#!/usr/bin/env python3
"""
Integration tests for dbt MCP Server
Comprehensive test suite for validating MCP protocol compliance,
dbt tool functionality, and Claude Code integration capabilities.
"""
import asyncio
import json
import os
import sys
import tempfile
import unittest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from mcp_servers.dbt_server import DbtMCPServer
from mcp.types import CallToolResult, ListResourcesResult, ListToolsResult
class TestMCPServerIntegration(unittest.IsolatedAsyncioTestCase):
    """Integration tests for dbt MCP server functionality.

    Each test gets a fresh ``DbtMCPServer`` created while ``TEST_ENV`` is
    applied to ``os.environ``. The environment is patched per test and
    restored afterwards, so these settings do not leak into other tests
    running in the same process.
    """

    # Environment applied for the duration of every test in this class.
    # NOTE(review): the absolute paths below are specific to one developer
    # machine; the tests cannot locate the project anywhere else.
    TEST_ENV = {
        "DBT_PROJECT_DIR": "/Users/ajdoyle/claude-data-stack-mcp/transform",
        "DBT_PROFILES_DIR": "/Users/ajdoyle/claude-data-stack-mcp/transform/profiles/duckdb",
        "DBT_PROFILE": "data_stack",
        "DBT_TARGET": "dev",
        "DUCKDB_PATH": "/Users/ajdoyle/claude-data-stack-mcp/data/warehouse/data_stack.duckdb",
        "DBT_MCP_ENABLE_CLI_TOOLS": "true",
        "DBT_MCP_ENABLE_DISCOVERY_TOOLS": "true",
        "DBT_MCP_ENABLE_REMOTE_TOOLS": "true",
        "MCP_DEBUG": "false",
    }

    @staticmethod
    def _first_handler(registrations):
        """Return ``.handler`` of the first registration, or ``None`` if empty."""
        for registration in registrations:
            return registration.handler
        return None

    async def asyncSetUp(self):
        """Set up test environment"""
        # patch.dict snapshots os.environ and the cleanup restores it, even
        # if server construction below raises.
        env_patcher = patch.dict(os.environ, self.TEST_ENV)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        # Initialize server
        self.server = DbtMCPServer()

    def test_server_initialization(self):
        """Test MCP server initialization and configuration"""
        self.assertIsNotNone(self.server)
        self.assertEqual(self.server.dbt_profile, "data_stack")
        self.assertEqual(self.server.dbt_target, "dev")
        self.assertTrue(self.server.enable_cli_tools)
        self.assertTrue(self.server.enable_discovery_tools)
        self.assertTrue(self.server.enable_remote_tools)

    async def test_list_tools_compliance(self):
        """Test MCP protocol compliance for tool listing"""
        # Get tools list handler
        handler = self._first_handler(self.server.server._list_tools_handlers)
        self.assertIsNotNone(handler, "List tools handler not found")

        # Execute handler
        result = await handler()

        # Validate result type
        self.assertIsInstance(result, ListToolsResult)
        self.assertIsInstance(result.tools, list)

        # Validate tool count and structure
        self.assertGreaterEqual(len(result.tools), 6, "Expected at least 6 tools")

        # Check for required tools
        tool_names = [tool.name for tool in result.tools]

        # CLI tools
        if self.server.enable_cli_tools:
            self.assertIn("dbt_run", tool_names)
            self.assertIn("dbt_test", tool_names)
            self.assertIn("dbt_compile", tool_names)
            self.assertIn("dbt_build", tool_names)

        # Discovery tools
        if self.server.enable_discovery_tools:
            self.assertIn("discovery_list_models", tool_names)
            self.assertIn("discovery_model_details", tool_names)

        # Remote tools
        if self.server.enable_remote_tools:
            self.assertIn("remote_query_database", tool_names)

        # Validate tool schemas
        for tool in result.tools:
            self.assertIsNotNone(tool.name)
            self.assertIsNotNone(tool.description)
            self.assertIsNotNone(tool.inputSchema)
            self.assertIsInstance(tool.inputSchema, dict)

    async def test_list_resources_compliance(self):
        """Test MCP protocol compliance for resource listing"""
        # Get resources list handler
        handler = self._first_handler(self.server.server._list_resources_handlers)
        self.assertIsNotNone(handler, "List resources handler not found")

        # Execute handler
        result = await handler()

        # Validate result type
        self.assertIsInstance(result, ListResourcesResult)
        self.assertIsInstance(result.resources, list)

        # Validate resources
        self.assertGreaterEqual(len(result.resources), 1, "Expected at least 1 resource")

        # Check for project config resource
        resource_uris = [resource.uri for resource in result.resources]
        self.assertIn("dbt://project/config", resource_uris)

        # Validate resource structure
        for resource in result.resources:
            self.assertIsNotNone(resource.uri)
            self.assertIsNotNone(resource.name)
            self.assertIsNotNone(resource.description)

    async def test_model_discovery(self):
        """Test model discovery functionality"""
        models = await self.server._discover_models()

        # Validate models discovered
        self.assertIsInstance(models, list)
        self.assertGreater(len(models), 0, "No models discovered")

        # Validate model structure first, so a malformed entry is reported as
        # an assertion failure instead of a KeyError from the lookups below.
        for model in models:
            self.assertIsInstance(model, dict)
            self.assertIn('name', model)
            self.assertIn('path', model)
            self.assertIn('layer', model)
            self.assertIn('size_bytes', model)

        # Check for expected models from data stack
        model_names = [model['name'] for model in models]
        expected_models = ['stg_employees', 'dim_employees', 'agg_department_stats']
        for expected_model in expected_models:
            self.assertIn(expected_model, model_names, f"Model {expected_model} not found")

        # Check model layers
        layers = [model['layer'] for model in models]
        self.assertIn('staging', layers, "No staging models found")
        self.assertIn('marts', layers, "No marts models found")

    async def test_model_details(self):
        """Test model details retrieval"""
        # Test with known model
        model_name = "stg_employees"

        # Only the lookup itself is guarded; assertions stay outside the try.
        try:
            details = await self.server._get_model_details(model_name)
        except FileNotFoundError:
            self.fail(f"Model {model_name} not found")

        # Validate details structure
        self.assertIsInstance(details, dict)
        self.assertEqual(details['name'], model_name)
        self.assertIn('content', details)
        self.assertIn('line_count', details)
        self.assertIn('references', details)
        self.assertIn('columns', details)

        # Validate content
        self.assertIsInstance(details['content'], str)
        self.assertGreater(details['line_count'], 0)
        self.assertIsInstance(details['references'], list)
        self.assertIsInstance(details['columns'], list)

    async def test_project_config(self):
        """Test project configuration retrieval"""
        config = await self.server._get_project_config()

        # Validate config structure
        self.assertIsInstance(config, dict)
        self.assertIn('project_name', config)
        self.assertIn('dbt_version', config)
        self.assertIn('profile', config)
        self.assertIn('target', config)
        self.assertIn('tool_groups', config)

        # Validate tool groups
        tool_groups = config['tool_groups']
        self.assertIsInstance(tool_groups, dict)
        self.assertIn('cli_tools', tool_groups)
        self.assertIn('discovery_tools', tool_groups)
        self.assertIn('remote_tools', tool_groups)

    async def test_database_connectivity(self):
        """Test DuckDB database connectivity"""
        # Test simple query. Only the query is guarded, so that a format
        # assertion failing below is reported as itself rather than being
        # re-labelled as a connectivity failure.
        try:
            result = await self.server._execute_sql_query("SELECT 1 as test", limit=1)
        except Exception as e:
            self.fail(f"Database connectivity test failed: {e}")

        # Validate result format
        self.assertIsInstance(result, str)
        self.assertIn("Query:", result)
        self.assertIn("test", result)
        self.assertIn("Rows returned:", result)

    async def test_table_description(self):
        """Test table description functionality"""
        # Test with known table (if exists)
        try:
            result = await self.server._describe_table("stg_employees", "main")
        except Exception as e:
            # Table might not exist, which is acceptable for this test
            self.assertIn("Error", str(e))
            return

        # Validate result format. These checks used to live inside the try,
        # where the broad except caught their AssertionError and re-tested
        # its message for "Error" - so an unrelated value could pass by
        # accident.
        self.assertIsInstance(result, str)
        # NOTE(review): the old structure also let a returned text containing
        # "Error" pass; that leniency is kept, but explicitly. Confirm it is
        # intended.
        if "Error" not in result:
            self.assertIn("Query:", result)

    async def test_tool_execution_error_handling(self):
        """Test error handling in tool execution"""
        # Test with invalid arguments
        result = await self.server._handle_dbt_tool("dbt_invalid", {})

        # Validate error response
        self.assertIsInstance(result, CallToolResult)
        self.assertTrue(result.isError)
        self.assertGreater(len(result.content), 0)

    async def test_discovery_tool_execution(self):
        """Test discovery tool execution"""
        # Test model listing
        result = await self.server._handle_discovery_tool("discovery_list_models", {})

        # Validate result
        self.assertIsInstance(result, CallToolResult)
        self.assertFalse(result.isError)
        self.assertGreater(len(result.content), 0)

        # Parse JSON content
        content = result.content[0].text
        models = json.loads(content)
        self.assertIsInstance(models, list)

    async def test_remote_tool_execution(self):
        """Test remote tool execution"""
        # Test simple database query
        arguments = {"sql": "SELECT 1 as test", "limit": 1}
        result = await self.server._handle_remote_tool("remote_query_database", arguments)

        # Validate result
        self.assertIsInstance(result, CallToolResult)
        self.assertFalse(result.isError)
        self.assertGreater(len(result.content), 0)

        # Check result content
        content = result.content[0].text
        self.assertIn("Query:", content)
        self.assertIn("test", content)

    def test_environment_validation(self):
        """Test environment configuration validation"""
        # Check required paths exist
        self.assertTrue(self.server.dbt_project_dir.exists(), "dbt project directory missing")
        self.assertTrue(self.server.dbt_profiles_dir.exists(), "dbt profiles directory missing")

        # Check DuckDB database
        if not self.server.duckdb_path.exists():
            self.skipTest("DuckDB database not found - run Meltano ELT first")

    def test_configuration_validation(self):
        """Test server configuration validation"""
        # Validate paths are absolute
        self.assertTrue(self.server.dbt_project_dir.is_absolute())
        self.assertTrue(self.server.dbt_profiles_dir.is_absolute())
        self.assertTrue(self.server.duckdb_path.is_absolute())

        # Validate boolean configurations
        self.assertIsInstance(self.server.enable_cli_tools, bool)
        self.assertIsInstance(self.server.enable_discovery_tools, bool)
        self.assertIsInstance(self.server.enable_remote_tools, bool)


