# Databricks MCP Server
# by JustTryAI
# Verified
"""
Tests for the clusters API.
"""
import json
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from src.api import clusters
from src.server.app import create_app
@pytest.fixture
def client():
    """Return a ``TestClient`` wrapping a freshly created application."""
    return TestClient(create_app())
@pytest.fixture
def mock_cluster_response():
    """Provide a sample cluster payload for tests that need a mocked response."""
    payload = dict(
        cluster_id="1234-567890-abcdef",
        cluster_name="Test Cluster",
        spark_version="10.4.x-scala2.12",
        node_type_id="Standard_D3_v2",
        num_workers=2,
        state="RUNNING",
        creator_user_name="test@example.com",
    )
    return payload
@pytest.mark.asyncio
async def test_create_cluster():
    """Test creating a cluster.

    ``clusters.create_cluster`` is swapped for an ``AsyncMock`` only while the
    test body runs, so the real implementation is never invoked. The test
    checks that the configuration is forwarded unchanged and that the mocked
    response is handed back to the caller.
    """
    cluster_config = {
        "cluster_name": "Test Cluster",
        "spark_version": "10.4.x-scala2.12",
        "node_type_id": "Standard_D3_v2",
        "num_workers": 2,
    }
    expected = {"cluster_id": "1234-567890-abcdef"}

    # patch.object restores the original attribute on exit, so the mock
    # cannot leak into tests that run later in the same session.
    with patch.object(
        clusters, "create_cluster", new=AsyncMock(return_value=expected)
    ) as mock_create:
        response = await clusters.create_cluster(cluster_config)

    assert response["cluster_id"] == "1234-567890-abcdef"
    # The mock must have been awaited exactly once with the config.
    mock_create.assert_awaited_once_with(cluster_config)
@pytest.mark.asyncio
async def test_list_clusters():
    """Test listing clusters.

    ``clusters.list_clusters`` is swapped for an ``AsyncMock`` only while the
    test body runs. The test checks that both mocked cluster entries come back
    in order and that the function was awaited exactly once.
    """
    mock_response = {
        "clusters": [
            {
                "cluster_id": "1234-567890-abcdef",
                "cluster_name": "Test Cluster 1",
                "state": "RUNNING",
            },
            {
                "cluster_id": "9876-543210-fedcba",
                "cluster_name": "Test Cluster 2",
                "state": "TERMINATED",
            },
        ]
    }

    # patch.object restores the original attribute on exit, so the mock
    # cannot leak into tests that run later in the same session.
    with patch.object(
        clusters, "list_clusters", new=AsyncMock(return_value=mock_response)
    ) as mock_list:
        response = await clusters.list_clusters()

    assert len(response["clusters"]) == 2
    assert response["clusters"][0]["cluster_id"] == "1234-567890-abcdef"
    assert response["clusters"][1]["cluster_id"] == "9876-543210-fedcba"
    mock_list.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_cluster():
    """Test getting cluster information.

    ``clusters.get_cluster`` is swapped for an ``AsyncMock`` only while the
    test body runs. The test checks that the mocked cluster record is returned
    and that the cluster id was passed through as the sole argument.
    """
    cluster_id = "1234-567890-abcdef"
    mock_response = {
        "cluster_id": "1234-567890-abcdef",
        "cluster_name": "Test Cluster",
        "state": "RUNNING",
    }

    # patch.object restores the original attribute on exit, so the mock
    # cannot leak into tests that run later in the same session.
    with patch.object(
        clusters, "get_cluster", new=AsyncMock(return_value=mock_response)
    ) as mock_get:
        response = await clusters.get_cluster(cluster_id)

    assert response["cluster_id"] == "1234-567890-abcdef"
    assert response["state"] == "RUNNING"
    mock_get.assert_awaited_once_with(cluster_id)
@pytest.mark.asyncio
async def test_terminate_cluster():
    """Test terminating a cluster.

    ``clusters.terminate_cluster`` is swapped for an ``AsyncMock`` only while
    the test body runs. The test checks that the mocked empty response is
    returned and that the cluster id was passed through as the sole argument.
    """
    cluster_id = "1234-567890-abcdef"

    # patch.object restores the original attribute on exit, so the mock
    # cannot leak into tests that run later in the same session.
    with patch.object(
        clusters, "terminate_cluster", new=AsyncMock(return_value={})
    ) as mock_terminate:
        response = await clusters.terminate_cluster(cluster_id)

    assert response == {}
    mock_terminate.assert_awaited_once_with(cluster_id)
@pytest.mark.asyncio
async def test_start_cluster():
    """Test starting a cluster.

    ``clusters.start_cluster`` is swapped for an ``AsyncMock`` only while the
    test body runs. The test checks that the mocked empty response is returned
    and that the cluster id was passed through as the sole argument.
    """
    cluster_id = "1234-567890-abcdef"

    # patch.object restores the original attribute on exit, so the mock
    # cannot leak into tests that run later in the same session.
    with patch.object(
        clusters, "start_cluster", new=AsyncMock(return_value={})
    ) as mock_start:
        response = await clusters.start_cluster(cluster_id)

    assert response == {}
    mock_start.assert_awaited_once_with(cluster_id)
@pytest.mark.asyncio
async def test_resize_cluster():
    """Test resizing a cluster.

    ``clusters.resize_cluster`` is swapped for an ``AsyncMock`` only while the
    test body runs. The test checks that the mocked empty response is returned
    and that the cluster id and the second positional argument (4) were
    passed through unchanged.
    """
    cluster_id = "1234-567890-abcdef"
    new_size = 4

    # patch.object restores the original attribute on exit, so the mock
    # cannot leak into tests that run later in the same session.
    with patch.object(
        clusters, "resize_cluster", new=AsyncMock(return_value={})
    ) as mock_resize:
        response = await clusters.resize_cluster(cluster_id, new_size)

    assert response == {}
    mock_resize.assert_awaited_once_with(cluster_id, new_size)
@pytest.mark.asyncio
async def test_restart_cluster():
    """Test restarting a cluster.

    ``clusters.restart_cluster`` is swapped for an ``AsyncMock`` only while the
    test body runs. The test checks that the mocked empty response is returned
    and that the cluster id was passed through as the sole argument.
    """
    cluster_id = "1234-567890-abcdef"

    # patch.object restores the original attribute on exit, so the mock
    # cannot leak into tests that run later in the same session.
    with patch.object(
        clusters, "restart_cluster", new=AsyncMock(return_value={})
    ) as mock_restart:
        response = await clusters.restart_cluster(cluster_id)

    assert response == {}
    mock_restart.assert_awaited_once_with(cluster_id)