#!/usr/bin/env python3
"""
Tests for Databricks MCP Server Prompts
Tests the prompt functionality including listing prompts,
generating specific prompts, and validating arguments.
"""

import pytest

from databricks_mcp_server.prompts import PromptsManager


class TestPromptsManager:
    """Test the PromptsManager functionality."""

    def setup_method(self):
        """Set up test fixtures."""
        self.prompts_manager = PromptsManager()
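
    # PromptsManager is constructed with no arguments here, so prompt listing
    # and generation are assumed to be purely local operations (no Databricks
    # workspace connection is needed for these tests).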

    def test_list_prompts(self):
        """Test listing all available prompts."""
        prompts = self.prompts_manager.list_prompts()

        # Check that we have the expected prompts
        prompt_names = [p.name for p in prompts]
        expected_prompts = [
            "explore-dataset",
            "optimize-query",
            "investigate-lineage",
            "design-data-quality-checks",
            "generate-documentation",
            "troubleshoot-data-issue",
            "compare-schemas",
        ]
        for expected in expected_prompts:
            assert expected in prompt_names, f"Missing prompt: {expected}"

        # Check that all prompts have required attributes
        for prompt in prompts:
            assert prompt.name, "Prompt must have a name"
            assert prompt.description, "Prompt must have a description"
            assert hasattr(prompt, "arguments"), "Prompt must have an arguments attribute"

    def test_explore_dataset_prompt(self):
        """Test the explore-dataset prompt."""
        # Test with all required arguments
        result = self.prompts_manager.get_prompt(
            "explore-dataset",
            {
                "catalog_name": "test_catalog",
                "schema_name": "test_schema",
                "table_name": "test_table",
            },
        )
        assert result.description
        assert len(result.messages) == 1
        assert "test_catalog.test_schema.test_table" in result.messages[0].content.text

        # Test with optional argument
        result_with_focus = self.prompts_manager.get_prompt(
            "explore-dataset",
            {
                "catalog_name": "test_catalog",
                "schema_name": "test_schema",
                "table_name": "test_table",
                "analysis_focus": "data quality",
            },
        )
        assert "data quality" in result_with_focus.messages[0].content.text

    def test_optimize_query_prompt(self):
        """Test the optimize-query prompt."""
        test_query = "SELECT * FROM table WHERE id = 1"
        result = self.prompts_manager.get_prompt(
            "optimize-query",
            {"query": test_query},
        )
        assert result.description
        assert len(result.messages) == 1
        assert test_query in result.messages[0].content.text

    def test_data_quality_prompt(self):
        """Test the design-data-quality-checks prompt."""
        result = self.prompts_manager.get_prompt(
            "design-data-quality-checks",
            {
                "catalog_name": "test_catalog",
                "schema_name": "test_schema",
                "table_name": "test_table",
            },
        )
        assert result.description
        assert len(result.messages) == 1
        assert "test_catalog.test_schema.test_table" in result.messages[0].content.text
        assert "data quality" in result.messages[0].content.text.lower()

    def test_documentation_prompt(self):
        """Test the generate-documentation prompt."""
        # Test table documentation
        result = self.prompts_manager.get_prompt(
            "generate-documentation",
            {
                "catalog_name": "test_catalog",
                "schema_name": "test_schema",
                "table_name": "test_table",
            },
        )
        assert result.description
        assert len(result.messages) == 1
        assert "table `test_catalog.test_schema.test_table`" in result.messages[0].content.text

        # Test schema documentation (no table_name)
        result_schema = self.prompts_manager.get_prompt(
            "generate-documentation",
            {
                "catalog_name": "test_catalog",
                "schema_name": "test_schema",
            },
        )
        assert "schema `test_catalog.test_schema`" in result_schema.messages[0].content.text

    def test_troubleshoot_prompt(self):
        """Test the troubleshoot-data-issue prompt."""
        result = self.prompts_manager.get_prompt(
            "troubleshoot-data-issue",
            {
                "issue_description": "Data is missing",
                "affected_table": "catalog.schema.table",
            },
        )
        assert result.description
        assert len(result.messages) == 1
        assert "Data is missing" in result.messages[0].content.text
        assert "catalog.schema.table" in result.messages[0].content.text

    def test_lineage_prompt(self):
        """Test the investigate-lineage prompt."""
        result = self.prompts_manager.get_prompt(
            "investigate-lineage",
            {
                "catalog_name": "test_catalog",
                "schema_name": "test_schema",
                "table_name": "test_table",
            },
        )
        assert result.description
        assert len(result.messages) == 1
        assert "test_catalog.test_schema.test_table" in result.messages[0].content.text
        assert "lineage" in result.messages[0].content.text.lower()

    def test_compare_schemas_prompt(self):
        """Test the compare-schemas prompt."""
        result = self.prompts_manager.get_prompt(
            "compare-schemas",
            {
                "source_table": "catalog1.schema1.table1",
                "target_table": "catalog2.schema2.table2",
            },
        )
        assert result.description
        assert len(result.messages) == 1
        assert "catalog1.schema1.table1" in result.messages[0].content.text
        assert "catalog2.schema2.table2" in result.messages[0].content.text

    def test_invalid_prompt_name(self):
        """Test handling of invalid prompt names."""
        with pytest.raises(ValueError, match="Prompt not found"):
            # Pass an empty arguments dict so the only possible failure
            # is the unknown prompt name.
            self.prompts_manager.get_prompt("invalid-prompt", {})

    def test_missing_required_arguments(self):
        """Test handling of missing required arguments."""
        with pytest.raises(ValueError, match="Required argument"):
            self.prompts_manager.get_prompt("explore-dataset", {})

    def test_prompt_arguments_validation(self):
        """Test that prompt arguments are properly defined."""
        prompts = self.prompts_manager.list_prompts()
        for prompt in prompts:
            if prompt.arguments:
                for arg in prompt.arguments:
                    assert arg.name, f"Argument in prompt {prompt.name} must have a name"
                    assert arg.description, f"Argument {arg.name} in prompt {prompt.name} must have a description"
                    assert hasattr(arg, "required"), f"Argument {arg.name} must have a 'required' attribute"
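
    # A hedged cross-check: test_explore_dataset_prompt treats catalog_name,
    # schema_name, and table_name as the required arguments, so the declared
    # argument metadata is assumed to mark them as required too.
    def test_explore_dataset_required_arguments_declared(self):
        """explore-dataset should declare its table coordinates as required."""
        prompts = {p.name: p for p in self.prompts_manager.list_prompts()}
        required = {a.name for a in prompts["explore-dataset"].arguments if a.required}
        assert {"catalog_name", "schema_name", "table_name"} <= required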

    def test_prompt_message_structure(self):
        """Test that generated prompt messages have the correct structure."""
        result = self.prompts_manager.get_prompt(
            "explore-dataset",
            {
                "catalog_name": "test_catalog",
                "schema_name": "test_schema",
                "table_name": "test_table",
            },
        )
        assert hasattr(result, "description")
        assert hasattr(result, "messages")
        assert len(result.messages) > 0
        for message in result.messages:
            assert hasattr(message, "role")
            assert hasattr(message, "content")
            assert hasattr(message.content, "type")
            assert hasattr(message.content, "text")
            assert message.content.type == "text"
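
    # Only text content is expected from these prompts; the role value
    # (for example "user") is not asserted here, only its presence.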
if __name__ == "__main__":
pytest.main([__file__])