mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-03-19 21:38:18 +08:00
decouple message processing process
This commit is contained in:
@@ -19,19 +19,24 @@ class ChatGPTBot(Bot):
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if not context or not context.get('type') or context.get('type') == 'TEXT':
|
||||
if context['type']=='TEXT':
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
session_id = context.get('session_id') or context.get('from_user_id')
|
||||
session_id = context['session_id']
|
||||
reply=None
|
||||
if query == '#清除记忆':
|
||||
self.sessions.clear_session(session_id)
|
||||
return '记忆已清除'
|
||||
reply={'type':'INFO', 'content':'记忆已清除'}
|
||||
elif query == '#清除所有':
|
||||
self.sessions.clear_all_session()
|
||||
return '所有人记忆已清除'
|
||||
reply={'type':'INFO', 'content':'所有人记忆已清除'}
|
||||
elif query == '#更新配置':
|
||||
load_config()
|
||||
return '配置已更新'
|
||||
|
||||
reply={'type':'INFO', 'content':'配置已更新'}
|
||||
elif query == '#DEBUG':
|
||||
logger.setLevel('DEBUG')
|
||||
reply={'type':'INFO', 'content':'DEBUG模式已开启'}
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.build_session_query(query, session_id)
|
||||
logger.debug("[OPEN_AI] session query={}".format(session))
|
||||
|
||||
@@ -41,12 +46,26 @@ class ChatGPTBot(Bot):
|
||||
|
||||
reply_content = self.reply_text(session, session_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"]))
|
||||
if reply_content["completion_tokens"] > 0:
|
||||
if reply_content['completion_tokens']==0 and len(reply_content['content'])>0:
|
||||
reply={'type':'ERROR', 'content':reply_content['content']}
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
return reply_content["content"]
|
||||
reply={'type':'TEXT', 'content':reply_content["content"]}
|
||||
else:
|
||||
reply={'type':'ERROR', 'content':reply_content['content']}
|
||||
logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
|
||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||
return self.create_img(query, 0)
|
||||
elif context['type'] == 'IMAGE_CREATE':
|
||||
ok, retstring=self.create_img(query, 0)
|
||||
reply=None
|
||||
if ok:
|
||||
reply = {'type':'IMAGE', 'content':retstring}
|
||||
else:
|
||||
reply = {'type':'ERROR', 'content':retstring}
|
||||
return reply
|
||||
else:
|
||||
reply= {'type':'ERROR', 'content':'Bot不支持处理{}类型的消息'.format(context['type'])}
|
||||
|
||||
def reply_text(self, session, session_id, retry_count=0) ->dict:
|
||||
'''
|
||||
@@ -104,7 +123,7 @@ class ChatGPTBot(Bot):
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return image_url
|
||||
return True,image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
@@ -112,10 +131,10 @@ class ChatGPTBot(Bot):
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.create_img(query, retry_count+1)
|
||||
else:
|
||||
return "提问太快啦,请休息一下再问我吧"
|
||||
return False,"提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return None
|
||||
return False,str(e)
|
||||
|
||||
class SessionManager(object):
|
||||
def __init__(self):
|
||||
|
||||
Reference in New Issue
Block a user