# test_mcp_server.py (11.5 kB)
import json
import logging
import os
import random
import subprocess
import time
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Iterable, Literal
import pytest
from fastmcp import Client
from fastmcp.client import SSETransport, StdioTransport, StreamableHttpTransport
from mcp.types import TextContent
from integtests.conftest import (
    DEV_STORAGE_API_URL_ENV_VAR,
    DEV_STORAGE_TOKEN_ENV_VAR,
    DEV_WORKSPACE_SCHEMA_ENV_VAR,
    ConfigDef,
)
from keboola_mcp_server.tools.components.model import Configuration
from keboola_mcp_server.tools.project import ProjectInfo
# Module-level logger used by the tests and helpers in this file.
LOG = logging.getLogger(__name__)
# Names of the HTTP-based server transports exercised by these tests. For 'http-compat' the tests expect
# the server to be reachable on both the `/sse` and the `/mcp` endpoints (see `_run_server_remote`).
HttpTransportStr = Literal['sse', 'streamable-http', 'http-compat']
@pytest.mark.asyncio
async def test_stdio_setup(
    configs: list[ConfigDef],
    storage_api_token: str,
    workspace_schema: str,
    storage_api_url: str,
):
    """
    Start the server over the stdio transport, configured only through command line arguments,
    and run the basic setup checks plus one `get_config` tool call against it.
    """
    assert storage_api_token is not None
    assert workspace_schema is not None
    assert storage_api_url is not None

    server_args = ['-m', 'keboola_mcp_server']
    server_args += ['--api-url', storage_api_url]
    server_args += ['--storage-token', storage_api_token]
    server_args += ['--workspace-schema', workspace_schema]
    # an empty env makes sure that no env vars are passed from the test environment
    stdio_transport = StdioTransport(command='python', args=server_args, env={})

    first_config = configs[0]
    async with Client(stdio_transport) as mcp_client:
        await _assert_basic_setup(mcp_client)
        await _assert_get_component_details_tool_call(mcp_client, first_config)
@pytest.mark.asyncio
@pytest.mark.parametrize('transport', ['sse', 'streamable-http', 'http-compat'])
async def test_remote_setup(
    transport: HttpTransportStr,
    configs: list[ConfigDef],
    storage_api_token: str,
    workspace_schema: str,
    storage_api_url: str,
):
    """
    Start the server with the given HTTP transport and run the basic setup checks plus one `get_config`
    tool call against every URL that the server is expected to listen on.
    """
    assert storage_api_token is not None
    assert workspace_schema is not None
    assert storage_api_url is not None

    first_config = configs[0]
    for server_url in _run_server_remote(storage_api_url, transport):
        # the storage token and the workspace schema are handed to the server in the request headers
        request_headers = dict(storage_token=storage_api_token, workspace_schema=workspace_schema)
        async with _run_client(server_url, request_headers) as mcp_client:
            await _assert_basic_setup(mcp_client)
            await _assert_get_component_details_tool_call(mcp_client, first_config)
@pytest.mark.asyncio
async def test_http_multiple_clients(
    configs: list[ConfigDef],
    storage_api_token: str,
    workspace_schema: str,
    storage_api_url: str,
):
    """
    Connect three clients with identical headers to one streamable-http server at the same time
    and run the basic setup checks and one `get_config` tool call through each of them.
    """
    transport: HttpTransportStr = 'streamable-http'
    first_config = configs[0]
    for server_url in _run_server_remote(storage_api_url, transport):
        shared_headers = dict(
            storage_token=storage_api_token,
            workspace_schema=workspace_schema,
            storage_api_url=storage_api_url,
        )
        async with (
            _run_client(server_url, shared_headers) as client_1,
            _run_client(server_url, shared_headers) as client_2,
            _run_client(server_url, shared_headers) as client_3,
        ):
            connected_clients = (client_1, client_2, client_3)
            # first the basic checks on all clients, only then the tool calls, one client after another
            for mcp_client in connected_clients:
                await _assert_basic_setup(mcp_client)
            for mcp_client in connected_clients:
                await _assert_get_component_details_tool_call(mcp_client, first_config)
