sse.go•7.6 kB
package server
import (
"context"
"encoding/json"
"fmt"
"github.com/centralmind/gateway/cors"
"github.com/centralmind/gateway/xcontext"
"net/http"
"net/http/httptest"
"path"
"sync"
"github.com/centralmind/gateway/mcp"
"github.com/google/uuid"
)
// SSEServer serves the MCP protocol over Server-Sent Events (SSE).
//
// Clients open a long-lived GET stream on the "sse" endpoint and POST JSON-RPC
// messages to the "message" endpoint with a sessionId query parameter;
// responses and notifications are pushed back to them over the stream.
type SSEServer struct {
	// server processes the JSON-RPC messages received from clients.
	server *MCPServer
	// baseURL is prepended to the message endpoint advertised to clients.
	baseURL string
	// sessions maps a session ID (string) to its *sseSession.
	sessions sync.Map
	// srv is the HTTP server created by Start; it stays nil until then.
	srv *http.Server
	// prefix is the URL path prefix under which both endpoints are mounted.
	prefix string
}
// sseSession holds the per-connection state of a single SSE client.
type sseSession struct {
	// writer and flusher are the response stream of the GET request that
	// opened the session.
	writer  http.ResponseWriter
	flusher http.Flusher
	// done is closed to signal that the session is over.
	done chan struct{}
	// eventQueue buffers fully formatted SSE frames that are waiting to be
	// written to the client's stream.
	eventQueue chan string
}
// NewSSEServer returns an SSEServer that dispatches client messages to server.
//
// baseURL is the externally visible origin used when advertising the message
// endpoint to clients, and prefix is the URL path prefix under which the SSE
// and message endpoints are served. The HTTP server itself is only created by
// Start.
func NewSSEServer(server *MCPServer, baseURL string, prefix string) *SSEServer {
	s := new(SSEServer)
	s.server = server
	s.baseURL = baseURL
	s.prefix = prefix
	return s
}
// NewTestServer wraps server in an SSEServer and serves it from a started
// httptest.Server. The SSE server's baseURL is set to the test server's URL so
// the message endpoint it advertises is reachable. No path prefix is used. The
// caller should Close the returned server when finished.
func NewTestServer(server *MCPServer) *httptest.Server {
	handler := new(SSEServer)
	handler.server = server
	ts := httptest.NewServer(handler)
	// The URL is only known once the test listener is up.
	handler.baseURL = ts.URL
	return ts
}
// Start creates an http.Server on addr that routes every request through s
// (see ServeHTTP). It records the server so Shutdown can stop it, and then
// blocks serving.
//
// As with http.Server.ListenAndServe, the returned error is never nil. It is
// http.ErrServerClosed after a Shutdown.
func (s *SSEServer) Start(addr string) error {
	httpServer := &http.Server{Addr: addr, Handler: s}
	s.srv = httpServer
	return httpServer.ListenAndServe()
}
// Shutdown gracefully stops the SSE server, closing all active sessions
// and shutting down the HTTP server.
//
// Each session is first removed from the registry with LoadAndDelete, so two
// concurrent Shutdown calls can never both close the same session. Its done
// channel is then closed only if it is still open, because closing an
// already-closed channel panics. If Start was never called (s.srv is nil),
// Shutdown does nothing and returns nil.
//
// The returned error is the one from http.Server.Shutdown. For example, it is
// ctx's error when ctx expires before active connections finish.
func (s *SSEServer) Shutdown(ctx context.Context) error {
	if s.srv == nil {
		return nil
	}
	s.sessions.Range(func(key, _ interface{}) bool {
		value, loaded := s.sessions.LoadAndDelete(key)
		if !loaded {
			// Another caller already claimed and closed this session.
			return true
		}
		session, ok := value.(*sseSession)
		if !ok {
			return true
		}
		select {
		case <-session.done:
			// Already closed elsewhere; a second close would panic.
		default:
			close(session.done)
		}
		return true
	})
	return s.srv.Shutdown(ctx)
}
// handleSSE handles incoming SSE connection requests.
//
// It sets the SSE and CORS headers and answers CORS preflight requests. It
// rejects non-GET and unauthenticated requests, and fails if the
// ResponseWriter cannot stream. It then registers a new session under a random
// UUID and sends the client an "endpoint" event. That event carries the URL to
// which the client must POST its JSON-RPC messages. From then on, every event
// queued on the session is written and flushed until one of three things
// happens: the request context ends, the session's done channel is closed, or
// a write fails.
//
// The handler only observes session.done and never closes it. In this file
// only Shutdown closes it. If the handler also closed it on disconnect, a
// client that disconnected after Shutdown would trigger a "close of closed
// channel" panic.
func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	cors.ApplyCORSHeaders(w, "GET")
	if cors.HandlePreflight(w, r) {
		return
	}
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if s.server.NeedAuth(r) {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}
	flusher, ok := w.(http.Flusher)
	if !ok {
		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
		return
	}

	// r.Context() is also cancelled when this handler returns, which stops the
	// notification forwarder below.
	ctx := r.Context()
	sessionID := uuid.New().String()
	session := &sseSession{
		writer:     w,
		flusher:    flusher,
		done:       make(chan struct{}),
		eventQueue: make(chan string, 100), // Buffer for events
	}
	s.sessions.Store(sessionID, session)
	defer s.sessions.Delete(sessionID)

	go s.forwardSessionNotifications(ctx, sessionID, session)

	messageEndpoint := fmt.Sprintf(
		"%s?sessionId=%s",
		s.baseURL+path.Join("/", s.prefix, "message"),
		sessionID,
	)
	// Send the initial endpoint event.
	if _, err := fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", messageEndpoint); err != nil {
		return
	}
	flusher.Flush()

	// Main event loop: this goroutine is the only writer to w.
	for {
		select {
		case event := <-session.eventQueue:
			if _, err := fmt.Fprint(w, event); err != nil {
				return // the client is gone; stop streaming
			}
			flusher.Flush()
		case <-session.done:
			return // closed by Shutdown
		case <-ctx.Done():
			return // client disconnected
		}
	}
}

