"""MCP server exposing policy-governed SSH execution tools (mcp_server.py)."""
import json
import re
import time
from mcp.server.fastmcp import FastMCP
from mcp_ssh.config import Config
from mcp_ssh.policy import Policy
from mcp_ssh.ssh_client import SSHClient
from mcp_ssh.tools.utilities import (
ASYNC_TASKS,
TASKS,
hash_command,
log_json,
sanitize_error,
)
mcp = FastMCP()
config = Config()
# Set up notification callback for async tasks
def _send_task_notification(event_type: str, task_id: str, data: dict):
"""Send MCP notification for task events."""
try:
mcp.send_notification(f"tasks/{event_type}", {"task_id": task_id, **data})
except Exception as e:
log_json({"level": "warn", "msg": "notification_failed", "error": str(e)})
# Configure async task manager with notification callback
ASYNC_TASKS.set_notification_callback(_send_task_notification)
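# Illustrative notification shape produced by the callback above. The concrete
# event names are whatever the async task manager passes as event_type; the
# values shown here are assumptions for illustration, not a fixed contract:
#
#   method: "tasks/completed"
#   params: {"task_id": "web-01:ab12cd34:1717000000", "exit_code": 0}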
def _client_for(alias: str, limits: dict, require_known_host: bool) -> SSHClient:
"""Build SSH client from config and limits."""
host = config.get_host(alias)
creds_ref = host.get("credentials", "")
creds = config.get_credentials(creds_ref) if creds_ref else {}
known_hosts_path = (config.get_policy() or {}).get("known_hosts_path", "")
auto_add = bool(limits.get("host_key_auto_add", False))
# Input validation
hostname = host.get("host", "").strip()
if not hostname:
raise ValueError(f"Host '{alias}' has no hostname configured")
username = creds.get("username", "").strip()
if not username:
if creds_ref:
raise ValueError(
f"Host '{alias}' references credentials '{creds_ref}' but no username found"
)
else:
raise ValueError(
f"Host '{alias}' has no credentials reference and no username configured"
)
port = host.get("port", 22)
    # Coerce to int first; range-check outside the try so the specific
    # out-of-range message is not re-wrapped by the generic handler.
    try:
        port = int(port)
    except (ValueError, TypeError) as e:
        raise ValueError(f"Invalid port '{port}' for host '{alias}'") from e
    if not (1 <= port <= 65535):
        raise ValueError(f"Invalid port {port} for host '{alias}'")
# Validate authentication method
key_path = creds.get("key_path", "").strip()
password = creds.get("password", "").strip()
if not key_path and not password:
raise ValueError(
f"Host '{alias}' has no authentication method configured (neither key_path nor password)"
)
return SSHClient(
host=hostname,
username=username,
port=port,
        key_path=key_path,
        password=password,
passphrase=creds.get("passphrase", ""),
known_hosts_path=known_hosts_path,
auto_add_host_keys=auto_add,
require_known_host=require_known_host,
)
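# Illustrative config shape consumed by _client_for(). The field names mirror
# the accessors used above; the alias, credential names, and file layout are
# assumptions for illustration only:
#
#   hosts:
#     web-01:
#       host: 203.0.113.10
#       port: 22
#       credentials: web_deploy
#       tags: [web, prod]
#   credentials:
#     web_deploy:
#       username: deploy
#       key_path: ~/.ssh/id_ed25519
#       passphrase: ""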
# Input validation constants
MAX_ALIAS_LENGTH = 100
MAX_COMMAND_LENGTH = 10000
MAX_TAG_LENGTH = 50
MAX_TASK_ID_LENGTH = 200
def _validate_alias(alias: str) -> tuple[bool, str]:
"""Validate alias parameter.
Security: Validates alias format to prevent injection attacks.
- Length limit: 100 characters
- Allowed characters: alphanumeric, dash, underscore, dot
- Cannot be empty
Args:
alias: Alias string to validate
Returns:
Tuple of (is_valid, error_message)
If valid: (True, "")
If invalid: (False, error_message)
"""
if not alias or not alias.strip():
return False, "alias is required"
alias = alias.strip()
# Length validation
if len(alias) > MAX_ALIAS_LENGTH:
return False, f"alias too long (max {MAX_ALIAS_LENGTH} characters)"
# Character validation: alphanumeric, dash, underscore, dot only
if not re.match(r"^[a-zA-Z0-9._-]+$", alias):
return (
False,
"alias contains invalid characters (only alphanumeric, dot, dash, underscore allowed)",
)
return True, ""
def _validate_command(command: str) -> tuple[bool, str]:
"""Validate command parameter.
Security: Validates command format to prevent injection and resource exhaustion.
- Length limit: 10000 characters
- Rejects null bytes
- Rejects control characters (except newline, tab, carriage return)
Args:
command: Command string to validate
Returns:
Tuple of (is_valid, error_message)
If valid: (True, "")
If invalid: (False, error_message)
"""
if not command or not command.strip():
return False, "command is required"
command = command.strip()
# Length validation
if len(command) > MAX_COMMAND_LENGTH:
return False, f"command too long (max {MAX_COMMAND_LENGTH} characters)"
# Null byte validation (common injection vector)
if "\x00" in command:
log_json(
{
"level": "error",
"msg": "security_event",
"type": "null_byte_injection_attempt",
"field": "command",
}
)
return False, "command contains invalid characters (null bytes not allowed)"
# Control character validation (allow newline, tab, carriage return for legitimate use)
# Reject other control characters
for char in command:
if ord(char) < 32 and char not in ["\n", "\t", "\r"]:
log_json(
{
"level": "error",
"msg": "security_event",
"type": "control_character_injection_attempt",
"field": "command",
"char_code": ord(char),
}
)
return False, "command contains invalid control characters"
return True, ""
def _validate_tag(tag: str) -> tuple[bool, str]:
"""Validate tag parameter.
Security: Validates tag format to prevent injection attacks.
- Length limit: 50 characters
- Allowed characters: alphanumeric, dash, underscore, dot
- Cannot be empty
Args:
tag: Tag string to validate
Returns:
Tuple of (is_valid, error_message)
If valid: (True, "")
If invalid: (False, error_message)
"""
if not tag or not tag.strip():
return False, "tag is required"
tag = tag.strip()
# Length validation
if len(tag) > MAX_TAG_LENGTH:
return False, f"tag too long (max {MAX_TAG_LENGTH} characters)"
# Character validation: alphanumeric, dash, underscore, dot only
if not re.match(r"^[a-zA-Z0-9._-]+$", tag):
return (
False,
"tag contains invalid characters (only alphanumeric, dot, dash, underscore allowed)",
)
return True, ""
def _validate_task_id(task_id: str) -> tuple[bool, str]:
"""Validate task_id parameter.
Security: Validates task_id format.
- Length limit: 200 characters
- Format validation (expected: alias:hash:timestamp)
- Cannot be empty
Args:
task_id: Task ID string to validate
Returns:
Tuple of (is_valid, error_message)
If valid: (True, "")
If invalid: (False, error_message)
"""
if not task_id or not task_id.strip():
return False, "task_id is required"
task_id = task_id.strip()
# Length validation
if len(task_id) > MAX_TASK_ID_LENGTH:
return False, f"task_id too long (max {MAX_TASK_ID_LENGTH} characters)"
# Format validation: should match pattern alias:hash:timestamp
# Allow alphanumeric, colon, dash, underscore
if not re.match(r"^[a-zA-Z0-9:_-]+$", task_id):
return (
False,
"task_id contains invalid characters (only alphanumeric, colon, dash, underscore allowed)",
)
return True, ""
def _precheck_network(pol: Policy, hostname: str) -> tuple[bool, str]:
"""Resolve hostname and verify at least one resolved IP is allowed."""
ips = SSHClient.resolve_ips(hostname)
if not ips:
# No resolution: fail closed to be safe.
return False, "DNS resolution failed"
for ip in ips:
if pol.is_ip_allowed(ip):
return True, ""
return False, "No resolved IPs allowed by policy.network"
@mcp.tool()
def ssh_ping() -> str:
"""Health check."""
return "pong"
@mcp.tool()
def ssh_list_hosts() -> str:
"""List configured hosts."""
try:
hosts = config.list_hosts()
return json.dumps(hosts)
except Exception as e:
error_str = str(e)
log_json({"level": "error", "msg": "list_hosts_exception", "error": error_str})
return f"Error: {sanitize_error(error_str)}"
@mcp.tool()
def ssh_describe_host(alias: str = "") -> str:
"""Return host definition in JSON."""
try:
# Input validation
valid, error_msg = _validate_alias(alias)
if not valid:
return f"Error: {error_msg}"
host = config.get_host(alias)
return json.dumps(host, indent=2)
except Exception as e:
error_str = str(e)
log_json(
{"level": "error", "msg": "describe_host_exception", "error": error_str}
)
return f"Error: {sanitize_error(error_str)}"
@mcp.tool()
def ssh_plan(alias: str = "", command: str = "") -> str:
"""Show what would be executed and if policy allows."""
try:
# Input validation
valid, error_msg = _validate_alias(alias)
if not valid:
return f"Error: {error_msg}"
valid, error_msg = _validate_command(command)
if not valid:
return f"Error: {error_msg}"
cmd_hash = hash_command(command)
tags = config.get_host_tags(alias)
pol = Policy(config.get_policy())
allowed = pol.is_allowed(alias, tags, command)
limits = pol.limits_for(alias, tags)
preview = {
"alias": alias,
"command": command,
"hash": cmd_hash,
"allowed": allowed,
"limits": {
"max_seconds": limits.get("max_seconds", 60),
"max_output_bytes": limits.get("max_output_bytes", 1024 * 1024),
"host_key_auto_add": bool(limits.get("host_key_auto_add", False)),
"require_known_host": bool(limits.get("require_known_host", True)),
},
}
return json.dumps(preview, indent=2)
except Exception as e:
error_str = str(e)
log_json({"level": "error", "msg": "plan_exception", "error": error_str})
return f"Error: {sanitize_error(error_str)}"
@mcp.tool()
def ssh_run(alias: str = "", command: str = "") -> str:
"""Execute SSH command with policy, network checks, progress, timeout, and cancellation."""
start = time.time()
try:
# Input validation
valid, error_msg = _validate_alias(alias)
if not valid:
return f"Error: {error_msg}"
valid, error_msg = _validate_command(command)
if not valid:
return f"Error: {error_msg}"
# Normalize after validation
alias = alias.strip()
command = command.strip()
host = config.get_host(alias)
hostname = host.get("host", "")
cmd_hash = hash_command(command)
tags = config.get_host_tags(alias)
pol = Policy(config.get_policy())
# Command policy
allowed = pol.is_allowed(alias, tags, command)
pol.log_decision(alias, cmd_hash, allowed)
if not allowed:
return f"Denied by policy: {command}"
# Network precheck (DNS -> allowlist)
ok, reason = _precheck_network(pol, hostname)
if not ok:
return f"Denied by network policy: {reason}"
limits = pol.limits_for(alias, tags)
max_seconds = int(limits.get("max_seconds", 60))
max_output_bytes = int(limits.get("max_output_bytes", 1024 * 1024))
require_known_host = bool(
limits.get("require_known_host", pol.require_known_host())
)
task_id = TASKS.create(alias, cmd_hash)
def progress_cb(phase, bytes_read, elapsed_ms):
pol.log_progress(task_id, phase, int(bytes_read), int(elapsed_ms))
client = _client_for(alias, limits, require_known_host)
cancel_event = TASKS.get_event(task_id)
(
exit_code,
duration_ms,
cancelled,
timeout,
bytes_out,
bytes_err,
combined,
peer_ip,
) = client.run_streaming(
command=command,
cancel_event=cancel_event,
max_seconds=max_seconds,
max_output_bytes=max_output_bytes,
progress_cb=progress_cb,
)
TASKS.cleanup(task_id)
# Post-connect enforcement: ensure actual peer IP is allowed
if peer_ip and not pol.is_ip_allowed(peer_ip):
pol.log_audit(
alias,
cmd_hash,
int(exit_code),
int(duration_ms),
int(bytes_out),
int(bytes_err),
bool(cancelled),
bool(timeout),
peer_ip,
)
return f"Denied by network policy: peer IP {peer_ip} not allowed"
pol.log_audit(
alias,
cmd_hash,
int(exit_code),
int(duration_ms),
int(bytes_out),
int(bytes_err),
bool(cancelled),
bool(timeout),
peer_ip,
)
result = {
"task_id": task_id,
"alias": alias,
"hash": cmd_hash,
"exit_code": int(exit_code),
"duration_ms": int(duration_ms),
"cancelled": bool(cancelled),
"timeout": bool(timeout),
"target_ip": peer_ip,
"output": combined,
}
return json.dumps(result, indent=2)
except Exception as e:
error_str = str(e)
log_json({"level": "error", "msg": "run_exception", "error": error_str})
return f"Run error: {sanitize_error(error_str)}"
finally:
elapsed = int((time.time() - start) * 1000)
log_json({"type": "trace", "op": "run_done", "elapsed_ms": elapsed})
@mcp.tool()
def ssh_run_on_tag(tag: str = "", command: str = "") -> str:
"""Execute SSH command on all hosts with a tag (with network checks)."""
try:
# Input validation
valid, error_msg = _validate_tag(tag)
if not valid:
return f"Error: {error_msg}"
valid, error_msg = _validate_command(command)
if not valid:
return f"Error: {error_msg}"
# Normalize after validation
tag = tag.strip()
command = command.strip()
aliases = config.find_hosts_by_tag(tag)
if not aliases:
return json.dumps(
{"tag": tag, "results": [], "note": "No hosts matched."}, indent=2
)
results = []
for alias in aliases:
host = config.get_host(alias)
hostname = host.get("host", "")
cmd_hash = hash_command(command)
tags = config.get_host_tags(alias)
pol = Policy(config.get_policy())
# Command policy
allowed = pol.is_allowed(alias, tags, command)
pol.log_decision(alias, cmd_hash, allowed)
if not allowed:
results.append(
{
"alias": alias,
"hash": cmd_hash,
"denied": True,
"reason": "policy",
}
)
continue
# Network precheck
ok, reason = _precheck_network(pol, hostname)
if not ok:
results.append(
{
"alias": alias,
"hash": cmd_hash,
"denied": True,
"reason": f"network: {reason}",
}
)
continue
limits = pol.limits_for(alias, tags)
max_seconds = int(limits.get("max_seconds", 60))
max_output_bytes = int(limits.get("max_output_bytes", 1024 * 1024))
require_known_host = bool(
limits.get("require_known_host", pol.require_known_host())
)
task_id = TASKS.create(alias, cmd_hash)
def progress_cb(
phase, bytes_read, elapsed_ms, pol_ref=pol, task_ref=task_id
):
pol_ref.log_progress(task_ref, phase, int(bytes_read), int(elapsed_ms))
client = _client_for(alias, limits, require_known_host)
cancel_event = TASKS.get_event(task_id)
(
exit_code,
duration_ms,
cancelled,
timeout,
bytes_out,
bytes_err,
combined,
peer_ip,
) = client.run_streaming(
command=command,
cancel_event=cancel_event,
max_seconds=max_seconds,
max_output_bytes=max_output_bytes,
progress_cb=progress_cb,
)
TASKS.cleanup(task_id)
# Post-connect enforcement
if peer_ip and not pol.is_ip_allowed(peer_ip):
pol.log_audit(
alias,
cmd_hash,
int(exit_code),
int(duration_ms),
int(bytes_out),
int(bytes_err),
bool(cancelled),
bool(timeout),
peer_ip,
)
results.append(
{
"alias": alias,
"task_id": task_id,
"hash": cmd_hash,
"denied": True,
"reason": f"network: peer {peer_ip} not allowed",
}
)
continue
pol.log_audit(
alias,
cmd_hash,
int(exit_code),
int(duration_ms),
int(bytes_out),
int(bytes_err),
bool(cancelled),
bool(timeout),
peer_ip,
)
results.append(
{
"alias": alias,
"task_id": task_id,
"hash": cmd_hash,
"exit_code": int(exit_code),
"duration_ms": int(duration_ms),
"cancelled": bool(cancelled),
"timeout": bool(timeout),
"target_ip": peer_ip,
"output": combined,
}
)
return json.dumps({"tag": tag, "results": results}, indent=2)
except Exception as e:
error_str = str(e)
log_json({"level": "error", "msg": "run_on_tag_exception", "error": error_str})
return f"Run on tag error: {sanitize_error(error_str)}"
@mcp.tool()
def ssh_cancel(task_id: str = "") -> str:
"""Request cancellation for a running task."""
try:
# Input validation
valid, error_msg = _validate_task_id(task_id)
if not valid:
return f"Error: {error_msg}"
task_id = task_id.strip()
ok = TASKS.cancel(task_id)
if ok:
return f"Cancellation signaled for task_id: {task_id}"
return f"Task not found: {task_id}"
except Exception as e:
error_str = str(e)
log_json({"level": "error", "msg": "cancel_exception", "error": error_str})
return f"Cancel error: {sanitize_error(error_str)}"
@mcp.tool()
def ssh_reload_config() -> str:
"""Reload configuration files."""
try:
config.reload()
return "Configuration reloaded."
except Exception as e:
error_str = str(e)
log_json({"level": "error", "msg": "reload_exception", "error": error_str})
return f"Reload error: {sanitize_error(error_str)}"
@mcp.tool()
def ssh_run_async(alias: str = "", command: str = "") -> str:
"""Start SSH command asynchronously (SEP-1686 compliant).
Returns immediately with task_id for polling. Use ssh_get_task_status
and ssh_get_task_result to monitor and retrieve results.
"""
try:
# Input validation
valid, error_msg = _validate_alias(alias)
if not valid:
return f"Error: {error_msg}"
valid, error_msg = _validate_command(command)
if not valid:
return f"Error: {error_msg}"
# Normalize after validation
alias = alias.strip()
command = command.strip()
host = config.get_host(alias)
hostname = host.get("host", "")
cmd_hash = hash_command(command)
tags = config.get_host_tags(alias)
pol = Policy(config.get_policy())
# Command policy
allowed = pol.is_allowed(alias, tags, command)
pol.log_decision(alias, cmd_hash, allowed)
if not allowed:
return f"Denied by policy: {command}"
# Network precheck (DNS -> allowlist)
ok, reason = _precheck_network(pol, hostname)
if not ok:
return f"Denied by network policy: {reason}"
limits = pol.limits_for(alias, tags)
require_known_host = bool(
limits.get("require_known_host", pol.require_known_host())
)
# Create SSH client
client = _client_for(alias, limits, require_known_host)
# Enhanced progress callback for async tasks
def progress_cb(phase, bytes_read, elapsed_ms):
pol.log_progress(
f"async:{alias}:{cmd_hash}", phase, int(bytes_read), int(elapsed_ms)
)
# Start async task
task_id = ASYNC_TASKS.start_async_task(
alias=alias,
command=command,
ssh_client=client,
limits=limits,
progress_cb=progress_cb,
)
# Return SEP-1686 compliant response
result = {
"task_id": task_id,
"status": "pending",
"keepAlive": int(limits.get("task_result_ttl", 300)),
"pollFrequency": int(limits.get("task_progress_interval", 5)),
"alias": alias,
"command": command,
"hash": cmd_hash,
}
return json.dumps(result, indent=2)
except Exception as e:
error_str = str(e)
log_json({"level": "error", "msg": "async_run_exception", "error": error_str})
return f"Async run error: {sanitize_error(error_str)}"
@mcp.tool()
def ssh_get_task_status(task_id: str = "") -> str:
"""Get current status of an async task (SEP-1686 compliant).
Returns task state, progress, elapsed time, and output summary.
"""
try:
# Input validation
valid, error_msg = _validate_task_id(task_id)
if not valid:
return f"Error: {error_msg}"
task_id = task_id.strip()
status = ASYNC_TASKS.get_task_status(task_id)
if not status:
return f"Error: Task not found: {task_id}"
return json.dumps(status, indent=2)
except Exception as e:
error_str = str(e)
log_json({"level": "error", "msg": "status_exception", "error": error_str})
return f"Status error: {sanitize_error(error_str)}"
@mcp.tool()
def ssh_get_task_result(task_id: str = "") -> str:
"""Get final result of completed task (SEP-1686 compliant).
Returns complete output, exit code, and execution metadata.
"""
try:
# Input validation
valid, error_msg = _validate_task_id(task_id)
if not valid:
return f"Error: {error_msg}"
task_id = task_id.strip()
result = ASYNC_TASKS.get_task_result(task_id)
if not result:
return f"Error: Task not found or expired: {task_id}"
return json.dumps(result, indent=2)
except Exception as e:
error_str = str(e)
log_json({"level": "error", "msg": "result_exception", "error": error_str})
return f"Result error: {sanitize_error(error_str)}"
@mcp.tool()
def ssh_get_task_output(task_id: str = "", max_lines: int = 50) -> str:
"""Get recent output lines from running or completed task.
Enhanced beyond SEP-1686: enables streaming output visibility.
"""
try:
# Input validation
valid, error_msg = _validate_task_id(task_id)
if not valid:
return f"Error: {error_msg}"
if max_lines < 1 or max_lines > 1000:
return "Error: max_lines must be between 1 and 1000"
task_id = task_id.strip()
output = ASYNC_TASKS.get_task_output(task_id, max_lines)
if not output:
return f"Error: Task not found or no output available: {task_id}"
return json.dumps(output, indent=2)
except Exception as e:
error_str = str(e)
log_json({"level": "error", "msg": "output_exception", "error": error_str})
return f"Output error: {sanitize_error(error_str)}"
@mcp.tool()
def ssh_cancel_async_task(task_id: str = "") -> str:
"""Cancel a running async task."""
try:
# Input validation
valid, error_msg = _validate_task_id(task_id)
if not valid:
return f"Error: {error_msg}"
task_id = task_id.strip()
success = ASYNC_TASKS.cancel_task(task_id)
if success:
return f"Cancellation signaled for async task: {task_id}"
else:
return f"Task not found or not cancellable: {task_id}"
except Exception as e:
error_str = str(e)
log_json(
{"level": "error", "msg": "cancel_async_exception", "error": error_str}
)
return f"Cancel error: {sanitize_error(error_str)}"
def main():
"""Main entry point for MCP server."""
mcp.run(transport="stdio")
if __name__ == "__main__":
main()