# workspace.py
import abc
import asyncio
import json
import logging
import time
from typing import Any, Literal, Mapping, Sequence, cast
from urllib.parse import urlunparse
from httpx import HTTPStatusError
from pydantic import Field, TypeAdapter
from pydantic.dataclasses import dataclass
from keboola_mcp_server.clients.base import JsonDict
from keboola_mcp_server.clients.client import KeboolaClient
from keboola_mcp_server.clients.query import QueryServiceClient
LOG = logging.getLogger(__name__)
@dataclass(frozen=True)
class TableFqn:
    """The properly quoted parts of a fully qualified table name."""
    # TODO: refactor this and probably use just a simple string
    db_name: str  # the project ID on BigQuery
    schema_name: str  # the dataset on BigQuery
    table_name: str
    quote_char: str = ''
    @property
    def identifier(self) -> str:
        """Returns the properly quoted database identifier."""
        return '.'.join(
            f'{self.quote_char}{n}{self.quote_char}' for n in [self.db_name, self.schema_name, self.table_name]
        )
    def __repr__(self) -> str:
        return self.identifier
    def __str__(self) -> str:
        return self.identifier
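# Illustrative example (hypothetical names): with the Snowflake quote character '"',
#   TableFqn('SAPI_9000', 'in.c-sales', 'orders', quote_char='"').identifier
# evaluates to '"SAPI_9000"."in.c-sales"."orders"'.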
@dataclass(frozen=True)
class DbColumnInfo:
    name: str
    quoted_name: str
    native_type: str
    nullable: bool
@dataclass(frozen=True)
class DbTableInfo:
    id: str
    fqn: TableFqn
    columns: Mapping[str, DbColumnInfo]
QueryStatus = Literal['ok', 'error']
SqlSelectDataRow = Mapping[str, Any]
@dataclass(frozen=True)
class SqlSelectData:
    columns: Sequence[str] = Field(description='Names of the columns returned from SQL select.')
    rows: Sequence[SqlSelectDataRow] = Field(
        description='Selected rows, each row is a dictionary of column: value pairs.'
    )
@dataclass(frozen=True)
class QueryResult:
    status: QueryStatus = Field(description='Status of running the SQL query.')
    data: SqlSelectData | None = Field(None, description='Data selected by the SQL SELECT query.')
    message: str | None = Field(None, description='Either an error message or the information from non-SELECT queries.')
    @property
    def is_ok(self) -> bool:
        return self.status == 'ok'
    @property
    def is_error(self) -> bool:
        return not self.is_ok
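# Illustrative result shapes (hypothetical values):
#   a successful SELECT:    QueryResult(status='ok', data=SqlSelectData(columns=['id'], rows=[{'id': 1}]), message=None)
#   a non-SELECT statement: QueryResult(status='ok', data=None, message='<info from the backend>')
#   a failed query:         QueryResult(status='error', data=None, message='<error message>')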
class _Workspace(abc.ABC):
    _QUERY_TIMEOUT = 300.0  # 5 minutes
    def __init__(self, workspace_id: int) -> None:
        self._workspace_id = workspace_id
    @property
    def id(self) -> int:
        return self._workspace_id
    @abc.abstractmethod
    def get_sql_dialect(self) -> str:
        """Returns the name of the SQL dialect used by the workspace backend."""
    @abc.abstractmethod
    def get_quoted_name(self, name: str) -> str:
        """Returns the name wrapped in the backend's identifier quote character."""
    @abc.abstractmethod
    async def get_table_info(self, table: Mapping[str, Any]) -> DbTableInfo | None:
        """Resolves the database-level info of a Storage API table, or None if it cannot be resolved."""
        # TODO: use a pydantic class for the 'table' param
    @abc.abstractmethod
    async def execute_query(self, sql_query: str) -> QueryResult:
        """Runs a SQL query; SELECT queries return data, other statements return only a message."""
class _SnowflakeWorkspace(_Workspace):
    def __init__(self, workspace_id: int, schema: str, client: KeboolaClient):
        super().__init__(workspace_id)
        self._schema = schema  # default schema created for the workspace
        self._client = client
        self._qsclient: QueryServiceClient | None = None
    def get_sql_dialect(self) -> str:
        return 'Snowflake'
    def get_quoted_name(self, name: str) -> str:
        return f'"{name}"'  # wrap name in double quotes
    async def get_table_info(self, table: Mapping[str, Any]) -> DbTableInfo | None:
        table_id = table['id']
        db_name: str | None = None
        schema_name: str | None = None
        table_name: str | None = None
        if source_table := table.get('sourceTable'):
            # a table linked from some other project
            schema_name, table_name = source_table['id'].rsplit(sep='.', maxsplit=1)
            source_project_id = source_table['project']['id']
            # sql = f"show databases like '%_{source_project_id}';"
            sql = (
                f'select "DATABASE_NAME" from "INFORMATION_SCHEMA"."DATABASES" '
                f'where "DATABASE_NAME" like \'%_{source_project_id}\';'
            )
            result = await self.execute_query(sql)
            if result.is_ok and result.data and result.data.rows:
                db_name = result.data.rows[0]['DATABASE_NAME']
            else:
                LOG.error(f'Failed to run SQL: {sql}, query result: {result}')
        else:
            sql = 'select CURRENT_DATABASE() as "current_database";'
            result = await self.execute_query(sql)
            if result.is_ok and result.data and result.data.rows:
                row = result.data.rows[0]
                db_name = row['current_database']
                if '.' in table_id:
                    # a table local to the project that this workspace is connected to
                    schema_name, table_name = table_id.rsplit(sep='.', maxsplit=1)
                else:
                    # a table not in the project, but in the writable schema created for the workspace
                    # TODO: we should never come here, because the tools for listing tables can only see
                    #  tables that are in the project
                    schema_name = self._schema
                    table_name = table['name']
            else:
                LOG.error(f'Failed to run SQL: {sql}, query result: {result}')
        if db_name and schema_name and table_name:
            sql = (
                f'SELECT "COLUMN_NAME", "DATA_TYPE", "IS_NULLABLE" '
                f'FROM "INFORMATION_SCHEMA"."COLUMNS" '
                f'WHERE "TABLE_CATALOG" = \'{db_name}\' AND "TABLE_SCHEMA" = \'{schema_name}\' '
                f'AND "TABLE_NAME" = \'{table_name}\' '
                f'ORDER BY "ORDINAL_POSITION";'
            )
            result = await self.execute_query(sql)
            if result.is_ok and result.data:
                return DbTableInfo(
                    id=table_id,
                    fqn=TableFqn(db_name, schema_name, table_name, quote_char='"'),
                    columns={
                        row['COLUMN_NAME']: DbColumnInfo(
                            name=row['COLUMN_NAME'],
                            quoted_name=self.get_quoted_name(row['COLUMN_NAME']),
                            native_type=row['DATA_TYPE'],
                            nullable=row['IS_NULLABLE'] == 'YES',
                        )
                        for row in result.data.rows
                    },
                )
            else:
                LOG.error(f'Failed to run SQL: {sql}, query result: {result}')
        return None
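    # The polling in execute_query() below relies on the Query Service payloads in the
    # shapes they are consumed here (sketched with hypothetical values):
    #   submit_job(...)            -> 'job-123'
    #   get_job_status('job-123')  -> {'status': 'completed', 'statements': [{'id': 'stmt-1'}]}
    #   get_job_results('job-123', 'stmt-1')
    #       -> {'status': 'completed', 'columns': [{'name': 'id'}], 'data': [[1], [2]], 'message': '...'}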
    async def execute_query(self, sql_query: str) -> QueryResult:
        if not self._qsclient:
            self._qsclient = await self._create_qs_client()
        ts_start = time.perf_counter()
        job_id = await self._qsclient.submit_job(statements=[sql_query], workspace_id=str(self.id))
        while (job_status := await self._qsclient.get_job_status(job_id)) and job_status['status'] not in [
            'completed',
            'failed',
            'canceled',
        ]:
            await asyncio.sleep(1)
            elapsed_time = time.perf_counter() - ts_start
            if elapsed_time > self._QUERY_TIMEOUT:
                raise RuntimeError(
                    f'Query execution timed out after {elapsed_time:.2f} seconds: '
                    f'job_id={job_id}, status={job_status["status"]}'
                )
        statement_id = cast(list[JsonDict], job_status['statements'])[0]['id']  # a single statement was submitted
        results = await self._qsclient.get_job_results(job_id, statement_id)
        if results['status'] == 'completed':
            columns = [col['name'] for col in cast(list[JsonDict], results['columns'])]
            rows = [dict(zip(columns, row)) for row in cast(list[list[Any]], results['data'])]
            query_result = QueryResult(
                status='ok',
                data=SqlSelectData(columns=columns, rows=rows) if columns else None,
                message=results['message'],
            )
        elif results['status'] in ['failed', 'canceled']:
            query_result = QueryResult(status='error', data=None, message=results['message'])
        else:
            raise ValueError(f'Unexpected query status: {results["status"]}')
        return query_result
    async def _create_qs_client(self) -> QueryServiceClient:
        real_branch_id = self._client.branch_id
        if not real_branch_id:
            for branch in await self._client.storage_client.branches_list():
                if branch.get('isDefault') is True:
                    real_branch_id = branch['id']
                    break
        if not real_branch_id:
            raise RuntimeError('Cannot determine the default branch ID')
        return QueryServiceClient.create(
            root_url=urlunparse(('https', f'query.{self._client.hostname_suffix}', '', '', '', '')),
            branch_id=real_branch_id,
            token=self._client.token,
            headers=self._client.headers,
        )
class _BigQueryWorkspace(_Workspace):
    _BQ_FIELDS = {'_timestamp'}
    def __init__(self, workspace_id: int, dataset_id: str, project_id: str, client: KeboolaClient):
        super().__init__(workspace_id)
        self._dataset_id = dataset_id  # default dataset created for the workspace
        self._project_id = project_id
        self._client = client
    def get_sql_dialect(self) -> str:
        return 'BigQuery'
    def get_quoted_name(self, name: str) -> str:
        return f'`{name}`'  # wrap the name in backticks
    async def get_table_info(self, table: Mapping[str, Any]) -> DbTableInfo | None:
        table_id = table['id']
        if '.' in table_id:
            # a table local to the project that this workspace is connected to
            schema_name, table_name = table_id.rsplit(sep='.', maxsplit=1)
            # dataset names replace the dots and dashes from the bucket ID with underscores
            schema_name = schema_name.replace('.', '_').replace('-', '_')
        else:
            # a table not in the project, but in the writable schema created for the workspace
            # TODO: we should never come here, because the tools for listing tables can only see
            #  tables that are in the project
            schema_name = self._dataset_id
            table_name = table['name']
        if schema_name and table_name:
            sql = (
                f'SELECT column_name, data_type, is_nullable '
                f'FROM `{self._project_id}`.`{schema_name}`.`INFORMATION_SCHEMA`.`COLUMNS` '
                f"WHERE table_name = '{table_name}' "
                f'ORDER BY ordinal_position;'
            )
            result = await self.execute_query(sql)
            if result.is_ok and result.data:
                return DbTableInfo(
                    id=table_id,
                    fqn=TableFqn(self._project_id, schema_name, table_name, quote_char='`'),
                    columns={
                        row['column_name']: DbColumnInfo(
                            name=row['column_name'],
                            quoted_name=self.get_quoted_name(row['column_name']),
                            native_type=row['data_type'],
                            nullable=row['is_nullable'] == 'YES',
                        )
                        for row in result.data.rows
                    },
                )
            else:
                LOG.error(f'Failed to run SQL: {sql}, SAPI response: {result}')
        return None
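    # execute_query() below delegates to the SAPI workspace query endpoint and validates the
    # response straight into QueryResult, so the response is expected to already match that
    # schema, e.g. (hypothetical values):
    #   {'status': 'ok', 'data': {'columns': ['id'], 'rows': [{'id': 1}]}, 'message': None}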
    async def execute_query(self, sql_query: str) -> QueryResult:
        resp = await self._client.storage_client.workspace_query(workspace_id=self.id, query=sql_query)
        return TypeAdapter(QueryResult).validate_python(resp)
@dataclass(frozen=True)
class _WspInfo:
    id: int
    schema: str
    backend: str
    credentials: str | None  # the backend credentials; it can contain serialized JSON data
    readonly: bool | None
    @staticmethod
    def from_sapi_info(sapi_wsp_info: Mapping[str, Any]) -> '_WspInfo':
        conn = sapi_wsp_info.get('connection', {})
        return _WspInfo(
            id=sapi_wsp_info.get('id'),
            schema=conn.get('schema'),
            backend=conn.get('backend'),
            credentials=conn.get('user'),
            readonly=sapi_wsp_info.get('readOnlyStorageAccess'),
        )
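# Illustrative mapping (hypothetical values): a SAPI workspace detail such as
#   {'id': 123, 'readOnlyStorageAccess': True,
#    'connection': {'backend': 'snowflake', 'schema': 'WORKSPACE_123', 'user': '...'}}
# becomes
#   _WspInfo(id=123, schema='WORKSPACE_123', backend='snowflake', credentials='...', readonly=True)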
class WorkspaceManager:
    STATE_KEY = 'workspace_manager'
    MCP_META_KEY = 'KBC.McpServer.workspaceId'
    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> 'WorkspaceManager':
        instance = state[cls.STATE_KEY]
        assert isinstance(instance, WorkspaceManager), f'Expected WorkspaceManager, got: {instance}'
        return instance
    def __init__(self, client: KeboolaClient, workspace_schema: str | None = None):
        # The read-only workspace has access to all project data, which lives in the production branch,
        # so we need a KeboolaClient bound to the production/default branch.
        self._client = client.with_branch_id(None)
        self._workspace_schema = workspace_schema
        self._workspace: _Workspace | None = None
        self._table_info_cache: dict[str, DbTableInfo] = {}
    async def _find_ws_by_schema(self, schema: str) -> _WspInfo | None:
        """Finds the workspace info by its schema."""
        for sapi_wsp_info in await self._client.storage_client.workspace_list():
            assert isinstance(sapi_wsp_info, dict)
            wi = _WspInfo.from_sapi_info(sapi_wsp_info)  # type: ignore[attr-defined]
            if wi.id and wi.backend and wi.schema and wi.schema == schema:
                return wi
        return None
    async def _find_ws_by_id(self, workspace_id: int) -> _WspInfo | None:
        """Finds the workspace info by its ID."""
        try:
            sapi_wsp_info = await self._client.storage_client.workspace_detail(workspace_id)
            assert isinstance(sapi_wsp_info, dict)
            wi = _WspInfo.from_sapi_info(sapi_wsp_info)  # type: ignore[attr-defined]
            if wi.id and wi.backend and wi.schema:
                return wi
            else:
                raise ValueError(f'Invalid workspace info: {sapi_wsp_info}')
        except HTTPStatusError as e:
            if e.response.status_code == 404:
                return None
            else:
                raise
    async def _find_ws_in_branch(self) -> _WspInfo | None:
        """Finds the workspace info in the current branch."""
        metadata = await self._client.storage_client.branch_metadata_get()
        for m in metadata:
            if m.get('key') == self.MCP_META_KEY:
                workspace_id = m.get('value')
                if workspace_id and (info := await self._find_ws_by_id(workspace_id)) and info.readonly:
                    return info
        return None
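    # Branch metadata, as consumed above, is a list of key/value pairs, e.g. (hypothetical):
    #   [{'key': 'KBC.McpServer.workspaceId', 'value': '12345'}, ...]
    # Note that metadata values are assumed to arrive as strings.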
    async def _create_ws(self, *, timeout_sec: float = 300.0) -> _WspInfo | None:
        """
        Creates a new workspace in the current branch and returns its info.
        :param timeout_sec: The number of seconds to wait for the workspace creation job to finish.
        :return: The workspace info if the workspace was created successfully, None otherwise.
        """
        # Verify token before creating workspace to ensure it has proper permissions
        token_info = await self._client.storage_client.verify_token()
        # Check for defaultBackend parameter in token info under owner object
        owner_info = token_info.get('owner', {})
        default_backend = owner_info.get('defaultBackend')
        if default_backend == 'snowflake':
            resp = await self._client.storage_client.workspace_create(
                login_type='snowflake-person-sso',
                backend=default_backend,
                async_run=True,
                read_only_storage_access=True,
            )
        elif default_backend == 'bigquery':
            resp = await self._client.storage_client.workspace_create(
                login_type='default', backend=default_backend, async_run=True, read_only_storage_access=True
            )
        else:
            raise ValueError(f'Unexpected default backend: {default_backend}')
        assert 'id' in resp, f'Expected job ID in response: {resp}'
        assert isinstance(resp['id'], int)
        job_id = resp['id']
        start_ts = time.perf_counter()
        LOG.info(f'Requested new workspace: job_id={job_id}, timeout={timeout_sec:.2f} seconds')
        while True:
            job_info = await self._client.storage_client.job_detail(job_id)
            job_status = job_info['status']
            duration = time.perf_counter() - start_ts
            LOG.info(
                f'Job info: job_id={job_id}, status={job_status}, '
                f'duration={duration:.2f} seconds, timeout={timeout_sec:.2f} seconds'
            )
            if job_status == 'success':
                assert 'results' in job_info, f'Expected `results` in job info: {job_info}'
                job_results = job_info['results']
                assert isinstance(job_results, dict)
                assert 'id' in job_results, f'Expected `id` in `results` in job info: {job_info}'
                assert isinstance(job_results['id'], int)
                workspace_id = job_results['id']
                LOG.info(f'Created workspace: {workspace_id}')
                return await self._find_ws_by_id(workspace_id)
            elif job_status == 'error':
                # assuming 'error' is the terminal failure status; no point in polling further
                LOG.error(f'Workspace creation failed: job_id={job_id}, job_info={job_info}')
                return None
            elif duration > timeout_sec:
                LOG.info(f'Workspace creation timed out after {duration:.2f} seconds.')
                return None
            else:
                remaining_time = max(0.0, timeout_sec - duration)
                await asyncio.sleep(min(5.0, remaining_time))
    def _init_workspace(self, info: _WspInfo) -> _Workspace:
        """Creates a new `Workspace` instance based on the workspace info."""
        if info.backend == 'snowflake':
            return _SnowflakeWorkspace(workspace_id=info.id, schema=info.schema, client=self._client)
        elif info.backend == 'bigquery':
            credentials = json.loads(info.credentials or '{}')
            if project_id := credentials.get('project_id'):
                return _BigQueryWorkspace(
                    workspace_id=info.id,
                    dataset_id=info.schema,
                    project_id=project_id,
                    client=self._client,
                )
            else:
                raise ValueError(f'No credentials or no project ID in workspace: {info.schema}')
        else:
            raise ValueError(f'Unexpected backend type "{info.backend}" in workspace: {info.schema}')
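    # For BigQuery, the workspace credentials hold a serialized JSON document (service-account
    # style); only 'project_id' is read here, e.g. (hypothetical):
    #   {'project_id': 'my-gcp-project', ...}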
    async def _get_workspace(self) -> _Workspace:
        if self._workspace:
            return self._workspace
        if self._workspace_schema:
            # use the workspace that was explicitly requested
            # this workspace must never be written to the default branch metadata
            LOG.info(f'Looking up workspace by schema: {self._workspace_schema}')
            if info := await self._find_ws_by_schema(self._workspace_schema):
                LOG.info(f'Found workspace: {info}')
                self._workspace = self._init_workspace(info)
                return self._workspace
            else:
                raise ValueError(
                    f'No Keboola workspace found or the workspace has no read-only storage access: '
                    f'workspace_schema={self._workspace_schema}'
                )
        LOG.info('Looking up workspace in the default branch.')
        if info := await self._find_ws_in_branch():
            # use the workspace that the MCP server created earlier and recorded in the branch metadata
            LOG.info(f'Found workspace: {info}')
            self._workspace = self._init_workspace(info)
            return self._workspace
        # create a new workspace and record its ID in the branch metadata
        LOG.info('Creating workspace in the default branch.')
        if info := await self._create_ws():
            # update the branch metadata with the workspace ID
            meta = await self._client.storage_client.branch_metadata_update({self.MCP_META_KEY: info.id})
            LOG.info(f'Set metadata in the default branch: {meta}')
            # use the newly created workspace
            self._workspace = self._init_workspace(info)
            return self._workspace
        else:
            raise ValueError('Failed to initialize Keboola Workspace.')
    async def execute_query(self, sql_query: str) -> QueryResult:
        workspace = await self._get_workspace()
        return await workspace.execute_query(sql_query)
    async def get_table_info(self, table: Mapping[str, Any]) -> DbTableInfo | None:
        table_id = table['id']
        if table_id in self._table_info_cache:
            return self._table_info_cache[table_id]
        workspace = await self._get_workspace()
        if info := await workspace.get_table_info(table):
            self._table_info_cache[table_id] = info
        return info
    async def get_quoted_name(self, name: str) -> str:
        workspace = await self._get_workspace()
        return workspace.get_quoted_name(name)
    async def get_sql_dialect(self) -> str:
        workspace = await self._get_workspace()
        return workspace.get_sql_dialect()
    async def get_workspace_id(self) -> int:
        workspace = await self._get_workspace()
        return workspace.id
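# Minimal usage sketch (assumes an already configured KeboolaClient; the values are hypothetical):
#   client = KeboolaClient(...)
#   manager = WorkspaceManager(client)
#   result = await manager.execute_query('SELECT 1 AS "one";')
#   if result.is_ok and result.data:
#       print(result.data.rows)  # -> [{'one': 1}]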