"""
FastAPI application for ClientSphere RAG Backend.
Provides endpoints for knowledge base management and chat.
"""
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, BackgroundTasks, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pathlib import Path
import shutil
import uuid
from datetime import datetime
from typing import Optional
import logging
from app.config import settings
from app.middleware.auth import get_auth_context, require_auth
from app.middleware.rate_limit import (
limiter,
get_tenant_rate_limit_key,
RateLimitExceeded,
_rate_limit_exceeded_handler
)
from app.models.schemas import (
UploadResponse,
ChatRequest,
ChatResponse,
KnowledgeBaseStats,
HealthResponse,
DocumentStatus,
Citation,
)
from app.models.billing_schemas import (
UsageResponse,
PlanLimitsResponse,
CostReportResponse,
SetPlanRequest
)
from app.rag.ingest import parser
from app.rag.chunking import chunker
from app.rag.embeddings import get_embedding_service
from app.rag.vectorstore import get_vector_store
from app.rag.retrieval import get_retrieval_service
from app.rag.answer import get_answer_service
from app.db.database import get_db, init_db
from app.billing.quota import check_quota, ensure_tenant_exists
from app.billing.usage_tracker import track_usage
# Configure root logging once at import time; per-module logger for this app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
# Initialize FastAPI app; title is taken from config so deployments can rebrand.
app = FastAPI(
    title=settings.APP_NAME,
    description="RAG-based customer support chatbot API",
    version="1.0.0",
)
# Initialize database on startup
@app.on_event("startup")
async def startup_event():
    """Initialize database on application startup.

    NOTE(review): ``on_event`` is deprecated in newer FastAPI in favor of
    lifespan handlers; kept as-is to avoid restructuring app creation.
    """
    init_db()
    logger.info("Database initialized")
# Configure CORS - SECURITY: Restrict in production
if settings.ALLOWED_ORIGINS == "*":
    allowed_origins = ["*"]
else:
    # Split by comma and strip whitespace
    allowed_origins = [origin.strip() for origin in settings.ALLOWED_ORIGINS.split(",") if origin.strip()]
    # Fall back to wildcard when the setting parses to nothing (dev convenience).
    # NOTE(review): this fallback allows ALL origins (not just localhost as a
    # previous comment claimed), and is combined with allow_credentials=True
    # below — confirm this is intended outside of dev.
    if not allowed_origins or allowed_origins == ["*"]:
        allowed_origins = ["*"]  # Allow all in dev mode
