MCP Toolbox for Databases

serverless_spark_integration_test.go (33 kB)
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package serverlessspark

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"maps"
	"net/http"
	"os"
	"reflect"
	"regexp"
	"slices"
	"strings"
	"testing"
	"time"

	dataproc "cloud.google.com/go/dataproc/v2/apiv1"
	"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
	"github.com/google/go-cmp/cmp"
	"github.com/googleapis/genai-toolbox/internal/testutils"
	"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
	"github.com/googleapis/genai-toolbox/tests"
	"google.golang.org/api/iterator"
	"google.golang.org/api/option"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/testing/protocmp"
)

var (
	serverlessSparkLocation       = os.Getenv("SERVERLESS_SPARK_LOCATION")
	serverlessSparkProject        = os.Getenv("SERVERLESS_SPARK_PROJECT")
	serverlessSparkServiceAccount = os.Getenv("SERVERLESS_SPARK_SERVICE_ACCOUNT")
)

const (
	batchURLPrefix = "https://console.cloud.google.com/dataproc/batches/"
	logsURLPrefix  = "https://console.cloud.google.com/logs/viewer?"
)

func getServerlessSparkVars(t *testing.T) map[string]any {
	switch "" {
	case serverlessSparkLocation:
		t.Fatal("'SERVERLESS_SPARK_LOCATION' not set")
	case serverlessSparkProject:
		t.Fatal("'SERVERLESS_SPARK_PROJECT' not set")
	case serverlessSparkServiceAccount:
		t.Fatal("'SERVERLESS_SPARK_SERVICE_ACCOUNT' not set")
	}
	return map[string]any{
		"kind":     "serverless-spark",
		"project":  serverlessSparkProject,
		"location": serverlessSparkLocation,
	}
}

func TestServerlessSparkToolEndpoints(t *testing.T) {
	sourceConfig := getServerlessSparkVars(t)
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	toolsFile := map[string]any{
		"sources": map[string]any{
			"my-spark": sourceConfig,
		},
		"authServices": map[string]any{
			"my-google-auth": map[string]any{
				"kind":     "google",
				"clientId": tests.ClientId,
			},
		},
		"tools": map[string]any{
			"list-batches": map[string]any{
				"kind":   "serverless-spark-list-batches",
				"source": "my-spark",
			},
			"list-batches-with-auth": map[string]any{
				"kind":         "serverless-spark-list-batches",
				"source":       "my-spark",
				"authRequired": []string{"my-google-auth"},
			},
			"get-batch": map[string]any{
				"kind":   "serverless-spark-get-batch",
				"source": "my-spark",
			},
			"get-batch-with-auth": map[string]any{
				"kind":         "serverless-spark-get-batch",
				"source":       "my-spark",
				"authRequired": []string{"my-google-auth"},
			},
			"cancel-batch": map[string]any{
				"kind":   "serverless-spark-cancel-batch",
				"source": "my-spark",
			},
			"cancel-batch-with-auth": map[string]any{
				"kind":         "serverless-spark-cancel-batch",
				"source":       "my-spark",
				"authRequired": []string{"my-google-auth"},
			},
			"create-pyspark-batch": map[string]any{
				"kind":   "serverless-spark-create-pyspark-batch",
				"source": "my-spark",
				"environmentConfig": map[string]any{
					"executionConfig": map[string]any{
						"serviceAccount": serverlessSparkServiceAccount,
					},
				},
			},
			"create-pyspark-batch-2-3": map[string]any{
				"kind":          "serverless-spark-create-pyspark-batch",
				"source":        "my-spark",
				"runtimeConfig": map[string]any{"version": "2.3"},
				"environmentConfig": map[string]any{
					"executionConfig": map[string]any{
						"serviceAccount": serverlessSparkServiceAccount,
					},
				},
			},
			"create-pyspark-batch-with-auth": map[string]any{
				"kind":         "serverless-spark-create-pyspark-batch",
				"source":       "my-spark",
				"authRequired": []string{"my-google-auth"},
			},
			"create-spark-batch": map[string]any{
				"kind":   "serverless-spark-create-spark-batch",
				"source": "my-spark",
				"environmentConfig": map[string]any{
					"executionConfig": map[string]any{
						"serviceAccount": serverlessSparkServiceAccount,
					},
				},
			},
			"create-spark-batch-2-3": map[string]any{
				"kind":          "serverless-spark-create-spark-batch",
				"source":        "my-spark",
				"runtimeConfig": map[string]any{"version": "2.3"},
				"environmentConfig": map[string]any{
					"executionConfig": map[string]any{
						"serviceAccount": serverlessSparkServiceAccount,
					},
				},
			},
			"create-spark-batch-with-auth": map[string]any{
				"kind":         "serverless-spark-create-spark-batch",
				"source":       "my-spark",
				"authRequired": []string{"my-google-auth"},
			},
		},
	}

	cmd, cleanup, err := tests.StartCmd(ctx, toolsFile)
	if err != nil {
		t.Fatalf("command initialization returned an error: %s", err)
	}
	defer cleanup()

	waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
	if err != nil {
		t.Logf("toolbox command logs: \n%s", out)
		t.Fatalf("toolbox didn't start successfully: %s", err)
	}

	endpoint := fmt.Sprintf("%s-dataproc.googleapis.com:443", serverlessSparkLocation)
	client, err := dataproc.NewBatchControllerClient(ctx, option.WithEndpoint(endpoint))
	if err != nil {
		t.Fatalf("failed to create dataproc client: %v", err)
	}
	defer client.Close()

	t.Run("list-batches", func(t *testing.T) {
		// list-batches is sensitive to state changes, so this test must run sequentially.
		t.Run("success", func(t *testing.T) {
			runListBatchesTest(t, client, ctx)
		})
		t.Run("errors", func(t *testing.T) {
			t.Parallel()
			tcs := []struct {
				name     string
				toolName string
				request  map[string]any
				wantCode int
				wantMsg  string
			}{
				{
					name:     "zero page size",
					toolName: "list-batches",
					request:  map[string]any{"pageSize": 0},
					wantCode: http.StatusBadRequest,
					wantMsg:  "pageSize must be positive: 0",
				},
				{
					name:     "negative page size",
					toolName: "list-batches",
					request:  map[string]any{"pageSize": -1},
					wantCode: http.StatusBadRequest,
					wantMsg:  "pageSize must be positive: -1",
				},
			}
			for _, tc := range tcs {
				t.Run(tc.name, func(t *testing.T) {
					t.Parallel()
					testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg)
				})
			}
		})
		t.Run("auth", func(t *testing.T) {
			t.Parallel()
			runAuthTest(t, "list-batches-with-auth", map[string]any{"pageSize": 1}, http.StatusOK)
		})
	})

	// The following tool tests are independent and can run in parallel with each other.
	t.Run("parallel-tool-tests", func(t *testing.T) {
		t.Run("get-batch", func(t *testing.T) {
			t.Parallel()
			fullName := listBatchesRpc(t, client, ctx, "", 1, true)[0].Name
			t.Run("success", func(t *testing.T) {
				t.Parallel()
				runGetBatchTest(t, client, ctx, fullName)
			})
			t.Run("errors", func(t *testing.T) {
				t.Parallel()
				missingBatchFullName := fmt.Sprintf("projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation)
				tcs := []struct {
					name     string
					toolName string
					request  map[string]any
					wantCode int
					wantMsg  string
				}{
					{
						name:     "missing batch",
						toolName: "get-batch",
						request:  map[string]any{"name": "INVALID_BATCH"},
						wantCode: http.StatusBadRequest,
						wantMsg:  fmt.Sprintf("Not found: Batch projects/%s/locations/%s/batches/INVALID_BATCH", serverlessSparkProject, serverlessSparkLocation),
					},
					{
						name:     "full batch name",
						toolName: "get-batch",
						request:  map[string]any{"name": missingBatchFullName},
						wantCode: http.StatusBadRequest,
						wantMsg:  fmt.Sprintf("name must be a short batch name without '/': %s", missingBatchFullName),
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg)
					})
				}
			})
			t.Run("auth", func(t *testing.T) {
				t.Parallel()
				runAuthTest(t, "get-batch-with-auth", map[string]any{"name": shortName(fullName)}, http.StatusOK)
			})
		})

		t.Run("create-pyspark-batch", func(t *testing.T) {
			t.Parallel()
			t.Run("success", func(t *testing.T) {
				t.Parallel()
				piPy := "file:///usr/lib/spark/examples/src/main/python/pi.py"
				tcs := []struct {
					name           string
					toolName       string
					request        map[string]any
					waitForSuccess bool
					validate       func(t *testing.T, b *dataprocpb.Batch)
				}{
					{
						name:           "no params",
						toolName:       "create-pyspark-batch",
						waitForSuccess: true,
						request:        map[string]any{"mainFile": piPy},
					},
					// Tests below are just verifying options are set correctly on created batches,
					// they don't need to wait for success.
					{
						name:     "with arg",
						toolName: "create-pyspark-batch",
						request:  map[string]any{"mainFile": piPy, "args": []string{"100"}},
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							if !cmp.Equal(b.GetPysparkBatch().Args, []string{"100"}) {
								t.Errorf("unexpected args: got %v, want %v", b.GetPysparkBatch().Args, []string{"100"})
							}
						},
					},
					{
						name:     "version",
						toolName: "create-pyspark-batch",
						request:  map[string]any{"mainFile": piPy, "version": "2.2"},
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							v := b.GetRuntimeConfig().GetVersion()
							if v != "2.2" {
								t.Errorf("unexpected version: got %v, want 2.2", v)
							}
						},
					},
					{
						name:     "version param overrides tool",
						toolName: "create-pyspark-batch-2-3",
						request:  map[string]any{"mainFile": piPy, "version": "2.2"},
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							v := b.GetRuntimeConfig().GetVersion()
							if v != "2.2" {
								t.Errorf("unexpected version: got %v, want 2.2", v)
							}
						},
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						runCreateSparkBatchTest(t, client, ctx, tc.toolName, tc.request, tc.waitForSuccess, tc.validate)
					})
				}
			})
			t.Run("auth", func(t *testing.T) {
				t.Parallel()
				// Batch creation succeeds even with an invalid main file, but will fail quickly once running.
				runAuthTest(t, "create-pyspark-batch-with-auth", map[string]any{"mainFile": "file:///placeholder"}, http.StatusOK)
			})
			t.Run("errors", func(t *testing.T) {
				t.Parallel()
				tcs := []struct {
					name    string
					request map[string]any
					wantMsg string
				}{
					{
						name:    "missing main file",
						request: map[string]any{},
						wantMsg: "parameter \\\"mainFile\\\" is required",
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						testError(t, "create-pyspark-batch", tc.request, http.StatusBadRequest, tc.wantMsg)
					})
				}
			})
		})

		t.Run("create-spark-batch", func(t *testing.T) {
			t.Parallel()
			t.Run("success", func(t *testing.T) {
				t.Parallel()
				tcs := []struct {
					name           string
					toolName       string
					request        map[string]any
					waitForSuccess bool
					validate       func(t *testing.T, b *dataprocpb.Batch)
				}{
					{
						name:           "main class",
						toolName:       "create-spark-batch",
						waitForSuccess: true,
						request:        javaReq(map[string]any{}),
					},
					{
						// spark-examples.jar doesn't have a Main-Class, so pick an arbitrary other
						// jar that does. Note there's a chance a subminor release of 2.2 will
						// upgrade Spark and its dependencies, causing a failure. If that happens,
						// find the new ivy jar filename and use that. The alternative would be to
						// pin a subminor version, but that's guaranteed to be GC'ed after 1 year,
						// whereas 2.2 is old enough it's unlikely to see a Spark version bump.
						name:           "main jar",
						toolName:       "create-spark-batch",
						waitForSuccess: true,
						request: map[string]any{
							"version":     "2.2",
							"mainJarFile": "file:///usr/lib/spark/jars/ivy-2.5.2.jar",
							"args":        []string{"-version"},
						},
					},
					// Tests below are just verifying options are set correctly on created batches,
					// they don't need to wait for success.
					{
						name:     "with arg",
						toolName: "create-spark-batch",
						request:  javaReq(map[string]any{"args": []string{"100"}}),
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							if !cmp.Equal(b.GetSparkBatch().Args, []string{"100"}) {
								t.Errorf("unexpected args: got %v, want %v", b.GetSparkBatch().Args, []string{"100"})
							}
						},
					},
					{
						name:     "version",
						toolName: "create-spark-batch",
						request:  javaReq(map[string]any{"version": "2.2"}),
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							v := b.GetRuntimeConfig().GetVersion()
							if v != "2.2" {
								t.Errorf("unexpected version: got %v, want 2.2", v)
							}
						},
					},
					{
						name:     "version param overrides tool",
						toolName: "create-spark-batch-2-3",
						request:  javaReq(map[string]any{"version": "2.2"}),
						validate: func(t *testing.T, b *dataprocpb.Batch) {
							v := b.GetRuntimeConfig().GetVersion()
							if v != "2.2" {
								t.Errorf("unexpected version: got %v, want 2.2", v)
							}
						},
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						runCreateSparkBatchTest(t, client, ctx, tc.toolName, tc.request, tc.waitForSuccess, tc.validate)
					})
				}
			})
			t.Run("auth", func(t *testing.T) {
				t.Parallel()
				// Batch creation succeeds even with an invalid main file, but will fail quickly once running.
				runAuthTest(t, "create-spark-batch-with-auth", map[string]any{"mainJarFile": "file:///placeholder"}, http.StatusOK)
			})
			t.Run("errors", func(t *testing.T) {
				t.Parallel()
				tcs := []struct {
					name    string
					request map[string]any
					wantMsg string
				}{
					{
						name:    "no main jar or main class",
						request: map[string]any{},
						wantMsg: "must provide either mainJarFile or mainClass",
					},
					{
						name: "both main jar and main class",
						request: map[string]any{
							"mainJarFile": "my.jar",
							"mainClass":   "com.example.MyClass",
						},
						wantMsg: "cannot provide both mainJarFile and mainClass",
					},
					{
						name: "main class without jar files",
						request: map[string]any{
							"mainClass": "com.example.MyClass",
						},
						wantMsg: "jarFiles is required when mainClass is provided",
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						testError(t, "create-spark-batch", tc.request, http.StatusBadRequest, tc.wantMsg)
					})
				}
			})
		})

		t.Run("cancel-batch", func(t *testing.T) {
			t.Parallel()
			t.Run("success", func(t *testing.T) {
				t.Parallel()
				tcs := []struct {
					name         string
					getBatchName func(t *testing.T) string
				}{
					{
						name: "running batch",
						getBatchName: func(t *testing.T) string {
							return createBatch(t, client, ctx)
						},
					},
					{
						name: "succeeded batch",
						getBatchName: func(t *testing.T) string {
							return listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Name
						},
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						runCancelBatchTest(t, client, ctx, tc.getBatchName(t))
					})
				}
			})
			t.Run("errors", func(t *testing.T) {
				t.Parallel()
				// Find a batch that's already completed.
				completedBatchOp := listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 1, true)[0].Operation
				fullOpName := fmt.Sprintf("projects/%s/locations/%s/operations/%s", serverlessSparkProject, serverlessSparkLocation, shortName(completedBatchOp))
				tcs := []struct {
					name     string
					toolName string
					request  map[string]any
					wantCode int
					wantMsg  string
				}{
					{
						name:     "missing op parameter",
						toolName: "cancel-batch",
						request:  map[string]any{},
						wantCode: http.StatusBadRequest,
						wantMsg:  "parameter \\\"operation\\\" is required",
					},
					{
						name:     "nonexistent op",
						toolName: "cancel-batch",
						request:  map[string]any{"operation": "INVALID_OPERATION"},
						wantCode: http.StatusBadRequest,
						wantMsg:  "Operation not found",
					},
					{
						name:     "full op name",
						toolName: "cancel-batch",
						request:  map[string]any{"operation": fullOpName},
						wantCode: http.StatusBadRequest,
						wantMsg:  fmt.Sprintf("operation must be a short operation name without '/': %s", fullOpName),
					},
				}
				for _, tc := range tcs {
					t.Run(tc.name, func(t *testing.T) {
						t.Parallel()
						testError(t, tc.toolName, tc.request, tc.wantCode, tc.wantMsg)
					})
				}
			})
			t.Run("auth", func(t *testing.T) {
				t.Parallel()
				runAuthTest(t, "cancel-batch-with-auth", map[string]any{"operation": "INVALID_OPERATION"}, http.StatusBadRequest)
			})
		})
	})
}

func waitForBatch(t *testing.T, client *dataproc.BatchControllerClient, parentCtx context.Context, batch string, desiredStates []dataprocpb.Batch_State, timeout time.Duration) {
	t.Logf("waiting %s for batch %s to reach one of %v", timeout, batch, desiredStates)
	ctx, cancel := context.WithTimeout(parentCtx, timeout)
	defer cancel()
	start := time.Now()
	lastLog := start
	for {
		select {
		case <-ctx.Done():
			t.Fatalf("timed out waiting for batch %s to reach one of %v", batch, desiredStates)
		default:
		}
		getReq := &dataprocpb.GetBatchRequest{Name: batch}
		batch, err := client.GetBatch(ctx, getReq)
		if err != nil {
			t.Fatalf("failed to get batch %s: %v", batch, err)
		}
		now := time.Now()
		if now.Sub(lastLog) >= 30*time.Second {
			t.Logf("%s: batch %s is in state %s after %s", t.Name(), batch.Name, batch.State, now.Sub(start))
			lastLog = now
		}
		if slices.Contains(desiredStates, batch.State) {
			return
		}
		if batch.State == dataprocpb.Batch_FAILED || batch.State == dataprocpb.Batch_CANCELLED || batch.State == dataprocpb.Batch_SUCCEEDED {
			t.Fatalf("batch op %s is in a terminal state %s, but wanted one of %v. State message: %s", batch.Name, batch.State, desiredStates, batch.StateMessage)
		}
		time.Sleep(2 * time.Second)
	}
}

// createBatch creates a test batch and immediately returns the batch name, without waiting for the
// batch to start or complete.
func createBatch(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) string {
	parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation)
	req := &dataprocpb.CreateBatchRequest{
		Parent: parent,
		Batch: &dataprocpb.Batch{
			BatchConfig: &dataprocpb.Batch_SparkBatch{
				SparkBatch: &dataprocpb.SparkBatch{
					Driver: &dataprocpb.SparkBatch_MainClass{
						MainClass: "org.apache.spark.examples.SparkPi",
					},
					JarFileUris: []string{
						"file:///usr/lib/spark/examples/jars/spark-examples.jar",
					},
					Args: []string{"1000"},
				},
			},
			EnvironmentConfig: &dataprocpb.EnvironmentConfig{
				ExecutionConfig: &dataprocpb.ExecutionConfig{
					ServiceAccount: serverlessSparkServiceAccount,
				},
			},
		},
	}
	createOp, err := client.CreateBatch(ctx, req)
	if err != nil {
		t.Fatalf("failed to create batch: %v", err)
	}
	meta, err := createOp.Metadata()
	if err != nil {
		t.Fatalf("failed to get batch metadata: %v", err)
	}
	// Wait for the batch to become at least PENDING; it typically takes >10s to go from PENDING to
	// RUNNING, giving the cancel batch tests plenty of time to cancel it before it completes.
	waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_PENDING, dataprocpb.Batch_RUNNING}, 1*time.Minute)
	return meta.Batch
}

func runCancelBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, batchName string) {
	// First get the batch details directly from the Go proto API.
	batch, err := client.GetBatch(ctx, &dataprocpb.GetBatchRequest{Name: batchName})
	if err != nil {
		t.Fatalf("failed to get batch: %s", err)
	}
	request := map[string]any{"operation": shortName(batch.Operation)}
	resp, err := invokeTool("cancel-batch", request, nil)
	if err != nil {
		t.Fatalf("invokeTool failed: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
	}
	if batch.State != dataprocpb.Batch_SUCCEEDED {
		waitForBatch(t, client, ctx, batchName, []dataprocpb.Batch_State{dataprocpb.Batch_CANCELLING, dataprocpb.Batch_CANCELLED}, 2*time.Minute)
	}
}

// runListBatchesTest invokes the running list-batches tool and ensures it returns the correct
// number of results. It can run successfully against any GCP project that contains at least 2 total
// Serverless Spark batches.
func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context) {
	batch2 := listBatchesRpc(t, client, ctx, "", 2, true)
	batch20 := listBatchesRpc(t, client, ctx, "", 20, false)
	tcs := []struct {
		name     string
		filter   string
		pageSize int
		numPages int
		want     []serverlesssparklistbatches.Batch
	}{
		{name: "one page", pageSize: 2, numPages: 1, want: batch2},
		{name: "two pages", pageSize: 1, numPages: 2, want: batch2},
		{name: "20 batches", pageSize: 20, numPages: 1, want: batch20},
		{name: "omit page size", numPages: 1, want: batch20},
		{
			name:     "filtered",
			filter:   "state = SUCCEEDED",
			pageSize: 2,
			numPages: 1,
			want:     listBatchesRpc(t, client, ctx, "state = SUCCEEDED", 2, true),
		},
		{
			name:     "empty",
			filter:   "state = SUCCEEDED AND state = FAILED",
			pageSize: 1,
			numPages: 1,
			want:     nil,
		},
	}
	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			var actual []serverlesssparklistbatches.Batch
			var pageToken string
			for i := 0; i < tc.numPages; i++ {
				request := map[string]any{
					"filter":    tc.filter,
					"pageToken": pageToken,
				}
				if tc.pageSize > 0 {
					request["pageSize"] = tc.pageSize
				}
				resp, err := invokeTool("list-batches", request, nil)
				if err != nil {
					t.Fatalf("invokeTool failed: %v", err)
				}
				defer resp.Body.Close()
				if resp.StatusCode != http.StatusOK {
					bodyBytes, _ := io.ReadAll(resp.Body)
					t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
				}
				var body map[string]any
				if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
					t.Fatalf("error parsing response body: %v", err)
				}
				result, ok := body["result"].(string)
				if !ok {
					t.Fatalf("unable to find result in response body")
				}
				var listResponse serverlesssparklistbatches.ListBatchesResponse
				if err := json.Unmarshal([]byte(result), &listResponse); err != nil {
					t.Fatalf("error unmarshalling result: %s", err)
				}
				actual = append(actual, listResponse.Batches...)
				pageToken = listResponse.NextPageToken
			}
			if !reflect.DeepEqual(actual, tc.want) {
				t.Fatalf("unexpected batches: got %+v, want %+v", actual, tc.want)
			}
			// want has URLs because it's created from Batch instances by the same utility function
			// used by the tool internals. Double-check that the URLs are reasonable.
			for _, batch := range tc.want {
				if !strings.HasPrefix(batch.ConsoleURL, batchURLPrefix) {
					t.Errorf("unexpected consoleUrl in batch: %#v", batch)
				}
				if !strings.HasPrefix(batch.LogsURL, logsURLPrefix) {
					t.Errorf("unexpected logsUrl in batch: %#v", batch)
				}
			}
		})
	}
}

func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, filter string, n int, exact bool) []serverlesssparklistbatches.Batch {
	parent := fmt.Sprintf("projects/%s/locations/%s", serverlessSparkProject, serverlessSparkLocation)
	req := &dataprocpb.ListBatchesRequest{
		Parent:   parent,
		PageSize: 2,
		OrderBy:  "create_time desc",
	}
	if filter != "" {
		req.Filter = filter
	}
	it := client.ListBatches(ctx, req)
	pager := iterator.NewPager(it, n, "")
	var batchPbs []*dataprocpb.Batch
	_, err := pager.NextPage(&batchPbs)
	if err != nil {
		t.Fatalf("failed to list batches: %s", err)
	}
	if exact && len(batchPbs) != n {
		t.Fatalf("expected exactly %d batches, got %d", n, len(batchPbs))
	}
	if !exact && (len(batchPbs) == 0 || len(batchPbs) > n) {
		t.Fatalf("expected between 1 and %d batches, got %d", n, len(batchPbs))
	}
	batches, err := serverlesssparklistbatches.ToBatches(batchPbs)
	if err != nil {
		t.Fatalf("failed to convert batches to JSON: %v", err)
	}
	return batches
}

func runAuthTest(t *testing.T, toolName string, request map[string]any, wantStatus int) {
	idToken, err := tests.GetGoogleIdToken(tests.ClientId)
	if err != nil {
		t.Fatalf("error getting Google ID token: %s", err)
	}
	tcs := []struct {
		name       string
		headers    map[string]string
		wantStatus int
	}{
		{
			name:       "valid auth token",
			headers:    map[string]string{"my-google-auth_token": idToken},
			wantStatus: wantStatus,
		},
		{
			name:       "invalid auth token",
			headers:    map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
			wantStatus: http.StatusUnauthorized,
		},
		{
			name:       "no auth token",
			headers:    nil,
			wantStatus: http.StatusUnauthorized,
		},
	}
	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			resp, err := invokeTool(toolName, request, tc.headers)
			if err != nil {
				t.Fatalf("invokeTool failed: %v", err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != tc.wantStatus {
				bodyBytes, _ := io.ReadAll(resp.Body)
				t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatus, resp.StatusCode, string(bodyBytes))
			}
		})
	}
}

func runGetBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx context.Context, fullName string) {
	// First get the batch details directly from the Go proto API.
	req := &dataprocpb.GetBatchRequest{
		Name: fullName,
	}
	rawWantBatchPb, err := client.GetBatch(ctx, req)
	if err != nil {
		t.Fatalf("failed to get batch: %s", err)
	}
	// Trim unknown fields from the proto by marshalling and unmarshalling.
	jsonBytes, err := protojson.Marshal(rawWantBatchPb)
	if err != nil {
		t.Fatalf("failed to marshal batch to JSON: %s", err)
	}
	var wantBatchPb dataprocpb.Batch
	if err := protojson.Unmarshal(jsonBytes, &wantBatchPb); err != nil {
		t.Fatalf("error unmarshalling result: %s", err)
	}
	tcs := []struct {
		name      string
		batchName string
		want      *dataprocpb.Batch
	}{
		{
			name:      "found batch",
			batchName: shortName(fullName),
			want:      &wantBatchPb,
		},
	}
	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			request := map[string]any{"name": tc.batchName}
			resp, err := invokeTool("get-batch", request, nil)
			if err != nil {
				t.Fatalf("invokeTool failed: %v", err)
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusOK {
				bodyBytes, _ := io.ReadAll(resp.Body)
				t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
			}
			var body map[string]any
			if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
				t.Fatalf("error parsing response body: %v", err)
			}
			result, ok := body["result"].(string)
			if !ok {
				t.Fatalf("unable to find result in response body")
			}
			var wrappedResult map[string]any
			if err := json.Unmarshal([]byte(result), &wrappedResult); err != nil {
				t.Fatalf("error unmarshalling result: %s", err)
			}
			consoleURL, ok := wrappedResult["consoleUrl"].(string)
			if !ok || !strings.HasPrefix(consoleURL, batchURLPrefix) {
				t.Errorf("unexpected consoleUrl: %v", consoleURL)
			}
			logsURL, ok := wrappedResult["logsUrl"].(string)
			if !ok || !strings.HasPrefix(logsURL, logsURLPrefix) {
				t.Errorf("unexpected logsUrl: %v", logsURL)
			}
			batchJSON, err := json.Marshal(wrappedResult["batch"])
			if err != nil {
				t.Fatalf("failed to marshal batch: %v", err)
			}
			// Unmarshal JSON to proto for proto-aware deep comparison.
			var batch dataprocpb.Batch
			if err := protojson.Unmarshal(batchJSON, &batch); err != nil {
				t.Fatalf("error unmarshalling batch from wrapped result: %s", err)
			}
			if !cmp.Equal(&batch, tc.want, protocmp.Transform()) {
				diff := cmp.Diff(&batch, tc.want, protocmp.Transform())
				t.Errorf("GetBatch() returned diff (-got +want):\n%s", diff)
			}
		})
	}
}

func javaReq(req map[string]any) map[string]any {
	merged := map[string]any{
		"mainClass": "org.apache.spark.examples.SparkPi",
		"jarFiles":  []string{"file:///usr/lib/spark/examples/jars/spark-examples.jar"},
	}
	maps.Copy(merged, req)
	return merged
}

func runCreateSparkBatchTest(
	t *testing.T,
	client *dataproc.BatchControllerClient,
	ctx context.Context,
	toolName string,
	request map[string]any,
	waitForSuccess bool,
	validate func(t *testing.T, b *dataprocpb.Batch),
) {
	resp, err := invokeTool(toolName, request, nil)
	if err != nil {
		t.Fatalf("invokeTool failed: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(resp.Body)
		t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
	}
	var body map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
		t.Fatalf("error parsing response body: %v", err)
	}
	result, ok := body["result"].(string)
	if !ok {
		t.Fatalf("unable to find result in response body")
	}
	var resultMap map[string]any
	if err := json.Unmarshal([]byte(result), &resultMap); err != nil {
		t.Fatalf("failed to unmarshal result: %v", err)
	}
	consoleURL, ok := resultMap["consoleUrl"].(string)
	if !ok || !strings.HasPrefix(consoleURL, batchURLPrefix) {
		t.Errorf("unexpected consoleUrl: %v", consoleURL)
	}
	logsURL, ok := resultMap["logsUrl"].(string)
	if !ok || !strings.HasPrefix(logsURL, logsURLPrefix) {
		t.Errorf("unexpected logsUrl: %v", logsURL)
	}
	metaMap, ok := resultMap["opMetadata"].(map[string]any)
	if !ok {
		t.Fatalf("unexpected opMetadata: %v", metaMap)
	}
	metaJson, err := json.Marshal(metaMap)
	if err != nil {
		t.Fatalf("failed to marshal op metadata to JSON: %s", err)
	}
	var meta dataprocpb.BatchOperationMetadata
	if err := json.Unmarshal([]byte(metaJson), &meta); err != nil {
		t.Fatalf("failed to unmarshal result: %v", err)
	}
	if validate != nil {
		b, err := client.GetBatch(ctx, &dataprocpb.GetBatchRequest{Name: meta.Batch})
		if err != nil {
			t.Fatalf("failed to get batch %s: %s", meta.Batch, err)
		}
		validate(t, b)
	}
	if waitForSuccess {
		waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_SUCCEEDED}, 5*time.Minute)
	}
}

func testError(t *testing.T, toolName string, request map[string]any, wantCode int, wantMsg string) {
	resp, err := invokeTool(toolName, request, nil)
	if err != nil {
		t.Fatalf("invokeTool failed: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != wantCode {
		bodyBytes, _ := io.ReadAll(resp.Body)
		t.Fatalf("response status code is not %d, got %d: %s", wantCode, resp.StatusCode, string(bodyBytes))
	}
	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("failed to read response body: %v", err)
	}
	if !bytes.Contains(bodyBytes, []byte(wantMsg)) {
		t.Fatalf("response body does not contain %q: %s", wantMsg, string(bodyBytes))
	}
}

func invokeTool(toolName string, request map[string]any, headers map[string]string) (*http.Response, error) {
	requestBytes, err := json.Marshal(request)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	url := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", toolName)
	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(requestBytes))
	if err != nil {
		return nil, fmt.Errorf("unable to create request: %w", err)
	}
	req.Header.Add("Content-type", "application/json")
	for k, v := range headers {
		req.Header.Add(k, v)
	}
	return http.DefaultClient.Do(req)
}

func shortName(fullName string) string {
	parts := strings.Split(fullName, "/")
	return parts[len(parts)-1]
}
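
For context on how these endpoints are exercised: the test's invokeTool helper POSTs a JSON body to the toolbox server's /api/tool/<name>/invoke endpoint on 127.0.0.1:5000 and reads back a JSON object whose result field is itself a JSON-encoded string. Below is a minimal standalone sketch of the same call, assuming a toolbox server is already running locally with the list-batches tool configured as in the test above; the pageSize value is an arbitrary illustration, not something the test mandates.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Build the tool request body; pageSize here is an arbitrary example value.
	reqBody, err := json.Marshal(map[string]any{"pageSize": 2})
	if err != nil {
		panic(err)
	}

	// POST to the invoke endpoint, mirroring the test's invokeTool helper
	// (same localhost address and /api/tool/<name>/invoke path).
	resp, err := http.Post(
		"http://127.0.0.1:5000/api/tool/list-batches/invoke",
		"application/json",
		bytes.NewReader(reqBody),
	)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// As in the test, the response is a JSON object whose "result" field
	// holds the tool output as a JSON-encoded string.
	var body map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
		panic(err)
	}
	fmt.Println(body["result"])
}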
