"""
Mem0 API Server - Simple wrapper around mem0ai with llama.cpp embedding support

Exposes a small FastAPI application that stores text "memories" in a Qdrant
collection, using vectors obtained from an HTTP embedding endpoint.

Endpoints:
    GET    /health              - probe the embedding endpoint and Qdrant
    POST   /add                 - embed and store one memory
    POST   /search              - vector search, optionally filtered by user_id
    GET    /memories            - list every memory for a user
    DELETE /delete/{memory_id}  - delete one memory by id
"""
import os
import uuid
from typing import Optional, List, Any, Dict

import requests
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue

app = FastAPI(title="Mem0 API", version="1.0.0")

# Configuration from environment
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
EMBEDDING_URL = os.getenv("EMBEDDING_URL", "http://host.docker.internal:4700/embedding")
EMBEDDING_DIMS = int(os.getenv("EMBEDDING_DIMS", "1024"))
# Upper bound (seconds) on one embedding request, so a stalled embedding
# server cannot hang an API request forever.
EMBEDDING_TIMEOUT = float(os.getenv("EMBEDDING_TIMEOUT", "60"))
COLLECTION_NAME = "memories"


class LlamaCppEmbedder:
    """Custom embedder for llama.cpp embedding endpoint"""

    def __init__(self, base_url: str, dims: int, timeout: float = 60.0):
        """
        Args:
            base_url: Full URL of the embedding endpoint to POST to.
            dims: Expected embedding dimensionality (stored, not enforced here).
            timeout: Seconds to wait for the endpoint before giving up.
        """
        self.base_url = base_url
        self.dims = dims
        self.timeout = timeout

    def get_embedding(self, text: str) -> List[float]:
        """Get embedding from llama.cpp endpoint

        Blocking call: run it in a worker thread when called from async code.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector as a list of floats.

        Raises:
            requests.RequestException: On connection failure, timeout, or a
                non-2xx HTTP status.
            KeyError, IndexError, TypeError: If the response body does not
                have one of the supported shapes.
        """
        response = requests.post(
            self.base_url,
            json={"content": text},
            headers={"Content-Type": "application/json"},
            timeout=self.timeout,
        )
        response.raise_for_status()
        return self._parse_embedding(response.json())

    @staticmethod
    def _parse_embedding(result: Any) -> List[float]:
        """Extract a flat vector from the endpoint's JSON response.

        Supported shapes:
            [{"embedding": [[...]]}]   (the shape this service was written for)
            [{"embedding": [...]}]
            {"embedding": [...]} / {"embedding": [[...]]}
        NOTE(review): which shape a given llama.cpp build emits should be
        confirmed against the deployed server version.
        """
        item = result[0] if isinstance(result, list) else result
        embedding = item["embedding"]
        # A nested list means one vector per entry; use the first one.
        if embedding and isinstance(embedding[0], list):
            embedding = embedding[0]
        return embedding


# Initialize embedder and Qdrant client
embedder = LlamaCppEmbedder(EMBEDDING_URL, EMBEDDING_DIMS, timeout=EMBEDDING_TIMEOUT)
qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)


def init_collection():
    """Initialize Qdrant collection if it doesn't exist"""
    existing = {c.name for c in qdrant_client.get_collections().collections}
    if COLLECTION_NAME not in existing:
        qdrant_client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(size=EMBEDDING_DIMS, distance=Distance.COSINE),
        )


# Initialize collection on startup
init_collection()


class AddMemoryRequest(BaseModel):
    message: str
    user_id: Optional[str] = "default"
    metadata: Optional[dict] = None


class AddMemoryResponse(BaseModel):
    success: bool
    memory_id: Optional[str]
    message: str


class SearchMemoryRequest(BaseModel):
    query: str
    user_id: Optional[str] = "default"
    limit: Optional[int] = 5


class SearchResult(BaseModel):
    id: str
    text: str
    user_id: str
    score: float
    metadata: Optional[dict]


class SearchMemoryResponse(BaseModel):
    results: List[SearchResult]


class MemoryItem(BaseModel):
    id: str
    text: str
    user_id: str
    metadata: Optional[dict]


class GetMemoriesResponse(BaseModel):
    memories: List[MemoryItem]


class DeleteMemoryResponse(BaseModel):
    success: bool
    memory_id: str
    message: str


def _user_filter(user_id: Optional[str]) -> Optional[Filter]:
    """Return a Qdrant filter matching ``user_id``, or None when it is falsy."""
    if not user_id:
        return None
    return Filter(
        must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
    )


