Compare commits

...

36 Commits

Author SHA1 Message Date
zhayujie 94004b095b fix: no debug config #744 2023-04-04 15:59:56 +08:00
lanvent f652d592bd fix: typo in dequeue 2023-04-04 15:10:35 +08:00
lanvent 186e18fe94 godcmd: load clear_memory_commands 2023-04-04 14:58:51 +08:00
lanvent 28eb67bc24 feat: reset will cancel unprocessed messages 2023-04-04 14:57:38 +08:00
lanvent 6c7e4aaf37 feat: prioritize handling commands 2023-04-04 14:29:03 +08:00
lanvent 709a1317ef feat: add debug option 2023-04-04 14:02:14 +08:00
lanvent 371e38cfa6 add concurrency_in_session,request_timeout options 2023-04-04 13:33:01 +08:00
lanvent 5a221848e9 feat: avoid disorder by producer-consumer model 2023-04-04 05:18:09 +08:00
lanvent 7458a6298f feat: add trigger_by_self option 2023-04-03 23:58:19 +08:00
lanvent b0f54bb8b7 fix: dirty message including at and prefix 2023-04-03 23:53:58 +08:00
lanvent acddadc406 feat: add convert pcm32 to pcm16 2023-04-03 22:55:39 +08:00
lanvent b74274b96b fix: old code in hello plugin 2023-04-03 02:00:33 +08:00
lanvent 49ba278316 fix: use english filename 2023-04-02 16:50:11 +08:00
lanvent 388058467c fix: delete same file twice 2023-04-02 14:55:45 +08:00
lanvent cf25bd7869 feat: itchat show qrcode using viewer 2023-04-02 14:45:38 +08:00
lanvent 02a95345aa fix: add more qrcode api 2023-04-02 14:13:38 +08:00
lanvent 6076e2ed0a fix: voice longer than 60s cannot be sent 2023-04-02 12:29:10 +08:00
lanvent cec674cb47 update qrcode 2023-04-02 04:44:08 +08:00
Jianglang c5a90823fa Update README.md 2023-04-02 04:30:40 +08:00
Jianglang 18d82bc1f0 Update README.md 2023-04-02 04:23:13 +08:00
lanvent a68af990ea update Readme.md 2023-04-02 04:19:50 +08:00
lanvent e71c600d10 feat: new itchat qrcode generator 2023-04-02 03:46:09 +08:00
lanvent d7f1f7182c feat: add always_reply_voice option 2023-04-01 22:27:11 +08:00
lanvent dfb2e460b4 fix: voice length bug in wechaty 2023-04-01 21:58:55 +08:00
lanvent 5badef8ba9 fix: correct sample rate when convert to silk 2023-04-01 20:59:52 +08:00
lanvent 18aa5ce75c fix: get correct audio format in pytts 2023-04-01 20:58:06 +08:00
lanvent 1545a9f262 feat: support azure voice 2023-04-01 16:36:27 +08:00
Jianglang 47cc65a787 Merge pull request #720 from lanvent/master
wechaty支持插件
2023-04-01 05:01:31 +08:00
lanvent cda9d5873d plugins: support wechaty channel 2023-04-01 04:26:34 +08:00
lanvent 02cd553990 refactor: using one processing logic in chat channel 2023-04-01 04:24:00 +08:00
lanvent 87df588c80 refactor: stripp processing unrelated to channel 2023-03-31 22:31:10 +08:00
lanvent 4ad2997717 compatibility: lower boolean values in env 2023-03-31 15:25:02 +08:00
lanvent 50a03e7c15 refactor: wrap wechaty message 2023-03-31 05:36:53 +08:00
lanvent 4f3d12129c refactor: wrap wechat msg to make logic general 2023-03-31 05:36:52 +08:00
lanvent 37a95980d4 feat: support at everywhere 2023-03-31 05:27:19 +08:00
lanvent d9ef5a6612 fix: 无前缀触发bug 2023-03-30 18:26:44 +08:00
28 changed files with 1008 additions and 651 deletions
+1 -2
View File
@@ -12,8 +12,7 @@ name: Create and publish a Docker image
on:
push:
branches: ['master']
release:
types: [published]
create:
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
+1
View File
@@ -10,3 +10,4 @@ nohup.out
tmp
plugins.json
itchat.pkl
*.log
+7
View File
@@ -90,6 +90,13 @@ pip3 install -r requirements.txt
参考[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)
使用`azure`语音功能需安装依赖:
```bash
pip3 install azure-cognitiveservices-speech
```
> 目前默认发布的镜像和`railway`部署,都基于`alpine`,无法安装`azure`的依赖。若有需求请自行基于[`debian`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/docker/Dockerfile.debian.latest)打包。
参考[文档](https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/quickstarts/setup-platform?pivots=programming-language-python&tabs=linux%2Cubuntu%2Cdotnet%2Cjre%2Cmaven%2Cnodejs%2Cmac%2Cpypi)
## 配置
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
+6 -1
View File
@@ -1,5 +1,6 @@
# encoding:utf-8
import os
from config import conf, load_config
from channel import channel_factory
from common.log import logger
@@ -13,8 +14,12 @@ def run():
# create channel
channel_name=conf().get('channel_type', 'wx')
if channel_name == 'wxy':
os.environ['WECHATY_LOG']="warn"
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
channel = channel_factory.create_channel(channel_name)
if channel_name=='wx':
if channel_name in ['wx','wxy']:
PluginManager().load_plugins()
# startup channel
+1
View File
@@ -86,6 +86,7 @@ class ChatGPTBot(Bot,OpenAIImage):
"top_p":1,
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"request_timeout": conf().get('request_timeout', 30), # 请求超时时间
}
def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict:
+2 -1
View File
@@ -20,7 +20,8 @@ class Channel(object):
"""
raise NotImplementedError
def send(self, msg, receiver):
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, context: Context):
"""
send message to user
:param msg: message content
+310
View File
@@ -0,0 +1,310 @@
from asyncio import CancelledError
from concurrent.futures import Future, ThreadPoolExecutor
import os
import re
import threading
import time
from common.dequeue import Dequeue
from channel.channel import Channel
from bridge.reply import *
from bridge.context import *
from config import conf
from common.log import logger
from plugins import *
try:
from voice.audio_convert import any_to_wav
except Exception as e:
pass
# Abstract base class: holds the generic message-processing logic that is
# independent of any concrete message channel (itchat, wechaty, ...).
class ChatChannel(Channel):
    name = None  # nickname of the logged-in account
    user_id = None  # id of the logged-in account
    futures = {}  # per-session_id futures submitted to the thread pool; used to cancel not-yet-started work on session reset (running tasks are not cancelled)
    sessions = {}  # concurrency control: each session_id handles a bounded number of contexts at a time
    lock = threading.Lock()  # guards access to self.sessions (and self.futures)
    handler_pool = ThreadPoolExecutor(max_workers=8)  # thread pool that executes message handling

    def __init__(self):
        # Start the consumer loop in a daemon thread so it never blocks shutdown.
        _thread = threading.Thread(target=self.consume)
        _thread.setDaemon(True)
        _thread.start()

    # Build a Context from an incoming message; all content-related trigger rules live here.
    def _compose_context(self, ctype: ContextType, content, **kwargs):
        context = Context(ctype, content)
        context.kwargs = kwargs
        # On the first pass origin_ctype is None.
        # Rationale: a voice input produces two nested contexts — first speech-to-text,
        # then a text context for generating the reply. origin_ctype lets the second
        # (text) pass know the source was voice, so a private voice chat may skip
        # the prefix-match requirement.
        if 'origin_ctype' not in context:
            context['origin_ctype'] = ctype
        # On the first pass receiver is None; it is filled in below based on the message.
        first_in = 'receiver' not in context
        # Group-name matching: sets session_id and receiver.
        if first_in:  # first pass only: receiver is None, derive it from the message
            config = conf()
            cmsg = context['msg']
            if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', False):
                logger.debug("[WX]self message skipped")
                return None
            if context["isgroup"]:
                group_name = cmsg.other_user_nickname
                group_id = cmsg.other_user_id
                group_name_white_list = config.get('group_name_white_list', [])
                group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
                if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
                    group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
                    # Default: one session per actual group member; optionally share one session per group.
                    session_id = cmsg.actual_user_id
                    if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
                        session_id = group_id
                else:
                    # Group not whitelisted: drop the message.
                    return None
                context['session_id'] = session_id
                context['receiver'] = group_id
            else:
                context['session_id'] = cmsg.other_user_id
                context['receiver'] = cmsg.other_user_id
        # Content matching: decide whether to trigger, and normalize content.
        if ctype == ContextType.TEXT:
            if first_in and "\n- - - - - - -" in content:  # first pass: filter quoted/reference messages
                logger.debug("[WX]reference query skipped")
                return None
            if context["isgroup"]:  # group chat
                # Check trigger prefix / keywords.
                match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
                match_contain = check_contain(content, conf().get('group_chat_keyword'))
                flag = False
                if match_prefix is not None or match_contain is not None:
                    flag = True
                    if match_prefix:
                        # Strip the matched prefix (and surrounding whitespace) from the content.
                        content = content.replace(match_prefix, '', 1).strip()
                if context['msg'].is_at:
                    logger.info("[WX]receive group at")
                    if not conf().get("group_at_off", False):
                        flag = True
                    # Remove the "@<bot name>" mention (followed by U+2005 or a space) from the text.
                    pattern = f'@{self.name}(\u2005|\u0020)'
                    content = re.sub(pattern, r'', content)
                if not flag:
                    if context["origin_ctype"] == ContextType.VOICE:
                        logger.info("[WX]receive group voice, but checkprefix didn't match")
                    return None
            else:  # private chat
                match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
                if match_prefix is not None:  # custom prefix matched: strip the prefix plus whitespace
                    content = content.replace(match_prefix, '', 1).strip()
                elif context["origin_ctype"] == ContextType.VOICE:  # source was a private voice message: relax the prefix requirement
                    pass
                else:
                    return None
            img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
            if img_match_prefix:
                content = content.replace(img_match_prefix, '', 1).strip()
                context.type = ContextType.IMAGE_CREATE
            else:
                context.type = ContextType.TEXT
            context.content = content
            if 'desire_rtype' not in context and conf().get('always_reply_voice'):
                context['desire_rtype'] = ReplyType.VOICE
        elif context.type == ContextType.VOICE:
            if 'desire_rtype' not in context and conf().get('voice_reply_voice'):
                context['desire_rtype'] = ReplyType.VOICE
        return context

    def _handle(self, context: Context):
        """Full pipeline for one context: generate -> decorate -> send the reply."""
        if context is None or not context.content:
            return
        logger.debug('[WX] ready to handle context: {}'.format(context))
        # Build the reply.
        reply = self._generate_reply(context)
        logger.debug('[WX] ready to decorate reply: {}'.format(reply))
        # Decorate the reply (prefixes, @-mentions, text-to-voice conversion).
        reply = self._decorate_reply(context, reply)
        # Send the reply.
        self._send_reply(context, reply)

    # NOTE(review): `reply: Reply = Reply()` is a mutable default shared across calls —
    # confirm the default instance is never mutated before a plugin replaces it.
    def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
        """Produce a Reply for the context, letting plugins intercept first."""
        e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
            'channel': self, 'context': context, 'reply': reply}))
        reply = e_context['reply']
        if not e_context.is_pass():
            logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
            if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:  # text and image-creation messages
                reply = super().build_reply_content(context.content, context)
            elif context.type == ContextType.VOICE:  # voice messages
                cmsg = context['msg']
                cmsg.prepare()
                file_path = context.content
                wav_path = os.path.splitext(file_path)[0] + '.wav'
                try:
                    any_to_wav(file_path, wav_path)
                except Exception as e:  # conversion failed: use the raw file; some APIs accept mp3 too
                    logger.warning("[WX]any to wav error, use raw path. " + str(e))
                    wav_path = file_path
                # Speech recognition.
                reply = super().build_voice_to_text(wav_path)
                # Remove temporary files (best effort).
                try:
                    os.remove(file_path)
                    if wav_path != file_path:
                        os.remove(wav_path)
                except Exception as e:
                    pass
                    # logger.warning("[WX]delete temp file error: " + str(e))
                if reply.type == ReplyType.TEXT:
                    # Re-enter the pipeline with the transcribed text as a TEXT context.
                    new_context = self._compose_context(
                        ContextType.TEXT, reply.content, **context.kwargs)
                    if new_context:
                        reply = self._generate_reply(new_context)
                    else:
                        return
            else:
                logger.error('[WX] unknown context type: {}'.format(context.type))
                return
        return reply

    def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
        """Apply channel-level decoration: plugin hooks, chat prefixes, @-mention, voice conversion."""
        if reply and reply.type:
            e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
                'channel': self, 'context': context, 'reply': reply}))
            reply = e_context['reply']
            desire_rtype = context.get('desire_rtype')
            if not e_context.is_pass() and reply and reply.type:
                if reply.type == ReplyType.TEXT:
                    reply_text = reply.content
                    if desire_rtype == ReplyType.VOICE:
                        # Caller wants voice: synthesize and decorate the new reply recursively.
                        reply = super().build_text_to_voice(reply.content)
                        return self._decorate_reply(context, reply)
                    if context['isgroup']:
                        reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
                        reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
                    else:
                        reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
                    reply.content = reply_text
                elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
                    reply.content = str(reply.type)+":\n" + reply.content
                elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
                    pass
                else:
                    logger.error('[WX] unknown reply type: {}'.format(reply.type))
                    return
            if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
                logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
            return reply

    def _send_reply(self, context: Context, reply: Reply):
        """Emit the ON_SEND_REPLY plugin event, then hand the reply to the channel sender."""
        if reply and reply.type:
            e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
                'channel': self, 'context': context, 'reply': reply}))
            reply = e_context['reply']
            if not e_context.is_pass() and reply and reply.type:
                logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context))
                self._send(reply, context)

    def _send(self, reply: Reply, context: Context, retry_cnt = 0):
        """Send via the subclass's send(); retry up to 2 more times with backoff on failure."""
        try:
            self.send(reply, context)
        except Exception as e:
            logger.error('[WX] sendMsg error: {}'.format(str(e)))
            if isinstance(e, NotImplementedError):
                return
            logger.exception(e)
            if retry_cnt < 2:
                time.sleep(3+3*retry_cnt)
                self._send(reply, context, retry_cnt+1)

    def thread_pool_callback(self, session_id):
        """Build a done-callback that logs worker outcome and releases the session semaphore."""
        def func(worker:Future):
            try:
                worker_exception = worker.exception()
                if worker_exception:
                    logger.exception("Worker return exception: {}".format(worker_exception))
            except CancelledError as e:
                logger.info("Worker cancelled, session_id = {}".format(session_id))
            except Exception as e:
                logger.exception("Worker raise exception: {}".format(e))
            with self.lock:
                # Free one concurrency slot for this session.
                self.sessions[session_id][1].release()
        return func

    def produce(self, context: Context):
        """Enqueue a context on its session's queue, creating session state on first use."""
        session_id = context['session_id']
        with self.lock:
            if session_id not in self.sessions:
                self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1))]
            if context.type == ContextType.TEXT and context.content.startswith("#"):
                self.sessions[session_id][0].putleft(context)  # admin commands ("#...") jump the queue
            else:
                self.sessions[session_id][0].put(context)

    # Consumer loop, runs in its own thread: pulls contexts off per-session queues and dispatches them.
    def consume(self):
        while True:
            with self.lock:
                session_ids = list(self.sessions.keys())
                for session_id in session_ids:
                    context_queue, semaphore = self.sessions[session_id]
                    if semaphore.acquire(blocking = False):  # a session may only be removed once its workers finished
                        if not context_queue.empty():
                            context = context_queue.get()
                            logger.debug("[WX] consume context: {}".format(context))
                            future:Future = self.handler_pool.submit(self._handle, context)
                            future.add_done_callback(self.thread_pool_callback(session_id))
                            if session_id not in self.futures:
                                self.futures[session_id] = []
                            self.futures[session_id].append(future)
                        elif semaphore._initial_value == semaphore._value+1:  # apart from this acquire, no task holds the semaphore: all work for this session is done
                            # NOTE(review): relies on BoundedSemaphore private fields
                            # (_initial_value/_value) — confirm on the targeted Python version.
                            self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
                            assert len(self.futures[session_id]) == 0, "thread pool error"
                            del self.sessions[session_id]
                        else:
                            semaphore.release()
            time.sleep(0.1)

    # Cancel all tasks for session_id; only queued messages and futures not yet started can be cancelled.
    def cancel_session(self, session_id):
        with self.lock:
            if session_id in self.sessions:
                for future in self.futures[session_id]:
                    future.cancel()
                cnt = self.sessions[session_id][0].qsize()
                if cnt>0:
                    logger.info("Cancel {} messages in session {}".format(cnt, session_id))
                # Drop all queued contexts by replacing the queue wholesale.
                self.sessions[session_id][0] = Dequeue()

    def cancel_all_session(self):
        """Cancel pending work in every session (same rules as cancel_session)."""
        with self.lock:
            for session_id in self.sessions:
                for future in self.futures[session_id]:
                    future.cancel()
                cnt = self.sessions[session_id][0].qsize()
                if cnt>0:
                    logger.info("Cancel {} messages in session {}".format(cnt, session_id))
                self.sessions[session_id][0] = Dequeue()
