"""
FastAPI server exposing query and upload endpoints for PM (performance management) counter data
"""
from fastapi import FastAPI, Depends, HTTPException, Query, UploadFile, File
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from database import (
get_db, FileRecord, NetworkElement, MeasurementInterval,
InterfaceCounter, IPCounter, TCPCounter, SystemCounter, BGPCounter
)
from datetime import datetime, timezone
from typing import Optional, List
from pydantic import BaseModel, ConfigDict
from config import Config
import os
import logging
from xml_parser import XMLParser
from data_storage import DataStorage
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="PM Counter API", version="1.0.0")
# Pydantic models for responses
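# from_attributes=True (Pydantic v2) lets these models be built directly from SQLAlchemy ORM rows
# via model_validate(); field names must match the ORM column/label names.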
class InterfaceCounterResponse(BaseModel):
interface_name: str
counter_name: str
value: float
unit: str
timestamp: datetime
    model_config = ConfigDict(from_attributes=True)
class SystemCounterResponse(BaseModel):
counter_name: str
value: float
unit: str
timestamp: datetime
    model_config = ConfigDict(from_attributes=True)
class NetworkElementInfo(BaseModel):
ne_name: str
ne_type: str
site: str
region: str
    model_config = ConfigDict(from_attributes=True)
@app.get("/")
def root():
return {"message": "PM Counter API", "version": "1.0.0"}
@app.get("/network-elements")
def get_network_elements(db: Session = Depends(get_db)):
"""Get all network elements"""
elements = db.query(NetworkElement).all()
    return [NetworkElementInfo.model_validate(e) for e in elements]
@app.get("/interfaces/{interface_name}/counters")
def get_interface_counters(
interface_name: str,
counter_name: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = Query(100, le=1000),
db: Session = Depends(get_db)
):
"""Get interface counters"""
query = db.query(
InterfaceCounter.interface_name,
InterfaceCounter.counter_name,
InterfaceCounter.value,
InterfaceCounter.unit,
MeasurementInterval.start_time.label('timestamp')
).join(MeasurementInterval).filter(
InterfaceCounter.interface_name == interface_name
)
if counter_name:
query = query.filter(InterfaceCounter.counter_name == counter_name)
if start_time:
query = query.filter(MeasurementInterval.start_time >= start_time)
if end_time:
query = query.filter(MeasurementInterval.end_time <= end_time)
results = query.order_by(MeasurementInterval.start_time.desc()).limit(limit).all()
return [
{
"interface_name": r.interface_name,
"counter_name": r.counter_name,
"value": r.value,
"unit": r.unit,
"timestamp": r.timestamp
}
for r in results
]
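# Example request (interface name and counter name are illustrative):
#   GET /interfaces/eth0/counters?counter_name=ifInOctets&limit=10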
@app.get("/system/counters")
def get_system_counters(
counter_name: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = Query(100, le=1000),
db: Session = Depends(get_db)
):
"""Get system counters"""
query = db.query(
SystemCounter.counter_name,
SystemCounter.value,
SystemCounter.unit,
MeasurementInterval.start_time.label('timestamp')
).join(MeasurementInterval)
if counter_name:
query = query.filter(SystemCounter.counter_name == counter_name)
if start_time:
query = query.filter(MeasurementInterval.start_time >= start_time)
if end_time:
query = query.filter(MeasurementInterval.end_time <= end_time)
results = query.order_by(MeasurementInterval.start_time.desc()).limit(limit).all()
return [
{
"counter_name": r.counter_name,
"value": r.value,
"unit": r.unit,
"timestamp": r.timestamp
}
for r in results
]
@app.get("/cpu/utilization")
def get_cpu_utilization(
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = Query(100, le=1000),
db: Session = Depends(get_db)
):
"""Get CPU utilization metrics"""
query = db.query(
SystemCounter.value,
MeasurementInterval.start_time.label('timestamp')
).join(MeasurementInterval).filter(
SystemCounter.counter_name == 'cpuUtilization'
)
if start_time:
query = query.filter(MeasurementInterval.start_time >= start_time)
if end_time:
query = query.filter(MeasurementInterval.end_time <= end_time)
results = query.order_by(MeasurementInterval.start_time.desc()).limit(limit).all()
return [
{
"cpu_utilization": r.value,
"timestamp": r.timestamp
}
for r in results
]
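# Example request (timestamps are illustrative, ISO 8601):
#   GET /cpu/utilization?start_time=2024-01-01T00:00:00&end_time=2024-01-02T00:00:00&limit=96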
@app.get("/memory/utilization")
def get_memory_utilization(
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = Query(100, le=1000),
db: Session = Depends(get_db)
):
"""Get memory utilization metrics"""
query = db.query(
SystemCounter.value,
MeasurementInterval.start_time.label('timestamp')
).join(MeasurementInterval).filter(
SystemCounter.counter_name == 'memoryUtilization'
)
if start_time:
query = query.filter(MeasurementInterval.start_time >= start_time)
if end_time:
query = query.filter(MeasurementInterval.end_time <= end_time)
results = query.order_by(MeasurementInterval.start_time.desc()).limit(limit).all()
return [
{
"memory_utilization": r.value,
"timestamp": r.timestamp
}
for r in results
]
@app.get("/bgp/peers")
def get_bgp_peers(db: Session = Depends(get_db)):
"""Get all BGP peers"""
peers = db.query(
BGPCounter.peer_address,
BGPCounter.as_number
).distinct().all()
return [{"peer_address": p.peer_address, "as_number": p.as_number} for p in peers]
@app.get("/bgp/peers/{peer_address}/counters")
def get_bgp_counters(
peer_address: str,
counter_name: Optional[str] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
limit: int = Query(100, le=1000),
db: Session = Depends(get_db)
):
"""Get BGP peer counters"""
query = db.query(
BGPCounter.peer_address,
BGPCounter.counter_name,
BGPCounter.value,
BGPCounter.unit,
MeasurementInterval.start_time.label('timestamp')
).join(MeasurementInterval).filter(
BGPCounter.peer_address == peer_address
)
if counter_name:
query = query.filter(BGPCounter.counter_name == counter_name)
if start_time:
query = query.filter(MeasurementInterval.start_time >= start_time)
if end_time:
query = query.filter(MeasurementInterval.end_time <= end_time)
results = query.order_by(MeasurementInterval.start_time.desc()).limit(limit).all()
return [
{
"peer_address": r.peer_address,
"counter_name": r.counter_name,
"value": r.value,
"unit": r.unit,
"timestamp": r.timestamp
}
for r in results
]
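# Example request (peer address and counter name are illustrative):
#   GET /bgp/peers/192.0.2.1/counters?counter_name=bgpPeerInUpdates&limit=10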
@app.get("/files/processed")
def get_processed_files(
limit: int = Query(50, le=500),
db: Session = Depends(get_db)
):
"""Get list of processed files"""
files = db.query(FileRecord).filter(
FileRecord.processed_at.isnot(None)
).order_by(FileRecord.processed_at.desc()).limit(limit).all()
return [
{
"filename": f.filename,
"downloaded_at": f.downloaded_at,
"processed_at": f.processed_at,
"file_size": f.file_size
}
for f in files
]
@app.get("/stats/summary")
def get_stats_summary(db: Session = Depends(get_db)):
"""Get summary statistics"""
total_files = db.query(FileRecord).count()
processed_files = db.query(FileRecord).filter(
FileRecord.processed_at.isnot(None)
).count()
total_intervals = db.query(MeasurementInterval).count()
total_network_elements = db.query(NetworkElement).count()
return {
"total_files": total_files,
"processed_files": processed_files,
"total_intervals": total_intervals,
"total_network_elements": total_network_elements
}
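# Example response shape (values are illustrative):
#   {"total_files": 12, "processed_files": 11, "total_intervals": 480, "total_network_elements": 3}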
@app.get("/config/fetch-interval")
def get_fetch_interval():
"""Get current fetch interval"""
return {
"interval_hours": Config.FETCH_INTERVAL_HOURS,
"interval_minutes": Config.FETCH_INTERVAL_HOURS * 60
}
@app.get("/rag/query")
def rag_query(question: str = Query(..., description="Natural language question"), db: Session = Depends(get_db)):
"""Process natural language question using RAG"""
try:
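        # Imported lazily so the RAG dependencies are only loaded when this endpoint is called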
from rag_system import RAGSystem
rag = RAGSystem()
response = rag.process_question(question, db)
return {"response": response, "question": question}
    except Exception as e:
        logger.exception(f"Error in RAG query: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}")
@app.post("/upload")
async def upload_file(
file: UploadFile = File(...),
db: Session = Depends(get_db)
):
"""Upload and process an XML file"""
# Validate file type
    if not file.filename or not file.filename.lower().endswith('.xml'):
raise HTTPException(status_code=400, detail="Only XML files are allowed")
try:
# Create temporary file to save uploaded content
upload_dir = Config.SFTP_LOCAL_PATH
os.makedirs(upload_dir, exist_ok=True)
# Generate unique filename to avoid conflicts
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
safe_filename = f"{timestamp}_{file.filename}"
temp_file_path = os.path.join(upload_dir, safe_filename)
# Save uploaded file
with open(temp_file_path, "wb") as f:
content = await file.read()
f.write(content)
logger.info(f"Uploaded file saved to {temp_file_path}")
# Check if file already processed (by original filename)
storage = DataStorage(db)
if storage.file_already_processed(file.filename):
# Clean up temp file
os.remove(temp_file_path)
return JSONResponse(
status_code=200,
content={
"message": "File already processed",
"filename": file.filename,
"status": "skipped"
}
)
# Parse XML
try:
parser = XMLParser(temp_file_path)
parsed_data = parser.parse()
# Update filename in parsed data to use original filename
parsed_data['file_info']['filename'] = file.filename
# Save to database
file_id = storage.save_file_data(parsed_data)
# Clean up temp file
os.remove(temp_file_path)
return JSONResponse(
status_code=200,
content={
"message": "File processed successfully",
"filename": file.filename,
"file_id": file_id,
"status": "processed",
"intervals": len(parsed_data.get('measurement_intervals', []))
}
)
except Exception as e:
# Clean up temp file on error
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
logger.error(f"Error processing uploaded file: {e}")
raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
    except HTTPException:
        # Let HTTPExceptions raised by the processing step propagate without re-wrapping
        raise
    except Exception as e:
        logger.error(f"Error uploading file: {e}")
        raise HTTPException(status_code=500, detail=f"Error uploading file: {str(e)}")
@app.post("/upload/multiple")
async def upload_multiple_files(
files: List[UploadFile] = File(...),
db: Session = Depends(get_db)
):
"""Upload and process multiple XML files"""
results = []
for file in files:
try:
# Validate file type
            if not file.filename or not file.filename.lower().endswith('.xml'):
results.append({
"filename": file.filename,
"status": "error",
"message": "Only XML files are allowed"
})
continue
# Create temporary file
upload_dir = Config.SFTP_LOCAL_PATH
os.makedirs(upload_dir, exist_ok=True)
            timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
safe_filename = f"{timestamp}_{file.filename}"
temp_file_path = os.path.join(upload_dir, safe_filename)
# Save uploaded file
with open(temp_file_path, "wb") as f:
content = await file.read()
f.write(content)
# Check if already processed
storage = DataStorage(db)
if storage.file_already_processed(file.filename):
os.remove(temp_file_path)
results.append({
"filename": file.filename,
"status": "skipped",
"message": "File already processed"
})
continue
# Parse and save
try:
parser = XMLParser(temp_file_path)
parsed_data = parser.parse()
parsed_data['file_info']['filename'] = file.filename
file_id = storage.save_file_data(parsed_data)
os.remove(temp_file_path)
results.append({
"filename": file.filename,
"status": "processed",
"file_id": file_id,
"intervals": len(parsed_data.get('measurement_intervals', []))
})
except Exception as e:
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
results.append({
"filename": file.filename,
"status": "error",
"message": str(e)
})
except Exception as e:
results.append({
"filename": file.filename,
"status": "error",
"message": str(e)
})
return {
"message": f"Processed {len(results)} file(s)",
"results": results,
"successful": sum(1 for r in results if r["status"] == "processed"),
"skipped": sum(1 for r in results if r["status"] == "skipped"),
"errors": sum(1 for r in results if r["status"] == "error")
}
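# The app can also be started with the uvicorn CLI, e.g. (module name is illustrative):
#   uvicorn api_server:app --host 0.0.0.0 --port 8000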
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=Config.API_HOST, port=Config.API_PORT)