test_setup_spark_session.py
""" Tests for the setup_spark_session module. This is the most critical module for Spark session management. Tests cover: - Memory format conversion with overhead calculations - Executor and driver configuration generation - Delta Lake, Hive, and S3 configuration - Immutable configuration filtering for Spark Connect - Environment variable management for mode switching - Thread-safe session creation with locking - Full configuration generation workflow - Concurrent session creation safety """ import os import threading from unittest.mock import MagicMock, patch import pytest from pydantic import AnyUrl, AnyHttpUrl from src.delta_lake.setup_spark_session import ( convert_memory_format, _get_executor_conf, _get_spark_defaults_conf, _get_delta_conf, _get_hive_conf, _get_s3_conf, _filter_immutable_spark_connect_configs, _set_scheduler_pool, _clear_spark_env_for_mode_switch, generate_spark_conf, get_spark_session, SPARK_DEFAULT_POOL, SPARK_POOLS, EXECUTOR_MEMORY_OVERHEAD, DRIVER_MEMORY_OVERHEAD, IMMUTABLE_CONFIGS, _spark_session_lock, ) from src.settings import BERDLSettings # ============================================================================= # Fixtures # ============================================================================= @pytest.fixture def test_settings(): """Create test BERDLSettings.""" return BERDLSettings( KBASE_AUTH_TOKEN="test_token", USER="testuser", MINIO_ENDPOINT_URL="minio.test:9000", MINIO_ACCESS_KEY="test_access", MINIO_SECRET_KEY="test_secret", MINIO_SECURE=False, BERDL_REDIS_HOST="localhost", BERDL_REDIS_PORT=6379, SPARK_HOME="/usr/local/spark", SPARK_MASTER_URL=AnyUrl("spark://master:7077"), SPARK_CONNECT_URL=AnyUrl("sc://connect:15002"), BERDL_HIVE_METASTORE_URI=AnyUrl("thrift://hive:9083"), SPARK_WORKER_COUNT=2, SPARK_WORKER_CORES=4, SPARK_WORKER_MEMORY="4GiB", SPARK_MASTER_CORES=2, SPARK_MASTER_MEMORY="2GiB", GOVERNANCE_API_URL=AnyHttpUrl("http://governance:8000"), BERDL_POD_IP="10.0.0.1", ) # ============================================================================= # Test convert_memory_format # ============================================================================= class TestConvertMemoryFormat: """Tests for the convert_memory_format function.""" def test_gib_to_spark_format(self): """Test converting GiB to Spark format with overhead.""" # 4GiB with 10% overhead = 3.6GiB, rounds to 4g result = convert_memory_format("4GiB", overhead_percentage=0.1) assert result.endswith("g") # Result depends on rounding behavior assert result in ["3g", "4g"] def test_mib_to_spark_format(self): """Test converting MiB to Spark format.""" # 512MiB with 10% overhead = 460.8MiB result = convert_memory_format("512MiB", overhead_percentage=0.1) assert result.endswith("m") # 512 * 0.9 = 460.8 → rounds to 461 assert int(result[:-1]) in [460, 461] def test_kib_to_spark_format(self): """Test converting KiB to Spark format.""" result = convert_memory_format("1024KiB", overhead_percentage=0.0) assert result.endswith("m") assert result == "1m" # 1024KB = 1MB def test_gib_format_basic(self): """Test basic GiB format conversion.""" result = convert_memory_format("2GiB", overhead_percentage=0.0) assert result == "2g" def test_zero_overhead(self): """Test with zero overhead percentage.""" result = convert_memory_format("1GiB", overhead_percentage=0.0) assert result == "1g" def test_high_overhead(self): """Test with high overhead percentage.""" # 50% overhead result = convert_memory_format("4GiB", overhead_percentage=0.5) assert result == "2g" def 
test_various_unit_formats(self): """Test various memory unit formats.""" # Different case variations assert convert_memory_format("1GB", 0.0) == "1g" assert convert_memory_format("1gb", 0.0) == "1g" assert convert_memory_format("1Gb", 0.0) == "1g" assert convert_memory_format("1gib", 0.0) == "1g" def test_decimal_values(self): """Test decimal memory values.""" result = convert_memory_format("1.5GiB", overhead_percentage=0.0) # 1.5GB = 1536MB, may return as g or appropriate unit assert result in ["1g", "2g", "1536m"] def test_tb_format(self): """Test terabyte format.""" result = convert_memory_format("1TiB", overhead_percentage=0.0) assert result == "1024g" # 1TB = 1024GB def test_invalid_format_raises_error(self): """Test that invalid format raises ValueError.""" with pytest.raises(ValueError, match="Invalid memory format"): convert_memory_format("invalid") def test_invalid_unit_raises_error(self): """Test that invalid unit raises error or is handled.""" with pytest.raises(ValueError, match="Invalid memory format"): convert_memory_format("1XB") def test_executor_overhead_constant(self): """Test EXECUTOR_MEMORY_OVERHEAD constant value.""" assert EXECUTOR_MEMORY_OVERHEAD == 0.1 def test_driver_overhead_constant(self): """Test DRIVER_MEMORY_OVERHEAD constant value.""" assert DRIVER_MEMORY_OVERHEAD == 0.05 def test_memory_with_mib_unit(self): """Test memory with MiB unit.""" result = convert_memory_format("256MiB", overhead_percentage=0.0) assert result == "256m" def test_whitespace_in_format(self): """Test memory format with whitespace.""" result = convert_memory_format("4 GiB", overhead_percentage=0.0) assert result == "4g" # ============================================================================= # Test _get_executor_conf # ============================================================================= class TestGetExecutorConf: """Tests for the _get_executor_conf function.""" def test_spark_connect_mode_config(self, test_settings): """Test configuration for Spark Connect mode.""" config = _get_executor_conf(test_settings, use_spark_connect=True) assert "spark.remote" in config assert config["spark.remote"] == str(test_settings.SPARK_CONNECT_URL) assert "spark.driver.host" not in config # Not in Connect mode assert "spark.master" not in config # Not in Connect mode def test_legacy_mode_config(self, test_settings): """Test configuration for legacy mode.""" with patch("socket.gethostbyname", return_value="192.168.1.100"): with patch("socket.gethostname", return_value="testhost"): config = _get_executor_conf(test_settings, use_spark_connect=False) assert "spark.driver.host" in config assert config["spark.driver.host"] == "192.168.1.100" assert config["spark.driver.bindAddress"] == "0.0.0.0" assert "spark.master" in config assert "spark.remote" not in config def test_executor_configuration(self, test_settings): """Test executor-specific configuration.""" config = _get_executor_conf(test_settings, use_spark_connect=True) assert "spark.executor.instances" in config assert config["spark.executor.instances"] == "2" assert config["spark.executor.cores"] == "4" assert "spark.executor.memory" in config def test_driver_configuration(self, test_settings): """Test driver-specific configuration.""" config = _get_executor_conf(test_settings, use_spark_connect=True) assert "spark.driver.cores" in config assert config["spark.driver.cores"] == "2" assert "spark.driver.memory" in config def test_dynamic_allocation_disabled(self, test_settings): """Test that dynamic allocation is disabled.""" config = 
_get_executor_conf(test_settings, use_spark_connect=True) assert config["spark.dynamicAllocation.enabled"] == "false" assert config["spark.dynamicAllocation.shuffleTracking.enabled"] == "false" # ============================================================================= # Test _get_spark_defaults_conf # ============================================================================= class TestGetSparkDefaultsConf: """Tests for the _get_spark_defaults_conf function.""" def test_returns_dict(self): """Test that function returns a dictionary.""" config = _get_spark_defaults_conf() assert isinstance(config, dict) def test_decommission_settings(self): """Test decommission settings are present.""" config = _get_spark_defaults_conf() assert config["spark.decommission.enabled"] == "true" assert config["spark.storage.decommission.rddBlocks.enabled"] == "true" def test_broadcast_join_threshold(self): """Test broadcast join threshold setting.""" config = _get_spark_defaults_conf() assert "spark.sql.autoBroadcastJoinThreshold" in config # 50MB = 52428800 bytes assert config["spark.sql.autoBroadcastJoinThreshold"] == "52428800" def test_shuffle_configuration(self): """Test shuffle configuration settings.""" config = _get_spark_defaults_conf() assert config["spark.reducer.maxSizeInFlight"] == "96m" assert config["spark.shuffle.file.buffer"] == "1m" # ============================================================================= # Test _get_delta_conf # ============================================================================= class TestGetDeltaConf: """Tests for the _get_delta_conf function.""" def test_returns_dict(self): """Test that function returns a dictionary.""" config = _get_delta_conf() assert isinstance(config, dict) def test_delta_extensions(self): """Test Delta Lake SQL extensions are configured.""" config = _get_delta_conf() assert ( config["spark.sql.extensions"] == "io.delta.sql.DeltaSparkSessionExtension" ) assert ( config["spark.sql.catalog.spark_catalog"] == "org.apache.spark.sql.delta.catalog.DeltaCatalog" ) def test_delta_optimizations(self): """Test Delta Lake optimization settings.""" config = _get_delta_conf() assert config["spark.databricks.delta.optimizeWrite.enabled"] == "true" assert config["spark.databricks.delta.autoCompact.enabled"] == "true" assert ( config["spark.databricks.delta.retentionDurationCheck.enabled"] == "false" ) # ============================================================================= # Test _get_hive_conf # ============================================================================= class TestGetHiveConf: """Tests for the _get_hive_conf function.""" def test_hive_metastore_uri(self, test_settings): """Test Hive metastore URI is configured.""" config = _get_hive_conf(test_settings) assert config["hive.metastore.uris"] == str( test_settings.BERDL_HIVE_METASTORE_URI ) def test_catalog_implementation(self, test_settings): """Test catalog implementation is Hive.""" config = _get_hive_conf(test_settings) assert config["spark.sql.catalogImplementation"] == "hive" def test_hive_metastore_version(self, test_settings): """Test Hive metastore version.""" config = _get_hive_conf(test_settings) assert config["spark.sql.hive.metastore.version"] == "4.0.0" assert config["spark.sql.hive.metastore.jars"] == "path" assert "spark.sql.hive.metastore.jars.path" in config # ============================================================================= # Test _get_s3_conf # ============================================================================= class TestGetS3Conf: 
"""Tests for the _get_s3_conf function.""" def test_s3_endpoint(self, test_settings): """Test S3 endpoint configuration.""" config = _get_s3_conf(test_settings) assert ( config["spark.hadoop.fs.s3a.endpoint"] == test_settings.MINIO_ENDPOINT_URL ) assert ( config["spark.hadoop.fs.s3a.access.key"] == test_settings.MINIO_ACCESS_KEY ) assert ( config["spark.hadoop.fs.s3a.secret.key"] == test_settings.MINIO_SECRET_KEY ) def test_ssl_disabled(self, test_settings): """Test SSL disabled setting.""" config = _get_s3_conf(test_settings) assert config["spark.hadoop.fs.s3a.connection.ssl.enabled"] == "false" def test_ssl_enabled(self, test_settings): """Test SSL enabled setting.""" test_settings_ssl = BERDLSettings( USER="testuser", MINIO_ENDPOINT_URL="minio.test:9000", MINIO_ACCESS_KEY="key", MINIO_SECRET_KEY="secret", MINIO_SECURE=True, SPARK_CONNECT_URL=AnyUrl("sc://localhost:15002"), BERDL_HIVE_METASTORE_URI=AnyUrl("thrift://localhost:9083"), GOVERNANCE_API_URL=AnyHttpUrl("http://localhost:8000"), ) config = _get_s3_conf(test_settings_ssl) assert config["spark.hadoop.fs.s3a.connection.ssl.enabled"] == "true" def test_user_warehouse_path(self, test_settings): """Test user warehouse path without tenant.""" config = _get_s3_conf(test_settings, tenant_name=None) expected_warehouse = f"s3a://cdm-lake/users-sql-warehouse/{test_settings.USER}/" assert config["spark.sql.warehouse.dir"] == expected_warehouse def test_tenant_warehouse_path(self, test_settings): """Test tenant warehouse path.""" config = _get_s3_conf(test_settings, tenant_name="research_team") expected_warehouse = "s3a://cdm-lake/tenant-sql-warehouse/research_team/" assert config["spark.sql.warehouse.dir"] == expected_warehouse def test_event_log_directory(self, test_settings): """Test event log directory is configured.""" config = _get_s3_conf(test_settings) assert config["spark.eventLog.enabled"] == "true" assert f"{test_settings.USER}/" in config["spark.eventLog.dir"] def test_s3a_implementation(self, test_settings): """Test S3A implementation is configured.""" config = _get_s3_conf(test_settings) assert config["spark.hadoop.fs.s3a.path.style.access"] == "true" assert ( config["spark.hadoop.fs.s3a.impl"] == "org.apache.hadoop.fs.s3a.S3AFileSystem" ) # ============================================================================= # Test _filter_immutable_spark_connect_configs # ============================================================================= class TestFilterImmutableSparkConnectConfigs: """Tests for the _filter_immutable_spark_connect_configs function.""" def test_filters_immutable_configs(self): """Test that immutable configs are filtered out.""" config = { "spark.app.name": "test", "spark.driver.memory": "2g", # Immutable "spark.executor.memory": "4g", # Immutable "spark.sql.shuffle.partitions": "200", # Mutable } filtered = _filter_immutable_spark_connect_configs(config) assert "spark.app.name" in filtered assert "spark.sql.shuffle.partitions" in filtered assert "spark.driver.memory" not in filtered assert "spark.executor.memory" not in filtered def test_keeps_mutable_configs(self): """Test that mutable configs are kept.""" config = { "spark.app.name": "test", "spark.sql.shuffle.partitions": "100", "custom.config": "value", } filtered = _filter_immutable_spark_connect_configs(config) assert filtered == config def test_empty_config(self): """Test with empty config.""" filtered = _filter_immutable_spark_connect_configs({}) assert filtered == {} def test_all_immutable_configs(self): """Test with all immutable configs.""" config 
= {k: "value" for k in list(IMMUTABLE_CONFIGS)[:5]} filtered = _filter_immutable_spark_connect_configs(config) assert filtered == {} def test_immutable_configs_set(self): """Test that IMMUTABLE_CONFIGS contains expected keys.""" assert "spark.driver.memory" in IMMUTABLE_CONFIGS assert "spark.executor.memory" in IMMUTABLE_CONFIGS assert "spark.sql.warehouse.dir" in IMMUTABLE_CONFIGS assert "spark.eventLog.dir" in IMMUTABLE_CONFIGS # ============================================================================= # Test _set_scheduler_pool # ============================================================================= class TestSetSchedulerPool: """Tests for the _set_scheduler_pool function.""" def test_valid_pool(self): """Test setting a valid scheduler pool.""" mock_spark = MagicMock() mock_spark_context = MagicMock() mock_spark.sparkContext = mock_spark_context _set_scheduler_pool(mock_spark, "default") mock_spark_context.setLocalProperty.assert_called_once_with( "spark.scheduler.pool", "default" ) def test_high_priority_pool(self): """Test setting high priority pool.""" mock_spark = MagicMock() mock_spark_context = MagicMock() mock_spark.sparkContext = mock_spark_context _set_scheduler_pool(mock_spark, "highPriority") mock_spark_context.setLocalProperty.assert_called_once_with( "spark.scheduler.pool", "highPriority" ) def test_invalid_pool_defaults(self, capsys): """Test that invalid pool defaults to default pool.""" mock_spark = MagicMock() mock_spark_context = MagicMock() mock_spark.sparkContext = mock_spark_context _set_scheduler_pool(mock_spark, "invalidPool") # Should print warning captured = capsys.readouterr() assert "Warning" in captured.out assert "invalidPool" in captured.out # Should use default pool mock_spark_context.setLocalProperty.assert_called_once_with( "spark.scheduler.pool", SPARK_DEFAULT_POOL ) def test_spark_pools_constant(self): """Test SPARK_POOLS constant.""" assert "default" in SPARK_POOLS assert "highPriority" in SPARK_POOLS assert SPARK_DEFAULT_POOL == "default" # ============================================================================= # Test _clear_spark_env_for_mode_switch # ============================================================================= class TestClearSparkEnvForModeSwitch: """Tests for the _clear_spark_env_for_mode_switch function.""" def test_clear_for_spark_connect_mode(self): """Test clearing env vars for Spark Connect mode.""" with patch.dict(os.environ, {"MASTER": "spark://old:7077"}, clear=False): _clear_spark_env_for_mode_switch(use_spark_connect=True) # MASTER should be cleared for Connect mode assert "MASTER" not in os.environ def test_clear_for_legacy_mode(self): """Test clearing env vars for legacy mode.""" env_vars = { "SPARK_CONNECT_MODE_ENABLED": "1", "SPARK_REMOTE": "sc://old:15002", "SPARK_LOCAL_REMOTE": "1", } with patch.dict(os.environ, env_vars, clear=False): _clear_spark_env_for_mode_switch(use_spark_connect=False) # All Connect mode vars should be cleared assert "SPARK_CONNECT_MODE_ENABLED" not in os.environ assert "SPARK_REMOTE" not in os.environ assert "SPARK_LOCAL_REMOTE" not in os.environ def test_no_vars_to_clear(self): """Test when no env vars need to be cleared.""" # Should not raise even if vars don't exist _clear_spark_env_for_mode_switch(use_spark_connect=True) _clear_spark_env_for_mode_switch(use_spark_connect=False) # Verify no unexpected changes # (can't easily assert exact equality due to test isolation) def test_preserves_other_env_vars(self): """Test that other env vars are preserved.""" with patch.dict( 
os.environ, {"MY_VAR": "value", "MASTER": "spark://old:7077"}, clear=False ): _clear_spark_env_for_mode_switch(use_spark_connect=True) # MY_VAR should still exist assert os.environ.get("MY_VAR") == "value" # ============================================================================= # Test generate_spark_conf # ============================================================================= class TestGenerateSparkConf: """Tests for the generate_spark_conf function.""" def test_default_app_name_generated(self, test_settings): """Test that default app name is generated when not provided.""" config = generate_spark_conf( app_name=None, local=True, use_delta_lake=True, settings=test_settings ) assert "spark.app.name" in config assert "kbase_spark_session_" in config["spark.app.name"] def test_custom_app_name(self, test_settings): """Test custom app name is used.""" config = generate_spark_conf( app_name="MyCustomApp", local=True, settings=test_settings ) assert config["spark.app.name"] == "MyCustomApp" def test_local_mode_minimal_config(self, test_settings): """Test local mode has minimal configuration.""" config = generate_spark_conf( app_name="LocalApp", local=True, use_delta_lake=False, settings=test_settings, ) # Should only have app name assert "spark.app.name" in config # Should not have cluster configs assert "spark.executor.instances" not in config assert "spark.driver.host" not in config def test_local_mode_with_delta(self, test_settings): """Test local mode with Delta Lake enabled.""" config = generate_spark_conf( app_name="LocalApp", local=True, use_delta_lake=True, settings=test_settings ) assert "spark.sql.extensions" in config assert "DeltaSparkSessionExtension" in config["spark.sql.extensions"] def test_cluster_mode_full_config(self, test_settings): """Test cluster mode has full configuration.""" config = generate_spark_conf( app_name="ClusterApp", local=False, use_delta_lake=True, use_s3=True, use_hive=True, settings=test_settings, use_spark_connect=True, ) # Should have S3 config assert "spark.hadoop.fs.s3a.endpoint" in config # Should have app name assert config["spark.app.name"] == "ClusterApp" def test_spark_connect_mode_filters_immutable(self, test_settings): """Test Spark Connect mode filters immutable configs.""" config = generate_spark_conf( app_name="ConnectApp", local=False, settings=test_settings, use_spark_connect=True, ) # Immutable configs should be filtered assert "spark.driver.memory" not in config assert "spark.executor.memory" not in config def test_legacy_mode_keeps_all_configs(self, test_settings): """Test legacy mode keeps all configurations.""" with patch("socket.gethostbyname", return_value="192.168.1.1"): with patch("socket.gethostname", return_value="host"): config = generate_spark_conf( app_name="LegacyApp", local=False, settings=test_settings, use_spark_connect=False, ) # Should have driver/executor configs assert "spark.driver.memory" in config assert "spark.executor.memory" in config def test_disable_s3(self, test_settings): """Test disabling S3 configuration.""" config = generate_spark_conf( app_name="NoS3App", local=False, use_s3=False, settings=test_settings, use_spark_connect=True, ) assert "spark.hadoop.fs.s3a.endpoint" not in config def test_disable_hive(self, test_settings): """Test disabling Hive configuration.""" config = generate_spark_conf( app_name="NoHiveApp", local=False, use_hive=False, settings=test_settings, use_spark_connect=True, ) assert "hive.metastore.uris" not in config def test_with_tenant_name(self, test_settings): """Test 
configuration with tenant name.""" with patch("socket.gethostbyname", return_value="192.168.1.1"): with patch("socket.gethostname", return_value="host"): config = generate_spark_conf( app_name="TenantApp", local=False, tenant_name="my_tenant", settings=test_settings, use_spark_connect=False, ) # Warehouse should use tenant path assert "tenant-sql-warehouse/my_tenant" in config["spark.sql.warehouse.dir"] def test_settings_from_get_settings(self): """Test that settings are loaded from get_settings when not provided.""" with patch("src.delta_lake.setup_spark_session.get_settings") as mock_get: mock_settings = MagicMock() mock_settings.SPARK_CONNECT_URL = AnyUrl("sc://test:15002") mock_settings.SPARK_WORKER_MEMORY = "2GiB" mock_settings.SPARK_MASTER_MEMORY = "1GiB" mock_settings.SPARK_WORKER_COUNT = 1 mock_settings.SPARK_WORKER_CORES = 1 mock_settings.SPARK_MASTER_CORES = 1 mock_settings.MINIO_ENDPOINT_URL = "localhost:9000" mock_settings.MINIO_ACCESS_KEY = "key" mock_settings.MINIO_SECRET_KEY = "secret" mock_settings.MINIO_SECURE = False mock_settings.USER = "testuser" mock_settings.BERDL_HIVE_METASTORE_URI = AnyUrl("thrift://localhost:9083") mock_get.return_value = mock_settings # Call generate_spark_conf to trigger get_settings loading generate_spark_conf( app_name="TestApp", local=False, settings=None, # Force loading from get_settings use_spark_connect=True, ) mock_get.assert_called() # ============================================================================= # Test get_spark_session # ============================================================================= class TestGetSparkSession: """Tests for the get_spark_session function.""" def test_creates_spark_session(self, test_settings): """Test that get_spark_session creates a SparkSession.""" mock_session = MagicMock() mock_builder = MagicMock() mock_builder.config.return_value = mock_builder mock_builder.getOrCreate.return_value = mock_session with patch( "src.delta_lake.setup_spark_session.SparkSession" ) as mock_spark_class: mock_spark_class.builder = mock_builder with patch( "src.delta_lake.setup_spark_session._clear_spark_env_for_mode_switch" ): result = get_spark_session( app_name="TestSession", local=True, delta_lake=True, settings=test_settings, ) assert result == mock_session mock_builder.getOrCreate.assert_called_once() def test_clears_builder_options(self, test_settings): """Test that builder options are cleared.""" mock_session = MagicMock() mock_builder = MagicMock() mock_builder._options = {"old": "option"} mock_builder.config.return_value = mock_builder mock_builder.getOrCreate.return_value = mock_session with patch( "src.delta_lake.setup_spark_session.SparkSession" ) as mock_spark_class: mock_spark_class.builder = mock_builder with patch( "src.delta_lake.setup_spark_session._clear_spark_env_for_mode_switch" ): get_spark_session( app_name="TestSession", local=True, settings=test_settings, ) # Options should be cleared assert mock_builder._options == {} def test_override_config(self, test_settings): """Test override configuration.""" mock_session = MagicMock() mock_builder = MagicMock() mock_builder.config.return_value = mock_builder mock_builder.getOrCreate.return_value = mock_session with patch( "src.delta_lake.setup_spark_session.SparkSession" ) as mock_spark_class: mock_spark_class.builder = mock_builder with patch( "src.delta_lake.setup_spark_session._clear_spark_env_for_mode_switch" ): get_spark_session( app_name="TestSession", local=True, settings=test_settings, override={"custom.setting": "value"}, ) # Verify 
config was called (with override included) mock_builder.config.assert_called_once() def test_sets_scheduler_pool_in_legacy_mode(self, test_settings): """Test scheduler pool is set in legacy mode.""" mock_session = MagicMock() mock_builder = MagicMock() mock_builder.config.return_value = mock_builder mock_builder.getOrCreate.return_value = mock_session with patch( "src.delta_lake.setup_spark_session.SparkSession" ) as mock_spark_class: mock_spark_class.builder = mock_builder with patch( "src.delta_lake.setup_spark_session._clear_spark_env_for_mode_switch" ): with patch( "src.delta_lake.setup_spark_session._set_scheduler_pool" ) as mock_set_pool: with patch("socket.gethostbyname", return_value="192.168.1.1"): with patch("socket.gethostname", return_value="host"): get_spark_session( app_name="TestSession", local=False, scheduler_pool="highPriority", settings=test_settings, use_spark_connect=False, ) mock_set_pool.assert_called_once_with(mock_session, "highPriority") def test_no_scheduler_pool_in_connect_mode(self, test_settings): """Test scheduler pool is NOT set in Connect mode.""" mock_session = MagicMock() mock_builder = MagicMock() mock_builder.config.return_value = mock_builder mock_builder.getOrCreate.return_value = mock_session with patch( "src.delta_lake.setup_spark_session.SparkSession" ) as mock_spark_class: mock_spark_class.builder = mock_builder with patch( "src.delta_lake.setup_spark_session._clear_spark_env_for_mode_switch" ): with patch( "src.delta_lake.setup_spark_session._set_scheduler_pool" ) as mock_set_pool: get_spark_session( app_name="TestSession", local=False, settings=test_settings, use_spark_connect=True, ) mock_set_pool.assert_not_called() def test_uses_spark_conf_with_load_defaults_false(self, test_settings): """Test that SparkConf is created with loadDefaults=False.""" mock_session = MagicMock() mock_builder = MagicMock() mock_builder.config.return_value = mock_builder mock_builder.getOrCreate.return_value = mock_session with patch( "src.delta_lake.setup_spark_session.SparkSession" ) as mock_spark_class: mock_spark_class.builder = mock_builder with patch( "src.delta_lake.setup_spark_session.SparkConf" ) as mock_conf_class: mock_conf = MagicMock() mock_conf.setAll.return_value = mock_conf mock_conf_class.return_value = mock_conf with patch( "src.delta_lake.setup_spark_session._clear_spark_env_for_mode_switch" ): get_spark_session( app_name="TestSession", local=True, settings=test_settings, ) # Verify SparkConf was created with loadDefaults=False mock_conf_class.assert_called_once_with(loadDefaults=False) # ============================================================================= # Thread Safety Tests # ============================================================================= class TestThreadSafety: """Tests for thread safety of Spark session creation.""" def test_spark_session_lock_exists(self): """Test that the global lock exists.""" assert _spark_session_lock is not None assert isinstance(_spark_session_lock, type(threading.Lock())) def test_concurrent_session_creation_uses_lock(self, test_settings): """Test that concurrent session creation uses the lock.""" results = [] errors = [] mock_session = MagicMock() mock_builder = MagicMock() mock_builder.config.return_value = mock_builder mock_builder.getOrCreate.return_value = mock_session def create_session(thread_id): try: with patch( "src.delta_lake.setup_spark_session.SparkSession" ) as mock_spark_class: mock_spark_class.builder = mock_builder with patch( 
"src.delta_lake.setup_spark_session._clear_spark_env_for_mode_switch" ): # Simulate some work inside the lock result = get_spark_session( app_name=f"Thread{thread_id}Session", local=True, settings=test_settings, ) results.append((thread_id, result)) except Exception as e: errors.append((thread_id, e)) threads = [] for i in range(5): t = threading.Thread(target=create_session, args=(i,)) threads.append(t) for t in threads: t.start() for t in threads: t.join() # All threads should complete without errors assert len(errors) == 0 assert len(results) == 5 def test_lock_provides_thread_isolation(self, test_settings): """Test that lock provides thread isolation.""" # This test verifies the lock mechanism exists and works # We test it indirectly by verifying sessions can be created safely execution_count = {"value": 0} results = [] errors = [] lock = threading.Lock() def create_session_mock(thread_id): try: mock_session = MagicMock() mock_builder = MagicMock() mock_builder._options = {} mock_builder.config.return_value = mock_builder mock_builder.getOrCreate.return_value = mock_session with patch( "src.delta_lake.setup_spark_session.SparkSession" ) as mock_spark_class: mock_spark_class.builder = mock_builder with patch( "src.delta_lake.setup_spark_session._clear_spark_env_for_mode_switch" ): result = get_spark_session( app_name=f"Thread{thread_id}", local=True, settings=test_settings, ) with lock: execution_count["value"] += 1 results.append((thread_id, result)) except Exception as e: with lock: errors.append((thread_id, e)) threads = [] for i in range(3): t = threading.Thread( target=create_session_mock, args=(i,), name=f"Thread-{i}" ) threads.append(t) for t in threads: t.start() for t in threads: t.join() # All threads should complete without errors assert len(errors) == 0, f"Errors occurred: {errors}" assert execution_count["value"] == 3 # ============================================================================= # Integration-style Tests # ============================================================================= class TestConfigIntegration: """Integration tests for configuration generation.""" def test_full_config_generation_spark_connect(self, test_settings): """Test full configuration generation for Spark Connect mode.""" config = generate_spark_conf( app_name="IntegrationTest", local=False, use_delta_lake=True, use_s3=True, use_hive=True, settings=test_settings, use_spark_connect=True, ) # Verify essential configs are present assert "spark.app.name" in config assert config["spark.app.name"] == "IntegrationTest" # S3 configs should be present but filtered assert "spark.hadoop.fs.s3a.endpoint" in config # Immutable configs should be filtered assert "spark.driver.memory" not in config def test_full_config_generation_legacy_mode(self, test_settings): """Test full configuration generation for legacy mode.""" with patch("socket.gethostbyname", return_value="10.0.0.1"): with patch("socket.gethostname", return_value="testhost"): config = generate_spark_conf( app_name="LegacyIntegration", local=False, use_delta_lake=True, use_s3=True, use_hive=True, settings=test_settings, use_spark_connect=False, ) # All configs should be present assert "spark.driver.memory" in config assert "spark.executor.memory" in config assert "spark.driver.host" in config assert config["spark.driver.host"] == "10.0.0.1" def test_local_development_config(self, test_settings): """Test configuration for local development.""" config = generate_spark_conf( app_name="LocalDev", local=True, use_delta_lake=True, 
settings=test_settings, ) # Should have minimal config for local mode assert "spark.app.name" in config assert "spark.sql.extensions" in config # Delta enabled # Should not have cluster configs assert "spark.executor.instances" not in config assert "spark.hadoop.fs.s3a.endpoint" not in config
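For orientation, here is a minimal usage sketch of the helpers exercised by these tests, based only on the call signatures the tests mock (app_name, local, use_delta_lake/delta_lake, use_s3, use_hive, settings, use_spark_connect, override). The example values are illustrative, and passing settings=None to get_spark_session is an assumption by analogy with test_settings_from_get_settings, which only demonstrates it for generate_spark_conf.

# Usage sketch, inferred from the calls exercised in the tests above.
from src.delta_lake.setup_spark_session import generate_spark_conf, get_spark_session

# Inspect the configuration that would be applied (Spark Connect mode shown).
# With settings=None, generate_spark_conf loads BERDLSettings via get_settings().
conf = generate_spark_conf(
    app_name="example_app",
    local=False,
    use_delta_lake=True,
    use_s3=True,
    use_hive=True,
    settings=None,
    use_spark_connect=True,
)
print(conf["spark.app.name"])

# Create (or reuse) a session; extra key/value pairs can be layered on via override,
# as test_override_config does. settings=None here is an assumption (see note above).
spark = get_spark_session(
    app_name="example_app",
    local=False,
    delta_lake=True,
    settings=None,
    use_spark_connect=True,
    override={"spark.sql.shuffle.partitions": "200"},
)
spark.sql("SHOW DATABASES").show()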
