# vector_store.py — FAISS-backed vector store with disk persistence.
import os
import numpy as np
import faiss
import pickle
from typing import List, Dict, Any
class FAISSVectorStore:
    """In-memory FAISS vector store with optional disk persistence.

    Embeddings are held in an exact L2 (Euclidean) flat index; document
    metadata is kept in a parallel Python list so that FAISS row ids map
    directly onto ``self.documents`` positions.
    """

    def __init__(self, dimension: int = 768):
        """
        Initialize a FAISS vector store.

        Args:
            dimension: Dimension of the embeddings.
        """
        self.dimension = dimension
        self.index = faiss.IndexFlatL2(dimension)  # exact L2-distance search
        self.documents: List[Dict[str, Any]] = []

    def index_documents(self, documents: List[Dict[str, Any]]):
        """
        Index a list of documents.

        Documents lacking a numpy ``'embedding'`` field are skipped with a
        warning. Accepted documents are stored *without* their embedding to
        save memory — the vector lives only in the FAISS index.

        Args:
            documents: List of document dictionaries with 'embedding'
                (np.ndarray) and other metadata.
        """
        if not documents:
            return

        embeddings = []
        stored_docs = []
        for doc in documents:
            if isinstance(doc.get('embedding'), np.ndarray):
                embeddings.append(doc['embedding'])
                # Store document without embedding to save memory.
                doc_copy = doc.copy()
                doc_copy.pop('embedding', None)
                stored_docs.append(doc_copy)
            else:
                print(f"Warning: Document missing embedding, skipping: {doc.get('id', 'unknown')}")

        if not embeddings:
            print("No valid embeddings found in documents")
            return

        # FAISS requires a contiguous float32 matrix.
        embeddings_array = np.array(embeddings).astype('float32')
        self.index.add(embeddings_array)
        self.documents.extend(stored_docs)
        print(f"Added {len(embeddings)} documents to index. Total: {len(self.documents)}")

    def search(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Search for similar documents using a query embedding.

        Args:
            query_embedding: The embedding vector of the query; a 1-D vector
                is reshaped to a single-row query batch.
            top_k: Number of results to return (capped at the store size).

        Returns:
            List of document dictionaries, each with a ``'score'`` key in
            (0, 1] derived from the L2 distance (closer => higher).
        """
        if not self.documents:
            return []

        # FAISS expects a 2-D float32 array of shape (n_queries, dim).
        if query_embedding.ndim == 1:
            query_embedding = np.expand_dims(query_embedding, axis=0)
        query_embedding = query_embedding.astype('float32')

        top_k = min(top_k, len(self.documents))
        distances, indices = self.index.search(query_embedding, top_k)

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS pads missing results with index -1; guard against it
            # and against any id outside our metadata list.
            if 0 <= idx < len(self.documents):
                doc = self.documents[idx].copy()
                doc['score'] = float(1.0 / (1.0 + dist))  # map distance to a similarity score
                results.append(doc)
        return results

    def save(self, filepath: str):
        """
        Save the index and documents to disk.

        The FAISS index is written to ``filepath`` and the document metadata
        is pickled alongside it as ``filepath + '.documents'``.

        Args:
            filepath: Path to save the index.
        """
        parent = os.path.dirname(filepath)
        # dirname('') occurs for bare filenames; os.makedirs('') would raise.
        if parent:
            os.makedirs(parent, exist_ok=True)

        faiss.write_index(self.index, filepath)
        with open(f"{filepath}.documents", 'wb') as f:
            pickle.dump(self.documents, f)
        print(f"Saved index with {len(self.documents)} documents to {filepath}")

    def load(self, filepath: str):
        """
        Load the index and documents from disk.

        Args:
            filepath: Path to load the index from.

        Raises:
            FileNotFoundError: If the index file does not exist.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"Index file not found: {filepath}")
        self.index = faiss.read_index(filepath)
        # Keep the declared dimension consistent with the loaded index.
        self.dimension = self.index.d

        doc_path = f"{filepath}.documents"
        if os.path.exists(doc_path):
            # NOTE: pickle.load executes arbitrary code on malicious input —
            # only load files produced by save() from a trusted source.
            with open(doc_path, 'rb') as f:
                self.documents = pickle.load(f)
        else:
            # Without the sidecar file, keeping old metadata would pair stale
            # documents with the freshly loaded vectors — reset instead.
            self.documents = []
        print(f"Loaded index with {len(self.documents)} documents from {filepath}")