def check_prefix(content, prefix_list):
    """Return the first prefix in *prefix_list* that *content* starts with.

    :param content: the message text to test
    :param prefix_list: iterable of candidate prefixes; may be None or empty
        (config lookups like conf().get('group_chat_prefix') can return None)
    :return: the matched prefix string, or None if nothing matched
    """
    # Guard against a missing/None config value; the original raised TypeError here.
    if not prefix_list:
        return None
    for prefix in prefix_list:
        if content.startswith(prefix):
            return prefix
    return None
def check_contain(content, keyword_list):
    """Return True if any keyword from *keyword_list* occurs inside *content*.

    Returns None (not False) both when no keyword matches and when
    *keyword_list* is empty or None, mirroring check_prefix's contract.
    """
    if not keyword_list:
        return None
    for keyword in keyword_list:
        if keyword in content:
            return True
    return None
+83
View File
@@ -0,0 +1,83 @@
"""
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装
ChatMessage
msg_id: 消息id
create_time: 消息创建时间
ctype: 消息类型 : ContextType
content: 消息内容, 如果是声音/图片,这里是文件路径
from_user_id: 发送者id
from_user_nickname: 发送者昵称
to_user_id: 接收者id
to_user_nickname: 接收者昵称
other_user_id: 对方的id,如果你是发送者,那这个就是接收者id,如果你是接收者,那这个就是发送者id,如果是群消息,那这一直是群id
other_user_nickname: 同上
is_group: 是否是群消息
is_at: 是否被at
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
actual_user_id: 实际发送者id
actual_user_nickname:实际发送者昵称
_prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
_prepared: 是否已经调用过准备函数
_rawmsg: 原始消息对象
"""
class ChatMessage(object):
msg_id = None
create_time = None
ctype = None
content = None
from_user_id = None
from_user_nickname = None
to_user_id = None
to_user_nickname = None
other_user_id = None
other_user_nickname = None
is_group = False
is_at = False
actual_user_id = None
actual_user_nickname = None
_prepare_fn = None
_prepared = False
_rawmsg = None
def __init__(self,_rawmsg):
self._rawmsg = _rawmsg
def prepare(self):
if self._prepare_fn and not self._prepared:
self._prepared = True
self._prepare_fn()
def __str__(self):
return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format(
self.msg_id,
self.create_time,
self.ctype,
self.content,
self.from_user_id,
self.from_user_nickname,
self.to_user_id,
self.to_user_nickname,
self.other_user_id,
self.other_user_nickname,
self.is_group,
self.is_at,
self.actual_user_id,
self.actual_user_nickname,
)
+105 -304
View File
@@ -5,74 +5,95 @@ wechat channel
"""
import os
import threading
import requests
import io
import time
from lib import itchat
import json
from channel.chat_channel import ChatChannel
from channel.wechat.wechat_message import *
from common.singleton import singleton
from common.log import logger
from lib import itchat
from lib.itchat.content import *
from bridge.reply import *
from bridge.context import *
from channel.channel import Channel
from concurrent.futures import ThreadPoolExecutor
from common.log import logger
from common.tmp_dir import TmpDir
from config import conf
from common.time_check import time_checker
from common.expired_dict import ExpiredDict
from plugins import *
try:
from voice.audio_convert import mp3_to_wav
except Exception as e:
pass
thread_pool = ThreadPoolExecutor(max_workers=8)
def thread_pool_callback(worker):
worker_exception = worker.exception()
if worker_exception:
logger.exception("Worker return exception: {}".format(worker_exception))
@itchat.msg_register(TEXT)
def handler_single_msg(msg):
WechatChannel().handle_text(msg)
WechatChannel().handle_text(WeChatMessage(msg))
return None
@itchat.msg_register(TEXT, isGroupChat=True)
def handler_group_msg(msg):
WechatChannel().handle_group(msg)
WechatChannel().handle_group(WeChatMessage(msg,True))
return None
@itchat.msg_register(VOICE)
def handler_single_voice(msg):
WechatChannel().handle_voice(msg)
WechatChannel().handle_voice(WeChatMessage(msg))
return None
@itchat.msg_register(VOICE, isGroupChat=True)
def handler_group_voice(msg):
WechatChannel().handle_group_voice(msg)
WechatChannel().handle_group_voice(WeChatMessage(msg,True))
return None
def _check(func):
def wrapper(self, msg):
msgId = msg['MsgId']
def wrapper(self, cmsg: ChatMessage):
msgId = cmsg.msg_id
if msgId in self.receivedMsgs:
logger.info("Wechat message {} already received, ignore".format(msgId))
return
self.receivedMsgs[msgId] = msg
create_time = msg['CreateTime'] # 消息时间
self.receivedMsgs[msgId] = cmsg
create_time = cmsg.create_time # 消息时间
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history message {} skipped".format(msgId))
return
return func(self, msg)
return func(self, cmsg)
return wrapper
#可用的二维码生成接口
#https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
#https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
def qrCallback(uuid,status,qrcode):
# logger.debug("qrCallback: {} {}".format(uuid,status))
if status == '0':
try:
from PIL import Image
img = Image.open(io.BytesIO(qrcode))
_thread = threading.Thread(target=img.show, args=("QRCode",))
_thread.setDaemon(True)
_thread.start()
except Exception as e:
pass
class WechatChannel(Channel):
import qrcode
url = f"https://login.weixin.qq.com/l/{uuid}"
qr_api1="https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
qr_api2="https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
qr_api3="https://api.pwmqr.com/qrcode/create/?url={}".format(url)
qr_api4="https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
print("You can also scan QRCode in any website below:")
print(qr_api3)
print(qr_api4)
print(qr_api2)
print(qr_api1)
qr = qrcode.QRCode(border=1)
qr.add_data(url)
qr.make(fit=True)
qr.print_ascii(invert=True)
@singleton
class WechatChannel(ChatChannel):
def __init__(self):
self.userName = None
self.nickName = None
super().__init__()
self.receivedMsgs = ExpiredDict(60*60*24)
def startup(self):
@@ -81,22 +102,22 @@ class WechatChannel(Channel):
# login by scan QRCode
hotReload = conf().get('hot_reload', False)
try:
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
except Exception as e:
if hotReload:
logger.error("Hot reload failed, try to login without hot reload")
itchat.logout()
os.remove("itchat.pkl")
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
else:
raise e
self.userName = itchat.instance.storageClass.userName
self.nickName = itchat.instance.storageClass.nickName
logger.info("Wechat login success, username: {}, nickname: {}".format(self.userName, self.nickName))
self.user_id = itchat.instance.storageClass.userName
self.name = itchat.instance.storageClass.nickName
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
# start message listener
itchat.run()
# handle_* 系列函数处理收到的消息后构造Context,然后传入handle函数中处理Context和发送回复
# handle_* 系列函数处理收到的消息后构造Context,然后传入_handle函数中处理Context和发送回复
# Context包含了消息的所有信息,包括以下属性
# type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE
# content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
@@ -104,289 +125,69 @@ class WechatChannel(Channel):
# session_id: 会话id
# isgroup: 是否是群聊
# receiver: 需要回复的对象
# msg: itchat的原始消息对象
# origin_ctype: 原始消息类型,用于私聊语音消息时,避免匹配前缀
# desire_rtype: 希望回复类型,TEXT类型是文本回复,VOICE类型是语音回复
# msg: ChatMessage消息对象
# origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则
# desire_rtype: 希望回复类型,默认是文本回复,设置为ReplyType.VOICE是语音回复
@time_checker
@_check
def handle_voice(self, msg):
def handle_voice(self, cmsg : ChatMessage):
if conf().get('speech_recognition') != True:
return
logger.debug("[WX]receive voice msg: " + msg['FileName'])
to_user_id = msg['ToUserName']
from_user_id = msg['FromUserName']
try:
other_user_id = msg['User']['UserName'] # 对手方id
except Exception as e:
logger.warn("[WX]get other_user_id failed: " + str(e))
if from_user_id == self.userName:
other_user_id = to_user_id
else:
other_user_id = from_user_id
if from_user_id == other_user_id:
context = self._compose_context(ContextType.VOICE, msg['FileName'], isgroup=False, msg=msg, receiver=other_user_id, session_id=other_user_id)
if context:
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
@time_checker
@_check
def handle_text(self, msg):
logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
content = msg['Text']
from_user_id = msg['FromUserName']
to_user_id = msg['ToUserName'] # 接收人id
try:
other_user_id = msg['User']['UserName'] # 对手方id
except Exception as e:
logger.warn("[WX]get other_user_id failed: " + str(e))
if from_user_id == self.userName:
other_user_id = to_user_id
else:
other_user_id = from_user_id
if "\n- - - - - - - - - - - - - - -" in content:
logger.debug("[WX]reference query skipped")
return
context = self._compose_context(ContextType.TEXT, content, isgroup=False, msg=msg, receiver=other_user_id, session_id=other_user_id)
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg)
if context:
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
self.produce(context)
@time_checker
@_check
def handle_group(self, msg):
logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
group_name = msg['User'].get('NickName', None)
group_id = msg['User'].get('UserName', None)
if not group_name:
return ""
origin_content = msg['Content']
content = msg['Content']
content_list = content.split(' ', 1)
context_special_list = content.split('\u2005', 1)
if len(context_special_list) == 2:
content = context_special_list[1]
elif len(content_list) == 2:
content = content_list[1]
if "\n- - - - - - - - - - - - - - -" in content:
logger.debug("[WX]reference query skipped")
return ""
def handle_text(self, cmsg : ChatMessage):
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg)
if context:
self.produce(context)
config = conf()
group_name_white_list = config.get('group_name_white_list', [])
group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
session_id = msg['ActualUserName']
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
session_id = group_id
context = self._compose_context(ContextType.TEXT, content, isgroup=True, msg=msg, receiver=group_id, session_id=session_id)
if context:
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
@time_checker
@_check
def handle_group(self, cmsg : ChatMessage):
logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg)
if context:
self.produce(context)
@time_checker
@_check
def handle_group_voice(self, msg):
def handle_group_voice(self, cmsg : ChatMessage):
if conf().get('group_speech_recognition', False) != True:
return
logger.debug("[WX]receive voice for group msg: " + msg['FileName'])
group_name = msg['User'].get('NickName', None)
group_id = msg['User'].get('UserName', None)
# 验证群名
if not group_name:
return ""
config = conf()
group_name_white_list = config.get('group_name_white_list', [])
group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
session_id =msg['ActualUserName']
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
session_id = group_id
context = self._compose_context(ContextType.VOICE, msg['FileName'], isgroup=True, msg=msg, receiver=group_id, session_id=session_id)
if context:
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
def _compose_context(self, ctype: ContextType, content, **kwargs):
context = Context(ctype, content)
context.kwargs = kwargs
if 'origin_ctype' not in context:
context['origin_ctype'] = ctype
if ctype == ContextType.TEXT:
if context["isgroup"]: # 群聊
# 校验关键字
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
match_contain = check_contain(content, conf().get('group_chat_keyword'))
if match_prefix is not None or match_contain is not None:
# 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
if match_prefix:
content = content.replace(match_prefix, '', 1).strip()
elif context['msg']['IsAt'] and not conf().get("group_at_off", False):
logger.info("[WX]receive group at, continue")
elif context["origin_ctype"] == ContextType.VOICE:
logger.info("[WX]receive group voice, checkprefix didn't match")
return None
else:
return None
else: # 单聊
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
if match_prefix: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
content = content.replace(match_prefix, '', 1).strip()
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,不匹配前缀,直接返回
pass
else:
return None
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.replace(img_match_prefix, '', 1).strip()
context.type = ContextType.IMAGE_CREATE
else:
context.type = ContextType.TEXT
context.content = content
elif context.type == ContextType.VOICE:
if 'desire_rtype' not in context and conf().get('voice_reply_voice'):
context['desire_rtype'] = ReplyType.VOICE
return context
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg)
if context:
self.produce(context)
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
def send(self, reply: Reply, receiver, retry_cnt = 0):
try:
if reply.type == ReplyType.TEXT:
itchat.send(reply.content, toUserName=receiver)
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
itchat.send(reply.content, toUserName=receiver)
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
elif reply.type == ReplyType.VOICE:
itchat.send_file(reply.content, toUserName=receiver)
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
for block in pic_res.iter_content(1024):
image_storage.write(block)
image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
image_storage = reply.content
image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
logger.info('[WX] sendImage, receiver={}'.format(receiver))
except Exception as e:
logger.error('[WX] sendMsg error: {}, receiver={}'.format(e, receiver))
if retry_cnt < 2:
time.sleep(3+3*retry_cnt)
self.send(reply, receiver, retry_cnt + 1)
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
def handle(self, context: Context):
if context is None or not context.content:
return
logger.debug('[WX] ready to handle context: {}'.format(context))
# reply的构建步骤
reply = self._generate_reply(context)
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
# reply的包装步骤
reply = self._decorate_reply(context, reply)
# reply的发送步骤
self._send_reply(context, reply)
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
'channel': self, 'context': context, 'reply': reply}))
reply = e_context['reply']
if not e_context.is_pass():
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
reply = super().build_reply_content(context.content, context)
elif context.type == ContextType.VOICE: # 语音消息
msg = context['msg']
mp3_path = TmpDir().path() + context.content
msg.download(mp3_path)
# mp3转wav
wav_path = os.path.splitext(mp3_path)[0] + '.wav'
try:
mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path)
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
logger.warning("[WX]mp3 to wav error, use mp3 path. " + str(e))
wav_path = mp3_path
# 语音识别
reply = super().build_voice_to_text(wav_path)
# 删除临时文件
try:
os.remove(wav_path)
os.remove(mp3_path)
except Exception as e:
logger.warning("[WX]delete temp file error: " + str(e))
if reply.type == ReplyType.TEXT:
new_context = self._compose_context(
ContextType.TEXT, reply.content, **context.kwargs)
if new_context:
reply = self._generate_reply(new_context)
else:
return
else:
logger.error('[WX] unknown context type: {}'.format(context.type))
return
return reply
def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
    """Decorate a reply before sending: plugin hook, voice synthesis when a
    voice reply was requested, @-mention and configured prefixes for text,
    and a type banner for INFO/ERROR replies."""
    if not (reply and reply.type):
        return reply
    e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
        'channel': self, 'context': context, 'reply': reply}))
    reply = e_context['reply']
    desire_rtype = context.get('desire_rtype')
    if not e_context.is_pass() and reply and reply.type:
        if reply.type == ReplyType.TEXT:
            # If the caller asked for a voice reply, synthesize audio and
            # decorate the resulting voice reply instead.
            if desire_rtype == ReplyType.VOICE:
                voice_reply = super().build_text_to_voice(reply.content)
                return self._decorate_reply(context, voice_reply)
            if context['isgroup']:
                # NOTE(review): assumes msg supports dict-style access
                # ('ActualNickName'); the newer ChatMessage classes expose
                # actual_user_nickname — verify against the channel in use.
                decorated = '@' + context['msg']['ActualNickName'] + ' ' + reply.content.strip()
                decorated = conf().get("group_chat_reply_prefix", "") + decorated
            else:
                decorated = conf().get("single_chat_reply_prefix", "") + reply.content
            reply.content = decorated
        elif reply.type in (ReplyType.ERROR, ReplyType.INFO):
            # Prepend the reply type as a banner line.
            reply.content = str(reply.type) + ":\n" + reply.content
        elif reply.type in (ReplyType.IMAGE_URL, ReplyType.VOICE, ReplyType.IMAGE):
            pass  # media replies need no decoration
        else:
            logger.error('[WX] unknown reply type: {}'.format(reply.type))
            return
    if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
        logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
    return reply
def _send_reply(self, context: Context, reply: Reply):
    """Emit ON_SEND_REPLY, then deliver the reply unless a plugin consumed it."""
    if not (reply and reply.type):
        return
    e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
        'channel': self, 'context': context, 'reply': reply}))
    reply = e_context['reply']
    if e_context.is_pass() or not (reply and reply.type):
        return
    logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
    self.send(reply, context['receiver'])
def check_prefix(content, prefix_list):
    """Return the first prefix in prefix_list that content starts with, else None."""
    return next((prefix for prefix in prefix_list if content.startswith(prefix)), None)
def check_contain(content, keyword_list):
    """Return True if any keyword occurs in content; otherwise None.

    Note: returns None (not False) on no match, mirroring check_prefix.
    """
    if not keyword_list:
        return None
    return True if any(content.find(keyword) != -1 for keyword in keyword_list) else None
def send(self, reply: Reply, context: Context):
    """Deliver a decorated reply to its receiver through itchat,
    dispatching on reply.type."""
    receiver = context["receiver"]
    if reply.type in (ReplyType.TEXT, ReplyType.ERROR, ReplyType.INFO):
        # Plain text; INFO/ERROR content was already prefixed during decoration.
        itchat.send(reply.content, toUserName=receiver)
        logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
    elif reply.type == ReplyType.VOICE:
        itchat.send_file(reply.content, toUserName=receiver)
        logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
    elif reply.type == ReplyType.IMAGE_URL:
        # Download the image into memory, then send it.
        img_url = reply.content
        image_storage = io.BytesIO()
        for block in requests.get(img_url, stream=True).iter_content(1024):
            image_storage.write(block)
        image_storage.seek(0)
        itchat.send_image(image_storage, toUserName=receiver)
        logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
    elif reply.type == ReplyType.IMAGE:
        # Image already provided as a file-like object.
        image_storage = reply.content
        image_storage.seek(0)
        itchat.send_image(image_storage, toUserName=receiver)
        logger.info('[WX] sendImage, receiver={}'.format(receiver))
+57
View File
@@ -0,0 +1,57 @@
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.tmp_dir import TmpDir
from common.log import logger
from lib.itchat.content import *
from lib import itchat
class WeChatMessage(ChatMessage):
    """ChatMessage adapter for raw itchat message dicts.

    Extracts ids and nicknames from the itchat message; for voice messages the
    media download is deferred until ``_prepare_fn`` is invoked.

    Fix: ``logger.warn`` (a deprecated alias) replaced with ``logger.warning``.
    """

    def __init__(self, itchat_msg, is_group=False):
        super().__init__(itchat_msg)
        self.msg_id = itchat_msg['MsgId']
        self.create_time = itchat_msg['CreateTime']
        self.is_group = is_group

        if itchat_msg['Type'] == TEXT:
            self.ctype = ContextType.TEXT
            self.content = itchat_msg['Text']
        elif itchat_msg['Type'] == VOICE:
            self.ctype = ContextType.VOICE
            # content holds the temp-dir path; download happens lazily via _prepare_fn
            self.content = TmpDir().path() + itchat_msg['FileName']
            self._prepare_fn = lambda: itchat_msg.download(self.content)
        else:
            raise NotImplementedError("Unsupported message type: {}".format(itchat_msg['Type']))

        self.from_user_id = itchat_msg['FromUserName']
        self.to_user_id = itchat_msg['ToUserName']

        user_id = itchat.instance.storageClass.userName
        nickname = itchat.instance.storageClass.nickName

        # from_user_id / to_user_id are rarely used, but fill the nicknames
        # anyway for consistency — in short: fill everything that can be filled.
        if self.from_user_id == user_id:
            self.from_user_nickname = nickname
        if self.to_user_id == user_id:
            self.to_user_nickname = nickname
        try:  # the 'User' field may be missing (e.g. messages from strangers)
            self.other_user_id = itchat_msg['User']['UserName']
            self.other_user_nickname = itchat_msg['User']['NickName']
            if self.other_user_id == self.from_user_id:
                self.from_user_nickname = self.other_user_nickname
            if self.other_user_id == self.to_user_id:
                self.to_user_nickname = self.other_user_nickname
        except KeyError as e:  # occasionally the peer info is absent
            logger.warning("[WX]get other_user_id failed: " + str(e))
            if self.from_user_id == user_id:
                self.other_user_id = self.to_user_id
            else:
                self.other_user_id = self.from_user_id

        if self.is_group:
            self.is_at = itchat_msg['IsAt']
            self.actual_user_id = itchat_msg['ActualUserName']
            self.actual_user_nickname = itchat_msg['ActualNickName']
+95 -309
View File
@@ -4,336 +4,122 @@
wechaty channel
Python Wechaty - https://github.com/wechaty/python-wechaty
"""
import base64
import os
import time
import asyncio
from typing import Optional, Union
from bridge.context import Context, ContextType
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore
from bridge.context import Context
from wechaty_puppet import FileBox
from wechaty import Wechaty, Contact
from wechaty.user import Message, MiniProgram, UrlLink
from channel.channel import Channel
from wechaty.user import Message
from bridge.reply import *
from bridge.context import *
from channel.chat_channel import ChatChannel
from channel.wechat.wechaty_message import WechatyMessage
from common.log import logger
from common.tmp_dir import TmpDir
from common.singleton import singleton
from config import conf
from voice.audio_convert import sil_to_wav, mp3_to_sil
try:
from voice.audio_convert import any_to_sil
except Exception as e:
pass
class WechatyChannel(Channel):
@singleton
class WechatyChannel(ChatChannel):
def __init__(self):
pass
super().__init__()
def startup(self):
config = conf()
token = config.get('wechaty_puppet_service_token')
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token
asyncio.run(self.main())
async def main(self):
config = conf()
# 使用PadLocal协议 比较稳定(免费web协议 os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:8080')
token = config.get('wechaty_puppet_service_token')
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token
global bot
bot = Wechaty()
bot.on('scan', self.on_scan)
bot.on('login', self.on_login)
bot.on('message', self.on_message)
await bot.start()
loop = asyncio.get_event_loop()
#将asyncio的loop传入处理线程
self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop)
self.bot = Wechaty()
self.bot.on('login', self.on_login)
self.bot.on('message', self.on_message)
await self.bot.start()
async def on_login(self, contact: Contact):
self.user_id = contact.contact_id
self.name = contact.name
logger.info('[WX] login user={}'.format(contact))
async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None,
data: Optional[str] = None):
pass
# contact = self.Contact.load(self.contact_id)
# logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code))
# print(f'user <{contact}> scan status: {status.name} , 'f'qr_code: {qr_code}')
# Unified send function; each Channel implements its own, dispatching on reply.type.
def send(self, reply: Reply, context: Context):
    """Deliver a reply via wechaty.

    Called from worker threads, so every coroutine is bridged into the bot's
    asyncio loop with run_coroutine_threadsafe(...).result().
    """
    receiver_id = context['receiver']
    loop = asyncio.get_event_loop()
    # Resolve the receiver: a Room for group chats, a Contact otherwise.
    if context['isgroup']:
        receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id),loop).result()
    else:
        receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id),loop).result()
    msg = None
    if reply.type == ReplyType.TEXT:
        msg = reply.content
        asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
        logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
    elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
        msg = reply.content
        asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
        logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
    elif reply.type == ReplyType.VOICE:
        voiceLength = None
        file_path = reply.content
        # Convert the audio to silk; any_to_sil returns the duration in ms.
        sil_file = os.path.splitext(file_path)[0] + '.sil'
        voiceLength = int(any_to_sil(file_path, sil_file))
        # Voice clips longer than 60s cannot be sent; clamp the reported length.
        if voiceLength >= 60000:
            voiceLength = 60000
            logger.info('[WX] voice too long, length={}, set to 60s'.format(voiceLength))
        # Send the voice clip.
        t = int(time.time())
        msg = FileBox.from_file(sil_file, name=str(t) + '.sil')
        if voiceLength is not None:
            msg.metadata['voiceLength'] = voiceLength
        asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
        # Best-effort cleanup of the temp files (sil_file may equal file_path).
        try:
            os.remove(file_path)
            if sil_file != file_path:
                os.remove(sil_file)
        except Exception as e:
            pass
        logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver))
    elif reply.type == ReplyType.IMAGE_URL:  # download the image from the network
        img_url = reply.content
        t = int(time.time())
        msg = FileBox.from_url(url=img_url, name=str(t) + '.png')
        asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
        logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
    elif reply.type == ReplyType.IMAGE:  # read the image from a file-like object
        image_storage = reply.content
        image_storage.seek(0)
        t = int(time.time())
        # FileBox has no stream constructor here, so re-encode as base64.
        msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + '.png')
        asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
        logger.info('[WX] sendImage, receiver={}'.format(receiver))
