rag_integration.py
#!/usr/bin/env python
"""
RAG Integration Module for MCP Server
This module provides integration between the FAISS vector store
and LLM APIs for Retrieval-Augmented Generation (RAG).
"""
import os
import json
import logging
import argparse
import re
from typing import List, Dict, Any, Optional
from dotenv import load_dotenv
import requests
from sentence_transformers import SentenceTransformer
from mcp_server.models.vector_store import FAISSVectorStore
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Check .env.example file for OPENAI_API_KEY
dotenv_path = os.path.join(os.path.dirname(__file__), '.env.example')
api_key_set = False
# Check if the key is set in the .env.example file
if os.path.exists(dotenv_path):
try:
with open(dotenv_path, 'r') as f:
content = f.read()
if re.search(r'OPENAI_API_KEY=\S+', content):
api_key_set = True
except Exception as e:
logger.error(f"Error reading .env.example file: {str(e)}")
# Load environment variables from the .env.example file (load_dotenv is a no-op if the file is missing)
load_dotenv(dotenv_path=dotenv_path)
logger.info("Loaded environment variables from .env.example")
# Check if OPENAI_API_KEY is properly set
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key or openai_api_key.strip() == "" or not api_key_set:
logger.warning("OPENAI_API_KEY is not set in .env.example file")
print("\n" + "=" * 80)
print("WARNING: OPENAI_API_KEY is not set or is empty in the .env.example file.")
print("To set it up:")
print(f"1. Open the file: {dotenv_path}")
print("2. Add your API key to the OPENAI_API_KEY line (e.g., OPENAI_API_KEY=sk-your-key-here)")
print("3. Save the file and run the script again")
print("")
print("Alternatively, you can provide the API key via command line:")
print("python -m mcp_server.rag_integration \"Your query\" --api-key=your_openai_api_key")
print("=" * 80 + "\n")
class RagLlmIntegration:
"""Integrates FAISS vector store with LLM APIs for RAG"""
    def __init__(self, index_file: Optional[str] = None, api_key: Optional[str] = None,
                 api_url: Optional[str] = None, model: Optional[str] = None):
"""
Initialize the RAG-LLM integration.
Args:
index_file: Path to FAISS index file
api_key: LLM API key (defaults to OPENAI_API_KEY env var)
api_url: LLM API URL (defaults to env var or OpenAI)
            model: LLM model name (defaults to env var or gpt-4o)
"""
# Set up embedding model
self.embedding_model = SentenceTransformer(
os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2")
)
# Set up vector store
self.vector_store = FAISSVectorStore()
# Default index file from env var or default path
self.index_file = index_file or os.getenv("INDEX_FILE", "data/faiss_index.bin")
# Load index if it exists
if os.path.exists(self.index_file):
logger.info(f"Loading index from {self.index_file}")
self.vector_store.load(self.index_file)
logger.info(f"Loaded index with {len(self.vector_store.documents)} documents")
else:
logger.error(f"Index file not found: {self.index_file}")
raise FileNotFoundError(f"Index file not found: {self.index_file}")
# LLM API configuration
self.api_key = api_key or os.getenv("OPENAI_API_KEY", "")
self.api_url = api_url or os.getenv(
"LLM_API_URL", "https://api.openai.com/v1/chat/completions"
)
self.model = model or os.getenv("LLM_MODEL", "gpt-4o")
        # Warn if no API key is available; LLM responses will be simulated
if not self.api_key:
logger.warning("No OpenAI API key provided. LLM responses will be simulated.")
print("\nWARNING: No OpenAI API key found. Set the OPENAI_API_KEY environment variable.")
print("You can set it temporarily for this run with: export OPENAI_API_KEY=your_api_key_here\n")
def retrieve_documents(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""
Retrieve relevant documents for a query.
Args:
query: The search query
top_k: Number of documents to retrieve
Returns:
List of retrieved documents with scores
"""
# Get query embedding
query_embedding = self.embedding_model.encode(query)
# Search vector store
results = self.vector_store.search(query_embedding, top_k)
logger.info(f"Retrieved {len(results)} documents for query: {query}")
# Print documents to console
print("\n" + "=" * 80)
print("RETRIEVED DOCUMENTS:")
for i, doc in enumerate(results):
print(f"\nDocument {i+1} (Score: {doc['score']:.4f}):")
print(f"Path: {doc['path']}")
print("-" * 50)
print(f"Content:\n{doc['content'][:500]}..." if len(doc['content']) > 500 else f"Content:\n{doc['content']}")
print("-" * 50)
print("=" * 80 + "\n")
return results
def format_retrieved_context(self, documents: List[Dict[str, Any]]) -> str:
"""
Format retrieved documents into a context string for the LLM.
Args:
documents: List of retrieved documents
Returns:
Formatted context string
"""
context = "\n\n".join([
f"Document {i+1} (Score: {doc['score']:.4f}):\n{doc['content']}"
for i, doc in enumerate(documents)
])
return context
def call_llm_api(self, query: str, context: str) -> str:
"""
Call LLM API with the query and retrieved context.
Args:
query: User query
context: Retrieved document context
Returns:
LLM response text
"""
if not self.api_key:
logger.warning("No LLM API key provided, simulating LLM response")
return self._simulate_llm_response(query, context)
# Create prompt with retrieved context
system_prompt = (
"You are a helpful assistant that answers questions about the Move programming "
"language and the Sui blockchain. Use the provided context to answer the question. "
"If the context doesn't contain the information needed, say so instead of making up "
"an answer."
)
user_prompt = f"Context:\n{context}\n\nQuestion: {query}"
# Prepare request based on API
if "openai" in self.api_url.lower():
# OpenAI-style API
payload = {
"model": self.model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
elif "anthropic" in self.api_url.lower():
# Anthropic-style API
payload = {
"model": self.model,
"max_tokens": 1024,
"messages": [
{"role": "user", "content": f"{system_prompt}\n\n{user_prompt}"}
]
}
            headers = {
                "Content-Type": "application/json",
                "x-api-key": self.api_key,
                # The Anthropic Messages API also expects a version header
                "anthropic-version": "2023-06-01"
            }
else:
# Generic API, best effort
payload = {
"model": self.model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
"temperature": 0.7
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
try:
logger.info(f"Calling LLM API with model: {self.model}")
print(f"\nSending request to LLM API using model: {self.model}...")
response = requests.post(
self.api_url,
headers=headers,
json=payload,
timeout=30
)
if response.status_code != 200:
error_msg = f"API error: {response.status_code} - {response.text}"
logger.error(error_msg)
# Provide a more helpful error message
if response.status_code == 401:
return "Error: Invalid API key. Please provide a valid OpenAI API key."
elif response.status_code == 404:
return f"Error: Model '{self.model}' not found. Try using a different model name like 'gpt-4o' or 'gpt-3.5-turbo'."
elif response.status_code == 429:
return "Error: Rate limit exceeded. Please try again later or check your OpenAI account usage limits."
else:
return f"Error calling LLM API: {response.status_code} - {response.text}"
result = response.json()
# Extract answer based on API response format
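            # Typical response shapes assumed here:
            #   OpenAI:    {"choices": [{"message": {"content": "..."}}]}
            #   Anthropic: {"content": [{"type": "text", "text": "..."}]}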
if "openai" in self.api_url.lower():
# OpenAI format
answer = result.get("choices", [{}])[0].get("message", {}).get("content", "")
elif "anthropic" in self.api_url.lower():
# Anthropic format
answer = result.get("content", [{}])[0].get("text", "")
else:
# Generic fallback
answer = result.get("choices", [{}])[0].get("message", {}).get("content", "")
if not answer and "output" in result:
answer = result.get("output", "")
return answer
except requests.exceptions.Timeout:
logger.error("Timeout error calling LLM API")
return "Error: Request to LLM API timed out. Please try again later."
except requests.exceptions.ConnectionError:
logger.error("Connection error calling LLM API")
return "Error: Could not connect to LLM API. Please check your internet connection."
except Exception as e:
logger.error(f"Error calling LLM API: {str(e)}")
return f"Error calling LLM API: {str(e)}"
def _simulate_llm_response(self, query: str, context: str) -> str:
"""
Simulate an LLM response for testing without an API key.
Args:
query: User query
context: Retrieved document context
Returns:
Simulated LLM response
"""
        # This is a simple simulation that returns a canned template response
        # based on keywords in the query; the retrieved context is not used
if "module" in query.lower():
return (
"Based on the provided context, a module in Sui Move is a fundamental code organization unit. "
"\n\nFrom the examples I can see:\n\n```move\nmodule sui::sui {\n // module contents\n}\n```\n\n"
"A module is defined using the `module` keyword followed by the module path (like `sui::sui`). "
"The module path typically follows the format `package_name::module_name`. Module contents are "
"enclosed in curly braces `{}`.\n\n"
"Modules contain various elements like:\n- Structs and resource definitions\n- Functions (public and private)\n"
"- Constants\n- Use statements for dependencies"
)
elif "coin" in query.lower() or "sui coin" in query.lower():
return (
"Based on the provided context, I can see that Coin<SUI> is the native token used in the Sui blockchain.\n\n"
"From sui.move:\n```move\n/// Coin<SUI> is the token used to pay for gas in Sui.\n"
"/// It has 9 decimals, and the smallest unit (10^-9) is called \"mist\".\n"
"module sui::sui {\n // ...\n const MIST_PER_SUI: u64 = 1_000_000_000;\n // ...\n}\n```\n\n"
"Key information about SUI coin:\n1. It is used to pay for gas (transaction fees) in the Sui blockchain\n"
"2. It has 9 decimal places\n3. The smallest unit (10^-9 SUI) is called \"mist\"\n"
"4. The conversion rate is 1 SUI = 1,000,000,000 mist"
)
elif "struct" in query.lower():
return (
"Based on the provided context, in Sui Move, a struct is a custom data type that can hold multiple fields. "
"Here's how to define a struct in Sui Move:\n\n"
"```move\nstruct Example {\n field1: u64,\n field2: String,\n field3: address\n}\n```\n\n"
"Structs in Sui Move can be used to represent both ordinary data and resources. Resources are special "
"structs that cannot be copied or implicitly discarded, only moved or explicitly destroyed."
)
else:
# Generic response
return (
"Based on the provided context, I can answer your question about Sui Move.\n\n"
"The Sui Move programming language is a safe and expressive language for writing smart contracts "
"on the Sui blockchain. It includes features like resource types, abilities, and modules, which "
"help developers create secure and efficient smart contracts.\n\n"
"For more specific information, please ask a more targeted question about Sui Move."
)
def process_query(self, query: str, top_k: int = 5) -> Dict[str, Any]:
"""
Process a query using the complete RAG pipeline.
Args:
query: User query
top_k: Number of documents to retrieve
Returns:
Dictionary with query, retrieved documents, and LLM response
"""
# Retrieve relevant documents
retrieved_docs = self.retrieve_documents(query, top_k)
# Format context from retrieved documents
context = self.format_retrieved_context(retrieved_docs)
# Call LLM with context
llm_response = self.call_llm_api(query, context)
# Return complete results
return {
"query": query,
"retrieved_documents": retrieved_docs,
"llm_response": llm_response
}
def main():
"""Command-line entry point for mcp-rag command"""
parser = argparse.ArgumentParser(description="RAG Query with LLM Integration")
parser.add_argument(
"query",
nargs="?",
help="The search query"
)
parser.add_argument(
"--index-file",
default=os.getenv("INDEX_FILE", "data/faiss_index.bin"),
help="Path to FAISS index file"
)
parser.add_argument(
"--api-key",
default=os.getenv("OPENAI_API_KEY", ""),
help="LLM API key"
)
parser.add_argument(
"--api-url",
default=os.getenv("LLM_API_URL", "https://api.openai.com/v1/chat/completions"),
help="LLM API URL"
)
parser.add_argument(
"--model",
default=os.getenv("LLM_MODEL", "gpt-4o"),
help="LLM model name"
)
parser.add_argument(
"--top-k",
type=int,
default=5,
help="Number of documents to retrieve"
)
parser.add_argument(
"--output-json",
action="store_true",
help="Output results as JSON"
)
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose logging"
)
args = parser.parse_args()
# Set logging level
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
# Check for required query
if not args.query:
parser.print_help()
print("\nError: Query is required")
return 1
try:
# Initialize RAG integration
rag = RagLlmIntegration(
index_file=args.index_file,
api_key=args.api_key,
api_url=args.api_url,
model=args.model
)
# Process query
result = rag.process_query(args.query, args.top_k)
if args.output_json:
# Output as JSON for programmatic use
print(json.dumps(result, indent=2, default=str))
else:
# Pretty print for human readability
print("\n" + "=" * 80)
print(f"QUERY: {result['query']}")
print("=" * 80 + "\n")
print("RETRIEVED DOCUMENTS:")
for i, doc in enumerate(result['retrieved_documents']):
print(f" {i+1}. {os.path.basename(doc['path'])} (Score: {doc['score']:.4f})")
print("\n" + "=" * 80)
print("LLM RESPONSE:\n")
print(result['llm_response'])
print("\n" + "=" * 80)
return 0
except Exception as e:
logger.error(f"Error: {str(e)}")
return 1
if __name__ == "__main__":
    raise SystemExit(main())