mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-05-17 18:08:57 +08:00
Compare commits
77 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 47cc65a787 | |||
| cda9d5873d | |||
| 02cd553990 | |||
| 87df588c80 | |||
| 4ad2997717 | |||
| 50a03e7c15 | |||
| 4f3d12129c | |||
| 37a95980d4 | |||
| d9ef5a6612 | |||
| 66a81cd47c | |||
| 81edd13470 | |||
| 7a94745b8a | |||
| 06b02f5df8 | |||
| 83136e3142 | |||
| 950a9f2ee0 | |||
| a26c10fee8 | |||
| 4bcd76fe93 | |||
| 90ccb091ca | |||
| 62df27eaa1 | |||
| 349115b948 | |||
| 4fd7e4be67 | |||
| 947e892916 | |||
| d62b7d1a99 | |||
| 432b39a9c4 | |||
| 26540bfb63 | |||
| fd64f88a7e | |||
| 72994bc9ef | |||
| 7e1138af50 | |||
| 72dbddb7f7 | |||
| 10dba50843 | |||
| d6af1b5827 | |||
| 6c362a9b4b | |||
| 9a0584d649 | |||
| 5ab5211c95 | |||
| f644682be7 | |||
| ffad8e4d26 | |||
| 8f07e6304a | |||
| 834c03359f | |||
| 3e2c68ba49 | |||
| 2a21941b68 | |||
| e78886fb35 | |||
| 80bf6a0c7a | |||
| 48e066b677 | |||
| dcb9d7fc2a | |||
| 279f0f0234 | |||
| b3c8a7d8de | |||
| 1baf1a79e5 | |||
| 35160e717e | |||
| a12f2d8fbd | |||
| 6b7c17374b | |||
| 9b3585e795 | |||
| 74f383a7d4 | |||
| 820fbeed18 | |||
| f76e8d9a77 | |||
| 5b85e60d5d | |||
| 24de670c2c | |||
| 42aca71763 | |||
| 9b4ef85174 | |||
| 9b389ffc33 | |||
| b3cb81aa52 | |||
| 61865bc408 | |||
| c2ea6214a9 | |||
| b6684fe7a3 | |||
| b50ebc05a0 | |||
| dbb0648c39 | |||
| 5fc0987cc3 | |||
| 7c4037147c | |||
| f76cb1231e | |||
| 6701d8c5e6 | |||
| ff3d143185 | |||
| ea95ab9062 | |||
| 38c901a1c5 | |||
| 0c9753b7cd | |||
| 721b36c7f7 | |||
| f8e0716474 | |||
| 3d428ee844 | |||
| a3be1fcd8f |
@@ -1,9 +1,11 @@
|
||||
### 前置确认
|
||||
|
||||
1. 网络能够访问openai接口
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间,依赖已安装
|
||||
3. 在已有 issue 中未搜索到类似问题
|
||||
4. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间
|
||||
3. `git pull` 拉取最新代码
|
||||
4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
|
||||
5. 在已有 issue 中未搜索到类似问题
|
||||
6. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
|
||||
|
||||
### 问题描述
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
# GitHub recommends pinning actions to a commit SHA.
|
||||
# To get a newer version, you will need to update the SHA.
|
||||
# You can also reference a tag or branch, but the action may change without warning.
|
||||
|
||||
name: Create and publish a Docker image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ['master']
|
||||
create:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
with:
|
||||
package-name: 'chatgpt-on-wechat'
|
||||
package-type: 'container'
|
||||
min-versions-to-keep: 10
|
||||
delete-only-untagged-versions: 'true'
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -0,0 +1,3 @@
|
||||
FROM ghcr.io/zhayujie/chatgpt-on-wechat:latest
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
> ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型把我们的微信打造成一个智能机器人,可以在与好友对话中给出意想不到的回应,而且再也不用担心女朋友影响我们 ~~打游戏~~ 工作了。
|
||||
|
||||
|
||||
|
||||
基于ChatGPT的微信聊天机器人,通过 [ChatGPT](https://github.com/openai/openai-python) 接口生成对话内容,使用 [itchat](https://github.com/littlecodersh/ItChat) 实现微信消息的接收和自动回复。已实现的特性如下:
|
||||
|
||||
- [x] **文本对话:** 接收私聊及群组中的微信消息,使用ChatGPT生成回复内容,完成自动回复
|
||||
@@ -11,7 +11,11 @@
|
||||
- [x] **图片生成:** 支持根据描述生成图片,并自动发送至个人聊天或群聊
|
||||
- [x] **上下文记忆**:支持多轮对话记忆,且为每个好友维护独立的上下会话
|
||||
- [x] **语音识别:** 支持接收和处理语音消息,通过文字或语音回复
|
||||
- [x] **插件化:** 支持个性化功能插件,提供角色扮演、文字冒险游戏等预设插件
|
||||
|
||||
> 快速部署:
|
||||
>
|
||||
>[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
|
||||
# 更新日志
|
||||
|
||||
@@ -60,7 +64,7 @@
|
||||
|
||||
### 2.运行环境
|
||||
|
||||
支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
|
||||
支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
|
||||
> 建议Python版本在 3.7.1~3.9.X 之间,3.10及以上版本在 MacOS 可用,其他系统上不确定能否正常运行。
|
||||
|
||||
**(1) 克隆项目代码:**
|
||||
@@ -73,15 +77,18 @@ cd chatgpt-on-wechat/
|
||||
**(2) 安装核心依赖 (必选):**
|
||||
|
||||
```bash
|
||||
pip3 install itchat-uos==1.5.0.dev0
|
||||
pip3 install --upgrade openai
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
注:`itchat-uos`使用指定版本1.5.0.dev0,`openai`使用最新版本,需高于0.27.0。
|
||||
|
||||
其中`tiktoken`要求`python`版本在3.8以上,它用于精确计算会话使用的tokens数量,可以不装但建议安装。
|
||||
|
||||
**(3) 拓展依赖 (可选):**
|
||||
|
||||
语音识别及语音回复相关依赖:[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。
|
||||
使用`google`或`baidu`语音识别需安装`ffmpeg`,
|
||||
|
||||
默认的`openai`语音识别不需要安装`ffmpeg`。
|
||||
|
||||
参考[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)
|
||||
|
||||
## 配置
|
||||
|
||||
@@ -107,6 +114,7 @@ pip3 install --upgrade openai
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述,
|
||||
}
|
||||
@@ -127,8 +135,9 @@ pip3 install --upgrade openai
|
||||
|
||||
**3.语音识别**
|
||||
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,目前只支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音,但是需要配置对应语音合成平台的key,由于itchat协议的限制,只能发送语音mp3文件,若使用wechaty则回复的是微信语音。
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊),但是需要配置对应语音合成平台的key,由于itchat协议的限制,只能发送语音mp3文件,若使用wechaty则回复的是微信语音。
|
||||
|
||||
**4.其他配置**
|
||||
|
||||
@@ -143,6 +152,7 @@ pip3 install --upgrade openai
|
||||
+ `hot_reload`: 程序退出后,暂存微信扫码状态,默认关闭。
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
|
||||
**所有可选的配置项均在该[文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
|
||||
|
||||
## 运行
|
||||
|
||||
@@ -177,14 +187,11 @@ nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通
|
||||
|
||||
参考文档 [Docker部署](https://github.com/limccn/chatgpt-on-wechat/wiki/Docker%E9%83%A8%E7%BD%B2) (Contributed by [limccn](https://github.com/limccn))。
|
||||
|
||||
### 4. Railway部署
|
||||
[Use with Railway](#use-with-railway)(PaaS, Free, Stable, ✅Recommended)
|
||||
> Railway offers $5 (500 hours) of runtime per month
|
||||
1. Click the [Railway](https://railway.app/) button to go to the Railway homepage
|
||||
2. Click the `Start New Project` button.
|
||||
3. Click the `Deploy from Github repo` button.
|
||||
4. Choose your repo (you can fork this repo firstly)
|
||||
5. Set environment variable to override settings in config-template.json, such as: model, open_ai_api_base, open_ai_api_key, use_azure_chatgpt etc.
|
||||
### 4. Railway部署(✅推荐)
|
||||
> Railway每月提供5刀和最多500小时的免费额度。
|
||||
1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)。
|
||||
2. 点击 `Deploy Now` 按钮。
|
||||
3. 设置环境变量来重载程序运行的参数,例如`open_ai_api_key`, `character_desc`。
|
||||
|
||||
## 常见问题
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import config
|
||||
from config import conf, load_config
|
||||
from channel import channel_factory
|
||||
from common.log import logger
|
||||
|
||||
@@ -9,12 +9,12 @@ from plugins import *
|
||||
def run():
|
||||
try:
|
||||
# load config
|
||||
config.load_config()
|
||||
load_config()
|
||||
|
||||
# create channel
|
||||
channel_name='wx'
|
||||
channel_name=conf().get('channel_type', 'wx')
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name=='wx':
|
||||
if channel_name in ['wx','wxy']:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
# startup channel
|
||||
|
||||
+3
-3
@@ -6,9 +6,9 @@ from common import const
|
||||
|
||||
def create_bot(bot_type):
|
||||
"""
|
||||
create a channel instance
|
||||
:param channel_type: channel type code
|
||||
:return: channel instance
|
||||
create a bot_type instance
|
||||
:param bot_type: bot type code
|
||||
:return: bot instance
|
||||
"""
|
||||
if bot_type == const.BAIDU:
|
||||
# Baidu Unit对话接口
|
||||
|
||||
+43
-133
@@ -1,6 +1,9 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import Session, SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf, load_config
|
||||
@@ -8,28 +11,27 @@ from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from common.expired_dict import ExpiredDict
|
||||
import openai
|
||||
import openai.error
|
||||
import time
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ChatGPTBot(Bot):
|
||||
class ChatGPTBot(Bot,OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
proxy = conf().get('proxy')
|
||||
self.sessions = SessionManager()
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
if conf().get('rate_limit_chatgpt'):
|
||||
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
|
||||
if conf().get('rate_limit_dalle'):
|
||||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
|
||||
|
||||
self.sessions = SessionManager(ChatGPTSession, model= conf().get("model") or "gpt-3.5-turbo")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
logger.info("[CHATGPT] query={}".format(query))
|
||||
|
||||
session_id = context['session_id']
|
||||
reply = None
|
||||
@@ -45,23 +47,23 @@ class ChatGPTBot(Bot):
|
||||
reply = Reply(ReplyType.INFO, '配置已更新')
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.build_session_query(query, session_id)
|
||||
logger.debug("[OPEN_AI] session query={}".format(session))
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[CHATGPT] session query={}".format(session.messages))
|
||||
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, session_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"]))
|
||||
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
|
||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
@@ -86,7 +88,7 @@ class ChatGPTBot(Bot):
|
||||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
}
|
||||
|
||||
def reply_text(self, session, session_id, retry_count=0) -> dict:
|
||||
def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict:
|
||||
'''
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
@@ -96,62 +98,41 @@ class ChatGPTBot(Bot):
|
||||
'''
|
||||
try:
|
||||
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
|
||||
return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=session, **self.compose_args()
|
||||
messages=session.messages, **self.compose_args()
|
||||
)
|
||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
return {"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response.choices[0]['message']['content']}
|
||||
except openai.error.RateLimitError as e:
|
||||
# rate limit exception
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
|
||||
result['content'] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[CHATGPT] Timeout: {}".format(e))
|
||||
result['content'] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result['content'] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.warn("[CHATGPT] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(session, session_id, retry_count+1)
|
||||
else:
|
||||
return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
except openai.error.APIConnectionError as e:
|
||||
# api connection exception
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] APIConnection failed")
|
||||
return {"completion_tokens": 0, "content": "我连接不到你的网络"}
|
||||
except openai.error.Timeout as e:
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] Timeout")
|
||||
return {"completion_tokens": 0, "content": "我没有收到你的消息"}
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
logger.exception(e)
|
||||
self.sessions.clear_session(session_id)
|
||||
return {"completion_tokens": 0, "content": "请再问我一次吧"}
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||
response = openai.Image.create(
|
||||
prompt=query, #图片描述
|
||||
n=1, #每次生成图片的数量
|
||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.create_img(query, retry_count+1)
|
||||
else:
|
||||
return False, "提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False, str(e)
|
||||
return result
|
||||
|
||||
|
||||
class AzureChatGPTBot(ChatGPTBot):
|
||||
@@ -164,75 +145,4 @@ class AzureChatGPTBot(ChatGPTBot):
|
||||
args = super().compose_args()
|
||||
args["engine"] = args["model"]
|
||||
del(args["model"])
|
||||
return args
|
||||
|
||||
|
||||
class SessionManager(object):
|
||||
def __init__(self):
|
||||
if conf().get('expires_in_seconds'):
|
||||
sessions = ExpiredDict(conf().get('expires_in_seconds'))
|
||||
else:
|
||||
sessions = dict()
|
||||
self.sessions = sessions
|
||||
|
||||
def build_session(self, session_id, system_prompt=None):
|
||||
session = self.sessions.get(session_id, [])
|
||||
if len(session) == 0:
|
||||
if system_prompt is None:
|
||||
system_prompt = conf().get("character_desc", "")
|
||||
system_item = {'role': 'system', 'content': system_prompt}
|
||||
session.append(system_item)
|
||||
self.sessions[session_id] = session
|
||||
return session
|
||||
|
||||
def build_session_query(self, query, session_id):
|
||||
'''
|
||||
build query with conversation history
|
||||
e.g. [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
:param query: query content
|
||||
:param session_id: session id
|
||||
:return: query content with conversaction
|
||||
'''
|
||||
session = self.build_session(session_id)
|
||||
user_item = {'role': 'user', 'content': query}
|
||||
session.append(user_item)
|
||||
return session
|
||||
|
||||
def save_session(self, answer, session_id, total_tokens):
|
||||
max_tokens = conf().get("conversation_max_tokens")
|
||||
if not max_tokens:
|
||||
# default 3000
|
||||
max_tokens = 1000
|
||||
max_tokens = int(max_tokens)
|
||||
|
||||
session = self.sessions.get(session_id)
|
||||
if session:
|
||||
# append conversation
|
||||
gpt_item = {'role': 'assistant', 'content': answer}
|
||||
session.append(gpt_item)
|
||||
|
||||
# discard exceed limit conversation
|
||||
self.discard_exceed_conversation(session, max_tokens, total_tokens)
|
||||
|
||||
def discard_exceed_conversation(self, session, max_tokens, total_tokens):
|
||||
dec_tokens = int(total_tokens)
|
||||
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
|
||||
while dec_tokens > max_tokens:
|
||||
# pop first conversation
|
||||
if len(session) > 3:
|
||||
session.pop(1)
|
||||
session.pop(1)
|
||||
else:
|
||||
break
|
||||
dec_tokens = dec_tokens - max_tokens
|
||||
|
||||
def clear_session(self, session_id):
|
||||
self.sessions[session_id] = []
|
||||
|
||||
def clear_all_session(self):
|
||||
self.sessions.clear()
|
||||
return args
|
||||
@@ -0,0 +1,79 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
'''
|
||||
e.g. [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
'''
|
||||
class ChatGPTSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens= None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = num_tokens_from_messages(self.messages, self.model)
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = num_tokens_from_messages(self.messages, self.model)
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = num_tokens_from_messages(self.messages, self.model)
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
|
||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
import tiktoken
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
logger.debug("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model == "gpt-3.5-turbo":
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
|
||||
elif model == "gpt-4":
|
||||
return num_tokens_from_messages(messages, model="gpt-4-0314")
|
||||
elif model == "gpt-3.5-turbo-0301":
|
||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif model == "gpt-4-0314":
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
num_tokens += len(encoding.encode(value))
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens
|
||||
+55
-121
@@ -1,18 +1,23 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.openai.open_ai_session import OpenAISession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import openai
|
||||
import openai.error
|
||||
import time
|
||||
|
||||
user_session = dict()
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class OpenAIBot(Bot):
|
||||
class OpenAIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
@@ -20,34 +25,45 @@ class OpenAIBot(Bot):
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
|
||||
self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
from_user_id = context['session_id']
|
||||
session_id = context['session_id']
|
||||
reply = None
|
||||
if query == '#清除记忆':
|
||||
Session.clear_session(from_user_id)
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
Session.clear_all_session()
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
else:
|
||||
new_query = Session.build_session_query(query, from_user_id)
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
new_query = str(session)
|
||||
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
||||
|
||||
reply_content = self.reply_text(new_query, from_user_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
||||
if reply_content and query:
|
||||
Session.save_session(query, reply_content, from_user_id)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
total_tokens, completion_tokens, reply_content = self.reply_text(new_query, session_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(new_query, session_id, reply_content, completion_tokens))
|
||||
|
||||
if total_tokens == 0 :
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
return self.create_img(query, 0)
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, query, user_id, retry_count=0):
|
||||
def reply_text(self, query, session_id, retry_count=0):
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
model= conf().get("model") or "text-davinci-003", # 对话模型的名称
|
||||
@@ -60,116 +76,34 @@ class OpenAIBot(Bot):
|
||||
stop=["\n\n\n"]
|
||||
)
|
||||
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
completion_tokens = response["usage"]["completion_tokens"]
|
||||
logger.info("[OPEN_AI] reply={}".format(res_content))
|
||||
return res_content
|
||||
except openai.error.RateLimitError as e:
|
||||
# rate limit exception
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, user_id, retry_count+1)
|
||||
else:
|
||||
return "提问太快啦,请休息一下再问我吧"
|
||||
return total_tokens, completion_tokens, res_content
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
logger.exception(e)
|
||||
Session.clear_session(user_id)
|
||||
return "请再问我一次吧"
|
||||
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||
response = openai.Image.create(
|
||||
prompt=query, #图片描述
|
||||
n=1, #每次生成图片的数量
|
||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, retry_count+1)
|
||||
need_retry = retry_count < 2
|
||||
result = [0,0,"我现在有点累了,等会再来吧"]
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
|
||||
result[2] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[OPEN_AI] Timeout: {}".format(e))
|
||||
result[2] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result[2] = "我连接不到你的网络"
|
||||
else:
|
||||
return "提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return None
|
||||
logger.warn("[OPEN_AI] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session_id)
|
||||
|
||||
|
||||
class Session(object):
|
||||
@staticmethod
|
||||
def build_session_query(query, user_id):
|
||||
'''
|
||||
build query with conversation history
|
||||
e.g. Q: xxx
|
||||
A: xxx
|
||||
Q: xxx
|
||||
:param query: query content
|
||||
:param user_id: from user id
|
||||
:return: query content with conversaction
|
||||
'''
|
||||
prompt = conf().get("character_desc", "")
|
||||
if prompt:
|
||||
prompt += "<|endoftext|>\n\n\n"
|
||||
session = user_session.get(user_id, None)
|
||||
if session:
|
||||
for conversation in session:
|
||||
prompt += "Q: " + conversation["question"] + "\n\n\nA: " + conversation["answer"] + "<|endoftext|>\n"
|
||||
prompt += "Q: " + query + "\nA: "
|
||||
return prompt
|
||||
else:
|
||||
return prompt + "Q: " + query + "\nA: "
|
||||
|
||||
@staticmethod
|
||||
def save_session(query, answer, user_id):
|
||||
max_tokens = conf().get("conversation_max_tokens")
|
||||
if not max_tokens:
|
||||
# default 3000
|
||||
max_tokens = 1000
|
||||
conversation = dict()
|
||||
conversation["question"] = query
|
||||
conversation["answer"] = answer
|
||||
session = user_session.get(user_id)
|
||||
logger.debug(conversation)
|
||||
logger.debug(session)
|
||||
if session:
|
||||
# append conversation
|
||||
session.append(conversation)
|
||||
else:
|
||||
# create session
|
||||
queue = list()
|
||||
queue.append(conversation)
|
||||
user_session[user_id] = queue
|
||||
|
||||
# discard exceed limit conversation
|
||||
Session.discard_exceed_conversation(user_session[user_id], max_tokens)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def discard_exceed_conversation(session, max_tokens):
|
||||
count = 0
|
||||
count_list = list()
|
||||
for i in range(len(session)-1, -1, -1):
|
||||
# count tokens of conversation list
|
||||
history_conv = session[i]
|
||||
count += len(history_conv["question"]) + len(history_conv["answer"])
|
||||
count_list.append(count)
|
||||
|
||||
for c in count_list:
|
||||
if c > max_tokens:
|
||||
# pop first conversation
|
||||
session.pop(0)
|
||||
|
||||
@staticmethod
|
||||
def clear_session(user_id):
|
||||
user_session[user_id] = []
|
||||
|
||||
@staticmethod
|
||||
def clear_all_session():
|
||||
user_session.clear()
|
||||
if need_retry:
|
||||
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, session_id, retry_count+1)
|
||||
else:
|
||||
return result
|
||||
@@ -0,0 +1,38 @@
|
||||
import time
|
||||
import openai
|
||||
import openai.error
|
||||
from common.token_bucket import TokenBucket
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
# OPENAI提供的画图接口
|
||||
class OpenAIImage(object):
|
||||
def __init__(self):
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('rate_limit_dalle'):
|
||||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||
response = openai.Image.create(
|
||||
prompt=query, #图片描述
|
||||
n=1, #每次生成图片的数量
|
||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.create_img(query, retry_count+1)
|
||||
else:
|
||||
return False, "提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False, str(e)
|
||||
@@ -0,0 +1,67 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
class OpenAISession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model= "text-davinci-003"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def __str__(self):
|
||||
# 构造对话模型的输入
|
||||
'''
|
||||
e.g. Q: xxx
|
||||
A: xxx
|
||||
Q: xxx
|
||||
'''
|
||||
prompt = ""
|
||||
for item in self.messages:
|
||||
if item['role'] == 'system':
|
||||
prompt += item['content'] + "<|endoftext|>\n\n\n"
|
||||
elif item['role'] == 'user':
|
||||
prompt += "Q: " + item['content'] + "\n"
|
||||
elif item['role'] == 'assistant':
|
||||
prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n"
|
||||
|
||||
if len(self.messages) > 0 and self.messages[-1]['role'] == 'user':
|
||||
prompt += "A: "
|
||||
return prompt
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens= None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = num_tokens_from_string(str(self), self.model)
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 1:
|
||||
self.messages.pop(0)
|
||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
|
||||
self.messages.pop(0)
|
||||
if precise:
|
||||
cur_tokens = num_tokens_from_string(str(self), self.model)
|
||||
else:
|
||||
cur_tokens = len(str(self))
|
||||
break
|
||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
|
||||
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = num_tokens_from_string(str(self), self.model)
|
||||
else:
|
||||
cur_tokens = len(str(self))
|
||||
return cur_tokens
|
||||
|
||||
|
||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
def num_tokens_from_string(string: str, model: str) -> int:
|
||||
"""Returns the number of tokens in a text string."""
|
||||
import tiktoken
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
num_tokens = len(encoding.encode(string,disallowed_special=()))
|
||||
return num_tokens
|
||||
@@ -0,0 +1,85 @@
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
class Session(object):
|
||||
def __init__(self, session_id, system_prompt=None):
|
||||
self.session_id = session_id
|
||||
self.messages = []
|
||||
if system_prompt is None:
|
||||
self.system_prompt = conf().get("character_desc", "")
|
||||
else:
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
# 重置会话
|
||||
def reset(self):
|
||||
system_item = {'role': 'system', 'content': self.system_prompt}
|
||||
self.messages = [system_item]
|
||||
|
||||
def set_system_prompt(self, system_prompt):
|
||||
self.system_prompt = system_prompt
|
||||
self.reset()
|
||||
|
||||
def add_query(self, query):
|
||||
user_item = {'role': 'user', 'content': query}
|
||||
self.messages.append(user_item)
|
||||
|
||||
def add_reply(self, reply):
|
||||
assistant_item = {'role': 'assistant', 'content': reply}
|
||||
self.messages.append(assistant_item)
|
||||
|
||||
def discard_exceeding(self, max_tokens=None, cur_tokens=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
class SessionManager(object):
|
||||
def __init__(self, sessioncls, **session_args):
|
||||
if conf().get('expires_in_seconds'):
|
||||
sessions = ExpiredDict(conf().get('expires_in_seconds'))
|
||||
else:
|
||||
sessions = dict()
|
||||
self.sessions = sessions
|
||||
self.sessioncls = sessioncls
|
||||
self.session_args = session_args
|
||||
|
||||
def build_session(self, session_id, system_prompt=None):
|
||||
'''
|
||||
如果session_id不在sessions中,创建一个新的session并添加到sessions中
|
||||
如果system_prompt不会空,会更新session的system_prompt并重置session
|
||||
'''
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
|
||||
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
|
||||
self.sessions[session_id].set_system_prompt(system_prompt)
|
||||
session = self.sessions[session_id]
|
||||
return session
|
||||
|
||||
def session_query(self, query, session_id):
|
||||
session = self.build_session(session_id)
|
||||
session.add_query(query)
|
||||
try:
|
||||
max_tokens = conf().get("conversation_max_tokens", 1000)
|
||||
total_tokens = session.discard_exceeding(max_tokens, None)
|
||||
logger.debug("prompt tokens used={}".format(total_tokens))
|
||||
except Exception as e:
|
||||
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def session_reply(self, reply, session_id, total_tokens = None):
|
||||
session = self.build_session(session_id)
|
||||
session.add_reply(reply)
|
||||
try:
|
||||
max_tokens = conf().get("conversation_max_tokens", 1000)
|
||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
|
||||
except Exception as e:
|
||||
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def clear_session(self, session_id):
|
||||
if session_id in self.sessions:
|
||||
del(self.sessions[session_id])
|
||||
|
||||
def clear_all_session(self):
|
||||
self.sessions.clear()
|
||||
+1
-1
@@ -14,7 +14,7 @@ class Bridge(object):
|
||||
self.btype={
|
||||
"chat": const.CHATGPT,
|
||||
"voice_to_text": conf().get("voice_to_text", "openai"),
|
||||
"text_to_voice": conf().get("text_to_voice", "baidu")
|
||||
"text_to_voice": conf().get("text_to_voice", "google")
|
||||
}
|
||||
model_type = conf().get("model")
|
||||
if model_type in ["text-davinci-003"]:
|
||||
|
||||
@@ -14,6 +14,15 @@ class Context:
|
||||
self.type = type
|
||||
self.content = content
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __contains__(self, key):
|
||||
if key == 'type':
|
||||
return self.type is not None
|
||||
elif key == 'content':
|
||||
return self.content is not None
|
||||
else:
|
||||
return key in self.kwargs
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key == 'type':
|
||||
return self.type
|
||||
@@ -21,6 +30,12 @@ class Context:
|
||||
return self.content
|
||||
else:
|
||||
return self.kwargs[key]
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key == 'type':
|
||||
|
||||
+2
-1
@@ -20,7 +20,8 @@ class Channel(object):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def send(self, msg, receiver):
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
"""
|
||||
send message to user
|
||||
:param msg: message content
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
|
||||
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from common.expired_dict import ExpiredDict
|
||||
from channel.channel import Channel
|
||||
from bridge.reply import *
|
||||
from bridge.context import *
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from plugins import *
|
||||
try:
|
||||
from voice.audio_convert import any_to_wav
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑
|
||||
class ChatChannel(Channel):
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
# 根据消息构造context,消息内容相关的触发项写在这里
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
# context首次传入时,origin_ctype是None,
|
||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
||||
if 'origin_ctype' not in context:
|
||||
context['origin_ctype'] = ctype
|
||||
# context首次传入时,receiver是None,根据类型设置receiver
|
||||
first_in = 'receiver' not in context
|
||||
# 群名匹配过程,设置session_id和receiver
|
||||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
|
||||
config = conf()
|
||||
cmsg = context['msg']
|
||||
if cmsg.from_user_id == self.user_id:
|
||||
logger.debug("[WX]self message skipped")
|
||||
return None
|
||||
if context["isgroup"]:
|
||||
group_name = cmsg.other_user_nickname
|
||||
group_id = cmsg.other_user_id
|
||||
|
||||
group_name_white_list = config.get('group_name_white_list', [])
|
||||
group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
|
||||
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
session_id = cmsg.actual_user_id
|
||||
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
|
||||
session_id = group_id
|
||||
else:
|
||||
return None
|
||||
context['session_id'] = session_id
|
||||
context['receiver'] = group_id
|
||||
else:
|
||||
context['session_id'] = cmsg.other_user_id
|
||||
context['receiver'] = cmsg.other_user_id
|
||||
|
||||
# 消息内容匹配过程,并处理content
|
||||
if ctype == ContextType.TEXT:
|
||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return None
|
||||
|
||||
if context["isgroup"]: # 群聊
|
||||
# 校验关键字
|
||||
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
|
||||
match_contain = check_contain(content, conf().get('group_chat_keyword'))
|
||||
if match_prefix is not None or match_contain is not None:
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
elif context['msg'].is_at and not conf().get("group_at_off", False):
|
||||
logger.info("[WX]receive group at, continue")
|
||||
pattern = f'@{self.name}(\u2005|\u0020)'
|
||||
content = re.sub(pattern, r'', content)
|
||||
elif context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[WX]receive group voice, checkprefix didn't match")
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
else: # 单聊
|
||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
|
||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
|
||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, '', 1).strip()
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content
|
||||
elif context.type == ContextType.VOICE:
|
||||
if 'desire_rtype' not in context and conf().get('voice_reply_voice'):
|
||||
context['desire_rtype'] = ReplyType.VOICE
|
||||
|
||||
|
||||
return context
|
||||
|
||||
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
|
||||
def _handle(self, context: Context):
|
||||
if context is None or not context.content:
|
||||
return
|
||||
logger.debug('[WX] ready to handle context: {}'.format(context))
|
||||
# reply的构建步骤
|
||||
reply = self._generate_reply(context)
|
||||
|
||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
|
||||
# reply的包装步骤
|
||||
reply = self._decorate_reply(context, reply)
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
|
||||
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
if not e_context.is_pass():
|
||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
elif context.type == ContextType.VOICE: # 语音消息
|
||||
cmsg = context['msg']
|
||||
cmsg.prepare()
|
||||
file_path = context.content
|
||||
wav_path = os.path.splitext(file_path)[0] + '.wav'
|
||||
try:
|
||||
any_to_wav(file_path, wav_path)
|
||||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
|
||||
logger.warning("[WX]any to wav error, use raw path. " + str(e))
|
||||
wav_path = file_path
|
||||
# 语音识别
|
||||
reply = super().build_voice_to_text(wav_path)
|
||||
# 删除临时文件
|
||||
try:
|
||||
os.remove(file_path)
|
||||
os.remove(wav_path)
|
||||
except Exception as e:
|
||||
logger.warning("[WX]delete temp file error: " + str(e))
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
new_context = self._compose_context(
|
||||
ContextType.TEXT, reply.content, **context.kwargs)
|
||||
if new_context:
|
||||
reply = self._generate_reply(new_context)
|
||||
else:
|
||||
return
|
||||
else:
|
||||
logger.error('[WX] unknown context type: {}'.format(context.type))
|
||||
return
|
||||
return reply
|
||||
|
||||
def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
desire_rtype = context.get('desire_rtype')
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if desire_rtype == ReplyType.VOICE:
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context['isgroup']:
|
||||
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
|
||||
reply.content = reply_text
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = str(reply.type)+":\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
|
||||
pass
|
||||
else:
|
||||
logger.error('[WX] unknown reply type: {}'.format(reply.type))
|
||||
return
|
||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
|
||||
logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
|
||||
return reply
|
||||
|
||||
def _send_reply(self, context: Context, reply: Reply):
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context))
|
||||
self._send(reply, context)
|
||||
|
||||
def _send(self, reply: Reply, context: Context, retry_cnt = 0):
|
||||
try:
|
||||
self.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error('[WX] sendMsg error: {}'.format(e))
|
||||
if retry_cnt < 2:
|
||||
time.sleep(3+3*retry_cnt)
|
||||
self._send(reply, context, retry_cnt+1)
|
||||
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
def check_contain(content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
@@ -0,0 +1,83 @@
|
||||
|
||||
"""
|
||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装
|
||||
|
||||
ChatMessage
|
||||
msg_id: 消息id
|
||||
create_time: 消息创建时间
|
||||
|
||||
ctype: 消息类型 : ContextType
|
||||
content: 消息内容, 如果是声音/图片,这里是文件路径
|
||||
|
||||
from_user_id: 发送者id
|
||||
from_user_nickname: 发送者昵称
|
||||
to_user_id: 接收者id
|
||||
to_user_nickname: 接收者昵称
|
||||
|
||||
other_user_id: 对方的id,如果你是发送者,那这个就是接收者id,如果你是接收者,那这个就是发送者id,如果是群消息,那这一直是群id
|
||||
other_user_nickname: 同上
|
||||
|
||||
is_group: 是否是群消息
|
||||
is_at: 是否被at
|
||||
|
||||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
|
||||
actual_user_id: 实际发送者id
|
||||
actual_user_nickname:实际发送者昵称
|
||||
|
||||
|
||||
|
||||
|
||||
_prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
|
||||
_prepared: 是否已经调用过准备函数
|
||||
_rawmsg: 原始消息对象
|
||||
|
||||
"""
|
||||
class ChatMessage(object):
|
||||
msg_id = None
|
||||
create_time = None
|
||||
|
||||
ctype = None
|
||||
content = None
|
||||
|
||||
from_user_id = None
|
||||
from_user_nickname = None
|
||||
to_user_id = None
|
||||
to_user_nickname = None
|
||||
other_user_id = None
|
||||
other_user_nickname = None
|
||||
|
||||
is_group = False
|
||||
is_at = False
|
||||
actual_user_id = None
|
||||
actual_user_nickname = None
|
||||
|
||||
_prepare_fn = None
|
||||
_prepared = False
|
||||
_rawmsg = None
|
||||
|
||||
|
||||
def __init__(self,_rawmsg):
|
||||
self._rawmsg = _rawmsg
|
||||
|
||||
def prepare(self):
|
||||
if self._prepare_fn and not self._prepared:
|
||||
self._prepared = True
|
||||
self._prepare_fn()
|
||||
|
||||
def __str__(self):
|
||||
return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format(
|
||||
self.msg_id,
|
||||
self.create_time,
|
||||
self.ctype,
|
||||
self.content,
|
||||
self.from_user_id,
|
||||
self.from_user_nickname,
|
||||
self.to_user_id,
|
||||
self.to_user_nickname,
|
||||
self.other_user_id,
|
||||
self.other_user_nickname,
|
||||
self.is_group,
|
||||
self.is_at,
|
||||
self.actual_user_id,
|
||||
self.actual_user_nickname,
|
||||
)
|
||||
@@ -5,54 +5,74 @@ wechat channel
|
||||
"""
|
||||
|
||||
import os
|
||||
from lib import itchat
|
||||
import json
|
||||
from lib.itchat.content import *
|
||||
from bridge.reply import *
|
||||
from bridge.context import *
|
||||
from channel.channel import Channel
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from config import conf
|
||||
from common.time_check import time_checker
|
||||
from plugins import *
|
||||
import requests
|
||||
import io
|
||||
import time
|
||||
|
||||
|
||||
import json
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechat.wechat_message import *
|
||||
from common.singleton import singleton
|
||||
from common.log import logger
|
||||
from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
from bridge.reply import *
|
||||
from bridge.context import *
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from config import conf
|
||||
from common.time_check import time_checker
|
||||
from common.expired_dict import ExpiredDict
|
||||
from plugins import *
|
||||
thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||
|
||||
def thread_pool_callback(worker):
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
logger.exception("Worker return exception: {}".format(worker_exception))
|
||||
|
||||
|
||||
@itchat.msg_register(TEXT)
|
||||
def handler_single_msg(msg):
|
||||
WechatChannel().handle_text(msg)
|
||||
WechatChannel().handle_text(WeChatMessage(msg))
|
||||
return None
|
||||
|
||||
|
||||
@itchat.msg_register(TEXT, isGroupChat=True)
|
||||
def handler_group_msg(msg):
|
||||
WechatChannel().handle_group(msg)
|
||||
WechatChannel().handle_group(WeChatMessage(msg,True))
|
||||
return None
|
||||
|
||||
|
||||
@itchat.msg_register(VOICE)
|
||||
def handler_single_voice(msg):
|
||||
WechatChannel().handle_voice(msg)
|
||||
WechatChannel().handle_voice(WeChatMessage(msg))
|
||||
return None
|
||||
|
||||
@itchat.msg_register(VOICE, isGroupChat=True)
|
||||
def handler_group_voice(msg):
|
||||
WechatChannel().handle_group_voice(WeChatMessage(msg,True))
|
||||
return None
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: ChatMessage):
|
||||
msgId = cmsg.msg_id
|
||||
if msgId in self.receivedMsgs:
|
||||
logger.info("Wechat message {} already received, ignore".format(msgId))
|
||||
return
|
||||
self.receivedMsgs[msgId] = cmsg
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
return wrapper
|
||||
|
||||
class WechatChannel(Channel):
|
||||
@singleton
|
||||
class WechatChannel(ChatChannel):
|
||||
def __init__(self):
|
||||
pass
|
||||
super().__init__()
|
||||
self.receivedMsgs = ExpiredDict(60*60*24)
|
||||
|
||||
def startup(self):
|
||||
|
||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
# login by scan QRCode
|
||||
hotReload = conf().get('hot_reload', False)
|
||||
try:
|
||||
@@ -65,10 +85,13 @@ class WechatChannel(Channel):
|
||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
|
||||
else:
|
||||
raise e
|
||||
self.user_id = itchat.instance.storageClass.userName
|
||||
self.name = itchat.instance.storageClass.nickName
|
||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
||||
# start message listener
|
||||
itchat.run()
|
||||
|
||||
# handle_* 系列函数处理收到的消息后构造Context,然后传入handle函数中处理Context和发送回复
|
||||
# handle_* 系列函数处理收到的消息后构造Context,然后传入_handle函数中处理Context和发送回复
|
||||
# Context包含了消息的所有信息,包括以下属性
|
||||
# type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE
|
||||
# content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
|
||||
@@ -76,101 +99,49 @@ class WechatChannel(Channel):
|
||||
# session_id: 会话id
|
||||
# isgroup: 是否是群聊
|
||||
# receiver: 需要回复的对象
|
||||
# msg: itchat的原始消息对象
|
||||
# msg: ChatMessage消息对象
|
||||
# origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则
|
||||
# desire_rtype: 希望回复类型,默认是文本回复,设置为ReplyType.VOICE是语音回复
|
||||
|
||||
def handle_voice(self, msg):
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_voice(self, cmsg : ChatMessage):
|
||||
if conf().get('speech_recognition') != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: " + msg['FileName'])
|
||||
from_user_id = msg['FromUserName']
|
||||
other_user_id = msg['User']['UserName']
|
||||
if from_user_id == other_user_id:
|
||||
context = Context(ContextType.VOICE,msg['FileName'])
|
||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
@time_checker
|
||||
def handle_text(self, msg):
|
||||
logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
|
||||
content = msg['Text']
|
||||
from_user_id = msg['FromUserName']
|
||||
to_user_id = msg['ToUserName'] # 接收人id
|
||||
other_user_id = msg['User']['UserName'] # 对手方id
|
||||
create_time = msg['CreateTime'] # 消息时间
|
||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
|
||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message skipped")
|
||||
return
|
||||
if "」\n- - - - - - - - - - - - - - -" in content:
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
elif match_prefix is None:
|
||||
return
|
||||
context = Context()
|
||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
|
||||
|
||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, '', 1).strip()
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
|
||||
context.content = content
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
@_check
|
||||
def handle_text(self, cmsg : ChatMessage):
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
@time_checker
|
||||
def handle_group(self, msg):
|
||||
logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
|
||||
group_name = msg['User'].get('NickName', None)
|
||||
group_id = msg['User'].get('UserName', None)
|
||||
create_time = msg['CreateTime'] # 消息时间
|
||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history group message skipped")
|
||||
@_check
|
||||
def handle_group(self, cmsg : ChatMessage):
|
||||
logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group_voice(self, cmsg : ChatMessage):
|
||||
if conf().get('group_speech_recognition', False) != True:
|
||||
return
|
||||
if not group_name:
|
||||
return ""
|
||||
origin_content = msg['Content']
|
||||
content = msg['Content']
|
||||
content_list = content.split(' ', 1)
|
||||
context_special_list = content.split('\u2005', 1)
|
||||
if len(context_special_list) == 2:
|
||||
content = context_special_list[1]
|
||||
elif len(content_list) == 2:
|
||||
content = content_list[1]
|
||||
if "」\n- - - - - - - - - - - - - - -" in content:
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return ""
|
||||
config = conf()
|
||||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \
|
||||
or check_contain(origin_content, config.get('group_chat_keyword'))
|
||||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
|
||||
context = Context()
|
||||
context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id}
|
||||
|
||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, '', 1).strip()
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content
|
||||
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
if ('ALL_GROUP' in group_chat_in_one_session or
|
||||
group_name in group_chat_in_one_session or
|
||||
check_contain(group_name, group_chat_in_one_session)):
|
||||
context['session_id'] = group_id
|
||||
else:
|
||||
context['session_id'] = msg['ActualUserName']
|
||||
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
thread_pool.submit(self._handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply : Reply, receiver):
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if reply.type == ReplyType.TEXT:
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
||||
@@ -194,79 +165,3 @@ class WechatChannel(Channel):
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info('[WX] sendImage, receiver={}'.format(receiver))
|
||||
|
||||
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
|
||||
def handle(self, context):
|
||||
reply = Reply()
|
||||
|
||||
logger.debug('[WX] ready to handle context: {}'.format(context))
|
||||
|
||||
# reply的构建步骤
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
if not e_context.is_pass():
|
||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
elif context.type == ContextType.VOICE:
|
||||
msg = context['msg']
|
||||
file_name = TmpDir().path() + context.content
|
||||
msg.download(file_name)
|
||||
reply = super().build_voice_to_text(file_name)
|
||||
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
|
||||
context.content = reply.content # 语音转文字后,将文字内容作为新的context
|
||||
context.type = ContextType.TEXT
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
if reply.type == ReplyType.TEXT:
|
||||
if conf().get('voice_reply_voice'):
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
else:
|
||||
logger.error('[WX] unknown context type: {}'.format(context.type))
|
||||
return
|
||||
|
||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
|
||||
|
||||
# reply的包装步骤
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
|
||||
reply=e_context['reply']
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if context['isgroup']:
|
||||
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
|
||||
reply.content = reply_text
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = str(reply.type)+":\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
|
||||
pass
|
||||
else:
|
||||
logger.error('[WX] unknown reply type: {}'.format(reply.type))
|
||||
return
|
||||
|
||||
# reply的发送步骤
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
|
||||
reply=e_context['reply']
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
|
||||
self.send(reply, context['receiver'])
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
|
||||
def check_contain(content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.tmp_dir import TmpDir
|
||||
from common.log import logger
|
||||
from lib.itchat.content import *
|
||||
from lib import itchat
|
||||
|
||||
class WeChatMessage(ChatMessage):
|
||||
|
||||
def __init__(self, itchat_msg, is_group=False):
|
||||
super().__init__( itchat_msg)
|
||||
self.msg_id = itchat_msg['MsgId']
|
||||
self.create_time = itchat_msg['CreateTime']
|
||||
self.is_group = is_group
|
||||
|
||||
if itchat_msg['Type'] == TEXT:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = itchat_msg['Text']
|
||||
elif itchat_msg['Type'] == VOICE:
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: {}".format(itchat_msg['Type']))
|
||||
|
||||
self.from_user_id = itchat_msg['FromUserName']
|
||||
self.to_user_id = itchat_msg['ToUserName']
|
||||
|
||||
user_id = itchat.instance.storageClass.userName
|
||||
nickname = itchat.instance.storageClass.nickName
|
||||
|
||||
# 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下
|
||||
# 以下很繁琐,一句话总结:能填的都填了。
|
||||
if self.from_user_id == user_id:
|
||||
self.from_user_nickname = nickname
|
||||
if self.to_user_id == user_id:
|
||||
self.to_user_nickname = nickname
|
||||
try: # 陌生人时候, 'User'字段可能不存在
|
||||
self.other_user_id = itchat_msg['User']['UserName']
|
||||
self.other_user_nickname = itchat_msg['User']['NickName']
|
||||
if self.other_user_id == self.from_user_id:
|
||||
self.from_user_nickname = self.other_user_nickname
|
||||
if self.other_user_id == self.to_user_id:
|
||||
self.to_user_nickname = self.other_user_nickname
|
||||
except KeyError as e: # 处理偶尔没有对方信息的情况
|
||||
logger.warn("[WX]get other_user_id failed: " + str(e))
|
||||
if self.from_user_id == user_id:
|
||||
self.other_user_id = self.to_user_id
|
||||
else:
|
||||
self.other_user_id = self.from_user_id
|
||||
|
||||
if self.is_group:
|
||||
self.is_at = itchat_msg['IsAt']
|
||||
self.actual_user_id = itchat_msg['ActualUserName']
|
||||
self.actual_user_nickname = itchat_msg['ActualNickName']
|
||||
@@ -4,27 +4,32 @@
|
||||
wechaty channel
|
||||
Python Wechaty - https://github.com/wechaty/python-wechaty
|
||||
"""
|
||||
import io
|
||||
import base64
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import requests
|
||||
import pysilk
|
||||
import wave
|
||||
from pydub import AudioSegment
|
||||
from typing import Optional, Union
|
||||
from bridge.context import Context, ContextType
|
||||
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore
|
||||
from bridge.context import Context
|
||||
from wechaty_puppet import FileBox
|
||||
from wechaty import Wechaty, Contact
|
||||
from wechaty.user import Message, Room, MiniProgram, UrlLink
|
||||
from channel.channel import Channel
|
||||
from wechaty.user import Message
|
||||
from bridge.reply import *
|
||||
from bridge.context import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechat.wechaty_message import WechatyMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from config import conf
|
||||
try:
|
||||
from voice.audio_convert import mp3_to_sil
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
class WechatyChannel(Channel):
|
||||
thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||
def thread_pool_callback(worker):
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
logger.exception("Worker return exception: {}".format(worker_exception))
|
||||
class WechatyChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -34,259 +39,98 @@ class WechatyChannel(Channel):
|
||||
|
||||
async def main(self):
|
||||
config = conf()
|
||||
# 使用PadLocal协议 比较稳定(免费web协议 os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:8080')
|
||||
token = config.get('wechaty_puppet_service_token')
|
||||
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token
|
||||
global bot
|
||||
bot = Wechaty()
|
||||
|
||||
bot.on('scan', self.on_scan)
|
||||
bot.on('login', self.on_login)
|
||||
bot.on('message', self.on_message)
|
||||
await bot.start()
|
||||
os.environ['WECHATY_LOG']="warn"
|
||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
||||
self.bot = Wechaty()
|
||||
self.bot.on('login', self.on_login)
|
||||
self.bot.on('message', self.on_message)
|
||||
await self.bot.start()
|
||||
|
||||
async def on_login(self, contact: Contact):
|
||||
self.user_id = contact.contact_id
|
||||
self.name = contact.name
|
||||
logger.info('[WX] login user={}'.format(contact))
|
||||
|
||||
async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None,
|
||||
data: Optional[str] = None):
|
||||
contact = self.Contact.load(self.contact_id)
|
||||
logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code))
|
||||
# print(f'user <{contact}> scan status: {status.name} , 'f'qr_code: {qr_code}')
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver_id = context['receiver']
|
||||
loop = asyncio.get_event_loop()
|
||||
if context['isgroup']:
|
||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id),loop).result()
|
||||
else:
|
||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id),loop).result()
|
||||
msg = None
|
||||
if reply.type == ReplyType.TEXT:
|
||||
msg = reply.content
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
msg = reply.content
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
voiceLength = None
|
||||
if reply.content.endswith('.mp3'):
|
||||
mp3_file = reply.content
|
||||
sil_file = os.path.splitext(mp3_file)[0] + '.sil'
|
||||
voiceLength = mp3_to_sil(mp3_file, sil_file)
|
||||
try:
|
||||
os.remove(mp3_file)
|
||||
except Exception as e:
|
||||
pass
|
||||
elif reply.content.endswith('.sil'):
|
||||
sil_file = reply.content
|
||||
else:
|
||||
raise Exception('voice file must be mp3 or sil format')
|
||||
# 发送语音
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_file(sil_file, name=str(t) + '.sil')
|
||||
if voiceLength is not None:
|
||||
msg.metadata['voiceLength'] = voiceLength
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
|
||||
try:
|
||||
os.remove(sil_file)
|
||||
except Exception as e:
|
||||
pass
|
||||
logger.info('[WX] sendVoice={}, receiver={}'.format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_url(url=img_url, name=str(t) + '.png')
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
|
||||
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + '.png')
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg),loop).result()
|
||||
logger.info('[WX] sendImage, receiver={}'.format(receiver))
|
||||
|
||||
async def on_message(self, msg: Message):
|
||||
"""
|
||||
listen for message event
|
||||
"""
|
||||
from_contact = msg.talker() # 获取消息的发送者
|
||||
to_contact = msg.to() # 接收人
|
||||
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
|
||||
from_user_id = from_contact.contact_id
|
||||
to_user_id = to_contact.contact_id # 接收人id
|
||||
# other_user_id = msg['User']['UserName'] # 对手方id
|
||||
content = msg.text()
|
||||
mention_content = await msg.mention_text() # 返回过滤掉@name后的消息
|
||||
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
|
||||
conversation: Union[Room, Contact] = from_contact if room is None else room
|
||||
|
||||
if room is None and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
||||
if not msg.is_self() and match_prefix is not None:
|
||||
# 好友向自己发送消息
|
||||
if match_prefix != '':
|
||||
str_list = content.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
content = str_list[1].strip()
|
||||
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
await self._do_send_img(content, from_user_id)
|
||||
else:
|
||||
await self._do_send(content, from_user_id)
|
||||
elif msg.is_self() and match_prefix:
|
||||
# 自己给好友发送消息
|
||||
str_list = content.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
content = str_list[1].strip()
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
await self._do_send_img(content, to_user_id)
|
||||
else:
|
||||
await self._do_send(content, to_user_id)
|
||||
elif room is None and msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
|
||||
if not msg.is_self(): # 接收语音消息
|
||||
# 下载语音文件
|
||||
voice_file = await msg.to_file_box()
|
||||
silk_file = TmpDir().path() + voice_file.name
|
||||
await voice_file.to_file(silk_file)
|
||||
logger.info("[WX]receive voice file: " + silk_file)
|
||||
# 将文件转成wav格式音频
|
||||
wav_file = silk_file.replace(".slk", ".wav")
|
||||
with open(silk_file, 'rb') as f:
|
||||
silk_data = f.read()
|
||||
pcm_data = pysilk.decode(silk_data)
|
||||
|
||||
with wave.open(wav_file, 'wb') as wav_data:
|
||||
wav_data.setnchannels(1)
|
||||
wav_data.setsampwidth(2)
|
||||
wav_data.setframerate(24000)
|
||||
wav_data.writeframes(pcm_data)
|
||||
if os.path.exists(wav_file):
|
||||
converter_state = "true" # 转换wav成功
|
||||
else:
|
||||
converter_state = "false" # 转换wav失败
|
||||
logger.info("[WX]receive voice converter: " + converter_state)
|
||||
# 语音识别为文本
|
||||
query = super().build_voice_to_text(wav_file).content
|
||||
# 交验关键字
|
||||
match_prefix = self.check_prefix(query, conf().get('single_chat_prefix'))
|
||||
if match_prefix is not None:
|
||||
if match_prefix != '':
|
||||
str_list = query.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
query = str_list[1].strip()
|
||||
# 返回消息
|
||||
if conf().get('voice_reply_voice'):
|
||||
await self._do_send_voice(query, from_user_id)
|
||||
else:
|
||||
await self._do_send(query, from_user_id)
|
||||
else:
|
||||
logger.info("[WX]receive voice check prefix: " + 'False')
|
||||
# 清除缓存文件
|
||||
os.remove(wav_file)
|
||||
os.remove(silk_file)
|
||||
elif room and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
||||
# 群组&文本消息
|
||||
room_id = room.room_id
|
||||
room_name = await room.topic()
|
||||
from_user_id = from_contact.contact_id
|
||||
from_user_name = from_contact.name
|
||||
is_at = await msg.mention_self()
|
||||
content = mention_content
|
||||
config = conf()
|
||||
match_prefix = (is_at and not config.get("group_at_off", False)) \
|
||||
or self.check_prefix(content, config.get('group_chat_prefix')) \
|
||||
or self.check_contain(content, config.get('group_chat_keyword'))
|
||||
# Wechaty判断is_at为True,返回的内容是过滤掉@之后的内容;而is_at为False,则会返回完整的内容
|
||||
# 故判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
|
||||
prefixes = config.get('group_chat_prefix')
|
||||
for prefix in prefixes:
|
||||
if content.startswith(prefix):
|
||||
content = content.replace(prefix, '', 1).strip()
|
||||
break
|
||||
if ('ALL_GROUP' in config.get('group_name_white_list') or room_name in config.get(
|
||||
'group_name_white_list') or self.check_contain(room_name, config.get(
|
||||
'group_name_keyword_white_list'))) and match_prefix:
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
await self._do_send_group_img(content, room_id)
|
||||
else:
|
||||
await self._do_send_group(content, room_id, room_name, from_user_id, from_user_name)
|
||||
|
||||
async def send(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
|
||||
if receiver:
|
||||
contact = await bot.Contact.find(receiver)
|
||||
await contact.say(message)
|
||||
|
||||
async def send_group(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
|
||||
if receiver:
|
||||
room = await bot.Room.find(receiver)
|
||||
await room.say(message)
|
||||
|
||||
async def _do_send(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context['session_id'] = reply_user_id
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
async def _do_send_voice(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context['session_id'] = reply_user_id
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
# 转换 mp3 文件为 silk 格式
|
||||
mp3_file = super().build_text_to_voice(reply_text).content
|
||||
silk_file = mp3_file.replace(".mp3", ".silk")
|
||||
# Load the MP3 file
|
||||
audio = AudioSegment.from_file(mp3_file, format="mp3")
|
||||
# Convert to WAV format
|
||||
audio = audio.set_frame_rate(24000).set_channels(1)
|
||||
wav_data = audio.raw_data
|
||||
sample_width = audio.sample_width
|
||||
# Encode to SILK format
|
||||
silk_data = pysilk.encode(wav_data, 24000)
|
||||
# Save the silk file
|
||||
with open(silk_file, "wb") as f:
|
||||
f.write(silk_data)
|
||||
# 发送语音
|
||||
t = int(time.time())
|
||||
file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
|
||||
await self.send(file_box, reply_user_id)
|
||||
# 清除缓存文件
|
||||
os.remove(mp3_file)
|
||||
os.remove(silk_file)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
async def _do_send_img(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.IMAGE_CREATE, query)
|
||||
img_url = super().build_reply_content(query, context).content
|
||||
if not img_url:
|
||||
return
|
||||
# 图片下载
|
||||
# pic_res = requests.get(img_url, stream=True)
|
||||
# image_storage = io.BytesIO()
|
||||
# for block in pic_res.iter_content(1024):
|
||||
# image_storage.write(block)
|
||||
# image_storage.seek(0)
|
||||
|
||||
# 图片发送
|
||||
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
|
||||
t = int(time.time())
|
||||
file_box = FileBox.from_url(url=img_url, name=str(t) + '.png')
|
||||
await self.send(file_box, reply_user_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name):
|
||||
if not query:
|
||||
cmsg = await WechatyMessage(msg)
|
||||
except NotImplementedError as e:
|
||||
logger.debug('[WX] {}'.format(e))
|
||||
return
|
||||
context = Context(ContextType.TEXT, query)
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
if ('ALL_GROUP' in group_chat_in_one_session or \
|
||||
group_name in group_chat_in_one_session or \
|
||||
self.check_contain(group_name, group_chat_in_one_session)):
|
||||
context['session_id'] = str(group_id)
|
||||
else:
|
||||
context['session_id'] = str(group_id) + '-' + str(group_user_id)
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
reply_text = '@' + group_user_name + ' ' + reply_text.strip()
|
||||
await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)
|
||||
|
||||
async def _do_send_group_img(self, query, reply_room_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.IMAGE_CREATE, query)
|
||||
img_url = super().build_reply_content(query, context).content
|
||||
if not img_url:
|
||||
return
|
||||
# 图片发送
|
||||
logger.info('[WX] sendImage, receiver={}'.format(reply_room_id))
|
||||
t = int(time.time())
|
||||
file_box = FileBox.from_url(url=img_url, name=str(t) + '.png')
|
||||
await self.send_group(file_box, reply_room_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
logger.exception('[WX] {}'.format(e))
|
||||
return
|
||||
logger.debug('[WX] message:{}'.format(cmsg))
|
||||
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
|
||||
|
||||
isgroup = room is not None
|
||||
ctype = cmsg.ctype
|
||||
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
|
||||
if context:
|
||||
logger.info('[WX] receiveMsg={}, context={}'.format(cmsg, context))
|
||||
thread_pool.submit(self._handle_loop, context, asyncio.get_event_loop()).add_done_callback(thread_pool_callback)
|
||||
|
||||
def check_prefix(self, content, prefix_list):
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
def check_contain(self, content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
def _handle_loop(self,context,loop):
|
||||
asyncio.set_event_loop(loop)
|
||||
self._handle(context)
|
||||
@@ -0,0 +1,85 @@
|
||||
import asyncio
|
||||
import re
|
||||
from wechaty import MessageType
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.tmp_dir import TmpDir
|
||||
from common.log import logger
|
||||
from wechaty.user import Message
|
||||
|
||||
class aobject(object):
|
||||
"""Inheriting this class allows you to define an async __init__.
|
||||
|
||||
So you can create objects by doing something like `await MyClass(params)`
|
||||
"""
|
||||
async def __new__(cls, *a, **kw):
|
||||
instance = super().__new__(cls)
|
||||
await instance.__init__(*a, **kw)
|
||||
return instance
|
||||
|
||||
async def __init__(self):
|
||||
pass
|
||||
class WechatyMessage(ChatMessage, aobject):
|
||||
|
||||
async def __init__(self, wechaty_msg: Message):
|
||||
super().__init__(wechaty_msg)
|
||||
|
||||
room = wechaty_msg.room()
|
||||
|
||||
self.msg_id = wechaty_msg.message_id
|
||||
self.create_time = wechaty_msg.payload.timestamp
|
||||
self.is_group = room is not None
|
||||
|
||||
if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = wechaty_msg.text()
|
||||
elif wechaty_msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
|
||||
self.ctype = ContextType.VOICE
|
||||
voice_file = await wechaty_msg.to_file_box()
|
||||
self.content = TmpDir().path() + voice_file.name # content直接存临时目录路径
|
||||
|
||||
def func():
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content),loop).result()
|
||||
self._prepare_fn = func
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
|
||||
|
||||
from_contact = wechaty_msg.talker() # 获取消息的发送者
|
||||
self.from_user_id = from_contact.contact_id
|
||||
self.from_user_nickname = from_contact.name
|
||||
|
||||
# group中的from和to,wechaty跟itchat含义不一样
|
||||
# wecahty: from是消息实际发送者, to:所在群
|
||||
# itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己
|
||||
# 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户
|
||||
|
||||
if self.is_group:
|
||||
self.to_user_id = room.room_id
|
||||
self.to_user_nickname = await room.topic()
|
||||
else:
|
||||
to_contact = wechaty_msg.to()
|
||||
self.to_user_id = to_contact.contact_id
|
||||
self.to_user_nickname = to_contact.name
|
||||
|
||||
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
|
||||
self.other_user_id = self.to_user_id
|
||||
self.other_user_nickname = self.to_user_nickname
|
||||
else:
|
||||
self.other_user_id = self.from_user_id
|
||||
self.other_user_nickname = self.from_user_nickname
|
||||
|
||||
|
||||
|
||||
if self.is_group: # wechaty群聊中,实际发送用户就是from_user
|
||||
self.is_at = await wechaty_msg.mention_self()
|
||||
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
|
||||
name = wechaty_msg.wechaty.user_self().name
|
||||
pattern = f'@{name}(\u2005|\u0020)'
|
||||
if re.search(pattern,self.content):
|
||||
logger.debug(f'wechaty message {self.msg_id} include at')
|
||||
self.is_at = True
|
||||
|
||||
self.actual_user_id = self.from_user_id
|
||||
self.actual_user_nickname = self.from_user_nickname
|
||||
@@ -10,6 +10,7 @@
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"],
|
||||
"image_create_prefix": ["画", "看", "找"],
|
||||
"speech_recognition": false,
|
||||
"group_speech_recognition": false,
|
||||
"voice_reply_voice": false,
|
||||
"conversation_max_tokens": 1000,
|
||||
"expires_in_seconds": 3600,
|
||||
|
||||
@@ -5,70 +5,79 @@ import os
|
||||
from common.log import logger
|
||||
|
||||
# 将所有可用的配置项写在字典里, 请使用小写字母
|
||||
available_setting ={
|
||||
#openai api配置
|
||||
"open_ai_api_key": "", # openai api key
|
||||
"open_ai_api_base": "https://api.openai.com/v1", # openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base
|
||||
"proxy": "", # openai使用的代理
|
||||
"model": "gpt-3.5-turbo", # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
||||
available_setting = {
|
||||
# openai api配置
|
||||
"open_ai_api_key": "", # openai api key
|
||||
# openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"proxy": "", # openai使用的代理
|
||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"model": "gpt-3.5-turbo",
|
||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
||||
|
||||
#Bot触发配置
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀
|
||||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
|
||||
"group_at_off": False, # 是否关闭群聊时@bot的触发
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
|
||||
#chatgpt会话参数
|
||||
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
|
||||
#chatgpt限流配置
|
||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
|
||||
"rate_limit_dalle": 50, # openai dalle的调用频率限制
|
||||
# Bot触发配置
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀
|
||||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
|
||||
"group_at_off": False, # 是否关闭群聊时@bot的触发
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
|
||||
# chatgpt会话参数
|
||||
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
|
||||
# chatgpt限流配置
|
||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
|
||||
"rate_limit_dalle": 50, # openai dalle的调用频率限制
|
||||
|
||||
|
||||
#chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
|
||||
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
|
||||
#语音设置
|
||||
"speech_recognition": False, # 是否开启语音识别
|
||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai和google
|
||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu和google
|
||||
# 语音设置
|
||||
"speech_recognition": False, # 是否开启语音识别
|
||||
"group_speech_recognition": False, # 是否开启群组语音识别
|
||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,google
|
||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline)
|
||||
|
||||
# baidu api的配置, 使用百度语音识别和语音合成时需要
|
||||
'baidu_app_id': "",
|
||||
'baidu_api_key': "",
|
||||
'baidu_secret_key': "",
|
||||
"baidu_app_id": "",
|
||||
"baidu_api_key": "",
|
||||
"baidu_secret_key": "",
|
||||
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
|
||||
"baidu_dev_pid": "1536",
|
||||
|
||||
#服务时间限制,目前支持itchat
|
||||
"chat_time_module": False, # 是否开启服务时间限制
|
||||
"chat_start_time": "00:00", # 服务开始时间
|
||||
"chat_stop_time": "24:00", # 服务结束时间
|
||||
# 服务时间限制,目前支持itchat
|
||||
"chat_time_module": False, # 是否开启服务时间限制
|
||||
"chat_start_time": "00:00", # 服务开始时间
|
||||
"chat_stop_time": "24:00", # 服务结束时间
|
||||
|
||||
# itchat的配置
|
||||
"hot_reload": False, # 是否开启热重载
|
||||
"hot_reload": False, # 是否开启热重载
|
||||
|
||||
# wechaty的配置
|
||||
"wechaty_puppet_service_token": "", # wechaty的token
|
||||
"wechaty_puppet_service_token": "", # wechaty的token
|
||||
|
||||
# chatgpt指令自定义触发词
|
||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令
|
||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令
|
||||
|
||||
# channel配置
|
||||
"channel_type": "wx", # 通道类型,支持wx,wxy和terminal
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
class Config(dict):
|
||||
def __getitem__(self, key):
|
||||
if key not in available_setting:
|
||||
@@ -81,15 +90,17 @@ class Config(dict):
|
||||
return super().__setitem__(key, value)
|
||||
|
||||
def get(self, key, default=None):
|
||||
try :
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError as e:
|
||||
return default
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
config = Config()
|
||||
|
||||
|
||||
def load_config():
|
||||
global config
|
||||
config_path = "./config.json"
|
||||
@@ -108,18 +119,23 @@ def load_config():
|
||||
for name, value in os.environ.items():
|
||||
name = name.lower()
|
||||
if name in available_setting:
|
||||
logger.info("[INIT] override config by environ args: {}={}".format(name, value))
|
||||
logger.info(
|
||||
"[INIT] override config by environ args: {}={}".format(name, value))
|
||||
try:
|
||||
config[name] = eval(value)
|
||||
except:
|
||||
config[name] = value
|
||||
if value == "false":
|
||||
config[name] = False
|
||||
elif value == "true":
|
||||
config[name] = True
|
||||
else:
|
||||
config[name] = value
|
||||
|
||||
logger.info("[INIT] load config: {}".format(config))
|
||||
|
||||
|
||||
|
||||
def get_root():
|
||||
return os.path.dirname(os.path.abspath( __file__ ))
|
||||
return os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def read_file(path):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.7.9-alpine
|
||||
FROM python:3.10-alpine
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
@@ -22,9 +22,7 @@ RUN apk add --no-cache \
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json ${BUILD_PREFIX}/config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache \
|
||||
itchat-uos==1.5.0.dev0 \
|
||||
openai \
|
||||
&& pip install --no-cache -r requirements.txt \
|
||||
&& apk del curl wget
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.7.9
|
||||
FROM python:3.10
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
@@ -23,9 +23,7 @@ RUN apt-get update \
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json ${BUILD_PREFIX}/config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache \
|
||||
itchat-uos==1.5.0.dev0 \
|
||||
openai
|
||||
&& pip install --no-cache -r requirements.txt
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.7.9-alpine
|
||||
FROM python:3.10-alpine
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
@@ -7,22 +7,17 @@ ARG CHATGPT_ON_WECHAT_VER
|
||||
|
||||
ENV BUILD_PREFIX=/app
|
||||
|
||||
COPY chatgpt-on-wechat.tar.gz ./chatgpt-on-wechat.tar.gz
|
||||
ADD . ${BUILD_PREFIX}
|
||||
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
&& tar -xf chatgpt-on-wechat.tar.gz \
|
||||
&& mv chatgpt-on-wechat ${BUILD_PREFIX} \
|
||||
RUN apk add --no-cache bash ffmpeg espeak \
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json ${BUILD_PREFIX}/config.json \
|
||||
&& cp config-template.json config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache \
|
||||
itchat-uos==1.5.0.dev0 \
|
||||
openai
|
||||
&& pip install --no-cache -r requirements.txt
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
ADD docker/entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh \
|
||||
&& adduser -D -h /home/noroot -u 1000 -s /bin/bash noroot \
|
||||
@@ -30,4 +25,4 @@ RUN chmod +x /entrypoint.sh \
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
ENTRYPOINT ["docker/entrypoint.sh"]
|
||||
@@ -1,8 +1,4 @@
|
||||
#!/bin/bash
|
||||
|
||||
# move chatgpt-on-wechat
|
||||
tar -zcf chatgpt-on-wechat.tar.gz --exclude=../../chatgpt-on-wechat/docker ../../chatgpt-on-wechat
|
||||
|
||||
# build image
|
||||
docker build -f Dockerfile.latest \
|
||||
cd .. && docker build -f Dockerfile \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
+12
-12
@@ -10,17 +10,17 @@ CHATGPT_ON_WECHAT_EXEC=${CHATGPT_ON_WECHAT_EXEC:-""}
|
||||
|
||||
# use environment variables to pass parameters
|
||||
# if you have not defined environment variables, set them below
|
||||
export OPEN_AI_API_KEY=${OPEN_AI_API_KEY:-'YOUR API KEY'}
|
||||
export OPEN_AI_PROXY=${OPEN_AI_PROXY:-""}
|
||||
export SINGLE_CHAT_PREFIX=${SINGLE_CHAT_PREFIX:-'["bot", "@bot"]'}
|
||||
export SINGLE_CHAT_REPLY_PREFIX=${SINGLE_CHAT_REPLY_PREFIX:-'"[bot] "'}
|
||||
export GROUP_CHAT_PREFIX=${GROUP_CHAT_PREFIX:-'["@bot"]'}
|
||||
export GROUP_NAME_WHITE_LIST=${GROUP_NAME_WHITE_LIST:-'["ChatGPT测试群", "ChatGPT测试群2"]'}
|
||||
export IMAGE_CREATE_PREFIX=${IMAGE_CREATE_PREFIX:-'["画", "看", "找"]'}
|
||||
export CONVERSATION_MAX_TOKENS=${CONVERSATION_MAX_TOKENS:-"1000"}
|
||||
export SPEECH_RECOGNITION=${SPEECH_RECOGNITION:-"False"}
|
||||
export CHARACTER_DESC=${CHARACTER_DESC:-"你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"}
|
||||
export EXPIRES_IN_SECONDS=${EXPIRES_IN_SECONDS:-"3600"}
|
||||
# export OPEN_AI_API_KEY=${OPEN_AI_API_KEY:-'YOUR API KEY'}
|
||||
# export OPEN_AI_PROXY=${OPEN_AI_PROXY:-""}
|
||||
# export SINGLE_CHAT_PREFIX=${SINGLE_CHAT_PREFIX:-'["bot", "@bot"]'}
|
||||
# export SINGLE_CHAT_REPLY_PREFIX=${SINGLE_CHAT_REPLY_PREFIX:-'"[bot] "'}
|
||||
# export GROUP_CHAT_PREFIX=${GROUP_CHAT_PREFIX:-'["@bot"]'}
|
||||
# export GROUP_NAME_WHITE_LIST=${GROUP_NAME_WHITE_LIST:-'["ChatGPT测试群", "ChatGPT测试群2"]'}
|
||||
# export IMAGE_CREATE_PREFIX=${IMAGE_CREATE_PREFIX:-'["画", "看", "找"]'}
|
||||
# export CONVERSATION_MAX_TOKENS=${CONVERSATION_MAX_TOKENS:-"1000"}
|
||||
# export SPEECH_RECOGNITION=${SPEECH_RECOGNITION:-"False"}
|
||||
# export CHARACTER_DESC=${CHARACTER_DESC:-"你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"}
|
||||
# export EXPIRES_IN_SECONDS=${EXPIRES_IN_SECONDS:-"3600"}
|
||||
|
||||
# CHATGPT_ON_WECHAT_PREFIX is empty, use /app
|
||||
if [ "$CHATGPT_ON_WECHAT_PREFIX" == "" ] ; then
|
||||
@@ -38,7 +38,7 @@ if [ "$CHATGPT_ON_WECHAT_EXEC" == "" ] ; then
|
||||
fi
|
||||
|
||||
# modify content in config.json
|
||||
if [ "$OPEN_AI_API_KEY" == "YOUR API KEY" ] ; then
|
||||
if [ "$OPEN_AI_API_KEY" == "YOUR API KEY" ] || [ "$OPEN_AI_API_KEY" == "" ]; then
|
||||
echo -e "\033[31m[Warning] You need to set OPEN_AI_API_KEY before running!\033[0m"
|
||||
fi
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ def login(self, enableCmdQR=False, picDir=None, qrCallback=None,
|
||||
logger.info('Downloading QR code.')
|
||||
qrStorage = self.get_QR(enableCmdQR=enableCmdQR,
|
||||
picDir=picDir, qrCallback=qrCallback)
|
||||
logger.info('Please scan the QR code to log in.')
|
||||
# logger.info('Please scan the QR code to log in.')
|
||||
isLoggedIn = False
|
||||
while not isLoggedIn:
|
||||
status = self.check_login()
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
# entry point for online railway deployment
|
||||
from app import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
@@ -0,0 +1,7 @@
|
||||
providers = ['python']
|
||||
|
||||
[phases.setup]
|
||||
nixPkgs = ['python310']
|
||||
cmds = ['apt-get update','apt-get install -y --no-install-recommends ffmpeg espeak']
|
||||
[start]
|
||||
cmd = "python ./app.py"
|
||||
@@ -0,0 +1,30 @@
|
||||
## 插件说明
|
||||
|
||||
利用百度UNIT实现智能对话
|
||||
|
||||
- 1.解决问题:chatgpt无法处理的指令,交给百度UNIT处理如:天气,日期时间,数学运算等
|
||||
- 2.如问时间:现在几点钟,今天几号
|
||||
- 3.如问天气:明天广州天气怎么样,这个周末深圳会不会下雨
|
||||
- 4.如问数学运算:23+45=多少,100-23=多少,35转化为二进制是多少?
|
||||
|
||||
## 使用说明
|
||||
|
||||
### 获取apikey
|
||||
|
||||
在百度UNIT官网上自己创建应用,申请百度机器人,可以把预先训练好的模型导入到自己的应用中,
|
||||
|
||||
see https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087095fc10c8377aaf https://console.bce.baidu.com/ai平台申请
|
||||
|
||||
### 配置文件
|
||||
|
||||
将文件夹中`config.json.template`复制为`config.json`。
|
||||
|
||||
在其中填写百度UNIT官网上获取应用的API Key和Secret Key
|
||||
|
||||
``` json
|
||||
{
|
||||
"service_id": "s...", #"机器人ID"
|
||||
"api_key": "",
|
||||
"secret_key": ""
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,294 @@
|
||||
# encoding:utf-8
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import requests
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
import plugins
|
||||
from plugins import *
|
||||
from uuid import getnode as get_mac
|
||||
|
||||
|
||||
"""利用百度UNIT实现智能对话
|
||||
如果命中意图,返回意图对应的回复,否则返回继续交付给下个插件处理
|
||||
"""
|
||||
|
||||
|
||||
@plugins.register(name="BDunit", desc="Baidu unit bot system", version="0.1", author="jackson", desire_priority=0)
|
||||
class BDunit(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
try:
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
conf = None
|
||||
if not os.path.exists(config_path):
|
||||
raise Exception("config.json not found")
|
||||
else:
|
||||
with open(config_path, "r") as f:
|
||||
conf = json.load(f)
|
||||
self.service_id = conf["service_id"]
|
||||
self.api_key = conf["api_key"]
|
||||
self.secret_key = conf["secret_key"]
|
||||
self.access_token = self.get_token()
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
logger.info("[BDunit] inited")
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"BDunit init failed: %s, ignore " % e)
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
|
||||
if e_context['context'].type != ContextType.TEXT:
|
||||
return
|
||||
|
||||
content = e_context['context'].content
|
||||
logger.debug("[BDunit] on_handle_context. content: %s" % content)
|
||||
parsed = self.getUnit2(content)
|
||||
intent = self.getIntent(parsed)
|
||||
if intent: # 找到意图
|
||||
logger.debug("[BDunit] Baidu_AI Intent= %s", intent)
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = self.getSay(parsed)
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||
else:
|
||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
|
||||
|
||||
def get_help_text(self, **kwargs):
|
||||
help_text = "本插件会处理询问实时日期时间,天气,数学运算等问题,这些技能由您的百度智能对话UNIT决定\n"
|
||||
return help_text
|
||||
|
||||
def get_token(self):
|
||||
"""获取访问百度UUNIT 的access_token
|
||||
#param api_key: UNIT apk_key
|
||||
#param secret_key: UNIT secret_key
|
||||
Returns:
|
||||
string: access_token
|
||||
"""
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token?client_id={}&client_secret={}&grant_type=client_credentials".format(
|
||||
self.api_key, self.secret_key)
|
||||
payload = ""
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
|
||||
# print(response.text)
|
||||
return response.json()['access_token']
|
||||
|
||||
def getUnit(self, query):
|
||||
"""
|
||||
NLU 解析version 3.0
|
||||
:param query: 用户的指令字符串
|
||||
:returns: UNIT 解析结果。如果解析失败,返回 None
|
||||
"""
|
||||
|
||||
url = (
|
||||
'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token='
|
||||
+ self.access_token
|
||||
)
|
||||
request = {"query": query, "user_id": str(
|
||||
get_mac())[:32], "terminal_id": "88888"}
|
||||
body = {
|
||||
"log_id": str(uuid.uuid1()),
|
||||
"version": "3.0",
|
||||
"service_id": self.service_id,
|
||||
"session_id": str(uuid.uuid1()),
|
||||
"request": request,
|
||||
}
|
||||
try:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = requests.post(url, json=body, headers=headers)
|
||||
return json.loads(response.text)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def getUnit2(self, query):
|
||||
"""
|
||||
NLU 解析 version 2.0
|
||||
|
||||
:param query: 用户的指令字符串
|
||||
:returns: UNIT 解析结果。如果解析失败,返回 None
|
||||
"""
|
||||
url = (
|
||||
"https://aip.baidubce.com/rpc/2.0/unit/service/chat?access_token="
|
||||
+ self.access_token
|
||||
)
|
||||
request = {"query": query, "user_id": str(get_mac())[:32]}
|
||||
body = {
|
||||
"log_id": str(uuid.uuid1()),
|
||||
"version": "2.0",
|
||||
"service_id": self.service_id,
|
||||
"session_id": str(uuid.uuid1()),
|
||||
"request": request,
|
||||
}
|
||||
try:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = requests.post(url, json=body, headers=headers)
|
||||
return json.loads(response.text)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def getIntent(self, parsed):
|
||||
"""
|
||||
提取意图
|
||||
|
||||
:param parsed: UNIT 解析结果
|
||||
:returns: 意图数组
|
||||
"""
|
||||
if (
|
||||
parsed
|
||||
and "result" in parsed
|
||||
and "response_list" in parsed["result"]
|
||||
):
|
||||
try:
|
||||
return parsed["result"]["response_list"][0]["schema"]["intent"]
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
return ""
|
||||
else:
|
||||
return ""
|
||||
|
||||
def hasIntent(self, parsed, intent):
|
||||
"""
|
||||
判断是否包含某个意图
|
||||
|
||||
:param parsed: UNIT 解析结果
|
||||
:param intent: 意图的名称
|
||||
:returns: True: 包含; False: 不包含
|
||||
"""
|
||||
if (
|
||||
parsed
|
||||
and "result" in parsed
|
||||
and "response_list" in parsed["result"]
|
||||
):
|
||||
response_list = parsed["result"]["response_list"]
|
||||
for response in response_list:
|
||||
if (
|
||||
"schema" in response
|
||||
and "intent" in response["schema"]
|
||||
and response["schema"]["intent"] == intent
|
||||
):
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
def getSlots(self, parsed, intent=""):
|
||||
"""
|
||||
提取某个意图的所有词槽
|
||||
|
||||
:param parsed: UNIT 解析结果
|
||||
:param intent: 意图的名称
|
||||
:returns: 词槽列表。你可以通过 name 属性筛选词槽,
|
||||
再通过 normalized_word 属性取出相应的值
|
||||
"""
|
||||
if (
|
||||
parsed
|
||||
and "result" in parsed
|
||||
and "response_list" in parsed["result"]
|
||||
):
|
||||
response_list = parsed["result"]["response_list"]
|
||||
if intent == "":
|
||||
try:
|
||||
return parsed["result"]["response_list"][0]["schema"]["slots"]
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
return []
|
||||
for response in response_list:
|
||||
if (
|
||||
"schema" in response
|
||||
and "intent" in response["schema"]
|
||||
and "slots" in response["schema"]
|
||||
and response["schema"]["intent"] == intent
|
||||
):
|
||||
return response["schema"]["slots"]
|
||||
return []
|
||||
else:
|
||||
return []
|
||||
|
||||
def getSlotWords(self, parsed, intent, name):
|
||||
"""
|
||||
找出命中某个词槽的内容
|
||||
|
||||
:param parsed: UNIT 解析结果
|
||||
:param intent: 意图的名称
|
||||
:param name: 词槽名
|
||||
:returns: 命中该词槽的值的列表。
|
||||
"""
|
||||
slots = self.getSlots(parsed, intent)
|
||||
words = []
|
||||
for slot in slots:
|
||||
if slot["name"] == name:
|
||||
words.append(slot["normalized_word"])
|
||||
return words
|
||||
|
||||
def getSayByConfidence(self, parsed):
|
||||
"""
|
||||
提取 UNIT 置信度最高的回复文本
|
||||
|
||||
:param parsed: UNIT 解析结果
|
||||
:returns: UNIT 的回复文本
|
||||
"""
|
||||
if (
|
||||
parsed
|
||||
and "result" in parsed
|
||||
and "response_list" in parsed["result"]
|
||||
):
|
||||
response_list = parsed["result"]["response_list"]
|
||||
answer = {}
|
||||
for response in response_list:
|
||||
if (
|
||||
"schema" in response
|
||||
and "intent_confidence" in response["schema"]
|
||||
and (
|
||||
not answer
|
||||
or response["schema"]["intent_confidence"]
|
||||
> answer["schema"]["intent_confidence"]
|
||||
)
|
||||
):
|
||||
answer = response
|
||||
return answer["action_list"][0]["say"]
|
||||
else:
|
||||
return ""
|
||||
|
||||
def getSay(self, parsed, intent=""):
|
||||
"""
|
||||
提取 UNIT 的回复文本
|
||||
|
||||
:param parsed: UNIT 解析结果
|
||||
:param intent: 意图的名称
|
||||
:returns: UNIT 的回复文本
|
||||
"""
|
||||
if (
|
||||
parsed
|
||||
and "result" in parsed
|
||||
and "response_list" in parsed["result"]
|
||||
):
|
||||
response_list = parsed["result"]["response_list"]
|
||||
if intent == "":
|
||||
try:
|
||||
return response_list[0]["action_list"][0]["say"]
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
return ""
|
||||
for response in response_list:
|
||||
if (
|
||||
"schema" in response
|
||||
and "intent" in response["schema"]
|
||||
and response["schema"]["intent"] == intent
|
||||
):
|
||||
try:
|
||||
return response["action_list"][0]["say"]
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
return ""
|
||||
return ""
|
||||
else:
|
||||
return ""
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"service_id": "s...",
|
||||
"api_key": "",
|
||||
"secret_key": ""
|
||||
}
|
||||
@@ -52,7 +52,7 @@ class Dungeon(Plugin):
|
||||
if e_context['context'].type != ContextType.TEXT:
|
||||
return
|
||||
bottype = Bridge().get_bot_type("chat")
|
||||
if bottype != const.CHATGPT:
|
||||
if bottype not in (const.CHATGPT, const.OPEN_AI):
|
||||
return
|
||||
bot = Bridge().get_bot("chat")
|
||||
content = e_context['context'].content[:]
|
||||
|
||||
@@ -179,7 +179,7 @@ class Godcmd(Plugin):
|
||||
elif cmd == "id":
|
||||
ok, result = True, f"用户id=\n{user}"
|
||||
elif cmd == "reset":
|
||||
if bottype == const.CHATGPT:
|
||||
if bottype in (const.CHATGPT, const.OPEN_AI):
|
||||
bot.sessions.clear_session(session_id)
|
||||
ok, result = True, "会话已重置"
|
||||
else:
|
||||
@@ -201,7 +201,7 @@ class Godcmd(Plugin):
|
||||
load_config()
|
||||
ok, result = True, "配置已重载"
|
||||
elif cmd == "resetall":
|
||||
if bottype == const.CHATGPT:
|
||||
if bottype in (const.CHATGPT, const.OPEN_AI):
|
||||
bot.sessions.clear_all_session()
|
||||
ok, result = True, "重置所有会话成功"
|
||||
else:
|
||||
|
||||
@@ -17,15 +17,15 @@ class RolePlay():
|
||||
self.sessionid = sessionid
|
||||
self.wrapper = wrapper or "%s" # 用于包装用户输入
|
||||
self.desc = desc
|
||||
self.bot.sessions.build_session(self.sessionid, system_prompt=self.desc)
|
||||
|
||||
def reset(self):
|
||||
self.bot.sessions.clear_session(self.sessionid)
|
||||
|
||||
def action(self, user_action):
|
||||
session = self.bot.sessions.build_session(self.sessionid, self.desc)
|
||||
if session[0]['role'] == 'system' and session[0]['content'] != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置
|
||||
self.reset()
|
||||
self.bot.sessions.build_session(self.sessionid, self.desc)
|
||||
session = self.bot.sessions.build_session(self.sessionid)
|
||||
if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置
|
||||
session.set_system_prompt(self.desc)
|
||||
prompt = self.wrapper % user_action
|
||||
return prompt
|
||||
|
||||
@@ -74,7 +74,7 @@ class Role(Plugin):
|
||||
if e_context['context'].type != ContextType.TEXT:
|
||||
return
|
||||
bottype = Bridge().get_bot_type("chat")
|
||||
if bottype != const.CHATGPT:
|
||||
if bottype not in (const.CHATGPT, const.OPEN_AI):
|
||||
return
|
||||
bot = Bridge().get_bot("chat")
|
||||
content = e_context['context'].content[:]
|
||||
|
||||
+47
-5
@@ -85,14 +85,14 @@
|
||||
"remark": "引用已有数据资料,用新闻的写作风格输出主题文章。"
|
||||
},
|
||||
{
|
||||
"title": "论文1",
|
||||
"title": "论文学者",
|
||||
"description": "I want you to act as an academician. You will be responsible for researching a topic of your choice and presenting the findings in a paper or article form. Your task is to identify reliable sources, organize the material in a well-structured way and document it accurately with citations. ",
|
||||
"descn": "我希望你能作为一名学者行事。你将负责研究一个你选择的主题,并将研究结果以论文或文章的形式呈现出来。你的任务是确定可靠的来源,以结构良好的方式组织材料,并以引用的方式准确记录。",
|
||||
"wrapper": "论文主题是:\n\"%s\"",
|
||||
"remark": "根据主题撰写内容翔实、有信服力的论文。"
|
||||
},
|
||||
{
|
||||
"title": "论文2",
|
||||
"title": "论文作家",
|
||||
"description": "I want you to act as an essay writer. You will need to research a given topic, formulate a thesis statement, and create a persuasive piece of work that is both informative and engaging. ",
|
||||
"descn": "我想让你充当一名论文作家。你将需要研究一个给定的主题,制定一个论文声明,并创造一个有说服力的作品,既要有信息量,又要有吸引力。",
|
||||
"wrapper": "论文主题是:\n\"%s\"",
|
||||
@@ -107,10 +107,10 @@
|
||||
},
|
||||
{
|
||||
"title": "文本情绪分析",
|
||||
"description": "Specify the sentiment of the following text, assigning them the values of: positive, neutral or negative.",
|
||||
"descn": "请为提供的文本分析情绪,赋予它们的值为:正面、中性或负面。",
|
||||
"description": "I would like you to act as an emotion analysis expert, evaluating the emotions conveyed in the statements I provide. When I give you someone's statement, simply tell me what emotion it conveys, such as joy, sadness, anger, fear, etc. Please do not explain or evaluate the content of the statement in your answer, just briefly describe the expressed emotion.",
|
||||
"descn": "我希望你充当情感分析专家,针对我提供的发言来评估情感。当我给出某人的发言时,你只需告诉我它传达了什么情绪,例如喜悦、悲伤、愤怒、恐惧等。请在回答中不要解释或评价发言内容,只需简要地描述所表达的情绪。",
|
||||
"wrapper": "文本是:\n\"%s\"",
|
||||
"remark": "判断文本情绪:正面、中性或负面。"
|
||||
"remark": "判断文本情绪。"
|
||||
},
|
||||
{
|
||||
"title": "随机回复的疯子",
|
||||
@@ -181,6 +181,48 @@
|
||||
"descn": "我会给予你词语,请你按照我给的词构建一个知识文字世界,你是此世界的导游,在世界里一切知识都是以象征的形式表达的,你在描述经历时应当适当加入五感的描述",
|
||||
"wrapper": "词语是:\n\"%s\"",
|
||||
"remark": "用比喻的方式解释词语。"
|
||||
},
|
||||
{
|
||||
"title": "辩手",
|
||||
"description": "I want you to act as a debater. I will provide you with some topics related to current events and your task is to research both sides of the debates, present valid arguments for each side, refute opposing points of view, and draw persuasive conclusions based on evidence. Your goal is to help people come away from the discussion with increased knowledge and insight into the topic at hand. ",
|
||||
"descn": "我希望你能扮演一个辩论者的角色。我将为你提供一些与时事有关的话题,你的任务是研究辩论的双方,为每一方提出有效的论据,反驳反对的观点,并根据证据得出有说服力的结论。你的目标是帮助人们从讨论中获得更多的知识和对当前话题的洞察力。",
|
||||
"wrapper": "观点是:\n\"%s\"",
|
||||
"remark": "从正反两面分析话题。"
|
||||
},
|
||||
{
|
||||
"title": "心理学家",
|
||||
"description": "I want you to act a psychologist. i will provide you my thoughts. I want you to give me scientific suggestions that will make me feel better. my first thought, { 内心想法 }",
|
||||
"descn": "我希望你能扮演一个心理学家。我将向你提供我的想法。我希望你能给我科学的建议,使我感觉更好。",
|
||||
"wrapper": "需要诊断的资料是:\n\"%s\"",
|
||||
"remark": "心理学家。"
|
||||
},
|
||||
{
|
||||
"title": "IT 编程问题",
|
||||
"description": "I want you to act as a stackoverflow post. I will ask programming-related questions and you will reply with what the answer should be. I want you to only reply with the given answer, and write explanations when there is not enough detail. do not write explanations. When I need to tell you something in English, I will do so by putting text inside curly brackets {like this}. ",
|
||||
"descn": "我想让你充当 Stackoverflow 的帖子。我将提出与编程有关的问题,你将回答答案是什么。我希望你只回答给定的答案,在没有足够的细节时写出解释。当我需要用中文告诉你一些事情时,我会把文字放在大括号里{像这样}。",
|
||||
"wrapper":"我的问题是:\n\"%s?\"",
|
||||
"remark": "模拟编程社区来回答你的问题,并提供解决代码。"
|
||||
},
|
||||
{
|
||||
"title": "费曼学习法教练",
|
||||
"description": "I want you to act as a Feynman method tutor. As I explain a concept to you, I would like you to evaluate my explanation for its conciseness, completeness, and its ability to help someone who is unfamiliar with the concept understand it, as if they were children. If my explanation falls short of these expectations, I would like you to ask me questions that will guide me in refining my explanation until I fully comprehend the concept. Please response in Chinese. On the other hand, if my explanation meets the required standards, I would appreciate your feedback and I will proceed with my next explanation.",
|
||||
"descn": "我想让你充当一个费曼方法教练。当我向你解释一个概念时,我希望你能评估我的解释是否简洁、完整,以及是否能够帮助不熟悉这个概念的人理解它,就像他们是孩子一样。如果我的解释没有达到这些期望,我希望你能向我提出问题,引导我完善我的解释,直到我完全理解这个概念。另一方面,如果我的解释符合要求的标准,我将感谢你的反馈,我将继续进行下一次解释。",
|
||||
"wrapper": "解释是:\n\"%s\"",
|
||||
"remark": "解释概念时,判断该解释是否简洁、完整和易懂,避免陷入专家思维误区。"
|
||||
},
|
||||
{
|
||||
"title": "育儿帮手",
|
||||
"description": "你是一名育儿专家,会以幼儿园老师的方式回答2~6岁孩子提出的各种天马行空的问题。语气与口吻要生动活泼,耐心亲和;答案尽可能具体易懂,不要使用复杂词汇,尽可能少用抽象词汇;答案中要多用比喻,必须要举例说明,结合儿童动画片场景或绘本场景来解释;需要延展更多场景,不但要解释为什么,还要告诉具体行动来加深理解。",
|
||||
"descn": "你是一名育儿专家,会以幼儿园老师的方式回答2~6岁孩子提出的各种天马行空的问题。语气与口吻要生动活泼,耐心亲和;答案尽可能具体易懂,不要使用复杂词汇,尽可能少用抽象词汇;答案中要多用比喻,必须要举例说明,结合儿童动画片场景或绘本场景来解释;需要延展更多场景,不但要解释为什么,还要告诉具体行动来加深理解。",
|
||||
"wrapper": "小朋友的问题是:\n\"%s?\"",
|
||||
"remark": "小朋友有许多为什么,是什么的问题,用幼儿园老师的方式回答。"
|
||||
},
|
||||
{
|
||||
"title": "发言分析专家",
|
||||
"description": "I want you to act as a speech analysis expert. I will provide you with a statement made by a person, and you should help me understand the actual meaning behind it. Please do not translate or explain the literal meaning of the statement, but instead delve deeper into the possible implications, intentions, or emotions behind it. Provide your analysis in your response.",
|
||||
"descn": "我希望你充当一个发言分析专家。我会给你提供一个人的发言,你要帮我分析这句发言背后的实际意思。请不要翻译或解释发言的字面意义,而是深入挖掘发言背后可能的含义、目的或情感。请在回答中给出你的分析结果。",
|
||||
"wrapper": "分析这句话:\n\"%s\"",
|
||||
"remark": "分析发言的实际含义。"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"start":{
|
||||
"host" : "127.0.0.1",
|
||||
"port" : 7860
|
||||
"port" : 7860,
|
||||
"use_https" : false
|
||||
},
|
||||
"defaults": {
|
||||
"params": {
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
|
||||
具体信息,请参考[文章](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)。
|
||||
|
||||
部署运行后,保证主机能够成功访问http://127.0.0.1:7860/docs
|
||||
|
||||
请**安装**本插件的依赖包```webuiapi```
|
||||
|
||||
```
|
||||
@@ -18,6 +20,7 @@ pip install webuiapi
|
||||
|
||||
请将`config.json.template`复制为`config.json`,并修改其中的参数和规则。
|
||||
|
||||
PS: 如果修改了webui的`host`和`port`,也需要在配置文件中更改启动参数, 更多启动参数参考:https://github.com/mix1009/sdwebuiapi/blob/a1cb4c6d2f39389d6e962f0e6436f4aa74cd752c/webuiapi/webuiapi.py#L114
|
||||
### 画图请求格式
|
||||
|
||||
用户的画图请求格式为:
|
||||
@@ -85,4 +88,4 @@ pip install webuiapi
|
||||
PS: 实际参数分为两部分:
|
||||
|
||||
- 一部分是`params`,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致
|
||||
- 另一部分是`options`,指sdwebui的设置,使用的模型和vae需写在里面。它和(http://127.0.0.1:7860/sdapi/v1/options)所返回的键一致。
|
||||
- 另一部分是`options`,指sdwebui的设置,使用的模型和vae需写在里面。它和(http://127.0.0.1:7860/sdapi/v1/options )所返回的键一致。
|
||||
|
||||
+16
-3
@@ -1,3 +1,16 @@
|
||||
itchat-uos==1.5.0.dev0
|
||||
openai
|
||||
wechaty
|
||||
openai>=0.27.2
|
||||
baidu_aip>=4.16.10
|
||||
gTTS>=2.3.1
|
||||
HTMLParser>=0.0.2
|
||||
pydub>=0.25.1
|
||||
PyQRCode>=1.2.1
|
||||
pysilk>=0.0.1
|
||||
pysilk_mod>=1.6.0
|
||||
pyttsx3>=2.90
|
||||
requests>=2.28.2
|
||||
webuiapi>=0.6.2
|
||||
wechaty>=0.10.7
|
||||
wechaty_puppet>=0.4.23
|
||||
chardet>=5.1.0
|
||||
SpeechRecognition
|
||||
tiktoken>=0.3.2
|
||||
@@ -0,0 +1,70 @@
|
||||
import wave
|
||||
import pysilk
|
||||
from pydub import AudioSegment
|
||||
|
||||
|
||||
def get_pcm_from_wav(wav_path):
|
||||
"""
|
||||
从 wav 文件中读取 pcm
|
||||
|
||||
:param wav_path: wav 文件路径
|
||||
:returns: pcm 数据
|
||||
"""
|
||||
wav = wave.open(wav_path, "rb")
|
||||
return wav.readframes(wav.getnframes())
|
||||
|
||||
|
||||
def mp3_to_wav(mp3_path, wav_path):
|
||||
"""
|
||||
把mp3格式转成pcm文件
|
||||
"""
|
||||
audio = AudioSegment.from_mp3(mp3_path)
|
||||
audio.export(wav_path, format="wav")
|
||||
|
||||
def any_to_wav(any_path, wav_path):
|
||||
"""
|
||||
把任意格式转成wav文件
|
||||
"""
|
||||
if any_path.endswith('.wav'):
|
||||
return
|
||||
if any_path.endswith('.sil') or any_path.endswith('.silk') or any_path.endswith('.slk'):
|
||||
return sil_to_wav(any_path, wav_path)
|
||||
audio = AudioSegment.from_file(any_path)
|
||||
audio.export(wav_path, format="wav")
|
||||
|
||||
def pcm_to_silk(pcm_path, silk_path):
|
||||
"""
|
||||
wav 文件转成 silk
|
||||
return 声音长度,毫秒
|
||||
"""
|
||||
audio = AudioSegment.from_wav(pcm_path)
|
||||
wav_data = audio.raw_data
|
||||
silk_data = pysilk.encode(
|
||||
wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate)
|
||||
with open(silk_path, "wb") as f:
|
||||
f.write(silk_data)
|
||||
return audio.duration_seconds * 1000
|
||||
|
||||
|
||||
def mp3_to_sil(mp3_path, silk_path):
|
||||
"""
|
||||
mp3 文件转成 silk
|
||||
return 声音长度,毫秒
|
||||
"""
|
||||
audio = AudioSegment.from_mp3(mp3_path)
|
||||
wav_data = audio.raw_data
|
||||
silk_data = pysilk.encode(
|
||||
wav_data, data_rate=audio.frame_rate, sample_rate=audio.frame_rate)
|
||||
# Save the silk file
|
||||
with open(silk_path, "wb") as f:
|
||||
f.write(silk_data)
|
||||
return audio.duration_seconds * 1000
|
||||
|
||||
|
||||
def sil_to_wav(silk_path, wav_path, rate: int = 24000):
|
||||
"""
|
||||
silk 文件转 wav
|
||||
"""
|
||||
wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate)
|
||||
with open(wav_path, "wb") as f:
|
||||
f.write(wav_data)
|
||||
@@ -0,0 +1,55 @@
|
||||
## 说明
|
||||
百度语音识别与合成参数说明
|
||||
百度语音依赖,经常会出现问题,可能就是缺少依赖:
|
||||
pip install baidu-aip
|
||||
pip install pydub
|
||||
pip install pysilk
|
||||
还有ffmpeg,不同系统安装方式不同
|
||||
|
||||
系统中收到的语音文件为mp3格式(wx)或者sil格式(wxy),如果要识别需要转换为pcm格式,转换后的文件为16k采样率,单声道,16bit的pcm文件
|
||||
发送时又需要(wx)转换为mp3格式,转换后的文件为16k采样率,单声道,16bit的pcm文件,(wxy)转换为sil格式,还要计算声音长度,发送时需要带上声音长度
|
||||
这些事情都在audio_convert.py中封装了,直接调用即可
|
||||
|
||||
|
||||
参数说明
|
||||
识别参数
|
||||
https://ai.baidu.com/ai-doc/SPEECH/Vk38lxily
|
||||
合成参数
|
||||
https://ai.baidu.com/ai-doc/SPEECH/Gk38y8lzk
|
||||
|
||||
## 使用说明
|
||||
分两个地方配置
|
||||
|
||||
1、对于def voiceToText(self, filename)函数中调用的百度语音识别API,中接口调用asr(参数)这个配置见CHATGPT-ON-WECHAT工程目录下的`config.json`文件和config.py文件。
|
||||
参数 可需 描述
|
||||
app_id 必填 应用的APPID
|
||||
api_key 必填 应用的APIKey
|
||||
secret_key 必填 应用的SecretKey
|
||||
dev_pid 必填 语言选择,填写语言对应的dev_pid值
|
||||
|
||||
2、对于def textToVoice(self, text)函数中调用的百度语音合成API,中接口调用synthesis(参数)在本目录下的`config.json`文件中进行配置。
|
||||
参数 可需 描述
|
||||
tex 必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节
|
||||
lan 必填 固定值zh。语言选择,目前只有中英文混合模式,填写固定值zh
|
||||
spd 选填 语速,取值0-15,默认为5中语速
|
||||
pit 选填 音调,取值0-15,默认为5中语调
|
||||
vol 选填 音量,取值0-15,默认为5中音量(取值为0时为音量最小值,并非为无声)
|
||||
per(基础音库) 选填 度小宇=1,度小美=0,度逍遥(基础)=3,度丫丫=4
|
||||
per(精品音库) 选填 度逍遥(精品)=5003,度小鹿=5118,度博文=106,度小童=110,度小萌=111,度米朵=103,度小娇=5
|
||||
aue 选填 3为mp3格式(默认); 4为pcm-16k;5为pcm-8k;6为wav(内容同pcm-16k); 注意aue=4或者6是语音识别要求的格式,但是音频内容不是语音识别要求的自然人发音,所以识别效果会受影响。
|
||||
|
||||
关于per参数的说明,注意您购买的哪个音库,就填写哪个音库的参数,否则会报错。如果您购买的是基础音库,那么per参数只能填写0到4,如果您购买的是精品音库,那么per参数只能填写5003,5118,106,110,111,103,5其他的都会报错。
|
||||
### 配置文件
|
||||
|
||||
将文件夹中`config.json.template`复制为`config.json`。
|
||||
|
||||
``` json
|
||||
{
|
||||
"lang": "zh",
|
||||
"ctp": 1,
|
||||
"spd": 5,
|
||||
"pit": 5,
|
||||
"vol": 5,
|
||||
"per": 0
|
||||
}
|
||||
```
|
||||
+66
-12
@@ -2,35 +2,89 @@
|
||||
"""
|
||||
baidu voice service
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from aip import AipSpeech
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from voice.voice import Voice
|
||||
from voice.audio_convert import get_pcm_from_wav
|
||||
from config import conf
|
||||
"""
|
||||
百度的语音识别API.
|
||||
dev_pid:
|
||||
- 1936: 普通话远场
|
||||
- 1536:普通话(支持简单的英文识别)
|
||||
- 1537:普通话(纯中文识别)
|
||||
- 1737:英语
|
||||
- 1637:粤语
|
||||
- 1837:四川话
|
||||
要使用本模块, 首先到 yuyin.baidu.com 注册一个开发者账号,
|
||||
之后创建一个新应用, 然后在应用管理的"查看key"中获得 API Key 和 Secret Key
|
||||
然后在 config.json 中填入这两个值, 以及 app_id, dev_pid
|
||||
"""
|
||||
|
||||
|
||||
class BaiduVoice(Voice):
|
||||
APP_ID = conf().get('baidu_app_id')
|
||||
API_KEY = conf().get('baidu_api_key')
|
||||
SECRET_KEY = conf().get('baidu_secret_key')
|
||||
client = AipSpeech(APP_ID, API_KEY, SECRET_KEY)
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __init__(self):
|
||||
try:
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
bconf = None
|
||||
if not os.path.exists(config_path): #如果没有配置文件,创建本地配置文件
|
||||
bconf = { "lang": "zh", "ctp": 1, "spd": 5,
|
||||
"pit": 5, "vol": 5, "per": 0}
|
||||
with open(config_path, "w") as fw:
|
||||
json.dump(bconf, fw, indent=4)
|
||||
else:
|
||||
with open(config_path, "r") as fr:
|
||||
bconf = json.load(fr)
|
||||
|
||||
self.app_id = conf().get('baidu_app_id')
|
||||
self.api_key = conf().get('baidu_api_key')
|
||||
self.secret_key = conf().get('baidu_secret_key')
|
||||
self.dev_id = conf().get('baidu_dev_pid')
|
||||
self.lang = bconf["lang"]
|
||||
self.ctp = bconf["ctp"]
|
||||
self.spd = bconf["spd"]
|
||||
self.pit = bconf["pit"]
|
||||
self.vol = bconf["vol"]
|
||||
self.per = bconf["per"]
|
||||
|
||||
self.client = AipSpeech(self.app_id, self.api_key, self.secret_key)
|
||||
except Exception as e:
|
||||
logger.warn("BaiduVoice init failed: %s, ignore " % e)
|
||||
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
pass
|
||||
# 识别本地文件
|
||||
logger.debug('[Baidu] voice file name={}'.format(voice_file))
|
||||
pcm = get_pcm_from_wav(voice_file)
|
||||
res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id})
|
||||
if res["err_no"] == 0:
|
||||
logger.info("百度语音识别到了:{}".format(res["result"]))
|
||||
text = "".join(res["result"])
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
else:
|
||||
logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
|
||||
if res["err_msg"] == "request pv too much":
|
||||
logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
|
||||
reply = Reply(ReplyType.ERROR,
|
||||
"百度语音识别出错了;{0}".format(res["err_msg"]))
|
||||
return reply
|
||||
|
||||
def textToVoice(self, text):
|
||||
result = self.client.synthesis(text, 'zh', 1, {
|
||||
'spd': 5, 'pit': 5, 'vol': 5, 'per': 111
|
||||
})
|
||||
result = self.client.synthesis(text, self.lang, self.ctp, {
|
||||
'spd': self.spd, 'pit': self.pit, 'vol': self.vol, 'per': self.per})
|
||||
if not isinstance(result, dict):
|
||||
fileName = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
|
||||
with open(fileName, 'wb') as f:
|
||||
f.write(result)
|
||||
logger.info('[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
|
||||
logger.info(
|
||||
'[Baidu] textToVoice text={} voice file name={}'.format(text, fileName))
|
||||
reply = Reply(ReplyType.VOICE, fileName)
|
||||
else:
|
||||
logger.error('[Baidu] textToVoice error={}'.format(result))
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"lang": "zh",
|
||||
"ctp": 1,
|
||||
"spd": 5,
|
||||
"pit": 5,
|
||||
"vol": 5,
|
||||
"per": 0
|
||||
}
|
||||
@@ -3,12 +3,10 @@
|
||||
google voice service
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
import subprocess
|
||||
import time
|
||||
from bridge.reply import Reply, ReplyType
|
||||
import speech_recognition
|
||||
import pyttsx3
|
||||
from gtts import gTTS
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from voice.voice import Voice
|
||||
@@ -16,22 +14,12 @@ from voice.voice import Voice
|
||||
|
||||
class GoogleVoice(Voice):
|
||||
recognizer = speech_recognition.Recognizer()
|
||||
engine = pyttsx3.init()
|
||||
|
||||
def __init__(self):
|
||||
# 语速
|
||||
self.engine.setProperty('rate', 125)
|
||||
# 音量
|
||||
self.engine.setProperty('volume', 1.0)
|
||||
# 0为男声,1为女声
|
||||
voices = self.engine.getProperty('voices')
|
||||
self.engine.setProperty('voice', voices[1].id)
|
||||
pass
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
new_file = voice_file.replace('.mp3', '.wav')
|
||||
subprocess.call('ffmpeg -i ' + voice_file +
|
||||
' -acodec pcm_s16le -ac 1 -ar 16000 ' + new_file, shell=True)
|
||||
with speech_recognition.AudioFile(new_file) as source:
|
||||
with speech_recognition.AudioFile(voice_file) as source:
|
||||
audio = self.recognizer.record(source)
|
||||
try:
|
||||
text = self.recognizer.recognize_google(audio, language='zh-CN')
|
||||
@@ -46,12 +34,12 @@ class GoogleVoice(Voice):
|
||||
return reply
|
||||
def textToVoice(self, text):
|
||||
try:
|
||||
textFile = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
|
||||
self.engine.save_to_file(text, textFile)
|
||||
self.engine.runAndWait()
|
||||
mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
|
||||
tts = gTTS(text=text, lang='zh')
|
||||
tts.save(mp3File)
|
||||
logger.info(
|
||||
'[Google] textToVoice text={} voice file name={}'.format(text, textFile))
|
||||
reply = Reply(ReplyType.VOICE, textFile)
|
||||
'[Google] textToVoice text={} voice file name={}'.format(text, mp3File))
|
||||
reply = Reply(ReplyType.VOICE, mp3File)
|
||||
except Exception as e:
|
||||
reply = Reply(ReplyType.ERROR, str(e))
|
||||
finally:
|
||||
|
||||
@@ -28,6 +28,3 @@ class OpenaiVoice(Voice):
|
||||
reply = Reply(ReplyType.ERROR, str(e))
|
||||
finally:
|
||||
return reply
|
||||
|
||||
def textToVoice(self, text):
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
|
||||
"""
|
||||
pytts voice service (offline)
|
||||
"""
|
||||
|
||||
import time
|
||||
import pyttsx3
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from voice.voice import Voice
|
||||
|
||||
|
||||
class PyttsVoice(Voice):
|
||||
engine = pyttsx3.init()
|
||||
|
||||
def __init__(self):
|
||||
# 语速
|
||||
self.engine.setProperty('rate', 125)
|
||||
# 音量
|
||||
self.engine.setProperty('volume', 1.0)
|
||||
for voice in self.engine.getProperty('voices'):
|
||||
if "Chinese" in voice.name:
|
||||
self.engine.setProperty('voice', voice.id)
|
||||
|
||||
def textToVoice(self, text):
|
||||
try:
|
||||
mp3File = TmpDir().path() + '语音回复_' + str(int(time.time())) + '.mp3'
|
||||
self.engine.save_to_file(text, mp3File)
|
||||
self.engine.runAndWait()
|
||||
logger.info(
|
||||
'[Pytts] textToVoice text={} voice file name={}'.format(text, mp3File))
|
||||
reply = Reply(ReplyType.VOICE, mp3File)
|
||||
except Exception as e:
|
||||
reply = Reply(ReplyType.ERROR, str(e))
|
||||
finally:
|
||||
return reply
|
||||
@@ -17,4 +17,7 @@ def create_voice(voice_type):
|
||||
elif voice_type == 'openai':
|
||||
from voice.openai.openai_voice import OpenaiVoice
|
||||
return OpenaiVoice()
|
||||
elif voice_type == 'pytts':
|
||||
from voice.pytts.pytts_voice import PyttsVoice
|
||||
return PyttsVoice()
|
||||
raise RuntimeError
|
||||
|
||||
Reference in New Issue
Block a user