"""
PostgreSQL repository for raw database operations
"""
import logging
import re
import time
import uuid
from typing import Any, Dict, List, Optional

from sqlalchemy import Engine, inspect, text
from sqlalchemy.exc import SQLAlchemyError

from shared.exceptions import DatabaseError, SQLSafetyError, SchemaError
from shared.models import QueryResult, SchemaInfo
# Module logger, named after this module (``__name__``) so its records can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
class PostgresRepository:
    """
    Pure database operations repository - no business logic
    Handles raw SQL execution, schema introspection, and basic safety checks
    """

    # Patterns that make execute_query() reject a statement outright. They are
    # matched with IGNORECASE | DOTALL so the multi-statement check also catches
    # a second statement that begins on a later line.
    # INTO rejects PostgreSQL's ``SELECT ... INTO new_table`` (which creates a
    # table); GRANT/REVOKE mirror the keywords rejected by _is_sql_safe_to_run().
    _FORBIDDEN_PATTERNS = (
        r'\b(INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|TRUNCATE|INTO|GRANT|REVOKE)\b',
        r';.*\w',  # Multiple statements
        r'/\*',  # Block comments
    )

    # Queries matching this (aggregates / GROUP BY) get no automatic LIMIT.
    _AGGREGATE_PATTERN = r'\b(COUNT|SUM|AVG|MIN|MAX|GROUP\s+BY)\b'

    # Row cap appended by execute_query() when limiting is requested.
    _DEFAULT_ROW_LIMIT = 100

    # Relation referenced after FROM / JOIN / INTO / UPDATE, optionally
    # qualified as ``schema.table``. Groups: keyword, schema ('' if absent),
    # name, and '(' when the name is immediately called like a function.
    # The ``(?!\s*\.)`` guard stops a partial match on deeper qualifications
    # (``a.b.c``) or on just the schema part of ``schema.fn(...)``.
    _TABLE_REF_RE = re.compile(
        r'\b(FROM|JOIN|INTO|UPDATE)\s+'
        r'(?:([a-zA-Z_][a-zA-Z0-9_]*)\.)?'
        r'([a-zA-Z_][a-zA-Z0-9_]*)\b(?!\s*\.)'
        r'(\s*\()?',
        re.IGNORECASE,
    )

    # Expressions in which FROM is not a FROM clause: EXTRACT(YEAR FROM ts),
    # SUBSTRING(s FROM 2), TRIM(BOTH FROM s), OVERLAY(...), a IS DISTINCT FROM b.
    # Only calls without nested parentheses are masked, so a subquery can never
    # be hidden by this masking.
    _EXPRESSION_FROM_RE = re.compile(
        r'\b(?:EXTRACT|SUBSTRING|TRIM|OVERLAY)\s*\([^()]*\)|\bDISTINCT\s+FROM\b',
        re.IGNORECASE,
    )

    # Words the reference pattern can capture that are never table names.
    _NON_TABLE_WORDS = frozenset(
        {'select', 'where', 'order', 'group', 'having', 'as', 'on', 'and', 'or', 'lateral'}
    )

    def __init__(self, engine: Engine):
        """
        Args:
            engine: SQLAlchemy engine used for query execution and schema reflection
        """
        self.engine = engine
        # NOTE(review): SQLAlchemy's Inspector caches reflection results, so tables
        # created after construction may not appear in get_all_table_names() —
        # confirm whether a refresh is ever needed.
        self._inspector = inspect(engine)

    def execute_query(self, sql: str, limit: bool = True) -> QueryResult:
        """
        Execute raw SQL with comprehensive safety checks
        Extracted and enhanced from llmDatabaseRouter.safe_run_sql()

        Steps: reject forbidden patterns, verify referenced tables exist, append
        a LIMIT when requested, run the statement and convert rows to dicts.
        Errors are not raised; they are reported through the returned result.

        Args:
            sql: SQL query to execute
            limit: Whether to apply row limit for safety (appends ``LIMIT 100``
                when the query has no LIMIT and no aggregate / GROUP BY)
        Returns:
            QueryResult with execution details. On success ``query`` is the SQL
            actually executed; on failure it is the original ``sql``.
        """
        start_time = time.perf_counter()
        try:
            sql_clean = self._check_sql_safety(sql)

            # Validate table existence before execution
            validation_error = self.validate_tables_exist_in_sql(sql_clean)
            if validation_error:
                raise SQLSafetyError(validation_error)

            if limit:
                sql_clean = self._apply_default_limit(sql_clean)

            with self.engine.connect() as connection:
                result = connection.execute(text(sql_clean))
                rows = result.fetchall()
                columns = list(result.keys()) if rows else []

            data = [
                {column: self._to_serializable(value) for column, value in zip(columns, row)}
                for row in rows
            ]
            execution_time = time.perf_counter() - start_time
            logger.info("Query executed successfully: %d rows, %.2fs", len(data), execution_time)
            return QueryResult(
                success=True,
                data=data,
                rows_affected=len(data),
                execution_time=execution_time,
                query=sql_clean,
            )
        except SQLSafetyError as e:
            logger.error("SQL safety check failed: %s", e)
            return self._failure_result(sql, str(e), start_time)
        except SQLAlchemyError as e:
            error_msg = str(e)
            logger.error("SQL execution failed: %s", error_msg)
            return self._failure_result(sql, error_msg, start_time)
        except Exception as e:
            error_msg = f"Unexpected error: {str(e)}"
            # logger.exception keeps the traceback for this unexpected path.
            logger.exception("%s", error_msg)
            return self._failure_result(sql, error_msg, start_time)

    def _check_sql_safety(self, sql: str) -> str:
        """
        Strip ``sql`` and reject it if empty or matching a forbidden pattern.

        Returns:
            The stripped SQL.
        Raises:
            SQLSafetyError: if the query is empty or a forbidden pattern matches.
        """
        sql_clean = sql.strip()
        if not sql_clean:
            raise SQLSafetyError('Empty SQL query')
        for pattern in self._FORBIDDEN_PATTERNS:
            if re.search(pattern, sql_clean, re.IGNORECASE | re.DOTALL):
                raise SQLSafetyError(f'Forbidden SQL pattern detected: {pattern}')
        return sql_clean

    def _apply_default_limit(self, sql: str) -> str:
        """Append the default LIMIT unless the query already limits or aggregates rows."""
        if re.search(r'\bLIMIT\b', sql, re.IGNORECASE):
            return sql
        if re.search(self._AGGREGATE_PATTERN, sql, re.IGNORECASE):
            return sql
        return self._add_limit_to_query(sql, self._DEFAULT_ROW_LIMIT)

    @staticmethod
    def _to_serializable(value: Any) -> Any:
        """
        Convert a result value into a plain, JSON-friendly form.

        Objects with ``isoformat`` (datetime/date/time) become ISO strings,
        UUIDs become their canonical string, and bytes-like values become a hex
        string. Everything else is returned unchanged. Note: ``float`` also has
        a ``.hex()`` method, so duck-typing on ``hex`` must not be used here.
        """
        if hasattr(value, 'isoformat'):
            return value.isoformat()
        if isinstance(value, uuid.UUID):
            return str(value)
        if isinstance(value, (bytes, bytearray, memoryview)):
            return bytes(value).hex()
        return value

    @staticmethod
    def _failure_result(sql: str, error: str, start_time: float) -> QueryResult:
        """Build the failed QueryResult returned by execute_query()'s error handlers."""
        return QueryResult(
            success=False,
            error=error,
            execution_time=time.perf_counter() - start_time,
            query=sql,
        )

    def get_all_table_names(self) -> List[str]:
        """
        Get all table names from the database

        No schema argument is passed to the inspector, so SQLAlchemy reports the
        tables of the connection's default schema.

        Raises:
            SchemaError: if reflection fails.
        """
        try:
            return self._inspector.get_table_names()
        except Exception as e:
            logger.error("Failed to get table names: %s", e)
            raise SchemaError(f"Could not retrieve table names: {e}") from e

    def get_table_schema(self, table_name: str) -> Dict[str, Any]:
        """
        Get schema information for a specific table

        Args:
            table_name: Name of the table
        Returns:
            Dictionary with keys ``table_name``, ``columns``, ``primary_keys``
            (the inspector's primary-key constraint), ``foreign_keys`` and ``indexes``
        Raises:
            SchemaError: if any reflection call fails.
        """
        try:
            columns = self._inspector.get_columns(table_name)
            primary_keys = self._inspector.get_pk_constraint(table_name)
            foreign_keys = self._inspector.get_foreign_keys(table_name)
            indexes = self._inspector.get_indexes(table_name)
            return {
                'table_name': table_name,
                'columns': columns,
                'primary_keys': primary_keys,
                'foreign_keys': foreign_keys,
                'indexes': indexes,
            }
        except Exception as e:
            logger.error("Failed to get schema for table %s: %s", table_name, e)
            raise SchemaError(f"Could not retrieve schema for table {table_name}: {e}") from e

    def get_all_schemas(self) -> List[str]:
        """
        Get all schema names from the database, excluding information_schema,
        pg_catalog and pg_toast, ordered by name.

        Raises:
            SchemaError: if the query fails.
        """
        try:
            with self.engine.connect() as conn:
                result = conn.execute(text("""
                    SELECT schema_name
                    FROM information_schema.schemata
                    WHERE schema_name NOT IN ('information_schema', 'pg_catalog', 'pg_toast')
                    ORDER BY schema_name
                """))
                return [row[0] for row in result]
        except Exception as e:
            logger.error("Failed to get schema names: %s", e)
            raise SchemaError(f"Could not retrieve schema names: {e}") from e

    def validate_tables_exist(self, table_names: List[str]) -> Optional[str]:
        """
        Validate that all specified tables exist (exact, case-sensitive names)

        Args:
            table_names: List of table names to validate
        Returns:
            None if all tables exist, error message if any are missing or the
            lookup itself fails
        """
        try:
            existing_tables = set(self.get_all_table_names())
            missing_tables = [table for table in table_names if table not in existing_tables]
            if missing_tables:
                return f"Tables do not exist: {', '.join(missing_tables)}"
            return None
        except Exception as e:
            return f"Could not validate table existence: {e}"

    def validate_tables_exist_in_sql(self, sql: str) -> Optional[str]:
        """
        Validate that all tables referenced in the SQL actually exist
        Extracted from llmDatabaseRouter._validate_table_existence()

        This is a regex heuristic, not a SQL parser. The identifier following
        FROM / JOIN / INTO / UPDATE is compared case-insensitively with
        get_all_table_names(). Details:
          * ``schema.table`` is accepted only when ``schema`` is the inspector's
            default schema; other qualified names are reported as missing.
          * ``FROM fn(...)`` / ``JOIN fn(...)`` are function calls and skipped.
          * FROM inside EXTRACT/SUBSTRING/TRIM/OVERLAY(...) and ``IS DISTINCT
            FROM`` is ignored.
          * Quoted identifiers and the 2nd+ entries of comma-separated FROM lists
            are not checked; CTE names are reported as missing tables.

        Args:
            sql: SQL query to validate
        Returns:
            None if all tables exist, error message if any are missing
        """
        try:
            actual_tables = {table.lower() for table in self.get_all_table_names()}
            default_schema = (self._inspector.default_schema_name or '').lower()
            scan_text = self._EXPRESSION_FROM_RE.sub(' ', sql)

            missing_tables = set()
            for keyword, schema, table, call_paren in self._TABLE_REF_RE.findall(scan_text):
                table = table.lower()
                if call_paren and keyword.upper() in ('FROM', 'JOIN'):
                    continue  # set-returning function, e.g. generate_series(...)
                if schema:
                    schema = schema.lower()
                    if schema != default_schema or table not in actual_tables:
                        missing_tables.add(f'{schema}.{table}')
                elif table not in self._NON_TABLE_WORDS and table not in actual_tables:
                    missing_tables.add(table)

            if missing_tables:
                return f"Referenced tables do not exist: {', '.join(sorted(missing_tables))}"
            return None
        except Exception as e:
            return f"Could not validate table references in SQL: {e}"

    def get_row_count(self, table_name: str) -> int:
        """
        Get the exact row count of a table via ``SELECT COUNT(*)``

        ``table_name`` is interpolated unquoted into the SQL, so it should be a
        plain identifier (e.g. one returned by get_all_table_names()); the query
        still goes through execute_query()'s safety and existence checks.

        Returns:
            The row count, or 0 if it could not be obtained.
        """
        try:
            result = self.execute_query(f"SELECT COUNT(*) AS count FROM {table_name}")
            if result.success and result.data:
                return int(result.data[0]['count'])
            return 0
        except Exception as e:
            logger.warning("Could not get row count for %s: %s", table_name, e)
            return 0

    def check_vector_extension(self) -> bool:
        """
        Check if pgvector extension is available

        Queried directly rather than through execute_query(): that path's table
        existence check rejects the ``pg_extension`` catalog, which made this
        method always report False.
        """
        try:
            with self.engine.connect() as conn:
                has_vector = conn.execute(text(
                    "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'vector') AS has_vector"
                )).scalar()
            return bool(has_vector)
        except Exception as e:
            logger.warning("Could not check vector extension: %s", e)
            return False

    def _is_sql_safe_to_run(self, sql: str) -> bool:
        """
        Basic SQL safety check (substring based; not used by execute_query(),
        which applies its own regex checks)

        Args:
            sql: SQL query to validate
        Returns:
            True if the query starts with SELECT and contains none of the
            dangerous keyword substrings, False otherwise
        """
        sql_lower = sql.lower().strip()
        dangerous_keywords = (
            'drop table', 'drop database', 'drop schema',
            'delete from', 'truncate', 'alter table',
            'create table', 'create database', 'create schema',
            'insert into', 'update ', 'grant ', 'revoke ',
        )
        if any(keyword in sql_lower for keyword in dangerous_keywords):
            return False
        # Must start with SELECT
        return sql_lower.startswith('select')

    def _should_limit_query(self, sql: str) -> bool:
        """Check if query should have a limit applied (no 'limit' and no 'count(' substring)"""
        sql_lower = sql.lower()
        return 'limit' not in sql_lower and 'count(' not in sql_lower

    def _add_limit_to_query(self, sql: str, limit: int = 1000) -> str:
        """
        Add LIMIT clause to a query

        All trailing whitespace and semicolons are removed first, so input such
        as ``"SELECT 1; "`` still yields valid SQL. The LIMIT is placed on its
        own line so that a trailing ``--`` comment cannot swallow it.
        """
        body = re.sub(r'[\s;]+\Z', '', sql)
        return f"{body}\nLIMIT {limit}"