import os
import time
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from indexers import BaseIndexer, get_indexer_map
try:
from safetensors.numpy import load_file, save_file
except ImportError:
save_file = load_file = None
[docs]
class Collection:
"""
Primary interface for managing embeddings and metadata.
Provides a high-level API for indexing, search, and I/O.
"""
def __init__(
    self,
    name: str = "default",
    dimension: Optional[int] = None,
    indexer_type: str = "faiss",
    sparse_indexer_type: Optional[str] = None,
    truncate_dim: Optional[int] = None,
    **indexer_kwargs,
):
    """Create a collection.

    Args:
        name: Human-readable collection name.
        dimension: Full dimensionality of stored vectors. If None, it is
            inferred lazily from the first `add()` call.
        indexer_type: Key into the dense-indexer registry (lower-cased).
        sparse_indexer_type: Optional registry key for a sparse indexer
            (e.g. BM25); None disables sparse/hybrid search.
        truncate_dim: If set, vectors are truncated (Matryoshka-style) to
            this many leading dimensions; it then overrides `dimension`
            as the working dimension.
        **indexer_kwargs: Extra arguments forwarded to the dense indexer.
    """
    self.name = name
    # Working dimension: the truncated size wins when truncation is requested.
    self.dimension = truncate_dim if truncate_dim else dimension
    self.full_dimension = dimension
    self.truncate_dim = truncate_dim
    self.indexer_type = indexer_type.lower()
    self.sparse_indexer_type = sparse_indexer_type.lower() if sparse_indexer_type else None
    self.indexer_kwargs = indexer_kwargs
    self.indexer: Optional[BaseIndexer] = None
    self.sparse_indexer: Optional[BaseIndexer] = None
    self._vectors = None  # local np.ndarray copy of all vectors (for I/O, brute-force ops)
    self._metadata = []   # one dict per stored vector, parallel to self._vectors
    # Indexers are built eagerly only when the dimension is already known;
    # otherwise add() builds the dense indexer on first insert.
    if self.dimension:
        self._init_indexer(self.dimension)
    if self.sparse_indexer_type:
        self._init_sparse_indexer()
def _init_indexer(self, dimension: int):
    """Instantiate the dense indexer for `dimension` and record it as the working dimension."""
    registry = get_indexer_map()
    if self.indexer_type not in registry:
        raise ValueError(f"Indexer type '{self.indexer_type}' not found.")
    self.indexer = registry[self.indexer_type](dimension=dimension, **self.indexer_kwargs)
    self.dimension = dimension
def _init_sparse_indexer(self):
    """Instantiate the sparse (lexical) indexer from the registry."""
    registry = get_indexer_map()
    if self.sparse_indexer_type not in registry:
        raise ValueError(f"Sparse indexer type '{self.sparse_indexer_type}' not found.")
    # Dimension 0 for sparse indexers like BM25
    self.sparse_indexer = registry[self.sparse_indexer_type](dimension=0)
[docs]
def add(
    self,
    vectors: Union[np.ndarray, List[List[float]]],
    metadata: Optional[List[Dict[str, Any]]] = None,
):
    """Add vectors and metadata to the collection.

    Args:
        vectors: 2D array-like of shape (n, dim); cast to float32.
        metadata: Optional list of per-vector dicts (one per row). When
            omitted (or empty), empty dicts are stored instead.

    Raises:
        ValueError: If the vector width disagrees with the collection's
            established dimension (after truncation, if configured).
    """
    vectors = np.array(vectors).astype(np.float32)
    # Handle Matryoshka truncation during add if specified
    if self.truncate_dim:
        vectors = vectors[:, : self.truncate_dim]
    if self.dimension is None:
        # First insert decides the dimension and builds the dense indexer lazily.
        self._init_indexer(vectors.shape[1])
    elif vectors.shape[1] != self.dimension:
        raise ValueError(
            f"Vector dimension mismatch. Expected {self.dimension}, got {vectors.shape[1]}"
        )
    meta = metadata or [{} for _ in range(len(vectors))]
    # Build/Update dense index
    self.indexer.build_index(vectors.tolist(), meta)
    # Build/Update sparse index if present
    if self.sparse_indexer:
        self.sparse_indexer.build_index(vectors.tolist(), meta)
    # Keep local copy for I/O and non-native operations
    if self._vectors is None:
        self._vectors = vectors
    else:
        self._vectors = np.vstack([self._vectors, vectors])
    self._metadata.extend(meta)
[docs]
def add_images(self, image_paths: List[str], model: str = "openai/clip-vit-base-patch32", metadata: Optional[List[Dict[str, Any]]] = None):
    """
    Embed and add images to the collection.

    Each stored metadata dict is tagged with its source 'image_path'.
    Note: caller-supplied metadata dicts are mutated in place.
    """
    from llm import Embedder

    embedder = Embedder(model)
    image_vectors = embedder.embed_texts(image_paths)
    records = metadata or [{} for _ in range(len(image_paths))]
    for idx, img_path in enumerate(image_paths):
        records[idx]["image_path"] = img_path
    self.add(image_vectors, records)
