Skip to main content
Glama
googleapis

MCP Toolbox for Databases

by googleapis
tool.go154 kB
// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package tests import ( "bytes" "context" "database/sql" "encoding/json" "fmt" "io" "net/http" "reflect" "sort" "strings" "sync" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/jackc/pgx/v5/pgxpool" ) // RunToolGet runs the tool get endpoint func RunToolGetTest(t *testing.T) { // Test tool get endpoint tcs := []struct { name string api string want map[string]any }{ { name: "get my-simple-tool", api: "http://127.0.0.1:5000/api/tool/my-simple-tool/", want: map[string]any{ "my-simple-tool": map[string]any{ "description": "Simple tool to test end to end functionality.", "parameters": []any{}, "authRequired": []any{}, }, }, }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { resp, err := http.Get(tc.api) if err != nil { t.Fatalf("error when sending a request: %s", err) } defer resp.Body.Close() if resp.StatusCode != 200 { t.Fatalf("response status code is not 200") } var body map[string]interface{} err = json.NewDecoder(resp.Body).Decode(&body) if err != nil { t.Fatalf("error parsing response body") } got, ok := body["tools"] if !ok { t.Fatalf("unable to find tools in response body") } if !reflect.DeepEqual(got, tc.want) { t.Fatalf("got %q, want %q", got, tc.want) } }) } } func RunToolGetTestByName(t *testing.T, name string, want map[string]any) { // Test tool get endpoint tcs := []struct { name string api string want map[string]any }{ { name: fmt.Sprintf("get %s", name), api: fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/", name), want: want, }, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { resp, err := http.Get(tc.api) if err != nil { t.Fatalf("error when sending a request: %s", err) } defer resp.Body.Close() if resp.StatusCode != 200 { t.Fatalf("response status code is not 200") } var body map[string]interface{} err = json.NewDecoder(resp.Body).Decode(&body) if err != nil { t.Fatalf("error parsing response body") } got, ok := body["tools"] if !ok { t.Fatalf("unable to find tools in response body") } if !reflect.DeepEqual(got, tc.want) { t.Fatalf("got %q, want %q", got, tc.want) } }) } } // RunToolInvokeSimpleTest runs the tool invoke endpoint with no parameters func RunToolInvokeSimpleTest(t *testing.T, name string, simpleWant string) { // Test tool invoke endpoint invokeTcs := []struct { name string api string requestHeader map[string]string requestBody io.Reader want string isErr bool }{ { name: fmt.Sprintf("invoke %s", name), api: fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", name), requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), want: simpleWant, isErr: false, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { // Send Tool invocation request resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) if resp.StatusCode != http.StatusOK { if tc.isErr { return } t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } // Check response body var body map[string]interface{} err := json.Unmarshal(respBody, &body) if err != nil { t.Fatalf("error parsing response body") } got, ok := body["result"].(string) if !ok { t.Fatalf("unable to find result in response body") } if !strings.Contains(got, tc.want) { t.Fatalf("unexpected value: got %q, want %q", got, tc.want) } }) } } func RunToolInvokeParametersTest(t *testing.T, name string, params []byte, simpleWant string) { // Test tool invoke endpoint invokeTcs := []struct { name string api string requestHeader map[string]string requestBody io.Reader want string isErr bool }{ { name: fmt.Sprintf("invoke %s", name), api: fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", name), requestHeader: map[string]string{}, requestBody: bytes.NewBuffer(params), want: simpleWant, isErr: false, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { // Send Tool invocation request resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) if resp.StatusCode != http.StatusOK { if tc.isErr { return } t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } // Check response body var body map[string]interface{} err := json.Unmarshal(respBody, &body) if err != nil { t.Fatalf("error parsing response body") } got, ok := body["result"].(string) if !ok { t.Fatalf("unable to find result in response body") } if !strings.Contains(got, tc.want) { t.Fatalf("unexpected value: got %q, want %q", got, tc.want) } }) } } // RunToolInvoke runs the tool invoke endpoint func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOption) { // Resolve options // Default values for InvokeTestConfig configs := &InvokeTestConfig{ myToolId3NameAliceWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]", myToolById4Want: "[{\"id\":4,\"name\":null}]", myArrayToolWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]", nullWant: "null", supportOptionalNullParam: true, supportArrayParam: true, supportClientAuth: false, supportSelect1Want: true, supportSelect1Auth: true, } // Apply provided options for _, option := range options { option(configs) } // Get ID token idToken, err := GetGoogleIdToken(ClientId) if err != nil { t.Fatalf("error getting Google ID token: %s", err) } // Get access token accessToken, err := sources.GetIAMAccessToken(t.Context()) if err != nil { t.Fatalf("error getting access token from ADC: %s", err) } accessToken = "Bearer " + accessToken // Test tool invoke endpoint invokeTcs := []struct { name string api string enabled bool requestHeader map[string]string requestBody io.Reader wantStatusCode int wantBody string }{ { name: "invoke my-simple-tool", api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke", enabled: configs.supportSelect1Want, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: select1Want, wantStatusCode: http.StatusOK, }, { name: "invoke my-tool", api: "http://127.0.0.1:5000/api/tool/my-tool/invoke", enabled: true, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)), wantBody: configs.myToolId3NameAliceWant, wantStatusCode: http.StatusOK, }, { name: "invoke my-tool-by-id with nil response", api: "http://127.0.0.1:5000/api/tool/my-tool-by-id/invoke", enabled: true, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"id": 4}`)), wantBody: configs.myToolById4Want, wantStatusCode: http.StatusOK, }, { name: "invoke my-tool-by-name with nil response", api: "http://127.0.0.1:5000/api/tool/my-tool-by-name/invoke", enabled: configs.supportOptionalNullParam, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: configs.nullWant, wantStatusCode: http.StatusOK, }, { name: "Invoke my-tool without parameters", api: "http://127.0.0.1:5000/api/tool/my-tool/invoke", enabled: true, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: "", wantStatusCode: http.StatusBadRequest, }, { name: "Invoke my-tool with insufficient parameters", api: "http://127.0.0.1:5000/api/tool/my-tool/invoke", enabled: true, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)), wantBody: "", wantStatusCode: http.StatusBadRequest, }, { name: "invoke my-array-tool", api: "http://127.0.0.1:5000/api/tool/my-array-tool/invoke", enabled: configs.supportArrayParam, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"idArray": [1,2,3], "nameArray": ["Alice", "Sid", "RandomName"], "cmdArray": ["HGETALL", "row3"]}`)), wantBody: configs.myArrayToolWant, wantStatusCode: http.StatusOK, }, { name: "Invoke my-auth-tool with auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", enabled: configs.supportSelect1Auth, requestHeader: map[string]string{"my-google-auth_token": idToken}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: configs.myAuthToolWant, wantStatusCode: http.StatusOK, }, { name: "Invoke my-auth-tool with invalid auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", enabled: configs.supportSelect1Auth, requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: "", wantStatusCode: http.StatusUnauthorized, }, { name: "Invoke my-auth-tool without auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", enabled: true, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: "", wantStatusCode: http.StatusUnauthorized, }, { name: "Invoke my-auth-required-tool with auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke", enabled: configs.supportSelect1Auth, requestHeader: map[string]string{"my-google-auth_token": idToken}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: select1Want, wantStatusCode: http.StatusOK, }, { name: "Invoke my-auth-required-tool with invalid auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke", enabled: true, requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: "", wantStatusCode: http.StatusUnauthorized, }, { name: "Invoke my-auth-required-tool without auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", enabled: true, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: "", wantStatusCode: http.StatusUnauthorized, }, { name: "Invoke my-client-auth-tool with auth token", api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke", enabled: configs.supportClientAuth, requestHeader: map[string]string{"Authorization": accessToken}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: select1Want, wantStatusCode: http.StatusOK, }, { name: "Invoke my-client-auth-tool without auth token", api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke", enabled: configs.supportClientAuth, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantStatusCode: http.StatusUnauthorized, }, { name: "Invoke my-client-auth-tool with invalid auth token", api: "http://127.0.0.1:5000/api/tool/my-client-auth-tool/invoke", enabled: configs.supportClientAuth, requestHeader: map[string]string{"Authorization": "Bearer invalid-token"}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantStatusCode: http.StatusUnauthorized, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { if !tc.enabled { return } // Send Tool invocation request resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) // Check status code if resp.StatusCode != tc.wantStatusCode { t.Errorf("StatusCode mismatch: got %d, want %d. Response body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } // skip response body check if tc.wantBody == "" { return } // Check response body var body map[string]interface{} err = json.Unmarshal(respBody, &body) if err != nil { t.Fatalf("error parsing response body: %s", err) } got, ok := body["result"].(string) if !ok { t.Fatalf("unable to find result in response body") } if got != tc.wantBody { t.Fatalf("unexpected value: got %q, want %q", got, tc.wantBody) } }) } } // RunToolInvokeWithTemplateParameters runs tool invoke test cases with template parameters. func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options ...TemplateParamOption) { // Resolve options // Default values for TemplateParameterTestConfig configs := &TemplateParameterTestConfig{ ddlWant: "null", selectAllWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]", selectId1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]", selectNameWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]", selectEmptyWant: "null", insert1Want: "null", nameFieldArray: `["name"]`, nameColFilter: "name", createColArray: `["id INT","name VARCHAR(20)","age INT"]`, supportDdl: true, supportInsert: true, } // Apply provided options for _, option := range options { option(configs) } selectOnlyNamesWant := "[{\"name\":\"Alex\"},{\"name\":\"Alice\"}]" // Test tool invoke endpoint invokeTcs := []struct { name string enabled bool ddl bool insert bool api string requestHeader map[string]string requestBody io.Reader want string isErr bool }{ { name: "invoke create-table-templateParams-tool", ddl: true, api: "http://127.0.0.1:5000/api/tool/create-table-templateParams-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":%s}`, tableName, configs.createColArray))), want: configs.ddlWant, isErr: false, }, { name: "invoke insert-table-templateParams-tool", insert: true, api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"1, 'Alex', 21"}`, tableName))), want: configs.insert1Want, isErr: false, }, { name: "invoke insert-table-templateParams-tool", insert: true, api: "http://127.0.0.1:5000/api/tool/insert-table-templateParams-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":["id","name","age"], "values":"2, 'Alice', 100"}`, tableName))), want: configs.insert1Want, isErr: false, }, { name: "invoke select-templateParams-tool", api: "http://127.0.0.1:5000/api/tool/select-templateParams-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))), want: configs.selectAllWant, isErr: false, }, { name: "invoke select-templateParams-combined-tool", api: "http://127.0.0.1:5000/api/tool/select-templateParams-combined-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"id": 1, "tableName": "%s"}`, tableName))), want: configs.selectId1Want, isErr: false, }, { name: "invoke select-templateParams-combined-tool with no results", api: "http://127.0.0.1:5000/api/tool/select-templateParams-combined-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"id": 999, "tableName": "%s"}`, tableName))), want: configs.selectEmptyWant, isErr: false, }, { name: "invoke select-fields-templateParams-tool", enabled: configs.supportSelectFields, api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "fields":%s}`, tableName, configs.nameFieldArray))), want: selectOnlyNamesWant, isErr: false, }, { name: "invoke select-filter-templateParams-combined-tool", api: "http://127.0.0.1:5000/api/tool/select-filter-templateParams-combined-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"name": "Alex", "tableName": "%s", "columnFilter": "%s"}`, tableName, configs.nameColFilter))), want: configs.selectNameWant, isErr: false, }, { name: "invoke drop-table-templateParams-tool", ddl: true, api: "http://127.0.0.1:5000/api/tool/drop-table-templateParams-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))), want: configs.ddlWant, isErr: false, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { if !tc.enabled { return } // if test case is DDL and source support ddl test cases ddlAllow := !tc.ddl || (tc.ddl && configs.supportDdl) // if test case is insert statement and source support insert test cases insertAllow := !tc.insert || (tc.insert && configs.supportInsert) if ddlAllow && insertAllow { // Send Tool invocation request resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) if resp.StatusCode != http.StatusOK { if tc.isErr { return } t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } // Check response body var body map[string]interface{} err := json.Unmarshal(respBody, &body) if err != nil { t.Fatalf("error parsing response body") } got, ok := body["result"].(string) if !ok { t.Fatalf("unable to find result in response body") } if got != tc.want { t.Fatalf("unexpected value: got %q, want %q", got, tc.want) } } }) } } func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want string, options ...ExecuteSqlOption) { // Resolve options // Default values for ExecuteSqlTestConfig configs := &ExecuteSqlTestConfig{ select1Statement: `"SELECT 1"`, } // Apply provided options for _, option := range options { option(configs) } // Get ID token idToken, err := GetGoogleIdToken(ClientId) if err != nil { t.Fatalf("error getting Google ID token: %s", err) } // Test tool invoke endpoint invokeTcs := []struct { name string api string requestHeader map[string]string requestBody io.Reader want string isErr bool }{ { name: "invoke my-exec-sql-tool", api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))), want: select1Want, isErr: false, }, { name: "invoke my-exec-sql-tool create table", api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, createTableStatement))), want: "null", isErr: false, }, { name: "invoke my-exec-sql-tool select table", api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT * FROM t"}`)), want: "null", isErr: false, }, { name: "invoke my-exec-sql-tool drop table", api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"sql":"DROP TABLE t"}`)), want: "null", isErr: false, }, { name: "invoke my-exec-sql-tool without body", api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), isErr: true, }, { name: "Invoke my-auth-exec-sql-tool with auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", requestHeader: map[string]string{"my-google-auth_token": idToken}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))), isErr: false, want: select1Want, }, { name: "Invoke my-auth-exec-sql-tool with invalid auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))), isErr: true, }, { name: "Invoke my-auth-exec-sql-tool without auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sql": %s}`, configs.select1Statement))), isErr: true, }, { name: "invoke my-exec-sql-tool with invalid SELECT SQL", api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT * FROM non_existent_table"}`)), isErr: true, }, { name: "invoke my-exec-sql-tool with invalid ALTER SQL", api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"sql":"ALTER TALE t ALTER COLUMN id DROP NOT NULL"}`)), isErr: true, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { // Send Tool invocation request resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) if resp.StatusCode != http.StatusOK { if tc.isErr { return } t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } // Check response body var body map[string]interface{} err = json.Unmarshal(respBody, &body) if err != nil { t.Fatalf("error parsing response body") } got, ok := body["result"].(string) if !ok { t.Fatalf("unable to find result in response body") } if got != tc.want { t.Fatalf("unexpected value: got %q, want %q", got, tc.want) } }) } } // RunInitialize runs the initialize lifecycle for mcp to set up client-server connection func RunInitialize(t *testing.T, protocolVersion string) string { url := "http://127.0.0.1:5000/mcp" initializeRequestBody := map[string]any{ "jsonrpc": "2.0", "id": "mcp-initialize", "method": "initialize", "params": map[string]any{ "protocolVersion": protocolVersion, }, } reqMarshal, err := json.Marshal(initializeRequestBody) if err != nil { t.Fatalf("unexpected error during marshaling of body") } resp, _ := RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqMarshal), nil) if resp.StatusCode != 200 { t.Fatalf("response status code is not 200") } if contentType := resp.Header.Get("Content-type"); contentType != "application/json" { t.Fatalf("unexpected content-type header: want %s, got %s", "application/json", contentType) } sessionId := resp.Header.Get("Mcp-Session-Id") header := map[string]string{} if sessionId != "" { header["Mcp-Session-Id"] = sessionId } initializeNotificationBody := map[string]any{ "jsonrpc": "2.0", "method": "notifications/initialized", } notiMarshal, err := json.Marshal(initializeNotificationBody) if err != nil { t.Fatalf("unexpected error during marshaling of notifications body") } _, _ = RunRequest(t, http.MethodPost, url, bytes.NewBuffer(notiMarshal), header) return sessionId } // RunMCPToolCallMethod runs the tool/call for mcp endpoint func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, options ...McpTestOption) { // Resolve options // Default values for MCPTestConfig configs := &MCPTestConfig{ myToolId3NameAliceWant: `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`, mcpSelect1Want: select1Want, supportClientAuth: false, supportSelect1Auth: true, } // Apply provided options for _, option := range options { option(configs) } sessionId := RunInitialize(t, "2024-11-05") // Get access token accessToken, err := sources.GetIAMAccessToken(t.Context()) if err != nil { t.Fatalf("error getting access token from ADC: %s", err) } accessToken = "Bearer " + accessToken idToken, err := GetGoogleIdToken(ClientId) if err != nil { t.Fatalf("error getting Google ID token: %s", err) } // Test tool invoke endpoint invokeTcs := []struct { name string api string enabled bool // switch to turn on/off the test case requestBody jsonrpc.JSONRPCRequest requestHeader map[string]string wantStatusCode int wantBody string }{ { name: "MCP Invoke my-tool", api: "http://127.0.0.1:5000/mcp", enabled: true, requestHeader: map[string]string{}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", Id: "my-tool", Request: jsonrpc.Request{ Method: "tools/call", }, Params: map[string]any{ "name": "my-tool", "arguments": map[string]any{ "id": int(3), "name": "Alice", }, }, }, wantStatusCode: http.StatusOK, wantBody: configs.myToolId3NameAliceWant, }, { name: "MCP Invoke invalid tool", api: "http://127.0.0.1:5000/mcp", enabled: true, requestHeader: map[string]string{}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", Id: "invalid-tool", Request: jsonrpc.Request{ Method: "tools/call", }, Params: map[string]any{ "name": "foo", "arguments": map[string]any{}, }, }, wantStatusCode: http.StatusOK, wantBody: `{"jsonrpc":"2.0","id":"invalid-tool","error":{"code":-32602,"message":"invalid tool name: tool with name \"foo\" does not exist"}}`, }, { name: "MCP Invoke my-tool without parameters", api: "http://127.0.0.1:5000/mcp", enabled: true, requestHeader: map[string]string{}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", Id: "invoke-without-parameter", Request: jsonrpc.Request{ Method: "tools/call", }, Params: map[string]any{ "name": "my-tool", "arguments": map[string]any{}, }, }, wantStatusCode: http.StatusOK, wantBody: `{"jsonrpc":"2.0","id":"invoke-without-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"id\" is required"}}`, }, { name: "MCP Invoke my-tool with insufficient parameters", api: "http://127.0.0.1:5000/mcp", enabled: true, requestHeader: map[string]string{}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", Id: "invoke-insufficient-parameter", Request: jsonrpc.Request{ Method: "tools/call", }, Params: map[string]any{ "name": "my-tool", "arguments": map[string]any{"id": 1}, }, }, wantStatusCode: http.StatusOK, wantBody: `{"jsonrpc":"2.0","id":"invoke-insufficient-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"name\" is required"}}`, }, { name: "MCP Invoke my-auth-required-tool", api: "http://127.0.0.1:5000/mcp", enabled: configs.supportSelect1Auth, requestHeader: map[string]string{"my-google-auth_token": idToken}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", Id: "invoke my-auth-required-tool", Request: jsonrpc.Request{ Method: "tools/call", }, Params: map[string]any{ "name": "my-auth-required-tool", "arguments": map[string]any{}, }, }, wantStatusCode: http.StatusOK, wantBody: configs.mcpSelect1Want, }, { name: "MCP Invoke my-auth-required-tool with invalid auth token", api: "http://127.0.0.1:5000/mcp", requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", Id: "invoke my-auth-required-tool with invalid token", Request: jsonrpc.Request{ Method: "tools/call", }, Params: map[string]any{ "name": "my-auth-required-tool", "arguments": map[string]any{}, }, }, wantStatusCode: http.StatusUnauthorized, wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool with invalid token\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized\"}}", }, { name: "MCP Invoke my-auth-required-tool without auth token", api: "http://127.0.0.1:5000/mcp", requestHeader: map[string]string{}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", Id: "invoke my-auth-required-tool without token", Request: jsonrpc.Request{ Method: "tools/call", }, Params: map[string]any{ "name": "my-auth-required-tool", "arguments": map[string]any{}, }, }, wantStatusCode: http.StatusUnauthorized, wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool without token\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized\"}}", }, { name: "MCP Invoke my-client-auth-tool", enabled: configs.supportClientAuth, api: "http://127.0.0.1:5000/mcp", requestHeader: map[string]string{"Authorization": accessToken}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", Id: "invoke my-client-auth-tool", Request: jsonrpc.Request{ Method: "tools/call", }, Params: map[string]any{ "name": "my-client-auth-tool", "arguments": map[string]any{}, }, }, wantStatusCode: http.StatusOK, wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-client-auth-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"{\\\"f0_\\\":1}\"}]}}", }, { name: "MCP Invoke my-client-auth-tool without access token", enabled: configs.supportClientAuth, api: "http://127.0.0.1:5000/mcp", requestHeader: map[string]string{}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", Id: "invoke my-client-auth-tool", Request: jsonrpc.Request{ Method: "tools/call", }, Params: map[string]any{ "name": "my-client-auth-tool", "arguments": map[string]any{}, }, }, wantStatusCode: http.StatusUnauthorized, wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-client-auth-tool\",\"error\":{\"code\":-32600,\"message\":\"missing access token in the 'Authorization' header\"}", }, { name: "MCP Invoke my-client-auth-tool with invalid access token", enabled: configs.supportClientAuth, api: "http://127.0.0.1:5000/mcp", requestHeader: map[string]string{"Authorization": "Bearer invalid-token"}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", Id: "invoke my-client-auth-tool", Request: jsonrpc.Request{ Method: "tools/call", }, Params: map[string]any{ "name": "my-client-auth-tool", "arguments": map[string]any{}, }, }, wantStatusCode: http.StatusUnauthorized, }, { name: "MCP Invoke my-fail-tool", api: "http://127.0.0.1:5000/mcp", enabled: true, requestHeader: map[string]string{}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", Id: "invoke-fail-tool", Request: jsonrpc.Request{ Method: "tools/call", }, Params: map[string]any{ "name": "my-fail-tool", "arguments": map[string]any{"id": 1}, }, }, wantStatusCode: http.StatusOK, wantBody: myFailToolWant, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { if !tc.enabled { return } reqMarshal, err := json.Marshal(tc.requestBody) if err != nil { t.Fatalf("unexpected error during marshaling of request body") } // add headers headers := map[string]string{} if sessionId != "" { headers["Mcp-Session-Id"] = sessionId } for key, value := range tc.requestHeader { headers[key] = value } httpResponse, respBody := RunRequest(t, http.MethodPost, tc.api, bytes.NewBuffer(reqMarshal), headers) // Check status code if httpResponse.StatusCode != tc.wantStatusCode { t.Errorf("StatusCode mismatch: got %d, want %d", httpResponse.StatusCode, tc.wantStatusCode) } // Check response body got := string(bytes.TrimSpace(respBody)) if !strings.Contains(got, tc.wantBody) { t.Fatalf("Expected substring not found:\ngot: %q\nwant: %q (to be contained within got)", got, tc.wantBody) } }) } } func setupPostgresSchemas(t *testing.T, ctx context.Context, pool *pgxpool.Pool, schemaName string) func() { createSchemaStmt := fmt.Sprintf("CREATE SCHEMA %s", schemaName) _, err := pool.Exec(ctx, createSchemaStmt) if err != nil { t.Fatalf("failed to create schema: %v", err) } return func() { dropSchemaStmt := fmt.Sprintf("DROP SCHEMA %s CASCADE", schemaName) _, err := pool.Exec(ctx, dropSchemaStmt) if err != nil { t.Fatalf("failed to drop schema: %v", err) } } } func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user string) { // TableNameParam columns to construct want paramTableColumns := fmt.Sprintf(`[ {"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null}, {"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null} ]`, tableNameParam) // TableNameAuth columns to construct want authTableColumns := fmt.Sprintf(`[ {"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null}, {"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null}, {"data_type": "text", "column_name": "email", "column_default": null, "is_not_nullable": false, "ordinal_position": 3, "column_comment": null} ]`, tableNameAuth) const ( // Template to construct detailed output want detailedObjectTemplate = `{ "object_name": "%[1]s", "schema_name": "public", "object_details": { "owner": "%[3]s", "comment": null, "indexes": [{"is_primary": true, "is_unique": true, "index_name": "%[1]s_pkey", "index_method": "btree", "index_columns": ["id"], "index_definition": "CREATE UNIQUE INDEX %[1]s_pkey ON public.%[1]s USING btree (id)"}], "triggers": [], "columns": %[2]s, "object_name": "%[1]s", "object_type": "TABLE", "schema_name": "public", "constraints": [{"constraint_name": "%[1]s_pkey", "constraint_type": "PRIMARY KEY", "constraint_columns": ["id"], "constraint_definition": "PRIMARY KEY (id)", "foreign_key_referenced_table": null, "foreign_key_referenced_columns": null}] } }` // Template to construct simple output want simpleObjectTemplate = `{"object_name":"%s", "schema_name":"public", "object_details":{"name":"%s"}}` ) // Helper to build json for detailed want getDetailedWant := func(tableName, columnJSON string) string { return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON, user) } // Helper to build template for simple want getSimpleWant := func(tableName string) string { return fmt.Sprintf(simpleObjectTemplate, tableName, tableName) } invokeTcs := []struct { name string api string requestBody io.Reader wantStatusCode int want string isAllTables bool }{ { name: "invoke list_tables all tables detailed output", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(`{"table_names": ""}`)), wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)), isAllTables: true, }, { name: "invoke list_tables all tables simple output", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "simple"}`)), wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)), isAllTables: true, }, { name: "invoke list_tables detailed output", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth))), wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)), }, { name: "invoke list_tables simple output", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth))), wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)), }, { name: "invoke list_tables with invalid output format", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "abcd"}`)), wantStatusCode: http.StatusBadRequest, }, { name: "invoke list_tables with malformed table_names parameter", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(`{"table_names": 12345, "output_format": "detailed"}`)), wantStatusCode: http.StatusBadRequest, }, { name: "invoke list_tables with multiple table names", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))), wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)), }, { name: "invoke list_tables with non-existent table", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table"}`)), wantStatusCode: http.StatusOK, want: `[]`, }, { name: "invoke list_tables with one existing and one non-existent table", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam))), wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)), }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { resp, respBytes := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBytes)) } if tc.wantStatusCode == http.StatusOK { var bodyWrapper map[string]json.RawMessage if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil { t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes)) } resultJSON, ok := bodyWrapper["result"] if !ok { t.Fatal("unable to find 'result' in response body") } var resultString string if err := json.Unmarshal(resultJSON, &resultString); err != nil { t.Fatalf("'result' is not a JSON-encoded string: %s", err) } var got, want []any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal actual result string: %v", err) } if err := json.Unmarshal([]byte(tc.want), &want); err != nil { t.Fatalf("failed to unmarshal expected want string: %v", err) } // Checking only the default public schema where the test tables are created to avoid brittle tests. if tc.isAllTables { var filteredGot []any for _, item := range got { if tableMap, ok := item.(map[string]interface{}); ok { if schema, ok := tableMap["schema_name"]; ok && schema == "public" { filteredGot = append(filteredGot, item) } } } got = filteredGot } sort.SliceStable(got, func(i, j int) bool { return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j]) }) sort.SliceStable(want, func(i, j int) bool { return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j]) }) if !reflect.DeepEqual(got, want) { t.Errorf("Unexpected result: got %#v, want: %#v", got, want) } } }) } } func setUpPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, viewName string) func() { createView := fmt.Sprintf("CREATE VIEW %s AS SELECT 1 AS col", viewName) _, err := pool.Exec(ctx, createView) if err != nil { t.Fatalf("failed to create view: %v", err) } return func() { dropView := fmt.Sprintf("DROP VIEW %s", viewName) _, err := pool.Exec(ctx, dropView) if err != nil { t.Fatalf("failed to drop view: %v", err) } } } func RunPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { //adding this line temporarily viewName := "test_view_" + strings.ReplaceAll(uuid.New().String(), "-", "") dropViewfunc1 := setUpPostgresViews(t, ctx, pool, viewName) defer dropViewfunc1() invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int want string }{ { name: "invoke list_views with newly created view", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"view_name": "%s"}`, viewName))), wantStatusCode: http.StatusOK, want: fmt.Sprintf(`[{"schema_name":"public","view_name":"%s","owner_name":"postgres","definition":" SELECT 1 AS col;"}]`, viewName), }, { name: "invoke list_views with non-existent_view", requestBody: bytes.NewBuffer([]byte(`{"view_name": "non_existent_view"}`)), wantStatusCode: http.StatusOK, want: `null`, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_views/invoke" resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(body, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got, want any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v", err) } if err := json.Unmarshal([]byte(tc.want), &want); err != nil { t.Fatalf("failed to unmarshal want string: %v", err) } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("Unexpected result (-want +got):\n%s", diff) } }) } } func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { schemaName := "test_schema_" + strings.ReplaceAll(uuid.New().String(), "-", "") cleanup := setupPostgresSchemas(t, ctx, pool, schemaName) defer cleanup() wantSchema := map[string]any{"functions": float64(0), "grants": map[string]any{}, "owner": "postgres", "schema_name": schemaName, "tables": float64(0), "views": float64(0)} invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int want []map[string]any compareSubset bool }{ { name: "invoke list_schemas with schema_name", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"schema_name": "%s"}`, schemaName))), wantStatusCode: http.StatusOK, want: []map[string]any{wantSchema}, }, { name: "invoke list_schemas with owner name", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"owner": "%s"}`, "postgres"))), wantStatusCode: http.StatusOK, want: []map[string]any{wantSchema}, compareSubset: true, }, { name: "invoke list_schemas with limit 1", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"schema_name": "%s","limit": 1}`, schemaName))), wantStatusCode: http.StatusOK, want: []map[string]any{wantSchema}, }, { name: "invoke list_schemas with non-existent schema", requestBody: bytes.NewBuffer([]byte(`{"schema_name": "non_existent_schema"}`)), wantStatusCode: http.StatusOK, want: nil, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_schemas/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []map[string]any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v", err) } if tc.compareSubset { // Assert that the 'wantTrigger' is present in the 'got' list. found := false for _, resultSchema := range got { if resultSchema["schema_name"] == wantSchema["schema_name"] { found = true if diff := cmp.Diff(wantSchema, resultSchema); diff != "" { t.Errorf("Mismatch in fields for the expected trigger (-want +got):\n%s", diff) } break } } if !found { t.Errorf("Expected schema '%s' not found in the list of all schemas.", wantSchema) } } else { if diff := cmp.Diff(tc.want, got); diff != "" { t.Errorf("Unexpected result (-want +got):\n%s", diff) } } }) } } func RunPostgresDatabaseOverviewTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { const api = "http://127.0.0.1:5000/api/tool/database_overview/invoke" requestBody := bytes.NewBuffer([]byte(`{}`)) resp, respBody := RunRequest(t, http.MethodPost, api, requestBody, nil) if resp.StatusCode != http.StatusOK { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, http.StatusOK, string(respBody)) } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v, body: %s", err, string(respBody)) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []map[string]any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v, result string: %s", err, resultString) } if len(got) != 1 { t.Fatalf("Expected exactly one row in the result, got %d", len(got)) } resultRow := got[0] // Define expected keys based on the SELECT statement expectedKeys := []string{ "pg_version", "is_replica", "uptime", "max_connections", "current_connections", "active_connections", "pct_connections_used", } for _, key := range expectedKeys { if _, ok := resultRow[key]; !ok { t.Errorf("Missing expected key in result: %s", key) } } // Check types of the fields. JSON numbers are unmarshalled into float64. if _, ok := resultRow["pg_version"].(string); !ok { t.Errorf("Expected 'pg_version' to be a string, got %T", resultRow["pg_version"]) } if _, ok := resultRow["is_replica"].(bool); !ok { t.Errorf("Expected 'is_replica' to be a bool, got %T", resultRow["is_replica"]) } if _, ok := resultRow["uptime"].(string); !ok { t.Errorf("Expected 'uptime' to be a string, got %T", resultRow["uptime"]) } if _, ok := resultRow["max_connections"].(float64); !ok { t.Errorf("Expected 'max_connections' to be a number (float64), got %T", resultRow["max_connections"]) } if _, ok := resultRow["current_connections"].(float64); !ok { t.Errorf("Expected 'current_connections' to be a number (float64), got %T", resultRow["current_connections"]) } if _, ok := resultRow["active_connections"].(float64); !ok { t.Errorf("Expected 'active_connections' to be a number (float64), got %T", resultRow["active_connections"]) } if _, ok := resultRow["pct_connections_used"].(float64); !ok { t.Errorf("Expected 'pct_connections_used' to be a number (float64), got %T", resultRow["pct_connections_used"]) } // Basic sanity checks on values if maxConn, ok := resultRow["max_connections"].(float64); ok { if maxConn <= 0 { t.Errorf("Expected 'max_connections' to be positive, got %f", maxConn) } } if pctUsed, ok := resultRow["pct_connections_used"].(float64); ok { if pctUsed < 0 || pctUsed > 100 { t.Errorf("Expected 'pct_connections_used' to be between 0 and 100, got %f", pctUsed) } } } func setupPostgresTrigger(t *testing.T, ctx context.Context, pool *pgxpool.Pool, schemaName, tableName, functionName, triggerName string) func() { t.Helper() createSchemaStmt := fmt.Sprintf("CREATE SCHEMA %s", schemaName) if _, err := pool.Exec(ctx, createSchemaStmt); err != nil { t.Fatalf("failed to create schema %s: %v", schemaName, err) } createTableStmt := fmt.Sprintf("CREATE TABLE %s.%s (id SERIAL PRIMARY KEY, name TEXT)", schemaName, tableName) if _, err := pool.Exec(ctx, createTableStmt); err != nil { t.Fatalf("failed to create table %s.%s: %v", schemaName, tableName, err) } createFunctionStmt := fmt.Sprintf(` CREATE OR REPLACE FUNCTION %s.%s() RETURNS TRIGGER AS $$ BEGIN RETURN NEW; END; $$ LANGUAGE plpgsql; `, schemaName, functionName) if _, err := pool.Exec(ctx, createFunctionStmt); err != nil { t.Fatalf("failed to create function %s.%s: %v", schemaName, functionName, err) } createTriggerStmt := fmt.Sprintf(` CREATE TRIGGER %s AFTER INSERT ON %s.%s FOR EACH ROW EXECUTE FUNCTION %s.%s(); `, triggerName, schemaName, tableName, schemaName, functionName) if _, err := pool.Exec(ctx, createTriggerStmt); err != nil { t.Fatalf("failed to create trigger %s: %v", triggerName, err) } return func() { dropSchemaStmt := fmt.Sprintf("DROP SCHEMA %s CASCADE", schemaName) if _, err := pool.Exec(ctx, dropSchemaStmt); err != nil { t.Fatalf("failed to drop schema %s: %v", schemaName, err) } } } func RunPostgresListTriggersTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { uniqueID := strings.ReplaceAll(uuid.New().String(), "-", "") schemaName := "test_schema_" + uniqueID tableName := "test_table_" + uniqueID functionName := "test_func_" + uniqueID triggerName := "test_trigger_" + uniqueID cleanup := setupPostgresTrigger(t, ctx, pool, schemaName, tableName, functionName, triggerName) defer cleanup() // Definition can vary slightly based on server version/settings, so we fetch it to compare. var expectedDef string getDefQuery := fmt.Sprintf("SELECT pg_get_triggerdef(oid) FROM pg_trigger WHERE tgname = '%s'", triggerName) err := pool.QueryRow(ctx, getDefQuery).Scan(&expectedDef) if err != nil { t.Fatalf("failed to fetch trigger definition: %v", err) } wantTrigger := map[string]any{ "trigger_name": triggerName, "schema_name": schemaName, "table_name": tableName, "status": "ENABLED", "timing": "AFTER", "events": "INSERT", "activation_level": "ROW", "function_name": functionName, "definition": expectedDef, } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int want []map[string]any compareSubset bool }{ { name: "list all triggers (expecting the one we created)", requestBody: bytes.NewBuffer([]byte(`{}`)), wantStatusCode: http.StatusOK, want: []map[string]any{wantTrigger}, compareSubset: true, // avoid test flakiness in race condition }, { name: "filter by trigger_name", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"trigger_name": "%s"}`, triggerName))), wantStatusCode: http.StatusOK, want: []map[string]any{wantTrigger}, }, { name: "filter by schema_name", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"schema_name": "%s"}`, schemaName))), wantStatusCode: http.StatusOK, want: []map[string]any{wantTrigger}, }, { name: "filter by table_name", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_name": "%s"}`, tableName))), wantStatusCode: http.StatusOK, want: []map[string]any{wantTrigger}, }, { name: "filter by non-existent trigger_name", requestBody: bytes.NewBuffer([]byte(`{"trigger_name": "non_existent_trigger"}`)), wantStatusCode: http.StatusOK, want: nil, }, { name: "filter by non-existent schema_name", requestBody: bytes.NewBuffer([]byte(`{"schema_name": "non_existent_schema"}`)), wantStatusCode: http.StatusOK, want: nil, }, { name: "filter by non-existent table_name", requestBody: bytes.NewBuffer([]byte(`{"table_name": "non_existent_table"}`)), wantStatusCode: http.StatusOK, want: nil, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_triggers/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []map[string]any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v, content: %s", err, resultString) } if tc.compareSubset { // Assert that the 'wantTrigger' is present in the 'got' list. found := false for _, resultTrigger := range got { if resultTrigger["trigger_name"] == wantTrigger["trigger_name"] { found = true if diff := cmp.Diff(wantTrigger, resultTrigger); diff != "" { t.Errorf("Mismatch in fields for the expected trigger (-want +got):\n%s", diff) } break } } if !found { t.Errorf("Expected trigger '%s' not found in the list of all triggers.", triggerName) } } else { if diff := cmp.Diff(tc.want, got); diff != "" { t.Errorf("Unexpected result (-want +got):\n%s", diff) } } }) } } func setupPostgresPublicationTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string, pubName string) func(t *testing.T) { t.Helper() createTableStmt := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT);", tableName) if _, err := pool.Exec(ctx, createTableStmt); err != nil { t.Fatalf("unable to create table %s: %v", tableName, err) } createPubStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s;", pubName, tableName) if _, err := pool.Exec(ctx, createPubStmt); err != nil { if _, dropErr := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName)); dropErr != nil { t.Errorf("unable to drop table after failing to create publication: %v", dropErr) } t.Fatalf("unable to create publication %s: %v", pubName, err) } return func(t *testing.T) { t.Helper() if _, err := pool.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s;", pubName)); err != nil { t.Errorf("unable to drop publication %s: %v", pubName, err) } if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName)); err != nil { t.Errorf("unable to drop table %s: %v", tableName, err) } } } func RunPostgresListPublicationTablesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { table1Name := "pub_table_1" pub1Name := "pub_1" table2Name := "pub_table_2" pub2Name := "pub_2" cleanup := setupPostgresPublicationTable(t, ctx, pool, table1Name, pub1Name) defer cleanup(t) cleanup2 := setupPostgresPublicationTable(t, ctx, pool, table2Name, pub2Name) defer cleanup2(t) // Fetch the current user to match the publication_owner var currentUser string err := pool.QueryRow(ctx, "SELECT current_user;").Scan(&currentUser) if err != nil { t.Fatalf("unable to fetch current user: %v", err) } wantTable1 := map[string]any{ "publication_name": pub1Name, "schema_name": "public", "table_name": table1Name, "publishes_all_tables": false, "publishes_inserts": true, "publishes_updates": true, "publishes_deletes": true, "publishes_truncates": true, "publication_owner": currentUser, } wantTable2 := map[string]any{ "publication_name": pub2Name, "schema_name": "public", "table_name": table2Name, "publishes_all_tables": false, "publishes_inserts": true, "publishes_updates": true, "publishes_deletes": true, "publishes_truncates": true, "publication_owner": currentUser, } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int want []map[string]any }{ { name: "list all publication tables", requestBody: bytes.NewBufferString(`{}`), wantStatusCode: http.StatusOK, want: []map[string]any{wantTable1, wantTable2}, }, { name: "list all tables for the created publication", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"publication_names": "%s"}`, pub1Name)), wantStatusCode: http.StatusOK, want: []map[string]any{wantTable1}, }, { name: "filter by table_name", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s, %s"}`, table1Name, table2Name)), wantStatusCode: http.StatusOK, want: []map[string]any{wantTable1, wantTable2}, }, { name: "filter by schema_name and table_name", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_names": "public", "table_name": "%s , %s"}`, table1Name, table2Name)), wantStatusCode: http.StatusOK, want: []map[string]any{wantTable1, wantTable2}, }, { name: "invoke list_publication_tables with non-existent table", requestBody: bytes.NewBufferString(`{"table_names": "non_existent_table"}`), wantStatusCode: http.StatusOK, want: nil, }, { name: "invoke list_publication_tables with non-existent publication", requestBody: bytes.NewBufferString(`{"publication_names": "non_existent_pub"}`), wantStatusCode: http.StatusOK, want: nil, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_publication_tables/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []map[string]any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v, content: %s", err, resultString) } if diff := cmp.Diff(tc.want, got); diff != "" { t.Errorf("Unexpected result (-want +got):\n%s", diff) } }) } } func RunPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { type queryListDetails struct { ProcessId any `json:"pid"` User string `json:"user"` Datname string `json:"datname"` ApplicationName string `json:"application_name"` ClientAddress string `json:"client_addr"` State string `json:"state"` WaitEventType string `json:"wait_event_type"` WaitEvent string `json:"wait_event"` BackendStart any `json:"backend_start"` TransactionStart any `json:"xact_start"` QueryStart any `json:"query_start"` QueryDuration any `json:"query_duration"` Query string `json:"query"` } singleQueryWanted := queryListDetails{ ProcessId: any(nil), User: "", Datname: "", ApplicationName: "", ClientAddress: "", State: "", WaitEventType: "", WaitEvent: "", BackendStart: any(nil), TransactionStart: any(nil), QueryStart: any(nil), QueryDuration: any(nil), Query: "SELECT pg_sleep(10);", } invokeTcs := []struct { name string requestBody io.Reader clientSleepSecs int waitSecsBeforeCheck int wantStatusCode int want any }{ // exclude background monitoring apps such as "wal_uploader" { name: "invoke list_active_queries when the system is idle", requestBody: bytes.NewBufferString(`{"exclude_application_names": "wal_uploader"}`), clientSleepSecs: 0, waitSecsBeforeCheck: 0, wantStatusCode: http.StatusOK, want: []queryListDetails(nil), }, { name: "invoke list_active_queries when there is 1 ongoing but lower than the threshold", requestBody: bytes.NewBufferString(`{"min_duration": "100 seconds", "exclude_application_names": "wal_uploader"}`), clientSleepSecs: 1, waitSecsBeforeCheck: 1, wantStatusCode: http.StatusOK, want: []queryListDetails(nil), }, { name: "invoke list_active_queries when 1 ongoing query should show up", requestBody: bytes.NewBufferString(`{"min_duration": "1 seconds", "exclude_application_names": "wal_uploader"}`), clientSleepSecs: 10, waitSecsBeforeCheck: 5, wantStatusCode: http.StatusOK, want: []queryListDetails{singleQueryWanted}, }, } var wg sync.WaitGroup for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { if tc.clientSleepSecs > 0 { wg.Add(1) go func() { defer wg.Done() err := pool.Ping(ctx) if err != nil { t.Errorf("unable to connect to test database: %s", err) return } _, err = pool.Exec(ctx, fmt.Sprintf("SELECT pg_sleep(%d);", tc.clientSleepSecs)) if err != nil { t.Errorf("Executing 'SELECT pg_sleep' failed: %s", err) } }() } if tc.waitSecsBeforeCheck > 0 { time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second) } const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got any var details []queryListDetails if err := json.Unmarshal([]byte(resultString), &details); err != nil { t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err) } got = details if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b queryListDetails) bool { return a.Query == b.Query })); diff != "" { t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want) } }) } wg.Wait() } func RunPostgresListAvailableExtensionsTest(t *testing.T) { invokeTcs := []struct { name string api string requestBody io.Reader wantStatusCode int }{ { name: "invoke list_available_extensions output", api: "http://127.0.0.1:5000/api/tool/list_available_extensions/invoke", wantStatusCode: http.StatusOK, requestBody: bytes.NewBuffer([]byte(`{}`)), }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } // Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs. // Adding the check will make the test flaky. }) } } func RunPostgresListInstalledExtensionsTest(t *testing.T) { invokeTcs := []struct { name string api string requestBody io.Reader wantStatusCode int }{ { name: "invoke list_installed_extensions output", api: "http://127.0.0.1:5000/api/tool/list_installed_extensions/invoke", wantStatusCode: http.StatusOK, requestBody: bytes.NewBuffer([]byte(`{}`)), }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { resp, bodyBytes := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) } // Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs. // Adding the check will make the test flaky. }) } } func setupPostgresIndex(t *testing.T, ctx context.Context, pool *pgxpool.Pool, schemaName string, tableName string) func(t *testing.T) { t.Helper() createSchemaStmt := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s;", schemaName) if _, err := pool.Exec(ctx, createSchemaStmt); err != nil { t.Fatalf("unable to create schema %s: %v", schemaName, err) } fullTableName := fmt.Sprintf("%s.%s", schemaName, tableName) createTableStmt := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT, email TEXT);", fullTableName) if _, err := pool.Exec(ctx, createTableStmt); err != nil { t.Fatalf("unable to create table %s: %v", fullTableName, err) } // Create a unique index on email index1Stmt := fmt.Sprintf("CREATE UNIQUE INDEX %s_email_idx ON %s (email);", tableName, fullTableName) if _, err := pool.Exec(ctx, index1Stmt); err != nil { t.Fatalf("unable to create index %s_email_idx: %v", tableName, err) } // Create a non-unique index on name index2Stmt := fmt.Sprintf("CREATE INDEX %s_name_idx ON %s (name);", tableName, fullTableName) if _, err := pool.Exec(ctx, index2Stmt); err != nil { t.Fatalf("unable to create index %s_name_idx: %v", tableName, err) } return func(t *testing.T) { t.Helper() if _, err := pool.Exec(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE;", schemaName)); err != nil { t.Errorf("unable to drop schema: %v", err) } } } func RunPostgresListIndexesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { schemaName := "testschema_" + strings.ReplaceAll(uuid.New().String(), "-", "") tableName := "table1_" + strings.ReplaceAll(uuid.New().String(), "-", "") cleanup := setupPostgresIndex(t, ctx, pool, schemaName, tableName) defer cleanup(t) // Primary key index wantIndexPK := map[string]any{ "schema_name": schemaName, "table_name": tableName, "index_name": tableName + "_pkey", "index_type": "btree", "is_unique": true, "is_primary": true, "is_used": false, "index_definition": fmt.Sprintf("CREATE UNIQUE INDEX %s_pkey ON %s.%s USING btree (id)", tableName, schemaName, tableName), // Size and scan counts can vary, so omitting them from strict checks or using ranges might be better in real tests. } // Email unique index wantIndexEmail := map[string]any{ "schema_name": schemaName, "table_name": tableName, "index_name": tableName + "_email_idx", "index_type": "btree", "is_unique": true, "is_primary": false, "is_used": false, "index_definition": fmt.Sprintf("CREATE UNIQUE INDEX %s_email_idx ON %s.%s USING btree (email)", tableName, schemaName, tableName), } // Name non-unique index wantIndexName := map[string]any{ "schema_name": schemaName, "table_name": tableName, "index_name": tableName + "_name_idx", "index_type": "btree", "is_unique": false, "is_primary": false, "is_used": false, "index_definition": fmt.Sprintf("CREATE INDEX %s_name_idx ON %s.%s USING btree (name)", tableName, schemaName, tableName), } allWantIndexes := []map[string]any{wantIndexEmail, wantIndexName, wantIndexPK} invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int want []map[string]any }{ // List all indexes is skipped because the output might include indexes for other database tables // defined outside of this test, which could make the test flaky. { name: "list_indexes for a specific schema and table", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s"}`, schemaName, tableName)), wantStatusCode: http.StatusOK, want: allWantIndexes, }, { name: "list_indexes for a specific schema", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s"}`, schemaName)), wantStatusCode: http.StatusOK, want: allWantIndexes, }, { name: "list_indexes with non-existent schema", requestBody: bytes.NewBufferString(`{"schema_name": "non_existent_schema"}`), wantStatusCode: http.StatusOK, want: nil, }, { name: "list_indexes with non-existent table in existing schema", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "non_existent_table"}`, schemaName)), wantStatusCode: http.StatusOK, want: nil, }, { name: "list_indexes filter by index name", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s", "index_name": "%s"}`, schemaName, tableName, tableName+"_email_idx")), wantStatusCode: http.StatusOK, want: []map[string]any{wantIndexEmail}, }, { name: "list_indexes filter by non-existent index name", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s", "index_name": "non_existent_idx"}`, schemaName, tableName)), wantStatusCode: http.StatusOK, want: nil, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_indexes/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []map[string]any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v, resultString: %s", err, resultString) } // Normalize got by removing fields that are hard to predict (like size) for _, index := range got { delete(index, "index_size_bytes") delete(index, "index_scans") delete(index, "tuples_read") delete(index, "tuples_fetched") } if diff := cmp.Diff(tc.want, got); diff != "" { t.Errorf("Unexpected result (-want +got):\n%s", diff) } }) } } func setupListSequencesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, func(t *testing.T)) { sequenceName := "list_sequences_seq1_" + strings.ReplaceAll(uuid.New().String(), "-", "") createSequence1Stmt := fmt.Sprintf("CREATE SEQUENCE %s INCREMENT 1 START 1;", sequenceName) _, err := pool.Exec(ctx, createSequence1Stmt) if err != nil { t.Fatalf("unable to create sequence %s: %s", sequenceName, err) } return sequenceName, func(t *testing.T) { _, err := pool.Exec(ctx, fmt.Sprintf("DROP SEQUENCE IF EXISTS %s;", sequenceName)) if err != nil { t.Errorf("unable to drop sequences: %v", err) } } } func RunPostgresListSequencesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { sequenceName, teardown := setupListSequencesTest(t, ctx, pool) defer teardown(t) wantSequence := map[string]any{ "sequence_name": sequenceName, "schema_name": "public", "sequence_owner": "postgres", "data_type": "bigint", "start_value": float64(1), "min_value": float64(1), "max_value": float64(9223372036854775807), "increment_by": float64(1), "last_value": nil, } invokeTcs := []struct { name string api string requestBody io.Reader wantStatusCode int want []map[string]any }{ { name: "invoke list_sequences", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"sequence_name": "%s"}`, sequenceName)), wantStatusCode: http.StatusOK, want: []map[string]any{wantSequence}, }, { name: "invoke list_sequences with non-existent sequence", requestBody: bytes.NewBufferString(`{"sequence_name": "non_existent_sequence"}`), wantStatusCode: http.StatusOK, want: nil, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_sequences/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []map[string]any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v", err) } if diff := cmp.Diff(tc.want, got); diff != "" { t.Errorf("Unexpected result (-want +got):\n%s", diff) } }) } } func RunPostgresListTableSpacesTest(t *testing.T) { invokeTcs := []struct { name string api string requestBody io.Reader wantStatusCode int }{ { name: "invoke list_tablespaces output", api: "http://127.0.0.1:5000/api/tool/list_tablespaces/invoke", wantStatusCode: http.StatusOK, requestBody: bytes.NewBuffer([]byte(`{}`)), }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } // Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs. // Adding the check will make the test flaky. }) } } func RunPostgresListPgSettingsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { targetSetting := "maintenance_work_mem" var name, setting, unit, shortDesc, source, contextVal string // We query the raw pg_settings to get the data needed to reconstruct the logic // defined in your listPgSettingQuery. err := pool.QueryRow(ctx, ` SELECT name, setting, unit, short_desc, source, context FROM pg_settings WHERE name = $1 `, targetSetting).Scan(&name, &setting, &unit, &shortDesc, &source, &contextVal) if err != nil { t.Fatalf("Setup failed: could not fetch postgres setting '%s': %v", targetSetting, err) } // Replicate the SQL CASE logic for 'requires_restart' field requiresRestart := "No" switch contextVal { case "postmaster": requiresRestart = "Yes" case "sighup": requiresRestart = "No (Reload sufficient)" } expectedObject := map[string]interface{}{ "name": name, "current_value": setting, "unit": unit, "short_desc": shortDesc, "source": source, "requires_restart": requiresRestart, } expectedJSON, _ := json.Marshal([]interface{}{expectedObject}) invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int want string }{ { name: "invoke list_pg_settings with specific setting", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"setting_name": "%s"}`, targetSetting))), wantStatusCode: http.StatusOK, want: string(expectedJSON), }, { name: "invoke list_pg_settings with non-existent setting", requestBody: bytes.NewBuffer([]byte(`{"setting_name": "non_existent_config_xyz"}`)), wantStatusCode: http.StatusOK, want: `null`, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_pg_settings/invoke" resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(body, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got, want any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v", err) } if err := json.Unmarshal([]byte(tc.want), &want); err != nil { t.Fatalf("failed to unmarshal want string: %v", err) } if diff := cmp.Diff(want, got); diff != "" { t.Errorf("Unexpected result (-want +got):\n%s", diff) } }) } } // RunPostgresDatabaseStatsTest tests the database_stats tool by comparing API results // against a direct query to the database. func RunPostgresListDatabaseStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { dbName1 := "test_db_stats_" + strings.ReplaceAll(uuid.NewString(), "-", "") dbOwner1 := "test_user_" + strings.ReplaceAll(uuid.NewString(), "-", "") dbName2 := "test_db_stats_" + strings.ReplaceAll(uuid.NewString(), "-", "") dbOwner2 := "test_user_" + strings.ReplaceAll(uuid.NewString(), "-", "") cleanup1 := setUpDatabase(t, ctx, pool, dbName1, dbOwner1) defer cleanup1() cleanup2 := setUpDatabase(t, ctx, pool, dbName2, dbOwner2) defer cleanup2() requiredKeys := map[string]bool{ "database_name": true, "database_owner": true, "default_tablespace": true, "is_connectable": true, } db1Want := map[string]interface{}{ "database_name": dbName1, "database_owner": dbOwner1, "default_tablespace": "pg_default", "is_connectable": true, } db2Want := map[string]interface{}{ "database_name": dbName2, "database_owner": dbOwner2, "default_tablespace": "pg_default", "is_connectable": true, } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int want []map[string]interface{} }{ { name: "invoke database_stats filtering by specific database name", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"database_name": "%s"}`, dbName1))), wantStatusCode: http.StatusOK, want: []map[string]interface{}{db1Want}, }, { name: "invoke database_stats filtering by specific owner", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"database_owner": "%s"}`, dbOwner2))), wantStatusCode: http.StatusOK, want: []map[string]interface{}{db2Want}, }, { name: "filter by tablespace", requestBody: bytes.NewBuffer([]byte(`{"default_tablespace": "pg_default"}`)), wantStatusCode: http.StatusOK, want: []map[string]interface{}{db1Want, db2Want}, }, { name: "sort by size (desc)", requestBody: bytes.NewBuffer([]byte(`{"sort_by": "size"}`)), wantStatusCode: http.StatusOK, want: []map[string]interface{}{db1Want, db2Want}, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_database_stats/invoke" resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(body, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []map[string]interface{} if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v", err) } // Configuration for comparison opts := []cmp.Option{ // Ensure consistent order based on name for comparison cmpopts.SortSlices(func(a, b map[string]interface{}) bool { return a["database_name"].(string) < b["database_name"].(string) }), // Ignore Volatile Keys which change in every run and only compare the keys in 'requiredKeys' cmpopts.IgnoreMapEntries(func(key string, _ interface{}) bool { return !requiredKeys[key] }), // Ignore Irrelevant Databases cmpopts.IgnoreSliceElements(func(v map[string]interface{}) bool { name, ok := v["database_name"].(string) if !ok { return true } return name != dbName1 && name != dbName2 }), } if diff := cmp.Diff(tc.want, got, opts...); diff != "" { t.Errorf("Unexpected result (-want +got):\n%s", diff) } }) } } func setUpDatabase(t *testing.T, ctx context.Context, pool *pgxpool.Pool, dbName, dbOwner string) func() { _, err := pool.Exec(ctx, fmt.Sprintf("CREATE ROLE %s LOGIN PASSWORD 'password';", dbOwner)) if err != nil { _, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE %s;", dbOwner)) t.Fatalf("failed to create %s: %v", dbOwner, err) } _, err = pool.Exec(ctx, fmt.Sprintf("GRANT %s TO current_user;", dbOwner)) if err != nil { t.Fatalf("failed to grant %s to current_user: %v", dbOwner, err) } _, err = pool.Exec(ctx, fmt.Sprintf("CREATE DATABASE %s OWNER %s;", dbName, dbOwner)) if err != nil { t.Fatalf("failed to create %s: %v", dbName, err) } return func() { _, _ = pool.Exec(ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbName)) _, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", dbOwner)) } } func setupPostgresRoles(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, string, string, func(t *testing.T)) { t.Helper() suffix := strings.ReplaceAll(uuid.New().String(), "-", "") adminUser := "test_role_admin_" + suffix superUser := "test_role_super_" + suffix normalUser := "test_role_normal_" + suffix createAdminStmt := fmt.Sprintf("CREATE ROLE %s NOLOGIN;", adminUser) if _, err := pool.Exec(ctx, createAdminStmt); err != nil { t.Fatalf("unable to create role %s: %v", adminUser, err) } createSuperUserStmt := fmt.Sprintf("CREATE ROLE %s LOGIN CREATEDB;", superUser) if _, err := pool.Exec(ctx, createSuperUserStmt); err != nil { t.Fatalf("unable to create role %s: %v", superUser, err) } createNormalUserStmt := fmt.Sprintf("CREATE ROLE %s LOGIN;", normalUser) if _, err := pool.Exec(ctx, createNormalUserStmt); err != nil { t.Fatalf("unable to create role %s: %v", normalUser, err) } // Establish Relationships (Admin -> Superuser -> Normal) if _, err := pool.Exec(ctx, fmt.Sprintf("GRANT %s TO %s;", adminUser, superUser)); err != nil { t.Fatalf("unable to grant %s to %s: %v", adminUser, superUser, err) } if _, err := pool.Exec(ctx, fmt.Sprintf("GRANT %s TO %s;", superUser, normalUser)); err != nil { t.Fatalf("unable to grant %s to %s: %v", superUser, normalUser, err) } return adminUser, superUser, normalUser, func(t *testing.T) { t.Helper() _, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", normalUser)) _, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", superUser)) _, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", adminUser)) } } func RunPostgresListRolesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { adminUser, superUser, normalUser, cleanup := setupPostgresRoles(t, ctx, pool) defer cleanup(t) wantAdmin := map[string]any{ "role_name": adminUser, "connection_limit": float64(-1), "is_superuser": false, "inherits_privileges": true, "can_create_roles": false, "can_create_db": false, "can_login": false, "is_replication_role": false, "bypass_rls": false, "direct_members": []any{superUser}, "member_of": []any{}, } wantSuperUser := map[string]any{ "role_name": superUser, "connection_limit": float64(-1), "is_superuser": false, "inherits_privileges": true, "can_create_roles": false, "can_create_db": true, "can_login": true, "is_replication_role": false, "bypass_rls": false, "direct_members": []any{normalUser}, "member_of": []any{adminUser}, } wantNormalUser := map[string]any{ "role_name": normalUser, "connection_limit": float64(-1), "is_superuser": false, "inherits_privileges": true, "can_create_roles": false, "can_create_db": false, "can_login": true, "is_replication_role": false, "bypass_rls": false, "direct_members": []any{}, "member_of": []any{superUser}, } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int want []map[string]any }{ { name: "list_roles with filter for created roles", requestBody: bytes.NewBufferString(`{"role_name": "test_role_"}`), wantStatusCode: http.StatusOK, want: []map[string]any{wantAdmin, wantNormalUser, wantSuperUser}, }, { name: "list_roles filter specific role", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"role_name": "%s"}`, superUser)), wantStatusCode: http.StatusOK, want: []map[string]any{wantSuperUser}, }, { name: "list_roles non-existent role", requestBody: bytes.NewBufferString(`{"role_name": "non_existent_role_xyz"}`), wantStatusCode: http.StatusOK, want: nil, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_roles/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []map[string]any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v, resultString: %s", err, resultString) } gotMap := make(map[string]map[string]any) for _, role := range got { // Remove fields that change every run delete(role, "oid") delete(role, "valid_until") if name, ok := role["role_name"].(string); ok { gotMap[name] = role } } // Check that every role in 'want' exists in 'got' and matches for _, wantRole := range tc.want { roleName, _ := wantRole["role_name"].(string) gotRole, exists := gotMap[roleName] if !exists { t.Errorf("Expected role %q was not found in the response", roleName) continue } if diff := cmp.Diff(wantRole, gotRole); diff != "" { t.Errorf("Role %q mismatch (-want +got):\n%s", roleName, diff) } } // Verify that if want is nil/empty, got is also empty if len(tc.want) == 0 && len(got) != 0 { t.Errorf("Expected empty result, but got %d roles", len(got)) } }) } } // RunMySQLListTablesTest run tests against the mysql-list-tables tool func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNameAuth, expectedOwner string) { var ownerWant any if expectedOwner == "" { ownerWant = nil } else { ownerWant = expectedOwner } type tableInfo struct { ObjectName string `json:"object_name"` SchemaName string `json:"schema_name"` ObjectDetails string `json:"object_details"` } type column struct { DataType string `json:"data_type"` ColumnName string `json:"column_name"` ColumnComment string `json:"column_comment"` ColumnDefault any `json:"column_default"` IsNotNullable int `json:"is_not_nullable"` OrdinalPosition int `json:"ordinal_position"` } type objectDetails struct { Owner any `json:"owner"` Columns []column `json:"columns"` Comment string `json:"comment"` Indexes []any `json:"indexes"` Triggers []any `json:"triggers"` Constraints []any `json:"constraints"` ObjectName string `json:"object_name"` ObjectType string `json:"object_type"` SchemaName string `json:"schema_name"` } paramTableWant := objectDetails{ ObjectName: tableNameParam, SchemaName: databaseName, ObjectType: "TABLE", Owner: ownerWant, Columns: []column{ {DataType: "int", ColumnName: "id", IsNotNullable: 1, OrdinalPosition: 1}, {DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2}, }, Indexes: []any{map[string]any{"index_columns": []any{"id"}, "index_name": "PRIMARY", "is_primary": float64(1), "is_unique": float64(1)}}, Triggers: []any{}, Constraints: []any{map[string]any{"constraint_columns": []any{"id"}, "constraint_name": "PRIMARY", "constraint_type": "PRIMARY KEY", "foreign_key_referenced_columns": any(nil), "foreign_key_referenced_table": any(nil), "constraint_definition": ""}}, } authTableWant := objectDetails{ ObjectName: tableNameAuth, SchemaName: databaseName, ObjectType: "TABLE", Owner: ownerWant, Columns: []column{ {DataType: "int", ColumnName: "id", IsNotNullable: 1, OrdinalPosition: 1}, {DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2}, {DataType: "varchar(255)", ColumnName: "email", OrdinalPosition: 3}, }, Indexes: []any{map[string]any{"index_columns": []any{"id"}, "index_name": "PRIMARY", "is_primary": float64(1), "is_unique": float64(1)}}, Triggers: []any{}, Constraints: []any{map[string]any{"constraint_columns": []any{"id"}, "constraint_name": "PRIMARY", "constraint_type": "PRIMARY KEY", "foreign_key_referenced_columns": any(nil), "foreign_key_referenced_table": any(nil), "constraint_definition": ""}}, } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int want any isSimple bool isAllTables bool }{ { name: "invoke list_tables for all tables detailed output", requestBody: bytes.NewBufferString(`{"table_names":""}`), wantStatusCode: http.StatusOK, want: []objectDetails{authTableWant, paramTableWant}, isAllTables: true, }, { name: "invoke list_tables detailed output", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth)), wantStatusCode: http.StatusOK, want: []objectDetails{authTableWant}, }, { name: "invoke list_tables simple output", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth)), wantStatusCode: http.StatusOK, want: []map[string]any{{"name": tableNameAuth}}, isSimple: true, }, { name: "invoke list_tables with multiple table names", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth)), wantStatusCode: http.StatusOK, want: []objectDetails{authTableWant, paramTableWant}, }, { name: "invoke list_tables with one existing and one non-existent table", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameAuth)), wantStatusCode: http.StatusOK, want: []objectDetails{authTableWant}, }, { name: "invoke list_tables with non-existent table", requestBody: bytes.NewBufferString(`{"table_names": "non_existent_table"}`), wantStatusCode: http.StatusOK, want: []objectDetails{}, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_tables/invoke" resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(body, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got any if tc.isSimple { var tables []tableInfo if err := json.Unmarshal([]byte(resultString), &tables); err != nil { t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err) } details := []map[string]any{} for _, table := range tables { var d map[string]any if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil { t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err) } details = append(details, d) } got = details } else { var tables []tableInfo if err := json.Unmarshal([]byte(resultString), &tables); err != nil { t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err) } details := []objectDetails{} for _, table := range tables { var d objectDetails if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil { t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err) } details = append(details, d) } got = details } opts := []cmp.Option{ cmpopts.SortSlices(func(a, b objectDetails) bool { return a.ObjectName < b.ObjectName }), cmpopts.SortSlices(func(a, b column) bool { return a.ColumnName < b.ColumnName }), cmpopts.SortSlices(func(a, b map[string]any) bool { return a["name"].(string) < b["name"].(string) }), } // Checking only the current database where the test tables are created to avoid brittle tests. if tc.isAllTables { filteredGot := []objectDetails{} if got != nil { for _, item := range got.([]objectDetails) { if item.SchemaName == databaseName { filteredGot = append(filteredGot, item) } } } got = filteredGot } if diff := cmp.Diff(tc.want, got, opts...); diff != "" { t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want) } }) } } // RunMySQLListActiveQueriesTest run tests against the mysql-list-active-queries tests func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql.DB) { type queryListDetails struct { ProcessId any `json:"process_id"` Query string `json:"query"` TrxStarted any `json:"trx_started"` TrxDuration any `json:"trx_duration_seconds"` TrxWaitDuration any `json:"trx_wait_duration_seconds"` QueryTime any `json:"query_time"` TrxState string `json:"trx_state"` ProcessState string `json:"process_state"` User string `json:"user"` TrxRowsLocked any `json:"trx_rows_locked"` TrxRowsModified any `json:"trx_rows_modified"` Db string `json:"db"` } singleQueryWanted := queryListDetails{ ProcessId: any(nil), Query: "SELECT sleep(10)", TrxStarted: any(nil), TrxDuration: any(nil), TrxWaitDuration: any(nil), QueryTime: any(nil), TrxState: "", ProcessState: "User sleep", User: "", TrxRowsLocked: any(nil), TrxRowsModified: any(nil), Db: "", } invokeTcs := []struct { name string requestBody io.Reader clientSleepSecs int waitSecsBeforeCheck int wantStatusCode int want any }{ { name: "invoke list_active_queries when the system is idle", requestBody: bytes.NewBufferString(`{}`), clientSleepSecs: 0, waitSecsBeforeCheck: 0, wantStatusCode: http.StatusOK, want: []queryListDetails(nil), }, { name: "invoke list_active_queries when there is 1 ongoing but lower than the threshold", requestBody: bytes.NewBufferString(`{"min_duration_secs": 100}`), clientSleepSecs: 10, waitSecsBeforeCheck: 1, wantStatusCode: http.StatusOK, want: []queryListDetails(nil), }, { name: "invoke list_active_queries when 1 ongoing query should show up", requestBody: bytes.NewBufferString(`{"min_duration_secs": 5}`), clientSleepSecs: 0, waitSecsBeforeCheck: 5, wantStatusCode: http.StatusOK, want: []queryListDetails{singleQueryWanted}, }, { name: "invoke list_active_queries when 2 ongoing query should show up", requestBody: bytes.NewBufferString(`{"min_duration_secs": 2}`), clientSleepSecs: 10, waitSecsBeforeCheck: 3, wantStatusCode: http.StatusOK, want: []queryListDetails{singleQueryWanted, singleQueryWanted}, }, } var wg sync.WaitGroup for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { if tc.clientSleepSecs > 0 { wg.Add(1) go func() { defer wg.Done() err := pool.PingContext(ctx) if err != nil { t.Errorf("unable to connect to test database: %s", err) return } _, err = pool.ExecContext(ctx, fmt.Sprintf("SELECT sleep(%d);", tc.clientSleepSecs)) if err != nil { t.Errorf("Executing 'SELECT sleep' failed: %s", err) } }() } if tc.waitSecsBeforeCheck > 0 { time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second) } const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got any var details []queryListDetails if err := json.Unmarshal([]byte(resultString), &details); err != nil { t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err) } got = details if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b queryListDetails) bool { return a.Query == b.Query && a.ProcessState == b.ProcessState })); diff != "" { t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want) } }) } wg.Wait() } func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, pool *sql.DB, databaseName string) { type listDetails struct { TableSchema string `json:"table_schema"` TableName string `json:"table_name"` } // bunch of wanted nonUniqueKeyTableName := "t03_non_unqiue_key_table" noKeyTableName := "t04_no_key_table" nonUniqueKeyTableWant := listDetails{ TableSchema: databaseName, TableName: nonUniqueKeyTableName, } noKeyTableWant := listDetails{ TableSchema: databaseName, TableName: noKeyTableName, } invokeTcs := []struct { name string requestBody io.Reader newTableName string newTablePrimaryKey bool newTableUniqueKey bool newTableNonUniqueKey bool wantStatusCode int want any }{ { name: "invoke list_tables_missing_unique_indexes when nothing to be found", requestBody: bytes.NewBufferString(`{}`), newTableName: "", newTablePrimaryKey: false, newTableUniqueKey: false, newTableNonUniqueKey: false, wantStatusCode: http.StatusOK, want: []listDetails(nil), }, { name: "invoke list_tables_missing_unique_indexes pk table will not show", requestBody: bytes.NewBufferString(`{}`), newTableName: "t01", newTablePrimaryKey: true, newTableUniqueKey: false, newTableNonUniqueKey: false, wantStatusCode: http.StatusOK, want: []listDetails(nil), }, { name: "invoke list_tables_missing_unique_indexes uk table will not show", requestBody: bytes.NewBufferString(`{}`), newTableName: "t02", newTablePrimaryKey: false, newTableUniqueKey: true, newTableNonUniqueKey: false, wantStatusCode: http.StatusOK, want: []listDetails(nil), }, { name: "invoke list_tables_missing_unique_indexes non-unique key only table will show", requestBody: bytes.NewBufferString(`{}`), newTableName: nonUniqueKeyTableName, newTablePrimaryKey: false, newTableUniqueKey: false, newTableNonUniqueKey: true, wantStatusCode: http.StatusOK, want: []listDetails{nonUniqueKeyTableWant}, }, { name: "invoke list_tables_missing_unique_indexes table with no key at all will show", requestBody: bytes.NewBufferString(`{}`), newTableName: noKeyTableName, newTablePrimaryKey: false, newTableUniqueKey: false, newTableNonUniqueKey: false, wantStatusCode: http.StatusOK, want: []listDetails{nonUniqueKeyTableWant, noKeyTableWant}, }, { name: "invoke list_tables_missing_unique_indexes table w/ both pk & uk will not show", requestBody: bytes.NewBufferString(`{}`), newTableName: "t05", newTablePrimaryKey: true, newTableUniqueKey: true, newTableNonUniqueKey: false, wantStatusCode: http.StatusOK, want: []listDetails{nonUniqueKeyTableWant, noKeyTableWant}, }, { name: "invoke list_tables_missing_unique_indexes table w/ uk & nk will not show", requestBody: bytes.NewBufferString(`{}`), newTableName: "t06", newTablePrimaryKey: false, newTableUniqueKey: true, newTableNonUniqueKey: true, wantStatusCode: http.StatusOK, want: []listDetails{nonUniqueKeyTableWant, noKeyTableWant}, }, { name: "invoke list_tables_missing_unique_indexes table w/ pk & nk will not show", requestBody: bytes.NewBufferString(`{}`), newTableName: "t07", newTablePrimaryKey: true, newTableUniqueKey: false, newTableNonUniqueKey: true, wantStatusCode: http.StatusOK, want: []listDetails{nonUniqueKeyTableWant, noKeyTableWant}, }, { name: "invoke list_tables_missing_unique_indexes with a non-exist database, nothing to show", requestBody: bytes.NewBufferString(`{"table_schema": "non-exist-database"}`), newTableName: "", newTablePrimaryKey: false, newTableUniqueKey: false, newTableNonUniqueKey: false, wantStatusCode: http.StatusOK, want: []listDetails(nil), }, { name: "invoke list_tables_missing_unique_indexes with the right database, show everything", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_schema": "%s"}`, databaseName)), newTableName: "", newTablePrimaryKey: false, newTableUniqueKey: false, newTableNonUniqueKey: false, wantStatusCode: http.StatusOK, want: []listDetails{nonUniqueKeyTableWant, noKeyTableWant}, }, { name: "invoke list_tables_missing_unique_indexes with limited output", requestBody: bytes.NewBufferString(`{"limit": 1}`), newTableName: "", newTablePrimaryKey: false, newTableUniqueKey: false, newTableNonUniqueKey: false, wantStatusCode: http.StatusOK, want: []listDetails{nonUniqueKeyTableWant}, }, } createTableHelper := func(t *testing.T, tableName, databaseName string, primaryKey, uniqueKey, nonUniqueKey bool, ctx context.Context, pool *sql.DB) func() { var stmt strings.Builder stmt.WriteString(fmt.Sprintf("CREATE TABLE %s (", tableName)) stmt.WriteString("c1 INT") if primaryKey { stmt.WriteString(" PRIMARY KEY") } stmt.WriteString(", c2 INT, c3 CHAR(8)") if uniqueKey { stmt.WriteString(", UNIQUE(c2)") } if nonUniqueKey { stmt.WriteString(", INDEX(c3)") } stmt.WriteString(")") t.Logf("Creating table: %s", stmt.String()) if _, err := pool.ExecContext(ctx, stmt.String()); err != nil { t.Fatalf("failed executing %s: %v", stmt.String(), err) } return func() { t.Logf("Dropping table: %s", tableName) if _, err := pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s", tableName)); err != nil { t.Errorf("failed to drop table %s: %v", tableName, err) } } } var cleanups []func() defer func() { for i := len(cleanups) - 1; i >= 0; i-- { cleanups[i]() } }() for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { if tc.newTableName != "" { cleanup := createTableHelper(t, tc.newTableName, databaseName, tc.newTablePrimaryKey, tc.newTableUniqueKey, tc.newTableNonUniqueKey, ctx, pool) cleanups = append(cleanups, cleanup) } const api = "http://127.0.0.1:5000/api/tool/list_tables_missing_unique_indexes/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got any var details []listDetails if err := json.Unmarshal([]byte(resultString), &details); err != nil { t.Fatalf("failed to unmarshal nested listDetails string: %v", err) } got = details if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b listDetails) bool { return a.TableSchema == b.TableSchema && a.TableName == b.TableName })); diff != "" { t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want) } }) } } func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string) { type tableFragmentationDetails struct { TableSchema string `json:"table_schema"` TableName string `json:"table_name"` DataSize any `json:"data_size"` IndexSize any `json:"index_size"` DataFree any `json:"data_free"` FragmentationPercentage any `json:"fragmentation_percentage"` } paramTableEntryWanted := tableFragmentationDetails{ TableSchema: databaseName, TableName: tableNameParam, DataSize: any(nil), IndexSize: any(nil), DataFree: any(nil), FragmentationPercentage: any(nil), } authTableEntryWanted := tableFragmentationDetails{ TableSchema: databaseName, TableName: tableNameAuth, DataSize: any(nil), IndexSize: any(nil), DataFree: any(nil), FragmentationPercentage: any(nil), } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int want any }{ { name: "invoke list_table_fragmentation on all, no data_free threshold, expected to have 2 results", requestBody: bytes.NewBufferString(`{"data_free_threshold_bytes": 0}`), wantStatusCode: http.StatusOK, want: []tableFragmentationDetails{authTableEntryWanted, paramTableEntryWanted}, }, { name: "invoke list_table_fragmentation on all, no data_free threshold, limit to 1, expected to have 1 results", requestBody: bytes.NewBufferString(`{"data_free_threshold_bytes": 0, "limit": 1}`), wantStatusCode: http.StatusOK, want: []tableFragmentationDetails{authTableEntryWanted}, }, { name: "invoke list_table_fragmentation on all databases and 1 specific table name, no data_free threshold, expected to have 1 result", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_name": "%s","data_free_threshold_bytes": 0}`, tableNameAuth)), wantStatusCode: http.StatusOK, want: []tableFragmentationDetails{authTableEntryWanted}, }, { name: "invoke list_table_fragmentation on 1 database and 1 specific table name, no data_free threshold, expected to have 1 result", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_schema": "%s", "table_name": "%s", "data_free_threshold_bytes": 0}`, databaseName, tableNameParam)), wantStatusCode: http.StatusOK, want: []tableFragmentationDetails{paramTableEntryWanted}, }, { name: "invoke list_table_fragmentation on 1 database and 1 specific table name, high data_free threshold, expected to have 0 result", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_schema": "%s", "table_name": "%s", "data_free_threshold_bytes": 1000000000}`, databaseName, tableNameParam)), wantStatusCode: http.StatusOK, want: []tableFragmentationDetails(nil), }, { name: "invoke list_table_fragmentation on 1 non-exist database, no data_free threshold, expected to have 0 result", requestBody: bytes.NewBufferString(`{"table_schema": "non_existent_database", "data_free_threshold_bytes": 0}`), wantStatusCode: http.StatusOK, want: []tableFragmentationDetails(nil), }, { name: "invoke list_table_fragmentation on 1 non-exist table, no data_free threshold, expected to have 0 result", requestBody: bytes.NewBufferString(`{"table_name": "non_existent_table", "data_free_threshold_bytes": 0}`), wantStatusCode: http.StatusOK, want: []tableFragmentationDetails(nil), }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_table_fragmentation/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got any var details []tableFragmentationDetails if err := json.Unmarshal([]byte(resultString), &details); err != nil { t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err) } got = details if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b tableFragmentationDetails) bool { return a.TableSchema == b.TableSchema && a.TableName == b.TableName })); diff != "" { t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want) } }) } } func RunMySQLGetQueryPlanTest(t *testing.T, ctx context.Context, pool *sql.DB, databaseName, tableNameParam string) { // Create a simple query to explain query := fmt.Sprintf("SELECT * FROM %s", tableNameParam) invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int checkResult func(t *testing.T, result any) }{ { name: "invoke get_query_plan with valid query", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"sql_statement": "%s"}`, query)), wantStatusCode: http.StatusOK, checkResult: func(t *testing.T, result any) { resultMap, ok := result.(map[string]any) if !ok { t.Fatalf("result should be a map, got %T", result) } if _, ok := resultMap["query_block"]; !ok { t.Errorf("result should contain 'query_block', got %v", resultMap) } }, }, { name: "invoke get_query_plan with invalid query", requestBody: bytes.NewBufferString(`{"sql_statement": "SELECT * FROM non_existent_table"}`), wantStatusCode: http.StatusBadRequest, checkResult: nil, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/get_query_plan/invoke" resp, respBytes := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBytes)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper map[string]json.RawMessage if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil { t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes)) } resultJSON, ok := bodyWrapper["result"] if !ok { t.Fatal("unable to find 'result' in response body") } var resultString string if err := json.Unmarshal(resultJSON, &resultString); err != nil { if string(resultJSON) == "null" { resultString = "null" } else { t.Fatalf("'result' is not a JSON-encoded string: %s", err) } } var got map[string]any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal actual result string: %v", err) } if tc.checkResult != nil { tc.checkResult(t, got) } }) } } // RunMSSQLListTablesTest run tests againsts the mssql-list-tables tools. func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) { // TableNameParam columns to construct want. const paramTableColumns = `[ {"column_name": "id", "data_type": "INT", "column_ordinal_position": 1, "is_not_nullable": true}, {"column_name": "name", "data_type": "VARCHAR(255)", "column_ordinal_position": 2, "is_not_nullable": false} ]` // TableNameAuth columns to construct want const authTableColumns = `[ {"column_name": "id", "data_type": "INT", "column_ordinal_position": 1, "is_not_nullable": true}, {"column_name": "name", "data_type": "VARCHAR(255)", "column_ordinal_position": 2, "is_not_nullable": false}, {"column_name": "email", "data_type": "VARCHAR(255)", "column_ordinal_position": 3, "is_not_nullable": false} ]` const ( // Template to construct detailed output want. detailedObjectTemplate = `{ "schema_name": "dbo", "object_name": "%[1]s", "object_details": { "owner": "dbo", "triggers": [], "columns": %[2]s, "object_name": "%[1]s", "object_type": "TABLE", "schema_name": "dbo" } }` // Template to construct simple output want simpleObjectTemplate = `{"object_name":"%s", "schema_name":"dbo", "object_details":{"name":"%s"}}` ) // Helper to build json for detailed want getDetailedWant := func(tableName, columnJSON string) string { return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON) } // Helper to build template for simple want getSimpleWant := func(tableName string) string { return fmt.Sprintf(simpleObjectTemplate, tableName, tableName) } invokeTcs := []struct { name string api string requestBody string wantStatusCode int want string isAllTables bool }{ { name: "invoke list_tables for all tables detailed output", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: `{"table_names": ""}`, wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)), isAllTables: true, }, { name: "invoke list_tables for all tables simple output", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: `{"table_names": "", "output_format": "simple"}`, wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)), isAllTables: true, }, { name: "invoke list_tables detailed output", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth), wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)), }, { name: "invoke list_tables simple output", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth), wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)), }, { name: "invoke list_tables with invalid output format", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: `{"table_names": "", "output_format": "abcd"}`, wantStatusCode: http.StatusBadRequest, }, { name: "invoke list_tables with malformed table_names parameter", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: `{"table_names": 12345, "output_format": "detailed"}`, wantStatusCode: http.StatusBadRequest, }, { name: "invoke list_tables with multiple table names", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth), wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)), }, { name: "invoke list_tables with non-existent table", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: `{"table_names": "non_existent_table"}`, wantStatusCode: http.StatusOK, want: `[]`, }, { name: "invoke list_tables with one existing and one non-existent table", api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam), wantStatusCode: http.StatusOK, want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)), }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { resp, respBytes := RunRequest(t, http.MethodPost, tc.api, bytes.NewBuffer([]byte(tc.requestBody)), nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatusCode, resp.StatusCode, string(respBytes)) } if tc.wantStatusCode == http.StatusOK { var bodyWrapper map[string]json.RawMessage if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil { t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes)) } resultJSON, ok := bodyWrapper["result"] if !ok { t.Fatal("unable to find 'result' in response body") } var resultString string if err := json.Unmarshal(resultJSON, &resultString); err != nil { if string(resultJSON) == "null" { resultString = "null" } else { t.Fatalf("'result' is not a JSON-encoded string: %s", err) } } var got, want []any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal actual result string: %v", err) } if err := json.Unmarshal([]byte(tc.want), &want); err != nil { t.Fatalf("failed to unmarshal expected want string: %v", err) } for _, item := range got { itemMap, ok := item.(map[string]any) if !ok { continue } detailsStr, ok := itemMap["object_details"].(string) if !ok { continue } var detailsMap map[string]any if err := json.Unmarshal([]byte(detailsStr), &detailsMap); err != nil { t.Fatalf("failed to unmarshal nested object_details string: %v", err) } // clean unpredictable fields delete(detailsMap, "constraints") delete(detailsMap, "indexes") itemMap["object_details"] = detailsMap } // Checking only the default dbo schema where the test tables are created to avoid brittle tests. if tc.isAllTables { var filteredGot []any for _, item := range got { if tableMap, ok := item.(map[string]interface{}); ok { if schema, ok := tableMap["schema_name"]; ok && schema == "dbo" { filteredGot = append(filteredGot, item) } } } got = filteredGot } sort.SliceStable(got, func(i, j int) bool { return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j]) }) sort.SliceStable(want, func(i, j int) bool { return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j]) }) if !reflect.DeepEqual(got, want) { gotJSON, _ := json.MarshalIndent(got, "", " ") wantJSON, _ := json.MarshalIndent(want, "", " ") t.Errorf("Unexpected result:\ngot:\n%s\n\nwant:\n%s", string(gotJSON), string(wantJSON)) } } }) } } // RunPostgresListLocksTest runs tests for the postgres list-locks tool func RunPostgresListLocksTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { type lockDetails struct { Pid any `json:"pid"` Usename string `json:"usename"` Database string `json:"database"` RelName string `json:"relname"` LockType string `json:"locktype"` Mode string `json:"mode"` Granted bool `json:"granted"` FastPath bool `json:"fastpath"` VirtualXid any `json:"virtualxid"` TransactionId any `json:"transactionid"` ClassId any `json:"classid"` ObjId any `json:"objid"` ObjSubId any `json:"objsubid"` PageNumber any `json:"page"` TupleNumber any `json:"tuple"` VirtualBlock any `json:"virtualblock"` BlockNumber any `json:"blockno"` } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int expectResults bool }{ { name: "invoke list_locks with no arguments", requestBody: bytes.NewBuffer([]byte(`{}`)), wantStatusCode: http.StatusOK, expectResults: false, // locks may or may not exist }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_locks/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []lockDetails if resultString != "null" { if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString) } } // Verify that if results exist, they have the expected structure for _, lock := range got { if lock.LockType == "" { t.Errorf("lock type should not be empty") } } }) } } // RunPostgresLongRunningTransactionsTest runs tests for the postgres long-running-transactions tool func RunPostgresLongRunningTransactionsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { type transactionDetails struct { Pid any `json:"pid"` Usename string `json:"usename"` Database string `json:"database"` ApplicationName string `json:"application_name"` XactStart any `json:"xact_start"` XactDurationSecs any `json:"xact_duration_secs"` IdleInTransaction string `json:"idle_in_transaction"` Query string `json:"query"` } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int }{ { name: "invoke long_running_transactions with default threshold", requestBody: bytes.NewBuffer([]byte(`{}`)), wantStatusCode: http.StatusOK, }, { name: "invoke long_running_transactions with custom threshold", requestBody: bytes.NewBuffer([]byte(`{"min_transaction_duration_secs": 3600}`)), wantStatusCode: http.StatusOK, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/long_running_transactions/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []transactionDetails if resultString != "null" { if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString) } } // Verify that if results exist, they have the expected structure for _, tx := range got { if tx.XactDurationSecs == nil { t.Errorf("transaction duration should not be null for long-running transactions") } } }) } } // RunPostgresReplicationStatsTest runs tests for the postgres replication-stats tool func RunPostgresReplicationStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { type replicationStats struct { ClientAddr string `json:"client_addr"` Username string `json:"usename"` ApplicationName string `json:"application_name"` ClientHostname string `json:"client_hostname"` BackendStart any `json:"backend_start"` State string `json:"state"` SyncState string `json:"sync_state"` ReplyTime any `json:"reply_time"` FlushLsn string `json:"flush_lsn"` ReplayLsn string `json:"replay_lsn"` WriteLag any `json:"write_lag"` FlushLag any `json:"flush_lag"` ReplayLag any `json:"replay_lag"` SyncPriority any `json:"sync_priority"` ReplicationSlotName any `json:"slot_name"` IsStreaming bool `json:"is_streaming"` } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int }{ { name: "invoke replication_stats with no arguments", requestBody: bytes.NewBuffer([]byte(`{}`)), wantStatusCode: http.StatusOK, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/replication_stats/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []replicationStats if resultString != "null" { if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString) } } // Verify that if results exist, they have the expected structure for _, stat := range got { if stat.State == "" { t.Errorf("replication state should not be empty") } } }) } } func RunPostgresGetColumnCardinalityTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { schemaName := "testschema_" + strings.ReplaceAll(uuid.New().String(), "-", "") tableName := "table1_" + strings.ReplaceAll(uuid.New().String(), "-", "") cleanup := setupPostgresSchemas(t, ctx, pool, schemaName) defer cleanup() // Create table with multiple columns createTableStmt := fmt.Sprintf(` CREATE TABLE %s.%s ( id SERIAL PRIMARY KEY, email VARCHAR(100) UNIQUE, name VARCHAR(50), status VARCHAR(20), created_at TIMESTAMP ) `, schemaName, tableName) if _, err := pool.Exec(ctx, createTableStmt); err != nil { t.Fatalf("unable to create table: %s", err) } // Insert larger sample data to ensure statistics are collected insertStmt := fmt.Sprintf(` INSERT INTO %s.%s (email, name, status, created_at) VALUES ('user1@example.com', 'Alice', 'active', NOW()), ('user2@example.com', 'Bob', 'inactive', NOW()), ('user3@example.com', 'Charlie', 'active', NOW()), ('user4@example.com', 'David', 'active', NOW()), ('user5@example.com', 'Eve', 'inactive', NOW()), ('user6@example.com', 'Frank', 'active', NOW()), ('user7@example.com', 'Grace', 'inactive', NOW()), ('user8@example.com', 'Henry', 'active', NOW()), ('user9@example.com', 'Ivy', 'active', NOW()), ('user10@example.com', 'Jack', 'inactive', NOW()) `, schemaName, tableName) if _, err := pool.Exec(ctx, insertStmt); err != nil { t.Fatalf("unable to insert data: %s", err) } // Run ANALYZE to update statistics analyzeStmt := fmt.Sprintf(`ANALYZE %s.%s`, schemaName, tableName) if _, err := pool.Exec(ctx, analyzeStmt); err != nil { t.Fatalf("unable to run ANALYZE: %s", err) } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int shouldHaveData bool // Whether we expect data in the response }{ { name: "get cardinality for a specific column", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s", "column_name": "email"}`, schemaName, tableName)), wantStatusCode: http.StatusOK, shouldHaveData: true, }, { name: "get cardinality for all columns", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s"}`, schemaName, tableName)), wantStatusCode: http.StatusOK, shouldHaveData: true, }, { name: "get cardinality with non-existent column", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s", "column_name": "non_existent"}`, schemaName, tableName)), wantStatusCode: http.StatusOK, shouldHaveData: false, }, { name: "get cardinality with non-existent schema", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "non_existent_schema", "table_name": "%s"}`, tableName)), wantStatusCode: http.StatusOK, shouldHaveData: false, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/get_column_cardinality/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []map[string]any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v", err) } // Verify that we got the expected data presence if tc.shouldHaveData { if len(got) == 0 { t.Logf("warning: expected data but got empty result. This can happen if pg_stats is not populated yet.") return } // Verify column names and cardinality values for _, row := range got { columnName, ok := row["column_name"].(string) if !ok { t.Fatalf("column_name is not a string: %v", row["column_name"]) } // Check that estimated_cardinality is present and is a number cardinality, ok := row["estimated_cardinality"] if !ok { t.Fatalf("estimated_cardinality is missing for column %s", columnName) } // Convert to float64 for numeric checks cardinalityFloat, ok := cardinality.(float64) if !ok { t.Fatalf("estimated_cardinality is not a number: %v", cardinality) } // Cardinality should be >= 0 if cardinalityFloat < 0 { t.Errorf("cardinality for column %s is negative: %v", columnName, cardinalityFloat) } } } else { if len(got) != 0 { t.Errorf("expected no data but got: %v", got) } } }) } } func createPostgresExtension(t *testing.T, ctx context.Context, pool *pgxpool.Pool, extensionName string) func() { createExtensionCmd := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s", extensionName) _, err := pool.Exec(ctx, createExtensionCmd) if err != nil { t.Fatalf("failed to create extension: %v", err) } return func() { dropExtensionCmd := fmt.Sprintf("DROP EXTENSION IF EXISTS %s", extensionName) _, err := pool.Exec(ctx, dropExtensionCmd) if err != nil { t.Fatalf("failed to drop extension: %v", err) } } } func RunPostgresListQueryStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { // Insert a simple query by running a SELECT statement // This will record statistics in pg_stat_statements selectStmt := "SELECT 1 as test_query" if _, err := pool.Exec(ctx, selectStmt); err != nil { t.Logf("warning: unable to execute test query: %s", err) } dropExtensionFunc := createPostgresExtension(t, ctx, pool, "pg_stat_statements") defer dropExtensionFunc() type queryStatDetails struct { Datname string `json:"datname"` Query string `json:"query"` Calls any `json:"calls"` TotalExecTime any `json:"total_exec_time"` MinExecTime any `json:"min_exec_time"` MaxExecTime any `json:"max_exec_time"` MeanExecTime any `json:"mean_exec_time"` Rows any `json:"rows"` SharedBlksHit any `json:"shared_blks_hit"` SharedBlksRead any `json:"shared_blks_read"` } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int }{ { name: "list query stats with default limit", requestBody: bytes.NewBufferString(`{}`), wantStatusCode: http.StatusOK, }, { name: "list query stats with custom limit", requestBody: bytes.NewBufferString(`{"limit": 10}`), wantStatusCode: http.StatusOK, }, { name: "list query stats for specific database", requestBody: bytes.NewBufferString(`{"database_name": "postgres"}`), wantStatusCode: http.StatusOK, }, { name: "list query stats with non-existent database name", requestBody: bytes.NewBufferString(`{"database_name": "non_existent_db_xyz"}`), wantStatusCode: http.StatusOK, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_query_stats/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []map[string]any if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal nested result string: %v, resultString: %s", err, resultString) } // For databases with pg_stat_statements available, verify response structure if len(got) > 0 { // Verify the response has the expected fields requiredFields := []string{"datname", "query", "calls", "total_exec_time", "min_exec_time", "max_exec_time", "mean_exec_time", "rows", "shared_blks_hit", "shared_blks_read"} for _, field := range requiredFields { if _, ok := got[0][field]; !ok { t.Errorf("missing expected field: %s in result: %v", field, got[0]) } } // Verify data types var stat queryStatDetails statData, _ := json.Marshal(got[0]) if err := json.Unmarshal(statData, &stat); err != nil { t.Logf("warning: unable to unmarshal query stat: %v", err) } // Verify that results are ordered by total_exec_time (descending) if len(got) > 1 { for i := 0; i < len(got)-1; i++ { currentTime, ok1 := got[i]["total_exec_time"].(float64) nextTime, ok2 := got[i+1]["total_exec_time"].(float64) if ok1 && ok2 && currentTime < nextTime { t.Logf("warning: results may not be ordered by total_exec_time descending: %f vs %f", currentTime, nextTime) } } } } }) } } // RunPostgresListTableStatsTest runs tests for the postgres list-table-stats tool func RunPostgresListTableStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { type tableStatsDetails struct { SchemaName string `json:"schema_name"` TableName string `json:"table_name"` Owner string `json:"owner"` TotalSizeBytes any `json:"total_size_bytes"` SeqScan any `json:"seq_scan"` IdxScan any `json:"idx_scan"` IdxScanRatioPercent float64 `json:"idx_scan_ratio_percent"` LiveRows any `json:"live_rows"` DeadRows any `json:"dead_rows"` DeadRowRatioPercent float64 `json:"dead_row_ratio_percent"` NTupIns any `json:"n_tup_ins"` NTupUpd any `json:"n_tup_upd"` NTupDel any `json:"n_tup_del"` LastVacuum any `json:"last_vacuum"` LastAutovacuum any `json:"last_autovacuum"` LastAutoanalyze any `json:"last_autoanalyze"` } // Create a test table to generate statistics testTableName := "test_list_table_stats_" + strings.ReplaceAll(uuid.New().String(), "-", "") createTableStmt := fmt.Sprintf(` CREATE TABLE %s ( id SERIAL PRIMARY KEY, name VARCHAR(100), email VARCHAR(100) ) `, testTableName) if _, err := pool.Exec(ctx, createTableStmt); err != nil { t.Fatalf("unable to create test table: %s", err) } defer func() { dropTableStmt := fmt.Sprintf("DROP TABLE IF EXISTS %s", testTableName) if _, err := pool.Exec(ctx, dropTableStmt); err != nil { t.Logf("warning: unable to drop test table: %v", err) } }() // Insert some data to generate statistics insertStmt := fmt.Sprintf(` INSERT INTO %s (name, email) VALUES ('Alice', 'alice@example.com'), ('Bob', 'bob@example.com'), ('Charlie', 'charlie@example.com'), ('David', 'david@example.com'), ('Eve', 'eve@example.com') `, testTableName) if _, err := pool.Exec(ctx, insertStmt); err != nil { t.Fatalf("unable to insert test data: %s", err) } // Run some sequential scans to generate statistics for i := 0; i < 3; i++ { selectStmt := fmt.Sprintf("SELECT * FROM %s WHERE name = 'Alice'", testTableName) if _, err := pool.Exec(ctx, selectStmt); err != nil { t.Logf("warning: unable to execute select: %v", err) } } // Run ANALYZE to update statistics analyzeStmt := fmt.Sprintf("ANALYZE %s", testTableName) if _, err := pool.Exec(ctx, analyzeStmt); err != nil { t.Logf("warning: unable to run ANALYZE: %v", err) } invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int shouldHaveData bool filterTable bool }{ { name: "list table stats with no arguments (default limit)", requestBody: bytes.NewBufferString(`{}`), wantStatusCode: http.StatusOK, shouldHaveData: false, // may or may not have data depending on what's in the database }, { name: "list table stats with default limit", requestBody: bytes.NewBufferString(`{"schema_name": "public"}`), wantStatusCode: http.StatusOK, shouldHaveData: false, }, { name: "list table stats filtering by specific table", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_name": "%s"}`, testTableName)), wantStatusCode: http.StatusOK, shouldHaveData: true, filterTable: true, }, { name: "list table stats with custom limit", requestBody: bytes.NewBufferString(`{"limit": 10}`), wantStatusCode: http.StatusOK, shouldHaveData: false, }, { name: "list table stats sorted by size", requestBody: bytes.NewBufferString(`{"sort_by": "size", "limit": 5}`), wantStatusCode: http.StatusOK, shouldHaveData: false, }, { name: "list table stats sorted by seq_scan", requestBody: bytes.NewBufferString(`{"sort_by": "seq_scan", "limit": 5}`), wantStatusCode: http.StatusOK, shouldHaveData: false, }, { name: "list table stats sorted by idx_scan", requestBody: bytes.NewBufferString(`{"sort_by": "idx_scan", "limit": 5}`), wantStatusCode: http.StatusOK, shouldHaveData: false, }, { name: "list table stats sorted by dead_rows", requestBody: bytes.NewBufferString(`{"sort_by": "dead_rows", "limit": 5}`), wantStatusCode: http.StatusOK, shouldHaveData: false, }, { name: "list table stats with non-existent table filter", requestBody: bytes.NewBufferString(`{"table_name": "non_existent_table_xyz"}`), wantStatusCode: http.StatusOK, shouldHaveData: false, }, { name: "list table stats with non-existent schema filter", requestBody: bytes.NewBufferString(`{"schema_name": "non_existent_schema_xyz"}`), wantStatusCode: http.StatusOK, shouldHaveData: false, }, { name: "list table stats with owner filter", requestBody: bytes.NewBufferString(`{"owner": "postgres"}`), wantStatusCode: http.StatusOK, shouldHaveData: false, }, } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_table_stats/invoke" resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return } var bodyWrapper struct { Result json.RawMessage `json:"result"` } if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } var resultString string if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { resultString = string(bodyWrapper.Result) } var got []tableStatsDetails if resultString != "null" { if err := json.Unmarshal([]byte(resultString), &got); err != nil { t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString) } } // Verify expected data presence if tc.shouldHaveData { if len(got) == 0 { t.Fatalf("expected data but got empty result") } // Verify the test table is in results found := false for _, row := range got { if row.TableName == testTableName { found = true // Verify expected fields are present if row.SchemaName == "" { t.Errorf("schema_name should not be empty") } if row.Owner == "" { t.Errorf("owner should not be empty") } if row.TotalSizeBytes == nil { t.Errorf("total_size_bytes should not be null") } if row.LiveRows == nil { t.Errorf("live_rows should not be null") } break } } if !found { t.Errorf("test table %s not found in results", testTableName) } } else if tc.filterTable { // For filtered queries that shouldn't find anything if len(got) != 0 { t.Logf("warning: expected no data but got: %v", len(got)) } } // Verify result structure and data types for _, stat := range got { // Verify schema_name and table_name are strings if stat.SchemaName == "" && stat.TableName != "" { t.Errorf("schema_name is empty for table %s", stat.TableName) } // Verify numeric fields are valid if stat.IdxScanRatioPercent < 0 || stat.IdxScanRatioPercent > 100 { t.Errorf("idx_scan_ratio_percent should be between 0 and 100, got %f", stat.IdxScanRatioPercent) } if stat.DeadRowRatioPercent < 0 || stat.DeadRowRatioPercent > 100 { t.Errorf("dead_row_ratio_percent should be between 0 and 100, got %f", stat.DeadRowRatioPercent) } } // Verify sorting for specific sort_by options if tc.name == "list table stats sorted by size" && len(got) > 1 { for i := 0; i < len(got)-1; i++ { current, ok1 := got[i].TotalSizeBytes.(float64) next, ok2 := got[i+1].TotalSizeBytes.(float64) if ok1 && ok2 && current < next { t.Logf("warning: results may not be sorted by total_size_bytes descending") } } } }) } } // RunRequest is a helper function to send HTTP requests and return the response func RunRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) (*http.Response, []byte) { // Send request req, err := http.NewRequest(method, url, body) if err != nil { t.Fatalf("unable to create request: %s", err) } req.Header.Set("Content-type", "application/json") for k, v := range headers { req.Header.Set(k, v) } resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatalf("unable to send request: %s", err) } respBody, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("unable to read request body: %s", err) } defer resp.Body.Close() return resp, respBody }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/googleapis/genai-toolbox'

If you have feedback or need assistance with the MCP directory API, please join our Discord server