#!/usr/bin/env python3
"""
Tests for the Aerith Admin MCP server operation modes (HTTP and STDIO).
"""
import os
import sys
import time
import threading
import subprocess
import signal
import json
import pytest
import uvicorn # Import uvicorn explicitly
import requests # Add back missing import
from unittest import mock
from io import StringIO
# Import the test fixture
from conftest import add_mcp_to_path, configure_logging
class TestServerModes:
"""Test class for HTTP and STDIO server modes."""
@pytest.mark.parametrize("port", [8091]) # Use a non-default port for testing
def test_http_server_startup(self, tmpdir, port):
"""Test HTTP server startup and basic functionality."""
# Set up test environment
os.chdir(tmpdir)
# Let's test the HTTP server mode without threading
# by directly calling the code that would run in the if __name__ == "__main__" block
# Mock necessary components
with mock.patch('uvicorn.run') as mock_run:
# Set up the port argument
with mock.patch.object(sys, 'argv', ['server.py', '--port', str(port)]):
# Import server module
import server
# Create a mock FastMCP instance returned by server.FastMCP
mock_mcp = mock.MagicMock()
mock_mcp.sse_app.return_value = "mock_app"
# Replace server's mcp with our mock
original_mcp = server.mcp
server.mcp = mock_mcp
try:
# Call the function with our mocked values
# This simulates the execution of the __main__ block in non-stdio mode
if hasattr(server, 'main'):
# If there's a main function, call it directly
server.main()
else:
# Otherwise, execute the main part directly
# Get the use_stdio flag from sys.argv
use_stdio = "--stdio" in sys.argv
if not use_stdio:
host = "0.0.0.0"
port_to_use = port
app = mock_mcp.sse_app()
# Use the imported uvicorn directly, not through the server module
uvicorn.run(
app,
host=host,
port=port_to_use,
log_level="info"
)
# Verify that uvicorn.run was called with the correct arguments
mock_run.assert_called_once()
call_args = mock_run.call_args[1]
assert call_args['port'] == port
assert call_args['host'] == '0.0.0.0'
finally:
# Restore original values
server.mcp = original_mcp
def test_stdio_mode_startup(self, tmpdir):
"""Test STDIO mode startup."""
# Set up test environment
os.chdir(tmpdir)
# Let's test the STDIO server mode without threading
# by directly calling the code that would run in the if __name__ == "__main__" block
# Mock necessary components
with mock.patch('server.FastMCP.run_stdio_async') as mock_run_stdio:
# Create a mock asyncio future for run_stdio_async to return
mock_future = mock.MagicMock()
mock_run_stdio.return_value = mock_future
# Mock asyncio event loop functions
with mock.patch('asyncio.new_event_loop') as mock_new_loop:
with mock.patch('asyncio.set_event_loop') as mock_set_loop:
with mock.patch('time.sleep') as mock_sleep: # Prevent actual sleep
# Set up sys.argv for stdio mode
with mock.patch.object(sys, 'argv', ['server.py', '--stdio']):
# Mock stdin and stdout
mock_stdin = StringIO('{"jsonrpc":"2.0","method":"ping","params":{},"id":"test-1"}')
mock_stdout = StringIO()
# Replace sys.stdin and sys.stdout temporarily
with mock.patch.object(sys, 'stdin', mock_stdin), mock.patch.object(sys, 'stdout', mock_stdout):
# Import server module
import server
# Skip signal setup to avoid errors in threads
with mock.patch('signal.signal') as mock_signal:
# Call the function with our mocked values
# This simulates the execution of the __main__ block in stdio mode
if hasattr(server, 'main'):
# If there's a main function, call it directly
server.main()
else:
# Otherwise, execute the main stdio part directly
if "--stdio" in sys.argv:
server.mcp.run_stdio_async()
# Verify run_stdio_async was called
mock_run_stdio.assert_called_once()
def test_signal_handler(self):
"""Test the signal handler for graceful shutdown."""
# Import the server module
import server
# Mock sys.exit to prevent actual exit
with mock.patch('sys.exit') as mock_exit:
# Call the signal handler directly
server.signal_handler(signal.SIGINT, None)
# Verify sys.exit was called
mock_exit.assert_called_once_with(0)
@pytest.mark.slow
class TestServerProcessIntegration:
"""Integration tests that launch actual server processes.
These tests are marked as slow because they launch actual processes.
Skip with pytest -m "not slow" to avoid running these tests.
"""
def test_actual_http_server(self, tmpdir):
"""Test launching an actual HTTP server process and making a request."""
# Set up test environment
os.chdir(tmpdir)
os.makedirs(os.path.join(tmpdir, ".aerith", "instructions"), exist_ok=True)
port = 8092 # Use a specific port for this test
# Get the server path
server_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'server.py')
# Start the server as a subprocess
server_process = subprocess.Popen(
[sys.executable, server_path, '--port', str(port)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
preexec_fn=os.setsid # Create a new process group
)
try:
# Wait for the server to start up
time.sleep(2)
# Try to make a request to the server
response = requests.get(f'http://localhost:{port}/')
# Check the response
assert response.status_code == 200
finally:
# Kill the server process group
os.killpg(os.getpgid(server_process.pid), signal.SIGTERM)
server_process.wait()
def test_actual_stdio_server(self, tmpdir):
"""Test launching an actual STDIO server process and sending a command."""
# Set up test environment
os.chdir(tmpdir)
os.makedirs(os.path.join(tmpdir, ".aerith", "instructions"), exist_ok=True)
# Get the server path
server_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'server.py')
# Create a simple JSON-RPC request
request = {
"jsonrpc": "2.0",
"method": "create_instruction",
"params": {
"title": "Test Instruction",
"description": "Test Description",
"goal": "Test Goal",
"priority": "low"
},
"id": "test-1"
}
request_json = json.dumps(request)
request_bytes = request_json.encode('utf-8')
# Format the request with Content-Length header
stdin_data = f"Content-Length: {len(request_bytes)}\r\n\r\n{request_json}"
# Start the server as a subprocess
server_process = subprocess.Popen(
[sys.executable, server_path, '--stdio'],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
try:
# Send the request
server_process.stdin.write(stdin_data.encode('utf-8'))
server_process.stdin.flush()
# Read the response header
response_header = server_process.stdout.readline().decode('utf-8')
assert response_header.startswith('Content-Length: ')
# Skip empty line
server_process.stdout.readline()
# Read the expected content length
content_length = int(response_header.split(': ')[1])
# Read the response
response_data = server_process.stdout.read(content_length).decode('utf-8')
response = json.loads(response_data)
# Verify the response contains expected fields
assert "jsonrpc" in response
assert "id" in response
assert response["id"] == "test-1"
assert "result" in response
finally:
# Terminate the server process
server_process.terminate()
server_process.wait()