[docs]
def search(
    self,
    query: Union[np.ndarray, List[float]],
    top_k: int = 5,
    where: Optional[Dict[str, Any]] = None,
    reranker: Optional[Union[callable, "RerankHandler"]] = None,
    query_text: Optional[str] = None,
) -> List[Tuple[Dict[str, Any], float]]:
    """
    Search the collection for the nearest neighbors.

    Args:
        query: Vector to search for. A 2D array is treated as a batch of
            queries and yields a list of result lists.
        top_k: Number of results to return.
        where: Metadata filter dictionary (exact match on every key).
        reranker: A callable or RerankHandler for re-scoring.
        query_text: Original text for reranking context.

    Returns:
        List of (metadata, distance) tuples, or a list of such lists for
        a batched (2D) query.

    Raises:
        RuntimeError: If no data has been added yet.
    """
    if self.indexer is None:
        raise RuntimeError("Collection is empty. Add data before searching.")
    # Increase search limit if reranking or filtering is requested
    # (over-fetch so post-filtering can still fill top_k).
    search_k = top_k
    if where or reranker:
        search_k = max(top_k * 10, 100)
    query_vec = np.array(query).astype(np.float32)
    if self.truncate_dim:
        # Truncate the query the same way stored vectors were truncated on add().
        if len(query_vec.shape) == 1:
            query_vec = query_vec[: self.truncate_dim]
        else:
            query_vec = query_vec[:, : self.truncate_dim]
    def _process_single(q):
        # Per-query pipeline: raw search -> filter -> (optional) rerank -> truncate.
        results = self.indexer.search(q.tolist(), top_k=search_k)
        if where:
            results = self._apply_filter(results, where)
        if reranker:
            # If it's a RerankHandler, use its rerank method
            # (requires query_text; otherwise fall back to plain-callable form).
            if hasattr(reranker, "rerank") and query_text:
                results = reranker.rerank(query_text, results, top_k=top_k)
            else:
                results = reranker(q, results)
        return results[:top_k]
    if len(query_vec.shape) == 1:
        return _process_single(query_vec)
    else:
        return [_process_single(q) for q in query_vec]
[docs]
def search_image(self, image_path: str, model: str = "openai/clip-vit-base-patch32", top_k: int = 5) -> List[Tuple[Dict[str, Any], float]]:
    """
    Search for similar items using an image query.

    The image is embedded with the given CLIP-style model and the
    resulting vector is passed to the regular dense search.
    """
    from llm import Embedder

    query_vector = Embedder(model).embed_query(image_path)
    return self.search(query_vector, top_k=top_k)
[docs]
def search_trajectory(
    self,
    trajectory: Union[np.ndarray, List[List[float]]],
    top_k: int = 5,
    pooling: str = "mean",
    where: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Dict[str, Any], float]]:
    """
    Search for similar trajectories (sequences of vectors).

    The trajectory is pooled into a single query vector, then a regular
    dense search is performed.

    Args:
        trajectory: Sequence of vectors representing a state/action trajectory.
        top_k: Number of results to return.
        pooling: Method to pool the trajectory into a single search vector ('mean' or 'max').
        where: Metadata filter dictionary.
    """
    sequence = np.array(trajectory).astype(np.float32)
    if sequence.ndim != 2:
        raise ValueError("Trajectory must be a 2D array (sequence of vectors).")
    # Dispatch table instead of an if/elif chain.
    poolers = {"mean": np.mean, "max": np.max}
    if pooling not in poolers:
        raise ValueError(f"Unknown pooling method: {pooling}")
    pooled_query = poolers[pooling](sequence, axis=0)
    return self.search(pooled_query, top_k=top_k, where=where)
