Skip to main content
Glama
test_sql_validation.py (13.1 kB)
"""Tests for SQL validation and safe alternatives."""

from __future__ import annotations

import pytest

from igloo_mcp.config import SQLPermissions
from igloo_mcp.sql_validation import (
    extract_table_name,
    generate_sql_alternatives,
    get_sql_statement_type,
    validate_sql_statement,
)

# Value extract_table_name returns when it cannot find a table name
# (pinned by TestExtractTableName.test_extract_failure_returns_placeholder).
TABLE_PLACEHOLDER = "<table_name>"


def _validate_with_default_perms(sql: str) -> tuple:
    """Run validate_sql_statement on *sql* using the default SQLPermissions lists.

    Returns whatever validate_sql_statement returns; the tests below unpack it
    as ``(stmt_type, is_valid, error_msg)``.
    """
    perms = SQLPermissions()
    return validate_sql_statement(
        sql, perms.get_allow_list(), perms.get_disallow_list()
    )


def _assert_extracted_or_placeholder(sql: str, expected_table: str) -> None:
    """Assert that extract_table_name(sql) found *expected_table* or fell back.

    NOTE(review): accepting the placeholder means these checks cannot detect an
    extraction regression; tighten once the extractor's behavior is settled.
    """
    table = extract_table_name(sql)
    assert table == TABLE_PLACEHOLDER or expected_table in table.lower()


class TestSQLPermissions:
    """Test SQLPermissions configuration."""

    def test_default_permissions(self):
        """Test default permissions block dangerous operations."""
        perms = SQLPermissions()
        assert perms.select is True
        assert perms.insert is False
        assert perms.update is False
        assert perms.delete is False  # Blocked by default
        assert perms.drop is False  # Blocked by default
        assert perms.truncate is False  # Blocked by default

    def test_get_allow_list(self):
        """Test getting list of allowed statement types."""
        perms = SQLPermissions()
        allow_list = perms.get_allow_list()
        # Config returns lowercase to match upstream validation
        assert "select" in allow_list
        assert "insert" not in allow_list
        assert "update" not in allow_list
        assert "delete" not in allow_list
        assert "drop" not in allow_list
        assert "truncate" not in allow_list

    def test_get_disallow_list(self):
        """Test getting list of disallowed statement types."""
        perms = SQLPermissions()
        disallow_list = perms.get_disallow_list()
        # Config returns lowercase to match upstream validation
        assert "insert" in disallow_list
        assert "update" in disallow_list
        assert "delete" in disallow_list
        assert "drop" in disallow_list
        assert "truncate" in disallow_list
        assert "select" not in disallow_list

    def test_custom_permissions(self):
        """Test custom permission configuration."""
        perms = SQLPermissions(delete=True, drop=True)
        allow_list = perms.get_allow_list()
        # Config returns lowercase to match upstream validation
        assert "delete" in allow_list
        assert "drop" in allow_list


class TestExtractTableName:
    """Test table name extraction from SQL."""

    def test_extract_from_delete(self):
        """Test extracting table name from DELETE statement."""
        _assert_extracted_or_placeholder("DELETE FROM users WHERE id = 1", "users")

    def test_extract_from_drop(self):
        """Test extracting table name from DROP statement."""
        _assert_extracted_or_placeholder("DROP TABLE old_data", "old_data")

    def test_extract_from_truncate(self):
        """Test extracting table name from TRUNCATE statement."""
        _assert_extracted_or_placeholder("TRUNCATE TABLE temp_table", "temp_table")

    def test_extract_failure_returns_placeholder(self):
        """Test that failed extraction returns placeholder."""
        # Invalid SQL
        sql = "INVALID SQL STATEMENT"
        table = extract_table_name(sql)
        assert table == TABLE_PLACEHOLDER