class TestMCPProtocolCompliance(unittest.TestCase):
    """Test MCP protocol specification compliance.

    A fresh ``DbtMCPServer`` is created for each test while ``TEST_ENV`` is
    applied to ``os.environ``; the previous environment is restored when the
    test finishes.
    """

    # Environment applied for the duration of every test in this class.
    # NOTE(review): the absolute paths below are specific to one developer
    # machine.
    TEST_ENV = {
        "DBT_PROJECT_DIR": "/Users/ajdoyle/claude-data-stack-mcp/transform",
        "DBT_PROFILES_DIR": "/Users/ajdoyle/claude-data-stack-mcp/transform/profiles/duckdb",
        "MCP_DEBUG": "false",
    }

    def setUp(self):
        """Set up test environment"""
        # patch.dict snapshots os.environ; the cleanup restores it so these
        # values do not leak into other tests in the same process.
        env_patcher = patch.dict(os.environ, self.TEST_ENV)
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

        self.server = DbtMCPServer()

    def test_server_instance(self):
        """Test server instance and basic properties"""
        self.assertIsNotNone(self.server.server)
        self.assertEqual(self.server.server.name, "dbt-data-stack")

    def test_handler_registration(self):
        """Test MCP handler registration"""
        # Check handlers are registered
        mcp_server = self.server.server
        self.assertGreater(len(mcp_server._list_tools_handlers), 0)
        self.assertGreater(len(mcp_server._call_tool_handlers), 0)
        self.assertGreater(len(mcp_server._list_resources_handlers), 0)
        self.assertGreater(len(mcp_server._read_resource_handlers), 0)

    def test_tool_schema_validation(self):
        """Test tool input schema validation"""
        cli_tools = self.server._get_cli_tools()
        discovery_tools = self.server._get_discovery_tools()
        remote_tools = self.server._get_remote_tools()
        all_tools = cli_tools + discovery_tools + remote_tools

        for tool in all_tools:
            # subTest labels each failure with the offending tool and lets
            # the remaining tools still be checked.
            with self.subTest(tool=tool.name):
                # Validate required fields
                self.assertIsNotNone(tool.name)
                self.assertIsNotNone(tool.description)
                self.assertIsNotNone(tool.inputSchema)

                # Validate schema structure
                schema = tool.inputSchema
                self.assertIsInstance(schema, dict)
                self.assertEqual(schema.get("type"), "object")
                self.assertIn("properties", schema)

                # Validate properties
                properties = schema["properties"]
                self.assertIsInstance(properties, dict)

                # Every name listed as required must be a declared property
                if "required" in schema:
                    required = schema["required"]
                    self.assertIsInstance(required, list)
                    for req_field in required:
                        self.assertIn(req_field, properties)


