Mirror of https://github.com/zhayujie/chatgpt-on-wechat.git (synced 2026-05-10 13:32:14 +08:00)

Compare commits (69 commits)
Commits in this range (author and date columns were empty in the mirrored view):

a77e4bfb7a, 654c177333, b92669ba33, f2e4f6607d, 5ec909c565, a84f31d54a, e0dd21406d, 72f5f7a0b8, e3d20085c5, 8bf1aef801, 5f7ade20dc, 70d7e52df0, 8e6afa5614, a1ae3804e3, 814ce7a43b, 628f75009e, 03fc8c1202, 8c8e996c87, 933bb0b1fb, 931fbc3eb5, 3db5e70a3d, 7b19b70d90, 99b8103d70, 7167310ccd, 263667a2d4, d5cef291f6, c8d166e833, 6e25782d8b, c3127f7e84, 7b90fb018b, e8bc173cd7, 4d1cdf5207, 57a473364e, 40b62e9d38, ead5f9926b, 814b6753c2, ce505251f8, 5d2a987aaa, 4d67e08723, 2e71dd5fe2, c3b9643227, 0aad5dc2b7, cec900168f, f9b1c403d5, 9024b602f5, c139fd9a57, e299b68163, 7777a53a82, 3e185dbbfe, e8a32af369, 7b0ec6687e, ec1c6c7b92, 8dfaa86760, 323aebd1be, 436c038a2f, ccd50ec6c0, a7541c2c0f, c3a57d756c, aa300a4c98, 83ea7352b9, 9050712cd8, 8d92fdbb6e, a2442ec1b9, 71662c9cd9, 54ff5dbcc2, 4ab7bd3b51, ef3c61a297, abf79bf60c, 5d3cecd926
```
@@ -14,6 +14,9 @@ tmp
plugins.json
itchat.pkl
*.log
logs/
workspace
config.yaml
user_datas.pkl
chatgpt_tool_hub/
plugins/**/
@@ -30,4 +33,5 @@ plugins/banwords/lib/__pycache__
!plugins/role
!plugins/keyword
!plugins/linkai
!plugins/agent
client_config.json
```
```diff
@@ -1,11 +1,19 @@
-# Introduction
+<p align="center"><img src= "https://github.com/user-attachments/assets/31fb4eab-3be4-477d-aa76-82cf62bfd12c" alt="Chatgpt-on-Wechat" width="600" /></p>

-> chatgpt-on-wechat (CoW) is an LLM-based intelligent chatbot that can be connected to WeChat Official Accounts, WeCom (WeChat Work) apps, Feishu, and DingTalk. Model options include GPT-3.5/GPT-4.0/Claude/Gemini/LinkAI/ChatGLM/KIMI/ERNIE Bot/iFlytek Spark/Tongyi Qianwen. It handles text, voice, and images, can reach external resources such as the operating system and the internet through plugins, and supports building custom enterprise AI applications on a private knowledge base.
+<p align="center">
+<a href="https://github.com/zhayujie/chatgpt-on-wechat/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/chatgpt-on-wechat" alt="Latest release"></a>
+<a href="https://github.com/zhayujie/chatgpt-on-wechat/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/chatgpt-on-wechat" alt="License: MIT"></a>
+<a href="https://github.com/zhayujie/chatgpt-on-wechat"><img src="https://img.shields.io/github/stars/zhayujie/chatgpt-on-wechat?style=flat-square" alt="Stars"></a> <br/>
+</p>
+
+chatgpt-on-wechat (CoW) is an LLM-based intelligent chatbot that can be connected to WeChat Official Accounts, WeCom (WeChat Work) apps, Feishu, and DingTalk. Model options include GPT-3.5/GPT-4.0/Claude/Gemini/LinkAI/ChatGLM/KIMI/ERNIE Bot/iFlytek Spark/Tongyi Qianwen/ModelScope. It handles text, voice, and images, can reach external resources such as the operating system and the internet through plugins, and supports building custom enterprise AI applications on a private knowledge base.
+
+# Introduction

 The latest version supports the following features:

 - ✅ **Multi-platform deployment:** multiple fully featured deployment options; WeChat Official Accounts, WeCom apps, Feishu, and DingTalk are currently supported
-- ✅ **Basic chat:** smart replies in private and group chats with multi-turn context memory; supports GPT-3.5, GPT-4o-mini, GPT-4o, GPT-4, Claude-3.5, Gemini, ERNIE Bot, iFlytek Spark, Tongyi Qianwen, ChatGLM-4, Kimi (Moonshot), MiniMax
+- ✅ **Basic chat:** smart replies in private and group chats with multi-turn context memory; supports the GPT-4o series, the GPT-4.1 series, Claude, Gemini, ERNIE Bot, iFlytek Spark, Tongyi Qianwen, ChatGLM-4, Kimi, MiniMax, GiteeAI, ModelScope
 - ✅ **Voice:** recognizes voice messages and replies in text or voice; supports azure, baidu, google, and openai (whisper/tts) speech models
 - ✅ **Image:** image generation, image recognition, and image-to-image (e.g. photo restoration); Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, and vision models are available
 - ✅ **Rich plugins:** personalized plugin extensions; role switching, text adventure, banned-word filtering, chat-log summarization, document summarization and Q&A, web search, and more are already implemented
@@ -45,6 +53,13 @@ DEMO视频:https://cdn.link-ai.tech/doc/cow_demo.mp4

 <br>

 # 🏷 Changelog

+>**2025.05.23:** [Version 1.7.6](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.6) improved the web channel, added the [AgentMesh multi-agent plugin](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md), improved Baidu speech synthesis and WeCom app `access_token` retrieval, and added support for the `claude-4-sonnet` and `claude-4-opus` models
+
+>**2025.04.11:** [Version 1.7.5](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.5) added support for the [wechatferry](https://github.com/zhayujie/chatgpt-on-wechat/pull/2562) protocol, deepseek models, Tencent Cloud speech, and the ModelScope and Gitee-AI API interfaces
+
 >**2024.12.13:** [Version 1.7.4](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.4) added Gemini 2.0 models and a web channel, fixed a memory leak, and fixed the `#reloadp` reload command not taking effect

 >**2024.10.31:** [Version 1.7.3](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.3) improved program stability, added database support, optimized Claude models and the linkai plugin, and added offline notifications

 >**2024.09.26:** [Version 1.7.2](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.2) and [Version 1.7.1](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.1) optimized the ERNIE and iFlytek models and added o1 models plus quick install and management scripts
```
@@ -142,7 +157,7 @@ pip3 install -r requirements-optional.txt
|
||||
```bash
|
||||
# config.json文件内容示例
|
||||
{
|
||||
"model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-4, gpt-4-turbo, wenxin, xunfei, glm-4, claude-3-haiku, moonshot
|
||||
"model": "gpt-4o-mini", # 模型名称, 支持 gpt-4o-mini, gpt-4.1, gpt-4o, wenxin, xunfei, glm-4, claude-3-7-sonnet-latest, moonshot等
|
||||
"open_ai_api_key": "YOUR API KEY", # 如果使用openAI模型则填入上面创建的 OpenAI API KEY
|
||||
"open_ai_api_base": "https://api.openai.com/v1", # OpenAI接口代理地址
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
|
||||
@@ -186,7 +201,7 @@ pip3 install -r requirements-optional.txt
|
||||
|
||||
**4.其他配置**
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `gpt-4o-mini`, `gpt-4o`, `gpt-4`, `wenxin` , `claude` , `gemini`, `glm-4`, `xunfei`, `moonshot`等,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件
|
||||
+ `model`: 模型名称,目前支持 `gpt-4o-mini`, `gpt-4.1`, `gpt-4o`, `gpt-3.5-turbo`, `wenxin` , `claude` , `gemini`, `glm-4`, `xunfei`, `moonshot`等,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||
|
||||
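The sample config above contains `#` comments inside config.json, which standard JSON parsers reject, so the file has to be comment-stripped before parsing. A minimal sketch of such a comment-tolerant loader (`load_commented_json` is a hypothetical helper, not the project's actual config reader; it assumes no escaped quotes inside strings):

```python
import json

def load_commented_json(text: str) -> dict:
    """Parse config.json-style text that may contain `#` comments."""
    cleaned = []
    for line in text.splitlines():
        in_string = False
        kept = []
        for ch in line:
            if ch == '"':
                in_string = not in_string  # naive quote tracking
            if ch == '#' and not in_string:
                break  # drop the comment up to the end of the line
            kept.append(ch)
        cleaned.append("".join(kept))
    return json.loads("\n".join(cleaned))

sample = '''{
  "model": "gpt-4o-mini",  # model name
  "proxy": ""
}'''
config = load_commented_json(sample)
print(config["model"])  # gpt-4o-mini
```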
@@ -68,5 +68,9 @@ def create_bot(bot_type):
|
||||
from bot.minimax.minimax_bot import MinimaxBot
|
||||
return MinimaxBot()
|
||||
|
||||
elif bot_type == const.MODELSCOPE:
|
||||
from bot.modelscope.modelscope_bot import ModelScopeBot
|
||||
return ModelScopeBot()
|
||||
|
||||
|
||||
raise RuntimeError
|
||||
|
||||
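The factory above dispatches on `bot_type` and imports each bot module lazily, so optional dependencies are only loaded for the bot actually in use. A minimal self-contained sketch of that pattern (`StubBot` is a hypothetical stand-in for the real bot classes):

```python
class StubBot:
    """Hypothetical stand-in for the real bot classes."""
    def __init__(self, name):
        self.name = name

def create_bot(bot_type: str) -> StubBot:
    # In the real factory each branch performs a lazy `from ... import ...`;
    # here a lambda per type stands in for that deferred construction.
    factories = {
        "minimax": lambda: StubBot("minimax"),
        "modelscope": lambda: StubBot("modelscope"),
    }
    if bot_type not in factories:
        # mirrors the bare `raise RuntimeError` at the end of the real factory
        raise RuntimeError(f"unsupported bot type: {bot_type}")
    return factories[bot_type]()

bot = create_bot("modelscope")
print(bot.name)  # modelscope
```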
@@ -83,7 +83,7 @@ def num_tokens_from_messages(messages, model):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
|
||||
logger.debug(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
|
||||
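The hunk above only downgrades the log level of the tokenizer fallback; the accounting itself adds a fixed per-message overhead before counting content tokens. A minimal sketch of that scheme (the whitespace-splitting `encode` is a stand-in for a real tokenizer such as tiktoken; the overhead constants follow OpenAI's published accounting for the gpt-3.5/gpt-4 chat families):

```python
def num_tokens_from_messages(messages, encode=lambda s: s.split()):
    # Fixed overheads per message and per `name` field, plus the tokens of
    # every field value as counted by the stand-in tokenizer.
    tokens_per_message = 3
    tokens_per_name = 1
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with assistant-role tokens
    return num_tokens

msgs = [{"role": "user", "content": "hello world"}]
total = num_tokens_from_messages(msgs)
print(total)  # 9
```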
@@ -147,9 +147,9 @@ class LinkAIBot(Bot):
|
||||
if response["choices"][0].get("img_urls"):
|
||||
thread = threading.Thread(target=self._send_image, args=(context.get("channel"), context, response["choices"][0].get("img_urls")))
|
||||
thread.start()
|
||||
if response["choices"][0].get("text_content"):
|
||||
reply_content = response["choices"][0].get("text_content")
|
||||
reply_content = self._process_url(reply_content)
|
||||
reply_content = response["choices"][0].get("text_content")
|
||||
if reply_content:
|
||||
reply_content = self._process_url(reply_content)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
|
||||
else:
|
||||
|
||||
@@ -0,0 +1,277 @@ (new file: bot/modelscope/modelscope_bot.py)

```python
# encoding:utf-8

import time
import json

import requests

from bot.bot import Bot
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf, load_config
from .modelscope_session import ModelScopeSession


# Bot for the ModelScope chat-completion API
class ModelScopeBot(Bot):
    def __init__(self):
        super().__init__()
        self.sessions = SessionManager(ModelScopeSession, model=conf().get("model") or "Qwen/Qwen2.5-7B-Instruct")
        model = conf().get("model") or "Qwen/Qwen2.5-7B-Instruct"
        if model == "modelscope":
            model = "Qwen/Qwen2.5-7B-Instruct"
        self.args = {
            "model": model,  # chat model name
            "temperature": conf().get("temperature", 0.3),  # if set, must lie in [0, 1]; 0.3 is recommended
            "top_p": conf().get("top_p", 1.0),  # use the default value
        }
        self.api_key = conf().get("modelscope_api_key")
        self.base_url = conf().get("modelscope_base_url", "https://api-inference.modelscope.cn/v1/chat/completions")
        """
        For the list of ModelScope models that support API-Inference, see the model hub on the ModelScope site:
        https://modelscope.cn/models?filter=inference_type&page=1, or fetch model IDs with
        `curl https://api-inference.modelscope.cn/v1/models`; common/const.py also lists the supported models.
        For how to obtain a free ModelScope API key, see https://modelscope.cn/docs/model-service/API-Inference/intro.
        """
```
```python
    def reply(self, query, context=None):
        # acquire reply content
        if context.type == ContextType.TEXT:
            logger.info("[MODELSCOPE_AI] query={}".format(query))

            session_id = context["session_id"]
            reply = None
            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
            if query in clear_memory_commands:
                self.sessions.clear_session(session_id)
                reply = Reply(ReplyType.INFO, "记忆已清除")
            elif query == "#清除所有":
                self.sessions.clear_all_session()
                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
            elif query == "#更新配置":
                load_config()
                reply = Reply(ReplyType.INFO, "配置已更新")
            if reply:
                return reply
            session = self.sessions.session_query(query, session_id)
            logger.debug("[MODELSCOPE_AI] session query={}".format(session.messages))

            model = context.get("modelscope_model")
            new_args = self.args.copy()
            if model:
                new_args["model"] = model

            if new_args["model"] == "Qwen/QwQ-32B":
                reply_content = self.reply_text_stream(session, args=new_args)
            else:
                reply_content = self.reply_text(session, args=new_args)

            logger.debug(
                "[MODELSCOPE_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
                    session.messages,
                    session_id,
                    reply_content["content"],
                    reply_content["completion_tokens"],
                )
            )
            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
                # completion_tokens of 0 with non-empty content means a canned
                # error message from reply_text(); return it as plain text
                reply = Reply(ReplyType.TEXT, reply_content["content"])
            elif reply_content["completion_tokens"] > 0:
                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
                reply = Reply(ReplyType.TEXT, reply_content["content"])
            else:
                reply = Reply(ReplyType.ERROR, reply_content["content"])
                logger.debug("[MODELSCOPE_AI] reply {} used 0 tokens.".format(reply_content))
            return reply
        elif context.type == ContextType.IMAGE_CREATE:
            ok, retstring = self.create_img(query, 0)
            reply = None
            if ok:
                reply = Reply(ReplyType.IMAGE_URL, retstring)
            else:
                reply = Reply(ReplyType.ERROR, retstring)
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply
```
```python
    def reply_text(self, session: ModelScopeSession, args=None, retry_count=0) -> dict:
        """
        Call the ModelScope chat-completion API to get the answer.
        :param session: a conversation session
        :param retry_count: retry count
        :return: {}
        """
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": "Bearer " + self.api_key
            }

            body = args
            body["messages"] = session.messages
            res = requests.post(
                self.base_url,
                headers=headers,
                data=json.dumps(body)
            )

            if res.status_code == 200:
                response = res.json()
                return {
                    "total_tokens": response["usage"]["total_tokens"],
                    "completion_tokens": response["usage"]["completion_tokens"],
                    "content": response["choices"][0]["message"]["content"]
                }
            else:
                response = res.json()
                if "errors" in response:
                    error = response.get("errors")
                elif "error" in response:
                    error = response.get("error")
                else:
                    error = {"message": "Unknown error"}
                logger.error(f"[MODELSCOPE_AI] chat failed, status_code={res.status_code}, "
                             f"msg={error.get('message')}, type={error.get('type')}")

                result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
                need_retry = False
                if res.status_code >= 500:
                    # server error, retry
                    logger.warn(f"[MODELSCOPE_AI] do retry, times={retry_count}")
                    need_retry = retry_count < 2
                elif res.status_code == 401:
                    result["content"] = "授权失败,请检查API Key是否正确"
                elif res.status_code == 429:
                    result["content"] = "请求过于频繁,请稍后再试"
                    need_retry = retry_count < 2
                else:
                    need_retry = False

                if need_retry:
                    time.sleep(3)
                    return self.reply_text(session, args, retry_count + 1)
                else:
                    return result
        except Exception as e:
            logger.exception(e)
            need_retry = retry_count < 2
            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
            if need_retry:
                return self.reply_text(session, args, retry_count + 1)
            else:
                return result
```
```python
    def reply_text_stream(self, session: ModelScopeSession, args=None, retry_count=0) -> dict:
        """
        Call the ModelScope chat-completion API with a streamed response.
        :param session: a conversation session
        :param retry_count: retry count
        :return: {}
        """
        try:
            headers = {
                "Content-Type": "application/json",
                "Authorization": "Bearer " + self.api_key
            }

            body = args
            body["messages"] = session.messages
            body["stream"] = True  # enable streaming

            res = requests.post(
                self.base_url,
                headers=headers,
                data=json.dumps(body),
                stream=True
            )
            if res.status_code == 200:
                content = ""
                for line in res.iter_lines():
                    if line:
                        decoded_line = line.decode('utf-8')
                        if decoded_line.startswith("data: "):
                            try:
                                json_data = json.loads(decoded_line[6:])
                                delta_content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
                                if delta_content:
                                    content += delta_content
                            except json.JSONDecodeError:
                                pass
                return {
                    "total_tokens": 1,  # streamed responses usually do not report token usage
                    "completion_tokens": 1,
                    "content": content
                }
            else:
                response = res.json()
                if "errors" in response:
                    error = response.get("errors")
                elif "error" in response:
                    error = response.get("error")
                else:
                    error = {"message": "Unknown error"}
                logger.error(f"[MODELSCOPE_AI] chat failed, status_code={res.status_code}, "
                             f"msg={error.get('message')}, type={error.get('type')}")

                result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
                need_retry = False
                if res.status_code >= 500:
                    # server error, retry
                    logger.warn(f"[MODELSCOPE_AI] do retry, times={retry_count}")
                    need_retry = retry_count < 2
                elif res.status_code == 401:
                    result["content"] = "授权失败,请检查API Key是否正确"
                elif res.status_code == 429:
                    result["content"] = "请求过于频繁,请稍后再试"
                    need_retry = retry_count < 2
                else:
                    need_retry = False

                if need_retry:
                    time.sleep(3)
                    return self.reply_text_stream(session, args, retry_count + 1)
                else:
                    return result
        except Exception as e:
            logger.exception(e)
            need_retry = retry_count < 2
            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
            if need_retry:
                return self.reply_text_stream(session, args, retry_count + 1)
            else:
                return result
```
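For `Qwen/QwQ-32B` the bot reads the response as a Server-Sent Events stream and concatenates each chunk's `delta.content`. A minimal stand-alone sketch of that accumulation loop (the `data: [DONE]` sentinel follows the OpenAI-style streaming convention; the in-repo loop instead simply ignores non-JSON lines):

```python
import json

def accumulate_sse_content(lines):
    """Assemble the assistant text from OpenAI-style SSE chunk lines."""
    content = ""
    for raw in lines:
        line = raw.strip()
        if not line.startswith("data: "):
            continue  # skip comments/heartbeats
        payload = line[len("data: "):]
        if payload == "[DONE]":  # end-of-stream sentinel
            break
        try:
            chunk = json.loads(payload)
        except json.JSONDecodeError:
            continue  # skip malformed or partial lines
        delta = chunk.get("choices", [{}])[0].get("delta", {})
        content += delta.get("content", "") or ""
    return content

stream = [
    'data: {"choices":[{"delta":{"content":"Hel"}}]}',
    'data: {"choices":[{"delta":{"content":"lo"}}]}',
    "data: [DONE]",
]
print(accumulate_sse_content(stream))  # Hello
```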
```python
    def create_img(self, query, retry_count=0):
        try:
            logger.info("[ModelScopeImage] image_query={}".format(query))
            headers = {
                "Content-Type": "application/json; charset=utf-8",  # state the encoding explicitly
                "Authorization": f"Bearer {self.api_key}"
            }
            payload = {
                "prompt": query,  # required
                "n": 1,
                "model": conf().get("text_to_image"),
            }
            url = "https://api-inference.modelscope.cn/v1/images/generations"

            # serialize manually and keep Chinese characters (disable ASCII escaping)
            json_payload = json.dumps(payload, ensure_ascii=False).encode('utf-8')

            # send the raw bytes via the data parameter (requests handles the rest)
            res = requests.post(url, headers=headers, data=json_payload)

            response_data = res.json()
            image_url = response_data['images'][0]['url']
            logger.info("[ModelScopeImage] image_url={}".format(image_url))
            return True, image_url

        except Exception as e:
            logger.error(format(e))
            return False, "画图出现问题,请休息一下再问我吧"
```
@@ -0,0 +1,51 @@ (new file: bot/modelscope/modelscope_session.py)

```python
from bot.session_manager import Session
from common.log import logger


class ModelScopeSession(Session):
    def __init__(self, session_id, system_prompt=None, model="Qwen/Qwen2.5-7B-Instruct"):
        super().__init__(session_id, system_prompt)
        self.model = model
        self.reset()

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        precise = True
        try:
            cur_tokens = self.calc_tokens()
        except Exception as e:
            precise = False
            if cur_tokens is None:
                raise e
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
        while cur_tokens > max_tokens:
            if len(self.messages) > 2:
                self.messages.pop(1)
            elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
                self.messages.pop(1)
                if precise:
                    cur_tokens = self.calc_tokens()
                else:
                    cur_tokens = cur_tokens - max_tokens
                break
            elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
                logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
                break
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
                                                                                       len(self.messages)))
                break
            if precise:
                cur_tokens = self.calc_tokens()
            else:
                cur_tokens = cur_tokens - max_tokens
        return cur_tokens

    def calc_tokens(self):
        return num_tokens_from_messages(self.messages, self.model)


def num_tokens_from_messages(messages, model):
    # rough approximation: use the character count as the token count
    tokens = 0
    for msg in messages:
        tokens += len(msg["content"])
    return tokens
```
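The session above prunes history by character count rather than real tokens, always dropping the oldest non-system message first. A minimal stand-alone sketch of that pruning loop (a simplified version of `discard_exceeding`, keeping the system prompt at index 0):

```python
def discard_exceeding(messages, max_tokens, calc=None):
    """Drop the oldest non-system turns until the budget fits."""
    if calc is None:
        # same character-count approximation the session class uses
        calc = lambda ms: sum(len(m["content"]) for m in ms)
    cur = calc(messages)
    while cur > max_tokens and len(messages) > 2:
        messages.pop(1)  # index 0 holds the system prompt
        cur = calc(messages)
    return cur

history = [
    {"role": "system", "content": "be brief"},          # 8 chars
    {"role": "user", "content": "x" * 40},
    {"role": "assistant", "content": "y" * 40},
    {"role": "user", "content": "z" * 10},
]
remaining = discard_exceeding(history, 60)
print(remaining)  # 58
```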
@@ -41,7 +41,7 @@ class XunFeiBot(Bot):
|
||||
self.api_key = conf().get("xunfei_api_key")
|
||||
self.api_secret = conf().get("xunfei_api_secret")
|
||||
# 默认使用v2.0版本: "generalv2"
|
||||
# Spark Lite请求地址(spark_url): wss://spark-api.xf-yun.com/v1.1/chat, 对应的domain参数为: "general"
|
||||
# Spark Lite请求地址(spark_url): wss://spark-api.xf-yun.com/v1.1/chat, 对应的domain参数为: "lite"
|
||||
# Spark V2.0请求地址(spark_url): wss://spark-api.xf-yun.com/v2.1/chat, 对应的domain参数为: "generalv2"
|
||||
# Spark Pro 请求地址(spark_url): wss://spark-api.xf-yun.com/v3.1/chat, 对应的domain参数为: "generalv3"
|
||||
# Spark Pro-128K请求地址(spark_url): wss://spark-api.xf-yun.com/chat/pro-128k, 对应的domain参数为: "pro-128k"
|
||||
|
||||
+4 −1

```diff
@@ -40,7 +40,7 @@ class Bridge(object):
             self.btype["chat"] = const.GEMINI
         if model_type and model_type.startswith("glm"):
             self.btype["chat"] = const.ZHIPU_AI
-        if model_type and model_type.startswith("claude-3"):
+        if model_type and model_type.startswith("claude"):
             self.btype["chat"] = const.CLAUDEAPI

         if model_type in ["claude"]:
@@ -49,6 +49,9 @@ class Bridge(object):
         if model_type in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
             self.btype["chat"] = const.MOONSHOT

+        if model_type in [const.MODELSCOPE]:
+            self.btype["chat"] = const.MODELSCOPE
+
         if model_type in ["abab6.5-chat"]:
             self.btype["chat"] = const.MiniMax
```
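The change above widens the `claude-3` prefix check so every `claude` model routes to the Claude API backend. A minimal sketch of this prefix-plus-exact-match routing (backend names here are illustrative placeholders, not the repo's `const` values):

```python
def select_chat_backend(model_type: str) -> str:
    """Map a model name to a backend id, prefix checks first."""
    if model_type.startswith("glm"):
        return "zhipu_ai"
    if model_type.startswith("claude"):  # widened from "claude-3" by this change
        return "claude_api"
    if model_type == "modelscope":  # exact match added by this change
        return "modelscope"
    return "openai"  # default backend

print(select_chat_backend("claude-4-sonnet"))  # claude_api
```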
@@ -18,6 +18,9 @@ def create_channel(channel_type) -> Channel:
|
||||
elif channel_type == "wxy":
|
||||
from channel.wechat.wechaty_channel import WechatyChannel
|
||||
ch = WechatyChannel()
|
||||
elif channel_type == "wcf":
|
||||
from channel.wechat.wcf_channel import WechatfChannel
|
||||
ch = WechatfChannel()
|
||||
elif channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
ch = TerminalChannel()
|
||||
|
||||
@@ -146,6 +146,7 @@ class ChatChannel(Channel):
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
pass
|
||||
else:
|
||||
logger.info("[chat_channel]receive single chat msg, but checkprefix didn't match")
|
||||
return None
|
||||
content = content.strip()
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix",[""]))
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
# Web channel
|
||||
使用SSE(Server-Sent Events,服务器推送事件)实现,提供了一个默认的网页。也可以自己实现加入api
|
||||
# Web Channel
|
||||
|
||||
#使用方法
|
||||
- 在配置文件中channel_type填入web即可
|
||||
- 访问地址 http://localhost:9899
|
||||
- port可以在配置项 web_port中设置
|
||||
提供了一个默认的AI对话页面,可展示文本、图片等消息交互,支持markdown语法渲染,兼容插件执行。
|
||||
|
||||
# 使用说明
|
||||
|
||||
- 在 `config.json` 配置文件中的 `channel_type` 字段填入 `web`
|
||||
- 程序运行后将监听9899端口,浏览器访问 http://localhost:9899/chat 即可使用
|
||||
- 监听端口可以在配置文件 `web_port` 中自定义
|
||||
- 对于Docker运行方式,如果需要外部访问,需要在 `docker-compose.yml` 中通过 ports配置将端口监听映射到宿主机
|
||||
|
||||
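The new page works by POSTing a message, receiving a `request_id`, and then polling for the reply by `session_id`. A minimal sketch of the queue handoff between the send side and the poll side (simplified to module-level state instead of the channel class):

```python
from queue import Queue, Empty

session_queues = {}  # session_id -> queue of pending replies

def push_response(session_id, payload):
    # the send() side: queue the bot's reply under its session
    session_queues.setdefault(session_id, Queue()).put(payload)

def poll_response(session_id):
    # the /poll side: non-blocking get; report "no content yet" instead of waiting
    queue = session_queues.get(session_id)
    if queue is None:
        return {"status": "error", "message": "Invalid session ID"}
    try:
        item = queue.get(block=False)
        return {"status": "success", "has_content": True, "content": item}
    except Empty:
        return {"status": "success", "has_content": False}

push_response("s1", "hello")
print(poll_response("s1"))  # first poll delivers the reply
print(poll_response("s1"))  # second poll finds the queue empty
```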
+1454 −113: file diff suppressed because it is too large (vendored)
+2: file diff suppressed because one or more lines are too long
Three binary image files added (not shown): 4.2 KiB, 3.4 KiB, 21 KiB
+194 −107

```diff
@@ -2,7 +2,8 @@ import sys
 import time
 import web
 import json
-from queue import Queue
+import uuid
+from queue import Queue, Empty
 from bridge.context import *
 from bridge.reply import Reply, ReplyType
 from channel.chat_channel import ChatChannel, check_prefix
@@ -11,7 +12,9 @@ from common.log import logger
 from common.singleton import singleton
 from config import conf
 import os

+import mimetypes  # added to handle MIME types
+import threading
+import logging
```
```diff
 class WebMessage(ChatMessage):
     def __init__(
@@ -43,131 +46,138 @@ class WebChannel(ChatChannel):

     def __init__(self):
         super().__init__()
         self.message_queues = {}  # one message queue per user
         self.msg_id_counter = 0   # message ID counter
+        self.session_queues = {}      # maps session_id -> response queue
+        self.request_to_session = {}  # maps request_id -> session_id

     def _generate_msg_id(self):
         """Generate a unique message ID."""
         self.msg_id_counter += 1
         return str(int(time.time())) + str(self.msg_id_counter)

+    def _generate_request_id(self):
+        """Generate a unique request ID."""
+        return str(uuid.uuid4())
```
```
    def send(self, reply: Reply, context: Context):
        try:
            if reply.type == ReplyType.IMAGE:
                from PIL import Image
            if reply.type in self.NOT_SUPPORT_REPLYTYPE:
                logger.warning(f"Web channel doesn't support {reply.type} yet")
                return

                image_storage = reply.content
                image_storage.seek(0)
                img = Image.open(image_storage)
                print("<IMAGE>")
                img.show()
            elif reply.type == ReplyType.IMAGE_URL:
                import io
            if reply.type == ReplyType.IMAGE_URL:
                time.sleep(0.5)

                import requests
                from PIL import Image

                img_url = reply.content
                pic_res = requests.get(img_url, stream=True)
                image_storage = io.BytesIO()
                for block in pic_res.iter_content(1024):
                    image_storage.write(block)
                image_storage.seek(0)
                img = Image.open(image_storage)
                print(img_url)
                img.show()
            else:
                print(reply.content)

            # get the user ID, falling back to a default
            # user_id = getattr(context.get("session", None), "session_id", "default_user")
            user_id = context["receiver"]
            # make sure the user has a message queue
            if user_id not in self.message_queues:
                self.message_queues[user_id] = Queue()
            # get the request ID and session ID
            request_id = context.get("request_id", None)

            if not request_id:
                logger.error("No request_id found in context, cannot send message")
                return

            # put the message into the user's queue
            message_data = {
                "type": str(reply.type),
                "content": reply.content,
                "timestamp": time.time()
            }
            self.message_queues[user_id].put(message_data)
            logger.debug(f"Message queued for user {user_id}")
            # look up the session_id via the request_id
            session_id = self.request_to_session.get(request_id)
            if not session_id:
                logger.error(f"No session_id found for request {request_id}")
                return

            # check whether a session queue exists
            if session_id in self.session_queues:
                # build the response payload, including the request ID to tell requests apart
                response_data = {
                    "type": str(reply.type),
                    "content": reply.content,
                    "timestamp": time.time(),
                    "request_id": request_id
                }
                self.session_queues[session_id].put(response_data)
                logger.debug(f"Response sent to queue for session {session_id}, request {request_id}")
            else:
                logger.warning(f"No response queue found for session {session_id}, response dropped")

        except Exception as e:
            logger.error(f"Error in send method: {e}")
            raise
```
```
    def sse_handler(self, user_id):
        """
        Handle Server-Sent Events (SSE) for real-time communication.
        """
        web.header('Content-Type', 'text/event-stream')
        web.header('Cache-Control', 'no-cache')
        web.header('Connection', 'keep-alive')

        # make sure the user has a message queue
        if user_id not in self.message_queues:
            self.message_queues[user_id] = Queue()

        try:
            while True:
                try:
                    # send a heartbeat
                    yield f": heartbeat\n\n"

                    # fetch messages without blocking
                    if not self.message_queues[user_id].empty():
                        message = self.message_queues[user_id].get_nowait()
                        yield f"data: {json.dumps(message)}\n\n"
                    time.sleep(0.5)
                except Exception as e:
                    logger.error(f"SSE Error: {e}")
                    break
        finally:
            # clean up resources
            if user_id in self.message_queues:
                # delete the queue only when it is empty
                if self.message_queues[user_id].empty():
                    del self.message_queues[user_id]
```
```
    def post_message(self):
        """
        Handle incoming messages from users via POST request.
        Returns a request_id for tracking this specific request.
        """
        try:
            data = web.data()  # raw POST data
            json_data = json.loads(data)
            user_id = json_data.get('user_id', 'default_user')
            session_id = json_data.get('session_id', f'session_{int(time.time())}')
            prompt = json_data.get('message', '')
        except json.JSONDecodeError:
            return json.dumps({"status": "error", "message": "Invalid JSON"})
        except Exception as e:
            return json.dumps({"status": "error", "message": str(e)})

        if not prompt:
            return json.dumps({"status": "error", "message": "No message provided"})

        try:
            msg_id = self._generate_msg_id()
            context = self._compose_context(ContextType.TEXT, prompt, msg=WebMessage(msg_id,
                                                                                     prompt,
                                                                                     from_user_id=user_id,
                                                                                     other_user_id=user_id
                                                                                     ))
            context["isgroup"] = False
            # context["session"] = web.storage(session_id=user_id)
            # generate a request ID
            request_id = self._generate_request_id()

            if not context:
                return json.dumps({"status": "error", "message": "Failed to process message"})

            self.produce(context)
            return json.dumps({"status": "success", "message": "Message received"})
            # associate the request ID with the session ID
            self.request_to_session[request_id] = session_id

            # make sure the session queue exists
            if session_id not in self.session_queues:
                self.session_queues[session_id] = Queue()

            # build the message object
            msg = WebMessage(self._generate_msg_id(), prompt)
            msg.from_user_id = session_id  # use the session ID as the user ID

            # build the context
            context = self._compose_context(ContextType.TEXT, prompt, msg=msg)

            # add the required fields
            context["session_id"] = session_id
            context["request_id"] = request_id
            context["isgroup"] = False     # add the isgroup field
            context["receiver"] = session_id  # add the receiver field

            # process the message asynchronously, passing only the context
            threading.Thread(target=self.produce, args=(context,)).start()

            # return the request ID
            return json.dumps({"status": "success", "request_id": request_id})

        except Exception as e:
            logger.error(f"Error processing message: {e}")
            return json.dumps({"status": "error", "message": "Internal server error"})
            return json.dumps({"status": "error", "message": str(e)})
```
```
    def poll_response(self):
        """
        Poll for responses using the session_id.
        """
        try:
            # don't log polling requests
            web.ctx.log_request = False

            data = web.data()
            json_data = json.loads(data)
            session_id = json_data.get('session_id')

            if not session_id or session_id not in self.session_queues:
                return json.dumps({"status": "error", "message": "Invalid session ID"})

            # try to fetch a response from the queue without waiting
            try:
                # note: a peek (rather than get) would let the front end re-fetch
                # a response it failed to process; this version consumes it
                response = self.session_queues[session_id].get(block=False)

                # return the response, including the request ID to tell requests apart
                return json.dumps({
                    "status": "success",
                    "has_content": True,
                    "content": response["content"],
                    "request_id": response["request_id"],
                    "timestamp": response["timestamp"]
                })

            except Empty:
                # no new response
                return json.dumps({"status": "success", "has_content": False})

        except Exception as e:
            logger.error(f"Error polling response: {e}")
            return json.dumps({"status": "error", "message": str(e)})
```
def chat_page(self):
|
||||
"""Serve the chat HTML page."""
|
||||
@@ -176,22 +186,51 @@ class WebChannel(ChatChannel):
            return f.read()

    def startup(self):
        logger.setLevel("WARN")
        print("\nWeb Channel is running. Send POST requests to /message to send messages.")
        logger.info("""[WebChannel] 当前channel为web,可修改 config.json 配置文件中的 channel_type 字段进行切换。全部可用类型为:
1. web: 网页
2. terminal: 终端
3. wechatmp: 个人公众号
4. wechatmp_service: 企业公众号
5. wechatcom_app: 企微自建应用
6. dingtalk: 钉钉
7. feishu: 飞书""")
        logger.info("Web对话网页已运行, 请使用浏览器访问 http://localhost:9899/chat")

        # Make sure the static file directory exists
        static_dir = os.path.join(os.path.dirname(__file__), 'static')
        if not os.path.exists(static_dir):
            os.makedirs(static_dir)
            logger.info(f"Created static directory: {static_dir}")

        urls = (
            '/sse/(.+)', 'SSEHandler',  # route changed to receive a user ID
            '/', 'RootHandler',  # root path handler
            '/message', 'MessageHandler',
            '/chat', 'ChatHandler',
            '/poll', 'PollHandler',  # polling handler
            '/assets/(.*)', 'AssetsHandler',  # matches /assets/<any path>
        )
        port = conf().get("web_port", 9899)
        app = web.application(urls, globals(), autoreload=False)
        web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))

        # Suppress web.py's default log output
        import io
        from contextlib import redirect_stdout

        # Set web.py's log level to ERROR so only errors are shown
        logging.getLogger("web").setLevel(logging.ERROR)

        # Silence web.httpserver's logging
        logging.getLogger("web.httpserver").setLevel(logging.ERROR)

        # Temporarily redirect stdout to capture web.py's startup message
        with redirect_stdout(io.StringIO()):
            web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))


class SSEHandler:
    def GET(self, user_id):
        return WebChannel().sse_handler(user_id)


class RootHandler:
    def GET(self):
        # Redirect to /chat
        raise web.seeother('/chat')


class MessageHandler:
@@ -199,6 +238,54 @@ class MessageHandler:
        return WebChannel().post_message()


class PollHandler:
    def POST(self):
        return WebChannel().poll_response()


class ChatHandler:
    def GET(self):
        return WebChannel().chat_page()
        # Return the chat page directly
        file_path = os.path.join(os.path.dirname(__file__), 'chat.html')
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()


class AssetsHandler:
    def GET(self, file_path):  # default parameter changed
        try:
            # A bare /static/ request needs special handling
            if file_path == '':
                # Return a directory listing...
                pass

            # Absolute path of the current file
            current_dir = os.path.dirname(os.path.abspath(__file__))
            static_dir = os.path.join(current_dir, 'static')

            full_path = os.path.normpath(os.path.join(static_dir, file_path))

            # Security check: make sure the requested file is inside the static directory
            if not os.path.abspath(full_path).startswith(os.path.abspath(static_dir)):
                logger.error(f"Security check failed for path: {full_path}")
                raise web.notfound()

            if not os.path.exists(full_path) or not os.path.isfile(full_path):
                logger.error(f"File not found: {full_path}")
                raise web.notfound()

            # Set the correct Content-Type
            content_type = mimetypes.guess_type(full_path)[0]
            if content_type:
                web.header('Content-Type', content_type)
            else:
                # Fall back to a binary stream
                web.header('Content-Type', 'application/octet-stream')

            # Read and return the file contents
            with open(full_path, 'rb') as f:
                return f.read()

        except Exception as e:
            logger.error(f"Error serving static file: {e}", exc_info=True)  # include the full traceback
            raise web.notfound()

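`AssetsHandler` anchors its traversal defense on a `startswith` comparison of absolute paths. A standalone sketch of that check (directory and file names are illustrative); note the `os.sep` guard, which a bare `startswith` lacks, so a sibling directory such as `static2` cannot slip through:

```python
import os


def is_within(base_dir: str, requested: str) -> bool:
    """True if `requested`, joined under base_dir, stays inside base_dir."""
    base = os.path.abspath(base_dir)
    full = os.path.abspath(os.path.normpath(os.path.join(base, requested)))
    # The os.sep suffix prevents /app/static2 from matching base /app/static
    return full == base or full.startswith(base + os.sep)


print(is_within("static", "css/site.css"))    # True: stays inside
print(is_within("static", "../config.json"))  # False: escapes the base dir
```

`os.path.normpath` collapses the `..` segments before the comparison, which is what defeats the traversal attempt.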
@@ -0,0 +1,179 @@
# encoding:utf-8

"""
wechat channel
"""

import io
import json
import os
import threading
import time
from queue import Empty
from typing import Any

from bridge.context import *
from bridge.reply import *
from channel.chat_channel import ChatChannel
from channel.wechat.wcf_message import WechatfMessage
from common.log import logger
from common.singleton import singleton
from common.utils import *
from config import conf, get_appdata_dir
from wcferry import Wcf, WxMsg


@singleton
class WechatfChannel(ChatChannel):
    NOT_SUPPORT_REPLYTYPE = []

    def __init__(self):
        super().__init__()
        self.NOT_SUPPORT_REPLYTYPE = []
        # Keep recent messages in a dict for de-duplication
        self.received_msgs = {}
        # Initialize the wcferry client
        self.wcf = Wcf()
        self.wxid = None  # set to the logged-in user's wxid after login

    def startup(self):
        """
        Start the channel.
        """
        try:
            # wcferry automatically launches WeChat and logs in
            self.wxid = self.wcf.get_self_wxid()
            self.name = self.wcf.get_user_info().get("name")
            logger.info(f"微信登录成功,当前用户ID: {self.wxid}, 用户名:{self.name}")
            self.contact_cache = ContactCache(self.wcf)
            self.contact_cache.update()
            # Start receiving messages
            self.wcf.enable_receiving_msg()
            # Spawn the message-processing thread
            t = threading.Thread(target=self._process_messages, name="WeChatThread", daemon=True)
            t.start()

        except Exception as e:
            logger.error(f"微信通道启动失败: {e}")
            raise e

    def _process_messages(self):
        """
        Drain the message queue.
        """
        while True:
            try:
                msg = self.wcf.get_msg()
                if msg:
                    self._handle_message(msg)
            except Empty:
                continue
            except Exception as e:
                logger.error(f"处理消息失败: {e}")
                continue

    def _handle_message(self, msg: WxMsg):
        """
        Handle a single message.
        """
        try:
            # Build the message object
            cmsg = WechatfMessage(self, msg)
            # De-duplicate messages
            if cmsg.msg_id in self.received_msgs:
                return
            self.received_msgs[cmsg.msg_id] = time.time()
            # Clean up expired message IDs
            self._clean_expired_msgs()

            logger.debug(f"收到消息: {msg}")
            context = self._compose_context(cmsg.ctype, cmsg.content,
                                            isgroup=cmsg.is_group,
                                            msg=cmsg)
            if context:
                self.produce(context)
        except Exception as e:
            logger.error(f"处理消息失败: {e}")

    def _clean_expired_msgs(self, expire_time: float = 60):
        """
        Remove expired message IDs.
        """
        now = time.time()
        for msg_id in list(self.received_msgs.keys()):
            if now - self.received_msgs[msg_id] > expire_time:
                del self.received_msgs[msg_id]

    def send(self, reply: Reply, context: Context):
        """
        Send a message.
        """
        receiver = context["receiver"]
        if not receiver:
            logger.error("receiver is empty")
            return

        try:
            if reply.type == ReplyType.TEXT:
                # Handle @-mentions
                at_list = []
                if context.get("isgroup"):
                    if context["msg"].actual_user_id:
                        at_list = [context["msg"].actual_user_id]
                at_str = ",".join(at_list) if at_list else ""
                self.wcf.send_text(reply.content, receiver, at_str)

            elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
                self.wcf.send_text(reply.content, receiver)
            else:
                logger.error(f"暂不支持的消息类型: {reply.type}")

        except Exception as e:
            logger.error(f"发送消息失败: {e}")

    def close(self):
        """
        Close the channel.
        """
        try:
            self.wcf.cleanup()
        except Exception as e:
            logger.error(f"关闭通道失败: {e}")


class ContactCache:
    def __init__(self, wcf):
        """
        wcf: a wcferry.client.Wcf instance
        """
        self.wcf = wcf
        self._contact_map = {}  # shaped like {wxid: {full contact info}}

    def update(self):
        """
        Refresh the cache: call get_contacts(), then rebuild
        wcf.contacts into a {wxid: {full info}} dict.
        """
        self.wcf.get_contacts()
        self._contact_map.clear()
        for item in self.wcf.contacts:
            wxid = item.get('wxid')
            if wxid:  # make sure the wxid field exists
                self._contact_map[wxid] = item

    def get_contact(self, wxid: str) -> dict:
        """
        Return the full contact dict for this wxid,
        or None if not found.
        """
        return self._contact_map.get(wxid)

    def get_name_by_wxid(self, wxid: str) -> str:
        """
        Look up a member/group name by wxid.
        """
        contact = self.get_contact(wxid)
        if contact:
            return contact.get('name', '')
        return ''
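`received_msgs` together with `_clean_expired_msgs` above forms a small time-based de-duplication cache. The same idea in isolation (the 60-second TTL mirrors the default above; the class and method names are invented for this sketch):

```python
import time


class TTLDedup:
    """Remember recently seen message IDs for `ttl` seconds."""

    def __init__(self, ttl: float = 60):
        self.ttl = ttl
        self._seen = {}  # msg_id -> timestamp of first sighting

    def seen_before(self, msg_id) -> bool:
        """Record msg_id; return True if it was already seen within the TTL."""
        now = time.time()
        # Evict expired entries first, like _clean_expired_msgs does
        for key in [k for k, ts in self._seen.items() if now - ts > self.ttl]:
            del self._seen[key]
        if msg_id in self._seen:
            return True
        self._seen[msg_id] = now
        return False


dedup = TTLDedup(ttl=60)
print(dedup.seen_before("msg-1"))  # False: first sighting
print(dedup.seen_before("msg-1"))  # True: duplicate within the TTL
```

Eviction on every call keeps the dict bounded without a separate cleanup thread, which is the same trade-off the channel makes.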
@@ -0,0 +1,58 @@
# encoding:utf-8

"""
wechat channel message
"""

from bridge.context import ContextType
from channel.chat_message import ChatMessage
from common.log import logger
from wcferry import WxMsg


class WechatfMessage(ChatMessage):
    """
    WeChat message wrapper.
    """

    def __init__(self, channel, wcf_msg: WxMsg, is_group=False):
        """
        Initialize the message object.
        :param wcf_msg: wcferry message object
        :param is_group: whether this is a group message
        """
        super().__init__(wcf_msg)
        self.msg_id = wcf_msg.id
        self.create_time = wcf_msg.ts  # use the message timestamp
        self.is_group = is_group or wcf_msg._is_group
        self.wxid = channel.wxid
        self.name = channel.name

        # Resolve the message type
        if wcf_msg.is_text():
            self.ctype = ContextType.TEXT
            self.content = wcf_msg.content
        else:
            raise NotImplementedError(f"Unsupported message type: {wcf_msg.type}")

        # Fill in sender and receiver info
        self.from_user_id = self.wxid if wcf_msg.sender == self.wxid else wcf_msg.sender
        self.from_user_nickname = self.name if wcf_msg.sender == self.wxid else channel.contact_cache.get_name_by_wxid(wcf_msg.sender)
        self.to_user_id = self.wxid
        self.to_user_nickname = self.name
        self.other_user_id = wcf_msg.sender
        self.other_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.sender)

        # Special handling for group messages
        if self.is_group:
            self.other_user_id = wcf_msg.roomid
            self.other_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.roomid)
            self.actual_user_id = wcf_msg.sender
            self.actual_user_nickname = channel.wcf.get_alias_in_chatroom(wcf_msg.sender, wcf_msg.roomid)
            if not self.actual_user_nickname:  # WeCom members' aliases may be unavailable in group chats; fall back to the contact cache
                self.actual_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.sender)
            self.room_id = wcf_msg.roomid
            self.is_at = wcf_msg.is_at(self.wxid)  # whether the logged-in user was @-mentioned

        # Whether the message was sent by ourselves
        self.my_msg = wcf_msg.from_self()
@@ -117,23 +117,35 @@ class WechatChannel(ChatChannel):

    def startup(self):
        try:
            itchat.instance.receivingRetryCount = 600  # extend the disconnect timeout
            # login by scan QRCode
            hotReload = conf().get("hot_reload", False)
            status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
            itchat.auto_login(
                enableCmdQR=2,
                hotReload=hotReload,
                statusStorageDir=status_path,
                qrCallback=qrCallback,
                exitCallback=self.exitCallback,
                loginCallback=self.loginCallback
            )
            self.user_id = itchat.instance.storageClass.userName
            self.name = itchat.instance.storageClass.nickName
            logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
            # start message listener
            itchat.run()
            time.sleep(3)
            logger.error("""[WechatChannel] 当前channel暂不可用,目前支持的channel有:
1. terminal: 终端
2. wechatmp: 个人公众号
3. wechatmp_service: 企业公众号
4. wechatcom_app: 企微自建应用
5. dingtalk: 钉钉
6. feishu: 飞书
7. web: 网页
8. wcf: wechat (需Windows环境,参考 https://github.com/zhayujie/chatgpt-on-wechat/pull/2562 )
可修改 config.json 配置文件的 channel_type 字段进行切换""")

            # itchat.instance.receivingRetryCount = 600  # extend the disconnect timeout
            # # login by scan QRCode
            # hotReload = conf().get("hot_reload", False)
            # status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
            # itchat.auto_login(
            #     enableCmdQR=2,
            #     hotReload=hotReload,
            #     statusStorageDir=status_path,
            #     qrCallback=qrCallback,
            #     exitCallback=self.exitCallback,
            #     loginCallback=self.loginCallback
            # )
            # self.user_id = itchat.instance.storageClass.userName
            # self.name = itchat.instance.storageClass.nickName
            # logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
            # # start message listener
            # itchat.run()
        except Exception as e:
            logger.exception(e)


@@ -1,21 +1,43 @@
# wechatcomapp_client.py
import threading
import time

from wechatpy.enterprise import WeChatClient


class WechatComAppClient(WeChatClient):
    def __init__(self, corp_id, secret, access_token=None, session=None, timeout=None, auto_retry=True):
        super(WechatComAppClient, self).__init__(corp_id, secret, access_token, session, timeout, auto_retry)
        self.fetch_access_token_lock = threading.Lock()
        self._active_refresh()

    def _active_refresh(self):
        """Start the background thread that proactively refreshes the token."""
        def refresh_loop():
            while True:
                now = time.time()
                expires_at = self.session.get(f"{self.corp_id}_expires_at", 0)

                # Refresh 10 minutes (600 s) ahead of expiry
                if expires_at - now < 600:
                    with self.fetch_access_token_lock:
                        # Double-check to avoid a redundant refresh
                        if self.session.get(f"{self.corp_id}_expires_at", 0) - time.time() < 600:
                            super(WechatComAppClient, self).fetch_access_token()
                # Check every 60 seconds
                time.sleep(60)

        # Start the daemon thread
        refresh_thread = threading.Thread(
            target=refresh_loop,
            daemon=True,
            name="wechatcom_token_refresh_thread"
        )
        refresh_thread.start()

    def fetch_access_token(self):  # override the parent method; lock to avoid fetching the token from multiple threads
    def fetch_access_token(self):
        with self.fetch_access_token_lock:
            access_token = self.session.get(self.access_token_key)
            if access_token:
                if not self.expires_at:
                    return access_token
                timestamp = time.time()
                if self.expires_at - timestamp > 60:
                    return access_token
            return super().fetch_access_token()
            expires_at = self.session.get(f"{self.corp_id}_expires_at", 0)

            if access_token and expires_at > time.time() + 60:
                return access_token
            return super().fetch_access_token()
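`WechatComAppClient` combines a background polling loop with double-checked locking so that only one thread actually refreshes the access token. A stripped-down, synchronous sketch of the locking pattern (the `fetch` callable, lifetime, and margin values are stand-ins, not the wechatpy API):

```python
import threading
import time


class TokenCache:
    """Refresh a token ahead of expiry; only one thread refreshes at a time."""

    def __init__(self, fetch, lifetime: float, margin: float):
        self._fetch = fetch        # callable returning a fresh token
        self._lifetime = lifetime  # seconds a token stays valid
        self._margin = margin      # refresh this many seconds before expiry
        self._lock = threading.Lock()
        self._token = None
        self._expires_at = 0.0

    def get(self) -> str:
        if self._expires_at - time.time() > self._margin:
            return self._token  # fast path: still valid, no lock needed
        with self._lock:
            # Double-check: another thread may have refreshed while we waited
            if self._expires_at - time.time() <= self._margin:
                self._token = self._fetch()
                self._expires_at = time.time() + self._lifetime
            return self._token


calls = []
cache = TokenCache(fetch=lambda: calls.append(1) or f"token-{len(calls)}",
                   lifetime=7200, margin=600)
print(cache.get(), cache.get())  # same token, fetched only once
```

The second check inside the lock is what prevents a thundering herd of refreshes when several threads hit the expiry window at once.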
@@ -15,7 +15,7 @@ GEMINI = "gemini" # gemini-1.0-pro
ZHIPU_AI = "glm-4"
MOONSHOT = "moonshot"
MiniMax = "minimax"

MODELSCOPE = "modelscope"

# model
CLAUDE3 = "claude-3-opus-20240229"
@@ -37,6 +37,9 @@ GPT_4o_MINI = "gpt-4o-mini"
GPT4_32k = "gpt-4-32k"
GPT4_06_13 = "gpt-4-0613"
GPT4_32k_06_13 = "gpt-4-32k-0613"
GPT_41 = "gpt-4.1"
GPT_41_MINI = "gpt-4.1-mini"
GPT_41_NANO = "gpt-4.1-nano"

O1 = "o1-preview"
O1_MINI = "o1-mini"
@@ -74,28 +77,42 @@ GLM_4_AIRX = "glm-4-airx"

CLAUDE_3_OPUS = "claude-3-opus-latest"
CLAUDE_3_OPUS_0229 = "claude-3-opus-20240229"

CLAUDE_35_SONNET = "claude-3-5-sonnet-latest"  # "latest"-suffixed names always point at the most recently released model
CLAUDE_35_SONNET_1022 = "claude-3-5-sonnet-20241022"  # date-suffixed names are pinned to the model released on that date
CLAUDE_35_SONNET_0620 = "claude-3-5-sonnet-20240620"
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"

CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
CLAUDE_4_SONNET = "claude-sonnet-4-0"
CLAUDE_4_OPUS = "claude-opus-4-0"

DEEPSEEK_CHAT = "deepseek-chat"  # DeepSeek-V3 chat model
DEEPSEEK_REASONER = "deepseek-reasoner"  # DeepSeek-R1 model

GITEE_AI_MODEL_LIST = ["Yi-34B-Chat", "InternVL2-8B", "deepseek-coder-33B-instruct", "InternVL2.5-26B", "Qwen2-VL-72B", "Qwen2.5-32B-Instruct", "glm-4-9b-chat", "codegeex4-all-9b", "Qwen2.5-Coder-32B-Instruct", "Qwen2.5-72B-Instruct", "Qwen2.5-7B-Instruct", "Qwen2-72B-Instruct", "Qwen2-7B-Instruct", "code-raccoon-v1", "Qwen2.5-14B-Instruct"]

MODELSCOPE_MODEL_LIST = ["LLM-Research/c4ai-command-r-plus-08-2024","mistralai/Mistral-Small-Instruct-2409","mistralai/Ministral-8B-Instruct-2410","mistralai/Mistral-Large-Instruct-2407",
    "Qwen/Qwen2.5-Coder-32B-Instruct","Qwen/Qwen2.5-Coder-14B-Instruct","Qwen/Qwen2.5-Coder-7B-Instruct","Qwen/Qwen2.5-72B-Instruct","Qwen/Qwen2.5-32B-Instruct","Qwen/Qwen2.5-14B-Instruct","Qwen/Qwen2.5-7B-Instruct","Qwen/QwQ-32B-Preview",
    "LLM-Research/Llama-3.3-70B-Instruct","opencompass/CompassJudger-1-32B-Instruct","Qwen/QVQ-72B-Preview","LLM-Research/Meta-Llama-3.1-405B-Instruct","LLM-Research/Meta-Llama-3.1-8B-Instruct","Qwen/Qwen2-VL-7B-Instruct","LLM-Research/Meta-Llama-3.1-70B-Instruct",
    "Qwen/Qwen2.5-14B-Instruct-1M","Qwen/Qwen2.5-7B-Instruct-1M","Qwen/Qwen2.5-VL-3B-Instruct","Qwen/Qwen2.5-VL-7B-Instruct","Qwen/Qwen2.5-VL-72B-Instruct","deepseek-ai/DeepSeek-R1-Distill-Llama-70B","deepseek-ai/DeepSeek-R1-Distill-Llama-8B","deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B","deepseek-ai/DeepSeek-R1-Distill-Qwen-7B","deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B","deepseek-ai/DeepSeek-R1","deepseek-ai/DeepSeek-V3","Qwen/QwQ-32B"]

MODEL_LIST = [
    GPT35, GPT35_0125, GPT35_1106, "gpt-3.5-turbo-16k",
    O1, O1_MINI, GPT_4o, GPT_4O_0806, GPT_4o_MINI, GPT4_TURBO, GPT4_TURBO_PREVIEW, GPT4_TURBO_01_25, GPT4_TURBO_11_06, GPT4, GPT4_32k, GPT4_06_13, GPT4_32k_06_13,
    GPT_41, GPT_41_MINI, GPT_41_NANO, O1, O1_MINI, GPT_4o, GPT_4O_0806, GPT_4o_MINI, GPT4_TURBO, GPT4_TURBO_PREVIEW, GPT4_TURBO_01_25, GPT4_TURBO_11_06, GPT4, GPT4_32k, GPT4_06_13, GPT4_32k_06_13,
    WEN_XIN, WEN_XIN_4,
    XUNFEI,
    ZHIPU_AI, GLM_4, GLM_4_PLUS, GLM_4_flash, GLM_4_LONG, GLM_4_ALLTOOLS, GLM_4_0520, GLM_4_AIR, GLM_4_AIRX,
    MOONSHOT, MiniMax,
    GEMINI, GEMINI_PRO, GEMINI_15_flash, GEMINI_15_PRO, GEMINI_20_flash_exp,
    CLAUDE_3_OPUS, CLAUDE_3_OPUS_0229, CLAUDE_35_SONNET, CLAUDE_35_SONNET_1022, CLAUDE_35_SONNET_0620, CLAUDE_3_SONNET, CLAUDE_3_HAIKU, "claude", "claude-3-haiku", "claude-3-sonnet", "claude-3-opus", "claude-3.5-sonnet",
    CLAUDE_4_OPUS, CLAUDE_4_SONNET, CLAUDE_3_OPUS, CLAUDE_3_OPUS_0229, CLAUDE_35_SONNET, CLAUDE_35_SONNET_1022, CLAUDE_35_SONNET_0620, CLAUDE_3_SONNET, CLAUDE_3_HAIKU, "claude", "claude-3-haiku", "claude-3-sonnet", "claude-3-opus", "claude-3.5-sonnet",
    "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
    QWEN, QWEN_TURBO, QWEN_PLUS, QWEN_MAX,
    LINKAI_35, LINKAI_4_TURBO, LINKAI_4o
    LINKAI_35, LINKAI_4_TURBO, LINKAI_4o,
    DEEPSEEK_CHAT, DEEPSEEK_REASONER,
    MODELSCOPE
]

MODEL_LIST = MODEL_LIST + GITEE_AI_MODEL_LIST + MODELSCOPE_MODEL_LIST
# channel
FEISHU = "feishu"
DINGTALK = "dingtalk"

@@ -1,5 +1,5 @@
{
    "channel_type": "wx",
    "channel_type": "web",
    "model": "",
    "open_ai_api_key": "YOUR API KEY",
    "claude_api_key": "YOUR API KEY",

@@ -171,6 +171,9 @@ available_setting = {
    "zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
    "moonshot_api_key": "",
    "moonshot_base_url": "https://api.moonshot.cn/v1/chat/completions",
    # ModelScope platform config
    "modelscope_api_key": "",
    "modelscope_base_url": "https://api-inference.modelscope.cn/v1/chat/completions",
    # LinkAI platform config
    "use_linkai": False,
    "linkai_api_key": "",

@@ -0,0 +1,66 @@
# Agent Plugin

## Overview

An agent plugin built on the [AgentMesh](https://github.com/MinimalFuture/AgentMesh) multi-agent framework. It gives the bot agent capabilities, letting it drive tools such as the **terminal, browser, file system, and search engines** through natural-language conversation.
It also supports completing complex tasks through **multi-agent collaboration**, such as task dispatch across agents, multi-agent discussion, and cooperative processing.

AgentMesh project: https://github.com/MinimalFuture/AgentMesh

## Installation

1. Make sure the dependency is installed:

```bash
pip install agentmesh-sdk>=0.1.2
```

2. To use the browser tool, also install:

```bash
pip install browser-use>=0.1.40
playwright install
```

## Configuration

The plugin's configuration file is `config.yaml` in the `plugins/agent` directory. It holds the agent-team configuration and the tool configuration, and can be copied from the template file `config-template.yaml`:

```bash
cp config-template.yaml config.yaml
```

Notes:

- `team` selects the default agent team
- `teams` holds the agent-team configurations. A team's model defaults to `gpt-4.1-mini` and can be changed as needed; the `api_key` for the chosen model must be set in the global `config.json` at the project root. For example, OpenAI models require `open_ai_api_key`
- Each agent under `agents` may set its own `model` field to use a different model


## Usage

Trigger the plugin with the `$agent` prefix in messages sent to the bot. The following commands are supported:

- `$agent [task]`: run a task with the default team (change the default via the `team` field in config.yaml)
- `$agent teams`: list the available teams
- `$agent use [team_name] [task]`: run a task with the specified team


### Examples

```bash
$agent 帮我查看当前目录下有哪些文件夹
$agent teams
$agent use software_team 帮我写一个产品预约体验的表单页面
```

## Supported tools

Several built-in tools are currently supported, including but not limited to:

- `calculator`: math calculations
- `current_time`: get the current time
- `browser`: browser automation; requires the `browser-use` dependency
- `google_search`: search engine; requires an `api_key` in `config.yaml`
- `file_save`: file saving; when enabled, agent output is saved under the `workspace` directory
- `terminal`: run terminal commands
@@ -0,0 +1,3 @@
from .agent import AgentPlugin

__all__ = ["AgentPlugin"]
@@ -0,0 +1,282 @@
import os
import yaml
from typing import Dict, List, Optional

from agentmesh import AgentTeam, Agent, LLMModel
from agentmesh.models import ClaudeModel
from agentmesh.tools import ToolManager
from config import conf

import plugins
from plugins import Plugin, Event, EventContext, EventAction
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger


@plugins.register(
    name="agent",
    desc="Use AgentMesh framework to process tasks with multi-agent teams",
    version="0.1.0",
    author="Saboteur7",
    desire_priority=1,
)
class AgentPlugin(Plugin):
    """Plugin for integrating AgentMesh framework."""

    def __init__(self):
        super().__init__()
        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
        self.name = "agent"
        self.description = "Use AgentMesh framework to process tasks with multi-agent teams"
        self.config = self._load_config()
        self.tool_manager = ToolManager()
        self.tool_manager.load_tools(config_dict=self.config.get("tools"))
        logger.info("[agent] inited")

    def _load_config(self) -> Dict:
        """Load configuration from config.yaml file."""
        config_path = os.path.join(self.path, "config.yaml")
        if not os.path.exists(config_path):
            logger.warning(f"Config file not found at {config_path}")
            return {}

        with open(config_path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)

    def get_help_text(self, verbose=False, **kwargs):
        """Return help message for the agent plugin."""
        help_text = "通过AgentMesh实现对终端、浏览器、文件系统、搜索引擎等工具的执行,并支持多智能体协作。"
        trigger_prefix = conf().get("plugin_trigger_prefix", "$")

        if not verbose:
            return help_text

        teams = self.get_available_teams()
        teams_str = ", ".join(teams) if teams else "未配置任何团队"

        help_text += "\n\n使用说明:\n"
        help_text += f"{trigger_prefix}agent [task] - 使用默认团队执行任务\n"
        help_text += f"{trigger_prefix}agent teams - 列出可用的团队\n"
        help_text += f"{trigger_prefix}agent use [team_name] [task] - 使用特定团队执行任务\n\n"
        help_text += f"可用团队: \n{teams_str}\n\n"
        help_text += "示例:\n"
        help_text += f"{trigger_prefix}agent 帮我查看当前文件夹路径\n"
        help_text += f"{trigger_prefix}agent use software_team 帮我写一个产品预约体验的表单页面"
        return help_text

    def get_available_teams(self) -> List[str]:
        """Get list of available teams from configuration."""
        teams_config = self.config.get("teams", {})
        return list(teams_config.keys())

    def create_team_from_config(self, team_name: str) -> Optional[AgentTeam]:
        """Create a team from configuration."""
        # Get teams configuration
        teams_config = self.config.get("teams", {})

        # Check if the specified team exists
        if team_name not in teams_config:
            logger.error(f"Team '{team_name}' not found in configuration.")
            available_teams = list(teams_config.keys())
            logger.info(f"Available teams: {', '.join(available_teams)}")
            return None

        # Get team configuration
        team_config = teams_config[team_name]

        # Get team's model
        team_model_name = team_config.get("model", "gpt-4.1-mini")
        team_model = self.create_llm_model(team_model_name)

        # Get team's max_steps (default to 20 if not specified)
        team_max_steps = team_config.get("max_steps", 20)

        # Create team with the model
        team = AgentTeam(
            name=team_name,
            description=team_config.get("description", ""),
            rule=team_config.get("rule", ""),
            model=team_model,
            max_steps=team_max_steps
        )

        # Create and add agents to the team
        agents_config = team_config.get("agents", [])
        for agent_config in agents_config:
            # Check if agent has a specific model
            if agent_config.get("model"):
                agent_model = self.create_llm_model(agent_config.get("model"))
            else:
                agent_model = team_model

            # Get agent's max_steps
            agent_max_steps = agent_config.get("max_steps")

            agent = Agent(
                name=agent_config.get("name", ""),
                system_prompt=agent_config.get("system_prompt", ""),
                model=agent_model,  # use the agent's model if specified, otherwise the team's model
                description=agent_config.get("description", ""),
                max_steps=agent_max_steps
            )

            # Add tools to the agent if specified
            tool_names = agent_config.get("tools", [])
            for tool_name in tool_names:
                tool = self.tool_manager.create_tool(tool_name)
                if tool:
                    agent.add_tool(tool)
                else:
                    if tool_name == "browser":
                        logger.warning(
                            "Tool 'Browser' failed to load, "
                            "please install the required dependency with: \n"
                            "'pip install browser-use>=0.1.40' or 'pip install agentmesh-sdk[full]'\n"
                        )
                    else:
                        logger.warning(f"Tool '{tool_name}' not found for agent '{agent.name}'\n")

            # Add agent to team
            team.add(agent)

        return team

    def on_handle_context(self, e_context: EventContext):
        """Handle the message context."""
        if e_context['context'].type != ContextType.TEXT:
            return
        content = e_context['context'].content
        trigger_prefix = conf().get("plugin_trigger_prefix", "$")

        if not content.startswith(f"{trigger_prefix}agent "):
            e_context.action = EventAction.CONTINUE
            return

        if not self.config:
            reply = Reply()
            reply.type = ReplyType.ERROR
            reply.content = "未找到插件配置,请在 plugins/agent 目录下创建 config.yaml 配置文件,可根据 config-template.yaml 模板文件复制"
            e_context['reply'] = reply
            e_context.action = EventAction.BREAK_PASS
            return

        # Extract the actual task
        task = content[len(f"{trigger_prefix}agent "):].strip()

        # If task is empty, return help message
        if not task:
            reply = Reply()
            reply.type = ReplyType.TEXT
            reply.content = self.get_help_text(verbose=True)
            e_context['reply'] = reply
            e_context.action = EventAction.BREAK_PASS
            return

        # Check if task is asking for available teams
        if task.lower() in ["teams", "list teams", "show teams"]:
            teams = self.get_available_teams()
            reply = Reply()
            reply.type = ReplyType.TEXT

            if not teams:
                reply.content = "未配置任何团队。请检查 config.yaml 文件。"
            else:
                reply.content = f"可用团队: {', '.join(teams)}"

            e_context['reply'] = reply
            e_context.action = EventAction.BREAK_PASS
            return

        # Check if task specifies a team
        team_name = None
        if task.startswith("use "):
            parts = task[4:].split(" ", 1)
            if len(parts) > 0:
                team_name = parts[0]
                if len(parts) > 1:
                    task = parts[1].strip()
                else:
                    reply = Reply()
                    reply.type = ReplyType.TEXT
                    reply.content = f"已选择团队 '{team_name}'。请输入您想执行的任务。"
                    e_context['reply'] = reply
                    e_context.action = EventAction.BREAK_PASS
                    return
        if not team_name:
            team_name = self.config.get("team")

        # If no team specified, use default or first available
        if not team_name:
            teams = self.get_available_teams()
            if not teams:
                reply = Reply()
                reply.type = ReplyType.TEXT
                reply.content = "未配置任何团队。请检查 config.yaml 文件。"
                e_context['reply'] = reply
                e_context.action = EventAction.BREAK_PASS
                return
            team_name = teams[0]

        # Create team
        team = self.create_team_from_config(team_name)
        if not team:
            reply = Reply()
            reply.type = ReplyType.TEXT
            reply.content = f"创建团队 '{team_name}' 失败。请检查配置。"
            e_context['reply'] = reply
            e_context.action = EventAction.BREAK_PASS
            return

        # Run the task
        try:
            logger.info(f"[agent] Running task '{task}' with team '{team_name}', team_model={team.model.model}")
            result = team.run_async(task=task)
            for agent_result in result:
                res_text = f"🤖 {agent_result.get('agent_name')}\n\n{agent_result.get('final_answer')}"
                _send_text(e_context, content=res_text)

            reply = Reply()
            reply.type = ReplyType.TEXT
            reply.content = ""
            e_context['reply'] = reply
            e_context.action = EventAction.BREAK_PASS

        except Exception as e:
            logger.exception(f"Error running task with team '{team_name}'")

            reply = Reply()
            reply.type = ReplyType.ERROR
            reply.content = f"执行任务时出错: {str(e)}"
            e_context['reply'] = reply
            e_context.action = EventAction.BREAK_PASS
            return

    def create_llm_model(self, model_name) -> LLMModel:
        if conf().get("use_linkai"):
            api_base = "https://api.link-ai.tech/v1"
            api_key = conf().get("linkai_api_key")
        elif model_name.startswith(("gpt", "text-davinci", "o1", "o3")):
            api_base = conf().get("open_ai_api_base") or "https://api.openai.com/v1"
            api_key = conf().get("open_ai_api_key")
        elif model_name.startswith("claude"):
            return ClaudeModel(model=model_name, api_key=conf().get("claude_api_key"))
        elif model_name.startswith("moonshot"):
            api_base = "https://api.moonshot.cn/v1"
            api_key = conf().get("moonshot_api_key")
        elif model_name.startswith("qwen"):
            api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
            api_key = conf().get("dashscope_api_key")
        else:
            api_base = conf().get("open_ai_api_base") or "https://api.openai.com/v1"
            api_key = conf().get("open_ai_api_key")

        llm_model = LLMModel(model=model_name, api_key=api_key, api_base=api_base)
        return llm_model


def _send_text(e_context: EventContext, content: str):
    reply = Reply(ReplyType.TEXT, content)
    channel = e_context["channel"]
    channel.send(reply, e_context["context"])
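The `use [team_name] [task]` branch in `on_handle_context` is effectively a tiny prefix-command parser. Factored out for clarity (the function name is invented; the splitting logic mirrors the handler above):

```python
def parse_agent_command(task: str):
    """Split an '$agent ...' payload the way on_handle_context does.

    Returns (team_name_or_None, remaining_task).
    """
    if task.startswith("use "):
        # 'use <team> <task...>': at most one split, so the task keeps its spaces
        parts = task[4:].split(" ", 1)
        team = parts[0] if parts else None
        rest = parts[1].strip() if len(parts) > 1 else ""
        return team, rest
    return None, task


print(parse_agent_command("use software_team build a form page"))
# ('software_team', 'build a form page')
print(parse_agent_command("list the files here"))
# (None, 'list the files here')
```

Using `split(" ", 1)` is the key design choice: only the first word after `use` is treated as the team name, so the remaining task text may contain arbitrary spaces.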
@@ -0,0 +1,52 @@
# Name of the Agent Team selected by default
team: general_team

tools:
  google_search:
    # get your apikey from https://serper.dev/
    api_key: "YOUR API KEY"

# Agent Team configuration
teams:
  # General-purpose agent team
  general_team:
    model: "gpt-4.1-mini"  # model used by the team
    description: "A versatile research and information agent team"
    max_steps: 5
    agents:
      - name: "通用智能助手"
        description: "Universal assistant specializing in research, information synthesis, and task execution"
        system_prompt: "You are a versatile assistant who answers questions and completes tasks using available tools. Reply in a clearly structured, attractive and easy to read format."
        # Tools the agent is allowed to use
        tools:
          - time
          - calculator
          - google_search
          - browser
          - terminal

  # Software development agent team
  software_team:
    model: "gpt-4.1-mini"
    description: "A software development team with a product manager, a developer and a tester."
    rule: "A normal R&D process is that the Product Manager writes the PRD, the Developer writes code based on the PRD, and finally the Tester performs testing."
    max_steps: 10
    agents:
      - name: "Product-Manager"
        description: "Responsible for product requirements and documentation"
        system_prompt: "You are an experienced product manager who creates concise PRDs, focusing on user needs and feature specifications. You always format your responses in Markdown."
        tools:
          - time
          - file_save
      - name: "Developer"
        description: "Implements code based on the PRD"
        system_prompt: "You are a skilled developer. When developing a web application, you create a single-page website based on user needs, delivering HTML files with embedded JavaScript and CSS that are visually appealing, responsive, and user-friendly, featuring a grand layout and a beautiful background. The HTML, CSS, and JavaScript code should be well structured and effectively organized."
        tools:
          - file_save
      - name: "Tester"
        description: "Tests code and verifies functionality"
        system_prompt: "You are a tester who validates code against requirements. For HTML applications, use browser tools to test functionality. For Python or other client-side applications, use the terminal tool to run and test. You only need to test a few core cases."
        tools:
          - file_save
          - browser
          - terminal
@@ -155,7 +155,7 @@ def get_help_text(isadmin, isgroup):
    for plugin in plugins:
        if plugins[plugin].enabled and not plugins[plugin].hidden:
            namecn = plugins[plugin].namecn
            help_text += "\n%s:" % namecn
            help_text += "\n%s: " % namecn
            help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()

    if ADMIN_COMMANDS and isadmin:
@@ -339,7 +339,8 @@ class Godcmd(Plugin):
                ok, result = True, "配置已重载"
            elif cmd == "resetall":
                if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
                               const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.MOONSHOT]:
                               const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.MOONSHOT,
                               const.MODELSCOPE]:
                    channel.cancel_all_session()
                    bot.sessions.clear_all_session()
                    ok, result = True, "重置所有会话成功"

@@ -99,7 +99,7 @@ class Role(Plugin):
        if e_context["context"].type != ContextType.TEXT:
            return
        btype = Bridge().get_bot_type("chat")
        if btype not in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.QWEN_DASHSCOPE, const.XUNFEI, const.BAIDU, const.ZHIPU_AI, const.MOONSHOT, const.MiniMax, const.LINKAI]:
        if btype not in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.QWEN_DASHSCOPE, const.XUNFEI, const.BAIDU, const.ZHIPU_AI, const.MOONSHOT, const.MiniMax, const.LINKAI, const.MODELSCOPE]:
            logger.debug(f'不支持的bot: {btype}')
            return
        bot = Bridge().get_bot("chat")

@@ -44,3 +44,6 @@ zhipuai>=2.0.1

# tongyi qwen new sdk
dashscope

# tencentcloud sdk
tencentcloud-sdk-python>=3.0.0

@@ -8,3 +8,4 @@ Pillow
pre-commit
web.py
linkai>=0.0.6.0
agentmesh-sdk>=0.1.3

@@ -0,0 +1,278 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
  <meta charset="UTF-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <title>登录</title>
  <style>
    /* Reset and base */
    * {
      box-sizing: border-box;
    }
    body, html {
      margin: 0; padding: 0; height: 100%;
      font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif;
      background: linear-gradient(135deg, #667eea, #764ba2);
      display: flex;
      justify-content: center;
      align-items: center;
    }
    .login-container {
      background: rgba(255, 255, 255, 0.95);
      padding: 2.5rem 3rem;
      border-radius: 12px;
      box-shadow: 0 8px 24px rgba(0,0,0,0.15);
      width: 100%;
      max-width: 400px;
    }
    h2 {
      margin-bottom: 1.5rem;
      color: #333;
      text-align: center;
    }
    form {
      display: flex;
      flex-direction: column;
    }
    label {
      font-weight: 600;
      margin-bottom: 0.4rem;
      color: #444;
    }
    input[type="text"],
    input[type="email"],
    input[type="password"] {
      padding: 0.6rem 0.8rem;
      font-size: 1rem;
      border: 1.8px solid #ccc;
      border-radius: 6px;
      transition: border-color 0.3s ease;
      outline-offset: 2px;
    }
    input[type="text"]:focus,
    input[type="email"]:focus,
    input[type="password"]:focus {
      border-color: #667eea;
    }
    .password-wrapper {
      position: relative;
      display: flex;
      align-items: center;
    }
    .toggle-password {
      position: absolute;
      right: 0.8rem;
      background: none;
      border: none;
      cursor: pointer;
      font-size: 1rem;
      color: #667eea;
      user-select: none;
    }
    .login-button {
      margin-top: 1.5rem;
      padding: 0.75rem;
      font-size: 1.1rem;
      font-weight: 700;
      background-color: #667eea;
      color: white;
      border: none;
      border-radius: 8px;
      cursor: pointer;
      transition: background-color 0.3s ease;
    }
    .login-button:disabled {
      background-color: #a3a9f7;
      cursor: not-allowed;
    }
    .forgot-password {
      margin-top: 1rem;
      text-align: right;
    }
    .forgot-password a {
      color: #667eea;
      text-decoration: none;
      font-size: 0.9rem;
    }
    .forgot-password a:hover {
      text-decoration: underline;
    }
    .error-message {
      margin-top: 1rem;
      color: #d93025;
      font-weight: 600;
      text-align: center;
    }
    .loading-spinner {
      border: 3px solid #f3f3f3;
      border-top: 3px solid #667eea;
      border-radius: 50%;
      width: 20px;
      height: 20px;
      animation: spin 1s linear infinite;
      display: inline-block;
      vertical-align: middle;
      margin-left: 8px;
    }
    @keyframes spin {
      0% { transform: rotate(0deg); }
      100% { transform: rotate(360deg); }
    }
    /* Responsive */
    @media (max-width: 480px) {
      .login-container {
        margin: 1rem;
        padding: 2rem 1.5rem;
      }
    }
  </style>
</head>
<body>
  <div class="login-container" role="main" aria-label="登录表单">
    <h2>用户登录</h2>
    <form id="loginForm" novalidate>
      <label for="usernameEmail">用户名或邮箱</label>
      <input type="text" id="usernameEmail" name="usernameEmail" autocomplete="username" placeholder="请输入用户名或邮箱" required aria-describedby="usernameEmailError" />
      <div id="usernameEmailError" class="error-message" aria-live="polite"></div>

      <label for="password" style="margin-top:1rem;">密码</label>
      <div class="password-wrapper">
        <input type="password" id="password" name="password" autocomplete="current-password" placeholder="请输入密码" required minlength="6" aria-describedby="passwordError" />
        <button type="button" class="toggle-password" aria-label="切换密码可见性" title="切换密码可见性">👁️</button>
      </div>
      <div id="passwordError" class="error-message" aria-live="polite"></div>

      <button type="submit" id="loginButton" class="login-button" disabled>登录</button>
      <div class="forgot-password">
        <a href="/forgot-password.html" target="_blank" rel="noopener noreferrer">忘记密码?</a>
      </div>
      <div id="submitError" class="error-message" aria-live="polite"></div>
    </form>
  </div>

  <script>
    (function(){
      const usernameEmailInput = document.getElementById('usernameEmail');
      const passwordInput = document.getElementById('password');
      const loginButton = document.getElementById('loginButton');
      const usernameEmailError = document.getElementById('usernameEmailError');
      const passwordError = document.getElementById('passwordError');
      const submitError = document.getElementById('submitError');
      const togglePasswordBtn = document.querySelector('.toggle-password');
      const form = document.getElementById('loginForm');

      // Validate username or email format
      function validateUsernameEmail(value) {
        if (!value.trim()) {
          return "用户名或邮箱不能为空";
        }
        // Simple email regex
        const emailRegex = /^[^\s@]+@[^\s@]+\.[^\s@]+$/;
        // Username rule: letters, digits and underscores, length 3-20
        const usernameRegex = /^[a-zA-Z0-9_]{3,20}$/;
        if (emailRegex.test(value)) {
          return "";
        } else if (usernameRegex.test(value)) {
          return "";
        } else {
          return "请输入有效的用户名或邮箱格式";
        }
      }

      // Validate password format
      function validatePassword(value) {
        if (!value) {
          return "密码不能为空";
        }
        if (value.length < 6) {
          return "密码长度不能少于6位";
        }
        return "";
      }

      // Validate in real time, updating error messages and button state
      function validateForm() {
        const usernameEmailVal = usernameEmailInput.value;
        const passwordVal = passwordInput.value;

        const usernameEmailErrMsg = validateUsernameEmail(usernameEmailVal);
        const passwordErrMsg = validatePassword(passwordVal);

        usernameEmailError.textContent = usernameEmailErrMsg;
        passwordError.textContent = passwordErrMsg;
        submitError.textContent = "";

        const isValid = !usernameEmailErrMsg && !passwordErrMsg;
        loginButton.disabled = !isValid;
        return isValid;
      }

      // Toggle password visibility
      togglePasswordBtn.addEventListener('click', () => {
        if (passwordInput.type === 'password') {
          passwordInput.type = 'text';
          togglePasswordBtn.textContent = '🙈';
          togglePasswordBtn.setAttribute('aria-label', '隐藏密码');
          togglePasswordBtn.setAttribute('title', '隐藏密码');
        } else {
          passwordInput.type = 'password';
          togglePasswordBtn.textContent = '👁️';
          togglePasswordBtn.setAttribute('aria-label', '显示密码');
          togglePasswordBtn.setAttribute('title', '显示密码');
        }
      });

      // Re-validate on every input event
      usernameEmailInput.addEventListener('input', validateForm);
      passwordInput.addEventListener('input', validateForm);

      // Simulated login request
      function fakeLoginRequest(data) {
        return new Promise((resolve, reject) => {
          setTimeout(() => {
            // Succeeds only for username/email "user" or "user@example.com" with password "password123"
            const validUsers = ["user", "user@example.com"];
            if (validUsers.includes(data.usernameEmail.toLowerCase()) && data.password === "password123") {
              resolve();
            } else {
              reject(new Error("用户名或密码错误"));
            }
          }, 1500);
        });
      }

      // Form submit handling
      form.addEventListener('submit', async (e) => {
        e.preventDefault();
        if (!validateForm()) return;

        loginButton.disabled = true;
        const originalText = loginButton.textContent;
        loginButton.textContent = "登录中";
        const spinner = document.createElement('span');
        spinner.className = 'loading-spinner';
        loginButton.appendChild(spinner);
        submitError.textContent = "";

        try {
          await fakeLoginRequest({
            usernameEmail: usernameEmailInput.value.trim(),
            password: passwordInput.value
          });
          // On success, redirect to the user home page (simulated here with alert)
          alert("登录成功,跳转到用户主页");
          // window.location.href = "/user-home.html"; // actual redirect
        } catch (err) {
          submitError.textContent = err.message;
        } finally {
          loginButton.disabled = false;
          loginButton.textContent = originalText;
        }
      });

      // Validate once on load so cached values don't leave the button in a stale state
      validateForm();
    })();
  </script>
</body>
</html>
+130 -52
@@ -1,9 +1,11 @@
"""
baidu voice service
baidu voice service with thread-safe token caching
"""
import json
import os
import time
import threading
import requests

from aip import AipSpeech

@@ -14,28 +16,13 @@ from config import conf
from voice.audio_convert import get_pcm_from_wav
from voice.voice import Voice

"""
Baidu speech recognition API.
dev_pid:
    - 1936: Mandarin, far-field
    - 1536: Mandarin (supports basic English recognition)
    - 1537: Mandarin (Chinese only)
    - 1737: English
    - 1637: Cantonese
    - 1837: Sichuan dialect
To use this module, first register a developer account at yuyin.baidu.com,
then create a new application and obtain the API Key and Secret Key from "查看key" in the app management page.
Fill these two values, together with app_id and dev_pid, into config.json.
"""


class BaiduVoice(Voice):
    def __init__(self):
        try:
            # read local TTS parameter config
            curdir = os.path.dirname(__file__)
            config_path = os.path.join(curdir, "config.json")
            bconf = None
            if not os.path.exists(config_path):  # create a local config file if none exists
            if not os.path.exists(config_path):
                bconf = {"lang": "zh", "ctp": 1, "spd": 5, "pit": 5, "vol": 5, "per": 0}
                with open(config_path, "w") as fw:
                    json.dump(bconf, fw, indent=4)
@@ -47,48 +34,139 @@ class BaiduVoice(Voice):
            self.api_key = str(conf().get("baidu_api_key"))
            self.secret_key = str(conf().get("baidu_secret_key"))
            self.dev_id = conf().get("baidu_dev_pid")
            self.lang = bconf["lang"]
            self.ctp = bconf["ctp"]
            self.spd = bconf["spd"]
            self.pit = bconf["pit"]
            self.vol = bconf["vol"]
            self.per = bconf["per"]

            self.lang = bconf["lang"]
            self.ctp = bconf["ctp"]
            self.spd = bconf["spd"]
            self.pit = bconf["pit"]
            self.vol = bconf["vol"]
            self.per = bconf["per"]

            # Baidu SDK client (short-text TTS & speech recognition)
            self.client = AipSpeech(self.app_id, self.api_key, self.secret_key)

            # access_token cache and lock
            self._access_token = None
            self._token_expire_ts = 0
            self._token_lock = threading.Lock()
        except Exception as e:
            logger.warn("BaiduVoice init failed: %s, ignore " % e)
            logger.warn("BaiduVoice init failed: %s, ignore" % e)

    def _get_access_token(self):
        # thread-safe token retrieval
        with self._token_lock:
            now = time.time()
            if self._access_token and now < self._token_expire_ts:
                return self._access_token
            url = "https://aip.baidubce.com/oauth/2.0/token"
            params = {
                "grant_type": "client_credentials",
                "client_id": self.api_key,
                "client_secret": self.secret_key,
            }
            resp = requests.post(url, params=params).json()
            token = resp.get("access_token")
            expires_in = resp.get("expires_in", 2592000)
            if token:
                self._access_token = token
                self._token_expire_ts = now + expires_in - 60  # expire 1 minute early
                return token
            else:
                logger.error("BaiduVoice _get_access_token failed: %s", resp)
                return None

    def voiceToText(self, voice_file):
        # recognize a local file
        logger.debug("[Baidu] voice file name={}".format(voice_file))
        logger.debug("[Baidu] recognize voice file=%s", voice_file)
        pcm = get_pcm_from_wav(voice_file)
        res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id})
        if res["err_no"] == 0:
            logger.info("百度语音识别到了:{}".format(res["result"]))
        if res.get("err_no") == 0:
            text = "".join(res["result"])
            reply = Reply(ReplyType.TEXT, text)
            logger.info("[Baidu] ASR result: %s", text)
            return Reply(ReplyType.TEXT, text)
        else:
            logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
            if res["err_msg"] == "request pv too much":
                logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
            reply = Reply(ReplyType.ERROR, "百度语音识别出错了;{0}".format(res["err_msg"]))
            return reply
            err = res.get("err_msg", "")
            logger.error("[Baidu] ASR error: %s", err)
            return Reply(ReplyType.ERROR, f"语音识别失败:{err}")

    def _long_text_synthesis(self, text):
        token = self._get_access_token()
        if not token:
            return Reply(ReplyType.ERROR, "获取百度 access_token 失败")

        # create the synthesis task
        create_url = f"https://aip.baidubce.com/rpc/2.0/tts/v1/create?access_token={token}"
        payload = {
            "text": text,
            "format": "mp3-16k",
            "voice": 0,
            "lang": self.lang,
            "speed": self.spd,
            "pitch": self.pit,
            "volume": self.vol,
            "enable_subtitle": 0,
        }
        headers = {"Content-Type": "application/json"}
        create_resp = requests.post(create_url, headers=headers, json=payload).json()
        task_id = create_resp.get("task_id")
        if not task_id:
            logger.error("[Baidu] 长文本合成创建任务失败: %s", create_resp)
            return Reply(ReplyType.ERROR, "长文本合成任务提交失败")
        logger.info("[Baidu] 长文本合成任务已提交 task_id=%s", task_id)

        # poll the task status
        query_url = f"https://aip.baidubce.com/rpc/2.0/tts/v1/query?access_token={token}"
        for _ in range(100):
            time.sleep(3)
            resp = requests.post(query_url, headers=headers, json={"task_ids": [task_id]})
            result = resp.json()
            infos = result.get("tasks_info") or result.get("tasks") or []
            if not infos:
                continue
            info = infos[0]
            status = info.get("task_status")
            if status == "Success":
                task_res = info.get("task_result", {})
                audio_url = task_res.get("audio_address") or task_res.get("speech_url")
                break
            elif status == "Running":
                continue
            else:
                logger.error("[Baidu] 长文本合成失败: %s", info)
                return Reply(ReplyType.ERROR, "长文本合成执行失败")
        else:
            return Reply(ReplyType.ERROR, "长文本合成超时,请稍后重试")

        # download and save the audio
        audio_data = requests.get(audio_url).content
        fn = TmpDir().path() + f"reply-long-{int(time.time())}-{hash(text) & 0x7FFFFFFF}.mp3"
        with open(fn, "wb") as f:
            f.write(audio_data)
        logger.info("[Baidu] 长文本合成 success: %s", fn)
        return Reply(ReplyType.VOICE, fn)

    def textToVoice(self, text):
        result = self.client.synthesis(
            text,
            self.lang,
            self.ctp,
            {"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per},
        )
        if not isinstance(result, dict):
            # Avoid the same filename under multithreading
            fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
            with open(fileName, "wb") as f:
                f.write(result)
            logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName))
            reply = Reply(ReplyType.VOICE, fileName)
        else:
            logger.error("[Baidu] textToVoice error={}".format(result))
            reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
        return reply
        try:
            # byte length in GBK encoding
            gbk_len = len(text.encode("gbk", errors="ignore"))
            if gbk_len <= 1024:
                # short text goes through the SDK
                result = self.client.synthesis(
                    text, self.lang, self.ctp,
                    {"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per}
                )
                if not isinstance(result, dict):
                    fn = TmpDir().path() + f"reply-{int(time.time())}-{hash(text) & 0x7FFFFFFF}.mp3"
                    with open(fn, "wb") as f:
                        f.write(result)
                    logger.info("[Baidu] 短文本合成 success: %s", fn)
                    return Reply(ReplyType.VOICE, fn)
                else:
                    logger.error("[Baidu] 短文本合成 error: %s", result)
                    return Reply(ReplyType.ERROR, "短文本语音合成失败")
            else:
                # long text
                return self._long_text_synthesis(text)
        except Exception as e:
            logger.error("BaiduVoice textToVoice exception: %s", e)
            return Reply(ReplyType.ERROR, f"合成异常:{e}")

@@ -50,4 +50,8 @@ def create_voice(voice_type):
        from voice.xunfei.xunfei_voice import XunfeiVoice

        return XunfeiVoice()
    elif voice_type == "tencent":
        from voice.tencent.tencent_voice import TencentVoice

        return TencentVoice()
    raise RuntimeError

@@ -0,0 +1,5 @@
{
    "voice_type": 1003,
    "secret_id": "YOUR_SECRET_ID",
    "secret_key": "YOUR_SECRET_KEY"
}
@@ -0,0 +1,119 @@
import json
import base64
import os
import time
from voice.voice import Voice
from common.log import logger
from tencentcloud.common import credential
from tencentcloud.asr.v20190614 import asr_client, models as asr_models
from tencentcloud.tts.v20190823 import tts_client, models as tts_models
from bridge.reply import Reply, ReplyType
from common.tmp_dir import TmpDir


class TencentVoice(Voice):
    def __init__(self):
        super().__init__()
        self.secret_id = None
        self.secret_key = None
        self.voice_type = 1003
        self._load_config()

    def _load_config(self):
        """
        Load configuration from the local config file
        """
        try:
            config_path = os.path.join(os.path.dirname(__file__), 'config.json')
            with open(config_path, 'r') as f:
                config = json.load(f)
            self.secret_id = config.get('secret_id')
            self.secret_key = config.get('secret_key')
            self.voice_type = config.get('voice_type', self.voice_type)
            if not self.secret_id or not self.secret_key:
                logger.error("[Tencent] Missing credentials in config.json")
        except Exception as e:
            logger.error(f"[Tencent] Failed to load config: {e}")

    def setup(self, config):
        """
        Set configuration (kept for backward compatibility)
        """
        pass

    def voiceToText(self, voice_file):
        """
        Convert a voice file to text
        """
        try:
            # instantiate the credential object
            cred = credential.Credential(self.secret_id, self.secret_key)

            # instantiate the client
            client = asr_client.AsrClient(cred, "ap-guangzhou")

            # read the audio file
            with open(voice_file, 'rb') as f:
                audio_data = f.read()

            # base64-encode the audio
            base64_audio = base64.b64encode(audio_data).decode('utf-8')

            # build the request object
            req = asr_models.SentenceRecognitionRequest()
            req.ProjectId = 0
            req.SubServiceType = 2
            req.EngSerViceType = "16k_zh"
            req.SourceType = 1
            req.VoiceFormat = "wav"
            req.UsrAudioKey = "voice_recognition"
            req.Data = base64_audio

            # send the request
            resp = client.SentenceRecognition(req)

            # parse the result
            if resp.Result:
                logger.info("[Tencent] Voice to text success: {}".format(resp.Result))
                return Reply(ReplyType.TEXT, resp.Result)
            else:
                logger.warning("[Tencent] Voice to text failed")
                return Reply(ReplyType.ERROR, "腾讯语音识别失败")

        except Exception as e:
            logger.error("[Tencent] Voice to text error: {}".format(e))
            return Reply(ReplyType.ERROR, "腾讯语音识别出错:{}".format(str(e)))

    def textToVoice(self, text):
        """
        Convert text to voice
        """
        try:
            cred = credential.Credential(self.secret_id, self.secret_key)
            client = tts_client.TtsClient(cred, "ap-guangzhou")

            req = tts_models.TextToVoiceRequest()
            req.Text = text
            req.SessionId = str(int(time.time()))
            req.Volume = 5
            req.Speed = 0
            req.ProjectId = 0
            req.ModelType = 1
            req.PrimaryLanguage = 1
            req.SampleRate = 16000
            req.VoiceType = self.voice_type  # customer-service female voice by default

            response = client.TextToVoice(req)

            if response.Audio:
                fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
                with open(fileName, "wb") as f:
                    f.write(base64.b64decode(response.Audio))
                logger.info("[Tencent] textToVoice text={} voice file name={}".format(text, fileName))
                return Reply(ReplyType.VOICE, fileName)
            else:
                logger.error("[Tencent] textToVoice failed")
                return Reply(ReplyType.ERROR, "腾讯语音合成失败")

        except Exception as e:
            logger.error("[Tencent] Text to voice error: {}".format(e))
            return Reply(ReplyType.ERROR, "腾讯语音合成出错:{}".format(str(e)))