Skip to main content
Glama
fake_snowflake_connector.py (10.9 kB)
"""Deterministic Snowflake service/cursor fakes for execute_query tests."""

from __future__ import annotations

import threading
import time
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any

# Prefixes of the SHOW PARAMETERS commands the fake cursor answers itself.
# Both the plain spelling and the backslash-escaped one (``\_``, as written to
# match a literal underscore in a LIKE pattern) are accepted.
_SHOW_QUERY_TAG_PREFIXES = (
    "SHOW PARAMETERS LIKE 'QUERY_TAG'",
    r"SHOW PARAMETERS LIKE 'QUERY\_TAG'",
)
_SHOW_STATEMENT_TIMEOUT_PREFIXES = (
    "SHOW PARAMETERS LIKE 'STATEMENT_TIMEOUT_IN_SECONDS'",
    r"SHOW PARAMETERS LIKE 'STATEMENT\_TIMEOUT\_IN\_SECONDS'",
)
# USE commands are accepted and ignored; the fake does not track context changes.
_USE_PREFIXES = ("USE ROLE", "USE WAREHOUSE", "USE DATABASE", "USE SCHEMA")

# Value reported for STATEMENT_TIMEOUT_IN_SECONDS when it has not been set.
_DEFAULT_STATEMENT_TIMEOUT = "0"

# Seconds between cancellation checks while a planned query "runs".
_CANCEL_POLL_INTERVAL = 0.01


@dataclass
class FakeSessionDefaults:
    """Session defaults returned by snapshot_session in tests."""

    role: str = "TEST_ROLE"
    warehouse: str = "TEST_WH"
    database: str = "TEST_DB"
    schema: str = "PUBLIC"


def _copy_row(row: Any) -> Any:
    """Return an independent copy of a dict or namedtuple row; other rows are returned as-is."""
    if isinstance(row, dict):
        return dict(row)
    if hasattr(row, "_asdict"):
        return row.__class__(**row._asdict())
    return row


@dataclass
class FakeQueryPlan:
    """Plan describing how the fake cursor should behave for a single query.

    Fields, as consumed by :class:`FakeSnowflakeCursor`:

    * ``statement`` -- expected SQL, compared case- and whitespace-insensitively.
      An empty string accepts any statement on the first main execute.
    * ``rows`` -- result rows (dicts, namedtuples or plain tuples); ``None``
      means the statement produces no result set.
    * ``rowcount`` -- reported row count; when falsy and ``rows`` is set, the
      number of rows is reported instead.
    * ``duration`` -- simulated execution time in seconds, during which
      ``cancel()`` is honoured.
    * ``sfqid`` -- query id exposed on the cursor after a successful run.
    * ``error`` -- if set, raised instead of producing results.
    """

    statement: str
    rows: list[dict[str, Any]] | None = None
    rowcount: int | None = None
    duration: float = 0.05
    sfqid: str = "FAKE_QID_123"
    error: Exception | None = None

    def clone(self) -> FakeQueryPlan:
        """Return a copy whose ``rows`` list (and dict/namedtuple rows) can be mutated independently.

        Other row objects and ``error`` are shared with the original plan.
        """
        rows_copy = None if self.rows is None else [_copy_row(row) for row in self.rows]
        return FakeQueryPlan(
            statement=self.statement,
            rows=rows_copy,
            rowcount=self.rowcount,
            duration=self.duration,
            sfqid=self.sfqid,
            error=self.error,
        )


class FakeSnowflakeService:
    """Service that returns fixture-backed cursors following provided plans.

    Each :meth:`get_connection` call consumes the next plan in order; once the
    plans are exhausted the last plan is reused. Every cursor handed out is
    recorded in ``cursors`` so tests can inspect it afterwards.
    """

    def __init__(
        self,
        plans: Iterable[FakeQueryPlan],
        *,
        session_defaults: FakeSessionDefaults | None = None,
        query_tag_param: dict[str, Any] | None = None,
    ) -> None:
        """Store private copies of ``plans``.

        Raises:
            ValueError: if ``plans`` is empty.
        """
        self._plans = [plan.clone() for plan in plans]
        if not self._plans:
            raise ValueError("FakeSnowflakeService requires at least one query plan.")
        self._plan_index = 0
        self.session_defaults = session_defaults or FakeSessionDefaults()
        self.cursors: list[FakeSnowflakeCursor] = []
        # NOTE(review): not used inside this module; presumably mirrors a lock the
        # code under test expects on the real service -- confirm before removing.
        self._snowcli_session_lock = threading.Lock()
        self._query_tag_param = dict(query_tag_param or {})

    def get_query_tag_param(self) -> dict[str, Any]:
        """Return a copy of the configured query-tag parameters."""
        return dict(self._query_tag_param)

    def add_query_plan(self, plan: FakeQueryPlan) -> None:
        """Add a new query plan to the service.

        Plans are consumed in order as queries are executed.
        Useful for system tests that dynamically add queries.
        """
        self._plans.append(plan.clone())

    def get_connection(self, **_: Any) -> FakeSnowflakeConnection:
        """Return a connection wrapping a new cursor bound to the next plan; keyword options are ignored."""
        plan = self._consume_plan()
        cursor = FakeSnowflakeCursor(plan, self.session_defaults)
        self.cursors.append(cursor)
        return FakeSnowflakeConnection(cursor)

    def _consume_plan(self) -> FakeQueryPlan:
        """Return a copy of the next unconsumed plan, or of the last plan once all are consumed."""
        if self._plan_index < len(self._plans):
            plan = self._plans[self._plan_index]
            self._plan_index += 1
        else:
            # Reuse the last plan when more connections are requested
            plan = self._plans[-1]
        return plan.clone()


