Compare commits

...

9 Commits

Author SHA1 Message Date
Jianglang 9a20c1cb02 Update README.md 2023-04-07 00:43:47 +08:00
Jianglang 176f77ba5b Update README.md 2023-04-07 00:35:06 +08:00
lanvent 484de6237b feat: terminal support plugins 2023-04-06 23:55:25 +08:00
lanvent 898aa30b1d godcmd: add temp passwd 2023-04-06 21:57:02 +08:00
lanvent 8b73a74609 fix: bug when reinstall plugin 2023-04-06 21:54:38 +08:00
lanvent 3c6d42b22e feat: add installp/uninstallp command 2023-04-06 21:54:38 +08:00
lanvent 40563c1e96 plugins: remove sdwebui 2023-04-06 21:54:37 +08:00
lanvent cb0c86ec1c fix: a typo in sdwebui 2023-04-06 21:25:07 +08:00
Jianglang 614f3b1ea4 Update README.md 2023-04-06 14:15:49 +08:00
26 changed files with 330 additions and 355 deletions
+10 -1
View File
@@ -11,4 +11,13 @@ tmp
plugins.json plugins.json
itchat.pkl itchat.pkl
*.log *.log
user_datas.pkl user_datas.pkl
plugins/**/
!plugins/bdunit
!plugins/dungeon
!plugins/finish
!plugins/godcmd
!plugins/tool
!plugins/banwords
!plugins/hello
!plugins/role
-4
View File
@@ -62,10 +62,6 @@
> 项目中使用的对话模型是 davinci,计费方式是约每 750 字 (包含请求和回复) 消耗 $0.02,图片生成是每张消耗 $0.016,账号创建有免费的 $18 额度 (更新3.25: 最新注册的已经无免费额度了),使用完可以更换邮箱重新注册。 > 项目中使用的对话模型是 davinci,计费方式是约每 750 字 (包含请求和回复) 消耗 $0.02,图片生成是每张消耗 $0.016,账号创建有免费的 $18 额度 (更新3.25: 最新注册的已经无免费额度了),使用完可以更换邮箱重新注册。
#### 1.1 ChapGPT service On Azure
一种替换以上的方法是使用Azure推出的[ChatGPT service](https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/)。它host在公有云Azure上,因此不需要VPN就可以直接访问。不过目前仍然处于preview阶段。新用户可以通过Try Azure for free来薅一段时间的羊毛
### 2.运行环境 ### 2.运行环境
支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python` 支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`
+5 -1
View File
@@ -27,12 +27,16 @@ def run():
# create channel # create channel
channel_name=conf().get('channel_type', 'wx') channel_name=conf().get('channel_type', 'wx')
if "--cmd" in sys.argv:
channel_name = 'terminal'
if channel_name == 'wxy': if channel_name == 'wxy':
os.environ['WECHATY_LOG']="warn" os.environ['WECHATY_LOG']="warn"
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001' # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
channel = channel_factory.create_channel(channel_name) channel = channel_factory.create_channel(channel_name)
if channel_name in ['wx','wxy','wechatmp']: if channel_name in ['wx','wxy','wechatmp','terminal']:
PluginManager().load_plugins() PluginManager().load_plugins()
# startup channel # startup channel
+3 -3
View File
@@ -51,7 +51,7 @@ class ChatChannel(Channel):
if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True): if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True):
logger.debug("[WX]self message skipped") logger.debug("[WX]self message skipped")
return None return None
if context["isgroup"]: if context.get("isgroup", False):
group_name = cmsg.other_user_nickname group_name = cmsg.other_user_nickname
group_id = cmsg.other_user_id group_id = cmsg.other_user_id
@@ -76,7 +76,7 @@ class ChatChannel(Channel):
logger.debug("[WX]reference query skipped") logger.debug("[WX]reference query skipped")
return None return None
if context["isgroup"]: # 群聊 if context.get("isgroup", False): # 群聊
# 校验关键字 # 校验关键字
match_prefix = check_prefix(content, conf().get('group_chat_prefix')) match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
match_contain = check_contain(content, conf().get('group_chat_keyword')) match_contain = check_contain(content, conf().get('group_chat_keyword'))
@@ -193,7 +193,7 @@ class ChatChannel(Channel):
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
reply = super().build_text_to_voice(reply.content) reply = super().build_text_to_voice(reply.content)
return self._decorate_reply(context, reply) return self._decorate_reply(context, reply)
if context['isgroup']: if context.get("isgroup", False):
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip() reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
else: else:
+62 -15
View File
@@ -1,31 +1,78 @@
from bridge.context import * from bridge.context import *
from channel.channel import Channel from bridge.reply import Reply, ReplyType
from channel.chat_channel import ChatChannel, check_prefix
from channel.chat_message import ChatMessage
import sys import sys
class TerminalChannel(Channel): from config import conf
from common.log import logger
class TerminalMessage(ChatMessage):
def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"):
self.msg_id = msg_id
self.ctype = ctype
self.content = content
self.from_user_id = from_user_id
self.to_user_id = to_user_id
self.other_user_id = other_user_id
class TerminalChannel(ChatChannel):
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
def send(self, reply: Reply, context: Context):
print("\nBot:")
if reply.type == ReplyType.IMAGE:
from PIL import Image
image_storage = reply.content
image_storage.seek(0)
img = Image.open(image_storage)
print("<IMAGE>")
img.show()
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
from PIL import Image
import requests,io
img_url = reply.content
pic_res = requests.get(img_url, stream=True)
image_storage = io.BytesIO()
for block in pic_res.iter_content(1024):
image_storage.write(block)
image_storage.seek(0)
img = Image.open(image_storage)
print(img_url)
img.show()
else:
print(reply.content)
print("\nUser:", end="")
sys.stdout.flush()
return
def startup(self): def startup(self):
context = Context() context = Context()
print("\nPlease input your question") logger.setLevel("WARN")
print("\nPlease input your question:\nUser:", end="")
sys.stdout.flush()
msg_id = 0
while True: while True:
try: try:
prompt = self.get_input("User:\n") prompt = self.get_input()
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nExiting...") print("\nExiting...")
sys.exit() sys.exit()
msg_id += 1
trigger_prefixs = conf().get("single_chat_prefix",[""])
if check_prefix(prompt, trigger_prefixs) is None:
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt))
if context:
self.produce(context)
else:
raise Exception("context is None")
context.type = ContextType.TEXT def get_input(self):
context['session_id'] = "User"
context.content = prompt
print("Bot:")
sys.stdout.flush()
res = super().build_reply_content(prompt, context).content
print(res)
def get_input(self, prompt):
""" """
Multi-line input function Multi-line input function
""" """
print(prompt, end="") sys.stdout.flush()
line = input() line = input()
return line return line
+1 -1
View File
@@ -41,7 +41,7 @@ class WechatMPChannel(ChatChannel):
urls = ( urls = (
'/wx', 'SubsribeAccountQuery', '/wx', 'SubsribeAccountQuery',
) )
app = web.application(urls, globals()) app = web.application(urls, globals(), autoreload=False)
port = conf().get('wechatmp_port', 8080) port = conf().get('wechatmp_port', 8080)
web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port)) web.httpserver.runsimple(app.wsgifunc(), ('0.0.0.0', port))
+12 -3
View File
@@ -2,9 +2,13 @@ import logging
import sys import sys
def _get_logger(): def _reset_logger(log):
log = logging.getLogger('log') for handler in log.handlers:
log.setLevel(logging.INFO) handler.close()
log.removeHandler(handler)
del handler
log.handlers.clear()
log.propagate = False
console_handle = logging.StreamHandler(sys.stdout) console_handle = logging.StreamHandler(sys.stdout)
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s', console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')) datefmt='%Y-%m-%d %H:%M:%S'))
@@ -13,6 +17,11 @@ def _get_logger():
datefmt='%Y-%m-%d %H:%M:%S')) datefmt='%Y-%m-%d %H:%M:%S'))
log.addHandler(file_handle) log.addHandler(file_handle)
log.addHandler(console_handle) log.addHandler(console_handle)
def _get_logger():
log = logging.getLogger('log')
_reset_logger(log)
log.setLevel(logging.INFO)
return log return log
+30
View File
@@ -0,0 +1,30 @@
import time
import pip
from pip._internal import main as pipmain
from common.log import logger,_reset_logger
def install(package):
pipmain(['install', package])
def install_requirements(file):
pipmain(['install', '-r', file, "--upgrade"])
_reset_logger(logger)
def check_dulwich():
needwait = False
for i in range(2):
if needwait:
time.sleep(3)
needwait = False
try:
import dulwich
return
except ImportError:
try:
install('dulwich')
except:
needwait = True
try:
import dulwich
except ImportError:
raise ImportError("Unable to import dulwich")
+36 -3
View File
@@ -1,3 +1,11 @@
**Table of Content**
- [插件化初衷](#插件化初衷)
- [插件安装方法](#插件化安装方法)
- [插件化实现](#插件化实现)
- [插件编写示例](#插件编写示例)
- [插件设计建议](#插件设计建议)
## 插件化初衷 ## 插件化初衷
之前未插件化的代码耦合程度高,如果要定制一些个性化功能(如流量控制、接入`NovelAI`画图平台等),需要了解代码主体,避免影响到其他的功能。多个功能同时存在时,无法调整功能的优先级顺序,功能配置项也非常混乱。 之前未插件化的代码耦合程度高,如果要定制一些个性化功能(如流量控制、接入`NovelAI`画图平台等),需要了解代码主体,避免影响到其他的功能。多个功能同时存在时,无法调整功能的优先级顺序,功能配置项也非常混乱。
@@ -11,7 +19,23 @@
- [x] 插件化能够自由开关和调整优先级。 - [x] 插件化能够自由开关和调整优先级。
- [x] 每个插件可在插件文件夹内维护独立的配置文件,方便代码的测试和调试,可以在独立的仓库开发插件。 - [x] 每个插件可在插件文件夹内维护独立的配置文件,方便代码的测试和调试,可以在独立的仓库开发插件。
PS: 插件目前支持`itchat``wechaty` ## 插件安装方法
在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。
- 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。
- 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件。
安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。
- 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui
- 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git
在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。
安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。
## 插件化实现 ## 插件化实现
@@ -26,7 +50,7 @@ PS: 插件目前支持`itchat`和`wechaty`
1.收到消息 ---> 2.产生回复 ---> 3.包装回复 ---> 4.发送回复 1.收到消息 ---> 2.产生回复 ---> 3.包装回复 ---> 4.发送回复
``` ```
以下是它们的默认处理逻辑(太长不看,可跳) 以下是它们的默认处理逻辑(太长不看,可跳到[插件编写示例](#插件编写示例))
#### 1. 收到消息 #### 1. 收到消息
@@ -154,7 +178,8 @@ PS: 插件目前支持`itchat`和`wechaty`
### 1. 创建插件 ### 1. 创建插件
`plugins`目录下创建一个插件文件夹`hello`。然后,在该文件夹中创建一个与文件夹同名的`.py`文件`hello.py` `plugins`目录下创建一个插件文件夹`hello`。然后,在该文件夹中创建``__init__.py``文件,在``__init__.py``中将其他编写的模块文件导入。在程序启动时,插件管理器会读取``__init__.py``的所有内容
``` ```
plugins/ plugins/
└── hello └── hello
@@ -162,6 +187,11 @@ plugins/
└── hello.py └── hello.py
``` ```
``__init__.py``的内容:
```
from .hello import *
```
### 2. 编写插件类 ### 2. 编写插件类
在`hello.py`文件中,创建插件类,它继承自`Plugin`。 在`hello.py`文件中,创建插件类,它继承自`Plugin`。
@@ -234,5 +264,8 @@ class Hello(Plugin):
- 尽情将你想要的个性化功能设计为插件。 - 尽情将你想要的个性化功能设计为插件。
- 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。 - 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。
在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。
- 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。 - 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。
- 默认优先级不要超过管理员插件`Godcmd`的优先级(999)`Godcmd`插件提供了配置管理、插件管理等功能。 - 默认优先级不要超过管理员插件`Godcmd`的优先级(999)`Godcmd`插件提供了配置管理、插件管理等功能。
+1
View File
@@ -0,0 +1 @@
from .banwords import *
+1
View File
@@ -0,0 +1 @@
from .bdunit import *
+1
View File
@@ -0,0 +1 @@
from .dungeon import *
+1
View File
@@ -0,0 +1 @@
from .finish import *
+9 -3
View File
@@ -6,7 +6,13 @@
`config.json.template`复制为`config.json`,并修改其中`password`的值为口令。 `config.json.template`复制为`config.json`,并修改其中`password`的值为口令。
在私聊中可使用`#auth`指令,输入口令进行管理员认证,详细指令请输入`#help`查看帮助文档: 如果没有设置命令,在命令行日志中会打印出本次的临时口令,请注意观察,打印格式如下。
`#auth <口令>` - 管理员认证。 ```
`#help` - 输出帮助文档,是否是管理员和是否是在群聊中会影响帮助文档的输出内容 [INFO][2023-04-06 23:53:47][godcmd.py:165] - [Godcmd] 因未设置口令,本次的临时口令为0971
```
在私聊中可使用`#auth`指令,输入口令进行管理员认证。更多详细指令请输入`#help`查看帮助文档:
`#auth <口令>` - 管理员认证,仅可在私聊时认证。
`#help` - 输出帮助文档,**是否是管理员**和是否是在群聊中会影响帮助文档的输出内容。
+1
View File
@@ -0,0 +1 @@
from .godcmd import *
+43 -18
View File
@@ -2,6 +2,8 @@
import json import json
import os import os
import random
import string
import traceback import traceback
from typing import Tuple from typing import Tuple
from bridge.bridge import Bridge from bridge.bridge import Bridge
@@ -37,10 +39,10 @@ COMMANDS = {
"alias": ["reset_openai_api_key"], "alias": ["reset_openai_api_key"],
"desc": "重置为默认的api_key", "desc": "重置为默认的api_key",
}, },
# "id": { "id": {
# "alias": ["id", "用户"], "alias": ["id", "用户"],
# "desc": "获取用户id", #目前无实际意义 "desc": "获取用户id", # wechaty和wechatmp的用户id不会变化,可用于绑定管理员
# }, },
"reset": { "reset": {
"alias": ["reset", "重置会话"], "alias": ["reset", "重置会话"],
"desc": "重置会话", "desc": "重置会话",
@@ -92,6 +94,16 @@ ADMIN_COMMANDS = {
"args": ["插件名"], "args": ["插件名"],
"desc": "禁用指定插件", "desc": "禁用指定插件",
}, },
"installp": {
"alias": ["installp", "安装插件"],
"args": ["仓库地址或插件名"],
"desc": "安装指定插件",
},
"uninstallp": {
"alias": ["uninstallp", "卸载插件"],
"args": ["插件名"],
"desc": "卸载指定插件",
},
"debug": { "debug": {
"alias": ["debug", "调试模式", "DEBUG"], "alias": ["debug", "调试模式", "DEBUG"],
"desc": "开启机器调试日志", "desc": "开启机器调试日志",
@@ -103,7 +115,9 @@ def get_help_text(isadmin, isgroup):
for cmd, info in COMMANDS.items(): for cmd, info in COMMANDS.items():
if cmd=="auth": #不提示认证指令 if cmd=="auth": #不提示认证指令
continue continue
alias=["#"+a for a in info['alias']] if cmd=="id" and conf().get("channel_type","wx") not in ["wxy","wechatmp"]:
continue
alias=["#"+a for a in info['alias'][:1]]
help_text += f"{','.join(alias)} " help_text += f"{','.join(alias)} "
if 'args' in info: if 'args' in info:
args=[a for a in info['args']] args=[a for a in info['args']]
@@ -122,7 +136,7 @@ def get_help_text(isadmin, isgroup):
if ADMIN_COMMANDS and isadmin: if ADMIN_COMMANDS and isadmin:
help_text += "\n\n管理员指令:\n" help_text += "\n\n管理员指令:\n"
for cmd, info in ADMIN_COMMANDS.items(): for cmd, info in ADMIN_COMMANDS.items():
alias=["#"+a for a in info['alias']] alias=["#"+a for a in info['alias'][:1]]
help_text += f"{','.join(alias)} " help_text += f"{','.join(alias)} "
if 'args' in info: if 'args' in info:
args=[a for a in info['args']] args=[a for a in info['args']]
@@ -146,7 +160,11 @@ class Godcmd(Plugin):
else: else:
with open(config_path,"r") as f: with open(config_path,"r") as f:
gconf=json.load(f) gconf=json.load(f)
if gconf["password"] == "":
self.temp_password = "".join(random.sample(string.digits, 4))
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s"%self.temp_password)
else:
self.temp_password = None
custom_commands = conf().get("clear_memory_commands", []) custom_commands = conf().get("clear_memory_commands", [])
for custom_command in custom_commands: for custom_command in custom_commands:
if custom_command and custom_command.startswith("#"): if custom_command and custom_command.startswith("#"):
@@ -155,7 +173,7 @@ class Godcmd(Plugin):
COMMANDS["reset"]["alias"].append(custom_command) COMMANDS["reset"]["alias"].append(custom_command)
self.password = gconf["password"] self.password = gconf["password"]
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用 self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
self.isrunning = True # 机器人是否运行中 self.isrunning = True # 机器人是否运行中
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
@@ -176,7 +194,7 @@ class Godcmd(Plugin):
channel = e_context['channel'] channel = e_context['channel']
user = e_context['context']['receiver'] user = e_context['context']['receiver']
session_id = e_context['context']['session_id'] session_id = e_context['context']['session_id']
isgroup = e_context['context']['isgroup'] isgroup = e_context['context'].get("isgroup", False)
bottype = Bridge().get_bot_type("chat") bottype = Bridge().get_bot_type("chat")
bot = Bridge().get_bot("chat") bot = Bridge().get_bot("chat")
# 将命令和参数分割 # 将命令和参数分割
@@ -208,6 +226,8 @@ class Godcmd(Plugin):
break break
if not ok: if not ok:
result = "插件不存在或未启用" result = "插件不存在或未启用"
elif cmd == "id":
ok, result = True, user
elif cmd == "set_openai_api_key": elif cmd == "set_openai_api_key":
if len(args) == 1: if len(args) == 1:
user_data = conf().get_user_data(user) user_data = conf().get_user_data(user)
@@ -296,11 +316,7 @@ class Godcmd(Plugin):
if len(args) != 1: if len(args) != 1:
ok, result = False, "请提供插件名" ok, result = False, "请提供插件名"
else: else:
ok = PluginManager().enable_plugin(args[0]) ok, result = PluginManager().enable_plugin(args[0])
if ok:
result = "插件已启用"
else:
result = "插件不存在"
elif cmd == "disablep": elif cmd == "disablep":
if len(args) != 1: if len(args) != 1:
ok, result = False, "请提供插件名" ok, result = False, "请提供插件名"
@@ -310,7 +326,16 @@ class Godcmd(Plugin):
result = "插件已禁用" result = "插件已禁用"
else: else:
result = "插件不存在" result = "插件不存在"
elif cmd == "installp":
if len(args) != 1:
ok, result = False, "请提供插件名或.git结尾的仓库地址"
else:
ok, result = PluginManager().install_plugin(args[0])
elif cmd == "uninstallp":
if len(args) != 1:
ok, result = False, "请提供插件名"
else:
ok, result = PluginManager().uninstall_plugin(args[0])
logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user)) logger.debug("[Godcmd] admin command: %s by %s" % (cmd, user))
else: else:
ok, result = False, "需要管理员权限才能执行该指令" ok, result = False, "需要管理员权限才能执行该指令"
@@ -339,9 +364,6 @@ class Godcmd(Plugin):
if isadmin: if isadmin:
return False,"管理员账号无需认证" return False,"管理员账号无需认证"
if len(self.password) == 0:
return False,"未设置口令,无法认证"
if len(args) != 1: if len(args) != 1:
return False,"请提供口令" return False,"请提供口令"
@@ -349,6 +371,9 @@ class Godcmd(Plugin):
if password == self.password: if password == self.password:
self.admin_users.append(userid) self.admin_users.append(userid)
return True,"认证成功" return True,"认证成功"
elif password == self.temp_password:
self.admin_users.append(userid)
return True,"认证成功,请尽快设置口令"
else: else:
return False,"认证失败" return False,"认证失败"
+1
View File
@@ -0,0 +1 @@
from .hello import *
+100 -15
View File
@@ -1,8 +1,10 @@
# encoding:utf-8 # encoding:utf-8
import importlib import importlib
import importlib.util
import json import json
import os import os
import sys
from common.singleton import singleton from common.singleton import singleton
from common.sorted_dict import SortedDict from common.sorted_dict import SortedDict
from .event import * from .event import *
@@ -17,6 +19,8 @@ class PluginManager:
self.listening_plugins = {} self.listening_plugins = {}
self.instances = {} self.instances = {}
self.pconf = {} self.pconf = {}
self.current_plugin_path = None
self.loaded = {}
def register(self, name: str, desire_priority: int = 0, **kwargs): def register(self, name: str, desire_priority: int = 0, **kwargs):
def wrapper(plugincls): def wrapper(plugincls):
@@ -24,13 +28,15 @@ class PluginManager:
plugincls.priority = desire_priority plugincls.priority = desire_priority
plugincls.desc = kwargs.get('desc') plugincls.desc = kwargs.get('desc')
plugincls.author = kwargs.get('author') plugincls.author = kwargs.get('author')
plugincls.path = self.current_plugin_path
plugincls.version = kwargs.get('version') if kwargs.get('version') != None else "1.0" plugincls.version = kwargs.get('version') if kwargs.get('version') != None else "1.0"
plugincls.namecn = kwargs.get('namecn') if kwargs.get('namecn') != None else name plugincls.namecn = kwargs.get('namecn') if kwargs.get('namecn') != None else name
plugincls.hidden = kwargs.get('hidden') if kwargs.get('hidden') != None else False plugincls.hidden = kwargs.get('hidden') if kwargs.get('hidden') != None else False
plugincls.enabled = True plugincls.enabled = True
if self.current_plugin_path == None:
raise Exception("Plugin path not set")
self.plugins[name.upper()] = plugincls self.plugins[name.upper()] = plugincls
logger.info("Plugin %s_v%s registered" % (name, plugincls.version)) logger.info("Plugin %s_v%s registered, path=%s" % (name, plugincls.version, plugincls.path))
return plugincls
return wrapper return wrapper
def save_config(self): def save_config(self):
@@ -56,26 +62,38 @@ class PluginManager:
def scan_plugins(self): def scan_plugins(self):
logger.info("Scaning plugins ...") logger.info("Scaning plugins ...")
plugins_dir = "./plugins" plugins_dir = "./plugins"
raws = [self.plugins[name] for name in self.plugins]
for plugin_name in os.listdir(plugins_dir): for plugin_name in os.listdir(plugins_dir):
plugin_path = os.path.join(plugins_dir, plugin_name) plugin_path = os.path.join(plugins_dir, plugin_name)
if os.path.isdir(plugin_path): if os.path.isdir(plugin_path):
# 判断插件是否包含同名.py文件 # 判断插件是否包含同名__init__.py文件
main_module_path = os.path.join(plugin_path, plugin_name+".py") main_module_path = os.path.join(plugin_path,"__init__.py")
if os.path.isfile(main_module_path): if os.path.isfile(main_module_path):
# 导入插件 # 导入插件
import_path = "plugins.{}.{}".format(plugin_name, plugin_name) import_path = "plugins.{}".format(plugin_name)
try: try:
main_module = importlib.import_module(import_path) self.current_plugin_path = plugin_path
if plugin_path in self.loaded:
if self.loaded[plugin_path] == None:
logger.info("reload module %s" % plugin_name)
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
dependent_module_names = [name for name in sys.modules.keys() if name.startswith( import_path+ '.')]
for name in dependent_module_names:
logger.info("reload module %s" % name)
importlib.reload(sys.modules[name])
else:
self.loaded[plugin_path] = importlib.import_module(import_path)
self.current_plugin_path = None
except Exception as e: except Exception as e:
logger.warn("Failed to import plugin %s: %s" % (plugin_name, e)) logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
continue continue
pconf = self.pconf pconf = self.pconf
new_plugins = [] news = [self.plugins[name] for name in self.plugins]
new_plugins = list(set(news) - set(raws))
modified = False modified = False
for name, plugincls in self.plugins.items(): for name, plugincls in self.plugins.items():
rawname = plugincls.name rawname = plugincls.name
if rawname not in pconf["plugins"]: if rawname not in pconf["plugins"]:
new_plugins.append(plugincls)
modified = True modified = True
logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name) logger.info("Plugin %s not found in pconfig, adding to pconfig..." % name)
pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority} pconf["plugins"][rawname] = {"enabled": plugincls.enabled, "priority": plugincls.priority}
@@ -92,14 +110,16 @@ class PluginManager:
self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True) self.listening_plugins[event].sort(key=lambda name: self.plugins[name].priority, reverse=True)
def activate_plugins(self): # 生成新开启的插件实例 def activate_plugins(self): # 生成新开启的插件实例
failed_plugins = []
for name, plugincls in self.plugins.items(): for name, plugincls in self.plugins.items():
if plugincls.enabled: if plugincls.enabled:
if name not in self.instances: if name not in self.instances:
try: try:
instance = plugincls() instance = plugincls()
except Exception as e: except Exception as e:
logger.warn("Failed to create init %s, diabled. %s" % (name, e)) logger.warn("Failed to init %s, diabled. %s" % (name, e))
self.disable_plugin(name) self.disable_plugin(name)
failed_plugins.append(name)
continue continue
self.instances[name] = instance self.instances[name] = instance
for event in instance.handlers: for event in instance.handlers:
@@ -107,6 +127,7 @@ class PluginManager:
self.listening_plugins[event] = [] self.listening_plugins[event] = []
self.listening_plugins[event].append(name) self.listening_plugins[event].append(name)
self.refresh_order() self.refresh_order()
return failed_plugins
def reload_plugin(self, name:str): def reload_plugin(self, name:str):
name = name.upper() name = name.upper()
@@ -156,15 +177,17 @@ class PluginManager:
def enable_plugin(self, name:str): def enable_plugin(self, name:str):
name = name.upper() name = name.upper()
if name not in self.plugins: if name not in self.plugins:
return False return False, "插件不存在"
if not self.plugins[name].enabled : if not self.plugins[name].enabled :
self.plugins[name].enabled = True self.plugins[name].enabled = True
rawname = self.plugins[name].name rawname = self.plugins[name].name
self.pconf["plugins"][rawname]["enabled"] = True self.pconf["plugins"][rawname]["enabled"] = True
self.save_config() self.save_config()
self.activate_plugins() failed_plugins = self.activate_plugins()
return True if name in failed_plugins:
return True return False, "插件开启失败"
return True, "插件已开启"
return True, "插件已开启"
def disable_plugin(self, name:str): def disable_plugin(self, name:str):
name = name.upper() name = name.upper()
@@ -179,4 +202,66 @@ class PluginManager:
return True return True
def list_plugins(self): def list_plugins(self):
return self.plugins return self.plugins
def install_plugin(self, repo:str):
try:
import common.package_manager as pkgmgr
pkgmgr.check_dulwich()
except Exception as e:
logger.error("Failed to install plugin, {}".format(e))
return False, "无法导入dulwich,安装插件失败"
import re
from dulwich import porcelain
logger.info("clone git repo: {}".format(repo))
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
if not match:
try:
with open("./plugins/source.json","r") as f:
source = json.load(f)
if repo in source["repo"]:
repo = source["repo"][repo]["url"]
match = re.match(r"^(https?:\/\/|git@)([^\/:]+)[\/:]([^\/:]+)\/(.+).git$", repo)
if not match:
return False, "安装插件失败,source中的仓库地址不合法"
else:
return False, "安装插件失败,仓库地址不合法"
except Exception as e:
logger.error("Failed to install plugin, {}".format(e))
return False, "安装插件失败,请检查仓库地址是否正确"
dirname = os.path.join("./plugins",match.group(4))
try:
repo = porcelain.clone(repo, dirname, checkout=True)
if os.path.exists(os.path.join(dirname,"requirements.txt")):
logger.info("detect requirements.txtinstalling...")
pkgmgr.install_requirements(os.path.join(dirname,"requirements.txt"))
return True, "安装插件成功,请使用#scanp命令扫描插件或重启程序"
except Exception as e:
logger.error("Failed to install plugin, {}".format(e))
return False, "安装插件失败,"+str(e)
def uninstall_plugin(self, name:str):
name = name.upper()
if name not in self.plugins:
return False, "插件不存在"
if name in self.instances:
self.disable_plugin(name)
dirname = self.plugins[name].path
try:
import shutil
shutil.rmtree(dirname)
rawname = self.plugins[name].name
for event in self.listening_plugins:
if name in self.listening_plugins[event]:
self.listening_plugins[event].remove(name)
del self.plugins[name]
del self.pconf["plugins"][rawname]
self.loaded[dirname] = None
self.save_config()
return True, "卸载插件成功"
except Exception as e:
logger.error("Failed to uninstall plugin, {}".format(e))
return False, "卸载插件失败,请手动删除文件夹完成卸载,"+str(e)
+1
View File
@@ -0,0 +1 @@
from .role import *
View File
-71
View File
@@ -1,71 +0,0 @@
{
"start":{
"host" : "127.0.0.1",
"port" : 7860,
"use_https" : false
},
"defaults": {
"params": {
"sampler_name": "DPM++ 2M Karras",
"steps": 20,
"width": 512,
"height": 512,
"cfg_scale": 7,
"prompt":"masterpiece, best quality",
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
"enable_hr": false,
"hr_scale": 2,
"hr_upscaler": "Latent",
"hr_second_pass_steps": 15,
"denoising_strength": 0.7
},
"options": {
"sd_model_checkpoint": "perfectWorld_v2Baked"
}
},
"rules": [
{
"keywords": [
"横版",
"壁纸"
],
"params": {
"width": 640,
"height": 384
},
"desc": "分辨率会变成640x384"
},
{
"keywords": [
"竖版"
],
"params": {
"width": 384,
"height": 640
}
},
{
"keywords": [
"高清"
],
"params": {
"enable_hr": true,
"hr_scale": 1.6
},
"desc": "出图分辨率长宽都会提高1.6倍"
},
{
"keywords": [
"二次元"
],
"params": {
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
"prompt": "masterpiece, best quality"
},
"options": {
"sd_model_checkpoint": "meinamix_meinaV8"
},
"desc": "使用二次元风格模型出图"
}
]
}
-91
View File
@@ -1,91 +0,0 @@
## 插件描述
本插件用于将画图请求转发给stable diffusion webui。
## 环境要求
使用前先安装stable diffusion webui,并在它的启动参数中添加 "--api"。
具体信息,请参考[文章](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)。
部署运行后,保证主机能够成功访问http://127.0.0.1:7860/docs
请**安装**本插件的依赖包```webuiapi```
```
pip install webuiapi
```
## 使用说明
请将`config.json.template`复制为`config.json`,并修改其中的参数和规则。
PS: 如果修改了webui的`host`和`port`,也需要在配置文件中更改启动参数, 更多启动参数参考:https://github.com/mix1009/sdwebuiapi/blob/a1cb4c6d2f39389d6e962f0e6436f4aa74cd752c/webuiapi/webuiapi.py#L114
### 画图请求格式
用户的画图请求格式为:
```
<画图触发词><关键词1> <关键词2> ... <关键词n>:<prompt>
```
- 本插件会对画图触发词后的关键词进行逐个匹配,如果触发了规则中的关键词,则会在画图请求中重载对应的参数。
- 规则的匹配顺序参考`config.json`中的顺序,每个关键词最多被匹配到1次,如果多个关键词触发了重复的参数,重复参数以最后一个关键词为准。
- 关键词中包含`help`或`帮助`,会打印出帮助文档。
第一个"**:**"号之后的内容会作为附加的**prompt**,接在最终的prompt后
例如: 画横版 高清 二次元:cat
会触发三个关键词 "横版", "高清", "二次元"prompt为"cat"
若默认参数是:
```json
"width": 512,
"height": 512,
"enable_hr": false,
"prompt": "8k"
"negative_prompt": "nsfw",
"sd_model_checkpoint": "perfectWorld_v2Baked"
```
"横版"触发的规则参数为:
```json
"width": 640,
"height": 384,
```
"高清"触发的规则参数为:
```json
"enable_hr": true,
"hr_scale": 1.6,
```
"二次元"触发的规则参数为:
```json
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
"steps": 20,
"prompt": "masterpiece, best quality",
"sd_model_checkpoint": "meinamix_meinaV8"
```
以上这些规则的参数会和默认参数合并。第一个":"后的内容cat会连接在prompt后。
得到最终参数为:
```json
"width": 640,
"height": 384,
"enable_hr": true,
"hr_scale": 1.6,
"negative_prompt": "(low quality, worst quality:1.4),(bad_prompt:0.8), (monochrome:1.1), (greyscale)",
"steps": 20,
"prompt": "masterpiece, best quality, cat",
"sd_model_checkpoint": "meinamix_meinaV8"
```
PS: 实际参数分为两部分:
- 一部分是`params`,为画画的参数;参数名**必须**与webuiapi包中[txt2img api](https://github.com/mix1009/sdwebuiapi/blob/fb2054e149c0a4e25125c0cd7e7dca06bda839d4/webuiapi/webuiapi.py#L163)的参数名一致
- 另一部分是`options`,指sdwebui的设置,使用的模型和vae需写在里面。它和(http://127.0.0.1:7860/sdapi/v1/options )所返回的键一致。
-123
View File
@@ -1,123 +0,0 @@
# encoding:utf-8
import json
import os
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from config import conf
import plugins
from plugins import *
from common.log import logger
import webuiapi
import io
@plugins.register(name="sdwebui", desc="利用stable-diffusion webui来画图", version="2.0", author="lanvent")
class SDWebUI(Plugin):
def __init__(self):
super().__init__()
curdir = os.path.dirname(__file__)
config_path = os.path.join(curdir, "config.json")
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
self.rules = config["rules"]
defaults = config["defaults"]
self.default_params = defaults["params"]
self.default_options = defaults["options"]
self.start_args = config["start"]
self.api = webuiapi.WebUIApi(**self.start_args)
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
logger.info("[SD] inited")
except Exception as e:
if isinstance(e, FileNotFoundError):
logger.warn(f"[SD] init failed, {config_path} not found, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/sdwebui .")
else:
logger.warn("[SD] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/sdwebui .")
raise e
def on_handle_context(self, e_context: EventContext):
if e_context['context'].type != ContextType.IMAGE_CREATE:
return
channel = e_context['context'].channel
if ReplyType.IMAGE in channel.NOT_SUPPORT_REPLYTYPE:
return
logger.debug("[SD] on_handle_context. content: %s" %e_context['context'].content)
logger.info("[SD] image_query={}".format(e_context['context'].content))
reply = Reply()
try:
content = e_context['context'].content[:]
# 解析用户输入 如"横版 高清 二次元:cat"
if ":" in content:
keywords, prompt = content.split(":", 1)
else:
keywords = content
prompt = ""
keywords = keywords.split()
if "help" in keywords or "帮助" in keywords:
reply.type = ReplyType.INFO
reply.content = self.get_help_text(verbose = True)
else:
rule_params = {}
rule_options = {}
for keyword in keywords:
matched = False
for rule in self.rules:
if keyword in rule["keywords"]:
for key in rule["params"]:
rule_params[key] = rule["params"][key]
if "options" in rule:
for key in rule["options"]:
rule_options[key] = rule["options"][key]
matched = True
break # 一个关键词只匹配一个规则
if not matched:
logger.warning("[SD] keyword not matched: %s" % keyword)
params = {**self.default_params, **rule_params}
options = {**self.default_options, **rule_options}
params["prompt"] = params.get("prompt", "")+f", {prompt}"
if len(options) > 0:
logger.info("[SD] cover options={}".format(options))
self.api.set_options(options)
logger.info("[SD] params={}".format(params))
result = self.api.txt2img(
**params
)
reply.type = ReplyType.IMAGE
b_img = io.BytesIO()
result.image.save(b_img, format="PNG")
reply.content = b_img
e_context.action = EventAction.BREAK_PASS # 事件结束后,跳过处理context的默认逻辑
except Exception as e:
reply.type = ReplyType.ERROR
reply.content = "[SD] "+str(e)
logger.error("[SD] exception: %s" % e)
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
finally:
e_context['reply'] = reply
def get_help_text(self, verbose = False, **kwargs):
if not conf().get('image_create_prefix'):
return "画图功能未启用"
else:
trigger = conf()['image_create_prefix'][0]
help_text = "利用stable-diffusion来画图。\n"
if not verbose:
return help_text
help_text += f"使用方法:\n使用\"{trigger}[关键词1] [关键词2]...:提示语\"的格式作画,如\"{trigger}横版 高清:cat\"\n"
help_text += "目前可用关键词:\n"
for rule in self.rules:
keywords = [f"[{keyword}]" for keyword in rule['keywords']]
help_text += f"{','.join(keywords)}"
if "desc" in rule:
help_text += f"-{rule['desc']}\n"
else:
help_text += "\n"
return help_text
+8
View File
@@ -0,0 +1,8 @@
{
"repo": {
"sdwebui": {
"url": "https://github.com/lanvent/plugin_sdwebui.git",
"desc": "利用stable-diffusion画图的插件"
}
}
}
+1
View File
@@ -0,0 +1 @@
from .tool import *
+3 -3
View File
@@ -8,6 +8,9 @@ pyttsx3>=2.90 # pytsx text to speech
baidu_aip>=4.16.10 # baidu voice baidu_aip>=4.16.10 # baidu voice
# azure-cognitiveservices-speech # azure voice # azure-cognitiveservices-speech # azure voice
#install plugin
dulwich
# wechaty # wechaty
wechaty>=0.10.7 wechaty>=0.10.7
wechaty_puppet>=0.4.23 wechaty_puppet>=0.4.23
@@ -16,9 +19,6 @@ pysilk_mod>=1.6.0 # needed by send voice
# wechatmp # wechatmp
web.py web.py
# sdwebui plugin
webuiapi>=0.6.2
# chatgpt-tool-hub plugin # chatgpt-tool-hub plugin
--extra-index-url https://pypi.python.org/simple --extra-index-url https://pypi.python.org/simple
chatgpt_tool_hub>=0.3.5 chatgpt_tool_hub>=0.3.5