[docs]
def hybrid_search(
    self,
    query_vector: Union[np.ndarray, List[float]],
    query_text: str,
    top_k: int = 5,
    dense_weight: float = 0.5,
    sparse_weight: float = 0.5,
    where: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Dict[str, Any], float]]:
    """
    Perform hybrid search combining dense and sparse results using Reciprocal Rank Fusion (RRF).

    Args:
        query_vector: Dense query embedding.
        query_text: Raw text for the sparse (lexical) indexer.
        top_k: Number of fused results to return.
        dense_weight: Multiplier on dense RRF contributions.
        sparse_weight: Multiplier on sparse RRF contributions.
        where: Metadata filter applied to both result sets.

    Returns:
        (metadata, fused_score) tuples sorted by descending fused score
        (higher = better, unlike distance-based search results).

    Raises:
        RuntimeError: If no sparse indexer was configured.
    """
    if not self.sparse_indexer:
        raise RuntimeError(
            "Sparse indexer not initialized. Initialize Collection with sparse_indexer_type."
        )
    # 1. Get Dense results (over-fetch so fusion has enough candidates)
    dense_results = self.search(query_vector, top_k=max(top_k * 2, 50), where=where)
    # 2. Get Sparse results
    sparse_results = self.sparse_indexer.search(query_text, top_k=max(top_k * 2, 50))
    if where:
        sparse_results = self._apply_filter(sparse_results, where)
    # 3. Reciprocal Rank Fusion (RRF)
    scores = {}
    def _update_scores(results, weight):
        for rank, (meta, _) in enumerate(results):
            # Use 'text' or 'id' as a unique key for fusion
            doc_key = meta.get("id") or meta.get("text") or str(meta)
            if doc_key not in scores:
                scores[doc_key] = {"meta": meta, "score": 0.0}
            # RRF formula component: weight * (1 / (rank + k)); 60 is the
            # conventional RRF damping constant.
            scores[doc_key]["score"] += weight * (1.0 / (rank + 60))
    _update_scores(dense_results, dense_weight)
    _update_scores(sparse_results, sparse_weight)
    # Sort by fused score
    sorted_results = sorted(scores.values(), key=lambda x: x["score"], reverse=True)
    return [(item["meta"], item["score"]) for item in sorted_results[:top_k]]
[docs]
def benchmark(self, indexers: Optional[List[str]] = None, top_k: int = 5):
    """
    Benchmark multiple indexers on the current collection data.

    Args:
        indexers: List of indexer names to compare (e.g. ["faiss", "hnswlib"]).
            If None, benchmarks all available indexers. Unknown names are
            silently skipped.
        top_k: Number of neighbors to search for during benchmark.
    """
    if self._vectors is None:
        raise RuntimeError("Collection is empty. Add data before benchmarking.")
    from rich.console import Console
    from benchmark import benchmark_single_indexer, display_results
    from indexers import get_indexer_map

    console = Console()
    registry = get_indexer_map()
    if indexers is None:
        candidates = list(registry.keys())
    else:
        candidates = [n.lower() for n in indexers if n.lower() in registry]
    outcomes = []
    for candidate_name in candidates:
        outcome = benchmark_single_indexer(
            candidate_name,
            registry[candidate_name],
            self.dimension,
            self._vectors.tolist(),
            self._metadata,
            console,
        )
        if outcome:
            outcomes.append(outcome)
    if outcomes:
        display_results(outcomes, console)
    return outcomes
[docs]
def evaluate(
    self, indexer_type: str = "faiss-hnsw", top_k: int = 10, **kwargs
) -> Dict[str, Any]:
    """
    Evaluate an indexer's recall and latency against an exact search baseline.

    Args:
        indexer_type: Registry key of the candidate indexer.
        top_k: Neighbors per query used for Recall@K.
        **kwargs: Extra constructor arguments for the candidate indexer.

    Returns:
        Dictionary with 'indexer', 'recall', 'latency_ms' and 'samples' keys.

    Raises:
        RuntimeError: If the collection is empty.
        ValueError: If `indexer_type` is unknown.
    """
    if self._vectors is None:
        raise RuntimeError("Collection is empty. Add data before evaluating.")
    # Fix: the old body re-imported `time` locally, shadowing the
    # module-level import; the module-level `time` is used directly now.
    from indexers.simple_indexer import SimpleIndexer

    def _doc_key(meta: Dict[str, Any]) -> str:
        # Stable identity for a result row; mirrors the key used in hybrid_search.
        return meta.get("id") or meta.get("text") or str(meta)

    # 1. Exact Search Baseline
    exact = SimpleIndexer(self.dimension)
    exact.build_index(self._vectors.tolist(), self._metadata)
    # 2. Candidate Indexer
    indexer_map = get_indexer_map()
    if indexer_type not in indexer_map:
        raise ValueError(f"Indexer '{indexer_type}' not found.")
    candidate = indexer_map[indexer_type](self.dimension, **kwargs)
    candidate.build_index(self._vectors.tolist(), self._metadata)
    # 3. Sample queries (up to 100)
    n_samples = min(100, len(self._vectors))
    sample_indices = np.random.choice(len(self._vectors), n_samples, replace=False)
    queries = self._vectors[sample_indices]
    recalls = []
    latencies = []
    for q in queries:
        q_list = q.tolist()
        # Get exact ground-truth IDs
        exact_ids = {_doc_key(meta) for meta, _ in exact.search(q_list, top_k=top_k)}
        # Get candidate results and measure latency
        t0 = time.perf_counter()
        cand_res = candidate.search(q_list, top_k=top_k)
        latencies.append((time.perf_counter() - t0) * 1000)
        cand_ids = {_doc_key(meta) for meta, _ in cand_res}
        # Recall@K: fraction of exact neighbors the candidate recovered.
        if exact_ids:
            recalls.append(len(exact_ids & cand_ids) / len(exact_ids))
        else:
            recalls.append(1.0)
    return {
        "indexer": indexer_type,
        "recall": float(np.mean(recalls)),
        "latency_ms": float(np.mean(latencies)),
        "samples": n_samples,
    }
