import asyncio
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastmcp import Context
from fastmcp.exceptions import ToolError
from pydantic import BaseModel, Field
from keboola_mcp_server.clients.client import KeboolaClient
from keboola_mcp_server.config import Config, ServerRuntimeInfo
from keboola_mcp_server.mcp import (
AggregateError,
ServerState,
SessionStateMiddleware,
ToolsFilteringMiddleware,
_exclude_none_serializer,
_filter_toon_nulls,
process_concurrently,
toon_serializer,
unwrap_results,
)
class SimpleModel(BaseModel):
    """Flat Pydantic model with only optional primitive fields; used to exercise the serializers."""

    field1: str | None = None
    # The serialization alias lets the tests observe whether a serializer honors aliases or not.
    field2: int | None = Field(default=None, serialization_alias='field2_alias')
    field3: datetime | None = None
class NestedModel(BaseModel):
    """Pydantic model with a list-valued field; used to exercise non-tabular TOON output."""

    field1: str | None = None
    field2: list[str] | None = None
def _tool(name: str) -> MagicMock:
tool = MagicMock()
tool.name = name
return tool
async def _async_square(n: int) -> int:
"""Simple async function that squares a number after a short delay."""
await asyncio.sleep(0.01) # Simulate some async work
return n * n
async def _async_fail(n: int) -> None:
"""Simple async function that always raises an exception."""
await asyncio.sleep(0.01)
raise ValueError(f'Failed for {n}')
async def _async_square_or_fail(n: int) -> int:
"""Async function that squares even numbers and fails for odd numbers."""
await asyncio.sleep(0.01)
if n % 2 == 0:
return n * n
else:
raise ValueError(f'Failed for odd number {n}')
@pytest.mark.parametrize(
    ('data', 'expected'),
    [
        (None, ''),
        # Exclude none values from a single model
        (SimpleModel(field1='value1'), '{"field1":"value1"}'),
        # Exclude none values from a list of models
        (
            [SimpleModel(field1='value1', field2=None), SimpleModel(field2=123)],
            '[{"field1":"value1"},{"field2":123}]',
        ),
        # Exclude none values from a dictionary with models
        (
            {'key1': SimpleModel(field1='value1'), 'key2': None, 'key3': SimpleModel(field2=456)},
            '{"key1":{"field1":"value1"},"key3":{"field2":456}}',
        ),
        # Exclude none values from primitives
        ({'key1': 123, 'key2': None, 'key3': 'value'}, '{"key1":123,"key3":"value"}'),
        # Exclude none values with nested structures
        (
            {'key1': [SimpleModel(field1='value1'), None], 'key2': {'nested_key': SimpleModel(field2=789)}},
            '{"key1":[{"field1":"value1"}],"key2":{"nested_key":{"field2":789}}}',
        ),
        # Datetimes serialize to ISO-8601, preserving the timezone offset when present
        (
            {
                'key1': [
                    SimpleModel(field3=datetime(2025, 2, 3, 10, 11, 12, tzinfo=timezone(timedelta(hours=2)))),
                    None,
                ],
                'key2': {'nested_key': SimpleModel(field2=789)},
                'key3': datetime(2025, 1, 1, 1, 2, 3),
            },
            '{"key1":[{"field3":"2025-02-03T10:11:12+02:00"}],'
            '"key2":{"nested_key":{"field2":789}},'
            '"key3":"2025-01-01T01:02:03"}',
        ),
    ],
)
def test_exclude_none_serializer(data, expected):
    """_exclude_none_serializer drops None values at every nesting level and emits compact JSON.

    NOTE(review): the expected outputs use 'field2' rather than its serialization alias,
    which suggests the serializer does not apply aliases — confirm against its implementation.
    """
    result = _exclude_none_serializer(data)
    assert result == expected
