From 2f98f63ecd20b896a26d28d8c4338ba8ca006fe5 Mon Sep 17 00:00:00 2001 From: ARIA Date: Fri, 10 Apr 2026 13:45:34 +0200 Subject: [PATCH 1/3] Fix code review issues - Add agent_id parameter to client.add() and use it in mem0_conclude - Fix inconsistent field access in queue_prefetch_and_get (check both text and memory) - Remove unused _read_filters() and _write_filters() methods - Mark prefetch() query parameter as deprecated (was unused) --- __init__.py | 24 +++++++++++++----------- client.py | 5 ++++- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/__init__.py b/__init__.py index f99ab0e..03193a2 100644 --- a/__init__.py +++ b/__init__.py @@ -260,14 +260,6 @@ class Mem0LocalMemoryProvider(MemoryProvider): self._agent_id = self._config.get("agent_id", "hermes") self._rerank = self._config.get("rerank", True) - def _read_filters(self) -> Dict[str, Any]: - """Filters for search/get_all — scoped to user only.""" - return {"user_id": self._user_id} - - def _write_filters(self) -> Dict[str, Any]: - """Filters for add — scoped to user + agent.""" - return {"user_id": self._user_id, "agent_id": self._agent_id} - def system_prompt_block(self) -> str: return ( "# Mem0 Memory (Local)\n" @@ -276,8 +268,13 @@ class Mem0LocalMemoryProvider(MemoryProvider): "mem0_profile for a full overview." ) - def prefetch(self, query: str, *, session_id: str = "") -> str: - """Return cached prefetch result from previous turn.""" + def prefetch(self, query: str = "", *, session_id: str = "") -> str: + """Return cached prefetch result from previous turn. + + Args: + query: Deprecated, kept for API compatibility. + session_id: Session identifier. + """ if self._prefetch_thread and self._prefetch_thread.is_alive(): self._prefetch_thread.join(timeout=3.0) with self._prefetch_lock: @@ -299,7 +296,11 @@ class Mem0LocalMemoryProvider(MemoryProvider): limit=5, ) if results: - lines = [r.get("text", "") for r in results if r.get("text")] + lines = [ + r.get("text") or r.get("memory", "") + for r in results + if r.get("text") or r.get("memory") + ] if lines: self._record_success() return "\n".join(f"- {line}" for line in lines) @@ -427,6 +428,7 @@ class Mem0LocalMemoryProvider(MemoryProvider): client.add( message=conclusion, user_id=self._user_id, + agent_id=self._agent_id, ) self._record_success() return json.dumps({"result": "Fact stored."}) diff --git a/client.py b/client.py index 2dbb234..31fe6f1 100644 --- a/client.py +++ b/client.py @@ -94,17 +94,20 @@ class LocalMem0Client: self, message: str, user_id: Optional[str] = None, + agent_id: Optional[str] = None, metadata: Optional[Dict] = None, ) -> Dict: """Add a new memory. API: POST /add - Request: {message, user_id, metadata} + Request: {message, user_id, agent_id, metadata} Response: {success, memory_id, message} """ payload = {"message": message} if user_id: payload["user_id"] = user_id + if agent_id: + payload["agent_id"] = agent_id if metadata: payload["metadata"] = metadata return self._request("POST", "/add", json=payload) From b70402ed30a1b413f231febad1b1fe8220d75092 Mon Sep 17 00:00:00 2001 From: ARIA Date: Fri, 10 Apr 2026 13:47:08 +0200 Subject: [PATCH 2/3] Fix missing agent_id in sync_turn() The sync_turn() method was not passing agent_id to client.add(), causing inconsistent memory scoping compared to mem0_conclude. --- __init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/__init__.py b/__init__.py index 03193a2..104f676 100644 --- a/__init__.py +++ b/__init__.py @@ -350,11 +350,11 @@ class Mem0LocalMemoryProvider(MemoryProvider): def _sync(): try: client = self._get_client() - # Combine user and assistant content for context combined = f"User: {user_content}\nAssistant: {assistant_content}" client.add( message=combined, user_id=self._user_id, + agent_id=self._agent_id, ) self._record_success() except Exception as e: From 93876a2653fe76f90f6c2090fd6d647c88cb680f Mon Sep 17 00:00:00 2001 From: ARIA Date: Fri, 10 Apr 2026 13:50:19 +0200 Subject: [PATCH 3/3] Refactor: improve thread safety and code organization - Add dedicated _breaker_lock for thread-safe circuit breaker state access - Extract _format_search_results() helper to eliminate DRY violation - Document unused session_id parameters for API compatibility --- __init__.py | 78 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/__init__.py b/__init__.py index 104f676..c5257cd 100644 --- a/__init__.py +++ b/__init__.py @@ -151,6 +151,7 @@ class Mem0LocalMemoryProvider(MemoryProvider): self._prefetch_thread = None self._sync_thread = None # Circuit breaker state + self._breaker_lock = threading.Lock() self._consecutive_failures = 0 self._breaker_open_until = 0.0 @@ -230,26 +231,38 @@ class Mem0LocalMemoryProvider(MemoryProvider): def _is_breaker_open(self) -> bool: """Return True if the circuit breaker is tripped (too many failures).""" - if self._consecutive_failures < _BREAKER_THRESHOLD: - return False - if time.monotonic() >= self._breaker_open_until: - self._consecutive_failures = 0 - return False - return True + with self._breaker_lock: + if self._consecutive_failures < _BREAKER_THRESHOLD: + return False + if time.monotonic() >= self._breaker_open_until: + self._consecutive_failures = 0 + return False + return True def _record_success(self): - self._consecutive_failures = 0 + with self._breaker_lock: + self._consecutive_failures = 0 def _record_failure(self): - self._consecutive_failures += 1 - if self._consecutive_failures >= _BREAKER_THRESHOLD: - self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS - logger.warning( - "Mem0 circuit breaker tripped after %d consecutive failures. " - "Pausing API calls for %ds.", - self._consecutive_failures, - _BREAKER_COOLDOWN_SECS, - ) + with self._breaker_lock: + self._consecutive_failures += 1 + if self._consecutive_failures >= _BREAKER_THRESHOLD: + self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS + logger.warning( + "Mem0 circuit breaker tripped after %d consecutive failures. " + "Pausing API calls for %ds.", + self._consecutive_failures, + _BREAKER_COOLDOWN_SECS, + ) + + def _format_search_results(self, results: List[Dict]) -> str: + """Format search results into a bullet list string.""" + lines = [ + r.get("text") or r.get("memory", "") + for r in results + if r.get("text") or r.get("memory") + ] + return "\n".join(f"- {line}" for line in lines) if lines else "" def initialize(self, session_id: str, **kwargs) -> None: self._config = _load_config() @@ -296,21 +309,22 @@ class Mem0LocalMemoryProvider(MemoryProvider): limit=5, ) if results: - lines = [ - r.get("text") or r.get("memory", "") - for r in results - if r.get("text") or r.get("memory") - ] - if lines: + formatted = self._format_search_results(results) + if formatted: self._record_success() - return "\n".join(f"- {line}" for line in lines) + return formatted except Exception as e: self._record_failure() logger.debug("Mem0 prefetch failed: %s", e) return "" def queue_prefetch(self, query: str, *, session_id: str = "") -> None: - """Queue async prefetch for next turn (called before LLM request).""" + """Queue async prefetch for next turn (called before LLM request). + + Args: + query: Search query for prefetching memories. + session_id: Unused. Kept for API compatibility. + """ if self._is_breaker_open(): return @@ -323,13 +337,9 @@ class Mem0LocalMemoryProvider(MemoryProvider): limit=5, ) if results: - lines = [ - r.get("text") or r.get("memory", "") - for r in results - if r.get("text") or r.get("memory") - ] + formatted = self._format_search_results(results) with self._prefetch_lock: - self._prefetch_result = "\n".join(f"- {line}" for line in lines) + self._prefetch_result = formatted self._record_success() except Exception as e: self._record_failure() @@ -343,7 +353,13 @@ class Mem0LocalMemoryProvider(MemoryProvider): def sync_turn( self, user_content: str, assistant_content: str, *, session_id: str = "" ) -> None: - """Send the turn to Mem0 for server-side fact extraction (non-blocking).""" + """Send the turn to Mem0 for server-side fact extraction (non-blocking). + + Args: + user_content: User message content. + assistant_content: Assistant response content. + session_id: Unused. Kept for API compatibility. + """ if self._is_breaker_open(): return