# File: mem0-local-hermes-plugin/__init__.py (Python; 514 lines, 18 KiB)
"""Mem0 local server memory plugin — MemoryProvider interface.

Self-hosted Mem0 server with semantic search and automatic fact extraction.

Config via environment variables:
  MEM0_BASE_URL — Local Mem0 server URL (defaults to http://localhost:8000)
  MEM0_USER_ID  — User identifier for memory scoping (default: hermes-user)
  MEM0_AGENT_ID — Agent identifier (default: hermes)

Or via $HERMES_HOME/mem0-local.json; non-empty values in that file override
the environment variables.
"""
from __future__ import annotations
import json
import logging
import os
import threading
import time
from typing import Any, Dict, List, Optional
from agent.memory_provider import MemoryProvider
from tools.registry import tool_error
from .client import LocalMem0Client
logger = logging.getLogger(__name__)
# Circuit breaker: after this many consecutive failures, pause API calls.
# The failure counter is reset by any successful call.
_BREAKER_THRESHOLD = 5
# How long (in seconds) API calls stay paused once the breaker has tripped.
_BREAKER_COOLDOWN_SECS = 120
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
def _load_config() -> dict:
    """Load config from env vars, with $HERMES_HOME/mem0-local.json overrides.

    Environment variables (or built-in defaults) form the base configuration.
    If ``mem0-local.json`` exists in the Hermes home directory and holds a JSON
    object, every entry whose value is neither ``None`` nor ``""`` replaces the
    corresponding base value. A missing, unreadable or malformed file is never
    fatal: a warning is logged and the base configuration is returned.

    Returns:
        A dict with at least the keys ``base_url``, ``user_id``, ``agent_id``,
        ``rerank`` and ``timeout``.
    """
    # Imported lazily so that importing this module does not require it.
    from hermes_constants import get_hermes_home

    config = {
        "base_url": os.environ.get("MEM0_BASE_URL", "http://localhost:8000"),
        "user_id": os.environ.get("MEM0_USER_ID", "hermes-user"),
        "agent_id": os.environ.get("MEM0_AGENT_ID", "hermes"),
        "rerank": True,
        "timeout": 10.0,
    }

    config_path = get_hermes_home() / "mem0-local.json"
    if not config_path.exists():
        return config

    try:
        file_cfg = json.loads(config_path.read_text(encoding="utf-8"))
    except Exception as e:
        logger.warning("Failed to load mem0-local.json: %s", e)
        return config

    # Valid JSON that is not an object (e.g. a list or a bare string) has no
    # key/value pairs to merge; say so explicitly instead of failing on .items().
    if not isinstance(file_cfg, dict):
        logger.warning(
            "Failed to load mem0-local.json: expected a JSON object, got %s",
            type(file_cfg).__name__,
        )
        return config

    config.update({k: v for k, v in file_cfg.items() if v is not None and v != ""})
    return config
# ---------------------------------------------------------------------------
# Tool schemas
# ---------------------------------------------------------------------------
# Tool schema for "mem0_profile": takes no arguments.
PROFILE_SCHEMA = {
    "name": "mem0_profile",
    "description": (
        "Retrieve all stored memories about the user — preferences, facts, "
        "project context. Fast, no reranking. Use at conversation start."
    ),
    "parameters": {
        "type": "object",
        "properties": {},
        "required": [],
    },
}
# Argument definitions for "mem0_search"; only "query" is mandatory.
# NOTE(review): the mem0_search handler in this file reads "query" and "top_k"
# but does not read "rerank" — confirm whether it should be forwarded.
_SEARCH_PROPERTIES = {
    "query": {
        "type": "string",
        "description": "What to search for.",
    },
    "rerank": {
        "type": "boolean",
        "description": "Enable reranking for precision (default: false).",
    },
    "top_k": {
        "type": "integer",
        "description": "Max results (default: 10, max: 50).",
    },
}

