diff --git a/__init__.py b/__init__.py index f99ab0e..c5257cd 100644 --- a/__init__.py +++ b/__init__.py @@ -151,6 +151,7 @@ class Mem0LocalMemoryProvider(MemoryProvider): self._prefetch_thread = None self._sync_thread = None # Circuit breaker state + self._breaker_lock = threading.Lock() self._consecutive_failures = 0 self._breaker_open_until = 0.0 @@ -230,26 +231,38 @@ class Mem0LocalMemoryProvider(MemoryProvider): def _is_breaker_open(self) -> bool: """Return True if the circuit breaker is tripped (too many failures).""" - if self._consecutive_failures < _BREAKER_THRESHOLD: - return False - if time.monotonic() >= self._breaker_open_until: - self._consecutive_failures = 0 - return False - return True + with self._breaker_lock: + if self._consecutive_failures < _BREAKER_THRESHOLD: + return False + if time.monotonic() >= self._breaker_open_until: + self._consecutive_failures = 0 + return False + return True def _record_success(self): - self._consecutive_failures = 0 + with self._breaker_lock: + self._consecutive_failures = 0 def _record_failure(self): - self._consecutive_failures += 1 - if self._consecutive_failures >= _BREAKER_THRESHOLD: - self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS - logger.warning( - "Mem0 circuit breaker tripped after %d consecutive failures. " - "Pausing API calls for %ds.", - self._consecutive_failures, - _BREAKER_COOLDOWN_SECS, - ) + with self._breaker_lock: + self._consecutive_failures += 1 + if self._consecutive_failures >= _BREAKER_THRESHOLD: + self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS + logger.warning( + "Mem0 circuit breaker tripped after %d consecutive failures. " + "Pausing API calls for %ds.", + self._consecutive_failures, + _BREAKER_COOLDOWN_SECS, + ) + + def _format_search_results(self, results: List[Dict]) -> str: + """Format search results into a bullet list string.""" + lines = [ + r.get("text") or r.get("memory", "") + for r in results + if r.get("text") or r.get("memory") + ] + return "\n".join(f"- {line}" for line in lines) if lines else "" def initialize(self, session_id: str, **kwargs) -> None: self._config = _load_config() @@ -260,14 +273,6 @@ class Mem0LocalMemoryProvider(MemoryProvider): self._agent_id = self._config.get("agent_id", "hermes") self._rerank = self._config.get("rerank", True) - def _read_filters(self) -> Dict[str, Any]: - """Filters for search/get_all — scoped to user only.""" - return {"user_id": self._user_id} - - def _write_filters(self) -> Dict[str, Any]: - """Filters for add — scoped to user + agent.""" - return {"user_id": self._user_id, "agent_id": self._agent_id} - def system_prompt_block(self) -> str: return ( "# Mem0 Memory (Local)\n" @@ -276,8 +281,13 @@ class Mem0LocalMemoryProvider(MemoryProvider): "mem0_profile for a full overview." ) - def prefetch(self, query: str, *, session_id: str = "") -> str: - """Return cached prefetch result from previous turn.""" + def prefetch(self, query: str = "", *, session_id: str = "") -> str: + """Return cached prefetch result from previous turn. + + Args: + query: Deprecated, kept for API compatibility. + session_id: Session identifier. + """ if self._prefetch_thread and self._prefetch_thread.is_alive(): self._prefetch_thread.join(timeout=3.0) with self._prefetch_lock: @@ -299,17 +309,22 @@ class Mem0LocalMemoryProvider(MemoryProvider): limit=5, ) if results: - lines = [r.get("text", "") for r in results if r.get("text")] - if lines: + formatted = self._format_search_results(results) + if formatted: self._record_success() - return "\n".join(f"- {line}" for line in lines) + return formatted except Exception as e: self._record_failure() logger.debug("Mem0 prefetch failed: %s", e) return "" def queue_prefetch(self, query: str, *, session_id: str = "") -> None: - """Queue async prefetch for next turn (called before LLM request).""" + """Queue async prefetch for next turn (called before LLM request). + + Args: + query: Search query for prefetching memories. + session_id: Unused. Kept for API compatibility. + """ if self._is_breaker_open(): return @@ -322,13 +337,9 @@ class Mem0LocalMemoryProvider(MemoryProvider): limit=5, ) if results: - lines = [ - r.get("text") or r.get("memory", "") - for r in results - if r.get("text") or r.get("memory") - ] + formatted = self._format_search_results(results) with self._prefetch_lock: - self._prefetch_result = "\n".join(f"- {line}" for line in lines) + self._prefetch_result = formatted self._record_success() except Exception as e: self._record_failure() @@ -342,18 +353,24 @@ class Mem0LocalMemoryProvider(MemoryProvider): def sync_turn( self, user_content: str, assistant_content: str, *, session_id: str = "" ) -> None: - """Send the turn to Mem0 for server-side fact extraction (non-blocking).""" + """Send the turn to Mem0 for server-side fact extraction (non-blocking). + + Args: + user_content: User message content. + assistant_content: Assistant response content. + session_id: Unused. Kept for API compatibility. + """ if self._is_breaker_open(): return def _sync(): try: client = self._get_client() - # Combine user and assistant content for context combined = f"User: {user_content}\nAssistant: {assistant_content}" client.add( message=combined, user_id=self._user_id, + agent_id=self._agent_id, ) self._record_success() except Exception as e: @@ -427,6 +444,7 @@ class Mem0LocalMemoryProvider(MemoryProvider): client.add( message=conclusion, user_id=self._user_id, + agent_id=self._agent_id, ) self._record_success() return json.dumps({"result": "Fact stored."}) diff --git a/client.py b/client.py index 2dbb234..31fe6f1 100644 --- a/client.py +++ b/client.py @@ -94,17 +94,20 @@ class LocalMem0Client: self, message: str, user_id: Optional[str] = None, + agent_id: Optional[str] = None, metadata: Optional[Dict] = None, ) -> Dict: """Add a new memory. API: POST /add - Request: {message, user_id, metadata} + Request: {message, user_id, agent_id, metadata} Response: {success, memory_id, message} """ payload = {"message": message} if user_id: payload["user_id"] = user_id + if agent_id: + payload["agent_id"] = agent_id if metadata: payload["metadata"] = metadata return self._request("POST", "/add", json=payload)