"""Source code for core."""

import os
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from sklearn.cluster import KMeans

from indexers import BaseIndexer, get_indexer_map

try:
    from safetensors.numpy import load_file, save_file
except ImportError:
    save_file = load_file = None


class Collection:
    """
    Primary interface for managing embeddings and metadata.

    Provides a high-level API for indexing, search, and I/O.
    """

    def __init__(
        self,
        name: str = "default",
        dimension: Optional[int] = None,
        indexer_type: str = "faiss",
        sparse_indexer_type: Optional[str] = None,
        truncate_dim: Optional[int] = None,
        **indexer_kwargs,
    ):
        self.name = name
        # Matryoshka truncation: when set, the effective dimension is the
        # truncated one; the full input dimension is kept separately.
        self.dimension = truncate_dim if truncate_dim else dimension
        self.full_dimension = dimension
        self.truncate_dim = truncate_dim
        self.indexer_type = indexer_type.lower()
        self.sparse_indexer_type = (
            sparse_indexer_type.lower() if sparse_indexer_type else None
        )
        self.indexer_kwargs = indexer_kwargs

        self.indexer: Optional[BaseIndexer] = None
        self.sparse_indexer: Optional[BaseIndexer] = None
        # Local mirrors of the indexed data, used for I/O, benchmarking and
        # non-native operations.
        self._vectors = None
        self._metadata = []

        # Indexers are built eagerly only when the dimension is known;
        # otherwise they are created lazily on the first add().
        if self.dimension:
            self._init_indexer(self.dimension)
        if self.sparse_indexer_type:
            self._init_sparse_indexer()

    def _init_indexer(self, dimension: int):
        # Resolve the dense indexer class by registered name.
        available = get_indexer_map()
        if self.indexer_type not in available:
            raise ValueError(f"Indexer type '{self.indexer_type}' not found.")
        self.indexer = available[self.indexer_type](
            dimension=dimension, **self.indexer_kwargs
        )
        self.dimension = dimension

    def _init_sparse_indexer(self):
        available = get_indexer_map()
        if self.sparse_indexer_type not in available:
            raise ValueError(
                f"Sparse indexer type '{self.sparse_indexer_type}' not found."
            )
        # Dimension 0 for sparse indexers like BM25
        self.sparse_indexer = available[self.sparse_indexer_type](dimension=0)
[docs] def add( self, vectors: Union[np.ndarray, List[List[float]]], metadata: Optional[List[Dict[str, Any]]] = None, ): """Add vectors and metadata to the collection.""" vectors = np.array(vectors).astype(np.float32) # Handle Matryoshka truncation during add if specified if self.truncate_dim: vectors = vectors[:, : self.truncate_dim] if self.dimension is None: self._init_indexer(vectors.shape[1]) elif vectors.shape[1] != self.dimension: raise ValueError( f"Vector dimension mismatch. Expected {self.dimension}, got {vectors.shape[1]}" ) meta = metadata or [{} for _ in range(len(vectors))] # Build/Update dense index self.indexer.build_index(vectors.tolist(), meta) # Build/Update sparse index if present if self.sparse_indexer: self.sparse_indexer.build_index(vectors.tolist(), meta) # Keep local copy for I/O and non-native operations if self._vectors is None: self._vectors = vectors else: self._vectors = np.vstack([self._vectors, vectors]) self._metadata.extend(meta)
    def add_images(self, image_paths: List[str], model: str = "openai/clip-vit-base-patch32", metadata: Optional[List[Dict[str, Any]]] = None):
        """
        Embed and add images to the collection.

        Args:
            image_paths: Paths to image files. Each path is recorded in the
                corresponding item's metadata under "image_path".
            model: Embedding model name passed to ``llm.Embedder``.
            metadata: Optional per-image metadata dicts. NOTE: the provided
                dicts are mutated in place to add "image_path".
        """
        from llm import Embedder
        emb = Embedder(model)
        # NOTE(review): image paths are fed through embed_texts(); presumably
        # Embedder routes paths to a vision encoder for CLIP-style models —
        # confirm against llm.Embedder.
        vectors = emb.embed_texts(image_paths)
        meta = metadata or [{} for _ in range(len(image_paths))]
        for i, path in enumerate(image_paths):
            meta[i]["image_path"] = path
        self.add(vectors, meta)
