Fix code review issues and improve code quality #1
+47
-31
@@ -151,6 +151,7 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
self._prefetch_thread = None
|
self._prefetch_thread = None
|
||||||
self._sync_thread = None
|
self._sync_thread = None
|
||||||
# Circuit breaker state
|
# Circuit breaker state
|
||||||
|
self._breaker_lock = threading.Lock()
|
||||||
self._consecutive_failures = 0
|
self._consecutive_failures = 0
|
||||||
self._breaker_open_until = 0.0
|
self._breaker_open_until = 0.0
|
||||||
|
|
||||||
@@ -230,26 +231,38 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
|
|
||||||
def _is_breaker_open(self) -> bool:
|
def _is_breaker_open(self) -> bool:
|
||||||
"""Return True if the circuit breaker is tripped (too many failures)."""
|
"""Return True if the circuit breaker is tripped (too many failures)."""
|
||||||
if self._consecutive_failures < _BREAKER_THRESHOLD:
|
with self._breaker_lock:
|
||||||
return False
|
if self._consecutive_failures < _BREAKER_THRESHOLD:
|
||||||
if time.monotonic() >= self._breaker_open_until:
|
return False
|
||||||
self._consecutive_failures = 0
|
if time.monotonic() >= self._breaker_open_until:
|
||||||
return False
|
self._consecutive_failures = 0
|
||||||
return True
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def _record_success(self):
|
def _record_success(self):
|
||||||
self._consecutive_failures = 0
|
with self._breaker_lock:
|
||||||
|
self._consecutive_failures = 0
|
||||||
|
|
||||||
def _record_failure(self):
|
def _record_failure(self):
|
||||||
self._consecutive_failures += 1
|
with self._breaker_lock:
|
||||||
if self._consecutive_failures >= _BREAKER_THRESHOLD:
|
self._consecutive_failures += 1
|
||||||
self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS
|
if self._consecutive_failures >= _BREAKER_THRESHOLD:
|
||||||
logger.warning(
|
self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS
|
||||||
"Mem0 circuit breaker tripped after %d consecutive failures. "
|
logger.warning(
|
||||||
"Pausing API calls for %ds.",
|
"Mem0 circuit breaker tripped after %d consecutive failures. "
|
||||||
self._consecutive_failures,
|
"Pausing API calls for %ds.",
|
||||||
_BREAKER_COOLDOWN_SECS,
|
self._consecutive_failures,
|
||||||
)
|
_BREAKER_COOLDOWN_SECS,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _format_search_results(self, results: List[Dict]) -> str:
|
||||||
|
"""Format search results into a bullet list string."""
|
||||||
|
lines = [
|
||||||
|
r.get("text") or r.get("memory", "")
|
||||||
|
for r in results
|
||||||
|
if r.get("text") or r.get("memory")
|
||||||
|
]
|
||||||
|
return "\n".join(f"- {line}" for line in lines) if lines else ""
|
||||||
|
|
||||||
def initialize(self, session_id: str, **kwargs) -> None:
|
def initialize(self, session_id: str, **kwargs) -> None:
|
||||||
self._config = _load_config()
|
self._config = _load_config()
|
||||||
@@ -296,21 +309,22 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
limit=5,
|
limit=5,
|
||||||
)
|
)
|
||||||
if results:
|
if results:
|
||||||
lines = [
|
formatted = self._format_search_results(results)
|
||||||
r.get("text") or r.get("memory", "")
|
if formatted:
|
||||||
for r in results
|
|
||||||
if r.get("text") or r.get("memory")
|
|
||||||
]
|
|
||||||
if lines:
|
|
||||||
self._record_success()
|
self._record_success()
|
||||||
return "\n".join(f"- {line}" for line in lines)
|
return formatted
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._record_failure()
|
self._record_failure()
|
||||||
logger.debug("Mem0 prefetch failed: %s", e)
|
logger.debug("Mem0 prefetch failed: %s", e)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
||||||
"""Queue async prefetch for next turn (called before LLM request)."""
|
"""Queue async prefetch for next turn (called before LLM request).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query for prefetching memories.
|
||||||
|
session_id: Unused. Kept for API compatibility.
|
||||||
|
"""
|
||||||
if self._is_breaker_open():
|
if self._is_breaker_open():
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -323,13 +337,9 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
limit=5,
|
limit=5,
|
||||||
)
|
)
|
||||||
if results:
|
if results:
|
||||||
lines = [
|
formatted = self._format_search_results(results)
|
||||||
r.get("text") or r.get("memory", "")
|
|
||||||
for r in results
|
|
||||||
if r.get("text") or r.get("memory")
|
|
||||||
]
|
|
||||||
with self._prefetch_lock:
|
with self._prefetch_lock:
|
||||||
self._prefetch_result = "\n".join(f"- {line}" for line in lines)
|
self._prefetch_result = formatted
|
||||||
self._record_success()
|
self._record_success()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._record_failure()
|
self._record_failure()
|
||||||
@@ -343,7 +353,13 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
def sync_turn(
|
def sync_turn(
|
||||||
self, user_content: str, assistant_content: str, *, session_id: str = ""
|
self, user_content: str, assistant_content: str, *, session_id: str = ""
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Send the turn to Mem0 for server-side fact extraction (non-blocking)."""
|
"""Send the turn to Mem0 for server-side fact extraction (non-blocking).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_content: User message content.
|
||||||
|
assistant_content: Assistant response content.
|
||||||
|
session_id: Unused. Kept for API compatibility.
|
||||||
|
"""
|
||||||
if self._is_breaker_open():
|
if self._is_breaker_open():
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user