# test_storage.py (11.7 kB)
import csv
import logging
import pytest
from fastmcp import Context
from integtests.conftest import BucketDef, TableDef
from keboola_mcp_server.clients.client import KeboolaClient, get_metadata_property
from keboola_mcp_server.config import MetadataField
from keboola_mcp_server.tools.storage import (
    BucketDetail,
    DescriptionUpdate,
    ListBucketsOutput,
    ListTablesOutput,
    TableDetail,
    UpdateDescriptionsOutput,
    get_bucket,
    get_table,
    list_buckets,
    list_tables,
    update_descriptions,
)
# Module-level logger named after this test module.
LOG = logging.getLogger(__name__)
@pytest.mark.asyncio
async def test_list_buckets(mcp_context: Context, buckets: list[BucketDef]):
    """`list_buckets` returns one `BucketDetail` per test bucket, with consistent per-stage counts."""
    output = await list_buckets(mcp_context)
    assert isinstance(output, ListBucketsOutput)
    assert all(isinstance(detail, BucketDetail) for detail in output.buckets)

    expected_total = len(buckets)
    assert len(output.buckets) == expected_total
    counts = output.bucket_counts
    assert counts.total_buckets == expected_total

    # BucketDef carries no stage information, so the per-stage expectations are
    # derived from the stages reported on the returned bucket details.
    stages = [detail.stage for detail in output.buckets]
    assert counts.input_buckets == stages.count('in')
    assert counts.output_buckets == stages.count('out')

    # The test expects every bucket to be either 'in' or 'out', so the two
    # per-stage counts must add up to the total.
    assert counts.input_buckets + counts.output_buckets == counts.total_buckets
@pytest.mark.asyncio
async def test_get_bucket(mcp_context: Context, buckets: list[BucketDef]):
    """`get_bucket` returns a `BucketDetail` whose id matches each requested test bucket."""
    for bucket_def in buckets:
        expected_id = bucket_def.bucket_id
        detail = await get_bucket(expected_id, mcp_context)
        assert isinstance(detail, BucketDetail)
        assert detail.id == expected_id
@pytest.mark.asyncio
async def test_get_table(mcp_context: Context, tables: list[TableDef]):
    """`get_table` returns a `TableDetail` whose id, name and column names match each test table."""
    for table_def in tables:
        # The CSV header row lists the column names the table detail should report.
        with table_def.file_path.open('r', encoding='utf-8') as csv_file:
            header = next(csv.reader(csv_file))
        expected_columns = frozenset(header)

        detail = await get_table(table_def.table_id, mcp_context)
        assert isinstance(detail, TableDetail)
        assert detail.id == table_def.table_id
        assert detail.name == table_def.table_name
        assert detail.columns is not None
        actual_columns = {column.name for column in detail.columns}
        assert actual_columns == expected_columns
@pytest.mark.asyncio
async def test_list_tables(mcp_context: Context, tables: list[TableDef], buckets: list[BucketDef]):
    """`list_tables` returns exactly the test tables that belong to each test bucket."""
    # Index the expected tables by the bucket they belong to.
    expected_by_bucket = {}
    for table_def in tables:
        expected_by_bucket.setdefault(table_def.bucket_id, []).append(table_def)

    for bucket_def in buckets:
        output = await list_tables(bucket_def.bucket_id, mcp_context)
        assert isinstance(output, ListTablesOutput)
        assert all(isinstance(detail, TableDetail) for detail in output.tables)

        # Buckets without any test table are expected to list no tables.
        expected = expected_by_bucket.get(bucket_def.bucket_id, [])
        assert len(output.tables) == len(expected)
        listed_ids = {detail.id for detail in output.tables}
        assert listed_ids == {expected_table.table_id for expected_table in expected}
