Skip to main content
Glama
data_store.py16.4 kB
"""Module for interacting with Spark databases and tables. This module provides functions to retrieve information about databases, tables, and their schemas from a Spark cluster or directly from Hive metastore in PostgreSQL. Uses berdl_notebook_utils for shared functionality with the notebook environment. """ import atexit import json import logging from functools import lru_cache from typing import Any, Dict, List, Optional, Union import httpx from pyspark.sql import SparkSession # Use local MCP copies that don't rely on environment variables from src.delta_lake import hive_metastore # Use shared utilities from berdl_notebook_utils for consistency with notebooks from berdl_notebook_utils.spark import data_store as notebook_data_store from src.settings import BERDLSettings, get_settings logger = logging.getLogger(__name__) # ============================================================================= # HTTP CLIENT MANAGEMENT # ============================================================================= # Shared httpx client with connection pooling for governance API calls. # Creating a new client per request is inefficient and can exhaust connections. @lru_cache(maxsize=1) def _get_http_client() -> httpx.Client: """ Get a shared httpx Client with connection pooling. The client is cached and reused across all requests for efficiency. Uses reasonable defaults for timeouts and connection limits. Returns: Shared httpx.Client instance """ logger.info("Initializing shared httpx client for governance API") client = httpx.Client( timeout=httpx.Timeout( connect=10.0, # Connection timeout read=30.0, # Read timeout write=10.0, # Write timeout pool=5.0, # Pool timeout ), limits=httpx.Limits( max_connections=100, # Total max connections in pool max_keepalive_connections=20, # Max idle keepalive connections keepalive_expiry=30.0, # Keepalive connection expiry in seconds ), ) return client def _close_http_client() -> None: """Close the shared HTTP client on shutdown.""" if _get_http_client.cache_info().currsize > 0: try: client = _get_http_client() client.close() logger.info("Shared httpx client closed") except Exception as e: logger.warning(f"Error closing httpx client: {e}") # Register cleanup on process exit atexit.register(_close_http_client) # Re-export get_table_schema from berdl_notebook_utils (works identically in both contexts) get_table_schema = notebook_data_store.get_table_schema # get_tables is customized below to support use_hms parameter with settings def _execute_with_spark( func: Any, spark: Optional[SparkSession] = None, *args, **kwargs ) -> Any: """ Execute a function with a SparkSession. In the MCP server context, spark must be provided via FastAPI dependency injection. """ if spark is None: raise ValueError( "SparkSession must be provided. In MCP server context, use FastAPI dependency injection." ) return func(spark, *args, **kwargs) def _format_output(data: Any, return_json: bool = True) -> Union[str, Any]: """ Format the output based on the return_json flag. """ return json.dumps(data) if return_json else data def _extract_databases_from_paths(paths: List[str]) -> List[str]: """ Extract unique database names from S3 SQL warehouse paths. Only considers paths in SQL warehouses (not general warehouses or logs): - s3a://cdm-lake/users-sql-warehouse/... - s3a://cdm-lake/tenant-sql-warehouse/... S3 paths are in format: s3a://bucket/warehouse/user_or_tenant/database.db/... Example: s3a://cdm-lake/users-sql-warehouse/tgu1/u_tgu1__sharing_test.db/employee_records_1/ Args: paths: List of S3 paths from accessible paths API Returns: List of unique database names (without .db suffix) """ databases = set() for path in paths: # Only process paths from SQL warehouses if not any( warehouse in path for warehouse in ["/users-sql-warehouse/", "/tenant-sql-warehouse/"] ): continue # Remove s3a:// prefix and split by / parts = path.replace("s3a://", "").split("/") # Look for .db directory (database directory in Hive convention) for part in parts: if part.endswith(".db"): db_name = part[:-3] # Remove .db suffix databases.add(db_name) break return sorted(list(databases)) def _get_user_namespace_prefixes(auth_token: str) -> List[str]: """ Get all namespace prefixes for the authenticated user (user + all groups). Args: auth_token: KBase authentication token Returns: List of namespace prefixes (user prefix + all group/tenant prefixes) """ settings = get_settings() governance_url = str(settings.GOVERNANCE_API_URL).rstrip("/") prefixes = [] headers = {"Authorization": f"Bearer {auth_token}"} client = _get_http_client() try: # Get user's namespace prefix response = client.get( f"{governance_url}/workspaces/me/namespace-prefix", headers=headers ) response.raise_for_status() data = response.json() user_prefix = data.get("user_namespace_prefix") if user_prefix: prefixes.append(user_prefix) logger.debug(f"User namespace prefix: {user_prefix}") # Get user's groups response = client.get(f"{governance_url}/workspaces/me/groups", headers=headers) response.raise_for_status() groups_data = response.json() groups = groups_data.get("groups", []) logger.debug(f"User groups: {groups}") # Get namespace prefix for each group for group_name in groups: try: response = client.get( f"{governance_url}/workspaces/me/namespace-prefix", params={"tenant": group_name}, headers=headers, ) response.raise_for_status() data = response.json() tenant_prefix = data.get("tenant_namespace_prefix") if tenant_prefix: prefixes.append(tenant_prefix) logger.debug( f"Tenant '{group_name}' namespace prefix: {tenant_prefix}" ) except Exception as e: logger.warning( f"Could not get namespace prefix for group {group_name}: {e}" ) # Continue with other groups return prefixes except Exception as e: logger.error(f"Error fetching namespace prefixes from governance API: {e}") raise Exception(f"Could not filter databases by namespace: {e}") from e def _get_accessible_paths(auth_token: str) -> List[str]: """ Get all S3 paths accessible to the user from the governance API. Args: auth_token: KBase authentication token Returns: List of accessible S3 paths """ settings = get_settings() governance_url = str(settings.GOVERNANCE_API_URL).rstrip("/") headers = {"Authorization": f"Bearer {auth_token}"} client = _get_http_client() try: response = client.get( f"{governance_url}/workspaces/me/accessible-paths", headers=headers ) response.raise_for_status() data = response.json() accessible_paths = data.get("accessible_paths", []) logger.debug(f"Retrieved {len(accessible_paths)} accessible paths") return accessible_paths except Exception as e: logger.error(f"Error fetching accessible paths from governance API: {e}") raise Exception(f"Could not get accessible paths: {e}") from e def _get_tables_with_schemas( db: str, tables: List[str], spark: SparkSession ) -> Dict[str, Any]: """ Get schemas for a list of tables in a database. """ return { table: get_table_schema( database=db, table=table, spark=spark, return_json=False ) for table in tables } def get_databases( spark: Optional[SparkSession] = None, use_hms: bool = True, return_json: bool = True, filter_by_namespace: bool = False, auth_token: Optional[str] = None, settings: Optional[BERDLSettings] = None, ) -> Union[str, List[str]]: """ Get the list of databases in the Hive metastore. Args: spark: Optional SparkSession to use (if use_hms is False) use_hms: Whether to use Hive Metastore client direct query (faster) or Spark return_json: Whether to return JSON string or raw data filter_by_namespace: Whether to filter databases by user/group ownership AND shared access. When True, returns: - User's owned databases (u_username_*) - Group/tenant databases (groupname_*) - Databases shared with the user (from accessible paths API) When False, returns all databases in the metastore. auth_token: KBase auth token (required if filter_by_namespace is True) settings: BERDLSettings instance (required if use_hms is True) Returns: List of database names, either as JSON string or raw list Raises: ValueError: If filter_by_namespace is True but auth_token is not provided ValueError: If use_hms is True but settings is not provided """ # Validate early: auth_token is required if filtering by namespace if filter_by_namespace and not auth_token: raise ValueError("auth_token is required when filter_by_namespace is True") def _get_dbs(session: SparkSession) -> List[str]: return [db.name for db in session.catalog.listDatabases()] if use_hms: if settings is None: settings = get_settings() databases = hive_metastore.get_databases(settings=settings) else: databases = _execute_with_spark(_get_dbs, spark) # Apply filtering: owned databases (fast) + shared databases (API call) if filter_by_namespace: try: # Step 1: Get owned/group databases using namespace prefixes (fast) prefixes = _get_user_namespace_prefixes(auth_token) if prefixes: # Filter databases by namespace prefixes (owned + group databases) owned_databases = [ db for db in databases if db.startswith(tuple(prefixes)) ] logger.info( f"Found {len(owned_databases)} owned/group databases matching {len(prefixes)} prefix(es)" ) else: logger.warning("No namespace prefixes found") owned_databases = [] # Step 2: Get shared databases from accessible paths API # These are databases shared with the user that don't match their namespace accessible_paths = _get_accessible_paths(auth_token) shared_databases = _extract_databases_from_paths(accessible_paths) logger.info( f"Found {len(shared_databases)} shared databases from accessible paths" ) # Combine owned and shared, remove duplicates all_accessible = set(owned_databases) | set(shared_databases) # Filter to only databases that exist in metastore databases = sorted([db for db in databases if db in all_accessible]) logger.info( f"Total accessible databases: {len(databases)} (owned: {len(owned_databases)}, shared: {len(shared_databases)})" ) except Exception as e: logger.error(f"Error filtering databases by namespace: {e}") raise return _format_output(databases, return_json) def get_tables( database: str, spark: Optional[SparkSession] = None, use_hms: bool = True, return_json: bool = True, settings: Optional[BERDLSettings] = None, ) -> Union[str, List[str]]: """ Get the list of tables in a database. Args: database: Name of the database spark: Optional SparkSession to use (if use_hms is False) use_hms: Whether to use Hive Metastore client direct query (faster) or Spark return_json: Whether to return JSON string or raw data settings: BERDLSettings instance (required if use_hms is True) Returns: List of table names, either as JSON string or raw list """ def _get_tbls(session: SparkSession) -> List[str]: return [t.name for t in session.catalog.listTables(database)] if use_hms: if settings is None: settings = get_settings() tables = hive_metastore.get_tables(database=database, settings=settings) else: tables = _execute_with_spark(_get_tbls, spark) return _format_output(tables, return_json) def get_db_structure( spark: Optional[SparkSession] = None, with_schema: bool = False, use_hms: bool = True, return_json: bool = True, settings: Optional[BERDLSettings] = None, ) -> Union[str, Dict]: """Get the structure of all databases in the Hive metastore. Args: spark: Optional SparkSession to use for operations with_schema: Whether to include table schemas use_hms: Whether to use Hive Metastore client for metadata retrieval return_json: Whether to return the result as a JSON string settings: BERDLSettings instance (required if use_hms is True) Returns: Database structure as either JSON string or dictionary: { "database_name": ["table1", "table2"] or "database_name": { "table1": ["column1", "column2"], "table2": ["column1", "column2"] } } """ def _get_structure( session: SparkSession, ) -> Dict[str, Union[List[str], Dict[str, List[str]]]]: db_structure = {} databases = get_databases(spark=session, return_json=False) for db in databases: tables = get_tables(database=db, spark=session, return_json=False) if with_schema and isinstance(tables, list): db_structure[db] = _get_tables_with_schemas(db, tables, session) else: db_structure[db] = tables return db_structure if use_hms: if settings is None: settings = get_settings() db_structure = {} databases = hive_metastore.get_databases(settings=settings) for db in databases: tables = hive_metastore.get_tables(database=db, settings=settings) if with_schema and isinstance(tables, list): if spark is None: raise ValueError( "SparkSession must be provided for schema retrieval. " "In MCP server context, use FastAPI dependency injection." ) db_structure[db] = _get_tables_with_schemas(db, tables, spark) else: db_structure[db] = tables else: db_structure = _execute_with_spark(_get_structure, spark) return _format_output(db_structure, return_json) def database_exists( database: str, spark: Optional[SparkSession] = None, use_hms: bool = True, settings: Optional[BERDLSettings] = None, ) -> bool: """ Check if a database exists in the Hive metastore. """ if settings is None: settings = get_settings() return database in get_databases( spark=spark, use_hms=use_hms, return_json=False, settings=settings ) def table_exists( database: str, table: str, spark: Optional[SparkSession] = None, use_hms: bool = True, settings: Optional[BERDLSettings] = None, ) -> bool: """ Check if a table exists in a database. """ if settings is None: settings = get_settings() return table in get_tables( database=database, spark=spark, use_hms=use_hms, return_json=False, settings=settings, )

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/BERDataLakehouse/datalake-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server