"""
Schema service for database introspection and schema management
"""
import logging
from typing import Dict, List, Optional, Any
from repositories.postgres_repository import PostgresRepository
from shared.models import SchemaInfo
from shared.exceptions import SchemaError
logger = logging.getLogger(__name__)
class SchemaService:
"""
Service for database schema introspection and management.

Provides business logic around schema discovery and catalog building.
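
Example (illustrative sketch; how ``PostgresRepository`` is constructed is not
defined in this module, so the connection string below is an assumption):

    repo = PostgresRepository("postgresql://user:pass@localhost/appdb")
    service = SchemaService(repo)
    schema = service.get_schema_info()
    print(schema.summary)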
"""
def __init__(self, postgres_repo: PostgresRepository):
self.postgres_repo = postgres_repo
self._schema_cache = {}
self._catalog_cache = None
def get_schema_info(self, use_cache: bool = True) -> SchemaInfo:
"""
Get comprehensive schema information for the database
Args:
use_cache: Whether to use cached schema info
Returns:
SchemaInfo object with complete schema details
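
Example (assuming ``service`` is a configured SchemaService): the first call
introspects the database and caches the result; pass ``use_cache=False`` to
force a refresh:

    info = service.get_schema_info()                  # hits the database
    info = service.get_schema_info()                  # served from the cache
    fresh = service.get_schema_info(use_cache=False)  # re-introspects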
"""
cache_key = "full_schema"
if use_cache and cache_key in self._schema_cache:
logger.info("Returning cached schema info")
return self._schema_cache[cache_key]
try:
# Get all tables
table_names = self.postgres_repo.get_all_table_names()
tables = []
for table_name in table_names:
table_schema = self.postgres_repo.get_table_schema(table_name)
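# NOTE: fetching a row count for every table may be expensive on large
# databases (depends on how get_row_count is implemented); the schema
# cache above keeps this from being repeated on subsequent calls.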
table_schema['row_count'] = self.postgres_repo.get_row_count(table_name)
tables.append(table_schema)
# Get relationships (foreign keys across all tables)
relationships = self._extract_relationships(tables)
# Get indexes across all tables
indexes = self._extract_indexes(tables)
# Get views (if any)
views = self._get_views()
# Get functions (stored procedures, etc.)
functions = self._get_functions()
# Build summary
summary = self._build_schema_summary(tables, relationships, views, functions)
schema_info = SchemaInfo(
tables=tables,
relationships=relationships,
indexes=indexes,
views=views,
functions=functions,
summary=summary
)
# Cache the result
self._schema_cache[cache_key] = schema_info
logger.info(f"Schema info collected: {len(tables)} tables, {len(relationships)} relationships")
return schema_info
except Exception as e:
logger.error(f"Failed to get schema info: {e}")
raise SchemaError(f"Could not retrieve schema information: {e}") from e
def find_relevant_tables(self, question: str, schema_info: Optional[SchemaInfo] = None) -> List[str]:
"""
Find tables relevant to a natural language question
Args:
question: User's question
schema_info: Optional pre-loaded schema info
Returns:
List of relevant table names
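
Example (illustrative; the returned names depend entirely on the schema):

    tables = service.find_relevant_tables("How many orders did each customer place?")
    # e.g. ['orders', 'customers'] if tables or columns with matching names exist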
"""
if schema_info is None:
schema_info = self.get_schema_info()
question_lower = question.lower()
relevant_tables = []
for table in schema_info.tables:
table_name = table['table_name'].lower()
# Direct table name mention
if table_name in question_lower:
relevant_tables.append(table['table_name'])
continue
# Check column names; skip very short names (e.g. 'id') that would match
# almost any question as a substring
column_match = False
for column in table['columns']:
column_name = column['name'].lower()
if len(column_name) > 2 and column_name in question_lower:
relevant_tables.append(table['table_name'])
column_match = True
break
if column_match:
continue
# Check table comment/description if available, ignoring short stopwords
comment = table.get('comment')
if comment and any(word in comment.lower() for word in question_lower.split() if len(word) > 3):
relevant_tables.append(table['table_name'])
# Remove duplicates while preserving order
return list(dict.fromkeys(relevant_tables))
def get_table_details(self, table_name: str) -> Dict[str, Any]:
"""Get detailed information about a specific table"""
try:
schema = self.postgres_repo.get_table_schema(table_name)
schema['row_count'] = self.postgres_repo.get_row_count(table_name)
schema['sample_data'] = self._get_sample_data(table_name)
return schema
except Exception as e:
logger.error(f"Failed to get table details for {table_name}: {e}")
raise SchemaError(f"Could not get details for table {table_name}: {e}") from e
def build_catalog_descriptions(self, schema_info: Optional[SchemaInfo] = None) -> List[Dict[str, Any]]:
"""
Build semantic descriptions of schema elements for catalog search
Args:
schema_info: Optional pre-loaded schema info
Returns:
List of catalog entries with descriptions
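
Example entry shape (values are illustrative):

    {
        'type': 'table',
        'name': 'orders',
        'description': "Table 'orders' with 8 columns and 1200 rows. ...",
        'metadata': {'columns': 8, 'row_count': 1200,
                     'has_primary_key': True, 'has_foreign_keys': True}
    }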
"""
# Serve the cached catalog only when the caller did not supply fresh schema info
if schema_info is None:
if self._catalog_cache is not None:
return self._catalog_cache
schema_info = self.get_schema_info()
catalog_entries = []
for table in schema_info.tables:
# Main table entry
description = self._build_table_description(table)
catalog_entries.append({
'type': 'table',
'name': table['table_name'],
'description': description,
'metadata': {
'columns': len(table['columns']),
'row_count': table.get('row_count', 0),
'has_primary_key': bool(table.get('primary_keys', {}).get('constrained_columns')),
'has_foreign_keys': bool(table.get('foreign_keys'))
}
})
# Individual column entries for important columns
for column in table['columns']:
if self._is_important_column(column):
column_description = self._build_column_description(table['table_name'], column)
catalog_entries.append({
'type': 'column',
'name': f"{table['table_name']}.{column['name']}",
'description': column_description,
'metadata': {
'table': table['table_name'],
'column': column['name'],
'data_type': column['type'],
'nullable': column['nullable']
}
})
self._catalog_cache = catalog_entries
logger.info(f"Built catalog with {len(catalog_entries)} entries")
return catalog_entries
def invalidate_cache(self):
"""Clear all cached schema information"""
self._schema_cache.clear()
self._catalog_cache = None
logger.info("Schema cache invalidated")
def _extract_relationships(self, tables: List[Dict]) -> List[Dict[str, Any]]:
"""Extract foreign key relationships from table schemas"""
relationships = []
for table in tables:
table_name = table['table_name']
for fk in table.get('foreign_keys', []):
relationships.append({
'from_table': table_name,
'from_columns': fk['constrained_columns'],
'to_table': fk['referred_table'],
'to_columns': fk['referred_columns'],
'constraint_name': fk['name']
})
return relationships
def _extract_indexes(self, tables: List[Dict]) -> List[Dict[str, Any]]:
"""Extract index information from table schemas"""
indexes = []
for table in tables:
table_name = table['table_name']
for index in table.get('indexes', []):
indexes.append({
'table': table_name,
'name': index['name'],
'columns': index['column_names'],
'unique': index['unique']
})
return indexes
def _get_views(self) -> List[Dict[str, Any]]:
"""Get database views"""
try:
result = self.postgres_repo.execute_query("""
SELECT table_name, view_definition
FROM information_schema.views
WHERE table_schema = 'public'
""")
if result.success:
return result.data
return []
except Exception as e:
logger.warning(f"Could not get views: {e}")
return []
def _get_functions(self) -> List[Dict[str, Any]]:
"""Get database functions and procedures"""
try:
result = self.postgres_repo.execute_query("""
SELECT routine_name, routine_type, data_type
FROM information_schema.routines
WHERE routine_schema = 'public'
""")
if result.success:
return result.data
return []
except Exception as e:
logger.warning(f"Could not get functions: {e}")
return []
def _build_schema_summary(self, tables: List[Dict], relationships: List[Dict],
views: List[Dict], functions: List[Dict]) -> str:
"""Build a human-readable summary of the schema"""
summary_parts = [
f"Database contains {len(tables)} tables with {len(relationships)} foreign key relationships."
]
if views:
summary_parts.append(f"Includes {len(views)} views.")
if functions:
summary_parts.append(f"Has {len(functions)} stored procedures/functions.")
# Add information about key tables
if tables:
largest_tables = sorted(tables, key=lambda t: t.get('row_count', 0), reverse=True)[:3]
if largest_tables[0].get('row_count', 0) > 0:
table_info = ", ".join([f"{t['table_name']} ({t.get('row_count', 0)} rows)"
for t in largest_tables if t.get('row_count', 0) > 0])
summary_parts.append(f"Largest tables: {table_info}.")
return " ".join(summary_parts)
def _build_table_description(self, table: Dict) -> str:
"""Build a semantic description of a table"""
table_name = table['table_name']
columns = table['columns']
row_count = table.get('row_count', 0)
# Basic description
description = f"Table '{table_name}' with {len(columns)} columns"
if row_count > 0:
description += f" and {row_count} rows"
# Add key column information
key_columns = []
for col in columns[:5]: # First 5 columns
col_desc = f"{col['name']} ({col['type']})"
if not col['nullable']:
col_desc += " NOT NULL"
key_columns.append(col_desc)
if key_columns:
description += f". Key columns: {', '.join(key_columns)}"
return description
def _build_column_description(self, table_name: str, column: Dict) -> str:
"""Build a semantic description of a column"""
col_name = column['name']
col_type = column['type']
nullable = "nullable" if column['nullable'] else "required"
return f"Column '{col_name}' in table '{table_name}': {col_type}, {nullable}"
def _is_important_column(self, column: Dict) -> bool:
"""Determine if a column is important enough to include in catalog"""
col_name = column['name'].lower()
col_type = str(column['type']).lower()
# Include ID-like columns (likely primary or foreign keys); avoid matching
# names that merely end in "id" such as "paid" or "valid"
if col_name == 'id' or col_name.endswith('_id'):
return True
# Include text columns that might contain searchable content
if 'text' in col_type or 'varchar' in col_type:
return True
# Include date/time columns
if any(t in col_type for t in ['date', 'time', 'timestamp']):
return True
# Include commonly searched column names
important_names = ['name', 'title', 'description', 'email', 'status', 'type', 'category']
return any(name in col_name for name in important_names)
def _get_sample_data(self, table_name: str, limit: int = 3) -> List[Dict]:
"""Get sample data from a table"""
try:
# Quote the identifier so mixed-case or reserved-word table names still work
result = self.postgres_repo.execute_query(f'SELECT * FROM "{table_name}" LIMIT {limit}')
if result.success:
return result.data
return []
except Exception as e:
logger.warning(f"Could not get sample data for {table_name}: {e}")
return []
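

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). How PostgresRepository is constructed is
# not defined in this module, so the connection-string argument below is an
# assumption; swap in however the repository is actually created.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    repo = PostgresRepository("postgresql://user:pass@localhost:5432/appdb")  # hypothetical constructor
    service = SchemaService(repo)

    schema = service.get_schema_info()
    print(schema.summary)

    for name in service.find_relevant_tables("Which customers placed the most orders?"):
        details = service.get_table_details(name)
        print(f"{name}: {details.get('row_count')} rows")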