import os
import sys
from typing import List

import requests
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from rich.console import Console

from rich_utils import (
    display_answer,
    display_collection_info,
    display_error,
    display_instructions,
    display_reranking_info,
    display_search_results,
    display_separator,
    display_welcome_banner,
    get_user_question,
    show_processing_step,
)

load_dotenv()
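
# All service endpoints and credentials are read from a .env file.
# Variables used below (values are deployment-specific):
#   QDRANT_URL, QDRANT_API_KEY                         - Qdrant vector database
#   EMBEDDING_URL, EMBEDDING_MODEL, EMBEDDING_API_KEY  - embedding service
#   RERANK_URL, RERANK_MODEL                           - reranking service
#   LLM_SERVICE_CHAT_COMPLETIONS_URL, LLM_SERVICE_API_KEY, LLM_SERVICE_MODEL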


def create_connection_qdrant():
    """Create a Qdrant client from the QDRANT_URL / QDRANT_API_KEY settings."""
    url = os.getenv("QDRANT_URL")
    api_key = os.getenv("QDRANT_API_KEY")
    return QdrantClient(url=url, api_key=api_key)


def vector_search(
    collection_name: str, query_vector: List[float], limit: int = 5
) -> list:
    """
    Search a Qdrant collection for the points closest to the query vector.

    :param collection_name: name of the Qdrant collection to search
    :param query_vector: embedding of the user query
    :param limit: maximum number of closest points to return
    :return: list of scored points as plain dicts
    """
    client = create_connection_qdrant()
    hits = client.search(
        collection_name=collection_name, query_vector=query_vector, limit=limit
    )
    return [el.model_dump() for el in hits]
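
# For reference, each dict returned by vector_search() is a dumped Qdrant
# ScoredPoint; a hypothetical hit looks roughly like:
#   {"id": 42, "version": 3, "score": 0.83, "payload": {"text": "..."}, "vector": None}
# display_search_results() in main() is expected to pull the document text
# out of these payloads.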


def rerank_documents(query: str, documents: List[str]) -> dict:
    """
    Rerank candidate documents against the user query via the rerank service.

    :param query: user query
    :param documents: candidate documents to rerank
    :return: rerank service response with per-document relevance scores
    """
    url = os.getenv("RERANK_URL")
    rerank_model = os.getenv("RERANK_MODEL")
    data = {"model": rerank_model, "query": query, "documents": documents}
    response = requests.post(url, json=data, timeout=60)
    response.raise_for_status()
    return response.json()
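
# main() consumes the rerank response via result["results"], so the service is
# assumed to follow the common rerank API shape, e.g.:
#   {"results": [{"index": 2, "relevance_score": 0.97},
#                {"index": 0, "relevance_score": 0.41}, ...]}
# where "index" refers back into the submitted documents list.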


def llm_chat_completion(text: str) -> str:
    """
    Send a single-turn chat completion request and return the model's answer.

    :param text: prompt to send as the user message
    :return: generated answer text
    """
    url = os.getenv("LLM_SERVICE_CHAT_COMPLETIONS_URL")
    api_key = os.getenv("LLM_SERVICE_API_KEY")
    model_name = os.getenv("LLM_SERVICE_MODEL")
    data = {"model": model_name, "messages": [{"role": "user", "content": text}]}
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    response = requests.post(url, json=data, headers=headers, timeout=120)
    response.raise_for_status()
    result = response.json()
    return result["choices"][0]["message"]["content"]
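
# The chat endpoint is assumed to be OpenAI-compatible; the indexing above
# matches a response of the form:
#   {"choices": [{"message": {"role": "assistant", "content": "..."}}], ...}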


def get_embedding(text: str) -> List[float]:
    """Generate an embedding for the input text using the embedding service."""
    embedding_url = os.getenv("EMBEDDING_URL")
    embedding_model = os.getenv("EMBEDDING_MODEL")
    headers = {
        "Authorization": f"Bearer {os.getenv('EMBEDDING_API_KEY')}",
        "Content-Type": "application/json",
    }
    data = {"model": embedding_model, "input": text}
    response = requests.post(embedding_url, headers=headers, json=data, timeout=60)
    response.raise_for_status()
    return response.json()["data"][0]["embedding"]
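
# Likewise assumes an OpenAI-style embeddings response:
#   {"data": [{"embedding": [0.01, -0.02, ...], "index": 0}], ...}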


def main():
    """Interactive RAG loop: embed the query, search, rerank, and answer."""
    console = Console()
    if len(sys.argv) != 2:
        console.print("[red]❌ Usage: python rag.py <collection_name>[/red]")
        sys.exit(1)
collection_name = sys.argv[1]
# Display welcome interface
display_welcome_banner(console)
display_collection_info(collection_name, console)
display_instructions(console)
query_count = 0
while True:
try:
# Get user question
query = get_user_question(console)
# Handle special commands
if query.lower() in ["exit", "quit"]:
console.print(
"\n[cyan]π Thank you for using RAG Assistant! Goodbye![/cyan]"
)
break
elif query.lower() == "clear":
console.clear()
display_welcome_banner(console)
display_collection_info(collection_name, console)
display_instructions(console)
continue
elif not query:
continue
query_count += 1
console.print(f"\n[dim]Query #{query_count}[/dim]")
# Step 1: Generate embedding
show_processing_step("π§ Generating query embedding", console, "blue")
query_vector = get_embedding(query)
# Step 2: Search documents
show_processing_step("π Searching relevant documents", console, "cyan")
results = vector_search(collection_name, query_vector, limit=5)
# Display search results
documents = display_search_results(results, console)
if not documents:
display_separator(console)
continue
# Step 3: Rerank documents
show_processing_step("π― Reranking documents by relevance", console, "magenta")
reranked = rerank_documents(query, documents)
# Display reranking info
display_reranking_info(reranked, console)
# Get top 3 documents
top_docs = [
documents[result["index"]]
for result in sorted(
reranked["results"],
key=lambda x: x["relevance_score"],
reverse=True,
)[:3]
]
# Prepare context
context = "\n".join(f"{i + 1}. {doc}" for i, doc in enumerate(top_docs))
prompt = f"""Based on the following information:
{context}
Give me a short answer to the question: {query}
Answer:"""
# Step 4: Generate answer
show_processing_step("β¨ Generating AI response", console, "green")
answer = llm_chat_completion(prompt)
# Display final answer
display_answer(answer, console)
# Add separator between queries
display_separator(console)
except KeyboardInterrupt:
console.print("\n[yellow]β οΈ Interrupted by user[/yellow]")
console.print("[cyan]π Goodbye![/cyan]")
break
        except Exception as e:
            display_error(str(e), console)
            console.print_exception(show_locals=True)
            display_separator(console)


if __name__ == "__main__":
    main()