mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-04-18 01:53:47 +08:00
feat: prioritize handling commands
This commit is contained in:
@@ -1,14 +1,12 @@
|
|||||||
|
|
||||||
|
|
||||||
from asyncio import CancelledError
|
from asyncio import CancelledError
|
||||||
import queue
|
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from channel.chat_message import ChatMessage
|
from common.dequeue import Dequeue
|
||||||
from common.expired_dict import ExpiredDict
|
|
||||||
from channel.channel import Channel
|
from channel.channel import Channel
|
||||||
from bridge.reply import *
|
from bridge.reply import *
|
||||||
from bridge.context import *
|
from bridge.context import *
|
||||||
@@ -245,8 +243,11 @@ class ChatChannel(Channel):
|
|||||||
session_id = context['session_id']
|
session_id = context['session_id']
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if session_id not in self.sessions:
|
if session_id not in self.sessions:
|
||||||
self.sessions[session_id] = (queue.Queue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)))
|
self.sessions[session_id] = (Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)))
|
||||||
self.sessions[session_id][0].put(context)
|
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||||
|
self.sessions[session_id][0].putleft(context) # 优先处理命令
|
||||||
|
else:
|
||||||
|
self.sessions[session_id][0].put(context)
|
||||||
|
|
||||||
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
||||||
def consume(self):
|
def consume(self):
|
||||||
@@ -277,7 +278,7 @@ class ChatChannel(Channel):
|
|||||||
if session_id in self.sessions:
|
if session_id in self.sessions:
|
||||||
for future in self.futures[session_id]:
|
for future in self.futures[session_id]:
|
||||||
future.cancel()
|
future.cancel()
|
||||||
self.sessions[session_id][0]=queue.Queue()
|
self.sessions[session_id][0]=Dequeue()
|
||||||
|
|
||||||
|
|
||||||
def check_prefix(content, prefix_list):
|
def check_prefix(content, prefix_list):
|
||||||
|
|||||||
33
common/dequeue.py
Normal file
33
common/dequeue.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
|
||||||
|
from queue import Full, Queue
|
||||||
|
from time import monotonic as time
|
||||||
|
|
||||||
|
# add implementation of putleft to Queue
|
||||||
|
class Dequeue(Queue):
|
||||||
|
def putleft(self, item, block=True, timeout=None):
|
||||||
|
with self.not_full:
|
||||||
|
if self.maxsize > 0:
|
||||||
|
if not block:
|
||||||
|
if self._qsize() >= self.maxsize:
|
||||||
|
raise Full
|
||||||
|
elif timeout is None:
|
||||||
|
while self._qsize() >= self.maxsize:
|
||||||
|
self.not_full.wait()
|
||||||
|
elif timeout < 0:
|
||||||
|
raise ValueError("'timeout' must be a non-negative number")
|
||||||
|
else:
|
||||||
|
endtime = time() + timeout
|
||||||
|
while self._qsize() >= self.maxsize:
|
||||||
|
remaining = endtime - time()
|
||||||
|
if remaining <= 0.0:
|
||||||
|
raise Full
|
||||||
|
self.not_full.wait(remaining)
|
||||||
|
self._putleft(item)
|
||||||
|
self.unfinished_tasks += 1
|
||||||
|
self.not_empty.notify()
|
||||||
|
|
||||||
|
def put_nowait(self, item):
|
||||||
|
return self.put(item, block=False)
|
||||||
|
|
||||||
|
def _putleft(self, item):
|
||||||
|
self.queue.appendleft(item)
|
||||||
Reference in New Issue
Block a user