Files
2026-02-05 16:25:52 +08:00

264 lines
7.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI 服务抽象层 - 支持通义千问和本地模型
"""
import os
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any
import yaml
from pathlib import Path
# Path of config.yaml: it lives in the directory two levels above the one
# containing this module (i.e. the great-grandparent of this file).
CONFIG_PATH = Path(__file__).parent.parent.parent.joinpath("config.yaml")
def load_config() -> Dict[str, Any]:
    """Load the YAML configuration file located at ``CONFIG_PATH``.

    Returns:
        The parsed top-level mapping, or an empty dict when the file does not
        exist or its content is empty / falsy.

    Raises:
        ValueError: If the file's top-level YAML value is a non-empty
            non-mapping (e.g. a list or a scalar). Callers treat the result
            as a dict, so failing here gives a clear message instead of an
            AttributeError further down the line.
        yaml.YAMLError: Propagated unchanged if the file is not valid YAML.
    """
    # EAFP instead of exists()+open(): avoids the check-then-use race.
    try:
        with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
    except FileNotFoundError:
        return {}
    # Empty file (None) and other falsy documents mean "no configuration".
    if not data:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"{CONFIG_PATH}: top-level YAML value must be a mapping, "
            f"got {type(data).__name__}"
        )
    return data
class AIProvider(ABC):
    """Abstract interface that every AI backend must implement."""

    @abstractmethod
    async def generate(self, prompt: str, context: Optional[str] = None) -> str:
        """Produce text in response to a prompt.

        Args:
            prompt: The user's prompt text.
            context: Optional extra context to accompany the prompt.

        Returns:
            The generated text.
        """
        ...

    @abstractmethod
    async def check(self, content: str, requirements: Optional[list] = None) -> Dict[str, Any]:
        """Review content for whether it contains the necessary information.

        Args:
            content: The text to review.
            requirements: Optional list of review requirements.

        Returns:
            A result dict carrying fields such as ``passed``, ``issues`` and
            ``suggestions``.
        """
        ...
class OpenAICompatibleProvider(AIProvider):
    """Provider backed by an OpenAI-compatible chat-completions API.

    Intended for services that expose the OpenAI chat protocol, such as
    Tongyi Qianwen (DashScope compatible mode) or DeepSeek. Errors are
    reported in the return value (an error string / a failed result dict)
    rather than raised.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = "qwen-plus"
    ):
        """Store settings and create the async OpenAI client.

        Args:
            api_key: API key; falls back to the ``AI_API_KEY`` environment
                variable, then to an empty string.
            base_url: API base URL; falls back to ``AI_BASE_URL``, then to the
                DashScope compatible-mode endpoint.
            model: Model name sent with every request.
        """
        self.api_key = api_key or os.getenv("AI_API_KEY", "")
        self.base_url = base_url or os.getenv("AI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        self.model = model
        # Imported here rather than at module top, so importing this module
        # does not itself require the openai package.
        from openai import AsyncOpenAI
        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )

    async def generate(self, prompt: str, context: Optional[str] = None) -> str:
        """Generate a reply for ``prompt`` via the chat-completions endpoint.

        Args:
            prompt: User prompt, sent as the ``user`` message.
            context: Optional text sent first as a ``system`` message.

        Returns:
            The reply text (``""`` if the API returned no content), or a
            string starting with ``调用API出错`` if the call failed.
        """
        messages = []
        if context:
            messages.append({"role": "system", "content": context})
        messages.append({"role": "user", "content": prompt})
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages
            )
            # The SDK types message.content as Optional[str]; keep the str contract.
            return response.choices[0].message.content or ""
        except Exception as e:
            return f"调用API出错: {str(e)}"

    async def check(self, content: str, requirements: Optional[list] = None) -> Dict[str, Any]:
        """Ask the model to review ``content`` for completeness.

        Args:
            content: The text to review.
            requirements: Optional extra requirements; each item is appended
                to the review prompt as its own bullet.

        Returns:
            The model's JSON verdict as a dict (missing ``passed``/``score``/
            ``issues``/``suggestions`` keys are filled with failing defaults).
            If the reply cannot be parsed, a failed result carrying
            ``raw_response`` is returned. If the API call fails, a failed
            result whose issue describes the error is returned.
        """
        check_prompt = self._build_check_prompt(content, requirements)
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": check_prompt}]
            )
            result_text = response.choices[0].message.content
        except Exception as e:
            return {
                "passed": False,
                "score": 0,
                "issues": [f"调用出错: {str(e)}"],
                "suggestions": []
            }
        return self._parse_check_result(result_text)

    @staticmethod
    def _build_check_prompt(content: str, requirements: Optional[list] = None) -> str:
        """Build the review prompt; identical to the base prompt when no requirements are given."""
        lines = [
            "请检查以下意图编制内容是否完整,是否包含必要的信息。",
            "要检查的内容:",
            f"{content}",
            "请检查以下方面:",
            "1. 测试目标是否明确",
            "2. 测试范围是否清晰",
            "3. 测试条件是否完整",
            "4. 预期结果是否明确",
            "5. 是否有遗漏的关键信息",
        ]
        if requirements:
            # Caller-supplied requirements were previously accepted but ignored.
            lines.append("另外请检查以下要求:")
            lines.extend(f"- {req}" for req in requirements)
        lines += [
            "请以JSON格式返回检查结果",
            "{",
            '"passed": true/false,',
            '"score": 0-100,',
            '"issues": ["问题1", "问题2"],',
            '"suggestions": ["建议1", "建议2"]',
            "}",
            "",  # keeps the trailing newline of the original prompt
        ]
        return "\n".join(lines)

    @staticmethod
    def _parse_check_result(result_text: Optional[str]) -> Dict[str, Any]:
        """Extract the outermost ``{...}`` span of the reply and decode it as JSON."""
        import json

        text = result_text or ""
        start = text.find('{')
        end = text.rfind('}') + 1
        if start != -1 and end > start:
            try:
                parsed = json.loads(text[start:end])
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                pass
            else:
                # Guarantee the fields documented on AIProvider.check.
                parsed.setdefault("passed", False)
                parsed.setdefault("score", 0)
                parsed.setdefault("issues", [])
                parsed.setdefault("suggestions", [])
                return parsed
        return {
            "passed": False,
            "score": 0,
            "issues": ["无法解析AI返回结果"],
            "suggestions": [],
            "raw_response": result_text
        }
