import calendar
from datetime import datetime, timedelta
from typing import Any
from dateutil.relativedelta import relativedelta
from lib import timeutils
def _validate_alias(alias: str | None) -> None:
"""Validate table alias for SQL injection prevention.
Args:
alias: Table alias to validate
Raises:
ValueError: If alias contains invalid characters
"""
if alias is not None and not alias.replace('_', '').isalnum():
raise ValueError(
f"Invalid alias: {alias}. Must contain only alphanumeric characters and underscores"
)
def _validate_date_range(
start: datetime,
end: datetime,
max_delta: relativedelta | timedelta,
) -> None:
"""Validate date range to prevent OOM and query plan issues.
Args:
start: Start datetime
end: End datetime
max_delta: Maximum allowed delta
Raises:
ValueError: If start > end or date range exceeds maximum allowed delta
"""
if start > end:
raise ValueError(f"Start date {start} must be <= end date {end}")
max_end = start + max_delta
if max_end < end:
actual_delta = end - start
actual_delta_str = timeutils.pretty_timedelta(actual_delta)
max_delta_str = timeutils.pretty_timedelta(max_delta)
raise ValueError(
f"Date range too large: from {start} to {end} ({actual_delta_str}). "
f"Maximum allowed: {max_delta_str}. "
f"Consider using a coarser granularity or splitting the query."
)
def clear_duplicates(
rows: list[dict[str, Any]], conflict_fields: list[str] | set[str]
) -> list[dict[str, Any]]:
"""Remove duplicate rows based on conflict fields.
Args:
rows: List of dictionaries to deduplicate
conflict_fields: Fields to use for uniqueness check
Returns:
List with duplicates removed (last occurrence wins)
Raises:
KeyError: If any row is missing a conflict field
"""
index: dict[tuple[Any, ...], dict[str, Any]] = {}
for row in rows:
# Use tuple of values as key - no collision possible
key = tuple(row[field] for field in conflict_fields)
index[key] = row
return list(index.values())
def list_to_filter(filter_name: str, data: list[Any] | None) -> str:
"""Build SQL IN filter with proper escaping.
Args:
filter_name: Column name (alphanumeric, underscore, dot only)
data: List of values to filter
Returns:
SQL IN clause or empty string if data is empty
Raises:
ValueError: If filter_name contains invalid characters
"""
# Validate filter_name to prevent SQL injection
# Allow alphanumeric, underscore, and dot (for qualified names like table.column)
if not filter_name.replace('_', '').replace('.', '').isalnum():
raise ValueError(
f"Invalid filter_name: {filter_name}. "
f"Must contain only alphanumeric characters, underscores, and dots"
)
if not data:
return ''
# Escape single quotes in values by doubling them (SQL standard)
escaped_data = []
for p in data:
escaped_value = str(p).replace("'", "''")
escaped_data.append(f"'{escaped_value}'")
return f'{filter_name} IN ({",".join(escaped_data)})'
def build_partition_filter_minute(start: datetime, end: datetime, alias: str | None = None) -> str:
    """Build a minute-granularity partition filter for partitioned tables.

    Produces an OR-separated list of parenthesised groups over the
    year/month/day/hour/minute partition columns:

    * consecutive whole days (00:00 through 23:59) of one month collapse
      into a single ``day = N`` / ``day IN (...)`` group;
    * a partially covered day gets one group whose hour part lists whole
      hours as ``hour = N`` / ``hour IN (...)`` and partial hours as
      ``hour = N AND minute ...`` (``=``, ``<=``, ``>=`` or ``BETWEEN``).

    Args:
        start: Start datetime (inclusive; seconds/microseconds are dropped).
        end: End datetime (inclusive; seconds/microseconds are dropped).
        alias: Optional table alias to prefix column names
            (alphanumeric + underscore only).

    Returns:
        SQL filter clause. Examples:

        10:00 to 10:02 on 2025-01-01:
            (year = 2025 AND month = 1 AND day = 1 AND (hour = 10 AND minute <= 2))

        2025-01-01 23:58 to 2025-01-02 00:01 with alias='p':
            (p.year = 2025 AND p.month = 1 AND p.day = 1 AND (p.hour = 23 AND p.minute >= 58))
            OR (p.year = 2025 AND p.month = 1 AND p.day = 2 AND (p.hour = 0 AND p.minute <= 1))

        2025-01-01 23:58 to 2025-01-05 00:01:
            (year = 2025 AND month = 1 AND day = 1 AND (hour = 23 AND minute >= 58))
            OR (year = 2025 AND month = 1 AND day IN (2, 3, 4))
            OR (year = 2025 AND month = 1 AND day = 5 AND (hour = 0 AND minute <= 1))

    Raises:
        ValueError: If start > end or alias contains invalid characters.
    """
    _validate_alias(alias)
    if start > end:
        raise ValueError(f'Start date {start} must be <= end date {end}')

    col = f'{alias}.' if alias else ''
    first_minute = start.replace(second=0, microsecond=0)
    last_minute = end.replace(second=0, microsecond=0)

    clauses: list[str] = []
    # Pending run of whole days that all belong to ``pending_month``.
    pending_days: list[int] = []
    pending_month: tuple[int, int] | None = None

    def flush_days() -> None:
        """Emit the pending whole-day run (if any) as one group and reset it."""
        nonlocal pending_days, pending_month
        if pending_days and pending_month is not None:
            year, month = pending_month
            if len(pending_days) == 1:
                day_part = f"{col}day = {pending_days[0]}"
            else:
                day_part = f"{col}day IN ({', '.join(str(d) for d in pending_days)})"
            clauses.append(f"({col}year = {year} AND {col}month = {month} AND {day_part})")
            pending_days = []
            pending_month = None

    def hour_run_clause(hours: list[int]) -> str:
        """Format a non-empty run of whole hours."""
        if len(hours) == 1:
            return f'{col}hour = {hours[0]}'
        return f"{col}hour IN ({', '.join(str(h) for h in hours)})"

    def minute_range_clause(low: int, high: int) -> str:
        """Format the minute predicate for a partially covered hour."""
        if low == high:
            return f'{col}minute = {low}'
        if low == 0:
            return f'{col}minute <= {high}'
        if high == 59:
            return f'{col}minute >= {low}'
        return f'{col}minute BETWEEN {low} AND {high}'

    def partial_day_hours(day_start: datetime, day_end: datetime) -> str:
        """Build the OR-joined hour/minute predicates for one partial day."""
        parts: list[str] = []
        whole_hours: list[int] = []
        hour_cursor = day_start
        while hour_cursor <= day_end:
            hour_end = min(day_end, hour_cursor.replace(minute=59))
            low, high = hour_cursor.minute, hour_end.minute
            if low == 0 and high == 59:
                whole_hours.append(hour_cursor.hour)
            else:
                # Keep output chronological: emit the whole-hour run that
                # precedes this partial hour before the partial hour itself.
                if whole_hours:
                    parts.append(hour_run_clause(whole_hours))
                    whole_hours = []
                parts.append(
                    f'{col}hour = {hour_cursor.hour} AND {minute_range_clause(low, high)}'
                )
            hour_cursor = (hour_cursor + timedelta(hours=1)).replace(minute=0)
        if whole_hours:
            parts.append(hour_run_clause(whole_hours))
        return ' OR '.join(parts)

    cursor = first_minute
    while cursor <= last_minute:
        # Last covered minute of the current calendar day.
        day_end = min(last_minute, cursor.replace(hour=23, minute=59))
        starts_at_midnight = cursor.hour == 0 and cursor.minute == 0
        ends_at_last_minute = day_end.hour == 23 and day_end.minute == 59

        if starts_at_midnight and ends_at_last_minute:
            # Whole day: extend the run, starting a new one on month change.
            month_key = (cursor.year, cursor.month)
            if pending_month != month_key:
                flush_days()
                pending_month = month_key
            pending_days.append(cursor.day)
        else:
            flush_days()
            day_clause = (
                f"{col}year = {cursor.year} AND {col}month = {cursor.month} "
                f"AND {col}day = {cursor.day}"
            )
            clauses.append(f"({day_clause} AND ({partial_day_hours(cursor, day_end)}))")

        cursor = (cursor + timedelta(days=1)).replace(hour=0, minute=0)

    flush_days()
    return ' OR '.join(clauses)
