// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package serverlessspark
import (
"context"
"encoding/json"
"fmt"
"time"
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
longrunning "cloud.google.com/go/longrunning/autogen"
"cloud.google.com/go/longrunning/autogen/longrunningpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/protobuf/encoding/protojson"
)
const SourceKind string = "serverless-spark"
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Project string `yaml:"project" validate:"required"`
Location string `yaml:"location" validate:"required"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
ua, err := util.UserAgentFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
}
endpoint := fmt.Sprintf("%s-dataproc.googleapis.com:443", r.Location)
client, err := dataproc.NewBatchControllerClient(ctx, option.WithEndpoint(endpoint), option.WithUserAgent(ua))
if err != nil {
return nil, fmt.Errorf("failed to create dataproc client: %w", err)
}
opsClient, err := longrunning.NewOperationsClient(ctx, option.WithEndpoint(endpoint), option.WithUserAgent(ua))
if err != nil {
return nil, fmt.Errorf("failed to create longrunning client: %w", err)
}
s := &Source{
Config: r,
Client: client,
OpsClient: opsClient,
}
return s, nil
}
var _ sources.Source = &Source{}
type Source struct {
Config
Client *dataproc.BatchControllerClient
OpsClient *longrunning.OperationsClient
}
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetProject() string {
return s.Project
}
func (s *Source) GetLocation() string {
return s.Location
}
func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient {
return s.Client
}
func (s *Source) GetOperationsClient(ctx context.Context) (*longrunning.OperationsClient, error) {
return s.OpsClient, nil
}
func (s *Source) Close() error {
if err := s.Client.Close(); err != nil {
return err
}
if err := s.OpsClient.Close(); err != nil {
return err
}
return nil
}
func (s *Source) CancelOperation(ctx context.Context, operation string) (any, error) {
req := &longrunningpb.CancelOperationRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.GetProject(), s.GetLocation(), operation),
}
client, err := s.GetOperationsClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get operations client: %w", err)
}
err = client.CancelOperation(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to cancel operation: %w", err)
}
return fmt.Sprintf("Cancelled [%s].", operation), nil
}
func (s *Source) CreateBatch(ctx context.Context, batch *dataprocpb.Batch) (map[string]any, error) {
req := &dataprocpb.CreateBatchRequest{
Parent: fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()),
Batch: batch,
}
client := s.GetBatchControllerClient()
op, err := client.CreateBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to create batch: %w", err)
}
meta, err := op.Metadata()
if err != nil {
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
}
projectID, location, batchID, err := ExtractBatchDetails(meta.GetBatch())
if err != nil {
return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
}
consoleUrl := BatchConsoleURL(projectID, location, batchID)
logsUrl := BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})
wrappedResult := map[string]any{
"opMetadata": meta,
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
}
return wrappedResult, nil
}
// ListBatchesResponse is the response from the list batches API.
type ListBatchesResponse struct {
Batches []Batch `json:"batches"`
NextPageToken string `json:"nextPageToken"`
}
// Batch represents a single batch job.
type Batch struct {
Name string `json:"name"`
UUID string `json:"uuid"`
State string `json:"state"`
Creator string `json:"creator"`
CreateTime string `json:"createTime"`
Operation string `json:"operation"`
ConsoleURL string `json:"consoleUrl"`
LogsURL string `json:"logsUrl"`
}
func (s *Source) ListBatches(ctx context.Context, ps *int, pt, filter string) (any, error) {
client := s.GetBatchControllerClient()
parent := fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation())
req := &dataprocpb.ListBatchesRequest{
Parent: parent,
OrderBy: "create_time desc",
}
if ps != nil {
req.PageSize = int32(*ps)
}
if pt != "" {
req.PageToken = pt
}
if filter != "" {
req.Filter = filter
}
it := client.ListBatches(ctx, req)
pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)
var batchPbs []*dataprocpb.Batch
nextPageToken, err := pager.NextPage(&batchPbs)
if err != nil {
return nil, fmt.Errorf("failed to list batches: %w", err)
}
batches, err := ToBatches(batchPbs)
if err != nil {
return nil, err
}
return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
}
// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
batches := make([]Batch, 0, len(batchPbs))
for _, batchPb := range batchPbs {
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}
batch := Batch{
Name: batchPb.Name,
UUID: batchPb.Uuid,
State: batchPb.State.Enum().String(),
Creator: batchPb.Creator,
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
Operation: batchPb.Operation,
ConsoleURL: consoleUrl,
LogsURL: logsUrl,
}
batches = append(batches, batch)
}
return batches, nil
}
func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) {
client := s.GetBatchControllerClient()
req := &dataprocpb.GetBatchRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.GetProject(), s.GetLocation(), name),
}
batchPb, err := client.GetBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to get batch: %w", err)
}
jsonBytes, err := protojson.Marshal(batchPb)
if err != nil {
return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
}
var result map[string]any
if err := json.Unmarshal(jsonBytes, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
}
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}
wrappedResult := map[string]any{
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
"batch": result,
}
return wrappedResult, nil
}