def _failure_reason(trace_text, marker):
    """Return a short reason extracted from a formatted traceback.

    When ``marker`` occurs in ``trace_text`` the text after its last
    occurrence is returned. Otherwise the last non-blank traceback line is
    returned, so the summary never prints the whole traceback.
    """
    if marker in trace_text:
        return trace_text.split(marker)[-1].strip()
    non_blank = [line for line in trace_text.splitlines() if line.strip()]
    return non_blank[-1].strip() if non_blank else ""


def run_integration_tests():
    """Run all integration tests and generate report.

    Loads both test classes into one suite, runs it verbosely on stdout and
    prints a summary (counts, one line per failure/error, success rate).

    Returns:
        The ``unittest`` result object produced by the runner.
    """
    print("🧪 Starting MCP Server Integration Tests")
    print("=" * 50)

    # Create test suite
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()

    # Add test classes
    for test_case in (TestMCPServerIntegration, TestMCPProtocolCompliance):
        suite.addTests(loader.loadTestsFromTestCase(test_case))

    # Run tests with verbose output
    runner = unittest.TextTestRunner(verbosity=2, stream=sys.stdout)
    result = runner.run(suite)

    # Generate summary report
    print("\n" + "=" * 50)
    print("🎯 Test Results Summary")
    print("=" * 50)
    print(f"Tests run: {result.testsRun}")
    print(f"Failures: {len(result.failures)}")
    print(f"Errors: {len(result.errors)}")
    print(f"Skipped: {len(result.skipped)}")

    if result.failures:
        print("\n❌ Failures:")
        for test, trace_text in result.failures:
            print(f" - {test}: {_failure_reason(trace_text, 'AssertionError:')}")

    if result.errors:
        print("\n🚨 Errors:")
        for test, trace_text in result.errors:
            # Most errors (NameError, AttributeError, ...) do not contain the
            # text 'Exception:'; the helper then falls back to the last line.
            print(f" - {test}: {_failure_reason(trace_text, 'Exception:')}")

    # Guard against division by zero when nothing was collected.
    if result.testsRun > 0:
        not_passed = len(result.failures) + len(result.errors)
        success_rate = (result.testsRun - not_passed) / result.testsRun * 100
    else:
        success_rate = 0
    print(f"\n✅ Success Rate: {success_rate:.1f}%")

    return result


if __name__ == "__main__":
    # Propagate the outcome through the exit status so shells and CI can
    # detect a failing run; previously the result was discarded and the
    # script always exited 0.
    sys.exit(0 if run_integration_tests().wasSuccessful() else 1)