test_main.py•2.82 kB
"""Tests for the main entry point of the AWS MCP Server."""
from unittest.mock import MagicMock, patch
import pytest
# Import handle_interrupt function for direct testing
from aws_mcp_server.__main__ import handle_interrupt
def test_handle_interrupt():
"""Test the handle_interrupt function."""
with patch("sys.exit") as mock_exit:
# Call the function with mock signal and frame
handle_interrupt(MagicMock(), MagicMock())
# Verify sys.exit was called with 0
mock_exit.assert_called_once_with(0)
@pytest.mark.skip(reason="Cannot reload main module during testing")
def test_main_with_valid_transport():
"""Test main module with valid transport setting."""
with patch("aws_mcp_server.__main__.TRANSPORT", "stdio"):
with patch("aws_mcp_server.__main__.mcp.run") as mock_run:
# We can't easily test the full __main__ module execution
from aws_mcp_server.__main__ import mcp
# Instead, we'll test the specific function we modified
with patch("aws_mcp_server.__main__.logger") as mock_logger:
# Import the function to ensure proper validation
from aws_mcp_server.__main__ import TRANSPORT
# Call the relevant function directly
mcp.run(transport=TRANSPORT)
# Check that mcp.run was called with the correct transport
mock_run.assert_called_once_with(transport="stdio")
# Verify logger was called
mock_logger.info.assert_any_call("Starting server with transport protocol: stdio")
def test_main_transport_validation():
"""Test transport protocol validation."""
with patch("aws_mcp_server.config.TRANSPORT", "invalid"):
from aws_mcp_server.config import TRANSPORT
# Test the main function's validation logic
with patch("aws_mcp_server.server.mcp.run") as mock_run:
with patch("sys.exit") as mock_exit:
with patch("aws_mcp_server.__main__.logger") as mock_logger:
# Execute the validation logic directly
if TRANSPORT not in ("stdio", "sse"):
mock_logger.error(f"Invalid transport protocol: {TRANSPORT}. Must be 'stdio' or 'sse'")
mock_exit(1)
else:
mock_run(transport=TRANSPORT)
# Check that error was logged with invalid transport
mock_logger.error.assert_called_once_with("Invalid transport protocol: invalid. Must be 'stdio' or 'sse'")
# Check that exit was called
mock_exit.assert_called_once_with(1)
# Check that mcp.run was not called
mock_run.assert_not_called()