# Strava MCP Server
#
# by yorrickjansen
# Verified
"""Tests for the Strava OAuth server module."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import FastAPI

from strava_mcp.oauth_server import StravaOAuthServer, get_refresh_token_from_oauth

# Authorization URL returned by the mocked authenticator in the get_token tests.
AUTH_URL = "https://example.com/auth"


@pytest.fixture
def client_credentials():
    """Fixture for client credentials."""
    return {
        "client_id": "test_client_id",
        "client_secret": "test_client_secret",
    }


@pytest.fixture
def oauth_server(client_credentials):
    """Fixture for StravaOAuthServer."""
    return StravaOAuthServer(
        client_id=client_credentials["client_id"],
        client_secret=client_credentials["client_secret"],
    )


def _create_task_stub(task):
    """Build a fake ``asyncio.create_task`` that returns *task* without scheduling anything.

    If the code under test hands ``create_task`` a coroutine object (presumably the one
    from ``_run_server()`` — confirm against the implementation), a plain mock would leave
    it un-awaited and Python would emit a "coroutine ... was never awaited" RuntimeWarning
    when it is garbage-collected. Closing it here keeps the run warning-free, including
    under ``-W error``.
    """

    def fake_create_task(coro, *args, **kwargs):
        if asyncio.iscoroutine(coro):
            coro.close()
        return task

    return fake_create_task


def _stub_get_token_dependencies(oauth_server):
    """Replace the server lifecycle methods and authenticator with mocks for get_token() tests."""
    oauth_server._initialize_server = AsyncMock()
    oauth_server._stop_server = AsyncMock()
    oauth_server.authenticator = MagicMock()
    oauth_server.authenticator.get_authorization_url = MagicMock(return_value=AUTH_URL)


def _assert_full_flow_ran(oauth_server, mock_open):
    """Assert get_token() initialized the server, opened the auth URL, and stopped the server once."""
    oauth_server._initialize_server.assert_called_once()
    oauth_server.authenticator.get_authorization_url.assert_called_once()
    mock_open.assert_called_once_with(AUTH_URL)
    oauth_server._stop_server.assert_called_once()


@pytest.mark.asyncio
async def test_initialize_server(oauth_server):
    """Test initializing the server."""
    mock_authenticator = MagicMock()
    mock_task = MagicMock()

    # Mock the OAuth server's dependencies directly
    with patch("strava_mcp.oauth_server.StravaAuthenticator") as mock_authenticator_class:
        with patch("asyncio.create_task") as mock_create_task:
            mock_authenticator_class.return_value = mock_authenticator
            mock_create_task.side_effect = _create_task_stub(mock_task)

            await oauth_server._initialize_server()

            # Verify FastAPI app was created
            assert oauth_server.app is not None
            assert oauth_server.app.title == "Strava OAuth"

            # Verify authenticator was created and configured
            mock_authenticator_class.assert_called_once_with(
                client_id=oauth_server.client_id,
                client_secret=oauth_server.client_secret,
                app=oauth_server.app,
                host=oauth_server.host,
                port=oauth_server.port,
            )
            assert oauth_server.authenticator is mock_authenticator

            # Verify token future was stored in authenticator
            assert mock_authenticator.token_future is oauth_server.token_future

            # Verify routes were set up
            mock_authenticator.setup_routes.assert_called_once_with(oauth_server.app)

            # Verify server task was created
            mock_create_task.assert_called_once()
            assert oauth_server.server_task is mock_task


@pytest.mark.asyncio
async def test_run_server(oauth_server):
    """Test running the server."""
    with patch("uvicorn.Server") as mock_server_class, patch("uvicorn.Config") as mock_config_class:
        mock_server = AsyncMock()
        mock_server_class.return_value = mock_server
        mock_config = MagicMock()
        mock_config_class.return_value = mock_config

        oauth_server.app = FastAPI()

        await oauth_server._run_server()

        # Verify config was created correctly
        mock_config_class.assert_called_once_with(
            app=oauth_server.app,
            host=oauth_server.host,
            port=oauth_server.port,
            log_level="info",
        )

        # Verify server was created and run
        mock_server_class.assert_called_once_with(mock_config)
        mock_server.serve.assert_called_once()
        assert oauth_server.server is mock_server


@pytest.mark.asyncio
async def test_run_server_exception(oauth_server):
    """Test running the server with an exception."""
    with patch("uvicorn.Server") as mock_server_class, patch("uvicorn.Config") as mock_config_class:
        mock_server = AsyncMock()
        mock_server.serve = AsyncMock(side_effect=Exception("Test error"))
        mock_server_class.return_value = mock_server
        mock_config_class.return_value = MagicMock()

        oauth_server.app = FastAPI()
        oauth_server.token_future = asyncio.get_running_loop().create_future()

        await oauth_server._run_server()

        # The serve() failure must be propagated into the token future
        assert oauth_server.token_future.done()
        with pytest.raises(Exception, match="Test error"):
            await oauth_server.token_future


@pytest.mark.asyncio
async def test_stop_server(oauth_server):
    """Test stopping the server."""
    oauth_server.server = MagicMock()
    oauth_server.server_task = MagicMock()
    oauth_server.server_task.done = MagicMock(return_value=False)

    # Make asyncio.wait_for return immediately
    with patch("asyncio.wait_for", new=AsyncMock()) as mock_wait_for:
        await oauth_server._stop_server()

        assert oauth_server.server.should_exit is True
        mock_wait_for.assert_called_once_with(oauth_server.server_task, timeout=5.0)


@pytest.mark.asyncio
async def test_stop_server_timeout(oauth_server):
    """Test stopping the server with timeout."""
    oauth_server.server = MagicMock()
    oauth_server.server_task = MagicMock()
    oauth_server.server_task.done = MagicMock(return_value=False)

    # Make asyncio.wait_for raise TimeoutError; _stop_server must not propagate it
    with patch("asyncio.wait_for", new=AsyncMock(side_effect=TimeoutError())) as mock_wait_for:
        await oauth_server._stop_server()

        assert oauth_server.server.should_exit is True
        mock_wait_for.assert_called_once_with(oauth_server.server_task, timeout=5.0)


@pytest.mark.asyncio
async def test_get_token(oauth_server):
    """Test getting a token."""
    _stub_get_token_dependencies(oauth_server)

    with patch("webbrowser.open") as mock_open:
        # Prepare an already-resolved token future
        oauth_server.token_future = asyncio.get_running_loop().create_future()
        oauth_server.token_future.set_result("test_refresh_token")

        token = await oauth_server.get_token()

        assert token == "test_refresh_token"
        _assert_full_flow_ran(oauth_server, mock_open)


@pytest.mark.asyncio
async def test_get_token_no_authenticator(oauth_server):
    """Test getting a token with no authenticator."""
    oauth_server._initialize_server = AsyncMock()
    oauth_server._stop_server = AsyncMock()
    oauth_server.authenticator = None

    with pytest.raises(Exception, match="Authenticator not initialized"):
        await oauth_server.get_token()

    oauth_server._initialize_server.assert_called_once()
    # get_token() raises before it reaches the stop-server step, so _stop_server must not run.
    oauth_server._stop_server.assert_not_called()


@pytest.mark.asyncio
async def test_get_token_cancelled(oauth_server):
    """Test getting a token that is cancelled."""
    _stub_get_token_dependencies(oauth_server)

    with patch("webbrowser.open") as mock_open:
        # Prepare a cancelled token future
        oauth_server.token_future = asyncio.get_running_loop().create_future()
        oauth_server.token_future.cancel()

        with pytest.raises(Exception, match="OAuth flow was cancelled"):
            await oauth_server.get_token()

        _assert_full_flow_ran(oauth_server, mock_open)


@pytest.mark.asyncio
async def test_get_token_exception(oauth_server):
    """Test getting a token with exception."""
    _stub_get_token_dependencies(oauth_server)

    with patch("webbrowser.open") as mock_open:
        # Prepare a token future that failed
        oauth_server.token_future = asyncio.get_running_loop().create_future()
        oauth_server.token_future.set_exception(Exception("Test error"))

        with pytest.raises(Exception, match="OAuth flow failed: Test error"):
            await oauth_server.get_token()

        _assert_full_flow_ran(oauth_server, mock_open)


@pytest.mark.asyncio
async def test_get_refresh_token_from_oauth(client_credentials):
    """Test get_refresh_token_from_oauth function."""
    with patch("strava_mcp.oauth_server.StravaOAuthServer") as mock_oauth_server_class:
        mock_server = MagicMock()
        mock_server.get_token = AsyncMock(return_value="test_refresh_token")
        mock_oauth_server_class.return_value = mock_server

        token = await get_refresh_token_from_oauth(client_credentials["client_id"], client_credentials["client_secret"])

        assert token == "test_refresh_token"
        mock_oauth_server_class.assert_called_once_with(
            client_credentials["client_id"], client_credentials["client_secret"]
        )
        mock_server.get_token.assert_called_once()