async def on_message(self, msg: Message):
"""
listen for message event
"""
from_contact = msg.talker() # 获取消息的发送者
to_contact = msg.to() # 接收人
try:
cmsg = await WechatyMessage(msg)
except NotImplementedError as e:
logger.debug('[WX] {}'.format(e))
return
except Exception as e:
logger.exception('[WX] {}'.format(e))
return
logger.debug('[WX] message:{}'.format(cmsg))
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
from_user_id = from_contact.contact_id
to_user_id = to_contact.contact_id # 接收人id
# other_user_id = msg['User']['UserName'] # 对手方id
content = msg.text()
mention_content = await msg.mention_text() # 返回过滤掉@name后的消息
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
# conversation: Union[Room, Contact] = from_contact if room is None else room
if room is None and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
if not msg.is_self() and match_prefix is not None:
# 好友向自己发送消息
if match_prefix != '':
str_list = content.split(match_prefix, 1)
if len(str_list) == 2:
content = str_list[1].strip()
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.split(img_match_prefix, 1)[1].strip()
await self._do_send_img(content, from_user_id)
else:
await self._do_send(content, from_user_id)
elif msg.is_self() and match_prefix:
# 自己给好友发送消息
str_list = content.split(match_prefix, 1)
if len(str_list) == 2:
content = str_list[1].strip()
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.split(img_match_prefix, 1)[1].strip()
await self._do_send_img(content, to_user_id)
else:
await self._do_send(content, to_user_id)
elif room is None and msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
if not msg.is_self(): # 接收语音消息
# 下载语音文件
voice_file = await msg.to_file_box()
silk_file = TmpDir().path() + voice_file.name
await voice_file.to_file(silk_file)
logger.info("[WX]receive voice file: " + silk_file)
# 将文件转成wav格式音频
wav_file = os.path.splitext(silk_file)[0] + '.wav'
sil_to_wav(silk_file, wav_file)
# 语音识别为文本
query = super().build_voice_to_text(wav_file).content
# 校验关键字
match_prefix = self.check_prefix(query, conf().get('single_chat_prefix'))
if match_prefix is not None:
if match_prefix != '':
str_list = query.split(match_prefix, 1)
if len(str_list) == 2:
query = str_list[1].strip()
# 返回消息
if conf().get('voice_reply_voice'):
await self._do_send_voice(query, from_user_id)
else:
await self._do_send(query, from_user_id)
else:
logger.info("[WX]receive voice check prefix: " + 'False')
# 清除缓存文件
os.remove(wav_file)
os.remove(silk_file)
elif room and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
# 群组&文本消息
room_id = room.room_id
room_name = await room.topic()
from_user_id = from_contact.contact_id
from_user_name = from_contact.name
is_at = await msg.mention_self()
content = mention_content
config = conf()
match_prefix = (is_at and not config.get("group_at_off", False)) \
or self.check_prefix(content, config.get('group_chat_prefix')) \
or self.check_contain(content, config.get('group_chat_keyword'))
# Wechaty判断is_at为True,返回的内容是过滤掉@之后的内容;而is_at为False,则会返回完整的内容
# 故判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
prefixes = config.get('group_chat_prefix')
for prefix in prefixes:
if content.startswith(prefix):
content = content.replace(prefix, '', 1).strip()
break
if ('ALL_GROUP' in config.get('group_name_white_list') or room_name in config.get(
'group_name_white_list') or self.check_contain(room_name, config.get(
'group_name_keyword_white_list'))) and match_prefix:
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
if img_match_prefix:
content = content.split(img_match_prefix, 1)[1].strip()
await self._do_send_group_img(content, room_id)
else:
await self._do_send_group(content, room_id, room_name, from_user_id, from_user_name)
elif room and msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
# 群组&语音消息
room_id = room.room_id
room_name = await room.topic()
from_user_id = from_contact.contact_id
from_user_name = from_contact.name
is_at = await msg.mention_self()
config = conf()
# 是否开启语音识别、群消息响应功能、群名白名单符合等条件
if config.get('group_speech_recognition') and (
'ALL_GROUP' in config.get('group_name_white_list') or room_name in config.get(
'group_name_white_list') or self.check_contain(room_name, config.get(
'group_name_keyword_white_list'))):
# 下载语音文件
voice_file = await msg.to_file_box()
silk_file = TmpDir().path() + voice_file.name
await voice_file.to_file(silk_file)
logger.info("[WX]receive voice file: " + silk_file)
# 将文件转成wav格式音频
wav_file = os.path.splitext(silk_file)[0] + '.wav'
sil_to_wav(silk_file, wav_file)
# 语音识别为文本
query = super().build_voice_to_text(wav_file).content
# 校验关键字
match_prefix = self.check_prefix(query, config.get('group_chat_prefix')) \
or self.check_contain(query, config.get('group_chat_keyword'))
# Wechaty判断is_at为True,返回的内容是过滤掉@之后的内容;而is_at为False,则会返回完整的内容
if match_prefix is not None:
# 故判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
prefixes = config.get('group_chat_prefix')
for prefix in prefixes:
if query.startswith(prefix):
query = query.replace(prefix, '', 1).strip()
break
# 返回消息
img_match_prefix = self.check_prefix(query, conf().get('image_create_prefix'))
if img_match_prefix:
query = query.split(img_match_prefix, 1)[1].strip()
await self._do_send_group_img(query, room_id)
elif config.get('voice_reply_voice'):
await self._do_send_group_voice(query, room_id, room_name, from_user_id, from_user_name)
else:
await self._do_send_group(query, room_id, room_name, from_user_id, from_user_name)
else:
logger.info("[WX]receive voice check prefix: " + 'False')
# 清除缓存文件
os.remove(wav_file)
os.remove(silk_file)
async def send(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
if receiver:
contact = await bot.Contact.find(receiver)
await contact.say(message)
async def send_group(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
if receiver:
room = await bot.Room.find(receiver)
await room.say(message)
async def _do_send(self, query, reply_user_id):
try:
if not query:
return
context = Context(ContextType.TEXT, query)
context['session_id'] = reply_user_id
reply_text = super().build_reply_content(query, context).content
if reply_text:
await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
except Exception as e:
logger.exception(e)
async def _do_send_voice(self, query, reply_user_id):
try:
if not query:
return
context = Context(ContextType.TEXT, query)
context['session_id'] = reply_user_id
reply_text = super().build_reply_content(query, context).content
if reply_text:
# 转换 mp3 文件为 silk 格式
mp3_file = super().build_text_to_voice(reply_text).content
silk_file = os.path.splitext(mp3_file)[0] + '.sil'
voiceLength = mp3_to_sil(mp3_file, silk_file)
# 发送语音
t = int(time.time())
file_box = FileBox.from_file(silk_file, name=str(t) + '.sil')
file_box.metadata = {'voiceLength': voiceLength}
await self.send(file_box, reply_user_id)
# 清除缓存文件
os.remove(mp3_file)
os.remove(silk_file)
except Exception as e:
logger.exception(e)
async def _do_send_img(self, query, reply_user_id):
try:
if not query:
return
context = Context(ContextType.IMAGE_CREATE, query)
img_url = super().build_reply_content(query, context).content
if not img_url:
return
# 图片下载
# pic_res = requests.get(img_url, stream=True)
# image_storage = io.BytesIO()
# for block in pic_res.iter_content(1024):
# image_storage.write(block)
# image_storage.seek(0)
# 图片发送
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
t = int(time.time())
file_box = FileBox.from_url(url=img_url, name=str(t) + '.png')
await self.send(file_box, reply_user_id)
except Exception as e:
logger.exception(e)
async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name):
if not query:
return
context = Context(ContextType.TEXT, query)
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
if ('ALL_GROUP' in group_chat_in_one_session or \
group_name in group_chat_in_one_session or \
self.check_contain(group_name, group_chat_in_one_session)):
context['session_id'] = str(group_id)
else:
context['session_id'] = str(group_id) + '-' + str(group_user_id)
reply_text = super().build_reply_content(query, context).content
if reply_text:
reply_text = '@' + group_user_name + ' ' + reply_text.strip()
await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)
async def _do_send_group_voice(self, query, group_id, group_name, group_user_id, group_user_name):
if not query:
return
context = Context(ContextType.TEXT, query)
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
if ('ALL_GROUP' in group_chat_in_one_session or \
group_name in group_chat_in_one_session or \
self.check_contain(group_name, group_chat_in_one_session)):
context['session_id'] = str(group_id)
else:
context['session_id'] = str(group_id) + '-' + str(group_user_id)
reply_text = super().build_reply_content(query, context).content
if reply_text:
reply_text = '@' + group_user_name + ' ' + reply_text.strip()
# 转换 mp3 文件为 silk 格式
mp3_file = super().build_text_to_voice(reply_text).content
silk_file = os.path.splitext(mp3_file)[0] + '.sil'
voiceLength = mp3_to_sil(mp3_file, silk_file)
# 发送语音
t = int(time.time())
file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
file_box.metadata = {'voiceLength': voiceLength}
await self.send_group(file_box, group_id)
# 清除缓存文件
os.remove(mp3_file)
os.remove(silk_file)
async def _do_send_group_img(self, query, reply_room_id):
try:
if not query:
return
context = Context(ContextType.IMAGE_CREATE, query)
img_url = super().build_reply_content(query, context).content
if not img_url:
return
# 图片发送
logger.info('[WX] sendImage, receiver={}'.format(reply_room_id))
t = int(time.time())
file_box = FileBox.from_url(url=img_url, name=str(t) + '.png')
await self.send_group(file_box, reply_room_id)
except Exception as e:
logger.exception(e)
def check_prefix(self, content, prefix_list):
for prefix in prefix_list:
if content.startswith(prefix):
return prefix
return None
def check_contain(self, content, keyword_list):
if not keyword_list:
return None
for ky in keyword_list:
if content.find(ky) != -1:
return True
return None
isgroup = room is not None
ctype = cmsg.ctype
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
if context:
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context))
self.produce(context)
+85
View File
@@ -0,0 +1,85 @@
import asyncio
import re
from wechaty import MessageType
from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.tmp_dir import TmpDir
from common.log import logger
from wechaty.user import Message
class aobject(object):
    """Base class enabling an asynchronous ``__init__``.

    Overriding ``__new__`` as a coroutine means instantiation itself must be
    awaited: ``instance = await MyClass(params)``. Awaiting runs the async
    ``__init__`` before the instance is handed back.
    """
    async def __new__(cls, *a, **kw):
        instance = super().__new__(cls)
        # Python does not call __init__ automatically when __new__ returns a
        # coroutine, so await the (async) initializer explicitly here.
        await instance.__init__(*a, **kw)
        return instance

    async def __init__(self):
        pass
