Skip to main content
Glama
setup_spark_session.py19.2 kB
""" Spark session setup for the Datalake MCP Server. This is copied from berdl_notebook_utils.setup_spark_session.py and adapted for the MCP server context (multi-user shared service, no environment variables). MAINTENANCE NOTE: This file is copied from: /spark_notebook/notebook_utils/berdl_notebook_utils/setup_spark_session.py When updating, copy the file and adapt the imports and warehouse configuration. """ import logging import os import re import socket import threading import warnings from datetime import datetime from typing import Any from pyspark.conf import SparkConf from pyspark.sql import SparkSession from src.settings import BERDLSettings, get_settings # Configure logging logger = logging.getLogger(__name__) # ============================================================================= # THREAD SAFETY # ============================================================================= # Global lock to prevent race conditions during Spark session creation. # PySpark's SparkSession.builder.getOrCreate() is NOT thread-safe and manipulates # global state (environment variables, builder._options). Without this lock, # concurrent requests can cause undefined behavior and deadlocks. _spark_session_lock = threading.Lock() # Suppress Protobuf version warnings from PySpark Spark Connect warnings.filterwarnings( "ignore", category=UserWarning, module="google.protobuf.runtime_version" ) # Suppress CANNOT_MODIFY_CONFIG warnings for Hive metastore settings in Spark Connect warnings.filterwarnings( "ignore", category=UserWarning, module="pyspark.sql.connect.conf" ) # ============================================================================= # CONSTANTS # ============================================================================= # Fair scheduler configuration SPARK_DEFAULT_POOL = "default" SPARK_POOLS = [SPARK_DEFAULT_POOL, "highPriority"] # Memory overhead percentages for Spark components EXECUTOR_MEMORY_OVERHEAD = ( 0.1 # 10% overhead for executors (accounts for JVM + system overhead) ) DRIVER_MEMORY_OVERHEAD = 0.05 # 5% overhead for driver (typically less memory pressure) # ============================================================================= # PRIVATE HELPER FUNCTIONS # ============================================================================= def convert_memory_format(memory_str: str, overhead_percentage: float = 0.1) -> str: """ Convert memory format from profile format to Spark format with overhead adjustment. Args: memory_str: Memory string in profile format (supports B, KiB, MiB, GiB, TiB) overhead_percentage: Percentage of memory to reserve for system overhead (default: 0.1 = 10%) Returns: Memory string in Spark format with overhead accounted for """ # Extract number and unit from memory string match = re.match(r"^(\d+(?:\.\d+)?)\s*([kmgtKMGT]i?[bB]?)$", memory_str) if not match: raise ValueError(f"Invalid memory format: {memory_str}") value, unit = match.groups() value = float(value) # Convert to bytes for calculation unit_lower = unit.lower() multipliers = { "b": 1, "kb": 1024, "kib": 1024, "mb": 1024**2, "mib": 1024**2, "gb": 1024**3, "gib": 1024**3, "tb": 1024**4, "tib": 1024**4, } # Remove trailing 'b' if present for lookup unit_key = ( unit_lower.rstrip("b") + "b" if unit_lower.endswith("b") else unit_lower + "b" ) if unit_key not in multipliers: unit_key = unit_lower bytes_value = value * multipliers.get(unit_key, multipliers["b"]) # Apply overhead reduction (reserve percentage for system) adjusted_bytes = bytes_value * (1 - overhead_percentage) # Convert back to appropriate Spark unit (prefer GiB for larger values) if adjusted_bytes >= 1024**3: adjusted_value = adjusted_bytes / (1024**3) spark_unit = "g" elif adjusted_bytes >= 1024**2: adjusted_value = adjusted_bytes / (1024**2) spark_unit = "m" elif adjusted_bytes >= 1024: adjusted_value = adjusted_bytes / 1024 spark_unit = "k" else: adjusted_value = adjusted_bytes spark_unit = "" # Format as integer to ensure Spark compatibility # Some Spark versions don't accept fractional memory values return f"{int(round(adjusted_value))}{spark_unit}" def _get_executor_conf( settings: BERDLSettings, use_spark_connect: bool ) -> dict[str, str]: """ Get Spark executor and driver configuration based on profile settings. Args: settings: BERDLSettings instance with profile-specific configuration use_spark_connect: bool indicating whether or not spark connect is to be used Returns: Dictionary of Spark executor and driver configuration """ # Convert memory formats from profile to Spark format with overhead adjustment executor_memory = convert_memory_format( settings.SPARK_WORKER_MEMORY, EXECUTOR_MEMORY_OVERHEAD ) driver_memory = convert_memory_format( settings.SPARK_MASTER_MEMORY, DRIVER_MEMORY_OVERHEAD ) if use_spark_connect: conf_base = {"spark.remote": str(settings.SPARK_CONNECT_URL)} else: driver_host = socket.gethostbyname(socket.gethostname()) conf_base = { "spark.driver.host": driver_host, "spark.driver.bindAddress": "0.0.0.0", # Bind to all interfaces "spark.master": str(settings.SPARK_MASTER_URL), } logger.info(f"Legacy mode: driver.host={driver_host}, bindAddress=0.0.0.0") return { **conf_base, # Driver configuration (critical for remote cluster connections) "spark.driver.memory": driver_memory, "spark.driver.cores": str(settings.SPARK_MASTER_CORES), # Executor configuration "spark.executor.instances": str(settings.SPARK_WORKER_COUNT), "spark.executor.cores": str(settings.SPARK_WORKER_CORES), "spark.executor.memory": executor_memory, # Disable dynamic allocation since we're setting explicit instances "spark.dynamicAllocation.enabled": "false", "spark.dynamicAllocation.shuffleTracking.enabled": "false", } def _get_spark_defaults_conf() -> dict[str, str]: """ Get Spark defaults configuration. """ return { # Decommissioning "spark.decommission.enabled": "true", "spark.storage.decommission.rddBlocks.enabled": "true", # Broadcast join configurations "spark.sql.autoBroadcastJoinThreshold": "52428800", # 50MB (default is 10MB) # Shuffle and compression configurations "spark.reducer.maxSizeInFlight": "96m", # 96MB (default is 48MB) "spark.shuffle.file.buffer": "1m", # 1MB (default is 32KB) } def _get_delta_conf() -> dict[str, str]: return { "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension", "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.delta.catalog.DeltaCatalog", "spark.databricks.delta.retentionDurationCheck.enabled": "false", # Delta Lake optimizations "spark.databricks.delta.optimizeWrite.enabled": "true", "spark.databricks.delta.autoCompact.enabled": "true", } def _get_hive_conf(settings: BERDLSettings) -> dict[str, str]: return { "hive.metastore.uris": str(settings.BERDL_HIVE_METASTORE_URI), "spark.sql.catalogImplementation": "hive", "spark.sql.hive.metastore.version": "4.0.0", "spark.sql.hive.metastore.jars": "path", "spark.sql.hive.metastore.jars.path": "/usr/local/spark/jars/*", } def _get_s3_conf( settings: BERDLSettings, tenant_name: str | None = None ) -> dict[str, str]: """ Get S3 configuration for MinIO. Args: settings: BERDLSettings instance with configuration tenant_name: Tenant/group name to use for SQL warehouse. If provided, configures Spark to write tables to the tenant's SQL warehouse. If None, uses the user's personal SQL warehouse. Returns: Dictionary of S3/MinIO Spark configuration properties """ # Construct warehouse path directly (MCP server doesn't have access to governance API) if tenant_name: # Tenant warehouse: s3a://cdm-lake/tenant-sql-warehouse/{tenant_name}/ warehouse_dir = f"s3a://cdm-lake/tenant-sql-warehouse/{tenant_name}/" else: # User warehouse: s3a://cdm-lake/users-sql-warehouse/{username}/ warehouse_dir = f"s3a://cdm-lake/users-sql-warehouse/{settings.USER}/" event_log_dir = f"s3a://cdm-spark-job-logs/spark-job-logs/{settings.USER}/" return { "spark.hadoop.fs.s3a.endpoint": settings.MINIO_ENDPOINT_URL, "spark.hadoop.fs.s3a.access.key": settings.MINIO_ACCESS_KEY, "spark.hadoop.fs.s3a.secret.key": settings.MINIO_SECRET_KEY, "spark.hadoop.fs.s3a.connection.ssl.enabled": str( settings.MINIO_SECURE ).lower(), "spark.hadoop.fs.s3a.path.style.access": "true", "spark.hadoop.fs.s3a.impl": "org.apache.hadoop.fs.s3a.S3AFileSystem", "spark.sql.warehouse.dir": warehouse_dir, "spark.eventLog.enabled": "true", "spark.eventLog.dir": event_log_dir, } IMMUTABLE_CONFIGS = { # Cluster-level settings (must be set at master startup) "spark.decommission.enabled", "spark.storage.decommission.rddBlocks.enabled", "spark.reducer.maxSizeInFlight", "spark.shuffle.file.buffer", # Driver and executor resource configs (locked at server startup) "spark.driver.memory", "spark.driver.cores", "spark.executor.instances", "spark.executor.cores", "spark.executor.memory", "spark.dynamicAllocation.enabled", "spark.dynamicAllocation.shuffleTracking.enabled", # Event logging (locked at server startup) "spark.eventLog.enabled", "spark.eventLog.dir", # SQL extensions (must be loaded at startup) "spark.sql.extensions", "spark.sql.catalog.spark_catalog", # Hive catalog (locked at startup) "spark.sql.catalogImplementation", # Warehouse directory (locked at server startup) "spark.sql.warehouse.dir", } def _filter_immutable_spark_connect_configs(config: dict[str, str]) -> dict[str, str]: """ Filter out configurations that cannot be modified in Spark Connect mode. These configs must be set server-side when the Spark Connect server starts. Attempting to set them from the client results in CANNOT_MODIFY_CONFIG warnings. Args: config: Dictionary of Spark configurations Returns: Filtered configuration dictionary with only mutable configs """ return {k: v for k, v in config.items() if k not in IMMUTABLE_CONFIGS} def _set_scheduler_pool(spark: SparkSession, scheduler_pool: str) -> None: """Set the scheduler pool for the Spark session.""" if scheduler_pool not in SPARK_POOLS: print( f"Warning: Scheduler pool '{scheduler_pool}' not in available pools: {SPARK_POOLS}. " f"Defaulting to '{SPARK_DEFAULT_POOL}'" ) scheduler_pool = SPARK_DEFAULT_POOL spark.sparkContext.setLocalProperty("spark.scheduler.pool", scheduler_pool) def _clear_spark_env_for_mode_switch(use_spark_connect: bool) -> None: """ Clear PySpark environment variables to allow clean mode switching. PySpark 3.4+ uses several environment variables to determine whether to use Spark Connect or classic mode. These persist across sessions and can cause conflicts when switching modes within the same process. Environment variables managed: - SPARK_CONNECT_MODE_ENABLED: Set to "1" when Connect mode is used - SPARK_REMOTE: Spark Connect URL - SPARK_LOCAL_REMOTE: Set when using local Connect server - MASTER: Spark master URL (classic mode) - SPARK_API_MODE: Can be "classic" or "connect" Args: use_spark_connect: If True, clears legacy mode vars; if False, clears Connect vars """ if use_spark_connect: # Switching TO Spark Connect: clear legacy mode variables env_vars_to_clear = ["MASTER"] logger.debug( f"Clearing legacy mode env vars for Spark Connect: {env_vars_to_clear}" ) else: # Switching TO legacy mode: clear ALL Spark Connect related variables env_vars_to_clear = [ "SPARK_CONNECT_MODE_ENABLED", # Critical: forces Connect mode if present "SPARK_REMOTE", # Connect URL "SPARK_LOCAL_REMOTE", # Local Connect server flag ] logger.debug( f"Clearing Connect mode env vars for legacy mode: {env_vars_to_clear}" ) for var in env_vars_to_clear: if var in os.environ: logger.info(f"Clearing environment variable: {var}={os.environ[var]}") del os.environ[var] def generate_spark_conf( app_name: str | None = None, local: bool = False, use_delta_lake: bool = True, use_s3: bool = True, use_hive: bool = True, settings: BERDLSettings | None = None, tenant_name: str | None = None, use_spark_connect: bool = True, ) -> dict[str, str]: """Generate a spark session configuration dictionary from a set of input variables.""" # Generate app name if not provided if app_name is None: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") app_name = f"kbase_spark_session_{timestamp}" # Build common configuration dictionary config: dict[str, str] = {"spark.app.name": app_name} if use_delta_lake: config.update(_get_delta_conf()) if not local: # Add default Spark configurations config.update(_get_spark_defaults_conf()) if settings is None: get_settings.cache_clear() settings = get_settings() # Add profile-specific executor and driver configuration config.update(_get_executor_conf(settings, use_spark_connect)) if use_s3: config.update(_get_s3_conf(settings, tenant_name)) if use_hive: config.update(_get_hive_conf(settings)) if use_spark_connect: # Spark Connect: filter out immutable configs that cannot be modified from the client config = _filter_immutable_spark_connect_configs(config) return config # ============================================================================= # PUBLIC FUNCTIONS # ============================================================================= def get_spark_session( app_name: str | None = None, local: bool = False, # TODO: switch to `use_delta_lake` for consistency with s3 / hive delta_lake: bool = True, scheduler_pool: str = SPARK_DEFAULT_POOL, use_s3: bool = True, use_hive: bool = True, settings: BERDLSettings | None = None, tenant_name: str | None = None, use_spark_connect: bool = True, override: dict[str, Any] | None = None, ) -> SparkSession: """ Create and configure a Spark session with BERDL-specific settings. This function creates a Spark session configured for the BERDL environment, including support for Delta Lake, MinIO S3 storage, and tenant-aware warehouses. Args: app_name: Application name. If None, generates a timestamp-based name local: If True, creates a local Spark session; the only other allowable option is `delta_lake` delta_lake: If True, enables Delta Lake support with required JARs scheduler_pool: Fair scheduler pool name (default: "default") use_s3: if True, enables reading from and writing to s3 use_hive: If True, enables Hive metastore integration settings: BERDLSettings instance. If None, creates new instance from env vars tenant_name: Tenant/group name to use for SQL warehouse location. If specified, tables will be written to the tenant's SQL warehouse instead of the user's personal warehouse. use_spark_connect: If True, uses Spark Connect instead of legacy mode override: dictionary of tag-value pairs to replace the values in the generated spark conf (e.g. for testing) Returns: Configured SparkSession instance Raises: EnvironmentError: If required environment variables are missing ValueError: If user is not a member of the specified tenant Example: >>> # Basic usage (user's personal warehouse) >>> spark = get_spark_session("MyApp") >>> # Using tenant warehouse (writes to tenant's SQL directory) >>> spark = get_spark_session("MyApp", tenant_name="research_team") >>> # With custom scheduler pool >>> spark = get_spark_session("MyApp", scheduler_pool="highPriority") >>> # Local development >>> spark = get_spark_session("TestApp", local=True) """ config = generate_spark_conf( app_name, local, delta_lake, use_s3, use_hive, settings, tenant_name, use_spark_connect, ) if override: config.update(override) # ========================================================================== # CRITICAL: Thread-safe session creation # ========================================================================== # PySpark's SparkSession.builder is NOT thread-safe. The following operations # must be performed atomically to prevent race conditions: # # 1. Clearing environment variables (os.environ modifications) # 2. Clearing builder._options # 3. Creating SparkConf and calling getOrCreate() # # Without this lock, concurrent requests can cause: # - Environment variable corruption between threads # - Builder options being modified mid-creation # - Undefined behavior leading to service hangs # ========================================================================== with _spark_session_lock: logger.debug("Acquired Spark session creation lock") # Clean environment before creating session # PySpark 3.4+ uses environment variables to determine mode _clear_spark_env_for_mode_switch(use_spark_connect) # Clear builder's cached options to prevent conflicts builder = SparkSession.builder if hasattr(builder, "_options"): builder._options.clear() # Use loadDefaults=False to prevent SparkConf from inheriting configuration # from any existing JVM (e.g., spark.master from a previous session). spark_conf = SparkConf(loadDefaults=False).setAll(list(config.items())) # Use the same builder instance that we cleared spark = builder.config(conf=spark_conf).getOrCreate() logger.debug("Spark session created, releasing lock") # Post-creation configuration (only for legacy mode with SparkContext) # This can be done outside the lock as it operates on the session instance if not local and not use_spark_connect: _set_scheduler_pool(spark, scheduler_pool) return spark

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/BERDataLakehouse/datalake-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server