// forwardSessionNotifications copies server notifications addressed to
// sessionID onto the session's event queue as SSE "message" frames. It stops
// when ctx is done or the session's done channel is closed, including while
// it is blocked on a full queue.
//
// NOTE(review): s.server.notifications is a single channel read by the
// forwarder of every session. A notification addressed to one session can be
// received, and then discarded, by another session's forwarder. With several
// concurrent sessions, notifications may therefore be lost. A single
// dispatcher that routes by SessionID would avoid this.
func (s *SSEServer) forwardSessionNotifications(ctx context.Context, sessionID string, session *sseSession) {
	for {
		select {
		case serverNotification := <-s.server.notifications:
			// Only forward notifications meant for this session.
			if serverNotification.Context.SessionID != sessionID {
				continue
			}
			eventData, err := json.Marshal(serverNotification.Notification)
			if err != nil {
				continue // best-effort: skip notifications that cannot be encoded
			}
			select {
			case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
			case <-session.done:
				return
			case <-ctx.Done():
				return
			}
		case <-session.done:
			return
		case <-ctx.Done():
			return
		}
	}
}
// handleMessage processes incoming JSON-RPC messages from clients and sends responses
// back through both the SSE connection and HTTP response.
//
// Requests must be POSTs carrying a sessionId query parameter that names a
// registered SSE session. OPTIONS requests get a bare 200. Validation and
// parse failures are reported through writeJSONRPCError.
//
// The raw JSON body is passed to MCPServer.HandleMessage. The context it gets
// carries the session ID (as both client and session ID) and the request
// headers. A non-nil response is JSON-encoded once. It is queued on the
// session's SSE stream on a best-effort basis: it is dropped if the queue is
// full or the session is closed. It is also written as the HTTP body with
// status 202 Accepted. A nil response (a notification) gets an empty 202. If
// the response cannot be encoded, the client receives a 500 instead of an
// empty or malformed event.
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Access-Control-Allow-Origin", "*")
	if r.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}
	if r.Method != http.MethodPost {
		s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed")
		return
	}
	sessionID := r.URL.Query().Get("sessionId")
	if sessionID == "" {
		s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
		return
	}
	// Set the client context in the server before handling the message.
	ctx := s.server.WithContext(r.Context(), NotificationContext{
		ClientID:  sessionID,
		SessionID: sessionID,
	})
	ctx = xcontext.WithSession(ctx, sessionID)
	ctx = xcontext.WithHeader(ctx, r.Header)

	sessionI, ok := s.sessions.Load(sessionID)
	if !ok {
		s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")
		return
	}
	session := sessionI.(*sseSession)

	// Parse message as raw JSON; MCPServer does the actual decoding.
	var rawMessage json.RawMessage
	if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil {
		s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error")
		return
	}

	response := s.server.HandleMessage(ctx, rawMessage)
	if response == nil {
		// Notifications produce no response: acknowledge with an empty 202.
		w.WriteHeader(http.StatusAccepted)
		return
	}

	// Marshal once and reuse the bytes for both the SSE event and the HTTP body.
	eventData, err := json.Marshal(response)
	if err != nil {
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	// Queue the event for sending via SSE without ever blocking this request.
	select {
	case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
		// Event queued successfully.
	case <-session.done:
		// Session is closed; don't try to queue.
	default:
		// Queue is full; the client still gets the response in the HTTP body.
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusAccepted)
	// json.Encoder.Encode terminates its output with a newline; keep that framing.
	// The write error is ignored because the status line has already been sent.
	_, _ = w.Write(append(eventData, '\n'))
}
// writeJSONRPCError replies with the JSON-RPC error object that
// CreateErrorResponse builds from id, code and message.
//
// The reply always has HTTP status 400 Bad Request and Content-Type
// application/json. The encoding error is discarded, because the status line
// has already been sent by the time it could occur.
func (s *SSEServer) writeJSONRPCError(w http.ResponseWriter, id interface{}, code int, message string) {
	body := CreateErrorResponse(id, code, message)
	h := w.Header()
	h.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)
	_ = json.NewEncoder(w).Encode(body)
}
// SendEventToSession marshals event to JSON and queues it as an SSE "message"
// event on the session identified by sessionID.
//
// It never blocks. It returns an error in four cases:
//   - the session is not registered;
//   - event cannot be marshaled (wrapped, so errors.As still works);
//   - the session has been closed;
//   - the session's event queue is full.
func (s *SSEServer) SendEventToSession(sessionID string, event interface{}) error {
	sessionI, ok := s.sessions.Load(sessionID)
	if !ok {
		return fmt.Errorf("session not found: %s", sessionID)
	}
	session := sessionI.(*sseSession)

	eventData, err := json.Marshal(event)
	if err != nil {
		return fmt.Errorf("marshal event for session %s: %w", sessionID, err)
	}

	// Report closure first. If done is closed and the queue still has room,
	// both cases of the select below are ready and Go picks one at random, so a
	// closed session could otherwise be reported as a successful send.
	select {
	case <-session.done:
		return fmt.Errorf("session closed")
	default:
	}

	// Queue the event for sending via SSE.
	select {
	case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
		return nil
	case <-session.done:
		return fmt.Errorf("session closed")
	default:
		return fmt.Errorf("event queue full")
	}
}
// ServeHTTP implements the http.Handler interface by routing on the request
// path:
//   - /<prefix>/sse opens an SSE stream (handleSSE);
//   - /<prefix>/message accepts client JSON-RPC messages (handleMessage);
//   - any other path gets a 404.
//
// Routes are built with path.Join("/", s.prefix, ...), the same way handleSSE
// builds the message endpoint it advertises to clients. A prefix therefore
// resolves identically with or without a leading slash ("api" and "/api" both
// give "/api/sse"). Writing "/" + path.Join(s.prefix, ...) would turn "/api"
// into "//api/sse", which would never match the advertised URL.
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	switch r.URL.Path {
	case path.Join("/", s.prefix, "sse"):
		s.handleSSE(w, r)
	case path.Join("/", s.prefix, "message"):
		s.handleMessage(w, r)
	default:
		http.NotFound(w, r)
	}
}