@pytest.mark.asyncio
async def test_http_multiple_clients_with_different_headers(
    storage_api_url: str,
    storage_api_token: str,
    workspace_schema: str,
    storage_api_token_2: str | None,
    workspace_schema_2: str | None,
):
    """
    Test that the server can handle multiple clients with different headers and checks the values of the headers.

    Each client calls the `get_project_info` tool and the returned project ID is compared with the leading
    `-`-separated segment of the storage token that the client sent in its headers.
    The test is skipped when the second token or the second workspace schema is not available.
    """
    if not storage_api_token_2 or not workspace_schema_2:
        pytest.skip('No SAPI token or workspace schema for the second client. Skipping test.')
    headers = {
        'client_1': {'storage_token': storage_api_token, 'workspace_schema': workspace_schema},
        'client_2': {'storage_token': storage_api_token_2, 'workspace_schema': workspace_schema_2},
    }
    transport: HttpTransportStr = 'streamable-http'
    for url in _run_server_remote(storage_api_url, transport):
        async with (
            _run_client(url, headers['client_1']) as client_1,
            _run_client(url, headers['client_2']) as client_2,
        ):
            await _assert_basic_setup(client_1)
            await _assert_basic_setup(client_2)

            response_1 = await client_1.call_tool('get_project_info')
            project_info_1 = ProjectInfo.model_validate(response_1.structured_content)
            LOG.info(f'project_info_1={project_info_1}')
            # Compare as strings: the token segment is a string, while the type of `project_id`
            # is declared outside of this file.
            assert str(project_info_1.project_id) == storage_api_token.split(sep='-')[0]

            response_2 = await client_2.call_tool('get_project_info')
            project_info_2 = ProjectInfo.model_validate(response_2.structured_content)
            LOG.info(f'project_info_2={project_info_2}')
            assert str(project_info_2.project_id) == storage_api_token_2.split(sep='-')[0]
async def _assert_basic_setup(client: Client):
    """
    Check that the server behind the client exposes exactly the expected tools, six prompts and no resources.

    :param client: The connected MCP client to run the checks with.
    """
    listed_tools = await client.list_tools()

    # the create_conditional_flow, create_flow and search tools may not be present based on the testing project,
    # so they are ignored on both sides of the comparison
    project_dependent_tools = {'create_conditional_flow', 'create_flow', 'search'}
    known_tools = {
        'add_config_row',
        'create_conditional_flow',
        'create_config',
        'create_flow',
        'create_oauth_url',
        'create_sql_transformation',
        'deploy_data_app',
        'docs_query',
        'find_component_id',
        'get_bucket',
        'get_component',
        'get_config',
        'get_config_examples',
        'get_data_apps',
        'get_flow',
        'get_flow_examples',
        'get_flow_schema',
        'get_job',
        'get_project_info',
        'get_table',
        'list_buckets',
        'list_configs',
        'list_flows',
        'list_jobs',
        'list_tables',
        'modify_data_app',
        'query_data',
        'run_job',
        'search',
        'update_config',
        'update_config_row',
        'update_descriptions',
        'update_flow',
        'update_sql_transformation',
    }
    expected_names = known_tools.difference(project_dependent_tools)
    actual_names = {listed.name for listed in listed_tools}.difference(project_dependent_tools)

    missing_tools = expected_names - actual_names
    assert not missing_tools, f'Missing tools: {missing_tools}'
    unexpected_tools = actual_names - expected_names
    assert not unexpected_tools, f'Unexpected new tools: {unexpected_tools}'

    listed_prompts = await client.list_prompts()
    assert len(listed_prompts) == 6

    # there are no resources exposed in the MCP server; just check that the call succeeds
    listed_resources = await client.list_resources()
    assert len(listed_resources) == 0
async def _assert_get_component_details_tool_call(client: Client, config: ConfigDef):
    """
    Call the `get_config` tool for the given configuration and check the returned configuration.

    :param client: The connected MCP client to call the tool with.
    :param config: The definition of the configuration to fetch; its `configuration_id` must be set.
    """
    assert config.configuration_id is not None

    tool_args = {'configuration_id': config.configuration_id, 'component_id': config.component_id}
    tool_result = await client.call_tool('get_config', tool_args)
    assert tool_result is not None
    # only one tool call is executed, so a single content item is expected
    assert len(tool_result.content) == 1
    text_item = tool_result.content[0]
    assert isinstance(text_item, TextContent)

    # the text content carries the configuration serialized as JSON
    parsed_config = Configuration.model_validate(json.loads(text_item.text))
    assert isinstance(parsed_config, Configuration)

    assert parsed_config.component is not None
    component = parsed_config.component
    assert component.component_id == config.component_id
    assert component.component_type is not None
    assert component.component_name is not None

    assert parsed_config.configuration_root is not None
    root = parsed_config.configuration_root
    assert root.configuration_id == config.configuration_id
    assert parsed_config.configuration_rows is None
