Skip to main content
Glama
sql_validation.py (16.9 kB)
"""SQL validation and safe alternative generation for igloo-mcp.

This module provides SQL statement type validation and generates safe
alternatives for blocked operations like DELETE, DROP, and TRUNCATE.
"""

from __future__ import annotations

import time
from typing import Dict, List, Optional

# Import upstream validation from snowflake-labs-mcp
from mcp_server_snowflake.query_manager.tools import (
    get_statement_type,
    validate_sql_type,
)

try:
    import sqlglot
    from sqlglot import exp

    HAS_SQLGLOT = True
except ImportError:  # pragma: no cover
    # sqlglot is optional: the sqlglot-based helpers in this module check
    # HAS_SQLGLOT and degrade (placeholder / False / no fallback) without it.
    HAS_SQLGLOT = False

# Template-based safe alternatives for blocked SQL operations.
# Outer keys are statement-type names as reported by the parser; inner keys
# label each alternative. The {table}, {condition} and {timestamp}
# placeholders are filled in with str.format() by generate_sql_alternatives().
SAFE_ALTERNATIVES: Dict[str, Dict[str, str]] = {
    "Delete": {
        # Mark rows instead of removing them (assumes a deleted_at column exists
        # — NOTE(review): confirm this convention for target tables).
        "soft_delete": "UPDATE {table} SET deleted_at = CURRENT_TIMESTAMP() WHERE {condition}",
        "create_view": "CREATE VIEW active_{table} AS SELECT * FROM {table} WHERE NOT ({condition})",
    },
    "Drop": {
        # Keep the object around under a timestamped name instead of dropping it.
        "rename": "ALTER TABLE {table} RENAME TO {table}_deprecated_{timestamp}",
        "comment": "ALTER TABLE {table} SET COMMENT = 'Deprecated {timestamp}'",
    },
    "Truncate": {
        "delete_all": "DELETE FROM {table} -- Add WHERE clause for safety",
    },
    "TruncateTable": {  # Upstream may return this variant
        "delete_all": "DELETE FROM {table} -- Add WHERE clause for safety",
    },
}