class WechatyMessage(ChatMessage, aobject):
    """ChatMessage adapter for wechaty messages (constructed with ``await``).

    Fix: the user's own nickname was interpolated unescaped into a regex
    (``f'@{name}(...)'``); nicknames containing regex metacharacters such as
    ``+`` or ``(`` would raise or mis-match. It is now quoted with re.escape.
    """

    async def __init__(self, wechaty_msg: Message):
        super().__init__(wechaty_msg)
        room = wechaty_msg.room()

        self.msg_id = wechaty_msg.message_id
        self.create_time = wechaty_msg.payload.timestamp
        self.is_group = room is not None

        if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
            self.ctype = ContextType.TEXT
            self.content = wechaty_msg.text()
        elif wechaty_msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
            self.ctype = ContextType.VOICE
            voice_file = await wechaty_msg.to_file_box()
            self.content = TmpDir().path() + voice_file.name  # content holds the temp-dir path

            def func():
                # Called from a worker thread; bridge the download coroutine
                # into the running asyncio loop.
                loop = asyncio.get_event_loop()
                asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
            self._prepare_fn = func
        else:
            raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))

        from_contact = wechaty_msg.talker()  # actual sender of the message
        self.from_user_id = from_contact.contact_id
        self.from_user_nickname = from_contact.name

        # For group messages, wechaty and itchat assign from/to differently:
        # wechaty: from is the actual sender, to is the room;
        # itchat: your own group message has from=you/to=room, while someone
        # else's has from=room/to=you.
        # The difference doesn't matter here: groups only use (1) from to
        # detect self-sent messages and (2) actual_user_id for the real sender.
        if self.is_group:
            self.to_user_id = room.room_id
            self.to_user_nickname = await room.topic()
        else:
            to_contact = wechaty_msg.to()
            self.to_user_id = to_contact.contact_id
            self.to_user_nickname = to_contact.name

        # other_user is the room for group chats; for a self-sent private
        # message it is the peer (the recipient).
        if self.is_group or wechaty_msg.is_self():
            self.other_user_id = self.to_user_id
            self.other_user_nickname = self.to_user_nickname
        else:
            self.other_user_id = self.from_user_id
            self.other_user_nickname = self.from_user_nickname

        if self.is_group:  # in wechaty group chats the actual sender is from_user
            self.is_at = await wechaty_msg.mention_self()
            if not self.is_at:
                # Copy-pasted messages may contain a literal @nickname without
                # counting as a mention; treat those as an @ as well.
                name = wechaty_msg.wechaty.user_self().name
                pattern = f'@{re.escape(name)}(\u2005|\u0020)'
                if re.search(pattern, self.content):
                    logger.debug(f'wechaty message {self.msg_id} include at')
                    self.is_at = True
            self.actual_user_id = self.from_user_id
            self.actual_user_nickname = self.from_user_nickname
