514 lines
18 KiB
Python
514 lines
18 KiB
Python
"""Mem0 local server memory plugin — MemoryProvider interface.

Self-hosted Mem0 server with semantic search and automatic fact extraction.

Configuration via environment variables:
    MEM0_BASE_URL — Local Mem0 server URL (default: http://localhost:8000)
    MEM0_USER_ID  — User identifier for memory scoping (default: hermes-user)
    MEM0_AGENT_ID — Agent identifier (default: hermes)

Any non-empty value in $HERMES_HOME/mem0-local.json overrides the matching
environment setting; that file may also set ``rerank`` and ``timeout``.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import threading
|
|
import time
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from agent.memory_provider import MemoryProvider
|
|
from tools.registry import tool_error
|
|
|
|
from .client import LocalMem0Client
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Circuit breaker: after this many consecutive failures, pause API calls
|
|
_BREAKER_THRESHOLD = 5
|
|
_BREAKER_COOLDOWN_SECS = 120
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _load_config() -> dict:
|
|
"""Load config from env vars, with $HERMES_HOME/mem0-local.json overrides."""
|
|
from hermes_constants import get_hermes_home
|
|
|
|
config = {
|
|
"base_url": os.environ.get("MEM0_BASE_URL", "http://localhost:8000"),
|
|
"user_id": os.environ.get("MEM0_USER_ID", "hermes-user"),
|
|
"agent_id": os.environ.get("MEM0_AGENT_ID", "hermes"),
|
|
"rerank": True,
|
|
"timeout": 10.0,
|
|
}
|
|
|
|
config_path = get_hermes_home() / "mem0-local.json"
|
|
if config_path.exists():
|
|
try:
|
|
file_cfg = json.loads(config_path.read_text(encoding="utf-8"))
|
|
config.update(
|
|
{k: v for k, v in file_cfg.items() if v is not None and v != ""}
|
|
)
|
|
except Exception as e:
|
|
logger.warning("Failed to load mem0-local.json: %s", e)
|
|
|
|
return config
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tool schemas
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Each schema is a function-calling tool definition: the tool name, a
# description shown to the model, and a JSON-Schema object for its arguments.

PROFILE_SCHEMA = dict(
    name="mem0_profile",
    description=(
        "Retrieve all stored memories about the user — "
        "preferences, facts, project context. "
        "Fast, no reranking. Use at conversation start."
    ),
    # Takes no arguments.
    parameters=dict(type="object", properties={}, required=[]),
)

SEARCH_SCHEMA = dict(
    name="mem0_search",
    description=(
        "Search memories by meaning. "
        "Returns relevant facts ranked by similarity. "
        "Set rerank=true for higher accuracy on important queries."
    ),
    parameters=dict(
        type="object",
        properties=dict(
            query=dict(type="string", description="What to search for."),
            # NOTE(review): handle_tool_call does not read this argument;
            # confirm whether it should be forwarded to the client.
            rerank=dict(
                type="boolean",
                description="Enable reranking for precision (default: false).",
            ),
            top_k=dict(
                type="integer",
                description="Max results (default: 10, max: 50).",
            ),
        ),
        required=["query"],
    ),
)

CONCLUDE_SCHEMA = dict(
    name="mem0_conclude",
    # NOTE(review): the handler uses the same client.add() call as turn
    # syncing; confirm the client really skips LLM extraction here.
    description=(
        "Store a durable fact about the user. "
        "Stored verbatim (no LLM extraction). "
        "Use for explicit preferences, corrections, or decisions."
    ),
    parameters=dict(
        type="object",
        properties=dict(
            conclusion=dict(type="string", description="The fact to store."),
        ),
        required=["conclusion"],
    ),
)

DELETE_SCHEMA = dict(
    name="mem0_delete",
    description=(
        "Delete a specific memory by ID. "
        "Use when user explicitly requests to remove or forget a stored fact."
    ),
    parameters=dict(
        type="object",
        properties=dict(
            memory_id=dict(
                type="string",
                description="The ID of the memory to delete.",
            ),
        ),
        required=["memory_id"],
    ),
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MemoryProvider implementation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class Mem0LocalMemoryProvider(MemoryProvider):
    """Self-hosted Mem0 memory with semantic search and fact extraction.

    All server traffic goes through a lazily created ``LocalMem0Client``.
    Prefetch and turn syncing run on daemon threads. A circuit breaker skips
    API calls for ``_BREAKER_COOLDOWN_SECS`` seconds once
    ``_BREAKER_THRESHOLD`` consecutive calls have failed.
    """

    def __init__(self):
        self._config: Optional[dict] = None
        self._client: Optional[LocalMem0Client] = None
        self._client_lock = threading.Lock()
        self._user_id = "hermes-user"
        self._agent_id = "hermes"
        # Read from config; not currently passed to the client by this class.
        self._rerank = True
        # Formatted result of the latest background prefetch; consumed and
        # cleared by prefetch().
        self._prefetch_result = ""
        self._prefetch_lock = threading.Lock()
        self._prefetch_thread: Optional[threading.Thread] = None
        self._sync_thread: Optional[threading.Thread] = None
        # Circuit breaker state (guarded by _breaker_lock)
        self._breaker_lock = threading.Lock()
        self._consecutive_failures = 0
        self._breaker_open_until = 0.0

    @property
    def name(self) -> str:
        """Short identifier for this provider."""
        return "mem0-local"

    def is_available(self) -> bool:
        """Return the server health check result; any exception means unavailable."""
        try:
            return self._get_client().health()
        except Exception as e:
            logger.debug("Mem0 availability check failed: %s", e)
            return False

    def save_config(self, values: dict, hermes_home):
        """Merge ``values`` into ``$HERMES_HOME/mem0-local.json`` and write it.

        Keys already in the file but absent from ``values`` are preserved.
        An unreadable or non-object existing file is logged and replaced.
        """
        from pathlib import Path

        config_path = Path(hermes_home) / "mem0-local.json"
        existing: dict = {}
        if config_path.exists():
            try:
                loaded = json.loads(config_path.read_text(encoding="utf-8"))
            except (OSError, ValueError) as e:
                logger.warning(
                    "Could not read existing mem0-local.json; overwriting: %s", e
                )
            else:
                if isinstance(loaded, dict):
                    existing = loaded
                else:
                    logger.warning(
                        "Existing mem0-local.json is not a JSON object; overwriting"
                    )
        existing.update(values)
        config_path.parent.mkdir(parents=True, exist_ok=True)
        config_path.write_text(json.dumps(existing, indent=2), encoding="utf-8")

    def get_config_schema(self):
        """Return the user-configurable settings for this provider."""
        return [
            {
                "key": "base_url",
                "description": "Local Mem0 server URL",
                "required": True,
                "env_var": "MEM0_BASE_URL",
                "url": "https://github.com/mem0ai/mem0",
            },
            {
                "key": "user_id",
                "description": "User identifier",
                "default": "hermes-user",
            },
            {"key": "agent_id", "description": "Agent identifier", "default": "hermes"},
            {
                "key": "rerank",
                "description": "Enable reranking for recall",
                "default": "true",
                "choices": ["true", "false"],
            },
            {
                "key": "timeout",
                "description": "Request timeout in seconds",
                "default": "10.0",
            },
        ]

    def _get_client(self) -> LocalMem0Client:
        """Thread-safe client accessor with lazy initialization."""
        with self._client_lock:
            if self._client is not None:
                return self._client
            # Lazy config loading if initialize() wasn't called
            if self._config is None:
                self._config = _load_config()
                self._user_id = self._config.get("user_id", "hermes-user")
                self._agent_id = self._config.get("agent_id", "hermes")
                self._rerank = self._config.get("rerank", True)
            base_url = self._config.get("base_url", "http://localhost:8000")
            timeout = float(self._config.get("timeout", 10.0))
            self._client = LocalMem0Client(base_url, timeout=timeout)
            return self._client

    def _is_breaker_open(self) -> bool:
        """Return True while the breaker is tripped.

        Once the cooldown has elapsed the failure count is reset, so the next
        call is allowed through.
        """
        with self._breaker_lock:
            if self._consecutive_failures < _BREAKER_THRESHOLD:
                return False
            if time.monotonic() >= self._breaker_open_until:
                self._consecutive_failures = 0
                return False
            return True

    def _record_success(self):
        """Reset the consecutive-failure counter after a completed API call."""
        with self._breaker_lock:
            self._consecutive_failures = 0

    def _record_failure(self):
        """Count a failed API call and trip the breaker at the threshold."""
        with self._breaker_lock:
            self._consecutive_failures += 1
            if self._consecutive_failures >= _BREAKER_THRESHOLD:
                self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS
                logger.warning(
                    "Mem0 circuit breaker tripped after %d consecutive failures. "
                    "Pausing API calls for %ds.",
                    self._consecutive_failures,
                    _BREAKER_COOLDOWN_SECS,
                )

    @staticmethod
    def _memory_text(record: Any) -> str:
        """Return a memory record's text, accepting ``text`` or ``memory`` keys.

        Non-dict records yield an empty string.
        """
        if not isinstance(record, dict):
            return ""
        return record.get("text") or record.get("memory") or ""

    def _format_search_results(self, results: List[Dict]) -> str:
        """Format search results into a bullet list string ("" if none have text)."""
        lines = [text for text in map(self._memory_text, results) if text]
        return "\n".join(f"- {line}" for line in lines) if lines else ""

    def initialize(self, session_id: str, **kwargs) -> None:
        """Load configuration for a new session.

        A ``user_id`` keyword takes precedence over the configured one. Any
        cached client is dropped so it is rebuilt from the fresh settings.
        """
        config = _load_config()
        with self._client_lock:
            self._config = config
            # Prefer gateway-provided user_id for per-user memory scoping
            self._user_id = kwargs.get("user_id") or config.get(
                "user_id", "hermes-user"
            )
            self._agent_id = config.get("agent_id", "hermes")
            self._rerank = config.get("rerank", True)
            # A client built earlier (e.g. by is_available) may use stale settings.
            self._client = None

    def system_prompt_block(self) -> str:
        """Return the system-prompt snippet announcing the memory tools."""
        return (
            "# Mem0 Memory (Local)\n"
            f"Active. User: {self._user_id}.\n"
            "Use mem0_search to find memories, mem0_conclude to store facts, "
            "mem0_profile for a full overview."
        )

    def prefetch(self, query: str = "", *, session_id: str = "") -> str:
        """Return cached prefetch result from previous turn.

        Waits up to 3 s for an in-flight prefetch, then returns and clears
        the cached result.

        Args:
            query: Deprecated, kept for API compatibility.
            session_id: Session identifier.
        """
        if self._prefetch_thread and self._prefetch_thread.is_alive():
            self._prefetch_thread.join(timeout=3.0)
        with self._prefetch_lock:
            result = self._prefetch_result
            self._prefetch_result = ""
        if not result:
            return ""
        return f"## Mem0 Memory\n{result}"

    def queue_prefetch_and_get(self, query: str) -> str:
        """Sync prefetch for pre_llm_call hook - returns memory context immediately.

        Returns "" for an empty query, while the breaker is open, on failure,
        or when nothing relevant is found.
        """
        if not (query and query.strip()) or self._is_breaker_open():
            return ""
        try:
            client = self._get_client()
            results = client.search(
                query=query,
                user_id=self._user_id,
                limit=5,
            )
            # A completed request is healthy even if it found nothing.
            self._record_success()
            return self._format_search_results(results or [])
        except Exception as e:
            self._record_failure()
            logger.debug("Mem0 prefetch failed: %s", e)
        return ""

    def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
        """Queue async prefetch for next turn (called before LLM request).

        Args:
            query: Search query for prefetching memories.
            session_id: Unused. Kept for API compatibility.
        """
        if not (query and query.strip()) or self._is_breaker_open():
            return

        def _run():
            try:
                client = self._get_client()
                results = client.search(
                    query=query,
                    user_id=self._user_id,
                    limit=5,
                )
                self._record_success()
                formatted = self._format_search_results(results or [])
                # Overwrite even when empty so a stale result is not served.
                with self._prefetch_lock:
                    self._prefetch_result = formatted
            except Exception as e:
                self._record_failure()
                logger.debug("Mem0 prefetch failed: %s", e)

        self._prefetch_thread = threading.Thread(
            target=_run, daemon=True, name="mem0-local-prefetch"
        )
        self._prefetch_thread.start()

    def sync_turn(
        self, user_content: str, assistant_content: str, *, session_id: str = ""
    ) -> None:
        """Send the turn to Mem0 for server-side fact extraction in the background.

        Waits up to 5 s for the previous sync thread before starting a new one.
        Empty turns are skipped.

        Args:
            user_content: User message content.
            assistant_content: Assistant response content.
            session_id: Unused. Kept for API compatibility.
        """
        if not (user_content or assistant_content) or self._is_breaker_open():
            return

        def _sync():
            try:
                client = self._get_client()
                combined = f"User: {user_content}\nAssistant: {assistant_content}"
                client.add(
                    message=combined,
                    user_id=self._user_id,
                    agent_id=self._agent_id,
                )
                self._record_success()
            except Exception as e:
                self._record_failure()
                logger.warning("Mem0 sync failed: %s", e)

        if self._sync_thread and self._sync_thread.is_alive():
            self._sync_thread.join(timeout=5.0)

        self._sync_thread = threading.Thread(
            target=_sync, daemon=True, name="mem0-local-sync"
        )
        self._sync_thread.start()

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Return the tool schemas this provider handles."""
        return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA, DELETE_SCHEMA]

    def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
        """Execute one of the mem0_* tools and return a JSON string.

        Args:
            tool_name: Name of the tool being called.
            args: Tool arguments from the model (``None`` is treated as ``{}``).
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A JSON string result, or a ``tool_error`` string on failure.
        """
        if self._is_breaker_open():
            return json.dumps(
                {
                    "error": "Mem0 server temporarily unavailable (multiple consecutive failures). Will retry automatically."
                }
            )
        args = args or {}

        try:
            client = self._get_client()
        except Exception as e:
            return tool_error(str(e))

        if tool_name == "mem0_profile":
            try:
                memories = client.get_all(user_id=self._user_id)
                self._record_success()
                if not memories:
                    return json.dumps({"result": "No memories stored yet."})
                lines = [text for text in map(self._memory_text, memories) if text]
                return json.dumps({"result": "\n".join(lines), "count": len(lines)})
            except Exception as e:
                self._record_failure()
                return tool_error(f"Failed to fetch profile: {e}")

        elif tool_name == "mem0_search":
            query = args.get("query", "")
            if not query:
                return tool_error("Missing required parameter: query")
            raw_top_k = args.get("top_k")
            try:
                top_k = 10 if raw_top_k is None else int(raw_top_k)
            except (TypeError, ValueError):
                return tool_error("Invalid parameter: top_k must be an integer")
            # Clamp to 1..50 (the schema advertises a maximum of 50).
            top_k = max(1, min(top_k, 50))
            try:
                results = client.search(
                    query=query,
                    user_id=self._user_id,
                    limit=top_k,
                )
                self._record_success()
                if not results:
                    return json.dumps({"result": "No relevant memories found."})
                items = []
                for r in results:
                    item = {"memory": self._memory_text(r), "score": r.get("score", 0)}
                    # Expose the ID (when the server provides one) so the
                    # model can pass it to mem0_delete.
                    memory_id = r.get("id")
                    if memory_id is not None:
                        item["id"] = memory_id
                    items.append(item)
                return json.dumps({"results": items, "count": len(items)})
            except Exception as e:
                self._record_failure()
                return tool_error(f"Search failed: {e}")

        elif tool_name == "mem0_conclude":
            conclusion = args.get("conclusion", "")
            if not conclusion:
                return tool_error("Missing required parameter: conclusion")
            try:
                client.add(
                    message=conclusion,
                    user_id=self._user_id,
                    agent_id=self._agent_id,
                )
                self._record_success()
                return json.dumps({"result": "Fact stored."})
            except Exception as e:
                self._record_failure()
                return tool_error(f"Failed to store: {e}")

        elif tool_name == "mem0_delete":
            memory_id = args.get("memory_id", "")
            if not memory_id:
                return tool_error("Missing required parameter: memory_id")
            try:
                client.delete(memory_id=memory_id)
                self._record_success()
                return json.dumps({"result": "Memory deleted."})
            except Exception as e:
                self._record_failure()
                return tool_error(f"Failed to delete: {e}")

        return tool_error(f"Unknown tool: {tool_name}")

    def shutdown(self) -> None:
        """Wait briefly (up to 5 s each) for background threads, then drop the client."""
        for t in (self._prefetch_thread, self._sync_thread):
            if t and t.is_alive():
                t.join(timeout=5.0)
        with self._client_lock:
            self._client = None
