from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Optional

from .database import DatabaseConnector
from .models import TableInfo
from .schema.manager import SchemaManager
class MultiDatabaseContext:
    """Manages multiple database contexts for different Oracle databases"""

    def __init__(
        self,
        databases: Dict[str, Dict[str, Any]],
        cache_base_path: Path,
        read_only: bool = True
    ):
        """
        Initialize multi-database context.

        One DatabaseContext is created per entry in ``databases``. Each gets
        its own schema cache file ``<cache_base_path>/<name>_schema_cache.json``.

        Args:
            databases: Dict mapping database names to connection configs.
                "connection_string" is required. "target_schema", "use_thick_mode"
                (default False) and "lib_dir" are optional.
                Example: {
                    "prod": {
                        "connection_string": "user/pass@host:port/service",
                        "target_schema": "PROD_SCHEMA",
                        "use_thick_mode": False,
                        "lib_dir": None
                    },
                    "test": {...}
                }
            cache_base_path: Base directory for all database caches. A ``str``
                is also accepted and converted to ``Path``.
            read_only: Whether to enable read-only mode for all databases

        Raises:
            KeyError: If a database config lacks "connection_string".
        """
        # Normalise so the `/` operator below also works when a str is passed.
        cache_base_path = Path(cache_base_path)
        self.databases: Dict[str, 'DatabaseContext'] = {}
        self.cache_base_path = cache_base_path
        self.read_only = read_only
        # Create individual DatabaseContext for each database
        for db_name, config in databases.items():
            try:
                connection_string = config["connection_string"]
            except KeyError:
                # Same exception type as before, but say which database is misconfigured.
                raise KeyError(
                    f"Database '{db_name}' config is missing required key 'connection_string'"
                ) from None
            cache_path = cache_base_path / f"{db_name}_schema_cache.json"
            self.databases[db_name] = DatabaseContext(
                connection_string=connection_string,
                cache_path=cache_path,
                target_schema=config.get("target_schema"),
                use_thick_mode=config.get("use_thick_mode", False),
                lib_dir=config.get("lib_dir"),
                read_only=read_only
            )

    async def initialize(self):
        """Initialize all database contexts sequentially, in insertion order."""
        for db_name, ctx in self.databases.items():
            # NOTE(review): print() writes to stdout. If this process speaks a
            # protocol over stdout, consider logging/stderr instead — confirm.
            print(f"Initializing database context for: {db_name}")
            await ctx.initialize()

    async def close(self):
        """Close all database connections.

        Every context is closed even if an earlier one fails, so a single
        failing close() cannot leak the remaining connection pools.

        Raises:
            Exception: The first exception raised by any context's close(),
                re-raised after all contexts have been attempted.
        """
        first_error: Optional[Exception] = None
        for ctx in self.databases.values():
            try:
                await ctx.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    def get_database(self, db_name: str) -> 'DatabaseContext':
        """Get a specific database context.

        Raises:
            ValueError: If ``db_name`` is not configured. The message lists
                the available names.
        """
        if db_name not in self.databases:
            raise ValueError(f"Database '{db_name}' not found. Available: {list(self.databases.keys())}")
        return self.databases[db_name]

    def list_databases(self) -> List[str]:
        """List all available database names, in configuration order."""
        return list(self.databases.keys())

    async def get_all_database_info(self) -> Dict[str, Dict[str, Any]]:
        """Get vendor info for all databases.

        Best-effort: if one database fails, its entry becomes
        ``{"error": <message>}`` instead of aborting the whole call.
        """
        result = {}
        for db_name, ctx in self.databases.items():
            try:
                result[db_name] = await ctx.get_database_info()
            except Exception as e:
                result[db_name] = {"error": str(e)}
        return result

    async def list_tables(self, db_name: str) -> List[str]:
        """List all table names in the specified database.

        Raises:
            ValueError: If ``db_name`` is not configured.
        """
        ctx = self.get_database(db_name)
        return await ctx.list_tables()
