Skip to main content
Glama
client_test.go15.4 kB
package trino import ( "context" "reflect" "strings" "testing" "github.com/tuannvm/mcp-trino/internal/config" oauth "github.com/tuannvm/oauth-mcp-proxy" ) func TestFilterCatalogs(t *testing.T) { tests := []struct { name string allowedCatalogs []string input []string expected []string }{ { name: "No allowlist - return all", allowedCatalogs: nil, input: []string{"hive", "postgresql", "mysql"}, expected: []string{"hive", "postgresql", "mysql"}, }, { name: "Empty allowlist - return all", allowedCatalogs: []string{}, input: []string{"hive", "postgresql", "mysql"}, expected: []string{"hive", "postgresql", "mysql"}, }, { name: "Filter to allowed catalogs", allowedCatalogs: []string{"hive", "postgresql"}, input: []string{"hive", "postgresql", "mysql", "oracle"}, expected: []string{"hive", "postgresql"}, }, { name: "Case insensitive filtering", allowedCatalogs: []string{"HIVE", "PostgreSQL"}, input: []string{"hive", "postgresql", "mysql"}, expected: []string{"hive", "postgresql"}, }, { name: "No matches", allowedCatalogs: []string{"nonexistent"}, input: []string{"hive", "postgresql", "mysql"}, expected: []string{}, }, { name: "Partial matches", allowedCatalogs: []string{"hive"}, input: []string{"hive", "postgresql", "mysql"}, expected: []string{"hive"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := &Client{ config: &config.TrinoConfig{ AllowedCatalogs: tt.allowedCatalogs, }, } result := client.filterCatalogs(tt.input) if !reflect.DeepEqual(result, tt.expected) { t.Errorf("filterCatalogs() = %v, want %v", result, tt.expected) } }) } } func TestFilterSchemas(t *testing.T) { tests := []struct { name string allowedSchemas []string catalog string input []string expected []string }{ { name: "No allowlist - return all", allowedSchemas: nil, catalog: "hive", input: []string{"analytics", "marts", "staging"}, expected: []string{"analytics", "marts", "staging"}, }, { name: "Filter to allowed schemas", allowedSchemas: []string{"hive.analytics", "hive.marts"}, catalog: "hive", input: []string{"analytics", "marts", "staging", "raw"}, expected: []string{"analytics", "marts"}, }, { name: "Case insensitive filtering", allowedSchemas: []string{"HIVE.ANALYTICS", "hive.marts"}, catalog: "hive", input: []string{"analytics", "marts", "staging"}, expected: []string{"analytics", "marts"}, }, { name: "Different catalog - no matches", allowedSchemas: []string{"hive.analytics", "hive.marts"}, catalog: "postgresql", input: []string{"public", "private"}, expected: []string{}, }, { name: "Mixed catalogs in allowlist", allowedSchemas: []string{"hive.analytics", "postgresql.public"}, catalog: "hive", input: []string{"analytics", "marts"}, expected: []string{"analytics"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := &Client{ config: &config.TrinoConfig{ AllowedSchemas: tt.allowedSchemas, }, } result := client.filterSchemas(tt.input, tt.catalog) if !reflect.DeepEqual(result, tt.expected) { t.Errorf("filterSchemas() = %v, want %v", result, tt.expected) } }) } } func TestFilterTables(t *testing.T) { tests := []struct { name string allowedTables []string catalog string schema string input []string expected []string }{ { name: "No allowlist - return all", allowedTables: nil, catalog: "hive", schema: "analytics", input: []string{"users", "events", "sessions"}, expected: []string{"users", "events", "sessions"}, }, { name: "Filter to allowed tables", allowedTables: []string{"hive.analytics.users", "hive.analytics.events"}, catalog: "hive", schema: "analytics", input: []string{"users", "events", "sessions", "temp"}, expected: []string{"users", "events"}, }, { name: "Case insensitive filtering", allowedTables: []string{"HIVE.ANALYTICS.USERS", "hive.analytics.events"}, catalog: "hive", schema: "analytics", input: []string{"users", "events", "sessions"}, expected: []string{"users", "events"}, }, { name: "Different catalog/schema - no matches", allowedTables: []string{"hive.analytics.users"}, catalog: "postgresql", schema: "public", input: []string{"orders", "customers"}, expected: []string{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { client := &Client{ config: &config.TrinoConfig{ AllowedTables: tt.allowedTables, }, } result := client.filterTables(tt.input, tt.catalog, tt.schema) if !reflect.DeepEqual(result, tt.expected) { t.Errorf("filterTables() = %v, want %v", result, tt.expected) } }) } } func TestIsCatalogAllowed(t *testing.T) { client := &Client{ config: &config.TrinoConfig{ AllowedCatalogs: []string{"hive", "postgresql", "MySQL"}, }, } tests := []struct { catalog string expected bool }{ {"hive", true}, {"postgresql", true}, {"mysql", true}, // Case insensitive {"MySQL", true}, {"HIVE", true}, {"oracle", false}, {"sqlserver", false}, {"", false}, } for _, tt := range tests { t.Run(tt.catalog, func(t *testing.T) { result := client.isCatalogAllowed(tt.catalog) if result != tt.expected { t.Errorf("isCatalogAllowed(%q) = %v, want %v", tt.catalog, result, tt.expected) } }) } } func TestIsSchemaAllowed(t *testing.T) { client := &Client{ config: &config.TrinoConfig{ AllowedSchemas: []string{"hive.analytics", "hive.marts", "PostgreSQL.PUBLIC"}, }, } tests := []struct { catalog string schema string expected bool }{ {"hive", "analytics", true}, {"hive", "marts", true}, {"postgresql", "public", true}, // Case insensitive {"PostgreSQL", "PUBLIC", true}, {"hive", "staging", false}, {"postgresql", "private", false}, {"mysql", "analytics", false}, } for _, tt := range tests { t.Run(tt.catalog+"."+tt.schema, func(t *testing.T) { result := client.isSchemaAllowed(tt.catalog, tt.schema) if result != tt.expected { t.Errorf("isSchemaAllowed(%q, %q) = %v, want %v", tt.catalog, tt.schema, result, tt.expected) } }) } } func TestIsTableAllowed(t *testing.T) { client := &Client{ config: &config.TrinoConfig{ AllowedTables: []string{"hive.analytics.users", "hive.marts.sales", "PostgreSQL.PUBLIC.ORDERS"}, }, } tests := []struct { name string catalog string schema string table string expected bool }{ {"Simple match", "hive", "analytics", "users", true}, {"Case insensitive match", "PostgreSQL", "PUBLIC", "ORDERS", true}, {"No match - different table", "hive", "analytics", "events", false}, {"No match - different schema", "hive", "staging", "users", false}, {"No match - different catalog", "mysql", "analytics", "users", false}, {"Empty catalog", "", "analytics", "users", false}, {"Empty schema", "hive", "", "users", false}, {"Empty table", "hive", "analytics", "", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := client.isTableAllowed(tt.catalog, tt.schema, tt.table) if result != tt.expected { t.Errorf("isTableAllowed(%q, %q, %q) = %v, want %v", tt.catalog, tt.schema, tt.table, result, tt.expected) } }) } } func TestTableParameterResolution(t *testing.T) { client := &Client{ config: &config.TrinoConfig{ Catalog: "hive", Schema: "default", }, } // Test table parameter resolution logic (extracted from GetTableSchema) testResolution := func(inputCatalog, inputSchema, inputTable, expectedCatalog, expectedSchema, expectedTable string) { // Simulate the resolution logic from GetTableSchema catalog, schema, table := inputCatalog, inputSchema, inputTable parts := strings.Split(table, ".") if len(parts) == 3 { // If table is already fully qualified, extract components catalog = parts[0] schema = parts[1] table = parts[2] } else if len(parts) == 2 { // If table has schema.table format schema = parts[0] table = parts[1] if catalog == "" { catalog = client.config.Catalog } } else { // Use provided or default catalog and schema if catalog == "" { catalog = client.config.Catalog } if schema == "" { schema = client.config.Schema } } if catalog != expectedCatalog || schema != expectedSchema || table != expectedTable { t.Errorf("Resolution(%q, %q, %q) = (%q, %q, %q), want (%q, %q, %q)", inputCatalog, inputSchema, inputTable, catalog, schema, table, expectedCatalog, expectedSchema, expectedTable) } } // Test the resolution logic that was causing the bug testResolution("", "analytics", "users", "hive", "analytics", "users") // use default catalog testResolution("", "", "analytics.users", "hive", "analytics", "users") // schema.table format testResolution("", "", "hive.analytics.users", "hive", "analytics", "users") // fully qualified testResolution("postgresql", "public", "orders", "postgresql", "public", "orders") // explicit params } func TestGetTableSchemaAllowlistLogic(t *testing.T) { client := &Client{ config: &config.TrinoConfig{ Catalog: "hive", Schema: "default", AllowedTables: []string{"hive.analytics.users", "hive.marts.sales"}, }, } // Test the combined resolution + allowlist check logic testAllowlistAfterResolution := func(inputCatalog, inputSchema, inputTable string, expectedAllowed bool) { // Simulate the resolution + allowlist check from GetTableSchema catalog, schema, table := inputCatalog, inputSchema, inputTable // Resolution logic (copied from GetTableSchema) parts := strings.Split(table, ".") if len(parts) == 3 { catalog = parts[0] schema = parts[1] table = parts[2] } else if len(parts) == 2 { schema = parts[0] table = parts[1] if catalog == "" { catalog = client.config.Catalog } } else { if catalog == "" { catalog = client.config.Catalog } if schema == "" { schema = client.config.Schema } } // Allowlist check (after resolution) allowed := client.isTableAllowed(catalog, schema, table) if allowed != expectedAllowed { t.Errorf("Allowlist check after resolution(%q, %q, %q) -> isTableAllowed(%q, %q, %q) = %v, want %v", inputCatalog, inputSchema, inputTable, catalog, schema, table, allowed, expectedAllowed) } } // Test cases that verify the bug fix testAllowlistAfterResolution("hive", "analytics", "users", true) // explicit - should work testAllowlistAfterResolution("", "analytics", "users", true) // default catalog - should work testAllowlistAfterResolution("", "", "analytics.users", true) // schema.table - BUG FIX: should work now testAllowlistAfterResolution("", "", "hive.analytics.users", true) // fully qualified - should work testAllowlistAfterResolution("hive", "analytics", "events", false) // not in allowlist - should deny testAllowlistAfterResolution("postgresql", "analytics", "users", false) // wrong catalog - should deny } func TestImprovedIsReadOnlyQuery(t *testing.T) { tests := []struct { name string query string expected bool }{ // Basic read-only queries with word boundaries {"SELECT with word boundary", "SELECT * FROM users", true}, {"SELECT with leading spaces", " SELECT * FROM users", true}, {"SELECT with newlines", "\n SELECT * FROM users\n", true}, {"SHOW with word boundary", "SHOW TABLES", true}, {"DESCRIBE with word boundary", "DESCRIBE users", true}, {"EXPLAIN with word boundary", "EXPLAIN SELECT * FROM users", true}, {"WITH CTE", "WITH cte AS (SELECT 1) SELECT * FROM cte", true}, // SHOW CREATE statements (read-only despite containing "create" keyword) {"SHOW CREATE TABLE", "SHOW CREATE TABLE users", true}, {"SHOW CREATE TABLE with schema", "SHOW CREATE TABLE myschema.users", true}, {"SHOW CREATE TABLE fully qualified", "SHOW CREATE TABLE catalog.schema.table", true}, {"SHOW CREATE TABLE with spaces", " SHOW CREATE TABLE users ", true}, {"SHOW CREATE VIEW", "SHOW CREATE VIEW my_view", true}, {"SHOW CREATE SCHEMA", "SHOW CREATE SCHEMA myschema", true}, {"SHOW CREATE MATERIALIZED VIEW", "SHOW CREATE MATERIALIZED VIEW my_mat_view", true}, // Edge cases with word boundaries (these should now be stricter) {"SELECT without space", "SELECT*FROM users", true}, // Word boundary handles this {"SHOW without space", "SHOWTABLES", false}, // Word boundary requires separation // Write operations that should be blocked {"INSERT statement", "INSERT INTO users VALUES (1)", false}, {"UPDATE statement", "UPDATE users SET name = 'test'", false}, {"DELETE statement", "DELETE FROM users", false}, {"CREATE statement", "CREATE TABLE test (id INT)", false}, {"CREATE VIEW statement", "CREATE VIEW myview AS SELECT 1", false}, {"DROP statement", "DROP TABLE users", false}, {"ALTER statement", "ALTER TABLE users ADD COLUMN age INT", false}, // Complex cases {"SELECT with INSERT in string", "SELECT 'INSERT INTO' FROM dual", true}, {"SELECT with INSERT in comment", "SELECT 1 -- INSERT INTO users", true}, {"Multi-statement with semicolon", "SELECT 1; INSERT INTO users VALUES (1)", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := isReadOnlyQuery(tt.query) if result != tt.expected { t.Errorf("isReadOnlyQuery(%q) = %v, want %v", tt.query, result, tt.expected) } }) } } func TestGetQueryUsername(t *testing.T) { tests := []struct { name string user *oauth.User expected string }{ { name: "User with email", user: &oauth.User{Email: "abc@example.com"}, expected: "abc@example.com", }, { name: "User with username", user: &oauth.User{Username: "abc@example.com"}, expected: "abc@example.com", }, { name: "User with username and email", user: &oauth.User{Username: "abc@example.com", Email: "def@example.com"}, expected: "abc@example.com", }, { name: "User with subject", user: &oauth.User{Subject: "abc@example.com"}, expected: "abc@example.com", }, { name: "Empty User - returns empty (no attribution)", user: &oauth.User{}, expected: "", }, { name: "Nil User - returns empty (no attribution)", user: nil, expected: "", }, } for index := range tests { t.Run(tests[index].name, func(t *testing.T) { ctx := context.Background() tt := tests[index] if tt.user != nil { ctx = oauth.WithUser(ctx, tt.user) } result := getQueryUsername(ctx) if result != tt.expected { t.Errorf("getQueryUsername(%v) = %s, want %s", tt.user, result, tt.expected) } }) } }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/tuannvm/mcp-trino'

If you have feedback or need assistance with the MCP directory API, please join our Discord server