import unittest
from unittest.mock import MagicMock, patch
import json
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.client import SparkHistoryClient
from src.optimizer.engine import OptimizationEngine
from src.llm_client import LLMClient
from src.models import StageMetric
class TestErrorHandling(unittest.TestCase):
    """Test error handling and edge cases in the optimization engine.

    The Spark History client and the LLM client are both mocked; every test
    drives ``OptimizationEngine.analyze_application`` end-to-end and checks
    that it degrades gracefully instead of raising.
    """

    # One well-formed JSON reply per analysis agent, consumed in order via
    # ``side_effect``. Shared by tests that are not themselves exercising
    # LLM-response parsing. (Order matters: execution, spill, skew, SQL,
    # config recommendations, code issues.)
    _VALID_LLM_RESPONSES = [
        '{"bottlenecks": [], "imbalance_detected": false, "overhead_analysis": "ok"}',
        '{"spill_issues": [], "partitioning_issues": "ok"}',
        '{"skewed_stages": [], "suggested_mitigations": []}',
        '{"inefficient_joins": [], "missing_predicates": [], "aqe_opportunities": "none"}',
        '{"recommendations": []}',
        '{"code_issues": []}',
    ]

    def setUp(self):
        self.mock_spark = MagicMock(spec=SparkHistoryClient)
        self.mock_llm = MagicMock(spec=LLMClient)
        self.engine = OptimizationEngine(self.mock_spark, self.mock_llm)

    def _mock_empty_application(self, stages=None):
        """Configure the Spark client mock to report an app with no data.

        Args:
            stages: optional list of StageMetric to return from
                ``get_stages``; all other endpoints return empty results.
        """
        self.mock_spark.get_jobs.return_value = []
        self.mock_spark.get_stages.return_value = [] if stages is None else stages
        self.mock_spark.get_executors.return_value = []
        self.mock_spark.get_sql_metrics.return_value = []
        self.mock_spark.get_rdd_storage.return_value = []
        self.mock_spark.get_environment.return_value = {}

    def test_empty_stages(self):
        """An application with no stages yields a valid, empty report."""
        self._mock_empty_application()
        self.mock_llm.generate_recommendation.side_effect = list(self._VALID_LLM_RESPONSES)

        report = self.engine.analyze_application("app-empty")

        # Should not crash and should return a valid report
        self.assertIsNotNone(report)
        self.assertEqual(report.app_id, "app-empty")
        self.assertEqual(len(report.skew_analysis), 0)
        self.assertEqual(len(report.spill_analysis), 0)

    def test_malformed_llm_response(self):
        """Malformed LLM JSON must be tolerated, not raised."""
        self._mock_empty_application()
        # First two agents return invalid payloads: plain text and broken JSON.
        self.mock_llm.generate_recommendation.side_effect = [
            'This is not JSON',  # Execution agent
            '{"spill_issues": [}',  # Malformed JSON
            '{"skewed_stages": []}',
            '{"inefficient_joins": []}',
            '{"recommendations": []}',
            '{"code_issues": []}'
        ]

        # Should not crash
        report = self.engine.analyze_application("app-malformed")
        self.assertIsNotNone(report)

    def test_missing_stage_data(self):
        """A stage whose detail endpoint returns nothing is handled gracefully."""
        self._mock_empty_application(stages=[
            StageMetric(
                stageId=0, name="test", status="COMPLETE",
                numTasks=10, numActiveTasks=0, numCompleteTasks=10, numFailedTasks=0,
                executorRunTime=1000, executorCpuTime=800000000,
                jvmGcTime=10, resultSerializationTime=5,
                inputBytes=1024, outputBytes=512,
                shuffleReadBytes=0, shuffleWriteBytes=256,
                diskBytesSpilled=0, memoryBytesSpilled=0,
                peakExecutionMemory=1024
            )
        ])
        # Empty dict simulates a missing per-stage detail payload.
        self.mock_spark.get_stage_details.return_value = {}
        self.mock_llm.generate_recommendation.side_effect = list(self._VALID_LLM_RESPONSES)

        report = self.engine.analyze_application("app-missing-data")
        self.assertIsNotNone(report)
