"""
REST API Server for PM Data
FastAPI server that wraps MCP server functionality
"""
import os
import logging
from typing import Optional
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from dotenv import load_dotenv
from mcp_client import MCPClient
# Load environment variables (python-dotenv reads a .env file if present)
# before anything below consults os.getenv.
load_dotenv()

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="PM Data API", version="1.0.0")

# CORS: allowed origins are taken from CORS_ALLOW_ORIGINS as a comma-separated
# list (e.g. "https://ui.example.com,http://localhost:3000"). When the variable
# is unset or empty, every origin ("*") is allowed, matching the previous
# hard-coded behaviour.
# NOTE(review): with allow_credentials=True and "*", Starlette echoes the
# request's Origin header instead of sending "*", so any site can make
# credentialed requests. Set CORS_ALLOW_ORIGINS explicitly in production.
_cors_origins = [
    origin.strip()
    for origin in (os.getenv("CORS_ALLOW_ORIGINS") or "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize MCP client (a single shared instance used by every endpoint)
mcp_client = MCPClient()
# ---------------------------------------------------------------------------
# Request body models
# ---------------------------------------------------------------------------


class QueryRequest(BaseModel):
    """Body carrying a free-text query string."""

    # NOTE(review): not referenced by any endpoint in this file (/api/query
    # takes GET parameters); kept for external users.
    query: str


class FetchIntervalUpdate(BaseModel):
    """Body for POST /api/config/fetch-interval."""

    # New fetch interval, in minutes.
    minutes: int
# API Endpoints
@app.get("/")
async def root():
"""Root endpoint"""
return {"message": "PM Data API", "version": "1.0.0"}
@app.get("/api/query")
async def query(
    q: Optional[str] = Query(None, description="Natural language query"),
    counter_name: Optional[str] = Query(None, description="Counter name"),
    query_time: Optional[str] = Query(None, description="Query time (ISO format)"),
    interface_name: Optional[str] = Query(None, description="Interface name"),
):
    """
    Query counter values.

    Two modes are supported:

    * ``q`` - a natural language query, forwarded to
      ``mcp_client.query_natural_language``.
    * ``counter_name`` + ``query_time`` (optionally ``interface_name``) -
      a structured lookup, forwarded to ``mcp_client.query_counter``.

    A non-empty ``q`` takes precedence; the structured parameters are then
    ignored.

    Returns:
        The MCP client's result dict, unchanged.

    Raises:
        HTTPException(400): neither mode's parameters were supplied, or the
            MCP client reported the query as failed.
        HTTPException(500): any unexpected error while querying.

    NOTE(review): the mcp_client calls are synchronous and run on the event
    loop thread; if they block on I/O, consider a threadpool.
    """
    try:
        # Natural language mode
        if q:
            result = mcp_client.query_natural_language(q)
            # A missing "success" key is treated as success here, so that
            # results that merely carry a "no data" message pass through.
            if not result.get("success", True):
                logger.error(f"Query error: {result.get('error')}")
                raise HTTPException(status_code=400, detail=result.get("error", "Query failed"))
            return result

        # Structured mode
        if counter_name and query_time:
            result = mcp_client.query_counter(
                counter_name=counter_name,
                query_time=query_time,
                interface_name=interface_name
            )
            if not result.get("success"):
                logger.error(f"Query error: {result.get('error')}")
                raise HTTPException(status_code=400, detail=result.get("error", "Query failed"))
            return result

        raise HTTPException(status_code=400, detail="Either 'q' parameter or 'counter_name' and 'query_time' must be provided")
    except HTTPException:
        # Intentional client-facing errors: propagate unchanged.
        raise
    except Exception as e:
        logger.error(f"Error in query endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
@app.get("/api/interfaces")
async def list_interfaces():
    """
    List all available interfaces via ``mcp_client.list_interfaces``.

    Returns:
        The MCP client's result dict, unchanged, when it reports success.

    Raises:
        HTTPException(500): the MCP client reported failure (its ``error``
            value, or a generic message, is the detail), or an unexpected
            error occurred.
    """
    try:
        result = mcp_client.list_interfaces()
        if not result.get("success"):
            logger.error(f"list_interfaces failed: {result.get('error')}")
            raise HTTPException(status_code=500, detail=result.get("error", "Failed to list interfaces"))
        return result
    except HTTPException:
        # Already shaped for the client. Re-raise so the generic handler below
        # does not re-wrap it (which would yield a "500: ..." detail).
        raise
    except Exception as e:
        logger.error(f"Error in list_interfaces endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/api/counters")
async def list_counters():
    """
    List all available counter types via ``mcp_client.list_counters``.

    Returns:
        The MCP client's result dict, unchanged, when it reports success.

    Raises:
        HTTPException(500): the MCP client reported failure (its ``error``
            value, or a generic message, is the detail), or an unexpected
            error occurred.
    """
    try:
        result = mcp_client.list_counters()
        if not result.get("success"):
            logger.error(f"list_counters failed: {result.get('error')}")
            raise HTTPException(status_code=500, detail=result.get("error", "Failed to list counters"))
        return result
    except HTTPException:
        # Already shaped for the client. Re-raise so the generic handler below
        # does not re-wrap it (which would yield a "500: ..." detail).
        raise
    except Exception as e:
        logger.error(f"Error in list_counters endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/api/alerts")
async def get_alerts(
    start_time: Optional[str] = Query(None, description="Start time (ISO format)"),
    end_time: Optional[str] = Query(None, description="End time (ISO format)")
):
    """
    Get threshold alerts for a time range via ``mcp_client.query_alerts``.

    Both bounds are optional and are passed through as given (possibly None).
    How a missing bound is interpreted is up to the MCP client; this function
    does not show it.

    Returns:
        The MCP client's result dict, unchanged, when it reports success.

    Raises:
        HTTPException(500): the MCP client reported failure (its ``error``
            value, or a generic message, is the detail), or an unexpected
            error occurred.
    """
    try:
        result = mcp_client.query_alerts(start_time=start_time, end_time=end_time)
        if not result.get("success"):
            logger.error(f"query_alerts failed: {result.get('error')}")
            raise HTTPException(status_code=500, detail=result.get("error", "Failed to query alerts"))
        return result
    except HTTPException:
        # Already shaped for the client. Re-raise so the generic handler below
        # does not re-wrap it (which would yield a "500: ..." detail).
        raise
    except Exception as e:
        logger.error(f"Error in get_alerts endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Value reported by GET /api/config/fetch-interval when the config table has
# no 'fetch_interval_minutes' row.
DEFAULT_FETCH_INTERVAL_MINUTES = 5


def _get_db_connection():
    """
    Open a new psycopg2 connection configured from POSTGRES_* env variables.

    Unset variables fall back to host "postgres", database "pm_data",
    user/password "postgres" and port 5432. The caller owns the returned
    connection and must close it.
    """
    # Imported locally (presumably so this module loads even where psycopg2
    # is not installed).
    import psycopg2

    return psycopg2.connect(
        host=os.getenv("POSTGRES_HOST", "postgres"),
        database=os.getenv("POSTGRES_DB", "pm_data"),
        user=os.getenv("POSTGRES_USER", "postgres"),
        password=os.getenv("POSTGRES_PASSWORD", "postgres"),
        port=int(os.getenv("POSTGRES_PORT", "5432")),
    )


@app.get("/api/config/fetch-interval")
async def get_fetch_interval():
    """
    Get the current fetch interval in minutes from the ``config`` table.

    Returns:
        ``{"success": True, "interval_minutes": <int>}``. If the
        'fetch_interval_minutes' row is absent, DEFAULT_FETCH_INTERVAL_MINUTES
        is returned.

    Raises:
        HTTPException(500): connection/query failure, or a stored value that
            is not an integer.

    NOTE(review): psycopg2 is blocking and runs on the event loop thread here.
    """
    try:
        conn = _get_db_connection()
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT value FROM config WHERE key = 'fetch_interval_minutes'")
                row = cur.fetchone()
        finally:
            conn.close()
        interval = int(row[0]) if row else DEFAULT_FETCH_INTERVAL_MINUTES
        return {"success": True, "interval_minutes": interval}
    except Exception as e:
        logger.error(f"Error in get_fetch_interval endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to get fetch interval: {str(e)}") from e


@app.post("/api/config/fetch-interval")
async def update_fetch_interval(update: FetchIntervalUpdate):
    """
    Update the fetch interval in minutes in the ``config`` table.

    If the 'fetch_interval_minutes' row does not exist yet, it is inserted.
    Previously a missing row made the UPDATE a silent no-op while the
    endpoint still reported success.

    Returns:
        ``{"success": True, "interval_minutes": <minutes>, "message": ...}``.

    Raises:
        HTTPException(400): ``minutes`` is less than 1.
        HTTPException(500): connection or write failure.

    NOTE(review): psycopg2 is blocking and runs on the event loop thread here.
    """
    # Validate before the broad try/except so the 400 is not converted to a 500.
    if update.minutes < 1:
        raise HTTPException(status_code=400, detail="Fetch interval must be at least 1 minute")

    try:
        conn = _get_db_connection()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    UPDATE config SET value = %s, updated_at = CURRENT_TIMESTAMP
                    WHERE key = 'fetch_interval_minutes'
                """, (str(update.minutes),))
                if cur.rowcount == 0:
                    # No row to update: create it so the new value takes effect.
                    # NOTE(review): assumes key/value/updated_at are the only
                    # required columns of `config` - confirm against the schema.
                    cur.execute("""
                        INSERT INTO config (key, value, updated_at)
                        VALUES ('fetch_interval_minutes', %s, CURRENT_TIMESTAMP)
                    """, (str(update.minutes),))
            conn.commit()
        finally:
            conn.close()
    except Exception as e:
        logger.error(f"Error in update_fetch_interval endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to update fetch interval: {str(e)}") from e

    return {"success": True, "interval_minutes": update.minutes, "message": "Fetch interval updated"}
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces.
    # The port comes from API_SERVER_PORT (default 8000).
    import uvicorn

    listen_port = int(os.getenv("API_SERVER_PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)