refactor llm config
This commit is contained in:
@@ -14,12 +14,6 @@ from .parser import parse_json
|
||||
from .prompt import build_prompt, load_template
|
||||
from .exceptions import LLMError, ParseError
|
||||
|
||||
try:
|
||||
import litellm
|
||||
HAS_LITELLM = True
|
||||
except ImportError:
|
||||
HAS_LITELLM = False
|
||||
|
||||
# 模块级信号量,懒加载
|
||||
_SEMAPHORE: Optional[asyncio.Semaphore] = None
|
||||
|
||||
@@ -38,14 +32,7 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str:
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {config.api_key}"
|
||||
}
|
||||
|
||||
# 兼容 litellm 的 openai/ 前缀处理,以及其他常见前缀清理
|
||||
model_name = config.model_name
|
||||
for prefix in ["openai/", "azure/", "bedrock/"]:
|
||||
if model_name.startswith(prefix):
|
||||
model_name = model_name[len(prefix):]
|
||||
break
|
||||
|
||||
data = {
|
||||
"model": model_name,
|
||||
"messages": [{"role": "user", "content": prompt}]
|
||||
@@ -82,28 +69,13 @@ def _call_with_requests(config: LLMConfig, prompt: str) -> str:
|
||||
async def call_llm(prompt: str, mode: LLMMode = LLMMode.NORMAL) -> str:
|
||||
"""
|
||||
基础 LLM 调用,自动控制并发
|
||||
使用 urllib 直接调用 OpenAI 兼容接口
|
||||
"""
|
||||
config = LLMConfig.from_mode(mode)
|
||||
semaphore = _get_semaphore()
|
||||
|
||||
async with semaphore:
|
||||
if HAS_LITELLM:
|
||||
try:
|
||||
# 使用 litellm 原生异步接口
|
||||
response = await litellm.acompletion(
|
||||
model=config.model_name,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url,
|
||||
)
|
||||
result = response.choices[0].message.content
|
||||
except Exception as e:
|
||||
# 再次抛出以便上层处理,或者记录日志
|
||||
raise Exception(f"LiteLLM call failed: {str(e)}") from e
|
||||
else:
|
||||
# 降级到 requests (在线程池中运行),实现 OpenAI 兼容接口
|
||||
# 这样即使没有 litellm,只要模型服务提供商支持 OpenAI 格式(如 Qwen, DeepSeek, LocalAI 等)均可工作
|
||||
result = await asyncio.to_thread(_call_with_requests, config, prompt)
|
||||
result = await asyncio.to_thread(_call_with_requests, config, prompt)
|
||||
|
||||
log_llm_call(config.model_name, prompt, result)
|
||||
return result
|
||||
@@ -188,17 +160,7 @@ def test_connectivity(mode: LLMMode = LLMMode.NORMAL, config: Optional[LLMConfig
|
||||
if config is None:
|
||||
config = LLMConfig.from_mode(mode)
|
||||
|
||||
if HAS_LITELLM:
|
||||
# 使用 litellm 同步接口
|
||||
litellm.completion(
|
||||
model=config.model_name,
|
||||
messages=[{"role": "user", "content": "你好"}],
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url,
|
||||
)
|
||||
else:
|
||||
# 直接调用 requests 实现
|
||||
_call_with_requests(config, "test")
|
||||
_call_with_requests(config, "test")
|
||||
return True, ""
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
Reference in New Issue
Block a user