Source code for indexers.scann_indexer
import os
import tempfile
from typing import Any, Dict, List, Tuple
import numpy as np
try:
import scann
except ImportError:
scann = None
from .base import BaseIndexer
[docs]
class ScaNNIndexer(BaseIndexer):
def __init__(self, dimension: int):
super().__init__("ScaNN", dimension)
if scann is None:
raise ImportError("scann is not installed. Please install it with 'pip install scann'.")
self.searcher = None
self.metadata = []
self.temp_dir = tempfile.TemporaryDirectory()
[docs]
def build_index(self, embeddings: List[List[float]], metadata: List[Dict[str, Any]]) -> None:
data = np.array(embeddings).astype(np.float32)
# Normalize for cosine similarity
norms = np.linalg.norm(data, axis=1, keepdims=True)
norms[norms == 0] = 1.0
data = data / norms
# ScaNN configuration
# tree: AH (Anisotropic Hashing) is generally recommended for high accuracy
# num_leaves: typically sqrt(N)
num_leaves = int(np.sqrt(len(data)))
self.searcher = (
scann.scann_ops_pybind.builder(data, 10, "dot_product")
.tree(
num_leaves=num_leaves,
num_leaves_to_search=min(num_leaves, 100),
training_sample_size=len(data),
)
.score_ah(2, anisotropic_quantization_threshold=0.2)
.reorder(100)
.build()
)
self.metadata = metadata
# Save to temp dir to track size
self.searcher.serialize(self.temp_dir.name)
[docs]
def search(
self, query_embedding: List[float], top_k: int = 5
) -> List[Tuple[Dict[str, Any], float]]:
if self.searcher is None:
return []
query = np.array(query_embedding).astype(np.float32)
query_norm = np.linalg.norm(query)
if query_norm > 0:
query = query / query_norm
indices, distances = self.searcher.search(query, final_num_neighbors=top_k)
results = []
for idx, dist in zip(indices, distances):
results.append((self.metadata[idx], float(dist)))
return results
[docs]
def get_size(self) -> int:
total_size = 0
try:
if os.path.exists(self.temp_dir.name):
for dirpath, _, filenames in os.walk(self.temp_dir.name):
for f in filenames:
fp = os.path.join(dirpath, f)
total_size += os.path.getsize(fp)
except Exception:
pass
return total_size
[docs]
def cleanup(self) -> None:
self.temp_dir.cleanup()