"""Trend analysis tools for MCP."""
import logging
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Any
import httpx
from mcp.types import TextContent, Tool
from jana_mcp.client import APIError, AuthenticationError, JanaClient
from jana_mcp.constants import (
BBOX_COORDINATES_REQUIRED,
MIN_TREND_DATA_POINTS,
POINT_COORDINATES_REQUIRED,
TREND_DATA_LIMIT,
)
from jana_mcp.tools.response import create_error_response, serialize_response
from jana_mcp.tools.validation import validate_bbox, validate_coordinates, validate_date_format
logger = logging.getLogger(__name__)
TRENDS_TOOL = Tool(
name="get_trends",
description="""Analyze temporal trends in environmental data.
Returns time-series analysis showing how measurements change over time,
with aggregations at specified intervals (daily, weekly, monthly).""",
inputSchema={
"type": "object",
"properties": {
"source": {
"type": "string",
"enum": ["openaq", "climatetrace", "edgar"],
"description": "Data source to analyze",
},
"parameter": {
"type": "string",
"description": "Parameter to analyze (e.g., 'pm25', 'co2')",
},
"location_bbox": {
"type": "array",
"items": {"type": "number"},
"minItems": 4,
"maxItems": 4,
"description": "Bounding box [min_lon, min_lat, max_lon, max_lat]",
            },
            "location_point": {
                "type": "array",
                "items": {"type": "number"},
                "minItems": 2,
                "maxItems": 2,
                # NOTE: [lon, lat] ordering assumed here to match location_bbox;
                # this property is read by execute_trends but was missing from the schema
                "description": "Point coordinates [lon, lat]; use together with radius_km",
            },
            "radius_km": {
                "type": "number",
                "description": "Search radius in kilometers around location_point",
            },
"country_codes": {
"type": "array",
"items": {"type": "string"},
"description": "ISO-3 country codes",
},
"date_from": {
"type": "string",
"format": "date",
"description": "Start date (ISO 8601)",
},
"date_to": {
"type": "string",
"format": "date",
"description": "End date (ISO 8601)",
},
"temporal_resolution": {
"type": "string",
"enum": ["daily", "weekly", "monthly"],
"default": "monthly",
"description": "Time aggregation interval",
},
},
"required": ["source", "parameter"],
},
)
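# Example arguments for a get_trends call (illustrative values only):
#     {
#         "source": "openaq",
#         "parameter": "pm25",
#         "country_codes": ["USA"],
#         "date_from": "2024-01-01",
#         "date_to": "2024-12-31",
#         "temporal_resolution": "monthly",
#     }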
def aggregate_by_period(
data: list[dict[str, Any]], resolution: str
) -> dict[str, list[float]]:
"""
Aggregate values by time period.
Args:
data: List of data dictionaries with timestamp and value fields
resolution: Time resolution ("daily", "weekly", or "monthly")
Returns:
Dictionary mapping period keys to lists of values
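
    Example (illustrative doctest; timestamps are parsed as UTC):
        >>> aggregate_by_period(
        ...     [
        ...         {"datetime": "2024-01-05T00:00:00Z", "value": 10.0},
        ...         {"datetime": "2024-01-20T00:00:00Z", "value": 14.0},
        ...     ],
        ...     "monthly",
        ... )
        {'2024-01': [10.0, 14.0]}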
"""
aggregated: dict[str, list[float]] = defaultdict(list)
for item in data:
# Try to extract timestamp from various possible fields
timestamp_str = (
item.get("datetime")
or item.get("date")
or item.get("timestamp")
or item.get("measurement_date")
)
if not timestamp_str:
continue
        try:
            # Normalize the timestamp into an aware datetime
            if isinstance(timestamp_str, datetime):
                dt = timestamp_str
            elif isinstance(timestamp_str, str):
                try:
                    # fromisoformat handles ISO 8601; rewrite a trailing "Z" so
                    # pre-3.11 Pythons accept it
                    dt = datetime.fromisoformat(timestamp_str.replace("Z", "+00:00"))
                except ValueError:
                    # Fallback for non-ISO strings: try common formats and assume UTC.
                    # The [:19] slice drops fractional seconds and offsets, so a
                    # trailing-"Z" format could never match and is not listed.
                    for fmt in ("%Y-%m-%dT%H:%M:%S", "%Y-%m-%d"):
                        try:
                            dt = datetime.strptime(timestamp_str[:19], fmt)
                            break
                        except ValueError:
                            continue
                    else:
                        continue  # no format matched; skip this item
            else:
                continue
            # Ensure timezone-aware (assume UTC for naive datetimes)
            if dt.tzinfo is None:
                dt = dt.replace(tzinfo=timezone.utc)
            # Build the aggregation key for the requested resolution
            if resolution == "daily":
                period_key = dt.strftime("%Y-%m-%d")
            elif resolution == "weekly":
                # ISO week numbering avoids mislabeled weeks at year boundaries
                # (strftime's "%Y-W%W" pairs the calendar year with a
                # Monday-based week count)
                iso_year, iso_week, _ = dt.isocalendar()
                period_key = f"{iso_year}-W{iso_week:02d}"
            else:  # monthly
                period_key = dt.strftime("%Y-%m")
            # Extract the numeric value; test "is not None" so legitimate zero
            # readings are kept (a plain `or` chain would drop them)
            value = next(
                (item[key] for key in ("value", "emissions", "measurement")
                 if item.get(key) is not None),
                None,
            )
            if value is not None:
                try:
                    aggregated[period_key].append(float(value))
                except (ValueError, TypeError):
                    # Skip non-numeric values
                    continue
except (ValueError, TypeError, KeyError, AttributeError) as e:
# Skip items with invalid timestamp or value formats
logger.debug("Skipping invalid data item: %s", e)
continue
return dict(aggregated)
def calculate_trend_stats(aggregated: dict[str, list[float]]) -> dict[str, Any]:
"""
Calculate statistics from aggregated data.
Args:
aggregated: Dictionary mapping period keys to lists of values
Returns:
Dictionary with trend statistics and summary
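
    Example (illustrative doctest; trend_direction also depends on
    MIN_TREND_DATA_POINTS, so only stable fields are shown):
        >>> stats = calculate_trend_stats({"2024-01": [10.0], "2024-02": [12.0, 14.0]})
        >>> stats["summary"]["date_range"]
        {'from': '2024-01', 'to': '2024-02'}
        >>> stats["periods"][0]["average"]
        10.0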
"""
if not aggregated:
return {"error": "No data available for trend analysis"}
# Sort by period
sorted_periods = sorted(aggregated.keys())
# Calculate averages per period
trend_data: list[dict[str, Any]] = []
values_for_trend: list[float] = []
for period in sorted_periods:
values = aggregated[period]
avg = sum(values) / len(values)
trend_data.append({
"period": period,
"average": round(avg, 4),
"min": round(min(values), 4),
"max": round(max(values), 4),
"count": len(values),
})
values_for_trend.append(avg)
    # Estimate the overall trend by comparing first-half and second-half averages
    # (a coarse direction check, not a linear regression)
    if len(values_for_trend) >= MIN_TREND_DATA_POINTS:
        midpoint = len(values_for_trend) // 2
        first_half_avg = sum(values_for_trend[:midpoint]) / midpoint
        second_half_avg = sum(values_for_trend[midpoint:]) / (
            len(values_for_trend) - midpoint
        )
        if second_half_avg > first_half_avg:
            trend_direction = "increasing"
        elif second_half_avg < first_half_avg:
            trend_direction = "decreasing"
        else:
            trend_direction = "stable"
        if first_half_avg == 0:
            # Avoid division by zero: report an infinite percent change when
            # rising from zero
            change_pct = float("inf") if second_half_avg > 0 else 0.0
        else:
            change_pct = ((second_half_avg - first_half_avg) / first_half_avg) * 100
    else:
        trend_direction = "insufficient data"
        change_pct = 0.0
return {
"periods": trend_data,
"summary": {
"total_periods": len(sorted_periods),
"date_range": {"from": sorted_periods[0], "to": sorted_periods[-1]},
"trend_direction": trend_direction,
"change_percent": round(change_pct, 2),
"overall_average": round(sum(values_for_trend) / len(values_for_trend), 4),
},
}
async def execute_trends(
client: JanaClient, arguments: dict[str, Any]
) -> list[TextContent]:
"""
Execute the trends analysis tool.
Args:
client: Jana API client
arguments: Tool arguments from MCP call
Returns:
List of TextContent with results or error message
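
    Example (illustrative; assumes a configured JanaClient instance):
        results = await execute_trends(
            client,
            {"source": "openaq", "parameter": "pm25", "country_codes": ["USA"]},
        )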
"""
logger.info("Executing get_trends")
source = arguments.get("source")
parameter = arguments.get("parameter")
if not source or not parameter:
error_response = create_error_response(
"source and parameter are required", "VALIDATION_ERROR"
)
return [TextContent(type="text", text=serialize_response(error_response))]
    # Validate the location filter: trends accepts a bounding box,
    # a point with radius_km, or country codes
location_bbox = arguments.get("location_bbox")
location_point = arguments.get("location_point")
country_codes = arguments.get("country_codes")
has_location = (
(location_bbox and len(location_bbox) >= BBOX_COORDINATES_REQUIRED)
or (location_point and len(location_point) >= POINT_COORDINATES_REQUIRED)
or (country_codes and len(country_codes) > 0)
)
if not has_location:
validation_error = create_error_response(
"At least one location filter is required (location_bbox, location_point+radius_km, or country_codes)",
"VALIDATION_ERROR",
)
return [TextContent(type="text", text=serialize_response(validation_error))]
# Validate coordinates and dates
if location_bbox:
bbox_error = validate_bbox(location_bbox)
if bbox_error:
validation_error = create_error_response(bbox_error, "VALIDATION_ERROR")
return [TextContent(type="text", text=serialize_response(validation_error))]
if location_point and len(location_point) >= POINT_COORDINATES_REQUIRED:
coord_error = validate_coordinates(location_point[0], location_point[1])
if coord_error:
validation_error = create_error_response(coord_error, "VALIDATION_ERROR")
return [TextContent(type="text", text=serialize_response(validation_error))]
# Validate date formats if provided
date_from = arguments.get("date_from")
if date_from:
date_error = validate_date_format(date_from)
if date_error:
validation_error = create_error_response(date_error, "VALIDATION_ERROR")
return [TextContent(type="text", text=serialize_response(validation_error))]
date_to = arguments.get("date_to")
if date_to:
date_error = validate_date_format(date_to)
if date_error:
validation_error = create_error_response(date_error, "VALIDATION_ERROR")
return [TextContent(type="text", text=serialize_response(validation_error))]
# Get resolution early so it's available for all responses
resolution = arguments.get("temporal_resolution", "monthly")
    # Default to the last 12 months when dates are not provided, so the
    # analysis covers a meaningful window
if not date_from or not date_to:
now = datetime.now(timezone.utc)
if not date_to:
date_to = now.strftime("%Y-%m-%d")
if not date_from:
# Default to 12 months ago
twelve_months_ago = now - timedelta(days=365)
date_from = twelve_months_ago.strftime("%Y-%m-%d")
try:
# Fetch data based on source
if source == "openaq":
result = await client.get_air_quality(
bbox=arguments.get("location_bbox"),
point=arguments.get("location_point"),
radius_km=arguments.get("radius_km"),
country_codes=arguments.get("country_codes"),
parameters=[parameter],
date_from=date_from,
date_to=date_to,
limit=TREND_DATA_LIMIT,
)
else:
result = await client.get_emissions(
sources=[source],
bbox=arguments.get("location_bbox"),
point=arguments.get("location_point"),
radius_km=arguments.get("radius_km"),
country_codes=arguments.get("country_codes"),
date_from=date_from,
date_to=date_to,
limit=TREND_DATA_LIMIT,
)
# Extract data list from result
data = result.get("results", result.get("data", []))
if not data:
no_data_response: dict[str, Any] = {
"source": source,
"parameter": parameter,
"temporal_resolution": resolution,
"success": False,
"error": "No data found for the specified criteria",
"error_code": "NO_DATA",
}
return [
TextContent(
type="text",
text=serialize_response(no_data_response),
)
]
# Aggregate and calculate trends
aggregated = aggregate_by_period(data, resolution)
trend_stats = calculate_trend_stats(aggregated)
response: dict[str, Any] = {
"source": source,
"parameter": parameter,
"temporal_resolution": resolution,
"analysis": trend_stats,
}
return [
TextContent(
type="text",
text=serialize_response(response),
)
]
except (APIError, AuthenticationError) as e:
logger.exception("API error in get_trends")
error_response = create_error_response(
f"API error analyzing trends: {e}", "API_ERROR"
)
return [TextContent(type="text", text=serialize_response(error_response))]
except httpx.RequestError as e:
logger.exception("Network error in get_trends")
error_response = create_error_response(
f"Network error analyzing trends: {e}", "NETWORK_ERROR"
)
return [TextContent(type="text", text=serialize_response(error_response))]
except (KeyError, ValueError, TypeError) as e:
logger.exception("Data parsing error in get_trends")
error_response = create_error_response(
f"Invalid data format: {e}", "DATA_ERROR"
)
return [TextContent(type="text", text=serialize_response(error_response))]
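

if __name__ == "__main__":
    # Illustrative smoke test (not part of the MCP server): exercises the pure
    # aggregation helpers on synthetic data. Field names mirror those probed in
    # aggregate_by_period above.
    _sample = [
        {"datetime": "2024-01-05T00:00:00Z", "value": 10.0},
        {"datetime": "2024-02-10T00:00:00Z", "value": 12.5},
        {"datetime": "2024-03-15T00:00:00Z", "value": 15.0},
    ]
    _aggregated = aggregate_by_period(_sample, "monthly")
    print(calculate_trend_stats(_aggregated))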