// stdio.go

package server
import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"os"
	"os/signal"
	"sync"
	"syscall"

	"github.com/centralmind/gateway/mcp"
	"github.com/centralmind/gateway/xcontext"
)
// StdioServer adapts an MCPServer to newline-delimited JSON-RPC exchanged
// over a pair of byte streams (normally the process's standard input and
// output), which makes it easy to build command-line MCP servers.
type StdioServer struct {
	// server handles every decoded JSON-RPC message and is the source of
	// outgoing notifications.
	server *MCPServer
	// errLogger receives the error messages produced while serving.
	errLogger *log.Logger
}
// NewStdioServer creates a new stdio server wrapper around an MCPServer.
//
// By default, error messages are logged to os.Stderr with the standard log
// flags; they are not discarded. Use SetErrorLogger to redirect or silence
// them. Logging to stderr keeps diagnostics out of the JSON-RPC stream when
// stdout carries the protocol.
func NewStdioServer(server *MCPServer) *StdioServer {
	return &StdioServer{
		server:    server,
		errLogger: log.New(os.Stderr, "", log.LstdFlags),
	}
}
// SetErrorLogger configures where error messages from the StdioServer are logged.
// The provided logger will receive all error messages generated during server
// operation.
//
// Passing nil disables error logging. Without this, the server would keep a
// nil *log.Logger and panic on the first error it tried to report.
func (s *StdioServer) SetErrorLogger(logger *log.Logger) {
	if logger == nil {
		logger = log.New(io.Discard, "", 0)
	}
	s.errLogger = logger
}
// Listen reads newline-delimited JSON-RPC messages from stdin and writes
// responses and server notifications to stdout, one JSON document per line.
//
// It returns:
//   - ctx.Err() when ctx is cancelled;
//   - nil when stdin reaches EOF;
//   - the underlying error when reading input or handling a message fails.
//
// A final message that is not terminated by a newline is still processed
// before EOF is reported.
//
// All writes to stdout are serialized: request responses and asynchronous
// notifications alike. Concurrently produced messages are therefore never
// interleaved, and stdout does not need to be safe for concurrent use.
//
// The notification goroutine started here has exited by the time Listen
// returns. The goroutine reading stdin cannot be interrupted while blocked in
// a read. It exits as soon as that read completes.
func (s *StdioServer) Listen(
	ctx context.Context,
	stdin io.Reader,
	stdout io.Writer,
) error {
	// Cancel our own child context on return. This stops the helper
	// goroutines even when the caller's ctx is never cancelled, for example
	// when stdin simply reaches EOF.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Set a static client context since stdio only has one client.
	ctx = s.server.WithContext(ctx, NotificationContext{
		ClientID:  "stdio",
		SessionID: "stdio",
	})
	ctx = xcontext.WithSession(ctx, "stdio")

	out := &stdioSyncWriter{w: stdout}
	reader := bufio.NewReader(stdin)

	// Forward notifications addressed to the stdio client.
	notifyDone := make(chan struct{})
	go func() {
		defer close(notifyDone)
		for {
			select {
			case <-ctx.Done():
				return
			case n, ok := <-s.server.notifications:
				if !ok {
					// A closed channel would otherwise yield zero values forever.
					return
				}
				// Only handle notifications for the stdio client.
				if n.Context.ClientID != "stdio" {
					continue
				}
				if err := s.writeResponse(n.Notification, out); err != nil {
					s.errLogger.Printf("Error writing notification: %v", err)
				}
			}
		}
	}()
	// Runs before the earlier `defer cancel()`, so cancel here first, then wait.
	// Otherwise we would wait on a goroutine that is waiting for cancellation.
	defer func() {
		cancel()
		<-notifyDone
	}()

	// A single long-lived reader goroutine makes blocking reads selectable
	// against ctx, without spawning a goroutine per line.
	type readResult struct {
		line string
		err  error
	}
	results := make(chan readResult)
	go func() {
		for {
			line, err := reader.ReadString('\n')
			select {
			case results <- readResult{line: line, err: err}:
			case <-ctx.Done():
				return
			}
			if err != nil {
				return
			}
		}
	}()

	for {
		var res readResult
		select {
		case <-ctx.Done():
			return ctx.Err()
		case res = <-results:
		}

		if res.err != nil && !errors.Is(res.err, io.EOF) {
			s.errLogger.Printf("Error reading input: %v", res.err)
			return res.err
		}

		// At EOF, ReadString may still return a final unterminated line.
		// Handle it instead of silently dropping it.
		if res.line != "" {
			if err := s.processMessage(ctx, res.line, out); err != nil {
				if errors.Is(err, io.EOF) {
					return nil
				}
				s.errLogger.Printf("Error handling message: %v", err)
				return err
			}
		}

		if res.err != nil {
			// Clean end of input.
			return nil
		}
	}
}

