Initial commit
This commit is contained in:
130
backend/planner/planning_agent/rag_pipeline.py
Normal file
130
backend/planner/planning_agent/rag_pipeline.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import hashlib
import json
import os
import time
from pathlib import Path
from typing import Optional

from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

from planner.mineru_client import extract_texts_with_mineru
|
||||
|
||||
class LocalKnowledgeBase:
    """Local RAG knowledge base.

    Parses the documents in ``knowledge_dir`` with MinerU, splits the
    extracted text into overlapping chunks, embeds them, and persists a
    FAISS index under ``db_path``. A fingerprint of the knowledge
    directory is stored alongside the index so unchanged knowledge is
    loaded from disk instead of being re-parsed and re-embedded.
    """

    def __init__(
        self,
        knowledge_dir: str = "planner/knowledge",
        db_path: str = "planner/vector_db",
        mineru_save_dir: str = "planner/mineru_result",
    ):
        """
        Initialize the local knowledge base.

        :param knowledge_dir: directory containing the source knowledge files
        :param db_path: directory where the FAISS index + meta.json are stored
        :param mineru_save_dir: directory for MinerU parsing output
        """
        # BUGFIX: the original defaults used backslash literals
        # ("planner\\knowledge", "planner\\mineru_result") whose "\k"/"\m"
        # are invalid escape sequences (DeprecationWarning on 3.12+) and
        # Windows-only. pathlib accepts "/" on every platform.
        self.knowledge_dir = Path(knowledge_dir)
        self.db_path = Path(db_path)
        self.mineru_save_dir = Path(mineru_save_dir)

        self.db_path.mkdir(parents=True, exist_ok=True)
        self.mineru_save_dir.mkdir(parents=True, exist_ok=True)

        # Lightweight sentence-transformer embedding model.
        # NOTE(review): the original comment claimed Chinese support, but
        # all-MiniLM-L6-v2 is English-centric — confirm the model choice
        # if the knowledge base is primarily Chinese.
        self.embedding_model = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Built/loaded lazily by build_or_load_db().
        self.vectorstore: Optional[FAISS] = None

    def _fingerprint_knowledge_dir(self) -> str:
        """
        Compute a stable fingerprint of the knowledge directory
        (relative path + size + mtime_ns of every file), used to decide
        whether MinerU parsing and index building must be redone.

        :return: hex SHA-256 digest of the directory state
        """
        parts = []
        for f in sorted(self.knowledge_dir.rglob("*")):
            if f.is_file():
                st = f.stat()
                # Use the relative path, not just the file name, so a file
                # moved between subdirectories changes the fingerprint.
                rel = f.relative_to(self.knowledge_dir).as_posix()
                parts.append(f"{rel}:{st.st_size}:{st.st_mtime_ns}")
        # BUGFIX: the original returned str(hash(...)). The builtin hash()
        # of a str is salted per process (PYTHONHASHSEED), so the stored
        # fingerprint never matched on the next run and the cached index
        # was always rebuilt. SHA-256 is stable across processes.
        return hashlib.sha256("|".join(parts).encode("utf-8")).hexdigest()

    def _meta_path(self) -> Path:
        """Path of the JSON metadata file stored next to the FAISS index."""
        return self.db_path / "meta.json"

    def build_or_load_db(self, force_rebuild: bool = False) -> FAISS:
        """
        Parse the knowledge files and build the vector database, or load
        the existing index when the knowledge directory is unchanged.

        :param force_rebuild: rebuild even if a valid cached index exists
        :return: the loaded or freshly built FAISS vector store
        :raises FileNotFoundError: if MinerU produces no output text file
        """
        index_file = self.db_path / "index.faiss"
        meta_path = self._meta_path()

        # Reuse the cached index only when it exists and its recorded
        # fingerprint matches the current state of the knowledge dir.
        fingerprint = self._fingerprint_knowledge_dir()
        if not force_rebuild and index_file.exists() and meta_path.exists():
            try:
                meta = json.loads(meta_path.read_text(encoding="utf-8"))
                if meta.get("fingerprint") == fingerprint:
                    print("[INFO] 检测到知识库未变,直接加载向量数据库...")
                    self.vectorstore = FAISS.load_local(
                        str(self.db_path),
                        self.embedding_model,
                        # Required by LangChain to unpickle the docstore;
                        # acceptable here because we only load files this
                        # class wrote itself.
                        allow_dangerous_deserialization=True,
                    )
                    print("[INFO] 向量数据库加载完成。")
                    return self.vectorstore
            except Exception as e:
                # A corrupt meta.json is not fatal — fall through to rebuild.
                print(f"[INFO] 读取 meta.json 出错,将重新构建:{e}")

        print("[INFO] 正在调用 MinerU API 解析知识库文件...")
        parsed_text_path = extract_texts_with_mineru(
            str(self.knowledge_dir), save_dir=str(self.mineru_save_dir)
        )

        if not parsed_text_path or not Path(parsed_text_path).exists():
            raise FileNotFoundError("[INFO] MinerU 解析失败,未生成知识文本。")

        print(f"[INFO] MinerU 输出: {parsed_text_path}")
        all_text = Path(parsed_text_path).read_text(encoding="utf-8", errors="ignore")

        print("[INFO] 正在切分知识文本...")
        splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=150)
        chunks = splitter.split_text(all_text)
        print(f"[INFO] 已切分为 {len(chunks)} 段。")

        print("[INFO] 正在生成嵌入向量并构建数据库...")
        docs = [
            Document(page_content=c, metadata={"source": "mineru_all_knowledge"})
            for c in chunks
        ]

        self.vectorstore = FAISS.from_documents(docs, self.embedding_model)
        self.vectorstore.save_local(str(self.db_path))

        # Persist build metadata so the next run can detect staleness.
        meta = {
            "fingerprint": fingerprint,
            "built_at": time.strftime("%Y-%m-%d %H:%M:%S"),
            "num_chunks": len(chunks),
            "mineru_text": str(parsed_text_path),
        }
        meta_path.write_text(
            json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8"
        )

        print(f"[INFO] 知识库向量数据库构建完成,共 {len(chunks)} 段。")
        return self.vectorstore

    def retrieve_context(self, query: str, top_k: int = 4) -> str:
        """
        Retrieve the chunks most relevant to *query* from the knowledge base.

        :param query: natural-language search query
        :param top_k: number of chunks to retrieve
        :return: the retrieved chunks joined by blank lines
        """
        if self.vectorstore is None:
            # Lazy initialization: build or load the index on first use.
            self.build_or_load_db()

        docs = self.vectorstore.similarity_search(query, k=top_k)
        context = "\n\n".join(d.page_content for d in docs)
        print(f"[INFO] 已召回 {len(docs)} 条相关知识。")
        return context
|
||||
Reference in New Issue
Block a user