[Dify Deep Dive] Chapter 7: Knowledge Base and Vector Retrieval Implementation

Introduction

I still remember the first time I tried Dify's knowledge base feature: the retrieval speed and accuracy genuinely impressed me. Upload a PDF and it is indexed within seconds; after that, both fuzzy keyword queries and semantic searches return precise results almost instantly.

Having spent years working on search systems, I knew there had to be serious engineering behind this. In this chapter we dig into Dify's knowledge base system and examine how a production-grade vector retrieval system is designed, from document parsing to vector storage, and from retrieval algorithms to index optimization. Every stage carries real engineering insight.

1. Overall Architecture of the Knowledge Base System

1.1 Design Philosophy

Open the api/core/rag/ directory and you will see that Dify's RAG (Retrieval-Augmented Generation) system is laid out cleanly:

rag/
├── datasource/           # data source management
├── extractor/            # document extractors
├── models/               # data models
├── splitter/             # text splitters
├── embedding/            # embedding services
└── retrieval/            # retrieval services

This layout follows the single-responsibility principle: each module focuses on one specific concern, and the modules cooperate through clear interfaces.
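
To make that collaboration concrete, here is a hedged end-to-end sketch of how the modules chain together for ingestion and query. It is my own glue code, not taken from Dify; it simply wires up the classes introduced later in this chapter and assumes a `dataset` record and an API key already exist.

# hypothetical glue code; the classes are the ones discussed in this chapter
def ingest_and_query(file_path: str, file_type: str, query: str, dataset):
    extracted = ExtractProcessor().extract(file_path, file_type)               # extractor/
    chunks = FixedTextSplitter(chunk_size=1000, chunk_overlap=200).split_text(extracted.content)  # splitter/
    embedder = OpenAIEmbedding(api_key="sk-...")                               # embedding/
    vectors = embedder.embed_documents([c.content for c in chunks])
    store = VectorFactory.get_vector_database(VectorType.QDRANT, dataset)      # datasource/vdb/
    store.add_texts([c.content for c in chunks], vectors, [c.metadata for c in chunks])
    return store.similarity_search(embedder.embed_query(query), k=5)           # retrieval/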

1.2 Core Processing Flow

Let's first walk through the knowledge base's core processing flow:

# api/core/rag/datasource/vdb/vector_factory.py
class VectorFactory:
    """Factory for vector database instances"""

    @staticmethod
    def get_vector_database(vector_type: str, dataset: Dataset) -> BaseVector:
        """Return a vector database instance for the given type"""
        vector_config = dataset.vector_config

        if vector_type == VectorType.WEAVIATE:
            return WeaviateVector(
                collection_name=dataset.collection_name,
                config=WeaviateConfig(**vector_config)
            )
        elif vector_type == VectorType.QDRANT:
            return QdrantVector(
                collection_name=dataset.collection_name,
                config=QdrantConfig(**vector_config)
            )
        elif vector_type == VectorType.CHROMA:
            return ChromaVector(
                collection_name=dataset.collection_name,
                config=ChromaConfig(**vector_config)
            )

        raise ValueError(f"Unsupported vector database type: {vector_type}")

This factory class is a straightforward application of the strategy pattern: multiple vector databases are supported, and users can pick whichever one fits their needs.
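
On the calling side, nothing but the enum value changes when you switch backends. A registry dict is a common alternative to the if/elif chain; a minimal sketch of that variant (my own illustration, not Dify's actual code):

# registry-based variant of the same strategy pattern (illustrative only)
_VECTOR_BACKENDS = {
    VectorType.WEAVIATE: (WeaviateVector, WeaviateConfig),
    VectorType.QDRANT: (QdrantVector, QdrantConfig),
    VectorType.CHROMA: (ChromaVector, ChromaConfig),
}

def get_vector_database(vector_type, dataset):
    try:
        vector_cls, config_cls = _VECTOR_BACKENDS[vector_type]
    except KeyError:
        raise ValueError(f"Unsupported vector database type: {vector_type}")
    # every backend is constructed the same way, so adding one is a single registry entry
    return vector_cls(collection_name=dataset.collection_name,
                      config=config_cls(**dataset.vector_config))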

2. The Document Processing Pipeline

2.1 Document Parsers: An All-Purpose Content Extractor

Dify supports many document formats, each with a dedicated parser:

# api/core/rag/extractor/extract_processor.py
class ExtractProcessor:
    """Document extraction processor"""

    def __init__(self):
        self.extractors = {
            'pdf': PDFExtractor(),
            'docx': DocxExtractor(),
            'txt': TxtExtractor(),
            'markdown': MarkdownExtractor(),
            'html': HtmlExtractor(),
            'csv': CsvExtractor()
        }

    def extract(self, file_path: str, file_type: str) -> ExtractedContent:
        """Extract the content of a document"""
        extractor = self.extractors.get(file_type.lower())
        if not extractor:
            raise UnsupportedFileTypeError(f"Unsupported file type: {file_type}")

        try:
            # 1. Parse the document
            raw_content = extractor.extract(file_path)

            # 2. Clean the content
            cleaned_content = self._clean_content(raw_content)

            # 3. Extract metadata
            metadata = self._extract_metadata(file_path, file_type)

            return ExtractedContent(
                content=cleaned_content,
                metadata=metadata,
                file_type=file_type
            )

        except Exception as e:
            logger.error(f"Failed to extract content from {file_path}: {str(e)}")
            raise DocumentProcessingError(str(e))

    def _clean_content(self, content: str) -> str:
        """Clean extracted content"""
        # 1. Collapse redundant whitespace
        content = re.sub(r'\s+', ' ', content)

        # 2. Strip special characters (keep word characters, CJK, and basic punctuation)
        content = re.sub(r'[^\w\s\u4e00-\u9fff.,!?;:]', '', content)

        # 3. Normalize the encoding
        content = content.encode('utf-8', errors='ignore').decode('utf-8')

        return content.strip()

Design highlights

  • A factory-style registry manages the different extractor types
  • Unified error handling and logging
  • Content cleaning safeguards data quality (a usage sketch follows below)
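
A minimal usage sketch of the extractor layer (hypothetical call site; the file path is a placeholder and the exception classes are the ones used above):

# hypothetical call site for ExtractProcessor
processor = ExtractProcessor()
try:
    extracted = processor.extract("/tmp/manual.pdf", "pdf")
    print(extracted.file_type, len(extracted.content), extracted.metadata)
except UnsupportedFileTypeError:
    pass  # tell the uploader which formats are supported
except DocumentProcessingError:
    pass  # mark the document as failed so ingestion can be retried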

2.2 The PDF Parser: An Elegant Answer to a Special Challenge

PDF parsing is the most complex case; let's see how Dify handles it:

# api/core/rag/extractor/pdf_extractor.py
class PDFExtractor(BaseExtractor):
    """PDF document extractor"""

    def __init__(self):
        self.use_ocr = current_app.config.get('PDF_OCR_ENABLED', False)
        self.max_pages = current_app.config.get('PDF_MAX_PAGES', 1000)

    def extract(self, file_path: str) -> str:
        """Extract text from a PDF"""
        try:
            # Strategy 1: try to extract embedded text directly
            text = self._extract_text_directly(file_path)

            # Strategy 2: if that fails or yields too little text, fall back to OCR
            if not text or len(text.strip()) < 100:
                if self.use_ocr:
                    text = self._extract_with_ocr(file_path)
                else:
                    logger.warning(f"PDF {file_path} has little text and OCR is disabled")

            return text

        except Exception as e:
            logger.error(f"Failed to extract PDF {file_path}: {str(e)}")
            raise

    def _extract_text_directly(self, file_path: str) -> str:
        """Extract embedded text directly"""
        import PyPDF2

        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)

            if len(pdf_reader.pages) > self.max_pages:
                raise ValueError(f"PDF has too many pages: {len(pdf_reader.pages)}")

            text_parts = []
            for page_num, page in enumerate(pdf_reader.pages):
                try:
                    page_text = page.extract_text()
                    if page_text:
                        text_parts.append(f"[Page {page_num + 1}]\n{page_text}")
                except Exception as e:
                    logger.warning(f"Failed to extract page {page_num}: {str(e)}")
                    continue

            return '\n\n'.join(text_parts)

    def _extract_with_ocr(self, file_path: str) -> str:
        """Extract text via OCR"""
        import pdf2image
        import pytesseract

        # 1. Render PDF pages to images
        images = pdf2image.convert_from_path(file_path)

        # 2. Run OCR on each page
        text_parts = []
        for i, image in enumerate(images):
            if i >= self.max_pages:
                break

            try:
                # Preprocess the image to improve OCR accuracy
                processed_image = self._preprocess_image(image)

                # OCR recognition
                page_text = pytesseract.image_to_string(
                    processed_image,
                    lang='chi_sim+eng'  # Simplified Chinese + English
                )

                if page_text.strip():
                    text_parts.append(f"[Page {i + 1}]\n{page_text}")

            except Exception as e:
                logger.warning(f"OCR failed for page {i}: {str(e)}")
                continue

        return '\n\n'.join(text_parts)

    def _preprocess_image(self, image):
        """Preprocess an image for OCR"""
        import cv2
        import numpy as np

        # PIL image -> OpenCV array
        img_array = np.array(image)

        # Grayscale
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)

        # Denoise
        denoised = cv2.fastNlMeansDenoising(gray)

        # Binarize (Otsu threshold)
        _, binary = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        return binary

Engineering wisdom

  • Progressive strategy: try the simple method first and only fall back to the expensive one when it fails (generalized in the sketch below)
  • Resource protection: cap the maximum page count to avoid exhausting resources
  • Image preprocessing: improves OCR accuracy
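
The progressive strategy generalizes beyond PDFs. A small self-contained sketch of the "try the cheap extractor first, fall back to the expensive one" idea (my own illustration, not Dify code):

from typing import Callable, List, Optional

def extract_with_fallback(strategies: List[Callable[[str], str]],
                          file_path: str,
                          min_chars: int = 100) -> Optional[str]:
    """Try extraction strategies in order of cost; accept the first result with enough text."""
    for extract in strategies:
        try:
            text = extract(file_path)
        except Exception:
            continue  # a failed strategy simply hands over to the next one
        if text and len(text.strip()) >= min_chars:
            return text
    return None

# usage (hypothetical): extract_with_fallback([direct_pdf_text, ocr_pdf_text], "/tmp/scan.pdf")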

3. Text Splitting Strategies

3.1 The Smart Chunking Algorithm

Text splitting is a critical step in a RAG system and directly affects retrieval quality:

# api/core/rag/splitter/fixed_text_splitter.py
class FixedTextSplitter(BaseTextSplitter):
    """Fixed-length text splitter"""

    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # Sentence-boundary regex, covering Chinese and English punctuation
        self.sentence_endings = re.compile(r'[.!?。!?][\s]*')

    def split_text(self, text: str) -> List[TextChunk]:
        """Split text into chunks"""
        if not text or not text.strip():
            return []

        # 1. Preprocess
        text = self._preprocess_text(text)

        # 2. Split into sentences
        sentences = self._split_sentences(text)

        # 3. Pack sentences into chunks
        chunks = self._combine_sentences_to_chunks(sentences)

        return chunks

    def _split_sentences(self, text: str) -> List[str]:
        """Split text into sentences"""
        # Split on the sentence-ending regex
        sentences = self.sentence_endings.split(text)

        # Drop empty sentences
        sentences = [s.strip() for s in sentences if s.strip()]

        return sentences

    def _combine_sentences_to_chunks(self, sentences: List[str]) -> List[TextChunk]:
        """Pack sentences into chunks"""
        chunks = []
        current_chunk = ""
        current_length = 0

        for sentence in sentences:
            sentence_length = len(sentence)

            # If adding this sentence would exceed the chunk size limit
            if current_length + sentence_length > self.chunk_size and current_chunk:
                # Flush the current chunk
                chunks.append(TextChunk(
                    content=current_chunk.strip(),
                    metadata={'chunk_index': len(chunks)}
                ))

                # Handle overlap
                if self.chunk_overlap > 0:
                    overlap_text = self._get_overlap_text(current_chunk, self.chunk_overlap)
                    current_chunk = overlap_text + " " + sentence
                    current_length = len(current_chunk)
                else:
                    current_chunk = sentence
                    current_length = sentence_length
            else:
                # Append to the current chunk
                if current_chunk:
                    current_chunk += " " + sentence
                else:
                    current_chunk = sentence
                current_length += sentence_length + 1  # +1 for the joining space

        # Flush the last chunk
        if current_chunk.strip():
            chunks.append(TextChunk(
                content=current_chunk.strip(),
                metadata={'chunk_index': len(chunks)}
            ))

        return chunks

    def _get_overlap_text(self, text: str, overlap_size: int) -> str:
        """Take the trailing overlap from a chunk"""
        if len(text) <= overlap_size:
            return text

        # Try to cut at sentence boundaries
        sentences = self._split_sentences(text)
        overlap_text = ""

        for sentence in reversed(sentences):
            if len(overlap_text) + len(sentence) <= overlap_size:
                overlap_text = sentence + " " + overlap_text
            else:
                break

        return overlap_text.strip()
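
A quick usage sketch (assuming the splitter above is importable; the numbers are illustrative). Smaller chunks give finer retrieval granularity but less context per hit, while the overlap keeps sentences near chunk boundaries retrievable from either side:

# hypothetical call site for FixedTextSplitter
splitter = FixedTextSplitter(chunk_size=500, chunk_overlap=100)
chunks = splitter.split_text(long_document_text)  # long_document_text: any extracted document string
for chunk in chunks:
    print(chunk.metadata['chunk_index'], len(chunk.content))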

3.2 The Semantics-Aware Splitter

Besides fixed-length splitting, Dify also implements a semantics-aware splitting strategy:

# api/core/rag/splitter/semantic_text_splitter.py
class SemanticTextSplitter(BaseTextSplitter):
    """Semantics-aware text splitter"""

    def __init__(self, embedding_model, similarity_threshold: float = 0.8):
        self.embedding_model = embedding_model
        self.similarity_threshold = similarity_threshold

    def split_text(self, text: str) -> List[TextChunk]:
        """Split text based on semantic similarity"""
        # 1. Split into paragraphs
        paragraphs = text.split('\n\n')
        paragraphs = [p.strip() for p in paragraphs if p.strip()]

        if not paragraphs:
            return []

        # 2. Embed each paragraph
        paragraph_embeddings = self.embedding_model.embed(paragraphs)

        # 3. Cluster paragraphs by similarity
        clusters = self._cluster_by_similarity(paragraphs, paragraph_embeddings)

        # 4. Build text chunks from the clusters
        chunks = []
        for i, cluster in enumerate(clusters):
            content = '\n\n'.join([paragraphs[idx] for idx in cluster])
            chunks.append(TextChunk(
                content=content,
                metadata={
                    'chunk_index': i,
                    'paragraph_count': len(cluster),
                    'semantic_cluster': True
                }
            ))

        return chunks

    def _cluster_by_similarity(self, paragraphs: List[str], embeddings: List[List[float]]) -> List[List[int]]:
        """Cluster paragraphs by embedding similarity"""
        from sklearn.cluster import AgglomerativeClustering
        from sklearn.metrics.pairwise import cosine_similarity

        # Similarity matrix
        similarity_matrix = cosine_similarity(embeddings)

        # Convert to a distance matrix
        distance_matrix = 1 - similarity_matrix

        # Hierarchical clustering
        clustering = AgglomerativeClustering(
            n_clusters=None,
            distance_threshold=1 - self.similarity_threshold,
            linkage='average',
            metric='precomputed'
        )

        cluster_labels = clustering.fit_predict(distance_matrix)

        # Group paragraph indices by cluster label
        clusters = {}
        for idx, label in enumerate(cluster_labels):
            if label not in clusters:
                clusters[label] = []
            clusters[label].append(idx)

        return list(clusters.values())

Design rationale

  • Semantic coherence: similar paragraphs are grouped together
  • Dynamic size: chunks follow the natural semantics of the content rather than a fixed length
  • Compute cost: it requires embeddings, so it suits quality-sensitive scenarios (a selection helper is sketched below)
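
In practice the two splitters can sit behind a single switch. A hedged sketch of such a selection helper (the function and parameter names are my own, not Dify's):

def build_splitter(mode: str, embedding_model=None) -> BaseTextSplitter:
    """'fixed' is cheap and predictable; 'semantic' costs embeddings but keeps related paragraphs together."""
    if mode == "semantic":
        if embedding_model is None:
            raise ValueError("semantic splitting requires an embedding model")
        return SemanticTextSplitter(embedding_model, similarity_threshold=0.8)
    return FixedTextSplitter(chunk_size=1000, chunk_overlap=200)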

4. Vectorization and Storage

4.1 The Embedding Model Abstraction Layer

Dify supports multiple embedding models behind a unified abstraction layer:

# api/core/rag/embedding/embedding_base.py
class BaseEmbedding:
    """Base class for embedding models"""

    def __init__(self, model_name: str, **kwargs):
        self.model_name = model_name
        self.max_tokens = kwargs.get('max_tokens', 8192)
        self.dimensions = kwargs.get('dimensions', 1536)

    def embed(self, texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """Embed one text or a list of texts"""
        raise NotImplementedError

    def embed_query(self, text: str) -> List[float]:
        """Embed a query"""
        return self.embed(text)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents"""
        return self.embed(texts)

# api/core/rag/embedding/openai_embedding.py
class OpenAIEmbedding(BaseEmbedding):
    """OpenAI embedding model"""

    def __init__(self, api_key: str, model_name: str = "text-embedding-ada-002"):
        super().__init__(model_name)
        self.client = OpenAI(api_key=api_key)
        self.rate_limiter = RateLimiter(requests_per_minute=3000)

    def embed(self, texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
        """Embed one text or a list of texts"""
        is_single = isinstance(texts, str)
        if is_single:
            texts = [texts]

        # 1. Preprocess the texts
        processed_texts = [self._preprocess_text(text) for text in texts]

        # 2. Process in batches to stay under API limits
        batch_size = 100  # batch size recommended for the OpenAI API
        all_embeddings = []

        for i in range(0, len(processed_texts), batch_size):
            batch = processed_texts[i:i + batch_size]

            # 3. Apply rate limiting
            with self.rate_limiter:
                try:
                    response = self.client.embeddings.create(
                        model=self.model_name,
                        input=batch
                    )

                    batch_embeddings = [item.embedding for item in response.data]
                    all_embeddings.extend(batch_embeddings)

                except Exception as e:
                    logger.error(f"OpenAI embedding error: {str(e)}")
                    # Brief backoff before propagating; a full retry loop would go here
                    time.sleep(1)
                    raise

        return all_embeddings[0] if is_single else all_embeddings

    def _preprocess_text(self, text: str) -> str:
        """Preprocess a text before embedding"""
        # 1. Clean up whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # 2. Truncate overly long text
        if len(text) > self.max_tokens * 4:  # rough estimate of ~4 characters per token
            text = text[:self.max_tokens * 4]

        return text
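
Call sites only ever use the two base-class entry points, so the batching and rate limiting stay invisible. A short usage sketch (the API key and texts are placeholders):

# hypothetical call site for OpenAIEmbedding
embedder = OpenAIEmbedding(api_key="sk-...")
doc_vectors = embedder.embed_documents(["first chunk", "second chunk"])  # batched internally
query_vector = embedder.embed_query("what does this chapter cover?")
print(len(doc_vectors), len(query_vector))  # 2 document vectors, one query vector of `dimensions` floats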

4.2 Vector Database Adapters

Dify supports several vector databases, each with a dedicated adapter:

# api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
class QdrantVector(BaseVector):
    """Qdrant vector database adapter"""

    def __init__(self, collection_name: str, config: QdrantConfig):
        self.collection_name = collection_name
        self.config = config
        self.client = QdrantClient(
            host=config.host,
            port=config.port,
            api_key=config.api_key,
            timeout=config.timeout
        )

        # Make sure the collection exists
        self._ensure_collection_exists()

    def _ensure_collection_exists(self):
        """Make sure the collection exists"""
        try:
            self.client.get_collection(self.collection_name)
        except Exception:
            # The collection does not exist yet; create it
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=models.VectorParams(
                    size=self.config.vector_size,
                    distance=models.Distance.COSINE
                ),
                optimizers_config=models.OptimizersConfigDiff(
                    default_segment_number=2,
                    max_segment_size=20000,
                    memmap_threshold=20000,
                ),
                hnsw_config=models.HnswConfigDiff(
                    m=16,
                    ef_construct=100,
                    full_scan_threshold=10000,
                )
            )

    def add_texts(self, texts: List[str], embeddings: List[List[float]],
                  metadatas: Optional[List[dict]] = None, **kwargs) -> List[str]:
        """Add texts and their vectors"""
        ids = [str(uuid.uuid4()) for _ in texts]

        points = []
        for i, (text, embedding) in enumerate(zip(texts, embeddings)):
            metadata = metadatas[i] if metadatas else {}
            metadata.update({
                'text': text,
                'created_at': datetime.utcnow().isoformat()
            })

            points.append(models.PointStruct(
                id=ids[i],
                vector=embedding,
                payload=metadata
            ))

        # Insert in batches
        batch_size = 100
        for i in range(0, len(points), batch_size):
            batch = points[i:i + batch_size]
            self.client.upsert(
                collection_name=self.collection_name,
                points=batch
            )

        return ids

    def similarity_search(self, query_embedding: List[float], k: int = 5,
                         score_threshold: Optional[float] = None,
                         filter_dict: Optional[dict] = None) -> List[Document]:
        """Similarity search"""
        # Build the filter condition
        query_filter = None
        if filter_dict:
            query_filter = self._build_filter(filter_dict)

        # Run the search
        search_result = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            query_filter=query_filter,
            limit=k,
            score_threshold=score_threshold,
            with_payload=True,
            with_vectors=False
        )

        # Convert the results
        documents = []
        for scored_point in search_result:
            doc = Document(
                page_content=scored_point.payload.get('text', ''),
                metadata={
                    **scored_point.payload,
                    'score': scored_point.score,
                    'id': scored_point.id
                }
            )
            documents.append(doc)

        return documents

    def _build_filter(self, filter_dict: dict) -> models.Filter:
        """Build a query filter"""
        conditions = []

        for key, value in filter_dict.items():
            if isinstance(value, list):
                # IN-style match
                conditions.append(
                    models.FieldCondition(
                        key=key,
                        match=models.MatchAny(any=value)
                    )
                )
            else:
                # Exact-value match
                conditions.append(
                    models.FieldCondition(
                        key=key,
                        match=models.MatchValue(value=value)
                    )
                )

        return models.Filter(must=conditions) if conditions else None

Architectural advantages

  • Unified interface: every vector database is driven through the same API (illustrated in the sketch below)
  • Performance: batched operations, connection pooling, and index configuration
  • Flexible filtering: supports complex query conditions
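
Because every adapter honors the same BaseVector interface, switching backends is a configuration change rather than a code change. A hedged sketch of a call site that never mentions Qdrant or Weaviate by name:

# works unchanged no matter which BaseVector implementation is injected
def index_and_search(store: BaseVector, texts, embeddings, metadatas, query_vec):
    store.add_texts(texts=texts, embeddings=embeddings, metadatas=metadatas)
    return store.similarity_search(query_embedding=query_vec, k=5,
                                   filter_dict={"document_id": "doc-123"})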

5. Retrieval Algorithm Implementation

5.1 The Hybrid Retrieval Strategy

Dify implements a hybrid strategy that combines vector retrieval with keyword retrieval:

# api/core/rag/retrieval/retrival_methods.py
class HybridRetrieval:
    """Hybrid retrieval strategy"""

    def __init__(self, vector_store: BaseVector, keyword_store: BaseKeyword,
                 vector_weight: float = 0.7):
        self.vector_store = vector_store
        self.keyword_store = keyword_store
        self.vector_weight = vector_weight
        self.keyword_weight = 1.0 - vector_weight

    def retrieve(self, query: str, k: int = 10, **kwargs) -> List[Document]:
        """Hybrid retrieval"""
        # 1. Vector retrieval
        vector_results = self.vector_store.similarity_search(
            query_embedding=self._embed_query(query),
            k=k * 2,  # fetch extra candidates for fusion
            **kwargs
        )

        # 2. Keyword retrieval
        keyword_results = self.keyword_store.search(
            query=query,
            k=k * 2,
            **kwargs
        )

        # 3. Merge the results
        merged_results = self._merge_results(vector_results, keyword_results, k)

        return merged_results

    def _merge_results(self, vector_results: List[Document],
                      keyword_results: List[Document],
                      k: int) -> List[Document]:
        """Merge retrieval results"""
        # 1. Deduplicate and collect per-source scores
        document_scores = {}

        # Vector retrieval results
        for doc in vector_results:
            doc_id = doc.metadata.get('id')
            vector_score = doc.metadata.get('score', 0.0)

            if doc_id not in document_scores:
                document_scores[doc_id] = {
                    'document': doc,
                    'vector_score': vector_score,
                    'keyword_score': 0.0
                }
            else:
                document_scores[doc_id]['vector_score'] = max(
                    document_scores[doc_id]['vector_score'],
                    vector_score
                )

        # Keyword retrieval results
        for doc in keyword_results:
            doc_id = doc.metadata.get('id')
            keyword_score = doc.metadata.get('score', 0.0)

            if doc_id not in document_scores:
                document_scores[doc_id] = {
                    'document': doc,
                    'vector_score': 0.0,
                    'keyword_score': keyword_score
                }
            else:
                document_scores[doc_id]['keyword_score'] = max(
                    document_scores[doc_id]['keyword_score'],
                    keyword_score
                )

        # 2. Compute a combined score per document
        for doc_id, scores in document_scores.items():
            # Weighted score fusion (Reciprocal Rank Fusion is a common alternative)
            vector_score = scores['vector_score']
            keyword_score = scores['keyword_score']

            # Normalize the scores
            normalized_vector = self._normalize_score(vector_score, 'vector')
            normalized_keyword = self._normalize_score(keyword_score, 'keyword')

            # Weighted combination
            final_score = (
                self.vector_weight * normalized_vector +
                self.keyword_weight * normalized_keyword
            )

            scores['final_score'] = final_score
            scores['document'].metadata['final_score'] = final_score

        # 3. Sort and return the top k
        sorted_results = sorted(
            document_scores.values(),
            key=lambda x: x['final_score'],
            reverse=True
        )

        return [item['document'] for item in sorted_results[:k]]

    def _normalize_score(self, score: float, score_type: str) -> float:
        """Normalize a score"""
        if score_type == 'vector':
            # Vector similarity is usually already in [0, 1]
            return max(0.0, min(1.0, score))
        elif score_type == 'keyword':
            # Keyword scores need scaling that fits the keyword backend
            return max(0.0, min(1.0, score / 10.0))

        return score
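
The fusion above works on weighted, normalized scores. Reciprocal Rank Fusion (RRF), often mentioned in this context, sidesteps score normalization entirely by combining ranks instead of scores. A minimal self-contained sketch for comparison (not Dify's implementation):

from collections import defaultdict
from typing import Dict, List

def reciprocal_rank_fusion(result_lists: List[List[str]], k: int = 60) -> Dict[str, float]:
    """Fuse several ranked lists of document ids; k=60 is the constant from the original RRF paper."""
    scores: Dict[str, float] = defaultdict(float)
    for results in result_lists:
        for rank, doc_id in enumerate(results):
            scores[doc_id] += 1.0 / (k + rank + 1)
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))

# usage: reciprocal_rank_fusion([["a", "b", "c"], ["b", "a", "d"]])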

5.2 Query Expansion and Rewriting

To improve retrieval recall, Dify implements a query expansion mechanism:

# api/core/rag/retrieval/query_expansion.py
class QueryExpansion:
    """Query expander"""

    def __init__(self, llm_client, embedding_model):
        self.llm_client = llm_client
        self.embedding_model = embedding_model
        self.synonym_cache = {}

    def expand_query(self, query: str, method: str = 'llm') -> List[str]:
        """Expand a query"""
        if method == 'llm':
            return self._expand_with_llm(query)
        elif method == 'embedding':
            return self._expand_with_embedding(query)
        elif method == 'hybrid':
            llm_expansions = self._expand_with_llm(query)
            embedding_expansions = self._expand_with_embedding(query)
            return list(set(llm_expansions + embedding_expansions))

        return [query]

    def _expand_with_llm(self, query: str) -> List[str]:
        """Expand a query with an LLM"""
        prompt = f"""
        Generate 3-5 semantically similar rephrasings of the following query to improve search recall:

        Original query: {query}

        Expanded queries (one per line):
        """

        try:
            response = self.llm_client.chat([
                {"role": "user", "content": prompt}
            ])

            # Parse the response
            expanded_queries = [query]  # always include the original query
            lines = response.strip().split('\n')

            for line in lines:
                line = line.strip()
                if line and line != query:
                    # Strip any leading numbering
                    cleaned_line = re.sub(r'^\d+\.?\s*', '', line)
                    if cleaned_line:
                        expanded_queries.append(cleaned_line)

            return expanded_queries[:6]  # return at most 6 queries

        except Exception as e:
            logger.warning(f"LLM query expansion failed: {str(e)}")
            return [query]

    def _expand_with_embedding(self, query: str) -> List[str]:
        """Expand a query via embedding similarity"""
        # Similar phrasings could be mined from the existing document corpus.
        # Simplified here; in practice a synonym embedding index could be built.
        synonyms = self._get_cached_synonyms(query)
        if synonyms:
            return [query] + synonyms[:3]

        return [query]

    def _get_cached_synonyms(self, query: str) -> List[str]:
        """Fetch cached synonyms"""
        return self.synonym_cache.get(query, [])
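
Expanded queries only help if the retriever actually runs all of them and de-duplicates the hits. A hedged sketch of that glue step (the function and variable names are my own; `retriever` is anything with the retrieve() interface shown in 5.1):

def retrieve_with_expansion(expander: QueryExpansion, retriever, query: str, k: int = 10):
    seen_ids = set()
    merged = []
    for q in expander.expand_query(query, method="hybrid"):
        for doc in retriever.retrieve(q, k=k):
            doc_id = doc.metadata.get("id")
            if doc_id in seen_ids:
                continue  # the same chunk often comes back for several rewrites
            seen_ids.add(doc_id)
            merged.append(doc)
    return merged[:k]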

5.3 The Reranking Algorithm

Reranking after retrieval is crucial to the final quality:

# api/core/rag/retrieval/reranker.py
class Reranker:
    """Reranker"""

    def __init__(self, model_name: str = "cross-encoder"):
        self.model_name = model_name
        self.model = self._load_model()

    def rerank(self, query: str, documents: List[Document],
               top_k: Optional[int] = None) -> List[Document]:
        """Rerank documents"""
        if not documents:
            return documents

        # 1. Score query-document relevance
        scores = []
        for doc in documents:
            score = self._compute_relevance_score(query, doc.page_content)
            scores.append(score)

        # 2. Blend with the original retrieval scores
        final_scores = []
        for i, doc in enumerate(documents):
            original_score = doc.metadata.get('score', 0.0)
            rerank_score = scores[i]

            # Weighted blend
            final_score = 0.6 * rerank_score + 0.4 * original_score
            final_scores.append((final_score, doc))

        # 3. Sort
        final_scores.sort(key=lambda x: x[0], reverse=True)

        # 4. Attach the new scores and return
        reranked_docs = []
        for score, doc in final_scores:
            doc.metadata['rerank_score'] = score
            reranked_docs.append(doc)

        return reranked_docs[:top_k] if top_k else reranked_docs

    def _compute_relevance_score(self, query: str, document: str) -> float:
        """Compute a relevance score"""
        try:
            # Use a cross-encoder to score relevance
            inputs = self.model.tokenizer(
                query, document,
                return_tensors='pt',
                truncation=True,
                max_length=512
            )

            with torch.no_grad():
                outputs = self.model(**inputs)
                score = torch.sigmoid(outputs.logits).item()

            return score

        except Exception as e:
            logger.warning(f"Reranking failed: {str(e)}")
            # Fall back to a simple keyword-overlap match
            return self._simple_relevance_score(query, document)

    def _simple_relevance_score(self, query: str, document: str) -> float:
        """Simple relevance score (fallback)"""
        query_words = set(query.lower().split())
        doc_words = set(document.lower().split())

        if not query_words:
            return 0.0

        # Jaccard similarity
        intersection = query_words.intersection(doc_words)
        union = query_words.union(doc_words)

        return len(intersection) / len(union) if union else 0.0

6. Data Index Optimization

6.1 The Index Building Strategy

Dify applies several index optimization techniques:

# api/core/rag/datasource/vdb/vector_index.py
class VectorIndexManager:
    """Vector index manager"""

    def __init__(self, vector_store: BaseVector):
        self.vector_store = vector_store
        self.index_config = self._get_index_config()

    def create_index(self, documents: List[Document],
                    batch_size: int = 1000) -> str:
        """Create a vector index"""
        logger.info(f"Creating index for {len(documents)} documents")

        # 1. Preprocess the documents
        processed_docs = self._preprocess_documents(documents)

        # 2. Embed in batches
        embeddings = []
        texts = []
        metadatas = []

        for i in range(0, len(processed_docs), batch_size):
            batch_docs = processed_docs[i:i + batch_size]
            batch_texts = [doc.page_content for doc in batch_docs]
            batch_metadatas = [doc.metadata for doc in batch_docs]

            # Embed the batch
            batch_embeddings = self.vector_store.embedding_model.embed_documents(batch_texts)

            embeddings.extend(batch_embeddings)
            texts.extend(batch_texts)
            metadatas.extend(batch_metadatas)

            logger.info(f"Processed {i + len(batch_docs)}/{len(processed_docs)} documents")

        # 3. Bulk-insert into the vector database
        doc_ids = self.vector_store.add_texts(
            texts=texts,
            embeddings=embeddings,
            metadatas=metadatas
        )

        # 4. Optimize the index
        self._optimize_index()

        logger.info(f"Index created successfully with {len(doc_ids)} vectors")
        return f"index_{int(time.time())}"

    def _preprocess_documents(self, documents: List[Document]) -> List[Document]:
        """Preprocess documents"""
        processed = []

        for doc in documents:
            # 1. Clean the content
            cleaned_content = self._clean_content(doc.page_content)

            # 2. Length check: skip documents that are too short
            if len(cleaned_content) < 10:
                continue

            # 3. Enrich the metadata
            metadata = doc.metadata.copy()
            metadata.update({
                'content_length': len(cleaned_content),
                'word_count': len(cleaned_content.split()),
                'indexed_at': datetime.utcnow().isoformat()
            })

            processed.append(Document(
                page_content=cleaned_content,
                metadata=metadata
            ))

        return processed

    def _optimize_index(self):
        """Optimize index performance"""
        if hasattr(self.vector_store, 'optimize_index'):
            self.vector_store.optimize_index()

        # Warm up the index
        if hasattr(self.vector_store, 'warmup'):
            self.vector_store.warmup()

    def rebuild_index(self, dataset_id: str) -> str:
        """Rebuild an index"""
        logger.info(f"Rebuilding index for dataset {dataset_id}")

        # 1. Fetch all documents
        documents = self._get_dataset_documents(dataset_id)

        # 2. Delete the old index
        self._delete_old_index(dataset_id)

        # 3. Create the new index
        new_index_id = self.create_index(documents)

        # 4. Update the index mapping
        self._update_index_mapping(dataset_id, new_index_id)

        return new_index_id

6.2 Incremental Index Updates

# api/core/rag/datasource/vdb/incremental_index.py
class IncrementalIndexer:
    """Incremental indexer"""

    def __init__(self, vector_store: BaseVector):
        self.vector_store = vector_store
        self.change_log = []

    def add_document(self, document: Document) -> str:
        """Add a single document"""
        # 1. Embed
        embedding = self.vector_store.embedding_model.embed_documents([document.page_content])[0]

        # 2. Add to the vector store
        doc_ids = self.vector_store.add_texts(
            texts=[document.page_content],
            embeddings=[embedding],
            metadatas=[document.metadata]
        )

        # 3. Record the change
        self.change_log.append({
            'action': 'add',
            'doc_id': doc_ids[0],
            'timestamp': datetime.utcnow()
        })

        return doc_ids[0]

    def update_document(self, doc_id: str, new_document: Document) -> bool:
        """Update a document"""
        try:
            # 1. Delete the old document
            self.vector_store.delete([doc_id])

            # 2. Add the new document
            new_doc_id = self.add_document(new_document)

            # 3. Record the change
            self.change_log.append({
                'action': 'update',
                'old_doc_id': doc_id,
                'new_doc_id': new_doc_id,
                'timestamp': datetime.utcnow()
            })

            return True

        except Exception as e:
            logger.error(f"Failed to update document {doc_id}: {str(e)}")
            return False

    def delete_document(self, doc_id: str) -> bool:
        """Delete a document"""
        try:
            self.vector_store.delete([doc_id])

            self.change_log.append({
                'action': 'delete',
                'doc_id': doc_id,
                'timestamp': datetime.utcnow()
            })

            return True

        except Exception as e:
            logger.error(f"Failed to delete document {doc_id}: {str(e)}")
            return False

    def batch_update(self, operations: List[dict]) -> dict:
        """Apply a batch of update operations"""
        results = {
            'success': 0,
            'failed': 0,
            'errors': []
        }

        for op in operations:
            try:
                if op['action'] == 'add':
                    self.add_document(op['document'])
                elif op['action'] == 'update':
                    self.update_document(op['doc_id'], op['document'])
                elif op['action'] == 'delete':
                    self.delete_document(op['doc_id'])

                results['success'] += 1

            except Exception as e:
                results['failed'] += 1
                results['errors'].append({
                    'operation': op,
                    'error': str(e)
                })

        return results

7. Performance Monitoring and Optimization

7.1 Retrieval Performance Monitoring

# api/core/rag/monitoring/performance_monitor.py
class RetrievalPerformanceMonitor:
    """Retrieval performance monitor"""

    def __init__(self):
        self.metrics = {
            'query_count': 0,
            'total_latency': 0.0,
            'error_count': 0,
            'cache_hit_rate': 0.0
        }
        self.cache = {}
        self.cache_ttl = 3600  # 1 hour

    def monitor_query(self, func):
        """Decorator that monitors query calls"""
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            query_hash = self._hash_query(args, kwargs)

            try:
                # Check the cache
                if query_hash in self.cache:
                    cache_entry = self.cache[query_hash]
                    if time.time() - cache_entry['timestamp'] < self.cache_ttl:
                        self._update_cache_metrics(hit=True)
                        return cache_entry['result']

                # Run the query
                result = func(*args, **kwargs)

                # Cache the result
                self.cache[query_hash] = {
                    'result': result,
                    'timestamp': time.time()
                }

                # Update metrics
                latency = time.time() - start_time
                self._update_metrics(latency, success=True)
                self._update_cache_metrics(hit=False)

                return result

            except Exception as e:
                latency = time.time() - start_time
                self._update_metrics(latency, success=False)
                logger.error(f"Query failed: {str(e)}")
                raise

        return wrapper

    def _update_metrics(self, latency: float, success: bool):
        """Update performance metrics"""
        self.metrics['query_count'] += 1
        self.metrics['total_latency'] += latency

        if not success:
            self.metrics['error_count'] += 1

    def _update_cache_metrics(self, hit: bool):
        """Update cache metrics"""
        if not hasattr(self, '_cache_requests'):
            self._cache_requests = 0
            self._cache_hits = 0

        self._cache_requests += 1
        if hit:
            self._cache_hits += 1

        self.metrics['cache_hit_rate'] = self._cache_hits / self._cache_requests

    def get_metrics(self) -> dict:
        """Return performance metrics"""
        avg_latency = (
            self.metrics['total_latency'] / self.metrics['query_count']
            if self.metrics['query_count'] > 0 else 0
        )

        error_rate = (
            self.metrics['error_count'] / self.metrics['query_count']
            if self.metrics['query_count'] > 0 else 0
        )

        return {
            **self.metrics,
            'avg_latency': avg_latency,
            'error_rate': error_rate,
            'qps': self._calculate_qps()
        }

    def _calculate_qps(self) -> float:
        """Compute QPS"""
        # Simplified; a real implementation needs a sliding time window (see the sketch below)
        return self.metrics['query_count'] / 3600  # assumes a 1-hour window
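
The comment above notes that real QPS needs a time window. A small self-contained sketch of a sliding-window counter that could back such a metric (my own illustration, not part of Dify):

import time
from collections import deque
from typing import Optional

class SlidingWindowCounter:
    """Counts events that happened within the last `window_seconds` seconds."""

    def __init__(self, window_seconds: float = 60.0):
        self.window = window_seconds
        self.events = deque()

    def record(self, now: Optional[float] = None) -> None:
        self.events.append(time.time() if now is None else now)

    def rate(self, now: Optional[float] = None) -> float:
        now = time.time() if now is None else now
        while self.events and now - self.events[0] > self.window:
            self.events.popleft()  # drop events that fell out of the window
        return len(self.events) / self.window  # events per second over the window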

7.2 Automated Optimization Suggestions

# api/core/rag/optimization/auto_optimizer.py
class AutoOptimizer:
    """Automatic optimizer"""

    def __init__(self, monitor: RetrievalPerformanceMonitor):
        self.monitor = monitor
        self.optimization_rules = [
            self._check_latency_issues,
            self._check_cache_efficiency,
            self._check_error_rate,
            self._check_index_health
        ]

    def analyze_and_suggest(self) -> List[dict]:
        """Analyze metrics and produce optimization suggestions"""
        metrics = self.monitor.get_metrics()
        suggestions = []

        for rule in self.optimization_rules:
            suggestion = rule(metrics)
            if suggestion:
                suggestions.append(suggestion)

        return suggestions

    def _check_latency_issues(self, metrics: dict) -> Optional[dict]:
        """Check for latency issues"""
        if metrics['avg_latency'] > 2.0:  # more than 2 seconds
            return {
                'type': 'latency',
                'severity': 'high' if metrics['avg_latency'] > 5.0 else 'medium',
                'message': f"Average query latency of {metrics['avg_latency']:.2f}s is too high",
                'suggestions': [
                    "Add a caching layer",
                    "Tune the vector index configuration",
                    "Retrieve fewer documents per query",
                    "Consider a faster vector database"
                ]
            }
        return None

    def _check_cache_efficiency(self, metrics: dict) -> Optional[dict]:
        """Check cache efficiency"""
        if metrics['cache_hit_rate'] < 0.3:  # hit rate below 30%
            return {
                'type': 'cache',
                'severity': 'medium',
                'message': f"Cache hit rate of {metrics['cache_hit_rate']:.2%} is low",
                'suggestions': [
                    "Increase the cache TTL",
                    "Improve the cache-key strategy",
                    "Consider a tiered cache",
                    "Analyze query patterns to tune the caching strategy"
                ]
            }
        return None

    def _check_error_rate(self, metrics: dict) -> Optional[dict]:
        """Check the error rate"""
        if metrics['error_rate'] > 0.05:  # error rate above 5%
            return {
                'type': 'error',
                'severity': 'high',
                'message': f"Error rate of {metrics['error_rate']:.2%} is too high",
                'suggestions': [
                    "Check the vector database connection",
                    "Add a retry mechanism",
                    "Improve the error-handling logic",
                    "Monitor resource usage"
                ]
            }
        return None

8. Summary and Practical Advice

Having dissected Dify's knowledge base and vector retrieval system, we have now seen the complete architecture of a production-grade RAG system:

Core design highlights

  1. Modular architecture: a clean separation of responsibilities that is easy to maintain and extend
  2. Multiple strategies: document parsing, text splitting, and vector storage all support pluggable strategies
  3. Performance optimization: batch processing, caching, incremental updates, and similar techniques
  4. Thorough monitoring: end-to-end performance metrics plus automated optimization suggestions

Lessons from practice

Document processing best practices

  • Use the progressive strategy for PDFs: try the simple method first
  • Text splitting has to balance semantic completeness against retrieval granularity
  • Content cleaning has an outsized impact on final quality; do not skip it

Vector retrieval optimization tips

  • Hybrid retrieval noticeably improves recall
  • Reranking is the key lever for precision
  • Expand queries in moderation to avoid semantic drift

Performance optimization essentials

  • Batched operations are far more efficient than one-at-a-time calls
  • A sensible caching strategy dramatically improves response time
  • Monitoring and automated tuning keep the system running stably

In the next chapter we will dig into how Dify implements Agent capabilities and see how it gives AI tool-calling and reasoning abilities. There is plenty more great material on AI agent architecture design waiting for us there!

If you run into problems building your own RAG system, or see some of these designs differently, feel free to discuss in the comments. Remember: the best architecture is not the most complex one, but the one that best fits your business scenario!
