Strava MCP Server
by yorrickjansen
Verified
- strava-mcp
- tests
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from strava_mcp.models import Activity, DetailedActivity, Segment, SegmentEffort
# Patch the StravaOAuthServer._run_server method to prevent coroutine warnings
# This must be at the module level before any imports that might create the coroutine
with patch("strava_mcp.oauth_server.StravaOAuthServer._run_server", new_callable=AsyncMock):
# Now imports will use the patched version
pass
class MockContext:
"""Mock MCP context for testing."""
def __init__(self, service):
self.request_context = MagicMock()
self.request_context.lifespan_context = {"service": service}
@pytest.fixture
def mock_service():
mock = AsyncMock()
mock.get_activities = AsyncMock()
mock.get_activity = AsyncMock()
mock.get_activity_segments = AsyncMock()
return mock
@pytest.fixture
def mock_ctx(mock_service):
return MockContext(mock_service)
@pytest.mark.asyncio
async def test_get_user_activities(mock_ctx, mock_service):
# Setup mock data
mock_activity = Activity(
id=1234567890,
name="Morning Run",
distance=5000,
moving_time=1200,
elapsed_time=1300,
total_elevation_gain=50,
type="Run",
sport_type="Run",
start_date=datetime.fromisoformat("2023-01-01T10:00:00+00:00"),
start_date_local=datetime.fromisoformat("2023-01-01T10:00:00+00:00"),
timezone="Europe/London",
achievement_count=2,
kudos_count=5,
comment_count=0,
athlete_count=1,
photo_count=0,
trainer=False,
commute=False,
manual=False,
private=False,
flagged=False,
average_speed=4.167,
max_speed=5.3,
has_heartrate=True,
average_heartrate=140,
max_heartrate=160,
# Add required fields with default values
map=None,
workout_type=None,
elev_high=None,
elev_low=None,
)
mock_service.get_activities.return_value = [mock_activity]
# Test tool
from strava_mcp.server import get_user_activities
result = await get_user_activities(mock_ctx)
# Verify service call
mock_service.get_activities.assert_called_once_with(None, None, 1, 30)
# Verify result
assert len(result) == 1
assert result[0]["id"] == mock_activity.id
assert result[0]["name"] == mock_activity.name
@pytest.mark.asyncio
async def test_get_activity(mock_ctx, mock_service):
# Setup mock data
mock_activity = DetailedActivity(
id=1234567890,
name="Morning Run",
distance=5000,
moving_time=1200,
elapsed_time=1300,
total_elevation_gain=50,
type="Run",
sport_type="Run",
start_date=datetime.fromisoformat("2023-01-01T10:00:00+00:00"),
start_date_local=datetime.fromisoformat("2023-01-01T10:00:00+00:00"),
timezone="Europe/London",
achievement_count=2,
kudos_count=5,
comment_count=0,
athlete_count=1,
photo_count=0,
trainer=False,
commute=False,
manual=False,
private=False,
flagged=False,
average_speed=4.167,
max_speed=5.3,
has_heartrate=True,
average_heartrate=140,
max_heartrate=160,
athlete={"id": 123},
description="Test description",
# Add required fields with default values
map=None,
workout_type=None,
elev_high=None,
elev_low=None,
calories=None,
segment_efforts=None,
splits_metric=None,
splits_standard=None,
best_efforts=None,
photos=None,
gear=None,
device_name=None,
)
mock_service.get_activity.return_value = mock_activity
# Test tool
from strava_mcp.server import get_activity
result = await get_activity(mock_ctx, 1234567890)
# Verify service call
mock_service.get_activity.assert_called_once_with(1234567890, False)
# Verify result
assert result["id"] == mock_activity.id
assert result["name"] == mock_activity.name
assert result["description"] == mock_activity.description
@pytest.mark.asyncio
async def test_get_activity_segments(mock_ctx, mock_service):
# Setup mock data
mock_segment = SegmentEffort(
id=67890,
activity_id=1234567890,
segment_id=12345,
name="Test Segment",
elapsed_time=180,
moving_time=180,
start_date=datetime.fromisoformat("2023-01-01T10:05:00+00:00"),
start_date_local=datetime.fromisoformat("2023-01-01T10:05:00+00:00"),
distance=1000,
athlete={"id": 123},
segment=Segment(
id=12345,
name="Test Segment",
activity_type="Run",
distance=1000,
average_grade=5.0,
maximum_grade=10.0,
elevation_high=200,
elevation_low=150,
total_elevation_gain=50,
start_latlng=[51.5, -0.1],
end_latlng=[51.5, -0.2],
climb_category=0,
private=False,
starred=False,
city=None,
state=None,
country=None,
),
# Add required fields with default values
average_watts=None,
device_watts=None,
average_heartrate=None,
max_heartrate=None,
pr_rank=None,
achievements=None,
)
mock_service.get_activity_segments.return_value = [mock_segment]
# Test tool
from strava_mcp.server import get_activity_segments
result = await get_activity_segments(mock_ctx, 1234567890)
# Verify service call
mock_service.get_activity_segments.assert_called_once_with(1234567890)
# Verify result
assert len(result) == 1
assert result[0]["id"] == mock_segment.id
assert result[0]["name"] == mock_segment.name