[docs] def search( self, query: Union[np.ndarray, List[float]], top_k: int = 5, where: Optional[Dict[str, Any]] = None, reranker: Optional[Union[callable, "RerankHandler"]] = None, query_text: Optional[str] = None, ) -> List[Tuple[Dict[str, Any], float]]: """ Search the collection for the nearest neighbors. Args: query: Vector to search for. top_k: Number of results to return. where: Metadata filter dictionary. reranker: A callable or RerankHandler for re-scoring. query_text: Original text for reranking context. """ if self.indexer is None: raise RuntimeError("Collection is empty. Add data before searching.") # Increase search limit if reranking or filtering is requested search_k = top_k if where or reranker: search_k = max(top_k * 10, 100) query_vec = np.array(query).astype(np.float32) if self.truncate_dim: if len(query_vec.shape) == 1: query_vec = query_vec[: self.truncate_dim] else: query_vec = query_vec[:, : self.truncate_dim] def _process_single(q): results = self.indexer.search(q.tolist(), top_k=search_k) if where: results = self._apply_filter(results, where) if reranker: # If it's a RerankHandler, use its rerank method if hasattr(reranker, "rerank") and query_text: results = reranker.rerank(query_text, results, top_k=top_k) else: results = reranker(q, results) return results[:top_k] if len(query_vec.shape) == 1: return _process_single(query_vec) else: return [_process_single(q) for q in query_vec]
[docs] def search_image(self, image_path: str, model: str = "openai/clip-vit-base-patch32", top_k: int = 5) -> List[Tuple[Dict[str, Any], float]]: """ Search for similar items using an image query. """ from llm import Embedder emb = Embedder(model) q_vec = emb.embed_query(image_path) return self.search(q_vec, top_k=top_k)
[docs] def search_trajectory( self, trajectory: Union[np.ndarray, List[List[float]]], top_k: int = 5, pooling: str = "mean", where: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Dict[str, Any], float]]: """ Search for similar trajectories (sequences of vectors). Args: trajectory: Sequence of vectors representing a state/action trajectory. top_k: Number of results to return. pooling: Method to pool the trajectory into a single search vector ('mean' or 'max'). where: Metadata filter dictionary. """ traj_vecs = np.array(trajectory).astype(np.float32) if len(traj_vecs.shape) != 2: raise ValueError("Trajectory must be a 2D array (sequence of vectors).") if pooling == "mean": query_vec = np.mean(traj_vecs, axis=0) elif pooling == "max": query_vec = np.max(traj_vecs, axis=0) else: raise ValueError(f"Unknown pooling method: {pooling}") return self.search(query_vec, top_k=top_k, where=where)
[docs] def benchmark(self, indexers: Optional[List[str]] = None, top_k: int = 5): """ Benchmark multiple indexers on the current collection data. Args: indexers: List of indexer names to compare (e.g. ["faiss", "hnswlib"]). If None, benchmarks all available indexers. top_k: Number of neighbors to search for during benchmark. """ if self._vectors is None: raise RuntimeError("Collection is empty. Add data before benchmarking.") from rich.console import Console from benchmark import benchmark_single_indexer, display_results from indexers import get_indexer_map console = Console() indexer_map = get_indexer_map() if indexers is None: selected = list(indexer_map.keys()) else: selected = [i.lower() for i in indexers if i.lower() in indexer_map] results = [] for name in selected: res = benchmark_single_indexer( name, indexer_map[name], self.dimension, self._vectors.tolist(), self._metadata, console, ) if res: results.append(res) if results: display_results(results, console) return results
[docs] def evaluate( self, indexer_type: str = "faiss-hnsw", top_k: int = 10, **kwargs ) -> Dict[str, Any]: """ Evaluate an indexer's recall and latency against an exact search baseline. Returns: Dictionary with 'recall' and 'latency_ms' metrics. """ if self._vectors is None: raise RuntimeError("Collection is empty. Add data before evaluating.") import time from indexers.simple_indexer import SimpleIndexer # 1. Exact Search Baseline exact = SimpleIndexer(self.dimension) exact.build_index(self._vectors.tolist(), self._metadata) # 2. Candidate Indexer indexer_map = get_indexer_map() if indexer_type not in indexer_map: raise ValueError(f"Indexer '{indexer_type}' not found.") candidate_cls = indexer_map[indexer_type] candidate = candidate_cls(self.dimension, **kwargs) candidate.build_index(self._vectors.tolist(), self._metadata) # 3. Sample queries (up to 100) n_samples = min(100, len(self._vectors)) sample_indices = np.random.choice(len(self._vectors), n_samples, replace=False) queries = self._vectors[sample_indices] recalls = [] latencies = [] for q in queries: q_list = q.tolist() # Get exact ground truth IDs exact_res = exact.search(q_list, top_k=top_k) exact_ids = {meta.get("id") or meta.get("text") or str(meta) for meta, _ in exact_res} # Get candidate results and measure latency t0 = time.perf_counter() cand_res = candidate.search(q_list, top_k=top_k) latencies.append((time.perf_counter() - t0) * 1000) cand_ids = {meta.get("id") or meta.get("text") or str(meta) for meta, _ in cand_res} # Calculate intersection (Recall@K) if exact_ids: intersection = exact_ids.intersection(cand_ids) recalls.append(len(intersection) / len(exact_ids)) else: recalls.append(1.0) return { "indexer": indexer_type, "recall": float(np.mean(recalls)), "latency_ms": float(np.mean(latencies)), "samples": n_samples, }
    def export_to_production(self, backend: str, connection_url: str, collection_name: Optional[str] = None):
        """
        One-click export from local Embenx collection to production clusters.

        Supported backends: 'qdrant', 'milvus'.

        Args:
            backend: Target backend name ('qdrant' or 'milvus'), case-insensitive.
            connection_url: Connection URL/URI for the target cluster.
            collection_name: Remote collection name; defaults to this
                collection's name.

        Raises:
            RuntimeError: If the collection holds no data yet.
            ValueError: If the backend is not supported.
        """
        if self._vectors is None:
            raise RuntimeError("Collection is empty. Add data before exporting.")

        name = collection_name or self.name

        if backend.lower() == "qdrant":
            from qdrant_client import QdrantClient
            from qdrant_client.http import models

            client = QdrantClient(url=connection_url)
            # Destructive: drops and re-creates the remote collection.
            client.recreate_collection(
                collection_name=name,
                vectors_config=models.VectorParams(size=self.dimension, distance=models.Distance.COSINE),
            )
            # Vectors and metadata are uploaded together as payload.
            client.upload_collection(
                collection_name=name,
                vectors=self._vectors,
                payload=self._metadata,
                ids=None
            )
            print(f"Successfully exported {len(self._vectors)} vectors to Qdrant at {connection_url}")

        elif backend.lower() == "milvus":
            from pymilvus import Collection as MilvusCollection
            from pymilvus import CollectionSchema, DataType, FieldSchema, connections

            connections.connect(alias="default", uri=connection_url)

            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.dimension)
            ]
            schema = CollectionSchema(fields, f"Embenx export of {name}")
            mc = MilvusCollection(name, schema)

            # Milvus usually needs flat list for insert
            # NOTE(review): unlike the Qdrant path above, metadata is NOT
            # exported here — only vectors are inserted. Confirm whether a
            # JSON metadata field should be added to the schema.
            entities = [
                self._vectors.tolist()
            ]
            mc.insert(entities)
            mc.flush()
            print(f"Successfully exported {len(self._vectors)} vectors to Milvus at {connection_url}")
        else:
            raise ValueError(f"Export to backend '{backend}' not supported yet.")
