# Databricks MCP Server
# by JustTryAI
# Verified
"""
Tests for the Databricks MCP server.
This test file connects to the MCP server using the MCP client library
and tests the cluster and notebook operations.
"""
import asyncio
import json
import logging
import os
import subprocess
import sys
import time
from typing import Any, Dict, List, Optional, Tuple
import anyio
import pytest
from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
# Send INFO-and-above log records to stdout, timestamped and tagged with the
# logger name and level, so they show up alongside the rest of the output.
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
class DatabricksMCPClient:
    """Client for testing the Databricks MCP server.

    ``connect()`` launches the server over stdio through the MCP client
    library, opens a ``ClientSession`` and runs the ``test_*`` coroutines of
    this class against it.
    """

    # Default launch command: the PowerShell start script with ``-SkipPrompt``
    # so the script does not wait for interactive input.
    DEFAULT_SERVER_COMMAND = "pwsh"
    DEFAULT_SERVER_ARGS = ("-File", "start_mcp_server.ps1", "-SkipPrompt")

    def __init__(
        self,
        server_command: str = DEFAULT_SERVER_COMMAND,
        server_args: Optional[List[str]] = None,
    ):
        """Create an unconnected client.

        Args:
            server_command: Executable that starts the MCP server.
            server_args: Arguments for ``server_command``; defaults to
                ``DEFAULT_SERVER_ARGS``.
        """
        self.server_command = server_command
        self.server_args: List[str] = list(
            self.DEFAULT_SERVER_ARGS if server_args is None else server_args
        )
        self.session: Optional[ClientSession] = None
        self.stdio_transport: Optional[Tuple[Any, Any]] = None
        # The server subprocess is spawned and torn down by ``stdio_client``.
        # This attribute is kept for backward compatibility; ``run_tests`` only
        # terminates it if a caller sets it.
        self.server_process: Optional[subprocess.Popen] = None

    async def connect(self) -> None:
        """Launch the server, open an MCP session, and run all tests.

        ``stdio_client`` starts the server process and shuts it down when its
        ``async with`` block exits, so no separate subprocess is started here
        (starting one would run a second, unused server).
        """
        logger.info("Starting Databricks MCP server...")
        # NOTE(review): with env=None the MCP SDK passes the child only a default
        # subset of environment variables (PATH, HOME, ...). If the server needs
        # DATABRICKS_HOST / DATABRICKS_TOKEN from this process, pass them via
        # ``env``; confirm against the installed mcp version.
        params = StdioServerParameters(
            command=self.server_command,
            args=list(self.server_args),
            env=None,
        )
        logger.info("Connecting to MCP server...")
        try:
            async with stdio_client(params) as stdio_transport:
                self.stdio_transport = stdio_transport
                read_stream, write_stream = stdio_transport
                # Entering ClientSession as an async context manager starts its
                # background receive loop; without it, initialize() never sees
                # the server's reply and waits forever.
                async with ClientSession(read_stream, write_stream) as session:
                    self.session = session
                    await session.initialize()
                    tools_response = await session.list_tools()
                    logger.info(
                        "Available tools: %s",
                        [tool.name for tool in tools_response.tools],
                    )
                    # Awaited directly: TaskGroup.start() would require
                    # run_tests to accept and signal a ``task_status``.
                    await self.run_tests()
        finally:
            # The session and transport are closed once the blocks above exit.
            self.session = None
            self.stdio_transport = None

    async def run_tests(self) -> None:
        """Run every client test in order, stopping at the first failure.

        Raises:
            Exception: whatever the failing test raised, after logging it.
        """
        try:
            await self.test_list_clusters()
            await self.test_get_cluster()
            await self.test_list_notebooks()
            await self.test_export_notebook()
            logger.info("All tests completed successfully!")
        except Exception as e:
            logger.error("Test failed: %s", e)
            raise
        finally:
            if self.server_process is not None:
                self.server_process.terminate()

    async def _call_tool_json(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Call MCP tool ``name`` and decode its first text content item as JSON.

        ``ClientSession.call_tool`` returns a ``CallToolResult`` object rather
        than the payload itself. This assumes the server serializes each tool's
        result as JSON text (TODO confirm for every tool).

        Raises:
            RuntimeError: not connected, or the server flagged the call as an error.
            ValueError: the result has no text content, or it is not valid JSON.
        """
        if self.session is None:
            raise RuntimeError("Not connected to the MCP server; call connect() first")
        result = await self.session.call_tool(name, arguments)
        text = next(
            (item.text for item in result.content if getattr(item, "type", None) == "text"),
            None,
        )
        if getattr(result, "isError", False):
            raise RuntimeError(f"Tool {name!r} returned an error: {text}")
        if text is None:
            raise ValueError(f"Tool {name!r} returned no text content")
        try:
            return json.loads(text)
        except json.JSONDecodeError as err:
            raise ValueError(f"Tool {name!r} did not return JSON: {text[:200]!r}") from err

    async def test_list_clusters(self) -> Dict[str, Any]:
        """Call ``list_clusters`` and check the payload has a ``clusters`` key.

        Returns:
            The decoded ``list_clusters`` payload.
        """
        logger.info("Testing list_clusters...")
        response = await self._call_tool_json("list_clusters", {})
        logger.info("list_clusters response: %s", json.dumps(response, indent=2))
        assert "clusters" in response, "Response should contain 'clusters' key"
        return response

    async def test_get_cluster(self) -> None:
        """Fetch the first listed cluster and check the echoed ``cluster_id``.

        Logs a warning and returns early when no clusters are listed.
        """
        logger.info("Testing get_cluster...")
        # List clusters first to discover a real cluster_id.
        clusters = (await self.test_list_clusters()).get("clusters")
        if not clusters:
            logger.warning("No clusters found to test get_cluster")
            return
        cluster_id = clusters[0]["cluster_id"]
        response = await self._call_tool_json("get_cluster", {"cluster_id": cluster_id})
        logger.info("get_cluster response: %s", json.dumps(response, indent=2))
        assert "cluster_id" in response, "Response should contain 'cluster_id' key"
        assert response["cluster_id"] == cluster_id, "Returned cluster ID should match requested ID"

    async def test_list_notebooks(self) -> Dict[str, Any]:
        """List the workspace root and check the payload has an ``objects`` key.

        Returns:
            The decoded ``list_notebooks`` payload.
        """
        logger.info("Testing list_notebooks...")
        response = await self._call_tool_json("list_notebooks", {"path": "/"})
        logger.info("list_notebooks response: %s", json.dumps(response, indent=2))
        assert "objects" in response, "Response should contain 'objects' key"
        return response

    async def test_export_notebook(self) -> None:
        """Export the first notebook found at the workspace root in SOURCE format.

        Logs a warning and returns early when the root contains no notebook.
        """
        logger.info("Testing export_notebook...")
        objects = (await self.test_list_notebooks()).get("objects")
        if not objects:
            logger.warning("No notebooks found to test export_notebook")
            return
        # The listing can include non-notebook entries such as directories.
        notebook = next(
            (obj for obj in objects if obj.get("object_type") == "NOTEBOOK"),
            None,
        )
        if notebook is None:
            logger.warning("No notebooks found to test export_notebook")
            return
        response = await self._call_tool_json(
            "export_notebook",
            {"path": notebook["path"], "format": "SOURCE"},
        )
        logger.info("export_notebook response (truncated): %s...", str(response)[:200])
        assert "content" in response, "Response should contain 'content' key"
# NOTE: kept skipped until the hang observed when running this under pytest is diagnosed.
@pytest.mark.skip(reason="Test causes hanging issues - needs further investigation")
@pytest.mark.asyncio
async def test_databricks_mcp_server():
    """Run every DatabricksMCPClient check end to end against a live server.

    ``connect()`` launches the server itself. With the default launch
    command, that needs ``pwsh`` and ``start_mcp_server.ps1`` to be
    resolvable from the current working directory.
    """
    await DatabricksMCPClient().connect()
if __name__ == "__main__":
    # Run the client tests directly, outside pytest. A failing test propagates
    # out of asyncio.run, so the script exits with a non-zero status.
    asyncio.run(DatabricksMCPClient().connect())