mcp-snowflake-server
- src
- mcp_snowflake_server
import sqlparse
from sqlparse.sql import Token, TokenList
from sqlparse.tokens import Keyword, DML, DDL
from typing import Dict, List, Set, Tuple
class SQLWriteDetector:
    """Detect write operations (DML, DDL, DCL) in SQL text.

    The SQL is tokenized with ``sqlparse`` and only tokens that the lexer
    classifies as keywords are compared against the configured write-keyword
    sets. Identifiers, string literals and comments are never matched, so a
    column such as ``CREATED_AT`` or a literal such as ``'DELETE'`` does not
    count as a write.
    """

    def __init__(self):
        # Keyword families that are treated as write operations.
        self.dml_write_keywords = {"INSERT", "UPDATE", "DELETE", "MERGE", "UPSERT", "REPLACE"}
        self.ddl_keywords = {"CREATE", "ALTER", "DROP", "TRUNCATE", "RENAME"}
        self.dcl_keywords = {"GRANT", "REVOKE"}
        # Union of every family; this is the set the scanners test against.
        self.write_keywords = self.dml_write_keywords | self.ddl_keywords | self.dcl_keywords

    def analyze_query(self, sql_query: str) -> Dict:
        """Analyze a SQL string and report any write operations it contains.

        Args:
            sql_query: The SQL text to analyze. It may hold several statements.

        Returns:
            A dictionary with three keys:
              * ``contains_write`` - True if any write operation was found.
              * ``write_operations`` - set of the write keywords found, plus
                the marker ``"CTE_WRITE"`` when a statement with a WITH clause
                contains a write keyword after that clause starts.
              * ``has_cte_write`` - True if ``"CTE_WRITE"`` was recorded.
        """
        parsed = sqlparse.parse(sql_query)
        if not parsed:
            return {"contains_write": False, "write_operations": set(), "has_cte_write": False}

        found_operations: Set[str] = set()
        has_cte_write = False

        for statement in parsed:
            # Statements that open a WITH clause get an additional marker.
            if self._has_cte(statement) and self._analyze_cte(statement):
                has_cte_write = True
                found_operations.add("CTE_WRITE")

            # Every statement is scanned for write keywords regardless.
            found_operations.update(self._find_write_operations(statement))

        return {
            "contains_write": bool(found_operations) or has_cte_write,
            "write_operations": found_operations,
            "has_cte_write": has_cte_write,
        }

    def _has_cte(self, statement: TokenList) -> bool:
        """Return True if the statement has a top-level WITH keyword."""
        return any(token.is_keyword and token.normalized == "WITH" for token in statement.tokens)

    def _analyze_cte(self, statement: TokenList) -> bool:
        """Return True if a write keyword follows the statement's WITH keyword.

        Everything after the first top-level WITH is scanned, i.e. both the
        CTE bodies and the statement that consumes them, so
        ``WITH x AS (...) INSERT INTO t ...`` is reported as well.
        """
        in_cte = False
        for token in statement.tokens:
            if token.is_keyword and token.normalized == "WITH":
                in_cte = True
            elif in_cte and self._write_keywords_in(token):
                return True
        return False

    def _find_write_operations(self, statement: TokenList) -> Set[str]:
        """Return the set of write keywords found anywhere in the statement."""
        return self._write_keywords_in(statement)

    def _write_keywords_in(self, token: Token) -> Set[str]:
        """Collect write keywords among the leaf tokens under ``token``.

        Only leaf tokens flagged as keywords by sqlparse are inspected, which
        avoids substring hits inside identifiers or string literals and makes
        the match independent of the letter case used in the SQL.
        """
        found: Set[str] = set()
        for leaf in token.flatten():
            if not leaf.is_keyword:
                continue
            # sqlparse lexes "CREATE OR REPLACE" as one DDL keyword token, so
            # compare on the leading word instead of the full keyword text.
            words = leaf.normalized.upper().split()
            if words and words[0] in self.write_keywords:
                found.add(words[0])
        return found