Skip to main content
Glama
test_usage_tracking.py (13.9 kB)
"""Unit tests for src/api/middleware/usage_tracking.py"""

from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import Request, Response

from src.api.middleware.usage_tracking import UsageTrackingMiddleware

# Dotted path of the Supabase client factory as looked up by the middleware
# module; every `_track_usage` test replaces it with a mock.
_SUPABASE_FACTORY = "src.api.middleware.usage_tracking.get_supabase_client"


# === Shared helpers ===


def _current_month():
    """Return the current month formatted as ``YYYY-MM``."""
    return datetime.now().strftime("%Y-%m")


def _supabase_stub(*, data=None, error=None):
    """Build a mock Supabase client whose ``rpc()`` returns the client itself.

    ``execute()`` returns an object exposing ``data`` unless ``error`` is
    given, in which case ``execute()`` raises ``error`` instead.
    """
    stub = MagicMock()
    stub.rpc.return_value = stub
    if error is None:
        stub.execute.return_value = MagicMock(data=data)
    else:
        stub.execute.side_effect = error
    return stub


def _request(path, **state):
    """Build a ``Request``-spec'd mock for ``path``.

    Each keyword argument is assigned as an attribute on ``request.state``;
    attributes that are not passed are left as auto-created mocks.
    """
    fake = MagicMock(spec=Request)
    fake.url.path = path
    for attr, value in state.items():
        setattr(fake.state, attr, value)
    return fake


def _middleware():
    """Instantiate the middleware around a mock ASGI app."""
    return UsageTrackingMiddleware(app=MagicMock())


def _call_next(response):
    """Return an async ``call_next`` stand-in that yields ``response``."""
    return AsyncMock(return_value=response)


def _track_usage_patch():
    """Return a patcher swapping ``_track_usage`` for an ``AsyncMock``."""
    return patch.object(
        UsageTrackingMiddleware, "_track_usage", new_callable=AsyncMock
    )


# === Test _extract_tool_name ===


def test_extract_tool_name_valid_path():
    """A resource path with an id yields the resource segment."""
    extracted = UsageTrackingMiddleware._extract_tool_name("/api/v1/properties/123")

    assert extracted == "properties"


def test_extract_tool_name_various_resources():
    """Different resource types each map to their own tool name."""
    expectations = {
        "/api/v1/reservations": "reservations",
        "/api/v1/listings/456": "listings",
        "/api/v1/webhooks/test": "webhooks",
    }

    for path, tool in expectations.items():
        assert UsageTrackingMiddleware._extract_tool_name(path) == tool


def test_extract_tool_name_with_trailing_slash():
    """A trailing slash does not change the extracted tool name."""
    extracted = UsageTrackingMiddleware._extract_tool_name("/api/v1/properties/")

    assert extracted == "properties"


def test_extract_tool_name_malformed_path():
    """Paths without a resource segment resolve to 'unknown'."""
    for path in ("/api", "/api/v1", "/", ""):
        assert UsageTrackingMiddleware._extract_tool_name(path) == "unknown"


# === Test _track_usage ===


@pytest.mark.asyncio
@patch(_SUPABASE_FACTORY)
async def test_track_usage_success(mock_get_client, capsys):
    """A successful RPC is issued once and reported on stdout."""
    client = _supabase_stub(data={"success": True})
    mock_get_client.return_value = client
    org = "org-123"
    tool = "properties"
    # NOTE(review): captured before the call; presumably the middleware derives
    # the month from the current time, so a month rollover in between would
    # make this comparison fail - confirm if that matters.
    month = _current_month()

    await UsageTrackingMiddleware._track_usage(
        organization_id=org,
        tool_name=tool,
    )

    client.rpc.assert_called_once_with(
        "increment_usage_metrics",
        {
            "org_id": org,
            "month": month,
            "tool": tool,
        },
    )
    client.execute.assert_called_once()

    # The debug output is read from stdout via capsys.
    printed = capsys.readouterr().out
    for fragment in (
        "Tracking usage - org:",
        org,
        month,
        tool,
        "Usage tracked successfully",
    ):
        assert fragment in printed


@pytest.mark.asyncio
@patch(_SUPABASE_FACTORY)
async def test_track_usage_no_data_returned(mock_get_client, capsys):
    """An RPC result without data produces a warning on stdout."""
    mock_get_client.return_value = _supabase_stub(data=None)
    org = "org-456"
    tool = "listings"

    await UsageTrackingMiddleware._track_usage(
        organization_id=org,
        tool_name=tool,
    )

    printed = capsys.readouterr().out
    # Debug line first, then the warning about the empty result.
    for fragment in (
        "Tracking usage - org:",
        org,
        tool,
        "Warning: Usage tracking returned no data",
    ):
        assert fragment in printed


@pytest.mark.asyncio
@patch(_SUPABASE_FACTORY)
async def test_track_usage_rpc_failure(mock_get_client, capsys):
    """An exception from ``execute()`` is reported, not propagated."""
    mock_get_client.return_value = _supabase_stub(
        error=Exception("Database connection failed")
    )
    org = "org-789"
    tool = "webhooks"

    # Must complete without raising.
    await UsageTrackingMiddleware._track_usage(
        organization_id=org,
        tool_name=tool,
    )

    printed = capsys.readouterr().out
    assert "Failed to track usage" in printed
    assert org in printed