|
|
|
|
|
|
def register(ctx) -> None:
    """Register Mem0 local as a memory provider plugin.

    Works in both contexts:
    - Memory provider context (plugins/memory/) — uses register_memory_provider()
    - General plugin context (~/.hermes/plugins/) — registers the tools
      directly, plus a ``pre_llm_call`` hook that injects relevant memories.

    Args:
        ctx: Plugin registration context from the host; only the
            ``register_memory_provider`` / ``register_tool`` /
            ``register_hook`` methods used below are touched.
    """
    provider = Mem0LocalMemoryProvider()

    # Memory provider context (plugins/memory/ directory)
    if hasattr(ctx, "register_memory_provider"):
        ctx.register_memory_provider(provider)
        return

    # General plugin context (~/.hermes/plugins/ directory)
    # Register tools manually since we can't register as memory provider
    def make_handler(tool_name: str):
        """Create a handler closure bound to one tool name."""

        def handler(args, **kwargs):
            return provider.handle_tool_call(tool_name, args, **kwargs)

        return handler

    # Reuse the provider's own list so both registration paths expose the
    # same tools.
    for schema in provider.get_tool_schemas():
        ctx.register_tool(
            name=schema["name"],
            toolset="mem0_local",
            schema=schema,
            handler=make_handler(schema["name"]),
        )
        logger.debug("Registered tool: %s", schema["name"])

    # Register pre_llm_call hook to inject memory context
    def pre_llm_call_hook(user_message: str, **kwargs) -> dict:
        """Return ``{"context": memories}`` for the message, or ``{}``.

        Empty messages skip the lookup entirely; failures are logged at
        debug level and yield ``{}``.
        """
        if not user_message or not str(user_message).strip():
            return {}
        try:
            results = provider.queue_prefetch_and_get(user_message)
        except Exception as e:
            logger.debug("Mem0 pre_llm_call hook failed: %s", e)
            return {}
        return {"context": results} if results else {}

    ctx.register_hook("pre_llm_call", pre_llm_call_hook)
    logger.debug("Registered pre_llm_call hook")
|