test_input_validation.py•8.3 kB
"""Tests for input validation in MCP server tools."""
from mcp_ssh import mcp_server
def test_validate_alias_valid():
    """Test that valid aliases pass validation."""
    # Valid aliases
    valid, msg = mcp_server._validate_alias("web1")
    assert valid is True
    assert msg == ""
    valid, msg = mcp_server._validate_alias("prod-web-1")
    assert valid is True
    valid, msg = mcp_server._validate_alias("db_server")
    assert valid is True
    valid, msg = mcp_server._validate_alias("server.example.com")
    assert valid is True
def test_validate_alias_empty():
    """Test that empty aliases are rejected."""
    valid, msg = mcp_server._validate_alias("")
    assert valid is False
    assert "required" in msg.lower()
    valid, msg = mcp_server._validate_alias("   ")
    assert valid is False
    assert "required" in msg.lower()
def test_validate_alias_too_long():
    """Test that aliases exceeding length limit are rejected."""
    long_alias = "a" * (mcp_server.MAX_ALIAS_LENGTH + 1)
    valid, msg = mcp_server._validate_alias(long_alias)
    assert valid is False
    assert "too long" in msg.lower()
    assert str(mcp_server.MAX_ALIAS_LENGTH) in msg
def test_validate_alias_invalid_characters():
    """Test that aliases with invalid characters are rejected."""
    # Special characters not allowed
    invalid_aliases = [
        "web@1",
        "web#1",
        "web$1",
        "web/1",
        "web\\1",
        "web 1",
        "web(1)",
        "web[1]",
        "web{1}",
        "web|1",
        "web&1",
    ]
    for alias in invalid_aliases:
        valid, msg = mcp_server._validate_alias(alias)
        assert valid is False, f"Alias '{alias}' should be rejected"
        assert "invalid characters" in msg.lower()
def test_validate_alias_at_limit():
    """Test that aliases at exactly the length limit are accepted."""
    alias_at_limit = "a" * mcp_server.MAX_ALIAS_LENGTH
    valid, msg = mcp_server._validate_alias(alias_at_limit)
    assert valid is True
def test_validate_command_valid():
    """Test that valid commands pass validation."""
    valid, msg = mcp_server._validate_command("uptime")
    assert valid is True
    assert msg == ""
    valid, msg = mcp_server._validate_command("df -h")
    assert valid is True
    valid, msg = mcp_server._validate_command("echo 'hello world'")
    assert valid is True
    # Commands with newlines (legitimate use case)
    valid, msg = mcp_server._validate_command("echo 'line1'\necho 'line2'")
    assert valid is True
    # Commands with tabs (legitimate use case)
    valid, msg = mcp_server._validate_command("echo 'line1'\techo 'line2'")
    assert valid is True
def test_validate_command_empty():
    """Test that empty commands are rejected."""
    valid, msg = mcp_server._validate_command("")
    assert valid is False
    assert "required" in msg.lower()
    valid, msg = mcp_server._validate_command("   ")
    assert valid is False
    assert "required" in msg.lower()
def test_validate_command_too_long():
    """Test that commands exceeding length limit are rejected."""
    long_command = "a" * (mcp_server.MAX_COMMAND_LENGTH + 1)
    valid, msg = mcp_server._validate_command(long_command)
    assert valid is False
    assert "too long" in msg.lower()
    assert str(mcp_server.MAX_COMMAND_LENGTH) in msg
def test_validate_command_at_limit():
    """Test that commands at exactly the length limit are accepted."""
    command_at_limit = "a" * mcp_server.MAX_COMMAND_LENGTH
    valid, msg = mcp_server._validate_command(command_at_limit)
    assert valid is True
def test_validate_command_null_bytes():
    """Test that commands with null bytes are rejected."""
    command_with_null = "echo\x00hello"
    valid, msg = mcp_server._validate_command(command_with_null)
    assert valid is False
    assert "null bytes" in msg.lower()
