"""
Tests for the DatabricksClient class.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from databricks_mcp_server.databricks_client import DatabricksClient, TableMetadata, QueryResult
@pytest.fixture
def mock_workspace_client():
    """Create a mock workspace client that lists two catalogs.

    ``client.catalogs.list()`` returns two catalog mocks whose ``.name``
    attributes are the plain strings ``"main"`` and ``"test"`` and whose
    ``.comment`` attributes are ``"Main catalog"`` and ``"Test catalog"``.
    """
    def _catalog(name, comment):
        # Mock(name=...) only sets the mock's repr label; it does NOT create a
        # ``.name`` attribute. Assign it after construction so that
        # ``catalog.name`` is the real string.
        catalog = Mock(comment=comment)
        catalog.name = name
        return catalog

    client = Mock()
    client.catalogs.list.return_value = [
        _catalog("main", "Main catalog"),
        _catalog("test", "Test catalog"),
    ]
    return client
@pytest.fixture
async def databricks_client():
    """Build a DatabricksClient whose WorkspaceClient is replaced by a Mock.

    ``WorkspaceClient`` is patched only while ``DatabricksClient.create()``
    runs. The returned client keeps a reference to the Mock instance, so later
    calls through ``client.workspace_client`` hit the Mock.

    NOTE(review): this is an ``async def`` fixture under plain
    ``@pytest.fixture``. That presumably relies on pytest-asyncio running in
    "auto" mode; confirm against the project's pytest configuration.
    """
    target = 'databricks_mcp_server.databricks_client.WorkspaceClient'
    with patch(target) as workspace_cls:
        workspace_cls.return_value = Mock()
        instance = await DatabricksClient.create()
    return instance
class TestDatabricksClient:
    """Test cases for DatabricksClient."""

    @staticmethod
    def _named_mock(name, **attrs):
        """Return a Mock whose ``.name`` attribute is the plain string *name*.

        ``Mock(name=...)`` does not set a ``.name`` attribute: the ``name``
        keyword only labels the mock's repr, and ``mock.name`` would be
        another Mock. The attribute must therefore be assigned after
        construction. Any other keyword arguments are passed to ``Mock``.
        """
        mock = Mock(**attrs)
        mock.name = name
        return mock

    @pytest.mark.asyncio
    async def test_create(self):
        """Test client creation populates the workspace client and executor."""
        with patch('databricks_mcp_server.databricks_client.WorkspaceClient') as mock_ws:
            mock_ws.return_value = Mock()
            client = await DatabricksClient.create()
            assert client.workspace_client is not None
            assert client.executor is not None
        # Release the client's resources so this test does not leak them.
        await client.close()

    @pytest.mark.asyncio
    async def test_list_catalogs(self, databricks_client, mock_workspace_client):
        """Test listing catalogs."""
        databricks_client.workspace_client = mock_workspace_client
        catalogs = await databricks_client.list_catalogs()
        assert len(catalogs) == 2
        assert catalogs[0].name == "main"
        assert catalogs[1].name == "test"

    @pytest.mark.asyncio
    async def test_describe_table(self, databricks_client):
        """Test describing a table."""
        # Mock table info returned by ``workspace_client.tables.get``.
        mock_table = Mock()
        mock_table.name = "test_table"
        mock_table.catalog_name = "main"
        mock_table.schema_name = "default"
        mock_table.table_type.value = "MANAGED"
        mock_table.comment = "Test table"
        mock_table.owner = "test_user"
        # Column mocks need a real string ``.name`` (see _named_mock).
        mock_table.columns = [
            self._named_mock("id", type_name="int", nullable=False, comment="ID column"),
            self._named_mock("name", type_name="string", nullable=True, comment="Name column"),
        ]
        mock_table.properties = {}
        mock_table.storage_location = None
        mock_table.created_at = None
        mock_table.created_by = None
        mock_table.updated_at = None
        mock_table.updated_by = None
        mock_table.data_source_format = None
        databricks_client.workspace_client.tables.get.return_value = mock_table

        result = await databricks_client.describe_table("main", "default", "test_table")

        assert isinstance(result, TableMetadata)
        assert result.name == "test_table"
        assert result.catalog_name == "main"
        assert result.schema_name == "default"
        assert result.table_type == "MANAGED"
        assert len(result.columns) == 2

    @pytest.mark.asyncio
    async def test_sample_table(self, databricks_client):
        """Test sampling table data."""
        # Mock query execution: the same object serves as both the
        # execute_statement result and the get_statement poll result.
        mock_execution = Mock()
        mock_execution.status.state = "SUCCEEDED"
        mock_execution.status.duration_ms = 100
        mock_execution.result.data_array = [["1", "Alice"], ["2", "Bob"]]
        # Column names become the row-dict keys asserted below, so they must be
        # real strings rather than Mock objects (see _named_mock).
        mock_execution.manifest.schema.columns = [
            self._named_mock("id"),
            self._named_mock("name"),
        ]
        databricks_client.workspace_client.warehouses.list.return_value = [
            Mock(id="warehouse-123")
        ]
        databricks_client.workspace_client.statement_execution.execute_statement.return_value = mock_execution
        databricks_client.workspace_client.statement_execution.get_statement.return_value = mock_execution

        result = await databricks_client.sample_table("main", "default", "test_table", limit=2)

        assert isinstance(result, QueryResult)
        assert result.status == "SUCCESS"
        assert result.row_count == 2
        assert len(result.data) == 2
        assert result.data[0]["id"] == "1"
        assert result.data[0]["name"] == "Alice"

    @pytest.mark.asyncio
    async def test_close(self, databricks_client):
        """Test closing the client."""
        await databricks_client.close()
        # Should not raise any exceptions