class DatabaseContext:
    """Per-database facade combining a DatabaseConnector with a SchemaManager.

    Lookups for PL/SQL objects, constraints, indexes, user-defined types and
    related tables are read-through cached in the schema manager's
    ``object_cache``. On a miss they query the connector, store the result
    via ``update_cache`` and call ``save_cache()``. All other calls are
    forwarded directly to the connector or the schema manager.
    """

    def __init__(
        self,
        connection_string: str,
        cache_path: Path,
        target_schema: Optional[str] = None,
        use_thick_mode: bool = False,
        lib_dir: Optional[str] = None,
        read_only: bool = True,
    ):
        """Create the connector and schema manager and link them together.

        Args:
            connection_string: Connection string passed to DatabaseConnector.
            cache_path: Cache file path passed to SchemaManager.
            target_schema: Optional schema name passed to DatabaseConnector.
            use_thick_mode: Passed to DatabaseConnector.
            lib_dir: Passed to DatabaseConnector.
            read_only: Passed to DatabaseConnector.
        """
        self.db_connector = DatabaseConnector(connection_string, target_schema, use_thick_mode, lib_dir, read_only)
        self.schema_manager = SchemaManager(self.db_connector, cache_path)
        # Set the schema manager reference in the connector (back-link).
        self.db_connector.set_schema_manager(self.schema_manager)

    @staticmethod
    def _pattern_key(pattern: Optional[str]) -> str:
        """Return the cache-key fragment for an optional name pattern.

        ``None`` (no filter) maps to ``"all"``, matching the key used for
        unfiltered lookups. Explicit patterns are prefixed with ``pattern:``,
        so a literal pattern such as ``"all"`` or ``""`` cannot share a cache
        entry with the unfiltered lookup.
        """
        return "all" if pattern is None else f"pattern:{pattern}"

    async def _cached_fetch(
        self,
        category: str,
        cache_key: str,
        fetch: Callable[[], Awaitable[Any]],
    ) -> Any:
        """Return cached data for (category, cache_key), fetching on a miss.

        ``fetch`` is a zero-argument callable returning an awaitable. It is
        only invoked on a miss, so no coroutine is created for cache hits.
        Hit/miss counters in ``cache_stats`` are updated either way.
        """
        manager = self.schema_manager
        if manager.is_cache_valid(category, cache_key):
            manager.cache_stats['hits'] += 1
            return manager.object_cache[category][cache_key]['data']
        # If not in cache or expired, get from database and persist it.
        manager.cache_stats['misses'] += 1
        result = await fetch()
        manager.update_cache(category, cache_key, result)
        await manager.save_cache()
        return result

    async def initialize(self) -> None:
        """Initialize the connection pool, then the schema manager."""
        await self.db_connector.initialize_pool()
        await self.schema_manager.initialize()

    async def close(self) -> None:
        """Close the connector's connection pool."""
        await self.db_connector.close_pool()

    async def get_database_info(self):
        """Get information about the database vendor and version"""
        return await self.db_connector.get_database_info()

    # Quoted annotation: resolved lazily, so the class body does not need
    # TableInfo to be importable at definition time.
    async def get_schema_info(self, table_name: str) -> Optional['TableInfo']:
        """Get schema information for a specific table"""
        return await self.schema_manager.get_schema_info(table_name)

    async def search_tables(self, search_term: str, limit: int = 20) -> List[str]:
        """Search for table names matching the search term"""
        return await self.schema_manager.search_tables(search_term, limit)

    async def rebuild_cache(self, fetch_all_metadata: bool = False) -> None:
        """Force a rebuild of the schema cache.

        Args:
            fetch_all_metadata: If True, fetches complete metadata (constraints, indexes, stats, comments)
                for all tables during cache rebuild. This takes longer but provides comprehensive metadata.
        """
        self.schema_manager.cache = await self.schema_manager.load_or_build_cache(
            force_rebuild=True,
            fetch_all_metadata=fetch_all_metadata
        )

    async def search_columns(self, search_term: str, limit: int = 50) -> Dict[str, List[Dict[str, Any]]]:
        """Search for columns matching the given pattern across all tables"""
        return await self.schema_manager.search_columns(search_term, limit)

    async def get_pl_sql_objects(self, object_type: str, name_pattern: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get information about PL/SQL objects of the specified type (cached under 'plsql')."""
        cache_key = f"{object_type}_{self._pattern_key(name_pattern)}"
        return await self._cached_fetch(
            'plsql',
            cache_key,
            lambda: self.db_connector.get_pl_sql_objects(object_type, name_pattern),
        )

    async def get_object_source(self, object_type: str, object_name: str) -> str:
        """Get the source code for a PL/SQL object (not cached)."""
        return await self.db_connector.get_object_source(object_type, object_name)

    async def get_table_constraints(self, table_name: str) -> List[Dict[str, Any]]:
        """Get constraints for a specific table (cached under 'constraints')."""
        return await self._cached_fetch(
            'constraints',
            table_name,
            lambda: self.db_connector.get_table_constraints(table_name),
        )

    async def get_table_indexes(self, table_name: str) -> List[Dict[str, Any]]:
        """Get indexes for a specific table (cached under 'indexes')."""
        return await self._cached_fetch(
            'indexes',
            table_name,
            lambda: self.db_connector.get_table_indexes(table_name),
        )

    async def get_dependent_objects(self, object_name: str) -> List[Dict[str, Any]]:
        """Get objects that depend on the specified object (not cached)."""
        return await self.db_connector.get_dependent_objects(object_name)

    async def get_user_defined_types(self, type_pattern: Optional[str] = None) -> List[Dict[str, Any]]:
        """Get information about user-defined types (cached under 'types')."""
        return await self._cached_fetch(
            'types',
            self._pattern_key(type_pattern),
            lambda: self.db_connector.get_user_defined_types(type_pattern),
        )

    async def get_related_tables(self, table_name: str) -> Dict[str, List[str]]:
        """Get all tables that are related to the specified table through foreign keys (cached)."""
        return await self._cached_fetch(
            'related_tables',
            f"related_{table_name}",
            lambda: self.db_connector.get_related_tables(table_name),
        )

    async def run_sql_query(self, sql: str, params: Optional[Dict[str, Any]] = None, max_rows: int = 100) -> Dict[str, Any]:
        """Runs a SQL query via the connector and returns the results."""
        return await self.db_connector.execute_sql_query(sql, params, max_rows)

    async def explain_query_plan(self, query: str) -> Dict[str, Any]:
        """Get execution plan for an SQL query with optimization suggestions"""
        return await self.db_connector.explain_query_plan(query)

    async def list_tables(self) -> List[str]:
        """List all table names in the database, sorted ascending."""
        return sorted(await self.db_connector.get_all_table_names())