"""
Service layer for interacting with Delta Lake tables via Spark.
"""
import hashlib
import json
import logging
import re
from typing import Any, Dict, List, Optional
import sqlparse
from pyspark.sql import SparkSession
from src.cache.redis_cache import get_cached_value, set_cached_value
from src.delta_lake.data_store import database_exists, table_exists
from src.service.exceptions import (
DeltaDatabaseNotFoundError,
DeltaTableNotFoundError,
SparkOperationError,
SparkQueryError,
SparkTimeoutError,
)
from src.service.timeouts import run_with_timeout, DEFAULT_SPARK_COLLECT_TIMEOUT
from src.service.models import (
AggregationSpec,
ColumnSpec,
FilterCondition,
JoinClause,
PaginationInfo,
TableSelectRequest,
TableSelectResponse,
)
logger = logging.getLogger(__name__)
# Row limits to prevent OOM and ensure service stability
MAX_SAMPLE_ROWS = 1000
MAX_QUERY_ROWS = 50000 # Maximum rows returned by arbitrary SQL queries
MAX_SELECT_ROWS = 10000 # Maximum rows for structured SELECT (enforced in model)
CACHE_EXPIRY_SECONDS = 3600 # Cache results for 1 hour by default
# Common SQL keywords that might indicate destructive operations
FORBIDDEN_KEYWORDS = {
# NOTE: This might create false positives; legitimate queries may include these keywords,
# e.g. "SELECT * FROM orders ORDER BY created_at DESC"
"drop",
"truncate",
"delete",
"insert",
"update",
"create",
"alter",
"merge",
"replace",
"rename",
"vacuum",
}
DISALLOW_SQL_META_CHARS = {
"--",
"/*",
"*/",
";",
"\\",
}
ALLOWED_STATEMENTS = {
"select",
}
FORBIDDEN_POSTGRESQL_SCHEMAS = {
# NOTE: This might create false positives; legitimate queries may include these schemas,
# e.g. "SELECT * FROM jpg_files"
# NOTE: may also need to expand this if other databases are used
"pg_",
"pg_catalog",
"information_schema",
}
def _extract_limit_from_query(query: str) -> Optional[int]:
"""
Extract the LIMIT value from a SQL query if present.
Args:
query: SQL query string
Returns:
The LIMIT value as int, or None if no LIMIT clause found
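Example (illustrative):
>>> _extract_limit_from_query("SELECT * FROM t LIMIT 10")
10
>>> _extract_limit_from_query("SELECT * FROM t") is None
True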
"""
# Simple regex to find LIMIT clause - handles most common cases
# Pattern matches: LIMIT <number> with optional whitespace
limit_pattern = re.compile(r"\bLIMIT\s+(\d+)\b", re.IGNORECASE)
match = limit_pattern.search(query)
if match:
return int(match.group(1))
return None
def _enforce_query_limit(query: str, max_rows: int = MAX_QUERY_ROWS) -> str:
"""
Ensure a SQL query has a LIMIT clause that doesn't exceed max_rows.
If the query has no LIMIT, adds one. If it has a LIMIT > max_rows,
raises an error to inform the user.
Args:
query: SQL query string
max_rows: Maximum allowed rows (default: MAX_QUERY_ROWS)
Returns:
Query with enforced LIMIT clause
Raises:
SparkQueryError: If query has LIMIT exceeding max_rows
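Example (illustrative):
>>> _enforce_query_limit("SELECT * FROM t", max_rows=100)
'SELECT * FROM t LIMIT 100'
>>> _enforce_query_limit("SELECT * FROM t LIMIT 5", max_rows=100)
'SELECT * FROM t LIMIT 5'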
"""
existing_limit = _extract_limit_from_query(query)
if existing_limit is not None:
if existing_limit > max_rows:
raise SparkQueryError(
f"Query LIMIT ({existing_limit}) exceeds maximum allowed ({max_rows}). "
f"Please reduce your LIMIT or use pagination."
)
# Query already has acceptable LIMIT
return query
# No LIMIT clause - add one
# Strip trailing whitespace and add LIMIT
query = query.rstrip()
logger.info(f"Adding LIMIT {max_rows} to query without explicit limit")
return f"{query} LIMIT {max_rows}"
def _check_query_is_valid(query: str) -> bool:
"""
Check if a query is valid.
Note that this function is not a comprehensive SQL query validator; it only
checks basic syntax and structure. The MCP server should therefore be
configured to use a read-only user for both PostgreSQL and MinIO.
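Example (illustrative):
>>> _check_query_is_valid("SELECT id, name FROM users WHERE id > 10")
True
A statement such as "DROP TABLE users" raises SparkQueryError instead of returning.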
"""
try:
# NOTE: sqlparse is lenient and does not fully validate SQL syntax; malformed input may still parse without raising
statements = sqlparse.parse(query)
except Exception as e:
raise SparkQueryError(f"Query {query} is not a valid SQL query: {e}")
if len(statements) != 1:
raise SparkQueryError(f"Query {query} must contain exactly one statement")
statement = statements[0]
# NOTE: the statement might contain subqueries; only the top-level statement type is checked here!
if statement.get_type().lower() not in ALLOWED_STATEMENTS:
raise SparkQueryError(
f"Query {query} must be one of the following: {', '.join(ALLOWED_STATEMENTS)}, got {statement.get_type()}"
)
if any(schema in query.lower() for schema in FORBIDDEN_POSTGRESQL_SCHEMAS):
raise SparkQueryError(
f"Query {query} contains forbidden PostgreSQL schema: {', '.join(FORBIDDEN_POSTGRESQL_SCHEMAS)}"
)
if any(char in query for char in DISALLOW_SQL_META_CHARS):
raise SparkQueryError(
f"Query {query} contains disallowed metacharacter: {', '.join(char for char in DISALLOW_SQL_META_CHARS if char in query)}"
)
if any(keyword in query.lower() for keyword in FORBIDDEN_KEYWORDS):
raise SparkQueryError(
f"Query {query} contains forbidden keyword: {', '.join(keyword for keyword in FORBIDDEN_KEYWORDS if keyword in query.lower())}"
)
return True
def _check_exists(database: str, table: str) -> bool:
"""
Check if a table exists in a database.
"""
if not database_exists(database):
raise DeltaDatabaseNotFoundError(f"Database [{database}] not found")
if not table_exists(database, table):
raise DeltaTableNotFoundError(
f"Table [{table}] not found in database [{database}]"
)
return True
def _generate_cache_key(params: Dict[str, Any]) -> str:
"""
Generate a cache key from parameters.
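Example (illustrative):
>>> key = _generate_cache_key({"database": "sales", "table": "orders"})
>>> len(key)  # MD5 hex digest
32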
"""
# Convert parameters to a sorted JSON string to ensure consistency
param_str = json.dumps(params, sort_keys=True)
# Create a hash of the parameters to avoid very long keys
param_hash = hashlib.md5(param_str.encode()).hexdigest()
return param_hash
def _get_from_cache(namespace: str, cache_key: str) -> Optional[List[Dict[str, Any]]]:
"""
Try to get data from Redis cache.
"""
return get_cached_value(namespace=namespace, cache_key=cache_key)
def _store_in_cache(
namespace: str,
cache_key: str,
data: List[Dict[str, Any]],
ttl: int = CACHE_EXPIRY_SECONDS,
) -> None:
"""
Store data in Redis cache.
"""
set_cached_value(namespace=namespace, cache_key=cache_key, data=data, ttl=ttl)
def count_delta_table(
spark: SparkSession, database: str, table: str, use_cache: bool = True
) -> int:
"""
Counts the number of rows in a specific Delta table.
Args:
spark: The SparkSession object.
database: The database (namespace) containing the table.
table: The name of the Delta table.
use_cache: Whether to read from and write to the Redis cache for this result.
Returns:
The number of rows in the table.
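Example (illustrative; assumes an active SparkSession `spark` and an existing
Delta table; the database/table names below are hypothetical):
>>> n_rows = count_delta_table(spark, "sales", "orders")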
"""
namespace = "count"
params = {"database": database, "table": table}
cache_key = _generate_cache_key(params)
if use_cache:
cached_result = _get_from_cache(namespace, cache_key)
if cached_result:
logger.info(f"Cache hit for count on {database}.{table}")
return cached_result[0]["count"]
_check_exists(database, table)
full_table_name = f"`{database}`.`{table}`"
logger.info(f"Counting rows in {full_table_name}")
try:
# Use timeout wrapper for count operation
count = run_with_timeout(
lambda: spark.table(full_table_name).count(),
timeout_seconds=DEFAULT_SPARK_COLLECT_TIMEOUT,
operation_name=f"count_{database}.{table}",
)
logger.info(f"{full_table_name} has {count} rows.")
if use_cache:
_store_in_cache(namespace, cache_key, [{"count": count}])
return count
except SparkTimeoutError:
raise # Re-raise timeout errors as-is
except Exception as e:
logger.error(f"Error counting rows in {full_table_name}: {e}")
raise SparkOperationError(
f"Failed to count rows in {full_table_name}: {str(e)}"
)
def sample_delta_table(
spark: SparkSession,
database: str,
table: str,
limit: int = 10,
columns: List[str] | None = None,
where_clause: str | None = None,
use_cache: bool = True,
) -> List[Dict[str, Any]]:
"""
Retrieves a sample of rows from a specific Delta table.
Args:
spark: The SparkSession object.
database: The database (namespace) containing the table.
table: The name of the Delta table.
limit: The maximum number of rows to return.
columns: The columns to return. If None, all columns will be returned.
where_clause: A SQL WHERE clause to filter the rows. e.g. "id > 100"
use_cache: Whether to read from and write to the Redis cache for this result.
Returns:
A list of dictionaries, where each dictionary represents a row.
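Example (illustrative; assumes an active SparkSession `spark`; the table and
column names below are hypothetical):
>>> rows = sample_delta_table(
...     spark, "sales", "orders", limit=5,
...     columns=["order_id", "amount"], where_clause="amount > 100",
... )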
"""
namespace = "sample"
params = {
"database": database,
"table": table,
"limit": limit,
"columns": sorted(columns) if columns else None,
"where_clause": where_clause,
}
cache_key = _generate_cache_key(params)
if use_cache:
cached_result = _get_from_cache(namespace, cache_key)
if cached_result:
logger.info(f"Cache hit for sample on {database}.{table}")
return cached_result
if not 0 < limit <= MAX_SAMPLE_ROWS:
raise ValueError(f"Limit must be between 1 and {MAX_SAMPLE_ROWS}, got {limit}")
_check_exists(database, table)
full_table_name = f"`{database}`.`{table}`"
logger.info(f"Sampling {limit} rows from {full_table_name}")
try:
df = spark.table(full_table_name)
if columns:
df = df.select(columns)
if where_clause:
equivalent_query = f"SELECT * FROM {full_table_name} WHERE {where_clause}"
_check_query_is_valid(equivalent_query)
df = df.filter(where_clause)
df = df.limit(limit)
# Use timeout wrapper for collect operation
results = run_with_timeout(
lambda: [row.asDict() for row in df.collect()],
timeout_seconds=DEFAULT_SPARK_COLLECT_TIMEOUT,
operation_name=f"sample_{database}.{table}",
)
logger.info(f"Sampled {len(results)} rows.")
if use_cache:
_store_in_cache(namespace, cache_key, results)
return results
except SparkTimeoutError:
raise # Re-raise timeout errors as-is
except Exception as e:
logger.error(f"Error sampling rows from {full_table_name}: {e}")
raise SparkOperationError(
f"Failed to sample rows from {full_table_name}: {str(e)}"
)
def query_delta_table(
spark: SparkSession, query: str, use_cache: bool = True
) -> List[Dict[str, Any]]:
"""
Executes a SQL query against the registered Delta tables after basic validation.
Note: Queries are automatically limited to MAX_QUERY_ROWS (50,000) to prevent
OOM errors. If your query needs more rows, use pagination via the select endpoint.
Args:
spark: The SparkSession object.
query: The SQL query string to execute.
use_cache: Whether to read from and write to the Redis cache for this result.
Returns:
A list of dictionaries, where each dictionary represents a row.
Raises:
SparkQueryError: If query validation fails or LIMIT exceeds MAX_QUERY_ROWS
SparkOperationError: If query execution fails
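Example (illustrative; assumes an active SparkSession `spark` and a
hypothetical `sales.orders` table):
>>> rows = query_delta_table(spark, "SELECT order_id FROM `sales`.`orders` LIMIT 100")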
"""
# Validate query structure first
_check_query_is_valid(query)
# Enforce row limit to prevent OOM
query = _enforce_query_limit(query, MAX_QUERY_ROWS)
namespace = "query"
params = {"query": query}
cache_key = _generate_cache_key(params)
if use_cache:
cached_result = _get_from_cache(namespace, cache_key)
if cached_result:
logger.info(
f"Cache hit for query: {query[:50]}{'...' if len(query) > 50 else ''}"
)
return cached_result
logger.info(f"Executing validated query: {query}")
try:
df = spark.sql(query)
# Use timeout wrapper for collect operation
results = run_with_timeout(
lambda: [row.asDict() for row in df.collect()],
timeout_seconds=DEFAULT_SPARK_COLLECT_TIMEOUT,
operation_name="query_delta_table",
)
logger.info(f"Query returned {len(results)} rows.")
if use_cache:
_store_in_cache(namespace, cache_key, results)
return results
except SparkTimeoutError:
raise # Re-raise timeout errors as-is
except Exception as e:
logger.error(f"Error executing query: {e}")
raise SparkOperationError(f"Failed to execute query: {str(e)}")
# ---
# Query Builder Functions
# ---
# Valid identifier pattern: letters, digits, and underscores, starting with a letter or underscore
VALID_IDENTIFIER_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
def _validate_identifier(name: str, identifier_type: str = "identifier") -> None:
"""
Validate that an identifier (table, column, database name) is safe.
Args:
name: The identifier to validate.
identifier_type: Type of identifier for error messages.
Raises:
SparkQueryError: If the identifier is invalid.
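Example (illustrative):
>>> _validate_identifier("order_id", "column")  # passes silently
A value such as "order id; --" raises SparkQueryError.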
"""
if not name or not VALID_IDENTIFIER_PATTERN.match(name):
raise SparkQueryError(
f"Invalid {identifier_type}: '{name}'. "
"Identifiers must start with a letter or underscore and contain "
"only alphanumeric characters and underscores."
)
def _escape_value(value: Any) -> str:
"""
Escape a value for safe SQL use.
Args:
value: The value to escape.
Returns:
The escaped value as a SQL-safe string.
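Example:
>>> _escape_value("O'Brien")
"'O''Brien'"
>>> _escape_value(None)
'NULL'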
"""
if value is None:
return "NULL"
elif isinstance(value, bool):
return "TRUE" if value else "FALSE"
elif isinstance(value, (int, float)):
return str(value)
elif isinstance(value, str):
# Escape single quotes by doubling them
escaped = value.replace("'", "''")
return f"'{escaped}'"
else:
# For other types, convert to string and escape
escaped = str(value).replace("'", "''")
return f"'{escaped}'"
def _build_column_expression(col: ColumnSpec) -> str:
"""
Build a column expression from a ColumnSpec.
Args:
col: The column specification.
Returns:
SQL column expression string.
"""
parts = []
if col.table_alias:
_validate_identifier(col.table_alias, "table alias")
parts.append(f"`{col.table_alias}`.")
_validate_identifier(col.column, "column")
parts.append(f"`{col.column}`")
if col.alias:
_validate_identifier(col.alias, "column alias")
parts.append(f" AS `{col.alias}`")
return "".join(parts)
def _build_aggregation_expression(agg: AggregationSpec) -> str:
"""
Build an aggregation expression from an AggregationSpec.
Args:
agg: The aggregation specification.
Returns:
SQL aggregation expression string.
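Example (illustrative; the keyword arguments assume the AggregationSpec
fields used below):
>>> _build_aggregation_expression(AggregationSpec(function="COUNT", column="*", alias="n"))
'COUNT(*) AS `n`'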
"""
if agg.column == "*":
if agg.function != "COUNT":
raise SparkQueryError(
f"Aggregation function {agg.function} does not support '*'. "
"Only COUNT(*) is valid."
)
expr = "COUNT(*)"
else:
_validate_identifier(agg.column, "aggregation column")
expr = f"{agg.function}(`{agg.column}`)"
if agg.alias:
_validate_identifier(agg.alias, "aggregation alias")
expr += f" AS `{agg.alias}`"
return expr
def _build_filter_condition(condition: FilterCondition) -> str:
"""
Build a single filter condition for WHERE or HAVING clause.
Args:
condition: The filter condition specification.
Returns:
SQL condition string.
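Example (illustrative; the keyword arguments assume the FilterCondition
fields used below):
>>> _build_filter_condition(FilterCondition(column="status", operator="IN", values=["new", "paid"]))
"`status` IN ('new', 'paid')"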
"""
_validate_identifier(condition.column, "filter column")
col = f"`{condition.column}`"
op = condition.operator
if op in ("IS NULL", "IS NOT NULL"):
return f"{col} {op}"
elif op in ("IN", "NOT IN"):
if not condition.values:
raise SparkQueryError(f"Operator {op} requires 'values' to be provided")
escaped_values = [_escape_value(v) for v in condition.values]
values_str = ", ".join(escaped_values)
return f"{col} {op} ({values_str})"
elif op == "BETWEEN":
if not condition.values or len(condition.values) != 2:
raise SparkQueryError(
"Operator BETWEEN requires exactly 2 values in 'values'"
)
return (
f"{col} BETWEEN {_escape_value(condition.values[0])} "
f"AND {_escape_value(condition.values[1])}"
)
else:
# Standard comparison operators: =, !=, <, >, <=, >=, LIKE, NOT LIKE
if condition.value is None:
raise SparkQueryError(f"Operator {op} requires 'value' to be provided")
return f"{col} {op} {_escape_value(condition.value)}"
def _build_filter_clause(
conditions: List[FilterCondition], clause_type: str = "WHERE"
) -> str:
"""
Build a WHERE or HAVING clause from filter conditions.
Args:
conditions: List of filter conditions.
clause_type: Either "WHERE" or "HAVING".
Returns:
SQL clause string (including the keyword) or empty string if no conditions.
"""
if not conditions:
return ""
condition_strs = [_build_filter_condition(c) for c in conditions]
return f" {clause_type} " + " AND ".join(condition_strs)
def _build_join_clause(join: JoinClause, main_table: str) -> str:
"""
Build a JOIN clause.
Args:
join: The join specification.
main_table: The name of the main table for the ON clause.
Returns:
SQL JOIN clause string.
"""
_validate_identifier(join.database, "join database")
_validate_identifier(join.table, "join table")
_validate_identifier(join.on_left_column, "join left column")
_validate_identifier(join.on_right_column, "join right column")
join_table = f"`{join.database}`.`{join.table}`"
join_type = join.join_type
return (
f" {join_type} JOIN {join_table} "
f"ON `{main_table}`.`{join.on_left_column}` = "
f"`{join.table}`.`{join.on_right_column}`"
)
def build_select_query(
request: TableSelectRequest, include_pagination: bool = True
) -> str:
"""
Build a SQL SELECT query from a TableSelectRequest.
Args:
request: The structured select request.
include_pagination: Whether to include LIMIT/OFFSET clauses.
Returns:
The constructed SQL query string.
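Example (illustrative; assumes TableSelectRequest defaults leave the optional
clauses empty; the database/table names below are hypothetical):
>>> build_select_query(TableSelectRequest(database="sales", table="orders", limit=10, offset=0))
'SELECT * FROM `sales`.`orders` LIMIT 10 OFFSET 0'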
"""
_validate_identifier(request.database, "database")
_validate_identifier(request.table, "table")
# Build SELECT clause
select_parts = []
# Add DISTINCT keyword if requested
distinct_keyword = "DISTINCT " if request.distinct else ""
# Add columns
if request.columns:
select_parts.extend([_build_column_expression(c) for c in request.columns])
# Add aggregations
if request.aggregations:
select_parts.extend(
[_build_aggregation_expression(a) for a in request.aggregations]
)
# If no columns or aggregations, select all
if not select_parts:
select_clause = f"SELECT {distinct_keyword}*"
else:
select_clause = f"SELECT {distinct_keyword}" + ", ".join(select_parts)
# Build FROM clause
main_table = f"`{request.database}`.`{request.table}`"
from_clause = f" FROM {main_table}"
# Build JOIN clauses
join_clauses = ""
if request.joins:
for join in request.joins:
# Validate join table exists
_check_exists(join.database, join.table)
join_clauses += _build_join_clause(join, request.table)
# Build WHERE clause
where_clause = (
_build_filter_clause(request.filters, "WHERE") if request.filters else ""
)
# Build GROUP BY clause
group_by_clause = ""
if request.group_by:
for col in request.group_by:
_validate_identifier(col, "group by column")
group_by_cols = ", ".join([f"`{col}`" for col in request.group_by])
group_by_clause = f" GROUP BY {group_by_cols}"
# Build HAVING clause
having_clause = (
_build_filter_clause(request.having, "HAVING") if request.having else ""
)
# Build ORDER BY clause
order_by_clause = ""
if request.order_by:
order_parts = []
for order in request.order_by:
_validate_identifier(order.column, "order by column")
order_parts.append(f"`{order.column}` {order.direction}")
order_by_clause = " ORDER BY " + ", ".join(order_parts)
# Build LIMIT/OFFSET clause
pagination_clause = ""
if include_pagination:
pagination_clause = f" LIMIT {request.limit} OFFSET {request.offset}"
# Combine all parts
query = (
select_clause
+ from_clause
+ join_clauses
+ where_clause
+ group_by_clause
+ having_clause
+ order_by_clause
+ pagination_clause
)
return query
def select_from_delta_table(
spark: SparkSession, request: TableSelectRequest, use_cache: bool = True
) -> TableSelectResponse:
"""
Execute a structured SELECT query against Delta tables with pagination.
Args:
spark: The SparkSession object.
request: The structured select request.
use_cache: Whether to read from and write to the Redis cache for this result.
Returns:
TableSelectResponse with data and pagination info.
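Example (illustrative; assumes an active SparkSession `spark` and an existing
Delta table; the names below are hypothetical):
>>> resp = select_from_delta_table(
...     spark,
...     TableSelectRequest(database="sales", table="orders", limit=100, offset=0),
... )
>>> total = resp.pagination.total_count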
"""
namespace = "select"
# Generate cache key from request parameters
params = request.model_dump()
cache_key = _generate_cache_key(params)
if use_cache:
cached_result = _get_from_cache(namespace, cache_key)
if cached_result:
logger.info(f"Cache hit for select on {request.database}.{request.table}")
# Reconstruct response from cached data
return TableSelectResponse(
data=cached_result[0]["data"],
pagination=PaginationInfo(**cached_result[0]["pagination"]),
)
# Validate main table exists
_check_exists(request.database, request.table)
# Build and execute count query (without pagination) for total count
count_query = f"SELECT COUNT(*) as cnt FROM ({build_select_query(request, include_pagination=False)})"
logger.info(f"Executing count query: {count_query}")
try:
# Use timeout wrapper for count query
count_result = run_with_timeout(
lambda: spark.sql(count_query).collect(),
timeout_seconds=DEFAULT_SPARK_COLLECT_TIMEOUT,
operation_name=f"count_select_{request.database}.{request.table}",
)
total_count = count_result[0]["cnt"]
except SparkTimeoutError:
raise # Re-raise timeout errors as-is
except Exception as e:
logger.error(f"Error executing count query: {e}")
raise SparkOperationError(f"Failed to execute count query: {str(e)}")
# Build and execute main query with pagination
main_query = build_select_query(request, include_pagination=True)
logger.info(f"Executing select query: {main_query}")
try:
df = spark.sql(main_query)
# Use timeout wrapper for data query
results = run_with_timeout(
lambda: [row.asDict() for row in df.collect()],
timeout_seconds=DEFAULT_SPARK_COLLECT_TIMEOUT,
operation_name=f"select_{request.database}.{request.table}",
)
logger.info(f"Select query returned {len(results)} rows.")
# Calculate pagination info
has_more = (request.offset + len(results)) < total_count
pagination = PaginationInfo(
limit=request.limit,
offset=request.offset,
total_count=total_count,
has_more=has_more,
)
response = TableSelectResponse(data=results, pagination=pagination)
if use_cache:
# Store serializable version in cache
cache_data = [
{
"data": results,
"pagination": pagination.model_dump(),
}
]
_store_in_cache(namespace, cache_key, cache_data)
return response
except SparkTimeoutError:
raise # Re-raise timeout errors as-is
except Exception as e:
logger.error(f"Error executing select query: {e}")
raise SparkOperationError(f"Failed to execute select query: {str(e)}")