# Tool schema for "mem0_search": semantic lookup over stored memories.
SEARCH_SCHEMA = {
    "name": "mem0_search",
    "description": (
        "Search memories by meaning. Returns relevant facts ranked by similarity. "
        "Set rerank=true for higher accuracy on important queries."
    ),
    "parameters": {
        "type": "object",
        "properties": _SEARCH_PROPERTIES,
        "required": ["query"],
    },
}
# Tool schema for "mem0_conclude": store one explicit fact.
# NOTE(review): the handler passes the fact to client.add() with the same
# arguments sync_turn uses — confirm the server really stores it verbatim.
CONCLUDE_SCHEMA = {
    "name": "mem0_conclude",
    "description": (
        "Store a durable fact about the user. Stored verbatim (no LLM extraction). "
        "Use for explicit preferences, corrections, or decisions."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "conclusion": {
                "type": "string",
                "description": "The fact to store.",
            },
        },
        "required": ["conclusion"],
    },
}
# Tool schema for "mem0_delete": remove a single memory identified by its ID.
DELETE_SCHEMA = {
    "name": "mem0_delete",
    "description": (
        "Delete a specific memory by ID. Use when user explicitly requests "
        "to remove or forget a stored fact."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "memory_id": {
                "type": "string",
                "description": "The ID of the memory to delete.",
            },
        },
        "required": ["memory_id"],
    },
}
# ---------------------------------------------------------------------------
# MemoryProvider implementation
# ---------------------------------------------------------------------------
class Mem0LocalMemoryProvider(MemoryProvider):
    """Self-hosted Mem0 memory with semantic search and fact extraction.

    The provider talks to a local Mem0 server through ``LocalMem0Client``.
    Network work for prefetching and turn syncing runs on daemon threads, and
    a simple circuit breaker pauses all calls after repeated failures.
    """

    def __init__(self):
        self._config: Optional[dict] = None
        self._client: Optional[LocalMem0Client] = None
        self._client_lock = threading.Lock()
        self._user_id = "hermes-user"
        self._agent_id = "hermes"
        self._rerank = True
        # Result of the most recent background prefetch, consumed by prefetch().
        self._prefetch_result = ""
        self._prefetch_lock = threading.Lock()
        self._prefetch_thread: Optional[threading.Thread] = None
        self._sync_thread: Optional[threading.Thread] = None
        # Circuit breaker state
        self._breaker_lock = threading.Lock()
        self._consecutive_failures = 0
        self._breaker_open_until = 0.0

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_bool(value: Any, default: bool = True) -> bool:
        """Coerce a config value to bool.

        The config schema offers ``rerank`` as the strings "true"/"false", and
        a non-empty string such as "false" would otherwise be truthy.
        """
        if isinstance(value, bool):
            return value
        if value is None:
            return default
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered in ("true", "1", "yes", "on"):
                return True
            if lowered in ("false", "0", "no", "off"):
                return False
            return default
        return bool(value)

    @staticmethod
    def _memory_text(record: Dict[str, Any]) -> str:
        """Return the memory text of a record, accepting "text" or "memory"."""
        return record.get("text") or record.get("memory") or ""

    @staticmethod
    def _clamp_top_k(raw: Any, default: int = 10, maximum: int = 50) -> int:
        """Parse a tool-supplied ``top_k`` and clamp it to ``1..maximum``.

        Values that cannot be converted to int (including ``None``) fall back
        to ``default`` instead of raising.
        """
        try:
            value = int(raw)
        except (TypeError, ValueError):
            return default
        return max(1, min(value, maximum))

    # ------------------------------------------------------------------
    # Provider identity / configuration
    # ------------------------------------------------------------------
    @property
    def name(self) -> str:
        return "mem0-local"

    def is_available(self) -> bool:
        """Return the client's health() result; False if the check raises."""
        try:
            return self._get_client().health()
        except Exception as e:
            logger.debug("Mem0 availability check failed: %s", e)
            return False

    def save_config(self, values: dict, hermes_home) -> None:
        """Write config to $HERMES_HOME/mem0-local.json.

        Existing entries in the file are kept and updated with ``values``. An
        existing file that cannot be read, parsed, or is not a JSON object is
        reported with a warning and replaced.

        Args:
            values: Config entries to store.
            hermes_home: Directory that holds ``mem0-local.json``.
        """
        from pathlib import Path

        config_path = Path(hermes_home) / "mem0-local.json"
        existing: dict = {}
        if config_path.exists():
            try:
                # Same encoding that _load_config() uses when reading the file.
                loaded = json.loads(config_path.read_text(encoding="utf-8"))
            except Exception as e:
                logger.warning("Ignoring unreadable mem0-local.json: %s", e)
            else:
                if isinstance(loaded, dict):
                    existing = loaded
                else:
                    logger.warning(
                        "Ignoring mem0-local.json: expected a JSON object, got %s",
                        type(loaded).__name__,
                    )
        existing.update(values)
        config_path.write_text(json.dumps(existing, indent=2), encoding="utf-8")

    def get_config_schema(self):
        """Describe the configurable keys of this provider."""
        return [
            {
                "key": "base_url",
                "description": "Local Mem0 server URL",
                "required": True,
                "env_var": "MEM0_BASE_URL",
                "url": "https://github.com/mem0ai/mem0",
            },
            {
                "key": "user_id",
                "description": "User identifier",
                "default": "hermes-user",
            },
            {"key": "agent_id", "description": "Agent identifier", "default": "hermes"},
            {
                "key": "rerank",
                "description": "Enable reranking for recall",
                "default": "true",
                "choices": ["true", "false"],
            },
            {
                "key": "timeout",
                "description": "Request timeout in seconds",
                "default": "10.0",
            },
        ]

    def _get_client(self) -> LocalMem0Client:
        """Thread-safe client accessor with lazy initialization."""
        with self._client_lock:
            if self._client is not None:
                return self._client
            # Lazy config loading if initialize() wasn't called
            if self._config is None:
                self._config = _load_config()
                self._user_id = self._config.get("user_id", "hermes-user")
                self._agent_id = self._config.get("agent_id", "hermes")
                self._rerank = self._as_bool(self._config.get("rerank", True))
            base_url = self._config.get("base_url", "http://localhost:8000")
            timeout = float(self._config.get("timeout", 10.0))
            self._client = LocalMem0Client(base_url, timeout=timeout)
            return self._client

    # ------------------------------------------------------------------
    # Circuit breaker
    # ------------------------------------------------------------------
    def _is_breaker_open(self) -> bool:
        """Return True if the circuit breaker is tripped (too many failures)."""
        with self._breaker_lock:
            if self._consecutive_failures < _BREAKER_THRESHOLD:
                return False
            if time.monotonic() >= self._breaker_open_until:
                # Cooldown elapsed: close the breaker and start counting afresh.
                self._consecutive_failures = 0
                return False
            return True

    def _record_success(self) -> None:
        """Reset the consecutive-failure counter."""
        with self._breaker_lock:
            self._consecutive_failures = 0

    def _record_failure(self) -> None:
        """Count one failure and trip the breaker once the threshold is reached."""
        with self._breaker_lock:
            self._consecutive_failures += 1
            if self._consecutive_failures >= _BREAKER_THRESHOLD:
                self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS
                logger.warning(
                    "Mem0 circuit breaker tripped after %d consecutive failures. "
                    "Pausing API calls for %ds.",
                    self._consecutive_failures,
                    _BREAKER_COOLDOWN_SECS,
                )

    def _format_search_results(self, results: List[Dict]) -> str:
        """Format search results into a bullet list string.

        Records without any memory text are skipped; an empty string is
        returned when nothing remains.
        """
        texts = [self._memory_text(r) for r in results]
        return "\n".join(f"- {text}" for text in texts if text)

    # ------------------------------------------------------------------
    # MemoryProvider lifecycle
    # ------------------------------------------------------------------
    def initialize(self, session_id: str, **kwargs) -> None:
        """Load configuration for a session.

        Args:
            session_id: Session identifier (not used here).
            **kwargs: May carry ``user_id``, which takes precedence over the
                configured user id.
        """
        self._config = _load_config()
        # Prefer gateway-provided user_id for per-user memory scoping
        self._user_id = kwargs.get("user_id") or self._config.get(
            "user_id", "hermes-user"
        )
        self._agent_id = self._config.get("agent_id", "hermes")
        self._rerank = self._as_bool(self._config.get("rerank", True))

    def system_prompt_block(self) -> str:
        """Return the text block that advertises this provider to the model."""
        return (
            "# Mem0 Memory (Local)\n"
            f"Active. User: {self._user_id}.\n"
            "Use mem0_search to find memories, mem0_conclude to store facts, "
            "mem0_profile for a full overview."
        )

    def prefetch(self, query: str = "", *, session_id: str = "") -> str:
        """Return cached prefetch result from previous turn.

        Waits up to 3 seconds for a still-running prefetch thread, then hands
        out the cached result and clears it.

        Args:
            query: Deprecated, kept for API compatibility.
            session_id: Session identifier.

        Returns:
            The cached memories under a "## Mem0 Memory" heading, or "" when
            nothing is cached.
        """
        worker = self._prefetch_thread
        if worker and worker.is_alive():
            worker.join(timeout=3.0)
        with self._prefetch_lock:
            result = self._prefetch_result
            self._prefetch_result = ""
        if not result:
            return ""
        return f"## Mem0 Memory\n{result}"

    def queue_prefetch_and_get(self, query: str) -> str:
        """Sync prefetch for pre_llm_call hook - returns memory context immediately.

        Returns:
            A bullet list of up to 5 matching memories, or "" when the query is
            empty, the breaker is open, nothing matched, or the call failed.
        """
        if not query or self._is_breaker_open():
            return ""
        try:
            results = self._get_client().search(
                query=query,
                user_id=self._user_id,
                limit=5,
            )
            formatted = self._format_search_results(results) if results else ""
            # A call that returned normally counts as a success even when it
            # found nothing, so an empty store cannot keep old failures alive.
            self._record_success()
            return formatted
        except Exception as e:
            self._record_failure()
            logger.debug("Mem0 prefetch failed: %s", e)
            return ""

    def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
        """Queue async prefetch for next turn (called before LLM request).

        Args:
            query: Search query for prefetching memories.
            session_id: Unused. Kept for API compatibility.
        """
        if not query or self._is_breaker_open():
            return

        def _run():
            try:
                results = self._get_client().search(
                    query=query,
                    user_id=self._user_id,
                    limit=5,
                )
                if results:
                    formatted = self._format_search_results(results)
                    with self._prefetch_lock:
                        self._prefetch_result = formatted
                self._record_success()
            except Exception as e:
                self._record_failure()
                logger.debug("Mem0 prefetch failed: %s", e)

        self._prefetch_thread = threading.Thread(
            target=_run, daemon=True, name="mem0-local-prefetch"
        )
        self._prefetch_thread.start()

    def sync_turn(
        self, user_content: str, assistant_content: str, *, session_id: str = ""
    ) -> None:
        """Send the turn to Mem0 for server-side fact extraction (non-blocking).

        Args:
            user_content: User message content.
            assistant_content: Assistant response content.
            session_id: Unused. Kept for API compatibility.
        """
        if self._is_breaker_open():
            return

        def _sync():
            try:
                combined = f"User: {user_content}\nAssistant: {assistant_content}"
                self._get_client().add(
                    message=combined,
                    user_id=self._user_id,
                    agent_id=self._agent_id,
                )
                self._record_success()
            except Exception as e:
                self._record_failure()
                logger.warning("Mem0 sync failed: %s", e)

        # Give a previous sync up to 5 seconds to finish before starting the next.
        previous = self._sync_thread
        if previous and previous.is_alive():
            previous.join(timeout=5.0)
        self._sync_thread = threading.Thread(
            target=_sync, daemon=True, name="mem0-local-sync"
        )
        self._sync_thread.start()

    # ------------------------------------------------------------------
    # Tools
    # ------------------------------------------------------------------
    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA, DELETE_SCHEMA]

    def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
        """Execute one of the mem0_* tools and return its result as a string.

        Args:
            tool_name: One of "mem0_profile", "mem0_search", "mem0_conclude",
                "mem0_delete".
            args: Tool arguments as described by the matching schema.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            A JSON string on success or while the breaker is open, otherwise
            whatever ``tool_error`` returns.
        """
        if self._is_breaker_open():
            return json.dumps(
                {
                    "error": "Mem0 server temporarily unavailable (multiple consecutive failures). Will retry automatically."
                }
            )
        try:
            client = self._get_client()
        except Exception as e:
            return tool_error(str(e))

        args = args or {}
        if tool_name == "mem0_profile":
            return self._tool_profile(client)
        if tool_name == "mem0_search":
            return self._tool_search(client, args)
        if tool_name == "mem0_conclude":
            return self._tool_conclude(client, args)
        if tool_name == "mem0_delete":
            return self._tool_delete(client, args)
        return tool_error(f"Unknown tool: {tool_name}")

    def _tool_profile(self, client: LocalMem0Client) -> str:
        """Handle mem0_profile: return every stored memory of the current user."""
        try:
            memories = client.get_all(user_id=self._user_id)
            self._record_success()
            if not memories:
                return json.dumps({"result": "No memories stored yet."})
            # Accept "text" or "memory", matching _format_search_results().
            lines = [text for text in map(self._memory_text, memories) if text]
            return json.dumps({"result": "\n".join(lines), "count": len(lines)})
        except Exception as e:
            self._record_failure()
            return tool_error(f"Failed to fetch profile: {e}")

    def _tool_search(self, client: LocalMem0Client, args: dict) -> str:
        """Handle mem0_search: semantic search limited to ``top_k`` results."""
        query = args.get("query", "")
        if not query:
            return tool_error("Missing required parameter: query")
        # NOTE(review): the schema also advertises "rerank"; it is not forwarded
        # because the search() signature is not visible here.
        top_k = self._clamp_top_k(args.get("top_k", 10))
        try:
            results = client.search(
                query=query,
                user_id=self._user_id,
                limit=top_k,
            )
            self._record_success()
            if not results:
                return json.dumps({"result": "No relevant memories found."})
            items = [
                {"memory": self._memory_text(r), "score": r.get("score", 0)}
                for r in results
            ]
            return json.dumps({"results": items, "count": len(items)})
        except Exception as e:
            self._record_failure()
            return tool_error(f"Search failed: {e}")

    def _tool_conclude(self, client: LocalMem0Client, args: dict) -> str:
        """Handle mem0_conclude: store one fact for the current user."""
        conclusion = args.get("conclusion", "")
        if not conclusion:
            return tool_error("Missing required parameter: conclusion")
        try:
            client.add(
                message=conclusion,
                user_id=self._user_id,
                agent_id=self._agent_id,
            )
            self._record_success()
            return json.dumps({"result": "Fact stored."})
        except Exception as e:
            self._record_failure()
            return tool_error(f"Failed to store: {e}")

    def _tool_delete(self, client: LocalMem0Client, args: dict) -> str:
        """Handle mem0_delete: delete the memory with the given ID."""
        memory_id = args.get("memory_id", "")
        if not memory_id:
            return tool_error("Missing required parameter: memory_id")
        try:
            client.delete(memory_id=memory_id)
            self._record_success()
            return json.dumps({"result": "Memory deleted."})
        except Exception as e:
            self._record_failure()
            return tool_error(f"Failed to delete: {e}")

    def shutdown(self) -> None:
        """Wait briefly for background threads, then drop the client."""
        for t in (self._prefetch_thread, self._sync_thread):
            if t and t.is_alive():
                t.join(timeout=5.0)
        with self._client_lock:
            self._client = None
