server.py•7.38 kB
import asyncio
import logging
import signal
import traceback
from collections.abc import Sequence
from typing import Any
from mcp.server import Server
from mcp.types import TextContent, Tool
from .shell_executor import ShellExecutor
from .version import __version__
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mcp-shell-server")
app: Server = Server("mcp-shell-server")
class ExecuteToolHandler:
    """Handler for shell command execution"""
    name = "shell_execute"
    description = "Execute a shell command"
    def __init__(self):
        self.executor = ShellExecutor()
    def get_allowed_commands(self) -> list[str]:
        """Get the allowed commands"""
        return self.executor.validator.get_allowed_commands()
    def get_allowed_patterns(self) -> list[str]:
        """Get the allowed regex patterns"""
        return [
            pattern.pattern
            for pattern in self.executor.validator._get_allowed_patterns()
        ]
    def get_tool_description(self) -> Tool:
        """Get the tool description for the execute command"""
        allowed_commands = ", ".join(self.get_allowed_commands())
        allowed_patterns = ", ".join(self.get_allowed_patterns())
        """Get the tool description for the execute command"""
        return Tool(
            name=self.name,
            description=f"{self.description}\nAllowed commands: {allowed_commands}\nAllowed patterns: {allowed_patterns}",
            inputSchema={
                "type": "object",
                "properties": {
                    "command": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Command and its arguments as array",
                    },
                    "stdin": {
                        "type": "string",
                        "description": "Input to be passed to the command via stdin",
                    },
                    "directory": {
                        "type": "string",
                        "description": "Absolute path to a working directory where the command will be executed",
                    },
                    "timeout": {
                        "type": "integer",
                        "description": "Maximum execution time in seconds",
                        "minimum": 0,
                    },
                },
                "required": ["command", "directory"],
            },
        )
    async def run_tool(self, arguments: dict) -> Sequence[TextContent]:
        """Execute the shell command with the given arguments"""
        command = arguments.get("command", [])
        stdin = arguments.get("stdin")
        directory = arguments.get("directory", "/tmp")  # default to /tmp for safety
        timeout = arguments.get("timeout")
        if not command:
            raise ValueError("No command provided")
        if not isinstance(command, list):
            raise ValueError("'command' must be an array")
        # Make sure directory exists
        if not directory:
            raise ValueError("Directory is required")
        content: list[TextContent] = []
        try:
            # Handle execution with timeout
            try:
                # Add small buffer to timeout for CI scheduling delays if timeout is specified
                actual_timeout = timeout + 0.5 if timeout is not None else None
                result = await asyncio.wait_for(
                    self.executor.execute(
                        command, directory, stdin, None
                    ),  # Pass None for timeout
                    timeout=actual_timeout,
                )
            except asyncio.TimeoutError as e:
                raise ValueError("Command execution timed out") from e
            if result.get("error"):
                raise ValueError(result["error"])
            # Add stdout if present
            if result.get("stdout"):
                content.append(TextContent(type="text", text=result["stdout"]))
            # Add stderr if present (filter out specific messages)
            stderr = result.get("stderr")
            if stderr and "cannot set terminal process group" not in stderr:
                content.append(TextContent(type="text", text=stderr))
        except asyncio.TimeoutError as e:
            raise ValueError(f"Command timed out after {timeout} seconds") from e
        return content
# Initialize tool handlers
tool_handler = ExecuteToolHandler()
@app.list_tools()
async def list_tools() -> list[Tool]:
    """List available tools."""
    return [tool_handler.get_tool_description()]
@app.call_tool()
async def call_tool(name: str, arguments: Any) -> Sequence[TextContent]:
    """Handle tool calls"""
    try:
        if name != tool_handler.name:
            raise ValueError(f"Unknown tool: {name}")
        if not isinstance(arguments, dict):
            raise ValueError("Arguments must be a dictionary")
        return await tool_handler.run_tool(arguments)
    except Exception as e:
        logger.error(traceback.format_exc())
        raise RuntimeError(f"Error executing command: {str(e)}") from e
async def main() -> None:
    """Main entry point for the MCP shell server"""
    logger.info(f"Starting MCP shell server v{__version__}")
    # Setup signal handling
    loop = asyncio.get_running_loop()
    stop_event = asyncio.Event()
    def handle_signal():
        if not stop_event.is_set():  # Prevent duplicate handling
            logger.info("Received shutdown signal, starting cleanup...")
            stop_event.set()
    # Register signal handlers
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, handle_signal)
    try:
        from mcp.server.stdio import stdio_server
        async with stdio_server() as (read_stream, write_stream):
            # Run the server until stop_event is set
            server_task = asyncio.create_task(
                app.run(read_stream, write_stream, app.create_initialization_options())
            )
            # Create task for stop event
            stop_task = asyncio.create_task(stop_event.wait())
            # Wait for either server completion or stop signal
            done, pending = await asyncio.wait(
                [server_task, stop_task], return_when=asyncio.FIRST_COMPLETED
            )
            # Check for exceptions in completed tasks
            for task in done:
                try:
                    await task
                except Exception:
                    raise  # Re-raise the exception
            # Cancel any pending tasks
            for task in pending:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
    except Exception as e:
        logger.error(f"Server error: {str(e)}")
        raise
    finally:
        # Cleanup signal handlers
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.remove_signal_handler(sig)
        # Ensure all processes are terminated
        if hasattr(tool_handler, "executor") and hasattr(
            tool_handler.executor, "process_manager"
        ):
            await tool_handler.executor.process_manager.cleanup_processes()
        logger.info("Server shutdown complete")