def _apply_filter(self, results: List[Tuple[Dict[str, Any], float]], where: Dict[str, Any]): filtered = [] for meta, dist in results: match = True for key, value in where.items(): if meta.get(key) != value: match = False break if match: filtered.append((meta, dist)) return filtered
[docs] @classmethod def from_numpy(cls, path: str, **kwargs): """Load a collection from a .npy or .npz file.""" data = np.load(path, allow_pickle=True) col = cls(**kwargs) if path.endswith(".npz"): vectors = data.get("vectors") metadata = data.get("metadata") col.add(vectors, metadata.tolist() if metadata is not None else None) else: col.add(data) return col
[docs] @classmethod def from_parquet(cls, path: str, vector_col: str = "vector", **kwargs): """Load a collection from a Parquet file.""" df = pd.read_parquet(path) vectors = np.stack(df[vector_col].values) metadata = df.drop(columns=[vector_col]).to_dict(orient="records") col = cls(**kwargs) col.add(vectors, metadata) return col
[docs] def generate_synthetic_queries( self, text_key: str = "text", n_queries_per_doc: int = 1, num_docs: int = 100, model: str = "gpt-4o-mini", custom_prompt: Optional[str] = None, output_path: Optional[str] = None, api_base: Optional[str] = None, **llm_kwargs, ) -> List[Dict[str, Any]]: """ Generate synthetic search queries for documents in the collection using an LLM. Supports local Ollama/vLLM via api_base and llm_kwargs. """ from llm import Generator import random generator = Generator(model_name=model, api_base=api_base, **llm_kwargs) valid_docs = [m for m in self._metadata if m.get(text_key)] if not valid_docs: return [] sample_size = min(num_docs, len(valid_docs)) sampled_docs = random.sample(valid_docs, sample_size) results = [] for i, doc in enumerate(sampled_docs): text = doc[text_key] if custom_prompt: prompt = custom_prompt.format(text=text, n=n_queries_per_doc) else: prompt = ( f"Given the following document text, generate {n_queries_per_doc} " f"diverse and realistic search queries that a user might type to find this document. " f"Return ONLY the queries, one per line, without numbering or bullets.\n\n" f"Document: {text}" ) response = generator.generate(prompt) if not response: continue lines = [q.lstrip("- *1234567890.\t").strip() for q in response.split("\n") if q.strip()] queries = [q for q in lines if q] for q in queries[:n_queries_per_doc]: results.append({ "query": q, "doc_id": doc.get("id"), "doc_text": text }) if output_path and results: df = pd.DataFrame(results) if output_path.endswith(".parquet"): df.to_parquet(output_path) elif output_path.endswith(".jsonl"): df.to_json(output_path, orient="records", lines=True) elif output_path.endswith(".csv"): df.to_csv(output_path, index=False) return results
[docs] def to_parquet(self, path: str, vector_col: str = "vector"): """Save the collection to a Parquet file.""" if self._vectors is None: raise RuntimeError("Collection is empty.") df = pd.DataFrame(self._metadata) df[vector_col] = list(self._vectors) df.to_parquet(path)
def __repr__(self) -> str: count = len(self._metadata) if self._metadata else 0 return f"Collection(name='{self.name}', size={count}, indexer='{self.indexer_type}', sparse='{self.sparse_indexer_type}')"
class CacheCollection(Collection):
    """
    Specialized collection for Retrieval-Augmented KV Caching (RA-KVC).

    Supports storing high-dimensional activation tensors, persisted as one
    safetensors file per document alongside the regular embeddings.
    """
