Merge pull request 'Add configurable prefetch limit and score threshold' (#2) from mem0-prefetch-improvements into main
Reviewed-on: #2
This commit was merged in pull request #2.
This commit is contained in:
+55
-8
@@ -46,6 +46,8 @@ def _load_config() -> dict:
|
|||||||
"agent_id": os.environ.get("MEM0_AGENT_ID", "hermes"),
|
"agent_id": os.environ.get("MEM0_AGENT_ID", "hermes"),
|
||||||
"rerank": True,
|
"rerank": True,
|
||||||
"timeout": 10.0,
|
"timeout": 10.0,
|
||||||
|
"prefetch_limit": 3,
|
||||||
|
"prefetch_score_threshold": 60,
|
||||||
}
|
}
|
||||||
|
|
||||||
config_path = get_hermes_home() / "mem0-local.json"
|
config_path = get_hermes_home() / "mem0-local.json"
|
||||||
@@ -146,6 +148,8 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
self._user_id = "hermes-user"
|
self._user_id = "hermes-user"
|
||||||
self._agent_id = "hermes"
|
self._agent_id = "hermes"
|
||||||
self._rerank = True
|
self._rerank = True
|
||||||
|
self._prefetch_limit = 3
|
||||||
|
self._prefetch_score_threshold = 60
|
||||||
self._prefetch_result = ""
|
self._prefetch_result = ""
|
||||||
self._prefetch_lock = threading.Lock()
|
self._prefetch_lock = threading.Lock()
|
||||||
self._prefetch_thread = None
|
self._prefetch_thread = None
|
||||||
@@ -206,6 +210,16 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
"description": "Request timeout in seconds",
|
"description": "Request timeout in seconds",
|
||||||
"default": "10.0",
|
"default": "10.0",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"key": "prefetch_limit",
|
||||||
|
"description": "Max memories to prefetch (pre-LLM hook)",
|
||||||
|
"default": "3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"key": "prefetch_score_threshold",
|
||||||
|
"description": "Min similarity score % to include memory (0-100)",
|
||||||
|
"default": "60",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
def _get_client(self) -> LocalMem0Client:
|
def _get_client(self) -> LocalMem0Client:
|
||||||
@@ -219,6 +233,10 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
self._user_id = self._config.get("user_id", "hermes-user")
|
self._user_id = self._config.get("user_id", "hermes-user")
|
||||||
self._agent_id = self._config.get("agent_id", "hermes")
|
self._agent_id = self._config.get("agent_id", "hermes")
|
||||||
self._rerank = self._config.get("rerank", True)
|
self._rerank = self._config.get("rerank", True)
|
||||||
|
self._prefetch_limit = int(self._config.get("prefetch_limit", 3))
|
||||||
|
self._prefetch_score_threshold = int(
|
||||||
|
self._config.get("prefetch_score_threshold", 60)
|
||||||
|
)
|
||||||
base_url = self._config.get("base_url", "http://localhost:8000")
|
base_url = self._config.get("base_url", "http://localhost:8000")
|
||||||
timeout = float(self._config.get("timeout", 10.0))
|
timeout = float(self._config.get("timeout", 10.0))
|
||||||
self._client = LocalMem0Client(base_url, timeout=timeout)
|
self._client = LocalMem0Client(base_url, timeout=timeout)
|
||||||
@@ -267,6 +285,10 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
)
|
)
|
||||||
self._agent_id = self._config.get("agent_id", "hermes")
|
self._agent_id = self._config.get("agent_id", "hermes")
|
||||||
self._rerank = self._config.get("rerank", True)
|
self._rerank = self._config.get("rerank", True)
|
||||||
|
self._prefetch_limit = int(self._config.get("prefetch_limit", 3))
|
||||||
|
self._prefetch_score_threshold = int(
|
||||||
|
self._config.get("prefetch_score_threshold", 60)
|
||||||
|
)
|
||||||
|
|
||||||
def system_prompt_block(self) -> str:
|
def system_prompt_block(self) -> str:
|
||||||
return (
|
return (
|
||||||
@@ -290,27 +312,39 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
self._prefetch_result = ""
|
self._prefetch_result = ""
|
||||||
if not result:
|
if not result:
|
||||||
return ""
|
return ""
|
||||||
|
# Check if it's an error message
|
||||||
|
if result.startswith("ERROR:"):
|
||||||
|
return f"## Mem0 Error\n{result[6:]}"
|
||||||
return f"## Mem0 Memory\n{result}"
|
return f"## Mem0 Memory\n{result}"
|
||||||
|
|
||||||
def queue_prefetch_and_get(self, query: str) -> str:
|
def queue_prefetch_and_get(self, query: str) -> str:
|
||||||
"""Sync prefetch for pre_llm_call hook - returns memory context immediately."""
|
"""Sync prefetch for pre_llm_call hook - returns memory context immediately."""
|
||||||
if self._is_breaker_open():
|
if self._is_breaker_open():
|
||||||
return ""
|
return (
|
||||||
|
"ERROR:Memory service temporarily unavailable. Please try again later."
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
results = client.search(
|
results = client.search(
|
||||||
query=query,
|
query=query,
|
||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
limit=5,
|
limit=self._prefetch_limit,
|
||||||
)
|
)
|
||||||
if results:
|
# Filter by score threshold
|
||||||
formatted = self._format_search_results(results)
|
threshold = self._prefetch_score_threshold / 100.0
|
||||||
|
filtered = [r for r in results if r.get("score", 0) >= threshold]
|
||||||
|
if filtered:
|
||||||
|
formatted = self._format_search_results(filtered)
|
||||||
if formatted:
|
if formatted:
|
||||||
self._record_success()
|
self._record_success()
|
||||||
return formatted
|
return formatted
|
||||||
|
self._record_success()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._record_failure()
|
self._record_failure()
|
||||||
logger.debug("Mem0 prefetch failed: %s", e)
|
logger.debug("Mem0 prefetch failed: %s", e)
|
||||||
|
return (
|
||||||
|
"ERROR:Memory service temporarily unavailable. Please try again later."
|
||||||
|
)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
||||||
@@ -321,6 +355,8 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
session_id: Unused. Kept for API compatibility.
|
session_id: Unused. Kept for API compatibility.
|
||||||
"""
|
"""
|
||||||
if self._is_breaker_open():
|
if self._is_breaker_open():
|
||||||
|
with self._prefetch_lock:
|
||||||
|
self._prefetch_result = "ERROR:Memory service temporarily unavailable. Please try again later."
|
||||||
return
|
return
|
||||||
|
|
||||||
def _run():
|
def _run():
|
||||||
@@ -329,16 +365,24 @@ class Mem0LocalMemoryProvider(MemoryProvider):
|
|||||||
results = client.search(
|
results = client.search(
|
||||||
query=query,
|
query=query,
|
||||||
user_id=self._user_id,
|
user_id=self._user_id,
|
||||||
limit=5,
|
limit=self._prefetch_limit,
|
||||||
)
|
)
|
||||||
if results:
|
# Filter by score threshold
|
||||||
formatted = self._format_search_results(results)
|
threshold = self._prefetch_score_threshold / 100.0
|
||||||
|
filtered = [r for r in results if r.get("score", 0) >= threshold]
|
||||||
|
if filtered:
|
||||||
|
formatted = self._format_search_results(filtered)
|
||||||
with self._prefetch_lock:
|
with self._prefetch_lock:
|
||||||
self._prefetch_result = formatted
|
self._prefetch_result = formatted
|
||||||
|
else:
|
||||||
|
with self._prefetch_lock:
|
||||||
|
self._prefetch_result = ""
|
||||||
self._record_success()
|
self._record_success()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._record_failure()
|
self._record_failure()
|
||||||
logger.debug("Mem0 prefetch failed: %s", e)
|
logger.debug("Mem0 prefetch failed: %s", e)
|
||||||
|
with self._prefetch_lock:
|
||||||
|
self._prefetch_result = "ERROR:Memory service temporarily unavailable. Please try again later."
|
||||||
|
|
||||||
self._prefetch_thread = threading.Thread(
|
self._prefetch_thread = threading.Thread(
|
||||||
target=_run, daemon=True, name="mem0-local-prefetch"
|
target=_run, daemon=True, name="mem0-local-prefetch"
|
||||||
@@ -504,7 +548,10 @@ def register(ctx) -> None:
|
|||||||
try:
|
try:
|
||||||
results = provider.queue_prefetch_and_get(user_message)
|
results = provider.queue_prefetch_and_get(user_message)
|
||||||
if results:
|
if results:
|
||||||
return {"context": results}
|
# Error messages get their own header, memories get standard header
|
||||||
|
if results.startswith("ERROR:"):
|
||||||
|
return {"context": f"## Mem0 Error\n{results[6:]}"}
|
||||||
|
return {"context": f"## Mem0 Memory\n{results}"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("Mem0 pre_llm_call hook failed: %s", e)
|
logger.debug("Mem0 pre_llm_call hook failed: %s", e)
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
Reference in New Issue
Block a user