Files
mem0-docker-qdrant/main.py
T
2026-04-13 16:38:35 +02:00

268 lines
7.8 KiB
Python

"""
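Mem0 API Server - Mem0-style memory REST API backed by Qdrant, with llama.cpp embedding support.

This module does not import the mem0ai package; it calls Qdrant and the
embedding endpoint directly.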
"""
import os
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, List, Any, Dict
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
# ASGI application object; the route decorators further down register their
# handlers on it.
app = FastAPI(
    title="Mem0 API",
    version="1.0.0",
)
# Deployment settings, resolved once at import time. Each value can be
# overridden through an environment variable of the same name; the second
# argument is the fallback used when the variable is unset.
QDRANT_HOST = os.environ.get("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", "6333"))
EMBEDDING_URL = os.environ.get(
    "EMBEDDING_URL", "http://host.docker.internal:4700/embedding"
)
EMBEDDING_DIMS = int(os.environ.get("EMBEDDING_DIMS", "1024"))
# Name of the single Qdrant collection that holds every memory (not configurable).
COLLECTION_NAME = "memories"
class LlamaCppEmbedder:
    """Client for a llama.cpp HTTP embedding endpoint.

    Attributes:
        base_url: Full URL the text is POSTed to.
        dims: Expected embedding size. Stored for callers; not enforced here.
        timeout: Seconds to wait for the endpoint before giving up.
    """

    def __init__(self, base_url: str, dims: int, timeout: float = 30.0):
        """Remember the endpoint URL, the expected vector size and the timeout."""
        self.base_url = base_url
        self.dims = dims
        self.timeout = timeout

    @staticmethod
    def _extract_embedding(result: Any) -> List[float]:
        """Pull one flat vector out of a decoded endpoint response.

        The primary shape is ``[{"embedding": [[...floats...]]}]``: a list of
        entries whose ``embedding`` holds nested vectors, of which the first is
        used. A bare entry (no outer list) and a flat, non-nested vector are
        accepted as well, so an unambiguous variation of the response does not
        degrade into a scalar or an opaque ``KeyError``/``IndexError``.

        Args:
            result: JSON-decoded response body.

        Returns:
            The embedding as a list of floats.

        Raises:
            ValueError: If no vector can be located in ``result``.
        """
        entry = result[0] if isinstance(result, list) and result else result
        if not isinstance(entry, dict) or "embedding" not in entry:
            raise ValueError("Unexpected embedding response shape")

        vector = entry["embedding"]
        # Nested form: one vector per element; keep the first one.
        if isinstance(vector, list) and vector and isinstance(vector[0], list):
            vector = vector[0]
        if not isinstance(vector, list) or not vector:
            raise ValueError("Embedding response contained no vector")
        return vector

    def get_embedding(self, text: str) -> List[float]:
        """Request the embedding of ``text`` from the endpoint.

        Args:
            text: Content to embed.

        Returns:
            The embedding vector as a list of floats.

        Raises:
            requests.RequestException: On connection errors, timeouts or a
                non-2xx status.
            ValueError: If the response body has an unexpected shape.
        """
        response = requests.post(
            self.base_url,
            json={"content": text},
            headers={"Content-Type": "application/json"},
            # requests applies no timeout unless told to; without one a hung
            # embedding server would block the caller indefinitely.
            timeout=self.timeout,
        )
        response.raise_for_status()
        return self._extract_embedding(response.json())
# Module-level singletons shared by all request handlers: the Qdrant client
# and the embedding client, both built from the configuration above.
qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
embedder = LlamaCppEmbedder(base_url=EMBEDDING_URL, dims=EMBEDDING_DIMS)
def init_collection():
    """Create the memories collection in Qdrant unless it is already there.

    Lists the existing collections and, when none is named
    ``COLLECTION_NAME``, creates it with ``EMBEDDING_DIMS``-sized vectors
    compared by cosine distance. An existing collection is left untouched.
    """
    existing = qdrant_client.get_collections().collections
    if any(entry.name == COLLECTION_NAME for entry in existing):
        return

    vector_settings = VectorParams(size=EMBEDDING_DIMS, distance=Distance.COSINE)
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=vector_settings,
    )
# Ensure the collection exists as soon as this module is imported. Nothing
# catches errors here, so any exception raised by init_collection() aborts
# the import.
init_collection()
class AddMemoryRequest(BaseModel):
    """Request body accepted by ``POST /add``."""

    # Text that is embedded and stored as the memory.
    message: str
    # Owner of the memory; "default" is used when the field is omitted.
    user_id: Optional[str] = "default"
    # Optional free-form data stored next to the text.
    metadata: Optional[dict] = None
class AddMemoryResponse(BaseModel):
    """Response body returned by ``POST /add``."""

    # Outcome flag for the add operation.
    success: bool
    # Identifier of the stored memory; the field is nullable.
    memory_id: Optional[str]
    # Human-readable status text.
    message: str
class SearchMemoryRequest(BaseModel):
    """Request body accepted by ``POST /search``."""

    # Text whose embedding is used as the search vector.
    query: str
    # Restrict hits to this user; the search handler applies no user filter
    # when the value is empty or null.
    user_id: Optional[str] = "default"
    # Maximum number of hits requested from Qdrant.
    limit: Optional[int] = 5
class SearchResult(BaseModel):
    """One hit in a search response."""

    # Point id, rendered as a string.
    id: str
    # Stored memory text.
    text: str
    # Owner recorded with the memory.
    user_id: str
    # Score reported by Qdrant for this hit.
    score: float
    # Metadata stored with the memory; the field is nullable.
    metadata: Optional[dict]
class SearchMemoryResponse(BaseModel):
    """Response body returned by ``POST /search``."""

    # Hits in the order the search handler received them from Qdrant.
    results: List[SearchResult]
class MemoryItem(BaseModel):
    """One stored memory as listed by ``GET /memories``."""

    # Point id, rendered as a string.
    id: str
    # Stored memory text.
    text: str
    # Owner recorded with the memory.
    user_id: str
    # Metadata stored with the memory; the field is nullable.
    metadata: Optional[dict]
class GetMemoriesResponse(BaseModel):
    """Response body returned by ``GET /memories``."""

    # Every memory collected by the listing handler.
    memories: List[MemoryItem]
class DeleteMemoryResponse(BaseModel):
    """Response body returned by ``DELETE /delete/{memory_id}``."""

    # Outcome flag for the delete operation.
    success: bool
    # Id taken from the request path, echoed back.
    memory_id: str
    # Human-readable status text.
    message: str
@app.get("/health")
async def health_check():
    """Report whether the embedding server and Qdrant answer over HTTP.

    Returns:
        A dict with ``status`` ("healthy" when both probes pass, otherwise
        "degraded"), ``service``, and one boolean per dependency. If a probe
        raises (connection refused, timeout, ...), the dict is instead
        ``{"status": "unhealthy", "service": ..., "error": <message>}``.
    """
    from urllib.parse import urlsplit, urlunsplit

    from fastapi.concurrency import run_in_threadpool

    def _embedding_root_url() -> str:
        # Drop only a trailing "/embedding" path segment. A plain
        # str.replace() over the whole URL would also rewrite a host name
        # that starts with "embedding" (e.g. "http://embedding:4700/...").
        parts = urlsplit(EMBEDDING_URL)
        trimmed = parts.path.rstrip("/")
        if trimmed.endswith("/embedding"):
            path = trimmed[: -len("embedding")]
        else:
            path = parts.path or "/"
        return urlunsplit(
            (parts.scheme, parts.netloc, path, parts.query, parts.fragment)
        )

    def _probe() -> Dict[str, Any]:
        embedding_response = requests.get(_embedding_root_url(), timeout=5)
        # NOTE(review): the "gzip" substring test is kept from the original;
        # presumably it tolerates a non-200 answer from the embedding
        # server's root page - confirm.
        embedding_healthy = (
            embedding_response.status_code == 200
            or "gzip" in embedding_response.text.lower()
        )

        qdrant_response = requests.get(
            f"http://{QDRANT_HOST}:{QDRANT_PORT}/", timeout=5
        )
        qdrant_healthy = qdrant_response.status_code == 200

        return {
            "status": "healthy" if (embedding_healthy and qdrant_healthy) else "degraded",
            "service": "mem0-api",
            "embedding_endpoint": embedding_healthy,
            "qdrant": qdrant_healthy,
        }

    try:
        # requests is blocking; run the probes in a worker thread so a slow
        # dependency does not stall the event loop for up to 5 s per probe.
        return await run_in_threadpool(_probe)
    except Exception as e:
        return {"status": "unhealthy", "service": "mem0-api", "error": str(e)}
@app.post("/add", response_model=AddMemoryResponse)
async def add_memory(request: AddMemoryRequest):
    """Embed a message and store it in Qdrant as a new memory.

    Args:
        request: Message text, owning user id and optional metadata.

    Returns:
        AddMemoryResponse carrying the freshly generated UUID of the point.

    Raises:
        HTTPException: 500 with the underlying error text if the embedding
            request or the Qdrant upsert fails.
    """
    import uuid

    from fastapi.concurrency import run_in_threadpool

    memory_id = str(uuid.uuid4())

    def _embed_and_store() -> None:
        # Both steps are blocking network calls (requests and the
        # non-awaited Qdrant client).
        vector = embedder.get_embedding(request.message)
        point = PointStruct(
            id=memory_id,
            vector=vector,
            payload={
                "text": request.message,
                "user_id": request.user_id,
                "metadata": request.metadata or {},
            },
        )
        qdrant_client.upsert(collection_name=COLLECTION_NAME, points=[point])

    try:
        # Run in a worker thread so a slow embedding server cannot stall the
        # event loop, and with it every other request.
        await run_in_threadpool(_embed_and_store)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return AddMemoryResponse(
        success=True,
        memory_id=memory_id,
        message="Memory added successfully",
    )
@app.post("/search", response_model=SearchMemoryResponse)
async def search_memory(request: SearchMemoryRequest):
    """Find the stored memories closest to a query text.

    Args:
        request: Query text, optional user id filter and result limit. An
            empty or null ``user_id`` searches across all users.

    Returns:
        SearchMemoryResponse with one SearchResult per hit.

    Raises:
        HTTPException: 500 with the underlying error text if the embedding
            request, the Qdrant query or result formatting fails.
    """
    from fastapi.concurrency import run_in_threadpool

    def _embed_and_query():
        # Blocking network calls: embedding request, then the Qdrant query.
        query_embedding = embedder.get_embedding(request.query)

        query_filter = None
        if request.user_id:
            query_filter = Filter(
                must=[FieldCondition(key="user_id", match=MatchValue(value=request.user_id))]
            )

        return qdrant_client.query_points(
            collection_name=COLLECTION_NAME,
            query=query_embedding,
            limit=request.limit,
            query_filter=query_filter,
            with_payload=True,
            with_vectors=False,
        )

    def _to_result(hit) -> SearchResult:
        payload = hit.payload or {}
        return SearchResult(
            id=str(hit.id),
            # /add stores user_id as given, which may be null; SearchResult
            # requires strings, so map missing/null values to "" instead of
            # failing the whole search with a validation error.
            text=payload.get("text") or "",
            user_id=payload.get("user_id") or "",
            score=hit.score,
            metadata=payload.get("metadata"),
        )

    try:
        # Keep the blocking work off the event loop.
        response = await run_in_threadpool(_embed_and_query)
        formatted_results = [_to_result(hit) for hit in response.points]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return SearchMemoryResponse(results=formatted_results)
@app.get("/memories", response_model=GetMemoriesResponse)
async def get_memories(user_id: Optional[str] = "default"):
    """List every stored memory, optionally restricted to one user.

    Args:
        user_id: Only memories whose payload ``user_id`` equals this value
            are returned. An empty value disables the filter.

    Returns:
        GetMemoriesResponse holding all matching memories.

    Raises:
        HTTPException: 500 with the underlying error text if scrolling the
            collection or formatting a point fails.
    """
    from fastapi.concurrency import run_in_threadpool

    def _scroll_all() -> List[MemoryItem]:
        scroll_filter = None
        if user_id:
            scroll_filter = Filter(
                must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
            )

        memories: List[MemoryItem] = []
        offset = None
        while True:
            points, next_offset = qdrant_client.scroll(
                collection_name=COLLECTION_NAME,
                limit=100,
                offset=offset,
                scroll_filter=scroll_filter,
                with_payload=True,
                with_vectors=False,
            )
            for point in points:
                payload = point.payload or {}
                memories.append(MemoryItem(
                    id=str(point.id),
                    # Null/missing values become "" because MemoryItem
                    # requires strings (/add may store a null user_id).
                    text=payload.get("text") or "",
                    user_id=payload.get("user_id") or "",
                    metadata=payload.get("metadata"),
                ))
            # The last page is signalled by a None offset. A truthiness test
            # would also stop early on a falsy-but-valid offset such as the
            # integer point id 0.
            if next_offset is None:
                break
            offset = next_offset
        return memories

    try:
        # The scroll loop is blocking network I/O; keep it off the event loop.
        memories = await run_in_threadpool(_scroll_all)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return GetMemoriesResponse(memories=memories)
@app.delete("/delete/{memory_id}", response_model=DeleteMemoryResponse)
async def delete_memory(memory_id: str):
    """Delete the memory stored under ``memory_id``.

    Success is reported whenever the Qdrant call returns without raising;
    this handler does not check that a point with that id existed.

    Args:
        memory_id: Point id taken from the URL path.

    Returns:
        DeleteMemoryResponse echoing the id.

    Raises:
        HTTPException: 500 with the underlying error text if the Qdrant
            delete call fails.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        # The Qdrant call is blocking; run it in a worker thread so the
        # event loop stays responsive.
        await run_in_threadpool(
            qdrant_client.delete,
            collection_name=COLLECTION_NAME,
            points_selector=[memory_id],
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return DeleteMemoryResponse(
        success=True,
        memory_id=memory_id,
        message="Memory deleted successfully",
    )
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface,
    # port 8000. Not executed when the module is merely imported.
    import uvicorn

    uvicorn.run(
        app,
        port=8000,
        host="0.0.0.0",
    )