import urllib.error
import urllib.request
import json
import logging
from typing import List, Dict, Any, Optional
from .models import SparkApp, JobMetric, StageMetric, ExecutorMetric, SQLMetric, RDDMetric
logger = logging.getLogger(__name__)
class SparkHistoryClient:
"""
    Client for interacting with the Spark History Server REST API.
This client provides methods to fetch application metrics, job details,
stage information, and other performance data from the Spark History Server.
Args:
base_url: Base URL of the Spark History Server (default: http://localhost:18080)
timeout: Request timeout in seconds (default: 30)
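
    Example:
        A minimal usage sketch, assuming a History Server is reachable at the
        default address:

            client = SparkHistoryClient("http://localhost:18080")
            apps = client.get_applications()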
"""
def __init__(self, base_url: str = "http://localhost:18080", timeout: int = 30):
self.base_url = base_url.rstrip("/")
self.api_root = f"{self.base_url}/api/v1"
self.timeout = timeout
logger.info(f"Initialized SparkHistoryClient with base_url={base_url}")
def _get(self, endpoint: str) -> Any:
"""
Internal method to make GET requests to the Spark History Server.
Args:
endpoint: API endpoint path (e.g., /applications)
Returns:
Parsed JSON response or None if request fails
"""
url = f"{self.api_root}{endpoint}"
try:
req = urllib.request.Request(url)
with urllib.request.urlopen(req, timeout=self.timeout) as response:
                if 200 <= response.status < 300:
                    data = json.loads(response.read().decode())
                    logger.debug(f"Successfully fetched {url}")
                    return data
                else:
                    # urlopen raises HTTPError for most non-2xx responses, but
                    # guard against any that slip through.
                    logger.warning(f"Unexpected status {response.status} from {url}")
except json.JSONDecodeError as e:
logger.error(f"Error parsing JSON from {url}: {e}")
except urllib.error.URLError as e:
logger.error(f"Network error fetching {url}: {e}")
except Exception as e:
logger.error(f"Unexpected error fetching {url}: {e}")
return None
def get_applications(self) -> List[SparkApp]:
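        """Fetch the list of applications known to the History Server."""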
data = self._get("/applications")
if data:
return [SparkApp(**app) for app in data]
return []
def get_jobs(self, app_id: str) -> List[JobMetric]:
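        """Fetch job-level metrics for the given application."""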
data = self._get(f"/applications/{app_id}/jobs")
if data:
# Helper to safely map dict to dataclass (ignoring extra fields)
def map_job(j):
return JobMetric(
jobId=j.get('jobId'),
name=j.get('name'),
submissionTime=j.get('submissionTime'),
completionTime=j.get('completionTime'),
status=j.get('status'),
numSimpleStages=j.get('numSimpleStages', 0),
stageIds=j.get('stageIds', [])
)
return [map_job(job) for job in data]
return []
def get_stages(self, app_id: str) -> List[StageMetric]:
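        """Fetch per-stage execution metrics for the given application."""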
data = self._get(f"/applications/{app_id}/stages")
if data:
mapped_stages = []
for stage in data:
# StageMetric fits the structure of SHS response mostly, but we need to be careful with missing keys
mapped_stages.append(StageMetric(
stageId=stage.get('stageId'),
name=stage.get('name'),
status=stage.get('status'),
numTasks=stage.get('numTasks', 0),
numActiveTasks=stage.get('numActiveTasks', 0),
numCompleteTasks=stage.get('numCompleteTasks', 0),
numFailedTasks=stage.get('numFailedTasks', 0),
executorRunTime=stage.get('executorRunTime', 0),
executorCpuTime=stage.get('executorCpuTime', 0),
jvmGcTime=stage.get('jvmGcTime', 0),
resultSerializationTime=stage.get('resultSerializationTime', 0),
inputBytes=stage.get('inputBytes', 0),
outputBytes=stage.get('outputBytes', 0),
shuffleReadBytes=stage.get('shuffleReadBytes', 0),
shuffleWriteBytes=stage.get('shuffleWriteBytes', 0),
diskBytesSpilled=stage.get('diskBytesSpilled', 0),
memoryBytesSpilled=stage.get('memoryBytesSpilled', 0),
peakExecutionMemory=stage.get('peakExecutionMemory', 0)
))
return mapped_stages
return []
def get_stage_details(self, app_id: str, stage_id: int) -> Dict[str, Any]:
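        """Fetch detailed metrics for a single stage, including task metric distributions."""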
        # withSummaries=true asks the History Server to include task metric
        # distributions at the requested quantiles (available in Spark 3.1+).
        data = self._get(
            f"/applications/{app_id}/stages/{stage_id}"
            "?withSummaries=true&quantiles=0.05,0.25,0.5,0.75,0.95"
        )
if data and isinstance(data, list):
return data[0] # Return most recent attempt
return {}
def get_sql_metrics(self, app_id: str) -> List[SQLMetric]:
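        """Fetch summary metrics for each SQL execution in the given application."""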
data = self._get(f"/applications/{app_id}/sql")
if data:
return [SQLMetric(
id=s.get('id'),
description=s.get('description'),
submissionTime=s.get('submissionTime'),
completionTime=s.get('completionTime'),
duration=s.get('duration', 0),
                # The SQL REST API reports job ids split by outcome; merge them.
                jobIds=(s.get('successJobIds', [])
                        + s.get('runningJobIds', [])
                        + s.get('failedJobIds', [])),
status=s.get('status')
) for s in data]
return []
def get_executors(self, app_id: str) -> List[ExecutorMetric]:
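        """Fetch per-executor resource and task metrics for the given application."""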
data = self._get(f"/applications/{app_id}/executors")
if data:
return [ExecutorMetric(
id=e.get('id'),
hostPort=e.get('hostPort'),
rddBlocks=e.get('rddBlocks', 0),
memoryUsed=e.get('memoryUsed', 0),
diskUsed=e.get('diskUsed', 0),
totalCores=e.get('totalCores', 0),
maxTasks=e.get('maxTasks', 0),
activeTasks=e.get('activeTasks', 0),
failedTasks=e.get('failedTasks', 0),
completedTasks=e.get('completedTasks', 0),
totalDuration=e.get('totalDuration', 0),
totalGCTime=e.get('totalGCTime', 0),
totalInputBytes=e.get('totalInputBytes', 0),
totalShuffleRead=e.get('totalShuffleRead', 0),
totalShuffleWrite=e.get('totalShuffleWrite', 0)
) for e in data]
return []
def get_sql_plan(self, app_id: str, execution_id: int) -> Dict[str, Any]:
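        """Fetch details for a single SQL execution, including its plan description."""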
return self._get(f"/applications/{app_id}/sql/{execution_id}") or {}
def get_rdd_storage(self, app_id: str) -> List[RDDMetric]:
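        """Fetch storage information for the application's cached RDDs."""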
data = self._get(f"/applications/{app_id}/storage/rdd")
if data:
return [RDDMetric(
id=r.get('id'),
name=r.get('name'),
numPartitions=r.get('numPartitions', 0),
numCachedPartitions=r.get('numCachedPartitions', 0),
storageLevel=r.get('storageLevel', "Unknown"),
memoryUsed=r.get('memoryUsed', 0),
diskUsed=r.get('diskUsed', 0)
) for r in data]
return []
def get_environment(self, app_id: str) -> Dict[str, Any]:
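        """Fetch Spark properties and environment details for the given application."""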
return self._get(f"/applications/{app_id}/environment") or {}
def get_event_timeline(self, app_id: str) -> List[Dict[str, Any]]:
        """Build a coarse event timeline for the application.

        The History Server exposes no dedicated timeline endpoint, so this
        assembles one from job submission and completion times.
        """
timeline = []
jobs = self.get_jobs(app_id)
for j in jobs:
timeline.append({
"type": "JOB",
"id": j.jobId,
"start": j.submissionTime,
"end": j.completionTime,
"status": j.status
})
return timeline
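

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the client API: assumes a Spark
    # History Server is reachable at http://localhost:18080 with at least one
    # completed application, and that SparkApp exposes an `id` attribute
    # mirroring the REST payload. Because of the relative import above, run
    # this with `python -m` from the package root rather than executing the
    # file directly.
    logging.basicConfig(level=logging.INFO)
    client = SparkHistoryClient()
    apps = client.get_applications()
    print(f"Found {len(apps)} application(s)")
    if apps:
        app_id = apps[0].id  # assumed attribute; adjust to the SparkApp model
        print(f"Jobs for {app_id}: {len(client.get_jobs(app_id))}")
        print(f"Executors for {app_id}: {len(client.get_executors(app_id))}")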