"""Static GTFS feed loader for CATA bus data."""
import asyncio
import csv
import io
import logging
import os
import zipfile
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import aiofiles
import aiohttp
logger = logging.getLogger(__name__)
GTFS_STATIC_URL = "https://catabus.com/wp-content/uploads/google_transit.zip"
# Cloud environments (FastMCP Cloud, AWS Lambda) typically only allow writes
# under /tmp, so choose the cache directory accordingly.
def get_cache_dir() -> Path:
    """Return a writable cache directory for the current environment."""
    if (os.environ.get("LAMBDA_RUNTIME_DIR")
            or os.environ.get("AWS_LAMBDA_FUNCTION_NAME")
            or os.environ.get("FASTMCP_CLOUD")
            or (os.path.exists("/tmp") and not os.path.exists(os.path.expanduser("~")))):
        return Path("/tmp/catabus_cache")
    return Path("cache")
CACHE_DIR = get_cache_dir()
@dataclass
class Stop:
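    """A transit stop from stops.txt."""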
stop_id: str
stop_name: str
stop_lat: float
stop_lon: float
stop_code: Optional[str] = None
stop_desc: Optional[str] = None
@dataclass
class Route:
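    """A route from routes.txt."""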
route_id: str
route_short_name: str
route_long_name: str
route_type: int
route_color: Optional[str] = None
route_text_color: Optional[str] = None
@dataclass
class Trip:
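    """A scheduled trip from trips.txt."""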
trip_id: str
route_id: str
service_id: str
trip_headsign: Optional[str] = None
direction_id: Optional[int] = None
shape_id: Optional[str] = None
@dataclass
class StopTime:
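    """A scheduled stop event from stop_times.txt."""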
trip_id: str
arrival_time: str
departure_time: str
stop_id: str
stop_sequence: int
pickup_type: Optional[int] = None
drop_off_type: Optional[int] = None
@dataclass
class GTFSData:
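    """In-memory container for the parsed GTFS tables."""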
routes: Dict[str, Route] = field(default_factory=dict)
stops: Dict[str, Stop] = field(default_factory=dict)
trips: Dict[str, Trip] = field(default_factory=dict)
stop_times: List[StopTime] = field(default_factory=list)
shapes: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict)
last_updated: Optional[datetime] = None
class StaticGTFSLoader:
    """Downloads, caches, and parses the CATA static GTFS feed."""

    def __init__(self):
        self.data = GTFSData()
        # parents=True so nested cache paths (e.g. /tmp/catabus_cache) never fail.
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
async def download_feed(self, timeout_seconds: int = 15) -> bytes:
"""Download the static GTFS feed with strict timeout for cloud deployment."""
timeout = aiohttp.ClientTimeout(total=timeout_seconds, connect=5)
async with aiohttp.ClientSession(timeout=timeout) as session:
logger.info(f"Downloading GTFS feed from {GTFS_STATIC_URL} (timeout: {timeout_seconds}s)")
try:
async with session.get(GTFS_STATIC_URL) as response:
response.raise_for_status()
content = await response.read()
logger.info(f"Downloaded GTFS feed: {len(content)} bytes")
return content
except asyncio.TimeoutError:
logger.error(f"Download timed out after {timeout_seconds}s")
raise
except Exception as e:
logger.error(f"Download failed: {e}")
raise
def parse_csv(self, content: str) -> List[Dict[str, str]]:
"""Parse CSV content into list of dictionaries."""
reader = csv.DictReader(io.StringIO(content))
return list(reader)
async def load_feed(self, force_refresh: bool = False, timeout_seconds: int = 30) -> GTFSData:
"""Load and parse the GTFS static feed with improved error handling."""
cache_file = CACHE_DIR / "google_transit.zip"
feed_data: Optional[bytes] = None
# Attempt to download fresh data first
if force_refresh or not cache_file.exists():
logger.info("Downloading fresh GTFS feed.")
try:
feed_data = await self.download_feed(timeout_seconds)
async with aiofiles.open(cache_file, "wb") as f:
await f.write(feed_data)
logger.info("Successfully downloaded and cached GTFS feed.")
except Exception as e:
logger.error(f"Failed to download GTFS feed: {e}")
if cache_file.exists():
logger.warning("Using stale cached feed due to download failure.")
async with aiofiles.open(cache_file, "rb") as f:
feed_data = await f.read()
else:
logger.critical("No cached GTFS data available. The server will run without static data.")
return self.data # Return empty data
else:
# Use cached data if it's not too old
age = datetime.now().timestamp() - cache_file.stat().st_mtime
if age < 86400: # 24 hours
logger.info("Using fresh cached GTFS feed.")
async with aiofiles.open(cache_file, "rb") as f:
feed_data = await f.read()
else:
logger.info("Cached GTFS feed is stale. Attempting to refresh.")
try:
feed_data = await self.download_feed(timeout_seconds)
async with aiofiles.open(cache_file, "wb") as f:
await f.write(feed_data)
logger.info("Successfully refreshed and cached GTFS feed.")
except Exception as e:
logger.warning(f"Failed to refresh stale cache: {e}. Using stale cache.")
async with aiofiles.open(cache_file, "rb") as f:
feed_data = await f.read()
if not feed_data:
logger.error("Failed to load GTFS data from any source.")
return self.data # Return empty data
# Parse the feed
with zipfile.ZipFile(io.BytesIO(feed_data)) as zf:
# Load routes
if "routes.txt" in zf.namelist():
content = zf.read("routes.txt").decode("utf-8-sig")
for row in self.parse_csv(content):
route = Route(
route_id=row["route_id"],
route_short_name=row.get("route_short_name", ""),
route_long_name=row.get("route_long_name", ""),
                        route_type=int(row.get("route_type") or 3),  # missing/empty -> 3 (bus)
route_color=row.get("route_color"),
route_text_color=row.get("route_text_color"),
)
self.data.routes[route.route_id] = route
# Load stops
if "stops.txt" in zf.namelist():
content = zf.read("stops.txt").decode("utf-8-sig")
for row in self.parse_csv(content):
stop = Stop(
stop_id=row["stop_id"],
stop_name=row["stop_name"],
stop_lat=float(row["stop_lat"]),
stop_lon=float(row["stop_lon"]),
stop_code=row.get("stop_code"),
stop_desc=row.get("stop_desc"),
)
self.data.stops[stop.stop_id] = stop
# Load trips
if "trips.txt" in zf.namelist():
content = zf.read("trips.txt").decode("utf-8-sig")
for row in self.parse_csv(content):
trip = Trip(
trip_id=row["trip_id"],
route_id=row["route_id"],
service_id=row["service_id"],
trip_headsign=row.get("trip_headsign"),
direction_id=int(row["direction_id"]) if row.get("direction_id") else None,
shape_id=row.get("shape_id"),
)
self.data.trips[trip.trip_id] = trip
# Load stop times
if "stop_times.txt" in zf.namelist():
content = zf.read("stop_times.txt").decode("utf-8-sig")
for row in self.parse_csv(content):
stop_time = StopTime(
trip_id=row["trip_id"],
arrival_time=row["arrival_time"],
departure_time=row["departure_time"],
stop_id=row["stop_id"],
stop_sequence=int(row["stop_sequence"]),
pickup_type=int(row["pickup_type"]) if row.get("pickup_type") else None,
drop_off_type=int(row["drop_off_type"]) if row.get("drop_off_type") else None,
)
self.data.stop_times.append(stop_time)
# Load shapes (optional)
if "shapes.txt" in zf.namelist():
content = zf.read("shapes.txt").decode("utf-8-sig")
for row in self.parse_csv(content):
shape_id = row["shape_id"]
                    self.data.shapes.setdefault(shape_id, []).append({
"lat": float(row["shape_pt_lat"]),
"lon": float(row["shape_pt_lon"]),
"sequence": int(row["shape_pt_sequence"]),
})
self.data.last_updated = datetime.now()
logger.info(f"Loaded {len(self.data.routes)} routes, {len(self.data.stops)} stops, "
f"{len(self.data.trips)} trips, {len(self.data.stop_times)} stop times")
return self.data
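

# Minimal usage sketch: run this module directly to exercise the loader end to
# end. Assumes network access to the CATA feed URL (otherwise the loader falls
# back to any cached copy).
if __name__ == "__main__":
    async def _demo() -> None:
        loader = StaticGTFSLoader()
        data = await loader.load_feed()
        # Print a small sample so a manual run shows the parse worked.
        for route in list(data.routes.values())[:5]:
            print(f"{route.route_short_name}: {route.route_long_name}")

    asyncio.run(_demo())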