Complete Source Code for Building a RAG Knowledge Base

Date: 2025-02-22 21:30:31
The complete, annotated implementation follows. It relies on chromadb for vector storage, pdfplumber for PDF parsing, PaddleOCR for text in images, text2vec for Chinese embeddings, and LangChain's text splitter for chunking.

```python
# -*- coding: utf-8 -*-
# File: rag_knowledge_base.py
# Complete source for building a RAG knowledge base (annotated)

import os
import re
import shutil
from datetime import datetime
from typing import Dict, List

import chromadb
import pdfplumber
from langchain.text_splitter import RecursiveCharacterTextSplitter
from paddleocr import PaddleOCR
from text2vec import SentenceModel


class KnowledgeBaseBuilder:
    def __init__(self):
        # Initialize models and tools
        self.ocr = PaddleOCR(use_angle_cls=True, lang="ch")
        self.vector_model = SentenceModel("shibing624/text2vec-base-chinese")
        self.chroma_client = chromadb.PersistentClient(path="./rag_db")

    def collect_documents(self, source_dir: str, target_dir: str) -> None:
        """Step 1: automatically collect valid documents."""
        os.makedirs(target_dir, exist_ok=True)
        # Valid = version V2.3 or above, marked 评审通过 (review passed)
        version_pattern = re.compile(r"V(2\.[3-9]|3\.\d+)_.*评审通过")
        for filename in os.listdir(source_dir):
            file_path = os.path.join(source_dir, filename)
            if filename.endswith(".pdf") and version_pattern.search(filename):
                # Copy valid documents into the target directory
                shutil.copy(file_path, os.path.join(target_dir, filename))
                print(f"Collected valid document: {filename}")

    def clean_document(self, file_path: str) -> str:
        """Step 2: clean the document (text, tables, OCR on images)."""
        text = ""
        if file_path.endswith(".pdf"):
            # One pass per page covers text, tables and images, and keeps
            # each page's content together
            with pdfplumber.open(file_path) as pdf:
                for page_num, page in enumerate(pdf.pages):
                    # Plain text; extract_text() may return None on empty pages
                    text += page.extract_text() or ""
                    # Tables, flattened into pipe-separated rows
                    for table in page.extract_tables():
                        text += "\n表格内容:\n"
                        for row in table:
                            text += "|".join(str(cell) for cell in row) + "\n"
                    # Embedded images: OCR on the raw image stream. Raw streams
                    # decode directly only for formats such as JPEG, so guard
                    # against undecodable images and empty OCR results.
                    for img in page.images:
                        try:
                            ocr_result = self.ocr.ocr(img["stream"].get_data())
                        except Exception:
                            continue
                        if not ocr_result or not ocr_result[0]:
                            continue
                        text += f"\n图片{page_num + 1}-{img['name']}识别结果:\n"
                        text += "\n".join(line[1][0] for line in ocr_result[0])
        # Strip sensitive markers
        text = re.sub(r"机密|内部资料", "", text)
        return text

    def chunk_text(self, text: str, doc_type: str) -> List[Dict]:
        """Step 3: split text into chunks according to document type."""
        # Chunking strategy per document type
        chunk_config = {
            "需求文档": {"size": 256, "separators": ["\n\n", "。", "!", "?"]},
            "API文档": {"size": 512, "separators": ["\n\n", "/api/"]},
            "测试用例": {"size": 200, "separators": ["测试场景:", "预期结果:"]},
        }
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_config[doc_type]["size"],
            chunk_overlap=32,  # LangChain's default (200) is too large for these sizes
            separators=chunk_config[doc_type]["separators"],
        )
        chunks = splitter.split_text(text)
        return [
            {
                "content": chunk,
                "metadata": {
                    "doc_type": doc_type,
                    "chunk_size": len(chunk),
                    "process_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                },
            }
            for chunk in chunks
        ]

    def vectorize_and_store(self, chunks: List[Dict], collection_name: str) -> None:
        """Step 4: embed the chunks and store them in ChromaDB."""
        # get_or_create_collection keeps reruns from failing on an existing
        # collection; cosine space makes 1 - distance a genuine similarity
        collection = self.chroma_client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        documents, metadatas, embeddings = [], [], []
        for idx, chunk in enumerate(chunks):
            # Attach business metadata
            metadata = chunk["metadata"]
            metadata.update({
                "module": self.detect_module(chunk["content"]),
                "priority": self.detect_priority(chunk["content"]),
            })
            # Generate the embedding
            embedding = self.vector_model.encode(chunk["content"])
            documents.append(chunk["content"])
            metadatas.append(metadata)
            embeddings.append(embedding.tolist())  # ndarray -> plain list
            if (idx + 1) % 10 == 0:
                print(f"Processed {idx + 1}/{len(chunks)} chunks")
        # Bulk insert into ChromaDB
        collection.add(
            documents=documents,
            metadatas=metadatas,
            embeddings=embeddings,
            ids=[str(i) for i in range(len(documents))],
        )

    def verify_knowledge_base(self, collection_name: str, query: str) -> Dict:
        """Step 5: verify the knowledge base with a retrieval query."""
        collection = self.chroma_client.get_collection(collection_name)
        # Encode the query with the same model used at index time; querying
        # by raw text would fall back to Chroma's default embedding function
        # and return meaningless distances
        query_embedding = self.vector_model.encode(query).tolist()
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=3,
            include=["documents", "metadatas", "distances"],
        )
        return {
            "query": query,
            "results": [
                {
                    "content": results["documents"][0][i],
                    "metadata": results["metadatas"][0][i],
                    "score": 1 - results["distances"][0][i],  # cosine distance -> similarity
                }
                for i in range(len(results["documents"][0]))
            ],
        }

    # ---------- Helpers ----------
    def detect_module(self, text: str) -> str:
        """Detect which functional module a chunk belongs to."""
        modules = ["登录", "支付", "订单", "用户"]  # login, payment, order, user
        for module in modules:
            if module in text:
                return module
        return "其他"  # other

    def detect_priority(self, text: str) -> str:
        """Detect a chunk's priority."""
        if "P0" in text:
            return "P0"
        elif "关键路径" in text:  # "critical path"
            return "P1"
        return "P2"


# ----------------- Usage example -----------------
if __name__ == "__main__":
    builder = KnowledgeBaseBuilder()

    # Step 1: collect documents
    builder.collect_documents(
        source_dir="./原始文档",    # raw documents
        target_dir="./有效知识库",  # valid knowledge base
    )

    # Step 2: clean and process a document
    sample_doc = "./有效知识库/支付_V2.3_评审通过.pdf"
    cleaned_text = builder.clean_document(sample_doc)

    # Step 3: chunk the text
    chunks = builder.chunk_text(cleaned_text, doc_type="需求文档")

    # Step 4: embed and store
    builder.vectorize_and_store(
        chunks=chunks,
        collection_name="payment_module",
    )

    # Step 5: verify retrieval quality
    test_query = "如何测试支付超时场景?"  # "How do I test the payment-timeout scenario?"
    results = builder.verify_knowledge_base("payment_module", test_query)
    print("\nVerification results:")
    for idx, result in enumerate(results["results"]):
        print(f"\nResult {idx + 1} (similarity: {result['score']:.2f}):")
        print(f"Module: {result['metadata']['module']}")
        print(f"Snippet: {result['content'][:100]}...")
```