package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"sync/atomic"
"github.com/google/uuid"
"github.com/mark3labs/mcp-go/mcp"
)
// sseSession represents an active SSE connection.
type sseSession struct {
writer http.ResponseWriter
flusher http.Flusher
done chan struct{}
eventQueue chan string // Channel for queuing events
sessionID string
notificationChannel chan mcp.JSONRPCNotification
initialized atomic.Bool
}
// SSEContextFunc is a function that takes an existing context and the current
// request and returns a potentially modified context based on the request
// content. This can be used to inject context values from headers, for example.
type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context
func (s *sseSession) SessionID() string {
return s.sessionID
}
func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
return s.notificationChannel
}
func (s *sseSession) Initialize() {
s.initialized.Store(true)
}
func (s *sseSession) Initialized() bool {
return s.initialized.Load()
}
var _ ClientSession = (*sseSession)(nil)
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
// It provides real-time communication capabilities over HTTP using the SSE protocol.
type SSEServer struct {
server *MCPServer
baseURL string
basePath string
messageEndpoint string
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc
}
// SSEOption defines a function type for configuring SSEServer
type SSEOption func(*SSEServer)
// WithBaseURL sets the base URL for the SSE server
func WithBaseURL(baseURL string) SSEOption {
return func(s *SSEServer) {
if baseURL != "" {
u, err := url.Parse(baseURL)
if err != nil {
return
}
if u.Scheme != "http" && u.Scheme != "https" {
return
}
// Check if the host is empty or only contains a port
if u.Host == "" || strings.HasPrefix(u.Host, ":") {
return
}
if len(u.Query()) > 0 {
return
}
}
s.baseURL = strings.TrimSuffix(baseURL, "/")
}
}
// Add a new option for setting base path
func WithBasePath(basePath string) SSEOption {
return func(s *SSEServer) {
// Ensure the path starts with / and doesn't end with /
if !strings.HasPrefix(basePath, "/") {
basePath = "/" + basePath
}
s.basePath = strings.TrimSuffix(basePath, "/")
}
}
// WithMessageEndpoint sets the message endpoint path
func WithMessageEndpoint(endpoint string) SSEOption {
return func(s *SSEServer) {
s.messageEndpoint = endpoint
}
}
// WithSSEEndpoint sets the SSE endpoint path
func WithSSEEndpoint(endpoint string) SSEOption {
return func(s *SSEServer) {
s.sseEndpoint = endpoint
}
}
// WithHTTPServer sets the HTTP server instance
func WithHTTPServer(srv *http.Server) SSEOption {
return func(s *SSEServer) {
s.srv = srv
}
}
// WithContextFunc sets a function that will be called to customise the context
// to the server using the incoming request.
func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
return func(s *SSEServer) {
s.contextFunc = fn
}
}
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
s := &SSEServer{
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
}
// Apply all options
for _, opt := range opts {
opt(s)
}
return s
}
// NewTestServer creates a test server for testing purposes
func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server {
sseServer := NewSSEServer(server)
for _, opt := range opts {
opt(sseServer)
}
testServer := httptest.NewServer(sseServer)
sseServer.baseURL = testServer.URL
return testServer
}
// Start begins serving SSE connections on the specified address.
// It sets up HTTP handlers for SSE and message endpoints.
func (s *SSEServer) Start(addr string) error {
s.srv = &http.Server{
Addr: addr,
Handler: s,
}
return s.srv.ListenAndServe()
}
// Shutdown gracefully stops the SSE server, closing all active sessions
// and shutting down the HTTP server.
func (s *SSEServer) Shutdown(ctx context.Context) error {
if s.srv != nil {
s.sessions.Range(func(key, value interface{}) bool {
if session, ok := value.(*sseSession); ok {
close(session.done)
}
s.sessions.Delete(key)
return true
})
return s.srv.Shutdown(ctx)
}
return nil
}
// handleSSE handles incoming SSE connection requests.
// It sets up appropriate headers and creates a new session for the client.
func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
return
}
sessionID := uuid.New().String()
session := &sseSession{
writer: w,
flusher: flusher,
done: make(chan struct{}),
eventQueue: make(chan string, 100), // Buffer for events
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
}
s.sessions.Store(sessionID, session)
defer s.sessions.Delete(sessionID)
if err := s.server.RegisterSession(session); err != nil {
http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError)
return
}
defer s.server.UnregisterSession(sessionID)
// Start notification handler for this session
go func() {
for {
select {
case notification := <-session.notificationChannel:
eventData, err := json.Marshal(notification)
if err == nil {
select {
case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
// Event queued successfully
case <-session.done:
return
}
}
case <-session.done:
return
case <-r.Context().Done():
return
}
}
}()
messageEndpoint := fmt.Sprintf("%s?sessionId=%s", s.CompleteMessageEndpoint(), sessionID)
// Send the initial endpoint event
fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", messageEndpoint)
flusher.Flush()
// Main event loop - this runs in the HTTP handler goroutine
for {
select {
case event := <-session.eventQueue:
// Write the event to the response
fmt.Fprint(w, event)
flusher.Flush()
case <-r.Context().Done():
close(session.done)
return
}
}
}
// handleMessage processes incoming JSON-RPC messages from clients and sends responses
// back through both the SSE connection and HTTP response.
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed")
return
}
sessionID := r.URL.Query().Get("sessionId")
if sessionID == "" {
s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
return
}
sessionI, ok := s.sessions.Load(sessionID)
if !ok {
s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")
return
}
session := sessionI.(*sseSession)
// Set the client context before handling the message
ctx := s.server.WithContext(r.Context(), session)
if s.contextFunc != nil {
ctx = s.contextFunc(ctx, r)
}
// Parse message as raw JSON
var rawMessage json.RawMessage
if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil {
s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error")
return
}
// Process message through MCPServer
response := s.server.HandleMessage(ctx, rawMessage)
// Only send response if there is one (not for notifications)
if response != nil {
eventData, _ := json.Marshal(response)
// Queue the event for sending via SSE
select {
case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
// Event queued successfully
case <-session.done:
// Session is closed, don't try to queue
default:
// Queue is full, could log this
}
// Send HTTP response
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
json.NewEncoder(w).Encode(response)
} else {
// For notifications, just send 202 Accepted with no body
w.WriteHeader(http.StatusAccepted)
}
}
// writeJSONRPCError writes a JSON-RPC error response with the given error details.
func (s *SSEServer) writeJSONRPCError(
w http.ResponseWriter,
id interface{},
code int,
message string,
) {
response := createErrorResponse(id, code, message)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(response)
}
// SendEventToSession sends an event to a specific SSE session identified by sessionID.
// Returns an error if the session is not found or closed.
func (s *SSEServer) SendEventToSession(
sessionID string,
event interface{},
) error {
sessionI, ok := s.sessions.Load(sessionID)
if !ok {
return fmt.Errorf("session not found: %s", sessionID)
}
session := sessionI.(*sseSession)
eventData, err := json.Marshal(event)
if err != nil {
return err
}
// Queue the event for sending via SSE
select {
case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
return nil
case <-session.done:
return fmt.Errorf("session closed")
default:
return fmt.Errorf("event queue full")
}
}
func (s *SSEServer) GetUrlPath(input string) (string, error) {
parse, err := url.Parse(input)
if err != nil {
return "", fmt.Errorf("failed to parse URL %s: %w", input, err)
}
return parse.Path, nil
}
func (s *SSEServer) CompleteSseEndpoint() string {
return s.baseURL + s.basePath + s.sseEndpoint
}
func (s *SSEServer) CompleteSsePath() string {
path, err := s.GetUrlPath(s.CompleteSseEndpoint())
if err != nil {
return s.basePath + s.sseEndpoint
}
return path
}
func (s *SSEServer) CompleteMessageEndpoint() string {
return s.baseURL + s.basePath + s.messageEndpoint
}
func (s *SSEServer) CompleteMessagePath() string {
path, err := s.GetUrlPath(s.CompleteMessageEndpoint())
if err != nil {
return s.basePath + s.messageEndpoint
}
return path
}
// ServeHTTP implements the http.Handler interface.
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
// Use exact path matching rather than Contains
ssePath := s.CompleteSsePath()
if ssePath != "" && path == ssePath {
s.handleSSE(w, r)
return
}
messagePath := s.CompleteMessagePath()
if messagePath != "" && path == messagePath {
s.handleMessage(w, r)
return
}
http.NotFound(w, r)
}