Complete Source Code for Building a RAG Knowledge Base
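The script below implements the full pipeline in five steps: collecting reviewed document versions, cleaning PDFs (body text, tables, and OCR on embedded images), chunking by document type, embedding the chunks into a persistent ChromaDB collection, and verifying retrieval quality with a test query.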
# -*- coding: utf-8 -*-
# File: rag_knowledge_base.py
# Complete source for building a RAG knowledge base
import os
import re
import shutil
import chromadb
from datetime import datetime
from typing import List, Dict
import pdfplumber
from langchain.text_splitter import RecursiveCharacterTextSplitter
from text2vec import SentenceModel
from paddleocr import PaddleOCR
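# Assumed third-party dependencies (the original does not pin versions;
# paddleocr additionally requires the paddlepaddle runtime):
#   pip install chromadb pdfplumber langchain text2vec paddleocr paddlepaddle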
class KnowledgeBaseBuilder:
    def __init__(self):
        # Initialize the OCR engine, embedding model, and vector store client
        self.ocr = PaddleOCR(use_angle_cls=True, lang="ch")
        self.vector_model = SentenceModel("shibing624/text2vec-base-chinese")
        self.chroma_client = chromadb.PersistentClient(path="./rag_db")
    def collect_documents(self, source_dir: str, target_dir: str) -> None:
        """Step 1: automatically collect valid documents."""
        os.makedirs(target_dir, exist_ok=True)
        # Only accept review-approved filenames at version V2.3+ or V3.x,
        # e.g. "支付_V2.3_评审通过.pdf"
        version_pattern = re.compile(r"V(2\.[3-9]|3\.\d+)_.*评审通过")
        for filename in os.listdir(source_dir):
            file_path = os.path.join(source_dir, filename)
            if filename.endswith(".pdf") and version_pattern.search(filename):
                # Copy valid documents into the target directory
                shutil.copy(file_path, os.path.join(target_dir, filename))
                print(f"Collected valid document: {filename}")
    def clean_document(self, file_path: str) -> str:
        """Step 2: clean a document (text, tables, and images in one pass)."""
        text = ""
        if file_path.endswith(".pdf"):
            with pdfplumber.open(file_path) as pdf:
                for page_num, page in enumerate(pdf.pages):
                    # Page text (extract_text can return None on empty pages)
                    text += page.extract_text() or ""
                    # Tables, flattened into pipe-delimited rows
                    for table in page.extract_tables():
                        text += "\n表格内容:\n"
                        for row in table:
                            text += "|".join(str(cell) for cell in row) + "\n"
                    # OCR on embedded images; raw stream bytes work for common
                    # JPEG/PNG streams, other encodings may need decoding first.
                    # PaddleOCR returns None when no text is found, so guard it.
                    for img in page.images:
                        img_text = self.ocr.ocr(img["stream"].get_data())[0]
                        if not img_text:
                            continue
                        text += f"\n图片{page_num+1}-{img['name']}识别结果:\n"
                        text += "\n".join([line[1][0] for line in img_text])
        # Strip sensitive markers ("confidential", "internal material")
        text = re.sub(r"机密|内部资料", "", text)
        return text
    def chunk_text(self, text: str, doc_type: str) -> List[Dict]:
        """Step 3: chunk text with a per-document-type strategy."""
        # Chunking strategy per document type
        # (requirements docs, API docs, test cases)
        chunk_config = {
            "需求文档": {"size": 256, "separators": ["\n\n", "。", "!", "?"]},
            "API文档": {"size": 512, "separators": ["\n\n", "/api/"]},
            "测试用例": {"size": 200, "separators": ["测试场景:", "预期结果:"]}
        }
        # Fall back to the requirements-document strategy for unknown types
        config = chunk_config.get(doc_type, chunk_config["需求文档"])
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=config["size"],
            separators=config["separators"]
        )
        chunks = splitter.split_text(text)
        return [{
            "content": chunk,
            "metadata": {
                "doc_type": doc_type,
                "chunk_size": len(chunk),
                "process_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }
        } for chunk in chunks]
    def vectorize_and_store(self, chunks: List[Dict], collection_name: str) -> None:
        """Step 4: vectorize chunks and store them in ChromaDB."""
        # Use cosine distance so that 1 - distance in verify_knowledge_base is
        # a proper similarity score; get_or_create avoids failing on reruns
        collection = self.chroma_client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}
        )
        documents = []
        metadatas = []
        embeddings = []
        for idx, chunk in enumerate(chunks):
            # Attach business metadata (module, priority) to each chunk
            metadata = chunk["metadata"]
            metadata.update({
                "module": self.detect_module(chunk["content"]),
                "priority": self.detect_priority(chunk["content"])
            })
            # Encode the chunk into a dense vector
            embedding = self.vector_model.encode(chunk["content"])
            documents.append(chunk["content"])
            metadatas.append(metadata)
            embeddings.append(embedding.tolist())  # ChromaDB expects plain lists
            if (idx + 1) % 10 == 0:
                print(f"Processed {idx + 1}/{len(chunks)} chunks")
        # Batch-insert into ChromaDB; offset the IDs by the current collection
        # size so repeated runs don't collide
        start = collection.count()
        collection.add(
            documents=documents,
            metadatas=metadatas,
            embeddings=embeddings,
            ids=[str(start + i) for i in range(len(documents))]
        )
    def verify_knowledge_base(self, collection_name: str, query: str) -> Dict:
        """Step 5: verify the knowledge base with a test query."""
        collection = self.chroma_client.get_collection(collection_name)
        # Embed the query with the same model used at indexing time; passing
        # query_texts would use Chroma's default embedder, whose vectors would
        # not match the stored text2vec embeddings
        results = collection.query(
            query_embeddings=[self.vector_model.encode(query).tolist()],
            n_results=3,
            include=["documents", "metadatas", "distances"]
        )
        return {
            "query": query,
            "results": [
                {
                    "content": results["documents"][0][i],
                    "metadata": results["metadatas"][0][i],
                    # Cosine distance -> similarity (collection uses cosine)
                    "score": 1 - results["distances"][0][i]
                }
                for i in range(len(results["documents"][0]))
            ]
        }
    # ---------- Helper functions ----------
    def detect_module(self, text: str) -> str:
        """Detect the functional module from keywords."""
        modules = ["登录", "支付", "订单", "用户"]  # login, payment, order, user
        for module in modules:
            if module in text:
                return module
        return "其他"  # other
    def detect_priority(self, text: str) -> str:
        """Detect the priority from markers in the text."""
        if "P0" in text:
            return "P0"
        elif "关键路径" in text:  # "critical path"
            return "P1"
        return "P2"
# ----------------- Usage example -----------------
if __name__ == "__main__":
    builder = KnowledgeBaseBuilder()
    # Step 1: collect documents
    builder.collect_documents(
        source_dir="./原始文档",
        target_dir="./有效知识库"
    )
    # Step 2: clean a document
    sample_doc = "./有效知识库/支付_V2.3_评审通过.pdf"
    cleaned_text = builder.clean_document(sample_doc)
    # Step 3: chunk the text
    chunks = builder.chunk_text(cleaned_text, doc_type="需求文档")
    # Step 4: vectorize and store
    builder.vectorize_and_store(
        chunks=chunks,
        collection_name="payment_module"
    )
    # Step 5: verify retrieval
    test_query = "如何测试支付超时场景?"
    results = builder.verify_knowledge_base("payment_module", test_query)
    print("\nVerification results:")
    for idx, result in enumerate(results["results"]):
        print(f"\nResult {idx + 1} (similarity: {result['score']:.2f}):")
        print(f"Module: {result['metadata']['module']}")
        print(f"Snippet: {result['content'][:100]}...")