diff --git a/__init__.py b/__init__.py index 0a3a627..19572f6 100644 --- a/__init__.py +++ b/__init__.py @@ -112,6 +112,24 @@ CONCLUDE_SCHEMA = { }, } +DELETE_SCHEMA = { + "name": "mem0_delete", + "description": ( + "Delete a specific memory by ID. Use when user explicitly requests " + "to remove or forget a stored fact." + ), + "parameters": { + "type": "object", + "properties": { + "memory_id": { + "type": "string", + "description": "The ID of the memory to delete.", + }, + }, + "required": ["memory_id"], + }, +} + # --------------------------------------------------------------------------- # MemoryProvider implementation @@ -279,9 +297,8 @@ class Mem0LocalMemoryProvider(MemoryProvider): client = self._get_client() results = client.search( query=query, - filters=self._read_filters(), - rerank=self._rerank, - top_k=5, + user_id=self._user_id, + limit=5, ) if results: lines = [ @@ -311,11 +328,12 @@ class Mem0LocalMemoryProvider(MemoryProvider): def _sync(): try: client = self._get_client() - messages = [ - {"role": "user", "content": user_content}, - {"role": "assistant", "content": assistant_content}, - ] - client.add(messages, filters=self._write_filters(), infer=True) + # Combine user and assistant content for context + combined = f"User: {user_content}\nAssistant: {assistant_content}" + client.add( + message=combined, + user_id=self._user_id, + ) self._record_success() except Exception as e: self._record_failure() @@ -330,7 +348,7 @@ class Mem0LocalMemoryProvider(MemoryProvider): self._sync_thread.start() def get_tool_schemas(self) -> List[Dict[str, Any]]: - return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA] + return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA, DELETE_SCHEMA] def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str: if self._is_breaker_open(): @@ -347,15 +365,11 @@ class Mem0LocalMemoryProvider(MemoryProvider): if tool_name == "mem0_profile": try: - memories = client.get_all(filters=self._read_filters()) + memories = client.get_all(user_id=self._user_id) self._record_success() if not memories: return json.dumps({"result": "No memories stored yet."}) - lines = [ - m.get("text") or m.get("memory", "") - for m in memories - if m.get("text") or m.get("memory") - ] + lines = [m.get("text", "") for m in memories if m.get("text")] return json.dumps({"result": "\n".join(lines), "count": len(lines)}) except Exception as e: self._record_failure() @@ -365,23 +379,18 @@ class Mem0LocalMemoryProvider(MemoryProvider): query = args.get("query", "") if not query: return tool_error("Missing required parameter: query") - rerank = args.get("rerank", False) top_k = min(int(args.get("top_k", 10)), 50) try: results = client.search( query=query, - filters=self._read_filters(), - rerank=rerank, - top_k=top_k, + user_id=self._user_id, + limit=top_k, ) self._record_success() if not results: return json.dumps({"result": "No relevant memories found."}) items = [ - { - "memory": r.get("text") or r.get("memory", ""), - "score": r.get("score", 0), - } + {"memory": r.get("text", ""), "score": r.get("score", 0)} for r in results ] return json.dumps({"results": items, "count": len(items)}) @@ -395,9 +404,8 @@ class Mem0LocalMemoryProvider(MemoryProvider): return tool_error("Missing required parameter: conclusion") try: client.add( - [{"role": "user", "content": conclusion}], - filters=self._write_filters(), - infer=False, # Store verbatim + message=conclusion, + user_id=self._user_id, ) self._record_success() return json.dumps({"result": "Fact stored."}) @@ -405,6 +413,18 @@ class Mem0LocalMemoryProvider(MemoryProvider): self._record_failure() return tool_error(f"Failed to store: {e}") + elif tool_name == "mem0_delete": + memory_id = args.get("memory_id", "") + if not memory_id: + return tool_error("Missing required parameter: memory_id") + try: + client.delete(memory_id=memory_id) + self._record_success() + return json.dumps({"result": "Memory deleted."}) + except Exception as e: + self._record_failure() + return tool_error(f"Failed to delete: {e}") + return tool_error(f"Unknown tool: {tool_name}") def shutdown(self) -> None: @@ -435,7 +455,7 @@ def register(ctx) -> None: """Create a handler closure for a specific tool.""" return lambda args, **kwargs: provider.handle_tool_call(tool_name, args) - for schema in [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA]: + for schema in [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA, DELETE_SCHEMA]: ctx.register_tool( name=schema["name"], toolset="mem0_local", diff --git a/client.py b/client.py index 604e60a..2dbb234 100644 --- a/client.py +++ b/client.py @@ -1,4 +1,13 @@ -"""Local Mem0 server HTTP client.""" +"""Local Mem0 server HTTP client. + +Based on OpenAPI spec from http://10.0.0.150:8889/openapi.json + +Endpoints: +- POST /add - Add memory +- POST /search - Search memories +- GET /memories - Get all memories +- DELETE /delete/{memory_id} - Delete memory +""" from __future__ import annotations @@ -11,13 +20,7 @@ logger = logging.getLogger(__name__) class LocalMem0Client: - """HTTP client for self-hosted Mem0 server. - - Expects Mem0 server at MEM0_BASE_URL with endpoints: - - POST /search - - GET /memories - - POST /add - """ + """HTTP client for self-hosted Mem0 server.""" def __init__(self, base_url: str, timeout: float = 10.0): self.base_url = base_url.rstrip("/") @@ -60,60 +63,59 @@ class LocalMem0Client: def search( self, query: str, - filters: Dict[str, Any], - rerank: bool = False, - top_k: int = 10, + user_id: Optional[str] = None, + limit: int = 5, ) -> List[Dict]: - """Search memories by semantic similarity.""" - payload = { - "query": query, - "user_id": filters.get("user_id"), - "agent_id": filters.get("agent_id"), - "top_k": top_k, - } - if rerank is not None: - payload["rerank"] = rerank + """Search memories by semantic similarity. + + API: POST /search + Request: {query, user_id, limit} + Response: {results: [{id, text, user_id, score, metadata}]} + """ + payload = {"query": query, "limit": limit} + if user_id: + payload["user_id"] = user_id result = self._request("POST", "/search", json=payload) - return self._unwrap_results(result) + return result.get("results", []) - def get_all(self, filters: Dict[str, Any]) -> List[Dict]: - """Get all memories matching filters.""" - params = filters + def get_all(self, user_id: Optional[str] = None) -> List[Dict]: + """Get all memories for a user. + + API: GET /memories?user_id=... + Response: {memories: [{id, text, user_id, metadata}]} + """ + params = {} + if user_id: + params["user_id"] = user_id result = self._request("GET", "/memories", params=params) - return self._unwrap_results(result) + return result.get("memories", []) - def add( + def add( self, - messages: List[Dict[str, str]], - filters: Dict[str, Any], - infer: bool = True, + message: str, + user_id: Optional[str] = None, + metadata: Optional[Dict] = None, ) -> Dict: - """Add conversation messages for fact extraction.""" - # Extract message content from messages array - if messages and isinstance(messages[0], dict): - message_content = messages[0].get("content", "") - elif messages: - message_content = str(messages[0]) - else: - message_content = "" + """Add a new memory. - payload = { - "message": message_content, - "user_id": filters.get("user_id"), - "agent_id": filters.get("agent_id"), - } - if not infer: - payload["only_store_messages"] = True + API: POST /add + Request: {message, user_id, metadata} + Response: {success, memory_id, message} + """ + payload = {"message": message} + if user_id: + payload["user_id"] = user_id + if metadata: + payload["metadata"] = metadata return self._request("POST", "/add", json=payload) - @staticmethod - def _unwrap_results(response: Any) -> List[Dict]: - """Normalize Mem0 API response.""" - if isinstance(response, dict): - return response.get("memories", response.get("results", [])) - if isinstance(response, list): - return response - return [] + def delete(self, memory_id: str) -> Dict: + """Delete a memory by ID. + + API: DELETE /delete/{memory_id} + Response: {success, memory_id, message} + """ + return self._request("DELETE", f"/delete/{memory_id}") def health(self) -> bool: """Check if server is reachable.""" diff --git a/plugin.yaml b/plugin.yaml index 03b2b58..c8ab4db 100644 --- a/plugin.yaml +++ b/plugin.yaml @@ -15,6 +15,7 @@ provides_tools: - mem0_profile - mem0_search - mem0_conclude + - mem0_delete pip_dependencies: - requests