@pytest.mark.parametrize(
    ('data', 'expected'),
    [
        # Top-level None
        (None, 'null'),
        # Empty dict
        ({}, ''),
        # Empty list
        ([], '[0]:'),
        # Empty tuple
        ((), '[0]:'),
        # Empty set
        (set(), '[0]:'),
        # Datetime
        (
            datetime(2025, 1, 1),
            '"2025-01-01T00:00:00"',
        ),
        # Simple dictionary
        (
            {'key': 'value', 'none_key': None},
            'key: value\nnone_key: null',
        ),
        # List
        (
            ['item1', 'item2'],
            '[2]: item1,item2',
        ),
        # Mixed types in a list
        (
            ['a', 1, True, None],
            '[4]: a,1,true,null',
        ),
        # Tuple
        (
            (1, 2, 3),
            '[3]: 1,2,3',
        ),
        # Nested dictionary
        (
            {'a': {'b': 1}},
            'a:\n b: 1',
        ),
        # Deeply nested None
        (
            {'a': {'b': None}},
            'a:\n b: null',
        ),
        # Model with some None values - toon_serializer includes None and does NOT use aliases
        (
            SimpleModel(field1='value1', field2=123),
            'field1: value1\nfield2: 123\nfield3: null',
        ),
        # Simple model (only has primitive fields) in a list
        (
            [SimpleModel(field1='value1', field2=123), SimpleModel(field1='value2', field2=456)],
            '[2]{field1,field2,field3}:\n value1,123,null\n value2,456,null',
        ),
        # Nested model (has a list field) in a list - this disables the tabular view
        (
            [
                NestedModel(field1='value1', field2=['item1', 'item2']),
                NestedModel(field1='value2', field2=['item3', 'item4']),
            ],
            '[2]:\n'
            ' - field1: value1\n'
            ' field2[2]: item1,item2\n'
            ' - field1: value2\n'
            ' field2[2]: item3,item4',
        ),
        # Complex structure with models, lists, dicts, and None
        (
            {
                'users': [
                    {'name': 'Alice', 'active': True},
                    {'name': 'Bob', 'active': None},
                ],
                'meta': SimpleModel(field1='test'),
            },
            'users[2]{name,active}:\n'
            ' Alice,true\n'
            ' Bob,null\n'
            'meta:\n'
            ' field1: test\n'
            ' field2: null\n'
            ' field3: null',
        ),
    ],
)
def test_toon_serializer(data, expected):
    """toon_serializer renders Python data and Pydantic models in TOON text format.

    Unlike _exclude_none_serializer, None values are kept (rendered as 'null') and
    homogeneous lists of flat models are rendered in a tabular '{col,...}' layout.
    """
    result = toon_serializer(data)
    assert result == expected
def test_filter_toon_nulls_single_item_list() -> None:
    """A single-item list has its None values stripped recursively."""
    result = _filter_toon_nulls([{'a': 1, 'b': None, 'c': {'d': None, 'e': 2}}])
    assert result == [{'a': 1, 'c': {'e': 2}}]
def test_filter_toon_nulls_multi_item_list_preserves_alignment() -> None:
    """In a multi-item list, a key that is non-None in any item stays (as None) in all items."""
    rows = [{'a': 1, 'b': None}, {'a': None, 'b': 2}]
    # The expected value equals the input: neither key is None everywhere, so nothing is dropped.
    assert _filter_toon_nulls(rows) == [{'a': 1, 'b': None}, {'a': None, 'b': 2}]
def test_filter_toon_nulls_multi_item_list_preserves_key_order() -> None:
    """Keys dropped only when None in every item; first-seen key order is preserved."""
    first = {'b': 1, 'd': None, 'a': None}
    second = {'a': 2, 'b': None, 'c': 3, 'd': None, 'e': None}
    filtered = _filter_toon_nulls([first, second])
    # 'd' and 'e' are None in every item and disappear; 'a' and 'c' survive via the other item.
    assert filtered == [{'b': 1, 'a': None, 'c': None}, {'b': None, 'a': 2, 'c': 3}]
    # Key order follows first appearance across the items: 'b', 'd'(dropped), 'a', then 'c'.
    assert list(filtered[0]) == ['b', 'a', 'c']