@pytest.mark.asyncio
@patch(_SUPABASE_FACTORY)
async def test_track_usage_multiple_tools(mock_get_client):
    """Three tools tracked for one organization produce three RPC calls."""
    client = _supabase_stub(data={"success": True})
    mock_get_client.return_value = client
    org = "org-multi"
    month = _current_month()
    tools = ("properties", "reservations", "listings")

    for tool in tools:
        await UsageTrackingMiddleware._track_usage(org, tool)

    assert client.rpc.call_count == 3
    for recorded, tool in zip(client.rpc.call_args_list, tools):
        # Second positional argument of rpc() is the parameter dict.
        payload = recorded.args[1]
        assert payload["tool"] == tool
        assert payload["month"] == month
        assert payload["org_id"] == org


# === Test dispatch method ===


@pytest.mark.asyncio
async def test_dispatch_tracks_api_routes():
    """Requests under /api/* are tracked with their org and tool name."""
    mw = _middleware()
    request = _request("/api/v1/properties/123", organization_id="org-123")
    expected = Response(content=b"test", status_code=200)
    call_next = _call_next(expected)

    with _track_usage_patch() as track:
        result = await mw.dispatch(request, call_next)

    assert result == expected
    call_next.assert_called_once_with(request)
    track.assert_called_once_with(
        organization_id="org-123",
        tool_name="properties",
    )


@pytest.mark.asyncio
async def test_dispatch_skips_non_api_routes():
    """Requests outside /api/* pass through without tracking."""
    mw = _middleware()
    request = _request("/health")
    expected = Response(content=b"ok", status_code=200)
    call_next = _call_next(expected)

    with _track_usage_patch() as track:
        result = await mw.dispatch(request, call_next)

    assert result == expected
    call_next.assert_called_once_with(request)
    track.assert_not_called()  # Should not track non-API routes


@pytest.mark.asyncio
async def test_dispatch_skips_tracking_without_organization_id():
    """A request whose organization_id is None is not tracked."""
    mw = _middleware()
    request = _request("/api/v1/properties", organization_id=None)
    expected = Response(content=b"test", status_code=200)
    call_next = _call_next(expected)

    with _track_usage_patch() as track:
        result = await mw.dispatch(request, call_next)

    assert result == expected
    track.assert_not_called()  # Should not track without org_id


@pytest.mark.asyncio
async def test_dispatch_continues_on_tracking_failure(capsys):
    """A tracking exception is reported on stdout and the response still returns."""
    mw = _middleware()
    request = _request("/api/v1/listings", organization_id="org-fail")
    expected = Response(content=b"success", status_code=200)
    call_next = _call_next(expected)

    with _track_usage_patch() as track:
        track.side_effect = Exception("Tracking error")
        result = await mw.dispatch(request, call_next)

    assert result == expected  # Request succeeds despite tracking failure
    printed = capsys.readouterr().out
    assert "Usage tracking error" in printed


@pytest.mark.asyncio
async def test_dispatch_multiple_requests_same_org():
    """Two requests from one organization are tracked as two calls."""
    mw = _middleware()
    org = "org-repeat"
    call_next = _call_next(Response(content=b"test", status_code=200))
    visits = (
        ("/api/v1/properties", "properties"),
        ("/api/v1/reservations", "reservations"),
    )

    with _track_usage_patch() as track:
        for path, _tool in visits:
            await mw.dispatch(_request(path, organization_id=org), call_next)

    assert track.call_count == 2
    for recorded, (_path, tool) in zip(track.call_args_list, visits):
        assert recorded.kwargs["organization_id"] == org
        assert recorded.kwargs["tool_name"] == tool


@pytest.mark.asyncio
async def test_dispatch_different_organizations():
    """Requests from different organizations are tracked separately."""
    mw = _middleware()
    call_next = _call_next(Response(content=b"test", status_code=200))
    orgs = ("org-1", "org-2")

    with _track_usage_patch() as track:
        for org in orgs:
            await mw.dispatch(
                _request("/api/v1/properties", organization_id=org), call_next
            )

    assert track.call_count == 2
    for recorded, org in zip(track.call_args_list, orgs):
        assert recorded.kwargs["organization_id"] == org


@pytest.mark.asyncio
async def test_dispatch_with_error_response():
    """Usage is tracked even when the downstream response is a 404."""
    mw = _middleware()
    request = _request("/api/v1/properties/999", organization_id="org-error")
    not_found = Response(content=b"Not Found", status_code=404)
    call_next = _call_next(not_found)

    with _track_usage_patch() as track:
        result = await mw.dispatch(request, call_next)

    assert result.status_code == 404
    track.assert_called_once()  # Should still track failed requests


@pytest.mark.asyncio
async def test_dispatch_request_state_without_org_attribute():
    """A request.state lacking organization_id entirely is not tracked."""
    mw = _middleware()
    request = _request("/api/v1/properties")
    # Deleting an attribute from a MagicMock makes later access raise
    # AttributeError, so the state behaves as if the attribute was never set.
    request.state = MagicMock()
    del request.state.organization_id
    expected = Response(content=b"test", status_code=200)
    call_next = _call_next(expected)

    with _track_usage_patch() as track:
        result = await mw.dispatch(request, call_next)

    assert result == expected
    track.assert_not_called()  # Should not track without org_id

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/darrentmorgan/hostaway-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.