Compare commits

...

58 Commits

Author SHA1 Message Date
zhayujie a64d7c42b1 fix: xunfei ws error log 2023-08-26 11:46:01 +08:00
zhayujie 36b6cc58bf fix: on_close params 2023-08-26 11:37:27 +08:00
zhayujie 5ac8a257e7 fix: add gpt-3.5-turbo in model_list 2023-08-26 10:50:31 +08:00
zhayujie 74119d0372 fix: websocket version 2023-08-25 23:57:59 +08:00
zhayujie 4e162c73e5 fix: update websocket version 2023-08-25 23:10:47 +08:00
zhayujie 5ff753a492 feat: add global model check 2023-08-25 17:26:40 +08:00
zhayujie 89400630c0 fix: xunfei client bug 2023-08-25 16:55:32 +08:00
zhayujie 3899c0cfe3 Merge pull request #1371 from uezhenxiang2023/Peter
add ElevenLabs TTS to voice factory
2023-08-25 16:15:18 +08:00
zhayujie a086f1989f feat: add xunfei spark bot 2023-08-25 16:06:55 +08:00
zhayujie 1171b04e93 fix: wenxin token discard bug 2023-08-25 12:24:16 +08:00
uezhenxiang2023 c55d81825a Merge branch 'zhayujie:master' into Peter 2023-08-25 12:12:06 +08:00
zhayujie 2dcd026e9f logs: add baidu reply log 2023-08-25 11:19:00 +08:00
zhayujie cdf8609d24 Merge pull request #1360 from zyqfork/master
dockerfile fallback debian11,fix azure cognitiveservices speech error
2023-08-25 01:24:34 +08:00
zhayujie 36580c5f7f Merge pull request #1363 from iRedScarf/master
Put the default temperature setting into config.json
2023-08-25 01:24:02 +08:00
zhayujie 1cff2521f4 fix: add web.py and linkai base url 2023-08-22 11:09:01 +08:00
uezhenxiang2023 db4998a56b replace requests with elevenlabs for audio generation 2023-08-20 10:58:26 +08:00
uezhenxiang2023 acbd506568 add ElevenLabs TTS to voice factory 2023-08-19 11:20:47 +08:00
eks 0cf8e3be73 Merge branch 'zhayujie:master' into master 2023-08-16 16:54:34 +08:00
zhayujie 2473334dfc fix: channel send compatibility and add log 2023-08-14 23:09:51 +08:00
eks 1ff72d1d37 Merge branch 'zhayujie:master' into master 2023-08-11 13:50:11 +08:00
eks 241fad5524 Update config-template.json
Put the default temperature value into config.json
2023-08-11 13:49:47 +08:00
zouyq 1b48cea50a dockerfile fallback debian11,fix azure cognitiveservices speech error
Python 3.10-slim is based on Debian 12, where using Azure text-to-speech may raise an error: the Speech SDK does not currently support OpenSSL 3.0, which is the default version in Ubuntu 22.04 and Debian 12
2023-08-10 17:39:25 +08:00
zhayujie 88bf345b91 docs: update plugin README 2023-08-08 17:03:18 +08:00
zhayujie ab4ff3d1a3 config: reduce the config of baidu-wenxin 2023-08-08 16:04:25 +08:00
zhayujie 3502e0d643 Merge pull request #1336 from kevin808/master
Add the Baidu Wenxin Yiyan (ERNIE Bot) API
2023-08-08 15:46:47 +08:00
zhayujie 995894d3aa Merge branch 'master' into master 2023-08-08 15:46:07 +08:00
zhayujie 4da8714124 Merge pull request #1358 from zhayujie/feat-1.3.5
feat: add midjourney variation and reset
2023-08-08 11:21:35 +08:00
zhayujie 6b247ae880 feat: add midjourney variation and reset 2023-08-07 19:14:09 +08:00
zhayujie 176941ea3b Merge pull request #1357 from zhayujie/feat-1.3.5
feat: add plugin instructions and fix some issues
2023-08-07 14:44:03 +08:00
zhayujie 5176b56d3b fix: global plugin read encoding 2023-08-07 14:42:24 +08:00
zhayujie 8abf18ab25 feat: add knowledge base and midjourney switch instruction 2023-08-06 17:57:07 +08:00
zhayujie 395edbd9f4 fix: only filter messages sent by the bot itself in private chat 2023-08-06 16:02:02 +08:00
zhayujie 2386eb8fc2 fix: unable to use plugin when group nickname is set 2023-08-06 15:44:48 +08:00
zhayujie 68208f82a0 docs: update README.md 2023-08-01 00:08:39 +08:00
zhayujie ca916b7ce5 fix: default to fast mode 2023-07-31 21:40:50 +08:00
zhayujie 01e02934da Merge pull request #1334 from zyqfork/master
azure api add api-version https://learn.microsoft.com/zh-cn/azure/ai-serv…
2023-07-31 18:40:06 +08:00
zhayujie c81a79f7b9 Merge pull request #1104 from mari1995/feat_my_msg
feat: replying to messages from the phone does not trigger the bot
2023-07-31 18:02:41 +08:00
zhayujie 1133648bf6 Merge branch 'master' of github.com:zhayujie/chatgpt-on-wechat 2023-07-31 17:58:06 +08:00
zhayujie e05bc541d7 Merge pull request #1346 from befantasy/patch-1
Update keyword.py: add support for replying with images
2023-07-31 17:53:46 +08:00
zhayujie d689d20482 docs: update README.md 2023-07-31 17:52:05 +08:00
zhayujie 39dd99b272 Merge pull request #1343 from zhayujie/feat-1.3.4
feat: add midjourney and app manager plugin
2023-07-31 17:15:22 +08:00
zhayujie cda21acb43 feat: use new linkai completion api 2023-07-31 16:11:33 +08:00
zhayujie 9bd7d09f20 fix: remove relax mode temporarily 2023-07-31 14:42:50 +08:00
zhayujie b22994c2d2 fix: some image bug 2023-07-30 19:55:56 +08:00
zhayujie e027286b6d fix: midjourney check task thread 2023-07-30 15:16:19 +08:00
befantasy d6e16995e0 Update keyword.py: add support for replying with images
Content starting with http/https and ending with .jpg/.jpeg/.png/.gif is recognized as a URL and automatically sent as an image.
2023-07-30 14:40:07 +08:00
zhayujie 782bff3a51 fix: add debug log 2023-07-29 12:22:45 +08:00
zhayujie de26dc0597 fix: fast mode and relax mode checkout 2023-07-28 18:50:21 +08:00
zhayujie 233b24ab0f feat: add global admin config 2023-07-28 16:33:41 +08:00
zhayujie 2f9e5b1219 feat: check app_code dynamically 2023-07-28 12:40:06 +08:00
zhayujie dd36b8b150 config: add config template 2023-07-27 21:29:50 +08:00
zhayujie f81ac31fe1 feat: add linkai plugin to support midjourney and distinguish app between groups 2023-07-27 21:21:36 +08:00
Kevin Li 24b63bc5bd Add Baidu access token validation 2023-07-25 11:11:02 +08:00
Kevin Li 1817a972c6 Add Baidu Wenxin Bot 2023-07-25 09:52:47 +08:00
zyqcn@live.com 74a253f521 azure api add api-version:https://learn.microsoft.com/zh-cn/azure/ai-services/openai/reference 2023-07-24 16:28:05 +08:00
SSMario 4dbc54fa15 Revert "feat: add ElevenLabs"
This reverts commit 1d4ff796d7.
2023-05-16 12:00:05 +08:00
SSMario 1d4ff796d7 feat: add ElevenLabs 2023-05-16 11:50:54 +08:00
SSMario 44cb54a9ea feat: replying to messages from the phone does not trigger the bot 2023-05-16 09:38:38 +08:00
32 changed files with 1358 additions and 72 deletions
+6 -1
@@ -1,6 +1,8 @@
.DS_Store
.idea
.vscode
.venv
.vs
.wechaty/
__pycache__/
venv*
@@ -22,6 +24,9 @@ plugins/**/
!plugins/tool
!plugins/banwords
!plugins/banwords/**/
plugins/banwords/__pycache__
plugins/banwords/lib/__pycache__
!plugins/hello
!plugins/role
!plugins/keyword
!plugins/linkai
+20 -7
@@ -5,11 +5,12 @@
The latest version supports the following features:
- [x] **Multi-channel deployment:** multiple fully-featured deployment options; personal WeChat, WeChat Official Accounts, and WeChat Work applications are currently supported
- [x] **Basic chat:** smart replies in private and group chats, with multi-turn conversation context memory; supports the GPT-3, GPT-3.5, and GPT-4 models
- [x] **Basic chat:** smart replies in private and group chats, with multi-turn conversation context memory; supports GPT-3, GPT-3.5, GPT-4, Wenxin Yiyan, and Xunfei Spark
- [x] **Speech recognition:** recognizes voice messages and replies with text or voice; supports azure, baidu, google, openai and other speech models
- [x] **Image generation:** supports image generation and image-to-image (e.g. photo restoration); DALL-E, stable diffusion, and replicate models available
- [x] **Image generation:** supports image generation and image-to-image (e.g. photo restoration); DALL-E, stable diffusion, replicate, and midjourney models available
- [x] **Rich plugins:** supports personalized plugin extensions; multi-role switching, text adventure, banned-word filtering, and chat summarization plugins are already implemented
- [X] **Tools:** interacts with the operating system and the internet; supports up-to-date information search, math, weather and news queries, and web-page summarization, based on [chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)
- [x] **Knowledge base:** build a dedicated bot by uploading knowledge-base files; usable as a digital twin, domain knowledge base, or smart customer service, based on [LinkAI](https://chat.link-ai.tech/console)
> More applications are welcome; refer to the [Terminal code](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py) to implement the receive/send logic. New plugins are also welcome, see the [plugin docs](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins).
@@ -27,7 +28,9 @@ Demo made by [Visionn](https://www.wangpc.cc/)
# Changelog
>**2023.06.12** Integrated the [LinkAI](https://chat.link-ai.tech/console) platform: create a personal knowledge base online and connect it to WeChat, Official Accounts, and WeChat Work. See the [integration docs](https://link-ai.tech/platform/link-app/wechat).
>**2023.08.08** Integrated the Baidu Wenxin Yiyan model; Midjourney drawing is supported via the [plugin](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai)
>**2023.06.12** Integrated the [LinkAI](https://chat.link-ai.tech/console) platform: create a personal knowledge base online and connect it to WeChat, Official Accounts, and WeChat Work to build a dedicated customer-service bot. See the [integration docs](https://link-ai.tech/platform/link-app/wechat).
>**2023.04.26** Supports WeChat Work application deployment, compatible with plugins and with voice and image interaction; an ideal personal assistant. [Docs](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatcom/README.md). (contributed by [@lanvent](https://github.com/lanvent) in [#944](https://github.com/zhayujie/chatgpt-on-wechat/pull/944))
@@ -110,8 +113,8 @@ pip3 install azure-cognitiveservices-speech
# Example config.json contents
{
"open_ai_api_key": "YOUR API KEY", # the OpenAI API key created above
"model": "gpt-3.5-turbo", # model name; when use_azure_chatgpt is true, this is the Azure model deployment name
"proxy": "127.0.0.1:7890", # proxy client ip and port
"model": "gpt-3.5-turbo", # model name; supports gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
"proxy": "", # proxy client ip and port; required when a proxy is needed in mainland China, e.g. "127.0.0.1:7890"
"single_chat_prefix": ["bot", "@bot"], # in private chat, text must contain this prefix to trigger a bot reply
"single_chat_reply_prefix": "[bot] ", # prefix added to automatic replies in private chat, to distinguish the bot from a human
"group_chat_prefix": ["@bot"], # in group chat, messages containing this prefix trigger a bot reply
@@ -123,9 +126,13 @@ pip3 install azure-cognitiveservices-speech
"group_speech_recognition": false, # whether to enable group speech recognition
"use_azure_chatgpt": false, # whether to use the Azure ChatGPT service instead of the openai ChatGPT service; when true, open_ai_api_base must be set, e.g. https://xxx.openai.azure.com/
"azure_deployment_id": "", # model deployment name when using Azure ChatGPT
"azure_api_version": "", # API version when using Azure ChatGPT
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # persona description
# subscription message; fill in for the Official Account and WeChat Work channels; it is sent automatically on subscription; placeholders are supported: currently {trigger_prefix}, which is replaced with the bot trigger word at runtime
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。"
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
"use_linkai": false, # whether to use the LinkAI API; off by default; when on, the service is reachable from mainland China and enables the knowledge base and MJ
"linkai_api_key": "", # LinkAI API key
"linkai_app_code": "" # LinkAI app code
}
```
**Configuration notes:**
@@ -150,7 +157,7 @@ pip3 install azure-cognitiveservices-speech
**4. Other settings**
+ `model`: model name; currently supports `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (the gpt-4 api is not yet fully open and can be used once your application is approved)
+ `model`: model name; currently supports `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`, `wenxin` (the gpt-4 api is not yet fully open and can be used once your application is approved)
+ `temperature`, `frequency_penalty`, `presence_penalty`: Chat API parameters; see the [official OpenAI docs](https://platform.openai.com/docs/api-reference/chat)
+ `proxy`: since the `openai` API is currently unreachable from mainland China, configure your proxy client's address; see [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
+ image generation requires an extra keyword prefix on top of the private/group trigger conditions; see the `image_create_prefix` setting
@@ -162,6 +169,12 @@ pip3 install azure-cognitiveservices-speech
+ `character_desc` holds a message you say to the bot; it remembers it as its persona, and you can customize any personality for it (see this [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43) for more on conversation context)
+ `subscribe_msg`: subscription message for the Official Account and WeChat Work channels, sent automatically on subscription; placeholders are supported: currently {trigger_prefix}, which is replaced with the bot trigger word at runtime
**5. LinkAI settings (optional)**
+ `use_linkai`: whether to use the LinkAI API; when on, the service is reachable from mainland China and enables the knowledge base and `Midjourney` drawing; see the [docs](https://link-ai.tech/platform/link-app/wechat)
+ `linkai_api_key`: LinkAI API key, created in the [console](https://chat.link-ai.tech/console/interface)
+ `linkai_app_code`: LinkAI app code, optional
**This document may not always be up to date; all currently available options are listed in [`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py).**
## Running
+104
@@ -0,0 +1,104 @@
# encoding:utf-8
import requests, json
from bot.bot import Bot
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
BAIDU_API_KEY = conf().get("baidu_wenxin_api_key")
BAIDU_SECRET_KEY = conf().get("baidu_wenxin_secret_key")
class BaiduWenxinBot(Bot):
def __init__(self):
super().__init__()
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("baidu_wenxin_model") or "eb-instant")
def reply(self, query, context=None):
# acquire reply content
if context and context.type:
if context.type == ContextType.TEXT:
logger.info("[BAIDU] query={}".format(query))
session_id = context["session_id"]
reply = None
if query == "#清除记忆":
self.sessions.clear_session(session_id)
reply = Reply(ReplyType.INFO, "记忆已清除")
elif query == "#清除所有":
self.sessions.clear_all_session()
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
else:
session = self.sessions.session_query(query, session_id)
result = self.reply_text(session)
total_tokens, completion_tokens, reply_content = (
result["total_tokens"],
result["completion_tokens"],
result["content"],
)
logger.debug(
"[BAIDU] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content, completion_tokens)
)
if total_tokens == 0:
reply = Reply(ReplyType.ERROR, reply_content)
else:
self.sessions.session_reply(reply_content, session_id, total_tokens)
reply = Reply(ReplyType.TEXT, reply_content)
return reply
elif context.type == ContextType.IMAGE_CREATE:
ok, retstring = self.create_img(query, 0)
reply = None
if ok:
reply = Reply(ReplyType.IMAGE_URL, retstring)
else:
reply = Reply(ReplyType.ERROR, retstring)
return reply
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
try:
logger.info("[BAIDU] model={}".format(session.model))
access_token = self.get_access_token()
if access_token == 'None':
logger.warn("[BAIDU] access token 获取失败")
return {
"total_tokens": 0,
"completion_tokens": 0,
"content": 0,
}
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + session.model + "?access_token=" + access_token
headers = {
'Content-Type': 'application/json'
}
payload = {'messages': session.messages}
response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
response_text = json.loads(response.text)
logger.info(f"[BAIDU] response text={response_text}")
res_content = response_text["result"]
total_tokens = response_text["usage"]["total_tokens"]
completion_tokens = response_text["usage"]["completion_tokens"]
logger.info("[BAIDU] reply={}".format(res_content))
return {
"total_tokens": total_tokens,
"completion_tokens": completion_tokens,
"content": res_content,
}
except Exception as e:
need_retry = retry_count < 2
logger.warn("[BAIDU] Exception: {}".format(e))
need_retry = False
self.sessions.clear_session(session.session_id)
result = {"completion_tokens": 0, "content": "出错了: {}".format(e)}
return result
def get_access_token(self):
"""
Generate an auth signature (access token) from the AK/SK pair
:return: access_token, or None on error
"""
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": BAIDU_API_KEY, "client_secret": BAIDU_SECRET_KEY}
return str(requests.post(url, params=params).json().get("access_token"))
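The access-token flow in `get_access_token` above uses Baidu's client-credentials grant. A minimal standalone sketch of the request construction and response handling (function names here are illustrative, not from the repo):

```python
import json
from urllib.parse import urlencode

TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"

def build_token_url(api_key, secret_key):
    # Client-credentials grant, as in get_access_token above
    params = {
        "grant_type": "client_credentials",
        "client_id": api_key,
        "client_secret": secret_key,
    }
    return TOKEN_URL + "?" + urlencode(params)

def parse_token_response(body):
    # Returns the access_token, or None when the response carries an error
    return json.loads(body).get("access_token")
```

Note that the original compares the token to the string `'None'` because it wraps the result in `str(...)`; returning `None` directly, as sketched here, avoids that stringly-typed check.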
+53
@@ -0,0 +1,53 @@
from bot.session_manager import Session
from common.log import logger
"""
e.g. [
{"role": "user", "content": "Who won the world series in 2020?"},
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
{"role": "user", "content": "Where was it played?"}
]
"""
class BaiduWenxinSession(Session):
def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
super().__init__(session_id, system_prompt)
self.model = model
# Baidu Wenxin does not support a system prompt
# self.reset()
def discard_exceeding(self, max_tokens, cur_tokens=None):
precise = True
try:
cur_tokens = self.calc_tokens()
except Exception as e:
precise = False
if cur_tokens is None:
raise e
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
while cur_tokens > max_tokens:
if len(self.messages) >= 2:
self.messages.pop(0)
self.messages.pop(0)
else:
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
break
if precise:
cur_tokens = self.calc_tokens()
else:
cur_tokens = cur_tokens - max_tokens
return cur_tokens
def calc_tokens(self):
return num_tokens_from_messages(self.messages, self.model)
def num_tokens_from_messages(messages, model):
"""Returns the number of tokens used by a list of messages."""
tokens = 0
for msg in messages:
# The official token-counting rule is not yet clear: roughly "number of Chinese characters + other-language word count x 1.3"
# For now, estimate roughly by character count; this only skews the decision of when to discard history and does not affect normal use
tokens += len(msg["content"])
return tokens
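The session trimming above pairs the character-count estimator with `discard_exceeding`: the oldest user/assistant pair is dropped until the estimate fits. A minimal standalone sketch of that interplay (names shortened for illustration):

```python
def estimate_tokens(messages):
    # Character count as a rough token proxy, matching the
    # num_tokens_from_messages fallback above
    return sum(len(m["content"]) for m in messages)

def trim_history(messages, max_tokens):
    # Drop the oldest user/assistant pair until the estimate fits,
    # as discard_exceeding does in the happy path
    cur = estimate_tokens(messages)
    while cur > max_tokens and len(messages) >= 2:
        messages.pop(0)
        messages.pop(0)
        cur = estimate_tokens(messages)
    return cur
```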
+9 -8
@@ -11,31 +11,32 @@ def create_bot(bot_type):
:return: bot instance
"""
if bot_type == const.BAIDU:
# Baidu Unit dialogue API
from bot.baidu.baidu_unit_bot import BaiduUnitBot
return BaiduUnitBot()
# Replace Baidu Unit with the Baidu Wenxin Qianfan dialogue API
# from bot.baidu.baidu_unit_bot import BaiduUnitBot
# return BaiduUnitBot()
from bot.baidu.baidu_wenxin import BaiduWenxinBot
return BaiduWenxinBot()
elif bot_type == const.CHATGPT:
# ChatGPT web API
from bot.chatgpt.chat_gpt_bot import ChatGPTBot
return ChatGPTBot()
elif bot_type == const.OPEN_AI:
# official OpenAI chat completion API
from bot.openai.open_ai_bot import OpenAIBot
return OpenAIBot()
elif bot_type == const.CHATGPTONAZURE:
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
return AzureChatGPTBot()
elif bot_type == const.XUNFEI:
from bot.xunfei.xunfei_spark_bot import XunFeiBot
return XunFeiBot()
elif bot_type == const.LINKAI:
from bot.linkai.link_ai_bot import LinkAIBot
return LinkAIBot()
raise RuntimeError
+1 -1
@@ -166,7 +166,7 @@ class AzureChatGPTBot(ChatGPTBot):
def __init__(self):
super().__init__()
openai.api_type = "azure"
openai.api_version = "2023-03-15-preview"
openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
self.args["deployment_id"] = conf().get("azure_deployment_id")
def create_img(self, query, retry_count=0, api_key=None):
+14 -1
@@ -55,11 +55,16 @@ class ChatGPTSession(Session):
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_messages(messages, model):
"""Returns the number of tokens used by a list of messages."""
if model in ["wenxin", "xunfei"]:
return num_tokens_by_character(messages)
import tiktoken
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo"]:
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]:
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]:
return num_tokens_from_messages(messages, model="gpt-4")
try:
@@ -85,3 +90,11 @@ def num_tokens_from_messages(messages, model):
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
def num_tokens_by_character(messages):
"""Returns the number of tokens used by a list of messages."""
tokens = 0
for msg in messages:
tokens += len(msg["content"])
return tokens
+38 -30
@@ -22,25 +22,30 @@ class LinkAIBot(Bot, OpenAIImage):
def __init__(self):
super().__init__()
self.base_url = "https://api.link-ai.chat/v1"
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
def reply(self, query, context: Context = None) -> Reply:
if context.type == ContextType.TEXT:
return self._chat(query, context)
elif context.type == ContextType.IMAGE_CREATE:
ok, retstring = self.create_img(query, 0)
reply = None
ok, res = self.create_img(query, 0)
if ok:
reply = Reply(ReplyType.IMAGE_URL, retstring)
reply = Reply(ReplyType.IMAGE_URL, res)
else:
reply = Reply(ReplyType.ERROR, retstring)
reply = Reply(ReplyType.ERROR, res)
return reply
else:
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply
def _chat(self, query, context, retry_count=0):
def _chat(self, query, context, retry_count=0) -> Reply:
"""
Issue a chat request
:param query: the prompt
:param context: the conversation context
:param retry_count: current recursive retry count
:return: the reply
"""
if retry_count >= 2:
# exit from retry 2 times
logger.warn("[LINKAI] failed after maximum number of retry times")
@@ -52,53 +57,56 @@ class LinkAIBot(Bot, OpenAIImage):
logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
app_code = None
else:
app_code = conf().get("linkai_app_code")
app_code = context.kwargs.get("app_code") or conf().get("linkai_app_code")
linkai_api_key = conf().get("linkai_api_key")
session_id = context["session_id"]
session = self.sessions.session_query(query, session_id)
model = conf().get("model") or "gpt-3.5-turbo"
# remove system message
if app_code and session.messages[0].get("role") == "system":
session.messages.pop(0)
logger.info(f"[LINKAI] query={query}, app_code={app_code}")
if session.messages[0].get("role") == "system":
if app_code or model == "wenxin":
session.messages.pop(0)
body = {
"appCode": app_code,
"app_code": app_code,
"messages": session.messages,
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
"model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin
"temperature": conf().get("temperature"),
"top_p": conf().get("top_p", 1),
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
}
logger.info(f"[LINKAI] query={query}, app_code={app_code}, mode={body.get('model')}")
headers = {"Authorization": "Bearer " + linkai_api_key}
# do http request
res = requests.post(url=self.base_url + "/chat/completion", json=body, headers=headers).json()
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
timeout=conf().get("request_timeout", 180))
if res.status_code == 200:
# execute success
response = res.json()
reply_content = response["choices"][0]["message"]["content"]
total_tokens = response["usage"]["total_tokens"]
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
self.sessions.session_reply(reply_content, session_id, total_tokens)
return Reply(ReplyType.TEXT, reply_content)
if not res or not res["success"]:
if res.get("code") == self.AUTH_FAILED_CODE:
logger.exception(f"[LINKAI] please check your linkai_api_key, res={res}")
return Reply(ReplyType.ERROR, "请再问我一次吧")
else:
response = res.json()
error = response.get("error")
logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
f"msg={error.get('message')}, type={error.get('type')}")
elif res.get("code") == self.NO_QUOTA_CODE:
logger.exception(f"[LINKAI] please check your account quota, https://chat.link-ai.tech/console/account")
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
else:
# retry
if res.status_code >= 500:
# server error, need retry
time.sleep(2)
logger.warn(f"[LINKAI] do retry, times={retry_count}")
return self._chat(query, context, retry_count + 1)
# execute success
reply_content = res["data"]["content"]
logger.info(f"[LINKAI] reply={reply_content}")
self.sessions.session_reply(reply_content, session_id)
return Reply(ReplyType.TEXT, reply_content)
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
except Exception as e:
logger.exception(e)
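The retry path in `_chat` above sleeps and re-issues the request on 5xx responses, up to two retries. That pattern can be sketched in isolation (the `send` callable stands in for the actual HTTP request, so this is a shape sketch rather than the repo's implementation):

```python
import time

def post_with_retry(send, payload, retries=2, backoff=0):
    # Retry on server errors, mirroring the _chat retry logic above;
    # backoff=2 would reproduce the 2-second sleep between attempts
    for attempt in range(retries + 1):
        status, body = send(payload)
        if status == 200:
            return body
        if status >= 500 and attempt < retries:
            time.sleep(backoff)
            continue
        # client errors and exhausted retries both give up
        return None
    return None
```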
+250
@@ -0,0 +1,250 @@
# encoding:utf-8
import requests, json
from bot.bot import Bot
from bot.session_manager import SessionManager
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
from bridge.context import ContextType, Context
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
from common import const
import time
import _thread as thread
import datetime
from datetime import datetime
from wsgiref.handlers import format_date_time
from urllib.parse import urlencode
import base64
import ssl
import hashlib
import hmac
import json
from time import mktime
from urllib.parse import urlparse
import websocket
import queue
import threading
import random
# message queue map
queue_map = dict()
# reply map
reply_map = dict()
class XunFeiBot(Bot):
def __init__(self):
super().__init__()
self.app_id = conf().get("xunfei_app_id")
self.api_key = conf().get("xunfei_api_key")
self.api_secret = conf().get("xunfei_api_secret")
# v2.0 by default; for v1.5 set the domain to "general"
self.domain = "generalv2"
# v2.0 by default; for v1.5 use "ws://spark-api.xf-yun.com/v1.1/chat"
self.spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"
self.host = urlparse(self.spark_url).netloc
self.path = urlparse(self.spark_url).path
# use the same session mechanism as wenxin
self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
def reply(self, query, context: Context = None) -> Reply:
if context.type == ContextType.TEXT:
logger.info("[XunFei] query={}".format(query))
session_id = context["session_id"]
request_id = self.gen_request_id(session_id)
reply_map[request_id] = ""
session = self.sessions.session_query(query, session_id)
threading.Thread(target=self.create_web_socket, args=(session.messages, request_id)).start()
depth = 0
time.sleep(0.1)
t1 = time.time()
usage = {}
while depth <= 300:
try:
data_queue = queue_map.get(request_id)
if not data_queue:
depth += 1
time.sleep(0.1)
continue
data_item = data_queue.get(block=True, timeout=0.1)
if data_item.is_end:
# request finished
del queue_map[request_id]
if data_item.reply:
reply_map[request_id] += data_item.reply
usage = data_item.usage
break
reply_map[request_id] += data_item.reply
depth += 1
except Exception as e:
depth += 1
continue
t2 = time.time()
logger.info(f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}")
self.sessions.session_reply(reply_map[request_id], session_id, usage.get("total_tokens"))
reply = Reply(ReplyType.TEXT, reply_map[request_id])
del reply_map[request_id]
return reply
else:
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
return reply
def create_web_socket(self, prompt, session_id, temperature=0.5):
logger.info(f"[XunFei] start connect, prompt={prompt}")
websocket.enableTrace(False)
wsUrl = self.create_url()
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close,
on_open=on_open)
data_queue = queue.Queue(1000)
queue_map[session_id] = data_queue
ws.appid = self.app_id
ws.question = prompt
ws.domain = self.domain
ws.session_id = session_id
ws.temperature = temperature
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def gen_request_id(self, session_id: str):
return session_id + "_" + str(int(time.time())) + "" + str(random.randint(0, 100))
# build the websocket url
def create_url(self):
# RFC1123-format timestamp
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
# assemble the string to sign
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
# sign with hmac-sha256
signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
f'signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
# collect the auth parameters into a dict
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
# append the auth parameters to build the url
url = self.spark_url + '?' + urlencode(v)
# when debugging, print the url built here and compare it against a reference implementation
return url
def gen_params(self, appid, domain, question):
"""
Generate the request parameters from the appid and the user's question
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": 2048,
"auditing": "default"
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
class ReplyItem:
def __init__(self, reply, usage=None, is_end=False):
self.is_end = is_end
self.reply = reply
self.usage = usage
# websocket error handler
def on_error(ws, error):
logger.error(f"[XunFei] error: {str(error)}")
# websocket close handler
def on_close(ws, one, two):
data_queue = queue_map.get(ws.session_id)
data_queue.put("END")
# websocket open handler
def on_open(ws):
logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
thread.start_new_thread(run, (ws,))
def run(ws, *args):
data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question, temperature=ws.temperature))
ws.send(data)
# websocket message handler
def on_message(ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
logger.error(f'请求错误: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
data_queue = queue_map.get(ws.session_id)
if not data_queue:
logger.error(f"[XunFei] can't find data queue, session_id={ws.session_id}")
return
reply_item = ReplyItem(content)
if status == 2:
usage = data["payload"].get("usage")
reply_item = ReplyItem(content, usage)
reply_item.is_end = True
ws.close()
data_queue.put(reply_item)
def gen_params(appid, domain, question, temperature=0.5):
"""
Generate the request parameters from the appid and the user's question
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"temperature": temperature,
"random_threshold": 0.5,
"max_tokens": 2048,
"auditing": "default"
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
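The `create_url` method above signs the host, date, and request-line with HMAC-SHA256 and base64-encodes both the signature and the resulting authorization string. A self-contained sketch of that signing scheme (the function name is illustrative):

```python
import base64
import hashlib
import hmac
from datetime import datetime
from time import mktime
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

def build_auth_url(spark_url, api_key, api_secret):
    # HMAC-SHA256 signature over host, date and request-line,
    # following the create_url method above
    host = urlparse(spark_url).netloc
    path = urlparse(spark_url).path
    date = format_date_time(mktime(datetime.now().timetuple()))  # RFC1123 timestamp
    signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
    signature = base64.b64encode(
        hmac.new(api_secret.encode("utf-8"), signature_origin.encode("utf-8"),
                 digestmod=hashlib.sha256).digest()
    ).decode("utf-8")
    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode("utf-8")
    return spark_url + "?" + urlencode({"authorization": authorization, "date": date, "host": host})
```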
+10
@@ -23,6 +23,10 @@ class Bridge(object):
self.btype["chat"] = const.OPEN_AI
if conf().get("use_azure_chatgpt", False):
self.btype["chat"] = const.CHATGPTONAZURE
if model_type in ["wenxin"]:
self.btype["chat"] = const.BAIDU
if model_type in ["xunfei"]:
self.btype["chat"] = const.XUNFEI
if conf().get("use_linkai") and conf().get("linkai_api_key"):
self.btype["chat"] = const.LINKAI
self.bots = {}
@@ -54,3 +58,9 @@ class Bridge(object):
def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
return self.get_bot("translate").translate(text, from_lang, to_lang)
def reset_bot(self):
"""
Reset bot routing
"""
self.__init__()
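The routing in `Bridge.__init__` above resolves the configured model to a bot type, with later checks overriding earlier ones. A condensed sketch of that precedence (constant strings mirror common/const.py; the full routing has a few more branches):

```python
def route_chat_bot(model, use_azure=False, use_linkai=False, linkai_api_key=""):
    # Later checks override earlier ones, as in Bridge.__init__ above
    btype = "chatGPT"
    if use_azure:
        btype = "chatGPTOnAzure"
    if model in ["wenxin"]:
        btype = "baidu"
    if model in ["xunfei"]:
        btype = "xunfei"
    if use_linkai and linkai_api_key:
        btype = "linkai"
    return btype
```

Note that LinkAI takes precedence over every model-based choice, so setting `use_linkai` with a key redirects chat regardless of the configured model.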
+6 -2
@@ -108,8 +108,12 @@ class ChatChannel(Channel):
if not conf().get("group_at_off", False):
flag = True
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
content = re.sub(pattern, r"", content)
subtract_res = re.sub(pattern, r"", content)
if subtract_res == content and context["msg"].self_display_name:
# nothing was stripped; try again with the group nickname
pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)"
subtract_res = re.sub(pattern, r"", content)
content = subtract_res
if not flag:
if context["origin_ctype"] == ContextType.VOICE:
logger.info("[WX]receive group voice, but checkprefix didn't match")
+3 -3
@@ -24,9 +24,7 @@ is_at: 是否被at
- (for group messages there is usually an actual sender, i.e. the id and nickname of a group member; the following fields exist only for group messages)
actual_user_id: actual sender id (required for group chat)
actual_user_nickname: actual sender nickname
self_display_name: own display name; when a group alias is set, this field holds the alias
_prepare_fn: preparation function used to prepare the message content, e.g. downloading images
_prepared: whether the preparation function has already been called
@@ -48,6 +46,8 @@ class ChatMessage(object):
to_user_nickname = None
other_user_id = None
other_user_nickname = None
my_msg = False
self_display_name = None
is_group = False
is_at = False
+7
@@ -58,6 +58,9 @@ def _check(func):
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
logger.debug("[WX]history message {} skipped".format(msgId))
return
if cmsg.my_msg and not cmsg.is_group:
logger.debug("[WX]my message {} skipped".format(msgId))
return
return func(self, cmsg)
return wrapper
@@ -189,10 +192,14 @@ class WechatChannel(ChatChannel):
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
elif reply.type == ReplyType.IMAGE_URL:  # download the image from the network
img_url = reply.content
logger.debug(f"[WX] start download image, img_url={img_url}")
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
size = 0
for block in pic_res.iter_content(1024):
size += len(block)
image_storage.write(block)
logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
image_storage.seek(0)
itchat.send_image(image_storage, toUserName=receiver)
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
+7 -1
@@ -57,13 +57,19 @@ class WechatMessage(ChatMessage):
self.from_user_nickname = nickname
if self.to_user_id == user_id:
self.to_user_nickname = nickname
try:  # for strangers, the 'User' field may be missing
try:  # for strangers, the User field may be missing
# my_msg is True when the message was sent by ourselves
self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \
itchat_msg["ToUserName"] != itchat_msg["FromUserName"]
self.other_user_id = itchat_msg["User"]["UserName"]
self.other_user_nickname = itchat_msg["User"]["NickName"]
if self.other_user_id == self.from_user_id:
self.from_user_nickname = self.other_user_nickname
if self.other_user_id == self.to_user_id:
self.to_user_nickname = self.other_user_nickname
if itchat_msg["User"].get("Self"):
# 自身的展示名,当设置了群昵称时,该字段表示群昵称
self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName")
except KeyError as e: # 处理偶尔没有对方信息的情况
logger.warn("[WX]get other_user_id failed: " + str(e))
if self.from_user_id == user_id:
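The `my_msg` field added above is derived from itchat's raw message dict. A minimal, self-contained sketch of that check (field names are taken from the diff; the full itchat message schema is assumed):

```python
def is_my_msg(itchat_msg: dict) -> bool:
    """Sketch of the my_msg check: a message counts as self-sent when it
    targets the chat peer ('User') while sender and receiver differ."""
    try:
        return (itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"]
                and itchat_msg["ToUserName"] != itchat_msg["FromUserName"])
    except KeyError:
        # the 'User' field may be missing for strangers
        return False
```

This is why the `_check` decorator can now skip `my_msg` messages in single chats without touching group traffic.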
+3
@@ -2,7 +2,10 @@
OPEN_AI = "openAI"
CHATGPT = "chatGPT"
BAIDU = "baidu"
XUNFEI = "xunfei"
CHATGPTONAZURE = "chatGPTOnAzure"
LINKAI = "linkai"
VERSION = "1.3.0"
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "xunfei"]
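The `MODEL_LIST` constant above backs the new `#model` admin command in the godcmd diff below. A minimal sketch of that validation, assuming a plain dict for the config:

```python
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "xunfei"]

def switch_model(conf: dict, name: str):
    # reject unknown names before touching the config, as the godcmd
    # "#model" command does in this changeset
    if name not in MODEL_LIST:
        return False, "模型名称不存在"
    conf["model"] = name
    return True, "模型设置为: " + name
```

In the real command handler a successful switch is followed by `Bridge().reset_bot()` so the new model takes effect immediately.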
+3
@@ -1,6 +1,7 @@
{
"open_ai_api_key": "YOUR API KEY",
"model": "gpt-3.5-turbo",
"channel_type": "wx",
"proxy": "",
"hot_reload": false,
"single_chat_prefix": [
@@ -29,6 +30,8 @@
"conversation_max_tokens": 1000,
"expires_in_seconds": 3600,
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
"temperature": 0.7,
"top_p": 1,
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
"use_linkai": false,
"linkai_api_key": "",
+24 -3
@@ -16,9 +16,10 @@ available_setting = {
"open_ai_api_base": "https://api.openai.com/v1",
"proxy": "", # openai使用的代理
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
"model": "gpt-3.5-turbo",
"model": "gpt-3.5-turbo", # 还支持 gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
"azure_deployment_id": "", # azure 模型部署名称
"azure_api_version": "", # azure api版本
# Bot触发配置
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
@@ -50,13 +51,21 @@ available_setting = {
"presence_penalty": 0,
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试
# Baidu 文心一言参数
"baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型
"baidu_wenxin_api_key": "", # Baidu api key
"baidu_wenxin_secret_key": "", # Baidu secret key
# 讯飞星火API
"xunfei_app_id": "", # 讯飞应用ID
"xunfei_api_key": "", # 讯飞 API key
"xunfei_api_secret": "", # 讯飞 API secret
# 语音设置
"speech_recognition": False, # 是否开启语音识别
"group_speech_recognition": False, # 是否开启群组语音识别
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
"always_reply_voice": False, # 是否一直使用语音回复
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure,elevenlabs
# baidu 语音api配置, 使用百度语音识别和语音合成时需要
"baidu_app_id": "",
"baidu_api_key": "",
@@ -66,6 +75,9 @@ available_setting = {
# azure 语音api配置, 使用azure语音识别和语音合成时需要
"azure_voice_api_key": "",
"azure_voice_region": "japaneast",
# elevenlabs 语音api配置
"xi_api_key": "", #获取ap的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication
"xi_voice_id": "", #ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam”
# 服务时间限制,目前支持itchat
"chat_time_module": False, # 是否开启服务时间限制
"chat_start_time": "00:00", # 服务开始时间
@@ -102,10 +114,13 @@ available_setting = {
"appdata_dir": "", # 数据目录
# 插件配置
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
# 是否使用全局插件配置
"use_global_plugin_config": False,
# 知识库平台配置
"use_linkai": False,
"linkai_api_key": "",
"linkai_app_code": ""
"linkai_app_code": "",
"linkai_api_base": "https://api.link-ai.chat" # linkAI服务地址,若国内无法访问或延迟较高可改为 https://api.link-ai.tech
}
@@ -252,3 +267,9 @@ def pconf(plugin_name: str) -> dict:
:return: 该插件的配置项
"""
return plugin_config.get(plugin_name.lower())
# 全局配置,用于存放全局生效的状态
global_config = {
"admin_users": []
}
+2 -2
@@ -1,4 +1,4 @@
FROM python:3.10-slim
FROM python:3.10-slim-bullseye
LABEL maintainer="foo@bar.com"
ARG TZ='Asia/Shanghai'
@@ -32,4 +32,4 @@ RUN chmod +x /entrypoint.sh \
USER noroot
ENTRYPOINT ["/entrypoint.sh"]
ENTRYPOINT ["/entrypoint.sh"]
+1
@@ -18,6 +18,7 @@ services:
SPEECH_RECOGNITION: 'False'
CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。'
EXPIRES_IN_SECONDS: 3600
USE_GLOBAL_PLUGIN_CONFIG: 'True'
USE_LINKAI: 'False'
LINKAI_API_KEY: ''
LINKAI_APP_CODE: ''
+14
@@ -20,5 +20,19 @@
"no_default": false,
"model_name": "gpt-3.5-turbo"
}
},
"linkai": {
"group_app_map": {
"测试群1": "default",
"测试群2": "Kv2fXJcH"
},
"midjourney": {
"enabled": true,
"auto_translate": true,
"img_proxy": true,
"max_tasks": 3,
"max_tasks_per_user": 1,
"use_image_create_prefix": true
}
}
}
+28 -5
@@ -4,7 +4,6 @@ import json
import os
import random
import string
import traceback
from typing import Tuple
import plugins
@@ -12,8 +11,7 @@ from bridge.bridge import Bridge
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common import const
from common.log import logger
from config import conf, load_config
from config import conf, load_config, global_config
from plugins import *
# 定义指令集
@@ -32,6 +30,10 @@ COMMANDS = {
"args": ["口令"],
"desc": "管理员认证",
},
"model": {
"alias": ["model", "模型"],
"desc": "查看和设置全局模型",
},
"set_openai_api_key": {
"alias": ["set_openai_api_key"],
"args": ["api_key"],
@@ -257,6 +259,18 @@ class Godcmd(Plugin):
break
if not ok:
result = "插件不存在或未启用"
elif cmd == "model":
if not isadmin and not self.is_admin_in_group(e_context["context"]):
ok, result = False, "需要管理员权限执行"
elif len(args) == 0:
ok, result = True, "当前模型为: " + str(conf().get("model"))
elif len(args) == 1:
if args[0] not in const.MODEL_LIST:
ok, result = False, "模型名称不存在"
else:
conf()["model"] = args[0]
Bridge().reset_bot()
ok, result = True, "模型设置为: " + str(conf().get("model"))
elif cmd == "id":
ok, result = True, user
elif cmd == "set_openai_api_key":
@@ -294,7 +308,7 @@ class Godcmd(Plugin):
except Exception as e:
ok, result = False, "你没有设置私有GPT模型"
elif cmd == "reset":
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI]:
bot.sessions.clear_session(session_id)
channel.cancel_session(session_id)
ok, result = True, "会话已重置"
@@ -317,7 +331,8 @@ class Godcmd(Plugin):
load_config()
ok, result = True, "配置已重载"
elif cmd == "resetall":
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
const.BAIDU, const.XUNFEI]:
channel.cancel_all_session()
bot.sessions.clear_all_session()
ok, result = True, "重置所有会话成功"
@@ -426,12 +441,20 @@ class Godcmd(Plugin):
password = args[0]
if password == self.password:
self.admin_users.append(userid)
global_config["admin_users"].append(userid)
return True, "认证成功"
elif password == self.temp_password:
self.admin_users.append(userid)
global_config["admin_users"].append(userid)
return True, "认证成功,请尽快设置口令"
else:
return False, "认证失败"
def get_help_text(self, isadmin=False, isgroup=False, **kwargs):
return get_help_text(isadmin, isgroup)
def is_admin_in_group(self, context):
if context["isgroup"]:
return context.kwargs.get("msg").actual_user_id in global_config["admin_users"]
return False
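The auth flow now mirrors admins into `global_config["admin_users"]` so other plugins (e.g. linkai) can check them. A simplified sketch, with the context flattened to a plain dict for illustration (the real code reads `context.kwargs["msg"].actual_user_id`):

```python
global_config = {"admin_users": []}

def authenticate(userid: str, password: str, expected: str) -> bool:
    # successful auth records the user in the shared global config,
    # in addition to the plugin's own admin list
    if password == expected:
        global_config["admin_users"].append(userid)
        return True
    return False

def is_admin_in_group(context: dict) -> bool:
    # in a group chat, admin status is decided by the actual sender's id
    if context.get("isgroup"):
        return context.get("actual_user_id") in global_config["admin_users"]
    return False
```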
+12 -3
@@ -54,9 +54,18 @@ class Keyword(Plugin):
logger.debug(f"[keyword] 匹配到关键字【{content}】")
reply_text = self.keyword[content]
reply = Reply()
reply.type = ReplyType.TEXT
reply.content = reply_text
# 判断匹配内容的类型
if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".webp"]):
# 如果是以 http:// 或 https:// 开头,且以 .jpg/.jpeg/.png/.gif/.webp 结尾,则认为是图片 URL
reply = Reply()
reply.type = ReplyType.IMAGE_URL
reply.content = reply_text
else:
# 否则认为是普通文本
reply = Reply()
reply.type = ReplyType.TEXT
reply.content = reply_text
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
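The branch added above chooses the reply type from the keyword value. A minimal sketch of that decision, with `ReplyType` reduced to plain strings for illustration:

```python
# extensions recognized as image URLs by the keyword plugin diff above
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".gif", ".webp")

def classify_reply(reply_text: str) -> str:
    # an http(s) URL ending in a known image extension is sent as an
    # IMAGE_URL reply; everything else is plain TEXT
    is_url = reply_text.startswith(("http://", "https://"))
    if is_url and reply_text.endswith(IMAGE_EXTS):
        return "IMAGE_URL"
    return "TEXT"
```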
+75
@@ -0,0 +1,75 @@
## Plugin Overview
Enhances the bot with the knowledge-base, Midjourney drawing and other capabilities provided by LinkAI. Platform: https://chat.link-ai.tech/console
## Plugin Configuration
Copy the `config.json.template` template in the `plugins/linkai` directory to the effective `config.json`:
The options are explained below:
```bash
{
    "group_app_map": {            # mapping from group chats to app codes
        "测试群1": "default",     # the group named "测试群1" uses the app whose app_code is default
        "测试群2": "Kv2fXJcH"
    },
    "midjourney": {
        "enabled": true,                 # Midjourney drawing on/off
        "auto_translate": true,          # whether to auto-translate prompts into English
        "img_proxy": true,               # whether to proxy generated images; on servers outside China, set this to false for faster generation
        "max_tasks": 3,                  # total number of tasks that may be submitted concurrently
        "max_tasks_per_user": 1,         # number of concurrent tasks allowed per single user
        "use_image_create_prefix": true  # whether to reuse the global drawing trigger words, i.e. also allow triggering via image_create_prefix in the main config.json
    }
}
```
Notes:
- The `group_app_map` section maps group chats to applications on the LinkAI platform, and the `midjourney` section configures MJ drawing; fill in only what you need, as a feature left unconfigured stays disabled by default
- The actual `config.json` must be valid JSON and must not contain '#' or the comments that follow it
- For `docker` deployments, the plugin can be configured by mapping `plugins/config.json` into the container; see the [docs](https://github.com/zhayujie/chatgpt-on-wechat#3-%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8)
## Plugin Usage
> The knowledge-base management features require `linkai` chat to be enabled first, which depends on the `use_linkai` and `linkai_api_key` options in the global `config.json`; the Midjourney drawing feature only needs `linkai_api_key`, and works whether `use_linkai` is on or off. See the [detailed docs](https://link-ai.tech/platform/link-app/wechat).
After configuration, run the project and the plugin loads automatically; type `#help linkai` to list its features.
### 1. Knowledge-base management
Lets different group chats use different applications. The mapping can be fixed in the `group_app_map` configuration above, or switched quickly inside a group via commands.
The app-switch command requires completing admin authentication with the `godcmd` plugin first, and then takes the form:
`$linkai app {app_code}`
For example, entering `$linkai app Kv2fXJcH` binds the current group chat to the application whose app_code is Kv2fXJcH.
In addition, `$linkai close` disables linkai chat with one command, falling back to the default openai API; likewise, `$linkai open` re-enables it.
### 2. Midjourney drawing
Command format:
```
- Generate: $mj prompt1, prompt2..
- Upscale:  $mju image_id index
- Variate:  $mjv image_id index
- Reset:    $mjr image_id
```
For example:
```
"$mj a little cat, white --ar 9:16"
"$mju 1105592717188272288 2"
"$mjv 11055927171882 2"
"$mjr 11055927171882"
```
Notes:
1. With `use_image_create_prefix` enabled, the global drawing trigger words can be reused directly, so a message starting with "画" generates an image.
2. Sensitive words in the prompt or malformed parameters may cause a drawing to fail; failed generations consume no credits.
3. The `$mj open` and `$mj close` commands quickly enable and disable the drawing feature.
+1
@@ -0,0 +1 @@
from .linkai import *
+14
@@ -0,0 +1,14 @@
{
"group_app_map": {
"测试群1": "default",
"测试群2": "Kv2fXJcH"
},
"midjourney": {
"enabled": true,
"auto_translate": true,
"img_proxy": true,
"max_tasks": 3,
"max_tasks_per_user": 1,
"use_image_create_prefix": true
}
}
+164
@@ -0,0 +1,164 @@
import plugins
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import global_config
from plugins import *
from .midjourney import MJBot
from bridge import bridge
@plugins.register(
name="linkai",
desc="A plugin that supports knowledge base and midjourney drawing.",
version="0.1.0",
author="https://link-ai.tech",
)
class LinkAI(Plugin):
def __init__(self):
super().__init__()
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
self.config = super().load_config()
if self.config:
self.mj_bot = MJBot(self.config.get("midjourney"))
logger.info("[LinkAI] inited")
def on_handle_context(self, e_context: EventContext):
"""
消息处理逻辑
:param e_context: 消息上下文
"""
if not self.config:
return
context = e_context['context']
if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE]:
# filter out context types this plugin does not handle
return
mj_type = self.mj_bot.judge_mj_task_type(e_context)
if mj_type:
# MJ作图任务处理
self.mj_bot.process_mj_task(mj_type, e_context)
return
if context.content.startswith(f"{_get_trigger_prefix()}linkai"):
# 应用管理功能
self._process_admin_cmd(e_context)
return
if self._is_chat_task(e_context):
# 文本对话任务处理
self._process_chat_task(e_context)
# 插件管理功能
def _process_admin_cmd(self, e_context: EventContext):
context = e_context['context']
cmd = context.content.split()
if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"):
_set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
return
if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
# 知识库开关指令
if not _is_admin(e_context):
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
return
is_open = True
tips_text = "开启"
if cmd[1] == "close":
tips_text = "关闭"
is_open = False
conf()["use_linkai"] = is_open
bridge.Bridge().reset_bot()
_set_reply_text(f"知识库功能已{tips_text}", e_context, level=ReplyType.INFO)
return
if len(cmd) == 3 and cmd[1] == "app":
# 知识库应用切换指令
if not context.kwargs.get("isgroup"):
_set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR)
return
if not _is_admin(e_context):
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
return
app_code = cmd[2]
group_name = context.kwargs.get("msg").from_user_nickname
group_mapping = self.config.get("group_app_map")
if group_mapping:
group_mapping[group_name] = app_code
else:
self.config["group_app_map"] = {group_name: app_code}
# 保存插件配置
super().save_config(self.config)
_set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO)
else:
_set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context,
level=ReplyType.INFO)
return
# LinkAI 对话任务处理
def _is_chat_task(self, e_context: EventContext):
context = e_context['context']
# 群聊应用管理
return self.config.get("group_app_map") and context.kwargs.get("isgroup")
def _process_chat_task(self, e_context: EventContext):
"""
处理LinkAI对话任务
:param e_context: 对话上下文
"""
context = e_context['context']
# 群聊应用管理
group_name = context.kwargs.get("msg").from_user_nickname
app_code = self._fetch_group_app_code(group_name)
if app_code:
context.kwargs['app_code'] = app_code
def _fetch_group_app_code(self, group_name: str) -> str:
"""
根据群聊名称获取对应的应用code
:param group_name: 群聊名称
:return: 应用code
"""
group_mapping = self.config.get("group_app_map")
if group_mapping:
app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
return app_code
def get_help_text(self, verbose=False, **kwargs):
trigger_prefix = _get_trigger_prefix()
help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画等能力。\n\n"
if not verbose:
return help_text
help_text += f'📖 知识库\n - 群聊中指定应用: {trigger_prefix}linkai app 应用编码\n'
help_text += f' - {trigger_prefix}linkai open: 开启对话\n'
help_text += f' - {trigger_prefix}linkai close: 关闭对话\n'
help_text += f'\n例如: \n"{trigger_prefix}linkai app Kv2fXJcH"\n\n'
help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: {trigger_prefix}mjv 图片ID 图片序号\n - 重置: {trigger_prefix}mjr 图片ID"
help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
return help_text
# 静态方法
def _is_admin(e_context: EventContext) -> bool:
"""
判断消息是否由管理员用户发送
:param e_context: 消息上下文
:return: True: 是, False: 否
"""
context = e_context["context"]
if context["isgroup"]:
return context.kwargs.get("msg").actual_user_id in global_config["admin_users"]
else:
return context["receiver"] in global_config["admin_users"]
def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
def _get_trigger_prefix():
return conf().get("plugin_trigger_prefix", "$")
+427
@@ -0,0 +1,427 @@
from enum import Enum
from config import conf
from common.log import logger
import requests
import threading
import time
from bridge.reply import Reply, ReplyType
import aiohttp
import asyncio
from bridge.context import ContextType
from plugins import EventContext, EventAction
INVALID_REQUEST = 410
NOT_FOUND_ORIGIN_IMAGE = 461
NOT_FOUND_TASK = 462
class TaskType(Enum):
GENERATE = "generate"
UPSCALE = "upscale"
VARIATION = "variation"
RESET = "reset"
def __str__(self):
return self.name
class Status(Enum):
PENDING = "pending"
FINISHED = "finished"
EXPIRED = "expired"
ABORTED = "aborted"
def __str__(self):
return self.name
class TaskMode(Enum):
FAST = "fast"
RELAX = "relax"
task_name_mapping = {
TaskType.GENERATE.name: "生成",
TaskType.UPSCALE.name: "放大",
TaskType.VARIATION.name: "变换",
TaskType.RESET.name: "重新生成",
}
class MJTask:
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30,
status=Status.PENDING):
self.id = id
self.user_id = user_id
self.task_type = task_type
self.raw_prompt = raw_prompt
self.send_func = None # send_func(img_url)
self.expiry_time = time.time() + expires
self.status = status
self.img_url = None # url
self.img_id = None
def __str__(self):
return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
# midjourney bot
class MJBot:
def __init__(self, config):
self.base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/img/midjourney"
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
self.config = config
self.tasks = {}
self.temp_dict = {}
self.tasks_lock = threading.Lock()
self.event_loop = asyncio.new_event_loop()
def judge_mj_task_type(self, e_context: EventContext):
"""
判断MJ任务的类型
:param e_context: 上下文
:return: 任务类型枚举
"""
if not self.config:
return None
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
context = e_context['context']
if context.type == ContextType.TEXT:
cmd_list = context.content.split(maxsplit=1)
if cmd_list[0].lower() == f"{trigger_prefix}mj":
return TaskType.GENERATE
elif cmd_list[0].lower() == f"{trigger_prefix}mju":
return TaskType.UPSCALE
elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
return TaskType.VARIATION
elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
return TaskType.RESET
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
return TaskType.GENERATE
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
"""
处理mj任务
:param mj_type: mj任务类型
:param e_context: 对话上下文
"""
context = e_context['context']
session_id = context["session_id"]
cmd = context.content.split(maxsplit=1)
if len(cmd) == 1 and context.type == ContextType.TEXT:
# midjourney 帮助指令
self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
return
if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
# midjourney 开关指令
is_open = True
tips_text = "开启"
if cmd[1] == "close":
tips_text = "关闭"
is_open = False
self.config["enabled"] = is_open
self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO)
return
if not self.config.get("enabled"):
logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置")
self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO)
return
if not self._check_rate_limit(session_id, e_context):
logger.warn("[MJ] midjourney task exceed rate limit")
return
if mj_type == TaskType.GENERATE:
if context.type == ContextType.IMAGE_CREATE:
raw_prompt = context.content
else:
# 图片生成
raw_prompt = cmd[1]
reply = self.generate(raw_prompt, session_id, e_context)
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS
return
elif mj_type == TaskType.UPSCALE or mj_type == TaskType.VARIATION:
# 图片放大/变换
clist = cmd[1].split()
if len(clist) < 2:
self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
return
img_id = clist[0]
index = int(clist[1])
if index < 1 or index > 4:
self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context)
return
key = f"{str(mj_type)}_{img_id}_{index}"
if self.temp_dict.get(key):
self._set_reply_text(f"{index} 张图片已经{task_name_mapping.get(str(mj_type))}过了", e_context)
return
# 执行图片放大/变换操作
reply = self.do_operate(mj_type, session_id, img_id, e_context, index)
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS
return
elif mj_type == TaskType.RESET:
# 图片重新生成
clist = cmd[1].split()
if len(clist) < 1:
self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
return
img_id = clist[0]
# 图片重新生成
reply = self.do_operate(mj_type, session_id, img_id, e_context)
e_context['reply'] = reply
e_context.action = EventAction.BREAK_PASS
else:
self._set_reply_text(f"暂不支持该命令", e_context)
def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply:
"""
图片生成
:param prompt: 提示词
:param user_id: 用户id
:param e_context: 对话上下文
:return: 任务ID
"""
logger.info(f"[MJ] image generate, prompt={prompt}")
mode = self._fetch_mode(prompt)
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
if not self.config.get("img_proxy"):
body["img_proxy"] = False
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40))
if res.status_code == 200:
res = res.json()
logger.debug(f"[MJ] image generate, res={res}")
if res.get("code") == 200:
task_id = res.get("data").get("task_id")
real_prompt = res.get("data").get("real_prompt")
if mode == TaskMode.RELAX.value:
time_str = "1~10分钟"
else:
time_str = "1分钟"
content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
if real_prompt:
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
else:
content += f"prompt: {prompt}"
reply = Reply(ReplyType.INFO, content)
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id,
task_type=TaskType.GENERATE)
# put to memory dict
self.tasks[task.id] = task
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
self._do_check_task(task, e_context)
return reply
else:
res_json = res.json()
logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
if res.status_code == INVALID_REQUEST:
reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容")
else:
reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
return reply
def do_operate(self, task_type: TaskType, user_id: str, img_id: str, e_context: EventContext,
index: int = None) -> Reply:
logger.info(f"[MJ] image operate, task_type={task_type}, img_id={img_id}, index={index}")
body = {"type": task_type.name, "img_id": img_id}
if index:
body["index"] = index
if not self.config.get("img_proxy"):
body["img_proxy"] = False
res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
logger.debug(res)
if res.status_code == 200:
res = res.json()
if res.get("code") == 200:
task_id = res.get("data").get("task_id")
logger.info(f"[MJ] image operate processing, task_id={task_id}")
icon_map = {TaskType.UPSCALE: "🔎", TaskType.VARIATION: "🪄", TaskType.RESET: "🔄"}
content = f"{icon_map.get(task_type)}图片正在{task_name_mapping.get(task_type.name)}中,请耐心等待"
reply = Reply(ReplyType.INFO, content)
task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=task_type)
# put to memory dict
self.tasks[task.id] = task
key = f"{task_type.name}_{img_id}_{index}"
self.temp_dict[key] = True
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
self._do_check_task(task, e_context)
return reply
else:
error_msg = ""
if res.status_code == NOT_FOUND_ORIGIN_IMAGE:
error_msg = "请输入正确的图片ID"
res_json = res.json()
logger.error(f"[MJ] operate error, msg={res_json.get('message')}, status_code={res.status_code}")
reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
return reply
def check_task_sync(self, task: MJTask, e_context: EventContext):
logger.debug(f"[MJ] start check task status, {task}")
max_retry_times = 90
while max_retry_times > 0:
time.sleep(10)
url = f"{self.base_url}/tasks/{task.id}"
try:
res = requests.get(url, headers=self.headers, timeout=8)
if res.status_code == 200:
res_json = res.json()
logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
f"data={res_json.get('data')}, thread={threading.current_thread().name}")
if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
# process success res
if self.tasks.get(task.id):
self.tasks[task.id].status = Status.FINISHED
self._process_success_task(task, res_json.get("data"), e_context)
return
max_retry_times -= 1
else:
res_json = res.json()
logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
max_retry_times -= 20
except Exception as e:
max_retry_times -= 20
logger.warn(e)
logger.warn("[MJ] end from poll")
if self.tasks.get(task.id):
self.tasks[task.id].status = Status.EXPIRED
def _do_check_task(self, task: MJTask, e_context: EventContext):
threading.Thread(target=self.check_task_sync, args=(task, e_context)).start()
def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
"""
处理任务成功的结果
:param task: MJ任务
:param res: 请求结果
:param e_context: 对话上下文
"""
# channel send img
task.status = Status.FINISHED
task.img_id = res.get("img_id")
task.img_url = res.get("img_url")
logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
# send img
reply = Reply(ReplyType.IMAGE_URL, task.img_url)
channel = e_context["channel"]
_send(channel, reply, e_context["context"])
# send info
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
text = ""
if task.task_type == TaskType.GENERATE or task.task_type == TaskType.VARIATION or task.task_type == TaskType.RESET:
text = f"🎨绘画完成!\n"
if task.raw_prompt:
text += f"prompt: {task.raw_prompt}\n"
text += f"- - - - - - - - -\n图片ID: {task.img_id}"
text += f"\n\n🔎使用 {trigger_prefix}mju 命令放大图片\n"
text += f"例如:\n{trigger_prefix}mju {task.img_id} 1"
text += f"\n\n🪄使用 {trigger_prefix}mjv 命令变换图片\n"
text += f"例如:\n{trigger_prefix}mjv {task.img_id} 1"
text += f"\n\n🔄使用 {trigger_prefix}mjr 命令重新生成图片\n"
text += f"例如:\n{trigger_prefix}mjr {task.img_id}"
reply = Reply(ReplyType.INFO, text)
_send(channel, reply, e_context["context"])
self._print_tasks()
return
def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
"""
midjourney任务限流控制
:param user_id: 用户id
:param e_context: 对话上下文
:return: 任务是否能够生成, True:可以生成, False: 被限流
"""
tasks = self.find_tasks_by_user_id(user_id)
task_count = len([t for t in tasks if t.status == Status.PENDING])
if task_count >= self.config.get("max_tasks_per_user"):
reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return False
task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
if task_count >= self.config.get("max_tasks"):
reply = Reply(ReplyType.INFO, "Midjourney作图任务数已达上限,请稍后再试")
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
return False
return True
def _fetch_mode(self, prompt) -> str:
mode = self.config.get("mode")
if "--relax" in prompt or mode == TaskMode.RELAX.value:
return TaskMode.RELAX.value
return mode or TaskMode.FAST.value
def _run_loop(self, loop: asyncio.BaseEventLoop):
"""
运行事件循环,用于轮询任务的线程
:param loop: 事件循环
"""
loop.run_forever()
loop.stop()
def _print_tasks(self):
for id in self.tasks:
logger.debug(f"[MJ] current task: {self.tasks[id]}")
def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
"""
设置回复文本
:param content: 回复内容
:param e_context: 对话上下文
:param level: 回复等级
"""
reply = Reply(level, content)
e_context["reply"] = reply
e_context.action = EventAction.BREAK_PASS
def get_help_text(self, verbose=False, **kwargs):
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
help_text = "🎨利用Midjourney进行画图\n\n"
if not verbose:
return help_text
help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: {trigger_prefix}mjv 图片ID 图片序号\n - 重置: {trigger_prefix}mjr 图片ID"
help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
return help_text
def find_tasks_by_user_id(self, user_id) -> list:
result = []
with self.tasks_lock:
now = time.time()
for task in self.tasks.values():
if task.status == Status.PENDING and now > task.expiry_time:
task.status = Status.EXPIRED
logger.info(f"[MJ] {task} expired")
if task.user_id == user_id:
result.append(task)
return result
def _send(channel, reply: Reply, context, retry_cnt=0):
try:
channel.send(reply, context)
except Exception as e:
logger.error("[WX] sendMsg error: {}".format(str(e)))
if isinstance(e, NotImplementedError):
return
logger.exception(e)
if retry_cnt < 2:
time.sleep(3 + 3 * retry_cnt)
_send(channel, reply, context, retry_cnt + 1)
def check_prefix(content, prefix_list):
if not prefix_list:
return None
for prefix in prefix_list:
if content.startswith(prefix):
return prefix
return None
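The module-level `_send` helper above retries failed channel sends with a growing pause (3s, then 6s), and gives up immediately on `NotImplementedError`. A standalone sketch of that pattern, with the delays parameterized so they can be zeroed in tests:

```python
import time

def send_with_retry(send, reply, context, retry_cnt=0, delays=(3, 6)):
    # one attempt plus up to two retries with an increasing pause;
    # NotImplementedError means the channel cannot send this reply
    # type at all, so retrying would be pointless
    try:
        send(reply, context)
    except NotImplementedError:
        return
    except Exception:
        if retry_cnt < len(delays):
            time.sleep(delays[retry_cnt])
            send_with_retry(send, reply, context, retry_cnt + 1, delays)
```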
+21 -4
@@ -1,6 +1,6 @@
import os
import json
from config import pconf
from config import pconf, plugin_config, conf
from common.log import logger
@@ -15,14 +15,31 @@ class Plugin:
"""
# 优先获取 plugins/config.json 中的全局配置
plugin_conf = pconf(self.name)
if not plugin_conf:
# 全局配置不存在,则获取插件目录下的配置
if not plugin_conf or not conf().get("use_global_plugin_config"):
# 全局配置不存在 或者 未开启全局配置开关,则获取插件目录下的配置
plugin_config_path = os.path.join(self.path, "config.json")
if os.path.exists(plugin_config_path):
with open(plugin_config_path, "r") as f:
with open(plugin_config_path, "r", encoding="utf-8") as f:
plugin_conf = json.load(f)
logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
return plugin_conf
def save_config(self, config: dict):
try:
plugin_config[self.name] = config
# 写入全局配置
global_config_path = "./plugins/config.json"
if os.path.exists(global_config_path):
with open(global_config_path, "w", encoding='utf-8') as f:
json.dump(plugin_config, f, indent=4, ensure_ascii=False)
# 写入插件配置
plugin_config_path = os.path.join(self.path, "config.json")
if os.path.exists(plugin_config_path):
with open(plugin_config_path, "w", encoding='utf-8') as f:
json.dump(config, f, indent=4, ensure_ascii=False)
except Exception as e:
logger.warn("save plugin config failed: {}".format(e))
def get_help_text(self, **kwargs):
return "暂无帮助信息"
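`save_config` above writes configs back as UTF-8 JSON with `ensure_ascii=False`, keeping Chinese group names readable in the file. A minimal sketch of that round trip (file path is hypothetical):

```python
import json
import os
import tempfile

def save_plugin_config(path: str, config: dict):
    # UTF-8 + ensure_ascii=False stores Chinese text literally
    # instead of as \uXXXX escapes, matching Plugin.save_config
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4, ensure_ascii=False)

def load_plugin_config(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
```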
+3 -1
@@ -23,6 +23,8 @@ web.py
wechatpy
# chatgpt-tool-hub plugin
--extra-index-url https://pypi.python.org/simple
chatgpt_tool_hub==0.4.6
# xunfei spark
websocket-client==1.2.0
+1
@@ -6,3 +6,4 @@ requests>=2.28.2
chardet>=5.1.0
Pillow
pre-commit
web.py
+33
@@ -0,0 +1,33 @@
import time
from elevenlabs import set_api_key, generate
from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
from voice.voice import Voice
from config import conf
XI_API_KEY = conf().get("xi_api_key")
set_api_key(XI_API_KEY)
name = conf().get("xi_voice_id")
class ElevenLabsVoice(Voice):
def __init__(self):
pass
def voiceToText(self, voice_file):
pass
def textToVoice(self, text):
audio = generate(
text=text,
voice=name,
model='eleven_multilingual_v1'
)
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
with open(fileName, "wb") as f:
f.write(audio)
logger.info("[ElevenLabs] textToVoice text={} voice file name={}".format(text, fileName))
return Reply(ReplyType.VOICE, fileName)
+4
@@ -29,4 +29,8 @@ def create_voice(voice_type):
from voice.azure.azure_voice import AzureVoice
return AzureVoice()
elif voice_type == "elevenlabs":
from voice.elevent.elevent_voice import ElevenLabsVoice
return ElevenLabsVoice()
raise RuntimeError