Refactor client API based on OpenAPI spec and add mem0_delete tool
- Update client.py to use correct API endpoints from OpenAPI spec - Add delete() method for memory deletion - Add mem0_delete tool with schema and handler - Simplify API calls to match actual Mem0 OSS server format
This commit is contained in:
+47
-27
@@ -112,6 +112,24 @@ CONCLUDE_SCHEMA = {
|
||||
},
|
||||
}
|
||||
|
||||
DELETE_SCHEMA = {
|
||||
"name": "mem0_delete",
|
||||
"description": (
|
||||
"Delete a specific memory by ID. Use when user explicitly requests "
|
||||
"to remove or forget a stored fact."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"memory_id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the memory to delete.",
|
||||
},
|
||||
},
|
||||
"required": ["memory_id"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider implementation
|
||||
@@ -279,9 +297,8 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
client = self._get_client()
|
||||
results = client.search(
|
||||
query=query,
|
||||
filters=self._read_filters(),
|
||||
rerank=self._rerank,
|
||||
top_k=5,
|
||||
user_id=self._user_id,
|
||||
limit=5,
|
||||
)
|
||||
if results:
|
||||
lines = [
|
||||
@@ -311,11 +328,12 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
def _sync():
|
||||
try:
|
||||
client = self._get_client()
|
||||
messages = [
|
||||
{"role": "user", "content": user_content},
|
||||
{"role": "assistant", "content": assistant_content},
|
||||
]
|
||||
client.add(messages, filters=self._write_filters(), infer=True)
|
||||
# Combine user and assistant content for context
|
||||
combined = f"User: {user_content}\nAssistant: {assistant_content}"
|
||||
client.add(
|
||||
message=combined,
|
||||
user_id=self._user_id,
|
||||
)
|
||||
self._record_success()
|
||||
except Exception as e:
|
||||
self._record_failure()
|
||||
@@ -330,7 +348,7 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
self._sync_thread.start()
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA]
|
||||
return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA, DELETE_SCHEMA]
|
||||
|
||||
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
|
||||
if self._is_breaker_open():
|
||||
@@ -347,15 +365,11 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
|
||||
if tool_name == "mem0_profile":
|
||||
try:
|
||||
memories = client.get_all(filters=self._read_filters())
|
||||
memories = client.get_all(user_id=self._user_id)
|
||||
self._record_success()
|
||||
if not memories:
|
||||
return json.dumps({"result": "No memories stored yet."})
|
||||
lines = [
|
||||
m.get("text") or m.get("memory", "")
|
||||
for m in memories
|
||||
if m.get("text") or m.get("memory")
|
||||
]
|
||||
lines = [m.get("text", "") for m in memories if m.get("text")]
|
||||
return json.dumps({"result": "\n".join(lines), "count": len(lines)})
|
||||
except Exception as e:
|
||||
self._record_failure()
|
||||
@@ -365,23 +379,18 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
query = args.get("query", "")
|
||||
if not query:
|
||||
return tool_error("Missing required parameter: query")
|
||||
rerank = args.get("rerank", False)
|
||||
top_k = min(int(args.get("top_k", 10)), 50)
|
||||
try:
|
||||
results = client.search(
|
||||
query=query,
|
||||
filters=self._read_filters(),
|
||||
rerank=rerank,
|
||||
top_k=top_k,
|
||||
user_id=self._user_id,
|
||||
limit=top_k,
|
||||
)
|
||||
self._record_success()
|
||||
if not results:
|
||||
return json.dumps({"result": "No relevant memories found."})
|
||||
items = [
|
||||
{
|
||||
"memory": r.get("text") or r.get("memory", ""),
|
||||
"score": r.get("score", 0),
|
||||
}
|
||||
{"memory": r.get("text", ""), "score": r.get("score", 0)}
|
||||
for r in results
|
||||
]
|
||||
return json.dumps({"results": items, "count": len(items)})
|
||||
@@ -395,9 +404,8 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
return tool_error("Missing required parameter: conclusion")
|
||||
try:
|
||||
client.add(
|
||||
[{"role": "user", "content": conclusion}],
|
||||
filters=self._write_filters(),
|
||||
infer=False, # Store verbatim
|
||||
message=conclusion,
|
||||
user_id=self._user_id,
|
||||
)
|
||||
self._record_success()
|
||||
return json.dumps({"result": "Fact stored."})
|
||||
@@ -405,6 +413,18 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
||||
self._record_failure()
|
||||
return tool_error(f"Failed to store: {e}")
|
||||
|
||||
elif tool_name == "mem0_delete":
|
||||
memory_id = args.get("memory_id", "")
|
||||
if not memory_id:
|
||||
return tool_error("Missing required parameter: memory_id")
|
||||
try:
|
||||
client.delete(memory_id=memory_id)
|
||||
self._record_success()
|
||||
return json.dumps({"result": "Memory deleted."})
|
||||
except Exception as e:
|
||||
self._record_failure()
|
||||
return tool_error(f"Failed to delete: {e}")
|
||||
|
||||
return tool_error(f"Unknown tool: {tool_name}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
@@ -435,7 +455,7 @@ def register(ctx) -> None:
|
||||
"""Create a handler closure for a specific tool."""
|
||||
return lambda args, **kwargs: provider.handle_tool_call(tool_name, args)
|
||||
|
||||
for schema in [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA]:
|
||||
for schema in [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA, DELETE_SCHEMA]:
|
||||
ctx.register_tool(
|
||||
name=schema["name"],
|
||||
toolset="mem0_local",
|
||||
|
||||
@@ -1,4 +1,13 @@
|
||||
"""Local Mem0 server HTTP client."""
|
||||
"""Local Mem0 server HTTP client.
|
||||
|
||||
Based on OpenAPI spec from http://10.0.0.150:8889/openapi.json
|
||||
|
||||
Endpoints:
|
||||
- POST /add - Add memory
|
||||
- POST /search - Search memories
|
||||
- GET /memories - Get all memories
|
||||
- DELETE /delete/{memory_id} - Delete memory
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -11,13 +20,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalMem0Client:
|
||||
"""HTTP client for self-hosted Mem0 server.
|
||||
|
||||
Expects Mem0 server at MEM0_BASE_URL with endpoints:
|
||||
- POST /search
|
||||
- GET /memories
|
||||
- POST /add
|
||||
"""
|
||||
"""HTTP client for self-hosted Mem0 server."""
|
||||
|
||||
def __init__(self, base_url: str, timeout: float = 10.0):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
@@ -60,60 +63,59 @@ class LocalMem0Client:
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
filters: Dict[str, Any],
|
||||
rerank: bool = False,
|
||||
top_k: int = 10,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
) -> List[Dict]:
|
||||
"""Search memories by semantic similarity."""
|
||||
payload = {
|
||||
"query": query,
|
||||
"user_id": filters.get("user_id"),
|
||||
"agent_id": filters.get("agent_id"),
|
||||
"top_k": top_k,
|
||||
}
|
||||
if rerank is not None:
|
||||
payload["rerank"] = rerank
|
||||
result = self._request("POST", "/search", json=payload)
|
||||
return self._unwrap_results(result)
|
||||
"""Search memories by semantic similarity.
|
||||
|
||||
def get_all(self, filters: Dict[str, Any]) -> List[Dict]:
|
||||
"""Get all memories matching filters."""
|
||||
params = filters
|
||||
API: POST /search
|
||||
Request: {query, user_id, limit}
|
||||
Response: {results: [{id, text, user_id, score, metadata}]}
|
||||
"""
|
||||
payload = {"query": query, "limit": limit}
|
||||
if user_id:
|
||||
payload["user_id"] = user_id
|
||||
result = self._request("POST", "/search", json=payload)
|
||||
return result.get("results", [])
|
||||
|
||||
def get_all(self, user_id: Optional[str] = None) -> List[Dict]:
|
||||
"""Get all memories for a user.
|
||||
|
||||
API: GET /memories?user_id=...
|
||||
Response: {memories: [{id, text, user_id, metadata}]}
|
||||
"""
|
||||
params = {}
|
||||
if user_id:
|
||||
params["user_id"] = user_id
|
||||
result = self._request("GET", "/memories", params=params)
|
||||
return self._unwrap_results(result)
|
||||
return result.get("memories", [])
|
||||
|
||||
def add(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
filters: Dict[str, Any],
|
||||
infer: bool = True,
|
||||
message: str,
|
||||
user_id: Optional[str] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
) -> Dict:
|
||||
"""Add conversation messages for fact extraction."""
|
||||
# Extract message content from messages array
|
||||
if messages and isinstance(messages[0], dict):
|
||||
message_content = messages[0].get("content", "")
|
||||
elif messages:
|
||||
message_content = str(messages[0])
|
||||
else:
|
||||
message_content = ""
|
||||
"""Add a new memory.
|
||||
|
||||
payload = {
|
||||
"message": message_content,
|
||||
"user_id": filters.get("user_id"),
|
||||
"agent_id": filters.get("agent_id"),
|
||||
}
|
||||
if not infer:
|
||||
payload["only_store_messages"] = True
|
||||
API: POST /add
|
||||
Request: {message, user_id, metadata}
|
||||
Response: {success, memory_id, message}
|
||||
"""
|
||||
payload = {"message": message}
|
||||
if user_id:
|
||||
payload["user_id"] = user_id
|
||||
if metadata:
|
||||
payload["metadata"] = metadata
|
||||
return self._request("POST", "/add", json=payload)
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_results(response: Any) -> List[Dict]:
|
||||
"""Normalize Mem0 API response."""
|
||||
if isinstance(response, dict):
|
||||
return response.get("memories", response.get("results", []))
|
||||
if isinstance(response, list):
|
||||
return response
|
||||
return []
|
||||
def delete(self, memory_id: str) -> Dict:
|
||||
"""Delete a memory by ID.
|
||||
|
||||
API: DELETE /delete/{memory_id}
|
||||
Response: {success, memory_id, message}
|
||||
"""
|
||||
return self._request("DELETE", f"/delete/{memory_id}")
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if server is reachable."""
|
||||
|
||||
@@ -15,6 +15,7 @@ provides_tools:
|
||||
- mem0_profile
|
||||
- mem0_search
|
||||
- mem0_conclude
|
||||
- mem0_delete
|
||||
|
||||
pip_dependencies:
|
||||
- requests
|
||||
|
||||
Reference in New Issue
Block a user