Skip to main content
Glama
test_middleware.py (30.6 kB)
"""Unit tests for access control middleware.

Tests the AgentAccessControl middleware that enforces per-agent access rules
by extracting agent identity, validating permissions, and managing context
state.
"""

from dataclasses import dataclass
from typing import Any, Optional
from unittest.mock import AsyncMock, MagicMock, Mock

import pytest
from fastmcp.exceptions import ToolError

from src.middleware import AgentAccessControl
from src.policy import PolicyEngine


# Mock classes to simulate FastMCP middleware context


@dataclass
class MockToolCall:
    """Mock tool call message."""

    name: str
    # Optional so edge-case tests can model a call whose arguments are None.
    arguments: Optional[dict[str, Any]]


@dataclass
class MockMiddlewareContext:
    """Mock middleware context."""

    message: MockToolCall
    fastmcp_context: Any = None
    method: str = "tools/call"


class MockFastMCPContext:
    """Mock FastMCP context with state management."""

    def __init__(self):
        self._state = {}

    def set_state(self, key: str, value: Any):
        """Store state value."""
        self._state[key] = value

    def get_state(self, key: str, default: Any = None) -> Any:
        """Retrieve state value."""
        return self._state.get(key, default)


def _make_middleware(rules: dict[str, Any]) -> AgentAccessControl:
    """Build the middleware under test from a raw rules dict."""
    return AgentAccessControl(PolicyEngine(rules))


def _make_call(
    name: str,
    arguments: Optional[dict[str, Any]],
) -> tuple[MockToolCall, MockFastMCPContext, MockMiddlewareContext]:
    """Build a tools/call middleware context with a fresh state store.

    Returns the tool call (so tests can inspect ``arguments`` afterwards), the
    FastMCP context (so tests can inspect stored state), and the middleware
    context wrapping both.
    """
    tool_call = MockToolCall(name=name, arguments=arguments)
    fastmcp_ctx = MockFastMCPContext()
    context = MockMiddlewareContext(message=tool_call, fastmcp_context=fastmcp_ctx)
    return tool_call, fastmcp_ctx, context


@pytest.fixture(autouse=True)
def _no_default_agent_fallback(monkeypatch):
    """Start every test with no GATEWAY_DEFAULT_AGENT fallback configured.

    Several tests assert behaviour "when no env var is set"; without this
    fixture they would pass or fail depending on the caller's shell
    environment and on whatever value ``gateway._default_agent_id`` holds.
    Tests that exercise the fallback chain re-patch ``_default_agent_id``
    themselves; monkeypatch restores the original values after each test.
    """
    from src import gateway

    monkeypatch.delenv("GATEWAY_DEFAULT_AGENT", raising=False)
    # None is the "unset" value (see test_fallback_empty_env_var_treated_as_unset).
    monkeypatch.setattr(gateway, "_default_agent_id", None)