+33
View File
@@ -0,0 +1,33 @@
from queue import Full, Queue
from time import monotonic as time
# Queue subclass that additionally supports inserting at the front.
class Dequeue(Queue):
    """A Queue with a putleft operation (front insertion).

    putleft has the same blocking/timeout/Full semantics as Queue.put; the
    only difference is that the item is prepended, so it becomes the next
    one returned by get().
    """

    def putleft(self, item, block=True, timeout=None):
        """Insert item at the front of the queue.

        Raises Full when the queue stays at capacity, ValueError on a
        negative timeout — exactly like Queue.put.
        """
        with self.not_full:
            if self.maxsize > 0:
                self._wait_for_free_slot(block, timeout)
            self._putleft(item)
            self.unfinished_tasks += 1
            self.not_empty.notify()

    def _wait_for_free_slot(self, block, timeout):
        # Mirrors the capacity handling of Queue.put; caller holds not_full.
        if not block:
            if self._qsize() >= self.maxsize:
                raise Full
        elif timeout is None:
            while self._qsize() >= self.maxsize:
                self.not_full.wait()
        elif timeout < 0:
            raise ValueError("'timeout' must be a non-negative number")
        else:
            deadline = time() + timeout
            while self._qsize() >= self.maxsize:
                remaining = deadline - time()
                if remaining <= 0.0:
                    raise Full
                self.not_full.wait(remaining)

    def putleft_nowait(self, item):
        """Front-insert without blocking; raise Full if no slot is free."""
        return self.putleft(item, block=False)

    def _putleft(self, item):
        # self.queue is a collections.deque, so appendleft is O(1).
        self.queue.appendleft(item)
