Skip to main content
Glama
ssh_executor.py (11.3 kB)
"""SSH executor for remote command execution.

This module provides functionality to execute commands on remote systems via
SSH, with connection pooling and SSH key discovery. It seamlessly routes
commands to either local or remote execution based on the provided parameters.
"""

import asyncio
import logging
import os
import shlex
import subprocess
import time
from pathlib import Path
from typing import Optional, Tuple

import asyncssh

from ..audit import log_ssh_connect, log_ssh_command

logger = logging.getLogger(__name__)


def discover_ssh_key() -> Optional[str]:
    """
    Discover SSH private key for authentication.

    Checks in order:
    1. LINUX_MCP_SSH_KEY_PATH environment variable (a leading ``~`` is expanded)
    2. Default locations: ~/.ssh/id_ed25519, ~/.ssh/id_ecdsa, ~/.ssh/id_rsa

    If the environment variable is set but does not point at a regular file,
    the default locations are NOT consulted and None is returned.

    Returns:
        Path to SSH private key if found, None otherwise.
    """
    logger.debug("Discovering SSH key for authentication")

    # Check environment variable first
    env_key = os.getenv("LINUX_MCP_SSH_KEY_PATH")
    if env_key:
        logger.debug(f"Checking SSH key from environment: {env_key}")
        key_path = Path(env_key)
        try:
            # Environment values are not shell-expanded, so "~/.ssh/key"
            # would otherwise be looked up literally and never be found.
            key_path = key_path.expanduser()
        except RuntimeError:
            # Home directory could not be resolved; fall back to the literal
            # path rather than failing (this runs at import time).
            pass

        # is_file() is already False for a missing path.
        if key_path.is_file():
            logger.info(f"Using SSH key from environment: {env_key}")
            return str(key_path)

        logger.warning(f"SSH key specified in LINUX_MCP_SSH_KEY_PATH not found: {env_key}")
        return None

    # Check default locations (prefer modern algorithms)
    home = Path.home()
    default_keys = [
        home / ".ssh" / "id_ed25519",
        home / ".ssh" / "id_ecdsa",
        home / ".ssh" / "id_rsa",
    ]

    logger.debug(f"Checking default SSH key locations: {[str(k) for k in default_keys]}")

    for key_path in default_keys:
        if key_path.is_file():
            logger.info(f"Using SSH key: {key_path}")
            return str(key_path)

    logger.warning("No SSH private key found in default locations")
    return None


class SSHConnectionManager:
    """
    Manages SSH connections with connection pooling.

    This class implements a singleton pattern to maintain a pool of SSH
    connections across the lifetime of the application, improving performance
    by reusing connections to the same hosts. Connections are keyed by
    ``"username@host"``.
    """

    _instance = None

    def __new__(cls):
        """Implement singleton pattern.

        The pool and the discovered SSH key are initialised once, when the
        first instance is created.
        """
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._connections = {}
            cls._instance._ssh_key = discover_ssh_key()
        return cls._instance

    async def get_connection(self, host: str, username: str) -> asyncssh.SSHClientConnection:
        """
        Get or create an SSH connection to a host.

        Args:
            host: Remote host address
            username: SSH username

        Returns:
            SSH connection object

        Raises:
            ConnectionError: If authentication fails or the connection cannot
                be established (SSH-level or network-level failure)
        """
        key = f"{username}@{host}"

        # Return existing connection if available
        if key in self._connections:
            conn = self._connections[key]
            if not conn.is_closed():
                # DEBUG level: Log connection reuse and pool state
                logger.debug(f"SSH_REUSE: {key} | pool_size={len(self._connections)}")
                # Use audit log with connection reuse info
                log_ssh_connect(host, username, status="success", reused=True, key_path=self._ssh_key)
                return conn

            # Connection was closed, remove it
            logger.debug(f"SSH_POOL: remove_closed_connection | connection={key}")
            del self._connections[key]

        # Create new connection
        # DEBUG level: Log connection attempt before it completes
        logger.debug(f"SSH_CONNECTING: {key} | key={self._ssh_key or 'none'}")

        connect_kwargs = {
            "host": host,
            "username": username,
            # Don't verify host keys for now.
            # NOTE(security): known_hosts=None disables host key validation,
            # so the remote end is not authenticated.
            "known_hosts": None,
        }
        if self._ssh_key:
            connect_kwargs["client_keys"] = [self._ssh_key]

        # NOTE(review): nothing serialises concurrent callers for the same
        # key; two tasks that both miss the pool will each open a connection
        # and the later one replaces the earlier one in the pool.
        try:
            conn = await asyncssh.connect(**connect_kwargs)
        except asyncssh.PermissionDenied as e:
            # Use audit log for authentication failure
            error_msg = str(e)
            log_ssh_connect(host, username, status="failed", error=f"Permission denied: {error_msg}")
            raise ConnectionError(f"Authentication failed for {username}@{host}") from e
        except (OSError, asyncssh.Error) as e:
            # Network-level failures (DNS lookup, refused/unreachable host)
            # surface as OSError rather than asyncssh.Error; handle both so
            # every failed attempt is audited and reported as ConnectionError.
            error_msg = str(e)
            log_ssh_connect(host, username, status="failed", error=error_msg)
            raise ConnectionError(f"Failed to connect to {username}@{host}: {e}") from e

        self._connections[key] = conn

        # Log successful connection using audit function
        log_ssh_connect(host, username, status="success", reused=False, key_path=self._ssh_key)
        # DEBUG level: Log pool state
        logger.debug(f"SSH_POOL: add_connection | connections={len(self._connections)}")

        return conn

    async def execute_remote(
        self,
        command: list[str],
        host: str,
        username: str
    ) -> Tuple[int, str, str]:
        """
        Execute a command on a remote host via SSH.

        Args:
            command: Command and arguments to execute
            host: Remote host address
            username: SSH username

        Returns:
            Tuple of (return_code, stdout, stderr). When the remote command
            was terminated by a signal the return code is the negated signal
            number; when the server reported neither an exit status nor a
            signal it is 0.

        Raises:
            ConnectionError: If SSH connection fails
        """
        conn = await self.get_connection(host, username)

        # Build command string with proper shell escaping
        # Use shlex.quote() to ensure special characters (like \n in printf format) are preserved
        cmd_str = " ".join(shlex.quote(arg) for arg in command)

        # Start timing for command execution
        start_time = time.time()

        try:
            result = await conn.run(cmd_str, check=False)

            return_code = result.exit_status
            if return_code is None:
                # No exit status: the command did not exit normally. Use the
                # subprocess-style returncode (negative signal number) so a
                # killed command is not reported as success.
                return_code = result.returncode if result.returncode is not None else 0

            stdout = result.stdout if result.stdout else ""
            stderr = result.stderr if result.stderr else ""

            # Calculate duration
            duration = time.time() - start_time

            # Use audit log for command execution
            log_ssh_command(cmd_str, host, exit_code=return_code, duration=duration)

            return return_code, stdout, stderr
        except asyncssh.Error as e:
            duration = time.time() - start_time
            logger.error(f"Error executing command on {username}@{host}: {e}", extra={
                'event': 'REMOTE_EXEC_ERROR',
                'command': cmd_str,
                'host': host,
                'duration': f"{duration:.3f}s",
                'error': str(e)
            })
            raise ConnectionError(f"Failed to execute command on {username}@{host}: {e}") from e

    async def close_all(self) -> None:
        """Close all SSH connections.

        Closing is best-effort: a failure on one connection is logged and the
        remaining connections are still closed. The pool is emptied afterwards.
        """
        connection_count = len(self._connections)
        logger.info(f"Closing {connection_count} SSH connections")

        # Iterate over a snapshot; the awaits below yield to other tasks.
        for key, conn in list(self._connections.items()):
            try:
                logger.debug(f"SSH_CLOSE: {key}")
                conn.close()
                await conn.wait_closed()
            except Exception as e:
                logger.warning(f"Error closing connection to {key}: {e}")

        self._connections.clear()
        logger.debug(f"SSH_POOL: cleared | closed_connections={connection_count}")


