package proto
import (
"fmt"
"io/fs"
"log/slog"
"path/filepath"
"sort"
"strings"
"sync"
"github.com/lithammer/fuzzysearch/fuzzy"
sahilfuzzy "github.com/sahilm/fuzzy"
)
// SearchResult represents a search result with metadata
// SearchResult represents a search result with metadata.
// Exactly one group of the type-specific fields (RPCs/Fields/Values) is
// populated, depending on Type; see createSearchResult.
type SearchResult struct {
	Name         string   `json:"name"`                    // fully qualified definition name
	Type         string   `json:"type"`                    // "service", "message", or "enum"
	File         string   `json:"file"`                    // path of the .proto file it was parsed from
	Score        int      `json:"score"`                   // match quality, 0-100 (100 = best)
	MatchType    string   `json:"match_type"`              // which part matched: "name", "field", "rpc", or "comment"
	Comment      string   `json:"comment,omitempty"`       // leading comment on the definition, if any
	RPCs         []string `json:"rpcs,omitempty"`          // RPC names (services only)
	RPCCount     int      `json:"rpc_count,omitempty"`     // len(RPCs) (services only)
	Fields       []string `json:"fields,omitempty"`        // field names (messages only)
	FieldCount   int      `json:"field_count,omitempty"`   // len(Fields) (messages only)
	Values       []string `json:"values,omitempty"`        // value names (enums only)
	ValueCount   int      `json:"value_count,omitempty"`   // len(Values) (enums only)
	MatchedRPC   string   `json:"matched_rpc,omitempty"`   // set when MatchType == "rpc"
	MatchedField string   `json:"matched_field,omitempty"` // set when MatchType == "field"
}
// Stats represents indexing statistics
// Stats represents indexing statistics, as reported by GetStats.
type Stats struct {
	TotalFiles             int `json:"total_files"`              // number of indexed .proto files
	TotalServices          int `json:"total_services"`           // distinct services across all files
	TotalMessages          int `json:"total_messages"`           // distinct messages across all files
	TotalEnums             int `json:"total_enums"`              // distinct enums across all files
	TotalSearchableEntries int `json:"total_searchable_entries"` // size of the flat search list
}
// searchEntry is one searchable proto definition. Exactly one of the
// service/message/enum pointers is non-nil, matching entryType (see
// IndexFile, which populates only the corresponding field).
type searchEntry struct {
	fullName  string        // fully qualified name of the definition
	entryType string        // "service", "message", or "enum"
	filePath  string        // file the definition was parsed from
	service   *ProtoService // set when entryType == "service"
	message   *ProtoMessage // set when entryType == "message"
	enum      *ProtoEnum    // set when entryType == "enum"
}
// ProtoIndex is an in-memory index of proto files with search capabilities
// ProtoIndex is an in-memory index of proto files with search capabilities.
// All exported methods are safe for concurrent use; mu guards every field
// below it. The struct contains a mutex and must not be copied.
type ProtoIndex struct {
	mu            sync.RWMutex
	files         map[string]*ProtoFile    // file path -> parsed file
	services      map[string]*ProtoService // full name -> service
	messages      map[string]*ProtoMessage // full name -> message
	enums         map[string]*ProtoEnum    // full name -> enum
	searchEntries []searchEntry            // flat list driving Search
	logger        *slog.Logger
}
// NewProtoIndex creates a new proto index
func NewProtoIndex(logger *slog.Logger) *ProtoIndex {
if logger == nil {
logger = slog.Default()
}
return &ProtoIndex{
files: make(map[string]*ProtoFile),
services: make(map[string]*ProtoService),
messages: make(map[string]*ProtoMessage),
enums: make(map[string]*ProtoEnum),
searchEntries: make([]searchEntry, 0),
logger: logger,
}
}
// IndexDirectory recursively scans directory for .proto files and indexes them
func (pi *ProtoIndex) IndexDirectory(rootPath string) (int, error) {
matches, err := filepath.Glob(filepath.Join(rootPath, "**/*.proto"))
if err != nil {
return 0, fmt.Errorf("failed to glob proto files: %w", err)
}
// Also try direct scan
count := 0
err = filepath.Walk(rootPath, func(path string, info fs.FileInfo, err error) error {
if err != nil {
return err
}
if filepath.Ext(path) == ".proto" {
if err := pi.IndexFile(path); err != nil {
pi.logger.Error("failed to index file", "path", path, "error", err)
} else {
count++
}
}
return nil
})
if err != nil {
return count, fmt.Errorf("failed to walk directory: %w", err)
}
pi.logger.Info("indexed proto files", "count", count)
// Also index matches from glob if any
for _, match := range matches {
if err := pi.IndexFile(match); err != nil {
pi.logger.Error("failed to index file", "path", match, "error", err)
}
}
return count, nil
}
// IndexFile parses and indexes a single proto file
func (pi *ProtoIndex) IndexFile(filePath string) error {
parser := NewParser()
protoFile, err := parser.ParseFile(filePath)
if err != nil {
return fmt.Errorf("failed to parse file: %w", err)
}
pi.mu.Lock()
defer pi.mu.Unlock()
pi.files[filePath] = protoFile
// Index services
for i := range protoFile.Services {
service := &protoFile.Services[i]
pi.services[service.FullName] = service
pi.searchEntries = append(pi.searchEntries, searchEntry{
fullName: service.FullName,
entryType: "service",
filePath: filePath,
service: service,
})
}
// Index messages
for i := range protoFile.Messages {
message := &protoFile.Messages[i]
pi.messages[message.FullName] = message
pi.searchEntries = append(pi.searchEntries, searchEntry{
fullName: message.FullName,
entryType: "message",
filePath: filePath,
message: message,
})
}
// Index enums
for i := range protoFile.Enums {
enum := &protoFile.Enums[i]
pi.enums[enum.FullName] = enum
pi.searchEntries = append(pi.searchEntries, searchEntry{
fullName: enum.FullName,
entryType: "enum",
filePath: filePath,
enum: enum,
})
}
pi.logger.Debug("indexed file",
"path", filePath,
"services", len(protoFile.Services),
"messages", len(protoFile.Messages),
"enums", len(protoFile.Enums),
)
return nil
}
// RemoveFile removes a file from the index
func (pi *ProtoIndex) RemoveFile(filePath string) {
pi.mu.Lock()
defer pi.mu.Unlock()
protoFile, exists := pi.files[filePath]
if !exists {
return
}
// Remove services
for _, service := range protoFile.Services {
delete(pi.services, service.FullName)
}
// Remove messages
for _, message := range protoFile.Messages {
delete(pi.messages, message.FullName)
}
// Remove enums
for _, enum := range protoFile.Enums {
delete(pi.enums, enum.FullName)
}
// Remove from search entries
newEntries := make([]searchEntry, 0, len(pi.searchEntries))
for _, entry := range pi.searchEntries {
if entry.filePath != filePath {
newEntries = append(newEntries, entry)
}
}
pi.searchEntries = newEntries
delete(pi.files, filePath)
pi.logger.Debug("removed file from index", "path", filePath)
}
// Search performs fuzzy search across all proto definitions
// Searches in: names, field names, RPC names, and comments
// Search performs fuzzy search across all proto definitions.
// It searches names first (highest priority), then field names, RPC names,
// and comments — each later pass only runs while fewer than limit results
// have been collected. Results are deduplicated by full name, sorted by
// score descending, and truncated to limit. An empty query returns nil.
func (pi *ProtoIndex) Search(query string, limit, minScore int) []SearchResult {
	if query == "" {
		return nil
	}
	pi.mu.RLock()
	defer pi.mu.RUnlock()
	seen := make(map[string]bool)
	lowered := strings.ToLower(query)
	var results []SearchResult
	// 1. Definition names.
	for _, r := range pi.searchInNames(query, minScore) {
		if seen[r.Name] {
			continue
		}
		seen[r.Name] = true
		results = append(results, r)
	}
	// 2. Message field names.
	if len(results) < limit {
		results = append(results, pi.searchInFields(lowered, minScore, seen)...)
	}
	// 3. Service RPC names.
	if len(results) < limit {
		results = append(results, pi.searchInRPCs(lowered, minScore, seen)...)
	}
	// 4. Comments.
	if len(results) < limit {
		results = append(results, pi.searchInComments(lowered, minScore, seen)...)
	}
	// Best matches first, then cap at limit.
	sort.Slice(results, func(a, b int) bool {
		return results[a].Score > results[b].Score
	})
	if limit < len(results) {
		results = results[:limit]
	}
	return results
}
// searchInNames performs fuzzy search on definition names
func (pi *ProtoIndex) searchInNames(query string, minScore int) []SearchResult {
// Build list of searchable names
names := make([]string, len(pi.searchEntries))
for i, entry := range pi.searchEntries {
names[i] = entry.fullName
}
queryLower := strings.ToLower(query)
var results []SearchResult
seen := make(map[int]bool)
// Strategy 0: Multi-word search (highest priority for multi-word queries)
// Split query by spaces and check if all words appear in the name
queryWords := strings.Fields(queryLower)
if len(queryWords) > 1 {
for i, name := range names {
nameLower := strings.ToLower(name)
// Remove spaces from both for comparison
nameNoSpaces := strings.ReplaceAll(nameLower, " ", "")
queryNoSpaces := strings.ReplaceAll(queryLower, " ", "")
// Try exact match without spaces first
if strings.Contains(nameNoSpaces, queryNoSpaces) {
score := 100
// Adjust based on position and length
if strings.HasSuffix(nameNoSpaces, queryNoSpaces) {
score = 100
} else if strings.HasPrefix(nameNoSpaces, queryNoSpaces) {
score = 99
} else {
score = 97
}
if score >= minScore {
entry := pi.searchEntries[i]
result := pi.createSearchResult(entry, score, "name")
results = append(results, result)
seen[i] = true
}
continue
}
// Check if all query words appear in order (subsequence)
allWordsMatch := true
searchFrom := 0
matchPositions := make([]int, 0, len(queryWords))
for _, word := range queryWords {
idx := strings.Index(nameNoSpaces[searchFrom:], word)
if idx == -1 {
allWordsMatch = false
break
}
actualPos := searchFrom + idx
matchPositions = append(matchPositions, actualPos)
searchFrom = actualPos + len(word)
}
if allWordsMatch {
// Calculate score based on match quality
score := 95
// Bonus if words match at camelCase boundaries
simpleName := name
if lastDot := strings.LastIndex(name, "."); lastDot >= 0 {
simpleName = name[lastDot+1:]
}
// Check if words align with camelCase boundaries
camelBonus := 0
for _, word := range queryWords {
if matchesCamelCase(simpleName, word) {
camelBonus += 2
}
}
score += camelBonus
// Adjust for compactness of match (how close are the words?)
if len(matchPositions) > 1 {
totalGap := matchPositions[len(matchPositions)-1] - matchPositions[0]
queryTotalLen := 0
for _, word := range queryWords {
queryTotalLen += len(word)
}
compactness := float64(queryTotalLen) / float64(totalGap+queryTotalLen)
if compactness > 0.7 {
score += 3
}
}
if score >= minScore {
entry := pi.searchEntries[i]
result := pi.createSearchResult(entry, score, "name")
results = append(results, result)
seen[i] = true
}
}
}
}
// Strategy 1: Exact substring matches (case-insensitive)
for i, name := range names {
if seen[i] {
continue
}
nameLower := strings.ToLower(name)
nameNoSpaces := strings.ReplaceAll(nameLower, " ", "")
queryNoSpaces := strings.ReplaceAll(queryLower, " ", "")
// Try without spaces first
if idx := strings.Index(nameNoSpaces, queryNoSpaces); idx >= 0 {
score := 100
// Adjust based on position
if strings.HasSuffix(nameNoSpaces, queryNoSpaces) {
score = 100 // Perfect suffix match (simple name)
} else if idx == 0 {
score = 98 // Match at beginning
} else {
score = 95 // Match in middle
}
// Adjust for length ratio
lengthRatio := float64(len(name)) / float64(len(query))
if lengthRatio > 5.0 {
score -= 3 // Penalize very long FQNs
}
if score >= minScore {
entry := pi.searchEntries[i]
result := pi.createSearchResult(entry, score, "name")
results = append(results, result)
seen[i] = true
}
continue
}
// Try with spaces
if idx := strings.Index(nameLower, queryLower); idx >= 0 {
score := 100
// Adjust based on position
if strings.HasSuffix(nameLower, queryLower) {
score = 100 // Perfect suffix match (simple name)
} else if idx == 0 {
score = 98 // Match at beginning
} else if idx > 0 && nameLower[idx-1] == '.' {
score = 97 // Match after package separator
} else {
score = 95 // Match in middle
}
// Adjust for length ratio
lengthRatio := float64(len(name)) / float64(len(query))
if lengthRatio > 5.0 {
score -= 3 // Penalize very long FQNs
}
if score >= minScore {
entry := pi.searchEntries[i]
result := pi.createSearchResult(entry, score, "name")
results = append(results, result)
seen[i] = true
}
}
}
// Strategy 2: Levenshtein distance for typo tolerance
// Check each name's simple name (last part after final dot) against query
queryNoSpaces := strings.ReplaceAll(queryLower, " ", "")
for i, name := range names {
if seen[i] {
continue
}
// Extract simple name (last component)
simpleName := name
if lastDot := strings.LastIndex(name, "."); lastDot >= 0 {
simpleName = name[lastDot+1:]
}
simpleNameLower := strings.ToLower(simpleName)
// Calculate Levenshtein distance (try both with and without spaces)
distance := fuzzy.LevenshteinDistance(queryNoSpaces, simpleNameLower)
// Convert distance to score (0-100)
// For similar lengths, small distances should score high
maxLen := len(queryNoSpaces)
if len(simpleNameLower) > maxLen {
maxLen = len(simpleNameLower)
}
if maxLen == 0 {
continue
}
// Score based on how many characters are correct
similarity := float64(maxLen-distance) / float64(maxLen)
score := int(similarity * 100)
// Require high similarity for Levenshtein matches (at least 70%)
if score >= 70 && score >= minScore {
entry := pi.searchEntries[i]
result := pi.createSearchResult(entry, score, "name")
results = append(results, result)
seen[i] = true
}
}
// Strategy 3: Subsequence matching with sahilm/fuzzy (like VSCode)
// This catches cases like "UsrSvc" matching "UserService"
// Try both with spaces and without
queryVariants := []string{query}
if strings.Contains(query, " ") {
queryVariants = append(queryVariants, strings.ReplaceAll(query, " ", ""))
}
for _, qVariant := range queryVariants {
matches := sahilfuzzy.Find(qVariant, names)
for _, match := range matches {
if seen[match.Index] {
continue
}
score := calculateSubsequenceScore(match.Score, len(qVariant), len(match.Str))
if score >= minScore {
entry := pi.searchEntries[match.Index]
result := pi.createSearchResult(entry, score, "name")
results = append(results, result)
seen[match.Index] = true
}
}
}
return results
}
// matchesCamelCase checks if a word matches a camelCase boundary in a name
func matchesCamelCase(name, word string) bool {
wordLower := strings.ToLower(word)
// Look for the word at camelCase boundaries (uppercase letters)
for i := 0; i < len(name); i++ {
// Check if this is a camelCase boundary (uppercase letter or start)
if i == 0 || (i > 0 && name[i] >= 'A' && name[i] <= 'Z') {
// Check if word matches here
if strings.HasPrefix(strings.ToLower(name[i:]), wordLower) {
return true
}
}
}
return false
}
// searchInFields searches for query in message field names
func (pi *ProtoIndex) searchInFields(query string, minScore int, seen map[string]bool) []SearchResult {
var results []SearchResult
queryLower := strings.ToLower(query)
for _, entry := range pi.searchEntries {
if seen[entry.fullName] || entry.entryType != "message" || entry.message == nil {
continue
}
// Check each field for matches
var bestScore int
var bestField string
for _, field := range entry.message.Fields {
fieldLower := strings.ToLower(field.Name)
// Try exact match first
if fieldLower == queryLower {
bestScore = 100
bestField = field.Name
break
}
// Try substring match
if strings.Contains(fieldLower, queryLower) {
score := 95
if score > bestScore {
bestScore = score
bestField = field.Name
}
continue
}
// Try Levenshtein distance for typo tolerance
distance := fuzzy.LevenshteinDistance(queryLower, fieldLower)
maxLen := len(queryLower)
if len(fieldLower) > maxLen {
maxLen = len(fieldLower)
}
if maxLen > 0 {
similarity := float64(maxLen-distance) / float64(maxLen)
score := int(similarity * 100)
if score >= 70 && score > bestScore {
bestScore = score
bestField = field.Name
}
}
}
if bestScore >= minScore && bestField != "" {
result := pi.createSearchResult(entry, bestScore, "field")
result.MatchedField = bestField
results = append(results, result)
seen[entry.fullName] = true
}
}
return results
}
// searchInRPCs searches for query in service RPC names
func (pi *ProtoIndex) searchInRPCs(query string, minScore int, seen map[string]bool) []SearchResult {
var results []SearchResult
queryLower := strings.ToLower(query)
for _, entry := range pi.searchEntries {
if seen[entry.fullName] || entry.entryType != "service" || entry.service == nil {
continue
}
// Check each RPC for matches
var bestScore int
var bestRPC string
for _, rpc := range entry.service.RPCs {
rpcLower := strings.ToLower(rpc.Name)
// Try exact match first
if rpcLower == queryLower {
bestScore = 100
bestRPC = rpc.Name
break
}
// Try substring match
if strings.Contains(rpcLower, queryLower) {
score := 95
if score > bestScore {
bestScore = score
bestRPC = rpc.Name
}
continue
}
// Try Levenshtein distance for typo tolerance
distance := fuzzy.LevenshteinDistance(queryLower, rpcLower)
maxLen := len(queryLower)
if len(rpcLower) > maxLen {
maxLen = len(rpcLower)
}
if maxLen > 0 {
similarity := float64(maxLen-distance) / float64(maxLen)
score := int(similarity * 100)
if score >= 70 && score > bestScore {
bestScore = score
bestRPC = rpc.Name
}
}
}
if bestScore >= minScore && bestRPC != "" {
result := pi.createSearchResult(entry, bestScore, "rpc")
result.MatchedRPC = bestRPC
results = append(results, result)
seen[entry.fullName] = true
}
}
return results
}
// searchInComments searches for query in comments
func (pi *ProtoIndex) searchInComments(query string, minScore int, seen map[string]bool) []SearchResult {
var results []SearchResult
for _, entry := range pi.searchEntries {
if seen[entry.fullName] {
continue
}
var comment string
switch entry.entryType {
case "service":
if entry.service != nil {
comment = entry.service.Comment
}
case "message":
if entry.message != nil {
comment = entry.message.Comment
}
case "enum":
if entry.enum != nil {
comment = entry.enum.Comment
}
}
if comment == "" {
continue
}
// Simple substring match for comments (case-insensitive)
commentLower := strings.ToLower(comment)
if strings.Contains(commentLower, query) {
// Score based on position and length
score := calculateCommentScore(query, commentLower)
if score >= minScore {
result := pi.createSearchResult(entry, score, "comment")
results = append(results, result)
seen[entry.fullName] = true
}
}
}
return results
}
// createSearchResult creates a SearchResult from a search entry
func (pi *ProtoIndex) createSearchResult(entry searchEntry, score int, matchType string) SearchResult {
result := SearchResult{
Name: entry.fullName,
Type: entry.entryType,
File: entry.filePath,
Score: score,
MatchType: matchType,
}
// Add type-specific metadata
switch entry.entryType {
case "service":
if entry.service != nil {
result.RPCCount = len(entry.service.RPCs)
result.RPCs = make([]string, len(entry.service.RPCs))
for i, rpc := range entry.service.RPCs {
result.RPCs[i] = rpc.Name
}
result.Comment = entry.service.Comment
}
case "message":
if entry.message != nil {
result.FieldCount = len(entry.message.Fields)
result.Fields = make([]string, len(entry.message.Fields))
for i, field := range entry.message.Fields {
result.Fields[i] = field.Name
}
result.Comment = entry.message.Comment
}
case "enum":
if entry.enum != nil {
result.ValueCount = len(entry.enum.Values)
result.Values = make([]string, len(entry.enum.Values))
for i, value := range entry.enum.Values {
result.Values[i] = value.Name
}
result.Comment = entry.enum.Comment
}
}
return result
}
// calculateSubsequenceScore converts sahilm/fuzzy library score to 0-100 scale
// sahilm/fuzzy: lower score = better match, but scores can be very large for long strings with gaps
// we want: higher score = better match (100 = exact)
// calculateSubsequenceScore converts a sahilm/fuzzy library score to a
// 0-100 scale. The library reports lower = better (0 = exact), with very
// large values for distant matches; we want higher = better (100 = exact).
//
// The raw score is normalized by target length into a penalty ratio and
// mapped piecewise onto bands: <1 => 95-100, 1-10 => 80-95, 10-100 =>
// 60-80, >=100 => below 60. Targets 1-3x the query length earn +5 for
// precision; targets over 10x lose 5. The result is clamped to [0, 100].
func calculateSubsequenceScore(fuzzyScore, queryLen, targetLen int) int {
	// Zero penalty from the library means an exact match.
	if fuzzyScore == 0 {
		return 100
	}
	ratio := float64(fuzzyScore) / float64(targetLen)
	var score int
	switch {
	case ratio < 1.0:
		score = 95 + int((1.0-ratio)*5.0)
	case ratio < 10.0:
		score = 80 + int((10.0-ratio)*1.5)
	case ratio < 100.0:
		score = 60 + int((100.0-ratio)*0.2)
	default:
		score = int(60.0 * (1000.0 / (ratio + 900.0)))
	}
	// Precision adjustment based on how long the target is vs. the query.
	switch lr := float64(targetLen) / float64(queryLen); {
	case lr >= 1.0 && lr <= 3.0:
		score += 5
	case lr > 10.0:
		score -= 5
	}
	// Clamp to the 0-100 scale.
	if score > 100 {
		score = 100
	} else if score < 0 {
		score = 0
	}
	return score
}
// calculateCommentScore scores comment matches
func calculateCommentScore(query, commentLower string) int {
// Base score for containing the query
score := 70
// Bonus if query is at the start
if strings.HasPrefix(commentLower, query) {
score += 15
} else {
// Check if it's at word boundary
idx := strings.Index(commentLower, query)
if idx > 0 && (commentLower[idx-1] == ' ' || commentLower[idx-1] == '\t') {
score += 10
}
}
// Bonus for exact word match
words := strings.Fields(commentLower)
for _, word := range words {
if word == query {
score += 10
break
}
}
// Penalty for very long comments (less precise match)
if len(commentLower) > len(query)*10 {
score -= 5
}
if score > 100 {
score = 100
}
if score < 0 {
score = 0
}
return score
}
// GetStats returns statistics about the indexed proto files
func (pi *ProtoIndex) GetStats() Stats {
pi.mu.RLock()
defer pi.mu.RUnlock()
return Stats{
TotalFiles: len(pi.files),
TotalServices: len(pi.services),
TotalMessages: len(pi.messages),
TotalEnums: len(pi.enums),
TotalSearchableEntries: len(pi.searchEntries),
}
}
// GetService retrieves a service by name
func (pi *ProtoIndex) GetService(name string, resolveTypes bool, maxDepth int) (map[string]interface{}, error) {
pi.mu.RLock()
defer pi.mu.RUnlock()
// Try exact match first
service, exists := pi.services[name]
if !exists {
// Try fuzzy match
for fullName, svc := range pi.services {
if endsWith(fullName, "."+name) || svc.Name == name {
service = svc
break
}
}
}
if service == nil {
return nil, fmt.Errorf("service not found: %s", name)
}
// Build result
result := map[string]interface{}{
"name": service.Name,
"full_name": service.FullName,
"comment": service.Comment,
"file": pi.findFileForDefinition(service.FullName, "service"),
}
// Add RPCs
rpcs := make([]map[string]interface{}, len(service.RPCs))
for i, rpc := range service.RPCs {
rpcs[i] = map[string]interface{}{
"name": rpc.Name,
"request_type": rpc.RequestType,
"response_type": rpc.ResponseType,
"request_streaming": rpc.RequestStreaming,
"response_streaming": rpc.ResponseStreaming,
"comment": rpc.Comment,
}
}
result["rpcs"] = rpcs
// Recursively resolve request/response types
if resolveTypes && maxDepth > 0 {
resolvedTypes := pi.resolveServiceTypes(service, maxDepth)
if len(resolvedTypes) > 0 {
result["resolved_types"] = resolvedTypes
}
}
return result, nil
}
// GetMessage retrieves a message by name
func (pi *ProtoIndex) GetMessage(name string, resolveTypes bool, maxDepth int) (map[string]interface{}, error) {
pi.mu.RLock()
defer pi.mu.RUnlock()
// Try exact match first
message, exists := pi.messages[name]
if !exists {
// Try fuzzy match
for fullName, msg := range pi.messages {
if endsWith(fullName, "."+name) || msg.Name == name {
message = msg
break
}
}
}
if message == nil {
return nil, fmt.Errorf("message not found: %s", name)
}
// Build result
result := map[string]interface{}{
"name": message.Name,
"full_name": message.FullName,
"comment": message.Comment,
"file": pi.findFileForDefinition(message.FullName, "message"),
}
// Add fields
fields := make([]map[string]interface{}, len(message.Fields))
for i, field := range message.Fields {
fields[i] = map[string]interface{}{
"name": field.Name,
"type": field.Type,
"number": field.Number,
"label": field.Label,
"comment": field.Comment,
}
}
result["fields"] = fields
// Recursively resolve field types
if resolveTypes && maxDepth > 0 {
resolvedTypes := pi.resolveMessageTypes(message, maxDepth, nil)
if len(resolvedTypes) > 0 {
result["resolved_types"] = resolvedTypes
}
}
return result, nil
}
// GetEnum retrieves an enum by name
func (pi *ProtoIndex) GetEnum(name string) (map[string]interface{}, error) {
pi.mu.RLock()
defer pi.mu.RUnlock()
// Try exact match first
enum, exists := pi.enums[name]
if !exists {
// Try fuzzy match
for fullName, e := range pi.enums {
if endsWith(fullName, "."+name) || e.Name == name {
enum = e
break
}
}
}
if enum == nil {
return nil, fmt.Errorf("enum not found: %s", name)
}
// Build result
result := map[string]interface{}{
"name": enum.Name,
"full_name": enum.FullName,
"comment": enum.Comment,
"file": pi.findFileForDefinition(enum.FullName, "enum"),
}
// Add values
values := make([]map[string]interface{}, len(enum.Values))
for i, value := range enum.Values {
values[i] = map[string]interface{}{
"name": value.Name,
"number": value.Number,
"comment": value.Comment,
}
}
result["values"] = values
return result, nil
}
// findFileForDefinition returns the path of the indexed file that declares
// fullName as a definition of kind defType ("service", "message", or
// "enum"), or "" when none does. NOTE(review): callers in this file invoke
// it while holding pi.mu; it takes no lock itself — confirm before adding
// new call sites.
func (pi *ProtoIndex) findFileForDefinition(fullName, defType string) string {
	for filePath, pf := range pi.files {
		found := false
		switch defType {
		case "service":
			for i := range pf.Services {
				if pf.Services[i].FullName == fullName {
					found = true
					break
				}
			}
		case "message":
			for i := range pf.Messages {
				if pf.Messages[i].FullName == fullName {
					found = true
					break
				}
			}
		case "enum":
			for i := range pf.Enums {
				if pf.Enums[i].FullName == fullName {
					found = true
					break
				}
			}
		}
		if found {
			return filePath
		}
	}
	return ""
}
// endsWith reports whether s ends with suffix. The previous hand-rolled
// length-and-slice comparison duplicated the standard library; delegate to
// strings.HasSuffix, which has identical semantics (including the empty
// suffix matching everything).
func endsWith(s, suffix string) bool {
	return strings.HasSuffix(s, suffix)
}
// TypeUsage represents a usage of a type in a service/RPC
type TypeUsage struct {
ServiceName string `json:"service_name"`
ServiceFullName string `json:"service_full_name"`
ServiceFile string `json:"service_file"`
RPCName string `json:"rpc_name"`
UsageContext string `json:"usage_context"` // e.g., "Request", "Response"
MessageType string `json:"message_type"` // The immediate message containing the type
FieldPath []string `json:"field_path"` // Path to the field containing the type
Depth int `json:"depth"` // How deep the type is nested
}
// FindTypeUsages finds all services and RPCs that use a given type (directly or transitively)
func (pi *ProtoIndex) FindTypeUsages(typeName string) ([]TypeUsage, error) {
pi.mu.RLock()
defer pi.mu.RUnlock()
// Resolve the type to its full name
targetFullName := pi.resolveTypeName(typeName)
if targetFullName == "" {
return nil, fmt.Errorf("type not found: %s", typeName)
}
var usages []TypeUsage
// Iterate through all services
for _, service := range pi.services {
// Check each RPC method
for _, rpc := range service.RPCs {
// Check request type
if requestUsages := pi.findTypeInMessage(rpc.RequestType, targetFullName, service, rpc.Name, "Request"); len(requestUsages) > 0 {
usages = append(usages, requestUsages...)
}
// Check response type
if responseUsages := pi.findTypeInMessage(rpc.ResponseType, targetFullName, service, rpc.Name, "Response"); len(responseUsages) > 0 {
usages = append(usages, responseUsages...)
}
}
}
return usages, nil
}
// resolveTypeName resolves a simple or fully qualified type name to its full name
func (pi *ProtoIndex) resolveTypeName(typeName string) string {
// Try exact match in messages
if _, exists := pi.messages[typeName]; exists {
return typeName
}
// Try exact match in enums
if _, exists := pi.enums[typeName]; exists {
return typeName
}
// Try fuzzy match in messages
for fullName, msg := range pi.messages {
if endsWith(fullName, "."+typeName) || msg.Name == typeName {
return fullName
}
}
// Try fuzzy match in enums
for fullName, enum := range pi.enums {
if endsWith(fullName, "."+typeName) || enum.Name == typeName {
return fullName
}
}
return ""
}
// findTypeInMessage recursively searches for a target type within a message
func (pi *ProtoIndex) findTypeInMessage(messageTypeName, targetFullName string, service *ProtoService, rpcName, context string) []TypeUsage {
var usages []TypeUsage
// Resolve the message type
resolvedMessageName := pi.resolveTypeName(messageTypeName)
if resolvedMessageName == "" {
return usages
}
// Track visited types to avoid infinite recursion
visited := make(map[string]bool)
pi.findTypeInMessageRecursive(resolvedMessageName, targetFullName, service, rpcName, context, []string{}, 0, visited, &usages)
return usages
}
// findTypeInMessageRecursive is the recursive helper for finding type usages
func (pi *ProtoIndex) findTypeInMessageRecursive(messageFullName, targetFullName string, service *ProtoService, rpcName, context string, fieldPath []string, depth int, visited map[string]bool, usages *[]TypeUsage) {
// Prevent infinite recursion
if visited[messageFullName] {
return
}
visited[messageFullName] = true
// Check if this message itself is the target type
if messageFullName == targetFullName {
serviceFile := pi.findFileForDefinition(service.FullName, "service")
*usages = append(*usages, TypeUsage{
ServiceName: service.Name,
ServiceFullName: service.FullName,
ServiceFile: serviceFile,
RPCName: rpcName,
UsageContext: context,
MessageType: messageFullName,
FieldPath: append([]string{}, fieldPath...), // Copy slice
Depth: depth,
})
return
}
// Get the message definition
message, exists := pi.messages[messageFullName]
if !exists {
return
}
// Check each field in the message
for _, field := range message.Fields {
fieldType := field.Type
// Skip primitive types
if isPrimitiveType(fieldType) {
continue
}
// Resolve the field type (try both message and enum)
resolvedFieldType := ""
if msg := pi.findMessageByType(fieldType, messageFullName); msg != nil {
resolvedFieldType = msg.FullName
} else if enum := pi.findEnumByType(fieldType, messageFullName); enum != nil {
resolvedFieldType = enum.FullName
}
// If we couldn't resolve the type, skip it
if resolvedFieldType == "" {
continue
}
// Check if the field type is the target
if resolvedFieldType == targetFullName {
serviceFile := pi.findFileForDefinition(service.FullName, "service")
newFieldPath := append(append([]string{}, fieldPath...), field.Name)
*usages = append(*usages, TypeUsage{
ServiceName: service.Name,
ServiceFullName: service.FullName,
ServiceFile: serviceFile,
RPCName: rpcName,
UsageContext: context,
MessageType: messageFullName,
FieldPath: newFieldPath,
Depth: depth + 1,
})
} else {
// Recursively search in this field's type (only if it's a message, not an enum)
if pi.messages[resolvedFieldType] != nil {
newFieldPath := append(append([]string{}, fieldPath...), field.Name)
pi.findTypeInMessageRecursive(resolvedFieldType, targetFullName, service, rpcName, context, newFieldPath, depth+1, visited, usages)
}
}
}
}