Initial commit

This commit is contained in:
Your Name
2026-02-05 16:25:52 +08:00
commit d5ea866eb4
178 changed files with 32681 additions and 0 deletions

7
backend/app/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
App package initialization
"""
from app.database import Base, get_db, engine, AsyncSessionLocal
__all__ = ["Base", "get_db", "engine", "AsyncSessionLocal"]

View File

@@ -0,0 +1,7 @@
"""
API package initialization
"""
from app.api.routes import intent_router, ai_router
__all__ = ["intent_router", "ai_router"]

View File

@@ -0,0 +1,8 @@
"""
API routes package initialization
"""
from app.api.routes.intent import router as intent_router
from app.api.routes.ai_assistant import router as ai_router
__all__ = ["intent_router", "ai_router"]

View File

@@ -0,0 +1,401 @@
"""
AI 助手 API 路由
"""
from typing import Optional, List
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from pydantic import BaseModel
from datetime import datetime
from app.database import get_db
from app.models.intent import Intent, AICheckRecord
from app.services.ai_service import get_ai_service
router = APIRouter(prefix="/ai", tags=["AI助手"])
# ============ Pydantic Schemas ============
class GenerateRequest(BaseModel):
"""AI生成请求"""
prompt: str
context: Optional[str] = None
class GenerateResponse(BaseModel):
"""AI生成响应"""
content: str
success: bool
error: Optional[str] = None
class CheckRequest(BaseModel):
"""AI检查请求"""
content: str
intent_id: Optional[int] = None # 如果提供,会保存检查记录
requirements: Optional[List[str]] = None
class CheckResult(BaseModel):
"""AI检查结果"""
passed: bool
score: int
issues: List[str]
suggestions: List[str]
raw_response: Optional[str] = None
class CheckResponse(BaseModel):
"""AI检查响应"""
result: CheckResult
record_id: Optional[int] = None # 如果保存了记录
success: bool
error: Optional[str] = None
class IntentGenerateRequest(BaseModel):
"""意图编制生成请求"""
test_type: str # 测试类型:功能测试、性能测试、安全测试等
test_target: str # 测试目标描述
additional_requirements: Optional[str] = None
class PlanGenerateRequest(BaseModel):
"""测试规划生成请求"""
requirement_text: str # 测试需求文本(前端文本框内容)
source_name: Optional[str] = "用户输入" # 来源名称
class ExtractStepsRequest(BaseModel):
"""提取测试步骤请求"""
content: str # 意图编制内容
title: Optional[str] = "意图编制"
class ExtractedStepData(BaseModel):
"""提取的步骤数据"""
id: int
name: str
purpose: str
deviceParams: dict # {设备类别: {参数名: 参数值}}
class ExtractStepsResponse(BaseModel):
"""提取测试步骤响应"""
success: bool
steps: List[ExtractedStepData]
error: Optional[str] = None
# ============ API Endpoints ============
@router.post("/generate", response_model=GenerateResponse)
async def generate_content(request: GenerateRequest):
"""AI 生成内容"""
try:
ai_service = get_ai_service()
content = await ai_service.generate(
prompt=request.prompt,
context=request.context
)
return GenerateResponse(content=content, success=True)
except Exception as e:
return GenerateResponse(
content="",
success=False,
error=str(e)
)
@router.post("/check", response_model=CheckResponse)
async def check_content(
request: CheckRequest,
db: AsyncSession = Depends(get_db)
):
"""AI 检查内容"""
try:
ai_service = get_ai_service()
result = await ai_service.check(
content=request.content,
requirements=request.requirements
)
check_result = CheckResult(
passed=result.get("passed", False),
score=result.get("score", 0),
issues=result.get("issues", []),
suggestions=result.get("suggestions", []),
raw_response=result.get("raw_response")
)
record_id = None
# 如果提供了 intent_id保存检查记录
if request.intent_id:
# 验证意图存在
query = select(Intent).where(Intent.id == request.intent_id)
intent_result = await db.execute(query)
intent = intent_result.scalar_one_or_none()
if intent:
record = AICheckRecord(
intent_id=request.intent_id,
check_result=result,
suggestions="\n".join(result.get("suggestions", []))
)
db.add(record)
await db.commit()
await db.refresh(record)
record_id = record.id
return CheckResponse(
result=check_result,
record_id=record_id,
success=True
)
except Exception as e:
return CheckResponse(
result=CheckResult(passed=False, score=0, issues=[str(e)], suggestions=[]),
success=False,
error=str(e)
)
@router.post("/generate-intent", response_model=GenerateResponse)
async def generate_intent_content(request: IntentGenerateRequest):
"""AI 生成意图编制初始内容"""
prompt = f"""请帮我生成一份测试意图编制文档的初始内容。
测试类型:{request.test_type}
测试目标:{request.test_target}
{"额外要求:" + request.additional_requirements if request.additional_requirements else ""}
请按照以下格式生成内容:
## 1. 测试目标
[详细描述测试的目标和预期达成的效果]
## 2. 测试范围
[明确测试的边界和覆盖范围]
## 3. 测试条件
### 3.1 前置条件
[列出测试开始前需要满足的条件]
### 3.2 测试环境
[描述测试所需的硬件、软件环境]
### 3.3 测试数据
[描述测试所需的数据准备]
## 4. 测试用例概述
[列出主要的测试场景和用例]
## 5. 预期结果
[描述测试成功的判定标准]
## 6. 风险与注意事项
[列出可能的风险和需要注意的事项]
"""
context = """你是一位专业的软件测试工程师,擅长编写测试意图编制文档。
请生成规范、完整、专业的测试意图编制内容。"""
try:
ai_service = get_ai_service()
content = await ai_service.generate(prompt=prompt, context=context)
return GenerateResponse(content=content, success=True)
except Exception as e:
return GenerateResponse(
content="",
success=False,
error=str(e)
)
@router.post("/generate-plan", response_model=GenerateResponse)
async def generate_test_plan(request: PlanGenerateRequest):
"""
根据前端输入的测试需求文本,调用 planner 生成测试规划。
直接返回生成的 Markdown 内容(不保存文件)。
"""
from planner.planning_agent.planner import build_plan_from_text
try:
# 直接调用规划生成函数,返回 Markdown 内容
md_content = build_plan_from_text(
requirement_text=request.requirement_text,
source_doc=request.source_name or "用户输入"
)
if not md_content:
return GenerateResponse(
content="",
success=False,
error="生成内容为空,请检查输入内容或后端服务"
)
return GenerateResponse(content=md_content, success=True)
except Exception as e:
import traceback
traceback.print_exc()
return GenerateResponse(
content="",
success=False,
error=f"规划生成失败: {str(e)}"
)
@router.post("/extract-steps", response_model=ExtractStepsResponse)
async def extract_steps_from_intent(request: ExtractStepsRequest):
"""
从意图编制内容中提取测试步骤和参数。
使用 LLM 解析文本并返回结构化的步骤数据。
"""
try:
ai_service = get_ai_service()
# 构建提取 prompt
extract_prompt = f"""请从以下测试规划内容中提取测试步骤和仪器参数。
请严格按照以下 JSON 格式返回,不要添加任何其他内容:
```json
{{
"steps": [
{{
"id": 1,
"name": "步骤名称",
"purpose": "步骤目的",
"deviceParams": {{
"程控电源参数": {{
"输出电压": "28V",
"输出电流": "6A"
}},
"频谱分析仪参数": {{
"中心频率": "2.7GHz",
"扫宽Span": "500MHz"
}}
}}
}}
]
}}
```
设备类别必须使用以下名称之一:
- 程控电源参数
- 功率探头参数
- 频谱分析仪参数
- 矢量网络分析仪参数
- 矢量信号分析仪参数
- 矢量信号源参数
- 示波器参数
- 信号源基础参数
测试规划内容:
{request.content}
"""
context = """你是一个专业的测试工程师,擅长从测试规划文档中提取结构化的测试步骤和仪器参数。
请仔细分析文档,提取每个测试步骤及其涉及的仪器配置参数。
只返回 JSON 格式的结果,不要添加任何解释或其他文字。"""
response_text = await ai_service.generate(prompt=extract_prompt, context=context)
# 解析 JSON 响应
import json
import re
# 尝试从响应中提取 JSON
json_match = re.search(r'```json\s*([\s\S]*?)\s*```', response_text)
if json_match:
json_str = json_match.group(1)
else:
# 尝试直接解析整个响应
json_str = response_text.strip()
# 清理可能的 markdown 代码块标记
json_str = re.sub(r'^```\w*\n?', '', json_str)
json_str = re.sub(r'\n?```$', '', json_str)
try:
parsed = json.loads(json_str)
steps_data = parsed.get("steps", [])
except json.JSONDecodeError:
# 如果解析失败,返回空列表
return ExtractStepsResponse(
success=False,
steps=[],
error="无法解析 LLM 返回的 JSON 格式"
)
# 构建响应,过滤掉空值
def filter_device_params(params: dict) -> dict:
"""过滤掉空的设备参数"""
if not params or not isinstance(params, dict):
return {}
filtered = {}
for device_category, device_params in params.items():
if not device_params or not isinstance(device_params, dict):
continue
# 过滤掉空值的参数
non_empty_params = {
k: v for k, v in device_params.items()
if v and isinstance(v, str) and v.strip()
}
if non_empty_params:
filtered[device_category] = non_empty_params
return filtered
steps = []
for step in steps_data:
filtered_params = filter_device_params(step.get("deviceParams", {}))
steps.append(ExtractedStepData(
id=step.get("id", len(steps) + 1),
name=step.get("name", f"步骤 {len(steps) + 1}"),
purpose=step.get("purpose", ""),
deviceParams=filtered_params
))
return ExtractStepsResponse(
success=True,
steps=steps
)
except Exception as e:
import traceback
traceback.print_exc()
return ExtractStepsResponse(
success=False,
steps=[],
error=f"步骤提取失败: {str(e)}"
)
@router.get("/check-records/{intent_id}")
async def get_check_records(
intent_id: int,
db: AsyncSession = Depends(get_db)
):
"""获取意图的所有检查记录"""
query = select(AICheckRecord).where(
AICheckRecord.intent_id == intent_id
).order_by(AICheckRecord.checked_at.desc())
result = await db.execute(query)
records = result.scalars().all()
return {
"intent_id": intent_id,
"total": len(records),
"records": [
{
"id": r.id,
"check_result": r.check_result,
"suggestions": r.suggestions,
"checked_at": r.checked_at.isoformat()
}
for r in records
]
}

View File

@@ -0,0 +1,188 @@
"""
意图编制 API 路由
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from pydantic import BaseModel
from datetime import datetime
from pathlib import Path
from app.database import get_db
from app.models.intent import Intent, AICheckRecord
from planner.planning_agent.input_pipeline import parse_intent_file
import tempfile
router = APIRouter(prefix="/intent", tags=["意图编制"])
# ============ Pydantic Schemas ============
class IntentCreate(BaseModel):
"""创建意图请求"""
title: str
content: Optional[str] = None
status: str = "draft"
class IntentUpdate(BaseModel):
"""更新意图请求"""
title: Optional[str] = None
content: Optional[str] = None
status: Optional[str] = None
class IntentResponse(BaseModel):
"""意图响应"""
id: int
title: str
content: Optional[str]
status: str
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class IntentListResponse(BaseModel):
"""意图列表响应"""
total: int
items: List[IntentResponse]
# ============ API Endpoints ============
@router.post("", response_model=IntentResponse, status_code=status.HTTP_201_CREATED)
async def create_intent(
intent_data: IntentCreate,
db: AsyncSession = Depends(get_db)
):
"""创建新意图"""
intent = Intent(
title=intent_data.title,
content=intent_data.content,
status=intent_data.status
)
db.add(intent)
await db.commit()
await db.refresh(intent)
return intent
@router.get("", response_model=IntentListResponse)
async def list_intents(
skip: int = 0,
limit: int = 20,
status_filter: Optional[str] = None,
db: AsyncSession = Depends(get_db)
):
"""获取意图列表"""
query = select(Intent)
if status_filter:
query = query.where(Intent.status == status_filter)
query = query.order_by(Intent.updated_at.desc())
# 获取总数
count_query = select(Intent)
if status_filter:
count_query = count_query.where(Intent.status == status_filter)
result = await db.execute(count_query)
total = len(result.scalars().all())
# 获取分页数据
query = query.offset(skip).limit(limit)
result = await db.execute(query)
items = result.scalars().all()
return IntentListResponse(total=total, items=items)
@router.get("/{intent_id}", response_model=IntentResponse)
async def get_intent(
intent_id: int,
db: AsyncSession = Depends(get_db)
):
"""获取单个意图"""
result = await db.execute(select(Intent).where(Intent.id == intent_id))
intent = result.scalar_one_or_none()
if not intent:
raise HTTPException(status_code=404, detail="意图不存在")
return intent
@router.put("/{intent_id}", response_model=IntentResponse)
async def update_intent(
intent_id: int,
intent_data: IntentUpdate,
db: AsyncSession = Depends(get_db)
):
"""更新意图"""
result = await db.execute(select(Intent).where(Intent.id == intent_id))
intent = result.scalar_one_or_none()
if not intent:
raise HTTPException(status_code=404, detail="意图不存在")
if intent_data.title is not None:
intent.title = intent_data.title
if intent_data.content is not None:
intent.content = intent_data.content
if intent_data.status is not None:
intent.status = intent_data.status
intent.updated_at = datetime.utcnow()
await db.commit()
await db.refresh(intent)
return intent
@router.delete("/{intent_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_intent(
intent_id: int,
db: AsyncSession = Depends(get_db)
):
"""删除意图"""
result = await db.execute(select(Intent).where(Intent.id == intent_id))
intent = result.scalar_one_or_none()
if not intent:
raise HTTPException(status_code=404, detail="意图不存在")
await db.delete(intent)
await db.commit()
return None
@router.post("/parse-file")
async def parse_intent_file_endpoint(file: UploadFile = File(...)):
"""解析意图编制导入文件pdf/图片走 MinerU文本直接读取"""
if not file.filename:
raise HTTPException(status_code=400, detail="未提供文件")
suffix = Path(file.filename).suffix.lower()
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir) / file.filename
content = await file.read()
tmp_path.write_bytes(content)
try:
result = parse_intent_file(str(tmp_path))
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
return {
"filename": file.filename,
"suffix": suffix,
"title": result.get("title", ""),
"content": result.get("content", ""),
"raw_result": result.get("raw_result"),
}

