# mcp-shell-server
#
# by tumf
"""Test configuration and fixtures."""

import asyncio
from typing import IO
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio

from mcp_shell_server.shell_executor import ShellExecutor


@pytest.fixture
def mock_file(mocker):
    """Provide a mock file object.

    Uses the ``mocker`` fixture (pytest-mock). The mock is spec'd on
    ``typing.IO`` and has an explicit ``close`` mock so tests can assert on it.
    """
    mock = mocker.MagicMock(spec=IO)
    mock.close = mocker.MagicMock()
    return mock


@pytest_asyncio.fixture
async def mock_process_manager():
    """Provide a mock process manager.

    - ``create_process`` is an ``AsyncMock`` that returns a *new* process mock
      on every call (``returncode == 0``, ``communicate()`` -> ``(b"", b"")``).
    - ``execute_with_timeout`` defaults to ``(b"", b"")``.
    - ``execute_pipeline`` defaults to ``(b"", b"", 0)``.
    - ``cleanup_processes`` is a plain ``AsyncMock``.

    Tests should override these return values as needed.
    """
    manager = MagicMock()

    async def create_process_side_effect(*args, **kwargs):
        # Build a fresh process mock per call so separately created
        # processes do not share call history or configured state.
        process = AsyncMock()
        process.returncode = 0
        process.communicate = AsyncMock(return_value=(b"", b""))
        # NOTE(review): asyncio.subprocess.Process.kill() is synchronous in the
        # stdlib; kept as AsyncMock to preserve existing behaviour — confirm
        # against how ShellExecutor calls it.
        process.kill = AsyncMock()
        process.wait = AsyncMock()
        return process

    # AsyncMock awaits a coroutine-function side_effect, so awaiting
    # manager.create_process(...) yields the process mock built above.
    manager.create_process = AsyncMock(side_effect=create_process_side_effect)
    # Empty default results - tests should override these as needed.
    manager.execute_with_timeout = AsyncMock(return_value=(b"", b""))
    manager.execute_pipeline = AsyncMock(return_value=(b"", b"", 0))
    manager.cleanup_processes = AsyncMock()
    return manager


@pytest_asyncio.fixture
async def shell_executor_with_mock(mock_process_manager):
    """Provide a ShellExecutor wired to the ``mock_process_manager`` fixture."""
    executor = ShellExecutor(process_manager=mock_process_manager)
    return executor


@pytest.fixture
def temp_test_dir(tmpdir):
    """Provide a temporary test directory as a string path."""
    return str(tmpdir)


@pytest.fixture(scope="function")
def event_loop():
    """Create and provide a new event loop for each test function.

    This fixture is deliberately *synchronous*: pytest-asyncio runs async
    fixtures and tests on the loop this fixture yields, so it cannot itself
    be an async fixture. Synchronous teardown also lets us drive the loop
    with ``run_until_complete`` (which is illegal from inside a running loop).

    NOTE(review): newer pytest-asyncio releases deprecate overriding
    ``event_loop``; confirm against the pinned pytest-asyncio version.

    Yields:
        The freshly created event loop, also installed as the current loop.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    yield loop

    try:
        # Cancel anything the test left running and let the cancellations
        # propagate; return_exceptions=True keeps CancelledError from raising.
        tasks = asyncio.all_tasks(loop)
        if tasks:
            for task in tasks:
                task.cancel()
            loop.run_until_complete(
                asyncio.gather(*tasks, return_exceptions=True)
            )

        # Close transports the loop still tracks. ``_transports`` is a private
        # attribute of selector-based loops, hence the hasattr guards.
        if hasattr(loop, "_transports"):
            for transport in list(loop._transports.values()):
                if hasattr(transport, "close"):
                    transport.close()

        # Finalize async generators; running the loop here also lets the
        # connection-lost callbacks scheduled by transport.close() execute.
        # loop.stop() is intentionally not called: stopping a non-running
        # loop makes the next run_until_complete() exit early.
        loop.run_until_complete(loop.shutdown_asyncgens())
    except Exception as e:
        # Best-effort cleanup: report, but do not fail the test on teardown.
        print(f"Error during loop cleanup: {e}")
    finally:
        # Always release the loop, even if cleanup above failed.
        asyncio.set_event_loop(None)
        if not loop.is_closed():
            loop.close()