Source code for indexers.milvus_indexer
import os
import tempfile
from typing import Any, Dict, List, Tuple
from pymilvus import DataType, MilvusClient
from .base import BaseIndexer
[docs]
class MilvusIndexer(BaseIndexer):
def __init__(self, dimension: int):
    """Create a file-backed Milvus collection for fixed-size vectors.

    A temporary ``.db`` file is created and its path is passed to
    ``MilvusClient`` as the URI. The ``"benchmark"`` collection is created
    with an auto-generated INT64 primary key ``id``, a FLOAT_VECTOR field
    ``vector`` indexed with FLAT / COSINE, and dynamic fields enabled so
    arbitrary metadata keys can be stored next to each vector.

    Args:
        dimension: Length of every embedding stored in the collection.
    """
    super().__init__("Milvus", dimension)

    # delete=False keeps the file on disk after the handle is closed, so the
    # path stays usable by the client and by get_size()/cleanup().
    self.temp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
    # Only the path is needed from here on. Closing the handle right away
    # avoids leaking a file descriptor for the lifetime of the indexer and
    # avoids holding the file open while another component opens the path.
    self.temp_file.close()

    self.client = MilvusClient(uri=self.temp_file.name)
    self.collection_name = "benchmark"

    schema = MilvusClient.create_schema(
        auto_id=True,
        enable_dynamic_field=True,
    )
    schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dimension)

    index_params = self.client.prepare_index_params()
    index_params.add_index(field_name="vector", index_type="FLAT", metric_type="COSINE")

    self.client.create_collection(
        collection_name=self.collection_name, schema=schema, index_params=index_params
    )
[docs]
def build_index(self, embeddings: List[List[float]], metadata: List[Dict[str, Any]]) -> None:
    """Insert embeddings and their metadata into the collection.

    Each row holds the embedding under ``"vector"`` plus one entry per
    metadata key. Metadata values that are not ``str``, ``int``, ``float``
    or ``bool`` are stored as ``str(value)``. Nothing is inserted when there
    are no rows.

    Args:
        embeddings: One vector per item.
        metadata: One metadata dict per item, aligned with ``embeddings``.

    Raises:
        ValueError: If ``embeddings`` and ``metadata`` differ in length.
    """
    # zip() would silently drop the unmatched tail; fail loudly instead so a
    # misaligned input cannot produce a partially built index.
    if len(embeddings) != len(metadata):
        raise ValueError(
            f"embeddings and metadata must have the same length "
            f"(got {len(embeddings)} and {len(metadata)})"
        )

    rows = []
    for emb, meta in zip(embeddings, metadata):
        row = {"vector": emb}
        # "vector" is the key that carries the embedding; a metadata entry
        # with the same name must not overwrite it.
        row.update(
            (key, value if isinstance(value, (str, int, float, bool)) else str(value))
            for key, value in meta.items()
            if key != "vector"
        )
        rows.append(row)

    if rows:
        self.client.insert(collection_name=self.collection_name, data=rows)
[docs]
def search(
    self, query_embedding: List[float], top_k: int = 5
) -> List[Tuple[Dict[str, Any], float]]:
    """Return the closest stored items for a single query vector.

    Args:
        query_embedding: The query vector.
        top_k: Maximum number of hits to request from the client.

    Returns:
        A list of ``(metadata, distance)`` tuples in the order the client
        returned them. The ``"id"`` and ``"vector"`` fields are removed from
        each metadata dict, and the distance is converted to ``float``.
        The list is empty when the client returns no hits.
    """
    response = self.client.search(
        collection_name=self.collection_name,
        data=[query_embedding],
        limit=top_k,
        output_fields=["*"],
    )

    # Guard clauses: no result sets at all, or an empty one for our query.
    if not response:
        return []
    # A single query vector was sent, so only the first result set is read.
    first_result = response[0]
    if len(first_result) == 0:
        return []

    matches = []
    for hit in first_result:
        payload = hit["entity"]
        score = hit["distance"]
        # Drop the storage fields so only caller-supplied metadata remains.
        for storage_field in ("id", "vector"):
            payload.pop(storage_field, None)
        matches.append((payload, float(score)))
    return matches
[docs]
def get_size(self) -> int:
    """Return the size in bytes of the backing database file.

    Returns:
        The file size, or 0 when the file does not exist or cannot be
        stat'ed.
    """
    # Ask for the size directly instead of checking existence first: the
    # file could disappear between the two calls. The caught exceptions are
    # the ones os.path.exists() would have reported as "does not exist".
    try:
        return os.path.getsize(self.temp_file.name)
    except (OSError, ValueError):
        return 0
[docs]
def cleanup(self) -> None:
    """Release the client and delete the temporary database file.

    The client and the temp-file handle are closed before the file is
    removed, so nothing in this object still holds the file open when it is
    deleted. Removing a file that is already gone is not an error.
    """
    # getattr: tolerate an instance whose client was never created.
    client = getattr(self, "client", None)
    if client is not None:
        # NOTE(review): assumes MilvusClient.close() tolerates repeated
        # calls if cleanup() is invoked twice; confirm for the pinned
        # pymilvus version.
        client.close()

    # Closing an already-closed temp-file handle is a no-op.
    self.temp_file.close()

    # Remove directly rather than checking existence first, which would race
    # with anything else deleting the file in between.
    try:
        os.remove(self.temp_file.name)
    except FileNotFoundError:
        pass