mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-05-02 08:55:39 +08:00
feat: support plugins
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -7,3 +7,4 @@ config.json
|
|||||||
QR.png
|
QR.png
|
||||||
nohup.out
|
nohup.out
|
||||||
tmp
|
tmp
|
||||||
|
plugins.json
|
||||||
7
app.py
7
app.py
@@ -4,14 +4,17 @@ import config
|
|||||||
from channel import channel_factory
|
from channel import channel_factory
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
|
||||||
|
from plugins import *
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
try:
|
try:
|
||||||
# load config
|
# load config
|
||||||
config.load_config()
|
config.load_config()
|
||||||
|
|
||||||
# create channel
|
# create channel
|
||||||
channel = channel_factory.create_channel("wx")
|
channel_name='wx'
|
||||||
|
channel = channel_factory.create_channel(channel_name)
|
||||||
|
if channel_name=='wx':
|
||||||
|
PluginManager().load_plugins()
|
||||||
|
|
||||||
# startup channel
|
# startup channel
|
||||||
channel.startup()
|
channel.startup()
|
||||||
|
|||||||
@@ -60,12 +60,13 @@ class ChatGPTBot(Bot):
|
|||||||
ok, retstring = self.create_img(query, 0)
|
ok, retstring = self.create_img(query, 0)
|
||||||
reply = None
|
reply = None
|
||||||
if ok:
|
if ok:
|
||||||
reply = {'type': 'IMAGE', 'content': retstring}
|
reply = {'type': 'IMAGE_URL', 'content': retstring}
|
||||||
else:
|
else:
|
||||||
reply = {'type': 'ERROR', 'content': retstring}
|
reply = {'type': 'ERROR', 'content': retstring}
|
||||||
return reply
|
return reply
|
||||||
else:
|
else:
|
||||||
reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])}
|
reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])}
|
||||||
|
return reply
|
||||||
|
|
||||||
def reply_text(self, session, session_id, retry_count=0) -> dict:
|
def reply_text(self, session, session_id, retry_count=0) -> dict:
|
||||||
'''
|
'''
|
||||||
@@ -139,7 +140,11 @@ class ChatGPTBot(Bot):
|
|||||||
|
|
||||||
class SessionManager(object):
|
class SessionManager(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.sessions = {}
|
if conf().get('expires_in_seconds'):
|
||||||
|
sessions = ExpiredDict(conf().get('expires_in_seconds'))
|
||||||
|
else:
|
||||||
|
sessions = dict()
|
||||||
|
self.sessions = sessions
|
||||||
|
|
||||||
def build_session_query(self, query, session_id):
|
def build_session_query(self, query, session_id):
|
||||||
'''
|
'''
|
||||||
|
|||||||
@@ -12,9 +12,12 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
from common.log import logger
|
from common.log import logger
|
||||||
from common.tmp_dir import TmpDir
|
from common.tmp_dir import TmpDir
|
||||||
from config import conf
|
from config import conf
|
||||||
|
from plugins import *
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import io
|
import io
|
||||||
|
|
||||||
|
|
||||||
thread_pool = ThreadPoolExecutor(max_workers=8)
|
thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||||
|
|
||||||
|
|
||||||
@@ -49,8 +52,8 @@ class WechatChannel(Channel):
|
|||||||
|
|
||||||
# handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context
|
# handle_* 系列函数处理收到的消息后构造context,然后调用handle函数处理context
|
||||||
# context是一个字典,包含了消息的所有信息,包括以下key
|
# context是一个字典,包含了消息的所有信息,包括以下key
|
||||||
# type: 消息类型,包括TEXT、VOICE、CMD_IMAGE_CREATE
|
# type: 消息类型,包括TEXT、VOICE、IMAGE_CREATE
|
||||||
# content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是CMD_IMAGE_CREATE类型,content就是图片生成命令
|
# content: 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
|
||||||
# session_id: 会话id
|
# session_id: 会话id
|
||||||
# isgroup: 是否是群聊
|
# isgroup: 是否是群聊
|
||||||
# msg: 原始消息对象
|
# msg: 原始消息对象
|
||||||
@@ -88,7 +91,7 @@ class WechatChannel(Channel):
|
|||||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||||
if img_match_prefix:
|
if img_match_prefix:
|
||||||
content = content.replace(img_match_prefix, '', 1).strip()
|
content = content.replace(img_match_prefix, '', 1).strip()
|
||||||
context['type'] = 'CMD_IMAGE_CREATE'
|
context['type'] = 'IMAGE_CREATE'
|
||||||
else:
|
else:
|
||||||
context['type'] = 'TEXT'
|
context['type'] = 'TEXT'
|
||||||
|
|
||||||
@@ -121,7 +124,7 @@ class WechatChannel(Channel):
|
|||||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||||
if img_match_prefix:
|
if img_match_prefix:
|
||||||
content = content.replace(img_match_prefix, '', 1).strip()
|
content = content.replace(img_match_prefix, '', 1).strip()
|
||||||
context['type'] = 'CMD_IMAGE_CREATE'
|
context['type'] = 'IMAGE_CREATE'
|
||||||
else:
|
else:
|
||||||
context['type'] = 'TEXT'
|
context['type'] = 'TEXT'
|
||||||
context['content'] = content
|
context['content'] = content
|
||||||
@@ -136,8 +139,7 @@ class WechatChannel(Channel):
|
|||||||
|
|
||||||
thread_pool.submit(self.handle, context)
|
thread_pool.submit(self.handle, context)
|
||||||
|
|
||||||
# 统一的发送函数,根据reply的type字段发送不同类型的消息
|
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||||
|
|
||||||
def send(self, reply, receiver):
|
def send(self, reply, receiver):
|
||||||
if reply['type'] == 'TEXT':
|
if reply['type'] == 'TEXT':
|
||||||
itchat.send(reply['content'], toUserName=receiver)
|
itchat.send(reply['content'], toUserName=receiver)
|
||||||
@@ -163,54 +165,63 @@ class WechatChannel(Channel):
|
|||||||
itchat.send_image(image_storage, toUserName=receiver)
|
itchat.send_image(image_storage, toUserName=receiver)
|
||||||
logger.info('[WX] sendImage, receiver={}'.format(receiver))
|
logger.info('[WX] sendImage, receiver={}'.format(receiver))
|
||||||
|
|
||||||
# 处理消息
|
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
|
||||||
def handle(self, context):
|
def handle(self, context):
|
||||||
content = context['content']
|
reply = {}
|
||||||
reply = None
|
|
||||||
|
|
||||||
logger.debug('[WX] ready to handle context: {}'.format(context))
|
logger.debug('[WX] ready to handle context: {}'.format(context))
|
||||||
|
|
||||||
# reply的构建步骤
|
# reply的构建步骤
|
||||||
if context['type'] == 'TEXT' or context['type'] == 'CMD_IMAGE_CREATE':
|
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
|
||||||
reply = super().build_reply_content(content, context)
|
reply=e_context['reply']
|
||||||
elif context['type'] == 'VOICE':
|
if not e_context.is_pass():
|
||||||
msg = context['msg']
|
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context['type'], context['content']))
|
||||||
file_name = TmpDir().path() + msg['FileName']
|
if context['type'] == 'TEXT' or context['type'] == 'IMAGE_CREATE':
|
||||||
msg.download(file_name)
|
reply = super().build_reply_content(context['content'], context)
|
||||||
reply = super().build_voice_to_text(file_name)
|
elif context['type'] == 'VOICE':
|
||||||
if reply['type'] != 'ERROR' and reply['type'] != 'INFO':
|
msg = context['msg']
|
||||||
reply = super().build_reply_content(reply['content'], context)
|
file_name = TmpDir().path() + msg['FileName']
|
||||||
if reply['type'] == 'TEXT':
|
msg.download(file_name)
|
||||||
if conf().get('voice_reply_voice'):
|
reply = super().build_voice_to_text(file_name)
|
||||||
reply = super().build_text_to_voice(reply['content'])
|
if reply['type'] != 'ERROR' and reply['type'] != 'INFO':
|
||||||
else:
|
reply = super().build_reply_content(reply['content'], context)
|
||||||
logger.error('[WX] unknown context type: {}'.format(context['type']))
|
if reply['type'] == 'TEXT':
|
||||||
return
|
if conf().get('voice_reply_voice'):
|
||||||
|
reply = super().build_text_to_voice(reply['content'])
|
||||||
|
else:
|
||||||
|
logger.error('[WX] unknown context type: {}'.format(context['type']))
|
||||||
|
return
|
||||||
|
|
||||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
|
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
|
||||||
|
|
||||||
# reply的包装步骤
|
# reply的包装步骤
|
||||||
if reply:
|
if reply and reply['type']:
|
||||||
if reply['type'] == 'TEXT':
|
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
|
||||||
reply_text = reply['content']
|
reply=e_context['reply']
|
||||||
if context['isgroup']:
|
if not e_context.is_pass() and reply and reply['type']:
|
||||||
reply_text = '@' + \
|
if reply['type'] == 'TEXT':
|
||||||
context['msg']['ActualNickName'] + \
|
reply_text = reply['content']
|
||||||
' ' + reply_text.strip()
|
if context['isgroup']:
|
||||||
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
|
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
|
||||||
|
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
|
||||||
|
else:
|
||||||
|
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
|
||||||
|
reply['content'] = reply_text
|
||||||
|
elif reply['type'] == 'ERROR' or reply['type'] == 'INFO':
|
||||||
|
reply['content'] = reply['type']+": " + reply['content']
|
||||||
|
elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE' or reply['type'] == 'IMAGE':
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
|
logger.error('[WX] unknown reply type: {}'.format(reply['type']))
|
||||||
reply['content'] = reply_text
|
return
|
||||||
elif reply['type'] == 'ERROR' or reply['type'] == 'INFO':
|
|
||||||
reply['content'] = reply['type']+": " + reply['content']
|
# reply的发送步骤
|
||||||
elif reply['type'] == 'IMAGE_URL' or reply['type'] == 'VOICE':
|
if reply and reply['type']:
|
||||||
pass
|
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
|
||||||
else:
|
reply=e_context['reply']
|
||||||
logger.error(
|
if not e_context.is_pass() and reply and reply['type']:
|
||||||
'[WX] unknown reply type: {}'.format(reply['type']))
|
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
|
||||||
return
|
self.send(reply, context['receiver'])
|
||||||
if reply:
|
|
||||||
logger.debug('[WX] ready to send reply: {} to {}'.format(
|
|
||||||
reply, context['receiver']))
|
|
||||||
self.send(reply, context['receiver'])
|
|
||||||
|
|
||||||
|
|
||||||
def check_prefix(content, prefix_list):
|
def check_prefix(content, prefix_list):
|
||||||
|
|||||||
9
plugins/__init__.py
Normal file
9
plugins/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from .plugin_manager import PluginManager
|
||||||
|
from .event import *
|
||||||
|
from .plugin import *
|
||||||
|
|
||||||
|
instance = PluginManager()
|
||||||
|
|
||||||
|
register = instance.register
|
||||||
|
# load_plugins = instance.load_plugins
|
||||||
|
# emit_event = instance.emit_event
|
||||||
49
plugins/event.py
Normal file
49
plugins/event.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# encoding:utf-8
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class Event(Enum):
|
||||||
|
# ON_RECEIVE_MESSAGE = 1 # 收到消息
|
||||||
|
|
||||||
|
ON_HANDLE_CONTEXT = 2 # 处理消息前
|
||||||
|
"""
|
||||||
|
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 }
|
||||||
|
"""
|
||||||
|
|
||||||
|
ON_DECORATE_REPLY = 3 # 得到回复后准备装饰
|
||||||
|
"""
|
||||||
|
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
|
||||||
|
"""
|
||||||
|
|
||||||
|
ON_SEND_REPLY = 4 # 发送回复前
|
||||||
|
"""
|
||||||
|
e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
|
||||||
|
"""
|
||||||
|
|
||||||
|
# AFTER_SEND_REPLY = 5 # 发送回复后
|
||||||
|
|
||||||
|
|
||||||
|
class EventAction(Enum):
|
||||||
|
CONTINUE = 1 # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
|
||||||
|
BREAK = 2 # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
|
||||||
|
BREAK_PASS = 3 # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑
|
||||||
|
|
||||||
|
|
||||||
|
class EventContext:
|
||||||
|
def __init__(self, event, econtext=dict()):
|
||||||
|
self.event = event
|
||||||
|
self.econtext = econtext
|
||||||
|
self.action = EventAction.CONTINUE
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self.econtext[key]
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
self.econtext[key] = value
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
del self.econtext[key]
|
||||||
|
|
||||||
|
def is_pass(self):
|
||||||
|
return self.action == EventAction.BREAK_PASS
|
||||||
3
plugins/plugin.py
Normal file
3
plugins/plugin.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
class Plugin:
|
||||||
|
def __init__(self):
|
||||||
|
self.handlers = {}
|
||||||
89
plugins/plugin_manager.py
Normal file
89
plugins/plugin_manager.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# encoding:utf-8
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from common.singleton import singleton
|
||||||
|
from .event import *
|
||||||
|
from .plugin import *
|
||||||
|
from common.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
@singleton
|
||||||
|
class PluginManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.plugins = {}
|
||||||
|
self.listening_plugins = {}
|
||||||
|
self.instances = {}
|
||||||
|
|
||||||
|
def register(self, name: str, desc: str, version: str, author: str):
|
||||||
|
def wrapper(plugincls):
|
||||||
|
self.plugins[name] = plugincls
|
||||||
|
plugincls.name = name
|
||||||
|
plugincls.desc = desc
|
||||||
|
plugincls.version = version
|
||||||
|
plugincls.author = author
|
||||||
|
plugincls.enabled = True
|
||||||
|
logger.info("Plugin %s registered" % name)
|
||||||
|
return plugincls
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def save_config(self, pconf):
|
||||||
|
with open("plugins/plugins.json", "w", encoding="utf-8") as f:
|
||||||
|
json.dump(pconf, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
def load_config(self):
|
||||||
|
logger.info("Loading plugins config...")
|
||||||
|
plugins_dir = "plugins"
|
||||||
|
for plugin_name in os.listdir(plugins_dir):
|
||||||
|
plugin_path = os.path.join(plugins_dir, plugin_name)
|
||||||
|
if os.path.isdir(plugin_path):
|
||||||
|
# 判断插件是否包含main.py文件
|
||||||
|
main_module_path = os.path.join(plugin_path, "main.py")
|
||||||
|
if os.path.isfile(main_module_path):
|
||||||
|
# 导入插件的main
|
||||||
|
import_path = "{}.{}.main".format(plugins_dir, plugin_name)
|
||||||
|
main_module = importlib.import_module(import_path)
|
||||||
|
|
||||||
|
modified = False
|
||||||
|
if os.path.exists("plugins/plugins.json"):
|
||||||
|
with open("plugins/plugins.json", "r", encoding="utf-8") as f:
|
||||||
|
pconf = json.load(f)
|
||||||
|
else:
|
||||||
|
modified = True
|
||||||
|
pconf = {"plugins": []}
|
||||||
|
for name, plugincls in self.plugins.items():
|
||||||
|
if name not in [plugin["name"] for plugin in pconf["plugins"]]:
|
||||||
|
modified = True
|
||||||
|
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
|
||||||
|
pconf["plugins"].append({"name": name, "enabled": True})
|
||||||
|
if modified:
|
||||||
|
self.save_config(pconf)
|
||||||
|
return pconf
|
||||||
|
|
||||||
|
def load_plugins(self):
|
||||||
|
pconf = self.load_config()
|
||||||
|
|
||||||
|
for plugin in pconf["plugins"]:
|
||||||
|
name = plugin["name"]
|
||||||
|
enabled = plugin["enabled"]
|
||||||
|
self.plugins[name].enabled = enabled
|
||||||
|
|
||||||
|
for name, plugincls in self.plugins.items():
|
||||||
|
if plugincls.enabled:
|
||||||
|
if name not in self.instances:
|
||||||
|
instance = plugincls()
|
||||||
|
self.instances[name] = instance
|
||||||
|
for event in instance.handlers:
|
||||||
|
if event not in self.listening_plugins:
|
||||||
|
self.listening_plugins[event] = []
|
||||||
|
self.listening_plugins[event].append(name)
|
||||||
|
|
||||||
|
def emit_event(self, e_context: EventContext, *args, **kwargs):
|
||||||
|
if e_context.event in self.listening_plugins:
|
||||||
|
for name in self.listening_plugins[e_context.event]:
|
||||||
|
if e_context.action == EventAction.CONTINUE:
|
||||||
|
logger.debug("Plugin %s triggered by event %s" % (name,e_context.event))
|
||||||
|
instance = self.instances[name]
|
||||||
|
instance.handlers[e_context.event](e_context, *args, **kwargs)
|
||||||
|
return e_context
|
||||||
Reference in New Issue
Block a user