def _run_server_remote(storage_api_url: str, transport: HttpTransportStr) -> Iterable[str]:
    """
    Run the server in a subprocess and yield the URL(s) on which it is expected to listen.

    This is a generator: the subprocess is started on the first iteration and it is terminated when
    the generator is exhausted or closed. The output of the server is logged after it has terminated.

    :param storage_api_url: The Storage API URL to use.
    :param transport: The transport to use.
    :return: The url of the remote server; 'http-compat' yields both the `/sse` and the `/mcp` URL.
    :raises ValueError: If the transport is not known; no subprocess is started in that case.
    :raises RuntimeError: If the server subprocess has already exited after the start-up wait.
    """
    # NOTE(review): a random port can collide with a port that is already in use; the check for
    # a prematurely exited server below surfaces that case, assuming the server exits on a bind failure.
    port = random.randint(8000, 9000)

    # Work out the URLs first so that an unknown transport fails before any process is spawned.
    urls: list[str] = []
    if transport in ['sse', 'http-compat']:
        urls.append(f'http://127.0.0.1:{port}/sse')
    if transport in ['streamable-http', 'http-compat']:
        urls.append(f'http://127.0.0.1:{port}/mcp')
    if not urls:
        raise ValueError(f'Unknown transport: {transport}')

    # NOTE(review): stdout/stderr are only drained at shutdown; a very chatty server could fill
    # the pipe buffers and block.
    p = subprocess.Popen(
        [
            'python',
            '-m',
            'keboola_mcp_server',
            '--transport',
            transport,
            '--api-url',
            storage_api_url,
            '--port',
            str(port),
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        # the server must not pick up the credentials from the test environment
        env={
            name: val
            for name, val in os.environ.items()
            if name not in [DEV_STORAGE_API_URL_ENV_VAR, DEV_STORAGE_TOKEN_ENV_VAR, DEV_WORKSPACE_SCHEMA_ENV_VAR]
        },
    )
    try:
        LOG.info(f'Running MCP server in subprocess with {transport} transport, listening on: {urls}')
        time.sleep(5)  # wait for the server to start
        if p.poll() is not None:
            # fail early with a clear message rather than with an obscure connection error in the client
            raise RuntimeError(f'MCP server subprocess exited prematurely with return code {p.returncode}.')
        yield from urls
    finally:
        LOG.info('Terminating MCP server subprocess.')
        p.terminate()
        try:
            stdout, stderr = p.communicate(timeout=10)
        except subprocess.TimeoutExpired:
            # the server did not react to the termination request; do not let the test run hang
            LOG.warning('MCP server subprocess did not terminate in time; killing it.')
            p.kill()
            stdout, stderr = p.communicate()
        LOG.info(f'-- MCP server stdout --\n{stdout}\n-- end stdout --')
        LOG.info(f'-- MCP server stderr --\n{stderr}\n-- end stderr --')
@asynccontextmanager
async def _run_client(url: str, headers: dict[str, str] | None = None) -> AsyncGenerator[Client, None]:
    """
    Provide a client connected to the remote server; the client is closed when the context is left.

    The client transport is derived from the url suffix: `/sse` selects the SSE transport and `/mcp`
    the streamable HTTP transport.

    :param url: The url of the remote server to which the client will be connected.
    :param headers: The headers to use for the client.
    :return: The Client connected to the remote server.
    :raises ValueError: If the url ends with neither `/sse` nor `/mcp`.
    """
    transport_by_suffix = {'/sse': SSETransport, '/mcp': StreamableHttpTransport}
    for suffix, transport_cls in transport_by_suffix.items():
        if url.endswith(suffix):
            transport = transport_cls(url=url, headers=headers)
            break
    else:
        raise ValueError(f'Unknown transport: {url}')

    mcp_client = Client(transport)
    captured_error = None
    LOG.info(f'Running MCP client connecting to {url} and expecting `{transport}` server transport.')
    try:
        async with mcp_client:
            try:
                yield mcp_client
            except Exception as err:
                # An exception raised by the user of the client is kept and re-raised only after the client
                # context has been left, otherwise only the task group error would be reported.
                LOG.error(f'Error in client TaskGroup: {err}')
                captured_error = err
    finally:
        del mcp_client
        if captured_error is not None:
            raise captured_error