"""Realtime GTFS feed poller for CATA bus data."""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime
import aiohttp
from google.transit import gtfs_realtime_pb2
from pydantic import BaseModel, Field
from typing import Any
logger = logging.getLogger(__name__)

# CATA GTFS-Realtime endpoints, one per feed type. No debug parameter is appended,
# so the responses are in protobuf format.
VEHICLE_POSITIONS_URL = (
    "https://realtime.catabus.com/InfoPoint/GTFS-Realtime.ashx?Type=VehiclePosition"
)
TRIP_UPDATES_URL = (
    "https://realtime.catabus.com/InfoPoint/GTFS-Realtime.ashx?Type=TripUpdate"
)
ALERTS_URL = (
    "https://realtime.catabus.com/InfoPoint/GTFS-Realtime.ashx?Type=Alert"
)

# Seconds slept between two polls of the same endpoint. The rate limit noted for this
# feed is a minimum of 10 seconds between requests to one endpoint; this value stays
# above that minimum.
MIN_POLL_INTERVAL = 15
class VehiclePosition(BaseModel):
    """Position report for one vehicle, built from a GTFS-Realtime vehicle entity."""

    # Vehicle descriptor id; the poller substitutes the feed entity id when it is empty.
    vehicle_id: str
    # Trip/route the vehicle is serving; None when the entity carries no trip descriptor.
    trip_id: str | None = None
    route_id: str | None = None
    latitude: float
    longitude: float
    # None when the feed does not set the field.
    bearing: float | None = None
    speed: float | None = None  # meters per second
    # Timezone-aware (UTC) datetime, as populated by the poller.
    timestamp: datetime
    # Occupancy enum name such as "EMPTY" or "FULL"; None when the feed does not set it.
    occupancy_status: str | None = None
class TripUpdate(BaseModel):
    """Realtime update for one trip, built from a GTFS-Realtime trip_update entity."""

    trip_id: str
    # None when the feed's trip descriptor has no route_id set.
    route_id: str | None = None
    # None when the trip update carries no vehicle descriptor.
    vehicle_id: str | None = None
    # Timezone-aware (UTC) datetime, as populated by the poller.
    timestamp: datetime
    # One dict per stop. The poller always writes "stop_id" and "stop_sequence", and adds
    # "arrival_delay"/"arrival_time" and "departure_delay"/"departure_time" only when the
    # feed provides the corresponding arrival/departure event.
    stop_time_updates: list[dict[str, Any]] = Field(default_factory=list)
class ServiceAlert(BaseModel):
    """Service alert built from a GTFS-Realtime alert entity."""

    alert_id: str
    # Header/description text picked from the feed's translations; the poller prefers
    # the "en" translation and otherwise keeps the first one listed.
    header: str
    description: str | None = None
    # Severity name as produced by the poller: "UNKNOWN", "INFO", "WARNING" or "SEVERE".
    severity: str = "UNKNOWN"
    # One dict per active period with "start" and "end" keys (UTC datetime or None).
    active_periods: list[dict[str, Any]] = Field(default_factory=list)
    # One dict per informed entity; may contain "route_id", "trip_id" and/or "stop_id".
    informed_entities: list[dict[str, Any]] = Field(default_factory=list)
@dataclass
class RealtimeData:
    """In-memory snapshot of the most recently polled realtime feeds."""

    # Latest vehicle positions, keyed by vehicle_id.
    vehicle_positions: dict[str, VehiclePosition] = field(default_factory=dict)
    # Latest trip updates, keyed by trip_id.
    trip_updates: dict[str, TripUpdate] = field(default_factory=dict)
    # Latest service alerts, in feed order.
    alerts: list[ServiceAlert] = field(default_factory=list)
    # When each feed was last refreshed successfully (UTC); None until the first success.
    last_vehicle_update: datetime | None = None
    last_trip_update: datetime | None = None
    last_alert_update: datetime | None = None
