mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-05-15 08:48:51 +08:00
Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ee91c86a29 | |||
| 48c08f4aad | |||
| fceabb8e67 | |||
| fcfafb05f1 | |||
| f1e8344beb | |||
| 89e8f385b4 | |||
| bf4ae9a051 | |||
| 6bd1242d43 | |||
| 8779eab36b | |||
| 3174b1158c | |||
| 18740093d1 | |||
| 8c7d1d4010 | |||
| 8c48a27e1a | |||
| 4278d2b8ef | |||
| 3a3affd3ec | |||
| 45d72b8b9b | |||
| 03b908c079 | |||
| d35d01f980 | |||
| 9c208ffa2c | |||
| bea4416f12 | |||
| 2ea8b4ef73 | |||
| e6946ef989 | |||
| 9aeb60f66d | |||
| d687f9329e | |||
| 3207258fd9 | |||
| d8b75206fe | |||
| 88e8dd5162 | |||
| c9306633b2 | |||
| c50d1cc99d | |||
| 9a20c1cb02 | |||
| 176f77ba5b | |||
| 484de6237b | |||
| 898aa30b1d | |||
| 8b73a74609 | |||
| 3c6d42b22e | |||
| 40563c1e96 | |||
| cb0c86ec1c | |||
| 614f3b1ea4 |
+11
-1
@@ -1,5 +1,6 @@
|
||||
.DS_Store
|
||||
.idea
|
||||
.vscode
|
||||
.wechaty/
|
||||
__pycache__/
|
||||
venv*
|
||||
@@ -11,4 +12,13 @@ tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
user_datas.pkl
|
||||
user_datas.pkl
|
||||
plugins/**/
|
||||
!plugins/bdunit
|
||||
!plugins/dungeon
|
||||
!plugins/finish
|
||||
!plugins/godcmd
|
||||
!plugins/tool
|
||||
!plugins/banwords
|
||||
!plugins/hello
|
||||
!plugins/role
|
||||
@@ -13,10 +13,14 @@
|
||||
- [x] **语音识别:** 支持接收和处理语音消息,通过文字或语音回复
|
||||
- [x] **插件化:** 支持个性化功能插件,提供角色扮演、文字冒险游戏等预设插件
|
||||
|
||||
> 快速部署:
|
||||
> 目前支持微信和微信个人号部署,欢迎接入更多应用,参考[`Terminal`代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。
|
||||
|
||||
|
||||
快速部署:
|
||||
>
|
||||
>[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
|
||||
|
||||
# 更新日志
|
||||
|
||||
>**2023.04.05:** 支持微信个人号部署,兼容角色扮演等预设插件,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
|
||||
@@ -62,10 +66,6 @@
|
||||
|
||||
> 项目中使用的对话模型是 davinci,计费方式是约每 750 字 (包含请求和回复) 消耗 $0.02,图片生成是每张消耗 $0.016,账号创建有免费的 $18 额度 (更新3.25: 最新注册的已经无免费额度了),使用完可以更换邮箱重新注册。
|
||||
|
||||
#### 1.1 ChapGPT service On Azure
|
||||
一种替换以上的方法是使用Azure推出的[ChatGPT service](https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/)。它host在公有云Azure上,因此不需要VPN就可以直接访问。不过目前仍然处于preview阶段。新用户可以通过Try Azure for free来薅一段时间的羊毛
|
||||
|
||||
|
||||
### 2.运行环境
|
||||
|
||||
支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
|
||||
|
||||
@@ -13,7 +13,8 @@ def sigterm_handler_wrap(_signo):
|
||||
def func(_signo, _stack_frame):
|
||||
logger.info("signal {} received, exiting...".format(_signo))
|
||||
conf().save_user_datas()
|
||||
return old_handler(_signo, _stack_frame)
|
||||
if callable(old_handler): # check old_handler
|
||||
return old_handler(_signo, _stack_frame)
|
||||
signal.signal(_signo, func)
|
||||
|
||||
def run():
|
||||
@@ -27,12 +28,16 @@ def run():
|
||||
|
||||
# create channel
|
||||
channel_name=conf().get('channel_type', 'wx')
|
||||
|
||||
if "--cmd" in sys.argv:
|
||||
channel_name = 'terminal'
|
||||
|
||||
if channel_name == 'wxy':
|
||||
os.environ['WECHATY_LOG']="warn"
|
||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
||||
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ['wx','wxy','wechatmp']:
|
||||
if channel_name in ['wx','wxy','terminal','wechatmp','wechatmp_service']:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
# startup channel
|
||||
|
||||
@@ -3,13 +3,12 @@
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import Session, SessionManager
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf, load_config
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from common.expired_dict import ExpiredDict
|
||||
import openai
|
||||
import openai.error
|
||||
import time
|
||||
@@ -91,8 +90,8 @@ class ChatGPTBot(Bot,OpenAIImage):
|
||||
"top_p":1,
|
||||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get('request_timeout', 60), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get('request_timeout', 120), #重试超时时间,在这个时间内,将会自动重试
|
||||
"request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
|
||||
}
|
||||
|
||||
def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict:
|
||||
@@ -151,6 +150,7 @@ class AzureChatGPTBot(ChatGPTBot):
|
||||
|
||||
def compose_args(self):
|
||||
args = super().compose_args()
|
||||
args["engine"] = args["model"]
|
||||
del(args["model"])
|
||||
return args
|
||||
args["deployment_id"] = conf().get("azure_deployment_id")
|
||||
#args["engine"] = args["model"]
|
||||
#del(args["model"])
|
||||
return args
|
||||
|
||||
@@ -55,7 +55,7 @@ def num_tokens_from_messages(messages, model):
|
||||
except KeyError:
|
||||
logger.debug("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model == "gpt-3.5-turbo":
|
||||
if model == "gpt-3.5-turbo" or model == "gpt-35-turbo":
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
|
||||
elif model == "gpt-4":
|
||||
return num_tokens_from_messages(messages, model="gpt-4-0314")
|
||||
@@ -76,4 +76,4 @@ def num_tokens_from_messages(messages, model):
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens
|
||||
return num_tokens
|
||||
|
||||
+1
-1
@@ -19,7 +19,7 @@ class Bridge(object):
|
||||
model_type = conf().get("model")
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype['chat'] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt"):
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype['chat'] = const.CHATGPTONAZURE
|
||||
self.bots={}
|
||||
|
||||
|
||||
+2
-1
@@ -5,7 +5,8 @@ from enum import Enum
|
||||
class ContextType (Enum):
|
||||
TEXT = 1 # 文本消息
|
||||
VOICE = 2 # 音频消息
|
||||
IMAGE_CREATE = 3 # 创建图片命令
|
||||
IMAGE = 3 # 图片消息
|
||||
IMAGE_CREATE = 10 # 创建图片命令
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@@ -19,5 +19,8 @@ def create_channel(channel_type):
|
||||
return TerminalChannel()
|
||||
elif channel_type == 'wechatmp':
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
return WechatMPChannel()
|
||||
return WechatMPChannel(passive_reply = True)
|
||||
elif channel_type == 'wechatmp_service':
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
return WechatMPChannel(passive_reply = False)
|
||||
raise RuntimeError
|
||||
|
||||
+20
-8
@@ -51,7 +51,7 @@ class ChatChannel(Channel):
|
||||
if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True):
|
||||
logger.debug("[WX]self message skipped")
|
||||
return None
|
||||
if context["isgroup"]:
|
||||
if context.get("isgroup", False):
|
||||
group_name = cmsg.other_user_nickname
|
||||
group_id = cmsg.other_user_id
|
||||
|
||||
@@ -76,7 +76,7 @@ class ChatChannel(Channel):
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return None
|
||||
|
||||
if context["isgroup"]: # 群聊
|
||||
if context.get("isgroup", False): # 群聊
|
||||
# 校验关键字
|
||||
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
|
||||
match_contain = check_contain(content, conf().get('group_chat_keyword'))
|
||||
@@ -97,7 +97,7 @@ class ChatChannel(Channel):
|
||||
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
||||
return None
|
||||
else: # 单聊
|
||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
|
||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix',['']))
|
||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
@@ -170,6 +170,8 @@ class ChatChannel(Channel):
|
||||
reply = self._generate_reply(new_context)
|
||||
else:
|
||||
return
|
||||
elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑
|
||||
pass
|
||||
else:
|
||||
logger.error('[WX] unknown context type: {}'.format(context.type))
|
||||
return
|
||||
@@ -193,7 +195,7 @@ class ChatChannel(Channel):
|
||||
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context['isgroup']:
|
||||
if context.get("isgroup", False):
|
||||
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
|
||||
else:
|
||||
@@ -231,12 +233,20 @@ class ChatChannel(Channel):
|
||||
time.sleep(3+3*retry_cnt)
|
||||
self._send(reply, context, retry_cnt+1)
|
||||
|
||||
def thread_pool_callback(self, session_id):
|
||||
def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数
|
||||
logger.debug("Worker return success, session_id = {}".format(session_id))
|
||||
|
||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.exception("Worker return exception: {}".format(exception))
|
||||
|
||||
def _thread_pool_callback(self, session_id, **kwargs):
|
||||
def func(worker:Future):
|
||||
try:
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
logger.exception("Worker return exception: {}".format(worker_exception))
|
||||
self._fail_callback(session_id, exception = worker_exception, **kwargs)
|
||||
else:
|
||||
self._success_callback(session_id, **kwargs)
|
||||
except CancelledError as e:
|
||||
logger.info("Worker cancelled, session_id = {}".format(session_id))
|
||||
except Exception as e:
|
||||
@@ -249,7 +259,7 @@ class ChatChannel(Channel):
|
||||
session_id = context['session_id']
|
||||
with self.lock:
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1))]
|
||||
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 4))]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
else:
|
||||
@@ -267,7 +277,7 @@ class ChatChannel(Channel):
|
||||
context = context_queue.get()
|
||||
logger.debug("[WX] consume context: {}".format(context))
|
||||
future:Future = self.handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self.thread_pool_callback(session_id))
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context = context))
|
||||
if session_id not in self.futures:
|
||||
self.futures[session_id] = []
|
||||
self.futures[session_id].append(future)
|
||||
@@ -302,6 +312,8 @@ class ChatChannel(Channel):
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
if not prefix_list:
|
||||
return None
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
|
||||
+12
-10
@@ -1,27 +1,29 @@
|
||||
|
||||
"""
|
||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装
|
||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
||||
|
||||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
||||
|
||||
ChatMessage
|
||||
msg_id: 消息id
|
||||
msg_id: 消息id (必填)
|
||||
create_time: 消息创建时间
|
||||
|
||||
ctype: 消息类型 : ContextType
|
||||
content: 消息内容, 如果是声音/图片,这里是文件路径
|
||||
ctype: 消息类型 : ContextType (必填)
|
||||
content: 消息内容, 如果是声音/图片,这里是文件路径 (必填)
|
||||
|
||||
from_user_id: 发送者id
|
||||
from_user_id: 发送者id (必填)
|
||||
from_user_nickname: 发送者昵称
|
||||
to_user_id: 接收者id
|
||||
to_user_id: 接收者id (必填)
|
||||
to_user_nickname: 接收者昵称
|
||||
|
||||
other_user_id: 对方的id,如果你是发送者,那这个就是接收者id,如果你是接收者,那这个就是发送者id,如果是群消息,那这一直是群id
|
||||
other_user_id: 对方的id,如果你是发送者,那这个就是接收者id,如果你是接收者,那这个就是发送者id,如果是群消息,那这一直是群id (必填)
|
||||
other_user_nickname: 同上
|
||||
|
||||
is_group: 是否是群消息
|
||||
is_at: 是否被at
|
||||
is_group: 是否是群消息 (群聊必填)
|
||||
is_at: 是否被at
|
||||
|
||||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
|
||||
actual_user_id: 实际发送者id
|
||||
actual_user_id: 实际发送者id (群聊必填)
|
||||
actual_user_nickname:实际发送者昵称
|
||||
|
||||
|
||||
|
||||
@@ -1,31 +1,78 @@
|
||||
from bridge.context import *
|
||||
from channel.channel import Channel
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.chat_message import ChatMessage
|
||||
import sys
|
||||
|
||||
class TerminalChannel(Channel):
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
|
||||
class TerminalMessage(ChatMessage):
|
||||
def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"):
|
||||
self.msg_id = msg_id
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = to_user_id
|
||||
self.other_user_id = other_user_id
|
||||
|
||||
class TerminalChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
print("\nBot:")
|
||||
if reply.type == ReplyType.IMAGE:
|
||||
from PIL import Image
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print("<IMAGE>")
|
||||
img.show()
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
from PIL import Image
|
||||
import requests,io
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print(img_url)
|
||||
img.show()
|
||||
else:
|
||||
print(reply.content)
|
||||
print("\nUser:", end="")
|
||||
sys.stdout.flush()
|
||||
return
|
||||
|
||||
def startup(self):
|
||||
context = Context()
|
||||
print("\nPlease input your question")
|
||||
logger.setLevel("WARN")
|
||||
print("\nPlease input your question:\nUser:", end="")
|
||||
sys.stdout.flush()
|
||||
msg_id = 0
|
||||
while True:
|
||||
try:
|
||||
prompt = self.get_input("User:\n")
|
||||
prompt = self.get_input()
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting...")
|
||||
sys.exit()
|
||||
msg_id += 1
|
||||
trigger_prefixs = conf().get("single_chat_prefix",[""])
|
||||
if check_prefix(prompt, trigger_prefixs) is None:
|
||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt))
|
||||
if context:
|
||||
self.produce(context)
|
||||
else:
|
||||
raise Exception("context is None")
|
||||
|
||||
context.type = ContextType.TEXT
|
||||
context['session_id'] = "User"
|
||||
context.content = prompt
|
||||
print("Bot:")
|
||||
sys.stdout.flush()
|
||||
res = super().build_reply_content(prompt, context).content
|
||||
print(res)
|
||||
|
||||
|
||||
def get_input(self, prompt):
|
||||
def get_input(self):
|
||||
"""
|
||||
Multi-line input function
|
||||
"""
|
||||
print(prompt, end="")
|
||||
sys.stdout.flush()
|
||||
line = input()
|
||||
return line
|
||||
|
||||
@@ -23,26 +23,21 @@ from common.time_check import time_checker
|
||||
from common.expired_dict import ExpiredDict
|
||||
from plugins import *
|
||||
|
||||
@itchat.msg_register(TEXT)
|
||||
@itchat.msg_register([TEXT,VOICE,PICTURE])
|
||||
def handler_single_msg(msg):
|
||||
WechatChannel().handle_text(WeChatMessage(msg))
|
||||
# logger.debug("handler_single_msg: {}".format(msg))
|
||||
if msg['Type'] == PICTURE and msg['MsgType'] == 47:
|
||||
return None
|
||||
WechatChannel().handle_single(WeChatMessage(msg))
|
||||
return None
|
||||
|
||||
@itchat.msg_register(TEXT, isGroupChat=True)
|
||||
@itchat.msg_register([TEXT,VOICE,PICTURE], isGroupChat=True)
|
||||
def handler_group_msg(msg):
|
||||
if msg['Type'] == PICTURE and msg['MsgType'] == 47:
|
||||
return None
|
||||
WechatChannel().handle_group(WeChatMessage(msg,True))
|
||||
return None
|
||||
|
||||
@itchat.msg_register(VOICE)
|
||||
def handler_single_voice(msg):
|
||||
WechatChannel().handle_voice(WeChatMessage(msg))
|
||||
return None
|
||||
|
||||
@itchat.msg_register(VOICE, isGroupChat=True)
|
||||
def handler_group_voice(msg):
|
||||
WechatChannel().handle_group_voice(WeChatMessage(msg,True))
|
||||
return None
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: ChatMessage):
|
||||
msgId = cmsg.msg_id
|
||||
@@ -118,7 +113,7 @@ class WechatChannel(ChatChannel):
|
||||
# start message listener
|
||||
itchat.run()
|
||||
|
||||
# handle_* 系列函数处理收到的消息后构造Context,然后传入_handle函数中处理Context和发送回复
|
||||
# handle_* 系列函数处理收到的消息后构造Context,然后传入produce函数中处理Context和发送回复
|
||||
# Context包含了消息的所有信息,包括以下属性
|
||||
# type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE
|
||||
# content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
|
||||
@@ -132,37 +127,32 @@ class WechatChannel(ChatChannel):
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_voice(self, cmsg : ChatMessage):
|
||||
if conf().get('speech_recognition') != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_text(self, cmsg : ChatMessage):
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=False, msg=cmsg)
|
||||
def handle_single(self, cmsg : ChatMessage):
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if conf().get('speech_recognition') != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
|
||||
else:
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg : ChatMessage):
|
||||
logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
context = self._compose_context(ContextType.TEXT, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group_voice(self, cmsg : ChatMessage):
|
||||
if conf().get('group_speech_recognition', False) != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(ContextType.VOICE, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if conf().get('speech_recognition') != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
|
||||
else:
|
||||
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
pass
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
|
||||
@@ -22,6 +22,10 @@ class WeChatMessage(ChatMessage):
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg['Type'] == PICTURE and itchat_msg['MsgType'] == 3:
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = TmpDir().path() + itchat_msg['FileName'] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: {}".format(itchat_msg['Type']))
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# 个人微信公众号channel
|
||||
# 微信公众号channel
|
||||
|
||||
鉴于个人微信号在服务器上通过itchat登录有封号风险,这里新增了个人微信公众号channel,提供无风险的服务。
|
||||
但是由于个人微信公众号的众多接口限制,目前支持的功能有限,实现简陋,提供了一个最基本的文本对话服务,支持加载插件,优化了命令格式,支持私有api_key。暂未实现图片输入输出、语音输入输出等交互形式。
|
||||
如有公众号是企业主体且可以通过微信认证,即可获得更多接口,解除大多数限制。欢迎大家提供更多的支持。
|
||||
鉴于个人微信号在服务器上通过itchat登录有封号风险,这里新增了微信公众号channel,提供无风险的服务。
|
||||
目前支持订阅号(个人)和服务号(企业)两种类型的公众号,它们的主要区别就是被动回复和主动回复。
|
||||
个人微信订阅号有许多接口限制,目前仅支持最基本的文本对话和语音输入,支持加载插件,支持私有api_key。
|
||||
暂未实现图片输入输出、语音输出等交互形式。
|
||||
|
||||
## 使用方法
|
||||
## 使用方法(订阅号,服务号类似)
|
||||
|
||||
在开始部署前,你需要一个拥有公网IP的服务器,以提供微信服务器和我们自己服务器的连接。或者你需要进行内网穿透,否则微信服务器无法将消息发送给我们的服务器。
|
||||
|
||||
@@ -21,8 +22,10 @@ pip3 install web.py
|
||||
相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加
|
||||
```
|
||||
"channel_type": "wechatmp",
|
||||
"wechatmp_token": "your Token",
|
||||
"wechatmp_port": 8080,
|
||||
"wechatmp_token": "Token", # 微信公众平台的Token
|
||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
|
||||
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要
|
||||
```
|
||||
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口(443同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`需要修改相应的证书路径):
|
||||
```
|
||||
@@ -35,7 +38,7 @@ sudo iptables-save > /etc/iptables/rules.v4
|
||||
随后在[微信公众平台](https://mp.weixin.qq.com)启用服务器,关闭手动填写规则的自动回复,即可实现ChatGPT的自动回复。
|
||||
|
||||
## 个人微信公众号的限制
|
||||
由于目前测试的公众号不是企业主体,所以没有客服接口,因此公众号无法主动发出消息,只能被动回复。而微信官方对被动回复有5秒的时间限制,最多重试2次,因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙,ChatGPT的回答就没办法及时回复给用户。为了解决这个问题,这里做了回答缓存,它需要你在回复超时后,再次主动发送任意文字(例如1)来尝试拿到回答缓存。为了优化使用体验,目前设置了两分钟(120秒)的timeout,用户在至多两分钟后即可得到查询到回复或者错误原因。
|
||||
由于人微信公众号不能通过微信认证,所以没有客服接口,因此公众号无法主动发出消息,只能被动回复。而微信官方对被动回复有5秒的时间限制,最多重试2次,因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙,ChatGPT的回答就没办法及时回复给用户。为了解决这个问题,这里做了回答缓存,它需要你在回复超时后,再次主动发送任意文字(例如1)来尝试拿到回答缓存。为了优化使用体验,目前设置了两分钟(120秒)的timeout,用户在至多两分钟后即可得到查询到回复或者错误原因。
|
||||
|
||||
另外,由于微信官方的限制,自动回复有长度限制。因此这里将ChatGPT的回答拆分,分成每段600字回复(限制大约在700字)。
|
||||
|
||||
@@ -43,4 +46,9 @@ sudo iptables-save > /etc/iptables/rules.v4
|
||||
公共api有访问频率限制(免费账号每分钟最多20次ChatGPT的API调用),这在服务多人的时候会遇到问题。因此这里多加了一个设置私有api_key的功能。目前通过godcmd插件的命令来设置私有api_key。
|
||||
|
||||
## 测试范围
|
||||
目前在`RoboStyle`这个公众号上进行了测试,感兴趣的可以关注并体验。开启了godcmd, Banwords, role, dungeon, finish这五个插件,其他的插件还没有测试。百度的接口暂未测试。语音对话没有测试。图片直接以链接形式回复(没有临时素材上传接口的权限)。
|
||||
目前在`RoboStyle`这个公众号上进行了测试(基于[wechatmp-stable分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp-stable),而[master分支](https://github.com/zhayujie/chatgpt-on-wechat)含有最新功能,但是稳定性有待测试),感兴趣的可以关注并体验。开启了godcmd, Banwords, role, dungeon, finish这五个插件,其他的插件还没有测试。百度的接口暂未测试。语音对话没有测试。图片直接以链接形式回复(没有临时素材上传接口的权限)。
|
||||
|
||||
## TODO
|
||||
* 服务号交互完善
|
||||
* 服务号使用临时素材接口,提供图片回复能力
|
||||
* 插件测试
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
import web
|
||||
import time
|
||||
import channel.wechatmp.reply as reply
|
||||
import channel.wechatmp.receive as receive
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from bridge.context import *
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
|
||||
# This class is instantiated once per query
|
||||
class Query():
|
||||
|
||||
def GET(self):
|
||||
return verify_server(web.input())
|
||||
|
||||
def POST(self):
|
||||
# Make sure to return the instance that first created, @singleton will do that.
|
||||
channel = WechatMPChannel()
|
||||
try:
|
||||
webData = web.data()
|
||||
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
|
||||
wechatmp_msg = receive.parse_xml(webData)
|
||||
if wechatmp_msg.msg_type == 'text':
|
||||
from_user = wechatmp_msg.from_user_id
|
||||
message = wechatmp_msg.content.decode("utf-8")
|
||||
message_id = wechatmp_msg.msg_id
|
||||
|
||||
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
|
||||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
|
||||
if context:
|
||||
# set private openai_api_key
|
||||
# if from_user is not changed in itchat, this can be placed at chat_channel
|
||||
user_data = conf().get_user_data(from_user)
|
||||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
|
||||
channel.produce(context)
|
||||
# The reply will be sent by channel.send() in another thread
|
||||
return "success"
|
||||
|
||||
elif wechatmp_msg.msg_type == 'event':
|
||||
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.Event, wechatmp_msg.from_user_id))
|
||||
content = subscribe_msg()
|
||||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
|
||||
return replyMsg.send()
|
||||
else:
|
||||
logger.info("暂且不处理")
|
||||
return "success"
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
return exc
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
import web
|
||||
import time
|
||||
import channel.wechatmp.reply as reply
|
||||
import channel.wechatmp.receive as receive
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from bridge.context import *
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
|
||||
# This class is instantiated once per query
|
||||
class Query():
|
||||
|
||||
def GET(self):
|
||||
return verify_server(web.input())
|
||||
|
||||
def POST(self):
|
||||
# Make sure to return the instance that first created, @singleton will do that.
|
||||
channel = WechatMPChannel()
|
||||
try:
|
||||
query_time = time.time()
|
||||
webData = web.data()
|
||||
logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
|
||||
wechatmp_msg = receive.parse_xml(webData)
|
||||
if wechatmp_msg.msg_type == 'text':
|
||||
from_user = wechatmp_msg.from_user_id
|
||||
to_user = wechatmp_msg.to_user_id
|
||||
message = wechatmp_msg.content.decode("utf-8")
|
||||
message_id = wechatmp_msg.msg_id
|
||||
|
||||
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
|
||||
supported = True
|
||||
if "【收到不支持的消息类型,暂无法显示】" in message:
|
||||
supported = False # not supported, used to refresh
|
||||
cache_key = from_user
|
||||
|
||||
reply_text = ""
|
||||
# New request
|
||||
if cache_key not in channel.cache_dict and cache_key not in channel.running:
|
||||
# The first query begin, reset the cache
|
||||
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
|
||||
logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg))
|
||||
if message_id in channel.received_msgs: # received and finished
|
||||
# no return because of bandwords or other reasons
|
||||
return "success"
|
||||
if supported and context:
|
||||
# set private openai_api_key
|
||||
# if from_user is not changed in itchat, this can be placed at chat_channel
|
||||
user_data = conf().get_user_data(from_user)
|
||||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
|
||||
channel.received_msgs[message_id] = wechatmp_msg
|
||||
channel.running.add(cache_key)
|
||||
channel.produce(context)
|
||||
else:
|
||||
trigger_prefix = conf().get('single_chat_prefix',[''])[0]
|
||||
if trigger_prefix or not supported:
|
||||
if trigger_prefix:
|
||||
content = textwrap.dedent(f"""\
|
||||
请输入'{trigger_prefix}'接你想说的话跟我说话。
|
||||
例如:
|
||||
{trigger_prefix}你好,很高兴见到你。""")
|
||||
else:
|
||||
content = textwrap.dedent("""\
|
||||
你好,很高兴见到你。
|
||||
请跟我说话吧。""")
|
||||
else:
|
||||
logger.error(f"[wechatmp] unknown error")
|
||||
content = textwrap.dedent("""\
|
||||
未知错误,请稍后再试""")
|
||||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
|
||||
return replyMsg.send()
|
||||
channel.query1[cache_key] = False
|
||||
channel.query2[cache_key] = False
|
||||
channel.query3[cache_key] = False
|
||||
# User request again, and the answer is not ready
|
||||
elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True:
|
||||
channel.query1[cache_key] = False #To improve waiting experience, this can be set to True.
|
||||
channel.query2[cache_key] = False #To improve waiting experience, this can be set to True.
|
||||
channel.query3[cache_key] = False
|
||||
# User request again, and the answer is ready
|
||||
elif cache_key in channel.cache_dict:
|
||||
# Skip the waiting phase
|
||||
channel.query1[cache_key] = True
|
||||
channel.query2[cache_key] = True
|
||||
channel.query3[cache_key] = True
|
||||
|
||||
assert not (cache_key in channel.cache_dict and cache_key in channel.running)
|
||||
|
||||
if channel.query1.get(cache_key) == False:
|
||||
# The first query from wechat official server
|
||||
logger.debug("[wechatmp] query1 {}".format(cache_key))
|
||||
channel.query1[cache_key] = True
|
||||
cnt = 0
|
||||
while cache_key in channel.running and cnt < 45:
|
||||
cnt = cnt + 1
|
||||
time.sleep(0.1)
|
||||
if cnt == 45:
|
||||
# waiting for timeout (the POST query will be closed by wechat official server)
|
||||
time.sleep(1)
|
||||
# and do nothing
|
||||
return
|
||||
else:
|
||||
pass
|
||||
elif channel.query2.get(cache_key) == False:
|
||||
# The second query from wechat official server
|
||||
logger.debug("[wechatmp] query2 {}".format(cache_key))
|
||||
channel.query2[cache_key] = True
|
||||
cnt = 0
|
||||
while cache_key in channel.running and cnt < 45:
|
||||
cnt = cnt + 1
|
||||
time.sleep(0.1)
|
||||
if cnt == 45:
|
||||
# waiting for timeout (the POST query will be closed by wechat official server)
|
||||
time.sleep(1)
|
||||
# and do nothing
|
||||
return
|
||||
else:
|
||||
pass
|
||||
elif channel.query3.get(cache_key) == False:
|
||||
# The third query from wechat official server
|
||||
logger.debug("[wechatmp] query3 {}".format(cache_key))
|
||||
channel.query3[cache_key] = True
|
||||
cnt = 0
|
||||
while cache_key in channel.running and cnt < 40:
|
||||
cnt = cnt + 1
|
||||
time.sleep(0.1)
|
||||
if cnt == 40:
|
||||
# Have waiting for 3x5 seconds
|
||||
# return timeout message
|
||||
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
||||
logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id))
|
||||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
||||
return replyPost
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
if cache_key not in channel.cache_dict and cache_key not in channel.running:
|
||||
# no return because of bandwords or other reasons
|
||||
return "success"
|
||||
|
||||
# if float(time.time()) - float(query_time) > 4.8:
|
||||
# reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
||||
# logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id))
|
||||
# replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
||||
# return replyPost
|
||||
|
||||
if cache_key in channel.cache_dict:
|
||||
content = channel.cache_dict[cache_key]
|
||||
if len(content.encode('utf8'))<=MAX_UTF8_LEN:
|
||||
reply_text = channel.cache_dict[cache_key]
|
||||
channel.cache_dict.pop(cache_key)
|
||||
else:
|
||||
continue_text = "\n【未完待续,回复任意文字以继续】"
|
||||
splits = split_string_by_utf8_length(content, MAX_UTF8_LEN - len(continue_text.encode('utf-8')), max_split= 1)
|
||||
reply_text = splits[0] + continue_text
|
||||
channel.cache_dict[cache_key] = splits[1]
|
||||
logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text))
|
||||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
||||
return replyPost
|
||||
|
||||
elif wechatmp_msg.msg_type == 'event':
|
||||
logger.info("[wechatmp] Event {} from {}".format(wechatmp_msg.content, wechatmp_msg.from_user_id))
|
||||
content = subscribe_msg()
|
||||
replyMsg = reply.TextMsg(wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content)
|
||||
return replyMsg.send()
|
||||
else:
|
||||
logger.info("暂且不处理")
|
||||
return "success"
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
return exc
|
||||
@@ -0,0 +1,63 @@
|
||||
from config import conf
|
||||
import hashlib
|
||||
import textwrap
|
||||
|
||||
MAX_UTF8_LEN = 2048
|
||||
|
||||
class WeChatAPIException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def verify_server(data):
|
||||
try:
|
||||
if len(data) == 0:
|
||||
return "None"
|
||||
signature = data.signature
|
||||
timestamp = data.timestamp
|
||||
nonce = data.nonce
|
||||
echostr = data.echostr
|
||||
token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写
|
||||
|
||||
data_list = [token, timestamp, nonce]
|
||||
data_list.sort()
|
||||
sha1 = hashlib.sha1()
|
||||
# map(sha1.update, data_list) #python2
|
||||
sha1.update("".join(data_list).encode('utf-8'))
|
||||
hashcode = sha1.hexdigest()
|
||||
print("handle/GET func: hashcode, signature: ", hashcode, signature)
|
||||
if hashcode == signature:
|
||||
return echostr
|
||||
else:
|
||||
return ""
|
||||
except Exception as Argument:
|
||||
return Argument
|
||||
|
||||
def subscribe_msg():
|
||||
trigger_prefix = conf().get('single_chat_prefix',[''])[0]
|
||||
msg = textwrap.dedent(f"""\
|
||||
感谢您的关注!
|
||||
这里是ChatGPT,可以自由对话。
|
||||
资源有限,回复较慢,请勿着急。
|
||||
支持通用表情输入。
|
||||
暂时不支持图片输入。
|
||||
支持图片输出,画字开头的问题将回复图片链接。
|
||||
支持角色扮演和文字冒险两种定制模式对话。
|
||||
输入'{trigger_prefix}#帮助' 查看详细指令。""")
|
||||
return msg
|
||||
|
||||
|
||||
def split_string_by_utf8_length(string, max_length, max_split=0):
|
||||
encoded = string.encode('utf-8')
|
||||
start, end = 0, 0
|
||||
result = []
|
||||
while end < len(encoded):
|
||||
if max_split > 0 and len(result) >= max_split:
|
||||
result.append(encoded[start:].decode('utf-8'))
|
||||
break
|
||||
end = start + max_length
|
||||
# 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
|
||||
while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
|
||||
end -= 1
|
||||
result.append(encoded[start:end].decode('utf-8'))
|
||||
start = end
|
||||
return result
|
||||
@@ -19,7 +19,10 @@ class WeChatMPMessage(ChatMessage):
|
||||
self.from_user_id = xmlData.find('FromUserName').text
|
||||
self.create_time = xmlData.find('CreateTime').text
|
||||
self.msg_type = xmlData.find('MsgType').text
|
||||
self.msg_id = xmlData.find('MsgId').text
|
||||
try:
|
||||
self.msg_id = xmlData.find('MsgId').text
|
||||
except:
|
||||
self.msg_id = self.from_user_id+self.create_time
|
||||
self.is_group = False
|
||||
|
||||
# reply to other_user_id
|
||||
@@ -36,7 +39,7 @@ class WeChatMPMessage(ChatMessage):
|
||||
self.pic_url = xmlData.find('PicUrl').text
|
||||
self.media_id = xmlData.find('MediaId').text
|
||||
elif self.msg_type == 'event':
|
||||
self.event = xmlData.find('Event').text
|
||||
self.content = xmlData.find('Event').text
|
||||
else: # video, shortvideo, location, link
|
||||
# not implemented
|
||||
pass
|
||||
@@ -1,19 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import web
|
||||
import time
|
||||
import math
|
||||
import hashlib
|
||||
import textwrap
|
||||
from channel.chat_channel import ChatChannel
|
||||
import channel.wechatmp.reply as reply
|
||||
import channel.wechatmp.receive as receive
|
||||
import json
|
||||
import requests
|
||||
import threading
|
||||
from common.singleton import singleton
|
||||
from common.log import logger
|
||||
from common.expired_dict import ExpiredDict
|
||||
from config import conf
|
||||
from bridge.reply import *
|
||||
from bridge.context import *
|
||||
from plugins import *
|
||||
import traceback
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechatmp.common import *
|
||||
|
||||
# If using SSL, uncomment the following lines, and modify the certificate path.
|
||||
# from cheroot.server import HTTPServer
|
||||
@@ -22,213 +20,110 @@ import traceback
|
||||
# certificate='/ssl/cert.pem',
|
||||
# private_key='/ssl/cert.key')
|
||||
|
||||
|
||||
# from concurrent.futures import ThreadPoolExecutor
|
||||
# thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||
|
||||
@singleton
|
||||
class WechatMPChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
|
||||
def __init__(self):
|
||||
def __init__(self, passive_reply = True):
|
||||
super().__init__()
|
||||
self.cache_dict = dict()
|
||||
self.query1 = dict()
|
||||
self.query2 = dict()
|
||||
self.query3 = dict()
|
||||
|
||||
self.passive_reply = passive_reply
|
||||
self.running = set()
|
||||
self.received_msgs = ExpiredDict(60*60*24)
|
||||
if self.passive_reply:
|
||||
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
|
||||
self.cache_dict = dict()
|
||||
self.query1 = dict()
|
||||
self.query2 = dict()
|
||||
self.query3 = dict()
|
||||
else:
|
||||
# TODO support image
|
||||
self.NOT_SUPPORT_REPLYTYPE = [ReplyType.IMAGE, ReplyType.VOICE]
|
||||
self.app_id = conf().get('wechatmp_app_id')
|
||||
self.app_secret = conf().get('wechatmp_app_secret')
|
||||
self.access_token = None
|
||||
self.access_token_expires_time = 0
|
||||
self.access_token_lock = threading.Lock()
|
||||
self.get_access_token()
|
||||
|
||||
def startup(self):
|
||||
urls = (
|
||||
'/wx', 'SubsribeAccountQuery',
|
||||
)
|
||||
app = web.application(urls, globals())
|
||||
if self.passive_reply:
|
||||
urls = ('/wx', 'channel.wechatmp.SubscribeAccount.Query')
|
||||
else:
|
||||
urls = ('/wx', 'channel.wechatmp.ServiceAccount.Query')
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get('wechatmp_port', 8080)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port))
|
||||
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
reply_cnt = math.ceil(len(reply.content) / 600)
|
||||
receiver = context["receiver"]
|
||||
self.cache_dict[receiver] = (reply_cnt, reply.content)
|
||||
logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply))
|
||||
def wechatmp_request(self, method, url, **kwargs):
|
||||
r = requests.request(method=method, url=url, **kwargs)
|
||||
r.raise_for_status()
|
||||
r.encoding = "utf-8"
|
||||
ret = r.json()
|
||||
if "errcode" in ret and ret["errcode"] != 0:
|
||||
raise WeChatAPIException("{}".format(ret))
|
||||
return ret
|
||||
|
||||
def get_access_token(self):
|
||||
|
||||
def verify_server():
|
||||
try:
|
||||
data = web.input()
|
||||
if len(data) == 0:
|
||||
return "None"
|
||||
signature = data.signature
|
||||
timestamp = data.timestamp
|
||||
nonce = data.nonce
|
||||
echostr = data.echostr
|
||||
token = conf().get('wechatmp_token') #请按照公众平台官网\基本配置中信息填写
|
||||
# return the access_token
|
||||
if self.access_token:
|
||||
if self.access_token_expires_time - time.time() > 60:
|
||||
return self.access_token
|
||||
|
||||
data_list = [token, timestamp, nonce]
|
||||
data_list.sort()
|
||||
sha1 = hashlib.sha1()
|
||||
# map(sha1.update, data_list) #python2
|
||||
sha1.update("".join(data_list).encode('utf-8'))
|
||||
hashcode = sha1.hexdigest()
|
||||
print("handle/GET func: hashcode, signature: ", hashcode, signature)
|
||||
if hashcode == signature:
|
||||
return echostr
|
||||
# Get new access_token
|
||||
# Do not request access_token in parallel! Only the last obtained is valid.
|
||||
if self.access_token_lock.acquire(blocking=False):
|
||||
# Wait for other threads that have previously obtained access_token to complete the request
|
||||
# This happens every 2 hours, so it doesn't affect the experience very much
|
||||
time.sleep(1)
|
||||
self.access_token = None
|
||||
url="https://api.weixin.qq.com/cgi-bin/token"
|
||||
params={
|
||||
"grant_type": "client_credential",
|
||||
"appid": self.app_id,
|
||||
"secret": self.app_secret
|
||||
}
|
||||
data = self.wechatmp_request(method='get', url=url, params=params)
|
||||
self.access_token = data['access_token']
|
||||
self.access_token_expires_time = int(time.time()) + data['expires_in']
|
||||
logger.info("[wechatmp] access_token: {}".format(self.access_token))
|
||||
self.access_token_lock.release()
|
||||
else:
|
||||
return ""
|
||||
except Exception as Argument:
|
||||
return Argument
|
||||
# Wait for token update
|
||||
while self.access_token_lock.locked():
|
||||
time.sleep(0.1)
|
||||
return self.access_token
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
if self.passive_reply:
|
||||
receiver = context["receiver"]
|
||||
self.cache_dict[receiver] = reply.content
|
||||
logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply))
|
||||
else:
|
||||
receiver = context["receiver"]
|
||||
reply_text = reply.content
|
||||
url="https://api.weixin.qq.com/cgi-bin/message/custom/send"
|
||||
params = {
|
||||
"access_token": self.get_access_token()
|
||||
}
|
||||
json_data = {
|
||||
"touser": receiver,
|
||||
"msgtype": "text",
|
||||
"text": {"content": reply_text}
|
||||
}
|
||||
self.wechatmp_request(method='post', url=url, params=params, data=json.dumps(json_data, ensure_ascii=False).encode('utf8'))
|
||||
logger.info("[send] Do send to {}: {}".format(receiver, reply_text))
|
||||
return
|
||||
|
||||
|
||||
# This class is instantiated once per query
|
||||
class SubsribeAccountQuery():
|
||||
|
||||
def GET(self):
|
||||
return verify_server()
|
||||
|
||||
def POST(self):
|
||||
channel_instance = WechatMPChannel()
|
||||
try:
|
||||
query_time = time.time()
|
||||
webData = web.data()
|
||||
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
|
||||
wechat_msg = receive.parse_xml(webData)
|
||||
if wechat_msg.msg_type == 'text':
|
||||
from_user = wechat_msg.from_user_id
|
||||
to_user = wechat_msg.to_user_id
|
||||
message = wechat_msg.content.decode("utf-8")
|
||||
message_id = wechat_msg.msg_id
|
||||
|
||||
logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
|
||||
|
||||
cache_key = from_user
|
||||
cache = channel_instance.cache_dict.get(cache_key)
|
||||
|
||||
reply_text = ""
|
||||
# New request
|
||||
if cache == None:
|
||||
# The first query begin, reset the cache
|
||||
context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechat_msg)
|
||||
logger.debug("[wechatmp] context: {} {}".format(context, wechat_msg))
|
||||
if context:
|
||||
# set private openai_api_key
|
||||
# if from_user is not changed in itchat, this can be placed at chat_channel
|
||||
user_data = conf().get_user_data(from_user)
|
||||
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
|
||||
channel_instance.cache_dict[cache_key] = (0, "")
|
||||
channel_instance.produce(context)
|
||||
else:
|
||||
trigger_prefix = conf().get('single_chat_prefix',[''])[0]
|
||||
if trigger_prefix:
|
||||
content = textwrap.dedent(f"""\
|
||||
请输入'{trigger_prefix}'接你想说的话跟我说话。
|
||||
例如:
|
||||
{trigger_prefix}你好,很高兴见到你。""")
|
||||
else:
|
||||
logger.error(f"[wechatmp] unknown error")
|
||||
content = textwrap.dedent("""\
|
||||
未知错误,请稍后再试""")
|
||||
replyMsg = reply.TextMsg(wechat_msg.from_user_id, wechat_msg.to_user_id, content)
|
||||
return replyMsg.send()
|
||||
channel_instance.query1[cache_key] = False
|
||||
channel_instance.query2[cache_key] = False
|
||||
channel_instance.query3[cache_key] = False
|
||||
# Request again
|
||||
elif cache[0] == 0 and channel_instance.query1.get(cache_key) == True and channel_instance.query2.get(cache_key) == True and channel_instance.query3.get(cache_key) == True:
|
||||
channel_instance.query1[cache_key] = False #To improve waiting experience, this can be set to True.
|
||||
channel_instance.query2[cache_key] = False #To improve waiting experience, this can be set to True.
|
||||
channel_instance.query3[cache_key] = False
|
||||
elif cache[0] >= 1:
|
||||
# Skip the waiting phase
|
||||
channel_instance.query1[cache_key] = True
|
||||
channel_instance.query2[cache_key] = True
|
||||
channel_instance.query3[cache_key] = True
|
||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id))
|
||||
if self.passive_reply:
|
||||
self.running.remove(session_id)
|
||||
|
||||
|
||||
cache = channel_instance.cache_dict.get(cache_key)
|
||||
if channel_instance.query1.get(cache_key) == False:
|
||||
# The first query from wechat official server
|
||||
logger.debug("[wechatmp] query1 {}".format(cache_key))
|
||||
channel_instance.query1[cache_key] = True
|
||||
cnt = 0
|
||||
while cache[0] == 0 and cnt < 45:
|
||||
cnt = cnt + 1
|
||||
time.sleep(0.1)
|
||||
cache = channel_instance.cache_dict.get(cache_key)
|
||||
if cnt == 45:
|
||||
# waiting for timeout (the POST query will be closed by wechat official server)
|
||||
time.sleep(5)
|
||||
# and do nothing
|
||||
return
|
||||
else:
|
||||
pass
|
||||
elif channel_instance.query2.get(cache_key) == False:
|
||||
# The second query from wechat official server
|
||||
logger.debug("[wechatmp] query2 {}".format(cache_key))
|
||||
channel_instance.query2[cache_key] = True
|
||||
cnt = 0
|
||||
while cache[0] == 0 and cnt < 45:
|
||||
cnt = cnt + 1
|
||||
time.sleep(0.1)
|
||||
cache = channel_instance.cache_dict.get(cache_key)
|
||||
if cnt == 45:
|
||||
# waiting for timeout (the POST query will be closed by wechat official server)
|
||||
time.sleep(5)
|
||||
# and do nothing
|
||||
return
|
||||
else:
|
||||
pass
|
||||
elif channel_instance.query3.get(cache_key) == False:
|
||||
# The third query from wechat official server
|
||||
logger.debug("[wechatmp] query3 {}".format(cache_key))
|
||||
channel_instance.query3[cache_key] = True
|
||||
cnt = 0
|
||||
while cache[0] == 0 and cnt < 40:
|
||||
cnt = cnt + 1
|
||||
time.sleep(0.1)
|
||||
cache = channel_instance.cache_dict.get(cache_key)
|
||||
if cnt == 40:
|
||||
# Have waiting for 3x5 seconds
|
||||
# return timeout message
|
||||
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
||||
logger.info("[wechatmp] Three queries has finished For {}: {}".format(from_user, message_id))
|
||||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
||||
return replyPost
|
||||
else:
|
||||
pass
|
||||
|
||||
if float(time.time()) - float(query_time) > 4.8:
|
||||
logger.info("[wechatmp] Timeout for {} {}".format(from_user, message_id))
|
||||
return
|
||||
|
||||
|
||||
if cache[0] > 1:
|
||||
reply_text = cache[1][:600] + "\n【未完待续,回复任意文字以继续】" #wechatmp auto_reply length limit
|
||||
channel_instance.cache_dict[cache_key] = (cache[0] - 1, cache[1][600:])
|
||||
elif cache[0] == 1:
|
||||
reply_text = cache[1]
|
||||
channel_instance.cache_dict.pop(cache_key)
|
||||
logger.info("[wechatmp] {}:{} Do send {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), reply_text))
|
||||
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
|
||||
return replyPost
|
||||
|
||||
elif wechat_msg.msg_type == 'event':
|
||||
logger.info("[wechatmp] Event {} from {}".format(wechat_msg.Event, wechat_msg.from_user_id))
|
||||
trigger_prefix = conf().get('single_chat_prefix',[''])[0]
|
||||
content = textwrap.dedent(f"""\
|
||||
感谢您的关注!
|
||||
这里是ChatGPT,可以自由对话。
|
||||
资源有限,回复较慢,请勿着急。
|
||||
支持通用表情输入。
|
||||
暂时不支持图片输入。
|
||||
支持图片输出,画字开头的问题将回复图片链接。
|
||||
支持角色扮演和文字冒险两种定制模式对话。
|
||||
输入'{trigger_prefix}#帮助' 查看详细指令。""")
|
||||
replyMsg = reply.TextMsg(wechat_msg.from_user_id, wechat_msg.to_user_id, content)
|
||||
return replyMsg.send()
|
||||
else:
|
||||
logger.info("暂且不处理")
|
||||
return "success"
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
return exc
|
||||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception))
|
||||
if self.passive_reply:
|
||||
assert session_id not in self.cache_dict
|
||||
self.running.remove(session_id)
|
||||
|
||||
|
||||
+12
-3
@@ -2,9 +2,13 @@ import logging
|
||||
import sys
|
||||
|
||||
|
||||
def _get_logger():
|
||||
log = logging.getLogger('log')
|
||||
log.setLevel(logging.INFO)
|
||||
def _reset_logger(log):
|
||||
for handler in log.handlers:
|
||||
handler.close()
|
||||
log.removeHandler(handler)
|
||||
del handler
|
||||
log.handlers.clear()
|
||||
log.propagate = False
|
||||
console_handle = logging.StreamHandler(sys.stdout)
|
||||
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'))
|
||||
@@ -13,6 +17,11 @@ def _get_logger():
|
||||
datefmt='%Y-%m-%d %H:%M:%S'))
|
||||
log.addHandler(file_handle)
|
||||
log.addHandler(console_handle)
|
||||
|
||||
def _get_logger():
|
||||
log = logging.getLogger('log')
|
||||
_reset_logger(log)
|
||||
log.setLevel(logging.INFO)
|
||||
return log
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
import time
|
||||
import pip
|
||||
from pip._internal import main as pipmain
|
||||
from common.log import logger,_reset_logger
|
||||
|
||||
def install(package):
|
||||
pipmain(['install', package])
|
||||
|
||||
def install_requirements(file):
|
||||
pipmain(['install', '-r', file, "--upgrade"])
|
||||
_reset_logger(logger)
|
||||
|
||||
def check_dulwich():
|
||||
needwait = False
|
||||
for i in range(2):
|
||||
if needwait:
|
||||
time.sleep(3)
|
||||
needwait = False
|
||||
try:
|
||||
import dulwich
|
||||
return
|
||||
except ImportError:
|
||||
try:
|
||||
install('dulwich')
|
||||
except:
|
||||
needwait = True
|
||||
try:
|
||||
import dulwich
|
||||
except ImportError:
|
||||
raise ImportError("Unable to import dulwich")
|
||||
+1
-1
@@ -12,7 +12,7 @@ class TmpDir(object):
|
||||
|
||||
def __init__(self):
|
||||
pathExists = os.path.exists(self.tmpFilePath)
|
||||
if not pathExists and conf().get('speech_recognition') == True:
|
||||
if not pathExists:
|
||||
os.makedirs(self.tmpFilePath)
|
||||
|
||||
def path(self):
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
"open_ai_api_key": "YOUR API KEY",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"proxy": "",
|
||||
"use_azure_chatgpt": false,
|
||||
"single_chat_prefix": ["bot", "@bot"],
|
||||
"single_chat_reply_prefix": "[bot] ",
|
||||
"group_chat_prefix": ["@bot"],
|
||||
|
||||
@@ -16,6 +16,7 @@ available_setting = {
|
||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"model": "gpt-3.5-turbo",
|
||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
||||
"azure_deployment_id": "", #azure 模型部署名称
|
||||
|
||||
# Bot触发配置
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
@@ -79,14 +80,16 @@ available_setting = {
|
||||
"wechaty_puppet_service_token": "", # wechaty的token
|
||||
|
||||
# wechatmp的配置
|
||||
"wechatmp_token": "", # 微信公众平台的Token
|
||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||
"wechatmp_token": "", # 微信公众平台的Token
|
||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||
"wechatmp_app_id": "", # 微信公众平台的appID,仅服务号需要
|
||||
"wechatmp_app_secret": "", # 微信公众平台的appsecret,仅服务号需要
|
||||
|
||||
# chatgpt指令自定义触发词
|
||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令,必须以#开头
|
||||
|
||||
# channel配置
|
||||
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp}
|
||||
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service}
|
||||
|
||||
"debug": False, # 是否开启debug模式,开启后会打印更多日志
|
||||
|
||||
|
||||
+41
-6
@@ -1,3 +1,11 @@
|
||||
**Table of Content**
|
||||
|
||||
- [插件化初衷](#插件化初衷)
|
||||
- [插件安装方法](#插件安装方法)
|
||||
- [插件化实现](#插件化实现)
|
||||
- [插件编写示例](#插件编写示例)
|
||||
- [插件设计建议](#插件设计建议)
|
||||
|
||||
## 插件化初衷
|
||||
|
||||
之前未插件化的代码耦合程度高,如果要定制一些个性化功能(如流量控制、接入`NovelAI`画图平台等),需要了解代码主体,避免影响到其他的功能。多个功能同时存在时,无法调整功能的优先级顺序,功能配置项也非常混乱。
|
||||
@@ -11,7 +19,23 @@
|
||||
- [x] 插件化能够自由开关和调整优先级。
|
||||
- [x] 每个插件可在插件文件夹内维护独立的配置文件,方便代码的测试和调试,可以在独立的仓库开发插件。
|
||||
|
||||
PS: 插件目前支持`itchat`和`wechaty`
|
||||
## 插件安装方法
|
||||
|
||||
在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。
|
||||
|
||||
- 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。
|
||||
|
||||
- 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。
|
||||
|
||||
安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。
|
||||
|
||||
- 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui
|
||||
|
||||
- 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git
|
||||
|
||||
在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。
|
||||
|
||||
安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。
|
||||
|
||||
## 插件化实现
|
||||
|
||||
@@ -26,7 +50,9 @@ PS: 插件目前支持`itchat`和`wechaty`
|
||||
1.收到消息 ---> 2.产生回复 ---> 3.包装回复 ---> 4.发送回复
|
||||
```
|
||||
|
||||
以下是它们的默认处理逻辑(太长不看,可跳过):
|
||||
以下是它们的默认处理逻辑(太长不看,可跳到[插件编写示例](#插件编写示例)):
|
||||
|
||||
**注意以下包含的代码是`v1.1.0`中的片段,已过时,只可用于理解事件,最新的默认代码逻辑请参考[chat_channel](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/chat_channel.py)**
|
||||
|
||||
#### 1. 收到消息
|
||||
|
||||
@@ -67,9 +93,9 @@ PS: 插件目前支持`itchat`和`wechaty`
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:
|
||||
reply = super().build_reply_content(context.content, context) #文字跟画图交付给chatgpt
|
||||
elif context.type == ContextType.VOICE: # 声音先进行语音转文字后,修改Context类型为文字后,再交付给chatgpt
|
||||
msg = context['msg']
|
||||
file_name = TmpDir().path() + context.content
|
||||
msg.download(file_name)
|
||||
cmsg = context['msg']
|
||||
cmsg.prepare()
|
||||
file_name = context.content
|
||||
reply = super().build_voice_to_text(file_name)
|
||||
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
|
||||
context.content = reply.content # 语音转文字后,将文字内容作为新的context
|
||||
@@ -154,7 +180,8 @@ PS: 插件目前支持`itchat`和`wechaty`
|
||||
|
||||
### 1. 创建插件
|
||||
|
||||
在`plugins`目录下创建一个插件文件夹`hello`。然后,在该文件夹中创建一个与文件夹同名的`.py`文件`hello.py`。
|
||||
在`plugins`目录下创建一个插件文件夹`hello`。然后,在该文件夹中创建``__init__.py``文件,在``__init__.py``中将其他编写的模块文件导入。在程序启动时,插件管理器会读取``__init__.py``的所有内容。
|
||||
|
||||
```
|
||||
plugins/
|
||||
└── hello
|
||||
@@ -162,6 +189,11 @@ plugins/
|
||||
└── hello.py
|
||||
```
|
||||
|
||||
``__init__.py``的内容:
|
||||
```
|
||||
from .hello import *
|
||||
```
|
||||
|
||||
### 2. 编写插件类
|
||||
|
||||
在`hello.py`文件中,创建插件类,它继承自`Plugin`。
|
||||
@@ -234,5 +266,8 @@ class Hello(Plugin):
|
||||
|
||||
- 尽情将你想要的个性化功能设计为插件。
|
||||
- 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。
|
||||
|
||||
在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。
|
||||
|
||||
- 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。
|
||||
- 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .banwords import *
|
||||
@@ -0,0 +1 @@
|
||||
from .bdunit import *
|
||||
@@ -0,0 +1 @@
|
||||
from .dungeon import *
|
||||
@@ -0,0 +1 @@
|
||||
from .finish import *
|
||||
@@ -6,7 +6,13 @@
|
||||
|
||||
将`config.json.template`复制为`config.json`,并修改其中`password`的值为口令。
|
||||
|
||||
在私聊中可使用`#auth`指令,输入口令进行管理员认证,详细指令请输入`#help`查看帮助文档:
|
||||
如果没有设置命令,在命令行日志中会打印出本次的临时口令,请注意观察,打印格式如下。
|
||||
|
||||
`#auth <口令>` - 管理员认证。
|
||||
`#help` - 输出帮助文档,是否是管理员和是否是在群聊中会影响帮助文档的输出内容。
|
||||
```
|
||||
[INFO][2023-04-06 23:53:47][godcmd.py:165] - [Godcmd] 因未设置口令,本次的临时口令为0971。
|
||||
```
|
||||
|
||||
在私聊中可使用`#auth`指令,输入口令进行管理员认证。更多详细指令请输入`#help`查看帮助文档:
|
||||
|
||||
`#auth <口令>` - 管理员认证,仅可在私聊时认证。
|
||||
`#help` - 输出帮助文档,**是否是管理员**和是否是在群聊中会影响帮助文档的输出内容。
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .godcmd import *
|
||||
+53
-18
@@ -2,6 +2,8 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import traceback
|
||||
from typing import Tuple
|
||||
from bridge.bridge import Bridge
|
||||
@@ -37,10 +39,10 @@ COMMANDS = {
|
||||
"alias": ["reset_openai_api_key"],
|
||||
"desc": "重置为默认的api_key",
|
||||
},
|
||||
# "id": {
|
||||
# "alias": ["id", "用户"],
|
||||
# "desc": "获取用户id", #目前无实际意义
|
||||
# },
|
||||
"id": {
|
||||
"alias": ["id", "用户"],
|
||||
"desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员
|
||||
},
|
||||
"reset": {
|
||||
"alias": ["reset", "重置会话"],
|
||||
"desc": "重置会话",
|
||||
@@ -92,6 +94,21 @@ ADMIN_COMMANDS = {
|
||||
"args": ["插件名"],
|
||||
"desc": "禁用指定插件",
|
||||
},
|
||||
"installp": {
|
||||
"alias": ["installp", "安装插件"],
|
||||
"args": ["仓库地址或插件名"],
|
||||
"desc": "安装指定插件",
|
||||
},
|
||||
"uninstallp": {
|
||||
"alias": ["uninstallp", "卸载插件"],
|
||||
"args": ["插件名"],
|
||||
"desc": "卸载指定插件",
|
||||
},
|
||||
"updatep": {
|
||||
"alias": ["updatep", "更新插件"],
|
||||
"args": ["插件名"],
|
||||
"desc": "更新指定插件",
|
||||
},
|
||||
"debug": {
|
||||
"alias": ["debug", "调试模式", "DEBUG"],
|
||||
"desc": "开启机器调试日志",
|
||||
@@ -103,7 +120,9 @@ def get_help_text(isadmin, isgroup):
|
||||
for cmd, info in COMMANDS.items():
|
||||
if cmd=="auth": #不提示认证指令
|
||||
continue
|
||||
alias=["#"+a for a in info['alias']]
|
||||
if cmd=="id" and conf().get("channel_type","wx") not in ["wxy","wechatmp"]:
|
||||
continue
|
||||
alias=["#"+a for a in info['alias'][:1]]
|
||||
help_text += f"{','.join(alias)} "
|
||||
if 'args' in info:
|
||||
args=[a for a in info['args']]
|
||||
@@ -122,7 +141,7 @@ def get_help_text(isadmin, isgroup):
|
||||
if ADMIN_COMMANDS and isadmin:
|
||||
help_text += "\n\n管理员指令:\n"
|
||||
for cmd, info in ADMIN_COMMANDS.items():
|
||||
alias=["#"+a for a in info['alias']]
|
||||
alias=["#"+a for a in info['alias'][:1]]
|
||||
help_text += f"{','.join(alias)} "
|
||||
if 'args' in info:
|
||||
args=[a for a in info['args']]
|
||||
@@ -146,7 +165,11 @@ class Godcmd(Plugin):
|
||||
else:
|
||||
with open(config_path,"r") as f:
|
||||
gconf=json.load(f)
|
||||
|
||||
if gconf["password"] == "":
|
||||
self.temp_password = "".join(random.sample(string.digits, 4))
|
||||
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。"%self.temp_password)
|
||||
else:
|
||||
self.temp_password = None
|
||||
custom_commands = conf().get("clear_memory_commands", [])
|
||||
for custom_command in custom_commands:
|
||||
if custom_command and custom_command.startswith("#"):
|
||||
@@ -155,7 +178,7 @@ class Godcmd(Plugin):
|
||||
COMMANDS["reset"]["alias"].append(custom_command)
|
||||
|
||||
self.password = gconf["password"]
|
||||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用
|
||||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
|
||||
self.isrunning = True # 机器人是否运行中
|
||||
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
@@ -176,7 +199,7 @@ class Godcmd(Plugin):
|
||||
channel = e_context['channel']
|
||||
user = e_context['context']['receiver']
|
||||
session_id = e_context['context']['session_id']
|
||||
isgroup = e_context['context']['isgroup']
|
||||
isgroup = e_context['context'].get("isgroup", False)
|
||||
bottype = Bridge().get_bot_type("chat")
|
||||
bot = Bridge().get_bot("chat")
|
||||
# 将命令和参数分割
|
||||
@@ -208,6 +231,8 @@ class Godcmd(Plugin):
|
||||
break
|
||||
if not ok:
|
||||
result = "插件不存在或未启用"
|
||||
elif cmd == "id":
|
||||
ok, result = True, user
|
||||
elif cmd == "set_openai_api_key":
|
||||
if len(args) == 1:
|
||||
user_data = conf().get_user_data(user)
|
||||
@@ -296,11 +321,7 @@ class Godcmd(Plugin):
|
||||
if len(args) != 1:
|
||||
ok, result = False, "请提供插件名"
|
||||
else:
|
||||
ok = PluginManager().enable_plugin(args[0])
|
||||
if ok:
|
||||
result = "插件已启用"
|
||||
else:
|
||||
result = "插件不存在"
|
||||
ok, result = PluginManager().enable_plugin(args[0])
|
||||
elif cmd == "disablep":
|
||||
if len(args) != 1:
|
||||
ok, result = False, "请提供插件名"
|
||||
@@ -310,7 +331,21 @@ class Godcmd(Plugin):
|
||||
result = "插件已禁用"
|
||||
else:
|
||||
result = "插件不存在"
|
||||
|
||||
elif cmd == "installp":
|
||||
if len(args) != 1:
|
||||
ok, result = False, "请提供插件名或.git结尾的仓库地址"
|
||||
else:
|
||||
ok, result = PluginManager().install_plugin(args[0])
|
||||
elif cmd == "uninstallp":
|
||||
if len(args) != 1:
|
||||
ok, result = False, "请提供插件名"
|
||||
else:
|
||||
ok, result = PluginManager().uninstall_plugin(args[0])
|
||||
elif cmd == "updatep":
|
||||
if len(args) != 1:
|
||||
ok, result = False, "请提供插件名"
|
||||
else:
|
||||
ok, result = PluginManager().update_plugin(args[0])
|
||||
logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user))
|
||||
else:
|
||||
ok, result = False, "需要管理员权限才能执行该指令"
|
||||
@@ -339,9 +374,6 @@ class Godcmd(Plugin):
|
||||
if isadmin:
|
||||
return False,"管理员账号无需认证"
|
||||
|
||||
if len(self.password) == 0:
|
||||
return False,"未设置口令,无法认证"
|
||||
|
||||
if len(args) != 1:
|
||||
return False,"请提供口令"
|
||||
|
||||
@@ -349,6 +381,9 @@ class Godcmd(Plugin):
|
||||
if password == self.password:
|
||||
self.admin_users.append(userid)
|
||||
return True,"认证成功"
|
||||
elif password == self.temp_password:
|
||||
self.admin_users.append(userid)
|
||||
return True,"认证成功,请尽快设置口令"
|
||||
else:
|
||||
return False,"认证失败"
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .hello import *
|
||||
+124
-15
@@ -1,8 +1,10 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from common.singleton import singleton
|
||||
from common.sorted_dict import SortedDict
|
||||
from .event import *
|
||||
@@ -17,6 +19,8 @@ class PluginManager:
|
||||
self.listening_plugins = {}
|
||||
self.instances = {}
|
||||
self.pconf = {}
|
||||
self.current_plugin_path = None
|
||||
self.loaded = {}
|
||||
|
||||
def register(self, name: str, desire_priority: int = 0, **kwargs):
|
||||
def wrapper(plugincls):
|
||||
@@ -24,13 +28,15 @@ class PluginManager:
|
||||
plugincls.priority = desire_priority
|
||||
plugincls.desc = kwargs.get('desc')
|
||||
plugincls.author = kwargs.get('author')
|
||||
plugincls.path = self.current_plugin_path
|
||||
plugincls.version = kwargs.get('version') if kwargs.get('version') != None else "1.0"
|
||||
plugincls.namecn = kwargs.get('namecn') if kwargs.get('namecn') != None else name
|
||||
plugincls.hidden = kwargs.get('hidden') if kwargs.get('hidden') != None else False
|
||||
plugincls.enabled = True
|
||||
if self.current_plugin_path == None:
|
||||
raise Exception("Plugin path not set")
|
||||
self.plugins[name.upper()] = plugincls
|
||||
logger.info("Plugin %s_v%s registered" % (name, plugincls.version))
|
||||
return plugincls
|
||||
logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
|
||||
return wrapper
|
||||
|
||||
def save_config(self):
|
||||
@@ -56,26 +62,38 @@ class PluginManager:
|
||||
def scan_plugins(self):
|
||||
logger.info("Scaning plugins ...")
|
||||
plugins_dir = "./plugins"
|
||||
raws = [self.plugins[name] for name in self.plugins]
|
||||
for plugin_name in os.listdir(plugins_dir):
|
||||
plugin_path = os.path.join(plugins_dir, plugin_name)
|
||||
if os.path.isdir(plugin_path):
|
||||
# 判断插件是否包含同名.py文件
|
||||
main_module_path = os.path.join(plugin_path, plugin_name+".py")
|
||||
# 判断插件是否包含同名__init__.py文件
|
||||
main_module_path = os.path.join(plugin_path,"__init__.py")
|
||||
if os.path.isfile(main_module_path):
|
||||
# 导入插件
|
||||
import_path = "plugins.{}.{}".format(plugin_name, plugin_name)
|
||||
import_path = "plugins.{}".format(plugin_name)
|
||||
try:
|
||||
main_module = importlib.import_module(import_path)
|
||||
self.current_plugin_path = plugin_path
|
||||
if plugin_path in self.loaded:
|
||||
if self.loaded[plugin_path] == None:
|
||||
logger.info("reload module %s" % plugin_name)
|
||||
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
|
||||
dependent_module_names = [name for name in sys.modules.keys() if name.startswith( import_path+ '.')]
|
||||
for name in dependent_module_names:
|
||||
logger.info("reload module %s" % name)
|
||||
importlib.reload(sys.modules[name])
|
||||
else:
|
||||
self.loaded[plugin_path] = importlib.import_module(import_path)
|
||||
self.current_plugin_path = None
|
||||
except Exception as e:
|
||||
logger.warn("Failed to import plugin %s: %s" % (plugin_name, e))
|
||||
logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
|
||||
continue
|
||||
pconf = self.pconf
|
||||
new_plugins = []
|
||||
news = [self.plugins[name] for name in self.plugins]
|
||||
new_plugins = list(set(news) - set(raws))
|
||||
modified = False
|
||||
for name, plugincls in self.plugins.items():
|
||||
rawname = plugincls.name
|
||||
if rawname not in pconf["plugins"]:
|
||||
new_plugins.append(plugincls)
|
||||
modified = True
|
||||
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
|
||||
pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority}
|
||||
@@ -92,14 +110,16 @@ class PluginManager:
|
||||
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
|
||||
|
||||
def activate_plugins(self): # 生成新开启的插件实例
|
||||
failed_plugins = []
|
||||
for name, plugincls in self.plugins.items():
|
||||
if plugincls.enabled:
|
||||
if name not in self.instances:
|
||||
try:
|
||||
instance = plugincls()
|
||||
except Exception as e:
|
||||
logger.warn("Failed to create init %s, diabled. %s" % (name, e))
|
||||
logger.warn("Failed to init %s, diabled. %s" % (name, e))
|
||||
self.disable_plugin(name)
|
||||
failed_plugins.append(name)
|
||||
continue
|
||||
self.instances[name] = instance
|
||||
for event in instance.handlers:
|
||||
@@ -107,6 +127,7 @@ class PluginManager:
|
||||
self.listening_plugins[event] = []
|
||||
self.listening_plugins[event].append(name)
|
||||
self.refresh_order()
|
||||
return failed_plugins
|
||||
|
||||
def reload_plugin(self, name:str):
|
||||
name = name.upper()
|
||||
@@ -156,15 +177,17 @@ class PluginManager:
|
||||
def enable_plugin(self, name:str):
|
||||
name = name.upper()
|
||||
if name not in self.plugins:
|
||||
return False
|
||||
return False, "插件不存在"
|
||||
if not self.plugins[name].enabled :
|
||||
self.plugins[name].enabled = True
|
||||
rawname = self.plugins[name].name
|
||||
self.pconf["plugins"][rawname]["enabled"] = True
|
||||
self.save_config()
|
||||
self.activate_plugins()
|
||||
return True
|
||||
return True
|
||||
failed_plugins = self.activate_plugins()
|
||||
if name in failed_plugins:
|
||||
return False, "插件开启失败"
|
||||
return True, "插件已开启"
|
||||
return True, "插件已开启"
|
||||
|
||||
def disable_plugin(self, name:str):
|
||||
name = name.upper()
|
||||
@@ -179,4 +202,90 @@ class PluginManager:
|
||||
return True
|
||||
|
||||
def list_plugins(self):
|
||||
return self.plugins
|
||||
return self.plugins
|
||||
|
||||
def install_plugin(self, repo:str):
|
||||
try:
|
||||
import common.package_manager as pkgmgr
|
||||
pkgmgr.check_dulwich()
|
||||
except Exception as e:
|
||||
logger.error("Failed to install plugin, {}".format(e))
|
||||
return False, "无法导入dulwich,安装插件失败"
|
||||
import re
|
||||
from dulwich import porcelain
|
||||
|
||||
logger.info("clone git repo: {}".format(repo))
|
||||
|
||||
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
|
||||
|
||||
if not match:
|
||||
try:
|
||||
with open("./plugins/source.json","r", encoding="utf-8") as f:
|
||||
source = json.load(f)
|
||||
if repo in source["repo"]:
|
||||
repo = source["repo"][repo]["url"]
|
||||
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
|
||||
if not match:
|
||||
return False, "安装插件失败,source中的仓库地址不合法"
|
||||
else:
|
||||
return False, "安装插件失败,仓库地址不合法"
|
||||
except Exception as e:
|
||||
logger.error("Failed to install plugin, {}".format(e))
|
||||
return False, "安装插件失败,请检查仓库地址是否正确"
|
||||
dirname = os.path.join("./plugins",match.group(4))
|
||||
try:
|
||||
repo = porcelain.clone(repo, dirname, checkout=True)
|
||||
if os.path.exists(os.path.join(dirname,"requirements.txt")):
|
||||
logger.info("detect requirements.txt,installing...")
|
||||
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt"))
|
||||
return True, "安装插件成功,请使用 #scanp 命令扫描插件或重启程序,开启前请检查插件是否需要配置"
|
||||
except Exception as e:
|
||||
logger.error("Failed to install plugin, {}".format(e))
|
||||
return False, "安装插件失败,"+str(e)
|
||||
|
||||
def update_plugin(self, name:str):
|
||||
try:
|
||||
import common.package_manager as pkgmgr
|
||||
pkgmgr.check_dulwich()
|
||||
except Exception as e:
|
||||
logger.error("Failed to install plugin, {}".format(e))
|
||||
return False, "无法导入dulwich,更新插件失败"
|
||||
from dulwich import porcelain
|
||||
name = name.upper()
|
||||
if name not in self.plugins:
|
||||
return False, "插件不存在"
|
||||
if name in ["HELLO","GODCMD","ROLE","TOOL","BDUNIT","BANWORDS","FINISH","DUNGEON"]:
|
||||
return False, "预置插件无法更新,请更新主程序仓库"
|
||||
dirname = self.plugins[name].path
|
||||
try:
|
||||
porcelain.pull(dirname, "origin")
|
||||
if os.path.exists(os.path.join(dirname,"requirements.txt")):
|
||||
logger.info("detect requirements.txt,installing...")
|
||||
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt"))
|
||||
return True, "更新插件成功,请重新运行程序"
|
||||
except Exception as e:
|
||||
logger.error("Failed to update plugin, {}".format(e))
|
||||
return False, "更新插件失败,"+str(e)
|
||||
|
||||
def uninstall_plugin(self, name:str):
|
||||
name = name.upper()
|
||||
if name not in self.plugins:
|
||||
return False, "插件不存在"
|
||||
if name in self.instances:
|
||||
self.disable_plugin(name)
|
||||
dirname = self.plugins[name].path
|
||||
try:
|
||||
import shutil
|
||||
shutil.rmtree(dirname)
|
||||
rawname = self.plugins[name].name
|
||||
for event in self.listening_plugins:
|
||||
if name in self.listening_plugins[event]:
|
||||
self.listening_plugins[event].remove(name)
|
||||
del self.plugins[name]
|
||||
del self.pconf["plugins"][rawname]
|
||||
self.loaded[dirname] = None
|
||||
self.save_config()
|
||||
return True, "卸载插件成功"
|
||||
except Exception as e:
|
||||
logger.error("Failed to uninstall plugin, {}".format(e))
|
||||
return False, "卸载插件失败,请手动删除文件夹完成卸载,"+str(e)
|
||||
@@ -0,0 +1 @@
|
||||
from .role import *
|
||||
@@ -1,71 +0,0 @@
|
||||
{
|
||||
"start":{
|
||||
"host" : "127.0.0.1",
|
||||
"port" : 7860,
|
||||
"use_https" : false
|
||||
},
|
||||
"defaults": {
|
||||
"params": {
|
||||
"sampler_name": "DPM++ 2M Karras",
|
||||
"steps": 20,
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"cfg_scale": 7,
|
||||
"prompt":"masterpiece, best quality",
|
||||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
|
||||
"enable_hr": false,
|
||||
"hr_scale": 2,
|
||||
"hr_upscaler": "Latent",
|
||||
"hr_second_pass_steps": 15,
|
||||
"denoising_strength": 0.7
|
||||
},
|
||||
"options": {
|
||||
"sd_model_checkpoint": "perfectWorld_v2Baked"
|
||||
}
|
||||
},
|
||||
"rules": [
|
||||
{
|
||||
"keywords": [
|
||||
"横版",
|
||||
"壁纸"
|
||||
],
|
||||
"params": {
|
||||
"width": 640,
|
||||
"height": 384
|
||||
},
|
||||
"desc": "分辨率会变成640x384"
|
||||
},
|
||||
{
|
||||
"keywords": [
|
||||
"竖版"
|
||||
],
|
||||
"params": {
|
||||
"width": 384,
|
||||
"height": 640
|
||||
}
|
||||
},
|
||||
{
|
||||
"keywords": [
|
||||
"高清"
|
||||
],
|
||||
"params": {
|
||||
"enable_hr": true,
|
||||
"hr_scale": 1.6
|
||||
},
|
||||
"desc": "出图分辨率长宽都会提高1.6倍"
|
||||
},
|
||||
{
|
||||
"keywords": [
|
||||
"二次元"
|
||||
],
|
||||
"params": {
|
||||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
|
||||
"prompt": "masterpiece, best quality"
|
||||
},
|
||||
"options": {
|
||||
"sd_model_checkpoint": "meinamix_meinaV8"
|
||||
},
|
||||
"desc": "使用二次元风格模型出图"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
## 插件描述
|
||||
|
||||
本插件用于将画图请求转发给stable diffusion webui。
|
||||
|
||||
## 环境要求
|
||||
|
||||
使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api"。
|
||||
|
||||
具体信息,请参考[文章](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)。
|
||||
|
||||
部署运行后,保证主机能够成功访问http://127.0.0.1:7860/docs
|
||||
|
||||
请**安装**本插件的依赖包```webuiapi```
|
||||
|
||||
```
|
||||
pip install webuiapi
|
||||
```
|
||||
|
||||
## 使用说明
|
||||
|
||||
请将`config.json.template`复制为`config.json`,并修改其中的参数和规则。
|
||||
|
||||
PS: 如果修改了webui的`host`和`port`,也需要在配置文件中更改启动参数, 更多启动参数参考:https://github.com/mix1009/sdwebuiapi/blob/a1cb4c6d2f39389d6e962f0e6436f4aa74cd752c/webuiapi/webuiapi.py#L114
|
||||
### 画图请求格式
|
||||
|
||||
用户的画图请求格式为:
|
||||
|
||||
```
|
||||
<画图触发词><关键词1> <关键词2> ... <关键词n>:<prompt>
|
||||
```
|
||||
|
||||
- 本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。
|
||||
- 规则的匹配顺序参考`config.json`中的顺序,每个关键词最多被匹配到1次,如果多个关键词触发了重复的参数,重复参数以最后一个关键词为准。
|
||||
- 关键词中包含`help`或`帮助`,会打印出帮助文档。
|
||||
|
||||
第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后
|
||||
|
||||
例如: 画横版 高清 二次元:cat
|
||||
|
||||
会触发三个关键词 "横版", "高清", "二次元",prompt为"cat"
|
||||
|
||||
若默认参数是:
|
||||
```json
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
"enable_hr": false,
|
||||
"prompt": "8k"
|
||||
"negative_prompt": "nsfw",
|
||||
"sd_model_checkpoint": "perfectWorld_v2Baked"
|
||||
```
|
||||
|
||||
"横版"触发的规则参数为:
|
||||
```json
|
||||
"width": 640,
|
||||
"height": 384,
|
||||
```
|
||||
|
||||
"高清"触发的规则参数为:
|
||||
```json
|
||||
"enable_hr": true,
|
||||
"hr_scale": 1.6,
|
||||
```
|
||||
|
||||
"二次元"触发的规则参数为:
|
||||
```json
|
||||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
|
||||
"steps": 20,
|
||||
"prompt": "masterpiece, best quality",
|
||||
|
||||
"sd_model_checkpoint": "meinamix_meinaV8"
|
||||
```
|
||||
|
||||
以上这些规则的参数会和默认参数合并。第一个":"后的内容cat会连接在prompt后。
|
||||
|
||||
得到最终参数为:
|
||||
```json
|
||||
"width": 640,
|
||||
"height": 384,
|
||||
"enable_hr": true,
|
||||
"hr_scale": 1.6,
|
||||
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
|
||||
"steps": 20,
|
||||
"prompt": "masterpiece, best quality, cat",
|
||||
|
||||
"sd_model_checkpoint": "meinamix_meinaV8"
|
||||
```
|
||||
|
||||
PS: 实际参数分为两部分:
|
||||
|
||||
- 一部分是`params`,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致
|
||||
- 另一部分是`options`,指sdwebui的设置,使用的模型和vae需写在里面。它和(http://127.0.0.1:7860/sdapi/v1/options )所返回的键一致。
|
||||
@@ -1,123 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import os
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf
|
||||
import plugins
|
||||
from plugins import *
|
||||
from common.log import logger
|
||||
import webuiapi
|
||||
import io
|
||||
|
||||
|
||||
@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent")
|
||||
class SDWebUI(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
self.rules = config["rules"]
|
||||
defaults = config["defaults"]
|
||||
self.default_params = defaults["params"]
|
||||
self.default_options = defaults["options"]
|
||||
self.start_args = config["start"]
|
||||
self.api = webuiapi.WebUIApi(**self.start_args)
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
logger.info("[SD] inited")
|
||||
except Exception as e:
|
||||
if isinstance(e, FileNotFoundError):
|
||||
logger.warn(f"[SD] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/sdwebui .")
|
||||
else:
|
||||
logger.warn("[SD] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/sdwebui .")
|
||||
raise e
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
|
||||
if e_context['context'].type != ContextType.IMAGE_CREATE:
|
||||
return
|
||||
channel = e_context['context'].channel
|
||||
if ReplyType.IMAGE in channel.NOT_SUPPORT_REPLYTYPE:
|
||||
return
|
||||
|
||||
logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content)
|
||||
|
||||
logger.info("[SD] image_query={}".format(e_context['context'].content))
|
||||
reply = Reply()
|
||||
try:
|
||||
content = e_context['context'].content[:]
|
||||
# 解析用户输入 如"横版 高清 二次元:cat"
|
||||
if ":" in content:
|
||||
keywords, prompt = content.split(":", 1)
|
||||
else:
|
||||
keywords = content
|
||||
prompt = ""
|
||||
|
||||
keywords = keywords.split()
|
||||
|
||||
if "help" in keywords or "帮助" in keywords:
|
||||
reply.type = ReplyType.INFO
|
||||
reply.content = self.get_help_text(verbose = True)
|
||||
else:
|
||||
rule_params = {}
|
||||
rule_options = {}
|
||||
for keyword in keywords:
|
||||
matched = False
|
||||
for rule in self.rules:
|
||||
if keyword in rule["keywords"]:
|
||||
for key in rule["params"]:
|
||||
rule_params[key] = rule["params"][key]
|
||||
if "options" in rule:
|
||||
for key in rule["options"]:
|
||||
rule_options[key] = rule["options"][key]
|
||||
matched = True
|
||||
break # 一个关键词只匹配一个规则
|
||||
if not matched:
|
||||
logger.warning("[SD] keyword not matched: %s" % keyword)
|
||||
|
||||
params = {**self.default_params, **rule_params}
|
||||
options = {**self.default_options, **rule_options}
|
||||
params["prompt"] = params.get("prompt", "")+f", {prompt}"
|
||||
if len(options) > 0:
|
||||
logger.info("[SD] cover options={}".format(options))
|
||||
self.api.set_options(options)
|
||||
logger.info("[SD] params={}".format(params))
|
||||
result = self.api.txt2img(
|
||||
**params
|
||||
)
|
||||
reply.type = ReplyType.IMAGE
|
||||
b_img = io.BytesIO()
|
||||
result.image.save(b_img, format="PNG")
|
||||
reply.content = b_img
|
||||
e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑
|
||||
except Exception as e:
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = "[SD] "+str(e)
|
||||
logger.error("[SD] exception: %s" % e)
|
||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
|
||||
finally:
|
||||
e_context['reply'] = reply
|
||||
|
||||
def get_help_text(self, verbose = False, **kwargs):
|
||||
if not conf().get('image_create_prefix'):
|
||||
return "画图功能未启用"
|
||||
else:
|
||||
trigger = conf()['image_create_prefix'][0]
|
||||
help_text = "利用stable-diffusion来画图。\n"
|
||||
if not verbose:
|
||||
return help_text
|
||||
|
||||
help_text += f"使用方法:\n使用\"{trigger}[关键词1] [关键词2]...:提示语\"的格式作画,如\"{trigger}横版 高清:cat\"\n"
|
||||
help_text += "目前可用关键词:\n"
|
||||
for rule in self.rules:
|
||||
keywords = [f"[{keyword}]" for keyword in rule['keywords']]
|
||||
help_text += f"{','.join(keywords)}"
|
||||
if "desc" in rule:
|
||||
help_text += f"-{rule['desc']}\n"
|
||||
else:
|
||||
help_text += "\n"
|
||||
return help_text
|
||||
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"repo": {
|
||||
"sdwebui": {
|
||||
"url": "https://github.com/lanvent/plugin_sdwebui.git",
|
||||
"desc": "利用stable-diffusion画图的插件"
|
||||
},
|
||||
"replicate": {
|
||||
"url": "https://github.com/lanvent/plugin_replicate.git",
|
||||
"desc": "利用replicate api画图的插件"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
from .tool import *
|
||||
@@ -8,6 +8,9 @@ pyttsx3>=2.90 # pytsx text to speech
|
||||
baidu_aip>=4.16.10 # baidu voice
|
||||
# azure-cognitiveservices-speech # azure voice
|
||||
|
||||
#install plugin
|
||||
dulwich
|
||||
|
||||
# wechaty
|
||||
wechaty>=0.10.7
|
||||
wechaty_puppet>=0.4.23
|
||||
@@ -16,9 +19,6 @@ pysilk_mod>=1.6.0 # needed by send voice
|
||||
# wechatmp
|
||||
web.py
|
||||
|
||||
# sdwebui plugin
|
||||
webuiapi>=0.6.2
|
||||
|
||||
# chatgpt-tool-hub plugin
|
||||
--extra-index-url https://pypi.python.org/simple
|
||||
chatgpt_tool_hub>=0.3.5
|
||||
chatgpt_tool_hub>=0.3.7
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
openai>=0.27.2
|
||||
openai==0.27.2
|
||||
HTMLParser>=0.0.2
|
||||
PyQRCode>=1.2.1
|
||||
qrcode>=7.4.2
|
||||
|
||||
Reference in New Issue
Block a user