"""Tests for SQL validation and safe alternatives."""
from __future__ import annotations
import pytest
from igloo_mcp.config import SQLPermissions
from igloo_mcp.sql_validation import (
extract_table_name,
generate_sql_alternatives,
get_sql_statement_type,
validate_sql_statement,
)
class TestSQLPermissions:
    """Checks for the SQLPermissions configuration object."""

    def test_default_permissions(self) -> None:
        """A default instance enables SELECT and leaves the other flags off."""
        defaults = SQLPermissions()

        assert defaults.select is True
        # Write and destructive operations are expected to be blocked by default.
        for flag in ("insert", "update", "delete", "drop", "truncate"):
            assert getattr(defaults, flag) is False

    def test_get_allow_list(self) -> None:
        """The default allow list contains only the SELECT entry checked here."""
        allowed = SQLPermissions().get_allow_list()

        # Entries are compared in lowercase, as the config emits them.
        assert "select" in allowed
        for blocked in ("insert", "update", "delete", "drop", "truncate"):
            assert blocked not in allowed

    def test_get_disallow_list(self) -> None:
        """The default disallow list names the blocked types but not SELECT."""
        denied = SQLPermissions().get_disallow_list()

        # Entries are compared in lowercase, as the config emits them.
        for blocked in ("insert", "update", "delete", "drop", "truncate"):
            assert blocked in denied
        assert "select" not in denied

    def test_custom_permissions(self) -> None:
        """Flags switched on at construction show up in the allow list."""
        allowed = SQLPermissions(delete=True, drop=True).get_allow_list()

        # Entries are compared in lowercase, as the config emits them.
        assert "delete" in allowed
        assert "drop" in allowed
class TestExtractTableName:
    """Checks for pulling a table name out of a SQL string."""

    # Value the tests accept in place of a real table name.
    _PLACEHOLDER = "<table_name>"

    def test_extract_from_delete(self) -> None:
        """DELETE yields either the placeholder or a name containing 'users'."""
        extracted = extract_table_name("DELETE FROM users WHERE id = 1")

        # The placeholder is an acceptable outcome; anything else must match.
        if extracted != self._PLACEHOLDER:
            assert "users" in extracted.lower()

    def test_extract_from_drop(self) -> None:
        """DROP yields either the placeholder or a name containing 'old_data'."""
        extracted = extract_table_name("DROP TABLE old_data")

        # The placeholder is an acceptable outcome; anything else must match.
        if extracted != self._PLACEHOLDER:
            assert "old_data" in extracted.lower()

    def test_extract_from_truncate(self) -> None:
        """TRUNCATE yields either the placeholder or a name containing 'temp_table'."""
        extracted = extract_table_name("TRUNCATE TABLE temp_table")

        # The placeholder is an acceptable outcome; anything else must match.
        if extracted != self._PLACEHOLDER:
            assert "temp_table" in extracted.lower()

    def test_extract_failure_returns_placeholder(self) -> None:
        """Text that is not valid SQL produces the placeholder."""
        assert extract_table_name("INVALID SQL STATEMENT") == self._PLACEHOLDER
class TestGenerateSQLAlternatives:
    """Checks for the suggestions produced for blocked statements."""

    def test_delete_alternatives(self) -> None:
        """DELETE suggestions mention a soft delete via UPDATE and carry a warning."""
        suggestions = generate_sql_alternatives(
            "DELETE FROM users WHERE id = 1", "Delete"
        )
        rendered = "\n".join(suggestions)

        for fragment in ("soft_delete", "UPDATE", "deleted_at", "⚠️"):
            assert fragment in rendered

    def test_drop_alternatives(self) -> None:
        """DROP suggestions mention renaming the table as deprecated."""
        suggestions = generate_sql_alternatives("DROP TABLE old_data", "Drop")
        rendered = "\n".join(suggestions)

        for fragment in ("rename", "ALTER TABLE", "RENAME TO", "deprecated"):
            assert fragment in rendered

    def test_truncate_alternatives(self) -> None:
        """TRUNCATE suggestions mention a DELETE FROM with a WHERE clause."""
        suggestions = generate_sql_alternatives(
            "TRUNCATE TABLE temp_data", "Truncate"
        )
        rendered = "\n".join(suggestions)

        for fragment in ("DELETE FROM", "WHERE"):
            assert fragment in rendered

    def test_no_alternatives_for_unknown_type(self) -> None:
        """A statement type with no alternatives yields an empty list."""
        assert generate_sql_alternatives("SELECT * FROM foo", "Select") == []