# Global connection manager instance
_connection_manager = SSHConnectionManager()


async def execute_command(
    command: list[str],
    host: Optional[str] = None,
    username: Optional[str] = None,
    **kwargs
) -> Tuple[int, str, str]:
    """
    Execute a command locally or remotely.

    This is the main entry point for command execution. It routes the command
    to either local subprocess execution or remote SSH execution based on
    whether host/username parameters are provided.

    Args:
        command: Command and arguments to execute
        host: Optional remote host address
        username: Optional SSH username (required if host is provided)
        **kwargs: Additional arguments (reserved for future use)

    Returns:
        Tuple of (return_code, stdout, stderr)

    Raises:
        ValueError: If host is provided without username
        ConnectionError: If remote connection fails

    Examples:
        # Local execution
        >>> returncode, stdout, stderr = await execute_command(["ls", "-la"])

        # Remote execution
        >>> returncode, stdout, stderr = await execute_command(
        ...     ["ls", "-la"],
        ...     host="server.example.com",
        ...     username="admin"
        ... )
    """
    # Unquoted join: used for log messages only, never executed.
    cmd_str = " ".join(command)

    # Route to remote execution if host is provided
    if host:
        if not username:
            logger.error(f"Host provided without username for command: {cmd_str}")
            raise ValueError("username is required when host is provided")

        logger.debug(f"Routing to remote execution: {username}@{host} | command={cmd_str}")
        return await _connection_manager.execute_remote(command, host, username)

    # Local execution
    logger.debug(f"LOCAL_EXEC: {cmd_str}")
    return await _execute_local(command)


async def _execute_local(command: list[str]) -> Tuple[int, str, str]:
    """
    Execute a command locally using subprocess.

    The command is run without a shell. Failures to start or run the process
    are not raised; they are logged and reported as return code 1 with the
    error text in stderr.

    Args:
        command: Command and arguments to execute

    Returns:
        Tuple of (return_code, stdout, stderr)
    """
    # Unquoted join: used for log messages only, never executed.
    cmd_str = " ".join(command)
    start_time = time.time()

    try:
        proc = await asyncio.create_subprocess_exec(
            *command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )

        stdout_bytes, stderr_bytes = await proc.communicate()

        return_code = proc.returncode if proc.returncode is not None else 0
        # Undecodable bytes are replaced rather than raising.
        stdout = stdout_bytes.decode('utf-8', errors='replace')
        stderr = stderr_bytes.decode('utf-8', errors='replace')

        duration = time.time() - start_time

        # DEBUG level: Log local command execution with timing
        logger.debug(f"LOCAL_EXEC completed: {cmd_str} | exit_code={return_code} | duration={duration:.3f}s")

        return return_code, stdout, stderr
    except Exception as e:
        duration = time.time() - start_time
        logger.error(f"Error executing local command: {cmd_str}", extra={
            'event': 'LOCAL_EXEC_ERROR',
            'command': cmd_str,
            'duration': f"{duration:.3f}s",
            'error': str(e)
        })
        return 1, "", str(e)

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/narmaku/linux-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.