import time
import json
import hashlib
from pathlib import Path
from typing import Optional

from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings

from planner.mineru_client import extract_texts_with_mineru

class LocalKnowledgeBase:
    def __init__(
        self,
        knowledge_dir: str = r"planner\knowledge",
        db_path: str = r"planner\vector_db",
        mineru_save_dir: str = r"planner\mineru_result",
    ):
        """
        Initialize the local knowledge base.

        :param knowledge_dir: directory containing the knowledge files
        :param db_path: storage path for the vector database
        :param mineru_save_dir: output directory for MinerU parse results
        """
        self.knowledge_dir = Path(knowledge_dir)
        self.db_path = Path(db_path)
        self.mineru_save_dir = Path(mineru_save_dir)

        self.db_path.mkdir(parents=True, exist_ok=True)
        self.mineru_save_dir.mkdir(parents=True, exist_ok=True)

        # Embedding model. Note: all-MiniLM-L6-v2 is lightweight but
        # English-focused; for a mostly-Chinese corpus a multilingual
        # sentence-transformers model may retrieve more reliably.
        self.embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

        self.vectorstore: Optional[FAISS] = None
    def _fingerprint_knowledge_dir(self) -> str:
        """
        Compute a fingerprint of the knowledge directory (file name + size +
        mtime), used to decide whether MinerU re-parsing or an index rebuild
        is needed.
        """
        files = sorted(self.knowledge_dir.rglob("*"))
        fingerprint_parts = []
        for f in files:
            if f.is_file():
                stat = f.stat()
                fingerprint_parts.append(f"{f.name}:{stat.st_size}:{stat.st_mtime_ns}")
        # Use a stable digest rather than the built-in hash(), which is salted
        # per process (PYTHONHASHSEED) and would force a rebuild on every run.
        return hashlib.sha256("|".join(fingerprint_parts).encode("utf-8")).hexdigest()
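    # For illustration, a single fingerprint entry for a hypothetical file
    # "guide.pdf" of 1234 bytes would be "guide.pdf:1234:1714550000000000000"
    # (name:size:mtime_ns); the digest of all entries joined by "|" becomes
    # the directory fingerprint.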
    def _meta_path(self) -> Path:
        return self.db_path / "meta.json"
    def build_or_load_db(self, force_rebuild: bool = False) -> FAISS:
        """
        Parse the knowledge files and build the vector database.
        If a database already exists and the knowledge files are unchanged,
        it is loaded directly instead of being rebuilt.
        """
        index_file = self.db_path / "index.faiss"
        meta_path = self._meta_path()

        # Decide whether a rebuild is needed
        fingerprint = self._fingerprint_knowledge_dir()
        if not force_rebuild and index_file.exists() and meta_path.exists():
            try:
                meta = json.loads(meta_path.read_text(encoding="utf-8"))
                if meta.get("fingerprint") == fingerprint:
                    print("[INFO] Knowledge base unchanged, loading existing vector database...")
                    # load_local unpickles the stored docstore, so this flag is
                    # only safe because the index was built locally by this class.
                    self.vectorstore = FAISS.load_local(
                        str(self.db_path),
                        self.embedding_model,
                        allow_dangerous_deserialization=True,
                    )
                    print("[INFO] Vector database loaded.")
                    return self.vectorstore
            except Exception as e:
                print(f"[WARN] Failed to read meta.json, rebuilding: {e}")
print("[INFO] 正在调用 MinerU API 解析知识库文件...")
|
|
parsed_text_path = extract_texts_with_mineru(str(self.knowledge_dir), save_dir=str(self.mineru_save_dir))
|
|
|
|
if not parsed_text_path or not Path(parsed_text_path).exists():
|
|
raise FileNotFoundError("[INFO] MinerU 解析失败,未生成知识文本。")
|
|
|
|
print(f"[INFO] MinerU 输出: {parsed_text_path}")
|
|
all_text = Path(parsed_text_path).read_text(encoding="utf-8", errors="ignore")
|
|
|
|
print("[INFO] 正在切分知识文本...")
|
|
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=150)
|
|
chunks = splitter.split_text(all_text)
|
|
print(f"[INFO] 已切分为 {len(chunks)} 段。")
|
|
|
|
print("[INFO] 正在生成嵌入向量并构建数据库...")
|
|
|
|
|
|
docs = [Document(
|
|
page_content=c,
|
|
metadata={"source": "mineru_all_knowledge"}
|
|
) for c in chunks]
|
|
|
|
self.vectorstore = FAISS.from_documents(docs, self.embedding_model)
|
|
self.vectorstore.save_local(str(self.db_path))
|
|
|
|
|
|
# 保存 meta 信息
|
|
meta = {
|
|
"fingerprint": fingerprint,
|
|
"built_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
"num_chunks": len(chunks),
|
|
"mineru_text": str(parsed_text_path),
|
|
}
|
|
meta_path.write_text(json.dumps(meta, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
|
|
print(f"[INFO] 知识库向量数据库构建完成,共 {len(chunks)} 段。")
|
|
return self.vectorstore
|
|
|
|
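    # Illustrative shape of the meta.json written above (keys match the `meta`
    # dict; the values shown here are hypothetical):
    # {
    #   "fingerprint": "9f2c...e41a",
    #   "built_at": "2024-05-01 12:00:00",
    #   "num_chunks": 42,
    #   "mineru_text": "planner\\mineru_result\\all_knowledge.txt"
    # }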
    def retrieve_context(self, query: str, top_k: int = 4) -> str:
        """
        Retrieve the context most relevant to `query` from the knowledge base.
        """
        if self.vectorstore is None:
            self.build_or_load_db()

        docs = self.vectorstore.similarity_search(query, k=top_k)
        context = "\n\n".join([d.page_content for d in docs])
        print(f"[INFO] Recalled {len(docs)} relevant knowledge chunks.")
        return context
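
# Minimal usage sketch: assumes knowledge files exist under the default
# directories and that the MinerU endpoint used by planner.mineru_client is
# reachable; the query string is a placeholder.
if __name__ == "__main__":
    kb = LocalKnowledgeBase()
    kb.build_or_load_db()  # parses and embeds on the first run, loads the cache afterwards
    context = kb.retrieve_context("example query about the planner", top_k=4)
    print(context)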