# Statement types that should inherit SELECT permissions (case insensitive).
_SELECT_EQUIVALENT_PREFIXES = ("union", "intersect", "except", "minus") _SELECT_EQUIVALENT_ALLOWLIST = { "union", "union all", "union_all", "unionall", "intersect", "intersect all", "intersect_all", "intersectall", "except", "except all", "except_all", "exceptall", "minus", "minus all", "minus_all", "minusall", } def _canonicalize_statement_type(stmt_type: str | None) -> str: """Return a lowercase canonical representation of a statement type.""" if not stmt_type: return "" normalized = stmt_type.replace("_", "") normalized = normalized.replace(" ", "") return normalized.lower() def _is_select_equivalent(stmt_type: str | None) -> bool: """Determine if a statement type should be treated as SELECT.""" canonical = _canonicalize_statement_type(stmt_type) if not canonical: return False if canonical.startswith(_SELECT_EQUIVALENT_PREFIXES): return True return False def extract_table_name(sql_statement: str) -> str: """Extract table name from SQL statement using sqlglot. Args: sql_statement: SQL statement to parse Returns: Table name or "<table_name>" if extraction fails Raises: ValueError: If sqlglot is not available """ if not HAS_SQLGLOT: # pragma: no cover raise ValueError("sqlglot is required for table name extraction") try: parsed = sqlglot.parse_one(sql_statement) # Try to find any Table node in the AST for table in parsed.find_all(exp.Table): if table.name: return table.name # Try to get the string representation table_str = str(table) if table_str and table_str != "<table_name>": return table_str # Special handling for DROP which uses Identifier if isinstance(parsed, exp.Drop): if hasattr(parsed, "this"): # Get the identifier identifier = parsed.this if hasattr(identifier, "name"): return identifier.name return str(identifier) except Exception: # If parsing fails, return placeholder pass return "<table_name>" def generate_sql_alternatives( statement: str, stmt_type: str, ) -> List[str]: """Generate safe alternative SQL statements for blocked operations. 
Args: statement: Original SQL statement stmt_type: Statement type (Delete, Drop, Truncate, etc.) Returns: List of formatted alternative SQL statements with warnings """ if stmt_type not in SAFE_ALTERNATIVES: return [] # Try to extract table name try: table = extract_table_name(statement) except ValueError: # sqlglot not available, use placeholder table = "<table_name>" except Exception: table = "<table_name>" alternatives = [] templates = SAFE_ALTERNATIVES[stmt_type] for name, template in templates.items(): # Format template with extracted values formatted = template.format( table=table, condition="<your_condition>", timestamp=int(time.time()), ) alternatives.append(f" {name}: {formatted}") # Add warning alternatives.append("\n⚠️ Review and customize templates before executing.") return alternatives def validate_sql_statement( statement: str, allow_list: List[str], disallow_list: List[str], ) -> tuple[str, bool, str | None]: """Validate SQL statement against permission lists. Args: statement: SQL statement to validate allow_list: List of allowed statement types (e.g., ["Select", "Insert"]) disallow_list: List of disallowed statement types (e.g., ["Delete", "Drop"]) Returns: Tuple of (statement_type, is_valid, error_message) - statement_type: The detected SQL statement type - is_valid: True if allowed, False if blocked - error_message: Detailed error with alternatives if blocked, None if valid """ # Build effective allow list: include SELECT-equivalent statements when SELECT is allowed allow_set = {item.lower() for item in allow_list} disallow_set = {item.lower() for item in disallow_list} # CRITICAL FIX: Ensure all lists are lowercase for upstream validation compatibility effective_allow_list = [item.lower() for item in allow_list] if "select" in allow_set: for extra in _SELECT_EQUIVALENT_ALLOWLIST: if extra not in allow_set: effective_allow_list.append(extra) allow_set.add(extra) # ENHANCEMENT: Fallback validation with sqlglot for better robustness 
fallback_stmt_type: Optional[str] = None select_like_hint = False multi_statement_detected = False parsed_expressions: list[exp.Expression] = [] if HAS_SQLGLOT: try: parsed_expressions = sqlglot.parse(statement, dialect="snowflake") except Exception: parsed_expressions = [] if parsed_expressions: primary_expression = parsed_expressions[0] key = primary_expression.key or "" fallback_stmt_type = key.upper() or None multi_statement_detected = len(parsed_expressions) > 1 if not multi_statement_detected: select_like_hint = _is_select_like_statement( statement, parsed=primary_expression ) if not select_like_hint and fallback_stmt_type in {"SELECT", "WITH"}: statement_upper = statement.upper() if ( "LATERAL FLATTEN" in statement_upper or "CROSS JOIN LATERAL" in statement_upper ): select_like_hint = True # CRITICAL FIX: Use lowercase lists for upstream validation (it's case-sensitive) lowercase_disallow_list = [item.lower() for item in disallow_list] stmt_type, is_valid = validate_sql_type( statement, effective_allow_list, lowercase_disallow_list ) if multi_statement_detected: detected: list[str] = [] for expr in parsed_expressions: name = (expr.key or "UNKNOWN").upper() if name not in detected: detected.append(name) pretty_detected = ( ", ".join(t.title() for t in detected) if detected else "Unknown" ) error_msg = ( "Multiple SQL statements detected in a single request. " "Only a single statement is permitted for execute_query. " f"Detected statements: {pretty_detected}." 
) return "MultipleStatements", False, error_msg canonical_stmt = _canonicalize_statement_type(stmt_type) if canonical_stmt.startswith("with"): underlying_type = get_statement_type(statement) canonical_underlying = _canonicalize_statement_type(underlying_type) stmt_type = underlying_type or stmt_type if canonical_underlying == "select" and "select" in allow_set: return "Select", True, None if canonical_underlying in disallow_set: is_valid = False # Normalize statement types that should inherit SELECT permissions if _is_select_equivalent(stmt_type): stmt_type = "Select" if "select" in allow_set and not is_valid: # Treat SELECT-equivalent statements as allowed when SELECT is permitted return stmt_type, True, None if select_like_hint and "select" in allow_set and not multi_statement_detected: if is_valid: return "Select", True, None if canonical_stmt in {"", "unknown", "command"}: return "Select", True, None if _is_select_equivalent(stmt_type): return "Select", True, None if is_valid: return stmt_type, True, None # Generate error message with alternatives alternatives = generate_sql_alternatives(statement, stmt_type) # Enhanced structured error messages structured_error = { "code": "SQL_TYPE_NOT_ALLOWED", "statement_type": stmt_type, "allowed_types": [t.capitalize() for t in allow_list] if allow_list else [], "suggestions": [], } if alternatives: alt_text = "\n".join(alternatives) error_msg = ( f"SQL statement type '{stmt_type}' is not permitted.\n\n" f"Safe alternatives:\n{alt_text}" ) structured_error["suggestions"] = ["Use safe alternatives provided above"] else: # Capitalize allow_list for display (they're lowercase for validation) display_allowed = [t.capitalize() for t in allow_list] canonical_stmt = _canonicalize_statement_type(stmt_type) details = [f"SQL statement type '{stmt_type}' is not permitted."] if canonical_stmt == "command": details.append( "Snowflake returned 'Command' for this SQL, which is a fallback when the parser " "cannot classify the statement. 
Such statements are always blocked." ) else: details.append("Detected type is provided by the Snowflake parser.") # Enhanced suggestions for common issues if "Unknown" in stmt_type: # Special handling for Unknown type errors if "LATERAL" in statement.upper(): details.append("\n💡 This query contains LATERAL operations.") details.append( " If this is a SELECT query, LATERAL should be supported." ) structured_error["suggestions"].append( "Check if this is actually a SELECT query with LATERAL operations" ) elif "WITH" in statement.upper(): details.append("\n💡 This query starts WITH (CTE pattern).") details.append( " If this is a SELECT with CTE, it should be supported." ) structured_error["suggestions"].append( "Verify this is a SELECT statement with Common Table Expression" ) # Add sqlglot fallback information if available if fallback_stmt_type and fallback_stmt_type != "UNKNOWN": details.append(f"\n🔍 sqlglot detected this as: {fallback_stmt_type}") if fallback_stmt_type in ["SELECT", "WITH"] and "select" in allow_set: details.append( " This appears to be a SELECT query that should be allowed." 
) structured_error["suggestions"].append( "Consider enabling SELECT statements if this is a data query" ) if display_allowed: details.append(f"\nAllowed types: {', '.join(display_allowed)}") structured_error["allowed_types"] = display_allowed error_msg = "\n".join(details) return stmt_type, False, error_msg def _is_select_like_statement( statement: str, parsed: Optional[exp.Expression] = None ) -> bool: """Return True when the SQL behaves like a SELECT or set operation.""" if not HAS_SQLGLOT: return False if parsed is None: try: parsed = sqlglot.parse_one(statement, dialect="snowflake") except Exception: return False def is_select_like(node: exp.Expression | None) -> bool: if node is None: return False if isinstance(node, (exp.Select, exp.SetOperation)): return True if isinstance(node, exp.With): target = node.this or node.args.get("expression") return is_select_like(target) if isinstance(node, (exp.Subquery, exp.Paren)): return is_select_like(node.this) if isinstance(node, exp.Query): return is_select_like(node.this) return False def strip_comments(sql: str) -> str: """Remove block and line comments while preserving string literals.""" def remove_block_comments(source: str) -> str: result: list[str] = [] in_single = False in_double = False idx = 0 while idx < len(source): char = source[idx] if char == "'" and not in_double: in_single = not in_single result.append(char) idx += 1 continue if char == '"' and not in_single: in_double = not in_double result.append(char) idx += 1 continue if not in_single and not in_double and source.startswith("/*", idx): idx += 2 depth = 1 while idx < len(source) and depth > 0: if source.startswith("/*", idx): depth += 1 idx += 2 continue if source.startswith("*/", idx): depth -= 1 idx += 2 continue idx += 1 continue result.append(char) idx += 1 return "".join(result) def remove_line_comments(source: str) -> str: lines: list[str] = [] for line in source.splitlines(): in_single = False in_double = False idx = 0 while idx < len(line): 
char = line[idx] if char == "'" and not in_double: in_single = not in_single elif char == '"' and not in_single: in_double = not in_double elif char == "-" and not in_single and not in_double: if idx + 1 < len(line) and line[idx + 1] == "-": line = line[:idx] break idx += 1 lines.append(line) return "\n".join(lines) without_blocks = remove_block_comments(sql) return remove_line_comments(without_blocks) structural_select = is_select_like(parsed) if not structural_select: return False upper_without_line_comments = strip_comments(statement).upper() keyword_tokens = ("UNION", "INTERSECT", "EXCEPT", "MINUS") contains_keywords = any( token in upper_without_line_comments for token in keyword_tokens ) if not contains_keywords: return True has_set_operation = isinstance(parsed, exp.SetOperation) if not has_set_operation: has_set_operation = any( isinstance(node, exp.SetOperation) for node in parsed.walk() ) return bool(has_set_operation) def get_sql_statement_type(statement: str) -> str: """Get the type of a SQL statement. Args: statement: SQL statement to analyze Returns: Statement type (e.g., "Select", "Delete", "Unknown") """ stmt_type = get_statement_type(statement) if _is_select_equivalent(stmt_type): return "Select" return stmt_type

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Evan-Kim2028/igloo-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.