38
backend/app/database.py Normal file
View File

@@ -0,0 +1,38 @@
"""
数据库配置和连接管理
"""
import os
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import declarative_base
from dotenv import load_dotenv
load_dotenv()
DATABASE_URL = os.getenv(
"DATABASE_URL",
"postgresql+asyncpg://postgres:postgres123@localhost:5432/flexible_test_platform"
)
engine = create_async_engine(
DATABASE_URL,
echo=True, # 开发环境打印 SQL
pool_pre_ping=True,
)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
)
Base = declarative_base()
async def get_db():
"""获取数据库会话"""
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()

68
backend/app/main.py Normal file
View File

@@ -0,0 +1,68 @@
"""
柔性敏捷智能测试体系平台 - FastAPI 后端主入口
"""
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.database import engine, Base
from app.api import intent_router, ai_router
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
# 启动时创建数据库表
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
# 关闭时清理资源
await engine.dispose()
app = FastAPI(
title="柔性敏捷智能测试体系平台",
description="Flexible Agile Intelligent Test System Platform API",
version="1.0.0",
lifespan=lifespan
)
# CORS 配置 - 允许 Electron 前端访问
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Electron 应用使用 file:// 协议
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 注册路由
app.include_router(intent_router, prefix="/api")
app.include_router(ai_router, prefix="/api")
@app.get("/")
async def root():
"""根路径"""
return {
"name": "柔性敏捷智能测试体系平台",
"version": "1.0.0",
"status": "running"
}
@app.get("/health")
async def health_check():
"""健康检查"""
return {"status": "healthy"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=8080,
reload=True
)

