package server
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"log"
"os"
"os/signal"
"sync/atomic"
"syscall"
"github.com/mark3labs/mcp-go/mcp"
)
// StdioContextFunc is a function that takes an existing context and returns
// a potentially modified context.
// This can be used to inject context values from environment variables,
// for example. Listen invokes it once, after the stdio session has been
// attached to the context and before any message is read.
type StdioContextFunc func(ctx context.Context) context.Context
// StdioServer wraps a MCPServer and handles stdio communication.
// It provides a simple way to create command-line MCP servers that
// communicate via standard input/output streams using JSON-RPC messages.
type StdioServer struct {
// server is the wrapped MCPServer that every incoming message is handed to.
server *MCPServer
// errLogger receives errors hit while reading input, handling messages, or
// writing notifications.
errLogger *log.Logger
// contextFunc, when non-nil, is applied once by Listen to customise the
// context shared by all requests.
contextFunc StdioContextFunc
}
// StdioOption is a functional option that configures a StdioServer.
// ServeStdio applies the supplied options, in order, to the server it creates.
type StdioOption func(*StdioServer)
// WithErrorLogger returns a StdioOption that replaces the server's error
// logger with the given logger.
// NOTE(review): the logger is stored as-is and later used without a nil
// check, so callers should pass a non-nil logger.
func WithErrorLogger(logger *log.Logger) StdioOption {
return func(s *StdioServer) {
s.errLogger = logger
}
}
// WithStdioContextFunc returns a StdioOption that sets a function used to
// customise the context passed to the server. Note that the stdio server uses
// the same context for all requests, so this function is called only once per
// Listen call rather than once per request.
func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
return func(s *StdioServer) {
s.contextFunc = fn
}
}
// stdioSession is a static client session, since stdio has only one client.
type stdioSession struct {
// notifications queues server-to-client notifications until Listen writes
// them to its output.
notifications chan mcp.JSONRPCNotification
// initialized records whether Initialize has been called.
initialized atomic.Bool
}
// SessionID returns the fixed identifier "stdio" used for the stdio session.
func (s *stdioSession) SessionID() string {
return "stdio"
}
// NotificationChannel returns the send-only side of the session's
// notification queue.
func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
return s.notifications
}
// Initialize marks the session as initialized. The flag is stored atomically.
func (s *stdioSession) Initialize() {
s.initialized.Store(true)
}
// Initialized reports whether Initialize has been called on the session.
func (s *stdioSession) Initialized() bool {
return s.initialized.Load()
}
// Compile-time check that *stdioSession implements ClientSession.
var _ ClientSession = (*stdioSession)(nil)
// stdioSessionInstance is the single package-level session that Listen
// registers with the server. Its notification queue buffers up to 100
// messages.
var stdioSessionInstance = stdioSession{
notifications: make(chan mcp.JSONRPCNotification, 100),
}
// NewStdioServer creates a new stdio server wrapper around an MCPServer.
// The returned server logs errors to os.Stderr with the standard log flags
// until a different logger is installed via SetErrorLogger or WithErrorLogger.
func NewStdioServer(server *MCPServer) *StdioServer {
	// Default error logger: standard error, no prefix, date/time flags.
	stderrLogger := log.New(os.Stderr, "", log.LstdFlags)

	return &StdioServer{
		server:    server,
		errLogger: stderrLogger,
	}
}
// SetErrorLogger configures where error messages from the StdioServer are logged.
// The provided logger replaces the current one and receives the errors logged
// while reading input, handling messages, and writing notifications.
func (s *StdioServer) SetErrorLogger(logger *log.Logger) {
s.errLogger = logger
}
// SetContextFunc sets a function that will be called to customise the context
// passed to the server. Note that the stdio server uses the same context for
// all requests, so this function is called only once per Listen call rather
// than once per request.
func (s *StdioServer) SetContextFunc(fn StdioContextFunc) {
s.contextFunc = fn
}
// Listen registers the shared stdio session with the wrapped MCPServer and
// then reads newline-delimited JSON-RPC messages from stdin, writing each
// response (and any queued notifications) to stdout.
//
// It returns:
//   - a wrapped error if the session cannot be registered,
//   - ctx.Err() once the context is cancelled,
//   - nil when stdin reaches EOF,
//   - the underlying error if reading input or handling a message fails.
//
// A final message that is not terminated by a newline is still processed
// before Listen returns on EOF.
//
// NOTE(review): notifications and responses are written to stdout from two
// different goroutines without synchronisation, so stdout should be safe for
// concurrent Write calls.
func (s *StdioServer) Listen(
	ctx context.Context,
	stdin io.Reader,
	stdout io.Writer,
) error {
	// Set a static client context since stdio only has one client.
	if err := s.server.RegisterSession(&stdioSessionInstance); err != nil {
		return fmt.Errorf("register session: %w", err)
	}
	defer s.server.UnregisterSession(stdioSessionInstance.SessionID())
	ctx = s.server.WithContext(ctx, &stdioSessionInstance)

	// Add in any custom context.
	if s.contextFunc != nil {
		ctx = s.contextFunc(ctx)
	}

	// Derive a cancellable context so the notification goroutine below is
	// stopped on every return path (EOF, read error, handler error), not only
	// when the caller cancels. Without this it would keep draining the shared
	// notification channel into stdout after Listen has returned.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Forward queued notifications to stdout until the context ends.
	go func() {
		for {
			select {
			case notification := <-stdioSessionInstance.notifications:
				if err := s.writeResponse(notification, stdout); err != nil {
					s.errLogger.Printf("Error writing notification: %v", err)
				}
			case <-ctx.Done():
				return
			}
		}
	}()

	// readResult carries both values returned by ReadString, because on EOF
	// the reader may hand back data and an error at the same time.
	type readResult struct {
		line string
		err  error
	}

	reader := bufio.NewReader(stdin)
	for {
		// Check for cancellation before starting another blocking read.
		if err := ctx.Err(); err != nil {
			return err
		}

		// Read in a goroutine to make the read cancellable. The channel is
		// buffered so the goroutine can finish even if Listen has already
		// returned; it still stays blocked until stdin yields data or closes.
		resultChan := make(chan readResult, 1)
		go func() {
			line, err := reader.ReadString('\n')
			resultChan <- readResult{line: line, err: err}
		}()

		var res readResult
		select {
		case <-ctx.Done():
			return ctx.Err()
		case res = <-resultChan:
		}

		atEOF := res.err == io.EOF
		if res.err != nil && !atEOF {
			s.errLogger.Printf("Error reading input: %v", res.err)
			return res.err
		}

		// On EOF the line is either empty or an unterminated final message;
		// process the latter instead of silently dropping it.
		if res.line != "" {
			if err := s.processMessage(ctx, res.line, stdout); err != nil {
				if err == io.EOF {
					return nil
				}
				s.errLogger.Printf("Error handling message: %v", err)
				return err
			}
		}

		if atEOF {
			return nil
		}
	}
}
// processMessage handles one line of input as a JSON-RPC message.
// If the line is not valid JSON, a "Parse error" response is written and the
// result of that write is returned. Otherwise the message is handed to the
// wrapped MCPServer and its response, if any, is written to writer.
// Returns a wrapped error when writing the server's response fails.
func (s *StdioServer) processMessage(
	ctx context.Context,
	line string,
	writer io.Writer,
) error {
	// Decode into a RawMessage so the payload is only validated here and the
	// wrapped server does the actual interpretation.
	var raw json.RawMessage
	parseErr := json.Unmarshal([]byte(line), &raw)
	if parseErr != nil {
		return s.writeResponse(
			createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error"),
			writer,
		)
	}

	reply := s.server.HandleMessage(ctx, raw)
	if reply == nil {
		// No response to send (e.g. the message was a notification).
		return nil
	}

	writeErr := s.writeResponse(reply, writer)
	if writeErr != nil {
		return fmt.Errorf("failed to write response: %w", writeErr)
	}
	return nil
}
// writeResponse encodes response as JSON and writes it to writer as a single
// line terminated by a newline.
// Returns the marshaling error or the write error, whichever occurs first.
func (s *StdioServer) writeResponse(
	response mcp.JSONRPCMessage,
	writer io.Writer,
) error {
	payload, err := json.Marshal(response)
	if err != nil {
		return err
	}

	// One formatted write emits the payload and its trailing newline together.
	_, err = fmt.Fprintf(writer, "%s\n", payload)
	return err
}
// ServeStdio is a convenience function that creates a StdioServer for the
// given MCPServer and runs it on os.Stdin and os.Stdout.
// Errors are logged to os.Stderr unless an option installs another logger;
// options are applied in the order given.
// The server's context is cancelled on the first SIGTERM or SIGINT. The
// signal handler is removed again before ServeStdio returns, so later signals
// get their default behaviour.
// Returns whatever Listen returns.
func ServeStdio(server *MCPServer, opts ...StdioOption) error {
	s := NewStdioServer(server)
	// Install the stderr default before the options so an option can override it.
	s.SetErrorLogger(log.New(os.Stderr, "", log.LstdFlags))

	for _, opt := range opts {
		opt(s)
	}

	// NotifyContext cancels ctx when a listed signal arrives. Calling stop
	// unregisters the handler and releases its goroutine, which the previous
	// signal.Notify + goroutine version never did.
	ctx, stop := signal.NotifyContext(
		context.Background(),
		syscall.SIGTERM,
		syscall.SIGINT,
	)
	defer stop()

	return s.Listen(ctx, os.Stdin, os.Stdout)
}