"""
Milvus检索器模块

基于Milvus的混合检索器,支持dense和sparse向量检索。
使用BGEM3模型进行embedding。
"""

from typing import List, Optional
from langchain_core.documents import Document
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
    AnnSearchRequest,
    RRFRanker,
)
from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pathlib import Path
from src.evrag.retriever.base import BaseRetriever
from src.evrag.config import get_settings

# Module-level configuration constants.
COL_NAME = "hybrid_bge_m3"  # default collection name when none is passed in
MILVUS_ALIAS = "default"  # pymilvus connection alias used by this module
ID_MAX_LENGTH = 100  # max_length of the `unique_id` primary-key VARCHAR field
# max_length of the `text` VARCHAR field and the limit texts are cut to before
# insertion; the value is kept consistent with the previous project.
MAX_TEXT_LENGTH = 512
EMB_BATCH = 50  # rows per insert batch when writing to the collection


class MilvusRetriever(BaseRetriever):
    """
    Hybrid Milvus retriever.

    Each document is stored with its text, a BGE-M3 dense vector and a BGE-M3
    sparse vector. Queries run one dense and one sparse ANN search and the two
    rankings are fused with Reciprocal Rank Fusion (RRFRanker).
    """

    def __init__(
        self,
        docs: Optional[List[Document]] = None,
        retrieve: bool = False,
        collection_name: Optional[str] = None,
    ) -> None:
        """
        Initialize the Milvus retriever.

        Args:
            docs: Documents to index; only indexed when ``retrieve`` is False.
            retrieve: If True, reuse the existing collection for retrieval
                only. If False, any existing collection with the same name is
                dropped and a new one is built from ``docs``.
            collection_name: Collection to use; defaults to ``COL_NAME``.

        Raises:
            ConnectionError: If the Milvus connection cannot be established.
        """
        super().__init__(docs, retrieve)

        settings = get_settings()

        self.collection_name = collection_name or COL_NAME
        self.milvus_db_path = settings.milvus_db_path
        self.bge_m3_model_path = settings.bge_m3_model_path

        self._connect_milvus()

        # BGEM3EmbeddingFunction (sentence-transformers underneath) is built
        # with device="cuda", so GPU placement is steered by selecting torch's
        # current CUDA device before the model is constructed.
        # NOTE(review): torch.cuda.set_device is process-wide and is not
        # restored afterwards.
        if settings.device == "cuda":
            import torch

            gpu_id = settings.milvus_retriever_gpu_id
            if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
                torch.cuda.set_device(gpu_id)
                print(f"✓ Milvus检索器将使用GPU {settings.milvus_retriever_gpu_id}")
            else:
                print(f"⚠ GPU {settings.milvus_retriever_gpu_id} 不可用,使用默认GPU")

        self.embedding_handler = BGEM3EmbeddingFunction(
            model_name=str(self.bge_m3_model_path),
            device="cuda" if settings.device == "cuda" else "cpu",
        )
        self.col = self._init_collection()

        if not self.retrieve and self.documents:
            self._save_vectorstore(self.documents)

    def _connect_milvus(self) -> None:
        """
        Connect alias ``MILVUS_ALIAS`` to the local Milvus database file.

        The parent directory of ``self.milvus_db_path`` is created if needed
        and the absolute path is used as the connection URI.

        Raises:
            ConnectionError: If creating the directory or connecting fails.
        """
        try:
            db_path = Path(self.milvus_db_path)
            db_path.parent.mkdir(parents=True, exist_ok=True)
            uri = str(db_path.absolute())

            # Connect unconditionally. connections.list_connections() yields
            # (alias, handler) tuples, so testing the alias string for
            # membership in it can never succeed; connecting every time is
            # what effectively happened anyway.
            # NOTE(review): relies on pymilvus accepting a repeated connect
            # with the same URI for the same alias — confirm for the pinned
            # pymilvus version.
            connections.connect(alias=MILVUS_ALIAS, uri=uri)
        except Exception as e:
            raise ConnectionError(f"Failed to connect to Milvus: {e}") from e

    def _init_collection(self) -> Collection:
        """
        Create (or open) the hybrid collection, build its indexes and load it.

        In build mode (``retrieve`` is False) an existing collection with the
        same name is dropped first.

        Returns:
            The loaded collection object.
        """
        fields = [
            FieldSchema(
                name="unique_id",
                dtype=DataType.VARCHAR,
                is_primary=True,
                max_length=ID_MAX_LENGTH,
            ),
            FieldSchema(
                name="text", dtype=DataType.VARCHAR, max_length=MAX_TEXT_LENGTH
            ),
            FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
            FieldSchema(
                name="dense_vector",
                dtype=DataType.FLOAT_VECTOR,
                dim=self.embedding_handler.dim["dense"],
            ),
        ]

        schema = CollectionSchema(
            fields=fields, description="Hybrid BGE-M3 Collection Schema"
        )

        # Build mode starts from an empty collection.
        if not self.retrieve and utility.has_collection(self.collection_name):
            Collection(self.collection_name).drop()

        col = Collection(self.collection_name, schema, consistency_level="Strong")

        # Both vector fields use inner-product similarity.
        sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}
        dense_index = {"index_type": "AUTOINDEX", "metric_type": "IP"}
        col.create_index("sparse_vector", sparse_index)
        col.create_index("dense_vector", dense_index)
        col.load()

        return col

    @staticmethod
    def _truncate_utf8(text: str, max_bytes: int) -> str:
        """
        Return the longest prefix of ``text`` whose UTF-8 encoding fits in
        ``max_bytes`` bytes, never splitting a multi-byte character.
        """
        encoded = text.encode("utf-8")
        if len(encoded) <= max_bytes:
            return text
        # errors="ignore" drops a trailing character cut in half by the slice.
        return encoded[:max_bytes].decode("utf-8", errors="ignore")

    def _save_vectorstore(self, docs: List[Document]) -> None:
        """
        Embed ``docs`` and insert them into the collection.

        Texts are truncated to MAX_TEXT_LENGTH UTF-8 bytes so they fit the
        ``text`` VARCHAR field, whose max_length Milvus measures in bytes;
        a character-based cut can still overflow for multi-byte text. The
        truncated texts are what gets embedded and stored.

        Args:
            docs: Documents to save. ``metadata["unique_id"]`` is used as the
                primary key when present, otherwise the document's index.
        """
        if not docs:
            return

        raw_texts = []
        truncated_count = 0
        for doc in docs:
            text = self._truncate_utf8(doc.page_content, MAX_TEXT_LENGTH)
            if text != doc.page_content:
                truncated_count += 1
            raw_texts.append(text)

        if truncated_count > 0:
            print(
                f"⚠️  警告: {truncated_count} 个文档被截断(超过 {MAX_TEXT_LENGTH} 字节)"
            )

        unique_ids = [
            doc.metadata.get("unique_id", str(i)) for i, doc in enumerate(docs)
        ]

        # Embed everything once; rows are then inserted in EMB_BATCH slices.
        texts_embeddings = self.embedding_handler(raw_texts)

        for start in range(0, len(docs), EMB_BATCH):
            end = start + EMB_BATCH
            self.col.insert(
                [
                    unique_ids[start:end],
                    raw_texts[start:end],
                    texts_embeddings["sparse"][start:end],
                    texts_embeddings["dense"][start:end],
                ]
            )

        # num_entities only reflects flushed data, so flush before counting.
        self.col.flush()
        print(f"索引构建完成,插入了{self.col.num_entities}条数据")

    @staticmethod
    def _to_sparse_dict(sparse_embedding):
        """
        Convert a single sparse vector into the ``{index: weight}`` dict form.

        - scipy sparse matrix/array: densified and flattened; zero weights dropped.
        - dict: keys cast to int and values to float.
        - list/tuple of weights: zero weights dropped.
        - anything else: returned unchanged.
        """
        if hasattr(sparse_embedding, "toarray"):
            flat = sparse_embedding.toarray().ravel()
            return {int(i): float(flat[i]) for i in flat.nonzero()[0]}
        if isinstance(sparse_embedding, dict):
            return {int(k): float(v) for k, v in sparse_embedding.items()}
        if isinstance(sparse_embedding, (list, tuple)):
            return {
                int(i): float(v) for i, v in enumerate(sparse_embedding) if v != 0
            }
        return sparse_embedding

    def _hybrid_search(
        self,
        query_dense_embedding,
        query_sparse_embedding,
        limit: int = 10,
    ) -> List[dict]:
        """
        Run a dense + sparse hybrid search and fuse the rankings with RRF.

        Args:
            query_dense_embedding: Dense query vector.
            query_sparse_embedding: Sparse query vector. It may be a scipy
                sparse matrix, a ``{index: weight}`` dict, or a list/tuple of
                weights, and is converted to a dict first.
            limit: Candidates per sub-search and number of fused results.

        Returns:
            The hits for the single query, carrying the ``unique_id`` and
            ``text`` output fields.
        """
        sparse_query = self._to_sparse_dict(query_sparse_embedding)

        dense_req = AnnSearchRequest(
            [query_dense_embedding],
            "dense_vector",
            {"metric_type": "IP", "params": {}},
            limit=limit,
        )
        sparse_req = AnnSearchRequest(
            [sparse_query],
            "sparse_vector",
            {"metric_type": "IP", "params": {}},
            limit=limit,
        )

        res = self.col.hybrid_search(
            [sparse_req, dense_req],
            rerank=RRFRanker(),
            limit=limit,
            output_fields=["unique_id", "text"],
        )

        # One query was sent, so only the first result list is relevant.
        return res[0]

    @staticmethod
    def _hit_field(hit, field: str, default=""):
        """
        Read an output field from a search hit, or ``default`` if it is absent.

        Fields are looked up on ``hit.entity`` (or ``hit["entity"]`` for
        plain dicts) when present, otherwise on the hit itself. Only the
        one-argument ``get`` is used, because not every pymilvus Hit
        implementation accepts a default argument.
        """
        entity = getattr(hit, "entity", None)
        if entity is None and isinstance(hit, dict):
            entity = hit.get("entity")
        value = (hit if entity is None else entity).get(field)
        return default if value is None else value

    def retrieve_topk(self, query: str, topk: int = 10) -> List[Document]:
        """
        Retrieve the ``topk`` documents most relevant to ``query``.

        Args:
            query: Query string.
            topk: Number of documents to return.

        Returns:
            Documents whose ``page_content`` is the stored text and whose
            ``metadata["unique_id"]`` is the stored primary key.

        Raises:
            ValueError: If the collection has not been initialized.
        """
        if self.col is None:
            raise ValueError("Milvus collection not initialized")

        query_embeddings = self.embedding_handler.encode_queries([query])

        hits = self._hybrid_search(
            query_embeddings["dense"][0], query_embeddings["sparse"][0], limit=topk
        )

        related_docs = []
        for hit in hits:
            text = self._hit_field(hit, "text", "")
            # The primary key is requested as the "unique_id" output field.
            unique_id = self._hit_field(hit, "unique_id", None)
            if unique_id is None:
                unique_id = getattr(hit, "id", "")
            related_docs.append(
                Document(page_content=text, metadata={"unique_id": unique_id})
            )

        return related_docs
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