def test_validate_command_control_characters():
    """Test that commands with control characters are rejected."""
    # Control characters except newline, tab, carriage return
    invalid_commands = [
        "echo\x01hello",  # SOH
        "echo\x02hello",  # STX
        "echo\x03hello",  # ETX
        "echo\x04hello",  # EOT
        "echo\x05hello",  # ENQ
        "echo\x1bhello",  # ESC
    ]
    for cmd in invalid_commands:
        valid, msg = mcp_server._validate_command(cmd)
        assert (
            valid is False
        ), f"Command with control char should be rejected: {repr(cmd)}"
        assert "control characters" in msg.lower()
def test_validate_command_allowed_control_characters():
    """Test that newline, tab, and carriage return are allowed."""
    # These should be allowed for legitimate use
    valid, msg = mcp_server._validate_command("echo 'line1'\necho 'line2'")
    assert valid is True
    valid, msg = mcp_server._validate_command("echo 'line1'\techo 'line2'")
    assert valid is True
    valid, msg = mcp_server._validate_command("echo 'line1'\recho 'line2'")
    assert valid is True
def test_validate_tag_valid():
    """Test that valid tags pass validation."""
    valid, msg = mcp_server._validate_tag("web")
    assert valid is True
    assert msg == ""
    valid, msg = mcp_server._validate_tag("prod-web")
    assert valid is True
    valid, msg = mcp_server._validate_tag("db_server")
    assert valid is True
    valid, msg = mcp_server._validate_tag("server.example")
    assert valid is True
def test_validate_tag_empty():
    """Test that empty tags are rejected."""
    valid, msg = mcp_server._validate_tag("")
    assert valid is False
    assert "required" in msg.lower()
    valid, msg = mcp_server._validate_tag("   ")
    assert valid is False
    assert "required" in msg.lower()
def test_validate_tag_too_long():
    """Test that tags exceeding length limit are rejected."""
    long_tag = "a" * (mcp_server.MAX_TAG_LENGTH + 1)
    valid, msg = mcp_server._validate_tag(long_tag)
    assert valid is False
    assert "too long" in msg.lower()
    assert str(mcp_server.MAX_TAG_LENGTH) in msg
def test_validate_tag_invalid_characters():
    """Test that tags with invalid characters are rejected."""
    invalid_tags = ["web@1", "web#1", "web/1", "web 1"]
    for tag in invalid_tags:
        valid, msg = mcp_server._validate_tag(tag)
        assert valid is False, f"Tag '{tag}' should be rejected"
        assert "invalid characters" in msg.lower()
def test_validate_task_id_valid():
    """Test that valid task IDs pass validation."""
    valid, msg = mcp_server._validate_task_id("web1:abc123:1234567890")
    assert valid is True
    assert msg == ""
    valid, msg = mcp_server._validate_task_id("prod-web-1:def456:9876543210")
    assert valid is True
def test_validate_task_id_empty():
    """Test that empty task IDs are rejected."""
    valid, msg = mcp_server._validate_task_id("")
    assert valid is False
    assert "required" in msg.lower()
    valid, msg = mcp_server._validate_task_id("   ")
    assert valid is False
    assert "required" in msg.lower()
def test_validate_task_id_too_long():
    """Test that task IDs exceeding length limit are rejected."""
    long_task_id = "a" * (mcp_server.MAX_TASK_ID_LENGTH + 1)
    valid, msg = mcp_server._validate_task_id(long_task_id)
    assert valid is False
    assert "too long" in msg.lower()
    assert str(mcp_server.MAX_TASK_ID_LENGTH) in msg
def test_validate_task_id_invalid_characters():
    """Test that task IDs with invalid characters are rejected."""
    invalid_task_ids = [
        "web1@abc123:1234567890",
        "web1#abc123:1234567890",
        "web1/abc123:1234567890",
        "web1 abc123:1234567890",
    ]
    for task_id in invalid_task_ids:
        valid, msg = mcp_server._validate_task_id(task_id)
        assert valid is False, f"Task ID '{task_id}' should be rejected"
        assert "invalid characters" in msg.lower()
def test_validate_task_id_allowed_characters():
    """Test that task IDs with allowed characters pass validation."""
    # Colon is allowed (format: alias:hash:timestamp)
    valid, msg = mcp_server._validate_task_id("web1:abc123:1234567890")
    assert valid is True
    # Dash and underscore are allowed
    valid, msg = mcp_server._validate_task_id("prod-web_1:abc123:1234567890")
    assert valid is True