mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-05-11 22:35:32 +08:00
Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 186e18fe94 | |||
| 28eb67bc24 | |||
| 6c7e4aaf37 | |||
| 709a1317ef | |||
| 371e38cfa6 | |||
| 5a221848e9 | |||
| 7458a6298f | |||
| b0f54bb8b7 | |||
| acddadc406 | |||
| b74274b96b | |||
| 49ba278316 | |||
| 388058467c | |||
| cf25bd7869 | |||
| 02a95345aa | |||
| 6076e2ed0a | |||
| cec674cb47 | |||
| c5a90823fa | |||
| 18d82bc1f0 | |||
| a68af990ea | |||
| e71c600d10 | |||
| d7f1f7182c | |||
| dfb2e460b4 | |||
| 5badef8ba9 | |||
| 18aa5ce75c | |||
| 1545a9f262 |
@@ -10,3 +10,4 @@ nohup.out
|
||||
tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
@@ -90,6 +90,13 @@ pip3 install -r requirements.txt
|
||||
|
||||
参考[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)
|
||||
|
||||
使用`azure`语音功能需安装依赖:
|
||||
```bash
|
||||
pip3 install azure-cognitiveservices-speech
|
||||
```
|
||||
> 目前默认发布的镜像和`railway`部署,都基于`apline`,无法安装`azure`的依赖。若有需求请自行基于[`debian`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/docker/Dockerfile.debian.latest)打包。
|
||||
参考[文档](https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/quickstarts/setup-platform?pivots=programming-language-python&tabs=linux%2Cubuntu%2Cdotnet%2Cjre%2Cmaven%2Cnodejs%2Cmac%2Cpypi)
|
||||
|
||||
## 配置
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import os
|
||||
from config import conf, load_config
|
||||
from channel import channel_factory
|
||||
from common.log import logger
|
||||
@@ -13,6 +14,10 @@ def run():
|
||||
|
||||
# create channel
|
||||
channel_name=conf().get('channel_type', 'wx')
|
||||
if channel_name == 'wxy':
|
||||
os.environ['WECHATY_LOG']="warn"
|
||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
||||
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ['wx','wxy']:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
@@ -86,6 +86,7 @@ class ChatGPTBot(Bot,OpenAIImage):
|
||||
"top_p":1,
|
||||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get('request_timeout', 30), # 请求超时时间
|
||||
}
|
||||
|
||||
def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict:
|
||||
|
||||
+103
-16
@@ -1,10 +1,12 @@
|
||||
|
||||
|
||||
|
||||
from asyncio import CancelledError
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.dequeue import Dequeue
|
||||
from channel.channel import Channel
|
||||
from bridge.reply import *
|
||||
from bridge.context import *
|
||||
@@ -20,8 +22,16 @@ except Exception as e:
|
||||
class ChatChannel(Channel):
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
|
||||
|
||||
# 根据消息构造context,消息内容相关的触发项写在这里
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
@@ -38,7 +48,7 @@ class ChatChannel(Channel):
|
||||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
|
||||
config = conf()
|
||||
cmsg = context['msg']
|
||||
if cmsg.from_user_id == self.user_id:
|
||||
if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', False):
|
||||
logger.debug("[WX]self message skipped")
|
||||
return None
|
||||
if context["isgroup"]:
|
||||
@@ -70,17 +80,21 @@ class ChatChannel(Channel):
|
||||
# 校验关键字
|
||||
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
|
||||
match_contain = check_contain(content, conf().get('group_chat_keyword'))
|
||||
flag = False
|
||||
if match_prefix is not None or match_contain is not None:
|
||||
flag = True
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
elif context['msg'].is_at and not conf().get("group_at_off", False):
|
||||
logger.info("[WX]receive group at, continue")
|
||||
if context['msg'].is_at:
|
||||
logger.info("[WX]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
flag = True
|
||||
pattern = f'@{self.name}(\u2005|\u0020)'
|
||||
content = re.sub(pattern, r'', content)
|
||||
elif context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[WX]receive group voice, checkprefix didn't match")
|
||||
return None
|
||||
else:
|
||||
|
||||
if not flag:
|
||||
if context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
||||
return None
|
||||
else: # 单聊
|
||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
|
||||
@@ -98,14 +112,14 @@ class ChatChannel(Channel):
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content
|
||||
if 'desire_rtype' not in context and conf().get('always_reply_voice'):
|
||||
context['desire_rtype'] = ReplyType.VOICE
|
||||
elif context.type == ContextType.VOICE:
|
||||
if 'desire_rtype' not in context and conf().get('voice_reply_voice'):
|
||||
context['desire_rtype'] = ReplyType.VOICE
|
||||
|
||||
|
||||
return context
|
||||
|
||||
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
|
||||
def _handle(self, context: Context):
|
||||
if context is None or not context.content:
|
||||
return
|
||||
@@ -143,9 +157,11 @@ class ChatChannel(Channel):
|
||||
# 删除临时文件
|
||||
try:
|
||||
os.remove(file_path)
|
||||
os.remove(wav_path)
|
||||
if wav_path != file_path:
|
||||
os.remove(wav_path)
|
||||
except Exception as e:
|
||||
logger.warning("[WX]delete temp file error: " + str(e))
|
||||
pass
|
||||
# logger.warning("[WX]delete temp file error: " + str(e))
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
new_context = self._compose_context(
|
||||
@@ -194,18 +210,89 @@ class ChatChannel(Channel):
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context))
|
||||
logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context))
|
||||
self._send(reply, context)
|
||||
|
||||
def _send(self, reply: Reply, context: Context, retry_cnt = 0):
|
||||
try:
|
||||
self.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error('[WX] sendMsg error: {}'.format(e))
|
||||
logger.error('[WX] sendMsg error: {}'.format(str(e)))
|
||||
if isinstance(e, NotImplementedError):
|
||||
return
|
||||
logger.exception(e)
|
||||
if retry_cnt < 2:
|
||||
time.sleep(3+3*retry_cnt)
|
||||
self._send(reply, context, retry_cnt+1)
|
||||
|
||||
def thread_pool_callback(self, session_id):
|
||||
def func(worker:Future):
|
||||
try:
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
logger.exception("Worker return exception: {}".format(worker_exception))
|
||||
except CancelledError as e:
|
||||
logger.info("Worker cancelled, session_id = {}".format(session_id))
|
||||
except Exception as e:
|
||||
logger.exception("Worker raise exception: {}".format(e))
|
||||
with self.lock:
|
||||
self.sessions[session_id][1].release()
|
||||
return func
|
||||
|
||||
def produce(self, context: Context):
|
||||
session_id = context['session_id']
|
||||
with self.lock:
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1))]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
else:
|
||||
self.sessions[session_id][0].put(context)
|
||||
|
||||
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
||||
def consume(self):
|
||||
while True:
|
||||
with self.lock:
|
||||
session_ids = list(self.sessions.keys())
|
||||
for session_id in session_ids:
|
||||
context_queue, semaphore = self.sessions[session_id]
|
||||
if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[WX] consume context: {}".format(context))
|
||||
future:Future = self.handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self.thread_pool_callback(session_id))
|
||||
if session_id not in self.futures:
|
||||
self.futures[session_id] = []
|
||||
self.futures[session_id].append(future)
|
||||
elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
||||
assert len(self.futures[session_id]) == 0, "thread pool error"
|
||||
del self.sessions[session_id]
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.1)
|
||||
|
||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
||||
def cancel_session(self, session_id):
|
||||
with self.lock:
|
||||
if session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt>0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
def cancel_all_session(self):
|
||||
with self.lock:
|
||||
for session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt>0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
|
||||
@@ -5,6 +5,7 @@ wechat channel
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
import requests
|
||||
import io
|
||||
import time
|
||||
@@ -17,18 +18,10 @@ from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
from bridge.reply import *
|
||||
from bridge.context import *
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from config import conf
|
||||
from common.time_check import time_checker
|
||||
from common.expired_dict import ExpiredDict
|
||||
from plugins import *
|
||||
thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||
|
||||
def thread_pool_callback(worker):
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
logger.exception("Worker return exception: {}".format(worker_exception))
|
||||
|
||||
|
||||
@itchat.msg_register(TEXT)
|
||||
def handler_single_msg(msg):
|
||||
@@ -64,6 +57,39 @@ def _check(func):
|
||||
return func(self, cmsg)
|
||||
return wrapper
|
||||
|
||||
#可用的二维码生成接口
|
||||
#https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
|
||||
#https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
|
||||
def qrCallback(uuid,status,qrcode):
|
||||
# logger.debug("qrCallback: {} {}".format(uuid,status))
|
||||
if status == '0':
|
||||
try:
|
||||
from PIL import Image
|
||||
img = Image.open(io.BytesIO(qrcode))
|
||||
_thread = threading.Thread(target=img.show, args=("QRCode",))
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
import qrcode
|
||||
url = f"https://login.weixin.qq.com/l/{uuid}"
|
||||
|
||||
qr_api1="https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
|
||||
qr_api2="https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
|
||||
qr_api3="https://api.pwmqr.com/qrcode/create/?url={}".format(url)
|
||||
qr_api4="https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
|
||||
print("You can also scan QRCode in any website below:")
|
||||
print(qr_api3)
|
||||
print(qr_api4)
|
||||
print(qr_api2)
|
||||
print(qr_api1)
|
||||
|
||||
qr = qrcode.QRCode(border=1)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
|
||||
@singleton
|
||||
class WechatChannel(ChatChannel):
|
||||
def __init__(self):
|
||||
@@ -76,13 +102,13 @@ class WechatChannel(ChatChannel):
|
||||
# login by scan QRCode
|
||||
hotReload = conf().get('hot_reload', False)
|
||||
try:
|
||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
|
||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
|
||||
except Exception as e:
|
||||
if hotReload:
|
||||
logger.error("Hot reload failed, try to login without hot reload")
|
||||
itchat.logout()
|
||||
os.remove("itchat.pkl")
|
||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
|
||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload, qrCallback=qrCallback)
|
||||
else:
|
||||
raise e
|
||||
self.user_id = itchat.instance.storageClass.userName
|
||||
@@ -111,7 +137,7 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
@@ -119,7 +145,7 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
@@ -127,7 +153,7 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
@@ -137,7 +163,7 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
self.produce(context)
|
||||
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
|
||||
@@ -5,7 +5,6 @@ wechaty channel
|
||||
Python Wechaty - https://github.com/wechaty/python-wechaty
|
||||
"""
|
||||
import base64
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
@@ -18,31 +17,30 @@ from bridge.context import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechat.wechaty_message import WechatyMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
try:
|
||||
from voice.audio_convert import mp3_to_sil
|
||||
from voice.audio_convert import any_to_sil
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||
def thread_pool_callback(worker):
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
logger.exception("Worker return exception: {}".format(worker_exception))
|
||||
@singleton
|
||||
class WechatyChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
super().__init__()
|
||||
|
||||
def startup(self):
|
||||
asyncio.run(self.main())
|
||||
|
||||
async def main(self):
|
||||
config = conf()
|
||||
token = config.get('wechaty_puppet_service_token')
|
||||
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token
|
||||
os.environ['WECHATY_LOG']="warn"
|
||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
||||
asyncio.run(self.main())
|
||||
|
||||
async def main(self):
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
#将asyncio的loop传入处理线程
|
||||
self.handler_pool._initializer= lambda: asyncio.set_event_loop(loop)
|
||||
self.bot = Wechaty()
|
||||
self.bot.on('login', self.on_login)
|
||||
self.bot.on('message', self.on_message)
|
||||
@@ -72,18 +70,12 @@ class WechatyChannel(ChatChannel):
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
voiceLength = None
|
||||
if reply.content.endswith('.mp3'):
|
||||
mp3_file = reply.content
|
||||
sil_file = os.path.splitext(mp3_file)[0] + '.sil'
|
||||
voiceLength = mp3_to_sil(mp3_file, sil_file)
|
||||
try:
|
||||
os.remove(mp3_file)
|
||||
except Exception as e:
|
||||
pass
|
||||
elif reply.content.endswith('.sil'):
|
||||
sil_file = reply.content
|
||||
else:
|
||||
raise Exception('voice file must be mp3 or sil format')
|
||||
file_path = reply.content
|
||||
sil_file = os.path.splitext(file_path)[0] + '.sil'
|
||||
voiceLength = int(any_to_sil(file_path, sil_file))
|
||||
if voiceLength >= 60000:
|
||||
voiceLength = 60000
|
||||
logger.info('[WX] voice too long, length={}, set to 60s'.format(voiceLength))
|
||||
# 发送语音
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_file(sil_file, name=str(t) + '.sil')
|
||||
@@ -91,7 +83,9 @@ class WechatyChannel(ChatChannel):
|
||||
msg.metadata['voiceLength'] = voiceLength
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
|
||||
try:
|
||||
os.remove(sil_file)
|
||||
os.remove(file_path)
|
||||
if sil_file != file_path:
|
||||
os.remove(sil_file)
|
||||
except Exception as e:
|
||||
pass
|
||||
logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver))
|
||||
@@ -123,14 +117,9 @@ class WechatyChannel(ChatChannel):
|
||||
return
|
||||
logger.debug('[WX] message:{}'.format(cmsg))
|
||||
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
|
||||
|
||||
isgroup = room is not None
|
||||
ctype = cmsg.ctype
|
||||
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
|
||||
if context:
|
||||
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context))
|
||||
thread_pool.submit(self._handle_loop, context, asyncio.get_event_loop()).add_done_callback(thread_pool_callback)
|
||||
|
||||
def _handle_loop(self,context,loop):
|
||||
asyncio.set_event_loop(loop)
|
||||
self._handle(context)
|
||||
self.produce(context)
|
||||
@@ -0,0 +1,33 @@
|
||||
|
||||
from queue import Full, Queue
|
||||
from time import monotonic as time
|
||||
|
||||
# add implementation of putleft to Queue
|
||||
class Dequeue(Queue):
|
||||
def putleft(self, item, block=True, timeout=None):
|
||||
with self.not_full:
|
||||
if self.maxsize > 0:
|
||||
if not block:
|
||||
if self._qsize() >= self.maxsize:
|
||||
raise Full
|
||||
elif timeout is None:
|
||||
while self._qsize() >= self.maxsize:
|
||||
self.not_full.wait()
|
||||
elif timeout < 0:
|
||||
raise ValueError("'timeout' must be a non-negative number")
|
||||
else:
|
||||
endtime = time() + timeout
|
||||
while self._qsize() >= self.maxsize:
|
||||
remaining = endtime - time()
|
||||
if remaining <= 0.0:
|
||||
raise Full
|
||||
self.not_full.wait(remaining)
|
||||
self._putleft(item)
|
||||
self.unfinished_tasks += 1
|
||||
self.not_empty.notify()
|
||||
|
||||
def put_nowait(self, item):
|
||||
return self.put(item, block=False)
|
||||
|
||||
def _putleft(self, item):
|
||||
self.queue.appendleft(item)
|
||||
@@ -8,6 +8,10 @@ def _get_logger():
|
||||
console_handle = logging.StreamHandler(sys.stdout)
|
||||
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'))
|
||||
file_handle = logging.FileHandler('run.log', encoding='utf-8')
|
||||
file_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'))
|
||||
log.addHandler(file_handle)
|
||||
log.addHandler(console_handle)
|
||||
return log
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from common.log import logger
|
||||
|
||||
@@ -25,7 +26,9 @@ available_setting = {
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"trigger_by_self": False, # 是否允许机器人触发
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
|
||||
|
||||
# chatgpt会话参数
|
||||
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
||||
@@ -36,27 +39,32 @@ available_setting = {
|
||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
|
||||
"rate_limit_dalle": 50, # openai dalle的调用频率限制
|
||||
|
||||
|
||||
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"request_timeout": 30, # chatgpt请求超时时间
|
||||
|
||||
# 语音设置
|
||||
"speech_recognition": False, # 是否开启语音识别
|
||||
"group_speech_recognition": False, # 是否开启群组语音识别
|
||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,google
|
||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline)
|
||||
"always_reply_voice": False, # 是否一直使用语音回复
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,google,azure
|
||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure
|
||||
|
||||
# baidu api的配置, 使用百度语音识别和语音合成时需要
|
||||
# baidu 语音api配置, 使用百度语音识别和语音合成时需要
|
||||
"baidu_app_id": "",
|
||||
"baidu_api_key": "",
|
||||
"baidu_secret_key": "",
|
||||
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
|
||||
"baidu_dev_pid": "1536",
|
||||
|
||||
# azure 语音api配置, 使用azure语音识别和语音合成时需要
|
||||
"azure_voice_api_key": "",
|
||||
"azure_voice_region": "japaneast",
|
||||
|
||||
# 服务时间限制,目前支持itchat
|
||||
"chat_time_module": False, # 是否开启服务时间限制
|
||||
"chat_start_time": "00:00", # 服务开始时间
|
||||
@@ -69,11 +77,12 @@ available_setting = {
|
||||
"wechaty_puppet_service_token": "", # wechaty的token
|
||||
|
||||
# chatgpt指令自定义触发词
|
||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令
|
||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令,必须以#开头
|
||||
|
||||
# channel配置
|
||||
"channel_type": "wx", # 通道类型,支持wx,wxy和terminal
|
||||
|
||||
"debug": False, # 是否开启debug模式,开启后会打印更多日志
|
||||
|
||||
}
|
||||
|
||||
@@ -131,6 +140,10 @@ def load_config():
|
||||
else:
|
||||
config[name] = value
|
||||
|
||||
if config["debug"]:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.debug("[INIT] set log level to DEBUG")
|
||||
|
||||
logger.info("[INIT] load config: {}".format(config))
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
FROM python:3.10-slim
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
ARG CHATGPT_ON_WECHAT_VER
|
||||
|
||||
ENV BUILD_PREFIX=/app
|
||||
|
||||
ADD . ${BUILD_PREFIX}
|
||||
|
||||
RUN apt-get update \
|
||||
&&apt-get install -y --no-install-recommends bash \
|
||||
ffmpeg espeak \
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache -r requirements.txt \
|
||||
&& pip install azure-cognitiveservices-speech
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
ADD docker/entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh \
|
||||
&& groupadd -r noroot \
|
||||
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
|
||||
&& chown -R noroot:noroot ${BUILD_PREFIX}
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["docker/entrypoint.sh"]
|
||||
@@ -1,4 +1,4 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd .. && docker build -f Dockerfile \
|
||||
cd .. && docker build -f docker/Dockerfile.latest \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
+10
-7
@@ -1,6 +1,6 @@
|
||||
## 插件化初衷
|
||||
|
||||
之前未插件化的代码耦合程度高,如果要定制一些个性化功能(如流量控制、接入`NovelAI`画图平台等),需要了解代码主体,避免影响到其他的功能。在实现多个功能后,不但无法调整功能的优先级顺序,功能的配置项也会变得非常混乱。
|
||||
之前未插件化的代码耦合程度高,如果要定制一些个性化功能(如流量控制、接入`NovelAI`画图平台等),需要了解代码主体,避免影响到其他的功能。多个功能同时存在时,无法调整功能的优先级顺序,功能配置项也非常混乱。
|
||||
|
||||
此时插件化应声而出。
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
- [x] 插件化能够自由开关和调整优先级。
|
||||
- [x] 每个插件可在插件文件夹内维护独立的配置文件,方便代码的测试和调试,可以在独立的仓库开发插件。
|
||||
|
||||
PS: 插件目前仅支持`itchat`
|
||||
PS: 插件目前支持`itchat`和`wechaty`
|
||||
|
||||
## 插件化实现
|
||||
|
||||
@@ -101,7 +101,7 @@ PS: 插件目前仅支持`itchat`
|
||||
|
||||
根据`Context`和回复`Reply`的类型,对回复的内容进行装饰。目前的装饰有以下两种:
|
||||
|
||||
- `TEXT`文本回复,根据是否在群聊中来决定是艾特接收方还是添加回复的前缀。
|
||||
- `TEXT`文本回复:如果这次消息需要的回复是`VOICE`,进行文字转语音回复之后再次装饰。 否则根据是否在群聊中来决定是艾特接收方还是添加回复的前缀。
|
||||
|
||||
- `INFO`或`ERROR`类型,会在消息前添加对应的系统提示字样。
|
||||
|
||||
@@ -110,8 +110,11 @@ PS: 插件目前仅支持`itchat`
|
||||
```python
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if context.get('desire_rtype') == ReplyType.VOICE:
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context['isgroup']:
|
||||
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
|
||||
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
|
||||
@@ -213,11 +216,11 @@ class Hello(Plugin):
|
||||
if content == "Hello":
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
msg = e_context['context']['msg']
|
||||
msg:ChatMessage = e_context['context']['msg']
|
||||
if e_context['context']['isgroup']:
|
||||
reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group")
|
||||
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
|
||||
else:
|
||||
reply.content = "Hello, " + msg['User'].get('NickName', "My friend")
|
||||
reply.content = f"Hello, {msg.from_user_nickname}"
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||
if content == "End":
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Tuple
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import load_config
|
||||
from config import conf, load_config
|
||||
import plugins
|
||||
from plugins import *
|
||||
from common import const
|
||||
@@ -126,7 +126,14 @@ class Godcmd(Plugin):
|
||||
else:
|
||||
with open(config_path,"r") as f:
|
||||
gconf=json.load(f)
|
||||
|
||||
|
||||
custom_commands = conf().get("clear_memory_commands", [])
|
||||
for custom_command in custom_commands:
|
||||
if custom_command and custom_command.startswith("#"):
|
||||
custom_command = custom_command[1:]
|
||||
if custom_command and custom_command not in COMMANDS["reset"]["alias"]:
|
||||
COMMANDS["reset"]["alias"].append(custom_command)
|
||||
|
||||
self.password = gconf["password"]
|
||||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用
|
||||
self.isrunning = True # 机器人是否运行中
|
||||
@@ -146,6 +153,7 @@ class Godcmd(Plugin):
|
||||
logger.debug("[Godcmd] on_handle_context. content: %s" % content)
|
||||
if content.startswith("#"):
|
||||
# msg = e_context['context']['msg']
|
||||
channel = e_context['channel']
|
||||
user = e_context['context']['receiver']
|
||||
session_id = e_context['context']['session_id']
|
||||
isgroup = e_context['context']['isgroup']
|
||||
@@ -181,6 +189,7 @@ class Godcmd(Plugin):
|
||||
elif cmd == "reset":
|
||||
if bottype in (const.CHATGPT, const.OPEN_AI):
|
||||
bot.sessions.clear_session(session_id)
|
||||
channel.cancel_session(session_id)
|
||||
ok, result = True, "会话已重置"
|
||||
else:
|
||||
ok, result = False, "当前对话机器人不支持重置会话"
|
||||
@@ -202,6 +211,7 @@ class Godcmd(Plugin):
|
||||
ok, result = True, "配置已重载"
|
||||
elif cmd == "resetall":
|
||||
if bottype in (const.CHATGPT, const.OPEN_AI):
|
||||
channel.cancel_all_session()
|
||||
bot.sessions.clear_all_session()
|
||||
ok, result = True, "重置所有会话成功"
|
||||
else:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_message import ChatMessage
|
||||
import plugins
|
||||
from plugins import *
|
||||
from common.log import logger
|
||||
@@ -24,11 +25,11 @@ class Hello(Plugin):
|
||||
if content == "Hello":
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
msg = e_context['context']['msg']
|
||||
msg:ChatMessage = e_context['context']['msg']
|
||||
if e_context['context']['isgroup']:
|
||||
reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group")
|
||||
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
|
||||
else:
|
||||
reply.content = "Hello, " + msg['User'].get('NickName', "My friend")
|
||||
reply.content = f"Hello, {msg.from_user_nickname}"
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ PyQRCode>=1.2.1
|
||||
pysilk>=0.0.1
|
||||
pysilk_mod>=1.6.0
|
||||
pyttsx3>=2.90
|
||||
qrcode>=7.4.2
|
||||
requests>=2.28.2
|
||||
webuiapi>=0.6.2
|
||||
wechaty>=0.10.7
|
||||
|
||||
+53
-19
@@ -1,7 +1,23 @@
|
||||
import shutil
|
||||
import wave
|
||||
import pysilk
|
||||
from pydub import AudioSegment
|
||||
|
||||
sil_supports=[8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率
|
||||
def find_closest_sil_supports(sample_rate):
|
||||
"""
|
||||
找到最接近的支持的采样率
|
||||
"""
|
||||
if sample_rate in sil_supports:
|
||||
return sample_rate
|
||||
closest = 0
|
||||
mindiff = 9999999
|
||||
for rate in sil_supports:
|
||||
diff = abs(rate - sample_rate)
|
||||
if diff < mindiff:
|
||||
closest = rate
|
||||
mindiff = diff
|
||||
return closest
|
||||
|
||||
def get_pcm_from_wav(wav_path):
|
||||
"""
|
||||
@@ -13,6 +29,30 @@ def get_pcm_from_wav(wav_path):
|
||||
wav = wave.open(wav_path, "rb")
|
||||
return wav.readframes(wav.getnframes())
|
||||
|
||||
def any_to_wav(any_path, wav_path):
|
||||
"""
|
||||
把任意格式转成wav文件
|
||||
"""
|
||||
if any_path.endswith('.wav'):
|
||||
shutil.copy2(any_path, wav_path)
|
||||
return
|
||||
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'):
|
||||
return sil_to_wav(any_path, wav_path)
|
||||
audio = AudioSegment.from_file(any_path)
|
||||
audio.export(wav_path, format="wav")
|
||||
|
||||
def any_to_sil(any_path, sil_path):
|
||||
"""
|
||||
把任意格式转成sil文件
|
||||
"""
|
||||
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'):
|
||||
shutil.copy2(any_path, sil_path)
|
||||
return 10000
|
||||
if any_path.endswith('.wav'):
|
||||
return pcm_to_sil(any_path, sil_path)
|
||||
if any_path.endswith('.mp3'):
|
||||
return mp3_to_sil(any_path, sil_path)
|
||||
raise NotImplementedError("Not support file type: {}".format(any_path))
|
||||
|
||||
def mp3_to_wav(mp3_path, wav_path):
|
||||
"""
|
||||
@@ -21,46 +61,40 @@ def mp3_to_wav(mp3_path, wav_path):
|
||||
audio = AudioSegment.from_mp3(mp3_path)
|
||||
audio.export(wav_path, format="wav")
|
||||
|
||||
def any_to_wav(any_path, wav_path):
|
||||
"""
|
||||
把任意格式转成wav文件
|
||||
"""
|
||||
if any_path.endswith('.wav'):
|
||||
return
|
||||
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'):
|
||||
return sil_to_wav(any_path, wav_path)
|
||||
audio = AudioSegment.from_file(any_path)
|
||||
audio.export(wav_path, format="wav")
|
||||
|
||||
def pcm_to_sil(pcm_path, silk_path):
    """Convert a wav file to silk.

    Returns the audio length in milliseconds.
    """
    audio = AudioSegment.from_wav(pcm_path)
    # silk only accepts a fixed set of sample rates; pick the nearest one
    # and re-encode the audio as 16-bit PCM at that rate.
    rate = find_closest_sil_supports(audio.frame_rate)
    pcm_s16 = audio.set_sample_width(2).set_frame_rate(rate)
    silk_data = pysilk.encode(pcm_s16.raw_data, data_rate=rate, sample_rate=rate)
    with open(silk_path, "wb") as f:
        f.write(silk_data)
    return audio.duration_seconds * 1000
|
||||
|
||||
|
||||
def mp3_to_sil(mp3_path, silk_path):
    """Convert an mp3 file to silk.

    Returns the audio length in milliseconds.
    """
    audio = AudioSegment.from_mp3(mp3_path)
    # silk only accepts a fixed set of sample rates; pick the nearest one
    # and re-encode the audio as 16-bit PCM at that rate.
    rate = find_closest_sil_supports(audio.frame_rate)
    pcm_s16 = audio.set_sample_width(2).set_frame_rate(rate)
    silk_data = pysilk.encode(pcm_s16.raw_data, data_rate=rate, sample_rate=rate)
    # Save the silk file.
    with open(silk_path, "wb") as f:
        f.write(silk_data)
    return audio.duration_seconds * 1000
|
||||
|
||||
|
||||
def sil_to_wav(silk_path, wav_path, rate: int = 24000):
|
||||
"""
|
||||
silk 文件转 wav
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
|
||||
"""
|
||||
azure voice service
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import azure.cognitiveservices.speech as speechsdk
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from voice.voice import Voice
|
||||
from config import conf
|
||||
"""
|
||||
Azure voice
|
||||
主目录设置文件中需填写azure_voice_api_key和azure_voice_region
|
||||
|
||||
查看可用的 voice: https://speech.microsoft.com/portal/voicegallery
|
||||
|
||||
"""
|
||||
|
||||
class AzureVoice(Voice):
    """Voice service backed by Azure Cognitive Services Speech.

    Requires ``azure_voice_api_key`` and ``azure_voice_region`` in the main
    config file. Per-voice settings live in a local ``config.json`` next to
    this module (created with defaults on first run).
    Available voices: https://speech.microsoft.com/portal/voicegallery
    """

    def __init__(self):
        try:
            curdir = os.path.dirname(__file__)
            config_path = os.path.join(curdir, "config.json")
            config = None
            if not os.path.exists(config_path):  # no local config yet: create one with defaults
                config = { "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", "speech_recognition_language": "zh-CN"}
                with open(config_path, "w") as fw:
                    json.dump(config, fw, indent=4)
            else:
                with open(config_path, "r") as fr:
                    config = json.load(fr)
            self.api_key = conf().get('azure_voice_api_key')
            self.api_region = conf().get('azure_voice_region')
            self.speech_config = speechsdk.SpeechConfig(subscription=self.api_key, region=self.api_region)
            self.speech_config.speech_synthesis_voice_name = config["speech_synthesis_voice_name"]
            self.speech_config.speech_recognition_language = config["speech_recognition_language"]
        except Exception as e:
            # Best-effort init: keep running even if Azure is misconfigured.
            # logger.warn is a deprecated alias; use logger.warning.
            logger.warning("AzureVoice init failed: %s, ignore " % e)

    def voiceToText(self, voice_file):
        """Recognize speech in *voice_file*; return a TEXT Reply, or ERROR on failure."""
        audio_config = speechsdk.AudioConfig(filename=voice_file)
        speech_recognizer = speechsdk.SpeechRecognizer(speech_config=self.speech_config, audio_config=audio_config)
        result = speech_recognizer.recognize_once()
        if result.reason == speechsdk.ResultReason.RecognizedSpeech:
            logger.info('[Azure] voiceToText voice file name={} text={}'.format(voice_file, result.text))
            reply = Reply(ReplyType.TEXT, result.text)
        else:
            logger.error('[Azure] voiceToText error, result={}'.format(result))
            reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
        return reply

    def textToVoice(self, text):
        """Synthesize *text* into a wav file; return a VOICE Reply, or ERROR on failure."""
        fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
        audio_config = speechsdk.AudioConfig(filename=fileName)
        speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=audio_config)
        result = speech_synthesizer.speak_text(text)
        if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
            logger.info(
                '[Azure] textToVoice text={} voice file name={}'.format(text, fileName))
            reply = Reply(ReplyType.VOICE, fileName)
        else:
            logger.error('[Azure] textToVoice error, result={}'.format(result))
            reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
        return reply
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural",
|
||||
"speech_recognition_language": "zh-CN"
|
||||
}
|
||||
@@ -80,7 +80,7 @@ class BaiduVoice(Voice):
|
||||
result = self.client.synthesis(text, self.lang, self.ctp, {
|
||||
'spd': self.spd, 'pit': self.pit, 'vol': self.vol, 'per': self.per})
|
||||
if not isinstance(result, dict):
|
||||
fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
|
||||
fileName = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
|
||||
with open(fileName, 'wb') as f:
|
||||
f.write(result)
|
||||
logger.info(
|
||||
|
||||
@@ -34,7 +34,7 @@ class GoogleVoice(Voice):
|
||||
return reply
|
||||
def textToVoice(self, text):
|
||||
try:
|
||||
mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
|
||||
mp3File = TmpDir().path() + 'reply-' + str(int(time.time())) + '.mp3'
|
||||
tts = gTTS(text=text, lang='zh')
|
||||
tts.save(mp3File)
|
||||
logger.info(
|
||||
|
||||
@@ -25,12 +25,12 @@ class PyttsVoice(Voice):
|
||||
|
||||
def textToVoice(self, text):
|
||||
try:
|
||||
mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
|
||||
self.engine.save_to_file(text, mp3File)
|
||||
wavFile = TmpDir().path() + 'reply-' + str(int(time.time())) + '.wav'
|
||||
self.engine.save_to_file(text, wavFile)
|
||||
self.engine.runAndWait()
|
||||
logger.info(
|
||||
'[Pytts] textToVoice text={} voice file name={}'.format(text, mp3File))
|
||||
reply = Reply(ReplyType.VOICE, mp3File)
|
||||
'[Pytts] textToVoice text={} voice file name={}'.format(text, wavFile))
|
||||
reply = Reply(ReplyType.VOICE, wavFile)
|
||||
except Exception as e:
|
||||
reply = Reply(ReplyType.ERROR, str(e))
|
||||
finally:
|
||||
|
||||
@@ -20,4 +20,7 @@ def create_voice(voice_type):
|
||||
elif voice_type == 'pytts':
|
||||
from voice.pytts.pytts_voice import PyttsVoice
|
||||
return PyttsVoice()
|
||||
elif voice_type == 'azure':
|
||||
from voice.azure.azure_voice import AzureVoice
|
||||
return AzureVoice()
|
||||
raise RuntimeError
|
||||
|
||||
Reference in New Issue
Block a user