import unittest
from unittest.mock import MagicMock, patch
import json
import os
import sys
# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.client import SparkHistoryClient
from src.optimizer.engine import OptimizationEngine
from src.llm_client import LLMClient
from src.models import StageMetric, ExecutorMetric, SQLMetric
class TestSparkHistoryClient(unittest.TestCase):
    """Unit tests for the Spark History Server REST client."""

    def setUp(self):
        self.client = SparkHistoryClient(base_url="http://localhost:18080")

    @staticmethod
    def _fake_response(payload):
        """Build a MagicMock that mimics an HTTP 200 response carrying *payload* as JSON."""
        response = MagicMock()
        response.status = 200
        response.read.return_value = json.dumps(payload).encode()
        return response

    @patch('src.client.urllib.request.urlopen')
    def test_get_stages(self, mock_urlopen):
        """get_stages should deserialize the SHS stage list into StageMetric objects."""
        payload = [
            {
                "stageId": 0,
                "name": "map",
                "status": "COMPLETE",
                "numTasks": 10,
                "numActiveTasks": 0,
                "numCompleteTasks": 10,
                "numFailedTasks": 0,
                "executorRunTime": 5000,
                "executorCpuTime": 4000000000,
                "jvmGcTime": 100,
                "resultSerializationTime": 10,
                "inputBytes": 1024000,
                "outputBytes": 512000,
                "shuffleReadBytes": 0,
                "shuffleWriteBytes": 256000,
                "diskBytesSpilled": 0,
                "memoryBytesSpilled": 0,
                "peakExecutionMemory": 1024000
            }
        ]
        mock_urlopen.return_value.__enter__.return_value = self._fake_response(payload)

        stages = self.client.get_stages("app-123")

        self.assertEqual(len(stages), 1)
        self.assertEqual(stages[0].stageId, 0)
        self.assertEqual(stages[0].name, "map")

    @patch('src.client.urllib.request.urlopen')
    def test_get_stage_details_with_quantiles(self, mock_urlopen):
        """get_stage_details should surface the executor metrics distributions."""
        payload = [{
            "stageId": 0,
            "status": "COMPLETE",
            "executorMetricsDistributions": {
                "quantiles": [0.05, 0.25, 0.5, 0.75, 0.95],
                "executorRunTime": [100, 200, 300, 500, 1000]
            }
        }]
        mock_urlopen.return_value.__enter__.return_value = self._fake_response(payload)

        details = self.client.get_stage_details("app-123", 0)

        self.assertIn("executorMetricsDistributions", details)
        self.assertEqual(details["executorMetricsDistributions"]["quantiles"][2], 0.5)