View File

@@ -0,0 +1,7 @@
"""
Models package initialization
"""
from app.models.intent import Intent, AICheckRecord
__all__ = ["Intent", "AICheckRecord"]

View File

@@ -0,0 +1,38 @@
"""
意图编制数据模型
"""
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON
from sqlalchemy.orm import relationship
from app.database import Base
class Intent(Base):
"""意图编制表"""
__tablename__ = "intents"
id = Column(Integer, primary_key=True, index=True)
title = Column(String(255), nullable=False, comment="意图标题")
content = Column(Text, nullable=True, comment="意图内容")
status = Column(String(50), default="draft", comment="状态: draft, submitted, approved")
created_at = Column(DateTime, default=datetime.utcnow, comment="创建时间")
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, comment="更新时间")
# 关联 AI 检查记录
check_records = relationship("AICheckRecord", back_populates="intent", cascade="all, delete-orphan")
class AICheckRecord(Base):
"""AI 检查记录表"""
__tablename__ = "ai_check_records"
id = Column(Integer, primary_key=True, index=True)
intent_id = Column(Integer, ForeignKey("intents.id"), nullable=False)
check_result = Column(JSON, nullable=True, comment="检查结果JSON")
suggestions = Column(Text, nullable=True, comment="AI建议")
checked_at = Column(DateTime, default=datetime.utcnow, comment="检查时间")
# 关联意图
intent = relationship("Intent", back_populates="check_records")