def build_partition_filter_hour(start: datetime, end: datetime, alias: str | None = None) -> str:
    """Build partition filter with hour granularity for partitioned tables.

    Generates SQL WHERE clause to filter partitioned data by year/month/day/hour.
    Uses BETWEEN for consecutive hours within same day; a day covered for all
    24 hours gets no hour predicate at all.

    Args:
        start: Start datetime (inclusive, rounded down to hour)
        end: End datetime (exclusive, rounded up to hour if not on boundary)
        alias: Optional table alias to prefix column names (alphanumeric + underscore only)

    Returns:
        SQL filter clause with OR-separated day+hour conditions, or an empty
        string when the rounded range contains no hours.
        Example: (year = 2025 AND month = 11 AND day = 4 AND hour BETWEEN 10 AND 12)
            OR (year = 2025 AND month = 11 AND day = 5 AND hour BETWEEN 0 AND 2)

    Raises:
        ValueError: If start > end, the rounded range exceeds 30 days,
            or alias contains invalid characters
    """
    _validate_alias(alias)
    # Compare the raw inputs: once rounded, a reversed range inside a single
    # hour (e.g. 10:30 -> 10:10) turns into the valid-looking [10:00, 11:00)
    # and would slip past the range validation below.
    if start > end:
        raise ValueError(f"Start date {start} must be <= end date {end}")

    start_hour = timeutils.truncate_hour(start)
    end_hour = timeutils.truncate_hour(end)
    # A partially elapsed final hour still has to be read in full.
    if end > end_hour:
        end_hour += timedelta(hours=1)
    _validate_date_range(start_hour, end_hour, relativedelta(days=30))

    filters: list[str] = []
    current = start_hour
    prefix = f'{alias}.' if alias else ''
    while current < end_hour:
        day_key = current.date()
        day_hours: list[int] = []
        # Collect all hours of the range that fall on the current day
        while current < end_hour and current.date() == day_key:
            day_hours.append(current.hour)
            current += timedelta(hours=1)

        day_clause = (
            f"{prefix}year = {day_key.year} AND {prefix}month = {day_key.month} "
            f"AND {prefix}day = {day_key.day}"
        )
        if len(day_hours) == 24:
            # Full day - no hour filter needed
            filters.append(f"({day_clause})")
        else:
            if len(day_hours) == 1:
                hours_clause = f"{prefix}hour = {day_hours[0]}"
            else:
                # Hours collected within one day are consecutive, so the
                # first and last values bound the whole run.
                hours_clause = f"{prefix}hour BETWEEN {day_hours[0]} AND {day_hours[-1]}"
            filters.append(f"({day_clause} AND {hours_clause})")
    return ' OR '.join(filters)