def register(ctx) -> None:
    """Register Mem0 local as a memory provider plugin.

    Works in both contexts:
    - Memory provider context (plugins/memory/) — uses register_memory_provider()
    - General plugin context (~/.hermes/plugins/) — registers tools directly

    Args:
        ctx: Plugin context. If it exposes ``register_memory_provider`` the
            provider is registered through it; otherwise ``register_tool`` and
            ``register_hook`` are used.
    """
    provider = Mem0LocalMemoryProvider()

    # Memory provider context (plugins/memory/ directory)
    if hasattr(ctx, "register_memory_provider"):
        ctx.register_memory_provider(provider)
        return

    # General plugin context (~/.hermes/plugins/ directory)
    # Register tools manually since we can't register as memory provider
    def make_handler(tool_name: str):
        """Create a handler closure for a specific tool."""

        def _handler(args, **kwargs):
            return provider.handle_tool_call(tool_name, args)

        return _handler

    # Ask the provider for its schemas so the tool list is defined in one place.
    for schema in provider.get_tool_schemas():
        ctx.register_tool(
            name=schema["name"],
            toolset="mem0_local",
            schema=schema,
            handler=make_handler(schema["name"]),
        )
        logger.debug("Registered tool: %s", schema["name"])

    # Register pre_llm_call hook to inject memory context
    def pre_llm_call_hook(user_message: str = "", **kwargs) -> dict:
        """Inject memory context into user message before LLM call."""
        # Nothing to search for; skip the server round-trip entirely.
        if not user_message:
            return {}
        try:
            results = provider.queue_prefetch_and_get(user_message)
            if results:
                return {"context": results}
        except Exception as e:
            logger.debug("Mem0 pre_llm_call hook failed: %s", e)
        return {}

    ctx.register_hook("pre_llm_call", pre_llm_call_hook)
    logger.debug("Registered pre_llm_call hook")