class LocalModelProvider(AIProvider):
    """Provider for a self-hosted model exposing an OpenAI-style chat API.

    Requests are POSTed with aiohttp to ``<endpoint>/v1/chat/completions``.
    Errors are reported in the return value rather than raised.
    """

    def __init__(self, endpoint: str = "http://localhost:8000", model: str = "llama3", api_key: str = ""):
        """Store connection settings.

        Args:
            endpoint: Server URL, either the root (``http://host:port``) or the
                OpenAI-style base including ``/v1``; trailing slashes are stripped.
            model: Model name sent in each request.
            api_key: Bearer token; falls back to ``LOCAL_MODEL_API_KEY`` and
                then to the placeholder ``"not-needed"``.
        """
        self.endpoint = endpoint.rstrip('/')
        self.model = model
        self.api_key = api_key or os.getenv("LOCAL_MODEL_API_KEY", "not-needed")

    def _chat_url(self) -> str:
        """Return the chat-completions URL without doubling a trailing ``/v1``."""
        # OpenAI-style base URLs conventionally already end in /v1; appending
        # another /v1 would yield a broken .../v1/v1/chat/completions path.
        base = self.endpoint if self.endpoint.endswith("/v1") else f"{self.endpoint}/v1"
        return f"{base}/chat/completions"

    async def generate(self, prompt: str, context: Optional[str] = None) -> str:
        """Generate a reply from the local model.

        Args:
            prompt: User prompt, sent as the ``user`` message.
            context: Optional text sent first as a ``system`` message.

        Returns:
            The reply text (``""`` if the server returned null content),
            ``生成失败: HTTP <status>`` for a non-200 response, or a string
            starting with ``调用本地模型出错`` if the request or decoding failed.
        """
        import aiohttp

        messages = []
        if context:
            messages.append({"role": "system", "content": context})
        messages.append({"role": "user", "content": prompt})
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": False
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self._chat_url(),
                    headers=headers,
                    json=payload
                ) as response:
                    if response.status != 200:
                        return f"生成失败: HTTP {response.status}"
                    data = await response.json()
                    return data["choices"][0]["message"]["content"] or ""
        except Exception as e:
            return f"调用本地模型出错: {str(e)}"

    async def check(self, content: str, requirements: Optional[list] = None) -> Dict[str, Any]:
        """Ask the local model to review ``content`` for completeness.

        Args:
            content: The text to review.
            requirements: Optional extra requirements; each item is appended
                to the review prompt as its own bullet.

        Returns:
            The model's JSON verdict as a dict (missing standard keys filled
            with failing defaults), or a failed result carrying
            ``raw_response`` when no JSON object can be decoded. Transport
            errors from ``generate`` end up in ``raw_response`` this way.
        """
        result = await self.generate(self._build_check_prompt(content, requirements))
        return self._parse_check_result(result)

    @staticmethod
    def _build_check_prompt(content: str, requirements: Optional[list] = None) -> str:
        """Build the review prompt; identical to the base prompt when no requirements are given."""
        extra = ""
        if requirements:
            # Caller-supplied requirements were previously accepted but ignored.
            extra = "另外请检查以下要求:\n" + "\n".join(f"- {req}" for req in requirements) + "\n"
        return (
            "请检查以下意图编制内容是否完整。\n"
            "内容:\n"
            f"{content}\n"
            f"{extra}"
            '请以JSON格式返回{"passed": bool, "score": int, "issues": [], "suggestions": []}'
        )

    @staticmethod
    def _parse_check_result(result: Optional[str]) -> Dict[str, Any]:
        """Extract the outermost ``{...}`` span of the reply and decode it as JSON."""
        import json

        text = result or ""
        start = text.find('{')
        end = text.rfind('}') + 1
        if start != -1 and end > start:
            try:
                parsed = json.loads(text[start:end])
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                pass
            else:
                # Guarantee the fields documented on AIProvider.check.
                parsed.setdefault("passed", False)
                parsed.setdefault("score", 0)
                parsed.setdefault("issues", [])
                parsed.setdefault("suggestions", [])
                return parsed
        return {
            "passed": False,
            "score": 0,
            "issues": ["无法解析结果"],
            "suggestions": [],
            "raw_response": result
        }
