table_name_parser.go•11.7 kB
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bigquerycommon
import (
	"fmt"
	"sort"
	"strings"
	"unicode"
	"unicode/utf8"
)
// parserState defines the state of the SQL parser's state machine.
// Each value names the lexical context the scanner in parseSQL is currently
// inside: ordinary SQL text, one of the string-literal forms, or a comment.
type parserState int
const (
	// stateNormal is ordinary SQL text, outside any string literal or comment.
	// Keywords, identifiers and parentheses are only interpreted in this state.
	stateNormal parserState = iota
	// String states
	// stateInSingleQuoteString is inside '...'; a backslash escapes the next character.
	stateInSingleQuoteString
	// stateInDoubleQuoteString is inside "..."; a backslash escapes the next character.
	stateInDoubleQuoteString
	// stateInTripleSingleQuoteString is inside '''...''' and ends at the next '''.
	stateInTripleSingleQuoteString
	// stateInTripleDoubleQuoteString is inside """...""" and ends at the next """.
	stateInTripleDoubleQuoteString
	// stateInRawSingleQuoteString is inside r'...' (or R'...'); it ends at the next
	// single quote and backslashes are not treated as escapes.
	stateInRawSingleQuoteString
	// stateInRawDoubleQuoteString is inside r"..." (or R"..."); it ends at the next
	// double quote and backslashes are not treated as escapes.
	stateInRawDoubleQuoteString
	// stateInRawTripleSingleQuoteString is inside r'''...''' and ends at the next '''.
	stateInRawTripleSingleQuoteString
	// stateInRawTripleDoubleQuoteString is inside r"""...""" and ends at the next """.
	stateInRawTripleDoubleQuoteString
	// Comment states
	// stateInSingleLineCommentDash is inside a "--" comment, which runs to the next newline.
	stateInSingleLineCommentDash
	// stateInSingleLineCommentHash is inside a "#" comment, which runs to the next newline.
	stateInSingleLineCommentHash
	// stateInMultiLineComment is inside a "/* ... */" comment.
	stateInMultiLineComment
)
// SQL statement verbs, in the lower-cased form the parser compares against.
// parseSQL records the first of these seen in a statement as that statement's
// verb, and uses create/alter/drop to reject schema- and dataset-level operations.
const (
	verbCreate = "create"
	verbAlter  = "alter"
	verbDrop   = "drop"
	verbSelect = "select"
	verbInsert = "insert"
	verbUpdate = "update"
	verbDelete = "delete"
	verbMerge  = "merge"
)
// tableFollowsKeywords is the set of lower-cased keywords after which a table
// reference may appear. When the parser sees one of them it starts treating
// the next multi-part identifier (or parenthesised subquery) as a table.
var tableFollowsKeywords = map[string]bool{
	"from":   true,
	"join":   true,
	"update": true,
	"into":   true, // INSERT INTO, MERGE INTO
	"table":  true, // CREATE TABLE, ALTER TABLE
	"using":  true, // MERGE ... USING
	"insert": true, // INSERT my_table
	"merge":  true, // MERGE my_table
	"delete": true, // DELETE my_table (the FROM keyword is optional in BigQuery DML)
}
// tableContextExitKeywords is the set of lower-cased keywords that end a table
// context: when the parser sees one of them it stops expecting a table name
// and forgets which keyword opened the context.
var tableContextExitKeywords = map[string]bool{
	"where":  true,
	"group":  true, // GROUP BY
	"having": true,
	"order":  true, // ORDER BY
	"limit":  true,
	"window": true,
	"on":     true, // JOIN ... ON
	"set":    true, // UPDATE ... SET
	"when":   true, // MERGE ... WHEN
}
// TableParser is the main entry point for parsing a SQL string to find all referenced table IDs.
//
// It handles multi-statement SQL, comments, string literals and parenthesised
// subqueries. Statements whose effects cannot be analyzed statically (CALL,
// EXECUTE IMMEDIATE, CREATE PROCEDURE/FUNCTION) and dataset-level DDL are
// rejected with an error rather than parsed.
//
// defaultProjectID is prepended to two-part (dataset.table) references; if it
// is empty and such a reference is found, an error is returned.
//
// The returned IDs are de-duplicated, fully qualified as project.dataset.table,
// and sorted so that the result is deterministic.
func TableParser(sql, defaultProjectID string) ([]string, error) {
	tableIDSet := make(map[string]struct{})
	visitedSQLs := make(map[string]struct{})
	if _, err := parseSQL(sql, defaultProjectID, tableIDSet, visitedSQLs, false); err != nil {
		return nil, err
	}
	tableIDs := make([]string, 0, len(tableIDSet))
	for id := range tableIDSet {
		tableIDs = append(tableIDs, id)
	}
	// Map iteration order is random; sort so callers always see the same order.
	sort.Strings(tableIDs)
	return tableIDs, nil
}
// parseSQL is the core recursive function that processes SQL strings.
// It runs a small state machine over sql, skipping comments and string
// literals, and adds every table reference it recognizes to tableIDSet.
//
// A table reference is a multi-part identifier (or a parenthesised subquery,
// which is parsed by a recursive call) that follows one of tableFollowsKeywords.
// CALL, EXECUTE IMMEDIATE, CREATE [OR REPLACE] PROCEDURE/FUNCTION and
// CREATE/ALTER/DROP SCHEMA/DATASET are rejected with an error.
//
// Parameters:
//   - sql: the text to scan; for a subquery this is everything after its '('.
//   - defaultProjectID: project used to qualify dataset.table references.
//   - tableIDSet: output set of fully qualified table IDs.
//   - visitedSQLs: texts already scanned in this call tree; a repeated text is skipped.
//   - inSubquery: true when sql starts just inside a '(' that the caller consumed.
//
// It returns the number of bytes of sql that were consumed. For a subquery
// this includes the closing ')'; a subquery that is never closed is an error.
func parseSQL(sql, defaultProjectID string, tableIDSet map[string]struct{}, visitedSQLs map[string]struct{}, inSubquery bool) (int, error) {
	// Prevent scanning the same text twice.
	if _, ok := visitedSQLs[sql]; ok {
		return len(sql), nil
	}
	visitedSQLs[sql] = struct{}{}

	// escapeLen returns the byte length of a backslash escape at the start of s:
	// the backslash itself plus the (possibly multi-byte) character after it.
	escapeLen := func(s string) int {
		_, w := utf8.DecodeRuneInString(s[1:])
		return 1 + w
	}

	state := stateNormal
	expectingTable := false
	// parenDepth counts parentheses opened at this level that are not table
	// subqueries (function calls, IN lists, ...), so that their ')' is not
	// mistaken for the end of the enclosing subquery.
	parenDepth := 0
	var lastTableKeyword, lastToken, statementVerb string

	// i is a byte offset into sql. Every advance is expressed in bytes so that
	// char and remaining always describe the same position, even when the text
	// contains multi-byte UTF-8 characters.
	for i := 0; i < len(sql); {
		remaining := sql[i:]
		char, width := utf8.DecodeRuneInString(remaining)

		switch state {
		case stateNormal:
			if strings.HasPrefix(remaining, "--") {
				state = stateInSingleLineCommentDash
				i += 2
				continue
			}
			if char == '#' {
				state = stateInSingleLineCommentHash
				i++
				continue
			}
			if strings.HasPrefix(remaining, "/*") {
				state = stateInMultiLineComment
				i += 2
				continue
			}

			if char == '(' {
				if expectingTable {
					// The subquery starts after '('.
					consumed, err := parseSQL(remaining[1:], defaultProjectID, tableIDSet, visitedSQLs, true)
					if err != nil {
						return 0, err
					}
					// Advance past the opening parenthesis plus everything the
					// recursive call consumed, which includes the closing parenthesis.
					i += consumed + 1
					// For most keywords, we expect only one table. `from` can have multiple "tables" (subqueries).
					if lastTableKeyword != "from" {
						expectingTable = false
					}
					continue
				}
				// Not a table subquery: remember it so its ')' is matched here.
				parenDepth++
				i++
				continue
			}
			if char == ')' {
				if parenDepth > 0 {
					parenDepth--
				} else if inSubquery {
					// This ')' closes the subquery the caller opened.
					return i + 1, nil
				}
				i++
				continue
			}

			if char == ';' {
				// A new statement starts: forget the verb and the previous token.
				statementVerb = ""
				lastToken = ""
				i++
				continue
			}

			// Raw strings must be checked before regular strings, and triple
			// quotes before single quotes, so the longest opener wins.
			if strings.HasPrefix(remaining, "r'''") || strings.HasPrefix(remaining, "R'''") {
				state = stateInRawTripleSingleQuoteString
				i += 4
				continue
			}
			if strings.HasPrefix(remaining, `r"""`) || strings.HasPrefix(remaining, `R"""`) {
				state = stateInRawTripleDoubleQuoteString
				i += 4
				continue
			}
			if strings.HasPrefix(remaining, "r'") || strings.HasPrefix(remaining, "R'") {
				state = stateInRawSingleQuoteString
				i += 2
				continue
			}
			if strings.HasPrefix(remaining, `r"`) || strings.HasPrefix(remaining, `R"`) {
				state = stateInRawDoubleQuoteString
				i += 2
				continue
			}
			if strings.HasPrefix(remaining, "'''") {
				state = stateInTripleSingleQuoteString
				i += 3
				continue
			}
			if strings.HasPrefix(remaining, `"""`) {
				state = stateInTripleDoubleQuoteString
				i += 3
				continue
			}
			if char == '\'' {
				state = stateInSingleQuoteString
				i++
				continue
			}
			if char == '"' {
				state = stateInDoubleQuoteString
				i++
				continue
			}

			// Identifiers and keywords. Unquoted identifiers may start with a
			// letter or an underscore; quoted ones start with a backtick.
			if unicode.IsLetter(char) || char == '_' || char == '`' {
				parts, consumed, err := parseIdentifierSequence(remaining)
				if err != nil {
					return 0, err
				}
				if consumed == 0 {
					// Nothing recognized here; step over this character.
					i += width
					continue
				}

				if len(parts) == 1 {
					keyword := strings.ToLower(parts[0])
					switch keyword {
					case "call":
						return 0, fmt.Errorf("CALL is not allowed when dataset restrictions are in place, as the called procedure's contents cannot be safely analyzed")
					case "immediate":
						if lastToken == "execute" {
							return 0, fmt.Errorf("EXECUTE IMMEDIATE is not allowed when dataset restrictions are in place, as its contents cannot be safely analyzed")
						}
					case "procedure", "function":
						if lastToken == "create" || lastToken == "create or replace" {
							return 0, fmt.Errorf("unanalyzable statements like '%s %s' are not allowed", strings.ToUpper(lastToken), strings.ToUpper(keyword))
						}
					case verbCreate, verbAlter, verbDrop, verbSelect, verbInsert, verbUpdate, verbDelete, verbMerge:
						// Only the first verb of a statement defines it.
						if statementVerb == "" {
							statementVerb = keyword
						}
					}

					if statementVerb == verbCreate || statementVerb == verbAlter || statementVerb == verbDrop {
						if keyword == "schema" || keyword == "dataset" {
							return 0, fmt.Errorf("dataset-level operations like '%s %s' are not allowed when dataset restrictions are in place", strings.ToUpper(statementVerb), strings.ToUpper(keyword))
						}
					}

					if _, ok := tableFollowsKeywords[keyword]; ok {
						expectingTable = true
						lastTableKeyword = keyword
					} else if _, ok := tableContextExitKeywords[keyword]; ok {
						expectingTable = false
						lastTableKeyword = ""
					}

					// Track "create or replace" as one logical token so the
					// procedure/function check above can see it.
					if lastToken == "create" && keyword == "or" {
						lastToken = "create or"
					} else if lastToken == "create or" && keyword == "replace" {
						lastToken = "create or replace"
					} else {
						lastToken = keyword
					}
				} else if len(parts) >= 2 {
					// This is a multi-part identifier. If we were expecting a table, this is it.
					if expectingTable {
						tableID, err := formatTableID(parts, defaultProjectID)
						if err != nil {
							return 0, err
						}
						if tableID != "" {
							tableIDSet[tableID] = struct{}{}
						}
						// For most keywords, we expect only one table.
						if lastTableKeyword != "from" {
							expectingTable = false
						}
					}
					lastToken = ""
				}

				i += consumed
				continue
			}
			i += width

		case stateInSingleQuoteString:
			switch char {
			case '\\':
				i += escapeLen(remaining) // Skip backslash and the escaped character.
			case '\'':
				state = stateNormal
				i++
			default:
				i += width
			}

		case stateInDoubleQuoteString:
			switch char {
			case '\\':
				i += escapeLen(remaining) // Skip backslash and the escaped character.
			case '"':
				state = stateNormal
				i++
			default:
				i += width
			}

		case stateInTripleSingleQuoteString, stateInRawTripleSingleQuoteString:
			if strings.HasPrefix(remaining, "'''") {
				state = stateNormal
				i += 3
			} else {
				i += width
			}

		case stateInTripleDoubleQuoteString, stateInRawTripleDoubleQuoteString:
			if strings.HasPrefix(remaining, `"""`) {
				state = stateNormal
				i += 3
			} else {
				i += width
			}

		case stateInSingleLineCommentDash, stateInSingleLineCommentHash:
			if char == '\n' {
				state = stateNormal
			}
			i += width

		case stateInMultiLineComment:
			if strings.HasPrefix(remaining, "*/") {
				state = stateNormal
				i += 2
			} else {
				i += width
			}

		case stateInRawSingleQuoteString:
			if char == '\'' {
				state = stateNormal
			}
			i += width

		case stateInRawDoubleQuoteString:
			if char == '"' {
				state = stateNormal
			}
			i += width
		}
	}

	if inSubquery {
		return 0, fmt.Errorf("unclosed subquery parenthesis")
	}
	return len(sql), nil
}
// parseIdentifierSequence parses a sequence of dot-separated identifiers from
// the start of s, such as ds.tbl, `proj`.ds.tbl or `proj.ds.tbl`.
//
// Each part is either a backtick-quoted identifier (whose content is itself
// split on dots) or an unquoted identifier that starts with a letter or
// underscore and continues with letters, numbers, underscores or dashes.
// Whitespace before a part is skipped and counted as consumed.
//
// It returns the identifier parts, the number of bytes of s consumed, and an
// error if a backtick-quoted identifier is never closed. When s does not start
// with an identifier, it returns no parts.
func parseIdentifierSequence(s string) ([]string, int, error) {
	var parts []string
	var totalConsumed int

	for {
		// Skip whitespace in front of the next part.
		current := strings.TrimLeftFunc(s[totalConsumed:], unicode.IsSpace)
		totalConsumed = len(s) - len(current)
		if current == "" {
			break
		}

		// Decode a whole rune: classifying the first byte on its own would
		// misjudge any identifier that starts with a multi-byte character.
		first, _ := utf8.DecodeRuneInString(current)

		var consumed int
		switch {
		case first == '`':
			end := strings.IndexByte(current[1:], '`')
			if end == -1 {
				return nil, 0, fmt.Errorf("unclosed backtick identifier")
			}
			// A quoted identifier may hold a whole dotted path, e.g. `project.dataset.table`.
			parts = append(parts, strings.Split(current[1:end+1], ".")...)
			consumed = end + 2 // content plus both backticks
		case unicode.IsLetter(first) || first == '_':
			end := strings.IndexFunc(current, func(r rune) bool {
				return !unicode.IsLetter(r) && !unicode.IsNumber(r) && r != '_' && r != '-'
			})
			if end == -1 {
				end = len(current)
			}
			parts = append(parts, current[:end])
			consumed = end
		default:
			// Not the start of an identifier: the sequence ends here.
			return parts, totalConsumed, nil
		}
		totalConsumed += consumed

		// Continue only when a dot directly follows the part just read.
		if totalConsumed >= len(s) || s[totalConsumed] != '.' {
			break
		}
		totalConsumed++
	}
	return parts, totalConsumed, nil
}
// formatTableID turns the parts of a dotted identifier into a fully qualified
// table ID of the form project.dataset.table.
//
// Three parts are joined as they are. Two parts (dataset.table) are prefixed
// with defaultProjectID, and it is an error for defaultProjectID to be empty in
// that case. Any other number of parts is not treated as a table reference
// (it could be a CTE, a column, etc.) and yields an empty ID with no error.
func formatTableID(parts []string, defaultProjectID string) (string, error) {
	switch len(parts) {
	case 3: // project.dataset.table
		return strings.Join(parts, "."), nil
	case 2: // dataset.table
		datasetTable := strings.Join(parts, ".")
		if defaultProjectID == "" {
			return "", fmt.Errorf("query contains table '%s' without project ID, and no default project ID is provided", datasetTable)
		}
		return fmt.Sprintf("%s.%s", defaultProjectID, datasetTable), nil
	default:
		// Not a table identifier; the caller ignores an empty ID.
		return "", nil
	}
}