Improve streaming response and add system prompt support
- Add configurable initial streaming message
- Support system prompt in API requests
- Fix config key typo (open-webui -> open_webui)
- Add validation for required config values
- Improve error handling for network and API errors
- Set proper timeout for API requests (900s)
- Better logging for rate limit errors
This commit is contained in:
@@ -9,7 +9,7 @@ allow_dms: false # set to true if you want the bot to answer private DM
|
|||||||
|
|
||||||
# ─────────── Open‑WebUI Settings ───────────
|
# ─────────── Open‑WebUI Settings ───────────
|
||||||
open_webui_url: "http://your_open-webui_ip_or_domain:port"
|
open_webui_url: "http://your_open-webui_ip_or_domain:port"
|
||||||
open-webui_api_key: "user_api_key_from_open_webui"
|
open_webui_api_key: "user_api_key_from_open_webui"
|
||||||
model_name: "model_id_from_open-webui"
|
model_name: "model_id_from_open-webui"
|
||||||
knowledge_base: "knowledge_base_id_from_open-webui"
|
knowledge_base: "knowledge_base_id_from_open-webui"
|
||||||
|
|
||||||
@@ -20,5 +20,7 @@ tools:
|
|||||||
|
|
||||||
use_streaming: true # Allows streaming the answer so it feels more interactive.
|
use_streaming: true # Allows streaming the answer so it feels more interactive.
|
||||||
|
|
||||||
|
streaming_initial_message: "Bitte warte kurz, die Informationen werden gesammelt..."
|
||||||
|
|
||||||
# optional system prompt (you can leave it empty to use the default one or the system prompt configured in Open-WebUI for the specific model)
|
# optional system prompt (you can leave it empty to use the default one or the system prompt configured in Open-WebUI for the specific model)
|
||||||
system_prompt: ""
|
system_prompt: ""
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
import asyncio
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import logging
|
import logging
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
# ────────────────────────────────────────────────
|
# ────────────────────────────────────────────────
|
||||||
# Setup logging
|
# Setup logging
|
||||||
@@ -24,7 +25,7 @@ DISCORD_TOKEN = config["discord_token"] # Discord bot token
|
|||||||
WHITELIST_CHANNELS = set(map(int, config.get("whitelist_channels", []))) # Set of whitelisted channel IDs (int)
|
WHITELIST_CHANNELS = set(map(int, config.get("whitelist_channels", []))) # Set of whitelisted channel IDs (int)
|
||||||
|
|
||||||
OPENWEBUI_URL = config["open_webui_url"].rstrip('/') # Ensure no trailing slash
|
OPENWEBUI_URL = config["open_webui_url"].rstrip('/') # Ensure no trailing slash
|
||||||
OPENWEBUI_API_KEY = config["open-webui_api_key"] # API key for Open-WebUI
|
OPENWEBUI_API_KEY = config["open_webui_api_key"] # API key for Open-WebUI
|
||||||
MODEL_NAME = config["model_name"] # Model name to use, e.g., "gpt-3.5-turbo"
|
MODEL_NAME = config["model_name"] # Model name to use, e.g., "gpt-3.5-turbo"
|
||||||
KNOW_BASE = config["knowledge_base"] # Knowledge base to use, e.g., "knowledge_base_v1"
|
KNOW_BASE = config["knowledge_base"] # Knowledge base to use, e.g., "knowledge_base_v1"
|
||||||
|
|
||||||
@@ -33,29 +34,44 @@ USE_STREAMING = config.get("use_streaming", False) # Enable/disable streamin
|
|||||||
|
|
||||||
SYSTEM_PROMPT = config.get("system_prompt", None) # Optional system prompt to prepend to user messages
|
SYSTEM_PROMPT = config.get("system_prompt", None) # Optional system prompt to prepend to user messages
|
||||||
ALLOW_DMS = config.get("allow_dms", False) # Allow DMs to the bot (default: False)
|
ALLOW_DMS = config.get("allow_dms", False) # Allow DMs to the bot (default: False)
|
||||||
|
STREAMING_INITIAL_MESSAGE = config.get("streaming_initial_message", "Bitte warte kurz, die Informationen werden gesammelt...")
|
||||||
|
|
||||||
async def _query_openwebui(user_text: str, channel_id: int, tools_list: list):
|
if not DISCORD_TOKEN:
|
||||||
|
raise ValueError("discord_token is required in config.yml")
|
||||||
|
if not OPENWEBUI_API_KEY:
|
||||||
|
raise ValueError("open_webui_api_key is required in config.yml")
|
||||||
|
if not OPENWEBUI_URL:
|
||||||
|
raise ValueError("open_webui_url is required in config.yml")
|
||||||
|
if not MODEL_NAME:
|
||||||
|
raise ValueError("model_name is required in config.yml")
|
||||||
|
|
||||||
|
async def _query_openwebui(user_text: str, tools_list: list):
|
||||||
"""
|
"""
|
||||||
Payload structure for the OpenAI-compatible endpoint.
|
Payload structure for the OpenAI-compatible endpoint.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_text (str): The user's message to send to the Open-WebUI.
|
user_text (str): The user's message to send to the Open-WebUI.
|
||||||
channel_id (int): The Discord channel ID where the message was sent.
|
|
||||||
tools_list (list): List of tool IDs to use, if any.
|
tools_list (list): List of tool IDs to use, if any.
|
||||||
"""
|
"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=900)) as session:
|
||||||
# This payload structure is for the OpenAI-compatible endpoint from open-webui
|
messages = []
|
||||||
|
if SYSTEM_PROMPT:
|
||||||
|
messages.append({
|
||||||
|
"role": "system",
|
||||||
|
"content": SYSTEM_PROMPT
|
||||||
|
})
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": user_text
|
||||||
|
})
|
||||||
payload = {
|
payload = {
|
||||||
"model": MODEL_NAME,
|
"model": MODEL_NAME,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"messages": [
|
"messages": messages
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": user_text
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
# Attach tools if provided in the config file
|
if KNOW_BASE:
|
||||||
|
payload["knowledge_base"] = KNOW_BASE
|
||||||
|
logging.debug(f"📚 Using knowledge base: {payload['knowledge_base']}")
|
||||||
if tools_list:
|
if tools_list:
|
||||||
payload["tool_ids"] = tools_list
|
payload["tool_ids"] = tools_list
|
||||||
logging.debug(f"🔧 Using tools: {payload['tool_ids']}")
|
logging.debug(f"🔧 Using tools: {payload['tool_ids']}")
|
||||||
@@ -81,28 +97,36 @@ async def _query_openwebui(user_text: str, channel_id: int, tools_list: list):
|
|||||||
return "No response content received from Open-WebUI"
|
return "No response content received from Open-WebUI"
|
||||||
return content
|
return content
|
||||||
|
|
||||||
async def _query_openwebui_streaming(user_text: str, channel_id: int, tools_list: list, message_to_edit):
|
async def _query_openwebui_streaming(user_text: str, tools_list: list, message_to_edit):
|
||||||
"""
|
"""
|
||||||
Stream response from Open-WebUI and edit Discord message progressively.
|
Stream response from Open-WebUI and edit Discord message progressively.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_text (str): The user's message to send to the Open-WebUI.
|
user_text (str): The user's message to send to the Open-WebUI.
|
||||||
channel_id (int): The Discord channel ID where the message was sent.
|
|
||||||
tools_list (list): List of tool IDs to use, if any.
|
tools_list (list): List of tool IDs to use, if any.
|
||||||
message_to_edit: The Discord message object to edit with streaming content.
|
message_to_edit: The Discord message object to edit with streaming content.
|
||||||
"""
|
"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=900)) as session:
|
||||||
|
messages = []
|
||||||
|
if SYSTEM_PROMPT:
|
||||||
|
messages.append({
|
||||||
|
"role": "system",
|
||||||
|
"content": SYSTEM_PROMPT
|
||||||
|
})
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": user_text
|
||||||
|
})
|
||||||
payload = {
|
payload = {
|
||||||
"model": MODEL_NAME,
|
"model": MODEL_NAME,
|
||||||
"stream": True, # Enable streaming
|
"stream": True,
|
||||||
"messages": [
|
"messages": messages
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": user_text
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if KNOW_BASE:
|
||||||
|
payload["knowledge_base"] = KNOW_BASE
|
||||||
|
logging.debug(f"📚 Using knowledge base: {payload['knowledge_base']}")
|
||||||
|
|
||||||
if tools_list:
|
if tools_list:
|
||||||
payload["tool_ids"] = tools_list
|
payload["tool_ids"] = tools_list
|
||||||
logging.debug(f"🔧 Using tools: {payload['tool_ids']}")
|
logging.debug(f"🔧 Using tools: {payload['tool_ids']}")
|
||||||
@@ -131,7 +155,6 @@ async def _query_openwebui_streaming(user_text: str, channel_id: int, tools_list
|
|||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import json
|
|
||||||
chunk_data = json.loads(data_str)
|
chunk_data = json.loads(data_str)
|
||||||
|
|
||||||
if 'choices' in chunk_data and len(chunk_data['choices']) > 0:
|
if 'choices' in chunk_data and len(chunk_data['choices']) > 0:
|
||||||
@@ -141,19 +164,18 @@ async def _query_openwebui_streaming(user_text: str, channel_id: int, tools_list
|
|||||||
accumulated_content += content
|
accumulated_content += content
|
||||||
|
|
||||||
# Edit message periodically to avoid rate limits
|
# Edit message periodically to avoid rate limits
|
||||||
current_time = asyncio.get_event_loop().time()
|
current_time = time.time()
|
||||||
if current_time - last_edit_time >= edit_interval:
|
if current_time - last_edit_time >= edit_interval:
|
||||||
try:
|
try:
|
||||||
# Limit message length to Discord's 2000 character limit
|
# Limit message length to 2000. We don't want to spam the channel with too many edits, and Discord has a 2000 character limit per message.
|
||||||
content_to_show = accumulated_content[:1900]
|
content_to_show = accumulated_content[:2000]
|
||||||
if len(accumulated_content) > 1900:
|
if len(accumulated_content) > 2000:
|
||||||
content_to_show += "..."
|
content_to_show += "..."
|
||||||
|
|
||||||
await message_to_edit.edit(content=content_to_show)
|
await message_to_edit.edit(content=content_to_show)
|
||||||
last_edit_time = current_time
|
last_edit_time = current_time
|
||||||
except discord.HTTPException:
|
except discord.HTTPException as e:
|
||||||
# Handle rate limits gracefully
|
logging.warning(f"Discord rate limit or error while editing message: {e}")
|
||||||
pass
|
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
continue
|
continue
|
||||||
@@ -162,8 +184,8 @@ async def _query_openwebui_streaming(user_text: str, channel_id: int, tools_list
|
|||||||
try:
|
try:
|
||||||
final_content = accumulated_content[:2000] # Respect Discord's limit
|
final_content = accumulated_content[:2000] # Respect Discord's limit
|
||||||
await message_to_edit.edit(content=final_content)
|
await message_to_edit.edit(content=final_content)
|
||||||
except discord.HTTPException:
|
except discord.HTTPException as e:
|
||||||
pass
|
logging.warning(f"Discord rate limit or error on final edit: {e}")
|
||||||
|
|
||||||
return accumulated_content
|
return accumulated_content
|
||||||
|
|
||||||
@@ -224,39 +246,35 @@ async def on_message(message):
|
|||||||
if not is_dm and WHITELIST_CHANNELS and message.channel.id not in WHITELIST_CHANNELS:
|
if not is_dm and WHITELIST_CHANNELS and message.channel.id not in WHITELIST_CHANNELS:
|
||||||
return
|
return
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# A. Prepare payload
|
|
||||||
# ----------------------------------------------------------------------- #
|
|
||||||
# The OpenAI endpoint works better without the extra context in the prompt
|
|
||||||
prompt = message.content
|
prompt = message.content
|
||||||
if SYSTEM_PROMPT:
|
|
||||||
# The system prompt is handled differently in the OpenAI-compatible API
|
|
||||||
# For simplicity, we'll prepend it here. A more robust solution
|
|
||||||
# would add it as a separate message with the 'system' role.
|
|
||||||
prompt = f"{SYSTEM_PROMPT}\n\nUser Question: {message.content}"
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------- #
|
# ----------------------------------------------------------------------- #
|
||||||
# B. Query Open-WebUI and show typing indicator
|
# B. Query Open-WebUI and show typing indicator
|
||||||
# ----------------------------------------------------------------------- #
|
# ----------------------------------------------------------------------- #
|
||||||
|
initial_message = None
|
||||||
try:
|
try:
|
||||||
if USE_STREAMING:
|
if USE_STREAMING:
|
||||||
# Send initial "collecting information" message
|
initial_message = await message.reply(STREAMING_INITIAL_MESSAGE)
|
||||||
initial_message = await message.reply("Bitte warte kurz, die Informationen werden gesammelt...")
|
await _query_openwebui_streaming(prompt, TOOLS, initial_message)
|
||||||
|
|
||||||
# Start streaming response and edit the message
|
|
||||||
await _query_openwebui_streaming(prompt, message.channel.id, TOOLS, initial_message)
|
|
||||||
else:
|
else:
|
||||||
# Use the original non-streaming approach
|
|
||||||
async with message.channel.typing():
|
async with message.channel.typing():
|
||||||
reply = await _query_openwebui(prompt, message.channel.id, TOOLS)
|
reply = await _query_openwebui(prompt, TOOLS)
|
||||||
await message.reply(reply)
|
await message.reply(reply)
|
||||||
|
except RuntimeError as e:
|
||||||
|
if initial_message:
|
||||||
|
await initial_message.edit(content=f"⚠ Open-WebUI API error: {e}")
|
||||||
|
else:
|
||||||
|
await message.reply(f"⚠ Open-WebUI API error: {e}")
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
if initial_message:
|
||||||
|
await initial_message.edit(content=f"⚠ Network error contacting Open-WebUI API: {e}")
|
||||||
|
else:
|
||||||
|
await message.reply(f"⚠ Network error contacting Open-WebUI API: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If we're in streaming mode and have an initial message, edit it with error
|
if initial_message:
|
||||||
if USE_STREAMING and 'initial_message' in locals():
|
|
||||||
await initial_message.edit(content=f"⚠ Error contacting the Open-WebUI API: {e}")
|
await initial_message.edit(content=f"⚠ Error contacting the Open-WebUI API: {e}")
|
||||||
else:
|
else:
|
||||||
await message.reply(f"⚠ Error contacting the Open-WebUI API: {e}")
|
await message.reply(f"⚠ Error contacting the Open-WebUI API: {e}")
|
||||||
# No need to return here as the function ends after this block.
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------------- #
|
# --------------------------------------------------------------------------- #
|
||||||
# Start bot
|
# Start bot
|
||||||
|
|||||||
Reference in New Issue
Block a user