""" AI 服务抽象层 - 支持通义千问和本地模型 """ import os from abc import ABC, abstractmethod from typing import Optional, Dict, Any import yaml from pathlib import Path # 配置文件路径 CONFIG_PATH = Path(__file__).parent.parent.parent / "config.yaml" def load_config() -> Dict[str, Any]: """加载配置文件""" if CONFIG_PATH.exists(): with open(CONFIG_PATH, 'r', encoding='utf-8') as f: return yaml.safe_load(f) or {} return {} class AIProvider(ABC): """AI 服务提供者抽象基类""" @abstractmethod async def generate(self, prompt: str, context: Optional[str] = None) -> str: """生成内容 Args: prompt: 用户提示词 context: 可选的上下文信息 Returns: 生成的文本内容 """ pass @abstractmethod async def check(self, content: str, requirements: Optional[list] = None) -> Dict[str, Any]: """检查内容是否包含必要信息 Args: content: 要检查的内容 requirements: 可选的检查要求列表 Returns: 检查结果字典,包含 passed, issues, suggestions 等字段 """ pass class OpenAICompatibleProvider(AIProvider): """OpenAI 兼容接口实现 - 支持通义千问、DeepSeek 等""" def __init__( self, api_key: Optional[str] = None, base_url: Optional[str] = None, model: str = "qwen-plus" ): self.api_key = api_key or os.getenv("AI_API_KEY", "") self.base_url = base_url or os.getenv("AI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") self.model = model # 初始化 OpenAI 客户端 from openai import AsyncOpenAI self.client = AsyncOpenAI( api_key=self.api_key, base_url=self.base_url ) async def generate(self, prompt: str, context: Optional[str] = None) -> str: """使用 OpenAI 兼容接口生成内容""" messages = [] if context: messages.append({"role": "system", "content": context}) messages.append({"role": "user", "content": prompt}) try: response = await self.client.chat.completions.create( model=self.model, messages=messages ) return response.choices[0].message.content except Exception as e: return f"调用API出错: {str(e)}" async def check(self, content: str, requirements: Optional[list] = None) -> Dict[str, Any]: """使用 AI 检查内容""" check_prompt = f"""请检查以下意图编制内容是否完整,是否包含必要的信息。 要检查的内容: {content} 请检查以下方面: 1. 测试目标是否明确 2. 测试范围是否清晰 3. 测试条件是否完整 4. 预期结果是否明确 5. 是否有遗漏的关键信息 请以JSON格式返回检查结果: {{ "passed": true/false, "score": 0-100, "issues": ["问题1", "问题2"], "suggestions": ["建议1", "建议2"] }} """ try: response = await self.client.chat.completions.create( model=self.model, messages=[{"role": "user", "content": check_prompt}] ) result_text = response.choices[0].message.content # 尝试解析JSON import json try: # 提取JSON部分 start = result_text.find('{') end = result_text.rfind('}') + 1 if start != -1 and end > start: return json.loads(result_text[start:end]) except: pass return { "passed": False, "score": 0, "issues": ["无法解析AI返回结果"], "suggestions": [], "raw_response": result_text } except Exception as e: return { "passed": False, "score": 0, "issues": [f"调用出错: {str(e)}"], "suggestions": [] } class LocalModelProvider(AIProvider): """本地模型实现 - 兼容 OpenAI API 格式""" def __init__(self, endpoint: str = "http://localhost:8000", model: str = "llama3", api_key: str = ""): self.endpoint = endpoint.rstrip('/') self.model = model self.api_key = api_key or os.getenv("LOCAL_MODEL_API_KEY", "not-needed") async def generate(self, prompt: str, context: Optional[str] = None) -> str: """使用本地模型生成内容""" import aiohttp messages = [] if context: messages.append({"role": "system", "content": context}) messages.append({"role": "user", "content": prompt}) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}" } payload = { "model": self.model, "messages": messages, "stream": False } try: async with aiohttp.ClientSession() as session: async with session.post( f"{self.endpoint}/v1/chat/completions", headers=headers, json=payload ) as response: if response.status == 200: data = await response.json() return data["choices"][0]["message"]["content"] else: return f"生成失败: HTTP {response.status}" except Exception as e: return f"调用本地模型出错: {str(e)}" async def check(self, content: str, requirements: Optional[list] = None) -> Dict[str, Any]: """使用本地模型检查内容""" check_prompt = f"""请检查以下意图编制内容是否完整。 内容: {content} 请以JSON格式返回:{{"passed": bool, "score": int, "issues": [], "suggestions": []}}""" result = await self.generate(check_prompt) try: import json start = result.find('{') end = result.rfind('}') + 1 if start != -1 and end > start: return json.loads(result[start:end]) except: pass return { "passed": False, "score": 0, "issues": ["无法解析结果"], "suggestions": [], "raw_response": result } class AIServiceFactory: """AI 服务工厂 - 根据配置创建对应的 Provider""" _instance: Optional[AIProvider] = None @classmethod def get_provider(cls) -> AIProvider: """获取 AI Provider 单例""" if cls._instance is None: cls._instance = cls._create_provider() return cls._instance @classmethod def _create_provider(cls) -> AIProvider: """根据配置创建 Provider""" config = load_config() ai_config = config.get("ai", {}) # 优先从环境变量读取 api_key = os.getenv("AI_API_KEY", "") base_url = os.getenv("AI_BASE_URL", "") model = os.getenv("AI_MODEL", "") # 如果环境变量未设置,从配置文件读取 if not api_key: api_key = ai_config.get("api_key", "") if not base_url: base_url = ai_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1") if not model: model = ai_config.get("model", "qwen-plus") return OpenAICompatibleProvider( api_key=api_key, base_url=base_url, model=model ) @classmethod def reset(cls): """重置单例,用于切换 Provider""" cls._instance = None # 便捷函数 def get_ai_service() -> AIProvider: """获取 AI 服务实例""" return AIServiceFactory.get_provider()