@pytest.mark.parametrize(
    ('data', 'expected'),
    [
        ({}, {}),
        ([], []),
        # None items inside plain (non-dict) lists are preserved as-is.
        (['a', None, 1], ['a', None, 1]),
        ({'a': None, 'b': 2}, {'b': 2}),
        ({'a': {'b': None, 'c': 3}}, {'a': {'c': 3}}),
        # In a list of dicts, a key is dropped only when it is None in EVERY dict.
        ([{'a': None, 'b': None}, {'a': 1, 'b': None}], [{'a': None}, {'a': 1}]),
        ([{'a': None}, {'b': None}], [{}, {}]),
        ([{'a': {'b': None}, 'c': 1}], [{'a': {}, 'c': 1}]),
        # Test that _filter_toon_nulls applies recursively to lists nested inside dicts
        (
            [
                {'a': 1, 'b': [None, 2, 3]},
                {'a': None, 'b': [4, None]},
            ],
            [
                {'a': 1, 'b': [None, 2, 3]},
                {'a': None, 'b': [4, None]},
            ],
        ),
        # Test with deeper nesting for key 'b'
        (
            [
                {'a': 1, 'b': [{'c': None, 'd': 2}, {'c': None, 'd': None}]},
                {'a': 2, 'b': [{'c': None, 'd': None}, {'c': None, 'd': None}]},
            ],
            [
                {'a': 1, 'b': [{'d': 2}, {'d': None}]},
                {'a': 2, 'b': [{}, {}]},
            ],
        ),
    ],
)
def test_filter_toon_nulls_edge_cases(data, expected) -> None:
    """Edge cases for _filter_toon_nulls across empty, flat, and deeply nested structures."""
    assert _filter_toon_nulls(data) == expected
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ('items', 'afunc', 'max_concurrency', 'expected_successes', 'expected_exceptions'),
    [
        # All succeed
        (list(range(5)), _async_square, 2, [0, 1, 4, 9, 16], []),
        # Mixed success and failure (odd numbers fail)
        (list(range(5)), _async_square_or_fail, 3, [0, 4, 16], ['Failed for odd number 1', 'Failed for odd number 3']),
        # All fail
        (list(range(3)), _async_fail, 2, [], ['Failed for 0', 'Failed for 1', 'Failed for 2']),
        # Empty input
        ([], _async_square, 5, [], []),
    ],
    ids=['all_succeed', 'mixed_success_failure', 'all_fail', 'empty_input'],
)
async def test_process_concurrently(items, afunc, max_concurrency, expected_successes, expected_exceptions):
    """Test process_concurrently with various scenarios."""
    results = await process_concurrently(items, afunc, max_concurrency=max_concurrency)
    # Exactly one result (value or exception) is expected per input item.
    assert len(results) == len(items)
    # Successes are sorted so the assertion does not depend on result ordering;
    # exception messages are compared in order, which presumably follows input order — verify.
    successes = sorted([r for r in results if not isinstance(r, BaseException)])
    exceptions = [str(e) for e in results if isinstance(e, BaseException)]
    assert successes == expected_successes
    assert exceptions == expected_exceptions
@pytest.mark.asyncio
async def test_process_concurrently_respects_max_concurrency():
    """Test that max_concurrency limits simultaneous executions."""
    max_concurrency = 3
    current_running = 0
    peak_running = 0
    lock = asyncio.Lock()

    async def track_concurrency(n: int) -> int:
        # Track how many coroutines are inside this function at once;
        # the lock keeps the two counters consistent across awaits.
        nonlocal current_running, peak_running
        async with lock:
            current_running += 1
            peak_running = max(peak_running, current_running)
        try:
            await asyncio.sleep(0.01)
            return n * n
        finally:
            # Always decrement, even if the sleep is cancelled or fails.
            async with lock:
                current_running -= 1

    results = await process_concurrently(list(range(10)), track_concurrency, max_concurrency=max_concurrency)
    assert sorted(results) == [i * i for i in range(10)]
    # The observed peak must never exceed the configured limit.
    assert peak_running <= max_concurrency
