"""Helper functions to perform a query on the chroma database"""
import os
from datetime import datetime
from typing import Optional
from haystack import Pipeline
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack_integrations.document_stores.chroma import ChromaDocumentStore
from haystack_integrations.components.retrievers.chroma import ChromaEmbeddingRetriever
from parameters import EMBEDDING_MODEL, COLLECTION_NAME, PERSIST_PATH, DEFAULT_TOP_K
# Module-level caches for the lazily created singletons. Both start as None
# and are populated on first use by get_document_store() and
# get_retrieval_pipeline() respectively, then reused by later calls.
_document_store: Optional[ChromaDocumentStore] = None
_retrieval_pipeline: Optional[Pipeline] = None
def get_document_store() -> ChromaDocumentStore:
    """Return the shared ChromaDB document store, building it on first use.

    The store is cached in the module-level ``_document_store`` variable, so
    only the first call constructs a ``ChromaDocumentStore``; later calls
    return that same instance.

    Returns:
        The cached ``ChromaDocumentStore`` for ``COLLECTION_NAME``, persisted
        under ``PERSIST_PATH`` resolved relative to this module's directory.
    """
    global _document_store

    # Fast path: the store has already been created by an earlier call.
    if _document_store is not None:
        return _document_store

    # Anchor the persistence directory to this file rather than the current
    # working directory.
    module_dir = os.path.dirname(__file__)
    store_location = os.path.join(module_dir, PERSIST_PATH)

    _document_store = ChromaDocumentStore(
        collection_name=COLLECTION_NAME,
        persist_path=store_location,
        embedding_function="default",
    )
    return _document_store
def get_retrieval_pipeline(top_k: int = 5) -> Pipeline:
    """Return the shared retrieval pipeline, building it on first use.

    The pipeline consists of a ``SentenceTransformersTextEmbedder`` named
    ``"embedder"`` whose ``embedding`` output feeds the ``query_embedding``
    input of a ``ChromaEmbeddingRetriever`` named ``"retriever"``. It is
    cached in the module-level ``_retrieval_pipeline`` variable.

    Args:
        top_k: Number of documents the retriever should return. On the first
            call it is passed to the retriever's constructor; on later calls
            the cached retriever's ``top_k`` is updated to this value so the
            argument is never silently ignored.

    Returns:
        The cached ``Pipeline`` instance (the same object on every call).
    """
    global _retrieval_pipeline

    if _retrieval_pipeline is None:
        document_store = get_document_store()

        pipeline = Pipeline()
        pipeline.add_component(
            "embedder",
            SentenceTransformersTextEmbedder(model=EMBEDDING_MODEL),
        )
        pipeline.add_component(
            "retriever",
            ChromaEmbeddingRetriever(document_store=document_store, top_k=top_k),
        )
        pipeline.connect("embedder.embedding", "retriever.query_embedding")
        _retrieval_pipeline = pipeline
    else:
        # The pipeline was built by an earlier call, possibly with a different
        # top_k. Re-sync the cached retriever instead of rebuilding the whole
        # pipeline (which would construct a new embedder).
        retriever = _retrieval_pipeline.get_component("retriever")
        if retriever.top_k != top_k:
            retriever.top_k = top_k

    return _retrieval_pipeline
# Maps each restrictive ``source`` value to the substring that must appear in
# a document's ``source_file`` metadata for it to belong to that source.
_SOURCE_MARKERS = {"documents": "Documents", "repos": "Repos"}


def _utc_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string with a ``Z`` suffix."""
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # time and drop tzinfo so the rendered text keeps the exact same format.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"


def _source_file_of(doc) -> str:
    """Return the document's ``source_file`` metadata, or ``""`` if unset."""
    # ``or ""`` also covers an explicit None value, which would otherwise make
    # the substring checks below raise TypeError.
    return doc.meta.get("source_file") or ""


def _format_document(doc) -> dict:
    """Convert a retrieved document into the result-item dictionary."""
    meta = doc.meta
    score = getattr(doc, "score", None)
    # Any metadata beyond the two guaranteed keys is passed through unchanged.
    extra = {
        key: value
        for key, value in meta.items()
        if key not in ("file_name", "source_file")
    }
    return {
        "content": doc.content,
        "score": float(score) if score else 0.0,
        "metadata": {
            "file_name": meta.get("file_name", "Unknown"),
            "source_file": meta.get("source_file", "Unknown"),
            **extra,
        },
    }