[docs] def add_cache( self, vectors: Union[np.ndarray, List[List[float]]], activations: Dict[str, np.ndarray], metadata: Optional[List[Dict[str, Any]]] = None, quantize: bool = False ): """ Add embeddings and their associated KV cache activations. """ if save_file is None: raise ImportError("safetensors is required for CacheCollection.") os.makedirs(f"cache_{self.name}", exist_ok=True) meta = metadata or [{} for _ in range(len(vectors))] for i, m in enumerate(meta): cache_id = m.get("id") or f"idx_{len(self._metadata) + i}" cache_path = os.path.join(f"cache_{self.name}", f"{cache_id}.safetensors") doc_activations = {k: v[i] for k, v in activations.items()} if quantize: # Simple 1-bit quantization (sign-based) as per TurboQuant concepts doc_activations = {k: np.sign(v).astype(np.int8) for k, v in doc_activations.items()} m["quantized"] = True save_file(doc_activations, cache_path) m["cache_path"] = cache_path self.add(vectors, meta)
[docs] def get_cache(self, metadata: Dict[str, Any]) -> Dict[str, np.ndarray]: """ Retrieve activations for a given metadata result. """ if load_file is None: raise ImportError("safetensors is required for CacheCollection.") path = metadata.get("cache_path") if path and os.path.exists(path): return load_file(path) return {}
class StateCollection(Collection):
    """
    Specialized collection for State Space Model (SSM) hydration.

    Supports storing hidden states (h0), persisted as one safetensors file
    per item.
    """
