mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-03-19 21:38:18 +08:00
Merge branch 'plugins' into dev
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import requests
|
||||
from bot.bot import Bot
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
|
||||
# Baidu Unit对话接口 (可用, 但能力较弱)
|
||||
@@ -14,7 +15,8 @@ class BaiduUnitBot(Bot):
|
||||
headers = {'content-type': 'application/x-www-form-urlencoded'}
|
||||
response = requests.post(url, data=post_data.encode(), headers=headers)
|
||||
if response:
|
||||
return response.json()['result']['context']['SYS_PRESUMED_HIST'][1]
|
||||
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
|
||||
return reply
|
||||
|
||||
def get_token(self):
|
||||
access_key = 'YOUR_ACCESS_KEY'
|
||||
|
||||
@@ -3,8 +3,12 @@ Auto-replay chat robot abstract class
|
||||
"""
|
||||
|
||||
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
|
||||
|
||||
class Bot(object):
|
||||
def reply(self, query, context=None):
|
||||
def reply(self, query, context : Context =None) -> Reply:
|
||||
"""
|
||||
bot auto-reply content
|
||||
:param req: received message
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf, load_config
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
@@ -8,10 +10,6 @@ from common.expired_dict import ExpiredDict
|
||||
import openai
|
||||
import time
|
||||
|
||||
if conf().get('expires_in_seconds'):
|
||||
all_sessions = ExpiredDict(conf().get('expires_in_seconds'))
|
||||
else:
|
||||
all_sessions = dict()
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ChatGPTBot(Bot):
|
||||
@@ -20,6 +18,7 @@ class ChatGPTBot(Bot):
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
proxy = conf().get('proxy')
|
||||
self.sessions = SessionManager()
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
if conf().get('rate_limit_chatgpt'):
|
||||
@@ -29,21 +28,24 @@ class ChatGPTBot(Bot):
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if not context or not context.get('type') or context.get('type') == 'TEXT':
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
session_id = context.get('session_id') or context.get('from_user_id')
|
||||
|
||||
session_id = context['session_id']
|
||||
reply = None
|
||||
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
|
||||
if query in clear_memory_commands:
|
||||
Session.clear_session(session_id)
|
||||
return '记忆已清除'
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
Session.clear_all_session()
|
||||
return '所有人记忆已清除'
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
elif query == '#更新配置':
|
||||
load_config()
|
||||
return '配置已更新'
|
||||
|
||||
session = Session.build_session_query(query, session_id)
|
||||
reply = Reply(ReplyType.INFO, '配置已更新')
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.build_session_query(query, session_id)
|
||||
logger.debug("[OPEN_AI] session query={}".format(session))
|
||||
|
||||
# if context.get('stream'):
|
||||
@@ -52,14 +54,29 @@ class ChatGPTBot(Bot):
|
||||
|
||||
reply_content = self.reply_text(session, session_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"]))
|
||||
if reply_content["completion_tokens"] > 0:
|
||||
Session.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
return reply_content["content"]
|
||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
|
||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||
return self.create_img(query, 0)
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session, session_id, retry_count=0) ->dict:
|
||||
def reply_text(self, session, session_id, retry_count=0) -> dict:
|
||||
'''
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
@@ -80,8 +97,8 @@ class ChatGPTBot(Bot):
|
||||
presence_penalty=conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
)
|
||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
return {"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
return {"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response.choices[0]['message']['content']}
|
||||
except openai.error.RateLimitError as e:
|
||||
# rate limit exception
|
||||
@@ -96,15 +113,15 @@ class ChatGPTBot(Bot):
|
||||
# api connection exception
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] APIConnection failed")
|
||||
return {"completion_tokens": 0, "content":"我连接不到你的网络"}
|
||||
return {"completion_tokens": 0, "content": "我连接不到你的网络"}
|
||||
except openai.error.Timeout as e:
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] Timeout")
|
||||
return {"completion_tokens": 0, "content":"我没有收到你的消息"}
|
||||
return {"completion_tokens": 0, "content": "我没有收到你的消息"}
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
logger.exception(e)
|
||||
Session.clear_session(session_id)
|
||||
self.sessions.clear_session(session_id)
|
||||
return {"completion_tokens": 0, "content": "请再问我一次吧"}
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
@@ -119,7 +136,7 @@ class ChatGPTBot(Bot):
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return image_url
|
||||
return True, image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
@@ -127,14 +144,31 @@ class ChatGPTBot(Bot):
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.create_img(query, retry_count+1)
|
||||
else:
|
||||
return "请求太快啦,请休息一下再问我吧"
|
||||
return False, "提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return None
|
||||
return False, str(e)
|
||||
|
||||
class Session(object):
|
||||
@staticmethod
|
||||
def build_session_query(query, session_id):
|
||||
|
||||
class SessionManager(object):
|
||||
def __init__(self):
|
||||
if conf().get('expires_in_seconds'):
|
||||
sessions = ExpiredDict(conf().get('expires_in_seconds'))
|
||||
else:
|
||||
sessions = dict()
|
||||
self.sessions = sessions
|
||||
|
||||
def build_session(self, session_id, system_prompt=None):
|
||||
session = self.sessions.get(session_id, [])
|
||||
if len(session) == 0:
|
||||
if system_prompt is None:
|
||||
system_prompt = conf().get("character_desc", "")
|
||||
system_item = {'role': 'system', 'content': system_prompt}
|
||||
session.append(system_item)
|
||||
self.sessions[session_id] = session
|
||||
return session
|
||||
|
||||
def build_session_query(self, query, session_id):
|
||||
'''
|
||||
build query with conversation history
|
||||
e.g. [
|
||||
@@ -147,36 +181,28 @@ class Session(object):
|
||||
:param session_id: session id
|
||||
:return: query content with conversaction
|
||||
'''
|
||||
session = all_sessions.get(session_id, [])
|
||||
if len(session) == 0:
|
||||
system_prompt = conf().get("character_desc", "")
|
||||
system_item = {'role': 'system', 'content': system_prompt}
|
||||
session.append(system_item)
|
||||
all_sessions[session_id] = session
|
||||
session = self.build_session(session_id)
|
||||
user_item = {'role': 'user', 'content': query}
|
||||
session.append(user_item)
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
def save_session(answer, session_id, total_tokens):
|
||||
def save_session(self, answer, session_id, total_tokens):
|
||||
max_tokens = conf().get("conversation_max_tokens")
|
||||
if not max_tokens:
|
||||
# default 3000
|
||||
max_tokens = 1000
|
||||
max_tokens=int(max_tokens)
|
||||
max_tokens = int(max_tokens)
|
||||
|
||||
session = all_sessions.get(session_id)
|
||||
session = self.sessions.get(session_id)
|
||||
if session:
|
||||
# append conversation
|
||||
gpt_item = {'role': 'assistant', 'content': answer}
|
||||
session.append(gpt_item)
|
||||
|
||||
# discard exceed limit conversation
|
||||
Session.discard_exceed_conversation(session, max_tokens, total_tokens)
|
||||
|
||||
self.discard_exceed_conversation(session, max_tokens, total_tokens)
|
||||
|
||||
@staticmethod
|
||||
def discard_exceed_conversation(session, max_tokens, total_tokens):
|
||||
def discard_exceed_conversation(self, session, max_tokens, total_tokens):
|
||||
dec_tokens = int(total_tokens)
|
||||
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
|
||||
while dec_tokens > max_tokens:
|
||||
@@ -185,13 +211,11 @@ class Session(object):
|
||||
session.pop(1)
|
||||
session.pop(1)
|
||||
else:
|
||||
break
|
||||
break
|
||||
dec_tokens = dec_tokens - max_tokens
|
||||
|
||||
@staticmethod
|
||||
def clear_session(session_id):
|
||||
all_sessions[session_id] = []
|
||||
def clear_session(self, session_id):
|
||||
self.sessions[session_id] = []
|
||||
|
||||
@staticmethod
|
||||
def clear_all_session():
|
||||
all_sessions.clear()
|
||||
def clear_all_session(self):
|
||||
self.sessions.clear()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import openai
|
||||
@@ -18,29 +20,32 @@ class OpenAIBot(Bot):
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if not context or not context.get('type') or context.get('type') == 'TEXT':
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
from_user_id = context.get('from_user_id') or context.get('session_id')
|
||||
if query == '#清除记忆':
|
||||
Session.clear_session(from_user_id)
|
||||
return '记忆已清除'
|
||||
elif query == '#清除所有':
|
||||
Session.clear_all_session()
|
||||
return '所有人记忆已清除'
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
from_user_id = context['session_id']
|
||||
reply = None
|
||||
if query == '#清除记忆':
|
||||
Session.clear_session(from_user_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
Session.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
else:
|
||||
new_query = Session.build_session_query(query, from_user_id)
|
||||
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
||||
|
||||
new_query = Session.build_session_query(query, from_user_id)
|
||||
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
||||
|
||||
reply_content = self.reply_text(new_query, from_user_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
||||
if reply_content and query:
|
||||
Session.save_session(query, reply_content, from_user_id)
|
||||
return reply_content
|
||||
|
||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||
return self.create_img(query, 0)
|
||||
reply_content = self.reply_text(new_query, from_user_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
||||
if reply_content and query:
|
||||
Session.save_session(query, reply_content, from_user_id)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
return self.create_img(query, 0)
|
||||
|
||||
def reply_text(self, query, user_id, retry_count=0):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user