Skip to main content
Glama

Supabase MCP Server

Apache 2.0
797
  • Apple
  • Linux
validator.py13.4 kB
from typing import Any from pglast.parser import ParseError, parse_sql from supabase_mcp.exceptions import ValidationError from supabase_mcp.logger import logger from supabase_mcp.services.database.sql.models import ( QueryValidationResults, SQLQueryCategory, SQLQueryCommand, ValidatedStatement, ) from supabase_mcp.services.safety.safety_configs import SQLSafetyConfig class SQLValidator: """SQL validator class that is based on pglast library. Responsible for: - SQL query syntax validation - SQL query categorization""" # Mapping from statement types to object types STATEMENT_TYPE_TO_OBJECT_TYPE = { "CreateFunctionStmt": "function", "ViewStmt": "view", "CreateTableAsStmt": "materialized_view", # When relkind is 'm', otherwise 'table' "CreateEnumStmt": "type", "CreateTypeStmt": "type", "CreateExtensionStmt": "extension", "CreateForeignTableStmt": "foreign_table", "CreatePolicyStmt": "policy", "CreateTrigStmt": "trigger", "IndexStmt": "index", "CreateStmt": "table", "AlterTableStmt": "table", "GrantStmt": "privilege", "RevokeStmt": "privilege", "CreateProcStmt": "procedure", # For CREATE PROCEDURE } def __init__(self, safety_config: SQLSafetyConfig | None = None) -> None: self.safety_config = safety_config or SQLSafetyConfig() def validate_schema_name(self, schema_name: str) -> str: """Validate schema name. Rules: - Must be a string - Cannot be empty - Cannot contain spaces or special characters """ if not schema_name.strip(): raise ValidationError("Schema name cannot be empty") if " " in schema_name: raise ValidationError("Schema name cannot contain spaces") return schema_name def validate_table_name(self, table: str) -> str: """Validate table name. Rules: - Must be a string - Cannot be empty - Cannot contain spaces or special characters """ if not table.strip(): raise ValidationError("Table name cannot be empty") if " " in table: raise ValidationError("Table name cannot contain spaces") return table def basic_query_validation(self, query: str) -> str: """Validate SQL query. Rules: - Must be a string - Cannot be empty """ if not query.strip(): raise ValidationError("Query cannot be empty") return query @classmethod def validate_transaction_control(cls, query: str) -> bool: """Check if the query contains transaction control statements. Args: query: SQL query string Returns: bool: True if the query contains any transaction control statements """ return any(x in query.upper() for x in ["BEGIN", "COMMIT", "ROLLBACK"]) def validate_query(self, sql_query: str) -> QueryValidationResults: """ Identify the type of SQL query using PostgreSQL's parser. Args: sql_query: A SQL query string to parse Returns: QueryValidationResults: A validation result object containing information about the SQL statements Raises: ValidationError: If the query is not valid or contains TCL statements """ try: # Validate raw input sql_query = self.basic_query_validation(sql_query) # Parse the SQL using PostgreSQL's parser parse_tree = parse_sql(sql_query) if parse_tree is None: logger.debug("No statements found in the query") # logger.debug(f"Parse tree generated with {parse_tree} statements") # Validate statements result = self.validate_statements(original_query=sql_query, parse_tree=parse_tree) # Check if the query contains transaction control statements and reject them for statement in result.statements: if statement.category == SQLQueryCategory.TCL: logger.warning(f"Transaction control statement detected: {statement.command}") raise ValidationError( "Transaction control statements (BEGIN, COMMIT, ROLLBACK) are not allowed. " "Queries will be automatically wrapped in transactions by the system." ) return result except ParseError as e: logger.exception(f"SQL syntax error: {str(e)}") raise ValidationError(f"SQL syntax error: {str(e)}") from e except ValidationError: # let it propagate raise except Exception as e: logger.exception(f"Unexpected error during SQL validation: {str(e)}") raise ValidationError(f"Unexpected error during SQL validation: {str(e)}") from e def _map_to_command(self, stmt_type: str) -> SQLQueryCommand: """Map a pglast statement type to our SQLQueryCommand enum.""" mapping = { # DQL Commands "SelectStmt": SQLQueryCommand.SELECT, # DML Commands "InsertStmt": SQLQueryCommand.INSERT, "UpdateStmt": SQLQueryCommand.UPDATE, "DeleteStmt": SQLQueryCommand.DELETE, "MergeStmt": SQLQueryCommand.MERGE, # DDL Commands "CreateStmt": SQLQueryCommand.CREATE, "CreateTableAsStmt": SQLQueryCommand.CREATE, "CreateSchemaStmt": SQLQueryCommand.CREATE, "CreateExtensionStmt": SQLQueryCommand.CREATE, "CreateFunctionStmt": SQLQueryCommand.CREATE, "CreateTrigStmt": SQLQueryCommand.CREATE, "ViewStmt": SQLQueryCommand.CREATE, "IndexStmt": SQLQueryCommand.CREATE, # Additional DDL Commands "CreateEnumStmt": SQLQueryCommand.CREATE, "CreateTypeStmt": SQLQueryCommand.CREATE, "CreateDomainStmt": SQLQueryCommand.CREATE, "CreateSeqStmt": SQLQueryCommand.CREATE, "CreateForeignTableStmt": SQLQueryCommand.CREATE, "CreatePolicyStmt": SQLQueryCommand.CREATE, "CreateCastStmt": SQLQueryCommand.CREATE, "CreateOpClassStmt": SQLQueryCommand.CREATE, "CreateOpFamilyStmt": SQLQueryCommand.CREATE, "AlterTableStmt": SQLQueryCommand.ALTER, "AlterDomainStmt": SQLQueryCommand.ALTER, "AlterEnumStmt": SQLQueryCommand.ALTER, "AlterSeqStmt": SQLQueryCommand.ALTER, "AlterOwnerStmt": SQLQueryCommand.ALTER, "AlterObjectSchemaStmt": SQLQueryCommand.ALTER, "DropStmt": SQLQueryCommand.DROP, "TruncateStmt": SQLQueryCommand.TRUNCATE, "CommentStmt": SQLQueryCommand.COMMENT, "RenameStmt": SQLQueryCommand.RENAME, # DCL Commands "GrantStmt": SQLQueryCommand.GRANT, "GrantRoleStmt": SQLQueryCommand.GRANT, "RevokeStmt": SQLQueryCommand.REVOKE, "RevokeRoleStmt": SQLQueryCommand.REVOKE, "CreateRoleStmt": SQLQueryCommand.CREATE, "AlterRoleStmt": SQLQueryCommand.ALTER, "DropRoleStmt": SQLQueryCommand.DROP, # TCL Commands "TransactionStmt": SQLQueryCommand.BEGIN, # Will need refinement for different transaction types # PostgreSQL-specific Commands "VacuumStmt": SQLQueryCommand.VACUUM, "ExplainStmt": SQLQueryCommand.EXPLAIN, "CopyStmt": SQLQueryCommand.COPY, "ListenStmt": SQLQueryCommand.LISTEN, "NotifyStmt": SQLQueryCommand.NOTIFY, "PrepareStmt": SQLQueryCommand.PREPARE, "ExecuteStmt": SQLQueryCommand.EXECUTE, "DeallocateStmt": SQLQueryCommand.DEALLOCATE, } # Try to map the statement type, default to UNKNOWN return mapping.get(stmt_type, SQLQueryCommand.UNKNOWN) def validate_statements(self, original_query: str, parse_tree: Any) -> QueryValidationResults: """Validate the statements in the parse tree. Args: parse_tree: The parse tree to validate Returns: SQLBatchValidationResult: A validation result object containing information about the SQL statements Raises: ValidationError: If the query is not valid """ result = QueryValidationResults(original_query=original_query) if parse_tree is None: return result try: for stmt in parse_tree: if not hasattr(stmt, "stmt"): continue stmt_node = stmt.stmt stmt_type = stmt_node.__class__.__name__ logger.debug(f"Processing statement node type: {stmt_type}") # logger.debug(f"DEBUGGING stmt_node: {stmt_node}") logger.debug(f"DEBUGGING stmt_node.stmt_location: {stmt.stmt_location}") # Extract the object type if available object_type = None schema_name = None if hasattr(stmt_node, "relation") and stmt_node.relation is not None: if hasattr(stmt_node.relation, "relname"): object_type = stmt_node.relation.relname if hasattr(stmt_node.relation, "schemaname"): schema_name = stmt_node.relation.schemaname # For statements with 'relations' list (like TRUNCATE) elif hasattr(stmt_node, "relations") and stmt_node.relations: for relation in stmt_node.relations: if hasattr(relation, "relname"): object_type = relation.relname if hasattr(relation, "schemaname"): schema_name = relation.schemaname break # Simple approach: Set object_type based on statement type if not already set if object_type is None and stmt_type in self.STATEMENT_TYPE_TO_OBJECT_TYPE: object_type = self.STATEMENT_TYPE_TO_OBJECT_TYPE[stmt_type] # Default schema to public if not set if schema_name is None: schema_name = "public" # Get classification for this statement type classification = self.safety_config.classify_statement(stmt_type, stmt_node) logger.debug( f"Statement category classified as: {classification.get('category', 'UNKNOWN')} - risk level: {classification.get('risk_level', 'UNKNOWN')}" ) logger.debug(f"DEBUGGING QUERY EXTRACTION LOCATION: {stmt.stmt_location} - {stmt.stmt_len}") # Create validation result query_result = ValidatedStatement( category=classification["category"], command=self._map_to_command(stmt_type), risk_level=classification["risk_level"], needs_migration=classification["needs_migration"], object_type=object_type, schema_name=schema_name, query=original_query[stmt.stmt_location : stmt.stmt_location + stmt.stmt_len] if hasattr(stmt, "stmt_location") and hasattr(stmt, "stmt_len") else None, ) # logger.debug(f"Isolated query: {query_result.query}") logger.debug( "Query validation result:", { "statement_category": query_result.category, "risk_level": query_result.risk_level, "needs_migration": query_result.needs_migration, "object_type": query_result.object_type, "schema_name": query_result.schema_name, "query": query_result.query, }, ) # Add result to the batch result.statements.append(query_result) # Update highest risk level if query_result.risk_level > result.highest_risk_level: result.highest_risk_level = query_result.risk_level logger.debug(f"Updated batch validation result to: {query_result.risk_level}") if len(result.statements) == 0: logger.debug("No valid statements found in the query") raise ValidationError("No queries were parsed - please check correctness of your query") logger.debug( f"Validated a total of {len(result.statements)} with the highest risk level of: {result.highest_risk_level}" ) return result except AttributeError as e: # Handle attempting to access missing attributes in the parse tree raise ValidationError(f"Error accessing parse tree structure: {str(e)}") from e except KeyError as e: # Handle missing keys in classification dictionary raise ValidationError(f"Missing classification key: {str(e)}") from e

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/alexander-zuev/supabase-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server