[docs] def add_states( self, vectors: Union[np.ndarray, List[List[float]]], states: np.ndarray, metadata: Optional[List[Dict[str, Any]]] = None, ): """ Add embeddings and their associated SSM hidden states. """ if save_file is None: raise ImportError("safetensors is required for StateCollection.") os.makedirs(f"states_{self.name}", exist_ok=True) meta = metadata or [{} for _ in range(len(vectors))] for i, m in enumerate(meta): state_id = m.get("id") or f"state_{len(self._metadata) + i}" state_path = os.path.join(f"states_{self.name}", f"{state_id}.safetensors") # Store hidden state 'h' save_file({"h": states[i]}, state_path) m["state_path"] = state_path self.add(vectors, meta)
[docs] def get_state(self, metadata: Dict[str, Any]) -> np.ndarray: """ Retrieve hidden state for a given metadata result. """ if load_file is None: raise ImportError("safetensors is required for StateCollection.") path = metadata.get("state_path") if path and os.path.exists(path): return load_file(path)["h"] return None
class ClusterCollection(Collection):
    """
    Specialized collection for ClusterKV-style optimizations.

    Implements semantic clustering of vectors for improved retrieval throughput.
    """

    def __init__(self, n_clusters: int = 10, **kwargs):
        """
        Args:
            n_clusters: Number of KMeans clusters to build.
            **kwargs: Forwarded to Collection.__init__.
        """
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        # Fixed random_state keeps cluster assignments reproducible across runs.
        self.kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        self.cluster_map = {}  # cluster_id -> list of indices
[docs] def cluster_data(self): """ Perform semantic clustering on the current collection data. """ if self._vectors is None or len(self._vectors) < self.n_clusters: return cluster_labels = self.kmeans.fit_predict(self._vectors) self.cluster_map = {} for i, label in enumerate(cluster_labels): label = int(label) if label not in self.cluster_map: self.cluster_map[label] = [] self.cluster_map[label].append(i) # Update metadata with cluster info for i, label in enumerate(cluster_labels): self._metadata[i]["cluster_id"] = int(label)
[docs] def search_clustered(self, query: np.ndarray, top_k: int = 5) -> List[Tuple[Dict[str, Any], float]]: """ Search for vectors by first identifying the most relevant cluster. """ if self._vectors is None: return [] # 1. Identify nearest cluster query_vec = np.array(query).reshape(1, -1).astype(np.float32) cluster_id = int(self.kmeans.predict(query_vec)[0]) # 2. Retrieve indices for this cluster indices = self.cluster_map.get(cluster_id, []) if not indices: return self.search(query, top_k=top_k) # 3. Brute force within cluster (simulating ClusterKV pattern) cluster_vectors = self._vectors[indices] cluster_metadata = [self._metadata[i] for i in indices] # Simple cosine similarity within cluster norms = np.linalg.norm(cluster_vectors, axis=1) * np.linalg.norm(query) norms[norms == 0] = 1.0 similarities = np.dot(cluster_vectors, query.flatten()) / norms results_idx = np.argsort(similarities)[::-1][:top_k] return [(cluster_metadata[i], 1.0 - float(similarities[i])) for i in results_idx]
class SpatialCollection(Collection):
    """
    Specialized collection for ESWM (Episodic Spatial World Memory).

    Supports navigation trajectories and spatial-aware retrieval, where
    items carry (x, y, z) coordinates in their metadata.
    """
[docs] def add_spatial( self, vectors: Union[np.ndarray, List[List[float]]], coords: np.ndarray, metadata: Optional[List[Dict[str, Any]]] = None, ): """ Add semantic embeddings and their associated spatial coordinates (x, y, z). """ meta = metadata or [{} for _ in range(len(vectors))] for i, m in enumerate(meta): m["coords"] = coords[i].tolist() self.add(vectors, meta)
[docs] def search_spatial( self, query_vector: np.ndarray, current_coords: np.ndarray, top_k: int = 5, spatial_radius: float = 10.0 ) -> List[Tuple[Dict[str, Any], float]]: """ Spatial-aware search that favors episodic memories near the current location. """ # 1. Perform semantic search first results = self.search(query_vector, top_k=top_k * 2) spatial_results = [] for meta, sem_dist in results: item_coords = np.array(meta.get("coords", [0, 0, 0])) # Euclidean distance in 3D space euc_dist = np.linalg.norm(item_coords - current_coords) # Spatial gating (ESWM pattern) if euc_dist <= spatial_radius: # Combine semantic distance and spatial proximity # Higher weight to spatial proximity if needed combined_score = sem_dist * (1.0 + (euc_dist / spatial_radius)) spatial_results.append((meta, float(combined_score))) # Sort by combined score spatial_results.sort(key=lambda x: x[1]) return spatial_results[:top_k]
class TemporalCollection(Collection):
    """
    Specialized collection for Echo-style temporal episodic memory.

    Supports time-stamped embeddings and recency-biased retrieval.
    """