[docs]
def export_to_production(self, backend: str, connection_url: str, collection_name: Optional[str] = None):
    """
    One-click export from local Embenx collection to production clusters.
    Supported backends: 'qdrant', 'milvus'.

    Args:
        backend: Target backend name ('qdrant' or 'milvus'), case-insensitive.
        connection_url: Cluster URL/URI for the backend client.
        collection_name: Remote collection name; defaults to this collection's name.

    Raises:
        RuntimeError: If the collection has no data.
        ValueError: For unsupported backends.

    NOTE(review): the Milvus path inserts only the vectors — metadata is
    not exported (its schema has just 'id' and 'vector' fields), unlike
    the Qdrant path, which uploads metadata as payload. Confirm whether
    that asymmetry is intended.
    """
    if self._vectors is None:
        raise RuntimeError("Collection is empty. Add data before exporting.")
    name = collection_name or self.name
    if backend.lower() == "qdrant":
        from qdrant_client import QdrantClient
        from qdrant_client.http import models
        client = QdrantClient(url=connection_url)
        # recreate_collection drops any existing remote collection of the same name.
        client.recreate_collection(
            collection_name=name,
            vectors_config=models.VectorParams(size=self.dimension, distance=models.Distance.COSINE),
        )
        client.upload_collection(
            collection_name=name,
            vectors=self._vectors,
            payload=self._metadata,
            ids=None
        )
        print(f"Successfully exported {len(self._vectors)} vectors to Qdrant at {connection_url}")
    elif backend.lower() == "milvus":
        from pymilvus import Collection as MilvusCollection
        from pymilvus import CollectionSchema, DataType, FieldSchema, connections
        connections.connect(alias="default", uri=connection_url)
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.dimension)
        ]
        schema = CollectionSchema(fields, f"Embenx export of {name}")
        mc = MilvusCollection(name, schema)
        # Milvus usually needs flat list for insert
        entities = [
            self._vectors.tolist()
        ]
        mc.insert(entities)
        mc.flush()
        print(f"Successfully exported {len(self._vectors)} vectors to Milvus at {connection_url}")
    else:
        raise ValueError(f"Export to backend '{backend}' not supported yet.")
def _apply_filter(self, results: List[Tuple[Dict[str, Any], float]], where: Dict[str, Any]):
    """Keep only results whose metadata exactly matches every `where` key/value pair."""
    return [
        (meta, dist)
        for meta, dist in results
        if all(meta.get(key) == value for key, value in where.items())
    ]
[docs]
@classmethod
def from_numpy(cls, path: str, **kwargs):
    """Load a collection from a .npy or .npz file.

    For .npz archives the 'vectors' and 'metadata' entries are used;
    a .npy file is treated as a bare vector matrix.
    """
    payload = np.load(path, allow_pickle=True)
    collection = cls(**kwargs)
    if not path.endswith(".npz"):
        collection.add(payload)
        return collection
    vectors = payload.get("vectors")
    metadata = payload.get("metadata")
    collection.add(vectors, metadata.tolist() if metadata is not None else None)
    return collection
[docs]
@classmethod
def from_parquet(cls, path: str, vector_col: str = "vector", **kwargs):
    """Load a collection from a Parquet file.

    All columns except `vector_col` become per-row metadata.
    """
    frame = pd.read_parquet(path)
    embeddings = np.stack(frame[vector_col].values)
    records = frame.drop(columns=[vector_col]).to_dict(orient="records")
    collection = cls(**kwargs)
    collection.add(embeddings, records)
    return collection