@pytest.mark.asyncio
async def test_update_descriptions_bucket(mcp_context: Context, buckets: list[BucketDef]):
    """Updating one bucket description reports success and persists the value in bucket metadata.

    NOTE(review): unlike the mixed-types test, this test does not delete the description it
    writes; confirm the `buckets` fixture tolerates the leftover metadata.
    """
    target = buckets[0]
    new_description = 'New Description'
    client = KeboolaClient.from_state(mcp_context.session.state)
    update = DescriptionUpdate(item_id=target.bucket_id, description=new_description)
    output = await update_descriptions(ctx=mcp_context, updates=[update])

    assert isinstance(output, UpdateDescriptionsOutput)
    assert (output.total_processed, output.successful, output.failed) == (1, 1, 0)
    assert len(output.results) == 1
    item = output.results[0]
    assert item.item_id == target.bucket_id
    assert item.success is True
    assert item.error is None
    assert item.timestamp is not None

    # Read the bucket metadata back from Storage to confirm the value was persisted.
    stored = await client.storage_client.bucket_metadata_get(target.bucket_id)
    assert get_metadata_property(stored, MetadataField.DESCRIPTION) == new_description
@pytest.mark.asyncio
async def test_update_descriptions_table(mcp_context: Context, tables: list[TableDef]):
    """Updating one table description reports success and persists the value in table metadata.

    NOTE(review): unlike the mixed-types test, this test does not delete the description it
    writes; confirm the `tables` fixture tolerates the leftover metadata.
    """
    target = tables[0]
    new_description = 'New Table Description'
    client = KeboolaClient.from_state(mcp_context.session.state)
    update = DescriptionUpdate(item_id=target.table_id, description=new_description)
    output = await update_descriptions(ctx=mcp_context, updates=[update])

    assert isinstance(output, UpdateDescriptionsOutput)
    assert (output.total_processed, output.successful, output.failed) == (1, 1, 0)
    assert len(output.results) == 1
    item = output.results[0]
    assert item.item_id == target.table_id
    assert item.success is True
    assert item.error is None
    assert item.timestamp is not None

    # Read the table metadata back from Storage to confirm the value was persisted.
    stored = await client.storage_client.table_metadata_get(target.table_id)
    assert get_metadata_property(stored, MetadataField.DESCRIPTION) == new_description
@pytest.mark.asyncio
async def test_update_descriptions_table_column(mcp_context: Context, tables: list[TableDef]):
    """Tests that `update_descriptions` updates table column descriptions correctly.

    The column is addressed as ``<table_id>.<column_name>``. The new description is then checked
    through the column details returned by the `get_table` tool.
    """
    table = tables[0]
    # Use the first column from the CSV header the test table is associated with.
    with table.file_path.open('r', encoding='utf-8') as f:
        reader = csv.reader(f)
        columns = next(reader)
    column_name = columns[0]
    column_id = f'{table.table_id}.{column_name}'
    result = await update_descriptions(
        ctx=mcp_context,
        updates=[DescriptionUpdate(item_id=column_id, description='New Table Column Description')],
    )
    assert isinstance(result, UpdateDescriptionsOutput)
    assert result.total_processed == 1
    assert result.successful == 1
    assert result.failed == 0
    assert len(result.results) == 1
    column_result = result.results[0]
    assert column_result.item_id == column_id
    assert column_result.success is True
    assert column_result.error is None
    assert column_result.timestamp is not None
    # Verify the description is available in the table detail
    table_detail = await get_table(table.table_id, mcp_context)
    assert table_detail.columns is not None
    column_detail = next((col for col in table_detail.columns if col.name == column_name), None)
    assert column_detail is not None
    assert column_detail.description == 'New Table Column Description'
