test_delta_service.py (43.8 kB)
"""Tests for the delta_service module. Tests cover: - Query validation (_check_query_is_valid) - Query building (build_select_query) - Spark operations with mocked SparkSession: - count_delta_table() - sample_delta_table() - query_delta_table() - select_from_delta_table() - Caching behavior - Timeout handling - Concurrent query execution """ from unittest.mock import MagicMock, patch import pytest from src.delta_lake import delta_service from src.service.exceptions import ( DeltaDatabaseNotFoundError, DeltaTableNotFoundError, SparkOperationError, SparkQueryError, SparkTimeoutError, ) from src.service.models import ( AggregationSpec, ColumnSpec, FilterCondition, OrderBySpec, TableSelectRequest, TableSelectResponse, ) def test_delta_service_imports(): """Test that delta_service module can be imported.""" assert delta_service is not None # Lists of valid and invalid queries to test validation logic VALID_QUERIES = [ "SELECT * FROM my_table", "SELECT id, name FROM users", "SELECT COUNT(*) FROM transactions", "SELECT AVG(amount) FROM payments", "SELECT * FROM table WHERE id > 100", "SELECT DISTINCT category FROM products", "SELECT * FROM orders ORDER BY id DESC", "SELECT * FROM customers LIMIT 10", "SELECT * FROM (SELECT id FROM inner_table) AS subquery", "SELECT t1.id, t2.name FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id", "SELECT * FROM my_table WHERE date BETWEEN '2023-01-01' AND '2023-12-31'", "SELECT id FROM events WHERE type IN ('click', 'view', 'purchase')", "SELECT * FROM sales WHERE region = 'North' AND amount > 1000", "SELECT COALESCE(email, phone, 'no contact') FROM contacts", "SELECT * FROM employees WHERE department LIKE 'eng%'", ] # Invalid queries grouped by expected error message INVALID_QUERIES = { "must contain exactly one statement": [ # Multiple statements "SELECT * FROM users; SELECT * FROM roles", "SELECT * FROM table1; DROP TABLE table2", "SELECT * FROM users; WAITFOR DELAY '0:0:10'--", "SELECT * FROM users WHERE id = ABS(1); DROP TABLE logs; --", "SELECT * FROM users; DROP TABLE logs", # Empty statement "", " ", "\n\n", ], "must be one of the following: select": [ # Non-SELECT statements "INSERT INTO users VALUES (1, 'john')", "UPDATE users SET active = true WHERE id = 1", "DELETE FROM logs WHERE created_at < '2023-01-01'", "DROP TABLE old_data", "CREATE TABLE new_table (id INT, name VARCHAR)", "TRUNCATE TABLE logs", "ALTER TABLE users ADD COLUMN age INT", "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED THEN UPDATE SET target.val = source.val", "VACUUM delta_table", "CREATE OR REPLACE FUNCTION f() RETURNS void AS $$ SELECT 1; $$ LANGUAGE SQL", "WITH t AS (SELECT 1) DELETE FROM users", # Invalid SQL queries "/**/SEL/**/ECT * FR/**/OM users", "S%45L%45CT%20*%20%46ROM%20users", "S\tE\tL\tE\tC\tT * F\tR\tO\tM users", ], "contains forbidden keyword": [ # Forbidden keywords in various contexts # Most of them are not valid SQL queries "Select * from users where id in (DELETE from users where id < 100)", "SELECT * FROM users WHERE DROP = 1", "SELECT ${1+drop}table FROM users", "SELECT * FROM users WHERE id = 1 OR drop = true", "SELECT * FROM users WHERE id = 1 OR drop table users", "SELECT * FROM users WHERE id = 1 OR delete from users", "SELECT * FROM users WHERE id = 1 AND update users set name = 'hacked'", "SELECT * FROM users WHERE id = 1 AND create table hacked (id int)", "SELECT * FROM users WHERE command = 'drop table'", "SELECT * FROM users WHERE col = 'value' OR drop table logs", "SELECT * FROM users WHERE DROP = true", ], "contains disallowed 
metacharacter": [ # Metacharacters "SELECT * FROM users WHERE name = 'user;' ", "/* Comment */ SELECT id FROM users /* Another comment */", "SELECT * FROM users WHERE id = 1 -- bypass filter", ], "contains forbidden PostgreSQL schema": [ # PostgreSQL system schemas that should be blocked "SELECT * FROM pg_catalog.pg_tables", "SELECT * FROM information_schema.tables", "SELECT * FROM pg_class", "SELECT tablename FROM pg_tables", "SELECT table_name FROM information_schema.tables", "SELECT * FROM pg_settings WHERE name = 'data_directory'", "SELECT * FROM information_schema.schemata", "SELECT * FROM pg_shadow", "SELECT * FROM pg_authid", "SELECT * FROM pg_stat_activity", "SELECT datname FROM pg_database", ], } def test_valid_queries(): """Test that all valid queries pass validation.""" for query in VALID_QUERIES: try: assert delta_service._check_query_is_valid(query) is True except Exception as e: pytest.fail(f"Valid query '{query}' failed validation: {str(e)}") def test_invalid_queries(): """Test that all invalid queries fail validation with the correct error messages.""" for error_pattern, queries in INVALID_QUERIES.items(): for query in queries: with pytest.raises(SparkQueryError, match=error_pattern): delta_service._check_query_is_valid(query) # --- # Query Builder Tests # --- class TestValidateIdentifier: """Tests for _validate_identifier function.""" def test_valid_identifiers(self): """Test that valid identifiers pass validation.""" valid_names = [ "table_name", "column1", "_private", "TableName", "UPPERCASE", "a1b2c3", "_", "a", ] for name in valid_names: # Should not raise delta_service._validate_identifier(name) def test_invalid_identifiers(self): """Test that invalid identifiers fail validation.""" invalid_names = [ "", # Empty "123start", # Starts with number "has-dash", # Contains dash "has space", # Contains space "has.dot", # Contains dot "has;semicolon", # Contains semicolon "has'quote", # Contains quote 'has"double', # Contains double quote "has`backtick", # Contains backtick ] for name in invalid_names: with pytest.raises(SparkQueryError, match="Invalid"): delta_service._validate_identifier(name) class TestEscapeValue: """Tests for _escape_value function.""" def test_escape_none(self): """Test escaping None values.""" assert delta_service._escape_value(None) == "NULL" def test_escape_bool(self): """Test escaping boolean values.""" assert delta_service._escape_value(True) == "TRUE" assert delta_service._escape_value(False) == "FALSE" def test_escape_numbers(self): """Test escaping numeric values.""" assert delta_service._escape_value(42) == "42" assert delta_service._escape_value(3.14) == "3.14" assert delta_service._escape_value(-100) == "-100" assert delta_service._escape_value(0) == "0" def test_escape_strings(self): """Test escaping string values.""" assert delta_service._escape_value("hello") == "'hello'" assert delta_service._escape_value("O'Brien") == "'O''Brien'" assert delta_service._escape_value("test''double") == "'test''''double'" assert delta_service._escape_value("") == "''" class TestBuildColumnExpression: """Tests for _build_column_expression function.""" def test_simple_column(self): """Test building a simple column expression.""" col = ColumnSpec(column="name") assert delta_service._build_column_expression(col) == "`name`" def test_column_with_alias(self): """Test building a column expression with alias.""" col = ColumnSpec(column="name", alias="user_name") assert delta_service._build_column_expression(col) == "`name` AS `user_name`" def 
test_column_with_table_alias(self): """Test building a column expression with table alias.""" col = ColumnSpec(column="name", table_alias="u") assert delta_service._build_column_expression(col) == "`u`.`name`" def test_column_with_both_aliases(self): """Test building a column expression with both table and column alias.""" col = ColumnSpec(column="name", table_alias="u", alias="user_name") assert ( delta_service._build_column_expression(col) == "`u`.`name` AS `user_name`" ) class TestBuildAggregationExpression: """Tests for _build_aggregation_expression function.""" def test_count_star(self): """Test COUNT(*) aggregation.""" agg = AggregationSpec(function="COUNT", column="*") assert delta_service._build_aggregation_expression(agg) == "COUNT(*)" def test_count_star_with_alias(self): """Test COUNT(*) with alias.""" agg = AggregationSpec(function="COUNT", column="*", alias="total") assert delta_service._build_aggregation_expression(agg) == "COUNT(*) AS `total`" def test_sum_column(self): """Test SUM aggregation.""" agg = AggregationSpec(function="SUM", column="amount") assert delta_service._build_aggregation_expression(agg) == "SUM(`amount`)" def test_avg_with_alias(self): """Test AVG with alias.""" agg = AggregationSpec(function="AVG", column="price", alias="avg_price") assert ( delta_service._build_aggregation_expression(agg) == "AVG(`price`) AS `avg_price`" ) def test_min_max(self): """Test MIN and MAX aggregations.""" min_agg = AggregationSpec(function="MIN", column="created_at") max_agg = AggregationSpec(function="MAX", column="updated_at") assert ( delta_service._build_aggregation_expression(min_agg) == "MIN(`created_at`)" ) assert ( delta_service._build_aggregation_expression(max_agg) == "MAX(`updated_at`)" ) def test_invalid_star_with_non_count(self): """Test that * is only valid for COUNT.""" agg = AggregationSpec(function="SUM", column="*") with pytest.raises(SparkQueryError, match="does not support"): delta_service._build_aggregation_expression(agg) class TestBuildFilterCondition: """Tests for _build_filter_condition function.""" def test_equality(self): """Test equality filter.""" condition = FilterCondition(column="status", operator="=", value="active") assert delta_service._build_filter_condition(condition) == "`status` = 'active'" def test_not_equal(self): """Test not equal filter.""" condition = FilterCondition(column="type", operator="!=", value="deleted") assert delta_service._build_filter_condition(condition) == "`type` != 'deleted'" def test_comparison_operators(self): """Test numeric comparison operators.""" test_cases = [ (FilterCondition(column="age", operator=">", value=18), "`age` > 18"), (FilterCondition(column="age", operator=">=", value=21), "`age` >= 21"), (FilterCondition(column="price", operator="<", value=100), "`price` < 100"), (FilterCondition(column="qty", operator="<=", value=10), "`qty` <= 10"), ] for condition, expected in test_cases: assert delta_service._build_filter_condition(condition) == expected def test_in_operator(self): """Test IN operator.""" condition = FilterCondition( column="status", operator="IN", values=["active", "pending"] ) assert ( delta_service._build_filter_condition(condition) == "`status` IN ('active', 'pending')" ) def test_not_in_operator(self): """Test NOT IN operator.""" condition = FilterCondition(column="id", operator="NOT IN", values=[1, 2, 3]) assert ( delta_service._build_filter_condition(condition) == "`id` NOT IN (1, 2, 3)" ) def test_like_operator(self): """Test LIKE operator.""" condition = 
FilterCondition(column="name", operator="LIKE", value="%john%") assert ( delta_service._build_filter_condition(condition) == "`name` LIKE '%john%'" ) def test_not_like_operator(self): """Test NOT LIKE operator.""" condition = FilterCondition(column="email", operator="NOT LIKE", value="%spam%") assert ( delta_service._build_filter_condition(condition) == "`email` NOT LIKE '%spam%'" ) def test_is_null(self): """Test IS NULL operator.""" condition = FilterCondition(column="deleted_at", operator="IS NULL") assert ( delta_service._build_filter_condition(condition) == "`deleted_at` IS NULL" ) def test_is_not_null(self): """Test IS NOT NULL operator.""" condition = FilterCondition(column="email", operator="IS NOT NULL") assert delta_service._build_filter_condition(condition) == "`email` IS NOT NULL" def test_between_operator(self): """Test BETWEEN operator.""" condition = FilterCondition( column="created_at", operator="BETWEEN", values=["2023-01-01", "2023-12-31"] ) assert ( delta_service._build_filter_condition(condition) == "`created_at` BETWEEN '2023-01-01' AND '2023-12-31'" ) def test_in_without_values_raises_error(self): """Test that IN operator without values raises error.""" condition = FilterCondition(column="status", operator="IN") with pytest.raises(SparkQueryError, match="requires 'values'"): delta_service._build_filter_condition(condition) def test_between_with_wrong_values_count(self): """Test that BETWEEN with wrong number of values raises error.""" condition = FilterCondition(column="age", operator="BETWEEN", values=[18]) with pytest.raises(SparkQueryError, match="requires exactly 2 values"): delta_service._build_filter_condition(condition) def test_equality_without_value_raises_error(self): """Test that equality without value raises error.""" condition = FilterCondition(column="status", operator="=") with pytest.raises(SparkQueryError, match="requires 'value'"): delta_service._build_filter_condition(condition) class TestBuildSelectQuery: """Tests for build_select_query function.""" def test_simple_select_all(self): """Test SELECT * query.""" request = TableSelectRequest(database="mydb", table="users") query = delta_service.build_select_query(request) assert query == "SELECT * FROM `mydb`.`users` LIMIT 100 OFFSET 0" def test_select_specific_columns(self): """Test SELECT with specific columns.""" request = TableSelectRequest( database="mydb", table="users", columns=[ ColumnSpec(column="id"), ColumnSpec(column="name"), ], ) query = delta_service.build_select_query(request) assert "SELECT `id`, `name` FROM" in query def test_select_with_distinct(self): """Test SELECT DISTINCT.""" request = TableSelectRequest( database="mydb", table="users", columns=[ColumnSpec(column="status")], distinct=True, ) query = delta_service.build_select_query(request) assert "SELECT DISTINCT `status`" in query def test_select_with_aggregations(self): """Test SELECT with aggregations.""" request = TableSelectRequest( database="mydb", table="orders", aggregations=[ AggregationSpec(function="COUNT", column="*", alias="total"), AggregationSpec(function="SUM", column="amount", alias="total_amount"), ], ) query = delta_service.build_select_query(request) assert "COUNT(*) AS `total`" in query assert "SUM(`amount`) AS `total_amount`" in query def test_select_with_where(self): """Test SELECT with WHERE clause.""" request = TableSelectRequest( database="mydb", table="users", filters=[ FilterCondition(column="status", operator="=", value="active"), FilterCondition(column="age", operator=">=", value=18), ], ) query = 
delta_service.build_select_query(request) assert "WHERE `status` = 'active' AND `age` >= 18" in query def test_select_with_group_by(self): """Test SELECT with GROUP BY.""" request = TableSelectRequest( database="mydb", table="orders", columns=[ColumnSpec(column="category")], aggregations=[AggregationSpec(function="COUNT", column="*", alias="count")], group_by=["category"], ) query = delta_service.build_select_query(request) assert "GROUP BY `category`" in query def test_select_with_having(self): """Test SELECT with HAVING clause.""" request = TableSelectRequest( database="mydb", table="orders", columns=[ColumnSpec(column="category")], aggregations=[AggregationSpec(function="COUNT", column="*", alias="count")], group_by=["category"], having=[FilterCondition(column="count", operator=">", value=10)], ) query = delta_service.build_select_query(request) assert "HAVING `count` > 10" in query def test_select_with_order_by(self): """Test SELECT with ORDER BY.""" request = TableSelectRequest( database="mydb", table="users", order_by=[ OrderBySpec(column="created_at", direction="DESC"), OrderBySpec(column="name", direction="ASC"), ], ) query = delta_service.build_select_query(request) assert "ORDER BY `created_at` DESC, `name` ASC" in query def test_select_with_pagination(self): """Test SELECT with custom pagination.""" request = TableSelectRequest( database="mydb", table="users", limit=50, offset=100 ) query = delta_service.build_select_query(request) assert "LIMIT 50 OFFSET 100" in query def test_select_without_pagination(self): """Test SELECT without pagination clause.""" request = TableSelectRequest(database="mydb", table="users") query = delta_service.build_select_query(request, include_pagination=False) assert "LIMIT" not in query assert "OFFSET" not in query def test_complex_query(self): """Test a complex query with multiple features.""" request = TableSelectRequest( database="sales", table="orders", columns=[ ColumnSpec(column="customer_id"), ColumnSpec(column="category"), ], aggregations=[ AggregationSpec(function="SUM", column="amount", alias="total"), AggregationSpec(function="COUNT", column="*", alias="order_count"), ], distinct=False, filters=[ FilterCondition(column="status", operator="=", value="completed"), FilterCondition(column="created_at", operator=">=", value="2023-01-01"), ], group_by=["customer_id", "category"], having=[FilterCondition(column="total", operator=">", value=1000)], order_by=[OrderBySpec(column="total", direction="DESC")], limit=20, offset=0, ) query = delta_service.build_select_query(request) # Verify all parts are present assert "SELECT `customer_id`, `category`" in query assert "SUM(`amount`) AS `total`" in query assert "COUNT(*) AS `order_count`" in query assert "FROM `sales`.`orders`" in query assert "WHERE `status` = 'completed'" in query assert "`created_at` >= '2023-01-01'" in query assert "GROUP BY `customer_id`, `category`" in query assert "HAVING `total` > 1000" in query assert "ORDER BY `total` DESC" in query assert "LIMIT 20 OFFSET 0" in query # ============================================================================= # Tests for Limit Enforcement # ============================================================================= class TestQueryLimitEnforcement: """Tests for query limit enforcement functions.""" def test_extract_limit_from_query_with_limit(self): """Test extracting LIMIT from query that has one.""" query = "SELECT * FROM users LIMIT 50" limit = delta_service._extract_limit_from_query(query) assert limit == 50 def 
test_extract_limit_from_query_without_limit(self): """Test extracting LIMIT from query that doesn't have one.""" query = "SELECT * FROM users" limit = delta_service._extract_limit_from_query(query) assert limit is None def test_extract_limit_case_insensitive(self): """Test that LIMIT extraction is case insensitive.""" query = "SELECT * FROM users limit 100" limit = delta_service._extract_limit_from_query(query) assert limit == 100 def test_enforce_query_limit_adds_limit(self): """Test that _enforce_query_limit adds LIMIT when missing.""" query = "SELECT * FROM users" result = delta_service._enforce_query_limit(query, max_rows=1000) assert "LIMIT 1000" in result def test_enforce_query_limit_keeps_acceptable_limit(self): """Test that acceptable LIMIT is kept.""" query = "SELECT * FROM users LIMIT 500" result = delta_service._enforce_query_limit(query, max_rows=1000) assert result == query def test_enforce_query_limit_rejects_excessive_limit(self): """Test that excessive LIMIT raises error.""" query = "SELECT * FROM users LIMIT 100000" with pytest.raises(SparkQueryError, match="exceeds maximum"): delta_service._enforce_query_limit(query, max_rows=50000) # ============================================================================= # Tests for count_delta_table with Mocked Spark # ============================================================================= class TestCountDeltaTable: """Tests for count_delta_table function with mocked Spark.""" def test_count_returns_correct_value(self, mock_spark_session): """Test that count_delta_table returns correct count.""" spark = mock_spark_session() spark.table.return_value.count.return_value = 12345 with patch("src.delta_lake.delta_service._check_exists", return_value=True): with patch( "src.delta_lake.delta_service.run_with_timeout", side_effect=lambda func, **kwargs: func(), ): with patch( "src.delta_lake.delta_service._get_from_cache", return_value=None ): with patch("src.delta_lake.delta_service._store_in_cache"): count = delta_service.count_delta_table( spark, "testdb", "testtable", use_cache=False ) assert count == 12345 def test_count_uses_cache_hit(self, mock_spark_session): """Test that count_delta_table uses cached value.""" spark = mock_spark_session() with patch( "src.delta_lake.delta_service._get_from_cache", return_value=[{"count": 9999}], ): count = delta_service.count_delta_table( spark, "testdb", "testtable", use_cache=True ) assert count == 9999 # Spark should not be called spark.table.assert_not_called() def test_count_database_not_found(self, mock_spark_session): """Test that missing database raises error.""" spark = mock_spark_session() with patch("src.delta_lake.delta_service.database_exists", return_value=False): with pytest.raises(DeltaDatabaseNotFoundError): delta_service.count_delta_table(spark, "nonexistent", "table") def test_count_table_not_found(self, mock_spark_session): """Test that missing table raises error.""" spark = mock_spark_session() with patch("src.delta_lake.delta_service.database_exists", return_value=True): with patch("src.delta_lake.delta_service.table_exists", return_value=False): with pytest.raises(DeltaTableNotFoundError): delta_service.count_delta_table(spark, "testdb", "nonexistent") def test_count_timeout_raises_error(self, mock_spark_session): """Test that timeout raises SparkTimeoutError.""" spark = mock_spark_session() with patch("src.delta_lake.delta_service._check_exists", return_value=True): with patch( "src.delta_lake.delta_service._get_from_cache", return_value=None ): with patch( 
"src.delta_lake.delta_service.run_with_timeout", side_effect=SparkTimeoutError(operation="count", timeout=30), ): with pytest.raises(SparkTimeoutError): delta_service.count_delta_table( spark, "testdb", "table", use_cache=False ) def test_count_spark_error_wrapped(self, mock_spark_session): """Test that Spark errors are wrapped in SparkOperationError.""" spark = mock_spark_session() spark.table.return_value.count.side_effect = Exception("Spark failed") with patch("src.delta_lake.delta_service._check_exists", return_value=True): with patch( "src.delta_lake.delta_service._get_from_cache", return_value=None ): with patch( "src.delta_lake.delta_service.run_with_timeout", side_effect=lambda func, **kwargs: func(), ): with pytest.raises(SparkOperationError, match="Spark failed"): delta_service.count_delta_table( spark, "testdb", "table", use_cache=False ) # ============================================================================= # Tests for sample_delta_table with Mocked Spark # ============================================================================= class TestSampleDeltaTable: """Tests for sample_delta_table function with mocked Spark.""" def test_sample_returns_rows(self, mock_spark_session, mock_spark_row): """Test that sample_delta_table returns sample rows.""" test_data = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}] spark = mock_spark_session(table_results=test_data) with patch("src.delta_lake.delta_service._check_exists", return_value=True): with patch( "src.delta_lake.delta_service._get_from_cache", return_value=None ): with patch("src.delta_lake.delta_service._store_in_cache"): with patch( "src.delta_lake.delta_service.run_with_timeout", side_effect=lambda func, **kwargs: func(), ): result = delta_service.sample_delta_table( spark, "testdb", "testtable", limit=10, use_cache=False ) assert len(result) == 2 assert result[0]["name"] == "Alice" def test_sample_with_columns(self, mock_spark_session): """Test sampling with specific columns.""" spark = mock_spark_session() spark.table.return_value.select.return_value = spark.table.return_value with patch("src.delta_lake.delta_service._check_exists", return_value=True): with patch( "src.delta_lake.delta_service._get_from_cache", return_value=None ): with patch("src.delta_lake.delta_service._store_in_cache"): with patch( "src.delta_lake.delta_service.run_with_timeout", side_effect=lambda func, **kwargs: func(), ): delta_service.sample_delta_table( spark, "testdb", "testtable", columns=["id", "name"], use_cache=False, ) spark.table.return_value.select.assert_called_with(["id", "name"]) def test_sample_with_where_clause(self, mock_spark_session): """Test sampling with WHERE clause.""" spark = mock_spark_session() with patch("src.delta_lake.delta_service._check_exists", return_value=True): with patch( "src.delta_lake.delta_service._get_from_cache", return_value=None ): with patch("src.delta_lake.delta_service._store_in_cache"): with patch( "src.delta_lake.delta_service.run_with_timeout", side_effect=lambda func, **kwargs: func(), ): delta_service.sample_delta_table( spark, "testdb", "testtable", where_clause="id > 100", use_cache=False, ) spark.table.return_value.filter.assert_called_with("id > 100") def test_sample_invalid_limit_raises_error(self, mock_spark_session): """Test that invalid limit raises ValueError.""" spark = mock_spark_session() with pytest.raises(ValueError, match="Limit must be between"): delta_service.sample_delta_table(spark, "db", "table", limit=0) with pytest.raises(ValueError, match="Limit must be 
between"): delta_service.sample_delta_table(spark, "db", "table", limit=10000) def test_sample_uses_cache(self, mock_spark_session): """Test that sample uses cached results.""" spark = mock_spark_session() cached_data = [{"id": 1, "cached": True}] with patch( "src.delta_lake.delta_service._get_from_cache", return_value=cached_data ): result = delta_service.sample_delta_table( spark, "testdb", "testtable", use_cache=True ) assert result == cached_data spark.table.assert_not_called() # ============================================================================= # Tests for query_delta_table with Mocked Spark # ============================================================================= class TestQueryDeltaTable: """Tests for query_delta_table function with mocked Spark.""" def test_query_returns_results(self, mock_spark_session): """Test that query_delta_table returns query results.""" test_data = [{"count": 100}] spark = mock_spark_session(sql_results=test_data) with patch("src.delta_lake.delta_service._get_from_cache", return_value=None): with patch("src.delta_lake.delta_service._store_in_cache"): with patch( "src.delta_lake.delta_service.run_with_timeout", side_effect=lambda func, **kwargs: func(), ): result = delta_service.query_delta_table( spark, "SELECT COUNT(*) as count FROM users", use_cache=False ) assert len(result) == 1 assert result[0]["count"] == 100 def test_query_invalid_sql_rejected(self, mock_spark_session): """Test that invalid SQL is rejected.""" spark = mock_spark_session() with pytest.raises(SparkQueryError): delta_service.query_delta_table(spark, "DROP TABLE users") def test_query_adds_limit_when_missing(self, mock_spark_session): """Test that query gets LIMIT added when missing.""" spark = mock_spark_session() with patch("src.delta_lake.delta_service._get_from_cache", return_value=None): with patch("src.delta_lake.delta_service._store_in_cache"): with patch( "src.delta_lake.delta_service.run_with_timeout", side_effect=lambda func, **kwargs: func(), ): delta_service.query_delta_table( spark, "SELECT * FROM users", use_cache=False ) # Verify sql was called with LIMIT call_args = spark.sql.call_args[0][0] assert "LIMIT" in call_args def test_query_uses_cache(self, mock_spark_session): """Test that query uses cached results.""" spark = mock_spark_session() cached_data = [{"id": 1}] with patch( "src.delta_lake.delta_service._get_from_cache", return_value=cached_data ): result = delta_service.query_delta_table( spark, "SELECT * FROM users", use_cache=True ) assert result == cached_data spark.sql.assert_not_called() def test_query_timeout_raises_error(self, mock_spark_session): """Test that query timeout raises SparkTimeoutError.""" spark = mock_spark_session() with patch("src.delta_lake.delta_service._get_from_cache", return_value=None): with patch( "src.delta_lake.delta_service.run_with_timeout", side_effect=SparkTimeoutError(operation="query", timeout=30), ): with pytest.raises(SparkTimeoutError): delta_service.query_delta_table( spark, "SELECT * FROM users", use_cache=False ) # ============================================================================= # Tests for select_from_delta_table with Mocked Spark # ============================================================================= class TestSelectFromDeltaTable: """Tests for select_from_delta_table function with mocked Spark.""" def test_select_returns_response(self, mock_spark_session): """Test that select_from_delta_table returns TableSelectResponse.""" test_data = [{"id": 1}, {"id": 2}] spark = 
mock_spark_session(sql_results=test_data) # Mock the count query result count_row = MagicMock() count_row.__getitem__ = lambda self, key: 100 if key == "cnt" else None with patch("src.delta_lake.delta_service._check_exists", return_value=True): with patch( "src.delta_lake.delta_service._get_from_cache", return_value=None ): with patch("src.delta_lake.delta_service._store_in_cache"): with patch( "src.delta_lake.delta_service.run_with_timeout" ) as mock_timeout: # First call is count, second is data mock_timeout.side_effect = [ [count_row], # Count query result test_data, # Data query result (already converted) ] request = TableSelectRequest(database="testdb", table="users") result = delta_service.select_from_delta_table( spark, request, use_cache=False ) assert isinstance(result, TableSelectResponse) assert result.pagination.total_count == 100 def test_select_with_pagination(self, mock_spark_session): """Test select with pagination parameters.""" spark = mock_spark_session() count_row = MagicMock() count_row.__getitem__ = lambda self, key: 500 if key == "cnt" else None with patch("src.delta_lake.delta_service._check_exists", return_value=True): with patch( "src.delta_lake.delta_service._get_from_cache", return_value=None ): with patch("src.delta_lake.delta_service._store_in_cache"): with patch( "src.delta_lake.delta_service.run_with_timeout" ) as mock_timeout: mock_timeout.side_effect = [ [count_row], [], ] request = TableSelectRequest( database="testdb", table="users", limit=50, offset=100 ) result = delta_service.select_from_delta_table( spark, request, use_cache=False ) assert result.pagination.limit == 50 assert result.pagination.offset == 100 assert result.pagination.total_count == 500 assert result.pagination.has_more is True def test_select_uses_cache(self, mock_spark_session): """Test that select uses cached results.""" spark = mock_spark_session() cached_data = [ { "data": [{"id": 1}], "pagination": { "limit": 100, "offset": 0, "total_count": 1, "has_more": False, }, } ] with patch( "src.delta_lake.delta_service._get_from_cache", return_value=cached_data ): request = TableSelectRequest(database="testdb", table="users") result = delta_service.select_from_delta_table( spark, request, use_cache=True ) assert isinstance(result, TableSelectResponse) assert result.data == [{"id": 1}] spark.sql.assert_not_called() # ============================================================================= # Tests for Cache Key Generation # ============================================================================= class TestCacheKeyGeneration: """Tests for cache key generation.""" def test_generate_cache_key_deterministic(self): """Test that cache key generation is deterministic.""" params = {"database": "testdb", "table": "users"} key1 = delta_service._generate_cache_key(params) key2 = delta_service._generate_cache_key(params) assert key1 == key2 def test_generate_cache_key_different_for_different_params(self): """Test that different params produce different keys.""" params1 = {"database": "db1", "table": "users"} params2 = {"database": "db2", "table": "users"} key1 = delta_service._generate_cache_key(params1) key2 = delta_service._generate_cache_key(params2) assert key1 != key2 def test_generate_cache_key_order_independent(self): """Test that param order doesn't affect key.""" params1 = {"a": 1, "b": 2} params2 = {"b": 2, "a": 1} key1 = delta_service._generate_cache_key(params1) key2 = delta_service._generate_cache_key(params2) assert key1 == key2 # 
============================================================================= # Concurrent Query Tests # ============================================================================= class TestConcurrentQueries: """Tests for concurrent query execution.""" def test_concurrent_count_queries(self, mock_spark_session, concurrent_executor): """Test concurrent count queries.""" def count_query(i): spark = mock_spark_session() spark.table.return_value.count.return_value = i * 100 with patch("src.delta_lake.delta_service._check_exists", return_value=True): with patch( "src.delta_lake.delta_service._get_from_cache", return_value=None ): with patch("src.delta_lake.delta_service._store_in_cache"): with patch( "src.delta_lake.delta_service.run_with_timeout", side_effect=lambda func, **kwargs: func(), ): return delta_service.count_delta_table( spark, f"db_{i}", "table", use_cache=False ) args_list = [(i,) for i in range(5)] results, exceptions = concurrent_executor(count_query, args_list) assert len(exceptions) == 0 assert sorted(results) == [0, 100, 200, 300, 400] def test_concurrent_queries_with_timeout( self, mock_spark_session, concurrent_executor ): """Test that concurrent queries respect timeouts.""" def query_with_timeout(i): spark = mock_spark_session(sql_results=[{"id": i}]) with patch( "src.delta_lake.delta_service._get_from_cache", return_value=None ): with patch("src.delta_lake.delta_service._store_in_cache"): with patch( "src.delta_lake.delta_service.run_with_timeout", side_effect=lambda func, **kwargs: func(), ): result = delta_service.query_delta_table( spark, f"SELECT {i} as id FROM dual", use_cache=False, ) return result[0]["id"] args_list = [(i,) for i in range(5)] results, exceptions = concurrent_executor(query_with_timeout, args_list) assert len(exceptions) == 0 assert sorted(results) == [0, 1, 2, 3, 4] # ============================================================================= # Constants Tests # ============================================================================= class TestServiceConstants: """Tests for service constants.""" def test_max_sample_rows(self): """Test MAX_SAMPLE_ROWS constant.""" assert delta_service.MAX_SAMPLE_ROWS == 1000 def test_max_query_rows(self): """Test MAX_QUERY_ROWS constant.""" assert delta_service.MAX_QUERY_ROWS == 50000 def test_max_select_rows(self): """Test MAX_SELECT_ROWS constant.""" assert delta_service.MAX_SELECT_ROWS == 10000 def test_cache_expiry_seconds(self): """Test CACHE_EXPIRY_SECONDS constant.""" assert delta_service.CACHE_EXPIRY_SECONDS == 3600 def test_forbidden_keywords_set(self): """Test FORBIDDEN_KEYWORDS contains expected values.""" assert "drop" in delta_service.FORBIDDEN_KEYWORDS assert "delete" in delta_service.FORBIDDEN_KEYWORDS assert "insert" in delta_service.FORBIDDEN_KEYWORDS assert "update" in delta_service.FORBIDDEN_KEYWORDS
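Note that the fixtures these tests depend on (mock_spark_session, mock_spark_row, concurrent_executor) are not defined in this file; they would live in the suite's conftest.py. The sketch below is a minimal, hypothetical reconstruction inferred only from how the fixtures are called above: the fixture names and call signatures match their usage, but the helper functions (_make_row, _make_dataframe) and the exact DataFrame methods stubbed out before .collect() are assumptions, not the project's actual conftest.

# Hypothetical conftest.py sketch, inferred from fixture usage in the tests above.
from concurrent.futures import ThreadPoolExecutor, as_completed
from unittest.mock import MagicMock

import pytest


def _make_row(data):
    """Build a mock Spark Row supporting row["col"] and row.asDict()."""
    row = MagicMock()
    row.__getitem__.side_effect = data.__getitem__
    row.asDict.return_value = dict(data)
    return row


def _make_dataframe(rows):
    """Build a mock DataFrame whose transformations return the same mock."""
    df = MagicMock()
    for method in ("select", "filter", "where", "limit"):
        getattr(df, method).return_value = df
    df.collect.return_value = [_make_row(r) for r in rows]
    return df


@pytest.fixture
def mock_spark_row():
    """Factory for a single mock Row."""
    return _make_row


@pytest.fixture
def mock_spark_session():
    """Factory for a mock SparkSession.

    table_results feeds spark.table(...); sql_results feeds spark.sql(...).
    """

    def factory(table_results=None, sql_results=None):
        spark = MagicMock()
        spark.table.return_value = _make_dataframe(table_results or [])
        spark.sql.return_value = _make_dataframe(sql_results or [])
        return spark

    return factory


@pytest.fixture
def concurrent_executor():
    """Run func once per args tuple in threads; collect results and exceptions."""

    def run(func, args_list, max_workers=8):
        results, exceptions = [], []
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            futures = [pool.submit(func, *args) for args in args_list]
            for future in as_completed(futures):
                try:
                    results.append(future.result())
                except Exception as exc:
                    exceptions.append(exc)
        return results, exceptions

    return run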

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/BERDataLakehouse/datalake-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.