@pytest.mark.asyncio
@pytest.mark.parametrize('max_concurrency', [0, -1, -10])  # zero and negatives are all invalid
async def test_process_concurrently_invalid_max_concurrency(max_concurrency):
    """Test that process_concurrently raises ValueError for invalid max_concurrency."""
    with pytest.raises(ValueError, match='max_concurrency must be a positive integer'):
        await process_concurrently([1, 2, 3], _async_square, max_concurrency=max_concurrency)
@pytest.mark.parametrize(
    ('results', 'expected'),
    [
        # All successes
        ([1, 2, 3], [1, 2, 3]),
        # Empty list
        ([], []),
        # Single success
        (['value'], ['value']),
    ],
    ids=['all_successes', 'empty', 'single_success'],
)
def test_unwrap_results_success(results, expected):
    """unwrap_results returns the successes unchanged when no exceptions are present."""
    assert unwrap_results(results) == expected
def test_unwrap_results_raises_aggregate_error():
    """Test unwrap_results raises AggregateError when exceptions are present."""
    exc1 = ValueError('error 1')
    exc2 = RuntimeError('error 2')
    results: list[int | BaseException] = [1, exc1, 2, exc2, 3]
    with pytest.raises(AggregateError) as exc_info:
        unwrap_results(results, 'Test errors')
    err = exc_info.value
    # The custom message is kept, the exceptions retain input order (successes discarded),
    # and str() summarizes the count plus each exception's type and message.
    assert err.message == 'Test errors'
    assert err.exceptions == [exc1, exc2]
    assert str(err) == 'Test errors (2 errors): ValueError: error 1; RuntimeError: error 2'
def test_unwrap_results_all_exceptions():
    """Test unwrap_results when all results are exceptions."""
    exc1 = ValueError('error 1')
    exc2 = ValueError('error 2')
    results: list[int | BaseException] = [exc1, exc2]
    with pytest.raises(AggregateError) as exc_info:
        unwrap_results(results)
    err = exc_info.value
    assert err.exceptions == [exc1, exc2]
    # Without an explicit message argument, the default 'Multiple errors occurred' prefix is used.
    assert str(err) == 'Multiple errors occurred (2 errors): ValueError: error 1; ValueError: error 2'