class TestGenerateSQLAlternatives:
    """Test safe SQL alternative generation."""

    def test_delete_alternatives(self):
        """Test generating alternatives for DELETE."""
        sql = "DELETE FROM users WHERE id = 1"
        alternatives = generate_sql_alternatives(sql, "Delete")
        alt_text = "\n".join(alternatives)
        assert "soft_delete" in alt_text
        assert "UPDATE" in alt_text
        assert "deleted_at" in alt_text
        assert "⚠️" in alt_text

    def test_drop_alternatives(self):
        """Test generating alternatives for DROP."""
        sql = "DROP TABLE old_data"
        alternatives = generate_sql_alternatives(sql, "Drop")
        alt_text = "\n".join(alternatives)
        assert "rename" in alt_text
        assert "ALTER TABLE" in alt_text
        assert "RENAME TO" in alt_text
        assert "deprecated" in alt_text

    def test_truncate_alternatives(self):
        """Test generating alternatives for TRUNCATE."""
        sql = "TRUNCATE TABLE temp_data"
        alternatives = generate_sql_alternatives(sql, "Truncate")
        alt_text = "\n".join(alternatives)
        assert "DELETE FROM" in alt_text
        assert "WHERE" in alt_text

    def test_no_alternatives_for_unknown_type(self):
        """Test that a type with no safe-alternative mapping (Select) yields []."""
        alternatives = generate_sql_alternatives("SELECT * FROM foo", "Select")
        assert alternatives == []


class TestValidateSQLStatement:
    """Test SQL statement validation."""

    @pytest.mark.skip(reason="Upstream validate_sql_type behavior needs investigation")
    def test_allowed_select_statement(self):
        """Test that SELECT is allowed when in allow list."""
        sql = "SELECT * FROM users"
        # Config uses lowercase for validation (upstream returns capitalized)
        allow_list = ["select"]
        disallow_list = []
        stmt_type, is_valid, error_msg = validate_sql_statement(
            sql, allow_list, disallow_list
        )
        assert stmt_type == "Select"
        assert is_valid is True
        assert error_msg is None

    def test_blocked_delete_statement(self):
        """Test that DELETE is blocked with alternatives."""
        sql = "DELETE FROM users WHERE id = 1"
        # Config uses lowercase for validation
        allow_list = ["select", "insert", "update"]
        disallow_list = ["delete", "drop", "truncate"]
        stmt_type, is_valid, error_msg = validate_sql_statement(
            sql, allow_list, disallow_list
        )
        assert stmt_type == "Delete"
        assert is_valid is False
        assert error_msg is not None
        assert "not permitted" in error_msg
        assert "soft_delete" in error_msg
        assert "UPDATE" in error_msg

    def test_blocked_drop_statement(self):
        """Test that DROP is blocked with alternatives."""
        sql = "DROP TABLE old_data"
        # Config uses lowercase for validation
        allow_list = ["select", "insert"]
        disallow_list = ["delete", "drop", "truncate"]
        stmt_type, is_valid, error_msg = validate_sql_statement(
            sql, allow_list, disallow_list
        )
        assert stmt_type == "Drop"
        assert is_valid is False
        assert error_msg is not None
        assert "not permitted" in error_msg
        assert "rename" in error_msg
        assert "RENAME TO" in error_msg

    def test_blocked_truncate_statement(self):
        """Test that TRUNCATE is blocked with alternatives."""
        sql = "TRUNCATE TABLE temp_data"
        # Config uses lowercase for validation
        allow_list = ["select"]
        # "truncatetable" covers upstream reporting the type as "TruncateTable".
        disallow_list = ["delete", "drop", "truncate", "truncatetable"]
        stmt_type, is_valid, error_msg = validate_sql_statement(
            sql, allow_list, disallow_list
        )
        # Upstream may return "Truncate" or "TruncateTable"
        assert stmt_type in ["Truncate", "TruncateTable"]
        assert is_valid is False
        assert error_msg is not None
        assert "not permitted" in error_msg

    def test_command_fallback_includes_parser_context(self):
        """Parser fallback to Command should explain why statement was blocked."""
        sql = "ALTER USER foo SET RSA_PUBLIC_KEY = 'abc'"
        allow_list = ["select", "show", "describe", "use"]
        disallow_list = [
            "insert",
            "update",
            "create",
            "alter",
            "delete",
            "drop",
            "truncate",
            "command",
        ]
        stmt_type, is_valid, error_msg = validate_sql_statement(
            sql, allow_list, disallow_list
        )
        assert stmt_type == "Command"
        assert is_valid is False
        assert error_msg is not None
        assert "Snowflake returned 'Command'" in error_msg

    @pytest.mark.skip(reason="Upstream validate_sql_type behavior needs investigation")
    def test_allowed_insert_statement(self):
        """Test that INSERT is allowed when in allow list."""
        sql = "INSERT INTO users (name) VALUES ('Alice')"
        # Config uses lowercase for validation (upstream returns capitalized)
        allow_list = ["insert"]
        disallow_list = []
        stmt_type, is_valid, error_msg = validate_sql_statement(
            sql, allow_list, disallow_list
        )
        assert stmt_type == "Insert"
        assert is_valid is True
        assert error_msg is None

    def test_select_union_is_allowed(self):
        """UNION queries should inherit SELECT permissions."""
        stmt_type, is_valid, error_msg = _validate_with_default_perms(
            "SELECT 1 UNION SELECT 2"
        )
        assert stmt_type == "Select"
        assert is_valid is True
        assert error_msg is None

    def test_select_union_all_is_allowed(self):
        """UNION ALL queries should inherit SELECT permissions."""
        stmt_type, is_valid, error_msg = _validate_with_default_perms(
            "SELECT 1 UNION ALL SELECT 2"
        )
        assert stmt_type == "Select"
        assert is_valid is True
        assert error_msg is None

    def test_select_intersect_is_allowed(self):
        """INTERSECT queries should inherit SELECT permissions."""
        stmt_type, is_valid, error_msg = _validate_with_default_perms(
            "SELECT 1 INTERSECT SELECT 2"
        )
        assert stmt_type == "Select"
        assert is_valid is True
        assert error_msg is None

    def test_with_clause_is_allowed(self):
        """CTE (WITH) queries should inherit SELECT permissions."""
        stmt_type, is_valid, error_msg = _validate_with_default_perms(
            "WITH cte AS (SELECT 1) SELECT * FROM cte"
        )
        assert stmt_type == "Select"
        assert is_valid is True
        assert error_msg is None

    def test_with_delete_remains_blocked(self):
        """CTE followed by DELETE should stay blocked."""
        stmt_type, is_valid, error_msg = _validate_with_default_perms(
            "WITH cte AS (SELECT 1) DELETE FROM users"
        )
        assert stmt_type == "Delete"
        assert is_valid is False
        assert error_msg is not None


