"""Unit tests for transport selection and execution."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastmcp import settings as fastmcp_settings
from mcp_atlassian import _run_stdio_with_stdin_guard, main
class TestMainTransportSelection:
"""Test the main function's transport-specific execution logic."""
@pytest.fixture
def mock_server(self):
"""Create a mock server instance."""
server = MagicMock()
server.run_async = AsyncMock(return_value=None)
return server
@pytest.fixture
def mock_asyncio_run(self):
"""Mock asyncio.run to capture what coroutine is executed."""
with patch("asyncio.run") as mock_run:
# Store the coroutine for inspection
mock_run.side_effect = lambda coro: setattr(mock_run, "_called_with", coro)
yield mock_run
@pytest.mark.parametrize("transport", ["sse", "streamable-http"])
def test_http_transports_use_direct_execution(
self, mock_server, mock_asyncio_run, transport
):
"""Verify HTTP transports use direct execution without stdin monitoring.
This is a regression test for issues #519 and #524.
"""
with patch("mcp_atlassian.servers.main.AtlassianMCP", return_value=mock_server):
with patch.dict("os.environ", {"TRANSPORT": transport}):
with patch("sys.argv", ["mcp-atlassian"]):
try:
main()
except SystemExit:
pass
# Verify asyncio.run was called
assert mock_asyncio_run.called
# Get the coroutine info
called_coro = mock_asyncio_run._called_with
coro_repr = repr(called_coro)
assert "_run_stdio_with_stdin_guard" not in coro_repr
assert "run_async" in coro_repr or hasattr(called_coro, "cr_code")
def test_stdio_transport_uses_stdin_guard(self, mock_server, mock_asyncio_run):
with patch("mcp_atlassian.servers.main.AtlassianMCP", return_value=mock_server):
with patch.dict("os.environ", {"TRANSPORT": "stdio"}):
with patch("sys.argv", ["mcp-atlassian"]):
try:
main()
except SystemExit:
pass
assert mock_asyncio_run.called
called_coro = mock_asyncio_run._called_with
coro_repr = repr(called_coro)
assert "_run_stdio_with_stdin_guard" in coro_repr
@pytest.mark.parametrize("stateless", ["False", "True"])
def test_stateless_set(self, mock_asyncio_run, stateless):
"""Verify that the server is started in stateless mode when the environment variable is set."""
with patch.dict("os.environ", {"STATELESS": stateless}):
with patch.dict("os.environ", {"TRANSPORT": "streamable-http"}):
with patch("sys.argv", ["mcp-atlassian"]):
try:
main()
except SystemExit:
pass
# Verify asyncio.run was called
assert mock_asyncio_run.called
desired = stateless.lower() == "true"
assert fastmcp_settings.stateless_http == desired
@pytest.mark.parametrize("transport", ["stdio", "sse"])
def test_stateless_rejects_non_streamable_http(self, mock_asyncio_run, transport):
"""Verify that --stateless flag errors when used with non-streamable-http transport."""
with patch.dict("os.environ", {"STATELESS": "true", "TRANSPORT": transport}):
with patch("sys.argv", ["mcp-atlassian"]):
with pytest.raises(SystemExit) as exc_info:
main()
# Should exit with code 1 (error)
assert exc_info.value.code == 1
def test_cli_overrides_env_transport(self, mock_server, mock_asyncio_run):
"""Test that CLI transport argument overrides environment variable."""
with patch("mcp_atlassian.servers.main.AtlassianMCP", return_value=mock_server):
with patch.dict("os.environ", {"TRANSPORT": "sse"}):
# Simulate CLI args with --transport stdio
with patch("sys.argv", ["mcp-atlassian", "--transport", "stdio"]):
try:
main()
except SystemExit:
pass
called_coro = mock_asyncio_run._called_with
coro_repr = repr(called_coro)
assert "_run_stdio_with_stdin_guard" in coro_repr
@pytest.mark.asyncio
async def test_stdio_guard_cancels_server_when_parent_exits(self):
server_started = asyncio.Event()
server_cancelled = asyncio.Event()
async def fake_run_async(**kwargs):
del kwargs
server_started.set()
try:
await asyncio.Event().wait()
except asyncio.CancelledError:
server_cancelled.set()
raise
async def fake_watch_parent() -> None:
await server_started.wait()
with patch(
"mcp_atlassian.servers.main_mcp.run_async", side_effect=fake_run_async
):
with patch(
"mcp_atlassian._watch_parent_exit", side_effect=fake_watch_parent
):
await _run_stdio_with_stdin_guard({"transport": "stdio"})
assert server_cancelled.is_set()
def test_signal_handlers_always_setup(self, mock_server):
"""Test that signal handlers are set up regardless of transport."""
with patch("mcp_atlassian.servers.main.AtlassianMCP", return_value=mock_server):
with patch("asyncio.run"):
# Patch where it's imported in the main module
with patch("mcp_atlassian.setup_signal_handlers") as mock_setup:
with patch.dict("os.environ", {"TRANSPORT": "stdio"}):
with patch("sys.argv", ["mcp-atlassian"]):
try:
main()
except SystemExit:
pass
# Signal handlers should always be set up
mock_setup.assert_called_once()
def test_error_handling_preserved(self, mock_server):
"""Test that error handling works correctly for all transports."""
# Make the server's run_async raise an exception when awaited
error = RuntimeError("Server error")
async def failing_run_async(**kwargs):
raise error
mock_server.run_async = failing_run_async
with patch("mcp_atlassian.servers.main.AtlassianMCP", return_value=mock_server):
with patch("asyncio.run") as mock_run:
# Simulate the exception propagating through asyncio.run
mock_run.side_effect = error
with patch.dict("os.environ", {"TRANSPORT": "stdio"}):
with patch("sys.argv", ["mcp-atlassian"]):
# The main function logs the error and exits with code 1
with patch("sys.exit") as mock_exit:
main()
# Verify error was handled - sys.exit called with 1 for error
# and then with 0 in the finally block
assert mock_exit.call_count == 2
assert mock_exit.call_args_list[0][0][0] == 1 # Error exit
assert (
mock_exit.call_args_list[1][0][0] == 0
) # Finally exit