Skip to main content
Glama

MCP Toolbox for Databases

by googleapis
Apache 2.0
11,060
  • Linux
neo4jschema.go (23.8 kB)
// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package neo4jschema import ( "context" "fmt" "sync" "time" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/cache" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/types" "github.com/neo4j/neo4j-go-driver/v5/neo4j" ) // kind defines the unique identifier for this tool. const kind string = "neo4j-schema" // init registers the tool with the application's tool registry when the package is initialized. func init() { if !tools.Register(kind, newConfig) { panic(fmt.Sprintf("tool kind %q already registered", kind)) } } // newConfig decodes a YAML configuration into a Config struct. // This function is called by the tool registry to create a new configuration object. func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { actual := Config{Name: name} if err := decoder.DecodeContext(ctx, &actual); err != nil { return nil, err } return actual, nil } // compatibleSource defines the interface a data source must implement to be used by this tool. // It ensures that the source can provide a Neo4j driver and database name. 
type compatibleSource interface { Neo4jDriver() neo4j.DriverWithContext Neo4jDatabase() string } // Statically verify that our compatible source implementation is valid. var _ compatibleSource = &neo4jsc.Source{} // compatibleSources lists the kinds of sources that are compatible with this tool. var compatibleSources = [...]string{neo4jsc.SourceKind} // Config holds the configuration settings for the Neo4j schema tool. // These settings are typically read from a YAML file. type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` Source string `yaml:"source" validate:"required"` Description string `yaml:"description" validate:"required"` AuthRequired []string `yaml:"authRequired"` CacheExpireMinutes *int `yaml:"cacheExpireMinutes,omitempty"` // Cache expiration time in minutes. } // Statically verify that Config implements the tools.ToolConfig interface. var _ tools.ToolConfig = Config{} // ToolConfigKind returns the kind of this tool configuration. func (cfg Config) ToolConfigKind() string { return kind } // Initialize sets up the tool with its dependencies and returns a ready-to-use Tool instance. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { // Verify that the specified source exists. rawS, ok := srcs[cfg.Source] if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } // Verify the source is of a compatible kind. s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) } parameters := tools.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters) // Set a default cache expiration if not provided in the configuration. 
if cfg.CacheExpireMinutes == nil { defaultExpiration := cache.DefaultExpiration // Default to 60 minutes cfg.CacheExpireMinutes = &defaultExpiration } // Finish tool setup by creating the Tool instance. t := Tool{ Name: cfg.Name, Kind: kind, AuthRequired: cfg.AuthRequired, Driver: s.Neo4jDriver(), Database: s.Neo4jDatabase(), cache: cache.NewCache(), cacheExpireMinutes: cfg.CacheExpireMinutes, manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } return t, nil } // Statically verify that Tool implements the tools.Tool interface. var _ tools.Tool = Tool{} // Tool represents the Neo4j schema extraction tool. // It holds the Neo4j driver, database information, and a cache for the schema. type Tool struct { Name string `yaml:"name"` Kind string `yaml:"kind"` AuthRequired []string `yaml:"authRequired"` Driver neo4j.DriverWithContext Database string cache *cache.Cache cacheExpireMinutes *int manifest tools.Manifest mcpManifest tools.McpManifest } // Invoke executes the tool's main logic: fetching the Neo4j schema. // It first checks the cache for a valid schema before extracting it from the database. func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { // Check if a valid schema is already in the cache. if cachedSchema, ok := t.cache.Get("schema"); ok { if schema, ok := cachedSchema.(*types.SchemaInfo); ok { return schema, nil } } // If not cached, extract the schema from the database. schema, err := t.extractSchema(ctx) if err != nil { return nil, fmt.Errorf("failed to extract database schema: %w", err) } // Cache the newly extracted schema for future use. expiration := time.Duration(*t.cacheExpireMinutes) * time.Minute t.cache.Set("schema", schema, expiration) return schema, nil } // ParseParams is a placeholder as this tool does not require input parameters. 
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) { return tools.ParamValues{}, nil } // Manifest returns the tool's manifest, which describes its purpose and parameters. func (t Tool) Manifest() tools.Manifest { return t.manifest } // McpManifest returns the machine-consumable manifest for the tool. func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } // Authorized checks if the tool is authorized to run based on the provided authentication services. func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } func (t Tool) RequiresClientAuthorization() bool { return false } // checkAPOCProcedures verifies if essential APOC procedures are available in the database. // It returns true only if all required procedures are found. func (t Tool) checkAPOCProcedures(ctx context.Context) (bool, error) { proceduresToCheck := []string{"apoc.meta.schema", "apoc.meta.cypher.types"} session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) defer session.Close(ctx) // This query efficiently counts how many of the specified procedures exist. 
query := "SHOW PROCEDURES YIELD name WHERE name IN $procs RETURN count(name) AS procCount" params := map[string]any{"procs": proceduresToCheck} result, err := session.Run(ctx, query, params) if err != nil { return false, fmt.Errorf("failed to execute procedure check query: %w", err) } record, err := result.Single(ctx) if err != nil { return false, fmt.Errorf("failed to retrieve single result for procedure check: %w", err) } rawCount, found := record.Get("procCount") if !found { return false, fmt.Errorf("field 'procCount' not found in result record") } procCount, ok := rawCount.(int64) if !ok { return false, fmt.Errorf("expected 'procCount' to be of type int64, but got %T", rawCount) } // Return true only if the number of found procedures matches the number we were looking for. return procCount == int64(len(proceduresToCheck)), nil } // extractSchema orchestrates the concurrent extraction of different parts of the database schema. // It runs several extraction tasks in parallel for efficiency. func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { schema := &types.SchemaInfo{} var mu sync.Mutex // Define the different schema extraction tasks. tasks := []struct { name string fn func() error }{ { name: "database-info", fn: func() error { dbInfo, err := t.extractDatabaseInfo(ctx) if err != nil { return fmt.Errorf("failed to extract database info: %w", err) } mu.Lock() defer mu.Unlock() schema.DatabaseInfo = *dbInfo return nil }, }, { name: "schema-extraction", fn: func() error { // Check if APOC procedures are available. hasAPOC, err := t.checkAPOCProcedures(ctx) if err != nil { return fmt.Errorf("failed to check APOC procedures: %w", err) } var nodeLabels []types.NodeLabel var relationships []types.Relationship var stats *types.Statistics // Use APOC if available for a more detailed schema; otherwise, use native queries. 
if hasAPOC { nodeLabels, relationships, stats, err = t.GetAPOCSchema(ctx) } else { nodeLabels, relationships, stats, err = t.GetSchemaWithoutAPOC(ctx, 100) } if err != nil { return fmt.Errorf("failed to get schema: %w", err) } mu.Lock() defer mu.Unlock() schema.NodeLabels = nodeLabels schema.Relationships = relationships schema.Statistics = *stats return nil }, }, { name: "constraints", fn: func() error { constraints, err := t.extractConstraints(ctx) if err != nil { return fmt.Errorf("failed to extract constraints: %w", err) } mu.Lock() defer mu.Unlock() schema.Constraints = constraints return nil }, }, { name: "indexes", fn: func() error { indexes, err := t.extractIndexes(ctx) if err != nil { return fmt.Errorf("failed to extract indexes: %w", err) } mu.Lock() defer mu.Unlock() schema.Indexes = indexes return nil }, }, } var wg sync.WaitGroup errCh := make(chan error, len(tasks)) // Execute all tasks concurrently. for _, task := range tasks { wg.Add(1) go func(task struct { name string fn func() error }) { defer wg.Done() if err := task.fn(); err != nil { errCh <- err } }(task) } wg.Wait() close(errCh) // Collect any errors that occurred during the concurrent tasks. for err := range errCh { if err != nil { schema.Errors = append(schema.Errors, err.Error()) } } return schema, nil } // GetAPOCSchema extracts schema information using the APOC library, which provides detailed metadata. 
func (t Tool) GetAPOCSchema(ctx context.Context) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) { var nodeLabels []types.NodeLabel var relationships []types.Relationship stats := &types.Statistics{ NodesByLabel: make(map[string]int64), RelationshipsByType: make(map[string]int64), PropertiesByLabel: make(map[string]int64), PropertiesByRelType: make(map[string]int64), } var mu sync.Mutex var firstErr error ctx, cancel := context.WithCancel(ctx) defer cancel() handleError := func(err error) { mu.Lock() defer mu.Unlock() if firstErr == nil { firstErr = err cancel() // Cancel other operations on the first error. } } tasks := []struct { name string fn func(session neo4j.SessionWithContext) error }{ { name: "apoc-schema", fn: func(session neo4j.SessionWithContext) error { result, err := session.Run(ctx, "CALL apoc.meta.schema({sample: 10}) YIELD value RETURN value", nil) if err != nil { return fmt.Errorf("failed to run APOC schema query: %w", err) } if !result.Next(ctx) { return fmt.Errorf("no results from APOC schema query") } schemaMap, ok := result.Record().Values[0].(map[string]any) if !ok { return fmt.Errorf("unexpected result format from APOC schema query: %T", result.Record().Values[0]) } apocSchema, err := helpers.MapToAPOCSchema(schemaMap) if err != nil { return fmt.Errorf("failed to convert schema map to APOCSchemaResult: %w", err) } nodes, _, apocStats := helpers.ProcessAPOCSchema(apocSchema) mu.Lock() defer mu.Unlock() nodeLabels = nodes stats.TotalNodes = apocStats.TotalNodes stats.TotalProperties += apocStats.TotalProperties stats.NodesByLabel = apocStats.NodesByLabel stats.PropertiesByLabel = apocStats.PropertiesByLabel return nil }, }, { name: "apoc-relationships", fn: func(session neo4j.SessionWithContext) error { query := ` MATCH (startNode)-[rel]->(endNode) WITH labels(startNode)[0] AS startNode, type(rel) AS relType, apoc.meta.cypher.types(rel) AS relProperties, labels(endNode)[0] AS endNode, count(*) AS count RETURN relType, 
startNode, endNode, relProperties, count` result, err := session.Run(ctx, query, nil) if err != nil { return fmt.Errorf("failed to extract relationships: %w", err) } for result.Next(ctx) { record := result.Record() relType, startNode, endNode := record.Values[0].(string), record.Values[1].(string), record.Values[2].(string) properties, count := record.Values[3].(map[string]any), record.Values[4].(int64) if relType == "" || count == 0 { continue } relationship := types.Relationship{Type: relType, StartNode: startNode, EndNode: endNode, Count: count, Properties: []types.PropertyInfo{}} for prop, propType := range properties { relationship.Properties = append(relationship.Properties, types.PropertyInfo{Name: prop, Types: []string{propType.(string)}}) } mu.Lock() relationships = append(relationships, relationship) stats.RelationshipsByType[relType] += count stats.TotalRelationships += count propCount := int64(len(relationship.Properties)) stats.TotalProperties += propCount stats.PropertiesByRelType[relType] += propCount mu.Unlock() } mu.Lock() defer mu.Unlock() if len(stats.RelationshipsByType) == 0 { stats.RelationshipsByType = nil } if len(stats.PropertiesByRelType) == 0 { stats.PropertiesByRelType = nil } return nil }, }, } var wg sync.WaitGroup wg.Add(len(tasks)) for _, task := range tasks { go func(task struct { name string fn func(session neo4j.SessionWithContext) error }) { defer wg.Done() session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) defer session.Close(ctx) if err := task.fn(session); err != nil { handleError(fmt.Errorf("task %s failed: %w", task.name, err)) } }(task) } wg.Wait() if firstErr != nil { return nil, nil, nil, firstErr } return nodeLabels, relationships, stats, nil } // GetSchemaWithoutAPOC extracts schema information using native Cypher queries. // This serves as a fallback for databases without APOC installed. 
func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, sampleSize int) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) { nodePropsMap := make(map[string]map[string]map[string]bool) relPropsMap := make(map[string]map[string]map[string]bool) nodeCounts := make(map[string]int64) relCounts := make(map[string]int64) relConnectivity := make(map[string]types.RelConnectivityInfo) var mu sync.Mutex var firstErr error ctx, cancel := context.WithCancel(ctx) defer cancel() handleError := func(err error) { mu.Lock() defer mu.Unlock() if firstErr == nil { firstErr = err cancel() } } tasks := []struct { name string fn func(session neo4j.SessionWithContext) error }{ { name: "node-schema", fn: func(session neo4j.SessionWithContext) error { countResult, err := session.Run(ctx, `MATCH (n) UNWIND labels(n) AS label RETURN label, count(*) AS count ORDER BY count DESC`, nil) if err != nil { return fmt.Errorf("node count query failed: %w", err) } var labelsList []string mu.Lock() for countResult.Next(ctx) { record := countResult.Record() label, count := record.Values[0].(string), record.Values[1].(int64) nodeCounts[label] = count labelsList = append(labelsList, label) } mu.Unlock() if err = countResult.Err(); err != nil { return fmt.Errorf("node count result error: %w", err) } for _, label := range labelsList { propQuery := fmt.Sprintf(`MATCH (n:%s) WITH n LIMIT $sampleSize UNWIND keys(n) AS key WITH key, n[key] AS value WHERE value IS NOT NULL RETURN key, COLLECT(DISTINCT valueType(value)) AS types`, label) propResult, err := session.Run(ctx, propQuery, map[string]any{"sampleSize": sampleSize}) if err != nil { return fmt.Errorf("node properties query for label %s failed: %w", label, err) } mu.Lock() if nodePropsMap[label] == nil { nodePropsMap[label] = make(map[string]map[string]bool) } for propResult.Next(ctx) { record := propResult.Record() key, types := record.Values[0].(string), record.Values[1].([]any) if nodePropsMap[label][key] == nil { 
nodePropsMap[label][key] = make(map[string]bool) } for _, tp := range types { nodePropsMap[label][key][tp.(string)] = true } } mu.Unlock() if err = propResult.Err(); err != nil { return fmt.Errorf("node properties result error for label %s: %w", label, err) } } return nil }, }, { name: "relationship-schema", fn: func(session neo4j.SessionWithContext) error { relQuery := ` MATCH (start)-[r]->(end) WITH type(r) AS relType, labels(start) AS startLabels, labels(end) AS endLabels, count(*) AS count RETURN relType, CASE WHEN size(startLabels) > 0 THEN startLabels[0] ELSE null END AS startLabel, CASE WHEN size(endLabels) > 0 THEN endLabels[0] ELSE null END AS endLabel, sum(count) AS totalCount ORDER BY totalCount DESC` relResult, err := session.Run(ctx, relQuery, nil) if err != nil { return fmt.Errorf("relationship count query failed: %w", err) } var relTypesList []string mu.Lock() for relResult.Next(ctx) { record := relResult.Record() relType := record.Values[0].(string) startLabel := "" if record.Values[1] != nil { startLabel = record.Values[1].(string) } endLabel := "" if record.Values[2] != nil { endLabel = record.Values[2].(string) } count := record.Values[3].(int64) relCounts[relType] = count relTypesList = append(relTypesList, relType) if existing, ok := relConnectivity[relType]; !ok || count > existing.Count { relConnectivity[relType] = types.RelConnectivityInfo{StartNode: startLabel, EndNode: endLabel, Count: count} } } mu.Unlock() if err = relResult.Err(); err != nil { return fmt.Errorf("relationship count result error: %w", err) } for _, relType := range relTypesList { propQuery := fmt.Sprintf(`MATCH ()-[r:%s]->() WITH r LIMIT $sampleSize WHERE size(keys(r)) > 0 UNWIND keys(r) AS key WITH key, r[key] AS value WHERE value IS NOT NULL RETURN key, COLLECT(DISTINCT valueType(value)) AS types`, relType) propResult, err := session.Run(ctx, propQuery, map[string]any{"sampleSize": sampleSize}) if err != nil { return fmt.Errorf("relationship properties query for type %s 
failed: %w", relType, err) } mu.Lock() if relPropsMap[relType] == nil { relPropsMap[relType] = make(map[string]map[string]bool) } for propResult.Next(ctx) { record := propResult.Record() key, propTypes := record.Values[0].(string), record.Values[1].([]any) if relPropsMap[relType][key] == nil { relPropsMap[relType][key] = make(map[string]bool) } for _, t := range propTypes { relPropsMap[relType][key][t.(string)] = true } } mu.Unlock() if err = propResult.Err(); err != nil { return fmt.Errorf("relationship properties result error for type %s: %w", relType, err) } } return nil }, }, } var wg sync.WaitGroup wg.Add(len(tasks)) for _, task := range tasks { go func(task struct { name string fn func(session neo4j.SessionWithContext) error }) { defer wg.Done() session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) defer session.Close(ctx) if err := task.fn(session); err != nil { handleError(fmt.Errorf("task %s failed: %w", task.name, err)) } }(task) } wg.Wait() if firstErr != nil { return nil, nil, nil, firstErr } nodeLabels, relationships, stats := helpers.ProcessNonAPOCSchema(nodeCounts, nodePropsMap, relCounts, relPropsMap, relConnectivity) return nodeLabels, relationships, stats, nil } // extractDatabaseInfo retrieves general information about the Neo4j database instance. 
func (t Tool) extractDatabaseInfo(ctx context.Context) (*types.DatabaseInfo, error) { session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) defer session.Close(ctx) result, err := session.Run(ctx, "CALL dbms.components() YIELD name, versions, edition", nil) if err != nil { return nil, err } dbInfo := &types.DatabaseInfo{} if result.Next(ctx) { record := result.Record() dbInfo.Name = record.Values[0].(string) if versions, ok := record.Values[1].([]any); ok && len(versions) > 0 { dbInfo.Version = versions[0].(string) } dbInfo.Edition = record.Values[2].(string) } return dbInfo, result.Err() } // extractConstraints fetches all schema constraints from the database. func (t Tool) extractConstraints(ctx context.Context) ([]types.Constraint, error) { session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) defer session.Close(ctx) result, err := session.Run(ctx, "SHOW CONSTRAINTS", nil) if err != nil { return nil, err } var constraints []types.Constraint for result.Next(ctx) { record := result.Record().AsMap() constraint := types.Constraint{ Name: helpers.GetStringValue(record["name"]), Type: helpers.GetStringValue(record["type"]), EntityType: helpers.GetStringValue(record["entityType"]), } if labels, ok := record["labelsOrTypes"].([]any); ok && len(labels) > 0 { constraint.Label = labels[0].(string) } if props, ok := record["properties"].([]any); ok { constraint.Properties = helpers.ConvertToStringSlice(props) } constraints = append(constraints, constraint) } return constraints, result.Err() } // extractIndexes fetches all schema indexes from the database. 
func (t Tool) extractIndexes(ctx context.Context) ([]types.Index, error) { session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) defer session.Close(ctx) result, err := session.Run(ctx, "SHOW INDEXES", nil) if err != nil { return nil, err } var indexes []types.Index for result.Next(ctx) { record := result.Record().AsMap() index := types.Index{ Name: helpers.GetStringValue(record["name"]), State: helpers.GetStringValue(record["state"]), Type: helpers.GetStringValue(record["type"]), EntityType: helpers.GetStringValue(record["entityType"]), } if labels, ok := record["labelsOrTypes"].([]any); ok && len(labels) > 0 { index.Label = labels[0].(string) } if props, ok := record["properties"].([]any); ok { index.Properties = helpers.ConvertToStringSlice(props) } indexes = append(indexes, index) } return indexes, result.Err() }

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/googleapis/genai-toolbox'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.