class RealtimeGTFSPoller:
def __init__(self) -> None:
self.data = RealtimeData()
self._running = False
self._session: aiohttp.ClientSession | None = None
async def start(self) -> None:
"""Start the polling tasks without blocking startup."""
if self._running:
return
self._running = True
self._session = aiohttp.ClientSession()
# Start tasks immediately for fast cloud startup
# Use background task with internal staggering to avoid blocking
asyncio.create_task(self._start_staggered_polling())
async def _start_staggered_polling(self) -> None:
"""Start polling tasks with internal staggering - non-blocking."""
try:
# Start immediately without blocking main startup
asyncio.create_task(self._poll_vehicle_positions())
# Stagger other tasks in background
await asyncio.sleep(5)
asyncio.create_task(self._poll_trip_updates())
await asyncio.sleep(5)
asyncio.create_task(self._poll_alerts())
except Exception as e:
logger.error(f"Error starting staggered polling: {e}")
async def stop(self) -> None:
"""Stop the polling tasks."""
self._running = False
if self._session:
await self._session.close()
async def _fetch_protobuf(self, url: str) -> bytes | None:
"""Fetch protobuf data from URL."""
if not self._session:
return None
try:
async with self._session.get(url, timeout=30) as response:
response.raise_for_status()
return await response.read()
except Exception as e:
logger.error(f"Error fetching {url}: {e}")
return None
async def _poll_vehicle_positions(self) -> None:
"""Poll vehicle positions endpoint."""
while self._running:
try:
data = await self._fetch_protobuf(VEHICLE_POSITIONS_URL)
if data:
feed = gtfs_realtime_pb2.FeedMessage()
feed.ParseFromString(data)
positions = {}
for entity in feed.entity:
if entity.HasField("vehicle"):
vehicle = entity.vehicle
if vehicle.HasField("position"):
pos = VehiclePosition(
vehicle_id=(
vehicle.vehicle.id if vehicle.vehicle.id else entity.id
),
trip_id=(
vehicle.trip.trip_id if vehicle.HasField("trip") else None
),
route_id=(
vehicle.trip.route_id if vehicle.HasField("trip") else None
),
latitude=vehicle.position.latitude,
longitude=vehicle.position.longitude,
bearing=(
vehicle.position.bearing
if vehicle.position.HasField("bearing")
else None
),
speed=(
vehicle.position.speed
if vehicle.position.HasField("speed")
else None
),
timestamp=datetime.fromtimestamp(vehicle.timestamp, tz=UTC),
occupancy_status=(
self._get_occupancy_status(vehicle.occupancy_status)
if vehicle.HasField("occupancy_status")
else None
),
)
positions[pos.vehicle_id] = pos
self.data.vehicle_positions = positions
self.data.last_vehicle_update = datetime.now(UTC)
logger.debug(f"Updated {len(positions)} vehicle positions")
except Exception as e:
logger.error(f"Error polling vehicle positions: {e}")
await asyncio.sleep(MIN_POLL_INTERVAL)
async def _poll_trip_updates(self) -> None:
"""Poll trip updates endpoint."""
while self._running:
try:
data = await self._fetch_protobuf(TRIP_UPDATES_URL)
if data:
feed = gtfs_realtime_pb2.FeedMessage()
feed.ParseFromString(data)
updates = {}
for entity in feed.entity:
if entity.HasField("trip_update"):
trip_update = entity.trip_update
stop_time_updates = []
for stu in trip_update.stop_time_update:
update = {
"stop_id": stu.stop_id,
"stop_sequence": (
stu.stop_sequence if stu.HasField("stop_sequence") else None
),
}
if stu.HasField("arrival"):
update["arrival_delay"] = (
stu.arrival.delay if stu.arrival.HasField("delay") else 0
)
update["arrival_time"] = (
stu.arrival.time if stu.arrival.HasField("time") else None
)
if stu.HasField("departure"):
update["departure_delay"] = (
stu.departure.delay
if stu.departure.HasField("delay")
else 0
)
update["departure_time"] = (
stu.departure.time
if stu.departure.HasField("time")
else None
)
stop_time_updates.append(update)
update = TripUpdate(
trip_id=trip_update.trip.trip_id,
route_id=(
trip_update.trip.route_id
if trip_update.trip.HasField("route_id")
else None
),
vehicle_id=(
trip_update.vehicle.id
if trip_update.HasField("vehicle")
else None
),
timestamp=(
datetime.fromtimestamp(trip_update.timestamp, tz=UTC)
if trip_update.HasField("timestamp")
else datetime.now(UTC)
),
stop_time_updates=stop_time_updates,
)
updates[update.trip_id] = update
self.data.trip_updates = updates
self.data.last_trip_update = datetime.now(UTC)
logger.info(f"Successfully updated {len(updates)} trip updates.")
else:
logger.warning("No data received from trip updates endpoint.")
except Exception as e:
logger.error(
f"An exception occurred while polling trip updates: {e}", exc_info=True
)
await asyncio.sleep(MIN_POLL_INTERVAL)
async def _poll_alerts(self) -> None:
"""Poll service alerts endpoint."""
while self._running:
try:
data = await self._fetch_protobuf(ALERTS_URL)
if data:
feed = gtfs_realtime_pb2.FeedMessage()
feed.ParseFromString(data)
alerts = []
for entity in feed.entity:
if entity.HasField("alert"):
alert = entity.alert
# Get header text (handling translations)
header_text = ""
if alert.HasField("header_text"):
for translation in alert.header_text.translation:
if translation.language == "en" or not header_text:
header_text = translation.text
# Get description text
description_text = ""
if alert.HasField("description_text"):
for translation in alert.description_text.translation:
if translation.language == "en" or not description_text:
description_text = translation.text
# Get active periods
active_periods = []
for period in alert.active_period:
active_periods.append(
{
"start": (
datetime.fromtimestamp(period.start, tz=UTC)
if period.HasField("start")
else None
),
"end": (
datetime.fromtimestamp(period.end, tz=UTC)
if period.HasField("end")
else None
),
}
)
# Get informed entities
informed_entities = []
for entity in alert.informed_entity:
ie = {}
if entity.HasField("route_id"):
ie["route_id"] = entity.route_id
if entity.HasField("trip"):
ie["trip_id"] = entity.trip.trip_id
if entity.HasField("stop_id"):
ie["stop_id"] = entity.stop_id
informed_entities.append(ie)
service_alert = ServiceAlert(
alert_id=entity.id,
header=header_text,
description=description_text,
severity=(
self._get_severity(alert.severity_level)
if alert.HasField("severity_level")
else "UNKNOWN"
),
active_periods=active_periods,
informed_entities=informed_entities,
)
alerts.append(service_alert)
self.data.alerts = alerts
self.data.last_alert_update = datetime.now(UTC)
logger.debug(f"Updated {len(alerts)} alerts")
except Exception as e:
logger.error(f"Error polling alerts: {e}")
await asyncio.sleep(MIN_POLL_INTERVAL)
def _get_occupancy_status(self, status: int) -> str:
"""Convert occupancy status enum to string."""
mapping = {
0: "EMPTY",
1: "MANY_SEATS_AVAILABLE",
2: "FEW_SEATS_AVAILABLE",
3: "STANDING_ROOM_ONLY",
4: "CRUSHED_STANDING_ROOM_ONLY",
5: "FULL",
6: "NOT_ACCEPTING_PASSENGERS",
}
return mapping.get(status, "UNKNOWN")
def _get_severity(self, level: int) -> str:
"""Convert severity level enum to string."""
mapping = {
1: "UNKNOWN",
2: "INFO",
3: "WARNING",
4: "SEVERE",
}
return mapping.get(level, "UNKNOWN")