基于 Milvus 和 BGE-M3 的混合检索 Python 实现
【代码】基于 Milvus 和 BGE-M3 的混合检索 Python 实现。
·
"""
Milvus检索器模块
基于Milvus的混合检索器,支持dense和sparse向量检索。
使用BGEM3模型进行embedding。
"""
from typing import List, Optional
from langchain_core.documents import Document
from pymilvus import (
connections,
utility,
FieldSchema,
CollectionSchema,
DataType,
Collection,
AnnSearchRequest,
RRFRanker,
)
from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pathlib import Path
from src.evrag.retriever.base import BaseRetriever
from src.evrag.config import get_settings
# Module-level constants.
# EMB_BATCH: number of entities sent to Milvus per insert call.
EMB_BATCH = 50
# MAX_TEXT_LENGTH: used both as the max_length of the "text" VARCHAR field
# and as the truncation limit applied to documents before insertion.
MAX_TEXT_LENGTH = 512 # Kept consistent with the legacy project
# ID_MAX_LENGTH: max_length of the "unique_id" VARCHAR primary-key field.
ID_MAX_LENGTH = 100
# COL_NAME: collection name used when the caller does not supply one.
COL_NAME = "hybrid_bge_m3"
MILVUS_ALIAS = "default" # Connection alias
class MilvusRetriever(BaseRetriever):
    """
    Hybrid retriever backed by Milvus.

    Stores one dense and one sparse vector per document (both produced by the
    BGE-M3 embedding function) and answers queries with a hybrid search whose
    two result lists are fused by reciprocal-rank fusion (RRF).
    """

    def __init__(
        self,
        docs: Optional[List[Document]] = None,
        retrieve: bool = False,
        collection_name: Optional[str] = None,
    ) -> None:
        """
        Initialize the Milvus retriever.

        Args:
            docs: The documents to index / retrieve from.
            retrieve: Whether to retrieve from an existing index (True) or
                build a new index (False).
            collection_name: The name of the collection; falls back to
                COL_NAME when omitted or empty.
        """
        super().__init__(docs, retrieve)

        # Load configuration.
        settings = get_settings()
        self.collection_name = collection_name or COL_NAME
        self.milvus_db_path = settings.milvus_db_path
        self.bge_m3_model_path = settings.bge_m3_model_path

        # Connect to Milvus before anything touches a collection.
        self._connect_milvus()

        # The embedding function is only told "cuda" or "cpu"; the concrete
        # GPU is selected by changing torch's current CUDA device beforehand.
        if settings.device == "cuda":
            import torch

            gpu_id = settings.milvus_retriever_gpu_id
            if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
                torch.cuda.set_device(gpu_id)
                print(f"✓ Milvus检索器将使用GPU {gpu_id}")
            else:
                print(f"⚠ GPU {gpu_id} 不可用,使用默认GPU")

        self.embedding_handler = BGEM3EmbeddingFunction(
            model_name=str(self.bge_m3_model_path),
            device="cuda" if settings.device == "cuda" else "cpu",
        )
        self.col = self._init_collection()

        # Only (re)build the index when not in retrieve mode and there is
        # something to index.
        if not self.retrieve and self.documents:
            self._save_vectorstore(self.documents)

    def _connect_milvus(self) -> None:
        """
        Connect to Milvus under MILVUS_ALIAS, reusing an existing connection.

        Raises:
            ConnectionError: If the database directory cannot be created or
                the connection attempt fails (original error is chained).
        """
        try:
            # Make sure the directory holding the database file exists.
            db_path = Path(self.milvus_db_path)
            db_path.parent.mkdir(parents=True, exist_ok=True)

            # list_connections() yields (alias, handler) tuples, so a plain
            # `alias in list_connections()` test never matches; has_connection
            # is the API that reports whether the alias is actually connected.
            if not connections.has_connection(MILVUS_ALIAS):
                connections.connect(
                    alias=MILVUS_ALIAS,
                    uri=str(db_path.absolute()),  # use an absolute path
                )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to Milvus: {e}") from e

    def _init_collection(self) -> Collection:
        """
        Create (or open) the collection, build its indexes and load it.

        When not in retrieve mode, an existing collection with the same name
        is dropped first so the index is rebuilt from scratch.

        Returns:
            The loaded collection object.
        """
        # Declare fields.
        fields = [
            FieldSchema(
                name="unique_id",
                dtype=DataType.VARCHAR,
                is_primary=True,
                max_length=ID_MAX_LENGTH,
            ),
            FieldSchema(
                name="text", dtype=DataType.VARCHAR, max_length=MAX_TEXT_LENGTH
            ),
            FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
            FieldSchema(
                name="dense_vector",
                dtype=DataType.FLOAT_VECTOR,
                dim=self.embedding_handler.dim["dense"],
            ),
        ]
        schema = CollectionSchema(
            fields=fields, description="Hybrid BGE-M3 Collection Schema"
        )

        # Build mode: drop any stale collection with the same name.
        if not self.retrieve and utility.has_collection(self.collection_name):
            Collection(self.collection_name).drop()

        # Create (or open) the collection.
        col = Collection(self.collection_name, schema, consistency_level="Strong")

        # Create one index per vector field, then load the collection.
        sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}
        dense_index = {"index_type": "AUTOINDEX", "metric_type": "IP"}
        col.create_index("sparse_vector", sparse_index)
        col.create_index("dense_vector", dense_index)
        col.load()
        return col

    @staticmethod
    def _truncate_utf8(text: str, max_bytes: int) -> str:
        """Return *text* cut so its UTF-8 encoding is at most *max_bytes* bytes."""
        encoded = text.encode("utf-8")
        if len(encoded) <= max_bytes:
            return text
        # errors="ignore" drops a multi-byte character split by the cut.
        return encoded[:max_bytes].decode("utf-8", errors="ignore")

    def _save_vectorstore(self, docs: List[Document]) -> None:
        """
        Embed the documents and insert them into the Milvus collection.

        Args:
            docs: The documents to save.
        """
        if not docs:
            return

        # Milvus VARCHAR max_length is a byte limit, so truncation has to be
        # done on the UTF-8 encoding: 512 CJK characters are ~1536 bytes and
        # would be rejected by a 512-byte field.
        raw_texts = [
            self._truncate_utf8(doc.page_content, MAX_TEXT_LENGTH) for doc in docs
        ]
        truncated_count = sum(
            1 for doc, text in zip(docs, raw_texts) if text != doc.page_content
        )
        if truncated_count > 0:
            print(
                f"⚠️ 警告: {truncated_count} 个文档被截断(超过 {MAX_TEXT_LENGTH} 字节)"
            )

        # The primary key is a VARCHAR field, so ids are coerced to str;
        # documents without a "unique_id" fall back to their position.
        unique_ids = [
            str(doc.metadata.get("unique_id", i)) for i, doc in enumerate(docs)
        ]

        # Compute embeddings for the (truncated) texts.
        texts_embeddings = self.embedding_handler(raw_texts)

        # Insert in batches of EMB_BATCH; column order matches the schema.
        for start in range(0, len(docs), EMB_BATCH):
            end = start + EMB_BATCH
            batched_entities = [
                unique_ids[start:end],
                raw_texts[start:end],
                texts_embeddings["sparse"][start:end],
                texts_embeddings["dense"][start:end],
            ]
            self.col.insert(batched_entities)
        print(f"索引构建完成,插入了{self.col.num_entities}条数据")

    @staticmethod
    def _sparse_to_dict(sparse):
        """
        Normalize a sparse query vector to a ``{index: value}`` dict.

        Handles dicts, lists/tuples and scipy-style sparse objects; any other
        input is returned unchanged.
        """
        if isinstance(sparse, dict):
            # Already a mapping: just normalize key/value types.
            return {int(k): float(v) for k, v in sparse.items()}
        if isinstance(sparse, (list, tuple)):
            # Dense sequence: keep non-zero entries only.
            return {int(i): float(v) for i, v in enumerate(sparse) if v != 0}

        shape = getattr(sparse, "shape", None)
        single_row = shape is not None and (len(shape) == 1 or shape[0] == 1)
        if (
            getattr(sparse, "format", None) == "csr"
            and single_row
            and hasattr(sparse, "indices")
            and hasattr(sparse, "data")
        ):
            # For a single CSR row, `indices` are the column ids, so the dict
            # can be built without materializing the full dense vector.
            # (CSC is excluded on purpose: its `indices` are row ids.)
            merged = {}
            for idx, val in zip(sparse.indices, sparse.data):
                key = int(idx)
                # Sum duplicates so the result matches toarray() semantics.
                merged[key] = merged.get(key, 0.0) + float(val)
            return {k: v for k, v in merged.items() if v != 0}

        if hasattr(sparse, "toarray"):
            # Generic fallback: densify, then keep non-zero entries.
            flat = sparse.toarray().flatten()
            return {int(i): float(v) for i, v in enumerate(flat) if v != 0}
        return sparse

    def _hybrid_search(
        self,
        query_dense_embedding,
        query_sparse_embedding,
        limit: int = 10,
    ) -> List[dict]:
        """
        Run a dense + sparse hybrid search fused with RRF.

        Args:
            query_dense_embedding: Dense vector of the query.
            query_sparse_embedding: Sparse vector of the query (scipy sparse
                object, dict, list or tuple).
            limit: Number of results to return.

        Returns:
            The hits for the single query (first entry of the search result).
        """
        query_sparse_embedding = self._sparse_to_dict(query_sparse_embedding)

        dense_search_params = {"metric_type": "IP", "params": {}}
        dense_req = AnnSearchRequest(
            [query_dense_embedding], "dense_vector", dense_search_params, limit=limit
        )
        sparse_search_params = {"metric_type": "IP", "params": {}}
        sparse_req = AnnSearchRequest(
            [query_sparse_embedding], "sparse_vector", sparse_search_params, limit=limit
        )

        # Fuse the two ranked lists with reciprocal-rank fusion.
        rerank = RRFRanker()
        res = self.col.hybrid_search(
            [sparse_req, dense_req],
            rerank=rerank,
            limit=limit,
            output_fields=["unique_id", "text"],
        )
        return res[0]

    def retrieve_topk(self, query: str, topk: int = 10) -> List[Document]:
        """
        Retrieve the top-k documents related to a query.

        Args:
            query: The query string.
            topk: Number of documents to return.

        Returns:
            The related documents, in the order returned by the hybrid search.

        Raises:
            ValueError: If the Milvus collection has not been initialized.
        """
        if self.col is None:
            raise ValueError("Milvus collection not initialized")

        # Embed the query.
        query_embeddings = self.embedding_handler.encode_queries([query])

        # Hybrid search over the first (only) query's dense and sparse vectors.
        hybrid_results = self._hybrid_search(
            query_embeddings["dense"][0], query_embeddings["sparse"][0], limit=topk
        )

        # Convert each hit into a Document carrying its text and id.
        return [
            Document(
                page_content=result.get("text", ""),
                metadata={"unique_id": result.get("id", "")},
            )
            for result in hybrid_results
        ]
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐


所有评论(0)