class TestToolsFilteringMiddleware:
    """Tests for ToolsFilteringMiddleware, which hides or blocks data-app tools on dev branches."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        ('branch_id', 'expect_filtered'),
        [
            ('1234', True),
            (None, False),
        ],
    )
    async def test_list_tools_filters_data_apps_by_branch(
        self,
        mcp_context_client,
        branch_id: str | None,
        expect_filtered: bool,
    ) -> None:
        """Data-app tools are removed from the tools/list result when a dev branch is set."""
        keboola_client = KeboolaClient.from_state(mcp_context_client.session.state)
        keboola_client.branch_id = branch_id
        # Token verification is mocked with an empty feature set so the middleware's
        # feature checks do not interfere with the branch-based filtering under test.
        keboola_client.storage_client.verify_token = AsyncMock(return_value={'owner': {'features': []}, 'admin': {}})
        tools = [_tool('modify_data_app'), _tool('get_data_apps'), _tool('deploy_data_app'), _tool('other_tool')]

        async def call_next(_):
            # Downstream handler returns the full, unfiltered tool list.
            return tools

        middleware = ToolsFilteringMiddleware()
        context = SimpleNamespace(fastmcp_context=mcp_context_client)
        result = await middleware.on_list_tools(context, call_next)
        result_names = {t.name for t in result}
        if expect_filtered:
            assert 'modify_data_app' not in result_names
            assert 'get_data_apps' not in result_names
            assert 'deploy_data_app' not in result_names
        else:
            assert 'modify_data_app' in result_names
            assert 'get_data_apps' in result_names
            assert 'deploy_data_app' in result_names
        # Tools unrelated to data apps must never be filtered out.
        assert 'other_tool' in result_names

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        ('branch_id', 'expect_error'),
        [
            ('5678', True),
            (None, False),
        ],
    )
    async def test_call_tool_blocks_data_apps_by_branch(
        self,
        mcp_context_client,
        branch_id: str | None,
        expect_error: bool,
    ) -> None:
        """Calling a data-app tool raises ToolError on a dev branch; passes through otherwise."""
        keboola_client = KeboolaClient.from_state(mcp_context_client.session.state)
        keboola_client.branch_id = branch_id
        keboola_client.storage_client.verify_token = AsyncMock(return_value={'owner': {'features': []}, 'admin': {}})
        tool = _tool('modify_data_app')
        # The middleware presumably looks the tool up on the server before deciding — verify.
        mcp_context_client.fastmcp = SimpleNamespace(get_tool=AsyncMock(return_value=tool))
        context = SimpleNamespace(fastmcp_context=mcp_context_client, message=SimpleNamespace(name='modify_data_app'))
        expected = MagicMock()

        async def call_next(_):
            return expected

        middleware = ToolsFilteringMiddleware()
        if expect_error:
            with pytest.raises(ToolError, match='Data apps are supported only in the main production branch'):
                await middleware.on_call_tool(context, call_next)
        else:
            # On the production branch the call is forwarded and the downstream result returned.
            result = await middleware.on_call_tool(context, call_next)
            assert result is expected
class TestSessionStateMiddleware:
    """Tests for SessionStateMiddleware's per-request handling of the configured branch_id."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        ('method', 'expected_branch_id'),
        [
            ('tools/list', None),
            ('resources/list', None),
            ('prompts/list', None),
            ('tools/call', '999'),
            ('resources/read', '999'),
        ],
        ids=['tools_list', 'resources_list', 'prompts_list', 'tools_call', 'resources_read'],
    )
    async def test_on_request_branch_handling(self, method: str, expected_branch_id: str | None):
        """Listing methods drop the configured branch_id; call/read methods keep it."""
        config = Config(
            storage_api_url='https://connection.test.keboola.com',
            storage_token='test-token',
            branch_id='999',
        )
        runtime_info = ServerRuntimeInfo(transport='stdio')
        server_state = ServerState(config=config, runtime_info=runtime_info)
        # Use a non-MagicMock session so the middleware enters the branch-handling code path
        session = SimpleNamespace(state={})
        # ctx must pass isinstance(ctx, Context) check, so we use MagicMock(spec=Context).
        # However ctx.session must NOT be a MagicMock (line 146 guard), so we override it.
        ctx = MagicMock(spec=Context)
        ctx.session = session
        ctx.request_context.lifespan_context = server_state
        context = SimpleNamespace(method=method, fastmcp_context=ctx)
        expected_result = object()

        async def call_next(_):
            return expected_result

        # Capture the Config instances the middleware passes to create_session_state.
        captured_configs: list[Config] = []

        async def fake_create_session_state(cfg, _runtime_info, readonly=None):
            captured_configs.append(cfg)
            return {'fake': 'state'}

        middleware = SessionStateMiddleware()
        with (
            patch.object(middleware, 'create_session_state', side_effect=fake_create_session_state),
            # No HTTP request available: the middleware must fall back to the lifespan config.
            patch('keboola_mcp_server.mcp.get_http_request_or_none', return_value=None),
        ):
            result = await middleware.on_request(context, call_next)
        assert result is expected_result
        # create_session_state is invoked exactly once, with branch_id adjusted per method.
        assert len(captured_configs) == 1
        assert captured_configs[0].branch_id == expected_branch_id