def _build_response(
    status: str, query: str, source: str, results: list, docs: int, repos: int
) -> dict:
    """Assemble the response envelope shared by success and error paths."""
    return {
        "status": status,
        "query": query,
        "source": source,
        "timestamp": _utc_timestamp(),
        "results": results,
        "metadata": {
            "documents_searched": docs,
            "repos_searched": repos,
            "total_matches": len(results),
        },
    }


def perform_rag_lookup(
    query: str, source: str = "both", top_k: int = DEFAULT_TOP_K
) -> dict:
    """
    Perform RAG lookup in the ChromaDB database.

    The query is embedded and matched by the shared retrieval pipeline. The
    ``source`` filter is applied *after* retrieval, so a filtered lookup can
    return fewer than ``top_k`` results.

    Args:
        query: The search query
        source: Where to search ('documents', 'repos', or 'both'). Any other
            value matches no documents.
        top_k: Number of results to retrieve

    Returns:
        Dictionary with keys ``status`` ('success' or 'error'), ``query``,
        ``source``, ``timestamp`` (UTC, ISO-8601 with 'Z' suffix), ``results``
        (list of ``content`` / ``score`` / ``metadata`` items) and
        ``metadata`` (counts computed over the returned documents). On failure
        the exception is not raised; ``status`` is 'error', an ``error`` key
        holds the message and ``results`` is empty.
    """
    try:
        pipeline = get_retrieval_pipeline(top_k=top_k)

        # The pipeline is a cached singleton, so the top_k it was constructed
        # with may differ from this call's; pass it at run time so the
        # requested value is always the one used.
        result = pipeline.run(
            {"embedder": {"text": query}, "retriever": {"top_k": top_k}}
        )
        documents = result.get("retriever", {}).get("documents", [])

        # Filter by source if needed (unknown sources have no marker and
        # therefore match nothing).
        if source != "both":
            marker = _SOURCE_MARKERS.get(source)
            documents = [
                doc
                for doc in documents
                if marker is not None and marker in _source_file_of(doc)
            ]

        results = [_format_document(doc) for doc in documents]

        # Count documents by source; a path containing both markers is
        # counted as a document only.
        docs_count = 0
        repos_count = 0
        for doc in documents:
            source_file = _source_file_of(doc)
            if _SOURCE_MARKERS["documents"] in source_file:
                docs_count += 1
            elif _SOURCE_MARKERS["repos"] in source_file:
                repos_count += 1

        return _build_response(
            "success", query, source, results, docs_count, repos_count
        )
    except Exception as e:
        # Deliberate boundary: callers receive a structured error payload
        # instead of an exception.
        response = _build_response("error", query, source, [], 0, 0)
        response["error"] = str(e)
        return response
def test_query(
    query: str = "test query", source: str = "both", top_k: int = DEFAULT_TOP_K
):
    """
    Run a RAG lookup and print the outcome to stdout for manual inspection.

    Prints a header with the parameters, then one bold summary line (rank,
    score, file name) and one content line per result. Because
    ``perform_rag_lookup`` reports failures in its return value rather than
    raising, a failed lookup is printed explicitly so it cannot be mistaken
    for an empty result set.

    Args:
        query: The search query to test
        source: Where to search ('documents', 'repos', or 'both')
        top_k: Number of results to return
    """
    divider = "=" * 60
    print(f"\n{divider}")
    print("Testing RAG Lookup")
    print(divider)
    print(f"Query: {query}")
    print(f"Source: {source}")
    print(f"Top K: {top_k}")
    print(f"{divider}\n")

    # Perform the lookup
    result = perform_rag_lookup(query, source, top_k)

    # Surface failures instead of silently printing nothing.
    if result.get("status") == "error":
        print(f"Lookup failed: {result.get('error')}")
        return

    matches = result["results"]
    if not matches:
        print("No results found.")
        return

    for rank, item in enumerate(matches, start=1):
        # \033[1m ... \033[0m renders the summary line in bold on ANSI terminals.
        print(
            f"\033[1m{rank}. score={item['score']} document={item['metadata'].get('file_name')}\033[0m"
        )
        print(f"content={item['content']}")
if __name__ == "__main__":
    # Manual smoke test: edit the values below to try a different query,
    # source filter, or result count.
    sample_lookup = {"query": "xml format", "source": "both", "top_k": 5}
    test_query(**sample_lookup)
    # Trailing blank line to separate the output from the shell prompt.
    print()