"""Optimization engine: orchestrates the analysis agents and consolidates their findings into an OptimizationReport."""

import logging
from typing import Any, Dict, Optional

from ..client import SparkHistoryClient
from ..llm_client import LLMClient
from ..models import OptimizationReport, SkewAnalysis, SpillAnalysis, ResourceAnalysis, PartitioningAnalysis, JoinAnalysis, Recommendation
from .agents import (
    ExecutionAnalysisAgent,
    ShuffleSpillAgent,
    SkewDetectionAgent,
    SQLPlanAgent,
    ConfigRecommendationAgent,
    CodeRecommendationAgent,
)


class OptimizationEngine:
    """Coordinates the specialised analysis agents over a single Spark application."""

    def __init__(self, spark_client: SparkHistoryClient, llm_client: LLMClient):
        self.spark_client = spark_client
        self.llm_client = llm_client

        # Initialize the analysis agents; they all share the same LLM client.
        self.execution_agent = ExecutionAnalysisAgent(llm_client)
        self.shuffle_agent = ShuffleSpillAgent(llm_client)
        self.skew_agent = SkewDetectionAgent(llm_client)
        self.sql_agent = SQLPlanAgent(llm_client)
        self.config_agent = ConfigRecommendationAgent(llm_client)
        self.code_agent = CodeRecommendationAgent(llm_client)

    def analyze_application(self, app_id: str, code_path: Optional[str] = None, job_code: Optional[str] = None) -> OptimizationReport:
        """Run all agents against the given application and consolidate their findings into a report.

        Job code can be supplied either inline (job_code) or via a file path (code_path);
        inline code takes precedence.
        """
        logging.info(f"Starting analysis for app: {app_id}")

        # 1. Gather data (the planner's "data fetching" phase).
        # In a fully autonomous planner the LLM would request these on demand;
        # here we pre-fetch for efficiency and determinism, as per the system design.
        context = self._gather_context(app_id)

        if job_code:
            context['code'] = job_code
        elif code_path:
            try:
                with open(code_path, 'r') as f:
                    context['code'] = f.read()
            except Exception as e:
                logging.warning(f"Could not read code file '{code_path}': {e}")

        # 2. Run the agents.
        # Execution analysis runs first; its bottleneck list feeds the downstream agents.
        exec_result = self.execution_agent.analyze(context)
        context['bottlenecks'] = exec_result.get('bottlenecks', [])

        # Shuffle & spill
        shuffle_result = self.shuffle_agent.analyze(context)
        # Skew
        skew_result = self.skew_agent.analyze(context)
        # SQL plan
        sql_result = self.sql_agent.analyze(context)
        # Config (consumes the outputs above via the shared context)
        config_result = self.config_agent.analyze(context)

        # Code recommendations only run when job code is available.
        code_result = {}
        if context.get('code'):
            code_result = self.code_agent.analyze(context)

        # 3. Consolidate the report.
        return self._build_report(app_id, context, exec_result, shuffle_result, skew_result, sql_result, config_result, code_result)

    def _gather_context(self, app_id: str) -> Dict[str, Any]:
        """Fetch metrics from the Spark History Server and assemble the shared agent context."""
        stages = [s.to_dict() for s in self.spark_client.get_stages(app_id)]

        # Enrich the top 5 longest stages (by executorRunTime, descending) with task
        # distribution data (quantiles), which the skew detection agent relies on.
        sorted_stages = sorted(stages, key=lambda s: s.get('executorRunTime', 0), reverse=True)[:5]
        task_distributions = {}
        for s in sorted_stages:
            sid = s.get('stageId')
            # Fetch detailed stage info, including the task summary if available.
            # Note: get_stage_details must support fetching the summary/quantiles.
            details = self.spark_client.get_stage_details(app_id, sid)
            if 'taskMetricsDistributions' in details:
                # The Spark History Server stage output usually exposes per-task quantiles here.
                task_distributions[sid] = details['taskMetricsDistributions']
            else:
                # Fallback: mark the distribution as missing so agents don't guess.
                task_distributions[sid] = "Not Available"

        return {
            "app_info": self.spark_client.get_applications(),
            "jobs": [j.to_dict() for j in self.spark_client.get_jobs(app_id)],
            "stages": stages,
            "executors": [e.to_dict() for e in self.spark_client.get_executors(app_id)],
            "sql": [s.to_dict() for s in self.spark_client.get_sql_metrics(app_id)],
            "rdd": [r.to_dict() for r in self.spark_client.get_rdd_storage(app_id)],
            "environment": self.spark_client.get_environment(app_id),
            "task_distribution": task_distributions,
        }

    def _build_report(self, app_id: str, context: Dict[str, Any], exec_res, shuffle_res, skew_res, sql_res, config_res, code_res) -> OptimizationReport:
        """Map the loosely structured agent JSON outputs onto the strict report data types.

        Agent outputs are LLM-generated, so every field access is defensive about missing keys.
        """
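        # The shapes assumed below simply mirror the keys this method reads; they are a
        # working contract with the agents rather than a formally validated schema:
        #   config_res:  {"recommendations": [{"config", "current", "suggested", "reason"}, ...]}
        #   code_res:    {"code_issues": [{"issue", "suggestion", "line"}, ...]}
        #   skew_res:    {"skewed_stages": [{"stageId", "skew_ratio", "max_duration", "median_duration"}, ...]}
        #   shuffle_res: {"spill_issues": [{"stageId"}, ...]}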
        recommendations = []

        # Configuration recommendations
        for rec in config_res.get('recommendations', []):
            recommendations.append(Recommendation(
                category="Configuration",
                issue=rec.get('reason', 'Config adjustment'),
                suggestion=f"Set {rec.get('config')} to {rec.get('suggested')}",
                evidence=f"Current: {rec.get('current')}",
                impact_level="High"  # inferred; config changes tend to have broad impact
            ))

        # Code recommendations
        for rec in code_res.get('code_issues', []):
            recommendations.append(Recommendation(
                category="Code",
                issue=rec.get('issue'),
                suggestion=rec.get('suggestion'),
                evidence=f"Line: {rec.get('line')}",
                impact_level="Medium"
            ))

        # Populate the per-topic analysis lists.
        # Skew
        skew_analysis = []
        for s in skew_res.get('skewed_stages', []):
            skew_analysis.append(SkewAnalysis(
                is_skewed=True,
                skew_ratio=s.get('skew_ratio', 0.0),
                max_duration=s.get('max_duration', 0.0),
                median_duration=s.get('median_duration', 0.0),
                stage_id=s.get('stageId')
            ))

        # Spill: the agent only flags stages, so pull the actual spill metrics from the stage data.
        spill_analysis = []
        stages_data = context.get('stages', [])
        stage_map = {s.get('stageId'): s for s in stages_data}
        for s in shuffle_res.get('spill_issues', []):
            stage_id = s.get('stageId')
            stage_data = stage_map.get(stage_id, {})
            spill_analysis.append(SpillAnalysis(
                has_spill=True,
                total_disk_spill=stage_data.get('diskBytesSpilled', 0),
                total_memory_spill=stage_data.get('memoryBytesSpilled', 0),
                stage_id=stage_id
            ))

        return OptimizationReport(
            app_id=app_id,
            skew_analysis=skew_analysis,
            spill_analysis=spill_analysis,
            resource_analysis=[],       # TODO: map from exec_res into ResourceAnalysis
            partitioning_analysis=[],   # TODO: map from shuffle_res into PartitioningAnalysis
            join_analysis=[],           # TODO: map from sql_res into JoinAnalysis
            recommendations=recommendations
        )
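

# Minimal usage sketch, run via `python -m <package>.engine` because of the relative imports.
# The SparkHistoryClient and LLMClient constructor arguments shown here (a History Server URL
# and a model name) are assumptions for illustration; their real signatures live in their own
# modules and may differ.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    spark_client = SparkHistoryClient("http://localhost:18080/api/v1")  # assumed constructor
    llm_client = LLMClient(model="your-llm-model")                      # assumed constructor

    engine = OptimizationEngine(spark_client, llm_client)
    report = engine.analyze_application("app-20240101000000-0001", code_path="job.py")
    print(report)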