Refactor client API based on OpenAPI spec and add mem0_delete tool
- Update client.py to use correct API endpoints from OpenAPI spec - Add delete() method for memory deletion - Add mem0_delete tool with schema and handler - Simplify API calls to match actual Mem0 OSS server format
This commit is contained in:
+47
-27
@@ -112,6 +112,24 @@ CONCLUDE_SCHEMA = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DELETE_SCHEMA = {
|
||||||
|
"name": "mem0_delete",
|
||||||
|
"description": (
|
||||||
|
"Delete a specific memory by ID. Use when user explicitly requests "
|
||||||
|
"to remove or forget a stored fact."
|
||||||
|
),
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"memory_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The ID of the memory to delete.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["memory_id"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# MemoryProvider implementation
|
# MemoryProvider implementation
|
||||||
@@ -279,9 +297,8 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
results = client.search(
|
results = client.search(
|
||||||
query=query,
|
query=query,
|
||||||
filters=self._read_filters(),
|
user_id=self._user_id,
|
||||||
rerank=self._rerank,
|
limit=5,
|
||||||
top_k=5,
|
|
||||||
)
|
)
|
||||||
if results:
|
if results:
|
||||||
lines = [
|
lines = [
|
||||||
@@ -311,11 +328,12 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
def _sync():
|
def _sync():
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
messages = [
|
# Combine user and assistant content for context
|
||||||
{"role": "user", "content": user_content},
|
combined = f"User: {user_content}\nAssistant: {assistant_content}"
|
||||||
{"role": "assistant", "content": assistant_content},
|
client.add(
|
||||||
]
|
message=combined,
|
||||||
client.add(messages, filters=self._write_filters(), infer=True)
|
user_id=self._user_id,
|
||||||
|
)
|
||||||
self._record_success()
|
self._record_success()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._record_failure()
|
self._record_failure()
|
||||||
@@ -330,7 +348,7 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
self._sync_thread.start()
|
self._sync_thread.start()
|
||||||
|
|
||||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||||
return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA]
|
return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA, DELETE_SCHEMA]
|
||||||
|
|
||||||
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
|
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
|
||||||
if self._is_breaker_open():
|
if self._is_breaker_open():
|
||||||
@@ -347,15 +365,11 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
|
|
||||||
if tool_name == "mem0_profile":
|
if tool_name == "mem0_profile":
|
||||||
try:
|
try:
|
||||||
memories = client.get_all(filters=self._read_filters())
|
memories = client.get_all(user_id=self._user_id)
|
||||||
self._record_success()
|
self._record_success()
|
||||||
if not memories:
|
if not memories:
|
||||||
return json.dumps({"result": "No memories stored yet."})
|
return json.dumps({"result": "No memories stored yet."})
|
||||||
lines = [
|
lines = [m.get("text", "") for m in memories if m.get("text")]
|
||||||
m.get("text") or m.get("memory", "")
|
|
||||||
for m in memories
|
|
||||||
if m.get("text") or m.get("memory")
|
|
||||||
]
|
|
||||||
return json.dumps({"result": "\n".join(lines), "count": len(lines)})
|
return json.dumps({"result": "\n".join(lines), "count": len(lines)})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._record_failure()
|
self._record_failure()
|
||||||
@@ -365,23 +379,18 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
query = args.get("query", "")
|
query = args.get("query", "")
|
||||||
if not query:
|
if not query:
|
||||||
return tool_error("Missing required parameter: query")
|
return tool_error("Missing required parameter: query")
|
||||||
rerank = args.get("rerank", False)
|
|
||||||
top_k = min(int(args.get("top_k", 10)), 50)
|
top_k = min(int(args.get("top_k", 10)), 50)
|
||||||
try:
|
try:
|
||||||
results = client.search(
|
results = client.search(
|
||||||
query=query,
|
query=query,
|
||||||
filters=self._read_filters(),
|
user_id=self._user_id,
|
||||||
rerank=rerank,
|
limit=top_k,
|
||||||
top_k=top_k,
|
|
||||||
)
|
)
|
||||||
self._record_success()
|
self._record_success()
|
||||||
if not results:
|
if not results:
|
||||||
return json.dumps({"result": "No relevant memories found."})
|
return json.dumps({"result": "No relevant memories found."})
|
||||||
items = [
|
items = [
|
||||||
{
|
{"memory": r.get("text", ""), "score": r.get("score", 0)}
|
||||||
"memory": r.get("text") or r.get("memory", ""),
|
|
||||||
"score": r.get("score", 0),
|
|
||||||
}
|
|
||||||
for r in results
|
for r in results
|
||||||
]
|
]
|
||||||
return json.dumps({"results": items, "count": len(items)})
|
return json.dumps({"results": items, "count": len(items)})
|
||||||
@@ -395,9 +404,8 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
return tool_error("Missing required parameter: conclusion")
|
return tool_error("Missing required parameter: conclusion")
|
||||||
try:
|
try:
|
||||||
client.add(
|
client.add(
|
||||||
[{"role": "user", "content": conclusion}],
|
message=conclusion,
|
||||||
filters=self._write_filters(),
|
user_id=self._user_id,
|
||||||
infer=False, # Store verbatim
|
|
||||||
)
|
)
|
||||||
self._record_success()
|
self._record_success()
|
||||||
return json.dumps({"result": "Fact stored."})
|
return json.dumps({"result": "Fact stored."})
|
||||||
@@ -405,6 +413,18 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
self._record_failure()
|
self._record_failure()
|
||||||
return tool_error(f"Failed to store: {e}")
|
return tool_error(f"Failed to store: {e}")
|
||||||
|
|
||||||
|
elif tool_name == "mem0_delete":
|
||||||
|
memory_id = args.get("memory_id", "")
|
||||||
|
if not memory_id:
|
||||||
|
return tool_error("Missing required parameter: memory_id")
|
||||||
|
try:
|
||||||
|
client.delete(memory_id=memory_id)
|
||||||
|
self._record_success()
|
||||||
|
return json.dumps({"result": "Memory deleted."})
|
||||||
|
except Exception as e:
|
||||||
|
self._record_failure()
|
||||||
|
return tool_error(f"Failed to delete: {e}")
|
||||||
|
|
||||||
return tool_error(f"Unknown tool: {tool_name}")
|
return tool_error(f"Unknown tool: {tool_name}")
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
@@ -435,7 +455,7 @@ def register(ctx) -> None:
|
|||||||
"""Create a handler closure for a specific tool."""
|
"""Create a handler closure for a specific tool."""
|
||||||
return lambda args, **kwargs: provider.handle_tool_call(tool_name, args)
|
return lambda args, **kwargs: provider.handle_tool_call(tool_name, args)
|
||||||
|
|
||||||
for schema in [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA]:
|
for schema in [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA, DELETE_SCHEMA]:
|
||||||
ctx.register_tool(
|
ctx.register_tool(
|
||||||
name=schema["name"],
|
name=schema["name"],
|
||||||
toolset="mem0_local",
|
toolset="mem0_local",
|
||||||
|
|||||||
@@ -1,4 +1,13 @@
|
|||||||
"""Local Mem0 server HTTP client."""
|
"""Local Mem0 server HTTP client.
|
||||||
|
|
||||||
|
Based on OpenAPI spec from http://10.0.0.150:8889/openapi.json
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
- POST /add - Add memory
|
||||||
|
- POST /search - Search memories
|
||||||
|
- GET /memories - Get all memories
|
||||||
|
- DELETE /delete/{memory_id} - Delete memory
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -11,13 +20,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class LocalMem0Client:
|
class LocalMem0Client:
|
||||||
"""HTTP client for self-hosted Mem0 server.
|
"""HTTP client for self-hosted Mem0 server."""
|
||||||
|
|
||||||
Expects Mem0 server at MEM0_BASE_URL with endpoints:
|
|
||||||
- POST /search
|
|
||||||
- GET /memories
|
|
||||||
- POST /add
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, base_url: str, timeout: float = 10.0):
|
def __init__(self, base_url: str, timeout: float = 10.0):
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
@@ -60,60 +63,59 @@ class LocalMem0Client:
|
|||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
filters: Dict[str, Any],
|
user_id: Optional[str] = None,
|
||||||
rerank: bool = False,
|
limit: int = 5,
|
||||||
top_k: int = 10,
|
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""Search memories by semantic similarity."""
|
"""Search memories by semantic similarity.
|
||||||
payload = {
|
|
||||||
"query": query,
|
|
||||||
"user_id": filters.get("user_id"),
|
|
||||||
"agent_id": filters.get("agent_id"),
|
|
||||||
"top_k": top_k,
|
|
||||||
}
|
|
||||||
if rerank is not None:
|
|
||||||
payload["rerank"] = rerank
|
|
||||||
result = self._request("POST", "/search", json=payload)
|
|
||||||
return self._unwrap_results(result)
|
|
||||||
|
|
||||||
def get_all(self, filters: Dict[str, Any]) -> List[Dict]:
|
API: POST /search
|
||||||
"""Get all memories matching filters."""
|
Request: {query, user_id, limit}
|
||||||
params = filters
|
Response: {results: [{id, text, user_id, score, metadata}]}
|
||||||
|
"""
|
||||||
|
payload = {"query": query, "limit": limit}
|
||||||
|
if user_id:
|
||||||
|
payload["user_id"] = user_id
|
||||||
|
result = self._request("POST", "/search", json=payload)
|
||||||
|
return result.get("results", [])
|
||||||
|
|
||||||
|
def get_all(self, user_id: Optional[str] = None) -> List[Dict]:
|
||||||
|
"""Get all memories for a user.
|
||||||
|
|
||||||
|
API: GET /memories?user_id=...
|
||||||
|
Response: {memories: [{id, text, user_id, metadata}]}
|
||||||
|
"""
|
||||||
|
params = {}
|
||||||
|
if user_id:
|
||||||
|
params["user_id"] = user_id
|
||||||
result = self._request("GET", "/memories", params=params)
|
result = self._request("GET", "/memories", params=params)
|
||||||
return self._unwrap_results(result)
|
return result.get("memories", [])
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
self,
|
self,
|
||||||
messages: List[Dict[str, str]],
|
message: str,
|
||||||
filters: Dict[str, Any],
|
user_id: Optional[str] = None,
|
||||||
infer: bool = True,
|
metadata: Optional[Dict] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Add conversation messages for fact extraction."""
|
"""Add a new memory.
|
||||||
# Extract message content from messages array
|
|
||||||
if messages and isinstance(messages[0], dict):
|
|
||||||
message_content = messages[0].get("content", "")
|
|
||||||
elif messages:
|
|
||||||
message_content = str(messages[0])
|
|
||||||
else:
|
|
||||||
message_content = ""
|
|
||||||
|
|
||||||
payload = {
|
API: POST /add
|
||||||
"message": message_content,
|
Request: {message, user_id, metadata}
|
||||||
"user_id": filters.get("user_id"),
|
Response: {success, memory_id, message}
|
||||||
"agent_id": filters.get("agent_id"),
|
"""
|
||||||
}
|
payload = {"message": message}
|
||||||
if not infer:
|
if user_id:
|
||||||
payload["only_store_messages"] = True
|
payload["user_id"] = user_id
|
||||||
|
if metadata:
|
||||||
|
payload["metadata"] = metadata
|
||||||
return self._request("POST", "/add", json=payload)
|
return self._request("POST", "/add", json=payload)
|
||||||
|
|
||||||
@staticmethod
|
def delete(self, memory_id: str) -> Dict:
|
||||||
def _unwrap_results(response: Any) -> List[Dict]:
|
"""Delete a memory by ID.
|
||||||
"""Normalize Mem0 API response."""
|
|
||||||
if isinstance(response, dict):
|
API: DELETE /delete/{memory_id}
|
||||||
return response.get("memories", response.get("results", []))
|
Response: {success, memory_id, message}
|
||||||
if isinstance(response, list):
|
"""
|
||||||
return response
|
return self._request("DELETE", f"/delete/{memory_id}")
|
||||||
return []
|
|
||||||
|
|
||||||
def health(self) -> bool:
|
def health(self) -> bool:
|
||||||
"""Check if server is reachable."""
|
"""Check if server is reachable."""
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ provides_tools:
|
|||||||
- mem0_profile
|
- mem0_profile
|
||||||
- mem0_search
|
- mem0_search
|
||||||
- mem0_conclude
|
- mem0_conclude
|
||||||
|
- mem0_delete
|
||||||
|
|
||||||
pip_dependencies:
|
pip_dependencies:
|
||||||
- requests
|
- requests
|
||||||
|
|||||||
Reference in New Issue
Block a user