mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-05-07 03:32:18 +08:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 371e38cfa6 | |||
| 5a221848e9 |
@@ -86,6 +86,7 @@ class ChatGPTBot(Bot,OpenAIImage):
|
||||
"top_p":1,
|
||||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get('request_timeout', 30), # 请求超时时间
|
||||
}
|
||||
|
||||
def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict:
|
||||
|
||||
+65
-2
@@ -1,9 +1,13 @@
|
||||
|
||||
|
||||
|
||||
from asyncio import CancelledError
|
||||
import queue
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from channel.channel import Channel
|
||||
from bridge.reply import *
|
||||
@@ -20,8 +24,16 @@ except Exception as e:
|
||||
class ChatChannel(Channel):
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
|
||||
|
||||
# 根据消息构造context,消息内容相关的触发项写在这里
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
@@ -215,6 +227,57 @@ class ChatChannel(Channel):
|
||||
time.sleep(3+3*retry_cnt)
|
||||
self._send(reply, context, retry_cnt+1)
|
||||
|
||||
def thread_pool_callback(self, session_id):
|
||||
def func(worker:Future):
|
||||
try:
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
logger.exception("Worker return exception: {}".format(worker_exception))
|
||||
except CancelledError as e:
|
||||
logger.info("Worker cancelled, session_id = {}".format(session_id))
|
||||
except Exception as e:
|
||||
logger.exception("Worker raise exception: {}".format(e))
|
||||
with self.lock:
|
||||
self.sessions[session_id][1].release()
|
||||
return func
|
||||
|
||||
def produce(self, context: Context):
|
||||
session_id = context['session_id']
|
||||
with self.lock:
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = (queue.Queue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)))
|
||||
self.sessions[session_id][0].put(context)
|
||||
|
||||
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
||||
def consume(self):
|
||||
while True:
|
||||
with self.lock:
|
||||
session_ids = list(self.sessions.keys())
|
||||
for session_id in session_ids:
|
||||
context_queue, semaphore = self.sessions[session_id]
|
||||
if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[WX] consume context: {}".format(context))
|
||||
future:Future = self.handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self.thread_pool_callback(session_id))
|
||||
if session_id not in self.futures:
|
||||
self.futures[session_id] = []
|
||||
self.futures[session_id].append(future)
|
||||
elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
||||
assert len(self.futures[session_id]) == 0, "thread pool error"
|
||||
del self.sessions[session_id]
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.1)
|
||||
|
||||
def cancel(self, session_id):
|
||||
with self.lock:
|
||||
if session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
self.sessions[session_id][0]=queue.Queue()
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
|
||||
@@ -5,6 +5,7 @@ wechat channel
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import requests
|
||||
import io
|
||||
import time
|
||||
@@ -17,18 +18,10 @@ from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
from bridge.reply import *
|
||||
from bridge.context import *
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from config import conf
|
||||
from common.time_check import time_checker
|
||||
from common.expired_dict import ExpiredDict
|
||||
from plugins import *
|
||||
thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||
|
||||
def thread_pool_callback(worker):
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
logger.exception("Worker return exception: {}".format(worker_exception))
|
||||
|
||||
|
||||
@itchat.msg_register(TEXT)
|
||||
def handler_single_msg(msg):
|
||||
@@ -73,7 +66,9 @@ def qrCallback(uuid,status,qrcode):
|
||||
try:
|
||||
from PIL import Image
|
||||
img = Image.open(io.BytesIO(qrcode))
|
||||
thread_pool.submit(img.show,"QRCode")
|
||||
_thread = threading.Thread(target=img.show, args=("QRCode",))
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
@@ -142,7 +137,7 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
@@ -150,7 +145,7 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
@@ -158,7 +153,7 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
@@ -168,7 +163,7 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
self.produce(context)
|
||||
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
|
||||
@@ -5,7 +5,6 @@ wechaty channel
|
||||
Python Wechaty - https://github.com/wechaty/python-wechaty
|
||||
"""
|
||||
import base64
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
@@ -18,21 +17,18 @@ from bridge.context import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechat.wechaty_message import WechatyMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
try:
|
||||
from voice.audio_convert import any_to_sil
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||
def thread_pool_callback(worker):
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
logger.exception("Worker return exception: {}".format(worker_exception))
|
||||
@singleton
|
||||
class WechatyChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
super().__init__()
|
||||
|
||||
def startup(self):
|
||||
config = conf()
|
||||
@@ -41,6 +37,10 @@ class WechatyChannel(ChatChannel):
|
||||
asyncio.run(self.main())
|
||||
|
||||
async def main(self):
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
#将asyncio的loop传入处理线程
|
||||
self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop)
|
||||
self.bot = Wechaty()
|
||||
self.bot.on('login', self.on_login)
|
||||
self.bot.on('message', self.on_message)
|
||||
@@ -122,8 +122,4 @@ class WechatyChannel(ChatChannel):
|
||||
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
|
||||
if context:
|
||||
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context))
|
||||
thread_pool.submit(self._handle_loop, context, asyncio.get_event_loop()).add_done_callback(thread_pool_callback)
|
||||
|
||||
def _handle_loop(self,context,loop):
|
||||
asyncio.set_event_loop(loop)
|
||||
self._handle(context)
|
||||
self.produce(context)
|
||||
@@ -27,6 +27,7 @@ available_setting = {
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"trigger_by_self": False, # 是否允许机器人触发
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
|
||||
|
||||
# chatgpt会话参数
|
||||
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
||||
@@ -43,6 +44,7 @@ available_setting = {
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"request_timeout": 30, # chatgpt请求超时时间
|
||||
|
||||
# 语音设置
|
||||
"speech_recognition": False, # 是否开启语音识别
|
||||
|
||||
Reference in New Issue
Block a user