124 lines
3.8 KiB
Python
124 lines
3.8 KiB
Python
"""Local Mem0 server HTTP client."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import requests
|
|
|
|
# Module-level logger; LocalMem0Client._request logs timeouts, connection
# failures and HTTP error responses through it before re-raising them.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LocalMem0Client:
    """HTTP client for a self-hosted Mem0 server.

    Expects a Mem0 server at ``base_url`` exposing these endpoints:

    - POST /search   -- semantic search (see :meth:`search`)
    - GET  /memories -- list memories matching filters (see :meth:`get_all`)
    - POST /add      -- submit conversation messages (see :meth:`add`)
    - GET  /health   -- reachability probe (see :meth:`health`)

    The client owns a ``requests.Session`` that carries the shared headers
    and pooled connections. Call :meth:`close`, or use the client as a
    context manager, to release it.
    """

    def __init__(self, base_url: str, timeout: float = 10.0):
        """Create a client.

        Args:
            base_url: Server root URL. A trailing ``/`` is stripped so that
                endpoint paths (which start with ``/``) can be appended.
            timeout: Per-request timeout in seconds used by :meth:`_request`.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Content-Type": "application/json",
                "User-Agent": "hermes-agent-mem0-local-plugin/1.0.0",
            }
        )

    def close(self) -> None:
        """Close the underlying HTTP session and its pooled connections."""
        self.session.close()

    def __enter__(self) -> "LocalMem0Client":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the session; returning None never suppresses the
        # caller's exception.
        self.close()

    def _request(
        self,
        method: str,
        endpoint: str,
        json: Optional[Dict] = None,
        params: Optional[Dict] = None,
    ) -> Any:
        """Send an HTTP request to the server and decode its JSON reply.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            endpoint: Path appended to ``base_url``; expected to start with ``/``.
            json: Optional request body, serialized as JSON by ``requests``.
            params: Optional query-string parameters.

        Returns:
            The decoded JSON body (a dict or a list, depending on the
            endpoint), or ``{}`` when the server sends an empty body
            (e.g. ``204 No Content``).

        Raises:
            requests.exceptions.Timeout: The request exceeded ``self.timeout``.
            requests.exceptions.ConnectionError: The server was unreachable.
            requests.exceptions.HTTPError: The server answered 4xx/5xx.
            ValueError: The body was non-empty but not valid JSON.

            Each of these is logged before being re-raised.
        """
        url = f"{self.base_url}{endpoint}"
        try:
            resp = self.session.request(
                method, url, json=json, params=params, timeout=self.timeout
            )
            resp.raise_for_status()
        except requests.exceptions.Timeout:
            logger.error("Mem0 request timed out after %ss", self.timeout)
            raise
        except requests.exceptions.ConnectionError as e:
            logger.error(
                "Failed to connect to Mem0 server at %s: %s", self.base_url, e
            )
            raise
        except requests.exceptions.HTTPError as e:
            logger.error(
                "Mem0 API error: %s - %s", e.response.status_code, e.response.text
            )
            raise

        # An empty body (e.g. 204 No Content) is not valid JSON; treat it as an
        # empty object instead of letting resp.json() blow up.
        if not resp.content:
            return {}
        try:
            return resp.json()
        except ValueError:
            # requests' JSON decode errors are ValueError subclasses.
            logger.error(
                "Mem0 returned a non-JSON response (HTTP %s): %.200s",
                resp.status_code,
                resp.text,
            )
            raise

    def search(
        self,
        query: str,
        filters: Optional[Dict[str, Any]],
        rerank: Optional[bool] = False,
        top_k: int = 10,
    ) -> List[Dict]:
        """Search memories by semantic similarity via ``POST /search``.

        Args:
            query: Free-text query.
            filters: Scope filters. Only ``user_id`` and ``agent_id`` are
                forwarded, and missing keys are sent as ``null``. ``None`` is
                treated as no filters.
            rerank: Sent as ``rerank``; pass ``None`` to omit the field.
            top_k: Sent as ``top_k``.

        Returns:
            The memory list extracted by :meth:`_unwrap_results`.
        """
        filters = filters or {}
        payload: Dict[str, Any] = {
            "query": query,
            "user_id": filters.get("user_id"),
            "agent_id": filters.get("agent_id"),
            "top_k": top_k,
        }
        if rerank is not None:
            payload["rerank"] = rerank
        result = self._request("POST", "/search", json=payload)
        return self._unwrap_results(result)

    def get_all(self, filters: Optional[Dict[str, Any]]) -> List[Dict]:
        """List memories matching ``filters`` via ``GET /memories``.

        Args:
            filters: Sent as query-string parameters (``requests`` omits keys
                whose value is ``None``). ``None`` or ``{}`` sends none.

        Returns:
            The memory list extracted by :meth:`_unwrap_results`.
        """
        params = dict(filters) if filters else None
        result = self._request("GET", "/memories", params=params)
        return self._unwrap_results(result)

    def add(
        self,
        messages: List[Dict[str, str]],
        filters: Optional[Dict[str, Any]],
        infer: bool = True,
    ) -> Any:
        """Submit conversation messages via ``POST /add``.

        Args:
            messages: Chat messages. With ``infer=False`` only the first item
                is used: its ``"content"`` if it is a dict, else the item itself.
            filters: Only ``user_id`` and ``agent_id`` are forwarded. ``None``
                is treated as no filters.
            infer: When ``False``, the payload is reduced to a single ``user``
                message carrying the first message's content.

        Returns:
            The decoded server response from :meth:`_request`.
        """
        filters = filters or {}
        payload: Dict[str, Any] = {
            "messages": messages,
            "user_id": filters.get("user_id"),
            "agent_id": filters.get("agent_id"),
        }
        # NOTE(review): ``infer`` itself is never sent to the server; confirm
        # whether the server still runs extraction on the reduced message.
        # The ``messages`` guard avoids an IndexError on an empty list.
        if not infer and messages:
            first = messages[0]
            content = first.get("content", "") if isinstance(first, dict) else first
            payload["messages"] = [{"role": "user", "content": content}]
        return self._request("POST", "/add", json=payload)

    @staticmethod
    def _unwrap_results(response: Any) -> List[Dict]:
        """Normalize a Mem0 response into a list of memory dicts.

        Accepts a bare list, or a dict holding the list under ``"memories"``
        (checked first) or ``"results"``. Anything else yields ``[]``,
        including a key whose value is ``None`` or not a list, so callers
        can always iterate the result.
        """
        if isinstance(response, list):
            return response
        if isinstance(response, dict):
            for key in ("memories", "results"):
                value = response.get(key)
                if isinstance(value, list):
                    return value
        return []

    def health(self, timeout: float = 5.0) -> bool:
        """Return ``True`` if ``GET /health`` answers with HTTP 200.

        Args:
            timeout: Seconds to wait for the probe. The default is 5.0,
                independent of ``self.timeout``.

        Any ``requests`` transport error yields ``False`` instead of raising.
        """
        try:
            resp = self.session.get(f"{self.base_url}/health", timeout=timeout)
        except requests.exceptions.RequestException:
            return False
        return resp.status_code == 200
|