class TestOptimizationEngine(unittest.TestCase):
    """Tests for the optimization engine and its LLM agent orchestration."""

    def setUp(self):
        self.mock_spark = MagicMock(spec=SparkHistoryClient)
        self.mock_llm = MagicMock(spec=LLMClient)
        self.engine = OptimizationEngine(self.mock_spark, self.mock_llm)

    def _stub_empty_sources(self):
        """Point every auxiliary SHS data source at an empty result."""
        self.mock_spark.get_jobs.return_value = []
        self.mock_spark.get_executors.return_value = []
        self.mock_spark.get_sql_metrics.return_value = []
        self.mock_spark.get_rdd_storage.return_value = []
        self.mock_spark.get_environment.return_value = {}

    @staticmethod
    def _stage(**overrides):
        """Build a COMPLETE 10-task StageMetric, applying any keyword overrides."""
        base = dict(
            stageId=0, name="map", status="COMPLETE",
            numTasks=10, numActiveTasks=0, numCompleteTasks=10, numFailedTasks=0,
            executorRunTime=10000, executorCpuTime=8000000000,
            jvmGcTime=100, resultSerializationTime=10,
            inputBytes=1024000, outputBytes=512000,
            shuffleReadBytes=0, shuffleWriteBytes=256000,
            diskBytesSpilled=0, memoryBytesSpilled=0,
            peakExecutionMemory=1024000,
        )
        base.update(overrides)
        return StageMetric(**base)

    def test_skew_detection_with_metrics(self):
        """Skew detection should use the task duration distribution metrics."""
        self._stub_empty_sources()
        self.mock_spark.get_stages.return_value = [self._stage()]
        # Quantile distribution with heavy right tail: classic duration skew.
        self.mock_spark.get_stage_details.return_value = {
            "stageId": 0,
            "executorMetricsDistributions": {
                "quantiles": [0.05, 0.25, 0.5, 0.75, 0.95],
                "executorRunTime": [100, 200, 500, 1000, 5000]  # High variance = skew
            }
        }
        # One canned LLM reply per analysis agent, in invocation order.
        self.mock_llm.generate_recommendation.side_effect = [
            '{"bottlenecks": [], "imbalance_detected": false, "overhead_analysis": "ok"}',
            '{"spill_issues": [], "partitioning_issues": "ok"}',
            '{"skewed_stages": [{"stageId": 0, "skewType": "duration", "details": "High variance", "skew_ratio": 10.0, "max_duration": 5000.0, "median_duration": 500.0}], "suggested_mitigations": ["Enable AQE"]}',
            '{"inefficient_joins": [], "missing_predicates": [], "aqe_opportunities": "none"}',
            '{"recommendations": []}',
            '{"code_issues": []}'
        ]

        report = self.engine.analyze_application("app-123")

        # The single skewed stage must come back with its quantitative metrics.
        self.assertEqual(len(report.skew_analysis), 1)
        self.assertTrue(report.skew_analysis[0].is_skewed)
        self.assertEqual(report.skew_analysis[0].skew_ratio, 10.0)
        self.assertEqual(report.skew_analysis[0].max_duration, 5000.0)
        self.assertEqual(report.skew_analysis[0].median_duration, 500.0)

    def test_spill_detection(self):
        """Spill should be detected directly from the stage-level byte counters."""
        self._stub_empty_sources()
        self.mock_spark.get_stages.return_value = [self._stage(
            name="sort",
            shuffleReadBytes=512000,
            diskBytesSpilled=104857600,   # 100MB spill
            memoryBytesSpilled=52428800,  # 50MB spill
        )]
        self.mock_spark.get_stage_details.return_value = {}
        self.mock_llm.generate_recommendation.side_effect = [
            '{"bottlenecks": [], "imbalance_detected": false, "overhead_analysis": "ok"}',
            '{"spill_issues": [{"stageId": 0, "memorySpill": "50MB", "diskSpill": "100MB", "recommendation": "Increase executor memory"}], "partitioning_issues": "ok"}',
            '{"skewed_stages": [], "suggested_mitigations": []}',
            '{"inefficient_joins": [], "missing_predicates": [], "aqe_opportunities": "none"}',
            '{"recommendations": [{"config": "spark.executor.memory", "current": "1g", "suggested": "4g", "reason": "Reduce spill"}]}',
            '{"code_issues": []}'
        ]

        report = self.engine.analyze_application("app-123")

        self.assertEqual(len(report.spill_analysis), 1)
        self.assertTrue(report.spill_analysis[0].has_spill)
        # Spill byte counts come from the stage metrics, not from the LLM text.
        self.assertGreater(report.spill_analysis[0].total_disk_spill, 0)
class TestEndToEndWithRealJobs(unittest.TestCase):
    """Integration tests that hit a live Spark History Server (opt-in via env)."""

    @unittest.skipUnless(os.getenv("RUN_INTEGRATION_TESTS"), "Integration tests disabled")
    def test_analyze_skew_job(self):
        """Analyze the golden skew job end-to-end and sanity-check the report."""
        shs_client = SparkHistoryClient(base_url="http://localhost:18080")
        llm_client = LLMClient(api_key=os.getenv("GEMINI_API_KEY"))
        app_id = "application_1768320005356_0008"

        report = OptimizationEngine(shs_client, llm_client).analyze_application(app_id)

        # Basic report shape.
        self.assertIsNotNone(report)
        self.assertEqual(report.app_id, "application_1768320005356_0008")
        # The golden job is skewed, so both findings and advice must be present.
        self.assertGreater(len(report.skew_analysis), 0)
        self.assertGreater(len(report.recommendations), 0)
        print(f"\n✅ Skew Job Analysis:")
        print(f"   - Skewed stages: {len(report.skew_analysis)}")
        print(f"   - Recommendations: {len(report.recommendations)}")
if __name__ == '__main__':
    # SECURITY: never hard-code API keys in source control (the previous
    # revision embedded a real Gemini key here, which both leaks the secret
    # and silently clobbers any key the user already exported). Credentials
    # must come from the environment:
    #   export GEMINI_API_KEY=<your-key> RUN_INTEGRATION_TESTS=1
    if os.getenv("RUN_INTEGRATION_TESTS") and not os.getenv("GEMINI_API_KEY"):
        print(
            "WARNING: RUN_INTEGRATION_TESTS is set but GEMINI_API_KEY is not; "
            "integration tests will fail to authenticate.",
            file=sys.stderr,
        )
    # Run the full test suite with per-test output.
    unittest.main(verbosity=2)