[docs]
def generate_synthetic_queries(
    self,
    text_key: str = "text",
    n_queries_per_doc: int = 1,
    num_docs: int = 100,
    model: str = "gpt-4o-mini",
    custom_prompt: Optional[str] = None,
    output_path: Optional[str] = None,
    api_base: Optional[str] = None,
    **llm_kwargs,
) -> List[Dict[str, Any]]:
    """
    Generate synthetic search queries for documents in the collection using an LLM.
    Supports local Ollama/vLLM via api_base and llm_kwargs.

    Args:
        text_key: Metadata key holding each document's text.
        n_queries_per_doc: Queries requested (and kept) per document.
        num_docs: Maximum number of documents to sample.
        model: LLM model name passed to the Generator.
        custom_prompt: Optional prompt template; formatted with {text} and {n}.
        output_path: Optional path; a '.parquet', '.jsonl' or '.csv' suffix
            selects the output format (other suffixes are silently ignored).
        api_base: Optional API base URL for local/self-hosted backends.
        **llm_kwargs: Extra arguments forwarded to the Generator.

    Returns:
        List of {'query', 'doc_id', 'doc_text'} dicts (empty if no document
        carries text under `text_key`).
    """
    from llm import Generator
    import random
    generator = Generator(model_name=model, api_base=api_base, **llm_kwargs)
    # Only documents that actually carry non-empty text are usable.
    valid_docs = [m for m in self._metadata if m.get(text_key)]
    if not valid_docs:
        return []
    sample_size = min(num_docs, len(valid_docs))
    sampled_docs = random.sample(valid_docs, sample_size)
    results = []
    for i, doc in enumerate(sampled_docs):
        text = doc[text_key]
        if custom_prompt:
            prompt = custom_prompt.format(text=text, n=n_queries_per_doc)
        else:
            prompt = (
                f"Given the following document text, generate {n_queries_per_doc} "
                f"diverse and realistic search queries that a user might type to find this document. "
                f"Return ONLY the queries, one per line, without numbering or bullets.\n\n"
                f"Document: {text}"
            )
        response = generator.generate(prompt)
        if not response:
            # Skip documents where the LLM returned nothing.
            continue
        # Strip common list markers (bullets, numbering, tabs) from each line.
        lines = [q.lstrip("- *1234567890.\t").strip() for q in response.split("\n") if q.strip()]
        queries = [q for q in lines if q]
        for q in queries[:n_queries_per_doc]:
            results.append({
                "query": q,
                "doc_id": doc.get("id"),
                "doc_text": text
            })
    if output_path and results:
        # Persist in the format implied by the file extension.
        df = pd.DataFrame(results)
        if output_path.endswith(".parquet"):
            df.to_parquet(output_path)
        elif output_path.endswith(".jsonl"):
            df.to_json(output_path, orient="records", lines=True)
        elif output_path.endswith(".csv"):
            df.to_csv(output_path, index=False)
    return results
[docs]
def to_parquet(self, path: str, vector_col: str = "vector"):
    """Save the collection to a Parquet file.

    Metadata keys become columns; vectors go into `vector_col`.
    """
    if self._vectors is None:
        raise RuntimeError("Collection is empty.")
    frame = pd.DataFrame(self._metadata).assign(**{vector_col: list(self._vectors)})
    frame.to_parquet(path)
def __repr__(self) -> str:
    """Debug representation: name, row count, and configured indexer types."""
    size = len(self._metadata or [])
    return f"Collection(name='{self.name}', size={size}, indexer='{self.indexer_type}', sparse='{self.sparse_indexer_type}')"
[docs]
class CacheCollection(Collection):
"""
Specialized collection for Retrieval-Augmented KV Caching (RA-KVC).
Supports storing high-dimensional activation tensors.
"""
[docs]
def add_cache(
    self,
    vectors: Union[np.ndarray, List[List[float]]],
    activations: Dict[str, np.ndarray],
    metadata: Optional[List[Dict[str, Any]]] = None,
    quantize: bool = False
):
    """
    Add embeddings and their associated KV cache activations.

    Each document's activation slice is written to
    `cache_<name>/<id>.safetensors` and the file path recorded in its
    metadata under 'cache_path'.

    Args:
        vectors: Embedding rows, one per cached document.
        activations: Mapping of tensor name -> stacked array; entry i of
            each array is assumed to belong to document i (TODO confirm).
        metadata: Optional per-document dicts; mutated in place with
            'cache_path' (and 'quantized' when quantization is applied).
        quantize: If True, store sign-based 1-bit (int8) activations.

    Raises:
        ImportError: If safetensors is not installed.
    """
    if save_file is None:
        raise ImportError("safetensors is required for CacheCollection.")
    os.makedirs(f"cache_{self.name}", exist_ok=True)
    meta = metadata or [{} for _ in range(len(vectors))]
    for i, m in enumerate(meta):
        # Prefer the caller-provided id; otherwise derive an index-based one.
        cache_id = m.get("id") or f"idx_{len(self._metadata) + i}"
        cache_path = os.path.join(f"cache_{self.name}", f"{cache_id}.safetensors")
        doc_activations = {k: v[i] for k, v in activations.items()}
        if quantize:
            # Simple 1-bit quantization (sign-based) as per TurboQuant concepts
            doc_activations = {k: np.sign(v).astype(np.int8) for k, v in doc_activations.items()}
            m["quantized"] = True
        save_file(doc_activations, cache_path)
        m["cache_path"] = cache_path
    self.add(vectors, meta)
