mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-03-19 21:38:18 +08:00
formatting: run precommit on all files
This commit is contained in:
@@ -10,10 +10,7 @@ from bridge.reply import Reply, ReplyType
|
||||
class BaiduUnitBot(Bot):
|
||||
def reply(self, query, context=None):
|
||||
token = self.get_token()
|
||||
url = (
|
||||
"https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token="
|
||||
+ token
|
||||
)
|
||||
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
|
||||
post_data = (
|
||||
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
|
||||
+ query
|
||||
@@ -32,12 +29,7 @@ class BaiduUnitBot(Bot):
|
||||
def get_token(self):
|
||||
access_key = "YOUR_ACCESS_KEY"
|
||||
secret_key = "YOUR_SECRET_KEY"
|
||||
host = (
|
||||
"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id="
|
||||
+ access_key
|
||||
+ "&client_secret="
|
||||
+ secret_key
|
||||
)
|
||||
host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
|
||||
response = requests.get(host)
|
||||
if response:
|
||||
print(response.json())
|
||||
|
||||
@@ -30,23 +30,15 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
if conf().get("rate_limit_chatgpt"):
|
||||
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
|
||||
|
||||
self.sessions = SessionManager(
|
||||
ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo"
|
||||
)
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
# "max_tokens":4096, # 回复最大的字符数
|
||||
"top_p": 1,
|
||||
"frequency_penalty": conf().get(
|
||||
"frequency_penalty", 0.0
|
||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get(
|
||||
"presence_penalty", 0.0
|
||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get(
|
||||
"request_timeout", None
|
||||
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||
}
|
||||
|
||||
@@ -87,15 +79,10 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if (
|
||||
reply_content["completion_tokens"] == 0
|
||||
and len(reply_content["content"]) > 0
|
||||
):
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(
|
||||
reply_content["content"], session_id, reply_content["total_tokens"]
|
||||
)
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
@@ -126,9 +113,7 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
||||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||
# if api_key == None, the default openai.api_key will be used
|
||||
response = openai.ChatCompletion.create(
|
||||
api_key=api_key, messages=session.messages, **self.args
|
||||
)
|
||||
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **self.args)
|
||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
|
||||
@@ -25,9 +25,7 @@ class ChatGPTSession(Session):
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug(
|
||||
"Exception when counting tokens precisely for query: {}".format(e)
|
||||
)
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
@@ -39,16 +37,10 @@ class ChatGPTSession(Session):
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn(
|
||||
"user message exceed max_tokens. total_tokens={}".format(cur_tokens)
|
||||
)
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug(
|
||||
"max_tokens={}, total_tokens={}, len(messages)={}".format(
|
||||
max_tokens, cur_tokens, len(self.messages)
|
||||
)
|
||||
)
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
@@ -75,17 +67,13 @@ def num_tokens_from_messages(messages, model):
|
||||
elif model == "gpt-4":
|
||||
return num_tokens_from_messages(messages, model="gpt-4-0314")
|
||||
elif model == "gpt-3.5-turbo-0301":
|
||||
tokens_per_message = (
|
||||
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
)
|
||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif model == "gpt-4-0314":
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
logger.warn(
|
||||
f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301."
|
||||
)
|
||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
|
||||
@@ -28,23 +28,15 @@ class OpenAIBot(Bot, OpenAIImage):
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
|
||||
self.sessions = SessionManager(
|
||||
OpenAISession, model=conf().get("model") or "text-davinci-003"
|
||||
)
|
||||
self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
"max_tokens": 1200, # 回复最大的字符数
|
||||
"top_p": 1,
|
||||
"frequency_penalty": conf().get(
|
||||
"frequency_penalty", 0.0
|
||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get(
|
||||
"presence_penalty", 0.0
|
||||
), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get(
|
||||
"request_timeout", None
|
||||
), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||
"stop": ["\n\n\n"],
|
||||
}
|
||||
@@ -71,17 +63,13 @@ class OpenAIBot(Bot, OpenAIImage):
|
||||
result["content"],
|
||||
)
|
||||
logger.debug(
|
||||
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
str(session), session_id, reply_content, completion_tokens
|
||||
)
|
||||
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
|
||||
)
|
||||
|
||||
if total_tokens == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(
|
||||
reply_content, session_id, total_tokens
|
||||
)
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
@@ -96,9 +84,7 @@ class OpenAIBot(Bot, OpenAIImage):
|
||||
def reply_text(self, session: OpenAISession, retry_count=0):
|
||||
try:
|
||||
response = openai.Completion.create(prompt=str(session), **self.args)
|
||||
res_content = (
|
||||
response.choices[0]["text"].strip().replace("<|endoftext|>", "")
|
||||
)
|
||||
res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
completion_tokens = response["usage"]["completion_tokens"]
|
||||
logger.info("[OPEN_AI] reply={}".format(res_content))
|
||||
|
||||
@@ -23,9 +23,7 @@ class OpenAIImage(object):
|
||||
response = openai.Image.create(
|
||||
prompt=query, # 图片描述
|
||||
n=1, # 每次生成图片的数量
|
||||
size=conf().get(
|
||||
"image_create_size", "256x256"
|
||||
), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response["data"][0]["url"]
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
@@ -34,11 +32,7 @@ class OpenAIImage(object):
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn(
|
||||
"[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(
|
||||
retry_count + 1
|
||||
)
|
||||
)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
|
||||
return self.create_img(query, retry_count + 1)
|
||||
else:
|
||||
return False, "提问太快啦,请休息一下再问我吧"
|
||||
|
||||
@@ -36,9 +36,7 @@ class OpenAISession(Session):
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug(
|
||||
"Exception when counting tokens precisely for query: {}".format(e)
|
||||
)
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 1:
|
||||
self.messages.pop(0)
|
||||
@@ -50,18 +48,10 @@ class OpenAISession(Session):
|
||||
cur_tokens = len(str(self))
|
||||
break
|
||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
|
||||
logger.warn(
|
||||
"user question exceed max_tokens. total_tokens={}".format(
|
||||
cur_tokens
|
||||
)
|
||||
)
|
||||
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug(
|
||||
"max_tokens={}, total_tokens={}, len(conversation)={}".format(
|
||||
max_tokens, cur_tokens, len(self.messages)
|
||||
)
|
||||
)
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
|
||||
@@ -55,9 +55,7 @@ class SessionManager(object):
|
||||
return self.sessioncls(session_id, system_prompt, **self.session_args)
|
||||
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = self.sessioncls(
|
||||
session_id, system_prompt, **self.session_args
|
||||
)
|
||||
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
|
||||
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
|
||||
self.sessions[session_id].set_system_prompt(system_prompt)
|
||||
session = self.sessions[session_id]
|
||||
@@ -71,9 +69,7 @@ class SessionManager(object):
|
||||
total_tokens = session.discard_exceeding(max_tokens, None)
|
||||
logger.debug("prompt tokens used={}".format(total_tokens))
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Exception when counting tokens precisely for prompt: {}".format(str(e))
|
||||
)
|
||||
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def session_reply(self, reply, session_id, total_tokens=None):
|
||||
@@ -82,17 +78,9 @@ class SessionManager(object):
|
||||
try:
|
||||
max_tokens = conf().get("conversation_max_tokens", 1000)
|
||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||
logger.debug(
|
||||
"raw total_tokens={}, savesession tokens={}".format(
|
||||
total_tokens, tokens_cnt
|
||||
)
|
||||
)
|
||||
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Exception when counting tokens precisely for session: {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def clear_session(self, session_id):
|
||||
|
||||
Reference in New Issue
Block a user