class AIServiceFactory:
    """Create and cache the process-wide AI provider from configuration.

    Each setting is resolved with this precedence: environment variable
    (``AI_API_KEY`` / ``AI_BASE_URL`` / ``AI_MODEL``), then the ``ai`` section
    of the config file, then a built-in default. Empty or null values at any
    level fall through to the next one. Currently this always builds an
    ``OpenAICompatibleProvider``.
    """

    _DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    _DEFAULT_MODEL = "qwen-plus"

    # Cached provider; None until the first get_provider() call or after reset().
    _instance: Optional["AIProvider"] = None

    @classmethod
    def get_provider(cls) -> "AIProvider":
        """Return the cached provider, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls._create_provider()
        return cls._instance

    @classmethod
    def _resolve_settings(
        cls,
        ai_config: Optional[Dict[str, Any]],
        environ: Optional[Dict[str, str]] = None
    ) -> Dict[str, Any]:
        """Merge environment, config section and defaults into provider kwargs.

        Args:
            ai_config: The ``ai`` section of the config; may be None (e.g.
                ``ai:`` with no value in YAML) or missing keys.
            environ: Mapping to read variables from; defaults to ``os.environ``.

        Returns:
            A dict with ``api_key``, ``base_url`` and ``model`` keys.
        """
        env = os.environ if environ is None else environ
        section = ai_config or {}
        return {
            "api_key": env.get("AI_API_KEY") or section.get("api_key") or "",
            "base_url": env.get("AI_BASE_URL") or section.get("base_url") or cls._DEFAULT_BASE_URL,
            "model": env.get("AI_MODEL") or section.get("model") or cls._DEFAULT_MODEL,
        }

    @classmethod
    def _create_provider(cls) -> "AIProvider":
        """Build a provider from the current environment and config file."""
        config = load_config()
        # `ai:` with an empty value parses to None; _resolve_settings treats it as missing.
        settings = cls._resolve_settings(config.get("ai"))
        return OpenAICompatibleProvider(**settings)

    @classmethod
    def reset(cls):
        """Drop the cached provider so the next get_provider() builds a new one."""
        cls._instance = None
# Convenience accessor
def get_ai_service() -> AIProvider:
    """Shortcut for ``AIServiceFactory.get_provider()``.

    Returns:
        The shared AI provider instance, built on first use.
    """
    factory = AIServiceFactory
    provider = factory.get_provider()
    return provider