logger.info(f"CORS configured with origins: {allowed_origins}")
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "DELETE", "OPTIONS"],  # Include OPTIONS for preflight
    allow_headers=["Content-Type", "Authorization", "X-Tenant-Id", "X-User-Id"],  # Include auth headers
)
# Configure rate limiting (slowapi): the handler turns RateLimitExceeded into a 429.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Add exception handler for validation errors
# Exception handler for request validation errors (422s).
# BUGFIX: this file previously defined and registered TWO handlers for
# RequestValidationError (the second silently replaced the first) and
# re-imported RequestValidationError/JSONResponse mid-file. They are merged
# into the single handler below, keeping the richer logging of the first.
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Handle request validation errors with detailed logging.

    Logs the validation errors, the raw request body, and the headers, then
    returns a 422 response echoing the errors and the decoded body to ease
    client-side debugging.

    NOTE(review): echoing the raw body back to the client can leak sensitive
    payload data in production — confirm this is intended.
    """
    body = await request.body()  # Starlette caches the body; safe to read here
    logger.error(f"Request validation error: {exc.errors()}")
    logger.error(f"Request body (raw): {body}")
    logger.error(f"Request headers: {dict(request.headers)}")
    return JSONResponse(
        status_code=422,
        content={"detail": exc.errors(), "body": body.decode('utf-8', errors='ignore')}
    )
# ============== Health & Status Endpoints ==============
@app.get("/", response_model=HealthResponse)
async def root():
    """Root endpoint reporting basic service status."""
    llm_ready = bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
    return HealthResponse(
        status="ok",
        version="1.0.0",
        vector_db_connected=True,
        llm_configured=llm_ready,
    )
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Probes the vector store; reports "healthy" when it responds, otherwise
    "unhealthy" with both dependency flags cleared.
    """
    try:
        # get_stats() doubles as a connectivity probe; its result was
        # previously bound to an unused local and is now discarded.
        get_vector_store().get_stats()
        return HealthResponse(
            status="healthy",
            version="1.0.0",
            vector_db_connected=True,
            llm_configured=bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return HealthResponse(
            status="unhealthy",
            version="1.0.0",
            vector_db_connected=False,
            llm_configured=False
        )
@app.get("/health/live")
async def liveness():
    """Kubernetes liveness probe - always returns alive."""
    payload = {"status": "alive"}
    return payload
@app.get("/health/ready")
async def readiness():
    """Kubernetes readiness probe - checks dependencies.

    Returns 200 with per-dependency results when everything passes; otherwise
    raises a 503 carrying the same check map.
    """
    checks = {
        "vector_db": False,
        "llm_configured": bool(settings.GEMINI_API_KEY or settings.OPENAI_API_KEY)
    }
    # Check vector DB connection
    try:
        vector_store = get_vector_store()
        vector_store.get_stats()
        checks["vector_db"] = True
    except Exception as e:
        logger.warning(f"Vector DB check failed: {e}")
        checks["vector_db"] = False
    # All checks must pass
    if all(checks.values()):
        return {"status": "ready", "checks": checks}
    # HTTPException is imported at module level; the function-local re-import
    # that used to live here was redundant and has been removed.
    raise HTTPException(status_code=503, detail={"status": "not_ready", "checks": checks})
# ============== Knowledge Base Endpoints ==============
@app.post("/kb/upload", response_model=UploadResponse)
@limiter.limit("20/hour", key_func=get_tenant_rate_limit_key)
async def upload_document(
    background_tasks: BackgroundTasks,
    request: Request,
    file: UploadFile = File(...),
    tenant_id: Optional[str] = Form(None),  # Optional in dev, ignored in prod
    user_id: Optional[str] = Form(None),  # Optional in dev, ignored in prod
    kb_id: str = Form(...)
):
    """
    Upload a document to the knowledge base.

    - Validates tenant, file type, and file size
    - Saves file to disk
    - Schedules background parsing, chunking, embedding, and vector storage

    Raises:
        HTTPException 400: missing tenant (dev), unsupported type, or oversized file
        HTTPException 403: prod request whose token lacks a tenant_id
        HTTPException 500: file could not be written to disk
    """
    # SECURITY: Extract identity from the auth token in production; form
    # values are only trusted in dev.
    if settings.ENV == "prod":
        auth_context = await require_auth(request)
        tenant_id = auth_context.get("tenant_id")
        # SECURITY FIX: prefer the authenticated user_id over the form value
        # so chunk metadata cannot be attributed to an arbitrary user
        # (previously only tenant_id was taken from the token here).
        user_id = auth_context.get("user_id") or user_id
        if not tenant_id:
            raise HTTPException(
                status_code=403,
                detail="tenant_id must come from authentication token in production mode"
            )
    elif not tenant_id:
        raise HTTPException(
            status_code=400,
            detail="tenant_id is required"
        )
    # SECURITY FIX: strip any client-supplied directory components so a
    # filename like "../../etc/passwd" cannot escape the uploads directory.
    safe_filename = Path(file.filename).name
    # Validate file type
    file_ext = Path(safe_filename).suffix.lower()
    if file_ext not in parser.SUPPORTED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {file_ext}. Supported: {parser.SUPPORTED_EXTENSIONS}"
        )
    # Validate file size (SECURITY)
    file.file.seek(0, 2)  # Seek to end
    file_size = file.file.tell()
    file.file.seek(0)  # Reset to start
    max_size_bytes = settings.MAX_FILE_SIZE_MB * 1024 * 1024
    if file_size > max_size_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size: {settings.MAX_FILE_SIZE_MB}MB"
        )
    # Generate document ID (tenant- and KB-scoped for traceability)
    doc_id = f"{tenant_id}_{kb_id}_{uuid.uuid4().hex[:8]}"
    # Save file to uploads directory
    upload_path = settings.UPLOADS_DIR / f"{doc_id}_{safe_filename}"
    try:
        with open(upload_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        logger.info(f"Saved file: {upload_path}")
    except Exception as e:
        logger.error(f"Error saving file: {e}")
        raise HTTPException(status_code=500, detail="Failed to save file")
    # Process document in background
    background_tasks.add_task(
        process_document,
        upload_path,
        tenant_id,  # CRITICAL: Multi-tenant isolation
        user_id,
        kb_id,
        safe_filename,
        doc_id
    )
    return UploadResponse(
        success=True,
        message="Document upload started. Processing in background.",
        document_id=doc_id,
        file_name=safe_filename,
        chunks_created=0,
        status=DocumentStatus.PROCESSING
    )
async def process_document(
    file_path: Path,
    tenant_id: str,  # CRITICAL: Multi-tenant isolation
    user_id: str,
    kb_id: str,
    original_filename: str,
    document_id: str
):
    """
    Background task to process an uploaded document.

    Pipeline: parse -> chunk -> per-chunk metadata -> embed -> store in the
    vector DB. Exceptions are logged and re-raised so failures are visible
    in the background-task machinery.
    """
    try:
        logger.info(f"Processing document: {original_filename}")
        # Parse the file into raw text plus a page map
        parsed_doc = parser.parse(file_path)
        logger.info(f"Parsed document: {len(parsed_doc.text)} characters")
        # Split into retrieval-sized chunks
        chunks = chunker.chunk_text(
            parsed_doc.text,
            page_numbers=parsed_doc.page_map
        )
        logger.info(f"Created {len(chunks)} chunks")
        if not chunks:
            logger.warning(f"No chunks created from {original_filename}")
            return
        total = len(chunks)
        # Build per-chunk metadata (carries tenant/kb/user for isolation)
        metadatas = [
            chunker.create_chunk_metadata(
                chunk=chunk,
                tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
                kb_id=kb_id,
                user_id=user_id,
                file_name=original_filename,
                file_type=parsed_doc.file_type,
                total_chunks=total,
                document_id=document_id
            )
            for chunk in chunks
        ]
        chunk_ids = [meta["chunk_id"] for meta in metadatas]
        chunk_texts = [chunk.content for chunk in chunks]
        # Embed all chunk texts in one batch
        embeddings = get_embedding_service().embed_texts(chunk_texts)
        logger.info(f"Generated {len(embeddings)} embeddings")
        # Persist texts + vectors + metadata under stable chunk IDs
        get_vector_store().add_documents(
            documents=chunk_texts,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=chunk_ids
        )
        logger.info(f"Successfully processed {original_filename}: {total} chunks stored")
    except Exception as e:
        logger.error(f"Error processing document {original_filename}: {e}")
        raise
@app.get("/kb/stats", response_model=KnowledgeBaseStats)
async def get_kb_stats(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None  # Optional in dev, ignored in prod
):
    """Get statistics for a knowledge base."""
    # SECURITY: resolve tenant/user from the auth context. In prod the token
    # is the only trusted source; in dev query params may fill the gaps.
    auth_context = await get_auth_context(request)
    auth_tenant = auth_context.get("tenant_id")
    auth_user = auth_context.get("user_id")
    if settings.ENV == "prod":
        if not (auth_tenant and auth_user):
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id, user_id = auth_tenant, auth_user
    else:
        tenant_id = tenant_id or auth_tenant
        user_id = user_id or auth_user
    if not (tenant_id and kb_id and user_id):
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, and user_id are required"
        )
    try:
        stats = get_vector_store().get_stats(tenant_id=tenant_id, kb_id=kb_id, user_id=user_id)
        file_names = stats.get("file_names", [])
        return KnowledgeBaseStats(
            tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
            kb_id=kb_id,
            user_id=user_id,
            total_documents=len(file_names),
            total_chunks=stats.get("total_chunks", 0),
            file_names=file_names,
            last_updated=datetime.utcnow()
        )
    except Exception as e:
        logger.error(f"Error getting KB stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.delete("/kb/document")
async def delete_document(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None,  # Optional in dev, ignored in prod
    file_name: Optional[str] = None
):
    """Delete a document from the knowledge base."""
    # SECURITY: identity comes from the auth token in prod; query params are
    # honored only in dev as a fallback.
    auth_context = await get_auth_context(request)
    auth_tenant = auth_context.get("tenant_id")
    auth_user = auth_context.get("user_id")
    if settings.ENV == "prod":
        if not (auth_tenant and auth_user):
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id, user_id = auth_tenant, auth_user
    else:
        tenant_id = tenant_id or auth_tenant
        user_id = user_id or auth_user
    if not (tenant_id and kb_id and user_id and file_name):
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, user_id, and file_name are required (provide via headers or query params)"
        )
    try:
        # The filter carries all identity fields so deletion stays scoped to
        # this tenant's KB.
        deleted = get_vector_store().delete_by_filter({
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id,
            "file_name": file_name
        })
        return {
            "success": True,
            "message": f"Deleted {deleted} chunks",
            "file_name": file_name
        }
    except Exception as e:
        logger.error(f"Error deleting document: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.delete("/kb/clear")
async def clear_kb(
    request: Request,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None  # Optional in dev, ignored in prod
):
    """Clear all documents from a knowledge base."""
    # SECURITY: identity comes from the auth token in prod; query params are
    # honored only in dev as a fallback.
    auth_context = await get_auth_context(request)
    auth_tenant = auth_context.get("tenant_id")
    auth_user = auth_context.get("user_id")
    if settings.ENV == "prod":
        if not (auth_tenant and auth_user):
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
        tenant_id, user_id = auth_tenant, auth_user
    else:
        tenant_id = tenant_id or auth_tenant
        user_id = user_id or auth_user
    if not (tenant_id and kb_id and user_id):
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, and user_id are required"
        )
    try:
        # Delete every chunk matching this tenant/kb/user scope.
        deleted = get_vector_store().delete_by_filter({
            "tenant_id": tenant_id,  # CRITICAL: Multi-tenant isolation
            "kb_id": kb_id,
            "user_id": user_id
        })
        return {
            "success": True,
            "message": f"Cleared knowledge base. Deleted {deleted} chunks.",
            "kb_id": kb_id
        }
    except Exception as e:
        logger.error(f"Error clearing KB: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# ============== Chat Endpoints ==============
@app.post("/chat", response_model=ChatResponse)
@limiter.limit("10/minute", key_func=get_tenant_rate_limit_key)
async def chat(chat_request: ChatRequest, request: Request):
    """
    Process a chat message using RAG.
    - Retrieves relevant context from knowledge base
    - Generates answer using LLM
    - Returns answer with citations

    Errors are mostly reported in-band as a ChatResponse with success=False;
    auth (401/403), validation (400), and quota (402) failures surface as
    HTTPExceptions.
    """
    conversation_id = "unknown"  # fallback when an error occurs before ID assignment
    try:
        logger.info(f"=== CHAT REQUEST RECEIVED ===")
        logger.info(f"Request body: tenant_id={chat_request.tenant_id}, user_id={chat_request.user_id}, kb_id={chat_request.kb_id}, question_length={len(chat_request.question)}")
        logger.info(f"Request headers: {dict(request.headers)}")
        # SECURITY: Get tenant_id and user_id from auth context
        # In PROD: MUST come from JWT token (never from request body)
        try:
            auth_context = await get_auth_context(request)
        except Exception as e:
            logger.error(f"Error getting auth context: {e}", exc_info=True)
            raise HTTPException(status_code=401, detail=f"Authentication error: {str(e)}")
        tenant_id_from_auth = auth_context.get("tenant_id")
        user_id_from_auth = auth_context.get("user_id")
        if settings.ENV == "prod":
            if not tenant_id_from_auth or not user_id_from_auth:
                raise HTTPException(
                    status_code=403,
                    detail="tenant_id and user_id must come from authentication token in production mode"
                )
            # Override request values with auth context (security enforcement)
            chat_request.tenant_id = tenant_id_from_auth
            chat_request.user_id = user_id_from_auth
        else:
            # DEV mode: use from request if provided, otherwise from auth context
            if not chat_request.tenant_id:
                chat_request.tenant_id = tenant_id_from_auth
            if not chat_request.user_id:
                chat_request.user_id = user_id_from_auth
        if not chat_request.tenant_id or not chat_request.user_id:
            raise HTTPException(
                status_code=400,
                detail="tenant_id and user_id are required (provide via X-Tenant-Id/X-User-Id headers or request body)"
            )
        # Log without PII in production
        if settings.ENV == "prod":
            logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q_length={len(chat_request.question)}")
        else:
            logger.info(f"Chat request: tenant={chat_request.tenant_id}, user={chat_request.user_id}, kb={chat_request.kb_id}, q={chat_request.question[:50]}...")
        # Generate conversation ID if not provided
        conversation_id = chat_request.conversation_id or f"conv_{uuid.uuid4().hex[:12]}"
        # Get database session
        # NOTE(review): pulling the session via next(get_db()) bypasses
        # FastAPI's dependency cleanup — the generator's finally (and any
        # session close it performs) only runs at garbage collection.
        # Consider Depends(get_db) or an explicit generator close.
        try:
            db = next(get_db())
        except Exception as e:
            logger.error(f"Database connection error: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
        try:
            # Ensure tenant exists in billing DB
            ensure_tenant_exists(db, chat_request.tenant_id)
            # Check quota BEFORE making LLM call
            has_quota, quota_error = check_quota(db, chat_request.tenant_id)
            if not has_quota:
                logger.warning(f"Quota exceeded for tenant {chat_request.tenant_id}")
                raise HTTPException(
                    status_code=402,
                    detail=quota_error or "AI quota exceeded. Upgrade your plan."
                )
            # Retrieve relevant context
            retrieval_service = get_retrieval_service()
            results, confidence, has_relevant = retrieval_service.retrieve(
                query=chat_request.question,
                tenant_id=chat_request.tenant_id,  # CRITICAL: Multi-tenant isolation
                kb_id=chat_request.kb_id,
                user_id=chat_request.user_id
            )
            logger.info(f"Retrieval results: {len(results)} results, confidence={confidence:.3f}, has_relevant={has_relevant}")
            # Format context for LLM
            context, citations_info = retrieval_service.get_context_for_llm(results)
            logger.info(f"Formatted context length: {len(context)} chars, citations: {len(citations_info)}")
            # Generate answer
            answer_service = get_answer_service()
            answer_result = answer_service.generate_answer(
                question=chat_request.question,
                context=context,
                citations_info=citations_info,
                confidence=confidence,
                has_relevant_results=has_relevant
            )
            # Track usage if LLM was called (usage info present)
            usage_info = answer_result.get("usage")
            if usage_info:
                try:
                    track_usage(
                        db=db,
                        tenant_id=chat_request.tenant_id,
                        user_id=chat_request.user_id,
                        kb_id=chat_request.kb_id,
                        provider=settings.LLM_PROVIDER,
                        model=usage_info.get("model_used", settings.GEMINI_MODEL if settings.LLM_PROVIDER == "gemini" else settings.OPENAI_MODEL),
                        prompt_tokens=usage_info.get("prompt_tokens", 0),
                        completion_tokens=usage_info.get("completion_tokens", 0)
                    )
                except Exception as e:
                    logger.error(f"Failed to track usage: {e}", exc_info=True)
                    # Don't fail the request if usage tracking fails
            # Build metadata with refusal info
            metadata = {
                "chunks_retrieved": len(results),
                "kb_id": chat_request.kb_id
            }
            if "refused" in answer_result:
                metadata["refused"] = answer_result["refused"]
            if "refusal_reason" in answer_result:
                metadata["refusal_reason"] = answer_result["refusal_reason"]
            if "verifier_passed" in answer_result:
                metadata["verifier_passed"] = answer_result["verifier_passed"]
            return ChatResponse(
                success=True,
                answer=answer_result["answer"],
                citations=answer_result["citations"],
                confidence=answer_result["confidence"],
                from_knowledge_base=answer_result["from_knowledge_base"],
                escalation_suggested=answer_result["escalation_suggested"],
                conversation_id=conversation_id,
                refused=answer_result.get("refused", False),
                metadata=metadata
            )
        except ValueError as e:
            # API key or configuration error
            error_msg = str(e)
            logger.error(f"Configuration error: {error_msg}")
            # BUGFIX: the needle was previously "API key" (uppercase) tested
            # against a lowercased string, so this branch could never match.
            if "api key" in error_msg.lower():
                return ChatResponse(
                    success=False,
                    answer="⚠️ LLM API key not configured. Please set GEMINI_API_KEY in your .env file. Retrieval is working, but answer generation requires an API key.",
                    citations=[],
                    confidence=0.0,
                    from_knowledge_base=False,
                    escalation_suggested=True,
                    conversation_id=conversation_id,
                    metadata={"error": error_msg, "error_type": "configuration"}
                )
            else:
                return ChatResponse(
                    success=False,
                    answer=f"Configuration error: {error_msg}",
                    citations=[],
                    confidence=0.0,
                    from_knowledge_base=False,
                    escalation_suggested=True,
                    conversation_id=conversation_id,
                    metadata={"error": error_msg}
                )
        except HTTPException:
            # Re-raise HTTP exceptions (they have proper status codes)
            raise
        except Exception as e:
            logger.error(f"Chat error: {e}", exc_info=True)
            logger.error(f"Error type: {type(e).__name__}", exc_info=True)
            return ChatResponse(
                success=False,
                answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.",
                citations=[],
                confidence=0.0,
                from_knowledge_base=False,
                escalation_suggested=True,
                conversation_id=conversation_id,
                metadata={"error": str(e), "error_type": type(e).__name__}
            )
    except HTTPException:
        # Re-raise HTTP exceptions from outer try block
        raise
    except Exception as e:
        logger.error(f"Outer chat error: {e}", exc_info=True)
        return ChatResponse(
            success=False,
            answer=f"I encountered an error processing your request: {str(e)}. Please check the server logs for details.",
            citations=[],
            confidence=0.0,
            from_knowledge_base=False,
            escalation_suggested=True,
            conversation_id=conversation_id,
            metadata={"error": str(e), "error_type": type(e).__name__}
        )
# ============== Utility Endpoints ==============
@app.get("/kb/search")
@limiter.limit("30/minute", key_func=get_tenant_rate_limit_key)
async def search_kb(
    request: Request,
    query: str,
    tenant_id: Optional[str] = None,  # Optional in dev, ignored in prod
    kb_id: Optional[str] = None,
    user_id: Optional[str] = None,  # Optional in dev, ignored in prod
    top_k: int = 5
):
    """
    Search the knowledge base without generating an answer.
    Useful for debugging and testing retrieval.

    Raises:
        HTTPException 400: required identifiers missing
        HTTPException 403: prod token lacks tenant_id/user_id
        HTTPException 500: retrieval failed
    """
    # SECURITY: Extract tenant_id from auth token in production
    if settings.ENV == "prod":
        auth_context = await require_auth(request)
        tenant_id = auth_context.get("tenant_id")
        user_id = auth_context.get("user_id")
        if not tenant_id or not user_id:
            raise HTTPException(
                status_code=403,
                detail="tenant_id and user_id must come from authentication token in production mode"
            )
    elif not tenant_id or not kb_id or not user_id:
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, and user_id are required"
        )
    # BUGFIX: kb_id was previously unchecked in the prod branch, letting a
    # search with kb_id=None reach the retrieval layer.
    if not kb_id:
        raise HTTPException(
            status_code=400,
            detail="tenant_id, kb_id, and user_id are required"
        )
    try:
        retrieval_service = get_retrieval_service()
        results, confidence, has_relevant = retrieval_service.retrieve(
            query=query,
            tenant_id=tenant_id,  # CRITICAL: Multi-tenant isolation
            kb_id=kb_id,
            user_id=user_id,
            top_k=top_k
        )
        return {
            "success": True,
            "results": [
                {
                    "chunk_id": r.chunk_id,
                    # Truncate long chunks so the debug payload stays small
                    "content": r.content[:500] + "..." if len(r.content) > 500 else r.content,
                    "metadata": r.metadata,
                    "similarity_score": r.similarity_score
                }
                for r in results
            ],
            "confidence": confidence,
            "has_relevant_results": has_relevant
        }
    except Exception as e:
        logger.error(f"Search error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# ============== Billing & Usage Endpoints ==============
@app.get("/billing/usage", response_model=UsageResponse)
async def get_usage(
    request: Request,
    range: str = "month",  # "day" or "month"
    year: Optional[int] = None,
    month: Optional[int] = None,
    day: Optional[int] = None
):
    """
    Get usage statistics for the current tenant.
    Args:
        range: "day" or "month"
        year: Year (optional, defaults to current)
        month: Month 1-12 (optional, defaults to current)
        day: Day 1-31 (optional, defaults to current, only for range="day")
    Raises:
        HTTPException 403: no tenant in auth context
        HTTPException 500: query failure
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")
    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")
    # BUGFIX: keep a handle to the get_db() generator so its cleanup
    # (session close in its finally block) runs deterministically instead of
    # whenever the garbage collector gets around to it.
    db_gen = get_db()
    db = next(db_gen)
    try:
        from app.db.models import UsageDaily, UsageMonthly
        from datetime import datetime
        from calendar import monthrange
        now = datetime.utcnow()
        target_year = year or now.year
        target_month = month or now.month
        if range == "day":
            target_day = day or now.day
            date_start = datetime(target_year, target_month, target_day)
            daily = db.query(UsageDaily).filter(
                UsageDaily.tenant_id == tenant_id,
                UsageDaily.date == date_start
            ).first()
            if not daily:
                # No rows for that day -> zeroed response
                return UsageResponse(
                    tenant_id=tenant_id,
                    period="day",
                    total_requests=0,
                    total_tokens=0,
                    total_cost_usd=0.0,
                    start_date=date_start,
                    end_date=date_start
                )
            return UsageResponse(
                tenant_id=tenant_id,
                period="day",
                total_requests=daily.total_requests,
                total_tokens=daily.total_tokens,
                total_cost_usd=daily.total_cost_usd,
                gemini_requests=daily.gemini_requests,
                openai_requests=daily.openai_requests,
                start_date=daily.date,
                end_date=daily.date
            )
        else:  # month
            monthly = db.query(UsageMonthly).filter(
                UsageMonthly.tenant_id == tenant_id,
                UsageMonthly.year == target_year,
                UsageMonthly.month == target_month
            ).first()
            if not monthly:
                # Calculate date range for the month
                _, last_day = monthrange(target_year, target_month)
                start_date = datetime(target_year, target_month, 1)
                end_date = datetime(target_year, target_month, last_day)
                return UsageResponse(
                    tenant_id=tenant_id,
                    period="month",
                    total_requests=0,
                    total_tokens=0,
                    total_cost_usd=0.0,
                    start_date=start_date,
                    end_date=end_date
                )
            _, last_day = monthrange(monthly.year, monthly.month)
            start_date = datetime(monthly.year, monthly.month, 1)
            end_date = datetime(monthly.year, monthly.month, last_day)
            return UsageResponse(
                tenant_id=tenant_id,
                period="month",
                total_requests=monthly.total_requests,
                total_tokens=monthly.total_tokens,
                total_cost_usd=monthly.total_cost_usd,
                gemini_requests=monthly.gemini_requests,
                openai_requests=monthly.openai_requests,
                start_date=start_date,
                end_date=end_date
            )
    except Exception as e:
        logger.error(f"Error getting usage: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db_gen.close()
@app.get("/billing/limits", response_model=PlanLimitsResponse)
async def get_limits(request: Request):
    """Get current plan limits and usage for the tenant.

    Raises:
        HTTPException 403: no tenant in auth context
        HTTPException 500: query failure
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")
    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")
    # BUGFIX: keep a handle to the get_db() generator so its cleanup runs
    # deterministically instead of at garbage collection.
    db_gen = get_db()
    db = next(db_gen)
    try:
        from app.billing.quota import get_tenant_plan, get_monthly_usage
        from datetime import datetime
        plan = get_tenant_plan(db, tenant_id)
        if not plan:
            # Default to starter
            plan_name = "starter"
            monthly_limit = 500
        else:
            plan_name = plan.plan_name
            monthly_limit = plan.monthly_chat_limit
        # Get current month usage
        now = datetime.utcnow()
        monthly_usage = get_monthly_usage(db, tenant_id, now.year, now.month)
        current_usage = monthly_usage.total_requests if monthly_usage else 0
        # A limit of -1 means unlimited (no remaining count)
        remaining = None if monthly_limit == -1 else max(0, monthly_limit - current_usage)
        return PlanLimitsResponse(
            tenant_id=tenant_id,
            plan_name=plan_name,
            monthly_chat_limit=monthly_limit,
            current_month_usage=current_usage,
            remaining_chats=remaining
        )
    except Exception as e:
        logger.error(f"Error getting limits: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db_gen.close()
@app.post("/billing/plan")
async def set_plan(request_body: SetPlanRequest, http_request: Request):
    """
    Set tenant's subscription plan (admin only in production).
    In dev mode, allows any tenant to set their plan.
    In prod mode, should be restricted to admin users.

    Raises:
        HTTPException 400: unknown plan name
        HTTPException 403: prod request targeting another tenant
        HTTPException 500: persistence failure
    """
    # Get tenant from auth
    auth_context = await get_auth_context(http_request)
    auth_tenant_id = auth_context.get("tenant_id")
    # In prod, verify admin role (placeholder - implement actual admin check)
    if settings.ENV == "prod":
        # TODO: Add admin role check
        if auth_tenant_id != request_body.tenant_id:
            raise HTTPException(status_code=403, detail="Cannot set plan for other tenants")
    # BUGFIX: keep a handle to the get_db() generator so its cleanup runs
    # deterministically instead of at garbage collection.
    db_gen = get_db()
    db = next(db_gen)
    try:
        from app.billing.quota import set_tenant_plan
        plan = set_tenant_plan(db, request_body.tenant_id, request_body.plan_name)
        return {
            "success": True,
            "tenant_id": request_body.tenant_id,
            "plan_name": plan.plan_name,
            "monthly_chat_limit": plan.monthly_chat_limit
        }
    except ValueError as e:
        # Invalid plan name reported by the billing layer
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error setting plan: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db_gen.close()
@app.get("/billing/cost-report", response_model=CostReportResponse)
async def get_cost_report(
    request: Request,
    range: str = "month",
    year: Optional[int] = None,
    month: Optional[int] = None
):
    """Get cost report with breakdown by provider and model.

    Args:
        range: "month" for the selected month, anything else for all time
        year/month: period selectors (default: current)
    Raises:
        HTTPException 403: no tenant in auth context
        HTTPException 500: query failure
    """
    # Get tenant from auth
    auth_context = await get_auth_context(request)
    tenant_id = auth_context.get("tenant_id")
    if not tenant_id:
        raise HTTPException(status_code=403, detail="tenant_id required")
    # BUGFIX: keep a handle to the get_db() generator so its cleanup runs
    # deterministically instead of at garbage collection.
    db_gen = get_db()
    db = next(db_gen)
    try:
        from app.db.models import UsageEvent
        from datetime import datetime
        from sqlalchemy import func, and_
        now = datetime.utcnow()
        target_year = year or now.year
        target_month = month or now.month
        # Query usage events for the period
        if range == "month":
            query = db.query(UsageEvent).filter(
                and_(
                    UsageEvent.tenant_id == tenant_id,
                    func.extract('year', UsageEvent.request_timestamp) == target_year,
                    func.extract('month', UsageEvent.request_timestamp) == target_month
                )
            )
        else:  # all time
            query = db.query(UsageEvent).filter(UsageEvent.tenant_id == tenant_id)
        events = query.all()
        # Single pass over events for totals plus both breakdowns (was three
        # separate loops over the same list).
        total_cost = 0.0
        total_tokens = 0
        breakdown_by_provider = {}
        breakdown_by_model = {}
        for event in events:
            total_cost += event.estimated_cost_usd
            total_tokens += event.total_tokens
            for key, bucket in ((event.provider, breakdown_by_provider),
                                (event.model, breakdown_by_model)):
                entry = bucket.setdefault(key, {"requests": 0, "tokens": 0, "cost_usd": 0.0})
                entry["requests"] += 1
                entry["tokens"] += event.total_tokens
                entry["cost_usd"] += event.estimated_cost_usd
        return CostReportResponse(
            tenant_id=tenant_id,
            period=range,
            total_cost_usd=total_cost,
            total_requests=len(events),
            total_tokens=total_tokens,
            breakdown_by_provider=breakdown_by_provider,
            breakdown_by_model=breakdown_by_model
        )
    except Exception as e:
        logger.error(f"Error getting cost report: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db_gen.close()
# Local development entry point — in production run via an ASGI server
# command (e.g. `uvicorn app.main:app`) instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)