mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-05-10 13:32:14 +08:00
Compare commits
58 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a64d7c42b1 | |||
| 36b6cc58bf | |||
| 5ac8a257e7 | |||
| 74119d0372 | |||
| 4e162c73e5 | |||
| 5ff753a492 | |||
| 89400630c0 | |||
| 3899c0cfe3 | |||
| a086f1989f | |||
| 1171b04e93 | |||
| c55d81825a | |||
| 2dcd026e9f | |||
| cdf8609d24 | |||
| 36580c5f7f | |||
| 1cff2521f4 | |||
| db4998a56b | |||
| acbd506568 | |||
| 0cf8e3be73 | |||
| 2473334dfc | |||
| 1ff72d1d37 | |||
| 241fad5524 | |||
| 1b48cea50a | |||
| 88bf345b91 | |||
| ab4ff3d1a3 | |||
| 3502e0d643 | |||
| 995894d3aa | |||
| 4da8714124 | |||
| 6b247ae880 | |||
| 176941ea3b | |||
| 5176b56d3b | |||
| 8abf18ab25 | |||
| 395edbd9f4 | |||
| 2386eb8fc2 | |||
| 68208f82a0 | |||
| ca916b7ce5 | |||
| 01e02934da | |||
| c81a79f7b9 | |||
| 1133648bf6 | |||
| e05bc541d7 | |||
| d689d20482 | |||
| 39dd99b272 | |||
| cda21acb43 | |||
| 9bd7d09f20 | |||
| b22994c2d2 | |||
| e027286b6d | |||
| d6e16995e0 | |||
| 782bff3a51 | |||
| de26dc0597 | |||
| 233b24ab0f | |||
| 2f9e5b1219 | |||
| dd36b8b150 | |||
| f81ac31fe1 | |||
| 24b63bc5bd | |||
| 1817a972c6 | |||
| 74a253f521 | |||
| 4dbc54fa15 | |||
| 1d4ff796d7 | |||
| 44cb54a9ea |
+6
-1
@@ -1,6 +1,8 @@
|
||||
.DS_Store
|
||||
.idea
|
||||
.vscode
|
||||
.venv
|
||||
.vs
|
||||
.wechaty/
|
||||
__pycache__/
|
||||
venv*
|
||||
@@ -22,6 +24,9 @@ plugins/**/
|
||||
!plugins/tool
|
||||
!plugins/banwords
|
||||
!plugins/banwords/**/
|
||||
plugins/banwords/__pycache__
|
||||
plugins/banwords/lib/__pycache__
|
||||
!plugins/hello
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
@@ -5,11 +5,12 @@
|
||||
最新版本支持的功能如下:
|
||||
|
||||
- [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信,微信公众号和企业微信应用等部署方式
|
||||
- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3,GPT-3.5,GPT-4模型
|
||||
- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3, GPT-3.5, GPT-4, 文心一言, 讯飞星火
|
||||
- [x] **语音识别:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai等多种语音模型
|
||||
- [x] **图片生成:** 支持图片生成 和 图生图(如照片修复),可选择 Dell-E, stable diffusion, replicate模型
|
||||
- [x] **图片生成:** 支持图片生成 和 图生图(如照片修复),可选择 Dell-E, stable diffusion, replicate, midjourney模型
|
||||
- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结等插件
|
||||
- [X] **Tool工具:** 与操作系统和互联网交互,支持最新信息搜索、数学计算、天气和资讯查询、网页总结,基于 [chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) 实现
|
||||
- [x] **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、领域知识库、智能客服使用,基于 [LinkAI](https://chat.link-ai.tech/console) 实现
|
||||
|
||||
> 欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
|
||||
|
||||
@@ -27,7 +28,9 @@ Demo made by [Visionn](https://www.wangpc.cc/)
|
||||
|
||||
# 更新日志
|
||||
|
||||
>**2023.06.12:** 接入 [LinkAI](https://chat.link-ai.tech/console) 平台,可在线创建 个人知识库,并接入微信、公众号及企业微信中。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
>**2023.08.08:** 接入百度文心一言模型,通过 [插件](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai) 支持 Midjourney 绘图
|
||||
|
||||
>**2023.06.12:** 接入 [LinkAI](https://chat.link-ai.tech/console) 平台,可在线创建个人知识库,并接入微信、公众号及企业微信中,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
>**2023.04.26:** 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatcom/README.md)。(contributed by [@lanvent](https://github.com/lanvent) in [#944](https://github.com/zhayujie/chatgpt-on-wechat/pull/944))
|
||||
|
||||
@@ -110,8 +113,8 @@ pip3 install azure-cognitiveservices-speech
|
||||
# config.json文件内容示例
|
||||
{
|
||||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
|
||||
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口
|
||||
"model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
@@ -123,9 +126,13 @@ pip3 install azure-cognitiveservices-speech
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
|
||||
"azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称
|
||||
"azure_api_version": "", # 采用Azure ChatGPT时,API版本
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
# 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。"
|
||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ
|
||||
"linkai_api_key": "", # LinkAI Api Key
|
||||
"linkai_app_code": "" # LinkAI 应用code
|
||||
}
|
||||
```
|
||||
**配置说明:**
|
||||
@@ -150,7 +157,7 @@ pip3 install azure-cognitiveservices-speech
|
||||
|
||||
**4.其他配置**
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未完全开放,申请通过后可使用)
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`, `wenxin` (其中gpt-4 api暂未完全开放,申请通过后可使用)
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||
@@ -162,6 +169,12 @@ pip3 install azure-cognitiveservices-speech
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
|
||||
**5.LinkAI配置 (可选)**
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://chat.link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用code,选填
|
||||
|
||||
**本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
|
||||
|
||||
## 运行
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests, json
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
|
||||
BAIDU_API_KEY = conf().get("baidu_wenxin_api_key")
|
||||
BAIDU_SECRET_KEY = conf().get("baidu_wenxin_secret_key")
|
||||
|
||||
class BaiduWenxinBot(Bot):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("baidu_wenxin_model") or "eb-instant")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[BAIDU] query={}".format(query))
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
if query == "#清除记忆":
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
else:
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
result = self.reply_text(session)
|
||||
total_tokens, completion_tokens, reply_content = (
|
||||
result["total_tokens"],
|
||||
result["completion_tokens"],
|
||||
result["content"],
|
||||
)
|
||||
logger.debug(
|
||||
"[BAIDU] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content, completion_tokens)
|
||||
)
|
||||
|
||||
if total_tokens == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
|
||||
try:
|
||||
logger.info("[BAIDU] model={}".format(session.model))
|
||||
access_token = self.get_access_token()
|
||||
if access_token == 'None':
|
||||
logger.warn("[BAIDU] access token 获取失败")
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"content": 0,
|
||||
}
|
||||
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + session.model + "?access_token=" + access_token
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
payload = {'messages': session.messages}
|
||||
response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
|
||||
response_text = json.loads(response.text)
|
||||
logger.info(f"[BAIDU] response text={response_text}")
|
||||
res_content = response_text["result"]
|
||||
total_tokens = response_text["usage"]["total_tokens"]
|
||||
completion_tokens = response_text["usage"]["completion_tokens"]
|
||||
logger.info("[BAIDU] reply={}".format(res_content))
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": res_content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
logger.warn("[BAIDU] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
result = {"completion_tokens": 0, "content": "出错了: {}".format(e)}
|
||||
return result
|
||||
|
||||
def get_access_token(self):
|
||||
"""
|
||||
使用 AK,SK 生成鉴权签名(Access Token)
|
||||
:return: access_token,或是None(如果错误)
|
||||
"""
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
params = {"grant_type": "client_credentials", "client_id": BAIDU_API_KEY, "client_secret": BAIDU_SECRET_KEY}
|
||||
return str(requests.post(url, params=params).json().get("access_token"))
|
||||
@@ -0,0 +1,53 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
"""
|
||||
e.g. [
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
class BaiduWenxinSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
# 百度文心不支持system prompt
|
||||
# self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) >= 2:
|
||||
self.messages.pop(0)
|
||||
self.messages.pop(0)
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
# 官方token计算规则暂不明确: "大约为 token数为 "中文字 + 其他语种单词数 x 1.3"
|
||||
# 这里先直接根据字数粗略估算吧,暂不影响正常使用,仅在判断是否丢弃历史会话的时候会有偏差
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
+9
-8
@@ -11,31 +11,32 @@ def create_bot(bot_type):
|
||||
:return: bot instance
|
||||
"""
|
||||
if bot_type == const.BAIDU:
|
||||
# Baidu Unit对话接口
|
||||
from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
||||
|
||||
return BaiduUnitBot()
|
||||
# 替换Baidu Unit为Baidu文心千帆对话接口
|
||||
# from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
||||
# return BaiduUnitBot()
|
||||
from bot.baidu.baidu_wenxin import BaiduWenxinBot
|
||||
return BaiduWenxinBot()
|
||||
|
||||
elif bot_type == const.CHATGPT:
|
||||
# ChatGPT 网页端web接口
|
||||
from bot.chatgpt.chat_gpt_bot import ChatGPTBot
|
||||
|
||||
return ChatGPTBot()
|
||||
|
||||
elif bot_type == const.OPEN_AI:
|
||||
# OpenAI 官方对话模型API
|
||||
from bot.openai.open_ai_bot import OpenAIBot
|
||||
|
||||
return OpenAIBot()
|
||||
|
||||
elif bot_type == const.CHATGPTONAZURE:
|
||||
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
|
||||
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
|
||||
|
||||
return AzureChatGPTBot()
|
||||
|
||||
elif bot_type == const.XUNFEI:
|
||||
from bot.xunfei.xunfei_spark_bot import XunFeiBot
|
||||
return XunFeiBot()
|
||||
|
||||
elif bot_type == const.LINKAI:
|
||||
from bot.linkai.link_ai_bot import LinkAIBot
|
||||
return LinkAIBot()
|
||||
|
||||
raise RuntimeError
|
||||
|
||||
@@ -166,7 +166,7 @@ class AzureChatGPTBot(ChatGPTBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_type = "azure"
|
||||
openai.api_version = "2023-03-15-preview"
|
||||
openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
|
||||
self.args["deployment_id"] = conf().get("azure_deployment_id")
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
|
||||
@@ -55,11 +55,16 @@ class ChatGPTSession(Session):
|
||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
|
||||
if model in ["wenxin", "xunfei"]:
|
||||
return num_tokens_by_character(messages)
|
||||
|
||||
import tiktoken
|
||||
|
||||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo"]:
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]:
|
||||
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k"]:
|
||||
return num_tokens_from_messages(messages, model="gpt-4")
|
||||
|
||||
try:
|
||||
@@ -85,3 +90,11 @@ def num_tokens_from_messages(messages, model):
|
||||
num_tokens += tokens_per_name
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens
|
||||
|
||||
|
||||
def num_tokens_by_character(messages):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
|
||||
+38
-30
@@ -22,25 +22,30 @@ class LinkAIBot(Bot, OpenAIImage):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_url = "https://api.link-ai.chat/v1"
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
return self._chat(query, context)
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
ok, res = self.create_img(query, 0)
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
reply = Reply(ReplyType.IMAGE_URL, res)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
reply = Reply(ReplyType.ERROR, res)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def _chat(self, query, context, retry_count=0):
|
||||
def _chat(self, query, context, retry_count=0) -> Reply:
|
||||
"""
|
||||
发起对话请求
|
||||
:param query: 请求提示词
|
||||
:param context: 对话上下文
|
||||
:param retry_count: 当前递归重试次数
|
||||
:return: 回复
|
||||
"""
|
||||
if retry_count >= 2:
|
||||
# exit from retry 2 times
|
||||
logger.warn("[LINKAI] failed after maximum number of retry times")
|
||||
@@ -52,53 +57,56 @@ class LinkAIBot(Bot, OpenAIImage):
|
||||
logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
|
||||
app_code = None
|
||||
else:
|
||||
app_code = conf().get("linkai_app_code")
|
||||
app_code = context.kwargs.get("app_code") or conf().get("linkai_app_code")
|
||||
linkai_api_key = conf().get("linkai_api_key")
|
||||
|
||||
session_id = context["session_id"]
|
||||
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
|
||||
model = conf().get("model") or "gpt-3.5-turbo"
|
||||
# remove system message
|
||||
if app_code and session.messages[0].get("role") == "system":
|
||||
session.messages.pop(0)
|
||||
|
||||
logger.info(f"[LINKAI] query={query}, app_code={app_code}")
|
||||
if session.messages[0].get("role") == "system":
|
||||
if app_code or model == "wenxin":
|
||||
session.messages.pop(0)
|
||||
|
||||
body = {
|
||||
"appCode": app_code,
|
||||
"app_code": app_code,
|
||||
"messages": session.messages,
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin
|
||||
"temperature": conf().get("temperature"),
|
||||
"top_p": conf().get("top_p", 1),
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
}
|
||||
logger.info(f"[LINKAI] query={query}, app_code={app_code}, mode={body.get('model')}")
|
||||
headers = {"Authorization": "Bearer " + linkai_api_key}
|
||||
|
||||
# do http request
|
||||
res = requests.post(url=self.base_url + "/chat/completion", json=body, headers=headers).json()
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
|
||||
timeout=conf().get("request_timeout", 180))
|
||||
if res.status_code == 200:
|
||||
# execute success
|
||||
response = res.json()
|
||||
reply_content = response["choices"][0]["message"]["content"]
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
|
||||
if not res or not res["success"]:
|
||||
if res.get("code") == self.AUTH_FAILED_CODE:
|
||||
logger.exception(f"[LINKAI] please check your linkai_api_key, res={res}")
|
||||
return Reply(ReplyType.ERROR, "请再问我一次吧")
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
elif res.get("code") == self.NO_QUOTA_CODE:
|
||||
logger.exception(f"[LINKAI] please check your account quota, https://chat.link-ai.tech/console/account")
|
||||
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
|
||||
|
||||
else:
|
||||
# retry
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
|
||||
# execute success
|
||||
reply_content = res["data"]["content"]
|
||||
logger.info(f"[LINKAI] reply={reply_content}")
|
||||
self.sessions.session_reply(reply_content, session_id)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests, json
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from common import const
|
||||
import time
|
||||
import _thread as thread
|
||||
import datetime
|
||||
from datetime import datetime
|
||||
from wsgiref.handlers import format_date_time
|
||||
from urllib.parse import urlencode
|
||||
import base64
|
||||
import ssl
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from time import mktime
|
||||
from urllib.parse import urlparse
|
||||
import websocket
|
||||
import queue
|
||||
import threading
|
||||
import random
|
||||
|
||||
# 消息队列 map
|
||||
queue_map = dict()
|
||||
|
||||
# 响应队列 map
|
||||
reply_map = dict()
|
||||
|
||||
|
||||
class XunFeiBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app_id = conf().get("xunfei_app_id")
|
||||
self.api_key = conf().get("xunfei_api_key")
|
||||
self.api_secret = conf().get("xunfei_api_secret")
|
||||
# 默认使用v2.0版本,1.5版本可设置为 general
|
||||
self.domain = "generalv2"
|
||||
# 默认使用v2.0版本,1.5版本可设置为 "ws://spark-api.xf-yun.com/v1.1/chat"
|
||||
self.spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"
|
||||
self.host = urlparse(self.spark_url).netloc
|
||||
self.path = urlparse(self.spark_url).path
|
||||
# 和wenxin使用相同的session机制
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[XunFei] query={}".format(query))
|
||||
session_id = context["session_id"]
|
||||
request_id = self.gen_request_id(session_id)
|
||||
reply_map[request_id] = ""
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
threading.Thread(target=self.create_web_socket, args=(session.messages, request_id)).start()
|
||||
depth = 0
|
||||
time.sleep(0.1)
|
||||
t1 = time.time()
|
||||
usage = {}
|
||||
while depth <= 300:
|
||||
try:
|
||||
data_queue = queue_map.get(request_id)
|
||||
if not data_queue:
|
||||
depth += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
data_item = data_queue.get(block=True, timeout=0.1)
|
||||
if data_item.is_end:
|
||||
# 请求结束
|
||||
del queue_map[request_id]
|
||||
if data_item.reply:
|
||||
reply_map[request_id] += data_item.reply
|
||||
usage = data_item.usage
|
||||
break
|
||||
|
||||
reply_map[request_id] += data_item.reply
|
||||
depth += 1
|
||||
except Exception as e:
|
||||
depth += 1
|
||||
continue
|
||||
t2 = time.time()
|
||||
logger.info(f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}")
|
||||
self.sessions.session_reply(reply_map[request_id], session_id, usage.get("total_tokens"))
|
||||
reply = Reply(ReplyType.TEXT, reply_map[request_id])
|
||||
del reply_map[request_id]
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def create_web_socket(self, prompt, session_id, temperature=0.5):
|
||||
logger.info(f"[XunFei] start connect, prompt={prompt}")
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = self.create_url()
|
||||
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close,
|
||||
on_open=on_open)
|
||||
data_queue = queue.Queue(1000)
|
||||
queue_map[session_id] = data_queue
|
||||
ws.appid = self.app_id
|
||||
ws.question = prompt
|
||||
ws.domain = self.domain
|
||||
ws.session_id = session_id
|
||||
ws.temperature = temperature
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
|
||||
def gen_request_id(self, session_id: str):
|
||||
return session_id + "_" + str(int(time.time())) + "" + str(random.randint(0, 100))
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + self.host + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + self.path + " HTTP/1.1"
|
||||
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
|
||||
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
||||
|
||||
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
|
||||
f'signature="{signature_sha_base64}"'
|
||||
|
||||
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {
|
||||
"authorization": authorization,
|
||||
"date": date,
|
||||
"host": self.host
|
||||
}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = self.spark_url + '?' + urlencode(v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
return url
|
||||
|
||||
def gen_params(self, appid, domain, question):
|
||||
"""
|
||||
通过appid和用户的提问来生成请参数
|
||||
"""
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": appid,
|
||||
"uid": "1234"
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": domain,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": 2048,
|
||||
"auditing": "default"
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": question
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
class ReplyItem:
|
||||
def __init__(self, reply, usage=None, is_end=False):
|
||||
self.is_end = is_end
|
||||
self.reply = reply
|
||||
self.usage = usage
|
||||
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(ws, error):
|
||||
logger.error(f"[XunFei] error: {str(error)}")
|
||||
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(ws, one, two):
|
||||
data_queue = queue_map.get(ws.session_id)
|
||||
data_queue.put("END")
|
||||
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(ws):
|
||||
logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
|
||||
thread.start_new_thread(run, (ws,))
|
||||
|
||||
|
||||
def run(ws, *args):
|
||||
data = json.dumps(gen_params(appid=ws.appid, domain=ws.domain, question=ws.question, temperature=ws.temperature))
|
||||
ws.send(data)
|
||||
|
||||
|
||||
# Websocket 操作
|
||||
# 收到websocket消息的处理
|
||||
def on_message(ws, message):
|
||||
data = json.loads(message)
|
||||
code = data['header']['code']
|
||||
if code != 0:
|
||||
logger.error(f'请求错误: {code}, {data}')
|
||||
ws.close()
|
||||
else:
|
||||
choices = data["payload"]["choices"]
|
||||
status = choices["status"]
|
||||
content = choices["text"][0]["content"]
|
||||
data_queue = queue_map.get(ws.session_id)
|
||||
if not data_queue:
|
||||
logger.error(f"[XunFei] can't find data queue, session_id={ws.session_id}")
|
||||
return
|
||||
reply_item = ReplyItem(content)
|
||||
if status == 2:
|
||||
usage = data["payload"].get("usage")
|
||||
reply_item = ReplyItem(content, usage)
|
||||
reply_item.is_end = True
|
||||
ws.close()
|
||||
data_queue.put(reply_item)
|
||||
|
||||
|
||||
def gen_params(appid, domain, question, temperature=0.5):
|
||||
"""
|
||||
通过appid和用户的提问来生成请参数
|
||||
"""
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": appid,
|
||||
"uid": "1234"
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": domain,
|
||||
"temperature": temperature,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": 2048,
|
||||
"auditing": "default"
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": question
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
@@ -23,6 +23,10 @@ class Bridge(object):
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype["chat"] = const.CHATGPTONAZURE
|
||||
if model_type in ["wenxin"]:
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
self.btype["chat"] = const.LINKAI
|
||||
self.bots = {}
|
||||
@@ -54,3 +58,9 @@ class Bridge(object):
|
||||
|
||||
def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
|
||||
return self.get_bot("translate").translate(text, from_lang, to_lang)
|
||||
|
||||
def reset_bot(self):
|
||||
"""
|
||||
重置bot路由
|
||||
"""
|
||||
self.__init__()
|
||||
|
||||
@@ -108,8 +108,12 @@ class ChatChannel(Channel):
|
||||
if not conf().get("group_at_off", False):
|
||||
flag = True
|
||||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
|
||||
content = re.sub(pattern, r"", content)
|
||||
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
if subtract_res == content and context["msg"].self_display_name:
|
||||
# 前缀移除后没有变化,使用群昵称再次移除
|
||||
pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
content = subtract_res
|
||||
if not flag:
|
||||
if context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
||||
|
||||
@@ -24,9 +24,7 @@ is_at: 是否被at
|
||||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
|
||||
actual_user_id: 实际发送者id (群聊必填)
|
||||
actual_user_nickname:实际发送者昵称
|
||||
|
||||
|
||||
|
||||
self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称
|
||||
|
||||
_prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
|
||||
_prepared: 是否已经调用过准备函数
|
||||
@@ -48,6 +46,8 @@ class ChatMessage(object):
|
||||
to_user_nickname = None
|
||||
other_user_id = None
|
||||
other_user_nickname = None
|
||||
my_msg = False
|
||||
self_display_name = None
|
||||
|
||||
is_group = False
|
||||
is_at = False
|
||||
|
||||
@@ -58,6 +58,9 @@ def _check(func):
|
||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||
return
|
||||
if cmsg.my_msg and not cmsg.is_group:
|
||||
logger.debug("[WX]my message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
@@ -189,10 +192,14 @@ class WechatChannel(ChatChannel):
|
||||
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
logger.debug(f"[WX] start download image, img_url={img_url}")
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
size = 0
|
||||
for block in pic_res.iter_content(1024):
|
||||
size += len(block)
|
||||
image_storage.write(block)
|
||||
logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
|
||||
@@ -57,13 +57,19 @@ class WechatMessage(ChatMessage):
|
||||
self.from_user_nickname = nickname
|
||||
if self.to_user_id == user_id:
|
||||
self.to_user_nickname = nickname
|
||||
try: # 陌生人时候, 'User'字段可能不存在
|
||||
try: # 陌生人时候, User字段可能不存在
|
||||
# my_msg 为True是表示是自己发送的消息
|
||||
self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \
|
||||
itchat_msg["ToUserName"] != itchat_msg["FromUserName"]
|
||||
self.other_user_id = itchat_msg["User"]["UserName"]
|
||||
self.other_user_nickname = itchat_msg["User"]["NickName"]
|
||||
if self.other_user_id == self.from_user_id:
|
||||
self.from_user_nickname = self.other_user_nickname
|
||||
if self.other_user_id == self.to_user_id:
|
||||
self.to_user_nickname = self.other_user_nickname
|
||||
if itchat_msg["User"].get("Self"):
|
||||
# 自身的展示名,当设置了群昵称时,该字段表示群昵称
|
||||
self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName")
|
||||
except KeyError as e: # 处理偶尔没有对方信息的情况
|
||||
logger.warn("[WX]get other_user_id failed: " + str(e))
|
||||
if self.from_user_id == user_id:
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
OPEN_AI = "openAI"
|
||||
CHATGPT = "chatGPT"
|
||||
BAIDU = "baidu"
|
||||
XUNFEI = "xunfei"
|
||||
CHATGPTONAZURE = "chatGPTOnAzure"
|
||||
LINKAI = "linkai"
|
||||
|
||||
VERSION = "1.3.0"
|
||||
|
||||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "xunfei"]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"open_ai_api_key": "YOUR API KEY",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"channel_type": "wx",
|
||||
"proxy": "",
|
||||
"hot_reload": false,
|
||||
"single_chat_prefix": [
|
||||
@@ -29,6 +30,8 @@
|
||||
"conversation_max_tokens": 1000,
|
||||
"expires_in_seconds": 3600,
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"temperature": 0.7,
|
||||
"top_p": 1,
|
||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false,
|
||||
"linkai_api_key": "",
|
||||
|
||||
@@ -16,9 +16,10 @@ available_setting = {
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"proxy": "", # openai使用的代理
|
||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"model": "gpt-3.5-turbo",
|
||||
"model": "gpt-3.5-turbo", # 还支持 gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
||||
"azure_deployment_id": "", # azure 模型部署名称
|
||||
"azure_api_version": "", # azure api版本
|
||||
# Bot触发配置
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
@@ -50,13 +51,21 @@ available_setting = {
|
||||
"presence_penalty": 0,
|
||||
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试
|
||||
# Baidu 文心一言参数
|
||||
"baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型
|
||||
"baidu_wenxin_api_key": "", # Baidu api key
|
||||
"baidu_wenxin_secret_key": "", # Baidu secret key
|
||||
# 讯飞星火API
|
||||
"xunfei_app_id": "", # 讯飞应用ID
|
||||
"xunfei_api_key": "", # 讯飞 API key
|
||||
"xunfei_api_secret": "", # 讯飞 API secret
|
||||
# 语音设置
|
||||
"speech_recognition": False, # 是否开启语音识别
|
||||
"group_speech_recognition": False, # 是否开启群组语音识别
|
||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
|
||||
"always_reply_voice": False, # 是否一直使用语音回复
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure
|
||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure
|
||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure,elevenlabs
|
||||
# baidu 语音api配置, 使用百度语音识别和语音合成时需要
|
||||
"baidu_app_id": "",
|
||||
"baidu_api_key": "",
|
||||
@@ -66,6 +75,9 @@ available_setting = {
|
||||
# azure 语音api配置, 使用azure语音识别和语音合成时需要
|
||||
"azure_voice_api_key": "",
|
||||
"azure_voice_region": "japaneast",
|
||||
# elevenlabs 语音api配置
|
||||
"xi_api_key": "", #获取ap的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication
|
||||
"xi_voice_id": "", #ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam”
|
||||
# 服务时间限制,目前支持itchat
|
||||
"chat_time_module": False, # 是否开启服务时间限制
|
||||
"chat_start_time": "00:00", # 服务开始时间
|
||||
@@ -102,10 +114,13 @@ available_setting = {
|
||||
"appdata_dir": "", # 数据目录
|
||||
# 插件配置
|
||||
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
|
||||
# 是否使用全局插件配置
|
||||
"use_global_plugin_config": False,
|
||||
# 知识库平台配置
|
||||
"use_linkai": False,
|
||||
"linkai_api_key": "",
|
||||
"linkai_app_code": ""
|
||||
"linkai_app_code": "",
|
||||
"linkai_api_base": "https://api.link-ai.chat" # linkAI服务地址,若国内无法访问或延迟较高可改为 https://api.link-ai.tech
|
||||
}
|
||||
|
||||
|
||||
@@ -252,3 +267,9 @@ def pconf(plugin_name: str) -> dict:
|
||||
:return: 该插件的配置项
|
||||
"""
|
||||
return plugin_config.get(plugin_name.lower())
|
||||
|
||||
|
||||
# 全局配置,用于存放全局生效的状态
|
||||
global_config = {
|
||||
"admin_users": []
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.10-slim
|
||||
FROM python:3.10-slim-bullseye
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
@@ -32,4 +32,4 @@ RUN chmod +x /entrypoint.sh \
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
|
||||
@@ -18,6 +18,7 @@ services:
|
||||
SPEECH_RECOGNITION: 'False'
|
||||
CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。'
|
||||
EXPIRES_IN_SECONDS: 3600
|
||||
USE_GLOBAL_PLUGIN_CONFIG: 'True'
|
||||
USE_LINKAI: 'False'
|
||||
LINKAI_API_KEY: ''
|
||||
LINKAI_APP_CODE: ''
|
||||
|
||||
@@ -20,5 +20,19 @@
|
||||
"no_default": false,
|
||||
"model_name": "gpt-3.5-turbo"
|
||||
}
|
||||
},
|
||||
"linkai": {
|
||||
"group_app_map": {
|
||||
"测试群1": "default",
|
||||
"测试群2": "Kv2fXJcH"
|
||||
},
|
||||
"midjourney": {
|
||||
"enabled": true,
|
||||
"auto_translate": true,
|
||||
"img_proxy": true,
|
||||
"max_tasks": 3,
|
||||
"max_tasks_per_user": 1,
|
||||
"use_image_create_prefix": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import traceback
|
||||
from typing import Tuple
|
||||
|
||||
import plugins
|
||||
@@ -12,8 +11,7 @@ from bridge.bridge import Bridge
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from config import conf, load_config, global_config
|
||||
from plugins import *
|
||||
|
||||
# 定义指令集
|
||||
@@ -32,6 +30,10 @@ COMMANDS = {
|
||||
"args": ["口令"],
|
||||
"desc": "管理员认证",
|
||||
},
|
||||
"model": {
|
||||
"alias": ["model", "模型"],
|
||||
"desc": "查看和设置全局模型",
|
||||
},
|
||||
"set_openai_api_key": {
|
||||
"alias": ["set_openai_api_key"],
|
||||
"args": ["api_key"],
|
||||
@@ -257,6 +259,18 @@ class Godcmd(Plugin):
|
||||
break
|
||||
if not ok:
|
||||
result = "插件不存在或未启用"
|
||||
elif cmd == "model":
|
||||
if not isadmin and not self.is_admin_in_group(e_context["context"]):
|
||||
ok, result = False, "需要管理员权限执行"
|
||||
elif len(args) == 0:
|
||||
ok, result = True, "当前模型为: " + str(conf().get("model"))
|
||||
elif len(args) == 1:
|
||||
if args[0] not in const.MODEL_LIST:
|
||||
ok, result = False, "模型名称不存在"
|
||||
else:
|
||||
conf()["model"] = args[0]
|
||||
Bridge().reset_bot()
|
||||
ok, result = True, "模型设置为: " + str(conf().get("model"))
|
||||
elif cmd == "id":
|
||||
ok, result = True, user
|
||||
elif cmd == "set_openai_api_key":
|
||||
@@ -294,7 +308,7 @@ class Godcmd(Plugin):
|
||||
except Exception as e:
|
||||
ok, result = False, "你没有设置私有GPT模型"
|
||||
elif cmd == "reset":
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI]:
|
||||
bot.sessions.clear_session(session_id)
|
||||
channel.cancel_session(session_id)
|
||||
ok, result = True, "会话已重置"
|
||||
@@ -317,7 +331,8 @@ class Godcmd(Plugin):
|
||||
load_config()
|
||||
ok, result = True, "配置已重载"
|
||||
elif cmd == "resetall":
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
|
||||
const.BAIDU, const.XUNFEI]:
|
||||
channel.cancel_all_session()
|
||||
bot.sessions.clear_all_session()
|
||||
ok, result = True, "重置所有会话成功"
|
||||
@@ -426,12 +441,20 @@ class Godcmd(Plugin):
|
||||
password = args[0]
|
||||
if password == self.password:
|
||||
self.admin_users.append(userid)
|
||||
global_config["admin_users"].append(userid)
|
||||
return True, "认证成功"
|
||||
elif password == self.temp_password:
|
||||
self.admin_users.append(userid)
|
||||
global_config["admin_users"].append(userid)
|
||||
return True, "认证成功,请尽快设置口令"
|
||||
else:
|
||||
return False, "认证失败"
|
||||
|
||||
def get_help_text(self, isadmin=False, isgroup=False, **kwargs):
|
||||
return get_help_text(isadmin, isgroup)
|
||||
|
||||
|
||||
def is_admin_in_group(self, context):
|
||||
if context["isgroup"]:
|
||||
return context.kwargs.get("msg").actual_user_id in global_config["admin_users"]
|
||||
return False
|
||||
|
||||
@@ -54,9 +54,18 @@ class Keyword(Plugin):
|
||||
logger.debug(f"[keyword] 匹配到关键字【{content}】")
|
||||
reply_text = self.keyword[content]
|
||||
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = reply_text
|
||||
# 判断匹配内容的类型
|
||||
if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".webp"]):
|
||||
# 如果是以 http:// 或 https:// 开头,且.jpg/.jpeg/.png/.gif结尾,则认为是图片 URL
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.IMAGE_URL
|
||||
reply.content = reply_text
|
||||
else:
|
||||
# 否则认为是普通文本
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = reply_text
|
||||
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
## 插件说明
|
||||
|
||||
基于 LinkAI 提供的知识库、Midjourney绘画等能力对机器人的功能进行增强。平台地址: https://chat.link-ai.tech/console
|
||||
|
||||
## 插件配置
|
||||
|
||||
将 `plugins/linkai` 目录下的 `config.json.template` 配置模板复制为最终生效的 `config.json`:
|
||||
|
||||
以下是配置项说明:
|
||||
|
||||
```bash
|
||||
{
|
||||
"group_app_map": { # 群聊 和 应用编码 的映射关系
|
||||
"测试群1": "default", # 表示在名称为 "测试群1" 的群聊中将使用app_code 为 default 的应用
|
||||
"测试群2": "Kv2fXJcH"
|
||||
},
|
||||
"midjourney": {
|
||||
"enabled": true, # midjourney 绘画开关
|
||||
"auto_translate": true, # 是否自动将提示词翻译为英文
|
||||
"img_proxy": true, # 是否对生成的图片使用代理,如果你是国外服务器,将这一项设置为false会获得更快的生成速度
|
||||
"max_tasks": 3, # 支持同时提交的总任务个数
|
||||
"max_tasks_per_user": 1, # 支持单个用户同时提交的任务个数
|
||||
"use_image_create_prefix": true # 是否使用全局的绘画触发词,如果开启将同时支持由`config.json`中的 image_create_prefix 配置触发
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
注意:
|
||||
|
||||
- 配置项中 `group_app_map` 部分是用于映射群聊与LinkAI平台上的应用, `midjourney` 部分是 mj 画图的配置,可根据需要进行填写,未填写配置时默认不开启相应功能
|
||||
- 实际 `config.json` 配置中应保证json格式,不应携带 '#' 及后面的注释
|
||||
- 如果是`docker`部署,可通过映射 `plugins/config.json` 到容器中来完成插件配置,参考[文档](https://github.com/zhayujie/chatgpt-on-wechat#3-%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8)
|
||||
|
||||
## 插件使用
|
||||
|
||||
> 使用插件中的知识库管理功能需要首先开启`linkai`对话,依赖全局 `config.json` 中的 `use_linkai` 和 `linkai_api_key` 配置;而midjourney绘画功能则只需填写 `linkai_api_key` 配置,`use_linkai` 无论是否关闭均可使用。具体可参考 [详细文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
完成配置后运行项目,会自动运行插件,输入 `#help linkai` 可查看插件功能。
|
||||
|
||||
### 1.知识库管理功能
|
||||
|
||||
提供在不同群聊使用不同应用的功能。可以在上述 `group_app_map` 配置中固定映射关系,也可以通过指令在群中快速完成切换。
|
||||
|
||||
应用切换指令需要首先完成管理员 (`godcmd`) 插件的认证,然后按以下格式输入:
|
||||
|
||||
`$linkai app {app_code}`
|
||||
|
||||
例如输入 `$linkai app Kv2fXJcH`,即将当前群聊与 app_code为 Kv2fXJcH 的应用绑定。
|
||||
|
||||
另外,还可以通过 `$linkai close` 来一键关闭linkai对话,此时就会使用默认的openai接口;同理,发送 `$linkai open` 可以再次开启。
|
||||
|
||||
### 2.Midjourney绘画功能
|
||||
|
||||
指令格式:
|
||||
|
||||
```
|
||||
- 图片生成: $mj 描述词1, 描述词2..
|
||||
- 图片放大: $mju 图片ID 图片序号
|
||||
- 图片变换: $mjv 图片ID 图片序号
|
||||
- 重置: $mjr 图片ID
|
||||
```
|
||||
|
||||
例如:
|
||||
|
||||
```
|
||||
"$mj a little cat, white --ar 9:16"
|
||||
"$mju 1105592717188272288 2"
|
||||
"$mjv 11055927171882 2"
|
||||
"$mjr 11055927171882"
|
||||
```
|
||||
|
||||
注:
|
||||
1. 开启 `use_image_create_prefix` 配置后可直接复用全局画图触发词,以"画"开头便可以生成图片。
|
||||
2. 提示词内容中包含敏感词或者参数格式错误可能导致绘画失败,生成失败不消耗积分
|
||||
3. 使用 `$mj open` 和 `$mj close` 指令可以快速打开和关闭绘图功能
|
||||
@@ -0,0 +1 @@
|
||||
from .linkai import *
|
||||
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"group_app_map": {
|
||||
"测试群1": "default",
|
||||
"测试群2": "Kv2fXJcH"
|
||||
},
|
||||
"midjourney": {
|
||||
"enabled": true,
|
||||
"auto_translate": true,
|
||||
"img_proxy": true,
|
||||
"max_tasks": 3,
|
||||
"max_tasks_per_user": 1,
|
||||
"use_image_create_prefix": true
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
import plugins
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import global_config
|
||||
from plugins import *
|
||||
from .midjourney import MJBot
|
||||
from bridge import bridge
|
||||
|
||||
|
||||
@plugins.register(
|
||||
name="linkai",
|
||||
desc="A plugin that supports knowledge base and midjourney drawing.",
|
||||
version="0.1.0",
|
||||
author="https://link-ai.tech",
|
||||
)
|
||||
class LinkAI(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
self.config = super().load_config()
|
||||
if self.config:
|
||||
self.mj_bot = MJBot(self.config.get("midjourney"))
|
||||
logger.info("[LinkAI] inited")
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
"""
|
||||
消息处理逻辑
|
||||
:param e_context: 消息上下文
|
||||
"""
|
||||
if not self.config:
|
||||
return
|
||||
|
||||
context = e_context['context']
|
||||
if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE]:
|
||||
# filter content no need solve
|
||||
return
|
||||
|
||||
mj_type = self.mj_bot.judge_mj_task_type(e_context)
|
||||
if mj_type:
|
||||
# MJ作图任务处理
|
||||
self.mj_bot.process_mj_task(mj_type, e_context)
|
||||
return
|
||||
|
||||
if context.content.startswith(f"{_get_trigger_prefix()}linkai"):
|
||||
# 应用管理功能
|
||||
self._process_admin_cmd(e_context)
|
||||
return
|
||||
|
||||
if self._is_chat_task(e_context):
|
||||
# 文本对话任务处理
|
||||
self._process_chat_task(e_context)
|
||||
|
||||
# 插件管理功能
|
||||
def _process_admin_cmd(self, e_context: EventContext):
|
||||
context = e_context['context']
|
||||
cmd = context.content.split()
|
||||
if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"):
|
||||
_set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
|
||||
# 知识库开关指令
|
||||
if not _is_admin(e_context):
|
||||
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
|
||||
return
|
||||
is_open = True
|
||||
tips_text = "开启"
|
||||
if cmd[1] == "close":
|
||||
tips_text = "关闭"
|
||||
is_open = False
|
||||
conf()["use_linkai"] = is_open
|
||||
bridge.Bridge().reset_bot()
|
||||
_set_reply_text(f"知识库功能已{tips_text}", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if len(cmd) == 3 and cmd[1] == "app":
|
||||
# 知识库应用切换指令
|
||||
if not context.kwargs.get("isgroup"):
|
||||
_set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR)
|
||||
return
|
||||
if not _is_admin(e_context):
|
||||
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
|
||||
return
|
||||
app_code = cmd[2]
|
||||
group_name = context.kwargs.get("msg").from_user_nickname
|
||||
group_mapping = self.config.get("group_app_map")
|
||||
if group_mapping:
|
||||
group_mapping[group_name] = app_code
|
||||
else:
|
||||
self.config["group_app_map"] = {group_name: app_code}
|
||||
# 保存插件配置
|
||||
super().save_config(self.config)
|
||||
_set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO)
|
||||
else:
|
||||
_set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context,
|
||||
level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
# LinkAI 对话任务处理
|
||||
def _is_chat_task(self, e_context: EventContext):
|
||||
context = e_context['context']
|
||||
# 群聊应用管理
|
||||
return self.config.get("group_app_map") and context.kwargs.get("isgroup")
|
||||
|
||||
def _process_chat_task(self, e_context: EventContext):
|
||||
"""
|
||||
处理LinkAI对话任务
|
||||
:param e_context: 对话上下文
|
||||
"""
|
||||
context = e_context['context']
|
||||
# 群聊应用管理
|
||||
group_name = context.kwargs.get("msg").from_user_nickname
|
||||
app_code = self._fetch_group_app_code(group_name)
|
||||
if app_code:
|
||||
context.kwargs['app_code'] = app_code
|
||||
|
||||
def _fetch_group_app_code(self, group_name: str) -> str:
|
||||
"""
|
||||
根据群聊名称获取对应的应用code
|
||||
:param group_name: 群聊名称
|
||||
:return: 应用code
|
||||
"""
|
||||
group_mapping = self.config.get("group_app_map")
|
||||
if group_mapping:
|
||||
app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
|
||||
return app_code
|
||||
|
||||
def get_help_text(self, verbose=False, **kwargs):
|
||||
trigger_prefix = _get_trigger_prefix()
|
||||
help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画等能力。\n\n"
|
||||
if not verbose:
|
||||
return help_text
|
||||
help_text += f'📖 知识库\n - 群聊中指定应用: {trigger_prefix}linkai app 应用编码\n'
|
||||
help_text += f' - {trigger_prefix}linkai open: 开启对话\n'
|
||||
help_text += f' - {trigger_prefix}linkai close: 关闭对话\n'
|
||||
help_text += f'\n例如: \n"{trigger_prefix}linkai app Kv2fXJcH"\n\n'
|
||||
help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: {trigger_prefix}mjv 图片ID 图片序号\n - 重置: {trigger_prefix}mjr 图片ID"
|
||||
help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
|
||||
help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
|
||||
return help_text
|
||||
|
||||
|
||||
# 静态方法
|
||||
def _is_admin(e_context: EventContext) -> bool:
|
||||
"""
|
||||
判断消息是否由管理员用户发送
|
||||
:param e_context: 消息上下文
|
||||
:return: True: 是, False: 否
|
||||
"""
|
||||
context = e_context["context"]
|
||||
if context["isgroup"]:
|
||||
return context.kwargs.get("msg").actual_user_id in global_config["admin_users"]
|
||||
else:
|
||||
return context["receiver"] in global_config["admin_users"]
|
||||
|
||||
|
||||
def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
|
||||
reply = Reply(level, content)
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
|
||||
|
||||
def _get_trigger_prefix():
|
||||
return conf().get("plugin_trigger_prefix", "$")
|
||||
@@ -0,0 +1,427 @@
|
||||
from enum import Enum
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import requests
|
||||
import threading
|
||||
import time
|
||||
from bridge.reply import Reply, ReplyType
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from bridge.context import ContextType
|
||||
from plugins import EventContext, EventAction
|
||||
|
||||
INVALID_REQUEST = 410
|
||||
NOT_FOUND_ORIGIN_IMAGE = 461
|
||||
NOT_FOUND_TASK = 462
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
GENERATE = "generate"
|
||||
UPSCALE = "upscale"
|
||||
VARIATION = "variation"
|
||||
RESET = "reset"
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
PENDING = "pending"
|
||||
FINISHED = "finished"
|
||||
EXPIRED = "expired"
|
||||
ABORTED = "aborted"
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class TaskMode(Enum):
|
||||
FAST = "fast"
|
||||
RELAX = "relax"
|
||||
|
||||
|
||||
task_name_mapping = {
|
||||
TaskType.GENERATE.name: "生成",
|
||||
TaskType.UPSCALE.name: "放大",
|
||||
TaskType.VARIATION.name: "变换",
|
||||
TaskType.RESET.name: "重新生成",
|
||||
}
|
||||
|
||||
|
||||
class MJTask:
|
||||
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 30,
|
||||
status=Status.PENDING):
|
||||
self.id = id
|
||||
self.user_id = user_id
|
||||
self.task_type = task_type
|
||||
self.raw_prompt = raw_prompt
|
||||
self.send_func = None # send_func(img_url)
|
||||
self.expiry_time = time.time() + expires
|
||||
self.status = status
|
||||
self.img_url = None # url
|
||||
self.img_id = None
|
||||
|
||||
def __str__(self):
|
||||
return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
|
||||
|
||||
|
||||
# midjourney bot
|
||||
class MJBot:
|
||||
def __init__(self, config):
|
||||
self.base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/img/midjourney"
|
||||
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
self.config = config
|
||||
self.tasks = {}
|
||||
self.temp_dict = {}
|
||||
self.tasks_lock = threading.Lock()
|
||||
self.event_loop = asyncio.new_event_loop()
|
||||
|
||||
def judge_mj_task_type(self, e_context: EventContext):
|
||||
"""
|
||||
判断MJ任务的类型
|
||||
:param e_context: 上下文
|
||||
:return: 任务类型枚举
|
||||
"""
|
||||
if not self.config:
|
||||
return None
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
context = e_context['context']
|
||||
if context.type == ContextType.TEXT:
|
||||
cmd_list = context.content.split(maxsplit=1)
|
||||
if cmd_list[0].lower() == f"{trigger_prefix}mj":
|
||||
return TaskType.GENERATE
|
||||
elif cmd_list[0].lower() == f"{trigger_prefix}mju":
|
||||
return TaskType.UPSCALE
|
||||
elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
|
||||
return TaskType.VARIATION
|
||||
elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
|
||||
return TaskType.RESET
|
||||
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix"):
|
||||
return TaskType.GENERATE
|
||||
|
||||
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
|
||||
"""
|
||||
处理mj任务
|
||||
:param mj_type: mj任务类型
|
||||
:param e_context: 对话上下文
|
||||
"""
|
||||
context = e_context['context']
|
||||
session_id = context["session_id"]
|
||||
cmd = context.content.split(maxsplit=1)
|
||||
if len(cmd) == 1 and context.type == ContextType.TEXT:
|
||||
# midjourney 帮助指令
|
||||
self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
|
||||
# midjourney 开关指令
|
||||
is_open = True
|
||||
tips_text = "开启"
|
||||
if cmd[1] == "close":
|
||||
tips_text = "关闭"
|
||||
is_open = False
|
||||
self.config["enabled"] = is_open
|
||||
self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if not self.config.get("enabled"):
|
||||
logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置")
|
||||
self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if not self._check_rate_limit(session_id, e_context):
|
||||
logger.warn("[MJ] midjourney task exceed rate limit")
|
||||
return
|
||||
|
||||
if mj_type == TaskType.GENERATE:
|
||||
if context.type == ContextType.IMAGE_CREATE:
|
||||
raw_prompt = context.content
|
||||
else:
|
||||
# 图片生成
|
||||
raw_prompt = cmd[1]
|
||||
reply = self.generate(raw_prompt, session_id, e_context)
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
elif mj_type == TaskType.UPSCALE or mj_type == TaskType.VARIATION:
|
||||
# 图片放大/变换
|
||||
clist = cmd[1].split()
|
||||
if len(clist) < 2:
|
||||
self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
|
||||
return
|
||||
img_id = clist[0]
|
||||
index = int(clist[1])
|
||||
if index < 1 or index > 4:
|
||||
self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context)
|
||||
return
|
||||
key = f"{str(mj_type)}_{img_id}_{index}"
|
||||
if self.temp_dict.get(key):
|
||||
self._set_reply_text(f"第 {index} 张图片已经{task_name_mapping.get(str(mj_type))}过了", e_context)
|
||||
return
|
||||
# 执行图片放大/变换操作
|
||||
reply = self.do_operate(mj_type, session_id, img_id, e_context, index)
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
elif mj_type == TaskType.RESET:
|
||||
# 图片重新生成
|
||||
clist = cmd[1].split()
|
||||
if len(clist) < 1:
|
||||
self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
|
||||
return
|
||||
img_id = clist[0]
|
||||
# 图片重新生成
|
||||
reply = self.do_operate(mj_type, session_id, img_id, e_context)
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
else:
|
||||
self._set_reply_text(f"暂不支持该命令", e_context)
|
||||
|
||||
def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply:
|
||||
"""
|
||||
图片生成
|
||||
:param prompt: 提示词
|
||||
:param user_id: 用户id
|
||||
:param e_context: 对话上下文
|
||||
:return: 任务ID
|
||||
"""
|
||||
logger.info(f"[MJ] image generate, prompt={prompt}")
|
||||
mode = self._fetch_mode(prompt)
|
||||
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
|
||||
if not self.config.get("img_proxy"):
|
||||
body["img_proxy"] = False
|
||||
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40))
|
||||
if res.status_code == 200:
|
||||
res = res.json()
|
||||
logger.debug(f"[MJ] image generate, res={res}")
|
||||
if res.get("code") == 200:
|
||||
task_id = res.get("data").get("task_id")
|
||||
real_prompt = res.get("data").get("real_prompt")
|
||||
if mode == TaskMode.RELAX.value:
|
||||
time_str = "1~10分钟"
|
||||
else:
|
||||
time_str = "1分钟"
|
||||
content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
|
||||
if real_prompt:
|
||||
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
|
||||
else:
|
||||
content += f"prompt: {prompt}"
|
||||
reply = Reply(ReplyType.INFO, content)
|
||||
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id,
|
||||
task_type=TaskType.GENERATE)
|
||||
# put to memory dict
|
||||
self.tasks[task.id] = task
|
||||
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
||||
self._do_check_task(task, e_context)
|
||||
return reply
|
||||
else:
|
||||
res_json = res.json()
|
||||
logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
|
||||
if res.status_code == INVALID_REQUEST:
|
||||
reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容")
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
|
||||
return reply
|
||||
|
||||
def do_operate(self, task_type: TaskType, user_id: str, img_id: str, e_context: EventContext,
|
||||
index: int = None) -> Reply:
|
||||
logger.info(f"[MJ] image operate, task_type={task_type}, img_id={img_id}, index={index}")
|
||||
body = {"type": task_type.name, "img_id": img_id}
|
||||
if index:
|
||||
body["index"] = index
|
||||
if not self.config.get("img_proxy"):
|
||||
body["img_proxy"] = False
|
||||
res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
|
||||
logger.debug(res)
|
||||
if res.status_code == 200:
|
||||
res = res.json()
|
||||
if res.get("code") == 200:
|
||||
task_id = res.get("data").get("task_id")
|
||||
logger.info(f"[MJ] image operate processing, task_id={task_id}")
|
||||
icon_map = {TaskType.UPSCALE: "🔎", TaskType.VARIATION: "🪄", TaskType.RESET: "🔄"}
|
||||
content = f"{icon_map.get(task_type)}图片正在{task_name_mapping.get(task_type.name)}中,请耐心等待"
|
||||
reply = Reply(ReplyType.INFO, content)
|
||||
task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=task_type)
|
||||
# put to memory dict
|
||||
self.tasks[task.id] = task
|
||||
key = f"{task_type.name}_{img_id}_{index}"
|
||||
self.temp_dict[key] = True
|
||||
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
||||
self._do_check_task(task, e_context)
|
||||
return reply
|
||||
else:
|
||||
error_msg = ""
|
||||
if res.status_code == NOT_FOUND_ORIGIN_IMAGE:
|
||||
error_msg = "请输入正确的图片ID"
|
||||
res_json = res.json()
|
||||
logger.error(f"[MJ] operate error, msg={res_json.get('message')}, status_code={res.status_code}")
|
||||
reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
|
||||
return reply
|
||||
|
||||
def check_task_sync(self, task: MJTask, e_context: EventContext):
|
||||
logger.debug(f"[MJ] start check task status, {task}")
|
||||
max_retry_times = 90
|
||||
while max_retry_times > 0:
|
||||
time.sleep(10)
|
||||
url = f"{self.base_url}/tasks/{task.id}"
|
||||
try:
|
||||
res = requests.get(url, headers=self.headers, timeout=8)
|
||||
if res.status_code == 200:
|
||||
res_json = res.json()
|
||||
logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
|
||||
f"data={res_json.get('data')}, thread={threading.current_thread().name}")
|
||||
if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
|
||||
# process success res
|
||||
if self.tasks.get(task.id):
|
||||
self.tasks[task.id].status = Status.FINISHED
|
||||
self._process_success_task(task, res_json.get("data"), e_context)
|
||||
return
|
||||
max_retry_times -= 1
|
||||
else:
|
||||
res_json = res.json()
|
||||
logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
|
||||
max_retry_times -= 20
|
||||
except Exception as e:
|
||||
max_retry_times -= 20
|
||||
logger.warn(e)
|
||||
logger.warn("[MJ] end from poll")
|
||||
if self.tasks.get(task.id):
|
||||
self.tasks[task.id].status = Status.EXPIRED
|
||||
|
||||
def _do_check_task(self, task: MJTask, e_context: EventContext):
|
||||
threading.Thread(target=self.check_task_sync, args=(task, e_context)).start()
|
||||
|
||||
def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
|
||||
"""
|
||||
处理任务成功的结果
|
||||
:param task: MJ任务
|
||||
:param res: 请求结果
|
||||
:param e_context: 对话上下文
|
||||
"""
|
||||
# channel send img
|
||||
task.status = Status.FINISHED
|
||||
task.img_id = res.get("img_id")
|
||||
task.img_url = res.get("img_url")
|
||||
logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
|
||||
|
||||
# send img
|
||||
reply = Reply(ReplyType.IMAGE_URL, task.img_url)
|
||||
channel = e_context["channel"]
|
||||
_send(channel, reply, e_context["context"])
|
||||
|
||||
# send info
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
text = ""
|
||||
if task.task_type == TaskType.GENERATE or task.task_type == TaskType.VARIATION or task.task_type == TaskType.RESET:
|
||||
text = f"🎨绘画完成!\n"
|
||||
if task.raw_prompt:
|
||||
text += f"prompt: {task.raw_prompt}\n"
|
||||
text += f"- - - - - - - - -\n图片ID: {task.img_id}"
|
||||
text += f"\n\n🔎使用 {trigger_prefix}mju 命令放大图片\n"
|
||||
text += f"例如:\n{trigger_prefix}mju {task.img_id} 1"
|
||||
text += f"\n\n🪄使用 {trigger_prefix}mjv 命令变换图片\n"
|
||||
text += f"例如:\n{trigger_prefix}mjv {task.img_id} 1"
|
||||
text += f"\n\n🔄使用 {trigger_prefix}mjr 命令重新生成图片\n"
|
||||
text += f"例如:\n{trigger_prefix}mjr {task.img_id}"
|
||||
reply = Reply(ReplyType.INFO, text)
|
||||
_send(channel, reply, e_context["context"])
|
||||
|
||||
self._print_tasks()
|
||||
return
|
||||
|
||||
def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
|
||||
"""
|
||||
midjourney任务限流控制
|
||||
:param user_id: 用户id
|
||||
:param e_context: 对话上下文
|
||||
:return: 任务是否能够生成, True:可以生成, False: 被限流
|
||||
"""
|
||||
tasks = self.find_tasks_by_user_id(user_id)
|
||||
task_count = len([t for t in tasks if t.status == Status.PENDING])
|
||||
if task_count >= self.config.get("max_tasks_per_user"):
|
||||
reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return False
|
||||
task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
|
||||
if task_count >= self.config.get("max_tasks"):
|
||||
reply = Reply(ReplyType.INFO, "Midjourney作图任务数已达上限,请稍后再试")
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return False
|
||||
return True
|
||||
|
||||
def _fetch_mode(self, prompt) -> str:
|
||||
mode = self.config.get("mode")
|
||||
if "--relax" in prompt or mode == TaskMode.RELAX.value:
|
||||
return TaskMode.RELAX.value
|
||||
return mode or TaskMode.FAST.value
|
||||
|
||||
def _run_loop(self, loop: asyncio.BaseEventLoop):
|
||||
"""
|
||||
运行事件循环,用于轮询任务的线程
|
||||
:param loop: 事件循环
|
||||
"""
|
||||
loop.run_forever()
|
||||
loop.stop()
|
||||
|
||||
def _print_tasks(self):
|
||||
for id in self.tasks:
|
||||
logger.debug(f"[MJ] current task: {self.tasks[id]}")
|
||||
|
||||
def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
|
||||
"""
|
||||
设置回复文本
|
||||
:param content: 回复内容
|
||||
:param e_context: 对话上下文
|
||||
:param level: 回复等级
|
||||
"""
|
||||
reply = Reply(level, content)
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
|
||||
def get_help_text(self, verbose=False, **kwargs):
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
help_text = "🎨利用Midjourney进行画图\n\n"
|
||||
if not verbose:
|
||||
return help_text
|
||||
help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID"
|
||||
help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
|
||||
help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
|
||||
return help_text
|
||||
|
||||
def find_tasks_by_user_id(self, user_id) -> list:
|
||||
result = []
|
||||
with self.tasks_lock:
|
||||
now = time.time()
|
||||
for task in self.tasks.values():
|
||||
if task.status == Status.PENDING and now > task.expiry_time:
|
||||
task.status = Status.EXPIRED
|
||||
logger.info(f"[MJ] {task} expired")
|
||||
if task.user_id == user_id:
|
||||
result.append(task)
|
||||
return result
|
||||
|
||||
|
||||
def _send(channel, reply: Reply, context, retry_cnt=0):
|
||||
try:
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error("[WX] sendMsg error: {}".format(str(e)))
|
||||
if isinstance(e, NotImplementedError):
|
||||
return
|
||||
logger.exception(e)
|
||||
if retry_cnt < 2:
|
||||
time.sleep(3 + 3 * retry_cnt)
|
||||
channel.send(reply, context, retry_cnt + 1)
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
if not prefix_list:
|
||||
return None
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
+21
-4
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
from config import pconf
|
||||
from config import pconf, plugin_config, conf
|
||||
from common.log import logger
|
||||
|
||||
|
||||
@@ -15,14 +15,31 @@ class Plugin:
|
||||
"""
|
||||
# 优先获取 plugins/config.json 中的全局配置
|
||||
plugin_conf = pconf(self.name)
|
||||
if not plugin_conf:
|
||||
# 全局配置不存在,则获取插件目录下的配置
|
||||
if not plugin_conf or not conf().get("use_global_plugin_config"):
|
||||
# 全局配置不存在 或者 未开启全局配置开关,则获取插件目录下的配置
|
||||
plugin_config_path = os.path.join(self.path, "config.json")
|
||||
if os.path.exists(plugin_config_path):
|
||||
with open(plugin_config_path, "r") as f:
|
||||
with open(plugin_config_path, "r", encoding="utf-8") as f:
|
||||
plugin_conf = json.load(f)
|
||||
logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
|
||||
return plugin_conf
|
||||
|
||||
def save_config(self, config: dict):
|
||||
try:
|
||||
plugin_config[self.name] = config
|
||||
# 写入全局配置
|
||||
global_config_path = "./plugins/config.json"
|
||||
if os.path.exists(global_config_path):
|
||||
with open(global_config_path, "w", encoding='utf-8') as f:
|
||||
json.dump(plugin_config, f, indent=4, ensure_ascii=False)
|
||||
# 写入插件配置
|
||||
plugin_config_path = os.path.join(self.path, "config.json")
|
||||
if os.path.exists(plugin_config_path):
|
||||
with open(plugin_config_path, "w", encoding='utf-8') as f:
|
||||
json.dump(config, f, indent=4, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
logger.warn("save plugin config failed: {}".format(e))
|
||||
|
||||
def get_help_text(self, **kwargs):
|
||||
return "暂无帮助信息"
|
||||
|
||||
@@ -23,6 +23,8 @@ web.py
|
||||
wechatpy
|
||||
|
||||
# chatgpt-tool-hub plugin
|
||||
|
||||
--extra-index-url https://pypi.python.org/simple
|
||||
chatgpt_tool_hub==0.4.6
|
||||
|
||||
# xunfei spark
|
||||
websocket-client==1.2.0
|
||||
|
||||
@@ -6,3 +6,4 @@ requests>=2.28.2
|
||||
chardet>=5.1.0
|
||||
Pillow
|
||||
pre-commit
|
||||
web.py
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import time
|
||||
|
||||
from elevenlabs import set_api_key,generate
|
||||
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from voice.voice import Voice
|
||||
from config import conf
|
||||
|
||||
XI_API_KEY = conf().get("xi_api_key")
|
||||
set_api_key(XI_API_KEY)
|
||||
name = conf().get("xi_voice_id")
|
||||
|
||||
class ElevenLabsVoice(Voice):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
pass
|
||||
|
||||
def textToVoice(self, text):
|
||||
audio = generate(
|
||||
text=text,
|
||||
voice=name,
|
||||
model='eleven_multilingual_v1'
|
||||
)
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
|
||||
with open(fileName, "wb") as f:
|
||||
f.write(audio)
|
||||
logger.info("[ElevenLabs] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
return Reply(ReplyType.VOICE, fileName)
|
||||
@@ -29,4 +29,8 @@ def create_voice(voice_type):
|
||||
from voice.azure.azure_voice import AzureVoice
|
||||
|
||||
return AzureVoice()
|
||||
elif voice_type == "elevenlabs":
|
||||
from voice.elevent.elevent_voice import ElevenLabsVoice
|
||||
|
||||
return ElevenLabsVoice()
|
||||
raise RuntimeError
|
||||
|
||||
Reference in New Issue
Block a user