// stdioSyncWriter serializes Write calls to the wrapped writer. Each message
// is emitted with a single Write, so holding the lock per Write keeps whole
// messages from interleaving when responses and notifications are written
// from different goroutines.
type stdioSyncWriter struct {
	mu sync.Mutex
	w  io.Writer
}

// Write forwards p to the wrapped writer while holding the lock.
func (sw *stdioSyncWriter) Write(p []byte) (int, error) {
	sw.mu.Lock()
	defer sw.mu.Unlock()
	return sw.w.Write(p)
}
// processMessage decodes one line of input and dispatches it to the wrapped
// MCPServer.
//
// A line that is not valid JSON is answered with a JSON-RPC parse error
// (with a nil request id). Any response produced by HandleMessage is written
// to writer. A nil response (e.g. for notifications) produces no output.
//
// The returned error only ever reports a failure to write a response.
func (s *StdioServer) processMessage(
	ctx context.Context,
	line string,
	writer io.Writer,
) error {
	var raw json.RawMessage
	if json.Unmarshal([]byte(line), &raw) != nil {
		return s.writeResponse(
			CreateErrorResponse(nil, mcp.PARSE_ERROR, "Parse error"),
			writer,
		)
	}

	reply := s.server.HandleMessage(ctx, raw)
	if reply == nil {
		// Nothing to send back (e.g. the message was a notification).
		return nil
	}

	if err := s.writeResponse(reply, writer); err != nil {
		return fmt.Errorf("failed to write response: %w", err)
	}
	return nil
}
// writeResponse encodes response as JSON and emits it to writer as a single
// line: the encoded bytes followed by '\n', in exactly one Write call.
// It returns the marshaling error, if any, or the error reported by writer.
func (s *StdioServer) writeResponse(
	response mcp.JSONRPCMessage,
	writer io.Writer,
) error {
	payload, err := json.Marshal(response)
	if err != nil {
		return err
	}
	_, err = writer.Write(append(payload, '\n'))
	return err
}
// ServeStdio is a convenience function that creates a StdioServer and serves
// it over os.Stdin and os.Stdout, logging errors to os.Stderr.
//
// SIGTERM or SIGINT cancels the serving context for a graceful shutdown. In
// that case Listen typically returns the context's error (context.Canceled).
// The signal registration is undone when ServeStdio returns, so later signals
// regain their default behavior.
//
// Returns nil when stdin reaches EOF, or an error if the server encounters
// any issues during operation.
func ServeStdio(server *MCPServer) error {
	s := NewStdioServer(server)
	s.SetErrorLogger(log.New(os.Stderr, "", log.LstdFlags))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Set up signal handling.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)
	// Unregister before returning. Otherwise the process would keep
	// swallowing SIGINT/SIGTERM after the server has stopped.
	defer signal.Stop(sigChan)

	go func() {
		select {
		case <-sigChan:
			cancel()
		case <-ctx.Done():
			// Serving ended on its own; don't leak this goroutine.
		}
	}()

	return s.Listen(ctx, os.Stdin, os.Stdout)
}