import abc
import asyncio
import json
import logging
import time
from typing import Any, Literal, Mapping, Sequence, cast
from urllib.parse import urlunparse

from httpx import HTTPStatusError
from pydantic import Field, TypeAdapter
from pydantic.dataclasses import dataclass

from keboola_mcp_server.clients.base import JsonDict
from keboola_mcp_server.clients.client import KeboolaClient
from keboola_mcp_server.clients.query import QueryServiceClient
LOG = logging.getLogger(__name__)
@dataclass(frozen=True)
class TableFqn:
"""The properly quoted parts of a fully qualified table name."""
# TODO: refactor this and probably use just a simple string
    db_name: str  # the project ID in BigQuery
    schema_name: str  # the dataset name in BigQuery
table_name: str
quote_char: str = ''
@property
def identifier(self) -> str:
"""Returns the properly quoted database identifier."""
return '.'.join(
f'{self.quote_char}{n}{self.quote_char}' for n in [self.db_name, self.schema_name, self.table_name]
)
    def __repr__(self) -> str:
        return self.identifier
@dataclass(frozen=True)
class DbColumnInfo:
name: str
quoted_name: str
native_type: str
nullable: bool
@dataclass(frozen=True)
class DbTableInfo:
id: str
fqn: TableFqn
columns: Mapping[str, DbColumnInfo]
QueryStatus = Literal['ok', 'error']
SqlSelectDataRow = Mapping[str, Any]
@dataclass(frozen=True)
class SqlSelectData:
columns: Sequence[str] = Field(description='Names of the columns returned from SQL select.')
rows: Sequence[SqlSelectDataRow] = Field(
description='Selected rows, each row is a dictionary of column: value pairs.'
)
@dataclass(frozen=True)
class QueryResult:
status: QueryStatus = Field(description='Status of running the SQL query.')
data: SqlSelectData | None = Field(None, description='Data selected by the SQL SELECT query.')
message: str | None = Field(None, description='Either an error message or the information from non-SELECT queries.')
@property
def is_ok(self) -> bool:
return self.status == 'ok'
@property
def is_error(self) -> bool:
return not self.is_ok
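# Example of a successful SELECT result (an illustrative sketch, not actual data):
#   QueryResult(status='ok', data=SqlSelectData(columns=['id'], rows=[{'id': 1}]), message=None)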
class _Workspace(abc.ABC):
_QUERY_TIMEOUT = 300.0 # 5 minutes
def __init__(self, workspace_id: int) -> None:
self._workspace_id = workspace_id
@property
def id(self) -> int:
return self._workspace_id
@abc.abstractmethod
def get_sql_dialect(self) -> str:
pass
@abc.abstractmethod
def get_quoted_name(self, name: str) -> str:
pass
@abc.abstractmethod
async def get_table_info(self, table: Mapping[str, Any]) -> DbTableInfo | None:
# TODO: use a pydantic class for the 'table' param
pass
@abc.abstractmethod
async def execute_query(self, sql_query: str) -> QueryResult:
"""Runs a SQL SELECT query."""
pass
@abc.abstractmethod
async def get_branch_id(self) -> str:
"""Returns the branch ID."""
pass
class _SnowflakeWorkspace(_Workspace):
def __init__(self, workspace_id: int, schema: str, client: KeboolaClient):
super().__init__(workspace_id)
self._schema = schema # default schema created for the workspace
self._client = client
self._qsclient: QueryServiceClient | None = None
def get_sql_dialect(self) -> str:
return 'Snowflake'
def get_quoted_name(self, name: str) -> str:
return f'"{name}"' # wrap name in double quotes
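    # The `table` argument is SAPI table detail JSON. An illustrative sketch of the fields read below
    # (values are made up): {'id': 'in.c-bucket.table', 'name': 'table',
    #                        'sourceTable': {'id': 'in.c-bucket.table', 'project': {'id': 1234}}}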
async def get_table_info(self, table: Mapping[str, Any]) -> DbTableInfo | None:
table_id = table['id']
db_name: str | None = None
schema_name: str | None = None
table_name: str | None = None
if source_table := table.get('sourceTable'):
# a table linked from some other project
schema_name, table_name = source_table['id'].rsplit(sep='.', maxsplit=1)
source_project_id = source_table['project']['id']
# sql = f"show databases like '%_{source_project_id}';"
sql = (
f'select "DATABASE_NAME" from "INFORMATION_SCHEMA"."DATABASES" '
f'where "DATABASE_NAME" like \'%_{source_project_id}\';'
)
result = await self.execute_query(sql)
if result.is_ok and result.data and result.data.rows:
db_name = result.data.rows[0]['DATABASE_NAME']
else:
LOG.error(f'Failed to run SQL: {sql}, SAPI response: {result}')
else:
sql = 'select CURRENT_DATABASE() as "current_database";'
result = await self.execute_query(sql)
if result.is_ok and result.data and result.data.rows:
row = result.data.rows[0]
db_name = row['current_database']
if '.' in table_id:
                    # a table local to the project that the Snowflake connection/workspace is open for
schema_name, table_name = table_id.rsplit(sep='.', maxsplit=1)
else:
# a table not in the project, but in the writable schema created for the workspace
# TODO: we should never come here, because the tools for listing tables can only see
# tables that are in the project
schema_name = self._schema
table_name = table['name']
else:
LOG.error(f'Failed to run SQL: {sql}, SAPI response: {result}')
if db_name and schema_name and table_name:
sql = (
f'SELECT "COLUMN_NAME", "DATA_TYPE", "IS_NULLABLE" '
f'FROM "INFORMATION_SCHEMA"."COLUMNS" '
f'WHERE "TABLE_CATALOG" = \'{db_name}\' AND "TABLE_SCHEMA" = \'{schema_name}\' '
f'AND "TABLE_NAME" = \'{table_name}\' '
f'ORDER BY "ORDINAL_POSITION";'
)
result = await self.execute_query(sql)
if result.is_ok and result.data:
return DbTableInfo(
id=table_id,
fqn=TableFqn(db_name, schema_name, table_name, quote_char='"'),
columns={
row['COLUMN_NAME']: DbColumnInfo(
name=row['COLUMN_NAME'],
quoted_name=self.get_quoted_name(row['COLUMN_NAME']),
native_type=row['DATA_TYPE'],
nullable=row['IS_NULLABLE'] == 'YES',
)
for row in result.data.rows
},
)
else:
LOG.error(f'Failed to run SQL: {sql}, SAPI response: {result}')
return None
async def execute_query(self, sql_query: str) -> QueryResult:
if not self._qsclient:
self._qsclient = await self._create_qs_client()
ts_start = time.perf_counter()
job_id = await self._qsclient.submit_job(statements=[sql_query], workspace_id=str(self.id))
while (job_status := await self._qsclient.get_job_status(job_id)) and job_status['status'] not in [
'completed',
'failed',
'canceled',
]:
await asyncio.sleep(1)
elapsed_time = time.perf_counter() - ts_start
if elapsed_time > self._QUERY_TIMEOUT:
raise RuntimeError(
f'Query execution timed out after {elapsed_time:.2f} seconds: '
f'job_id={job_id}, status={job_status["status"]}'
)
statement_id = cast(list[JsonDict], job_status['statements'])[0]['id']
results = await self._qsclient.get_job_results(job_id, statement_id)
if results['status'] == 'completed':
columns = [col['name'] for col in cast(list[JsonDict], results['columns'])]
            rows = [dict(zip(columns, row)) for row in cast(list[list[Any]], results['data'])]
query_result = QueryResult(
status='ok',
data=SqlSelectData(columns=columns, rows=rows) if columns else None,
message=results['message'],
)
elif results['status'] in ['failed', 'canceled']:
query_result = QueryResult(status='error', data=None, message=results['message'])
else:
raise ValueError(f'Unexpected query status: {results["status"]}')
return query_result
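    # Usage sketch (assumes an initialized `_SnowflakeWorkspace` instance `ws`):
    #   result = await ws.execute_query('SELECT CURRENT_DATE() AS "today";')
    #   if result.is_ok and result.data:
    #       rows = result.data.rows  # e.g. [{'today': '2024-01-01'}]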
async def get_branch_id(self) -> str:
if not self._qsclient:
self._qsclient = await self._create_qs_client()
return self._qsclient.branch_id
async def _create_qs_client(self) -> QueryServiceClient:
real_branch_id = self._client.branch_id
if not real_branch_id:
for branch in await self._client.storage_client.branches_list():
                if branch.get('isDefault') is True:
real_branch_id = branch['id']
break
if not real_branch_id:
raise RuntimeError('Cannot determine the default branch ID')
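        # With e.g. hostname_suffix == 'keboola.com', the root URL below resolves to
        # 'https://query.keboola.com' (illustrative value).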
return QueryServiceClient.create(
root_url=urlunparse(('https', f'query.{self._client.hostname_suffix}', '', '', '', '')),
branch_id=real_branch_id,
token=self._client.token,
headers=self._client.headers,
)
class _BigQueryWorkspace(_Workspace):
_BQ_FIELDS = {'_timestamp'}
def __init__(self, workspace_id: int, dataset_id: str, project_id: str, client: KeboolaClient):
super().__init__(workspace_id)
self._dataset_id = dataset_id # default dataset created for the workspace
self._project_id = project_id
self._client = client
def get_sql_dialect(self) -> str:
return 'BigQuery'
def get_quoted_name(self, name: str) -> str:
        return f'`{name}`'  # wrap name in backticks
async def get_table_info(self, table: Mapping[str, Any]) -> DbTableInfo | None:
table_id = table['id']
if '.' in table_id:
            # a table local to the project that the workspace is open for
schema_name, table_name = table_id.rsplit(sep='.', maxsplit=1)
schema_name = schema_name.replace('.', '_').replace('-', '_')
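            # E.g. table_id 'in.c-sales.orders' -> schema_name 'in_c_sales', table_name 'orders'
            # (illustrative IDs)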
else:
# a table not in the project, but in the writable schema created for the workspace
# TODO: we should never come here, because the tools for listing tables can only see
# tables that are in the project
schema_name = self._dataset_id
table_name = table['name']
if schema_name and table_name:
sql = (
f'SELECT column_name, data_type, is_nullable '
f'FROM `{self._project_id}`.`{schema_name}`.`INFORMATION_SCHEMA`.`COLUMNS` '
f"WHERE table_name = '{table_name}' "
f'ORDER BY ordinal_position;'
)
result = await self.execute_query(sql)
if result.is_ok and result.data:
return DbTableInfo(
id=table_id,
fqn=TableFqn(self._project_id, schema_name, table_name, quote_char='`'),
columns={
row['column_name']: DbColumnInfo(
name=row['column_name'],
quoted_name=self.get_quoted_name(row['column_name']),
native_type=row['data_type'],
nullable=row['is_nullable'] == 'YES',
)
for row in result.data.rows
},
)
else:
LOG.error(f'Failed to run SQL: {sql}, SAPI response: {result}')
return None
async def execute_query(self, sql_query: str) -> QueryResult:
resp = await self._client.storage_client.workspace_query(workspace_id=self.id, query=sql_query)
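        # The SAPI response is expected to already match QueryResult's shape, e.g. (illustrative):
        #   {'status': 'ok', 'data': {'columns': ['one'], 'rows': [{'one': 1}]}, 'message': None}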
return TypeAdapter(QueryResult).validate_python(resp)
async def get_branch_id(self) -> str:
return self._client.branch_id or 'default'
@dataclass(frozen=True)
class _WspInfo:
id: int
schema: str
backend: str
    credentials: str | None  # the backend credentials; may contain serialized JSON data
readonly: bool | None
@staticmethod
def from_sapi_info(sapi_wsp_info: Mapping[str, Any]) -> '_WspInfo':
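        # Illustrative sketch of the expected SAPI workspace info (values are made up):
        #   {'id': 123, 'connection': {'backend': 'snowflake', 'schema': 'WORKSPACE_123', 'user': '...'},
        #    'readOnlyStorageAccess': True}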
        connection = sapi_wsp_info.get('connection', {})
        return _WspInfo(
            id=sapi_wsp_info.get('id'),
            schema=connection.get('schema'),
            backend=connection.get('backend'),
            credentials=connection.get('user'),
            readonly=sapi_wsp_info.get('readOnlyStorageAccess'),
        )
class WorkspaceManager:
STATE_KEY = 'workspace_manager'
MCP_META_KEY = 'KBC.McpServer.workspaceId'
@classmethod
def from_state(cls, state: Mapping[str, Any]) -> 'WorkspaceManager':
instance = state[cls.STATE_KEY]
assert isinstance(instance, WorkspaceManager), f'Expected WorkspaceManager, got: {instance}'
return instance
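    # Usage sketch (assumes a `state` mapping into which the manager was stored beforehand):
    #   manager = WorkspaceManager.from_state(state)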
def __init__(self, client: KeboolaClient, workspace_schema: str | None = None):
        # We use a read-only workspace that can access all project data and lives in the production branch,
        # so the KeboolaClient must be bound to the production/default branch.
self._client = client.with_branch_id(None)
self._workspace_schema = workspace_schema
self._workspace: _Workspace | None = None
self._table_info_cache: dict[str, DbTableInfo] = {}
async def _find_ws_by_schema(self, schema: str) -> _WspInfo | None:
"""Finds the workspace info by its schema."""
for sapi_wsp_info in await self._client.storage_client.workspace_list():
assert isinstance(sapi_wsp_info, dict)
wi = _WspInfo.from_sapi_info(sapi_wsp_info) # type: ignore[attr-defined]
if wi.id and wi.backend and wi.schema and wi.schema == schema:
return wi
return None
async def _find_ws_by_id(self, workspace_id: int) -> _WspInfo | None:
"""Finds the workspace info by its ID."""
try:
sapi_wsp_info = await self._client.storage_client.workspace_detail(workspace_id)
assert isinstance(sapi_wsp_info, dict)
wi = _WspInfo.from_sapi_info(sapi_wsp_info) # type: ignore[attr-defined]
if wi.id and wi.backend and wi.schema:
return wi
else:
raise ValueError(f'Invalid workspace info: {sapi_wsp_info}')
except HTTPStatusError as e:
if e.response.status_code == 404:
return None
else:
raise e
async def _find_ws_in_branch(self) -> _WspInfo | None:
"""Finds the workspace info in the current branch."""
metadata = await self._client.storage_client.branch_metadata_get()
for m in metadata:
if m.get('key') == self.MCP_META_KEY:
workspace_id = m.get('value')
if workspace_id and (info := await self._find_ws_by_id(workspace_id)) and info.readonly:
return info
return None
async def _create_ws(self, *, timeout_sec: float = 300.0) -> _WspInfo | None:
"""
Creates a new workspace in the current branch and returns its info.
:param timeout_sec: The number of seconds to wait for the workspace creation job to finish.
:return: The workspace info if the workspace was created successfully, None otherwise.
"""
        # Verify the token and read the project's default backend from the `owner` info
        token_info = await self._client.storage_client.verify_token()
        owner_info = token_info.get('owner', {})
default_backend = owner_info.get('defaultBackend')
if default_backend == 'snowflake':
resp = await self._client.storage_client.workspace_create(
login_type='snowflake-person-sso',
backend=default_backend,
async_run=True,
read_only_storage_access=True,
)
elif default_backend == 'bigquery':
resp = await self._client.storage_client.workspace_create(
login_type='default', backend=default_backend, async_run=True, read_only_storage_access=True
)
else:
raise ValueError(f'Unexpected default backend: {default_backend}')
assert 'id' in resp, f'Expected job ID in response: {resp}'
assert isinstance(resp['id'], int)
job_id = resp['id']
start_ts = time.perf_counter()
LOG.info(f'Requested new workspace: job_id={job_id}, timeout={timeout_sec:.2f} seconds')
while True:
job_info = await self._client.storage_client.job_detail(job_id)
job_status = job_info['status']
duration = time.perf_counter() - start_ts
LOG.info(
f'Job info: job_id={job_id}, status={job_status}, '
f'duration={duration:.2f} seconds, timeout={timeout_sec:.2f} seconds'
)
            if job_status == 'success':
assert 'results' in job_info, f'Expected `results` in job info: {job_info}'
job_results = job_info['results']
assert isinstance(job_results, dict)
assert 'id' in job_results, f'Expected `id` in `results` in job info: {job_info}'
assert isinstance(job_results['id'], int)
workspace_id = job_results['id']
LOG.info(f'Created workspace: {workspace_id}')
return await self._find_ws_by_id(workspace_id)
elif duration > timeout_sec:
LOG.info(f'Workspace creation timed out after {duration:.2f} seconds.')
return None
else:
remaining_time = max(0.0, timeout_sec - duration)
await asyncio.sleep(min(5.0, remaining_time))
def _init_workspace(self, info: _WspInfo) -> _Workspace:
"""Creates a new `Workspace` instance based on the workspace info."""
if info.backend == 'snowflake':
return _SnowflakeWorkspace(workspace_id=info.id, schema=info.schema, client=self._client)
elif info.backend == 'bigquery':
credentials = json.loads(info.credentials or '{}')
if project_id := credentials.get('project_id'):
return _BigQueryWorkspace(
workspace_id=info.id,
dataset_id=info.schema,
project_id=project_id,
client=self._client,
)
else:
raise ValueError(f'No credentials or no project ID in workspace: {info.schema}')
else:
raise ValueError(f'Unexpected backend type "{info.backend}" in workspace: {info.schema}')
async def _get_workspace(self) -> _Workspace:
if self._workspace:
return self._workspace
if self._workspace_schema:
# use the workspace that was explicitly requested
# this workspace must never be written to the default branch metadata
LOG.info(f'Looking up workspace by schema: {self._workspace_schema}')
if info := await self._find_ws_by_schema(self._workspace_schema):
LOG.info(f'Found workspace: {info}')
self._workspace = self._init_workspace(info)
return self._workspace
else:
                raise ValueError(f'No Keboola workspace found: workspace_schema={self._workspace_schema}')
LOG.info('Looking up workspace in the default branch.')
if info := await self._find_ws_in_branch():
# use the workspace that has already been created by the MCP server and noted to the branch
LOG.info(f'Found workspace: {info}')
self._workspace = self._init_workspace(info)
return self._workspace
# create a new workspace and note its ID to the branch
LOG.info('Creating workspace in the default branch.')
if info := await self._create_ws():
# update the branch metadata with the workspace ID
meta = await self._client.storage_client.branch_metadata_update({self.MCP_META_KEY: info.id})
LOG.info(f'Set metadata in the default branch: {meta}')
# use the newly created workspace
self._workspace = self._init_workspace(info)
return self._workspace
else:
raise ValueError('Failed to initialize Keboola Workspace.')
async def execute_query(self, sql_query: str) -> QueryResult:
workspace = await self._get_workspace()
return await workspace.execute_query(sql_query)
async def get_table_info(self, table: Mapping[str, Any]) -> DbTableInfo | None:
table_id = table['id']
if table_id in self._table_info_cache:
return self._table_info_cache[table_id]
workspace = await self._get_workspace()
if info := await workspace.get_table_info(table):
self._table_info_cache[table_id] = info
return info
async def get_quoted_name(self, name: str) -> str:
workspace = await self._get_workspace()
return workspace.get_quoted_name(name)
async def get_sql_dialect(self) -> str:
workspace = await self._get_workspace()
return workspace.get_sql_dialect()
async def get_workspace_id(self) -> int:
workspace = await self._get_workspace()
return workspace.id
async def get_branch_id(self) -> str:
workspace = await self._get_workspace()
return await workspace.get_branch_id()
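
# Usage sketch (illustrative; assumes an authenticated `KeboolaClient` instance `client`
# created elsewhere):
#
#   manager = WorkspaceManager(client)
#   dialect = await manager.get_sql_dialect()  # 'Snowflake' or 'BigQuery'
#   result = await manager.execute_query('SELECT 1 AS "one";')
#   if result.is_ok and result.data:
#       print(result.data.rows)  # e.g. [{'one': 1}]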