+4
View File
@@ -8,6 +8,10 @@ def _get_logger():
console_handle = logging.StreamHandler(sys.stdout)
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'))
file_handle = logging.FileHandler('run.log', encoding='utf-8')
file_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'))
log.addHandler(file_handle)
log.addHandler(console_handle)
return log
+24 -6
View File
@@ -1,6 +1,7 @@
# encoding:utf-8
import json
import logging
import os
from common.log import logger
@@ -25,7 +26,9 @@ available_setting = {
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
"trigger_by_self": False, # 是否允许机器人触发
"image_create_prefix": ["", "", ""], # 开启图片回复的前缀
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
# chatgpt会话参数
"expires_in_seconds": 3600, # 无操作会话的过期时间
@@ -36,27 +39,32 @@ available_setting = {
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
"rate_limit_dalle": 50, # openai dalle的调用频率限制
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
"temperature": 0.9,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"request_timeout": 30, # chatgpt请求超时时间
# 语音设置
"speech_recognition": False, # 是否开启语音识别
"group_speech_recognition": False, # 是否开启群组语音识别
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
"voice_to_text": "openai", # 语音识别引擎,支持openai,google
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline)
"always_reply_voice": False, # 是否一直使用语音回复
"voice_to_text": "openai", # 语音识别引擎,支持openai,google,azure
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure
# baidu api配置, 使用百度语音识别和语音合成时需要
# baidu 语音api配置, 使用百度语音识别和语音合成时需要
"baidu_app_id": "",
"baidu_api_key": "",
"baidu_secret_key": "",
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
"baidu_dev_pid": "1536",
# azure 语音api配置, 使用azure语音识别和语音合成时需要
"azure_voice_api_key": "",
"azure_voice_region": "japaneast",
# 服务时间限制,目前支持itchat
"chat_time_module": False, # 是否开启服务时间限制
"chat_start_time": "00:00", # 服务开始时间
@@ -69,11 +77,12 @@ available_setting = {
"wechaty_puppet_service_token": "", # wechaty的token
# chatgpt指令自定义触发词
"clear_memory_commands": ['#清除记忆'], # 重置会话指令
"clear_memory_commands": ['#清除记忆'], # 重置会话指令,必须以#开头
# channel配置
"channel_type": "wx", # 通道类型,支持wx,wxy和terminal
"debug": False, # 是否开启debug模式,开启后会打印更多日志
}
@@ -124,7 +133,16 @@ def load_config():
try:
config[name] = eval(value)
except:
config[name] = value
if value == "false":
config[name] = False
elif value == "true":
config[name] = True
else:
config[name] = value
if config.get("debug", False):
logger.setLevel(logging.DEBUG)
logger.debug("[INIT] set log level to DEBUG")
logger.info("[INIT] load config: {}".format(config))
+32
View File
@@ -0,0 +1,32 @@
# Build on the slim Python image to keep the final image small.
FROM python:3.10-slim

LABEL maintainer="foo@bar.com"
ARG TZ='Asia/Shanghai'
ARG CHATGPT_ON_WECHAT_VER

ENV BUILD_PREFIX=/app

# Copy the whole project into the image.
ADD . ${BUILD_PREFIX}

# ffmpeg is needed for audio conversion (pydub), espeak for offline TTS (pyttsx3);
# a default config.json is created from the template so the container starts out of the box.
RUN apt-get update \
    &&apt-get install -y --no-install-recommends bash \
    ffmpeg espeak \
    && cd ${BUILD_PREFIX} \
    && cp config-template.json config.json \
    && /usr/local/bin/python -m pip install --no-cache --upgrade pip \
    && pip install --no-cache -r requirements.txt \
    && pip install azure-cognitiveservices-speech

WORKDIR ${BUILD_PREFIX}

ADD docker/entrypoint.sh /entrypoint.sh

# Run as an unprivileged user; the app directory is chowned so config.json stays writable.
RUN chmod +x /entrypoint.sh \
    && groupadd -r noroot \
    && useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
    && chown -R noroot:noroot ${BUILD_PREFIX}

USER noroot

ENTRYPOINT ["docker/entrypoint.sh"]
+1 -1
View File
@@ -1,4 +1,4 @@
#!/bin/bash
cd .. && docker build -f Dockerfile \
cd .. && docker build -f docker/Dockerfile.latest \
-t zhayujie/chatgpt-on-wechat .
+10 -7
View File
@@ -1,6 +1,6 @@
## 插件化初衷
之前未插件化的代码耦合程度高,如果要定制一些个性化功能(如流量控制、接入`NovelAI`画图平台等),需要了解代码主体,避免影响到其他的功能。在实现多个功能后,不但无法调整功能的优先级顺序,功能配置项也会变得非常混乱。
之前未插件化的代码耦合程度高,如果要定制一些个性化功能(如流量控制、接入`NovelAI`画图平台等),需要了解代码主体,避免影响到其他的功能。多个功能同时存在时,无法调整功能的优先级顺序,功能配置项也非常混乱。
此时插件化应运而生。
@@ -11,7 +11,7 @@
- [x] 插件化能够自由开关和调整优先级。
- [x] 每个插件可在插件文件夹内维护独立的配置文件,方便代码的测试和调试,可以在独立的仓库开发插件。
PS: 插件目前支持`itchat`
PS: 插件目前支持`itchat``wechaty`
## 插件化实现
@@ -101,7 +101,7 @@ PS: 插件目前仅支持`itchat`
根据`Context`和回复`Reply`的类型,对回复的内容进行装饰。目前的装饰有以下两种:
- `TEXT`文本回复根据是否在群聊中来决定是艾特接收方还是添加回复的前缀。
- `TEXT`文本回复:如果这次消息需要的回复是`VOICE`,进行文字转语音回复之后再次装饰。 否则根据是否在群聊中来决定是艾特接收方还是添加回复的前缀。
- `INFO``ERROR`类型,会在消息前添加对应的系统提示字样。
@@ -110,8 +110,11 @@ PS: 插件目前仅支持`itchat`
```python
if reply.type == ReplyType.TEXT:
reply_text = reply.content
if context.get('desire_rtype') == ReplyType.VOICE:
reply = super().build_text_to_voice(reply.content)
return self._decorate_reply(context, reply)
if context['isgroup']:
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
else:
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
@@ -213,11 +216,11 @@ class Hello(Plugin):
if content == "Hello":
reply = Reply()
reply.type = ReplyType.TEXT
msg = e_context['context']['msg']
msg:ChatMessage = e_context['context']['msg']
if e_context['context']['isgroup']:
reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group")
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
else:
reply.content = "Hello, " + msg['User'].get('NickName', "My friend")
reply.content = f"Hello, {msg.from_user_nickname}"
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
if content == "End":
+12 -2
View File
@@ -7,7 +7,7 @@ from typing import Tuple
from bridge.bridge import Bridge
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import load_config
from config import conf, load_config
import plugins
from plugins import *
from common import const
@@ -126,7 +126,14 @@ class Godcmd(Plugin):
else:
with open(config_path,"r") as f:
gconf=json.load(f)
custom_commands = conf().get("clear_memory_commands", [])
for custom_command in custom_commands:
if custom_command and custom_command.startswith("#"):
custom_command = custom_command[1:]
if custom_command and custom_command not in COMMANDS["reset"]["alias"]:
COMMANDS["reset"]["alias"].append(custom_command)
self.password = gconf["password"]
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用
self.isrunning = True # 机器人是否运行中
@@ -146,6 +153,7 @@ class Godcmd(Plugin):
logger.debug("[Godcmd] on_handle_context. content: %s" % content)
if content.startswith("#"):
# msg = e_context['context']['msg']
channel = e_context['channel']
user = e_context['context']['receiver']
session_id = e_context['context']['session_id']
isgroup = e_context['context']['isgroup']
@@ -181,6 +189,7 @@ class Godcmd(Plugin):
elif cmd == "reset":
if bottype in (const.CHATGPT, const.OPEN_AI):
bot.sessions.clear_session(session_id)
channel.cancel_session(session_id)
ok, result = True, "会话已重置"
else:
ok, result = False, "当前对话机器人不支持重置会话"
@@ -202,6 +211,7 @@ class Godcmd(Plugin):
ok, result = True, "配置已重载"
elif cmd == "resetall":
if bottype in (const.CHATGPT, const.OPEN_AI):
channel.cancel_all_session()
bot.sessions.clear_all_session()
ok, result = True, "重置所有会话成功"
else:
+4 -3
View File
@@ -2,6 +2,7 @@
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_message import ChatMessage
import plugins
from plugins import *
from common.log import logger
@@ -24,11 +25,11 @@ class Hello(Plugin):
if content == "Hello":
reply = Reply()
reply.type = ReplyType.TEXT
msg = e_context['context']['msg']
msg:ChatMessage = e_context['context']['msg']
if e_context['context']['isgroup']:
reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group")
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
else:
reply.content = "Hello, " + msg['User'].get('NickName', "My friend")
reply.content = f"Hello, {msg.from_user_nickname}"
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
+1
View File
@@ -7,6 +7,7 @@ PyQRCode>=1.2.1
pysilk>=0.0.1
pysilk_mod>=1.6.0
pyttsx3>=2.90
qrcode>=7.4.2
requests>=2.28.2
webuiapi>=0.6.2
wechaty>=0.10.7
+53 -9
View File
@@ -1,7 +1,23 @@
import shutil
import wave
import pysilk
from pydub import AudioSegment
# Sample rates supported when converting between silk and wav.
sil_supports = [8000, 12000, 16000, 24000, 32000, 44100, 48000]


def find_closest_sil_supports(sample_rate):
    """Return the supported silk sample rate closest to ``sample_rate``.

    If ``sample_rate`` is already supported it is returned unchanged.
    Otherwise the entry of ``sil_supports`` with the smallest absolute
    difference is returned; on a tie the lower rate wins, matching the
    list order (same behavior as the previous manual search loop).
    """
    if sample_rate in sil_supports:
        return sample_rate
    # min() with an abs-difference key replaces the hand-rolled
    # closest-value loop and its magic "mindiff = 9999999" sentinel.
    return min(sil_supports, key=lambda rate: abs(rate - sample_rate))
def get_pcm_from_wav(wav_path):
"""
@@ -13,6 +29,30 @@ def get_pcm_from_wav(wav_path):
wav = wave.open(wav_path, "rb")
return wav.readframes(wav.getnframes())
def any_to_wav(any_path, wav_path):
    """Convert an audio file of any supported format to a wav file.

    A ``.wav`` input is copied as-is; silk input (``.sil``/``.silk``/``.slk``)
    is decoded via ``sil_to_wav`` (whose return value, if any, is forwarded);
    anything else is handed to pydub, which dispatches on the extension.
    """
    if any_path.endswith('.wav'):
        shutil.copy2(any_path, wav_path)
        return
    # endswith() accepts a tuple: one call instead of three chained checks.
    if any_path.endswith(('.sil', '.silk', '.slk')):
        return sil_to_wav(any_path, wav_path)
    audio = AudioSegment.from_file(any_path)
    audio.export(wav_path, format="wav")
def any_to_sil(any_path, sil_path):
    """Convert an audio file of any supported format to a silk file.

    Returns the audio duration in milliseconds.  A silk input is copied
    as-is; its real duration is unknown here, so a fixed 10000 ms is
    returned (pre-existing behavior, kept for caller compatibility).

    Raises:
        NotImplementedError: for an unsupported file extension.
    """
    # endswith() accepts a tuple: one call instead of three chained checks.
    if any_path.endswith(('.sil', '.silk', '.slk')):
        shutil.copy2(any_path, sil_path)
        return 10000
    if any_path.endswith('.wav'):
        return pcm_to_sil(any_path, sil_path)
    if any_path.endswith('.mp3'):
        return mp3_to_sil(any_path, sil_path)
    raise NotImplementedError("Not support file type: {}".format(any_path))
def mp3_to_wav(mp3_path, wav_path):
    """Decode an mp3 file and write it back out as a wav file."""
    sound = AudioSegment.from_mp3(mp3_path)
    sound.export(wav_path, format="wav")
def pcm_to_sil(pcm_path, silk_path):
    """Convert a wav file to a silk file.

    Returns the audio duration in milliseconds.

    The visible span held a merge/diff artifact: the old ``pcm_to_silk``
    def line and stale statements (``wav_data = audio.raw_data`` and an
    encode at the raw frame rate) were interleaved with the new version.
    This is the reconstructed new version only.
    """
    audio = AudioSegment.from_wav(pcm_path)
    # silk supports only specific sample rates; pick the nearest one.
    rate = find_closest_sil_supports(audio.frame_rate)
    # Convert to 16-bit PCM at the chosen rate before encoding.
    pcm_s16 = audio.set_sample_width(2)
    pcm_s16 = pcm_s16.set_frame_rate(rate)
    wav_data = pcm_s16.raw_data
    silk_data = pysilk.encode(
        wav_data, data_rate=rate, sample_rate=rate)
    with open(silk_path, "wb") as f:
        f.write(silk_data)
    return audio.duration_seconds * 1000
def mp3_to_sil(mp3_path, silk_path):
    """Convert an mp3 file to a silk file.

    Returns the audio duration in milliseconds.

    The visible span held a merge/diff artifact: stale statements that
    encoded raw data at the unadjusted frame rate were interleaved with
    the new version.  This is the reconstructed new version only.
    """
    audio = AudioSegment.from_mp3(mp3_path)
    # silk supports only specific sample rates; pick the nearest one.
    rate = find_closest_sil_supports(audio.frame_rate)
    # Convert to 16-bit PCM at the chosen rate before encoding.
    pcm_s16 = audio.set_sample_width(2)
    pcm_s16 = pcm_s16.set_frame_rate(rate)
    wav_data = pcm_s16.raw_data
    silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate)
    # Save the silk file
    with open(silk_path, "wb") as f:
        f.write(silk_data)
    return audio.duration_seconds * 1000
def sil_to_wav(silk_path, wav_path, rate: int = 24000):
"""
silk 文件转 wav
+68
View File
@@ -0,0 +1,68 @@
"""
azure voice service
"""
import json
import os
import time
import azure.cognitiveservices.speech as speechsdk
from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
from voice.voice import Voice
from config import conf
"""
Azure voice
主目录设置文件中需填写azure_voice_api_key和azure_voice_region
查看可用的 voice https://speech.microsoft.com/portal/voicegallery
"""
class AzureVoice(Voice):
    """Voice engine backed by the Azure Cognitive Services speech SDK.

    Reads ``azure_voice_api_key`` and ``azure_voice_region`` from the main
    config; the synthesis voice and recognition language come from a local
    ``config.json`` next to this module (created with defaults on first run).
    """

    def __init__(self):
        try:
            curdir = os.path.dirname(__file__)
            config_path = os.path.join(curdir, "config.json")
            config = None
            if not os.path.exists(config_path):
                # No local config yet: seed one with defaults so the user
                # has a file to edit.
                config = { "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", "speech_recognition_language": "zh-CN"}
                with open(config_path, "w") as fw:
                    json.dump(config, fw, indent=4)
            else:
                with open(config_path, "r") as fr:
                    config = json.load(fr)
            self.api_key = conf().get('azure_voice_api_key')
            self.api_region = conf().get('azure_voice_region')
            self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
            self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
            self.speech_config.speech_recognition_language = config["speech_recognition_language"]
        except Exception as e:
            # Best-effort init: a missing SDK/key should not crash startup.
            # logging.Logger.warn is a deprecated alias of warning().
            logger.warning("AzureVoice init failed: %s, ignore " % e)

    def voiceToText(self, voice_file):
        """Transcribe a local audio file; return a TEXT Reply on success,
        an ERROR Reply otherwise."""
        audio_config = speechsdk.AudioConfig(filename=voice_file)
        speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
        result = speech_recognizer.recognize_once()
        if result.reason == speechsdk.ResultReason.RecognizedSpeech:
            logger.info('[Azure] voiceToText voice file name={} text={}'.format(voice_file, result.text))
            reply = Reply(ReplyType.TEXT, result.text)
        else:
            logger.error('[Azure] voiceToText error, result={}'.format(result))
            reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
        return reply

    def textToVoice(self, text):
        """Synthesize ``text`` into a temporary wav file; return a VOICE
        Reply carrying the file path, or an ERROR Reply on failure."""
        fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
        audio_config = speechsdk.AudioConfig(filename=fileName)
        speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
        result = speech_synthesizer.speak_text(text)
        if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
            logger.info(
                '[Azure] textToVoice text={} voice file name={}'.format(text, fileName))
            reply = Reply(ReplyType.VOICE, fileName)
        else:
            logger.error('[Azure] textToVoice error, result={}'.format(result))
            reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
        return reply
+4
View File
@@ -0,0 +1,4 @@
{
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
"speech_recognition_language": "zh-CN"
}
+1 -1
View File
@@ -80,7 +80,7 @@ class BaiduVoice(Voice):
result = self.client.synthesis(text, self.lang, self.ctp, {
'spd': self.spd, 'pit': self.pit, 'vol': self.vol, 'per': self.per})
if not isinstance(result, dict):
fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
with open(fileName, 'wb') as f:
f.write(result)
logger.info(
+1 -1
View File
@@ -34,7 +34,7 @@ class GoogleVoice(Voice):
return reply
def textToVoice(self, text):
try:
mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
mp3File = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
tts = gTTS(text=text, lang='zh')
tts.save(mp3File)
logger.info(
+4 -4
View File
@@ -25,12 +25,12 @@ class PyttsVoice(Voice):
def textToVoice(self, text):
try:
mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
self.engine.save_to_file(text, mp3File)
wavFile = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
self.engine.save_to_file(text, wavFile)
self.engine.runAndWait()
logger.info(
'[Pytts] textToVoice text={} voice file name={}'.format(text, mp3File))
reply = Reply(ReplyType.VOICE, mp3File)
'[Pytts] textToVoice text={} voice file name={}'.format(text, wavFile))
reply = Reply(ReplyType.VOICE, wavFile)
except Exception as e:
reply = Reply(ReplyType.ERROR, str(e))
finally:
+3
View File
@@ -20,4 +20,7 @@ def create_voice(voice_type):
elif voice_type == 'pytts':
from voice.pytts.pytts_voice import PyttsVoice
return PyttsVoice()
elif voice_type == 'azure':
from voice.azure.azure_voice import AzureVoice
return AzureVoice()
raise RuntimeError