Skip to main content
Glama
test_token_aware_middleware.py (15.6 kB)
"""Unit tests for TokenAwareMiddleware. Tests response optimization, token budget enforcement, and summarization triggering. """ from unittest.mock import MagicMock, patch import pytest from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route from starlette.testclient import TestClient from src.api.middleware.token_aware_middleware import TokenAwareMiddleware class TestTokenAwareMiddleware: """Test suite for TokenAwareMiddleware.""" @pytest.fixture def mock_config_service(self): """Create mock config service.""" mock_service = MagicMock() mock_service.get_endpoint_config.return_value = ( 4000, # threshold 12000, # hard_cap 50, # page_size True, # summarization_enabled True, # pagination_enabled ) return mock_service @pytest.fixture def mock_summarization_service(self): """Create mock summarization service.""" mock_service = MagicMock() mock_response = MagicMock() mock_response.model_dump.return_value = { "summary": "Summarized content", "meta": {"summarized": True}, } mock_service.summarize_object.return_value = mock_response return mock_service @pytest.fixture def app(self): """Create test application with middleware.""" async def endpoint_small(request: Request): return JSONResponse({"message": "small response"}) async def endpoint_large(request: Request): # Large response that exceeds token budget return JSONResponse( { "data": "x" * 50000, # Large enough to exceed 4000 token threshold "items": list(range(1000)), } ) async def endpoint_non_json(request: Request): return Response(content="plain text", media_type="text/plain") async def endpoint_error(request: Request): return JSONResponse({"error": "something went wrong"}, status_code=500) app = Starlette( routes=[ Route("/api/small", endpoint_small), Route("/api/large", endpoint_large), Route("/api/text", endpoint_non_json), Route("/api/error", endpoint_error), ] ) # Add middleware without 
patches - patches will be applied in each test app.add_middleware(TokenAwareMiddleware) return app def test_small_response_passes_through( self, app, mock_config_service, mock_summarization_service ): """Test small response passes through without summarization.""" with patch( "src.api.middleware.token_aware_middleware.get_config_service", return_value=mock_config_service, ): with patch( "src.api.middleware.token_aware_middleware.get_summarization_service", return_value=mock_summarization_service, ): client = TestClient(app) response = client.get("/api/small") assert response.status_code == 200 assert response.json() == {"message": "small response"} # Summarization should not be called mock_summarization_service.summarize_object.assert_not_called() @patch("src.api.middleware.token_aware_middleware.estimate_tokens") def test_large_response_triggers_summarization( self, mock_estimate, app, mock_config_service, mock_summarization_service ): """Test large response triggers summarization.""" # Mock token estimation to exceed threshold mock_estimate.return_value = 5000 # Exceeds 4000 threshold with patch( "src.api.middleware.token_aware_middleware.get_config_service", return_value=mock_config_service, ): with patch( "src.api.middleware.token_aware_middleware.get_summarization_service", return_value=mock_summarization_service, ): client = TestClient(app) response = client.get("/api/large") assert response.status_code == 200 data = response.json() # Check if response was summarized (has summary field) assert "summary" in data or "data" in data # Could be original or summarized # Summarization should be called mock_summarization_service.summarize_object.assert_called_once() def test_non_json_response_passes_through(self, app): """Test non-JSON response is not processed.""" client = TestClient(app) response = client.get("/api/text") assert response.status_code == 200 assert response.text == "plain text" def test_error_response_passes_through( self, app, mock_config_service, 
mock_summarization_service ): """Test error responses (4xx, 5xx) are not processed.""" with patch( "src.api.middleware.token_aware_middleware.get_config_service", return_value=mock_config_service, ): with patch( "src.api.middleware.token_aware_middleware.get_summarization_service", return_value=mock_summarization_service, ): client = TestClient(app) response = client.get("/api/error") assert response.status_code == 500 assert response.json() == {"error": "something went wrong"} # Summarization should not be called for errors mock_summarization_service.summarize_object.assert_not_called() @patch("src.api.middleware.token_aware_middleware.estimate_tokens") def test_paginated_response_skipped( self, mock_estimate, app, mock_config_service, mock_summarization_service ): """Test paginated response is skipped (not summarized).""" mock_estimate.return_value = 5000 # Would normally trigger summarization # Create endpoint that returns paginated structure async def endpoint_paginated(request: Request): return JSONResponse( {"items": [{"id": 1}, {"id": 2}], "meta": {"totalCount": 100, "hasMore": True}} ) app.routes.append(Route("/api/paginated", endpoint_paginated)) with patch( "src.api.middleware.token_aware_middleware.get_config_service", return_value=mock_config_service, ): with patch( "src.api.middleware.token_aware_middleware.get_summarization_service", return_value=mock_summarization_service, ): client = TestClient(app) response = client.get("/api/paginated") assert response.status_code == 200 data = response.json() # Should not be summarized because it's already paginated assert "items" in data assert "meta" in data mock_summarization_service.summarize_object.assert_not_called() @patch("src.api.middleware.token_aware_middleware.estimate_tokens") def test_already_summarized_response_skipped( self, mock_estimate, app, mock_config_service, mock_summarization_service ): """Test already summarized response is skipped.""" mock_estimate.return_value = 5000 # Create endpoint 
that returns summary structure async def endpoint_summary(request: Request): return JSONResponse({"summary": "Already summarized", "meta": {"summarized": True}}) app.routes.append(Route("/api/summary", endpoint_summary)) with patch( "src.api.middleware.token_aware_middleware.get_config_service", return_value=mock_config_service, ): with patch( "src.api.middleware.token_aware_middleware.get_summarization_service", return_value=mock_summarization_service, ): client = TestClient(app) response = client.get("/api/summary") assert response.status_code == 200 data = response.json() # Should not be re-summarized assert "summary" in data mock_summarization_service.summarize_object.assert_not_called() def test_summarization_disabled_for_endpoint( self, mock_config_service, mock_summarization_service ): """Test summarization skipped when disabled for endpoint.""" # Configure summarization as disabled mock_config_service.get_endpoint_config.return_value = ( 4000, 12000, 50, False, True, # summarization_enabled=False ) async def endpoint_large(request: Request): return JSONResponse({"data": "x" * 50000}) app = Starlette(routes=[Route("/api/test", endpoint_large)]) with patch( "src.api.middleware.token_aware_middleware.get_config_service", return_value=mock_config_service, ): with patch( "src.api.middleware.token_aware_middleware.get_summarization_service", return_value=mock_summarization_service, ): with patch( "src.api.middleware.token_aware_middleware.estimate_tokens", return_value=5000 ): app.add_middleware(TokenAwareMiddleware) client = TestClient(app) response = client.get("/api/test") assert response.status_code == 200 # Summarization should not be called when disabled mock_summarization_service.summarize_object.assert_not_called() def test_invalid_json_response_passes_through(self, mock_config_service): """Test invalid JSON response passes through unchanged.""" async def endpoint_broken_json(request: Request): return Response(content=b"not valid json {{{", 
media_type="application/json") app = Starlette(routes=[Route("/api/broken", endpoint_broken_json)]) with patch( "src.api.middleware.token_aware_middleware.get_config_service", return_value=mock_config_service, ): app.add_middleware(TokenAwareMiddleware) client = TestClient(app) response = client.get("/api/broken") # Should return original broken response assert response.status_code == 200 assert response.text == "not valid json {{{" @patch("src.api.middleware.token_aware_middleware.estimate_tokens") def test_detect_object_type_booking( self, mock_estimate, mock_config_service, mock_summarization_service ): """Test object type detection for booking endpoint.""" mock_estimate.return_value = 5000 async def endpoint_booking(request: Request): return JSONResponse({"booking_id": "123", "guest_name": "John"}) app = Starlette(routes=[Route("/api/bookings/123", endpoint_booking)]) with patch( "src.api.middleware.token_aware_middleware.get_config_service", return_value=mock_config_service, ): with patch( "src.api.middleware.token_aware_middleware.get_summarization_service", return_value=mock_summarization_service, ): app.add_middleware(TokenAwareMiddleware) client = TestClient(app) response = client.get("/api/bookings/123") # Verify summarize_object was called with "booking" type call_args = mock_summarization_service.summarize_object.call_args assert call_args[1]["obj_type"] == "booking" @patch("src.api.middleware.token_aware_middleware.estimate_tokens") def test_detect_object_type_listing( self, mock_estimate, mock_config_service, mock_summarization_service ): """Test object type detection for listing endpoint.""" mock_estimate.return_value = 5000 async def endpoint_listing(request: Request): return JSONResponse({"listing_id": "456", "name": "Beach House"}) app = Starlette(routes=[Route("/api/listings/456", endpoint_listing)]) with patch( "src.api.middleware.token_aware_middleware.get_config_service", return_value=mock_config_service, ): with patch( 
"src.api.middleware.token_aware_middleware.get_summarization_service", return_value=mock_summarization_service, ): app.add_middleware(TokenAwareMiddleware) client = TestClient(app) response = client.get("/api/listings/456") call_args = mock_summarization_service.summarize_object.call_args assert call_args[1]["obj_type"] == "listing" @patch("src.api.middleware.token_aware_middleware.estimate_tokens") def test_detect_object_type_financial( self, mock_estimate, mock_config_service, mock_summarization_service ): """Test object type detection for financial endpoint.""" mock_estimate.return_value = 5000 async def endpoint_financial(request: Request): return JSONResponse({"transaction_id": "789", "amount": 500}) app = Starlette(routes=[Route("/api/financial/report", endpoint_financial)]) with patch( "src.api.middleware.token_aware_middleware.get_config_service", return_value=mock_config_service, ): with patch( "src.api.middleware.token_aware_middleware.get_summarization_service", return_value=mock_summarization_service, ): app.add_middleware(TokenAwareMiddleware) client = TestClient(app) response = client.get("/api/financial/report") call_args = mock_summarization_service.summarize_object.call_args assert call_args[1]["obj_type"] == "financial_transaction" @patch("src.api.middleware.token_aware_middleware.estimate_tokens") def test_list_response_warning( self, mock_estimate, mock_config_service, mock_summarization_service, caplog ): """Test warning logged when list response received.""" mock_estimate.return_value = 5000 async def endpoint_list(request: Request): return JSONResponse([{"id": 1}, {"id": 2}, {"id": 3}]) app = Starlette(routes=[Route("/api/items", endpoint_list)]) with patch( "src.api.middleware.token_aware_middleware.get_config_service", return_value=mock_config_service, ): with patch( "src.api.middleware.token_aware_middleware.get_summarization_service", return_value=mock_summarization_service, ): app.add_middleware(TokenAwareMiddleware) client = 
TestClient(app) response = client.get("/api/items") # Should log warning about list response assert "should be handled by pagination middleware" in caplog.text

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/darrentmorgan/hostaway-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.