Skip to main content
Glama

MCP Toolbox for Databases

by googleapis
Apache 2.0
11,037
  • Linux
mcp.go15.2 kB
// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package server import ( "bufio" "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "strings" "sync" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/render" "github.com/google/uuid" "github.com/googleapis/genai-toolbox/internal/server/mcp" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" mcputil "github.com/googleapis/genai-toolbox/internal/server/mcp/util" v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105" v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/metric" ) type sseSession struct { writer http.ResponseWriter flusher http.Flusher done chan struct{} eventQueue chan string lastActive time.Time } // sseManager manages and control access to sse sessions type sseManager struct { mu sync.Mutex sseSessions map[string]*sseSession } func (m *sseManager) get(id string) (*sseSession, bool) { m.mu.Lock() defer m.mu.Unlock() session, ok := m.sseSessions[id] session.lastActive = time.Now() return session, ok } func newSseManager(ctx context.Context) *sseManager { sseM := &sseManager{ mu: sync.Mutex{}, sseSessions: make(map[string]*sseSession), } go sseM.cleanupRoutine(ctx) return sseM } func (m *sseManager) add(id string, session *sseSession) { m.mu.Lock() defer m.mu.Unlock() m.sseSessions[id] = session session.lastActive = time.Now() } func (m *sseManager) remove(id string) { m.mu.Lock() delete(m.sseSessions, id) m.mu.Unlock() } func (m *sseManager) cleanupRoutine(ctx context.Context) { timeout := 10 * time.Minute ticker := time.NewTicker(timeout) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: func() { m.mu.Lock() defer m.mu.Unlock() now := time.Now() for id, sess := range m.sseSessions { if now.Sub(sess.lastActive) > timeout { delete(m.sseSessions, id) } } }() } } } type stdioSession struct { protocol string server *Server reader *bufio.Reader writer io.Writer } func NewStdioSession(s *Server, stdin io.Reader, stdout io.Writer) *stdioSession { stdioSession := &stdioSession{ server: s, reader: bufio.NewReader(stdin), writer: stdout, } return stdioSession } func (s *stdioSession) Start(ctx context.Context) error { return s.readInputStream(ctx) } // readInputStream reads requests/notifications from MCP clients through stdin func (s *stdioSession) readInputStream(ctx context.Context) error { for { if err := ctx.Err(); err != nil { return err } line, err := s.readLine(ctx) if err != nil { if err == io.EOF { return nil } return err } v, res, err := processMcpMessage(ctx, []byte(line), s.server, s.protocol, "", nil) if err != nil { // errors during the processing of message will generate a valid MCP Error response. // server can continue to run. s.server.logger.ErrorContext(ctx, err.Error()) } if v != "" { s.protocol = v } // no responses for notifications if res != nil { if err = s.write(ctx, res); err != nil { return err } } } } // readLine process each line within the input stream. func (s *stdioSession) readLine(ctx context.Context) (string, error) { readChan := make(chan string, 1) errChan := make(chan error, 1) done := make(chan struct{}) defer close(done) defer close(readChan) defer close(errChan) go func() { select { case <-done: return default: line, err := s.reader.ReadString('\n') if err != nil { select { case errChan <- err: case <-done: } return } select { case readChan <- line: case <-done: } return } }() select { // if context is cancelled, return an empty string case <-ctx.Done(): return "", ctx.Err() // return error if error is found case err := <-errChan: return "", err // return line if successful case line := <-readChan: return line, nil } } // write writes to stdout with response to client func (s *stdioSession) write(ctx context.Context, response any) error { res, _ := json.Marshal(response) _, err := fmt.Fprintf(s.writer, "%s\n", res) return err } // mcpRouter creates a router that represents the routes under /mcp func mcpRouter(s *Server) (chi.Router, error) { r := chi.NewRouter() r.Use(middleware.AllowContentType("application/json", "application/json-rpc", "application/jsonrequest")) r.Use(middleware.StripSlashes) r.Use(render.SetContentType(render.ContentTypeJSON)) r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) }) r.Get("/", func(w http.ResponseWriter, r *http.Request) { methodNotAllowed(s, w, r) }) r.Post("/", func(w http.ResponseWriter, r *http.Request) { httpHandler(s, w, r) }) r.Delete("/", func(w http.ResponseWriter, r *http.Request) {}) r.Route("/{toolsetName}", func(r chi.Router) { r.Get("/sse", func(w http.ResponseWriter, r *http.Request) { sseHandler(s, w, r) }) r.Get("/", func(w http.ResponseWriter, r *http.Request) { methodNotAllowed(s, w, r) }) r.Post("/", func(w http.ResponseWriter, r *http.Request) { httpHandler(s, w, r) }) r.Delete("/", func(w http.ResponseWriter, r *http.Request) {}) }) return r, nil } // sseHandler handles sse initialization and message. func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) { ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse") r = r.WithContext(ctx) sessionId := uuid.New().String() toolsetName := chi.URLParam(r, "toolsetName") s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) span.SetAttributes(attribute.String("session_id", sessionId)) span.SetAttributes(attribute.String("toolset_name", toolsetName)) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") var err error defer func() { if err != nil { span.SetStatus(codes.Error, err.Error()) } span.End() status := "success" if err != nil { status = "error" } s.instrumentation.McpSse.Add( r.Context(), 1, metric.WithAttributes(attribute.String("toolbox.toolset.name", toolsetName)), metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)), metric.WithAttributes(attribute.String("toolbox.operation.status", status)), ) }() flusher, ok := w.(http.Flusher) if !ok { err = fmt.Errorf("unable to retrieve flusher for sse") s.logger.DebugContext(ctx, err.Error()) _ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError)) } session := &sseSession{ writer: w, flusher: flusher, done: make(chan struct{}), eventQueue: make(chan string, 100), } s.sseManager.add(sessionId, session) defer s.sseManager.remove(sessionId) // https scheme formatting if (forwarded) request is a TLS request proto := r.Header.Get("X-Forwarded-Proto") if proto == "" { if r.TLS == nil { proto = "http" } else { proto = "https" } } // send initial endpoint event toolsetURL := "" if toolsetName != "" { toolsetURL = fmt.Sprintf("/%s", toolsetName) } messageEndpoint := fmt.Sprintf("%s://%s/mcp%s?sessionId=%s", proto, r.Host, toolsetURL, sessionId) s.logger.DebugContext(ctx, fmt.Sprintf("sending endpoint event: %s", messageEndpoint)) fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", messageEndpoint) flusher.Flush() clientClose := r.Context().Done() for { select { // Ensure that only a single responses are written at once case event := <-session.eventQueue: fmt.Fprint(w, event) s.logger.DebugContext(ctx, fmt.Sprintf("sending event: %s", event)) flusher.Flush() // channel for client disconnection case <-clientClose: close(session.done) s.logger.DebugContext(ctx, "client disconnected") return } } } // methodNotAllowed handles all mcp messages. func methodNotAllowed(s *Server, w http.ResponseWriter, r *http.Request) { err := fmt.Errorf("toolbox does not support streaming in streamable HTTP transport") s.logger.DebugContext(r.Context(), err.Error()) _ = render.Render(w, r, newErrResponse(err, http.StatusMethodNotAllowed)) } // httpHandler handles all mcp messages. func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp") r = r.WithContext(ctx) ctx = util.WithLogger(r.Context(), s.logger) var sessionId, protocolVersion string var session *sseSession // check if client connects via sse // v2024-11-05 supports http with sse paramSessionId := r.URL.Query().Get("sessionId") if paramSessionId != "" { sessionId = paramSessionId protocolVersion = v20241105.PROTOCOL_VERSION var ok bool session, ok = s.sseManager.get(sessionId) if !ok { s.logger.DebugContext(ctx, "sse session not available") } } // check if client have `Mcp-Session-Id` header // `Mcp-Session-Id` is only set for v2025-03-26 in Toolbox headerSessionId := r.Header.Get("Mcp-Session-Id") if headerSessionId != "" { protocolVersion = v20250326.PROTOCOL_VERSION } // check if client have `MCP-Protocol-Version` header // Only supported for v2025-06-18+. headerProtocolVersion := r.Header.Get("MCP-Protocol-Version") if headerProtocolVersion != "" { if !mcp.VerifyProtocolVersion(headerProtocolVersion) { err := fmt.Errorf("invalid protocol version: %s", headerProtocolVersion) _ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest)) return } protocolVersion = headerProtocolVersion } toolsetName := chi.URLParam(r, "toolsetName") s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName)) span.SetAttributes(attribute.String("toolset_name", toolsetName)) var err error defer func() { if err != nil { span.SetStatus(codes.Error, err.Error()) } span.End() status := "success" if err != nil { status = "error" } s.instrumentation.McpPost.Add( r.Context(), 1, metric.WithAttributes(attribute.String("toolbox.sse.sessionId", sessionId)), metric.WithAttributes(attribute.String("toolbox.operation.status", status)), ) }() // Read and returns a body from io.Reader body, err := io.ReadAll(r.Body) if err != nil { // Generate a new uuid if unable to decode id := uuid.New().String() s.logger.DebugContext(ctx, err.Error()) render.JSON(w, r, jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil)) return } v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, r.Header) if err != nil { s.logger.DebugContext(ctx, fmt.Errorf("error processing message: %w", err).Error()) } // notifications will return empty string if res == nil { // Notifications do not expect a response // Toolbox doesn't do anything with notifications yet w.WriteHeader(http.StatusAccepted) return } // for v20250326, add the `Mcp-Session-Id` header if v == v20250326.PROTOCOL_VERSION { sessionId = uuid.New().String() w.Header().Set("Mcp-Session-Id", sessionId) } if session != nil { // queue sse event eventData, _ := json.Marshal(res) select { case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): s.logger.DebugContext(ctx, "event queue successful") case <-session.done: s.logger.DebugContext(ctx, "session is close") default: s.logger.DebugContext(ctx, "unable to add to event queue") } } if rpcResponse, ok := res.(jsonrpc.JSONRPCError); ok { code := rpcResponse.Error.Code switch code { case jsonrpc.INTERNAL_ERROR: w.WriteHeader(http.StatusInternalServerError) case jsonrpc.INVALID_REQUEST: errStr := err.Error() if errors.Is(err, tools.ErrUnauthorized) { w.WriteHeader(http.StatusUnauthorized) } else if strings.Contains(errStr, "Error 401") { w.WriteHeader(http.StatusUnauthorized) } else if strings.Contains(errStr, "Error 403") { w.WriteHeader(http.StatusForbidden) } } } // send HTTP response render.JSON(w, r, res) } // processMcpMessage process the messages received from clients func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, header http.Header) (string, any, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return "", jsonrpc.NewError("", jsonrpc.INTERNAL_ERROR, err.Error(), nil), err } // Generic baseMessage could either be a JSONRPCNotification or JSONRPCRequest var baseMessage jsonrpc.BaseMessage if err = util.DecodeJSON(bytes.NewBuffer(body), &baseMessage); err != nil { // Generate a new uuid if unable to decode id := uuid.New().String() // check if user is sending a batch request var a []any unmarshalErr := json.Unmarshal(body, &a) if unmarshalErr == nil { err = fmt.Errorf("not supporting batch requests") return "", jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } return "", jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil), err } // Check if method is present if baseMessage.Method == "" { err = fmt.Errorf("method not found") return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err } logger.DebugContext(ctx, fmt.Sprintf("method is: %s", baseMessage.Method)) // Check for JSON-RPC 2.0 if baseMessage.Jsonrpc != jsonrpc.JSONRPC_VERSION { err = fmt.Errorf("invalid json-rpc version") return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } // Check if message is a notification if baseMessage.Id == nil { err := mcp.NotificationHandler(ctx, body) return "", nil, err } switch baseMessage.Method { case mcputil.INITIALIZE: res, v, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version) if err != nil { return "", res, err } return v, res, err default: toolset, ok := s.ResourceMgr.GetToolset(toolsetName) if !ok { err = fmt.Errorf("toolset does not exist") return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, s.ResourceMgr.GetToolsMap(), s.ResourceMgr.GetAuthServiceMap(), body, header) return "", res, err } }

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/googleapis/genai-toolbox'

If you have feedback or need assistance with the MCP directory API, please join our Discord server