"""
PostgreSQL Integration Module
Provides comprehensive PostgreSQL database integration functionality including
safe SQL execution, schema inspection, and general database operations.
"""
import json
import logging
import re
import uuid
from typing import Any, Dict, List, Optional

from sqlalchemy import text
from sqlalchemy.engine import Engine
logger = logging.getLogger(__name__)
class PostgreSQLIntegration:
"""Handles PostgreSQL database operations with safety and utility features"""
def __init__(self, engine: Engine, db_key: str = "default", table_config: Optional[Dict[str, Any]] = None):
self.engine = engine
self.db_key = db_key
self.table_config = table_config or self._create_default_table_config()
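    # Minimal construction sketch (the DSN below is a placeholder, not a value
    # this module ships with):
    #
    #     from sqlalchemy import create_engine
    #     engine = create_engine("postgresql+psycopg2://user:pass@localhost/db")
    #     pg = PostgreSQLIntegration(engine, db_key="db")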
def _create_default_table_config(self) -> Dict[str, Any]:
"""Create a default table configuration that learns from the actual database"""
return {
"semantic_mappings": {
# These will be populated dynamically from actual table discovery
},
"query_patterns": {
"count_keywords": ["count", "how many", "number of", "total"],
"recent_keywords": ["recent", "latest", "last", "newest", "current"],
"list_keywords": ["show", "list", "display", "get", "find", "all"],
"search_keywords": ["search", "find", "where", "with", "containing"]
},
"date_column_patterns": ["date", "time", "created", "updated", "start", "end"],
"name_column_patterns": ["name", "title", "label", "description"],
"id_column_patterns": ["id", "key", "pk"],
"priority_tables": [], # Will be populated based on table discovery
"relationship_hints": {} # Will be discovered from foreign keys
}
def set_table_config(self, config: Dict[str, Any]) -> None:
"""Set a custom table configuration"""
self.table_config = {**self._create_default_table_config(), **config}
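    # Example: overriding selected keys while keeping the generated defaults
    # (table names here are hypothetical). Note that the merge above replaces
    # top-level keys wholesale, so a partial "query_patterns" dict would drop
    # the default pattern lists.
    #
    #     pg.set_table_config({"priority_tables": ["orders", "customers"]})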
def _discover_table_semantics(self) -> Dict[str, Any]:
"""Automatically discover table semantics from the database schema"""
try:
tables = self.get_all_table_names()
semantic_mappings = {}
relationship_hints = {}
priority_tables = []
# Categorize tables based on naming patterns
table_categories = {
"entities": [], # Main entity tables (users, products, etc.)
"relationships": [], # Junction/linking tables
"metadata": [], # Lookup/reference tables
"temporal": [] # Time-based data tables
}
for table in tables:
schema = self.get_table_schema(table)
if 'error' not in schema:
columns = schema.get('columns', [])
column_names = [col['name'].lower() for col in columns]
# Detect entity tables (have ID and name-like columns)
has_id = any(any(pattern in col_name for pattern in self.table_config["id_column_patterns"])
for col_name in column_names)
has_name = any(any(pattern in col_name for pattern in self.table_config["name_column_patterns"])
for col_name in column_names)
# Detect temporal tables (have date columns)
has_date = any(any(pattern in col_name for pattern in self.table_config["date_column_patterns"])
for col_name in column_names)
# Detect relationship tables (have multiple foreign keys)
foreign_keys = schema.get('foreign_keys', [])
if len(foreign_keys) >= 2:
table_categories["relationships"].append(table)
# Map relationship patterns
for fk in foreign_keys:
ref_table = fk.get('references_table')
if ref_table and ref_table not in relationship_hints:
relationship_hints[ref_table] = []
if ref_table:
relationship_hints[ref_table].append({
'via_table': table,
'column': fk.get('column'),
'references': fk.get('references_column')
})
elif has_id and has_name:
table_categories["entities"].append(table)
priority_tables.append(table) # Main entities get priority
# Create semantic mapping based on table name
semantic_mappings[table] = {
'type': 'entity',
                            # Naive singular/plural variants. rstrip('s') would
                            # strip every trailing 's' ("address" -> "addre"),
                            # so trim at most one character instead.
                            'common_names': [table, table[:-1] if table.endswith('s') else table, table + 's'],
'primary_display_columns': [col['name'] for col in columns
if any(pattern in col['name'].lower()
for pattern in self.table_config["name_column_patterns"])][:3],
'date_columns': [col['name'] for col in columns
if any(pattern in col['name'].lower()
for pattern in self.table_config["date_column_patterns"])],
'searchable_columns': [col['name'] for col in columns
if col['type'] in ['text', 'varchar', 'character varying']][:5]
}
elif has_date:
table_categories["temporal"].append(table)
if table not in priority_tables:
priority_tables.append(table)
else:
table_categories["metadata"].append(table)
return {
'semantic_mappings': semantic_mappings,
'relationship_hints': relationship_hints,
'priority_tables': priority_tables[:10], # Limit to top 10
'table_categories': table_categories
}
except Exception as e:
logger.warning(f"Error discovering table semantics: {e}")
return {
'semantic_mappings': {},
'relationship_hints': {},
'priority_tables': [],
'table_categories': {}
}
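    # Illustrative discovery output for a small schema (names hypothetical):
    #
    #     {'semantic_mappings': {'users': {'type': 'entity', ...}},
    #      'relationship_hints': {'users': [{'via_table': 'user_roles',
    #                                        'column': 'user_id',
    #                                        'references': 'id'}]},
    #      'priority_tables': ['users', 'events'],
    #      'table_categories': {'entities': ['users'],
    #                           'relationships': ['user_roles'],
    #                           'metadata': ['schema_migrations'],
    #                           'temporal': ['events']}}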
def safe_run_sql(self, sql: str, limit_safe: bool = True) -> Dict[str, Any]:
"""Safely execute SQL with guardrails"""
try:
# Basic SQL validation
sql_clean = sql.strip()
if not sql_clean:
return {'error': 'Empty SQL query'}
# Check for forbidden statements
forbidden_patterns = [
r'\b(INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|TRUNCATE)\b',
r';.*\w', # Multiple statements
r'/\*' # Block comments
]
for pattern in forbidden_patterns:
if re.search(pattern, sql_clean, re.IGNORECASE):
return {'error': f'Forbidden SQL pattern detected: {pattern}'}
# Validate table existence before execution
validation_error = self._validate_table_existence(sql_clean)
if validation_error:
return {'error': validation_error}
# Auto-inject LIMIT if not present and not aggregate
if limit_safe and not re.search(r'\bLIMIT\b', sql_clean, re.IGNORECASE):
if not re.search(r'\b(COUNT|SUM|AVG|MIN|MAX|GROUP BY)\b', sql_clean, re.IGNORECASE):
# Remove trailing semicolon before adding LIMIT
sql_clean = sql_clean.rstrip(';').strip()
sql_clean += ' LIMIT 100'
# Execute query
with self.engine.connect() as connection:
result = connection.execute(text(sql_clean))
rows = result.fetchall()
                # result.keys() is valid even for empty result sets, so the
                # column list no longer disappears when a query returns no rows.
                columns = list(result.keys())
# Convert to list of dicts
data = []
for row in rows:
row_dict = {}
for i, col in enumerate(columns):
                        value = row[i]
                        # Handle datetime/date/time objects
                        if hasattr(value, 'isoformat'):
                            value = value.isoformat()
                        # Handle UUID objects explicitly; hasattr(value, 'hex')
                        # also matched floats and bytes, stringifying them.
                        elif isinstance(value, uuid.UUID):
                            value = str(value)
                        row_dict[col] = value
data.append(row_dict)
return {
'success': True,
'data': data,
'columns': columns,
'row_count': len(data),
'sql_executed': sql_clean
}
except Exception as e:
error_msg = str(e)
logger.error(f"Error executing SQL in {self.db_key}: {error_msg}")
return {'error': error_msg, 'sql_attempted': sql}
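    # Result shape on success (values illustrative and the "users" table
    # hypothetical; note the auto-injected LIMIT):
    #
    #     pg.safe_run_sql("SELECT id, name FROM users")
    #     # -> {'success': True,
    #     #     'data': [{'id': 1, 'name': 'Ada'}],
    #     #     'columns': ['id', 'name'],
    #     #     'row_count': 1,
    #     #     'sql_executed': 'SELECT id, name FROM users LIMIT 100'}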
def get_all_table_names(self) -> List[str]:
"""Get all table names from the current database"""
try:
with self.engine.connect() as connection:
result = connection.execute(text("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
ORDER BY table_name
"""))
return [row[0] for row in result]
except Exception as e:
logger.warning(f"Could not get table names for {self.db_key}: {e}")
return []
def get_table_schema(self, table_name: str) -> Dict[str, Any]:
"""Get detailed schema information for a specific table"""
try:
with self.engine.connect() as connection:
# Get column information
columns_result = connection.execute(text("""
SELECT
column_name,
data_type,
is_nullable,
column_default,
character_maximum_length
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = :table_name
ORDER BY ordinal_position
"""), {'table_name': table_name})
columns = []
for row in columns_result:
columns.append({
'name': row.column_name,
'type': row.data_type,
'nullable': row.is_nullable == 'YES',
'default': row.column_default,
'max_length': row.character_maximum_length
})
# Get primary key information
pk_result = connection.execute(text("""
SELECT kcu.column_name
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
WHERE tc.table_schema = 'public'
AND tc.table_name = :table_name
AND tc.constraint_type = 'PRIMARY KEY'
ORDER BY kcu.ordinal_position
"""), {'table_name': table_name})
primary_keys = [row.column_name for row in pk_result]
# Get foreign key information
fk_result = connection.execute(text("""
SELECT
kcu.column_name,
ccu.table_name AS foreign_table,
ccu.column_name AS foreign_column
FROM information_schema.table_constraints tc
JOIN information_schema.key_column_usage kcu
ON tc.constraint_name = kcu.constraint_name
JOIN information_schema.constraint_column_usage ccu
ON ccu.constraint_name = tc.constraint_name
WHERE tc.table_schema = 'public'
AND tc.table_name = :table_name
AND tc.constraint_type = 'FOREIGN KEY'
"""), {'table_name': table_name})
foreign_keys = []
for row in fk_result:
foreign_keys.append({
'column': row.column_name,
'references_table': row.foreign_table,
'references_column': row.foreign_column
})
return {
'table_name': table_name,
'columns': columns,
'primary_keys': primary_keys,
'foreign_keys': foreign_keys
}
except Exception as e:
logger.error(f"Error getting schema for table {table_name}: {e}")
return {'error': str(e)}
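    # Illustrative return shape (the "users" table is hypothetical):
    #
    #     pg.get_table_schema("users")
    #     # -> {'table_name': 'users',
    #     #     'columns': [{'name': 'id', 'type': 'integer', 'nullable': False,
    #     #                  'default': None, 'max_length': None}, ...],
    #     #     'primary_keys': ['id'],
    #     #     'foreign_keys': [{'column': 'org_id',
    #     #                       'references_table': 'orgs',
    #     #                       'references_column': 'id'}]}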
def get_database_schema(self) -> Dict[str, Any]:
"""Get comprehensive schema information for the entire database"""
try:
tables = self.get_all_table_names()
schema_info = {
'database': self.db_key,
'table_count': len(tables),
'tables': {}
}
for table in tables:
table_schema = self.get_table_schema(table)
if 'error' not in table_schema:
schema_info['tables'][table] = table_schema
return schema_info
except Exception as e:
logger.error(f"Error getting database schema: {e}")
return {'error': str(e)}
def check_table_exists(self, table_name: str) -> bool:
"""Check if a specific table exists"""
try:
with self.engine.connect() as connection:
result = connection.execute(text("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_schema = 'public'
AND table_name = :table_name
)
"""), {'table_name': table_name})
                return bool(result.scalar())
except Exception as e:
logger.warning(f"Error checking table existence: {e}")
return False
def get_table_row_count(self, table_name: str) -> Optional[int]:
"""Get the number of rows in a table"""
try:
if not self.check_table_exists(table_name):
return None
# Validate table name to prevent SQL injection
if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', table_name):
logger.warning(f"Invalid table name format: {table_name}")
return None
            with self.engine.connect() as connection:
                # Identifiers cannot be bound as query parameters; the name is
                # regex-validated above, and quoting it also lets mixed-case
                # table names resolve correctly.
                result = connection.execute(text(f'SELECT COUNT(*) FROM "{table_name}"'))
return result.scalar()
except Exception as e:
logger.warning(f"Error getting row count for {table_name}: {e}")
return None
def get_table_size(self, table_name: str) -> Dict[str, Any]:
"""Get detailed size information for a table"""
try:
if not self.check_table_exists(table_name):
return {'error': f'Table {table_name} does not exist'}
# Validate table name to prevent SQL injection
if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', table_name):
return {'error': f'Invalid table name format: {table_name}'}
            with self.engine.connect() as connection:
                # The pg_* size functions take a regclass, which can be built
                # from a bound text parameter with an explicit cast; only true
                # identifier positions require string interpolation.
                result = connection.execute(text("""
                    SELECT
                        pg_size_pretty(pg_total_relation_size(CAST(:table_name AS regclass))) AS total_size,
                        pg_size_pretty(pg_relation_size(CAST(:table_name AS regclass))) AS table_size,
                        pg_size_pretty(pg_total_relation_size(CAST(:table_name AS regclass)) - pg_relation_size(CAST(:table_name AS regclass))) AS index_size,
                        (SELECT COUNT(*) FROM information_schema.columns
                         WHERE table_schema = 'public' AND table_name = :table_name) AS column_count
                """), {'table_name': table_name})
row = result.fetchone()
if row:
return {
'table_name': table_name,
'total_size': row.total_size,
'table_size': row.table_size,
'index_size': row.index_size,
'column_count': row.column_count,
'row_count': self.get_table_row_count(table_name)
}
return {'error': 'Could not retrieve size information'}
except Exception as e:
logger.error(f"Error getting table size for {table_name}: {e}")
return {'error': str(e)}
def get_database_stats(self) -> Dict[str, Any]:
"""Get comprehensive database statistics"""
try:
with self.engine.connect() as connection:
# Get database size
db_size_result = connection.execute(text("""
SELECT pg_size_pretty(pg_database_size(current_database())) as database_size
"""))
db_size = db_size_result.scalar()
# Get table statistics
tables_result = connection.execute(text("""
SELECT
schemaname,
relname as tablename,
n_tup_ins as inserts,
n_tup_upd as updates,
n_tup_del as deletes,
n_live_tup as live_tuples,
n_dead_tup as dead_tuples
FROM pg_stat_user_tables
WHERE schemaname = 'public'
ORDER BY n_live_tup DESC
LIMIT 10
"""))
table_stats = []
for row in tables_result:
table_stats.append({
'table_name': row.tablename,
'inserts': row.inserts,
'updates': row.updates,
'deletes': row.deletes,
'live_tuples': row.live_tuples,
'dead_tuples': row.dead_tuples
})
                return {
                    'database_size': db_size,
                    'table_count': len(self.get_all_table_names()),
                    # Ordered by n_live_tup above, so this ranks tables by
                    # estimated row count rather than on-disk size.
                    'top_tables_by_live_tuples': table_stats
                }
except Exception as e:
logger.error(f"Error getting database stats: {e}")
return {'error': str(e)}
def get_postgresql_config(self) -> Dict[str, Any]:
"""Get PostgreSQL server configuration settings"""
try:
with self.engine.connect() as connection:
# Get current configuration settings
config_result = connection.execute(text("""
SELECT name, setting, unit, category, short_desc, context, vartype, source, min_val, max_val, enumvals, boot_val, reset_val, pending_restart
FROM pg_settings
ORDER BY category, name
"""))
config_settings = {}
categories = {}
for row in config_result:
category = row.category or 'Uncategorized'
if category not in categories:
categories[category] = []
setting_info = {
'name': row.name,
'current_value': row.setting,
'unit': row.unit,
'description': row.short_desc,
'context': row.context,
'type': row.vartype,
'source': row.source,
'min_value': row.min_val,
'max_value': row.max_val,
'allowed_values': row.enumvals,
'boot_value': row.boot_val,
'reset_value': row.reset_val,
'pending_restart': row.pending_restart
}
config_settings[row.name] = setting_info
categories[category].append(row.name)
# Get some key server information
server_info = connection.execute(text("""
SELECT version() as version,
current_database() as database,
current_user as user,
inet_server_addr() as server_address,
inet_server_port() as server_port
""")).fetchone()
return {
'server_info': {
'version': server_info.version,
'database': server_info.database,
'user': server_info.user,
'address': str(server_info.server_address) if server_info.server_address else None,
'port': server_info.server_port
},
'configuration_settings': config_settings,
'settings_by_category': categories,
'total_settings': len(config_settings)
}
except Exception as e:
logger.error(f"Error getting PostgreSQL configuration: {e}")
return {'error': str(e)}
def execute_explain(self, sql: str) -> Dict[str, Any]:
"""Execute EXPLAIN for a query to analyze performance"""
try:
# Basic SQL validation (same as safe_run_sql)
sql_clean = sql.strip()
if not sql_clean:
return {'error': 'Empty SQL query'}
# Only allow SELECT statements for EXPLAIN
if not re.match(r'^\s*SELECT\b', sql_clean, re.IGNORECASE):
return {'error': 'EXPLAIN only allowed for SELECT statements'}
# Validate table existence
validation_error = self._validate_table_existence(sql_clean)
if validation_error:
return {'error': validation_error}
with self.engine.connect() as connection:
                # EXPLAIN ANALYZE actually executes the statement, so this is
                # only acceptable because the guard above restricts it to
                # SELECT queries.
                explain_sql = f"EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) {sql_clean}"
result = connection.execute(text(explain_sql))
explain_result = result.fetchone()[0]
return {
'success': True,
'query': sql_clean,
'explain_plan': explain_result[0] # JSON result
}
except Exception as e:
logger.error(f"Error executing EXPLAIN: {e}")
return {'error': str(e)}
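    # Usage sketch: with psycopg2 the FORMAT JSON plan arrives already parsed,
    # so the root plan node is reachable directly (driver behavior may vary):
    #
    #     plan = pg.execute_explain("SELECT * FROM users WHERE id = 1")
    #     if plan.get('success'):
    #         print(plan['explain_plan']['Plan']['Node Type'])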
def generate_sql(self, question: str, schema_slice: Optional[Dict[str, Any]] = None) -> str:
"""Generate SQL from natural language question and schema context"""
# If no schema slice provided, get basic table info
if not schema_slice:
schema_slice = {'tables': []}
# First, get all available tables in this database to ensure accurate schema context
try:
all_tables = self.get_all_table_names()
available_tables_text = "Available tables: " + ", ".join(sorted(all_tables))
# Get actual column info for key tables
table_info = self._get_key_table_info()
available_tables_text += "\n" + table_info
        except Exception as e:
            logger.warning(f"Could not build table context for {self.db_key}: {e}")
            available_tables_text = "Available tables: Could not retrieve table list"
# Build a comprehensive prompt for the LLM
schema_description = []
for table_info in schema_slice.get('tables', []):
if isinstance(table_info, dict) and 'payload' in table_info:
payload = table_info['payload']
if isinstance(payload, str):
try:
payload = json.loads(payload)
                    except json.JSONDecodeError:
                        pass
if isinstance(payload, dict):
table_name = table_info.get('table_name', 'unknown')
columns = payload.get('columns', [])
col_descriptions = []
for col in columns:
col_desc = f"{col.get('column_name', 'unknown')} ({col.get('data_type', 'unknown')})"
col_descriptions.append(col_desc)
if col_descriptions:
schema_description.append(f"Table {table_name}: {', '.join(col_descriptions)}")
schema_text = "\n".join(schema_description) if schema_description else "No detailed schema information available"
# Get discovered table semantics for intelligent prompting
discovered = self._discover_table_semantics()
relationship_hints = discovered.get('relationship_hints', {})
semantic_mappings = discovered.get('semantic_mappings', {})
# Build relationship context dynamically
relationship_context = ""
if relationship_hints:
relationship_context = "\nDISCOVERED TABLE RELATIONSHIPS:\n"
for table, relationships in list(relationship_hints.items())[:3]:
for rel in relationships[:2]:
relationship_context += f"- {table} connects to other tables via {rel['via_table']} (join on {rel['column']})\n"
# Build semantic context
semantic_context = ""
if semantic_mappings:
semantic_context = "\nTABLE SEMANTIC MAPPINGS:\n"
for table, mapping in list(semantic_mappings.items())[:5]:
common_names = mapping.get('common_names', [])
if common_names:
semantic_context += f"- {table} can be referenced as: {', '.join(common_names)}\n"
date_columns = mapping.get('date_columns', [])
if date_columns:
semantic_context += f" Date columns in {table}: {', '.join(date_columns)}\n"
prompt = f"""You are a PostgreSQL expert. Generate a safe SELECT query for the user's question.
{available_tables_text}
Detailed schema for relevant tables:
{schema_text}
{relationship_context}
{semantic_context}
User question: {question}
CRITICAL REQUIREMENTS:
- ONLY use tables that exist in the "Available tables" list above
- Only SELECT statements (no INSERT, UPDATE, DELETE, DROP, etc.)
- Use proper JOIN syntax where needed based on discovered relationships above
- Include appropriate WHERE clauses
- Add LIMIT clauses for large result sets (LIMIT 50 or less)
- Use PostgreSQL-specific syntax and functions
- For temporal queries, use the appropriate date columns shown in the semantic mappings
- For search queries, use ILIKE '%term%' for flexible matching
- When joining tables, follow the relationship patterns discovered above
- Return ONLY the SQL query, no explanations or markdown formatting
SQL Query:"""
try:
# Try to use LLM integration if available
try:
# Check if we can import the LLM integration from the current system
import sys
import os
                module_dir = os.path.dirname(os.path.abspath(__file__))
                if module_dir not in sys.path:
                    sys.path.append(module_dir)
                # Try different potential LLM integration paths. Initialize
                # generated_sql up front: without this, a successful import of
                # LLMIntegrationSystem left it unbound at the check below.
                llm_service = None
                generated_sql = None
try:
from llmintegrationsystem import LLMIntegrationSystem
llm_service = LLMIntegrationSystem()
except ImportError:
try:
from AI.ai import query_vectorize_generate
# Use the external AI service
result = query_vectorize_generate(self.db_key, prompt, 'ollama/gpt-oss:20b')
if result and isinstance(result, dict):
generated_sql = result.get('text', '') or result.get('response', '') or str(result)
elif result:
generated_sql = str(result)
else:
generated_sql = None
except ImportError:
llm_service = None
generated_sql = None
# If we have an LLM service, use it
if llm_service and not generated_sql:
try:
import asyncio
# Create event loop if needed
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Use the LLM integration
response = loop.run_until_complete(llm_service.chat(prompt))
generated_sql = response if isinstance(response, str) else str(response)
except Exception as e:
logger.warning(f"LLM integration failed: {e}")
generated_sql = None
if generated_sql and generated_sql.strip():
# Clean up the response - remove markdown formatting and quotes if present
sql_clean = generated_sql.strip()
# Remove markdown formatting
if sql_clean.startswith('```sql'):
sql_clean = sql_clean.replace('```sql', '').replace('```', '').strip()
elif sql_clean.startswith('```'):
sql_clean = sql_clean.replace('```', '').strip()
# Remove surrounding quotes if present (common LLM issue)
if sql_clean.startswith('"') and sql_clean.endswith('"'):
sql_clean = sql_clean[1:-1].strip()
elif sql_clean.startswith("'") and sql_clean.endswith("'"):
sql_clean = sql_clean[1:-1].strip()
                # Strip any trailing semicolon; safe_run_sql appends a LIMIT
                # clause after the statement, and a semicolon there would
                # break it.
                sql_clean = sql_clean.rstrip(';').strip()
return sql_clean
                else:
                    # The LLM produced nothing usable; fall back to the
                    # heuristic generator instead of a placeholder query.
                    return self._generate_sql_heuristics(question, schema_slice)
except Exception as e:
logger.warning(f"LLM SQL generation failed for {self.db_key}: {e}")
# Fallback to heuristics
return self._generate_sql_heuristics(question, schema_slice)
except Exception as e:
logger.error(f"Error in generate_sql: {e}")
# Fallback to heuristics
return self._generate_sql_heuristics(question, schema_slice)
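    # The generated SQL is untrusted model output; a typical caller feeds it
    # back through the guardrails rather than executing it directly:
    #
    #     sql = pg.generate_sql("how many users signed up this week?")
    #     result = pg.safe_run_sql(sql)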
def _get_key_table_info(self) -> str:
"""Get key table information for SQL generation context using discovered semantics"""
try:
# Discover table semantics dynamically
discovered = self._discover_table_semantics()
priority_tables = discovered.get('priority_tables', [])
semantic_mappings = discovered.get('semantic_mappings', {})
# If no priority tables discovered, fall back to first few tables
if not priority_tables:
all_tables = self.get_all_table_names()
priority_tables = all_tables[:5] # Just take first 5 tables
info_parts = []
for table_name in priority_tables[:5]: # Limit to 5 tables for prompt size
schema = self.get_table_schema(table_name)
if 'error' not in schema:
columns = schema.get('columns', [])
# Prioritize important columns based on patterns
important_columns = []
regular_columns = []
for col in columns:
col_name_lower = col['name'].lower()
is_important = (
any(pattern in col_name_lower for pattern in self.table_config["id_column_patterns"]) or
any(pattern in col_name_lower for pattern in self.table_config["name_column_patterns"]) or
any(pattern in col_name_lower for pattern in self.table_config["date_column_patterns"])
)
col_info = f"{col['name']} ({col['type']})"
if is_important:
important_columns.append(col_info)
else:
regular_columns.append(col_info)
# Combine important columns first, then regular ones (limited)
all_col_info = important_columns + regular_columns[:max(0, 8 - len(important_columns))]
if all_col_info:
table_info = f"{table_name}: {', '.join(all_col_info)}"
# Add semantic context if available
if table_name in semantic_mappings:
mapping = semantic_mappings[table_name]
common_names = mapping.get('common_names', [])
if common_names:
table_info += f" (also known as: {', '.join(common_names)})"
info_parts.append(table_info)
# Add relationship hints
relationship_hints = discovered.get('relationship_hints', {})
if relationship_hints:
info_parts.append("\nTable Relationships:")
for table, relationships in list(relationship_hints.items())[:3]: # Limit relationships
for rel in relationships[:2]: # Limit relations per table
info_parts.append(f"- {table} connects to other tables via {rel['via_table']}")
return "\n".join(info_parts) if info_parts else "No detailed table information available"
except Exception as e:
logger.warning(f"Error getting key table info: {e}")
return "Could not retrieve detailed table information"
def _generate_sql_heuristics(self, question: str, schema_slice: Dict[str, Any]) -> str:
"""Generate SQL using heuristic patterns when LLM is not available"""
try:
question_lower = question.lower()
# Discover table semantics to make intelligent choices
discovered = self._discover_table_semantics()
priority_tables = discovered.get('priority_tables', [])
semantic_mappings = discovered.get('semantic_mappings', {})
# Get all tables as fallback
all_tables = self.get_all_table_names()
tables_to_search = priority_tables if priority_tables else all_tables
# Pattern matching for query types
patterns = self.table_config["query_patterns"]
# COUNT queries
if any(keyword in question_lower for keyword in patterns["count_keywords"]):
target_table = self._find_relevant_table(question_lower, tables_to_search, semantic_mappings)
if target_table:
return f"SELECT COUNT(*) as count FROM {target_table}"
else:
return f"SELECT COUNT(*) as count FROM {all_tables[0]}" if all_tables else "SELECT 1 as count"
# RECENT/LATEST queries
elif any(keyword in question_lower for keyword in patterns["recent_keywords"]):
target_table = self._find_relevant_table(question_lower, tables_to_search, semantic_mappings)
if target_table:
# Find appropriate date column for ordering
date_column = self._find_date_column(target_table)
order_clause = f"ORDER BY {date_column} DESC" if date_column else "ORDER BY 1 DESC"
return f"SELECT * FROM {target_table} {order_clause} LIMIT 10"
else:
return f"SELECT * FROM {all_tables[0]} LIMIT 10" if all_tables else "SELECT 'No tables available' as message"
# LIST/SHOW queries
elif any(keyword in question_lower for keyword in patterns["list_keywords"]):
target_table = self._find_relevant_table(question_lower, tables_to_search, semantic_mappings)
if target_table:
return f"SELECT * FROM {target_table} LIMIT 20"
else:
return f"SELECT * FROM {all_tables[0]} LIMIT 20" if all_tables else "SELECT 'No tables available' as message"
# SEARCH queries - try to find searchable columns
elif any(keyword in question_lower for keyword in patterns["search_keywords"]):
target_table = self._find_relevant_table(question_lower, tables_to_search, semantic_mappings)
if target_table:
searchable_cols = self._find_searchable_columns(target_table)
if searchable_cols:
# Extract potential search terms from question
search_terms = self._extract_search_terms_from_question(question_lower)
if search_terms:
                            search_term = search_terms[0].replace("'", "''")  # use first term; escape quotes for the SQL literal
                            where_clauses = [f"{col} ILIKE '%{search_term}%'" for col in searchable_cols[:2]]
where_clause = " OR ".join(where_clauses)
return f"SELECT * FROM {target_table} WHERE {where_clause} LIMIT 20"
return f"SELECT * FROM {target_table} LIMIT 20"
else:
return f"SELECT * FROM {all_tables[0]} LIMIT 20" if all_tables else "SELECT 'No tables available' as message"
# Default fallback - try to find most relevant table
else:
target_table = self._find_relevant_table(question_lower, tables_to_search, semantic_mappings)
if target_table:
return f"SELECT * FROM {target_table} LIMIT 10"
else:
return f"SELECT * FROM {all_tables[0]} LIMIT 10" if all_tables else "SELECT 'No tables available' as message"
except Exception as e:
logger.error(f"Error in heuristic SQL generation: {e}")
return "SELECT 'Error generating SQL query' as message"
def _find_relevant_table(self, question: str, tables: List[str], semantic_mappings: Dict[str, Any]) -> Optional[str]:
"""Find the most relevant table for a question using semantic mappings"""
# Check semantic mappings first
for table, mapping in semantic_mappings.items():
common_names = mapping.get('common_names', [])
for name in common_names:
if name.lower() in question:
return table
# Fallback to direct table name matching
for table in tables:
table_lower = table.lower()
# Check if table name (or singular/plural variants) appear in question
            variants = [table_lower, table_lower[:-1] if table_lower.endswith('s') else table_lower, table_lower + 's']
if any(variant in question for variant in variants):
return table
# Return first priority table as last resort
return tables[0] if tables else None
def _find_date_column(self, table_name: str) -> Optional[str]:
"""Find a suitable date column for ordering"""
try:
schema = self.get_table_schema(table_name)
if 'error' not in schema:
columns = schema.get('columns', [])
date_patterns = self.table_config["date_column_patterns"]
for col in columns:
col_name_lower = col['name'].lower()
col_type_lower = col['type'].lower()
# Check if it's a date/time column
if ('date' in col_type_lower or 'time' in col_type_lower or 'timestamp' in col_type_lower):
if any(pattern in col_name_lower for pattern in date_patterns):
return col['name']
# Fallback to any date/time column
for col in columns:
col_type_lower = col['type'].lower()
if ('date' in col_type_lower or 'time' in col_type_lower or 'timestamp' in col_type_lower):
return col['name']
except Exception as e:
logger.warning(f"Error finding date column for {table_name}: {e}")
return None
def _find_searchable_columns(self, table_name: str) -> List[str]:
"""Find columns suitable for text search"""
try:
schema = self.get_table_schema(table_name)
if 'error' not in schema:
columns = schema.get('columns', [])
searchable = []
for col in columns:
col_type_lower = col['type'].lower()
# Text-based columns that can be searched
if any(text_type in col_type_lower for text_type in ['text', 'varchar', 'char', 'string']):
searchable.append(col['name'])
return searchable[:3] # Limit to 3 columns
except Exception as e:
logger.warning(f"Error finding searchable columns for {table_name}: {e}")
return []
def _extract_search_terms_from_question(self, question: str) -> List[str]:
"""Extract potential search terms from a question"""
# Remove common query words
stop_words = set(['show', 'find', 'get', 'search', 'where', 'with', 'containing', 'for', 'all', 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'by', 'from'])
words = [word.strip('",.:;!?()') for word in question.split() if len(word) > 2 and word.lower() not in stop_words]
return words[:3] # Return first 3 meaningful words
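    # e.g. _extract_search_terms_from_question("find customers containing smith")
    # -> ['customers', 'smith']  ("find" and "containing" are stop words)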
def test_connection(self) -> Dict[str, Any]:
"""Test the database connection"""
try:
with self.engine.connect() as connection:
result = connection.execute(text("SELECT version()"))
version = result.scalar()
return {
'success': True,
'database_version': version,
'connection_active': True
}
except Exception as e:
return {
'success': False,
'error': str(e),
'connection_active': False
}
def _validate_table_existence(self, sql: str) -> Optional[str]:
"""Validate that all tables referenced in the SQL actually exist"""
try:
# Get all actual table names from the database
actual_tables = set(self.get_all_table_names())
            # Extract table names from the SQL using simple regex patterns.
            # This covers the common cases; CTE aliases and schema-qualified
            # names are not recognized and may be reported as missing.
table_patterns = [
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*)', # FROM table_name
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*)', # JOIN table_name
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*)', # INTO table_name
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)', # UPDATE table_name
]
            referenced_tables = set()
            for pattern in table_patterns:
                # re.IGNORECASE already handles case-insensitivity; upper-casing
                # the SQL first was redundant.
                for match in re.findall(pattern, sql, re.IGNORECASE):
                    referenced_tables.add(match.upper())
# Check if all referenced tables exist
missing_tables = referenced_tables - {t.upper() for t in actual_tables}
if missing_tables:
return f"Table(s) not found: {', '.join(missing_tables)}"
return None
except Exception as e:
logger.warning(f"Table validation failed: {e}")
# If validation fails, allow the query to proceed (safer than blocking all queries)
return None
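

if __name__ == "__main__":
    # Minimal smoke test (assumptions: psycopg2 is installed and the DSN below
    # is a placeholder to be replaced with a real connection string).
    from sqlalchemy import create_engine

    logging.basicConfig(level=logging.INFO)
    engine = create_engine("postgresql+psycopg2://user:password@localhost:5432/postgres")
    pg = PostgreSQLIntegration(engine, db_key="postgres")
    print(pg.test_connection())
    print(pg.safe_run_sql("SELECT 1 AS ok"))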