mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-05-12 23:23:25 +08:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 94004b095b | |||
| f652d592bd | |||
| 186e18fe94 | |||
| 28eb67bc24 | |||
| 6c7e4aaf37 | |||
| 709a1317ef |
@@ -10,3 +10,4 @@ nohup.out
|
||||
tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
+22
-7
@@ -1,14 +1,12 @@
|
||||
|
||||
|
||||
from asyncio import CancelledError
|
||||
import queue
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.dequeue import Dequeue
|
||||
from channel.channel import Channel
|
||||
from bridge.reply import *
|
||||
from bridge.context import *
|
||||
@@ -245,8 +243,11 @@ class ChatChannel(Channel):
|
||||
session_id = context['session_id']
|
||||
with self.lock:
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = (queue.Queue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)))
|
||||
self.sessions[session_id][0].put(context)
|
||||
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1))]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
else:
|
||||
self.sessions[session_id][0].put(context)
|
||||
|
||||
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
||||
def consume(self):
|
||||
@@ -272,12 +273,26 @@ class ChatChannel(Channel):
|
||||
semaphore.release()
|
||||
time.sleep(0.1)
|
||||
|
||||
def cancel(self, session_id):
|
||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
||||
def cancel_session(self, session_id):
|
||||
with self.lock:
|
||||
if session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
self.sessions[session_id][0]=queue.Queue()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt>0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
def cancel_all_session(self):
|
||||
with self.lock:
|
||||
for session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt>0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
|
||||
from queue import Full, Queue
|
||||
from time import monotonic as time
|
||||
|
||||
# add implementation of putleft to Queue
|
||||
class Dequeue(Queue):
|
||||
def putleft(self, item, block=True, timeout=None):
|
||||
with self.not_full:
|
||||
if self.maxsize > 0:
|
||||
if not block:
|
||||
if self._qsize() >= self.maxsize:
|
||||
raise Full
|
||||
elif timeout is None:
|
||||
while self._qsize() >= self.maxsize:
|
||||
self.not_full.wait()
|
||||
elif timeout < 0:
|
||||
raise ValueError("'timeout' must be a non-negative number")
|
||||
else:
|
||||
endtime = time() + timeout
|
||||
while self._qsize() >= self.maxsize:
|
||||
remaining = endtime - time()
|
||||
if remaining <= 0.0:
|
||||
raise Full
|
||||
self.not_full.wait(remaining)
|
||||
self._putleft(item)
|
||||
self.unfinished_tasks += 1
|
||||
self.not_empty.notify()
|
||||
|
||||
def putleft_nowait(self, item):
|
||||
return self.putleft(item, block=False)
|
||||
|
||||
def _putleft(self, item):
|
||||
self.queue.appendleft(item)
|
||||
@@ -8,6 +8,10 @@ def _get_logger():
|
||||
console_handle = logging.StreamHandler(sys.stdout)
|
||||
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'))
|
||||
file_handle = logging.FileHandler('run.log', encoding='utf-8')
|
||||
file_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'))
|
||||
log.addHandler(file_handle)
|
||||
log.addHandler(console_handle)
|
||||
return log
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from common.log import logger
|
||||
|
||||
@@ -38,7 +39,6 @@ available_setting = {
|
||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
|
||||
"rate_limit_dalle": 50, # openai dalle的调用频率限制
|
||||
|
||||
|
||||
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
@@ -77,11 +77,12 @@ available_setting = {
|
||||
"wechaty_puppet_service_token": "", # wechaty的token
|
||||
|
||||
# chatgpt指令自定义触发词
|
||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令
|
||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令,必须以#开头
|
||||
|
||||
# channel配置
|
||||
"channel_type": "wx", # 通道类型,支持wx,wxy和terminal
|
||||
|
||||
"debug": False, # 是否开启debug模式,开启后会打印更多日志
|
||||
|
||||
}
|
||||
|
||||
@@ -139,6 +140,10 @@ def load_config():
|
||||
else:
|
||||
config[name] = value
|
||||
|
||||
if config.get("debug", False):
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.debug("[INIT] set log level to DEBUG")
|
||||
|
||||
logger.info("[INIT] load config: {}".format(config))
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Tuple
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import load_config
|
||||
from config import conf, load_config
|
||||
import plugins
|
||||
from plugins import *
|
||||
from common import const
|
||||
@@ -126,7 +126,14 @@ class Godcmd(Plugin):
|
||||
else:
|
||||
with open(config_path,"r") as f:
|
||||
gconf=json.load(f)
|
||||
|
||||
|
||||
custom_commands = conf().get("clear_memory_commands", [])
|
||||
for custom_command in custom_commands:
|
||||
if custom_command and custom_command.startswith("#"):
|
||||
custom_command = custom_command[1:]
|
||||
if custom_command and custom_command not in COMMANDS["reset"]["alias"]:
|
||||
COMMANDS["reset"]["alias"].append(custom_command)
|
||||
|
||||
self.password = gconf["password"]
|
||||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证 TODO: 用户名每次都会变,目前不可用
|
||||
self.isrunning = True # 机器人是否运行中
|
||||
@@ -146,6 +153,7 @@ class Godcmd(Plugin):
|
||||
logger.debug("[Godcmd] on_handle_context. content: %s" % content)
|
||||
if content.startswith("#"):
|
||||
# msg = e_context['context']['msg']
|
||||
channel = e_context['channel']
|
||||
user = e_context['context']['receiver']
|
||||
session_id = e_context['context']['session_id']
|
||||
isgroup = e_context['context']['isgroup']
|
||||
@@ -181,6 +189,7 @@ class Godcmd(Plugin):
|
||||
elif cmd == "reset":
|
||||
if bottype in (const.CHATGPT, const.OPEN_AI):
|
||||
bot.sessions.clear_session(session_id)
|
||||
channel.cancel_session(session_id)
|
||||
ok, result = True, "会话已重置"
|
||||
else:
|
||||
ok, result = False, "当前对话机器人不支持重置会话"
|
||||
@@ -202,6 +211,7 @@ class Godcmd(Plugin):
|
||||
ok, result = True, "配置已重载"
|
||||
elif cmd == "resetall":
|
||||
if bottype in (const.CHATGPT, const.OPEN_AI):
|
||||
channel.cancel_all_session()
|
||||
bot.sessions.clear_all_session()
|
||||
ok, result = True, "重置所有会话成功"
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user