def build_partition_filter_day(start: datetime, end: datetime, alias: str | None = None) -> str:
    """Build partition filter with day granularity for partitioned tables.

    Generates SQL WHERE clause to filter partitioned data by year/month/day.
    Uses BETWEEN for consecutive days within same month; a month covered from
    its first to its last day gets no day predicate at all.

    Args:
        start: Start datetime (inclusive, rounded down to day)
        end: End datetime (exclusive, rounded up to day if not on boundary)
        alias: Optional table alias to prefix column names (alphanumeric + underscore only)

    Returns:
        SQL filter clause with OR-separated month+day conditions, or an empty
        string when the rounded range contains no days.
        Example: (year = 2025 AND month = 11 AND day BETWEEN 4 AND 6)

    Raises:
        ValueError: If start > end, the rounded range exceeds 365 days,
            or alias contains invalid characters
    """
    _validate_alias(alias)
    # Compare the raw inputs: once rounded, a reversed range inside a single
    # day (e.g. 15:00 -> 10:00) turns into a valid-looking one-day window
    # and would slip past the range validation below.
    if start > end:
        raise ValueError(f"Start date {start} must be <= end date {end}")

    start_day = timeutils.truncate_day(start)
    end_day = timeutils.truncate_day(end)
    # A partially elapsed final day still has to be read in full.
    if end > end_day:
        end_day += timedelta(days=1)
    _validate_date_range(start_day, end_day, relativedelta(days=365))

    filters: list[str] = []
    current = start_day
    prefix = f'{alias}.' if alias else ''
    while current < end_day:
        month_year = (current.year, current.month)
        month_days: list[int] = []
        # Collect all days of the range that fall in the current month
        while current < end_day and (current.year, current.month) == month_year:
            month_days.append(current.day)
            current += timedelta(days=1)

        # Check if full month
        days_in_month = calendar.monthrange(month_year[0], month_year[1])[1]
        if len(month_days) == days_in_month and month_days[0] == 1:
            # Full month - no day filter needed
            filters.append(f"({prefix}year = {month_year[0]} AND {prefix}month = {month_year[1]})")
        else:
            if len(month_days) == 1:
                day_clause = f"{prefix}day = {month_days[0]}"
            else:
                # Days collected within one month are consecutive, so the
                # first and last values bound the whole run.
                day_clause = f"{prefix}day BETWEEN {month_days[0]} AND {month_days[-1]}"
            filters.append(
                f"({prefix}year = {month_year[0]} AND {prefix}month = {month_year[1]} "
                f"AND {day_clause})"
            )
    return ' OR '.join(filters)
