"""Database inspection operations."""
import logging
from typing import Any, Optional

from .credentials import Credentials
from .connection import ConnectionManager
from .dependencies import ensure_deps_once

try:
    from sqlalchemy import Engine, MetaData, Table, select, func, text, inspect
    from sqlalchemy.exc import SQLAlchemyError
except ImportError:
    ensure_deps_once()
    from sqlalchemy import Engine, MetaData, Table, select, func, text, inspect
    from sqlalchemy.exc import SQLAlchemyError
class DatabaseInspector:
    """Handles database inspection operations.

    Every public method obtains a SQLAlchemy engine from the injected
    ``ConnectionManager``. For the exceptions each method catches, failures
    are reported in-band: an error string inside the returned list, an
    ``"error"`` key in the returned dict, or ``-1`` for a row count.
    """

    # Diagnostics from the best-effort helpers go through logging rather
    # than print(), so this library code never writes to the host's stdout.
    _logger = logging.getLogger(__name__)

    def __init__(self, connection_manager: ConnectionManager):
        """Store the connection manager used to build engines.

        Args:
            connection_manager: Object whose ``get_engine_with_credentials``
                returns an engine (or a falsy value on failure).
        """
        self.conn_manager = connection_manager

    def get_databases(self, creds: Credentials) -> list[str]:
        """Get list of databases on server.

        Connects with the caller's user, password, server, driver and port,
        but to the ``master`` database, and lists names from
        ``sys.databases`` (a SQL Server catalog view) with ``database_id > 4``.

        Args:
            creds: Connection credentials; ``creds.database`` is not used.

        Returns:
            Database names, or a one-element list holding an error message
            if the engine cannot be created or any exception is raised.
        """
        try:
            # NOTE(review): connecting to "master" instead of creds.database
            # presumably keeps the listing working when the target database
            # is unset or unreachable — confirm.
            master_creds = Credentials(
                user=creds.user,
                password=creds.password,
                server=creds.server,
                database="master",
                driver=creds.driver,
                port=creds.port,
            )
            engine = self.conn_manager.get_engine_with_credentials(master_creds)
            if not engine:
                return ["Error: Could not create connection"]
            with engine.connect() as conn:
                # database_id > 4 presumably skips SQL Server's built-in
                # system databases (ids 1-4).
                query_str = "SELECT name FROM sys.databases WHERE database_id > 4"
                result = conn.execute(text(query_str))
                return [row[0] for row in result]
        except Exception as e:
            return [f"Error retrieving databases: {e}"]

    def get_tables(self, creds: Credentials) -> list[str]:
        """Get list of tables in database.

        Uses SQLAlchemy reflection on the default schema first; if that
        yields nothing, falls back to an ``information_schema`` query.

        Args:
            creds: Connection credentials; ``creds.database`` is also passed
                to the fallback query.

        Returns:
            Table names (possibly empty), or a one-element list holding an
            error message if the engine cannot be created or a
            ``SQLAlchemyError`` is raised.
        """
        engine = self.conn_manager.get_engine_with_credentials(creds)
        if not engine:
            return ["Error: Could not create connection"]
        try:
            tables = inspect(engine).get_table_names(schema=None)
            if not tables:
                # Reflection found nothing in the default schema; ask
                # information_schema directly as a second opinion.
                tables = self._fallback_get_tables(engine, creds.database)
            return tables
        except SQLAlchemyError as e:
            return [f"Error retrieving tables: {e}"]

    def _fallback_get_tables(self, engine: Engine, database: str) -> list[str]:
        """Fallback method to get tables using information_schema.

        On SQL Server (``mssql`` dialect) lists tables in the ``dbo`` schema.
        On every other dialect (MySQL/MariaDB included) lists tables whose
        ``table_schema`` equals *database*, bound as a query parameter.

        Args:
            engine: Engine to query.
            database: Schema name to filter on for non-SQL-Server dialects.

        Returns:
            Table names, or an empty list if the query raises
            ``SQLAlchemyError`` (best-effort; the failure is logged).
        """
        # NOTE(review): INFORMATION_SCHEMA.TABLES also lists views, unlike
        # Inspector.get_table_names(); confirm that difference is intended.
        params: Optional[dict[str, str]]
        if engine.dialect.name == 'mssql':
            query = text(
                "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
                "WHERE TABLE_SCHEMA = 'dbo'"
            )
            params = None
        else:
            query = text(
                "SELECT table_name FROM information_schema.tables "
                "WHERE table_schema = :db_name"
            )
            params = {"db_name": database}
        try:
            with engine.connect() as connection:
                return [row[0] for row in connection.execute(query, params)]
        except SQLAlchemyError as e:
            self._logger.warning("Error listing tables via information_schema: %s", e)
            return []

    def describe_table(self, creds: Credentials, table: str) -> dict[str, Any]:
        """Get detailed table description.

        Args:
            creds: Connection credentials.
            table: Name of the table to describe (default schema).

        Returns:
            A dict with ``server_name``, ``database_name``, ``table_name``,
            ``columns``, ``primary_key``, ``foreign_keys``, ``indexes`` and
            ``row_count``; or ``{"error": ...}`` if the engine cannot be
            created or any exception escapes the helpers.
        """
        engine = self.conn_manager.get_engine_with_credentials(creds)
        if not engine:
            return {"error": "Could not create connection"}
        try:
            inspector = inspect(engine)
            return {
                "server_name": creds.server,
                "database_name": creds.database,
                "table_name": table,
                "columns": self._get_columns_info(inspector, table),
                "primary_key": self._get_primary_keys(inspector, table),
                "foreign_keys": self._get_foreign_keys(inspector, table),
                "indexes": self._get_indexes(inspector, table),
                "row_count": self._get_row_count(engine, table),
            }
        except Exception as e:
            return {"error": f"Error describing table: {e}"}

    def _get_columns_info(self, inspector, table: str) -> list[dict]:
        """Get column information.

        Returns:
            One dict per column with ``name``, ``type`` (stringified),
            ``nullable``, ``default`` and ``comment``; an empty list if
            reflection raises ``SQLAlchemyError`` (logged).
        """
        try:
            columns_result = inspector.get_columns(table_name=table, schema=None)
            return [{
                "name": col.get('name'),
                "type": str(col.get('type')),
                "nullable": col.get('nullable'),
                "default": col.get('default'),
                "comment": col.get('comment'),
            } for col in columns_result]
        except SQLAlchemyError as e:
            self._logger.warning("Error getting columns: %s", e)
            return []

    def _get_primary_keys(self, inspector, table: str) -> list[str]:
        """Get primary key columns.

        Returns:
            The ``constrained_columns`` of the primary-key constraint
            (``[]`` if absent); an empty list if reflection raises
            ``SQLAlchemyError`` (logged).
        """
        try:
            pk_constraint = inspector.get_pk_constraint(table_name=table, schema=None)
            return pk_constraint.get('constrained_columns', [])
        except SQLAlchemyError as e:
            self._logger.warning("Error getting primary keys: %s", e)
            return []

    def _get_foreign_keys(self, inspector, table: str) -> list[dict]:
        """Get foreign key information.

        Returns:
            One dict per foreign key with ``name``, ``constrained_columns``,
            ``referred_schema``, ``referred_table``, ``referred_columns`` and
            ``options``; an empty list if reflection raises
            ``SQLAlchemyError`` (logged).
        """
        try:
            fks = inspector.get_foreign_keys(table_name=table, schema=None)
            return [{
                "name": fk.get('name'),
                "constrained_columns": fk.get('constrained_columns'),
                "referred_schema": fk.get('referred_schema'),
                "referred_table": fk.get('referred_table'),
                "referred_columns": fk.get('referred_columns'),
                "options": fk.get('options'),
            } for fk in fks]
        except SQLAlchemyError as e:
            self._logger.warning("Error getting foreign keys: %s", e)
            return []

    def _get_indexes(self, inspector, table: str) -> list[dict]:
        """Get index information.

        Returns:
            One dict per index with ``name``, ``column_names`` and ``unique``
            (``False`` when not reported); an empty list if reflection
            raises ``SQLAlchemyError`` (logged).
        """
        try:
            indexes = inspector.get_indexes(table_name=table, schema=None)
            return [{
                "name": idx.get('name'),
                "column_names": idx.get('column_names'),
                "unique": idx.get('unique', False),
            } for idx in indexes]
        except SQLAlchemyError as e:
            self._logger.warning("Error getting indexes: %s", e)
            return []

    def _get_row_count(self, engine: Engine, table: str) -> int:
        """Get table row count.

        Reflects *table*, trying the ``dbo`` schema first on SQL Server and
        then the connection's default schema. It then runs
        ``SELECT count(*)`` against the reflected table.

        Returns:
            The row count, or ``-1`` if reflection or the query raises
            (best-effort; the failure is logged).
        """
        try:
            metadata = MetaData()
            schema = 'dbo' if engine.dialect.name == 'mssql' else None
            try:
                table_obj = Table(table, metadata, autoload_with=engine, schema=schema)
            except Exception:
                if not schema:
                    raise
                # Reflection under dbo failed; retry in the default schema.
                table_obj = Table(table, metadata, autoload_with=engine)
            with engine.connect() as connection:
                count_query = select(func.count()).select_from(table_obj)
                return connection.execute(count_query).scalar()
        except Exception as e:
            self._logger.warning("Error getting row count: %s", e)
            return -1