class FakeSnowflakeConnection:
    """Context manager returning the prepared fake cursor."""

    def __init__(self, cursor: FakeSnowflakeCursor) -> None:
        self.cursor = cursor

    def __enter__(self) -> tuple[None, FakeSnowflakeCursor]:
        """Return ``(None, cursor)``; the first element stands in for a connection object."""
        return None, self.cursor

    def __exit__(self, exc_type, exc, tb) -> bool:
        """Do nothing and never suppress exceptions."""
        return False


class FakeSnowflakeCursor:
    """Cursor that emulates the subset of Snowflake behaviour execute_query needs.

    Session-parameter SHOW/ALTER commands, the session snapshot query and USE
    commands are answered locally. Any other statement is treated as the
    plan's main statement.
    """

    def __init__(
        self,
        plan: FakeQueryPlan,
        session_defaults: FakeSessionDefaults,
    ) -> None:
        self.plan = plan
        self.session_defaults = session_defaults
        self.sfqid: str | None = None
        self.description: list[tuple[str]] | None = None
        self.rowcount: int = 0
        self._rows: list[dict[str, Any]] = []
        self._cancelled: bool = False
        # Set once the planned (main) statement has been executed.
        self._main_executed: bool = False
        # Row returned by fetchone(); populated only by the session snapshot query.
        self._fetchone_map: dict[str, Any] = {}
        # Session parameters tracked across SHOW / ALTER SESSION commands.
        self._session_parameters: dict[str, str | None] = {
            "QUERY_TAG": None,
            "STATEMENT_TIMEOUT_IN_SECONDS": _DEFAULT_STATEMENT_TIMEOUT,
        }
        # Parameter values in effect each time the main statement completes uncancelled.
        self.query_tags_seen: list[str | None] = []
        self.statement_timeouts_seen: list[str | None] = []

    # -- DB-API subset ---------------------------------------------------
    def execute(self, query: str) -> None:
        """Emulate ``cursor.execute`` for session commands and the planned statement.

        Raises:
            AssertionError: if the main statement does not match the plan.
            RuntimeError: if a non-matching statement arrives after the main
                statement has already run.
            Exception: ``plan.error`` when the plan is configured to fail.
        """
        normalized = " ".join(query.strip().split())
        upper = normalized.upper()

        if upper.startswith(_SHOW_QUERY_TAG_PREFIXES):
            self._set_parameter_result("QUERY_TAG")
            return
        if upper.startswith(_SHOW_STATEMENT_TIMEOUT_PREFIXES):
            self._set_parameter_result("STATEMENT_TIMEOUT_IN_SECONDS")
            return

        if upper.startswith("ALTER SESSION"):
            self._apply_alter_session(normalized, upper)
            self._clear_result()
            return

        # Handle snapshot_session query:
        # SELECT CURRENT_ROLE(), CURRENT_WAREHOUSE(), CURRENT_DATABASE(), CURRENT_SCHEMA()
        if "CURRENT_ROLE()" in upper and "CURRENT_WAREHOUSE()" in upper:
            self._set_snapshot_result()
            return

        if upper.startswith(_USE_PREFIXES):
            self._clear_result()
            return

        self._execute_main(query, normalized)

    def fetchall(self) -> list[dict[str, Any]]:
        """Return a shallow copy of the current result rows."""
        return list(self._rows)

    def fetchone(self) -> dict[str, Any]:
        """Return a copy of the last session-snapshot row (``{}`` if none).

        Unlike a real cursor this does not advance through the current results.
        """
        return dict(self._fetchone_map)

    def cancel(self) -> None:
        """Request cancellation; honoured while the main statement's duration elapses."""
        self._cancelled = True

    # -- Internal helpers ------------------------------------------------
    def _clear_result(self) -> None:
        """Mark the last command as producing no result set."""
        self._rows = []
        self.description = None

    def _set_parameter_result(self, key: str) -> None:
        """Expose the tracked value of session parameter ``key`` as a SHOW PARAMETERS row."""
        self.description = [("KEY",), ("VALUE",)]
        value = self._session_parameters.get(key) or ""
        self._rows = [{"KEY": key, "VALUE": value}]

    def _apply_alter_session(self, normalized: str, upper: str) -> None:
        """Update tracked parameters for an ALTER SESSION command.

        Only QUERY_TAG and STATEMENT_TIMEOUT_IN_SECONDS are tracked; any other
        ALTER SESSION command is accepted and ignored.
        """
        if upper.startswith("ALTER SESSION SET QUERY_TAG"):
            value = self._extract_assignment_value(normalized)
            self._session_parameters["QUERY_TAG"] = value or None
        elif upper.startswith("ALTER SESSION UNSET QUERY_TAG"):
            self._session_parameters["QUERY_TAG"] = None
        elif upper.startswith("ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS"):
            value = self._extract_assignment_value(normalized)
            self._session_parameters["STATEMENT_TIMEOUT_IN_SECONDS"] = value or _DEFAULT_STATEMENT_TIMEOUT
        elif upper.startswith("ALTER SESSION UNSET STATEMENT_TIMEOUT_IN_SECONDS"):
            self._session_parameters["STATEMENT_TIMEOUT_IN_SECONDS"] = _DEFAULT_STATEMENT_TIMEOUT

    def _set_snapshot_result(self) -> None:
        """Answer the session snapshot query from ``session_defaults``."""
        self._fetchone_map = {
            "ROLE": self.session_defaults.role,
            "WAREHOUSE": self.session_defaults.warehouse,
            "DATABASE": self.session_defaults.database,
            "SCHEMA": self.session_defaults.schema,
        }
        self.description = [("ROLE",), ("WAREHOUSE",), ("DATABASE",), ("SCHEMA",)]
        self._rows = [self._fetchone_map]

    def _expected_statement(self) -> str:
        """Return the plan's statement with whitespace collapsed and upper-cased."""
        return " ".join(self.plan.statement.strip().split()).upper()

    def _execute_main(self, query: str, normalized: str) -> None:
        """Run the planned statement, allowing exact re-runs (cache hits, retries) afterwards."""
        if not self._main_executed:
            self._execute_plan(normalized)
            return

        # Allow re-execution if it matches the plan (for cache hit scenarios or retries)
        expected = self._expected_statement()
        if expected and normalized.upper() == expected:
            self._execute_plan(normalized)
            return

        raise RuntimeError(f"Unexpected extra execute call in fake cursor: {query}")

    def _execute_plan(self, normalized_query: str) -> None:
        """Validate the statement, simulate its duration and publish the planned results.

        Raises:
            AssertionError: if the statement differs from ``plan.statement``.
            Exception: ``plan.error`` when the plan is configured to fail.
        """
        self._main_executed = True
        expected = self._expected_statement()
        if expected and normalized_query.upper() != expected:
            raise AssertionError(f"Expected query '{self.plan.statement}' but received '{normalized_query}'")
        if self.plan.error is not None:
            raise self.plan.error

        # Monotonic clock: wall-clock adjustments must not stretch or cut short
        # the simulated duration.
        deadline = time.monotonic() + max(self.plan.duration, 0.0)
        while time.monotonic() < deadline:
            if self._cancelled:
                self.description = None
                self._rows = []
                self.rowcount = 0
                self.sfqid = None
                return
            time.sleep(_CANCEL_POLL_INTERVAL)

        self.sfqid = self.plan.sfqid
        self.query_tags_seen.append(self._session_parameters.get("QUERY_TAG"))
        self.statement_timeouts_seen.append(self._session_parameters.get("STATEMENT_TIMEOUT_IN_SECONDS"))

        if self.plan.rows is not None:
            self.description = [(name,) for name in self._infer_column_names(self.plan.rows)]
            self._rows = list(self.plan.rows)
            # A falsy rowcount (None or 0) falls back to the number of rows.
            self.rowcount = self.plan.rowcount or len(self.plan.rows)
        else:
            self.description = None
            self._rows = []
            self.rowcount = int(self.plan.rowcount or 0)

    def _infer_column_names(self, rows: list[dict[str, Any]]) -> list[str]:
        """Derive column names from the first row.

        Uses dict keys for dicts and field names for namedtuples. Falls back to
        ``column_<index>`` for other sequences.
        """
        if not rows:
            return []
        first = rows[0]
        if isinstance(first, dict):
            return list(first.keys())
        if isinstance(first, tuple) and hasattr(first, "_fields"):
            return list(first._fields)
        return [f"column_{idx}" for idx in range(len(first))]

    @staticmethod
    def _extract_assignment_value(normalized: str) -> str | None:
        """Return the right-hand side of ``... = value``.

        A trailing ``;`` and one pair of surrounding single quotes are removed.
        Returns ``None`` when there is no ``=`` or the value is empty.
        """
        _, sep, value = normalized.partition("=")
        if not sep:
            return None
        cleaned = value.strip().strip(";").strip()
        if cleaned.startswith("'") and cleaned.endswith("'"):
            cleaned = cleaned[1:-1]
        return cleaned or None

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Evan-Kim2028/igloo-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.