class TestValidateSQLStatement:
    """Checks for validate_sql_statement with explicit and default lists."""

    @staticmethod
    def _validate_with_default_permissions(query):
        """Validate *query* against the lists of a default SQLPermissions."""
        permissions = SQLPermissions()
        return validate_sql_statement(
            query,
            permissions.get_allow_list(),
            permissions.get_disallow_list(),
        )

    @pytest.mark.skip(reason="Upstream validate_sql_type behavior needs investigation")
    def test_allowed_select_statement(self) -> None:
        """SELECT passes when the allow list contains it."""
        # Lists are lowercase; the reported statement type is capitalized.
        kind, ok, message = validate_sql_statement(
            "SELECT * FROM users", ["select"], []
        )

        assert kind == "Select"
        assert ok is True
        assert message is None

    def test_blocked_delete_statement(self) -> None:
        """DELETE is rejected and the message offers a soft-delete alternative."""
        # Lists are lowercase; the reported statement type is capitalized.
        permitted = ["select", "insert", "update"]
        forbidden = ["delete", "drop", "truncate"]

        kind, ok, message = validate_sql_statement(
            "DELETE FROM users WHERE id = 1", permitted, forbidden
        )

        assert kind == "Delete"
        assert ok is False
        assert message is not None
        for fragment in ("not permitted", "soft_delete", "UPDATE"):
            assert fragment in message

    def test_blocked_drop_statement(self) -> None:
        """DROP is rejected and the message offers a rename alternative."""
        # Lists are lowercase; the reported statement type is capitalized.
        permitted = ["select", "insert"]
        forbidden = ["delete", "drop", "truncate"]

        kind, ok, message = validate_sql_statement(
            "DROP TABLE old_data", permitted, forbidden
        )

        assert kind == "Drop"
        assert ok is False
        assert message is not None
        for fragment in ("not permitted", "rename", "RENAME TO"):
            assert fragment in message

    def test_blocked_truncate_statement(self) -> None:
        """TRUNCATE is rejected under either name the parser may report."""
        # Lists are lowercase; both spellings of the truncate type are denied.
        permitted = ["select"]
        forbidden = ["delete", "drop", "truncate", "truncatetable"]

        kind, ok, message = validate_sql_statement(
            "TRUNCATE TABLE temp_data", permitted, forbidden
        )

        # Upstream may report "Truncate" or "TruncateTable".
        assert kind in ["Truncate", "TruncateTable"]
        assert ok is False
        assert message is not None
        assert "not permitted" in message

    def test_command_fallback_includes_parser_context(self) -> None:
        """A statement typed as Command is blocked with an explanatory message."""
        permitted = ["select", "show", "describe", "use"]
        forbidden = [
            "insert",
            "update",
            "create",
            "alter",
            "delete",
            "drop",
            "truncate",
            "command",
        ]

        kind, ok, message = validate_sql_statement(
            "ALTER USER foo SET RSA_PUBLIC_KEY = 'abc'", permitted, forbidden
        )

        assert kind == "Command"
        assert ok is False
        assert message is not None
        assert "Snowflake returned 'Command'" in message

    @pytest.mark.skip(reason="Upstream validate_sql_type behavior needs investigation")
    def test_allowed_insert_statement(self) -> None:
        """INSERT passes when the allow list contains it."""
        # Lists are lowercase; the reported statement type is capitalized.
        kind, ok, message = validate_sql_statement(
            "INSERT INTO users (name) VALUES ('Alice')", ["insert"], []
        )

        assert kind == "Insert"
        assert ok is True
        assert message is None

    def test_select_union_is_allowed(self) -> None:
        """A UNION query is treated as Select under default permissions."""
        kind, ok, message = self._validate_with_default_permissions(
            "SELECT 1 UNION SELECT 2"
        )

        assert kind == "Select"
        assert ok is True
        assert message is None

    def test_select_union_all_is_allowed(self) -> None:
        """A UNION ALL query is treated as Select under default permissions."""
        kind, ok, message = self._validate_with_default_permissions(
            "SELECT 1 UNION ALL SELECT 2"
        )

        assert kind == "Select"
        assert ok is True
        assert message is None

    def test_select_intersect_is_allowed(self) -> None:
        """An INTERSECT query is treated as Select under default permissions."""
        kind, ok, message = self._validate_with_default_permissions(
            "SELECT 1 INTERSECT SELECT 2"
        )

        assert kind == "Select"
        assert ok is True
        assert message is None

    def test_with_clause_is_allowed(self) -> None:
        """A CTE feeding a SELECT is treated as Select under default permissions."""
        kind, ok, message = self._validate_with_default_permissions(
            "WITH cte AS (SELECT 1) SELECT * FROM cte"
        )

        assert kind == "Select"
        assert ok is True
        assert message is None

    def test_with_delete_remains_blocked(self) -> None:
        """A CTE feeding a DELETE is still reported as Delete and rejected."""
        kind, ok, message = self._validate_with_default_permissions(
            "WITH cte AS (SELECT 1) DELETE FROM users"
        )

        assert kind == "Delete"
        assert ok is False
        assert message is not None
class TestGetSQLStatementType:
    """Checks for the statement type reported by get_sql_statement_type."""

    def test_detect_select(self) -> None:
        """A plain SELECT is reported as Select."""
        assert get_sql_statement_type("SELECT * FROM users") == "Select"

    def test_detect_insert(self) -> None:
        """An INSERT is reported as Insert."""
        detected = get_sql_statement_type("INSERT INTO users (name) VALUES ('Alice')")
        assert detected == "Insert"

    def test_detect_update(self) -> None:
        """An UPDATE is reported as Update."""
        detected = get_sql_statement_type("UPDATE users SET name = 'Bob' WHERE id = 1")
        assert detected == "Update"

    def test_detect_delete(self) -> None:
        """A DELETE is reported as Delete."""
        assert get_sql_statement_type("DELETE FROM users WHERE id = 1") == "Delete"

    def test_detect_drop(self) -> None:
        """A DROP TABLE is reported as Drop."""
        assert get_sql_statement_type("DROP TABLE old_data") == "Drop"

    def test_detect_truncate(self) -> None:
        """A TRUNCATE TABLE is reported under one of its two accepted names."""
        # Upstream may report "Truncate" or "TruncateTable".
        accepted = ("Truncate", "TruncateTable")
        assert get_sql_statement_type("TRUNCATE TABLE temp_data") in accepted

    def test_detect_union_returns_select(self) -> None:
        """A UNION of SELECTs is reported as Select for diagnostics."""
        assert get_sql_statement_type("SELECT 1 UNION SELECT 2") == "Select"