Skip to main content
Glama
client.go (20.5 kB)
package trino import ( "context" "database/sql" "fmt" "log" "net/http" "net/url" "regexp" "strings" "time" "github.com/trinodb/trino-go-client/trino" "github.com/tuannvm/mcp-trino/internal/config" oauth "github.com/tuannvm/oauth-mcp-proxy" ) // Context key for impersonated user type contextKey string const ( impersonatedUserKey contextKey = "impersonated_user" ) // headerRoundTripper adds X-Trino-Source and X-Trino-User headers to requests type headerRoundTripper struct { base http.RoundTripper config *config.TrinoConfig } func (t *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { req = req.Clone(req.Context()) // Set X-Trino-Source header for query attribution if t.config.TrinoSource != "" { req.Header.Set("X-Trino-Source", t.config.TrinoSource) } // Set X-Trino-User header if impersonation is enabled if t.config.EnableImpersonation { if user, ok := req.Context().Value(impersonatedUserKey).(string); ok && user != "" { req.Header.Set("X-Trino-User", user) } } return t.base.RoundTrip(req) } // Client is a wrapper around Trino client type Client struct { db *sql.DB config *config.TrinoConfig timeout time.Duration } // NewClient creates a new Trino client func NewClient(cfg *config.TrinoConfig) (*Client, error) { dsnURL := url.URL{ Scheme: cfg.Scheme, User: url.UserPassword(cfg.User, cfg.Password), Host: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), } params := url.Values{} params.Add("catalog", cfg.Catalog) params.Add("schema", cfg.Schema) params.Add("SSL", fmt.Sprintf("%t", cfg.SSL)) params.Add("SSLInsecure", fmt.Sprintf("%t", cfg.SSLInsecure)) params.Add("custom_client", "mcp-trino") dsnURL.RawQuery = params.Encode() dsn := dsnURL.String() httpClient := &http.Client{ Transport: &headerRoundTripper{ base: http.DefaultTransport, config: cfg, }, } if err := trino.RegisterCustomClient("mcp-trino", httpClient); err != nil { // Ignore "already registered" errors - this can happen in tests or when client is recreated if !strings.Contains(err.Error(), 
"already registered") { return nil, fmt.Errorf("failed to register custom HTTP client: %w", err) } } db, err := sql.Open("trino", dsn) if err != nil { // Sanitize error to prevent password exposure sanitizedErr := sanitizeConnectionError(err, cfg.Password) return nil, fmt.Errorf("failed to connect to Trino: %w", sanitizedErr) } // Set connection pool parameters db.SetMaxOpenConns(10) db.SetMaxIdleConns(5) db.SetConnMaxLifetime(5 * time.Minute) // Test the connection if err := db.Ping(); err != nil { closeErr := db.Close() if closeErr != nil { log.Printf("Error closing DB connection: %v", closeErr) } // Sanitize error to prevent password exposure sanitizedErr := sanitizeConnectionError(err, cfg.Password) return nil, fmt.Errorf("failed to ping Trino: %w", sanitizedErr) } return &Client{ db: db, config: cfg, timeout: cfg.QueryTimeout, }, nil } // Close closes the database connection func (c *Client) Close() error { return c.db.Close() } // WithImpersonatedUser adds impersonated user to context func WithImpersonatedUser(ctx context.Context, username string) context.Context { return context.WithValue(ctx, impersonatedUserKey, username) } // GetImpersonatedUser retrieves impersonated user from context func GetImpersonatedUser(ctx context.Context) (string, bool) { user, ok := ctx.Value(impersonatedUserKey).(string) return user, ok } // isReadOnlyQuery checks if the SQL query is read-only (SELECT, SHOW, DESCRIBE, EXPLAIN) // This helps prevent SQL injection attacks by restricting the types of queries allowed func isReadOnlyQuery(query string) bool { // Convert to lowercase for case-insensitive comparison and normalize whitespace queryLower := strings.ToLower(strings.TrimSpace(query)) // Replace any newline characters with spaces to normalize the query format queryLower = strings.ReplaceAll(queryLower, "\n", " ") queryLower = strings.ReplaceAll(queryLower, "\r", " ") // Remove string literals and comments to avoid false positives queryLower = 
sanitizeQueryForKeywordDetection(queryLower) // First check for SQL injection attempts with multiple statements if strings.Contains(queryLower, ";") { return false } // Check if query starts with SELECT, SHOW, DESCRIBE, EXPLAIN or WITH (for CTEs) // These are generally read-only operations. Use word boundaries for robustness. // IMPORTANT: This check must come BEFORE write operation detection to avoid false positives // (e.g., "SHOW CREATE TABLE" contains "create" but is read-only) readOnlyPrefixPatterns := []string{ `^\s*select\b`, `^\s*show\b`, `^\s*describe\b`, `^\s*explain\b`, `^\s*with\b`, } for _, pattern := range readOnlyPrefixPatterns { matched, _ := regexp.MatchString(pattern, queryLower) if matched { // For queries starting with read-only prefixes, we still need to check // for disallowed write operations that might be embedded // But we allow common read-only patterns like "SHOW CREATE TABLE" if isAllowedReadOnlyPattern(queryLower) { return true } } } // Check for write operations anywhere in the query using word boundaries // - https://trino.io/docs/current/sql.html - Main SQL reference writeOperations := []string{ "insert", "update", "delete", "drop", "create", "alter", "truncate", "merge", "copy", "grant", "revoke", "commit", "rollback", "call", "execute", "refresh", "set", "reset", } for _, op := range writeOperations { // Use word boundary regex to catch operations followed by any whitespace pattern := fmt.Sprintf(`\b%s\b`, regexp.QuoteMeta(op)) matched, _ := regexp.MatchString(pattern, queryLower) if matched { return false } } return false } // isAllowedReadOnlyPattern checks if a query matches known safe read-only patterns // even if it contains keywords that might look like write operations func isAllowedReadOnlyPattern(queryLower string) bool { // SHOW CREATE statements are read-only (they just display DDL) showCreatePatterns := []string{ `^\s*show\s+create\s+table\b`, `^\s*show\s+create\s+view\b`, `^\s*show\s+create\s+schema\b`, 
`^\s*show\s+create\s+materialized\s+view\b`, } for _, pattern := range showCreatePatterns { matched, _ := regexp.MatchString(pattern, queryLower) if matched { return true } } // Other SHOW statements without CREATE are safe if matched, _ := regexp.MatchString(`^\s*show\b`, queryLower); matched { // Check if it doesn't contain any write operation keywords after SHOW // (except for "create" which is handled above) writeOpsExceptCreate := []string{ "insert", "update", "delete", "drop", "alter", "truncate", "merge", "copy", "grant", "revoke", "commit", "rollback", "call", "execute", "refresh", "set", "reset", } for _, op := range writeOpsExceptCreate { pattern := fmt.Sprintf(`\b%s\b`, regexp.QuoteMeta(op)) if matched, _ := regexp.MatchString(pattern, queryLower); matched { return false } } return true } // SELECT, DESCRIBE, EXPLAIN, WITH without write operations are safe safeStarts := []string{`^\s*select\b`, `^\s*describe\b`, `^\s*explain\b`, `^\s*with\b`} for _, pattern := range safeStarts { if matched, _ := regexp.MatchString(pattern, queryLower); matched { // If it starts with a safe keyword, check there are no write operations writeOps := []string{ "insert", "update", "delete", "drop", "create", "alter", "truncate", "merge", "copy", "grant", "revoke", "commit", "rollback", "call", "execute", "refresh", "set", "reset", } for _, op := range writeOps { opPattern := fmt.Sprintf(`\b%s\b`, regexp.QuoteMeta(op)) if matched, _ := regexp.MatchString(opPattern, queryLower); matched { return false } } return true } } return false } // sanitizeQueryForKeywordDetection removes string literals, quoted identifiers, and comments // to prevent false positives when detecting write operations func sanitizeQueryForKeywordDetection(query string) string { // Remove single-quoted string literals: 'text' // Handle escaped quotes: 'don''t' becomes 'don''t' query = regexp.MustCompile(`'(?:[^']|'')*'`).ReplaceAllString(query, "'LITERAL'") // Remove double-quoted identifiers: "column_name" 
// Handle escaped quotes: "column""name" becomes "column""name" query = regexp.MustCompile(`"(?:[^"]|"")*"`).ReplaceAllString(query, "\"IDENTIFIER\"") // Remove backtick-quoted identifiers: `column_name` query = regexp.MustCompile("`[^`]*`").ReplaceAllString(query, "`IDENTIFIER`") // Remove single-line comments: -- comment query = regexp.MustCompile(`--[^\r\n]*`).ReplaceAllString(query, "") // Remove multi-line comments: /* comment */ query = regexp.MustCompile(`/\*[^*]*\*+(?:[^/*][^*]*\*+)*/`).ReplaceAllString(query, "") return strings.TrimSpace(query) } // getQueryUsername returns the username of the user executing the query if present in OAuth context // This is used for query attribution (X-Trino-Client-Tags/Info) independent of impersonation func getQueryUsername(ctx context.Context) string { user, exists := oauth.GetUserFromContext(ctx) if !exists || user == nil { return "" } if user.Username != "" { return user.Username } if user.Email != "" { return user.Email } if user.Subject != "" { return user.Subject } return "" } // ExecuteQuery executes a SQL query and returns the results func (c *Client) ExecuteQuery(query string) ([]map[string]interface{}, error) { return c.ExecuteQueryWithContext(context.Background(), query) } // ExecuteQueryWithContext executes a SQL query and returns the results // It supports both: // - User impersonation via X-Trino-User header (when EnableImpersonation is true) // - Query attribution via X-Trino-Client-Tags/Info/Source (from OAuth user context) func (c *Client) ExecuteQueryWithContext(ctx context.Context, query string) ([]map[string]interface{}, error) { // Strip trailing semicolon that Trino doesn't allow query = strings.TrimSuffix(strings.TrimSpace(query), ";") // SQL injection protection: only allow read-only queries unless explicitly allowed in config if !c.config.AllowWriteQueries && !isReadOnlyQuery(query) { return nil, fmt.Errorf("security restriction: only SELECT, SHOW, DESCRIBE, and EXPLAIN queries are allowed. 
" + "Set TRINO_ALLOW_WRITE_QUERIES=true to enable write operations (at your own risk)") } // Create context with timeout, preserving any impersonation data queryCtx, cancel := context.WithTimeout(ctx, c.timeout) defer cancel() // Build query arguments for attribution headers // These are complementary to the X-Trino-User header set by RoundTripper var queryArgs []interface{} if userName := getQueryUsername(ctx); userName != "" { queryArgs = append(queryArgs, sql.Named("X-Trino-Client-Tags", userName), sql.Named("X-Trino-Client-Info", userName), ) // Only set X-Trino-Source if not already configured globally if c.config.TrinoSource == "" { queryArgs = append(queryArgs, sql.Named("X-Trino-Source", userName)) } } // Execute the query with optional attribution headers rows, err := c.db.QueryContext(queryCtx, query, queryArgs...) if err != nil { return nil, fmt.Errorf("query execution failed: %w", err) } defer func() { if err := rows.Close(); err != nil { log.Printf("Error closing rows: %v", err) } }() // Get column names columns, err := rows.Columns() if err != nil { return nil, fmt.Errorf("failed to get column names: %w", err) } // Prepare result container results := make([]map[string]interface{}, 0) // Iterate through rows for rows.Next() { // Create a slice of interface{} to hold the values values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) // Initialize the pointers for i := range values { valuePtrs[i] = &values[i] } // Scan the row into values if err := rows.Scan(valuePtrs...); err != nil { log.Printf("Error scanning row: %v", err) continue } // Create a map for the current row rowMap := make(map[string]interface{}) for i, col := range columns { val := values[i] rowMap[col] = val } results = append(results, rowMap) } // Check for errors after iterating if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating rows: %w", err) } return results, nil } // ListCatalogs returns a list of available catalogs func (c 
*Client) ListCatalogs() ([]string, error) { return c.ListCatalogsWithContext(context.Background()) } // ListCatalogsWithContext returns a list of available catalogs with context func (c *Client) ListCatalogsWithContext(ctx context.Context) ([]string, error) { results, err := c.ExecuteQueryWithContext(ctx, "SHOW CATALOGS") if err != nil { return nil, err } catalogs := make([]string, 0, len(results)) for _, row := range results { if catalog, ok := row["Catalog"].(string); ok { catalogs = append(catalogs, catalog) } } // Apply catalog filtering if allowlist is configured if len(c.config.AllowedCatalogs) > 0 { catalogs = c.filterCatalogs(catalogs) } return catalogs, nil } // ListSchemas returns a list of schemas in the specified catalog func (c *Client) ListSchemas(catalog string) ([]string, error) { return c.ListSchemasWithContext(context.Background(), catalog) } // ListSchemasWithContext returns a list of schemas in the specified catalog with context func (c *Client) ListSchemasWithContext(ctx context.Context, catalog string) ([]string, error) { if catalog == "" { catalog = c.config.Catalog } query := fmt.Sprintf("SHOW SCHEMAS FROM %s", catalog) results, err := c.ExecuteQueryWithContext(ctx, query) if err != nil { return nil, err } schemas := make([]string, 0, len(results)) for _, row := range results { if schema, ok := row["Schema"].(string); ok { schemas = append(schemas, schema) } } // Apply schema filtering if allowlist is configured if len(c.config.AllowedSchemas) > 0 { schemas = c.filterSchemas(schemas, catalog) } return schemas, nil } // ListTables returns a list of tables in the specified catalog and schema func (c *Client) ListTables(catalog, schema string) ([]string, error) { return c.ListTablesWithContext(context.Background(), catalog, schema) } // ListTablesWithContext returns a list of tables in the specified catalog and schema with context func (c *Client) ListTablesWithContext(ctx context.Context, catalog, schema string) ([]string, error) { if catalog 
== "" { catalog = c.config.Catalog } if schema == "" { schema = c.config.Schema } query := fmt.Sprintf("SHOW TABLES FROM %s.%s", catalog, schema) results, err := c.ExecuteQueryWithContext(ctx, query) if err != nil { return nil, err } tables := make([]string, 0, len(results)) for _, row := range results { if table, ok := row["Table"].(string); ok { tables = append(tables, table) } } // Apply table filtering if allowlist is configured if len(c.config.AllowedTables) > 0 { tables = c.filterTables(tables, catalog, schema) } return tables, nil } // GetTableSchema returns the schema of a table func (c *Client) GetTableSchema(catalog, schema, table string) ([]map[string]interface{}, error) { return c.GetTableSchemaWithContext(context.Background(), catalog, schema, table) } // GetTableSchemaWithContext returns the schema of a table with context func (c *Client) GetTableSchemaWithContext(ctx context.Context, catalog, schema, table string) ([]map[string]interface{}, error) { // Resolve catalog/schema/table parameters first parts := strings.Split(table, ".") if len(parts) == 3 { // If table is already fully qualified, extract components catalog = parts[0] schema = parts[1] table = parts[2] } else if len(parts) == 2 { // If table has schema.table format schema = parts[0] table = parts[1] if catalog == "" { catalog = c.config.Catalog } } else { // Use provided or default catalog and schema if catalog == "" { catalog = c.config.Catalog } if schema == "" { schema = c.config.Schema } } // Check if table access is allowed when table allowlist is configured (after resolution) if len(c.config.AllowedTables) > 0 { if !c.isTableAllowed(catalog, schema, table) { return nil, fmt.Errorf("table access denied: %s.%s.%s not in allowlist", catalog, schema, table) } } // Build and execute query with resolved parameters query := fmt.Sprintf("DESCRIBE %s.%s.%s", catalog, schema, table) return c.ExecuteQueryWithContext(ctx, query) } // ExplainQuery returns the query execution plan for a given SQL 
query func (c *Client) ExplainQuery(query string, format string) ([]map[string]interface{}, error) { return c.ExplainQueryWithContext(context.Background(), query, format) } // ExplainQueryWithContext returns the query execution plan for a given SQL query with context func (c *Client) ExplainQueryWithContext(ctx context.Context, query string, format string) ([]map[string]interface{}, error) { // Build EXPLAIN query with optional TYPE format (LOGICAL|DISTRIBUTED|VALIDATE|IO) explainQuery := "EXPLAIN" if f := strings.ToUpper(strings.TrimSpace(format)); f != "" { switch f { case "LOGICAL", "DISTRIBUTED", "VALIDATE", "IO": explainQuery = fmt.Sprintf("EXPLAIN (TYPE %s)", f) default: return nil, fmt.Errorf("invalid EXPLAIN format: %q (allowed: LOGICAL, DISTRIBUTED, VALIDATE, IO)", format) } } explainQuery = fmt.Sprintf("%s %s", explainQuery, query) return c.ExecuteQueryWithContext(ctx, explainQuery) } // sanitizeConnectionError removes sensitive information from connection errors func sanitizeConnectionError(err error, password string) error { if err == nil { return err } errStr := err.Error() // Replace password in error message if it exists if password != "" { // Replace URL-encoded password encodedPassword := url.QueryEscape(password) errStr = strings.ReplaceAll(errStr, encodedPassword, "[PASSWORD_REDACTED]") // Replace plain password errStr = strings.ReplaceAll(errStr, password, "[PASSWORD_REDACTED]") } return fmt.Errorf("%s", errStr) } // filterCatalogs filters a list of catalogs based on the allowlist configuration func (c *Client) filterCatalogs(catalogs []string) []string { if len(c.config.AllowedCatalogs) == 0 { return catalogs } filtered := make([]string, 0, len(catalogs)) for _, catalog := range catalogs { if c.isCatalogAllowed(catalog) { filtered = append(filtered, catalog) } } log.Printf("DEBUG: Catalog filtering: %d catalogs -> %d catalogs", len(catalogs), len(filtered)) return filtered } // filterSchemas filters a list of schemas based on the allowlist 
configuration func (c *Client) filterSchemas(schemas []string, catalog string) []string { if len(c.config.AllowedSchemas) == 0 { return schemas } filtered := make([]string, 0, len(schemas)) for _, schema := range schemas { if c.isSchemaAllowed(catalog, schema) { filtered = append(filtered, schema) } } log.Printf("DEBUG: Schema filtering: %d schemas -> %d schemas", len(schemas), len(filtered)) return filtered } // filterTables filters a list of tables based on the allowlist configuration func (c *Client) filterTables(tables []string, catalog, schema string) []string { if len(c.config.AllowedTables) == 0 { return tables } filtered := make([]string, 0, len(tables)) for _, table := range tables { if c.isTableAllowed(catalog, schema, table) { filtered = append(filtered, table) } } log.Printf("DEBUG: Table filtering: %d tables -> %d tables", len(tables), len(filtered)) return filtered } // isCatalogAllowed checks if a catalog is in the allowed catalogs list func (c *Client) isCatalogAllowed(catalog string) bool { for _, allowed := range c.config.AllowedCatalogs { if strings.EqualFold(catalog, allowed) { return true } } return false } // isSchemaAllowed checks if a schema is in the allowed schemas list func (c *Client) isSchemaAllowed(catalog, schema string) bool { fullSchemaName := catalog + "." + schema for _, allowed := range c.config.AllowedSchemas { if strings.EqualFold(fullSchemaName, allowed) { return true } } return false } // isTableAllowed checks if a table is in the allowed tables list func (c *Client) isTableAllowed(catalog, schema, table string) bool { fullTableName := catalog + "." + schema + "." + table for _, allowed := range c.config.AllowedTables { if strings.EqualFold(fullTableName, allowed) { return true } } return false }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/tuannvm/mcp-trino'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.