class TestMiddlewareAgentExtraction:
    """Test agent_id extraction and context storage."""

    @pytest.mark.asyncio
    async def test_middleware_extracts_agent_id(self):
        """Test that middleware successfully extracts agent_id from arguments."""
        rules = {
            "agents": {"test_agent": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": True},
        }
        middleware = _make_middleware(rules)
        _, fastmcp_ctx, context = _make_call(
            "list_servers", {"agent_id": "test_agent", "include_metadata": False}
        )
        call_next = AsyncMock(return_value={"result": "success"})

        result = await middleware.on_call_tool(context, call_next)

        # Verify agent was stored in context and the call was forwarded once
        assert fastmcp_ctx.get_state("current_agent") == "test_agent"
        assert result == {"result": "success"}
        call_next.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_middleware_keeps_agent_id(self):
        """Test that middleware keeps agent_id in arguments for gateway tools."""
        rules = {
            "agents": {"test_agent": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": True},
        }
        middleware = _make_middleware(rules)
        tool_call, _, context = _make_call(
            "list_servers",
            {"agent_id": "test_agent", "include_metadata": True, "format": "json"},
        )
        call_next = AsyncMock(return_value={"result": "success"})

        await middleware.on_call_tool(context, call_next)

        # agent_id is kept (gateway tools need it) along with other arguments
        assert "agent_id" in tool_call.arguments
        assert tool_call.arguments["agent_id"] == "test_agent"
        assert tool_call.arguments["include_metadata"] is True
        assert tool_call.arguments["format"] == "json"

    @pytest.mark.asyncio
    async def test_middleware_stores_in_context(self):
        """Test that middleware stores agent in context state."""
        rules = {
            "agents": {"researcher": {"allow": {"servers": ["brave-search"]}}},
            "defaults": {"deny_on_missing_agent": True},
        }
        middleware = _make_middleware(rules)
        _, fastmcp_ctx, context = _make_call(
            "get_server_tools", {"agent_id": "researcher", "server": "brave-search"}
        )
        call_next = AsyncMock(return_value={"tools": []})

        await middleware.on_call_tool(context, call_next)

        # Verify context state was updated
        assert fastmcp_ctx.get_state("current_agent") == "researcher"


class TestMiddlewareMissingAgentID:
    """Test handling of missing agent_id based on default policy."""

    @pytest.mark.asyncio
    async def test_middleware_missing_agent_id_deny(self):
        """Test that missing agent_id raises error when default policy denies."""
        rules = {
            "agents": {"known_agent": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": True},
        }
        middleware = _make_middleware(rules)
        # No agent_id in the arguments
        _, _, context = _make_call("list_servers", {"include_metadata": False})
        call_next = AsyncMock()

        with pytest.raises(ToolError) as exc_info:
            await middleware.on_call_tool(context, call_next)

        # Error must name the problem, and the call must not be forwarded
        error_msg = str(exc_info.value).lower()
        assert "agent_id" in error_msg
        assert "missing" in error_msg
        call_next.assert_not_called()

    @pytest.mark.asyncio
    async def test_middleware_missing_agent_id_allow(self):
        """Test that a missing agent_id with no fallback errors even when allowed.

        With ``deny_on_missing_agent`` false, a missing agent_id triggers the
        fallback chain (GATEWAY_DEFAULT_AGENT, then a 'default' agent) instead
        of proceeding with ``None``. Neither fallback is configured here, so
        the middleware must raise a ToolError with a helpful message.
        """
        rules = {
            "agents": {"known_agent": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": False},
        }
        middleware = _make_middleware(rules)
        # No agent_id, no env fallback (autouse fixture), no 'default' agent
        _, _, context = _make_call("list_servers", {"include_metadata": False})
        call_next = AsyncMock()

        with pytest.raises(ToolError) as exc_info:
            await middleware.on_call_tool(context, call_next)

        # Error message should point at one of the configuration options
        error_msg = str(exc_info.value).lower()
        assert (
            "agent_id" in error_msg
            or "default" in error_msg
            or "gateway_default_agent" in error_msg
        )
        call_next.assert_not_called()


class TestMiddlewareGatewayTools:
    """Test that gateway tools are allowed through middleware."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "tool_name",
        ["list_servers", "get_server_tools", "execute_tool"],
    )
    async def test_middleware_gateway_tools_allowed(self, tool_name):
        """Test that each gateway tool passes through middleware without blocking."""
        rules = {
            "agents": {"test_agent": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": True},
        }
        middleware = _make_middleware(rules)
        _, fastmcp_ctx, context = _make_call(
            tool_name, {"agent_id": "test_agent", "server": "api"}
        )
        call_next = AsyncMock(return_value={"result": "ok"})

        result = await middleware.on_call_tool(context, call_next)

        # Verify tool was allowed through
        assert result == {"result": "ok"}
        call_next.assert_awaited_once()
        assert fastmcp_ctx.get_state("current_agent") == "test_agent"

    @pytest.mark.asyncio
    async def test_middleware_list_tools_no_filtering(self):
        """Test that on_list_tools passes through without filtering."""
        rules = {
            "agents": {"test_agent": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": True},
        }
        middleware = _make_middleware(rules)
        # list_tools carries no message arguments
        context = MockMiddlewareContext(
            message=MockToolCall(name="", arguments={}),
            fastmcp_context=MockFastMCPContext(),
            method="tools/list",
        )
        mock_tools = [
            {"name": "list_servers", "description": "List servers"},
            {"name": "get_server_tools", "description": "Get tools"},
            {"name": "execute_tool", "description": "Execute tool"},
        ]
        call_next = AsyncMock(return_value=mock_tools)

        result = await middleware.on_list_tools(context, call_next)

        # Verify no filtering occurred
        assert result == mock_tools
        assert len(result) == 3
        call_next.assert_awaited_once()


class TestMiddlewareWithoutFastMCPContext:
    """Test middleware behavior when fastmcp_context is None."""

    @pytest.mark.asyncio
    async def test_middleware_without_context_object(self):
        """Test that middleware handles missing fastmcp_context gracefully."""
        rules = {
            "agents": {"test_agent": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": True},
        }
        middleware = _make_middleware(rules)
        tool_call = MockToolCall(name="list_servers", arguments={"agent_id": "test_agent"})
        context = MockMiddlewareContext(message=tool_call, fastmcp_context=None)
        call_next = AsyncMock(return_value={"result": "success"})

        # Should not crash even though there is nowhere to store state
        result = await middleware.on_call_tool(context, call_next)

        assert result == {"result": "success"}
        call_next.assert_awaited_once()


class TestMiddlewareEdgeCases:
    """Test edge cases and boundary conditions."""

    @pytest.mark.asyncio
    async def test_middleware_empty_arguments(self):
        """Test middleware with empty arguments dict."""
        rules = {"agents": {}, "defaults": {"deny_on_missing_agent": True}}
        middleware = _make_middleware(rules)
        _, _, context = _make_call("list_servers", {})
        call_next = AsyncMock()

        # Missing agent_id must be rejected before the call is forwarded
        with pytest.raises(ToolError):
            await middleware.on_call_tool(context, call_next)
        call_next.assert_not_called()

    @pytest.mark.asyncio
    async def test_middleware_none_arguments(self):
        """Test middleware when arguments is None."""
        rules = {"agents": {}, "defaults": {"deny_on_missing_agent": True}}
        middleware = _make_middleware(rules)
        _, _, context = _make_call("list_servers", None)
        call_next = AsyncMock()

        # None arguments are handled gracefully: a ToolError, not a TypeError
        with pytest.raises(ToolError):
            await middleware.on_call_tool(context, call_next)
        call_next.assert_not_called()

    @pytest.mark.asyncio
    async def test_middleware_agent_id_with_special_characters(self):
        """Test that agent_id with special characters is handled correctly."""
        rules = {
            "agents": {"agent-with-dashes_123": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": True},
        }
        middleware = _make_middleware(rules)
        tool_call, fastmcp_ctx, context = _make_call(
            "list_servers", {"agent_id": "agent-with-dashes_123", "foo": "bar"}
        )
        call_next = AsyncMock(return_value={"result": "ok"})

        result = await middleware.on_call_tool(context, call_next)

        # Verify agent was stored correctly
        assert fastmcp_ctx.get_state("current_agent") == "agent-with-dashes_123"
        assert result == {"result": "ok"}
        # agent_id is kept in arguments for gateway tools
        assert "agent_id" in tool_call.arguments
        assert tool_call.arguments["agent_id"] == "agent-with-dashes_123"
        assert "foo" in tool_call.arguments


class TestMiddlewareMultipleArguments:
    """Test middleware with various argument combinations."""

    @pytest.mark.asyncio
    async def test_middleware_preserves_all_arguments(self):
        """Test that all arguments including agent_id are preserved."""
        rules = {
            "agents": {"test": {"allow": {"servers": ["*"]}}},
            "defaults": {"deny_on_missing_agent": True},
        }
        middleware = _make_middleware(rules)
        original_args = {
            "agent_id": "test",
            "server": "postgres",
            "tool": "query",
            "args": {"sql": "SELECT * FROM users"},
            "timeout_ms": 5000,
            "format": "json",
        }
        # Pass a copy so original_args stays a pristine reference
        tool_call, _, context = _make_call("execute_tool", original_args.copy())
        call_next = AsyncMock(return_value={"rows": []})

        await middleware.on_call_tool(context, call_next)

        # Verify all arguments are preserved (including agent_id)
        assert "agent_id" in tool_call.arguments
        assert tool_call.arguments["agent_id"] == "test"
        assert tool_call.arguments["server"] == "postgres"
        assert tool_call.arguments["tool"] == "query"
        assert tool_call.arguments["args"] == {"sql": "SELECT * FROM users"}
        assert tool_call.arguments["timeout_ms"] == 5000
        assert tool_call.arguments["format"] == "json"

    @pytest.mark.asyncio
    async def test_middleware_agent_id_only_argument(self):
        """Test when agent_id is the only argument."""
        rules = {
            "agents": {"solo": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": True},
        }
        middleware = _make_middleware(rules)
        tool_call, fastmcp_ctx, context = _make_call("list_servers", {"agent_id": "solo"})
        call_next = AsyncMock(return_value={"servers": []})

        await middleware.on_call_tool(context, call_next)

        # Verify agent_id is still present and nothing else was added
        assert tool_call.arguments == {"agent_id": "solo"}
        assert fastmcp_ctx.get_state("current_agent") == "solo"


class TestMiddlewareAgentIDFallback:
    """Test agent_id fallback chain when agent_id is missing."""

    @pytest.mark.asyncio
    async def test_fallback_to_env_var(self, monkeypatch):
        """When agent_id missing, should use GATEWAY_DEFAULT_AGENT env var."""
        from src import gateway

        # Mock get_default_agent_id to return "researcher"
        monkeypatch.setattr(gateway, "_default_agent_id", "researcher")
        rules = {
            "agents": {"researcher": {"allow": {"servers": ["brave-search"]}}},
            "defaults": {"deny_on_missing_agent": False},
        }
        middleware = _make_middleware(rules)
        _, fastmcp_ctx, context = _make_call("list_servers", {"include_metadata": False})
        call_next = AsyncMock(return_value={"servers": ["brave-search"]})

        # Should run with researcher's permissions
        result = await middleware.on_call_tool(context, call_next)

        assert result == {"servers": ["brave-search"]}
        call_next.assert_awaited_once()
        # Context state was set to the fallback agent
        assert fastmcp_ctx.get_state("current_agent") == "researcher"

    @pytest.mark.asyncio
    async def test_fallback_to_default_agent(self):
        """When agent_id missing and no env var, should use 'default' agent."""
        rules = {
            "agents": {"default": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": False},
        }
        middleware = _make_middleware(rules)
        # No agent_id; the autouse fixture guarantees no env var fallback
        _, fastmcp_ctx, context = _make_call("list_servers", {"include_metadata": False})
        call_next = AsyncMock(return_value={"servers": ["api"]})

        # Should run with the 'default' agent's permissions
        result = await middleware.on_call_tool(context, call_next)

        assert result == {"servers": ["api"]}
        call_next.assert_awaited_once()
        assert fastmcp_ctx.get_state("current_agent") == "default"

    @pytest.mark.asyncio
    async def test_env_var_precedence_over_default(self, monkeypatch):
        """GATEWAY_DEFAULT_AGENT should override 'default' agent in rules."""
        from src import gateway

        # Mock get_default_agent_id to return "researcher"
        monkeypatch.setattr(gateway, "_default_agent_id", "researcher")
        rules = {
            "agents": {
                "researcher": {"allow": {"servers": ["brave-search"]}},
                "default": {"allow": {"servers": ["api"]}},
            },
            "defaults": {"deny_on_missing_agent": False},
        }
        middleware = _make_middleware(rules)
        _, fastmcp_ctx, context = _make_call("list_servers", {"include_metadata": False})
        call_next = AsyncMock(return_value={"servers": ["brave-search"]})

        await middleware.on_call_tool(context, call_next)

        # Used researcher (not default)
        assert fastmcp_ctx.get_state("current_agent") == "researcher"
        call_next.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_fallback_agent_not_in_rules(self, monkeypatch):
        """Should error if fallback agent doesn't exist in rules config."""
        from src import gateway

        # Mock get_default_agent_id to return an agent absent from the rules
        monkeypatch.setattr(gateway, "_default_agent_id", "nonexistent")
        rules = {
            "agents": {"researcher": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": False},
        }
        middleware = _make_middleware(rules)
        _, _, context = _make_call("list_servers", {"include_metadata": False})
        call_next = AsyncMock()

        with pytest.raises(ToolError) as exc_info:
            await middleware.on_call_tool(context, call_next)

        # Error message should be helpful
        error_msg = str(exc_info.value).lower()
        assert "nonexistent" in error_msg or "fallback" in error_msg or "default" in error_msg
        call_next.assert_not_called()

    @pytest.mark.asyncio
    async def test_no_fallback_configured(self):
        """Should error if no env var and no 'default' agent."""
        rules = {
            "agents": {"researcher": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": False},
        }
        middleware = _make_middleware(rules)
        # No agent_id, no env var (autouse fixture), no 'default' agent
        _, _, context = _make_call("list_servers", {"include_metadata": False})
        call_next = AsyncMock()

        with pytest.raises(ToolError) as exc_info:
            await middleware.on_call_tool(context, call_next)

        # Error message should explain the configuration options
        error_msg = str(exc_info.value).lower()
        assert (
            "agent_id" in error_msg
            or "default" in error_msg
            or "gateway_default_agent" in error_msg
        )
        call_next.assert_not_called()

    @pytest.mark.asyncio
    async def test_deny_on_missing_bypasses_fallback(self, monkeypatch):
        """When deny_on_missing_agent=true, should reject without checking fallback."""
        from src import gateway

        # A valid fallback exists, but must be ignored
        monkeypatch.setattr(gateway, "_default_agent_id", "researcher")
        rules = {
            "agents": {"researcher": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": True},
        }
        middleware = _make_middleware(rules)
        _, _, context = _make_call("list_servers", {"include_metadata": False})
        call_next = AsyncMock()

        # Should raise immediately (doesn't use env var)
        with pytest.raises(ToolError) as exc_info:
            await middleware.on_call_tool(context, call_next)

        error_msg = str(exc_info.value).lower()
        assert "agent_id" in error_msg
        assert "missing" in error_msg
        call_next.assert_not_called()

    @pytest.mark.asyncio
    async def test_explicit_agent_id_overrides_fallback(self, monkeypatch):
        """When agent_id is provided, fallback should not be used."""
        from src import gateway

        # Configure a fallback that should be ignored
        monkeypatch.setattr(gateway, "_default_agent_id", "researcher")
        rules = {
            "agents": {
                "researcher": {"allow": {"servers": ["brave-search"]}},
                "backend": {"allow": {"servers": ["postgres"]}},
            },
            "defaults": {"deny_on_missing_agent": False},
        }
        middleware = _make_middleware(rules)
        _, fastmcp_ctx, context = _make_call(
            "list_servers", {"agent_id": "backend", "include_metadata": False}
        )
        call_next = AsyncMock(return_value={"servers": ["postgres"]})

        await middleware.on_call_tool(context, call_next)

        # Used explicit agent_id (backend), not fallback (researcher)
        assert fastmcp_ctx.get_state("current_agent") == "backend"
        call_next.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_fallback_with_special_characters(self, monkeypatch):
        """Test fallback with agent name containing special characters."""
        from src import gateway

        # Fallback agent name with dashes and underscores
        monkeypatch.setattr(gateway, "_default_agent_id", "team-backend_v2")
        rules = {
            "agents": {"team-backend_v2": {"allow": {"servers": ["postgres"]}}},
            "defaults": {"deny_on_missing_agent": False},
        }
        middleware = _make_middleware(rules)
        _, fastmcp_ctx, context = _make_call("list_servers", {"include_metadata": False})
        call_next = AsyncMock(return_value={"servers": ["postgres"]})

        await middleware.on_call_tool(context, call_next)

        assert fastmcp_ctx.get_state("current_agent") == "team-backend_v2"
        call_next.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_fallback_empty_env_var_treated_as_unset(self, monkeypatch):
        """Empty GATEWAY_DEFAULT_AGENT env var should fall back to 'default' agent."""
        from src import gateway

        # Model an empty/unset env var as _default_agent_id = None (an empty
        # string is falsy and is expected to map to None).
        monkeypatch.setattr(gateway, "_default_agent_id", None)
        rules = {
            "agents": {"default": {"allow": {"servers": ["api"]}}},
            "defaults": {"deny_on_missing_agent": False},
        }
        middleware = _make_middleware(rules)
        _, fastmcp_ctx, context = _make_call("list_servers", {"include_metadata": False})
        call_next = AsyncMock(return_value={"servers": ["api"]})

        # Should fall back to the 'default' agent
        await middleware.on_call_tool(context, call_next)

        # Used 'default' agent (not an empty string)
        assert fastmcp_ctx.get_state("current_agent") == "default"
        call_next.assert_awaited_once()

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/roddutra/agent-mcp-gateway'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.