"""
RAG (Retrieval Augmented Generation) system for PM Counter queries
Uses Groq API for LLM and database queries for retrieval
"""
import os
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func
from database import (
get_db, MeasurementInterval, SystemCounter, InterfaceCounter,
IPCounter, TCPCounter, BGPCounter, NetworkElement, FileRecord
)
import json
import logging
# Module-level logging: INFO by default so parse failures surface in app logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Groq is an optional dependency: when it is missing, RAGSystem degrades to
# plain-text answer formatting instead of LLM-generated responses.
try:
    from groq import Groq
    GROQ_AVAILABLE = True
except ImportError:
    GROQ_AVAILABLE = False
    logger.warning("Groq library not installed. Install with: pip install groq")
class RAGSystem:
    """RAG system for querying PM counter data.

    Retrieval is performed with structured SQLAlchemy queries against the PM
    counter tables; generation uses the Groq chat API when a client is
    configured and falls back to plain-text formatting otherwise.
    """

    def __init__(self):
        # The Groq client is optional: without the library or the GROQ_API_KEY
        # environment variable, answers come from _format_response_simple().
        self.groq_client = None
        if GROQ_AVAILABLE:
            api_key = os.getenv("GROQ_API_KEY")
            if api_key:
                self.groq_client = Groq(api_key=api_key)
            else:
                logger.warning("GROQ_API_KEY not set in environment")

    def query_database(self, query_type: str, params: Dict[str, Any], db: Session) -> Dict[str, Any]:
        """Dispatch *query_type* to the matching ``_query_*`` handler.

        Returns the handler's result dict, or an ``{"error": ...}`` dict for
        an unknown query type.
        """
        handlers = {
            "system_counter": self._query_system_counter,
            "interface_counter": self._query_interface_counter,
            "network_element": self._query_network_element,
            "time_range": self._query_time_range,
            "summary": self._query_summary,
        }
        handler = handlers.get(query_type)
        if handler is None:
            return {"error": f"Unknown query type: {query_type}"}
        return handler(params, db)

    @staticmethod
    def _apply_interval_time_filters(query, start_time, end_time):
        """Constrain *query* (already joined to MeasurementInterval) to the
        optional [start_time, end_time] window. Either bound may be None."""
        if start_time:
            query = query.filter(MeasurementInterval.start_time >= start_time)
        if end_time:
            query = query.filter(MeasurementInterval.end_time <= end_time)
        return query

    def _query_system_counter(self, params: Dict[str, Any], db: Session) -> Dict[str, Any]:
        """Return system counter samples.

        Supported *params* keys: ``counter_name`` (substring, case-insensitive),
        ``start_time``/``end_time`` (optional window), ``limit`` (default 100).
        """
        counter_name = params.get("counter_name", "")
        limit = params.get("limit", 100)
        query = db.query(
            SystemCounter.counter_name,
            SystemCounter.value,
            SystemCounter.unit,
            MeasurementInterval.start_time.label('timestamp')
        ).join(MeasurementInterval)
        if counter_name:
            # Case-insensitive substring match so e.g. "cpu" finds "cpuUtilization".
            query = query.filter(SystemCounter.counter_name.ilike(f"%{counter_name}%"))
        query = self._apply_interval_time_filters(
            query, params.get("start_time"), params.get("end_time")
        )
        results = query.order_by(MeasurementInterval.start_time.desc()).limit(limit).all()
        return {
            "type": "system_counter",
            "data": [
                {
                    "counter_name": r.counter_name,
                    "value": r.value,
                    "unit": r.unit,
                    "timestamp": r.timestamp.isoformat()
                }
                for r in results
            ],
            "count": len(results)
        }

    def _query_interface_counter(self, params: Dict[str, Any], db: Session) -> Dict[str, Any]:
        """Return interface counter samples.

        Supported *params* keys: ``interface_name`` and ``counter_name``
        (substring, case-insensitive), ``start_time``/``end_time`` (optional
        window), ``limit`` (default 100).
        """
        interface_name = params.get("interface_name", "")
        counter_name = params.get("counter_name", "")
        limit = params.get("limit", 100)
        query = db.query(
            InterfaceCounter.interface_name,
            InterfaceCounter.counter_name,
            InterfaceCounter.value,
            InterfaceCounter.unit,
            MeasurementInterval.start_time.label('timestamp')
        ).join(MeasurementInterval)
        if interface_name:
            query = query.filter(InterfaceCounter.interface_name.ilike(f"%{interface_name}%"))
        if counter_name:
            query = query.filter(InterfaceCounter.counter_name.ilike(f"%{counter_name}%"))
        query = self._apply_interval_time_filters(
            query, params.get("start_time"), params.get("end_time")
        )
        results = query.order_by(MeasurementInterval.start_time.desc()).limit(limit).all()
        return {
            "type": "interface_counter",
            "data": [
                {
                    "interface_name": r.interface_name,
                    "counter_name": r.counter_name,
                    "value": r.value,
                    "unit": r.unit,
                    "timestamp": r.timestamp.isoformat()
                }
                for r in results
            ],
            "count": len(results)
        }

    def _query_network_element(self, params: Dict[str, Any], db: Session) -> Dict[str, Any]:
        """Return network elements, optionally filtered by ``ne_name``
        substring (case-insensitive)."""
        ne_name = params.get("ne_name", "")
        query = db.query(NetworkElement)
        if ne_name:
            query = query.filter(NetworkElement.ne_name.ilike(f"%{ne_name}%"))
        results = query.all()
        return {
            "type": "network_element",
            "data": [
                {
                    "ne_name": r.ne_name,
                    "ne_type": r.ne_type,
                    "site": r.site,
                    "region": r.region,
                    "country": r.country,
                    "management_ip": r.management_ip
                }
                for r in results
            ],
            "count": len(results)
        }

    def _query_time_range(self, params: Dict[str, Any], db: Session) -> Dict[str, Any]:
        """Return all counters for every measurement interval fully inside
        [start_time, end_time]. Both bounds are required.

        NOTE(review): this issues two queries per interval (N+1 pattern);
        acceptable for small ranges, consider eager loading for large ones.
        """
        start_time = params.get("start_time")
        end_time = params.get("end_time")
        if not start_time or not end_time:
            return {"error": "start_time and end_time required"}
        intervals = db.query(MeasurementInterval).filter(
            and_(
                MeasurementInterval.start_time >= start_time,
                MeasurementInterval.end_time <= end_time
            )
        ).all()
        result_data = []
        for interval in intervals:
            system_counters = db.query(SystemCounter).filter(
                SystemCounter.interval_id == interval.id
            ).all()
            interface_counters = db.query(InterfaceCounter).filter(
                InterfaceCounter.interval_id == interval.id
            ).all()
            result_data.append({
                "interval_start": interval.start_time.isoformat(),
                "interval_end": interval.end_time.isoformat(),
                "system_counters": [
                    {"name": sc.counter_name, "value": sc.value, "unit": sc.unit}
                    for sc in system_counters
                ],
                "interface_counters": [
                    {
                        "interface": ic.interface_name,
                        "name": ic.counter_name,
                        "value": ic.value,
                        "unit": ic.unit
                    }
                    for ic in interface_counters[:10]  # Limit to avoid too much data
                ]
            })
        return {
            "type": "time_range",
            "data": result_data,
            "count": len(result_data)
        }

    def _query_summary(self, params: Dict[str, Any], db: Session) -> Dict[str, Any]:
        """Return whole-database summary statistics (file, interval and
        network-element counts). *params* is unused but kept for a uniform
        handler signature."""
        total_files = db.query(FileRecord).count()
        processed_files = db.query(FileRecord).filter(
            FileRecord.processed_at.isnot(None)
        ).count()
        total_intervals = db.query(MeasurementInterval).count()
        total_network_elements = db.query(NetworkElement).count()
        return {
            "type": "summary",
            "data": {
                "total_files": total_files,
                "processed_files": processed_files,
                "total_intervals": total_intervals,
                "total_network_elements": total_network_elements
            }
        }

    def extract_query_params(self, question: str) -> Dict[str, Any]:
        """Extract query parameters (date, time, counter/interface names)
        from a natural-language *question*.

        Returns a dict that may contain: ``date`` (datetime.date), ``hour``,
        ``minute`` (ints, 24-hour), ``counter_name``, ``interface_name``,
        ``query_type``.
        """
        import re
        from dateutil import parser
        params: Dict[str, Any] = {}
        question_lower = question.lower()
        date_patterns = [
            r"(\d{4}-\d{2}-\d{2})",  # YYYY-MM-DD
            r"(\d{1,2}/\d{1,2}/\d{4})",  # MM/DD/YYYY
            r"(january|february|march|april|may|june|july|august|september|october|november|december)\s+\d{1,2},?\s+\d{4}",
        ]
        time_patterns = [
            r"(\d{1,2}):(\d{2})\s*(am|pm)",  # 2:10 pm
            r"(\d{1,2}):(\d{2})",  # 14:10
        ]
        # Try to parse dates; stop at the first successful pattern so a
        # later, looser pattern cannot overwrite the result.
        for pattern in date_patterns:
            match = re.search(pattern, question, re.IGNORECASE)
            if match:
                try:
                    # Use the full match: for the "Month DD, YYYY" pattern,
                    # group(1) is only the month name and would drop day/year.
                    parsed_date = parser.parse(match.group(0), fuzzy=True)
                    params["date"] = parsed_date.date()
                    break
                except (ValueError, OverflowError) as e:
                    logger.debug(f"Failed to parse date: {e}")
        # Try to parse times. The break is essential: without it "2:10 pm"
        # (matched by the 12-hour pattern as 14:10) would also match the
        # 24-hour pattern and be clobbered back to 02:10.
        for pattern in time_patterns:
            match = re.search(pattern, question, re.IGNORECASE)
            if match:
                try:
                    hour = int(match.group(1))
                    minute = int(match.group(2))
                    # Normalize 12-hour clock to 24-hour when AM/PM captured.
                    am_pm = match.group(3).lower() if match.lastindex and match.lastindex >= 3 else None
                    if am_pm == 'pm' and hour != 12:
                        hour += 12
                    elif am_pm == 'am' and hour == 12:
                        hour = 0
                    params["hour"] = hour
                    params["minute"] = minute
                    break
                except ValueError as e:
                    logger.debug(f"Failed to parse time: {e}")
        # Map common keywords to known counter names.
        if "cpu" in question_lower and "utilization" in question_lower:
            params["counter_name"] = "cpuUtilization"
        elif "memory" in question_lower and "utilization" in question_lower:
            params["counter_name"] = "memoryUtilization"
        elif "temperature" in question_lower:
            params["counter_name"] = "temperature"
        elif "interface" in question_lower:
            params["query_type"] = "interface"
        # Extract interface names like "gigabitethernet0/1".
        interface_match = re.search(r"(gigabit|ten.?gigabit)?ethernet[\w/]+", question_lower)
        if interface_match:
            params["interface_name"] = interface_match.group(0)
        return params

    def generate_response(self, question: str, db_data: Dict[str, Any]) -> str:
        """Phrase an answer to *question* from the retrieved *db_data*.

        Uses the Groq chat API when a client is configured; on any API error
        (or with no client) falls back to simple deterministic formatting.
        """
        if not self.groq_client:
            # Fallback to simple formatting
            return self._format_response_simple(db_data)
        try:
            # Prepare context for LLM
            context = json.dumps(db_data, indent=2)
            prompt = f"""You are a helpful assistant that answers questions about network performance monitoring data.
User Question: {question}
Retrieved Data from Database:
{context}
Based on the retrieved data, provide a clear and concise answer to the user's question.
If the data shows multiple records, summarize the key findings.
If no data is found, explain that clearly.
Be specific with numbers, timestamps, and units when available.
"""
            response = self.groq_client.chat.completions.create(
                model="llama-3.1-70b-versatile",  # or "mixtral-8x7b-32768" for faster responses
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that answers questions about network performance monitoring data. Always be precise with numbers and timestamps."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.3,
                max_tokens=500
            )
            return response.choices[0].message.content
        except Exception as e:
            # Deliberate best-effort: any API failure degrades to the
            # deterministic formatter rather than surfacing an error.
            logger.error(f"Error generating response with Groq: {e}")
            return self._format_response_simple(db_data)

    def _format_response_simple(self, db_data: Dict[str, Any]) -> str:
        """Deterministic plain-text formatting used when Groq is unavailable."""
        if "error" in db_data:
            return f"Error: {db_data['error']}"
        if db_data.get("count", 0) == 0:
            return "No data found matching your query."
        data = db_data.get("data", [])
        query_type = db_data.get("type", "")
        if query_type == "system_counter":
            if len(data) == 1:
                d = data[0]
                return f"{d['counter_name']}: {d['value']} {d['unit']} (at {d['timestamp']})"
            else:
                result = f"Found {len(data)} records:\n\n"
                for d in data[:10]:
                    result += f"- {d['counter_name']}: {d['value']} {d['unit']} at {d['timestamp']}\n"
                return result
        elif query_type == "interface_counter":
            result = f"Found {len(data)} interface counter records:\n\n"
            for d in data[:10]:
                result += f"- {d['interface_name']} - {d['counter_name']}: {d['value']} {d['unit']} at {d['timestamp']}\n"
            return result
        else:
            return json.dumps(data, indent=2)

    @staticmethod
    def _infer_query_type(question: str, params: Dict[str, Any]) -> str:
        """Pick a query type from keywords in *question*, falling back to
        hints in the already-extracted *params*."""
        question_lower = question.lower()
        if "system" in question_lower or "cpu" in question_lower or "memory" in question_lower:
            return "system_counter"
        if "interface" in question_lower:
            return "interface_counter"
        if "network element" in question_lower or "network" in question_lower:
            return "network_element"
        if "summary" in question_lower or "statistics" in question_lower or "stats" in question_lower:
            return "summary"
        if "counter_name" in params:
            return "system_counter"
        if "interface_name" in params:
            return "interface_counter"
        return "summary"

    @staticmethod
    def _apply_time_window(params: Dict[str, Any]) -> None:
        """If a date was extracted, add ``start_time``/``end_time`` to *params*.

        A 5-minute window is used when a specific time was mentioned; a bare
        date covers the whole day (previously it was wrongly narrowed to the
        00:00-00:05 window). Mutates *params* in place.
        """
        if "date" not in params:
            return
        date = params["date"]
        if isinstance(date, str):
            from dateutil import parser
            date = parser.parse(date).date()
        try:
            if "hour" in params:
                # Specific time asked for: use a 5-minute window.
                start_time = datetime.combine(
                    date,
                    datetime.min.time().replace(
                        hour=params.get("hour", 0),
                        minute=params.get("minute", 0)
                    )
                )
                end_time = start_time + timedelta(minutes=5)
            else:
                # Date only: cover the whole day.
                start_time = datetime.combine(date, datetime.min.time())
                end_time = datetime.combine(date, datetime.max.time())
        except ValueError as e:
            logger.error(f"Error parsing date/time: {e}")
            # If time construction fails, use the date for the whole day.
            start_time = datetime.combine(date, datetime.min.time())
            end_time = datetime.combine(date, datetime.max.time())
        params["start_time"] = start_time
        params["end_time"] = end_time

    def process_question(self, question: str, db: Session) -> str:
        """Answer a natural-language *question* using RAG.

        Pipeline: extract parameters from the question, infer a query type,
        attach a time window when a date/time was mentioned, retrieve from
        the database, then phrase the answer (LLM when available).
        """
        params = self.extract_query_params(question)
        query_type = self._infer_query_type(question, params)
        self._apply_time_window(params)
        db_result = self.query_database(query_type, params, db)
        return self.generate_response(question, db_result)