Skip to main content
Glama

MCP Server for Apache Airflow

by yangkyeongmo
test_dag.py21.2 kB
"""Table-driven tests for the dag module using pytest framework.""" from unittest.mock import ANY, MagicMock, patch import mcp.types as types import pytest from src.airflow.dag import ( clear_task_instances, delete_dag, get_dag, get_dag_details, get_dag_source, get_dag_tasks, get_dag_url, get_dags, get_task, get_tasks, patch_dag, pause_dag, reparse_dag_file, set_task_instances_state, unpause_dag, ) class TestDagModule: """Table-driven test cases for the dag module.""" @pytest.fixture def mock_dag_api(self): """Create a mock DAG API instance.""" with patch("src.airflow.dag.dag_api") as mock_api: yield mock_api def test_get_dag_url(self): """Test DAG URL generation.""" test_cases = [ # (dag_id, expected_url) ("test_dag", "http://localhost:8080/dags/test_dag/grid"), ("my-complex_dag.v2", "http://localhost:8080/dags/my-complex_dag.v2/grid"), ("", "http://localhost:8080/dags//grid"), ] for dag_id, expected_url in test_cases: with patch("src.airflow.dag.AIRFLOW_HOST", "http://localhost:8080"): result = get_dag_url(dag_id) assert result == expected_url @pytest.mark.parametrize( "test_case", [ # Test case structure: (input_params, mock_response_dict, expected_result_partial) { "name": "get_dags_no_params", "input": {}, "mock_response": {"dags": [{"dag_id": "test_dag", "description": "Test"}], "total_entries": 1}, "expected_call_kwargs": {}, "expected_ui_urls": True, }, { "name": "get_dags_with_limit_offset", "input": {"limit": 10, "offset": 5}, "mock_response": {"dags": [{"dag_id": "dag1"}, {"dag_id": "dag2"}], "total_entries": 2}, "expected_call_kwargs": {"limit": 10, "offset": 5}, "expected_ui_urls": True, }, { "name": "get_dags_with_filters", "input": {"tags": ["prod", "daily"], "only_active": True, "paused": False, "dag_id_pattern": "prod_*"}, "mock_response": {"dags": [{"dag_id": "prod_dag1"}], "total_entries": 1}, "expected_call_kwargs": { "tags": ["prod", "daily"], "only_active": True, "paused": False, "dag_id_pattern": "prod_*", }, "expected_ui_urls": True, }, ], ) async def test_get_dags_table_driven(self, test_case, mock_dag_api): """Table-driven test for get_dags function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.get_dags.return_value = mock_response # Execute function with patch("src.airflow.dag.AIRFLOW_HOST", "http://localhost:8080"): result = await get_dags(**test_case["input"]) # Verify API call mock_dag_api.get_dags.assert_called_once_with(**test_case["expected_call_kwargs"]) # Verify result structure assert len(result) == 1 assert isinstance(result[0], types.TextContent) # Parse result and verify UI URLs were added if expected if test_case["expected_ui_urls"]: result_text = result[0].text assert "ui_url" in result_text @pytest.mark.parametrize( "test_case", [ { "name": "get_dag_basic", "input": {"dag_id": "test_dag"}, "mock_response": {"dag_id": "test_dag", "description": "Test DAG", "is_paused": False}, "expected_call_kwargs": {"dag_id": "test_dag"}, }, { "name": "get_dag_complex_id", "input": {"dag_id": "complex-dag_name.v2"}, "mock_response": {"dag_id": "complex-dag_name.v2", "description": "Complex DAG", "is_paused": True}, "expected_call_kwargs": {"dag_id": "complex-dag_name.v2"}, }, ], ) async def test_get_dag_table_driven(self, test_case, mock_dag_api): """Table-driven test for get_dag function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.get_dag.return_value = mock_response # Execute function with patch("src.airflow.dag.AIRFLOW_HOST", "http://localhost:8080"): result = await get_dag(**test_case["input"]) # Verify API call mock_dag_api.get_dag.assert_called_once_with(**test_case["expected_call_kwargs"]) # Verify result structure and UI URL addition assert len(result) == 1 assert isinstance(result[0], types.TextContent) assert "ui_url" in result[0].text @pytest.mark.parametrize( "test_case", [ { "name": "get_dag_details_no_fields", "input": {"dag_id": "test_dag"}, "mock_response": {"dag_id": "test_dag", "file_path": "/path/to/dag.py"}, "expected_call_kwargs": {"dag_id": "test_dag"}, }, { "name": "get_dag_details_with_fields", "input": {"dag_id": "test_dag", "fields": ["dag_id", "description"]}, "mock_response": {"dag_id": "test_dag", "description": "Test"}, "expected_call_kwargs": {"dag_id": "test_dag", "fields": ["dag_id", "description"]}, }, ], ) async def test_get_dag_details_table_driven(self, test_case, mock_dag_api): """Table-driven test for get_dag_details function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.get_dag_details.return_value = mock_response # Execute function result = await get_dag_details(**test_case["input"]) # Verify API call and result mock_dag_api.get_dag_details.assert_called_once_with(**test_case["expected_call_kwargs"]) assert len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "pause_dag", "function": pause_dag, "input": {"dag_id": "test_dag"}, "mock_response": {"dag_id": "test_dag", "is_paused": True}, "expected_call_kwargs": {"dag_id": "test_dag", "dag": ANY, "update_mask": ["is_paused"]}, "expected_dag_is_paused": True, }, { "name": "unpause_dag", "function": unpause_dag, "input": {"dag_id": "test_dag"}, "mock_response": {"dag_id": "test_dag", "is_paused": False}, "expected_call_kwargs": {"dag_id": "test_dag", "dag": ANY, "update_mask": ["is_paused"]}, "expected_dag_is_paused": False, }, ], ) async def test_pause_unpause_dag_table_driven(self, test_case, mock_dag_api): """Table-driven test for pause_dag and unpause_dag functions.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.patch_dag.return_value = mock_response # Execute function result = await test_case["function"](**test_case["input"]) # Verify API call and result mock_dag_api.patch_dag.assert_called_once_with(**test_case["expected_call_kwargs"]) # Verify the DAG object has correct is_paused value actual_call_args = mock_dag_api.patch_dag.call_args actual_dag = actual_call_args.kwargs["dag"] assert actual_dag["is_paused"] == test_case["expected_dag_is_paused"] assert len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "get_tasks_no_order", "input": {"dag_id": "test_dag"}, "mock_response": { "tasks": [ {"task_id": "task1", "operator": "BashOperator"}, {"task_id": "task2", "operator": "PythonOperator"}, ] }, "expected_call_kwargs": {"dag_id": "test_dag"}, }, { "name": "get_tasks_with_order", "input": {"dag_id": "test_dag", "order_by": "task_id"}, "mock_response": { "tasks": [ {"task_id": "task1", "operator": "BashOperator"}, {"task_id": "task2", "operator": "PythonOperator"}, ] }, "expected_call_kwargs": {"dag_id": "test_dag", "order_by": "task_id"}, }, ], ) async def test_get_tasks_table_driven(self, test_case, mock_dag_api): """Table-driven test for get_tasks function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.get_tasks.return_value = mock_response # Execute function result = await get_tasks(**test_case["input"]) # Verify API call and result mock_dag_api.get_tasks.assert_called_once_with(**test_case["expected_call_kwargs"]) assert len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "patch_dag_pause_only", "input": {"dag_id": "test_dag", "is_paused": True}, "mock_response": {"dag_id": "test_dag", "is_paused": True}, "expected_update_mask": ["is_paused"], }, { "name": "patch_dag_tags_only", "input": {"dag_id": "test_dag", "tags": ["prod", "daily"]}, "mock_response": {"dag_id": "test_dag", "tags": ["prod", "daily"]}, "expected_update_mask": ["tags"], }, { "name": "patch_dag_both_fields", "input": {"dag_id": "test_dag", "is_paused": False, "tags": ["dev"]}, "mock_response": {"dag_id": "test_dag", "is_paused": False, "tags": ["dev"]}, "expected_update_mask": ["is_paused", "tags"], }, ], ) async def test_patch_dag_table_driven(self, test_case, mock_dag_api): """Table-driven test for patch_dag function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.patch_dag.return_value = mock_response # Execute function with patch("src.airflow.dag.DAG") as mock_dag_class: mock_dag_instance = MagicMock() mock_dag_class.return_value = mock_dag_instance result = await patch_dag(**test_case["input"]) # Verify DAG instance creation and API call expected_update_request = {k: v for k, v in test_case["input"].items() if k != "dag_id"} mock_dag_class.assert_called_once_with(**expected_update_request) mock_dag_api.patch_dag.assert_called_once_with( dag_id=test_case["input"]["dag_id"], dag=mock_dag_instance, update_mask=test_case["expected_update_mask"], ) assert len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "clear_task_instances_minimal", "input": {"dag_id": "test_dag"}, "mock_response": {"message": "Task instances cleared"}, "expected_clear_request": {}, }, { "name": "clear_task_instances_full", "input": { "dag_id": "test_dag", "task_ids": ["task1", "task2"], "start_date": "2023-01-01", "end_date": "2023-01-31", "include_subdags": True, "include_upstream": True, "dry_run": True, }, "mock_response": {"message": "Dry run completed"}, "expected_clear_request": { "task_ids": ["task1", "task2"], "start_date": "2023-01-01", "end_date": "2023-01-31", "include_subdags": True, "include_upstream": True, "dry_run": True, }, }, ], ) async def test_clear_task_instances_table_driven(self, test_case, mock_dag_api): """Table-driven test for clear_task_instances function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.post_clear_task_instances.return_value = mock_response # Execute function with patch("src.airflow.dag.ClearTaskInstances") as mock_clear_class: mock_clear_instance = MagicMock() mock_clear_class.return_value = mock_clear_instance result = await clear_task_instances(**test_case["input"]) # Verify ClearTaskInstances creation and API call mock_clear_class.assert_called_once_with(**test_case["expected_clear_request"]) mock_dag_api.post_clear_task_instances.assert_called_once_with( dag_id=test_case["input"]["dag_id"], clear_task_instances=mock_clear_instance ) assert len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "set_task_state_minimal", "input": {"dag_id": "test_dag", "state": "success"}, "mock_response": {"message": "Task state updated"}, "expected_state_request": {"state": "success"}, }, { "name": "set_task_state_full", "input": { "dag_id": "test_dag", "state": "failed", "task_ids": ["task1"], "execution_date": "2023-01-01T00:00:00Z", "include_upstream": True, "include_downstream": False, "dry_run": True, }, "mock_response": {"message": "Dry run state update"}, "expected_state_request": { "state": "failed", "task_ids": ["task1"], "execution_date": "2023-01-01T00:00:00Z", "include_upstream": True, "include_downstream": False, "dry_run": True, }, }, ], ) async def test_set_task_instances_state_table_driven(self, test_case, mock_dag_api): """Table-driven test for set_task_instances_state function.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] mock_dag_api.post_set_task_instances_state.return_value = mock_response # Execute function with patch("src.airflow.dag.UpdateTaskInstancesState") as mock_state_class: mock_state_instance = MagicMock() mock_state_class.return_value = mock_state_instance result = await set_task_instances_state(**test_case["input"]) # Verify UpdateTaskInstancesState creation and API call mock_state_class.assert_called_once_with(**test_case["expected_state_request"]) mock_dag_api.post_set_task_instances_state.assert_called_once_with( dag_id=test_case["input"]["dag_id"], update_task_instances_state=mock_state_instance ) assert len(result) == 1 assert isinstance(result[0], types.TextContent) @pytest.mark.parametrize( "test_case", [ { "name": "simple_functions_get_dag_source", "function": get_dag_source, "api_method": "get_dag_source", "input": {"file_token": "test_token"}, "mock_response": {"content": "DAG source code"}, "expected_call_kwargs": {"file_token": "test_token"}, }, { "name": "simple_functions_get_dag_tasks", "function": get_dag_tasks, "api_method": "get_tasks", "input": {"dag_id": "test_dag"}, "mock_response": {"tasks": []}, "expected_call_kwargs": {"dag_id": "test_dag"}, }, { "name": "simple_functions_get_task", "function": get_task, "api_method": "get_task", "input": {"dag_id": "test_dag", "task_id": "test_task"}, "mock_response": {"task_id": "test_task", "operator": "BashOperator"}, "expected_call_kwargs": {"dag_id": "test_dag", "task_id": "test_task"}, }, { "name": "simple_functions_delete_dag", "function": delete_dag, "api_method": "delete_dag", "input": {"dag_id": "test_dag"}, "mock_response": {"message": "DAG deleted"}, "expected_call_kwargs": {"dag_id": "test_dag"}, }, { "name": "simple_functions_reparse_dag_file", "function": reparse_dag_file, "api_method": "reparse_dag_file", "input": {"file_token": "test_token"}, "mock_response": {"message": "DAG file reparsed"}, "expected_call_kwargs": {"file_token": "test_token"}, }, ], ) async def test_simple_functions_table_driven(self, test_case, mock_dag_api): """Table-driven test for simple functions that directly call API methods.""" # Setup mock response mock_response = MagicMock() mock_response.to_dict.return_value = test_case["mock_response"] getattr(mock_dag_api, test_case["api_method"]).return_value = mock_response # Execute function result = await test_case["function"](**test_case["input"]) # Verify API call and result getattr(mock_dag_api, test_case["api_method"]).assert_called_once_with(**test_case["expected_call_kwargs"]) assert len(result) == 1 assert isinstance(result[0], types.TextContent) assert str(test_case["mock_response"]) in result[0].text @pytest.mark.integration async def test_dag_functions_integration_flow(self, mock_dag_api): """Integration test showing typical DAG management workflow.""" # Test data for a complete workflow dag_id = "integration_test_dag" # Mock responses for each step mock_responses = { "get_dag": {"dag_id": dag_id, "is_paused": True}, "patch_dag": {"dag_id": dag_id, "is_paused": False}, "get_tasks": {"tasks": [{"task_id": "task1"}, {"task_id": "task2"}]}, "delete_dag": {"message": "DAG deleted successfully"}, } # Setup mock responses for method, response in mock_responses.items(): mock_response = MagicMock() mock_response.to_dict.return_value = response getattr(mock_dag_api, method).return_value = mock_response # Execute workflow steps with patch("src.airflow.dag.AIRFLOW_HOST", "http://localhost:8080"): # 1. Get DAG info dag_info = await get_dag(dag_id) assert len(dag_info) == 1 # 2. Unpause DAG with patch("src.airflow.dag.DAG") as mock_dag_class: mock_dag_class.return_value = MagicMock() unpause_result = await patch_dag(dag_id, is_paused=False) assert len(unpause_result) == 1 # 3. Get tasks tasks_result = await get_tasks(dag_id) assert len(tasks_result) == 1 # 4. Delete DAG delete_result = await delete_dag(dag_id) assert len(delete_result) == 1 # Verify all API calls were made mock_dag_api.get_dag.assert_called_once_with(dag_id=dag_id) mock_dag_api.patch_dag.assert_called_once() mock_dag_api.get_tasks.assert_called_once_with(dag_id=dag_id) mock_dag_api.delete_dag.assert_called_once_with(dag_id=dag_id)

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/yangkyeongmo/mcp-server-apache-airflow'

If you have feedback or need assistance with the MCP directory API, please join our Discord server