class TestClientErrorHandling(unittest.TestCase):
    """Exercise SparkHistoryClient behaviour when the REST API misbehaves.

    Each test stubs out ``urllib.request.urlopen`` inside ``src.client`` and
    verifies that ``get_stages`` returns an empty list instead of raising.
    """

    # Patch target for the history server's HTTP layer.
    _URLOPEN_TARGET = 'src.client.urllib.request.urlopen'

    def setUp(self):
        self.client = SparkHistoryClient(base_url="http://localhost:18080")

    def test_network_error(self):
        """A connection failure yields an empty list, not an exception."""
        with patch(self._URLOPEN_TARGET) as fake_urlopen:
            fake_urlopen.side_effect = Exception("Connection refused")
            self.assertEqual(self.client.get_stages("app-123"), [])

    def test_404_error(self):
        """An HTTP 404 from the server is treated as 'no stages'."""
        with patch(self._URLOPEN_TARGET) as fake_urlopen:
            response = MagicMock()
            response.status = 404
            fake_urlopen.return_value.__enter__.return_value = response
            self.assertEqual(self.client.get_stages("app-nonexistent"), [])

    def test_invalid_json_response(self):
        """A 200 response whose body is not JSON yields an empty list."""
        with patch(self._URLOPEN_TARGET) as fake_urlopen:
            response = MagicMock()
            response.status = 200
            response.read.return_value = b"Not JSON"
            fake_urlopen.return_value.__enter__.return_value = response
            self.assertEqual(self.client.get_stages("app-123"), [])
class TestPerformanceScenarios(unittest.TestCase):
    """Scenario tests for specific performance problems the engine should flag."""

    def setUp(self):
        self.mock_spark = MagicMock(spec=SparkHistoryClient)
        self.mock_llm = MagicMock(spec=LLMClient)
        self.engine = OptimizationEngine(self.mock_spark, self.mock_llm)

    def test_high_gc_overhead(self):
        """A stage spending half its run time in GC should produce a memory recommendation."""
        gc_heavy_stage = StageMetric(
            stageId=0, name="gc_heavy", status="COMPLETE",
            numTasks=10, numActiveTasks=0, numCompleteTasks=10, numFailedTasks=0,
            executorRunTime=10000, executorCpuTime=5000000000,
            jvmGcTime=5000,  # 50% of executorRunTime spent in GC
            resultSerializationTime=10,
            inputBytes=1024000, outputBytes=512000,
            shuffleReadBytes=0, shuffleWriteBytes=256000,
            diskBytesSpilled=0, memoryBytesSpilled=0,
            peakExecutionMemory=1024000,
        )
        self.mock_spark.get_stages.return_value = [gc_heavy_stage]
        # Everything except the stage list is empty for this scenario.
        for empty_endpoint in (
            self.mock_spark.get_jobs,
            self.mock_spark.get_executors,
            self.mock_spark.get_sql_metrics,
            self.mock_spark.get_rdd_storage,
        ):
            empty_endpoint.return_value = []
        self.mock_spark.get_environment.return_value = {}
        self.mock_spark.get_stage_details.return_value = {}
        # Agent replies: the execution agent flags the GC issue and the
        # config agent recommends a larger executor memory setting.
        self.mock_llm.generate_recommendation.side_effect = [
            '{"bottlenecks": [{"stageId": 0, "issue": "High GC", "severity": "high"}], "imbalance_detected": false, "overhead_analysis": "High GC overhead detected"}',
            '{"spill_issues": [], "partitioning_issues": "ok"}',
            '{"skewed_stages": [], "suggested_mitigations": []}',
            '{"inefficient_joins": [], "missing_predicates": [], "aqe_opportunities": "none"}',
            '{"recommendations": [{"config": "spark.executor.memory", "current": "1g", "suggested": "4g", "reason": "Reduce GC pressure"}]}',
            '{"code_issues": []}',
        ]

        report = self.engine.analyze_application("app-gc")

        # At least one recommendation, and at least one mentioning memory.
        self.assertGreater(len(report.recommendations), 0)
        self.assertTrue(
            any("memory" in rec.suggestion.lower() for rec in report.recommendations)
        )
if __name__ == '__main__':
    # SECURITY: a real-looking Gemini API key was previously hardcoded here.
    # It has been removed — if it was ever valid it should be revoked. These
    # tests mock the LLM client entirely, so only a placeholder is needed;
    # setdefault avoids clobbering a key the environment already provides.
    os.environ.setdefault('GEMINI_API_KEY', 'test-key-not-real')
    unittest.main(verbosity=2)