[docs]
def get_cache(self, metadata: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """
    Retrieve activations for a given metadata result.

    Returns an empty dict when the entry records no cache file or the
    file is missing on disk.
    """
    if load_file is None:
        raise ImportError("safetensors is required for CacheCollection.")
    path = metadata.get("cache_path")
    if not path or not os.path.exists(path):
        return {}
    return load_file(path)
[docs]
class StateCollection(Collection):
"""
Specialized collection for State Space Model (SSM) hydration.
Supports storing hidden states (h0).
"""
[docs]
def add_states(
    self,
    vectors: Union[np.ndarray, List[List[float]]],
    states: np.ndarray,
    metadata: Optional[List[Dict[str, Any]]] = None,
):
    """
    Add embeddings and their associated SSM hidden states.

    Each hidden state is persisted to `states_<name>/<id>.safetensors`
    under the key 'h', and the path recorded in the row's metadata.
    """
    if save_file is None:
        raise ImportError("safetensors is required for StateCollection.")
    state_dir = f"states_{self.name}"
    os.makedirs(state_dir, exist_ok=True)
    records = metadata or [{} for _ in range(len(vectors))]
    for idx, record in enumerate(records):
        state_id = record.get("id") or f"state_{len(self._metadata) + idx}"
        state_path = os.path.join(state_dir, f"{state_id}.safetensors")
        # Store hidden state 'h'
        save_file({"h": states[idx]}, state_path)
        record["state_path"] = state_path
    self.add(vectors, records)
[docs]
def get_state(self, metadata: Dict[str, Any]) -> np.ndarray:
    """
    Retrieve hidden state for a given metadata result.

    Returns None when no state path is recorded or the file is missing.
    """
    if load_file is None:
        raise ImportError("safetensors is required for StateCollection.")
    path = metadata.get("state_path")
    if not path or not os.path.exists(path):
        return None
    return load_file(path)["h"]
[docs]
class ClusterCollection(Collection):
"""
Specialized collection for ClusterKV-style optimizations.
Implements semantic clustering of vectors for improved retrieval throughput.
"""
def __init__(self, n_clusters: int = 10, **kwargs):
    """
    Args:
        n_clusters: Number of KMeans clusters to partition vectors into.
        **kwargs: Forwarded to the base Collection constructor.
    """
    super().__init__(**kwargs)
    self.n_clusters = n_clusters
    # Fixed random_state keeps cluster assignments reproducible across runs.
    self.kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    self.cluster_map = {}  # cluster_id -> list of indices
[docs]
def cluster_data(self):
    """
    Perform semantic clustering on the current collection data.

    No-op when the collection is empty or holds fewer vectors than
    clusters. Rebuilds `cluster_map` and tags every metadata row with
    its 'cluster_id'.
    """
    if self._vectors is None or len(self._vectors) < self.n_clusters:
        return
    labels = self.kmeans.fit_predict(self._vectors)
    grouped = {}
    for idx, raw_label in enumerate(labels):
        grouped.setdefault(int(raw_label), []).append(idx)
    self.cluster_map = grouped
    # Update metadata with cluster info
    for idx, raw_label in enumerate(labels):
        self._metadata[idx]["cluster_id"] = int(raw_label)
[docs]
def search_clustered(self, query: np.ndarray, top_k: int = 5) -> List[Tuple[Dict[str, Any], float]]:
    """
    Search for vectors by first identifying the most relevant cluster.

    Falls back to a regular dense search when the predicted cluster has
    no members (e.g. `cluster_data` has not been run yet).

    NOTE(review): the in-cluster path scores results as 1 - cosine
    similarity, while the fallback returns the dense indexer's raw
    distances — the two scales may not be comparable; confirm callers
    tolerate this.
    """
    if self._vectors is None:
        return []
    # 1. Identify nearest cluster
    query_vec = np.array(query).reshape(1, -1).astype(np.float32)
    cluster_id = int(self.kmeans.predict(query_vec)[0])
    # 2. Retrieve indices for this cluster
    indices = self.cluster_map.get(cluster_id, [])
    if not indices:
        return self.search(query, top_k=top_k)
    # 3. Brute force within cluster (simulating ClusterKV pattern)
    cluster_vectors = self._vectors[indices]
    cluster_metadata = [self._metadata[i] for i in indices]
    # Simple cosine similarity within cluster
    norms = np.linalg.norm(cluster_vectors, axis=1) * np.linalg.norm(query)
    norms[norms == 0] = 1.0  # guard against division by zero for zero vectors
    similarities = np.dot(cluster_vectors, query.flatten()) / norms
    # Highest similarity first, then converted to a distance (1 - sim).
    results_idx = np.argsort(similarities)[::-1][:top_k]
    return [(cluster_metadata[i], 1.0 - float(similarities[i])) for i in results_idx]
[docs]
class SpatialCollection(Collection):
"""
Specialized collection for ESWM (Episodic Spatial World Memory).
Supports navigation trajectories and spatial-aware retrieval.
"""
[docs]
def add_spatial(
    self,
    vectors: Union[np.ndarray, List[List[float]]],
    coords: np.ndarray,
    metadata: Optional[List[Dict[str, Any]]] = None,
):
    """
    Add semantic embeddings and their associated spatial coordinates (x, y, z).

    Coordinates are stored per row under the 'coords' metadata key.
    """
    records = metadata or [{} for _ in range(len(vectors))]
    for idx, record in enumerate(records):
        record["coords"] = coords[idx].tolist()
    self.add(vectors, records)
[docs]
def search_spatial(
    self,
    query_vector: np.ndarray,
    current_coords: np.ndarray,
    top_k: int = 5,
    spatial_radius: float = 10.0
) -> List[Tuple[Dict[str, Any], float]]:
    """
    Spatial-aware search that favors episodic memories near the current location.

    Semantic candidates are over-fetched (2x), gated by Euclidean distance
    to `current_coords`, and re-scored so nearby memories rank first.
    Rows without 'coords' metadata are treated as being at the origin.
    """
    # 1. Perform semantic search first
    candidates = self.search(query_vector, top_k=top_k * 2)
    gated = []
    for meta, sem_dist in candidates:
        location = np.array(meta.get("coords", [0, 0, 0]))
        # Euclidean distance in 3D space
        distance = np.linalg.norm(location - current_coords)
        # Spatial gating (ESWM pattern): drop anything outside the radius.
        if distance > spatial_radius:
            continue
        # Penalize the semantic distance proportionally to spatial distance.
        combined = sem_dist * (1.0 + (distance / spatial_radius))
        gated.append((meta, float(combined)))
    # Sort by combined score (lower = better)
    gated.sort(key=lambda pair: pair[1])
    return gated[:top_k]
[docs]
class TemporalCollection(Collection):
"""
Specialized collection for Echo-style temporal episodic memory.
Supports time-stamped embeddings and recency-biased retrieval.
"""
[docs]
def add_temporal(
    self,
    vectors: Union[np.ndarray, List[List[float]]],
    timestamps: Optional[List[float]] = None,
    metadata: Optional[List[Dict[str, Any]]] = None,
):
    """
    Add embeddings with Unix timestamps.

    When `timestamps` is omitted, every row is stamped with "now".
    Timestamps are stored under the 'timestamp' metadata key.
    """
    if timestamps is None:
        timestamps = [time.time()] * len(vectors)
    records = metadata or [{} for _ in range(len(vectors))]
    for idx, record in enumerate(records):
        record["timestamp"] = float(timestamps[idx])
    self.add(vectors, records)
[docs]
def search_temporal(
    self,
    query_vector: np.ndarray,
    top_k: int = 5,
    recency_weight: float = 0.5,
    time_window: Optional[Tuple[float, float]] = None,
    where: Optional[Dict[str, Any]] = None,
) -> List[Tuple[Dict[str, Any], float]]:
    """
    Temporal-aware search that ranks results by similarity and recency.

    Args:
        query_vector: Dense query embedding.
        top_k: Number of results to return.
        recency_weight: 0.0 ranks purely by similarity, 1.0 purely by recency.
        time_window: Optional (start, end) Unix-timestamp bounds; results
            outside the window are dropped.
        where: Metadata filter passed down to the base search.

    Returns:
        (metadata, combined_score) tuples; lower score = better.
    """
    # 1. Semantic search
    # Pass the where filter down to the base search if supported or apply here
    # (over-fetch 5x so window filtering can still fill top_k).
    results = self.search(query_vector, top_k=top_k * 5, where=where)
    current_time = time.time()
    temporal_results = []
    for meta, sem_dist in results:
        # Rows without a timestamp default to 0.0 (epoch => maximally old).
        ts = meta.get("timestamp", 0.0)
        # 2. Time window filtering
        if time_window:
            if not (time_window[0] <= ts <= time_window[1]):
                continue
        # 3. Recency scoring (normalized time difference)
        # Smaller diff = more recent = higher boost
        time_diff = max(0, current_time - ts)
        # Simple decay: 1 / (1 + log(1 + time_diff))
        recency_score = 1.0 / (1.0 + np.log1p(time_diff))
        # Combined score (Distance is lower-better, Recency is higher-better)
        # We convert recency to a 'temporal distance': 1 - recency
        temporal_dist = 1.0 - recency_score
        combined_score = (sem_dist * (1.0 - recency_weight)) + (temporal_dist * recency_weight)
        temporal_results.append((meta, float(combined_score)))
    temporal_results.sort(key=lambda x: x[1])
    return temporal_results[:top_k]
[docs]
class AgenticCollection(Collection):
"""
Specialized collection for autonomous agent memory.
Supports search loops, feedback, and self-healing ranking.
"""
[docs]
def feedback(self, doc_id: str, label: str = "good"):
    """
    Provide feedback on a retrieval result.

    Adjusts the matching document's 'feedback_score' metadata in place so
    future `agentic_search` calls boost ("good") or demote ("bad") it.
    Only the first metadata entry whose 'id' matches is updated; unknown
    labels initialize the score but leave it unchanged.
    """
    # Fix: drop the unused enumerate index and use setdefault for the
    # score initialization.
    for record in self._metadata:
        if record.get("id") != doc_id:
            continue
        record.setdefault("feedback_score", 0.0)
        if label == "good":
            record["feedback_score"] += 0.5  # Stronger boost
        elif label == "bad":
            record["feedback_score"] -= 0.5  # Stronger demote
        break
[docs]
def agentic_search(
    self,
    query_vector: np.ndarray,
    top_k: int = 5,
    feedback_weight: float = 1.0  # High default for testing
) -> List[Tuple[Dict[str, Any], float]]:
    """
    Search with additive self-healing logic using stored feedback scores.

    Args:
        query_vector: Dense query embedding.
        top_k: Number of results to return.
        feedback_weight: Multiplier applied to stored feedback scores.

    Returns:
        (metadata, adjusted_distance) tuples, best (lowest) first.
    """
    # Over-fetch so demoted results can drop out of the final top_k.
    results = self.search(query_vector, top_k=top_k * 3)
    # We must read the live dicts in self._metadata because search()
    # returns a snapshot/copy, so feedback applied after indexing would
    # otherwise be invisible. Fix: build the id -> metadata lookup ONCE
    # instead of linearly scanning the whole corpus per result (the old
    # code was O(results * corpus)). First occurrence wins, matching the
    # original scan; rows without an 'id' fall back to their snapshot.
    live_by_id = {}
    for record in self._metadata:
        key = record.get("id")
        if key is not None and key not in live_by_id:
            live_by_id[key] = record
    agentic_results = []
    for meta, dist in results:
        actual_meta = live_by_id.get(meta.get("id"), meta)
        fb_score = actual_meta.get("feedback_score", 0.0)
        # Positive feedback REDUCES distance (boosts the result); the
        # additive shift means even exact matches (dist=0) can be demoted.
        adjusted_dist = dist - (fb_score * feedback_weight)
        agentic_results.append((actual_meta, float(adjusted_dist)))
    agentic_results.sort(key=lambda x: x[1])
    return agentic_results[:top_k]
[docs]
class Session:
"""
Managed agentic session with automatic temporal decay and persistence.
"""
def __init__(self, session_id: str, dimension: int, storage_dir: str = ".embenx_sessions"):
    """
    Args:
        session_id: Unique identifier; also the on-disk parquet filename stem.
        dimension: Embedding dimension for a freshly created session.
        storage_dir: Directory where session parquet files are stored.
    """
    self.session_id = session_id
    self.storage_dir = storage_dir
    self.path = os.path.join(storage_dir, f"{session_id}.parquet")
    os.makedirs(storage_dir, exist_ok=True)
    # Initialize with TemporalCollection for time-based features
    # NOTE(review): on reload, `dimension` is ignored — the collection
    # infers it from the stored vectors; confirm that is intended.
    if os.path.exists(self.path):
        self.collection = TemporalCollection.from_parquet(self.path)
    else:
        self.collection = TemporalCollection(name=session_id, dimension=dimension)
[docs]
def add_interaction(self, vector: Union[np.ndarray, List[float]], text: str, **metadata):
    """
    Add a new interaction to the session memory and persist it to disk.

    The interaction text is stored under the 'text' metadata key next to
    any extra keyword metadata.
    """
    record = dict(metadata) if metadata else {}
    record["text"] = text
    self.collection.add_temporal([vector], metadata=[record])
    self.collection.to_parquet(self.path)
[docs]
def retrieve_context(self, query_vector: np.ndarray, top_k: int = 5, recency_weight: float = 0.4):
    """
    Retrieve relevant context from the session with recency bias.

    Thin wrapper over the underlying collection's `search_temporal`.
    """
    return self.collection.search_temporal(
        query_vector, top_k=top_k, recency_weight=recency_weight
    )
[docs]
def cleanup(self):
    """Delete session data.

    Safe to call repeatedly. Fix: the old exists-then-remove pair raced
    (TOCTOU) — the file could disappear between the check and the delete;
    EAFP removal ignores a missing file instead.
    """
    try:
        os.remove(self.path)
    except FileNotFoundError:
        pass