class TestGetSQLStatementType:
    """Test SQL statement type detection."""

    def test_detect_select(self):
        """Test detecting SELECT statement."""
        sql = "SELECT * FROM users"
        stmt_type = get_sql_statement_type(sql)
        assert stmt_type == "Select"

    def test_detect_insert(self):
        """Test detecting INSERT statement."""
        sql = "INSERT INTO users (name) VALUES ('Alice')"
        stmt_type = get_sql_statement_type(sql)
        assert stmt_type == "Insert"

    def test_detect_update(self):
        """Test detecting UPDATE statement."""
        sql = "UPDATE users SET name = 'Bob' WHERE id = 1"
        stmt_type = get_sql_statement_type(sql)
        assert stmt_type == "Update"

    def test_detect_delete(self):
        """Test detecting DELETE statement."""
        sql = "DELETE FROM users WHERE id = 1"
        stmt_type = get_sql_statement_type(sql)
        assert stmt_type == "Delete"

    def test_detect_drop(self):
        """Test detecting DROP statement."""
        sql = "DROP TABLE old_data"
        stmt_type = get_sql_statement_type(sql)
        assert stmt_type == "Drop"

    def test_detect_truncate(self):
        """Test detecting TRUNCATE statement."""
        sql = "TRUNCATE TABLE temp_data"
        stmt_type = get_sql_statement_type(sql)
        # Upstream may return "Truncate" or "TruncateTable"
        assert stmt_type in ["Truncate", "TruncateTable"]

    def test_detect_union_returns_select(self):
        """Union statements should be reported as Select for diagnostics."""
        sql = "SELECT 1 UNION SELECT 2"
        stmt_type = get_sql_statement_type(sql)
        assert stmt_type == "Select"

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/Evan-Kim2028/igloo-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.