"""
Embedding Pipeline
Handles vector embedding generation for document chunks and FAISS index management.
Supports local embedding models and efficient similarity search.
"""
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from src.ingest import ChunkMetadata, DocumentChunk
# Import utility functions
from .utils import (
    get_logger,
    log_memory_usage,
    log_performance,
)
logger = get_logger(__name__)


@dataclass
class EmbeddingConfig:
    """Configuration for embedding generation."""

    model_name: str
    normalize_embeddings: bool
    device: str
    similarity_threshold: float
    top_k: int
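

# Example: a configuration matching the defaults that EmbeddingPipeline falls
# back to when keys are missing from its config dictionary:
#
#     config = EmbeddingConfig(
#         model_name="all-MiniLM-L6-v2",
#         normalize_embeddings=True,
#         device="cpu",
#         similarity_threshold=0.7,
#         top_k=5,
#     )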


class EmbeddingModel:
    """Handles embedding model loading and text embedding generation."""

    def __init__(self, config: EmbeddingConfig):
        """
        Initialize embedding model.

        Args:
            config: Embedding configuration
        """
        self.config = config
        self.model = None
        self._load_model()
        logger.info(f"Initialized embedding model: {config.model_name}")

    def _load_model(self) -> None:
        """Load the sentence transformer model."""
        try:
            self.model = SentenceTransformer(
                self.config.model_name, device=self.config.device
            )
            logger.info(
                f"Loaded model {self.config.model_name} on {self.config.device}"
            )
        except Exception as e:
            logger.error(f"Failed to load model {self.config.model_name}: {e}")
            raise

    def generate_embeddings(self, texts: list[str]) -> np.ndarray:
        """
        Generate embeddings for a list of texts.

        Args:
            texts: List of text strings to embed

        Returns:
            numpy array of embeddings
        """
        if not texts:
            return np.array([])
        try:
            embeddings = self.model.encode(
                texts,
                normalize_embeddings=self.config.normalize_embeddings,
                show_progress_bar=False,
            )
            return embeddings
        except Exception as e:
            logger.error(f"Error generating embeddings: {e}")
            raise

    def generate_single_embedding(self, text: str) -> np.ndarray:
        """
        Generate embedding for a single text.

        Args:
            text: Text string to embed

        Returns:
            numpy array of embedding
        """
        return self.generate_embeddings([text])[0]
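

# Example usage (a minimal sketch, reusing the EmbeddingConfig example above;
# the first call downloads the model, so network access is assumed):
#
#     model = EmbeddingModel(config)
#     vectors = model.generate_embeddings(["first chunk", "second chunk"])
#     assert vectors.shape == (2, 384)  # all-MiniLM-L6-v2 is 384-dimensional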


class FAISSIndex:
    """Handles FAISS index creation and management."""

    def __init__(self, dimension: int, index_type: str = "IndexFlatIP"):
        """
        Initialize FAISS index.

        Args:
            dimension: Dimension of embeddings
            index_type: Type of FAISS index to use
        """
        self.dimension = dimension
        self.index_type = index_type
        self.index = None
        self.chunk_metadata: list[ChunkMetadata] = []
        self._create_index()
        logger.info(
            f"Initialized FAISS index: {index_type} with dimension {dimension}"
        )

    def _create_index(self) -> None:
        """Create FAISS index based on type."""
        if self.index_type == "IndexFlatIP":
            self.index = faiss.IndexFlatIP(self.dimension)
        elif self.index_type == "IndexFlatL2":
            self.index = faiss.IndexFlatL2(self.dimension)
        elif self.index_type == "IndexIVFFlat":
            # For IVF, we need to train on some data first
            quantizer = faiss.IndexFlatL2(self.dimension)
            self.index = faiss.IndexIVFFlat(quantizer, self.dimension, 100)
        else:
            raise ValueError(f"Unsupported index type: {self.index_type}")

    def add_embeddings(
        self, embeddings: np.ndarray, chunk_metadata: list[ChunkMetadata]
    ) -> None:
        """
        Add embeddings to the index.

        Args:
            embeddings: numpy array of embeddings
            chunk_metadata: List of chunk metadata corresponding to embeddings
        """
        if len(embeddings) != len(chunk_metadata):
            raise ValueError(
                "Number of embeddings must match number of metadata entries"
            )
        if not self.index.is_trained:
            # IVF indexes must be trained before vectors can be added
            raise RuntimeError(
                f"Index {self.index_type} must be trained before adding embeddings"
            )
        # Add embeddings to index
        self.index.add(embeddings.astype("float32"))
        # Store metadata
        self.chunk_metadata.extend(chunk_metadata)
        logger.info(f"Added {len(embeddings)} embeddings to index")

    def search(
        self, query_embedding: np.ndarray, k: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Search for similar embeddings.

        Args:
            query_embedding: Query embedding
            k: Number of results to return

        Returns:
            Tuple of (distances, indices)
        """
        if self.index.ntotal == 0:
            return np.array([]), np.array([])
        # Reshape query embedding to the (1, dimension) shape FAISS expects
        query_embedding = query_embedding.reshape(1, -1).astype("float32")
        # Search
        distances, indices = self.index.search(
            query_embedding, min(k, self.index.ntotal)
        )
        return distances[0], indices[0]

    def get_chunk_by_index(self, index: int) -> ChunkMetadata | None:
        """
        Get chunk metadata by index.

        Args:
            index: Index in the metadata list

        Returns:
            Chunk metadata or None if index is invalid
        """
        if 0 <= index < len(self.chunk_metadata):
            return self.chunk_metadata[index]
        return None

    def get_total_embeddings(self) -> int:
        """Get total number of embeddings in index."""
        return self.index.ntotal

    def save_index(self, index_path: Path) -> None:
        """
        Save FAISS index and metadata to disk.

        Args:
            index_path: Path to save index
        """
        index_path.mkdir(parents=True, exist_ok=True)
        # Save FAISS index
        faiss.write_index(self.index, str(index_path / "faiss.index"))
        # Save metadata
        metadata_file = index_path / "chunk_metadata.json"
        with open(metadata_file, "w", encoding="utf-8") as f:
            metadata_list = [asdict(meta) for meta in self.chunk_metadata]
            json.dump(metadata_list, f, indent=2, ensure_ascii=False)
        # Save index info
        info = {
            "dimension": self.dimension,
            "index_type": self.index_type,
            "total_embeddings": self.get_total_embeddings(),
            # Placeholder; the actual model name is written to model_info.json
            # by EmbeddingPipeline.save_index
            "model_name": "sentence-transformers",
        }
        info_file = index_path / "index_info.json"
        with open(info_file, "w", encoding="utf-8") as f:
            json.dump(info, f, indent=2)
        logger.info(f"Saved index to {index_path}")

    def load_index(self, index_path: Path) -> None:
        """
        Load FAISS index and metadata from disk.

        Args:
            index_path: Path to load index from
        """
        # Load FAISS index
        faiss_index_file = index_path / "faiss.index"
        if not faiss_index_file.exists():
            raise FileNotFoundError(
                f"FAISS index file not found: {faiss_index_file}"
            )
        self.index = faiss.read_index(str(faiss_index_file))
        # Load metadata
        metadata_file = index_path / "chunk_metadata.json"
        if not metadata_file.exists():
            raise FileNotFoundError(f"Metadata file not found: {metadata_file}")
        with open(metadata_file, encoding="utf-8") as f:
            metadata_list = json.load(f)
            self.chunk_metadata = [ChunkMetadata(**meta) for meta in metadata_list]
        # Load index info
        info_file = index_path / "index_info.json"
        if info_file.exists():
            with open(info_file, encoding="utf-8") as f:
                info = json.load(f)
                self.dimension = info.get("dimension", self.dimension)
                self.index_type = info.get("index_type", self.index_type)
        logger.info(
            f"Loaded index from {index_path} with {len(self.chunk_metadata)} chunks"
        )
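

# Example: exercising FAISSIndex directly with random vectors (a minimal
# sketch; `metadata_list` stands in for 10 ChunkMetadata objects produced by
# src.ingest and is not constructed here):
#
#     vectors = np.random.rand(10, 384).astype("float32")
#     index = FAISSIndex(dimension=384)
#     index.add_embeddings(vectors, metadata_list)
#     distances, indices = index.search(vectors[0], k=3)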


class EmbeddingPipeline:
    """Main class for embedding generation and index management."""

    def __init__(self, config: dict[str, Any]):
        """
        Initialize embedding pipeline.

        Args:
            config: Configuration dictionary
        """
        self.config = config
        # Initialize embedding model
        embedding_cfg = config.get("embedding", {})
        embedding_config = EmbeddingConfig(
            model_name=embedding_cfg.get("model_name", "all-MiniLM-L6-v2"),
            normalize_embeddings=embedding_cfg.get("normalize_embeddings", True),
            device=embedding_cfg.get("device", "cpu"),
            similarity_threshold=embedding_cfg.get("similarity_threshold", 0.7),
            top_k=embedding_cfg.get("top_k", 5),
        )
        self.embedding_model = EmbeddingModel(embedding_config)
        self.faiss_index = None
        self.index_loaded = False
        logger.info("Initialized EmbeddingPipeline")

    @log_performance
    def create_embeddings_from_chunks(self, chunks: list[DocumentChunk]) -> None:
        """
        Create embeddings from document chunks and build FAISS index.

        Args:
            chunks: List of document chunks
        """
        if not chunks:
            logger.warning("No chunks provided for embedding")
            return
        logger.info(f"Creating embeddings for {len(chunks)} chunks")
        log_memory_usage(logger, "Before embedding creation")
        # Extract texts and metadata
        texts = [chunk.text for chunk in chunks]
        metadata = [chunk.metadata for chunk in chunks]
        # Get batch size from config
        batch_size = self.config.get("system", {}).get("batch_size", 32)
        # Process embeddings in batches
        all_embeddings = []
        num_batches = (len(texts) + batch_size - 1) // batch_size
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i : i + batch_size]
            logger.debug(
                f"Processing embedding batch {i // batch_size + 1}/{num_batches}"
            )
            try:
                batch_embeddings = self.embedding_model.generate_embeddings(
                    batch_texts
                )
                all_embeddings.extend(batch_embeddings)
            except Exception as e:
                logger.error(f"Error encoding batch: {e}")
                # Add zero embeddings for the failed batch to keep alignment
                # with the metadata list
                dimension = (
                    self.embedding_model.model.get_sentence_embedding_dimension()
                )
                zero_embeddings = [np.zeros(dimension)] * len(batch_texts)
                all_embeddings.extend(zero_embeddings)
        # Convert to numpy array
        embeddings = np.array(all_embeddings)
        # Create FAISS index
        dimension = embeddings.shape[1]
        self.faiss_index = FAISSIndex(dimension)
        # Add embeddings to index
        self.faiss_index.add_embeddings(embeddings, metadata)
        logger.info(f"Created embeddings with dimension {dimension}")
        log_memory_usage(logger, "After embedding creation")

    def save_index(self, index_path: Path) -> None:
        """
        Save the FAISS index and metadata.

        Args:
            index_path: Path to save index
        """
        if self.faiss_index is None:
            raise ValueError(
                "No index to save. Run create_embeddings_from_chunks first."
            )
        self.faiss_index.save_index(index_path)
        # Save embedding model info
        model_info = {
            "model_name": self.embedding_model.config.model_name,
            "normalize_embeddings": self.embedding_model.config.normalize_embeddings,
            "device": self.embedding_model.config.device,
            "similarity_threshold": self.embedding_model.config.similarity_threshold,
            "top_k": self.embedding_model.config.top_k,
        }
        model_info_file = index_path / "model_info.json"
        with open(model_info_file, "w", encoding="utf-8") as f:
            json.dump(model_info, f, indent=2)

    def load_index(self, index_path: Path) -> None:
        """
        Load the FAISS index and metadata.

        Args:
            index_path: Path to load index from
        """
        if not index_path.exists():
            raise FileNotFoundError(f"Index directory not found: {index_path}")
        # Load FAISS index
        self.faiss_index = FAISSIndex(384)  # Default dimension; updated on load
        self.faiss_index.load_index(index_path)
        # Load model info
        model_info_file = index_path / "model_info.json"
        if model_info_file.exists():
            with open(model_info_file, encoding="utf-8") as f:
                model_info = json.load(f)
            # Update embedding model config
            self.embedding_model.config.model_name = model_info.get(
                "model_name", "all-MiniLM-L6-v2"
            )
            self.embedding_model.config.normalize_embeddings = model_info.get(
                "normalize_embeddings", True
            )
            self.embedding_model.config.device = model_info.get("device", "cpu")
            self.embedding_model.config.similarity_threshold = model_info.get(
                "similarity_threshold", 0.7
            )
            self.embedding_model.config.top_k = model_info.get("top_k", 5)
        self.index_loaded = True
        logger.info(
            f"Loaded index with {self.faiss_index.get_total_embeddings()} embeddings"
        )

    def search_similar_chunks(
        self, query: str, top_k: int | None = None
    ) -> list[tuple[DocumentChunk, float]]:
        """
        Search for chunks similar to the query.

        Args:
            query: Query text
            top_k: Number of results to return (uses config default if None)

        Returns:
            List of (chunk, similarity_score) tuples
        """
        if self.faiss_index is None:
            raise ValueError("No index loaded. Load index first.")
        if top_k is None:
            top_k = self.embedding_model.config.top_k
        # Generate query embedding
        query_embedding = self.embedding_model.generate_single_embedding(query)
        # Search index
        distances, indices = self.faiss_index.search(query_embedding, top_k)
        # Debug logging
        logger.debug(f"FAISS search returned {len(distances)} results")
        logger.debug(f"Distances: {distances}")
        logger.debug(f"Indices: {indices}")
        # Convert to similarity scores (for IP index, higher is better)
        if self.faiss_index.index_type == "IndexFlatIP":
            similarities = distances
        else:
            # For L2 distance, convert to similarity (lower distance = higher similarity)
            similarities = 1.0 / (1.0 + distances)
        logger.debug(f"Similarities: {similarities}")
        # Return all results (threshold filtering will be done by the query engine)
        results = []
        for idx, similarity in zip(indices, similarities, strict=False):
            metadata = self.faiss_index.get_chunk_by_index(idx)
            if metadata:
                # Reconstruct chunk (we don't store text in the index to save
                # space; a real implementation might store text or load it
                # from chunks.json)
                chunk = DocumentChunk(
                    text="[Text not stored in index]", metadata=metadata
                )
                results.append((chunk, float(similarity)))
        logger.debug(f"Returning {len(results)} results")
        return results

    def get_index_stats(self) -> dict[str, Any]:
        """
        Get statistics about the index.

        Returns:
            Dictionary with index statistics
        """
        if self.faiss_index is None:
            return {"error": "No index loaded"}
        return {
            "total_embeddings": self.faiss_index.get_total_embeddings(),
            "dimension": self.faiss_index.dimension,
            "index_type": self.faiss_index.index_type,
            "model_name": self.embedding_model.config.model_name,
            "similarity_threshold": self.embedding_model.config.similarity_threshold,
            "top_k": self.embedding_model.config.top_k,
        }
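

# Example: building and querying an index in memory (a minimal sketch;
# `chunks` is a list of DocumentChunk objects produced by src.ingest):
#
#     pipeline = EmbeddingPipeline({"embedding": {"device": "cpu"}})
#     pipeline.create_embeddings_from_chunks(chunks)
#     for chunk, score in pipeline.search_similar_chunks("example query"):
#         print(score, chunk.metadata)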


def create_embeddings_from_chunks_file(
    chunks_file: Path, config: dict[str, Any], output_path: Path
) -> None:
    """
    Create embeddings from a chunks.json file.

    Args:
        chunks_file: Path to chunks.json file
        config: Configuration dictionary
        output_path: Path to save index
    """
    # Load chunks
    with open(chunks_file, encoding="utf-8") as f:
        chunks_data = json.load(f)
    # Convert to DocumentChunk objects
    chunks = []
    for chunk_data in chunks_data:
        metadata = ChunkMetadata(**chunk_data["metadata"])
        chunk = DocumentChunk(text=chunk_data["text"], metadata=metadata)
        chunks.append(chunk)
    logger.info(f"Loaded {len(chunks)} chunks from {chunks_file}")
    # Create embeddings
    pipeline = EmbeddingPipeline(config)
    pipeline.create_embeddings_from_chunks(chunks)
    # Save index
    pipeline.save_index(output_path)
    # Log stats
    stats = pipeline.get_index_stats()
    logger.info(f"Created index with {stats['total_embeddings']} embeddings")
    logger.info(f"Index dimension: {stats['dimension']}")
    logger.info(f"Model used: {stats['model_name']}")


def load_embedding_pipeline(
    config: dict[str, Any], index_path: Path
) -> EmbeddingPipeline:
    """
    Load an embedding pipeline with existing index.

    Args:
        config: Configuration dictionary
        index_path: Path to index directory

    Returns:
        Loaded EmbeddingPipeline
    """
    pipeline = EmbeddingPipeline(config)
    pipeline.load_index(index_path)
    return pipeline
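

# Minimal end-to-end sketch (assumes a chunks.json produced by src.ingest;
# the paths, config, and query below are illustrative, not part of the API):
if __name__ == "__main__":
    example_config: dict[str, Any] = {"embedding": {"device": "cpu"}}
    # Build and persist an index from pre-chunked documents
    create_embeddings_from_chunks_file(
        Path("data/chunks.json"), example_config, Path("data/index")
    )
    # Reload the index and run a similarity search against it
    pipeline = load_embedding_pipeline(example_config, Path("data/index"))
    query = "What does the corpus cover?"
    for chunk, score in pipeline.search_similar_chunks(query):
        print(f"{score:.3f} -> {chunk.metadata}")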