Compare commits

..

10 Commits

Author SHA1 Message Date
lanvent 371e38cfa6 add concurrency_in_session,request_timeout options 2023-04-04 13:33:01 +08:00
lanvent 5a221848e9 feat: avoid disorder by producer-consumer model 2023-04-04 05:18:09 +08:00
lanvent 7458a6298f feat: add trigger_by_self option 2023-04-03 23:58:19 +08:00
lanvent b0f54bb8b7 fix: dirty message including at and prefix 2023-04-03 23:53:58 +08:00
lanvent acddadc406 feat: add convert pcm32 to pcm16 2023-04-03 22:55:39 +08:00
lanvent b74274b96b fix: old code in hello plugin 2023-04-03 02:00:33 +08:00
lanvent 49ba278316 fix: use english filename 2023-04-02 16:50:11 +08:00
lanvent 388058467c fix: delete same file twice 2023-04-02 14:55:45 +08:00
lanvent cf25bd7869 feat: itchat show qrcode using viewer 2023-04-02 14:45:38 +08:00
lanvent 02a95345aa fix: add more qrcode api 2023-04-02 14:13:38 +08:00
12 changed files with 142 additions and 55 deletions
+1
View File
@@ -86,6 +86,7 @@ class ChatGPTBot(Bot,OpenAIImage):
"top_p":1,
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get('request_timeout', 30), # 请求超时时间
}
def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict:
+80 -12
View File
@@ -1,9 +1,13 @@
from asyncio import CancelledError
import queue
from concurrent.futures import Future, ThreadPoolExecutor
import os
import re
import threading
import time
from channel.chat_message import ChatMessage
from common.expired_dict import ExpiredDict
from channel.channel import Channel
from bridge.reply import *
@@ -20,8 +24,16 @@ except Exception as e:
class ChatChannel(Channel):
name = None # 登录的用户名
user_id = None # 登录的用户id
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
lock = threading.Lock() # 用于控制对sessions的访问
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
def __init__(self):
pass
_thread = threading.Thread(target=self.consume)
_thread.setDaemon(True)
_thread.start()
# 根据消息构造context,消息内容相关的触发项写在这里
def _compose_context(self, ctype: ContextType, content, **kwargs):
@@ -38,7 +50,7 @@ class ChatChannel(Channel):
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
config = conf()
cmsg = context['msg']
if cmsg.from_user_id == self.user_id:
if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', False):
logger.debug("[WX]self message skipped")
return None
if context["isgroup"]:
@@ -70,17 +82,21 @@ class ChatChannel(Channel):
# 校验关键字
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
match_contain = check_contain(content, conf().get('group_chat_keyword'))
flag = False
if match_prefix is not None or match_contain is not None:
flag = True
if match_prefix:
content = content.replace(match_prefix, '', 1).strip()
elif context['msg'].is_at and not conf().get("group_at_off", False):
logger.info("[WX]receive group at, continue")
if context['msg'].is_at:
logger.info("[WX]receive group at")
if not conf().get("group_at_off", False):
flag = True
pattern = f'@{self.name}(\u2005|\u0020)'
content = re.sub(pattern, r'', content)
elif context["origin_ctype"] == ContextType.VOICE:
logger.info("[WX]receive group voice, checkprefix didn't match")
return None
else:
if not flag:
if context["origin_ctype"] == ContextType.VOICE:
logger.info("[WX]receive group voice, but checkprefix didn't match")
return None
else: # 单聊
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
@@ -106,7 +122,6 @@ class ChatChannel(Channel):
return context
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
def _handle(self, context: Context):
if context is None or not context.content:
return
@@ -144,9 +159,11 @@ class ChatChannel(Channel):
# 删除临时文件
try:
os.remove(file_path)
os.remove(wav_path)
if wav_path != file_path:
os.remove(wav_path)
except Exception as e:
logger.warning("[WX]delete temp file error: " + str(e))
pass
# logger.warning("[WX]delete temp file error: " + str(e))
if reply.type == ReplyType.TEXT:
new_context = self._compose_context(
@@ -210,6 +227,57 @@ class ChatChannel(Channel):
time.sleep(3+3*retry_cnt)
self._send(reply, context, retry_cnt+1)
def thread_pool_callback(self, session_id):
def func(worker:Future):
try:
worker_exception = worker.exception()
if worker_exception:
logger.exception("Worker return exception: {}".format(worker_exception))
except CancelledError as e:
logger.info("Worker cancelled, session_id = {}".format(session_id))
except Exception as e:
logger.exception("Worker raise exception: {}".format(e))
with self.lock:
self.sessions[session_id][1].release()
return func
def produce(self, context: Context):
session_id = context['session_id']
with self.lock:
if session_id not in self.sessions:
self.sessions[session_id] = (queue.Queue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)))
self.sessions[session_id][0].put(context)
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
def consume(self):
while True:
with self.lock:
session_ids = list(self.sessions.keys())
for session_id in session_ids:
context_queue, semaphore = self.sessions[session_id]
if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除
if not context_queue.empty():
context = context_queue.get()
logger.debug("[WX] consume context: {}".format(context))
future:Future = self.handler_pool.submit(self._handle, context)
future.add_done_callback(self.thread_pool_callback(session_id))
if session_id not in self.futures:
self.futures[session_id] = []
self.futures[session_id].append(future)
elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
assert len(self.futures[session_id]) == 0, "thread pool error"
del self.sessions[session_id]
else:
semaphore.release()
time.sleep(0.1)
def cancel(self, session_id):
with self.lock:
if session_id in self.sessions:
for future in self.futures[session_id]:
future.cancel()
self.sessions[session_id][0]=queue.Queue()
def check_prefix(content, prefix_list):
+24 -15
View File
@@ -5,6 +5,7 @@ wechat channel
"""
import os
import threading
import requests
import io
import time
@@ -17,18 +18,10 @@ from lib import itchat
from lib.itchat.content import *
from bridge.reply import *
from bridge.context import *
from concurrent.futures import ThreadPoolExecutor
from config import conf
from common.time_check import time_checker
from common.expired_dict import ExpiredDict
from plugins import *
thread_pool = ThreadPoolExecutor(max_workers=8)
def thread_pool_callback(worker):
worker_exception = worker.exception()
if worker_exception:
logger.exception("Worker return exception: {}".format(worker_exception))
@itchat.msg_register(TEXT)
def handler_single_msg(msg):
@@ -70,11 +63,27 @@ def _check(func):
def qrCallback(uuid,status,qrcode):
# logger.debug("qrCallback: {} {}".format(uuid,status))
if status == '0':
try:
from PIL import Image
img = Image.open(io.BytesIO(qrcode))
_thread = threading.Thread(target=img.show, args=("QRCode",))
_thread.setDaemon(True)
_thread.start()
except Exception as e:
pass
import qrcode
url = f"https://login.weixin.qq.com/l/{uuid}"
qr_api="https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
print("You can also scan QRCode in the website below:\n{}".format(qr_api))
qr_api1="https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
qr_api2="https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
qr_api3="https://api.pwmqr.com/qrcode/create/?url={}".format(url)
qr_api4="https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
print("You can also scan QRCode in any website below:")
print(qr_api3)
print(qr_api4)
print(qr_api2)
print(qr_api1)
qr = qrcode.QRCode(border=1)
qr.add_data(url)
@@ -128,7 +137,7 @@ class WechatChannel(ChatChannel):
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg)
if context:
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
self.produce(context)
@time_checker
@_check
@@ -136,7 +145,7 @@ class WechatChannel(ChatChannel):
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg)
if context:
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
self.produce(context)
@time_checker
@_check
@@ -144,7 +153,7 @@ class WechatChannel(ChatChannel):
logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg)
if context:
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
self.produce(context)
@time_checker
@_check
@@ -154,7 +163,7 @@ class WechatChannel(ChatChannel):
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg)
if context:
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
self.produce(context)
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context):
+10 -13
View File
@@ -5,7 +5,6 @@ wechaty channel
Python Wechaty - https://github.com/wechaty/python-wechaty
"""
import base64
from concurrent.futures import ThreadPoolExecutor
import os
import time
import asyncio
@@ -18,21 +17,18 @@ from bridge.context import *
from channel.chat_channel import ChatChannel
from channel.wechat.wechaty_message import WechatyMessage
from common.log import logger
from common.singleton import singleton
from config import conf
try:
from voice.audio_convert import any_to_sil
except Exception as e:
pass
thread_pool = ThreadPoolExecutor(max_workers=8)
def thread_pool_callback(worker):
worker_exception = worker.exception()
if worker_exception:
logger.exception("Worker return exception: {}".format(worker_exception))
@singleton
class WechatyChannel(ChatChannel):
def __init__(self):
pass
super().__init__()
def startup(self):
config = conf()
@@ -41,6 +37,10 @@ class WechatyChannel(ChatChannel):
asyncio.run(self.main())
async def main(self):
loop = asyncio.get_event_loop()
#将asyncio的loop传入处理线程
self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop)
self.bot = Wechaty()
self.bot.on('login', self.on_login)
self.bot.on('message', self.on_message)
@@ -84,7 +84,8 @@ class WechatyChannel(ChatChannel):
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
try:
os.remove(file_path)
os.remove(sil_file)
if sil_file != file_path:
os.remove(sil_file)
except Exception as e:
pass
logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver))
@@ -121,8 +122,4 @@ class WechatyChannel(ChatChannel):
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
if context:
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context))
thread_pool.submit(self._handle_loop, context, asyncio.get_event_loop()).add_done_callback(thread_pool_callback)
def _handle_loop(self,context,loop):
asyncio.set_event_loop(loop)
self._handle(context)
self.produce(context)
+3
View File
@@ -25,7 +25,9 @@ available_setting = {
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
"trigger_by_self": False, # 是否允许机器人触发
"image_create_prefix": ["", "", ""], # 开启图片回复的前缀
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
# chatgpt会话参数
"expires_in_seconds": 3600, # 无操作会话的过期时间
@@ -42,6 +44,7 @@ available_setting = {
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"request_timeout": 30, # chatgpt请求超时时间
# 语音设置
"speech_recognition": False, # 是否开启语音识别
+8 -5
View File
@@ -101,7 +101,7 @@ PS: 插件目前支持`itchat`和`wechaty`
根据`Context`和回复`Reply`的类型,对回复的内容进行装饰。目前的装饰有以下两种:
- `TEXT`文本回复根据是否在群聊中来决定是艾特接收方还是添加回复的前缀。
- `TEXT`文本回复:如果这次消息需要的回复是`VOICE`,进行文字转语音回复之后再次装饰。 否则根据是否在群聊中来决定是艾特接收方还是添加回复的前缀。
- `INFO``ERROR`类型,会在消息前添加对应的系统提示字样。
@@ -110,8 +110,11 @@ PS: 插件目前支持`itchat`和`wechaty`
```python
if reply.type == ReplyType.TEXT:
reply_text = reply.content
if context.get('desire_rtype') == ReplyType.VOICE:
reply = super().build_text_to_voice(reply.content)
return self._decorate_reply(context, reply)
if context['isgroup']:
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
else:
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
@@ -213,11 +216,11 @@ class Hello(Plugin):
if content == "Hello":
reply = Reply()
reply.type = ReplyType.TEXT
msg = e_context['context']['msg']
msg:ChatMessage = e_context['context']['msg']
if e_context['context']['isgroup']:
reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group")
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
else:
reply.content = "Hello, " + msg['User'].get('NickName', "My friend")
reply.content = f"Hello, {msg.from_user_nickname}"
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
if content == "End":
+4 -3
View File
@@ -2,6 +2,7 @@
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_message import ChatMessage
import plugins
from plugins import *
from common.log import logger
@@ -24,11 +25,11 @@ class Hello(Plugin):
if content == "Hello":
reply = Reply()
reply.type = ReplyType.TEXT
msg = e_context['context']['msg']
msg:ChatMessage = e_context['context']['msg']
if e_context['context']['isgroup']:
reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group")
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
else:
reply.content = "Hello, " + msg['User'].get('NickName', "My friend")
reply.content = f"Hello, {msg.from_user_nickname}"
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+8 -3
View File
@@ -67,23 +67,28 @@ def pcm_to_sil(pcm_path, silk_path):
return 声音长度,毫秒
"""
audio = AudioSegment.from_wav(pcm_path)
wav_data = audio.raw_data
rate = find_closest_sil_supports(audio.frame_rate)
# Convert to PCM_s16
pcm_s16 = audio.set_sample_width(2)
pcm_s16 = pcm_s16.set_frame_rate(rate)
wav_data = pcm_s16.raw_data
silk_data = pysilk.encode(
wav_data, data_rate=rate, sample_rate=rate)
with open(silk_path, "wb") as f:
f.write(silk_data)
return audio.duration_seconds * 1000
def mp3_to_sil(mp3_path, silk_path):
"""
mp3 文件转成 silk
return 声音长度,毫秒
"""
audio = AudioSegment.from_mp3(mp3_path)
wav_data = audio.raw_data
rate = find_closest_sil_supports(audio.frame_rate)
# Convert to PCM_s16
pcm_s16 = audio.set_sample_width(2)
pcm_s16 = pcm_s16.set_frame_rate(rate)
wav_data = pcm_s16.raw_data
silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate)
# Save the silk file
with open(silk_path, "wb") as f:
+1 -1
View File
@@ -54,7 +54,7 @@ class AzureVoice(Voice):
return reply
def textToVoice(self, text):
fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.wav'
fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
audio_config = speechsdk.AudioConfig(filename=fileName)
speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
result = speech_synthesizer.speak_text(text)
+1 -1
View File
@@ -80,7 +80,7 @@ class BaiduVoice(Voice):
result = self.client.synthesis(text, self.lang, self.ctp, {
'spd': self.spd, 'pit': self.pit, 'vol': self.vol, 'per': self.per})
if not isinstance(result, dict):
fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
with open(fileName, 'wb') as f:
f.write(result)
logger.info(
+1 -1
View File
@@ -34,7 +34,7 @@ class GoogleVoice(Voice):
return reply
def textToVoice(self, text):
try:
mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
mp3File = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
tts = gTTS(text=text, lang='zh')
tts.save(mp3File)
logger.info(
+1 -1
View File
@@ -25,7 +25,7 @@ class PyttsVoice(Voice):
def textToVoice(self, text):
try:
wavFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.wav'
wavFile = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
self.engine.save_to_file(text, wavFile)
self.engine.runAndWait()
logger.info(