import json
import os
from pathlib import Path
from typing import Optional
import chromadb
from llama_index.core import Settings, StorageContext, VectorStoreIndex
#from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from indexer import DEFAULT_LOCAL_EMBEDDING_MODEL_NAME, build_embedding_function
def create_retriever_from_env(
    vector_store_path: str,
    similarity_top_k = 3
):
    """
    Build a Retriever whose embedding settings are read from the environment.

    Args:
        vector_store_path: Forwarded to Retriever as the vector store path
        similarity_top_k: Forwarded to Retriever unchanged (defaults to 3)

    Returns:
        A Retriever constructed with EMBED_ENDPOINT as its embed endpoint
        (None when the variable is unset) and EMBED_MODEL as its embedding
        model name (DEFAULT_LOCAL_EMBEDDING_MODEL_NAME when unset).
    """
    # Resolve both settings up front so the constructor call stays readable.
    endpoint = os.environ.get('EMBED_ENDPOINT')
    model_name = os.environ.get('EMBED_MODEL', DEFAULT_LOCAL_EMBEDDING_MODEL_NAME)

    return Retriever(
        vector_store_path,
        embed_endpoint=endpoint,
        embedding_model_name=model_name,
        similarity_top_k=similarity_top_k,
    )
class Retriever:
    """Look up document chunks in a persisted ChromaDB collection through a LlamaIndex retriever."""

    def __init__(self, vector_store_path: str,
                 embed_endpoint: Optional[str] = None,
                 embedding_model_name: Optional[str] = DEFAULT_LOCAL_EMBEDDING_MODEL_NAME,
                 similarity_top_k=3
                 ):
        """
        Build the embedding model and open the index.

        Args:
            vector_store_path: Path to the ChromaDB vector store directory
            embed_endpoint: Passed to build_embedding_function as embed_endpoint
            embedding_model_name: Passed to build_embedding_function as the model name
            similarity_top_k: Passed to the index retriever as similarity_top_k
        """
        self.vector_store_path = Path(vector_store_path)
        self.embed_model = build_embedding_function(
            embedding_model_name,
            embed_endpoint=embed_endpoint,
        )
        self.similarity_top_k = similarity_top_k

        # Opening the index is part of construction; the instance is ready to
        # query as soon as __init__ returns.
        self._load_index()

    def _load_index(self) -> None:
        """Open the "documents" collection and set self.index and self.retriever."""
        # Persistent client rooted at the configured directory.
        client = chromadb.PersistentClient(path=str(self.vector_store_path))
        collection = client.get_or_create_collection("documents")
        store = ChromaVectorStore(chroma_collection=collection)

        # Storage context wrapping the Chroma-backed vector store.
        context = StorageContext.from_defaults(vector_store=store)

        # Index over the existing vector store, using this instance's embedding model.
        self.index = VectorStoreIndex.from_vector_store(
            vector_store=store,
            storage_context=context,
            embed_model=self.embed_model,
        )

        # Plain retriever: only similarity_top_k is configured, no LLM is passed.
        self.retriever = self.index.as_retriever(
            similarity_top_k=self.similarity_top_k,
        )

    def query(self, question: str):
        """
        Retrieve the chunks that match a question.

        Args:
            question: The question to ask

        Returns:
            A list with one dict per retrieved node, each holding the node's
            'text' and 'metadata'
        """
        hits = self.retriever.retrieve(question)
        return [{'text': hit.text, 'metadata': hit.metadata} for hit in hits]

    def query_as_text(self, question: str):
        """
        Run query() and serialize the result.

        Args:
            question: The question to ask

        Returns:
            The chunk list from query() as a JSON string indented by 4 spaces
        """
        chunks = self.query(question)
        return json.dumps(chunks, indent=4)