mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-04-29 13:42:45 +08:00
feat: web channel stream chat
This commit is contained in:
@@ -3,7 +3,6 @@ import time
|
||||
import web
|
||||
import json
|
||||
import uuid
|
||||
import io
|
||||
from queue import Queue, Empty
|
||||
from bridge.context import *
|
||||
from bridge.reply import Reply, ReplyType
|
||||
@@ -13,7 +12,7 @@ from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
import os
|
||||
import mimetypes # 添加这行来处理MIME类型
|
||||
import mimetypes
|
||||
import threading
|
||||
import logging
|
||||
|
||||
@@ -47,9 +46,10 @@ class WebChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.msg_id_counter = 0 # 添加消息ID计数器
|
||||
self.session_queues = {} # 存储session_id到队列的映射
|
||||
self.request_to_session = {} # 存储request_id到session_id的映射
|
||||
self.msg_id_counter = 0
|
||||
self.session_queues = {} # session_id -> Queue (fallback polling)
|
||||
self.request_to_session = {} # request_id -> session_id
|
||||
self.sse_queues = {} # request_id -> Queue (SSE streaming)
|
||||
self._http_server = None
|
||||
|
||||
|
||||
@@ -71,22 +71,30 @@ class WebChannel(ChatChannel):
|
||||
if reply.type == ReplyType.IMAGE_URL:
|
||||
time.sleep(0.5)
|
||||
|
||||
# 获取请求ID和会话ID
|
||||
request_id = context.get("request_id", None)
|
||||
|
||||
if not request_id:
|
||||
logger.error("No request_id found in context, cannot send message")
|
||||
return
|
||||
|
||||
# 通过request_id获取session_id
|
||||
|
||||
session_id = self.request_to_session.get(request_id)
|
||||
if not session_id:
|
||||
logger.error(f"No session_id found for request {request_id}")
|
||||
return
|
||||
|
||||
# 检查是否有会话队列
|
||||
|
||||
# SSE mode: push done event to SSE queue
|
||||
if request_id in self.sse_queues:
|
||||
content = reply.content if reply.content is not None else ""
|
||||
self.sse_queues[request_id].put({
|
||||
"type": "done",
|
||||
"content": content,
|
||||
"request_id": request_id,
|
||||
"timestamp": time.time()
|
||||
})
|
||||
logger.debug(f"SSE done sent for request {request_id}")
|
||||
return
|
||||
|
||||
# Fallback: polling mode
|
||||
if session_id in self.session_queues:
|
||||
# 创建响应数据,包含请求ID以区分不同请求的响应
|
||||
response_data = {
|
||||
"type": str(reply.type),
|
||||
"content": reply.content,
|
||||
@@ -94,69 +102,133 @@ class WebChannel(ChatChannel):
|
||||
"request_id": request_id
|
||||
}
|
||||
self.session_queues[session_id].put(response_data)
|
||||
logger.debug(f"Response sent to queue for session {session_id}, request {request_id}")
|
||||
logger.debug(f"Response sent to poll queue for session {session_id}, request {request_id}")
|
||||
else:
|
||||
logger.warning(f"No response queue found for session {session_id}, response dropped")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in send method: {e}")
|
||||
|
||||
def _make_sse_callback(self, request_id: str):
|
||||
"""Build an on_event callback that pushes agent stream events into the SSE queue."""
|
||||
def on_event(event: dict):
|
||||
if request_id not in self.sse_queues:
|
||||
return
|
||||
q = self.sse_queues[request_id]
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
if event_type == "message_update":
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
q.put({"type": "delta", "content": delta})
|
||||
|
||||
elif event_type == "tool_execution_start":
|
||||
tool_name = data.get("tool_name", "tool")
|
||||
arguments = data.get("arguments", {})
|
||||
q.put({"type": "tool_start", "tool": tool_name, "arguments": arguments})
|
||||
|
||||
elif event_type == "tool_execution_end":
|
||||
tool_name = data.get("tool_name", "tool")
|
||||
status = data.get("status", "success")
|
||||
result = data.get("result", "")
|
||||
exec_time = data.get("execution_time", 0)
|
||||
# Truncate long results to avoid huge SSE payloads
|
||||
result_str = str(result)
|
||||
if len(result_str) > 2000:
|
||||
result_str = result_str[:2000] + "…"
|
||||
q.put({
|
||||
"type": "tool_end",
|
||||
"tool": tool_name,
|
||||
"status": status,
|
||||
"result": result_str,
|
||||
"execution_time": round(exec_time, 2)
|
||||
})
|
||||
|
||||
return on_event
|
||||
|
||||
def post_message(self):
|
||||
"""
|
||||
Handle incoming messages from users via POST request.
|
||||
Returns a request_id for tracking this specific request.
|
||||
"""
|
||||
try:
|
||||
data = web.data() # 获取原始POST数据
|
||||
data = web.data()
|
||||
json_data = json.loads(data)
|
||||
session_id = json_data.get('session_id', f'session_{int(time.time())}')
|
||||
prompt = json_data.get('message', '')
|
||||
|
||||
# 生成请求ID
|
||||
use_sse = json_data.get('stream', True)
|
||||
|
||||
request_id = self._generate_request_id()
|
||||
|
||||
# 将请求ID与会话ID关联
|
||||
self.request_to_session[request_id] = session_id
|
||||
|
||||
# 确保会话队列存在
|
||||
|
||||
if session_id not in self.session_queues:
|
||||
self.session_queues[session_id] = Queue()
|
||||
|
||||
# Web channel 不需要前缀,确保消息能通过前缀检查
|
||||
|
||||
if use_sse:
|
||||
self.sse_queues[request_id] = Queue()
|
||||
|
||||
trigger_prefixs = conf().get("single_chat_prefix", [""])
|
||||
if check_prefix(prompt, trigger_prefixs) is None:
|
||||
# 如果没有匹配到前缀,给消息加上第一个前缀
|
||||
if trigger_prefixs:
|
||||
prompt = trigger_prefixs[0] + prompt
|
||||
logger.debug(f"[WebChannel] Added prefix to message: {prompt}")
|
||||
|
||||
# 创建消息对象
|
||||
|
||||
msg = WebMessage(self._generate_msg_id(), prompt)
|
||||
msg.from_user_id = session_id # 使用会话ID作为用户ID
|
||||
|
||||
# 创建上下文,明确指定 isgroup=False
|
||||
msg.from_user_id = session_id
|
||||
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg=msg, isgroup=False)
|
||||
|
||||
# 检查 context 是否为 None(可能被插件过滤等)
|
||||
|
||||
if context is None:
|
||||
logger.warning(f"[WebChannel] Context is None for session {session_id}, message may be filtered")
|
||||
if request_id in self.sse_queues:
|
||||
del self.sse_queues[request_id]
|
||||
return json.dumps({"status": "error", "message": "Message was filtered"})
|
||||
|
||||
# 覆盖必要的字段(_compose_context 会设置默认值,但我们需要使用实际的 session_id)
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = session_id
|
||||
context["request_id"] = request_id
|
||||
|
||||
# 异步处理消息 - 只传递上下文
|
||||
|
||||
if use_sse:
|
||||
context["on_event"] = self._make_sse_callback(request_id)
|
||||
|
||||
threading.Thread(target=self.produce, args=(context,)).start()
|
||||
|
||||
# 返回请求ID
|
||||
return json.dumps({"status": "success", "request_id": request_id})
|
||||
|
||||
|
||||
return json.dumps({"status": "success", "request_id": request_id, "stream": use_sse})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
return json.dumps({"status": "error", "message": str(e)})
|
||||
|
||||
def stream_response(self, request_id: str):
|
||||
"""
|
||||
SSE generator for a given request_id.
|
||||
Yields UTF-8 encoded bytes to avoid WSGI Latin-1 mangling.
|
||||
"""
|
||||
if request_id not in self.sse_queues:
|
||||
yield b"data: {\"type\": \"error\", \"message\": \"invalid request_id\"}\n\n"
|
||||
return
|
||||
|
||||
q = self.sse_queues[request_id]
|
||||
timeout = 300 # 5 minutes max
|
||||
deadline = time.time() + timeout
|
||||
|
||||
try:
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
item = q.get(timeout=1)
|
||||
except Empty:
|
||||
yield b": keepalive\n\n"
|
||||
continue
|
||||
|
||||
payload = json.dumps(item, ensure_ascii=False)
|
||||
yield f"data: {payload}\n\n".encode("utf-8")
|
||||
|
||||
if item.get("type") == "done":
|
||||
break
|
||||
finally:
|
||||
self.sse_queues.pop(request_id, None)
|
||||
|
||||
def poll_response(self):
|
||||
"""
|
||||
Poll for responses using the session_id.
|
||||
@@ -223,6 +295,7 @@ class WebChannel(ChatChannel):
|
||||
'/', 'RootHandler',
|
||||
'/message', 'MessageHandler',
|
||||
'/poll', 'PollHandler',
|
||||
'/stream', 'StreamHandler',
|
||||
'/chat', 'ChatHandler',
|
||||
'/config', 'ConfigHandler',
|
||||
'/assets/(.*)', 'AssetsHandler',
|
||||
@@ -272,6 +345,21 @@ class PollHandler:
|
||||
return WebChannel().poll_response()
|
||||
|
||||
|
||||
class StreamHandler:
|
||||
def GET(self):
|
||||
params = web.input(request_id='')
|
||||
request_id = params.request_id
|
||||
if not request_id:
|
||||
raise web.badrequest()
|
||||
|
||||
web.header('Content-Type', 'text/event-stream; charset=utf-8')
|
||||
web.header('Cache-Control', 'no-cache')
|
||||
web.header('X-Accel-Buffering', 'no')
|
||||
web.header('Access-Control-Allow-Origin', '*')
|
||||
|
||||
return WebChannel().stream_response(request_id)
|
||||
|
||||
|
||||
class ChatHandler:
|
||||
def GET(self):
|
||||
# 正常返回聊天页面
|
||||
|
||||
Reference in New Issue
Block a user