# test_api.py (28.1 kB)
"""
Unit tests for the Treasure Data API client.
"""
import pytest
import requests
import responses
from td_mcp_server.api import Database, Project, Table, TreasureDataClient
class TestTreasureDataClient:
    """Tests for the TreasureDataClient class."""
    def setup_method(self):
        """Build a fresh client and the canned API payloads used by the tests.

        pytest runs this hook before each test method, so every test starts
        with its own client instance and its own copies of the fixture data.
        """
        self.api_key = "test_api_key"
        self.endpoint = "api.treasuredata.com"
        self.client = TreasureDataClient(api_key=self.api_key, endpoint=self.endpoint)

        def _database(name, stamp, count, protected):
            # One entry of the "databases" list payload.
            return {
                "name": name,
                "created_at": stamp,
                "updated_at": stamp,
                "count": count,
                "organization": None,
                "permission": "administrator",
                "delete_protected": protected,
            }

        def _table(
            table_id, name, size, iso_stamp, utc_stamp, protected, count, schema, expire_days
        ):
            # One entry of the "tables" list payload.
            return {
                "id": table_id,
                "name": name,
                "estimated_storage_size": size,
                "counter_updated_at": iso_stamp,
                "last_log_timestamp": iso_stamp,
                "delete_protected": protected,
                "created_at": utc_stamp,
                "updated_at": utc_stamp,
                "type": "log",
                "include_v": True,
                "count": count,
                "schema": schema,
                "expire_days": expire_days,
            }

        def _project(project_id, name, updated_at, archive_md5, metadata):
            # One entry of the "projects" list payload.
            return {
                "id": project_id,
                "name": name,
                "revision": "abcdef1234567890abcdef1234567890",
                "createdAt": "2022-01-01T00:00:00Z",
                "updatedAt": updated_at,
                "deletedAt": None,
                "archiveType": "s3",
                "archiveMd5": archive_md5,
                "metadata": metadata,
            }

        self.mock_databases = [
            _database("db1", "2023-01-01 00:00:00 UTC", 3, False),
            _database("db2", "2023-01-02 00:00:00 UTC", 5, True),
            _database("db3", "2023-01-03 00:00:00 UTC", 0, False),
        ]
        self.mock_tables = [
            _table(
                table_id=1234,
                name="table1",
                size=10000,
                iso_stamp="2023-01-01T00:00:00Z",
                utc_stamp="2023-01-01 00:00:00 UTC",
                protected=False,
                count=100,
                schema='[["id","string"],["name","string"]]',
                expire_days=None,
            ),
            _table(
                table_id=5678,
                name="table2",
                size=20000,
                iso_stamp="2023-01-02T00:00:00Z",
                utc_stamp="2023-01-02 00:00:00 UTC",
                protected=True,
                count=200,
                schema='[["id","string"],["value","integer"]]',
                expire_days=30,
            ),
        ]
        self.mock_projects = [
            _project(
                project_id="123456",
                name="demo_content_affinity",
                updated_at="2022-01-02T00:00:00Z",
                archive_md5="abcdefghijklmnopqrstuvwx==",
                metadata=[],
            ),
            _project(
                project_id="789012",
                name="cdp_audience_123456",
                updated_at="2023-01-01T00:00:00Z",
                archive_md5="zyxwvutsrqponmlkjihgfed==",
                metadata=[
                    {"key": "pbp", "value": "cdp_audience"},
                    {"key": "pbp", "value": "cdp_audience_123456"},
                    {"key": "sys", "value": "cdp_audience"},
                ],
            ),
        ]
    def test_init(self):
        """Check the attributes exposed by the client built in setup_method."""
        client = self.client
        assert (client.api_key, client.endpoint) == (self.api_key, self.endpoint)
        assert client.base_url == f"https://{self.endpoint}/v3"
        headers = client.headers
        assert headers["Authorization"] == f"TD1 {self.api_key}"
        assert headers["Content-Type"] == "application/json"
    def test_init_from_env(self, monkeypatch):
        """Check that the API key is taken from TD_API_KEY when none is passed."""
        env_key = "env_api_key"
        monkeypatch.setenv("TD_API_KEY", env_key)
        env_client = TreasureDataClient()
        assert env_client.api_key == env_key
        # No endpoint argument was given, so the default endpoint is expected.
        assert env_client.endpoint == "api.treasuredata.com"
    def test_init_missing_api_key(self, monkeypatch):
        """Check that constructing a client without any API key raises ValueError."""
        # Make sure the environment cannot supply a key either.
        monkeypatch.delenv("TD_API_KEY", raising=False)
        expected_failure = pytest.raises(ValueError, match="API key must be provided")
        with expected_failure:
            TreasureDataClient()
    @responses.activate
    def test_get_databases(self):
        """Check that get_databases turns the list payload into Database objects."""
        # Serve the canned database list from the mocked endpoint.
        list_url = f"https://{self.endpoint}/v3/database/list"
        payload = {"databases": self.mock_databases}
        responses.add(responses.GET, list_url, json=payload, status=200)

        result = self.client.get_databases()

        assert len(result) == 3
        assert isinstance(result[0], Database)
        assert [result[0].name, result[1].name, result[2].name] == [
            "db1",
            "db2",
            "db3",
        ]
        assert result[0].count == 3
        assert result[1].delete_protected is True
    @responses.activate
    def test_get_databases_with_pagination(self):
        """Check the limit, offset and all_results arguments of get_databases."""
        # The same mocked response answers all three calls below.
        list_url = f"https://{self.endpoint}/v3/database/list"
        payload = {"databases": self.mock_databases}
        responses.add(responses.GET, list_url, json=payload, status=200)
        fetch = self.client.get_databases

        # limit + offset: skip the first entry, keep the next two.
        window = fetch(limit=2, offset=1)
        assert len(window) == 2
        assert [window[0].name, window[1].name] == ["db2", "db3"]

        # all_results=True: every mocked database comes back.
        everything = fetch(all_results=True)
        assert len(everything) == 3

        # A limit smaller than the data set: only the first entry.
        head = fetch(limit=1)
        assert len(head) == 1
        assert head[0].name == "db1"
    @responses.activate
    def test_get_database(self):
        """Check get_database for a name that exists and for one that does not."""
        list_url = f"https://{self.endpoint}/v3/database/list"
        payload = {"databases": self.mock_databases}
        responses.add(responses.GET, list_url, json=payload, status=200)

        # A name present in the mocked list yields the matching record.
        found = self.client.get_database("db2")
        assert found is not None
        assert found.name == "db2"
        assert found.count == 5
        assert found.delete_protected is True

        # A name absent from the mocked list yields None.
        missing = self.client.get_database("nonexistent")
        assert missing is None
    @responses.activate
    def test_get_tables(self):
        """Check that get_tables turns the list payload into Table objects."""
        database_name = "test_db"
        # Serve the canned table list for this database.
        tables_url = f"https://{self.endpoint}/v3/table/list/{database_name}"
        payload = {"tables": self.mock_tables}
        responses.add(responses.GET, tables_url, json=payload, status=200)

        result = self.client.get_tables(database_name)

        assert len(result) == 2
        assert isinstance(result[0], Table)
        assert [result[0].name, result[1].name] == ["table1", "table2"]
        assert result[0].count == 100
        assert result[1].expire_days == 30
    @responses.activate
    def test_get_tables_with_pagination(self):
        """Check the limit, offset and all_results arguments of get_tables."""
        database_name = "test_db"
        # The same mocked response answers all three calls below.
        tables_url = f"https://{self.endpoint}/v3/table/list/{database_name}"
        payload = {"tables": self.mock_tables}
        responses.add(responses.GET, tables_url, json=payload, status=200)
        fetch = self.client.get_tables

        # limit + offset: only the second table is returned.
        window = fetch(database_name, limit=1, offset=1)
        assert len(window) == 1
        assert window[0].name == "table2"

        # all_results=True: both mocked tables come back.
        everything = fetch(database_name, all_results=True)
        assert len(everything) == 2

        # A limit larger than the data set: still just the two tables.
        generous = fetch(database_name, limit=10)
        assert len(generous) == 2
    @responses.activate
    def test_make_request_error(self):
        """Check that _make_request raises HTTPError on a 500 response."""
        # Have the mocked endpoint answer with a server error.
        failing_url = f"https://{self.endpoint}/v3/error"
        error_body = {"error": "Something went wrong"}
        responses.add(responses.GET, failing_url, json=error_body, status=500)

        expected_failure = pytest.raises(requests.exceptions.HTTPError)
        with expected_failure:
            self.client._make_request("GET", "error")
    @responses.activate
    def test_get_projects(self):
        """Check that get_projects turns the list payload into Project objects."""
        # Projects are served from the workflow host, not the main API host.
        workflow_endpoint = "api-workflow.treasuredata.com"
        projects_url = f"https://{workflow_endpoint}/api/projects"
        payload = {"projects": self.mock_projects}
        responses.add(responses.GET, projects_url, json=payload, status=200)

        result = self.client.get_projects()

        assert len(result) == 2
        first, second = result[0], result[1]
        assert isinstance(first, Project)
        assert first.id == "123456"
        assert first.name == "demo_content_affinity"
        assert second.id == "789012"
        assert second.name == "cdp_audience_123456"
        assert len(second.metadata) == 3
        leading_entry = second.metadata[0]
        assert leading_entry.key == "pbp"
        assert leading_entry.value == "cdp_audience"
    @responses.activate
    def test_get_projects_with_pagination(self):
        """Check the limit, offset and all_results arguments of get_projects."""
        # The same mocked response answers all three calls below.
        workflow_endpoint = "api-workflow.treasuredata.com"
        projects_url = f"https://{workflow_endpoint}/api/projects"
        payload = {"projects": self.mock_projects}
        responses.add(responses.GET, projects_url, json=payload, status=200)
        fetch = self.client.get_projects

        # limit + offset: only the second project is returned.
        window = fetch(limit=1, offset=1)
        assert len(window) == 1
        assert window[0].id == "789012"
        assert window[0].name == "cdp_audience_123456"

        # all_results=True: both mocked projects come back.
        everything = fetch(all_results=True)
        assert len(everything) == 2

        # A limit larger than the data set: still just the two projects.
        generous = fetch(limit=10)
        assert len(generous) == 2
    def test_workflow_endpoint_derivation(self):
        """Check how workflow_endpoint is derived, and that it can be overridden."""
        # (API endpoint, expected derived workflow endpoint)
        derived_cases = (
            # US region standard pattern
            ("api.treasuredata.com", "api-workflow.treasuredata.com"),
            # Japan region pattern
            ("api.treasuredata.co.jp", "api-workflow.treasuredata.co.jp"),
            # Non-standard region
            ("api.treasuredata.eu", "api-workflow.treasuredata.eu"),
            # Different domain structure: the replacement is still expected
            ("treasuredata-api.com", "treasuredata-api-workflow.com"),
        )
        for api_host, expected_workflow_host in derived_cases:
            derived = TreasureDataClient(api_key=self.api_key, endpoint=api_host)
            assert derived.workflow_endpoint == expected_workflow_host

        # An explicitly supplied workflow endpoint is used as given.
        custom_endpoint = "custom-workflow.example.com"
        overridden = TreasureDataClient(
            api_key=self.api_key,
            endpoint="api.example.com",
            workflow_endpoint=custom_endpoint,
        )
        assert overridden.workflow_endpoint == custom_endpoint
    @responses.activate
    def test_get_project(self):
        """Check that get_project maps a single-project payload onto a Project."""
        project_id = "123456"
        workflow_endpoint = "api-workflow.treasuredata.com"
        expected_revision = "abcdef1234567890abcdef1234567890"
        # Payload served for this project id by the mocked workflow host.
        project_payload = {
            "id": project_id,
            "name": "demo_content_affinity",
            "revision": expected_revision,
            "createdAt": "2022-01-01T00:00:00Z",
            "updatedAt": "2022-01-02T00:00:00Z",
            "deletedAt": None,
            "archiveType": "s3",
            "archiveMd5": "abcdefghijklmnopqrstuvwx==",
            "metadata": [
                {"key": "category", "value": "machine-learning"},
                {"key": "version", "value": "1.0.0"},
            ],
        }
        project_url = f"https://{workflow_endpoint}/api/projects/{project_id}"
        responses.add(responses.GET, project_url, json=project_payload, status=200)

        result = self.client.get_project(project_id)

        assert result is not None
        assert result.id == project_id
        assert result.name == "demo_content_affinity"
        assert result.revision == expected_revision
        # camelCase payload keys are exposed as snake_case attributes.
        assert result.created_at == "2022-01-01T00:00:00Z"
        assert result.updated_at == "2022-01-02T00:00:00Z"
        assert len(result.metadata) == 2
        leading_entry = result.metadata[0]
        assert leading_entry.key == "category"
        assert leading_entry.value == "machine-learning"
    @responses.activate
    def test_get_project_not_found(self):
        """Check that get_project returns None when the server answers 404."""
        project_id = "nonexistent"
        workflow_endpoint = "api-workflow.treasuredata.com"
        # The mocked workflow host reports the project as missing.
        project_url = f"https://{workflow_endpoint}/api/projects/{project_id}"
        not_found_body = {"error": "Project not found"}
        responses.add(responses.GET, project_url, json=not_found_body, status=404)

        result = self.client.get_project(project_id)

        assert result is None
    @responses.activate
    def test_download_project_archive(self, tmp_path):
        """Check that download_project_archive writes the response body to disk."""
        project_id = "123456"
        workflow_endpoint = "api-workflow.treasuredata.com"
        # Destination inside pytest's per-test temporary directory.
        destination = tmp_path / f"project_{project_id}.tar.gz"
        # Stand-in bytes for the archive; not a real tar.gz.
        archive_bytes = b"mock tar.gz content"
        archive_url = f"https://{workflow_endpoint}/api/projects/{project_id}/archive"
        responses.add(responses.GET, archive_url, body=archive_bytes, status=200)

        succeeded = self.client.download_project_archive(project_id, str(destination))

        assert succeeded is True
        assert destination.exists()
        # The file on disk must hold exactly the bytes the server sent.
        written = destination.read_bytes()
        assert written == archive_bytes
    @responses.activate
    def test_download_project_archive_not_found(self, tmp_path):
        """Check that a 404 makes download_project_archive return False, no file."""
        project_id = "nonexistent"
        workflow_endpoint = "api-workflow.treasuredata.com"
        # Destination inside pytest's per-test temporary directory.
        destination = tmp_path / f"project_{project_id}.tar.gz"
        # The mocked workflow host reports the project as missing.
        archive_url = f"https://{workflow_endpoint}/api/projects/{project_id}/archive"
        not_found_body = {"error": "Project not found"}
        responses.add(responses.GET, archive_url, json=not_found_body, status=404)

        succeeded = self.client.download_project_archive(project_id, str(destination))

        assert succeeded is False
        assert not destination.exists()
    @responses.activate
    def test_get_workflow_by_id(self):
        """Check that get_workflow_by_id maps the payload onto a workflow object."""
        workflow_id = "12345678"
        workflow_endpoint = "api-workflow.treasuredata.com"
        expected_revision = "abcdef1234567890abcdef1234567890"
        # Payload served for this workflow id by the mocked workflow host.
        task_definition = {
            "td>": {
                "database": "test_db",
                "engine": "presto",
            }
        }
        workflow_payload = {
            "id": workflow_id,
            "name": "test_workflow",
            "project": {
                "id": "123456",
                "name": "test_project",
            },
            "revision": expected_revision,
            "timezone": "UTC",
            "config": {"+task1": task_definition},
        }
        workflow_url = f"https://{workflow_endpoint}/api/workflows/{workflow_id}"
        responses.add(responses.GET, workflow_url, json=workflow_payload, status=200)

        result = self.client.get_workflow_by_id(workflow_id)

        assert result is not None
        assert result.id == workflow_id
        assert result.name == "test_workflow"
        owner = result.project
        assert owner.id == "123456"
        assert owner.name == "test_project"
        assert result.timezone == "UTC"
        assert result.revision == expected_revision
    @responses.activate
    def test_get_workflow_by_id_not_found(self):
        """Check that get_workflow_by_id returns None when the server answers 404."""
        workflow_id = "nonexistent"
        workflow_endpoint = "api-workflow.treasuredata.com"
        # The mocked workflow host reports the workflow as missing.
        workflow_url = f"https://{workflow_endpoint}/api/workflows/{workflow_id}"
        not_found_body = {"message": "Resource does not exist: workflow id=nonexistent"}
        responses.add(responses.GET, workflow_url, json=not_found_body, status=404)

        result = self.client.get_workflow_by_id(workflow_id)

        assert result is None
    @responses.activate
    def test_get_workflows_with_pagination(self):
        """Check get_workflows for an explicit count and page."""
        workflow_endpoint = "api-workflow.treasuredata.com"
        # Two workflows making up page 1 of the mocked listing.
        first_workflow = {
            "id": "123",
            "name": "workflow1",
            "project": {
                "id": "1",
                "name": "project1",
                "updatedAt": "2023-01-01T00:00:00Z",
            },
            "revision": "abc123",
            "timezone": "UTC",
            "config": {},
            "schedule": None,
            "latestSessions": [],
        }
        second_workflow = {
            "id": "456",
            "name": "workflow2",
            "project": {
                "id": "2",
                "name": "project2",
                "updatedAt": "2023-01-02T00:00:00Z",
            },
            "revision": "def456",
            "timezone": "Asia/Tokyo",
            "config": {},
            "schedule": {"cron": "0 0 * * *"},
            "latestSessions": [],
        }
        mock_workflows_page1 = [first_workflow, second_workflow]
        # The mock only answers when the request carries exactly this query.
        expected_query = {
            "count": "2",
            "page": "1",
            "order": "asc",
            "sessions": "5",
            "output": "simple",
            "project_type": "user",
        }
        query_matcher = responses.matchers.query_param_matcher(expected_query)
        responses.add(
            responses.GET,
            f"https://{workflow_endpoint}/api/console/workflows",
            json={"workflows": mock_workflows_page1},
            status=200,
            match=[query_matcher],
        )

        result = self.client.get_workflows(count=2, page=1)

        assert len(result) == 2
        assert (result[0].id, result[0].name) == ("123", "workflow1")
        assert (result[1].id, result[1].name) == ("456", "workflow2")
    @responses.activate
    def test_get_session(self):
        """Check that get_session maps the payload onto a session object."""
        session_id = "123456789"
        workflow_endpoint = "api-workflow.treasuredata.com"
        # Payload served for this session id by the mocked workflow host.
        last_attempt_payload = {
            "id": "987654321",
            "retryAttemptName": None,
            "done": True,
            "success": True,
            "cancelRequested": False,
            "params": {},
            "createdAt": "2025-08-03T03:00:00Z",
            "finishedAt": "2025-08-03T03:05:30Z",
            "status": "success",
        }
        session_payload = {
            "id": session_id,
            "project": {"id": "123456", "name": "test_project"},
            "workflow": {"name": "test_workflow", "id": "12345678"},
            "sessionUuid": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
            "sessionTime": "2025-08-03T03:00:00+00:00",
            "lastAttempt": last_attempt_payload,
        }
        session_url = f"https://{workflow_endpoint}/api/sessions/{session_id}"
        responses.add(responses.GET, session_url, json=session_payload, status=200)

        result = self.client.get_session(session_id)

        assert result is not None
        assert result.id == session_id
        # project and workflow are read with dict-style access here.
        assert result.project["name"] == "test_project"
        assert result.workflow["name"] == "test_workflow"
        assert result.session_time == "2025-08-03T03:00:00+00:00"
        assert result.last_attempt.success is True
    @responses.activate
    def test_get_sessions(self):
        """Check get_sessions called with its default arguments."""
        workflow_endpoint = "api-workflow.treasuredata.com"
        # The single session returned by the mocked listing.
        last_attempt_payload = {
            "id": "987654321",
            "retryAttemptName": None,
            "done": True,
            "success": True,
            "cancelRequested": False,
            "params": {},
            "createdAt": "2025-08-03T03:00:00Z",
            "finishedAt": "2025-08-03T03:05:30Z",
            "status": "success",
        }
        session_payload = {
            "id": "123456789",
            "project": {"id": "123456", "name": "test_project"},
            "workflow": {"name": "test_workflow", "id": "12345678"},
            "sessionUuid": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
            "sessionTime": "2025-08-03T03:00:00+00:00",
            "lastAttempt": last_attempt_payload,
        }
        # The mock only answers when the request carries exactly last=20.
        query_matcher = responses.matchers.query_param_matcher({"last": "20"})
        responses.add(
            responses.GET,
            f"https://{workflow_endpoint}/api/sessions",
            json={"sessions": [session_payload]},
            status=200,
            match=[query_matcher],
        )

        result = self.client.get_sessions()

        assert len(result) == 1
        only_session = result[0]
        assert only_session.id == "123456789"
        assert only_session.project["name"] == "test_project"
    @responses.activate
    def test_get_attempt(self):
        """Check that get_attempt maps the payload onto an attempt object."""
        attempt_id = "987654321"
        workflow_endpoint = "api-workflow.treasuredata.com"
        parent_session_id = "123456789"
        # Payload served for this attempt id by the mocked workflow host.
        attempt_payload = {
            "id": attempt_id,
            "index": 1,
            "project": {"id": "123456", "name": "test_project"},
            "workflow": {"name": "test_workflow", "id": "12345678"},
            "sessionId": parent_session_id,
            "sessionUuid": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
            "sessionTime": "2025-08-03T03:00:00+00:00",
            "retryAttemptName": None,
            "done": True,
            "success": True,
            "cancelRequested": False,
            "params": {},
            "createdAt": "2025-08-03T03:00:00Z",
            "finishedAt": "2025-08-03T03:05:30Z",
            "status": "success",
        }
        attempt_url = f"https://{workflow_endpoint}/api/attempts/{attempt_id}"
        responses.add(responses.GET, attempt_url, json=attempt_payload, status=200)

        result = self.client.get_attempt(attempt_id)

        assert result is not None
        assert result.id == attempt_id
        assert result.index == 1
        assert result.session_id == parent_session_id
        assert result.success is True
    @responses.activate
    def test_get_attempt_tasks(self):
        """Check that get_attempt_tasks maps the task list onto task objects."""
        attempt_id = "987654321"
        workflow_endpoint = "api-workflow.treasuredata.com"
        group_task_id = "1234567890"
        query_task_id = "1234567891"
        # Top-level group task: it has no parent.
        group_task = {
            "id": group_task_id,
            "fullName": "+main_workflow",
            "parentId": None,
            "config": {},
            "upstreams": [],
            "state": "success",
            "cancelRequested": False,
            "exportParams": {},
            "storeParams": {},
            "stateParams": {},
            "updatedAt": "2025-08-03T03:05:30Z",
            "retryAt": None,
            "startedAt": "2025-08-03T03:00:00Z",
            "error": {},
            "isGroup": True,
        }
        # Child task whose parentId points at the group task above.
        query_task = {
            "id": query_task_id,
            "fullName": "+main_workflow+extract_data",
            "parentId": group_task_id,
            "config": {
                "td>": {
                    "query": "SELECT * FROM test_table",
                    "database": "test_db",
                }
            },
            "upstreams": [],
            "state": "success",
            "cancelRequested": False,
            "exportParams": {},
            "storeParams": {},
            "stateParams": {},
            "updatedAt": "2025-08-03T03:02:15Z",
            "retryAt": None,
            "startedAt": "2025-08-03T03:00:01Z",
            "error": {},
            "isGroup": False,
        }
        tasks_url = f"https://{workflow_endpoint}/api/attempts/{attempt_id}/tasks"
        tasks_payload = {"tasks": [group_task, query_task]}
        responses.add(responses.GET, tasks_url, json=tasks_payload, status=200)

        result = self.client.get_attempt_tasks(attempt_id)

        assert len(result) == 2
        parent, child = result[0], result[1]
        assert parent.id == group_task_id
        assert parent.full_name == "+main_workflow"
        assert parent.is_group is True
        assert child.id == query_task_id
        assert child.full_name == "+main_workflow+extract_data"
        assert child.parent_id == group_task_id