[docs] def add_temporal( self, vectors: Union[np.ndarray, List[List[float]]], timestamps: Optional[List[float]] = None, metadata: Optional[List[Dict[str, Any]]] = None, ): """ Add embeddings with Unix timestamps. """ if timestamps is None: timestamps = [time.time()] * len(vectors) meta = metadata or [{} for _ in range(len(vectors))] for i, m in enumerate(meta): m["timestamp"] = float(timestamps[i]) self.add(vectors, meta)
[docs] def search_temporal( self, query_vector: np.ndarray, top_k: int = 5, recency_weight: float = 0.5, time_window: Optional[Tuple[float, float]] = None, where: Optional[Dict[str, Any]] = None, ) -> List[Tuple[Dict[str, Any], float]]: """ Temporal-aware search that ranks results by similarity and recency. """ # 1. Semantic search # Pass the where filter down to the base search if supported or apply here results = self.search(query_vector, top_k=top_k * 5, where=where) current_time = time.time() temporal_results = [] for meta, sem_dist in results: ts = meta.get("timestamp", 0.0) # 2. Time window filtering if time_window: if not (time_window[0] <= ts <= time_window[1]): continue # 3. Recency scoring (normalized time difference) # Smaller diff = more recent = higher boost time_diff = max(0, current_time - ts) # Simple decay: 1 / (1 + log(1 + time_diff)) recency_score = 1.0 / (1.0 + np.log1p(time_diff)) # Combined score (Distance is lower-better, Recency is higher-better) # We convert recency to a 'temporal distance': 1 - recency temporal_dist = 1.0 - recency_score combined_score = (sem_dist * (1.0 - recency_weight)) + (temporal_dist * recency_weight) temporal_results.append((meta, float(combined_score))) temporal_results.sort(key=lambda x: x[1]) return temporal_results[:top_k]
class AgenticCollection(Collection):
    """
    Specialized collection for autonomous agent memory.

    Supports search loops, feedback, and self-healing ranking.
    """
[docs] def feedback(self, doc_id: str, label: str = "good"): """ Provide feedback on a retrieval result. Adjusts metadata to influence future rankings. """ for i, m in enumerate(self._metadata): if m.get("id") == doc_id: if "feedback_score" not in m: m["feedback_score"] = 0.0 if label == "good": m["feedback_score"] += 0.5 # Stronger boost elif label == "bad": m["feedback_score"] -= 0.5 # Stronger demote break
class Session:
    """
    Managed agentic session with automatic temporal decay and persistence.

    Memory is backed by a TemporalCollection persisted as a Parquet file
    under ``storage_dir``; an existing session file is reloaded on
    construction.
    """

    def __init__(self, session_id: str, dimension: int, storage_dir: str = ".embenx_sessions"):
        """
        Args:
            session_id: Unique identifier for the session (used as the file name).
            dimension: Embedding dimension for a freshly created session.
            storage_dir: Directory where session Parquet files are stored
                (created if missing).
        """
        self.session_id = session_id
        self.storage_dir = storage_dir
        self.path = os.path.join(storage_dir, f"{session_id}.parquet")
        os.makedirs(storage_dir, exist_ok=True)

        # Initialize with TemporalCollection for time-based features
        if os.path.exists(self.path):
            # Resume a previously persisted session from disk.
            self.collection = TemporalCollection.from_parquet(self.path)
        else:
            self.collection = TemporalCollection(name=session_id, dimension=dimension)
[docs] def add_interaction(self, vector: Union[np.ndarray, List[float]], text: str, **metadata): """ Add a new interaction to the session memory. """ meta = metadata or {} meta["text"] = text self.collection.add_temporal([vector], metadata=[meta]) self.collection.to_parquet(self.path)
[docs] def retrieve_context(self, query_vector: np.ndarray, top_k: int = 5, recency_weight: float = 0.4): """ Retrieve relevant context from the session with recency bias. """ return self.collection.search_temporal(query_vector, top_k=top_k, recency_weight=recency_weight)
[docs] def cleanup(self): """Delete session data.""" if os.path.exists(self.path): os.remove(self.path)