From 93876a2653fe76f90f6c2090fd6d647c88cb680f Mon Sep 17 00:00:00 2001 From: ARIA Date: Fri, 10 Apr 2026 13:50:19 +0200 Subject: [PATCH] Refactor: improve thread safety and code organization - Add dedicated _breaker_lock for thread-safe circuit breaker state access - Extract _format_search_results() helper to eliminate DRY violation - Document unused session_id parameters for API compatibility --- __init__.py | 78 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/__init__.py b/__init__.py index 104f676..c5257cd 100644 --- a/__init__.py +++ b/__init__.py @@ -151,6 +151,7 @@ class Mem0LocalMemoryProvider(MemoryProvider): self._prefetch_thread = None self._sync_thread = None # Circuit breaker state + self._breaker_lock = threading.Lock() self._consecutive_failures = 0 self._breaker_open_until = 0.0 @@ -230,26 +231,38 @@ class Mem0LocalMemoryProvider(MemoryProvider): def _is_breaker_open(self) -> bool: """Return True if the circuit breaker is tripped (too many failures).""" - if self._consecutive_failures < _BREAKER_THRESHOLD: - return False - if time.monotonic() >= self._breaker_open_until: - self._consecutive_failures = 0 - return False - return True + with self._breaker_lock: + if self._consecutive_failures < _BREAKER_THRESHOLD: + return False + if time.monotonic() >= self._breaker_open_until: + self._consecutive_failures = 0 + return False + return True def _record_success(self): - self._consecutive_failures = 0 + with self._breaker_lock: + self._consecutive_failures = 0 def _record_failure(self): - self._consecutive_failures += 1 - if self._consecutive_failures >= _BREAKER_THRESHOLD: - self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS - logger.warning( - "Mem0 circuit breaker tripped after %d consecutive failures. " - "Pausing API calls for %ds.", - self._consecutive_failures, - _BREAKER_COOLDOWN_SECS, - ) + with self._breaker_lock: + self._consecutive_failures += 1 + if self._consecutive_failures >= _BREAKER_THRESHOLD: + self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS + logger.warning( + "Mem0 circuit breaker tripped after %d consecutive failures. " + "Pausing API calls for %ds.", + self._consecutive_failures, + _BREAKER_COOLDOWN_SECS, + ) + + def _format_search_results(self, results: List[Dict]) -> str: + """Format search results into a bullet list string.""" + lines = [ + r.get("text") or r.get("memory", "") + for r in results + if r.get("text") or r.get("memory") + ] + return "\n".join(f"- {line}" for line in lines) if lines else "" def initialize(self, session_id: str, **kwargs) -> None: self._config = _load_config() @@ -296,21 +309,22 @@ class Mem0LocalMemoryProvider(MemoryProvider): limit=5, ) if results: - lines = [ - r.get("text") or r.get("memory", "") - for r in results - if r.get("text") or r.get("memory") - ] - if lines: + formatted = self._format_search_results(results) + if formatted: self._record_success() - return "\n".join(f"- {line}" for line in lines) + return formatted except Exception as e: self._record_failure() logger.debug("Mem0 prefetch failed: %s", e) return "" def queue_prefetch(self, query: str, *, session_id: str = "") -> None: - """Queue async prefetch for next turn (called before LLM request).""" + """Queue async prefetch for next turn (called before LLM request). + + Args: + query: Search query for prefetching memories. + session_id: Unused. Kept for API compatibility. + """ if self._is_breaker_open(): return @@ -323,13 +337,9 @@ class Mem0LocalMemoryProvider(MemoryProvider): limit=5, ) if results: - lines = [ - r.get("text") or r.get("memory", "") - for r in results - if r.get("text") or r.get("memory") - ] + formatted = self._format_search_results(results) with self._prefetch_lock: - self._prefetch_result = "\n".join(f"- {line}" for line in lines) + self._prefetch_result = formatted self._record_success() except Exception as e: self._record_failure() @@ -343,7 +353,13 @@ class Mem0LocalMemoryProvider(MemoryProvider): def sync_turn( self, user_content: str, assistant_content: str, *, session_id: str = "" ) -> None: - """Send the turn to Mem0 for server-side fact extraction (non-blocking).""" + """Send the turn to Mem0 for server-side fact extraction (non-blocking). + + Args: + user_content: User message content. + assistant_content: Assistant response content. + session_id: Unused. Kept for API compatibility. + """ if self._is_breaker_open(): return