batch_client.py (24.9 kB)
"""Google Cloud Dataproc Batch operations client.""" import asyncio import json import os from typing import Any import structlog from google.api_core import client_options from google.auth import default from google.cloud import dataproc_v1 from google.cloud.dataproc_v1 import types from google.oauth2 import service_account from .gcloud_config import get_default_project logger = structlog.get_logger(__name__) class DataprocBatchClient: """Client for Dataproc Batch operations.""" def __init__(self, credentials_path: str | None = None): """Initialize the Dataproc Batch client.""" self._credentials = None self._project_id = None if credentials_path and os.path.exists(credentials_path): self._credentials = service_account.Credentials.from_service_account_file( credentials_path ) with open(credentials_path) as f: service_account_info = json.load(f) self._project_id = service_account_info.get("project_id") else: self._credentials, self._project_id = default() # If no project from ADC, try gcloud config if not self._project_id: self._project_id = get_default_project() def _get_batch_client(self, region: str) -> dataproc_v1.BatchControllerClient: """Get batch controller client with regional endpoint.""" # Configure regional endpoint regional_endpoint = f"{region}-dataproc.googleapis.com" client_opts = client_options.ClientOptions(api_endpoint=regional_endpoint) return dataproc_v1.BatchControllerClient( credentials=self._credentials, client_options=client_opts ) async def create_batch_job( self, project_id: str, region: str, batch_id: str, job_type: str, main_file: str, args: list[str] | None = None, jar_files: list[str] | None = None, properties: dict[str, str] | None = None, service_account: str | None = None, network_uri: str | None = None, subnetwork_uri: str | None = None, ) -> dict[str, Any]: """Create a batch job.""" try: loop = asyncio.get_event_loop() client = self._get_batch_client(region) args = args or [] jar_files = jar_files or [] properties = properties or {} # Configure runtime runtime_config = types.RuntimeConfig() if properties: runtime_config.properties = properties # Configure environment environment_config = types.EnvironmentConfig() if service_account or network_uri or subnetwork_uri: execution_config = types.ExecutionConfig() if service_account: execution_config.service_account = service_account if network_uri: execution_config.network_uri = network_uri if subnetwork_uri: execution_config.subnetwork_uri = subnetwork_uri environment_config.execution_config = execution_config # Configure job based on type if job_type == "spark": job_config = types.SparkBatch( main_class=main_file, jar_file_uris=jar_files, args=args ) batch = types.Batch( runtime_config=runtime_config, environment_config=environment_config, spark_batch=job_config, ) elif job_type == "pyspark": job_config = types.PySparkBatch( main_python_file_uri=main_file, args=args, jar_file_uris=jar_files ) batch = types.Batch( runtime_config=runtime_config, environment_config=environment_config, pyspark_batch=job_config, ) elif job_type == "spark_sql": job_config = types.SparkSqlBatch( query_file_uri=main_file, jar_file_uris=jar_files ) batch = types.Batch( runtime_config=runtime_config, environment_config=environment_config, spark_sql_batch=job_config, ) else: raise ValueError(f"Unsupported batch job type: {job_type}") request = types.CreateBatchRequest( parent=f"projects/{project_id}/locations/{region}", batch=batch, batch_id=batch_id, ) operation = await loop.run_in_executor(None, client.create_batch, request) 
operation_name = getattr(operation, "name", str(operation)) return { "operation_name": operation_name, "batch_id": batch_id, "job_type": job_type, "status": "CREATING", "message": f"Batch job creation initiated. Operation: {operation_name}", } except Exception as e: logger.error("Failed to create batch job", error=str(e)) raise async def list_batch_jobs( self, project_id: str, region: str, page_size: int = 100 ) -> dict[str, Any]: """List batch jobs.""" try: loop = asyncio.get_event_loop() client = self._get_batch_client(region) request = types.ListBatchesRequest( parent=f"projects/{project_id}/locations/{region}", page_size=page_size ) response = await loop.run_in_executor(None, client.list_batches, request) batches = [] for batch in response: batches.append( { "batch_id": batch.name.split("/")[-1], "state": batch.state.name, "create_time": batch.create_time.isoformat() if batch.create_time else None, "job_type": self._get_batch_job_type(batch), "operation": batch.operation if batch.operation else None, } ) return { "batches": batches, "total_count": len(batches), "project_id": project_id, "region": region, } except Exception as e: logger.error("Failed to list batch jobs", error=str(e)) raise async def get_batch_job( self, project_id: str, region: str, batch_id: str ) -> dict[str, Any]: """Get details of a specific batch job.""" try: loop = asyncio.get_event_loop() client = self._get_batch_client(region) request = types.GetBatchRequest( name=f"projects/{project_id}/locations/{region}/batches/{batch_id}" ) batch = await loop.run_in_executor(None, client.get_batch, request) # Extract runtime info if available runtime_info = {} if batch.runtime_info: runtime_info = { "endpoints": dict(batch.runtime_info.endpoints) if batch.runtime_info.endpoints else {}, "output_uri": batch.runtime_info.output_uri if batch.runtime_info.output_uri else None, "diagnostic_output_uri": batch.runtime_info.diagnostic_output_uri if batch.runtime_info.diagnostic_output_uri else None, } # Add usage information if available if batch.runtime_info.approximate_usage: runtime_info["approximate_usage"] = { "milli_dcu_seconds": str( batch.runtime_info.approximate_usage.milli_dcu_seconds ), "shuffle_storage_gb_seconds": str( batch.runtime_info.approximate_usage.shuffle_storage_gb_seconds ), } if batch.runtime_info.current_usage: runtime_info["current_usage"] = { "milli_dcu": str(batch.runtime_info.current_usage.milli_dcu), "shuffle_storage_gb": str( batch.runtime_info.current_usage.shuffle_storage_gb ), } # Extract job configuration details job_config: dict[str, Any] = {} job_type = self._get_batch_job_type(batch) if batch.spark_batch: job_config = { "main_class": batch.spark_batch.main_class if batch.spark_batch.main_class else None, "main_jar_file_uri": batch.spark_batch.main_jar_file_uri if batch.spark_batch.main_jar_file_uri else None, "jar_file_uris": list(batch.spark_batch.jar_file_uris) if batch.spark_batch.jar_file_uris else [], "file_uris": list(batch.spark_batch.file_uris) if batch.spark_batch.file_uris else [], "archive_uris": list(batch.spark_batch.archive_uris) if batch.spark_batch.archive_uris else [], "args": list(batch.spark_batch.args) if batch.spark_batch.args else [], } elif batch.pyspark_batch: job_config = { "main_python_file_uri": batch.pyspark_batch.main_python_file_uri, "python_file_uris": list(batch.pyspark_batch.python_file_uris) if batch.pyspark_batch.python_file_uris else [], "jar_file_uris": list(batch.pyspark_batch.jar_file_uris) if batch.pyspark_batch.jar_file_uris else [], "file_uris": 
list(batch.pyspark_batch.file_uris) if batch.pyspark_batch.file_uris else [], "archive_uris": list(batch.pyspark_batch.archive_uris) if batch.pyspark_batch.archive_uris else [], "args": list(batch.pyspark_batch.args) if batch.pyspark_batch.args else [], } elif batch.spark_sql_batch: job_config = { "query_file_uri": batch.spark_sql_batch.query_file_uri, "query_variables": dict(batch.spark_sql_batch.query_variables) if batch.spark_sql_batch.query_variables else {}, "jar_file_uris": list(batch.spark_sql_batch.jar_file_uris) if batch.spark_sql_batch.jar_file_uris else [], } elif batch.spark_r_batch: job_config = { "main_r_file_uri": batch.spark_r_batch.main_r_file_uri, "file_uris": list(batch.spark_r_batch.file_uris) if batch.spark_r_batch.file_uris else [], "archive_uris": list(batch.spark_r_batch.archive_uris) if batch.spark_r_batch.archive_uris else [], "args": list(batch.spark_r_batch.args) if batch.spark_r_batch.args else [], } # Extract runtime config details runtime_config = {} if batch.runtime_config: runtime_config = { "version": batch.runtime_config.version if batch.runtime_config.version else None, "container_image": batch.runtime_config.container_image if batch.runtime_config.container_image else None, "properties": dict(batch.runtime_config.properties) if batch.runtime_config.properties else {}, } # Extract environment config details environment_config: dict[str, Any] = {} if batch.environment_config: environment_config = { "execution_config": {}, "peripherals_config": {}, } if batch.environment_config.execution_config: exec_config = batch.environment_config.execution_config environment_config["execution_config"] = { "service_account": exec_config.service_account if exec_config.service_account else None, "network_uri": exec_config.network_uri if exec_config.network_uri else None, "subnetwork_uri": exec_config.subnetwork_uri if exec_config.subnetwork_uri else None, "network_tags": list(exec_config.network_tags) if exec_config.network_tags else [], "kms_key": exec_config.kms_key if exec_config.kms_key else None, } if batch.environment_config.peripherals_config: periph_config = batch.environment_config.peripherals_config environment_config["peripherals_config"] = { "metastore_service": periph_config.metastore_service if periph_config.metastore_service else None, "spark_history_server_config": {}, } if periph_config.spark_history_server_config: environment_config["peripherals_config"][ "spark_history_server_config" ] = { "dataproc_cluster": periph_config.spark_history_server_config.dataproc_cluster if periph_config.spark_history_server_config.dataproc_cluster else None, } return { "name": batch.name, "batch_id": batch.name.split("/")[-1], "uuid": batch.uuid if batch.uuid else None, "state": batch.state.name, "state_message": batch.state_message, "state_time": batch.state_time.isoformat() if batch.state_time else None, "create_time": batch.create_time.isoformat() if batch.create_time else None, "creator": batch.creator if batch.creator else None, "labels": dict(batch.labels) if batch.labels else {}, "job_type": job_type, "job_config": job_config, "runtime_config": runtime_config, "environment_config": environment_config, "runtime_info": runtime_info, "operation": batch.operation if batch.operation else None, "state_history": [ { "state": state.state.name, "state_message": state.state_message, "state_start_time": state.state_start_time.isoformat() if state.state_start_time else None, } for state in batch.state_history ], } except Exception as e: logger.error("Failed to get batch job", 
error=str(e)) raise async def delete_batch_job( self, project_id: str, region: str, batch_id: str ) -> dict[str, Any]: """Delete a batch job.""" try: loop = asyncio.get_event_loop() client = self._get_batch_client(region) request = types.DeleteBatchRequest( name=f"projects/{project_id}/locations/{region}/batches/{batch_id}" ) await loop.run_in_executor(None, client.delete_batch, request) return { "batch_id": batch_id, "status": "DELETED", "message": f"Batch job {batch_id} deletion initiated", } except Exception as e: logger.error("Failed to delete batch job", error=str(e)) raise async def compare_batches( self, project_id: str, region: str, batch_id_1: str, batch_id_2: str ) -> dict[str, Any]: """Compare two batch jobs and return detailed differences.""" try: # Get details for both batches batch_1 = await self.get_batch_job(project_id, region, batch_id_1) batch_2 = await self.get_batch_job(project_id, region, batch_id_2) # Compare basic information basic_comparison = { "batch_id": { "batch_1": batch_1["batch_id"], "batch_2": batch_2["batch_id"], }, "job_type": { "batch_1": batch_1["job_type"], "batch_2": batch_2["job_type"], "same": batch_1["job_type"] == batch_2["job_type"], }, "state": { "batch_1": batch_1["state"], "batch_2": batch_2["state"], "same": batch_1["state"] == batch_2["state"], }, "creator": { "batch_1": batch_1.get("creator"), "batch_2": batch_2.get("creator"), "same": batch_1.get("creator") == batch_2.get("creator"), }, "create_time": { "batch_1": batch_1.get("create_time"), "batch_2": batch_2.get("create_time"), }, } # Compare job configurations config_comparison = { "same_config": batch_1["job_config"] == batch_2["job_config"], "batch_1_config": batch_1["job_config"], "batch_2_config": batch_2["job_config"], } # Compare runtime configurations runtime_comparison = { "same_runtime": batch_1["runtime_config"] == batch_2["runtime_config"], "batch_1_runtime": batch_1["runtime_config"], "batch_2_runtime": batch_2["runtime_config"], } # Compare environment configurations env_comparison = { "same_environment": batch_1["environment_config"] == batch_2["environment_config"], "batch_1_environment": batch_1["environment_config"], "batch_2_environment": batch_2["environment_config"], } # Compare labels labels_comparison = { "same_labels": batch_1["labels"] == batch_2["labels"], "batch_1_labels": batch_1["labels"], "batch_2_labels": batch_2["labels"], } # Compare performance/runtime info performance_comparison = {} runtime_1 = batch_1.get("runtime_info", {}) runtime_2 = batch_2.get("runtime_info", {}) if runtime_1.get("approximate_usage") and runtime_2.get( "approximate_usage" ): usage_1 = runtime_1["approximate_usage"] usage_2 = runtime_2["approximate_usage"] performance_comparison = { "resource_usage": { "batch_1_milli_dcu_seconds": usage_1.get("milli_dcu_seconds"), "batch_2_milli_dcu_seconds": usage_2.get("milli_dcu_seconds"), "batch_1_shuffle_storage_gb_seconds": usage_1.get( "shuffle_storage_gb_seconds" ), "batch_2_shuffle_storage_gb_seconds": usage_2.get( "shuffle_storage_gb_seconds" ), } } # Compare state history (execution timeline) history_comparison = { "batch_1_states": [ state["state"] for state in batch_1.get("state_history", []) ], "batch_2_states": [ state["state"] for state in batch_2.get("state_history", []) ], "same_state_progression": [ state["state"] for state in batch_1.get("state_history", []) ] == [state["state"] for state in batch_2.get("state_history", [])], } # Calculate execution duration if possible def calculate_duration(batch_data: dict[str, Any]) -> float 
| None: state_history = batch_data.get("state_history", []) if len(state_history) >= 2: from datetime import datetime try: start_time = datetime.fromisoformat( state_history[0]["state_start_time"].replace("Z", "+00:00") ) end_time = datetime.fromisoformat( state_history[-1]["state_start_time"].replace("Z", "+00:00") ) return (end_time - start_time).total_seconds() except (ValueError, TypeError): return None return None duration_1 = calculate_duration(batch_1) duration_2 = calculate_duration(batch_2) if duration_1 is not None and duration_2 is not None: performance_comparison["execution_time"] = { "batch_1_seconds": duration_1, "batch_2_seconds": duration_2, "difference_seconds": abs(duration_1 - duration_2), } # Summary of differences differences = [] if not basic_comparison["job_type"]["same"]: differences.append("Different job types") if not basic_comparison["state"]["same"]: differences.append("Different current states") if not basic_comparison["creator"]["same"]: differences.append("Different creators") if not config_comparison["same_config"]: differences.append("Different job configurations") if not runtime_comparison["same_runtime"]: differences.append("Different runtime configurations") if not env_comparison["same_environment"]: differences.append("Different environment configurations") if not labels_comparison["same_labels"]: differences.append("Different labels") if not history_comparison["same_state_progression"]: differences.append("Different state progression") return { "comparison_summary": { "batch_1_id": batch_id_1, "batch_2_id": batch_id_2, "identical": len(differences) == 0, "differences": differences, }, "basic_info": basic_comparison, "job_configuration": config_comparison, "runtime_configuration": runtime_comparison, "environment_configuration": env_comparison, "labels": labels_comparison, "performance": performance_comparison, "state_history": history_comparison, } except Exception as e: logger.error("Failed to compare batch jobs", error=str(e)) raise def _get_batch_job_type(self, batch: types.Batch) -> str: """Extract job type from batch object.""" if batch.spark_batch: return "spark" elif batch.pyspark_batch: return "pyspark" elif batch.spark_sql_batch: return "spark_sql" elif batch.spark_r_batch: return "spark_r" else: return "unknown"
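For context, here is a minimal usage sketch showing how the client above might be driven from an async entry point. It is not part of the repository: the import path, project ID, and region are placeholder assumptions.

# Minimal usage sketch (not from the repository); import path, project, and region are assumptions.
import asyncio

from batch_client import DataprocBatchClient  # assumed import path; adjust to the package layout


async def main() -> None:
    # Assumes Application Default Credentials are configured in the environment.
    client = DataprocBatchClient()

    # List recent Dataproc Serverless batches in a region (placeholder project/region).
    batches = await client.list_batch_jobs(
        project_id="my-project", region="us-central1", page_size=10
    )
    print(f"Found {batches['total_count']} batches")

    # Drill into the first listed batch, if any.
    if batches["batches"]:
        batch_id = batches["batches"][0]["batch_id"]
        details = await client.get_batch_job("my-project", "us-central1", batch_id)
        print(details["state"], details["job_type"])


if __name__ == "__main__":
    asyncio.run(main())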

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/warrenzhu25/dataproc-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.