# test_server_tools.py (6.82 kB)
"""Tests for MCP server tools."""
import json
import os
import tempfile
import pytest
import yaml
from mcp_ssh import mcp_server
from mcp_ssh.config import Config
@pytest.fixture
def mock_config():
    """Install a temporary ``Config`` as ``mcp_server.config`` for one test.

    Writes ``servers.yml``, ``credentials.yml`` and ``policy.yml`` into a
    temporary directory, builds a ``Config`` from that directory and assigns
    it to ``mcp_server.config``.

    Yields:
        The ``Config`` instance built from the temporary directory.

    On teardown the previous value of ``mcp_server.config`` is put back, so
    the module is not left pointing at a config whose backing directory has
    already been removed.
    """
    servers = {
        "hosts": [
            {
                "alias": "test1",
                "host": "10.0.0.1",
                "port": 22,
                "credentials": "cred1",
                "tags": ["web"],
            },
            {
                "alias": "test2",
                "host": "10.0.0.2",
                "port": 22,
                "credentials": "cred1",
                "tags": ["db", "prod"],
            },
        ]
    }
    credentials = {
        "entries": [
            {"name": "cred1", "username": "user1", "key_path": "id_ed25519"},
        ]
    }
    policy = {
        "limits": {"max_seconds": 60},
        "rules": [
            {
                "action": "allow",
                "aliases": ["*"],
                "tags": [],
                "commands": ["uptime*"],
            },
        ],
    }
    documents = {
        "servers.yml": servers,
        "credentials.yml": credentials,
        "policy.yml": policy,
    }

    # Sentinel distinguishes "attribute absent" from "attribute set to None".
    missing = object()
    previous = getattr(mcp_server, "config", missing)

    with tempfile.TemporaryDirectory() as tmpdir:
        for filename, payload in documents.items():
            with open(os.path.join(tmpdir, filename), "w", encoding="utf-8") as f:
                yaml.dump(payload, f)

        # Replace global config
        config = Config(config_dir=tmpdir)
        mcp_server.config = config
        try:
            yield config
        finally:
            # Undo the global replacement even if the test body failed.
            if previous is missing:
                vars(mcp_server).pop("config", None)
            else:
                mcp_server.config = previous
def test_ssh_ping():
    """Check that the ping tool answers with the literal string "pong"."""
    assert mcp_server.ssh_ping() == "pong"
def test_ssh_list_hosts(mock_config):
    """Check that list_hosts returns a JSON list holding both fixture aliases."""
    listed = json.loads(mcp_server.ssh_list_hosts())
    assert isinstance(listed, list)
    assert len(listed) == 2
    for alias in ("test1", "test2"):
        assert alias in listed
def test_ssh_describe_host(mock_config):
    """Check that describe_host returns the JSON record for alias "test1"."""
    described = json.loads(mcp_server.ssh_describe_host(alias="test1"))
    expected_fields = {"alias": "test1", "host": "10.0.0.1", "port": 22}
    for field, expected in expected_fields.items():
        assert described[field] == expected
def test_ssh_describe_host_not_found(mock_config):
    """Check that describing an unknown alias yields a reply containing "Error"."""
    reply = mcp_server.ssh_describe_host(alias="nonexistent")
    assert "Error" in reply
def test_ssh_plan(mock_config):
    """Check that the plan tool echoes alias and command and reports policy keys."""
    raw_plan = mcp_server.ssh_plan(alias="test1", command="uptime")
    parsed = json.loads(raw_plan)
    assert parsed["alias"] == "test1"
    assert parsed["command"] == "uptime"
    # Only the presence of these keys is checked, not their values.
    for key in ("allowed", "limits"):
        assert key in parsed
def test_ssh_reload_config(mock_config):
    """Check that reload_config reports "reloaded" (case-insensitive)."""
    message = mcp_server.ssh_reload_config().lower()
    assert "reloaded" in message
def test_ssh_cancel_not_found():
    """Check that cancelling an unknown task id reports "not found"."""
    message = mcp_server.ssh_cancel(task_id="nonexistent").lower()
    assert "not found" in message
def test_ssh_cancel_no_task_id():
    """Check that cancelling with an empty task id reports it as "required"."""
    message = mcp_server.ssh_cancel(task_id="").lower()
    assert "required" in message
