"""Strava API client for interacting with the Strava API."""
import time
from datetime import date
from typing import Any, Optional, List

import httpx
class StravaClient:
    """Client for interacting with the Strava API.

    The client authenticates with an OAuth refresh token and obtains
    short-lived access tokens on demand. It owns an ``httpx.Client``; call
    :meth:`close` when finished, or use the instance as a context manager::

        with StravaClient(refresh_token, client_id, client_secret) as client:
            client.get_activities()
    """

    BASE_URL = "https://www.strava.com/api/v3"
    TOKEN_URL = "https://www.strava.com/oauth/token"
    # Strava caps ``per_page`` at 200; larger limits are served by requesting
    # several pages of this size.
    MAX_PAGE_SIZE = 200
    # Refresh this many seconds before the reported expiry so a token cannot
    # lapse between the validity check and the request that uses it.
    TOKEN_EXPIRY_MARGIN_SECONDS = 60
    # Streams requested by get_activity_streams when no keys are given.
    DEFAULT_STREAM_KEYS = (
        "time",
        "distance",
        "latlng",
        "altitude",
        "velocity_smooth",
        "heartrate",
        "cadence",
        "watts",
        "temp",
        "moving",
        "grade_smooth",
    )
    # Strava activity field name -> normalized field name (with a unit
    # suffix where one applies). Built once rather than per activity.
    _FIELD_MAPPING = {
        "id": "id",
        "name": "name",
        "description": "description",
        "sport_type": "sport_type",
        "start_date": "start_date",
        "distance": "distance_metres",
        "moving_time": "moving_time_seconds",
        "elapsed_time": "elapsed_time_seconds",
        "total_elevation_gain": "total_elevation_gain_metres",
        "elev_high": "elev_high_metres",
        "elev_low": "elev_low_metres",
        "average_speed": "average_speed_mps",
        "max_speed": "max_speed_mps",
        "average_cadence": "average_cadence_rpm",
        "average_heartrate": "average_heartrate_bpm",
        "max_heartrate": "max_heartrate_bpm",
        "average_watts": "average_watts",
        "weighted_average_watts": "weighted_average_watts",
        "kilojoules": "kilojoules",
        "calories": "calories",
        "workout_type": "workout_type",
        "perceived_exertion": "perceived_exertion",
        "suffer_score": "suffer_score",
        "start_latlng": "start_latlng",
        "end_latlng": "end_latlng",
    }

    def __init__(self, refresh_token: str, client_id: str, client_secret: str):
        """
        Initialize the Strava API client.

        No network request is made here; the first API call triggers a token
        refresh.

        Args:
            refresh_token: Refresh token for Strava API
            client_id: Client ID for Strava API
            client_secret: Client secret for Strava API
        """
        self.refresh_token = refresh_token
        self.client_id = client_id
        self.client_secret = client_secret
        self.access_token: Optional[str] = None
        # Expiry of access_token as returned by the token endpoint (compared
        # against time.time()); 0 forces a refresh on first use.
        self.expires_at = 0
        self.client = httpx.Client(timeout=30.0)

    def __enter__(self) -> "StravaClient":
        """Return self so the client can be used in a ``with`` statement."""
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """Close the underlying HTTP client when leaving a ``with`` block."""
        self.close()

    def _ensure_valid_token(self) -> None:
        """Ensure we have a valid access token, refreshing if necessary.

        A refresh happens when there is no token yet, or when the current one
        expires within TOKEN_EXPIRY_MARGIN_SECONDS.
        """
        now = int(time.time())
        refresh_deadline = self.expires_at - self.TOKEN_EXPIRY_MARGIN_SECONDS
        if not self.access_token or now >= refresh_deadline:
            self._refresh_token()

    def _refresh_token(self) -> None:
        """Exchange the refresh token for a new access token.

        Updates ``access_token`` and ``expires_at``. If the response carries a
        ``refresh_token``, it replaces the stored one as well.

        Raises:
            Exception: If the token endpoint does not answer with HTTP 200.
        """
        payload = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "refresh_token": self.refresh_token,
            "grant_type": "refresh_token",
        }
        response = self.client.post(self.TOKEN_URL, data=payload)
        if response.status_code != 200:
            raise Exception(f"Error {response.status_code}: {response.text}")
        token_data = response.json()
        self.access_token = token_data["access_token"]
        self.expires_at = token_data["expires_at"]
        # Strava may rotate the refresh token. Once a new one is issued the
        # old one stops working, so keep whichever the server returned.
        self.refresh_token = token_data.get("refresh_token", self.refresh_token)

    def _make_request(self, endpoint: str, params: Optional[dict] = None) -> Any:
        """Make an authenticated GET request to the Strava API.

        Args:
            endpoint: Path relative to BASE_URL (no leading slash).
            params: Optional query parameters.

        Returns:
            The decoded JSON response body.

        Raises:
            Exception: If the response status is not HTTP 200.
        """
        self._ensure_valid_token()
        url = f"{self.BASE_URL}/{endpoint}"
        headers = {"Authorization": f"Bearer {self.access_token}"}
        response = self.client.get(url, headers=headers, params=params)
        if response.status_code != 200:
            raise Exception(f"Error {response.status_code}: {response.text}")
        return response.json()

    def get_activities(
        self,
        limit: int = 10,
        before: Optional[int] = None,
        after: Optional[int] = None,
        sport_type: Optional[str] = None,
    ) -> List[dict]:
        """
        Get the authenticated athlete's activities.

        Up to ``limit`` activities are fetched, paging through the API when
        ``limit`` exceeds MAX_PAGE_SIZE. The ``sport_type`` filter is then
        applied to those fetched activities, so fewer than ``limit`` may be
        returned.

        Args:
            limit: Maximum number of activities to return (<= 0 returns [])
            before: Unix timestamp to filter activities before this time
            after: Unix timestamp to filter activities after this time
            sport_type: Filter by sport type (e.g., 'Run', 'Ride', 'Swim')

        Returns:
            List of activities with standardized field names
        """
        if limit <= 0:
            return []

        # Page size must stay constant across pages: page N covers items
        # (N-1)*per_page .. N*per_page, so varying it would skip or repeat.
        per_page = min(limit, self.MAX_PAGE_SIZE)
        base_params: dict = {"per_page": per_page}
        # Explicit None checks: 0 is a valid (epoch) timestamp.
        if before is not None:
            base_params["before"] = before
        if after is not None:
            base_params["after"] = after

        raw: List[dict] = []
        page = 1
        while len(raw) < limit:
            batch = self._make_request(
                "athlete/activities", {**base_params, "page": page}
            )
            raw.extend(batch)
            # A short (or empty) page means there is nothing more to fetch.
            if len(batch) < per_page:
                break
            page += 1

        activities = [self._normalize_activity(a) for a in raw[:limit]]
        if sport_type:
            activities = [a for a in activities if a.get("sport_type") == sport_type]
        return activities

    def get_activity(self, activity_id: int) -> dict:
        """
        Get detailed information about a specific activity.

        Args:
            activity_id: ID of the activity to retrieve

        Returns:
            Activity details with standardized field names
        """
        activity = self._make_request(f"activities/{activity_id}")
        return self._normalize_activity(activity)

    def get_activity_streams(
        self, activity_id: int, keys: Optional[List[str]] = None
    ) -> dict:
        """
        Get detailed time-series data (streams) for a specific activity.

        Args:
            activity_id: ID of the activity to retrieve streams for
            keys: List of stream types to retrieve (defaults to
                DEFAULT_STREAM_KEYS)

        Returns:
            Dictionary containing streams data, keyed by stream type
        """
        if keys is None:
            keys = list(self.DEFAULT_STREAM_KEYS)
        params = {"keys": ",".join(keys), "key_by_type": "true"}
        return self._make_request(f"activities/{activity_id}/streams", params)

    def get_activity_laps(self, activity_id: int) -> List[dict]:
        """
        Get laps for a specific activity.

        Args:
            activity_id: ID of the activity to retrieve laps for

        Returns:
            List of lap data as returned by the API
        """
        return self._make_request(f"activities/{activity_id}/laps")

    def get_activity_zones(self, activity_id: int) -> dict:
        """
        Get heart rate and power zone data for a specific activity.

        Args:
            activity_id: ID of the activity to retrieve zones for

        Returns:
            Zone data as returned by the API
        """
        return self._make_request(f"activities/{activity_id}/zones")

    def get_athlete_stats(self, athlete_id: Optional[int] = None) -> dict:
        """
        Get statistics for an athlete.

        Args:
            athlete_id: ID of the athlete (defaults to the authenticated
                athlete, which costs one extra request to look up its ID)

        Returns:
            Dictionary containing athlete statistics
        """
        if athlete_id is None:
            athlete_id = self.get_athlete()["id"]
        return self._make_request(f"athletes/{athlete_id}/stats")

    def get_athlete(self) -> dict:
        """
        Get the authenticated athlete's profile.

        Returns:
            Dictionary containing athlete profile data
        """
        return self._make_request("athlete")

    def _normalize_activity(self, activity: dict) -> dict:
        """Normalize activity data with consistent field names.

        Keeps only the fields listed in _FIELD_MAPPING, renamed to their
        normalized names. Values are copied unchanged, and fields missing
        from ``activity`` are omitted rather than set to None.
        """
        return {
            new_key: activity[old_key]
            for old_key, new_key in self._FIELD_MAPPING.items()
            if old_key in activity
        }

    def close(self) -> None:
        """Close the HTTP client."""
        self.client.close()
def parse_date(date_str: str) -> date:
    """
    Parse a date string in ISO format (YYYY-MM-DD).

    Parsing is delegated to ``date.fromisoformat``. Newer Python versions
    accept more ISO 8601 forms than plain YYYY-MM-DD.

    Args:
        date_str: Date string in ISO format

    Returns:
        Date object

    Raises:
        ValueError: If date format is invalid
    """
    try:
        return date.fromisoformat(date_str)
    except ValueError as err:
        raise ValueError(
            f"Invalid date format: {date_str}. Expected format: YYYY-MM-DD"
        ) from err