def _payload_fields(payload: Optional[dict]) -> Dict[str, Any]:
    """Map a stored point payload onto the response-model fields.

    ``text`` and ``user_id`` are declared as ``str`` on the response models,
    so a missing or null stored value is coerced to "" instead of failing
    response validation (a memory added with ``user_id: null`` stores None).
    """
    payload = payload or {}
    return {
        "text": payload.get("text") or "",
        "user_id": payload.get("user_id") or "",
        "metadata": payload.get("metadata"),
    }


# NOTE: `requests` and the synchronous QdrantClient block the calling thread.
# The handlers below are coroutines, so each blocking call is pushed to a
# worker thread with run_in_threadpool to keep the event loop responsive.


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        # Test embedding endpoint
        test_response = await run_in_threadpool(
            requests.get, EMBEDDING_URL.replace("/embedding", "/"), timeout=5
        )
        embedding_healthy = (
            test_response.status_code == 200
            or "gzip" in test_response.text.lower()
        )

        # Test Qdrant
        qdrant_url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/"
        qdrant_response = await run_in_threadpool(requests.get, qdrant_url, timeout=5)
        qdrant_healthy = qdrant_response.status_code == 200

        return {
            "status": "healthy" if (embedding_healthy and qdrant_healthy) else "degraded",
            "service": "mem0-api",
            "embedding_endpoint": embedding_healthy,
            "qdrant": qdrant_healthy,
        }
    except Exception as e:
        return {"status": "unhealthy", "service": "mem0-api", "error": str(e)}


@app.post("/add", response_model=AddMemoryResponse)
async def add_memory(request: AddMemoryRequest):
    """Add a new memory

    Embeds ``request.message`` and upserts it under a fresh UUID.

    Raises:
        HTTPException: 500 with the underlying error text on any failure.
    """
    try:
        memory_id = str(uuid.uuid4())

        # Get embedding
        embedding = await run_in_threadpool(embedder.get_embedding, request.message)

        # Create point
        point = PointStruct(
            id=memory_id,
            vector=embedding,
            payload={
                "text": request.message,
                "user_id": request.user_id,
                "metadata": request.metadata or {},
            },
        )

        # Upsert to Qdrant
        await run_in_threadpool(
            qdrant_client.upsert, collection_name=COLLECTION_NAME, points=[point]
        )

        return AddMemoryResponse(
            success=True,
            memory_id=memory_id,
            message="Memory added successfully",
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.post("/search", response_model=SearchMemoryResponse)
async def search_memory(request: SearchMemoryRequest):
    """Search for memories

    Embeds ``request.query`` and returns the nearest stored memories,
    restricted to ``request.user_id`` when one is given.

    Raises:
        HTTPException: 500 with the underlying error text on any failure.
    """
    try:
        # Get query embedding
        query_embedding = await run_in_threadpool(embedder.get_embedding, request.query)

        # Search in Qdrant using query_points (new API) - pass vector directly
        results = await run_in_threadpool(
            qdrant_client.query_points,
            collection_name=COLLECTION_NAME,
            query=query_embedding,
            limit=request.limit,
            query_filter=_user_filter(request.user_id),
            with_payload=True,
            with_vectors=False,
        )

        # Format results
        formatted_results = [
            SearchResult(id=str(hit.id), score=hit.score, **_payload_fields(hit.payload))
            for hit in results.points
        ]
        return SearchMemoryResponse(results=formatted_results)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.get("/memories", response_model=GetMemoriesResponse)
async def get_memories(user_id: Optional[str] = "default"):
    """Get all memories for a user

    Pages through the collection 100 points at a time until Qdrant reports
    no further offset.

    Raises:
        HTTPException: 500 with the underlying error text on any failure.
    """
    try:
        scroll_filter = _user_filter(user_id)

        # Scroll through collection
        memories = []
        offset = None
        while True:
            page, next_offset = await run_in_threadpool(
                qdrant_client.scroll,
                collection_name=COLLECTION_NAME,
                limit=100,
                offset=offset,
                scroll_filter=scroll_filter,
                with_payload=True,
                with_vectors=False,
            )
            memories.extend(
                MemoryItem(id=str(point.id), **_payload_fields(point.payload))
                for point in page
            )
            # Only None marks the last page; a falsy-but-valid offset must
            # not end the scan early.
            if next_offset is None:
                break
            offset = next_offset

        return GetMemoriesResponse(memories=memories)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.delete("/delete/{memory_id}", response_model=DeleteMemoryResponse)
async def delete_memory(memory_id: str):
    """Delete a memory by ID

    Raises:
        HTTPException: 500 with the underlying error text on any failure.
    """
    try:
        await run_in_threadpool(
            qdrant_client.delete,
            collection_name=COLLECTION_NAME,
            points_selector=[memory_id],
        )
        return DeleteMemoryResponse(
            success=True,
            memory_id=memory_id,
            message="Memory deleted successfully",
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)