Merge pull request 'Fix code review issues and improve code quality' (#1) from fixes/code-review-issues into main
Reviewed-on: #1
This commit was merged in pull request #1.
This commit is contained in:
+40
-22
@@ -151,6 +151,7 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
self._prefetch_thread = None
|
||||
self._sync_thread = None
|
||||
# Circuit breaker state
|
||||
self._breaker_lock = threading.Lock()
|
||||
self._consecutive_failures = 0
|
||||
self._breaker_open_until = 0.0
|
||||
|
||||
@@ -230,6 +231,7 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
|
||||
def _is_breaker_open(self) -> bool:
|
||||
"""Return True if the circuit breaker is tripped (too many failures)."""
|
||||
with self._breaker_lock:
|
||||
if self._consecutive_failures < _BREAKER_THRESHOLD:
|
||||
return False
|
||||
if time.monotonic() >= self._breaker_open_until:
|
||||
@@ -238,9 +240,11 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
return True
|
||||
|
||||
def _record_success(self):
|
||||
with self._breaker_lock:
|
||||
self._consecutive_failures = 0
|
||||
|
||||
def _record_failure(self):
|
||||
with self._breaker_lock:
|
||||
self._consecutive_failures += 1
|
||||
if self._consecutive_failures >= _BREAKER_THRESHOLD:
|
||||
self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS
|
||||
@@ -251,6 +255,15 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
_BREAKER_COOLDOWN_SECS,
|
||||
)
|
||||
|
||||
def _format_search_results(self, results: List[Dict]) -> str:
|
||||
"""Format search results into a bullet list string."""
|
||||
lines = [
|
||||
r.get("text") or r.get("memory", "")
|
||||
for r in results
|
||||
if r.get("text") or r.get("memory")
|
||||
]
|
||||
return "\n".join(f"- {line}" for line in lines) if lines else ""
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
self._config = _load_config()
|
||||
# Prefer gateway-provided user_id for per-user memory scoping
|
||||
@@ -260,14 +273,6 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
self._agent_id = self._config.get("agent_id", "hermes")
|
||||
self._rerank = self._config.get("rerank", True)
|
||||
|
||||
def _read_filters(self) -> Dict[str, Any]:
|
||||
"""Filters for search/get_all — scoped to user only."""
|
||||
return {"user_id": self._user_id}
|
||||
|
||||
def _write_filters(self) -> Dict[str, Any]:
|
||||
"""Filters for add — scoped to user + agent."""
|
||||
return {"user_id": self._user_id, "agent_id": self._agent_id}
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
return (
|
||||
"# Mem0 Memory (Local)\n"
|
||||
@@ -276,8 +281,13 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
"mem0_profile for a full overview."
|
||||
)
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
"""Return cached prefetch result from previous turn."""
|
||||
def prefetch(self, query: str = "", *, session_id: str = "") -> str:
|
||||
"""Return cached prefetch result from previous turn.
|
||||
|
||||
Args:
|
||||
query: Deprecated, kept for API compatibility.
|
||||
session_id: Session identifier.
|
||||
"""
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
@@ -299,17 +309,22 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
limit=5,
|
||||
)
|
||||
if results:
|
||||
lines = [r.get("text", "") for r in results if r.get("text")]
|
||||
if lines:
|
||||
formatted = self._format_search_results(results)
|
||||
if formatted:
|
||||
self._record_success()
|
||||
return "\n".join(f"- {line}" for line in lines)
|
||||
return formatted
|
||||
except Exception as e:
|
||||
self._record_failure()
|
||||
logger.debug("Mem0 prefetch failed: %s", e)
|
||||
return ""
|
||||
|
||||
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
||||
"""Queue async prefetch for next turn (called before LLM request)."""
|
||||
"""Queue async prefetch for next turn (called before LLM request).
|
||||
|
||||
Args:
|
||||
query: Search query for prefetching memories.
|
||||
session_id: Unused. Kept for API compatibility.
|
||||
"""
|
||||
if self._is_breaker_open():
|
||||
return
|
||||
|
||||
@@ -322,13 +337,9 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
limit=5,
|
||||
)
|
||||
if results:
|
||||
lines = [
|
||||
r.get("text") or r.get("memory", "")
|
||||
for r in results
|
||||
if r.get("text") or r.get("memory")
|
||||
]
|
||||
formatted = self._format_search_results(results)
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_result = "\n".join(f"- {line}" for line in lines)
|
||||
self._prefetch_result = formatted
|
||||
self._record_success()
|
||||
except Exception as e:
|
||||
self._record_failure()
|
||||
@@ -342,18 +353,24 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
def sync_turn(
|
||||
self, user_content: str, assistant_content: str, *, session_id: str = ""
|
||||
) -> None:
|
||||
"""Send the turn to Mem0 for server-side fact extraction (non-blocking)."""
|
||||
"""Send the turn to Mem0 for server-side fact extraction (non-blocking).
|
||||
|
||||
Args:
|
||||
user_content: User message content.
|
||||
assistant_content: Assistant response content.
|
||||
session_id: Unused. Kept for API compatibility.
|
||||
"""
|
||||
if self._is_breaker_open():
|
||||
return
|
||||
|
||||
def _sync():
|
||||
try:
|
||||
client = self._get_client()
|
||||
# Combine user and assistant content for context
|
||||
combined = f"User: {user_content}\nAssistant: {assistant_content}"
|
||||
client.add(
|
||||
message=combined,
|
||||
user_id=self._user_id,
|
||||
agent_id=self._agent_id,
|
||||
)
|
||||
self._record_success()
|
||||
except Exception as e:
|
||||
@@ -427,6 +444,7 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
client.add(
|
||||
message=conclusion,
|
||||
user_id=self._user_id,
|
||||
agent_id=self._agent_id,
|
||||
)
|
||||
self._record_success()
|
||||
return json.dumps({"result": "Fact stored."})
|
||||
|
||||
@@ -94,17 +94,20 @@ class LocalMem0Client:
|
||||
self,
|
||||
message: str,
|
||||
user_id: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
) -> Dict:
|
||||
"""Add a new memory.
|
||||
|
||||
API: POST /add
|
||||
Request: {message, user_id, metadata}
|
||||
Request: {message, user_id, agent_id, metadata}
|
||||
Response: {success, memory_id, message}
|
||||
"""
|
||||
payload = {"message": message}
|
||||
if user_id:
|
||||
payload["user_id"] = user_id
|
||||
if agent_id:
|
||||
payload["agent_id"] = agent_id
|
||||
if metadata:
|
||||
payload["metadata"] = metadata
|
||||
return self._request("POST", "/add", json=payload)
|
||||
|
||||
Reference in New Issue
Block a user