mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-03-19 21:38:18 +08:00
openai 接口返回token数量来修剪会话长度
This commit is contained in:
@@ -6,7 +6,6 @@ from common.log import logger
|
|||||||
from common.expired_dict import ExpiredDict
|
from common.expired_dict import ExpiredDict
|
||||||
import openai
|
import openai
|
||||||
import time
|
import time
|
||||||
import json
|
|
||||||
|
|
||||||
if conf().get('expires_in_seconds'):
|
if conf().get('expires_in_seconds'):
|
||||||
user_session = ExpiredDict(conf().get('expires_in_seconds'))
|
user_session = ExpiredDict(conf().get('expires_in_seconds'))
|
||||||
@@ -44,12 +43,19 @@ class ChatGPTBot(Bot):
|
|||||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
||||||
if reply_content:
|
if reply_content:
|
||||||
Session.save_session(query, reply_content, from_user_id)
|
Session.save_session(query, reply_content, from_user_id)
|
||||||
return reply_content
|
return reply_content[1]
|
||||||
|
|
||||||
elif context.get('type', None) == 'IMAGE_CREATE':
|
elif context.get('type', None) == 'IMAGE_CREATE':
|
||||||
return self.create_img(query, 0)
|
return self.create_img(query, 0)
|
||||||
|
|
||||||
def reply_text(self, query, user_id, retry_count=0):
|
def reply_text(self, query, user_id, retry_count=0):
|
||||||
|
'''
|
||||||
|
call openai's ChatCompletion to get the answer
|
||||||
|
:param query: query content
|
||||||
|
:param user_id: from user id
|
||||||
|
:param retry_count: retry count
|
||||||
|
:return: [0]-tokens used and [1]-answer
|
||||||
|
'''
|
||||||
try:
|
try:
|
||||||
response = openai.ChatCompletion.create(
|
response = openai.ChatCompletion.create(
|
||||||
model="gpt-3.5-turbo", # 对话模型的名称
|
model="gpt-3.5-turbo", # 对话模型的名称
|
||||||
@@ -62,8 +68,8 @@ class ChatGPTBot(Bot):
|
|||||||
)
|
)
|
||||||
# res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
# res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
||||||
logger.info(response.choices[0]['message']['content'])
|
logger.info(response.choices[0]['message']['content'])
|
||||||
# log.info("[OPEN_AI] reply={}".format(res_content))
|
|
||||||
return response.choices[0]['message']['content']
|
return response["usage"]["prompt_tokens"],response.choices[0]['message']['content']
|
||||||
except openai.error.RateLimitError as e:
|
except openai.error.RateLimitError as e:
|
||||||
# rate limit exception
|
# rate limit exception
|
||||||
logger.warn(e)
|
logger.warn(e)
|
||||||
@@ -72,21 +78,21 @@ class ChatGPTBot(Bot):
|
|||||||
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||||
return self.reply_text(query, user_id, retry_count+1)
|
return self.reply_text(query, user_id, retry_count+1)
|
||||||
else:
|
else:
|
||||||
return "提问太快啦,请休息一下再问我吧"
|
return 0,"提问太快啦,请休息一下再问我吧"
|
||||||
except openai.error.APIConnectionError as e:
|
except openai.error.APIConnectionError as e:
|
||||||
# api connection exception
|
# api connection exception
|
||||||
logger.warn(e)
|
logger.warn(e)
|
||||||
logger.warn("[OPEN_AI] APIConnection failed")
|
logger.warn("[OPEN_AI] APIConnection failed")
|
||||||
return "我连接不到你的网络"
|
return 0,"我连接不到你的网络"
|
||||||
except openai.error.Timeout as e:
|
except openai.error.Timeout as e:
|
||||||
logger.warn(e)
|
logger.warn(e)
|
||||||
logger.warn("[OPEN_AI] Timeout")
|
logger.warn("[OPEN_AI] Timeout")
|
||||||
return "我没有收到你的消息"
|
return 0,"我没有收到你的消息"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# unknown exception
|
# unknown exception
|
||||||
logger.exception(e)
|
logger.exception(e)
|
||||||
Session.clear_session(user_id)
|
Session.clear_session(user_id)
|
||||||
return "请再问我一次吧"
|
return 0,"请再问我一次吧"
|
||||||
|
|
||||||
def create_img(self, query, retry_count=0):
|
def create_img(self, query, retry_count=0):
|
||||||
try:
|
try:
|
||||||
@@ -142,31 +148,27 @@ class Session(object):
|
|||||||
if not max_tokens:
|
if not max_tokens:
|
||||||
# default 3000
|
# default 3000
|
||||||
max_tokens = 1000
|
max_tokens = 1000
|
||||||
|
max_tokens=int(max_tokens)
|
||||||
|
|
||||||
session = user_session.get(user_id)
|
session = user_session.get(user_id)
|
||||||
if session:
|
if session:
|
||||||
# append conversation
|
# append conversation
|
||||||
gpt_item = {'role': 'assistant', 'content': answer}
|
gpt_item = {'role': 'assistant', 'content': answer[1]}
|
||||||
session.append(gpt_item)
|
session.append(gpt_item)
|
||||||
|
|
||||||
# discard exceed limit conversation
|
# discard exceed limit conversation
|
||||||
Session.discard_exceed_conversation(user_session[user_id], max_tokens)
|
used_tokens=int(answer[0])
|
||||||
|
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
|
||||||
|
|
||||||
@staticmethod
|
while used_tokens > max_tokens:
|
||||||
def discard_exceed_conversation(session, max_tokens):
|
# pop first conversation
|
||||||
count = 0
|
if len(session) > 0:
|
||||||
count_list = list()
|
session.pop(0)
|
||||||
for i in range(len(session)-1, -1, -1):
|
else:
|
||||||
# count tokens of conversation list
|
break
|
||||||
history_conv = session[i]
|
|
||||||
tokens=json.dumps(history_conv).split()
|
used_tokens=used_tokens-max_tokens
|
||||||
count += len(tokens)
|
|
||||||
count_list.append(count)
|
|
||||||
|
|
||||||
for c in count_list:
|
|
||||||
if c > max_tokens:
|
|
||||||
# pop first conversation
|
|
||||||
session.pop(0)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def clear_session(user_id):
|
def clear_session(user_id):
|
||||||
|
|||||||
Reference in New Issue
Block a user