"""Jana Backend API Client.
This module provides a wrapper around the Jana backend API for use by MCP tools.
It handles authentication, request formatting, and response parsing.
"""
import asyncio
import json
import logging
from typing import Any
import httpx
from jana_mcp.config import Settings, get_settings
# Module-scoped logger; named after this module so callers can tune its level.
logger: logging.Logger = logging.getLogger(name=__name__)
class JanaClientError(Exception):
    """Common base class for the errors defined by the Jana client module.

    Catching this type covers both AuthenticationError and APIError.
    """
class AuthenticationError(JanaClientError):
    """Signals that authenticating against the Jana backend did not succeed.

    Raised for missing credentials, rejected credentials, connection
    failures during login, or a login response without a usable token.
    """
class APIError(JanaClientError):
    """Signals that a request to the Jana API did not succeed.

    Attributes:
        status_code: HTTP status of the failing response, or None when no
            response was obtained (e.g. a transport-level failure).
        response: Payload describing the failure, if any; JanaClient passes
            the raw response body text here.
    """

    def __init__(self, message: str, status_code: int | None = None, response: Any = None):
        self.status_code = status_code
        self.response = response
        super().__init__(message)
class JanaClient:
    """
    Client for interacting with the Jana backend API.

    Handles authentication and provides methods for accessing
    environmental data endpoints. The underlying ``httpx.AsyncClient`` is
    created lazily on first use and released by :meth:`close`; the client
    can also be used as an async context manager.
    """

    def __init__(self, settings: Settings | None = None):
        """
        Initialize the Jana client.

        Args:
            settings: Application settings. If None, uses default settings.
        """
        self.settings = settings or get_settings()
        # May already hold a token supplied via settings (presumably JANA_TOKEN);
        # otherwise authenticate() fills it in on first request.
        self._token: str | None = self.settings.get_auth_token()
        self._client: httpx.AsyncClient | None = None
        # Serializes lazy creation of the shared HTTP client across coroutines.
        self._client_lock = asyncio.Lock()

    @property
    def base_url(self) -> str:
        """Get the base URL for API requests (without a trailing slash)."""
        return self.settings.jana_backend_url.rstrip("/")

    async def _get_client(self) -> httpx.AsyncClient:
        """
        Get or create the shared HTTP client.

        Creation is guarded by an ``asyncio.Lock``, so concurrent coroutines on
        the same event loop never build two clients. Note this is
        coroutine-safe, not thread-safe: ``asyncio.Lock`` offers no protection
        against access from other threads.
        """
        async with self._client_lock:
            if self._client is None or self._client.is_closed:
                # Optional Host override, used for Docker networking.
                headers: dict[str, str] = {}
                if self.settings.jana_host_header:
                    headers["Host"] = self.settings.jana_host_header
                self._client = httpx.AsyncClient(
                    base_url=self.base_url,
                    timeout=httpx.Timeout(self.settings.jana_timeout),
                    limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
                    headers=headers or None,
                )
            return self._client

    async def close(self) -> None:
        """Close the HTTP client if it is open and drop the reference to it."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
            self._client = None

    async def __aenter__(self) -> "JanaClient":
        """Async context manager entry."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async context manager exit - cleanup resources."""
        await self.close()

    async def authenticate(self) -> str:
        """
        Authenticate with the Jana backend and obtain a token.

        If a token is already held, it is returned without contacting the
        backend. Otherwise the configured username/password are posted to
        ``/api/auth/login/`` and the ``token`` field of the JSON reply is stored.

        Returns:
            Authentication token

        Raises:
            AuthenticationError: If no credentials are configured, the backend
                is unreachable, the credentials are rejected, or the response
                does not contain a usable token.
        """
        if self._token:
            return self._token

        if not self.settings.jana_username or not self.settings.jana_password:
            raise AuthenticationError(
                "No authentication credentials configured. "
                "Provide JANA_TOKEN or JANA_USERNAME/JANA_PASSWORD."
            )

        from jana_mcp.constants import HTTP_OK, HTTP_UNAUTHORIZED

        logger.info("Authenticating with Jana backend...")
        client = await self._get_client()
        try:
            response = await client.post(
                "/api/auth/login/",
                json={
                    "username": self.settings.jana_username,
                    "password": self.settings.get_password(),
                },
            )
        except httpx.RequestError as e:
            raise AuthenticationError(f"Failed to connect to Jana backend: {e}") from e

        if response.status_code == HTTP_UNAUTHORIZED:
            raise AuthenticationError("Invalid username or password")
        if response.status_code != HTTP_OK:
            # Don't expose response.text, which may contain sensitive data.
            raise AuthenticationError(
                f"Authentication failed with status {response.status_code}"
            )

        try:
            data = response.json()
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # response.json() decodes raw bytes, so a non-UTF-8 body raises
            # UnicodeDecodeError rather than JSONDecodeError.
            raise AuthenticationError(f"Invalid JSON in authentication response: {e}") from e

        # A non-object payload (list, string, number) cannot carry a token.
        token = data.get("token") if isinstance(data, dict) else None
        if not token:
            raise AuthenticationError("No token in authentication response")

        self._token = token
        logger.info("Successfully authenticated with Jana backend")
        return self._token

    def _get_headers(self) -> dict[str, str]:
        """Get headers for authenticated requests (``Authorization: Token <token>`` when held)."""
        headers = {"Content-Type": "application/json"}
        if self._token:
            headers["Authorization"] = f"Token {self._token}"
        return headers

    async def _send_request(
        self,
        client: httpx.AsyncClient,
        method: str,
        endpoint: str,
        params: dict[str, Any] | None,
        json_data: dict[str, Any] | None,
    ) -> httpx.Response:
        """Issue a single HTTP request using the auth headers current at call time."""
        return await client.request(
            method=method,
            url=endpoint,
            params=params,
            json=json_data,
            headers=self._get_headers(),
        )

    async def request(
        self,
        method: str,
        endpoint: str,
        params: dict[str, Any] | None = None,
        json_data: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Make an authenticated request to the Jana API.

        On an HTTP 401 the held token is discarded, a fresh one is obtained via
        :meth:`authenticate`, and the request is retried exactly once.

        Args:
            method: HTTP method (GET, POST, etc.)
            endpoint: API endpoint (e.g., "/api/v1/esg/air-quality/")
            params: Query parameters; entries whose value is None are dropped
            json_data: JSON body for POST requests

        Returns:
            JSON response data

        Raises:
            APIError: If the request cannot be sent, the final response has a
                status >= HTTP_BAD_REQUEST, or the body is not valid JSON.
            AuthenticationError: If (re-)authentication fails.
        """
        from jana_mcp.constants import HTTP_BAD_REQUEST, HTTP_UNAUTHORIZED

        if not self._token:
            await self.authenticate()

        client = await self._get_client()

        if params:
            params = {k: v for k, v in params.items() if v is not None}

        logger.debug("Request: %s %s params=%s", method, endpoint, params)

        try:
            response = await self._send_request(client, method, endpoint, params, json_data)
            if response.status_code == HTTP_UNAUTHORIZED:
                # Token might have expired: drop it, log in again, retry once.
                logger.warning("Token expired, re-authenticating...")
                self._token = None
                await self.authenticate()
                response = await self._send_request(client, method, endpoint, params, json_data)
        except httpx.RequestError as e:
            raise APIError(f"Request failed: {e}") from e

        if response.status_code >= HTTP_BAD_REQUEST:
            raise APIError(
                f"API request failed: {response.status_code}",
                status_code=response.status_code,
                response=response.text,
            )

        try:
            response_data: dict[str, Any] = response.json()
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            # A non-UTF-8 body surfaces as UnicodeDecodeError from response.json().
            raise APIError(
                f"Invalid JSON in API response: {e}",
                status_code=response.status_code,
                response=response.text,
            ) from e
        return response_data

    # ==========================================
    # Convenience methods for specific endpoints
    # ==========================================

    @staticmethod
    def _spatial_params(
        bbox: list[float] | None,
        point: list[float] | None,
        radius_km: float | None,
    ) -> dict[str, Any]:
        """
        Build the shared ``bbox`` / ``coordinates`` / ``radius`` query params.

        Falsy inputs (None, empty list, 0) are omitted. List values are
        serialized as comma-separated strings.
        """
        spatial: dict[str, Any] = {}
        if bbox:
            spatial["bbox"] = ",".join(map(str, bbox))
        if point:
            spatial["coordinates"] = ",".join(map(str, point))
        if radius_km:
            spatial["radius"] = radius_km
        return spatial

    async def get_air_quality(
        self,
        *,
        bbox: list[float] | None = None,
        point: list[float] | None = None,
        radius_km: float | None = None,
        country_codes: list[str] | None = None,
        parameters: list[str] | None = None,
        date_from: str | None = None,
        date_to: str | None = None,
        limit: int = 100,
    ) -> dict[str, Any]:
        """
        Query air quality measurements (OpenAQ measurements endpoint).

        Args:
            bbox: Bounding box [min_lon, min_lat, max_lon, max_lat]
            point: Point coordinates [lon, lat]
            radius_km: Search radius in kilometers
            country_codes: List of ISO-3 country codes
            parameters: Pollutant parameters (pm25, pm10, o3, no2, so2, co)
            date_from: Start date (ISO 8601)
            date_to: End date (ISO 8601)
            limit: Maximum results to return

        Returns:
            Air quality measurement data
        """
        params: dict[str, Any] = {"limit": limit, **self._spatial_params(bbox, point, radius_km)}
        if country_codes:
            # OpenAQ uses ISO-3 country codes
            params["countries_id"] = ",".join(country_codes)
        if parameters:
            params["parameters"] = ",".join(parameters)
        if date_from:
            params["date_from"] = date_from
        if date_to:
            params["date_to"] = date_to
        return await self.request("GET", "/api/v1/data-sources/openaq/measurements/", params=params)

    async def get_emissions(
        self,
        *,
        sources: list[str] | None = None,
        bbox: list[float] | None = None,
        point: list[float] | None = None,
        radius_km: float | None = None,
        country_codes: list[str] | None = None,
        sectors: list[str] | None = None,
        gases: list[str] | None = None,
        date_from: str | None = None,
        date_to: str | None = None,
        limit: int = 100,
    ) -> dict[str, Any]:
        """
        Query greenhouse gas emissions.

        Args:
            sources: Data sources (climatetrace, edgar). Currently ignored: the
                Climate TRACE emissions endpoint is always queried.
            bbox: Bounding box [min_lon, min_lat, max_lon, max_lat]
            point: Point coordinates [lon, lat]
            radius_km: Search radius in kilometers
            country_codes: List of ISO-3 country codes
            sectors: Sector filters (sent as ``sector_name``)
            gases: Gas type filters (co2, ch4, n2o)
            date_from: Start date (sent as ``start_time__gte``)
            date_to: End date (sent as ``end_time__lte``)
            limit: Maximum results

        Returns:
            Emissions data
        """
        params: dict[str, Any] = {"limit": limit, **self._spatial_params(bbox, point, radius_km)}
        if country_codes:
            # Climate Trace uses ISO-3 country codes (country_iso3 field)
            params["country_iso3"] = ",".join(country_codes)
        if sectors:
            params["sector_name"] = ",".join(sectors)
        if gases:
            params["gas"] = ",".join(gases)
        if date_from:
            params["start_time__gte"] = date_from
        if date_to:
            params["end_time__lte"] = date_to
        # Default to climatetrace emissions endpoint
        # Future: could support edgar based on 'sources' parameter
        return await self.request("GET", "/api/v1/data-sources/climatetrace/emissions/", params=params)

    async def find_nearby(
        self,
        *,
        point: list[float],
        radius_km: float,
        sources: list[str] | None = None,
        limit: int = 100,
    ) -> dict[str, Any]:
        """
        Find environmental monitoring stations near a location.

        Args:
            point: Point coordinates [lon, lat]
            radius_km: Search radius in kilometers (always sent, even if 0)
            sources: Data sources to search (openaq, climatetrace). Currently
                ignored: the OpenAQ locations endpoint is always queried.
            limit: Maximum results

        Returns:
            Nearby monitoring stations/assets
        """
        params: dict[str, Any] = {
            "coordinates": ",".join(map(str, point)),
            "radius": radius_km,
            "limit": limit,
        }
        # Default to OpenAQ locations for nearby search
        # Future: could support climatetrace assets
        return await self.request("GET", "/api/v1/data-sources/openaq/locations/", params=params)

    async def get_summary(self) -> dict[str, Any]:
        """
        Get platform data summary.

        Returns:
            Summary of data coverage, record counts by source,
            geographic coverage, and data freshness.
        """
        return await self.request("GET", "/api/v1/esg/summary/")

    async def check_health(self) -> dict[str, Any]:
        """
        Check backend health status without raising.

        The request is sent without the Authorization header.

        Returns:
            ``{"status": "healthy", "backend": <parsed body>}`` on HTTP_OK (the
            body is replaced by ``{"raw_status": "invalid_json"}`` if it cannot
            be decoded), ``{"status": "unhealthy", "status_code": <code>}`` for
            any other status, or ``{"status": "unreachable", "error": "Network error"}``
            when the request could not be sent.
        """
        # Health endpoint doesn't require auth
        from jana_mcp.constants import HTTP_OK

        client = await self._get_client()
        try:
            response = await client.get("/health/")
        except httpx.RequestError:
            # Best-effort probe: report the failure instead of raising.
            return {"status": "unreachable", "error": "Network error"}

        if response.status_code != HTTP_OK:
            return {"status": "unhealthy", "status_code": response.status_code}

        try:
            backend_data = response.json()
        except (json.JSONDecodeError, UnicodeDecodeError):
            backend_data = {"raw_status": "invalid_json"}
        return {"status": "healthy", "backend": backend_data}