def build_partition_filter_month(start: datetime, end: datetime, alias: str | None = None) -> str:
    """Build partition filter with month granularity for partitioned tables.

    Generates SQL WHERE clause to filter partitioned data by year/month.

    Args:
        start: Start datetime (inclusive, rounded down to month)
        end: End datetime (exclusive, rounded up to month if not on boundary)
        alias: Optional table alias to prefix column names (alphanumeric + underscore only)

    Returns:
        SQL filter clause with OR-separated month conditions, or an empty
        string when the rounded range contains no months.
        Example: (year = 2025 AND month = 10)
            OR (year = 2025 AND month = 11)

    Raises:
        ValueError: If start > end, the rounded range exceeds 24 months,
            or alias contains invalid characters
    """
    _validate_alias(alias)
    # Compare the raw inputs: once rounded, a reversed range inside a single
    # month turns into a valid-looking one-month window and would slip past
    # the range validation below.
    if start > end:
        raise ValueError(f"Start date {start} must be <= end date {end}")

    start_month = timeutils.truncate_month(start)
    end_month = timeutils.truncate_month(end)
    # A partially elapsed final month still has to be read in full.
    if end > end_month:
        end_month += relativedelta(months=1)
    _validate_date_range(start_month, end_month, relativedelta(months=24))

    filters: list[str] = []
    current = start_month
    prefix = f'{alias}.' if alias else ''
    while current < end_month:
        filters.append(f"({prefix}year = {current.year} AND {prefix}month = {current.month})")
        # Move to next month, rolling over into January of the next year
        if current.month == 12:
            current = current.replace(year=current.year + 1, month=1)
        else:
            current = current.replace(month=current.month + 1)
    return ' OR '.join(filters)
def strip_sql_comments(query: str) -> str:
    """Strip blank lines and all but the first ``--`` comment line from a query.

    Reduces network traffic and DB logs by removing documentation headers
    while preserving the first comment line (e.g. a query identifier).

    Only whole-line comments are recognised: a line is treated as a comment
    when, ignoring leading whitespace, it starts with ``--``. Trailing inline
    comments and ``/* ... */`` comments are left untouched. Kept lines retain
    their original indentation and each is terminated with a newline.

    Args:
        query: SQL text, with lines separated by ``\\n``.

    Returns:
        The filtered query; an empty string if no lines are kept.
    """
    kept_lines: list[str] = []
    first_comment_kept = False
    for line in query.split('\n'):
        stripped = line.strip()
        if not stripped:
            continue
        if stripped.startswith('--'):
            # Only the very first comment line survives.
            if first_comment_kept:
                continue
            first_comment_kept = True
        kept_lines.append(line + '\n')
    # Join once instead of growing a string with += on every line.
    return ''.join(kept_lines)