@pytest.mark.asyncio
async def test_update_descriptions_mixed_types(mcp_context: Context, buckets: list[BucketDef], tables: list[TableDef]):
    """Tests that `update_descriptions` can handle mixed types in a single call.

    Updates a bucket, a table and one of the table's columns in one request. It then checks that
    every update is reported as successful and that each description is stored in Storage metadata.
    The description metadata entries it finds are deleted afterwards, even when an assertion fails.
    """
    bucket = buckets[0]
    table = tables[0]
    # Get the first column name from the table CSV file
    with table.file_path.open('r', encoding='utf-8') as f:
        column_name = next(csv.reader(f))[0]
    column_id = f'{table.table_id}.{column_name}'
    # (metadata kind, item id, metadata id) triples to delete in the `finally` block.
    md_ids: list[tuple[str, str, str]] = []
    client = KeboolaClient.from_state(mcp_context.session.state)
    try:
        result = await update_descriptions(
            ctx=mcp_context,
            updates=[
                DescriptionUpdate(item_id=bucket.bucket_id, description='Mixed Bucket Description'),
                DescriptionUpdate(item_id=table.table_id, description='Mixed Table Description'),
                DescriptionUpdate(item_id=column_id, description='Mixed Column Description'),
            ],
        )

        # Look up and register every description entry for cleanup BEFORE asserting anything,
        # so that a failing assertion cannot leave this test's metadata behind.
        bucket_metadata = await client.storage_client.bucket_metadata_get(bucket.bucket_id)
        bucket_entry = next((entry for entry in bucket_metadata if entry.get('key') == MetadataField.DESCRIPTION), None)
        if bucket_entry is not None:
            md_ids.append(('bucket', bucket.bucket_id, str(bucket_entry['id'])))

        table_metadata = await client.storage_client.table_metadata_get(table.table_id)
        table_entry = next((entry for entry in table_metadata if entry.get('key') == MetadataField.DESCRIPTION), None)
        if table_entry is not None:
            md_ids.append(('table', table.table_id, str(table_entry['id'])))

        table_detail = await client.storage_client.table_detail(table.table_id)
        # `or {}` also covers an empty / missing `columnMetadata` value.
        column_metadata = table_detail.get('columnMetadata') or {}
        column_entry = next(
            (entry for entry in column_metadata.get(column_name, []) if entry.get('key') == MetadataField.DESCRIPTION),
            None,
        )
        if column_entry is not None:
            md_ids.append(('column', column_id, str(column_entry['id'])))

        assert isinstance(result, UpdateDescriptionsOutput)
        assert result.total_processed == 3
        assert result.successful == 3
        assert result.failed == 0
        assert len(result.results) == 3
        # Verify all results are successful
        for item_result in result.results:
            assert item_result.success is True
            assert item_result.error is None
            assert item_result.timestamp is not None

        # A missing entry means the update was not persisted. Fail here rather than silently
        # skipping the value check.
        assert bucket_entry is not None
        assert bucket_entry['value'] == 'Mixed Bucket Description'
        assert table_entry is not None
        assert table_entry['value'] == 'Mixed Table Description'
        assert 'columnMetadata' in table_detail
        assert column_name in column_metadata
        assert column_entry is not None
        assert column_entry['value'] == 'Mixed Column Description'
    finally:
        # Best-effort cleanup. One failed delete must neither stop the remaining deletes nor
        # mask the original test failure, so errors are logged instead of raised.
        for md_type, item_id, md_id in md_ids:
            try:
                if md_type == 'bucket':
                    await client.storage_client.bucket_metadata_delete(bucket_id=item_id, metadata_id=md_id)
                elif md_type == 'table':
                    await client.storage_client.table_metadata_delete(table_id=item_id, metadata_id=md_id)
                elif md_type == 'column':
                    await client.storage_client.column_metadata_delete(column_id=item_id, metadata_id=md_id)
            except Exception:
                LOG.exception('Failed to delete %s metadata %s of %s', md_type, md_id, item_id)
@pytest.mark.asyncio
async def test_update_descriptions_invalid_path(mcp_context: Context):
    """An item id with an unrecognised format yields a failed result entry instead of raising."""
    bad_id = 'invalid-path'
    output = await update_descriptions(
        ctx=mcp_context,
        updates=[DescriptionUpdate(item_id=bad_id, description='This should fail')],
    )

    assert isinstance(output, UpdateDescriptionsOutput)
    assert output.total_processed == 1
    assert output.successful == 0
    assert output.failed == 1
    assert len(output.results) == 1

    failure = output.results[0]
    assert failure.item_id == bad_id
    assert failure.success is False
    assert failure.error is not None
    assert 'Invalid item_id format' in failure.error
    # The test expects rejected items to carry no timestamp.
    assert failure.timestamp is None