# Note: Testing ssh_run and ssh_run_on_tag requires mocking SSH connections,
# which is complex and better suited for integration tests with a real SSH server.
# The core logic is tested through the individual component tests (policy, ssh_client, config).
def test_default_parameters():
    """Check that tools can be invoked with no arguments at all.

    Only ``ssh_describe_host`` and ``ssh_plan`` are exercised here: the former
    must return the exact "alias is required" error, the latter must return
    something other than an empty string.
    """
    # Per instructions: all params should default to empty strings
    describe_reply = mcp_server.ssh_describe_host()
    assert describe_reply == "Error: alias is required"

    # An error reply is expected; the point is that the call does not raise.
    plan_reply = mcp_server.ssh_plan()
    assert plan_reply != ""
def test_ssh_run_async_invalid_alias(mock_config):
    """Check that ssh_run_async rejects an alias absent from the config.

    The ``mock_config`` fixture is requested so the lookup runs against the
    known test hosts rather than whatever global config happens to be active.
    """
    result = mcp_server.ssh_run_async(alias="nonexistent", command="uptime")
    lowered = result.lower()
    # A case-insensitive match already covers the capitalised "Error" form.
    assert "error" in lowered
    assert "not found" in lowered
def test_ssh_run_async_no_alias():
    """Check that ssh_run_async with an empty alias returns a "required" error."""
    reply = mcp_server.ssh_run_async(alias="", command="uptime")
    assert "Error" in reply
    assert "required" in reply.lower()
def test_ssh_run_async_no_command(mock_config):
    """Check that ssh_run_async with an empty command returns a "required" error.

    The alias "test1" is defined only by the ``mock_config`` fixture, so the
    fixture is requested explicitly instead of relying on configuration left
    behind by tests that ran earlier.
    """
    result = mcp_server.ssh_run_async(alias="test1", command="")
    assert "Error" in result
    assert "required" in result.lower()
def test_ssh_get_task_status_invalid_task():
    """Check that a status query for an unknown task id returns a "not found" error."""
    reply = mcp_server.ssh_get_task_status(task_id="invalid:task:id")
    assert "Error" in reply
    assert "not found" in reply.lower()
def test_ssh_get_task_status_no_task_id():
    """Check that a status query with an empty task id returns a "required" error."""
    reply = mcp_server.ssh_get_task_status(task_id="")
    assert "Error" in reply
    assert "required" in reply.lower()
def test_ssh_get_task_result_invalid_task():
    """Check that a result query for an unknown task id returns a "not found" error."""
    reply = mcp_server.ssh_get_task_result(task_id="invalid:task:id")
    assert "Error" in reply
    assert "not found" in reply.lower()
def test_ssh_get_task_result_no_task_id():
    """Check that a result query with an empty task id returns a "required" error."""
    reply = mcp_server.ssh_get_task_result(task_id="")
    assert "Error" in reply
    assert "required" in reply.lower()
def test_ssh_get_task_output_invalid_task():
    """Check that an output query for an unknown task id returns a "not found" error."""
    reply = mcp_server.ssh_get_task_output(task_id="invalid:task:id")
    assert "Error" in reply
    assert "not found" in reply.lower()
def test_ssh_get_task_output_no_task_id():
    """Check that an output query with an empty task id returns a "required" error."""
    reply = mcp_server.ssh_get_task_output(task_id="")
    assert "Error" in reply
    assert "required" in reply.lower()
def test_ssh_cancel_async_task_invalid_task():
    """Check that cancelling an unknown async task yields "Error" or "not found"."""
    reply = mcp_server.ssh_cancel_async_task(task_id="invalid:task:id")
    # Either marker is accepted: a capitalised "Error" or a "not found" phrase.
    accepted = "Error" in reply or "not found" in reply.lower()
    assert accepted
def test_ssh_cancel_async_task_no_task_id():
    """Check that cancelling an async task with an empty id returns a "required" error."""
    reply = mcp_server.ssh_cancel_async_task(task_id="")
    assert "Error" in reply
    assert "required" in reply.lower()
# Note: Testing actual async task execution requires mocking SSH connections,
# which is complex and better suited for integration tests with a real SSH server.
# The core async task management logic is tested through the AsyncTaskManager tests.