Merge pull request 'Fix code review issues and improve code quality' (#1) from fixes/code-review-issues into main

Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
2026-04-10 11:53:10 +00:00
2 changed files with 60 additions and 39 deletions
+56 -38
View File
@@ -151,6 +151,7 @@ class Mem0LocalMemoryProvider(MemoryProvider):
self._prefetch_thread = None self._prefetch_thread = None
self._sync_thread = None self._sync_thread = None
# Circuit breaker state # Circuit breaker state
self._breaker_lock = threading.Lock()
self._consecutive_failures = 0 self._consecutive_failures = 0
self._breaker_open_until = 0.0 self._breaker_open_until = 0.0
@@ -230,26 +231,38 @@ class Mem0LocalMemoryProvider(MemoryProvider):
def _is_breaker_open(self) -> bool: def _is_breaker_open(self) -> bool:
"""Return True if the circuit breaker is tripped (too many failures).""" """Return True if the circuit breaker is tripped (too many failures)."""
if self._consecutive_failures < _BREAKER_THRESHOLD: with self._breaker_lock:
return False if self._consecutive_failures < _BREAKER_THRESHOLD:
if time.monotonic() >= self._breaker_open_until: return False
self._consecutive_failures = 0 if time.monotonic() >= self._breaker_open_until:
return False self._consecutive_failures = 0
return True return False
return True
def _record_success(self): def _record_success(self):
self._consecutive_failures = 0 with self._breaker_lock:
self._consecutive_failures = 0
def _record_failure(self): def _record_failure(self):
self._consecutive_failures += 1 with self._breaker_lock:
if self._consecutive_failures >= _BREAKER_THRESHOLD: self._consecutive_failures += 1
self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS if self._consecutive_failures >= _BREAKER_THRESHOLD:
logger.warning( self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS
"Mem0 circuit breaker tripped after %d consecutive failures. " logger.warning(
"Pausing API calls for %ds.", "Mem0 circuit breaker tripped after %d consecutive failures. "
self._consecutive_failures, "Pausing API calls for %ds.",
_BREAKER_COOLDOWN_SECS, self._consecutive_failures,
) _BREAKER_COOLDOWN_SECS,
)
def _format_search_results(self, results: List[Dict]) -> str:
"""Format search results into a bullet list string."""
lines = [
r.get("text") or r.get("memory", "")
for r in results
if r.get("text") or r.get("memory")
]
return "\n".join(f"- {line}" for line in lines) if lines else ""
def initialize(self, session_id: str, **kwargs) -> None: def initialize(self, session_id: str, **kwargs) -> None:
self._config = _load_config() self._config = _load_config()
@@ -260,14 +273,6 @@ class Mem0LocalMemoryProvider(MemoryProvider):
self._agent_id = self._config.get("agent_id", "hermes") self._agent_id = self._config.get("agent_id", "hermes")
self._rerank = self._config.get("rerank", True) self._rerank = self._config.get("rerank", True)
def _read_filters(self) -> Dict[str, Any]:
"""Filters for search/get_all — scoped to user only."""
return {"user_id": self._user_id}
def _write_filters(self) -> Dict[str, Any]:
"""Filters for add — scoped to user + agent."""
return {"user_id": self._user_id, "agent_id": self._agent_id}
def system_prompt_block(self) -> str: def system_prompt_block(self) -> str:
return ( return (
"# Mem0 Memory (Local)\n" "# Mem0 Memory (Local)\n"
@@ -276,8 +281,13 @@ class Mem0LocalMemoryProvider(MemoryProvider):
"mem0_profile for a full overview." "mem0_profile for a full overview."
) )
def prefetch(self, query: str, *, session_id: str = "") -> str: def prefetch(self, query: str = "", *, session_id: str = "") -> str:
"""Return cached prefetch result from previous turn.""" """Return cached prefetch result from previous turn.
Args:
query: Deprecated, kept for API compatibility.
session_id: Session identifier.
"""
if self._prefetch_thread and self._prefetch_thread.is_alive(): if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=3.0) self._prefetch_thread.join(timeout=3.0)
with self._prefetch_lock: with self._prefetch_lock:
@@ -299,17 +309,22 @@ class Mem0LocalMemoryProvider(MemoryProvider):
limit=5, limit=5,
) )
if results: if results:
lines = [r.get("text", "") for r in results if r.get("text")] formatted = self._format_search_results(results)
if lines: if formatted:
self._record_success() self._record_success()
return "\n".join(f"- {line}" for line in lines) return formatted
except Exception as e: except Exception as e:
self._record_failure() self._record_failure()
logger.debug("Mem0 prefetch failed: %s", e) logger.debug("Mem0 prefetch failed: %s", e)
return "" return ""
def queue_prefetch(self, query: str, *, session_id: str = "") -> None: def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
"""Queue async prefetch for next turn (called before LLM request).""" """Queue async prefetch for next turn (called before LLM request).
Args:
query: Search query for prefetching memories.
session_id: Unused. Kept for API compatibility.
"""
if self._is_breaker_open(): if self._is_breaker_open():
return return
@@ -322,13 +337,9 @@ class Mem0LocalMemoryProvider(MemoryProvider):
limit=5, limit=5,
) )
if results: if results:
lines = [ formatted = self._format_search_results(results)
r.get("text") or r.get("memory", "")
for r in results
if r.get("text") or r.get("memory")
]
with self._prefetch_lock: with self._prefetch_lock:
self._prefetch_result = "\n".join(f"- {line}" for line in lines) self._prefetch_result = formatted
self._record_success() self._record_success()
except Exception as e: except Exception as e:
self._record_failure() self._record_failure()
@@ -342,18 +353,24 @@ class Mem0LocalMemoryProvider(MemoryProvider):
def sync_turn( def sync_turn(
self, user_content: str, assistant_content: str, *, session_id: str = "" self, user_content: str, assistant_content: str, *, session_id: str = ""
) -> None: ) -> None:
"""Send the turn to Mem0 for server-side fact extraction (non-blocking).""" """Send the turn to Mem0 for server-side fact extraction (non-blocking).
Args:
user_content: User message content.
assistant_content: Assistant response content.
session_id: Unused. Kept for API compatibility.
"""
if self._is_breaker_open(): if self._is_breaker_open():
return return
def _sync(): def _sync():
try: try:
client = self._get_client() client = self._get_client()
# Combine user and assistant content for context
combined = f"User: {user_content}\nAssistant: {assistant_content}" combined = f"User: {user_content}\nAssistant: {assistant_content}"
client.add( client.add(
message=combined, message=combined,
user_id=self._user_id, user_id=self._user_id,
agent_id=self._agent_id,
) )
self._record_success() self._record_success()
except Exception as e: except Exception as e:
@@ -427,6 +444,7 @@ class Mem0LocalMemoryProvider(MemoryProvider):
client.add( client.add(
message=conclusion, message=conclusion,
user_id=self._user_id, user_id=self._user_id,
agent_id=self._agent_id,
) )
self._record_success() self._record_success()
return json.dumps({"result": "Fact stored."}) return json.dumps({"result": "Fact stored."})
+4 -1
View File
@@ -94,17 +94,20 @@ class LocalMem0Client:
self, self,
message: str, message: str,
user_id: Optional[str] = None, user_id: Optional[str] = None,
agent_id: Optional[str] = None,
metadata: Optional[Dict] = None, metadata: Optional[Dict] = None,
) -> Dict: ) -> Dict:
"""Add a new memory. """Add a new memory.
API: POST /add API: POST /add
Request: {message, user_id, metadata} Request: {message, user_id, agent_id, metadata}
Response: {success, memory_id, message} Response: {success, memory_id, message}
""" """
payload = {"message": message} payload = {"message": message}
if user_id: if user_id:
payload["user_id"] = user_id payload["user_id"] = user_id
if agent_id:
payload["agent_id"] = agent_id
if metadata: if metadata:
payload["metadata"] = metadata payload["metadata"] = metadata
return self._request("POST", "/add", json=payload) return self._request("POST", "/add", json=payload)