View File

@@ -0,0 +1,19 @@
"""
Services package initialization
"""
from app.services.ai_service import (
AIProvider,
OpenAICompatibleProvider,
LocalModelProvider,
AIServiceFactory,
get_ai_service
)
__all__ = [
"AIProvider",
"OpenAICompatibleProvider",
"LocalModelProvider",
"AIServiceFactory",
"get_ai_service"
]

View File

@@ -0,0 +1,263 @@
"""
AI 服务抽象层 - 支持通义千问和本地模型
"""
import os
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any
import yaml
from pathlib import Path
# 配置文件路径
CONFIG_PATH = Path(__file__).parent.parent.parent / "config.yaml"
def load_config() -> Dict[str, Any]:
"""加载配置文件"""
if CONFIG_PATH.exists():
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
return yaml.safe_load(f) or {}
return {}
class AIProvider(ABC):
"""AI 服务提供者抽象基类"""
@abstractmethod
async def generate(self, prompt: str, context: Optional[str] = None) -> str:
"""生成内容
Args:
prompt: 用户提示词
context: 可选的上下文信息
Returns:
生成的文本内容
"""
pass
@abstractmethod
async def check(self, content: str, requirements: Optional[list] = None) -> Dict[str, Any]:
"""检查内容是否包含必要信息
Args:
content: 要检查的内容
requirements: 可选的检查要求列表
Returns:
检查结果字典,包含 passed, issues, suggestions 等字段
"""
pass
class OpenAICompatibleProvider(AIProvider):
"""OpenAI 兼容接口实现 - 支持通义千问、DeepSeek 等"""
def __init__(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
model: str = "qwen-plus"
):
self.api_key = api_key or os.getenv("AI_API_KEY", "")
self.base_url = base_url or os.getenv("AI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
self.model = model
# 初始化 OpenAI 客户端
from openai import AsyncOpenAI
self.client = AsyncOpenAI(
api_key=self.api_key,
base_url=self.base_url
)
async def generate(self, prompt: str, context: Optional[str] = None) -> str:
"""使用 OpenAI 兼容接口生成内容"""
messages = []
if context:
messages.append({"role": "system", "content": context})
messages.append({"role": "user", "content": prompt})
try:
response = await self.client.chat.completions.create(
model=self.model,
messages=messages
)
return response.choices[0].message.content
except Exception as e:
return f"调用API出错: {str(e)}"
async def check(self, content: str, requirements: Optional[list] = None) -> Dict[str, Any]:
"""使用 AI 检查内容"""
check_prompt = f"""请检查以下意图编制内容是否完整,是否包含必要的信息。
要检查的内容:
{content}
请检查以下方面:
1. 测试目标是否明确
2. 测试范围是否清晰
3. 测试条件是否完整
4. 预期结果是否明确
5. 是否有遗漏的关键信息
请以JSON格式返回检查结果
{{
"passed": true/false,
"score": 0-100,
"issues": ["问题1", "问题2"],
"suggestions": ["建议1", "建议2"]
}}
"""
try:
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": check_prompt}]
)
result_text = response.choices[0].message.content
# 尝试解析JSON
import json
try:
# 提取JSON部分
start = result_text.find('{')
end = result_text.rfind('}') + 1
if start != -1 and end > start:
return json.loads(result_text[start:end])
except:
pass
return {
"passed": False,
"score": 0,
"issues": ["无法解析AI返回结果"],
"suggestions": [],
"raw_response": result_text
}
except Exception as e:
return {
"passed": False,
"score": 0,
"issues": [f"调用出错: {str(e)}"],
"suggestions": []
}
class LocalModelProvider(AIProvider):
"""本地模型实现 - 兼容 OpenAI API 格式"""
def __init__(self, endpoint: str = "http://localhost:8000", model: str = "llama3", api_key: str = ""):
self.endpoint = endpoint.rstrip('/')
self.model = model
self.api_key = api_key or os.getenv("LOCAL_MODEL_API_KEY", "not-needed")
async def generate(self, prompt: str, context: Optional[str] = None) -> str:
"""使用本地模型生成内容"""
import aiohttp
messages = []
if context:
messages.append({"role": "system", "content": context})
messages.append({"role": "user", "content": prompt})
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
payload = {
"model": self.model,
"messages": messages,
"stream": False
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.endpoint}/v1/chat/completions",
headers=headers,
json=payload
) as response:
if response.status == 200:
data = await response.json()
return data["choices"][0]["message"]["content"]
else:
return f"生成失败: HTTP {response.status}"
except Exception as e:
return f"调用本地模型出错: {str(e)}"
async def check(self, content: str, requirements: Optional[list] = None) -> Dict[str, Any]:
"""使用本地模型检查内容"""
check_prompt = f"""请检查以下意图编制内容是否完整。
内容:
{content}
请以JSON格式返回{{"passed": bool, "score": int, "issues": [], "suggestions": []}}"""
result = await self.generate(check_prompt)
try:
import json
start = result.find('{')
end = result.rfind('}') + 1
if start != -1 and end > start:
return json.loads(result[start:end])
except:
pass
return {
"passed": False,
"score": 0,
"issues": ["无法解析结果"],
"suggestions": [],
"raw_response": result
}
class AIServiceFactory:
"""AI 服务工厂 - 根据配置创建对应的 Provider"""
_instance: Optional[AIProvider] = None
@classmethod
def get_provider(cls) -> AIProvider:
"""获取 AI Provider 单例"""
if cls._instance is None:
cls._instance = cls._create_provider()
return cls._instance
@classmethod
def _create_provider(cls) -> AIProvider:
"""根据配置创建 Provider"""
config = load_config()
ai_config = config.get("ai", {})
# 优先从环境变量读取
api_key = os.getenv("AI_API_KEY", "")
base_url = os.getenv("AI_BASE_URL", "")
model = os.getenv("AI_MODEL", "")
# 如果环境变量未设置,从配置文件读取
if not api_key:
api_key = ai_config.get("api_key", "")
if not base_url:
base_url = ai_config.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")
if not model:
model = ai_config.get("model", "qwen-plus")
return OpenAICompatibleProvider(
api_key=api_key,
base_url=base_url,
model=model
)
@classmethod
def reset(cls):
"""重置单例,用于切换 Provider"""
cls._instance = None
# 便捷函数
def get_ai_service() -> AIProvider:
"""获取 AI 服务实例"""
return AIServiceFactory.get_provider()