Skip to main content
Glama
fake_snowflake_connector.py (9.29 kB)
"""Deterministic Snowflake service/cursor fakes for execute_query tests.""" from __future__ import annotations import threading import time from dataclasses import dataclass from typing import Any, Dict, Iterable, List, Optional @dataclass class FakeSessionDefaults: """Session defaults returned by snapshot_session in tests.""" role: str = "TEST_ROLE" warehouse: str = "TEST_WH" database: str = "TEST_DB" schema: str = "PUBLIC" @dataclass class FakeQueryPlan: """Plan describing how the fake cursor should behave for a single query.""" statement: str rows: Optional[List[Dict[str, Any]]] = None rowcount: Optional[int] = None duration: float = 0.05 sfqid: str = "FAKE_QID_123" error: Optional[Exception] = None def clone(self) -> "FakeQueryPlan": rows_copy = None if self.rows is not None: rows_copy = [] for row in self.rows: if isinstance(row, dict): rows_copy.append(dict(row)) elif hasattr(row, "_asdict"): rows_copy.append(row.__class__(**row._asdict())) else: rows_copy.append(row) return FakeQueryPlan( statement=self.statement, rows=rows_copy, rowcount=self.rowcount, duration=self.duration, sfqid=self.sfqid, error=self.error, ) class FakeSnowflakeService: """Service that returns fixture-backed cursors following provided plans.""" def __init__( self, plans: Iterable[FakeQueryPlan], *, session_defaults: Optional[FakeSessionDefaults] = None, query_tag_param: Optional[Dict[str, Any]] = None, ) -> None: self._plans = [plan.clone() for plan in plans] if not self._plans: raise ValueError("FakeSnowflakeService requires at least one query plan.") self._plan_index = 0 self.session_defaults = session_defaults or FakeSessionDefaults() self.cursors: List[FakeSnowflakeCursor] = [] self._snowcli_session_lock = threading.Lock() self._query_tag_param = dict(query_tag_param or {}) def get_query_tag_param(self) -> Dict[str, Any]: return dict(self._query_tag_param) def get_connection(self, **_: Any) -> "FakeSnowflakeConnection": plan = self._consume_plan() cursor = FakeSnowflakeCursor(plan, 
self.session_defaults) self.cursors.append(cursor) return FakeSnowflakeConnection(cursor) def _consume_plan(self) -> FakeQueryPlan: if self._plan_index < len(self._plans): plan = self._plans[self._plan_index].clone() self._plan_index += 1 return plan # Reuse the last plan when more connections are requested return self._plans[-1].clone() class FakeSnowflakeConnection: """Context manager returning the prepared fake cursor.""" def __init__(self, cursor: "FakeSnowflakeCursor") -> None: self.cursor = cursor def __enter__(self) -> tuple[None, "FakeSnowflakeCursor"]: return None, self.cursor def __exit__(self, exc_type, exc, tb) -> bool: # noqa: D401 - standard protocol return False class FakeSnowflakeCursor: """Cursor that emulates the subset of Snowflake behaviour execute_query needs.""" def __init__( self, plan: FakeQueryPlan, session_defaults: FakeSessionDefaults, ) -> None: self.plan = plan self.session_defaults = session_defaults self.sfqid: Optional[str] = None self.description: Optional[List[tuple[str]]] = None self.rowcount: int = 0 self._rows: List[Dict[str, Any]] = [] self._cancelled: bool = False self._main_executed: bool = False self._fetchone_map: Dict[str, Any] = {} self._session_parameters: Dict[str, Optional[str]] = { "QUERY_TAG": None, "STATEMENT_TIMEOUT_IN_SECONDS": "0", } self.query_tags_seen: List[Optional[str]] = [] self.statement_timeouts_seen: List[Optional[str]] = [] # -- DB-API subset --------------------------------------------------- def execute(self, query: str) -> None: normalized = " ".join(query.strip().split()) upper = normalized.upper() if upper.startswith("SHOW PARAMETERS LIKE 'QUERY_TAG'"): self.description = [("KEY",), ("VALUE",)] tag = self._session_parameters.get("QUERY_TAG") or "" self._rows = [{"KEY": "QUERY_TAG", "VALUE": tag}] return if upper.startswith("SHOW PARAMETERS LIKE 'STATEMENT_TIMEOUT_IN_SECONDS'"): self.description = [("KEY",), ("VALUE",)] timeout = self._session_parameters.get("STATEMENT_TIMEOUT_IN_SECONDS") or "" 
self._rows = [{"KEY": "STATEMENT_TIMEOUT_IN_SECONDS", "VALUE": timeout}] return if upper.startswith("ALTER SESSION SET QUERY_TAG"): value = self._extract_assignment_value(normalized) self._session_parameters["QUERY_TAG"] = value or None self._rows = [] self.description = None return if upper.startswith("ALTER SESSION UNSET QUERY_TAG"): self._session_parameters["QUERY_TAG"] = None self._rows = [] self.description = None return if upper.startswith("ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS"): value = self._extract_assignment_value(normalized) self._session_parameters["STATEMENT_TIMEOUT_IN_SECONDS"] = value or "0" self._rows = [] self.description = None return if upper.startswith("ALTER SESSION UNSET STATEMENT_TIMEOUT_IN_SECONDS"): self._session_parameters["STATEMENT_TIMEOUT_IN_SECONDS"] = "0" self._rows = [] self.description = None return if "CURRENT_ROLE()" in upper and "CURRENT_WAREHOUSE()" in upper: self._fetchone_map = { "ROLE": self.session_defaults.role, "WAREHOUSE": self.session_defaults.warehouse, "DATABASE": self.session_defaults.database, "SCHEMA": self.session_defaults.schema, } return if upper.startswith("USE ROLE") or upper.startswith("USE WAREHOUSE"): self._rows = [] self.description = None return if upper.startswith("USE DATABASE") or upper.startswith("USE SCHEMA"): self._rows = [] self.description = None return # Main statement execution if not self._main_executed: self._execute_plan(normalized) return raise RuntimeError(f"Unexpected extra execute call in fake cursor: {query}") def fetchall(self) -> List[Dict[str, Any]]: return list(self._rows) def fetchone(self) -> Dict[str, Any]: return dict(self._fetchone_map) def cancel(self) -> None: self._cancelled = True # -- Internal helpers ------------------------------------------------ def _execute_plan(self, normalized_query: str) -> None: self._main_executed = True expected = " ".join(self.plan.statement.strip().split()).upper() if expected and normalized_query.upper() != expected: raise 
AssertionError( f"Expected query '{self.plan.statement}' but received '{normalized_query}'" ) if self.plan.error: raise self.plan.error deadline = time.time() + max(self.plan.duration, 0.0) while time.time() < deadline: if self._cancelled: self.description = None self._rows = [] self.rowcount = 0 self.sfqid = None return time.sleep(0.01) self.sfqid = self.plan.sfqid self.query_tags_seen.append(self._session_parameters.get("QUERY_TAG")) self.statement_timeouts_seen.append( self._session_parameters.get("STATEMENT_TIMEOUT_IN_SECONDS") ) if self.plan.rows is not None: column_names = self._infer_column_names(self.plan.rows) self.description = [(name,) for name in column_names] self._rows = list(self.plan.rows) self.rowcount = self.plan.rowcount or len(self.plan.rows) else: self.description = None self._rows = [] self.rowcount = int(self.plan.rowcount or 0) def _infer_column_names(self, rows: List[Dict[str, Any]]) -> List[str]: if not rows: return [] first = rows[0] if isinstance(first, dict): return list(first.keys()) return [f"column_{idx}" for idx in range(len(first))] @staticmethod def _extract_assignment_value(normalized: str) -> Optional[str]: if "=" not in normalized: return None _, value = normalized.split("=", 1) cleaned = value.strip().strip(";").strip() if cleaned.startswith("'") and cleaned.endswith("'"): cleaned = cleaned[1:-1] return cleaned or None

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Evan-Kim2028/igloo-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.