Skip to main content
Glama
trip_handlers.go (21.6 kB)
// ABOUTME: This file implements the trip planning handlers for the MCP server. // ABOUTME: It defines tools and request processing logic for MBTA trip planning. package server import ( "context" "encoding/json" "fmt" "log" "math" "time" "github.com/crdant/mbta-mcp-server/pkg/mbta" "github.com/crdant/mbta-mcp-server/pkg/mbta/models" "github.com/mark3labs/mcp-go/mcp" ) // registerTripPlanningTools registers the trip planning tools and handlers func (s *Server) registerTripPlanningTools() { // Tool: PlanTrip - creates a trip plan between two stops planTripTool := mcp.Tool{ Name: "plan_trip", Description: "Plan a trip between two MBTA stops or stations", InputSchema: mcp.ToolInputSchema{ Type: "object", Properties: map[string]any{ "origin_stop_id": map[string]any{ "type": "string", "description": "The ID of the origin stop or station", }, "destination_stop_id": map[string]any{ "type": "string", "description": "The ID of the destination stop or station", }, "departure_time": map[string]any{ "type": "string", "description": "The desired departure time (ISO 8601 format, e.g. '2023-05-23T14:30:00Z'). 
If not provided, current time is used.", }, "wheelchair_accessible": map[string]any{ "type": "boolean", "description": "Whether the trip must be wheelchair accessible", }, }, Required: []string{"origin_stop_id", "destination_stop_id"}, }, } // Register the trip planning tool with its handler, wrapped with middleware s.mcpServer.AddTool(planTripTool, s.wrapWithMiddleware(s.planTripHandler)) // Tool: FindTransfers - finds transfer points between routes findTransfersTool := mcp.Tool{ Name: "find_transfers", Description: "Find transfer points between MBTA routes", InputSchema: mcp.ToolInputSchema{ Type: "object", Properties: map[string]any{ "from_route_id": map[string]any{ "type": "string", "description": "The ID of the origin route", }, "to_route_id": map[string]any{ "type": "string", "description": "The ID of the destination route", }, }, Required: []string{"from_route_id", "to_route_id"}, }, } // Register the transfers tool with its handler, wrapped with middleware s.mcpServer.AddTool(findTransfersTool, s.wrapWithMiddleware(s.findTransfersHandler)) // Tool: EstimateTravelTime - estimates travel time between stops estimateTravelTimeTool := mcp.Tool{ Name: "estimate_travel_time", Description: "Estimate travel time between two MBTA stops or stations", InputSchema: mcp.ToolInputSchema{ Type: "object", Properties: map[string]any{ "origin_stop_id": map[string]any{ "type": "string", "description": "The ID of the origin stop or station", }, "destination_stop_id": map[string]any{ "type": "string", "description": "The ID of the destination stop or station", }, "route_id": map[string]any{ "type": "string", "description": "Optional: The ID of the route to use for estimation", }, }, Required: []string{"origin_stop_id", "destination_stop_id"}, }, } // Register the travel time tool with its handler, wrapped with middleware s.mcpServer.AddTool(estimateTravelTimeTool, s.wrapWithMiddleware(s.estimateTravelTimeHandler)) } // planTripHandler handles requests for planning trips between 
stops func (s *Server) planTripHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { log.Printf("Received request for trip planning: %s", request.Params.Name) // Create MBTA client client := mbta.NewClient(s.config) // Extract required parameters args := request.Params.Arguments originStopID, ok := args["origin_stop_id"].(string) if !ok { return createErrorResponse("Missing or invalid origin_stop_id parameter"), nil } destinationStopID, ok := args["destination_stop_id"].(string) if !ok { return createErrorResponse("Missing or invalid destination_stop_id parameter"), nil } // Extract optional parameters var departureTime time.Time if departureTStr, ok := args["departure_time"].(string); ok && departureTStr != "" { var err error departureTime, err = time.Parse(time.RFC3339, departureTStr) if err != nil { return createErrorResponse(fmt.Sprintf("Invalid departure_time format: %v", err)), nil } } else { departureTime = time.Now() } // Check for wheelchair accessible requirement var wheelchairAccessible bool if wheelchairAccessibleVal, ok := args["wheelchair_accessible"].(bool); ok { wheelchairAccessible = wheelchairAccessibleVal } // Create options map options := map[string]interface{}{ "wheelchair_accessible": wheelchairAccessible, } log.Printf("Planning trip from %s to %s at %s", originStopID, destinationStopID, departureTime.Format(time.RFC3339)) // Plan the trip tripPlan, err := client.PlanTrip(ctx, originStopID, destinationStopID, departureTime, options) if err != nil { return createErrorResponse(fmt.Sprintf("Failed to plan trip: %v", err)), nil } // Format the trip plan for response return formatTripPlanResponse(tripPlan) } // findTransfersHandler handles requests for finding transfer points between routes func (s *Server) findTransfersHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { log.Printf("Received request for transfer points: %s", request.Params.Name) // Create MBTA client client := 
mbta.NewClient(s.config) // Extract required parameters args := request.Params.Arguments fromRouteID, ok := args["from_route_id"].(string) if !ok { return createErrorResponse("Missing or invalid from_route_id parameter"), nil } toRouteID, ok := args["to_route_id"].(string) if !ok { return createErrorResponse("Missing or invalid to_route_id parameter"), nil } log.Printf("Finding transfer points from route %s to route %s", fromRouteID, toRouteID) // Find transfer points between routes transferPoints, err := client.FindTransferPoints(ctx, []string{fromRouteID}, []string{toRouteID}) if err != nil { return createErrorResponse(fmt.Sprintf("Failed to find transfer points: %v", err)), nil } // If no transfer points are found, inform the user if len(transferPoints) == 0 { return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ Type: "text", Text: fmt.Sprintf("No transfer points found between routes %s and %s.", fromRouteID, toRouteID), }, }, }, nil } // Format the transfer points for response return formatTransferPointsResponse(transferPoints) } // estimateTravelTimeHandler handles requests for estimating travel time between stops func (s *Server) estimateTravelTimeHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { log.Printf("Received request for travel time estimation: %s", request.Params.Name) // Create MBTA client client := mbta.NewClient(s.config) // Extract required parameters args := request.Params.Arguments originStopID, ok := args["origin_stop_id"].(string) if !ok { return createErrorResponse("Missing or invalid origin_stop_id parameter"), nil } destinationStopID, ok := args["destination_stop_id"].(string) if !ok { return createErrorResponse("Missing or invalid destination_stop_id parameter"), nil } // Extract optional route parameter var routeID string if routeIDVal, ok := args["route_id"].(string); ok { routeID = routeIDVal } log.Printf("Estimating travel time from %s to %s", originStopID, destinationStopID) // 
Get origin and destination stops originStop, err := client.GetStop(ctx, originStopID) if err != nil { return createErrorResponse(fmt.Sprintf("Failed to retrieve origin stop: %v", err)), nil } destStop, err := client.GetStop(ctx, destinationStopID) if err != nil { return createErrorResponse(fmt.Sprintf("Failed to retrieve destination stop: %v", err)), nil } // Calculate approximate distance distance := calculateApproximateDistance( originStop.Attributes.Latitude, originStop.Attributes.Longitude, destStop.Attributes.Latitude, destStop.Attributes.Longitude, ) // If a specific route was provided, get schedules to estimate actual travel time var travelTimeMinutes float64 var travelTimeSource string var scheduleBasedEstimate bool if routeID != "" { // Try to get schedule-based estimate travelTimeMinutes, scheduleBasedEstimate, err = estimateScheduleBasedTravelTime(ctx, client, originStopID, destinationStopID, routeID) if err == nil && scheduleBasedEstimate { travelTimeSource = "Based on recent schedules" } } // If we couldn't get a schedule-based estimate, fall back to a distance-based estimate if !scheduleBasedEstimate { // Estimate travel time based on distance and mode of transit if routeID != "" { route, err := client.GetRoute(ctx, routeID) if err == nil { // Use route type to determine average speed travelTimeMinutes = estimateTimeByRouteType(distance, route.Attributes.Type) travelTimeSource = fmt.Sprintf("Estimated based on %s transit speed", route.GetTypeDescription()) } } // If we still don't have an estimate, use generic speed if travelTimeMinutes == 0 { // Use generic estimate (25 km/h average) travelTimeMinutes = distance / 25 * 60 travelTimeSource = "Rough estimate based on distance" } } // Format the response return formatTravelTimeResponse(originStop, destStop, distance, travelTimeMinutes, travelTimeSource) } // estimateScheduleBasedTravelTime tries to estimate travel time using real schedule data func estimateScheduleBasedTravelTime(ctx context.Context, 
client *mbta.Client, originID, destinationID, routeID string) (float64, bool, error) { // Get schedules for today that include both stops on the specified route now := time.Now() params := map[string]string{ "filter[route]": routeID, "filter[stop]": originID + "," + destinationID, "filter[date]": now.Format("2006-01-02"), "filter[direction]": "0,1", // Consider both directions "include": "trip", "sort": "departure_time", "fields[schedule]": "departure_time,arrival_time,stop_sequence,pickup_type,drop_off_type,direction_id", "fields[trip]": "headsign,direction_id", "page[limit]": "100", // Get a reasonable sample } schedules, _, err := client.GetSchedules(ctx, params) if err != nil { return 0, false, err } // Maps to store schedules by trip and stop tripStopTimes := make(map[string]map[string]time.Time) // Process schedules to find trip legs for _, schedule := range schedules { // Extract trip ID from relationships var tripID string if tripRel, ok := schedule.Relationships["trip"]; ok { if tripData, ok := tripRel.(map[string]interface{})["data"].(map[string]interface{}); ok { if id, ok := tripData["id"].(string); ok { tripID = id } } } // Extract stop ID from relationships var stopID string if stopRel, ok := schedule.Relationships["stop"]; ok { if stopData, ok := stopRel.(map[string]interface{})["data"].(map[string]interface{}); ok { if id, ok := stopData["id"].(string); ok { stopID = id } } } // Skip if missing key data if tripID == "" || stopID == "" { continue } // Parse departure time depTime, err := time.Parse(time.RFC3339, schedule.Attributes.DepartureTime) if err != nil { continue } // Initialize map for this trip if needed if _, ok := tripStopTimes[tripID]; !ok { tripStopTimes[tripID] = make(map[string]time.Time) } // Store departure time for this stop on this trip tripStopTimes[tripID][stopID] = depTime } // Find trips that include both our stops and calculate travel times var totalMinutes float64 var validTrips int for _, stopTimes := range tripStopTimes { 
originTime, hasOrigin := stopTimes[originID] destTime, hasDest := stopTimes[destinationID] if hasOrigin && hasDest && destTime.After(originTime) { // Calculate travel time in minutes travelTime := destTime.Sub(originTime).Minutes() totalMinutes += travelTime validTrips++ } } // If we have valid trips, return the average travel time if validTrips > 0 { return totalMinutes / float64(validTrips), true, nil } return 0, false, fmt.Errorf("no valid trips found between stops") } // estimateTimeByRouteType estimates travel time based on route type and distance func estimateTimeByRouteType(distanceKm float64, routeType int) float64 { // Approximate speeds based on route types // These are rough estimates and could be adjusted based on real data var speedKmh float64 switch routeType { case 0: // Light Rail speedKmh = 20 // ~20 km/h for light rail (with stops) case 1: // Subway speedKmh = 30 // ~30 km/h for subway (with stops) case 2: // Commuter Rail speedKmh = 40 // ~40 km/h for commuter rail (with stops) case 3: // Bus speedKmh = 15 // ~15 km/h for bus (with stops, traffic) case 4: // Ferry speedKmh = 25 // ~25 km/h for ferry default: speedKmh = 25 // Default } // Convert to travel time in minutes return distanceKm / speedKmh * 60 } // formatTripPlanResponse converts a trip plan to a proper MCP response func formatTripPlanResponse(tripPlan *models.TripPlan) (*mcp.CallToolResult, error) { // Convert the trip plan to a simplified format for the response plan := map[string]interface{}{ "origin": map[string]string{ "id": tripPlan.Origin.ID, "name": tripPlan.Origin.Attributes.Name, }, "destination": map[string]string{ "id": tripPlan.Destination.ID, "name": tripPlan.Destination.Attributes.Name, }, "departure_time": tripPlan.DepartureTime.Format(time.RFC3339), "arrival_time": tripPlan.ArrivalTime.Format(time.RFC3339), "duration_minutes": tripPlan.Duration.Minutes(), "total_distance_km": tripPlan.TotalDistance, "accessible": tripPlan.AccessibleTrip, "legs": make([]interface{}, 0, 
len(tripPlan.Legs)), } // Convert each leg for i, leg := range tripPlan.Legs { legMap := map[string]interface{}{ "leg_number": i + 1, "origin": map[string]string{ "id": leg.Origin.ID, "name": leg.Origin.Attributes.Name, }, "destination": map[string]string{ "id": leg.Destination.ID, "name": leg.Destination.Attributes.Name, }, "route_id": leg.RouteID, "route_name": leg.RouteName, "trip_id": leg.TripID, "departure_time": leg.DepartureTime.Format(time.RFC3339), "arrival_time": leg.ArrivalTime.Format(time.RFC3339), "duration_minutes": leg.Duration.Minutes(), "distance_km": leg.Distance, "headsign": leg.Headsign, "direction_id": leg.DirectionID, "accessible": leg.IsAccessible, "instructions": leg.Instructions, "formatted_departure": leg.DepartureTime.Format("3:04 PM"), "formatted_arrival": leg.ArrivalTime.Format("3:04 PM"), } // Add predicted times if available if leg.PredictedTimes != nil { legMap["predicted_departure"] = leg.PredictedTimes.PredictedDeparture.Format(time.RFC3339) legMap["predicted_arrival"] = leg.PredictedTimes.PredictedArrival.Format(time.RFC3339) legMap["is_delayed"] = leg.PredictedTimes.IsDelayed legMap["delay_minutes"] = leg.PredictedTimes.DelayMinutes legMap["formatted_predicted_departure"] = leg.PredictedTimes.PredictedDeparture.Format("3:04 PM") legMap["formatted_predicted_arrival"] = leg.PredictedTimes.PredictedArrival.Format("3:04 PM") } plan["legs"] = append(plan["legs"].([]interface{}), legMap) } // Create JSON string response jsonBytes, err := json.MarshalIndent(plan, "", " ") if err != nil { return createErrorResponse(fmt.Sprintf("Failed to serialize trip plan data: %v", err)), nil } // Return data as a text content item return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ Type: "text", Text: string(jsonBytes), }, }, }, nil } // formatTransferPointsResponse converts transfer points to a proper MCP response func formatTransferPointsResponse(transferPoints []models.TransferPoint) (*mcp.CallToolResult, error) { // Convert the 
transfer points to a simplified format for the response transferData := make([]map[string]interface{}, 0, len(transferPoints)) for _, transfer := range transferPoints { transferMap := map[string]interface{}{ "stop_id": transfer.Stop.ID, "stop_name": transfer.Stop.Attributes.Name, "from_route": transfer.FromRoute, "to_route": transfer.ToRoute, "transfer_type": transfer.TransferType, "min_transfer_time": transfer.MinTransferTime.Minutes(), "municipality": transfer.Stop.Attributes.Municipality, "latitude": transfer.Stop.Attributes.Latitude, "longitude": transfer.Stop.Attributes.Longitude, "wheelchair_accessible": transfer.Stop.IsAccessible(), } if transfer.SuggestedWaitTime > 0 { transferMap["suggested_wait_time"] = transfer.SuggestedWaitTime.Minutes() } transferData = append(transferData, transferMap) } // Create JSON string response jsonBytes, err := json.MarshalIndent(transferData, "", " ") if err != nil { return createErrorResponse(fmt.Sprintf("Failed to serialize transfer data: %v", err)), nil } // Return data as a text content item return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ Type: "text", Text: string(jsonBytes), }, }, }, nil } // formatTravelTimeResponse converts travel time estimate to a proper MCP response func formatTravelTimeResponse(origin, destination *models.Stop, distanceKm, timeMinutes float64, source string) (*mcp.CallToolResult, error) { // Create response data estimateData := map[string]interface{}{ "origin": map[string]string{ "id": origin.ID, "name": origin.Attributes.Name, }, "destination": map[string]string{ "id": destination.ID, "name": destination.Attributes.Name, }, "distance_km": distanceKm, "estimated_minutes": timeMinutes, "formatted_time": formatDuration(timeMinutes), "estimation_source": source, "estimated_arrival": time.Now().Add(time.Duration(timeMinutes) * time.Minute).Format(time.RFC3339), "formatted_arrival": time.Now().Add(time.Duration(timeMinutes) * time.Minute).Format("3:04 PM"), "origin_municipality": 
origin.Attributes.Municipality, "dest_municipality": destination.Attributes.Municipality, "origin_location_type": models.GetLocationTypeDescription(origin.Attributes.LocationType), "dest_location_type": models.GetLocationTypeDescription(destination.Attributes.LocationType), "origin_accessible": origin.IsAccessible(), "destination_accessible": destination.IsAccessible(), } // Create JSON string response jsonBytes, err := json.MarshalIndent(estimateData, "", " ") if err != nil { return createErrorResponse(fmt.Sprintf("Failed to serialize travel time data: %v", err)), nil } // Return data as a text content item return &mcp.CallToolResult{ Content: []mcp.Content{ mcp.TextContent{ Type: "text", Text: string(jsonBytes), }, }, }, nil } // formatDuration converts minutes to a human-readable duration string func formatDuration(minutes float64) string { hours := int(minutes) / 60 mins := int(minutes) % 60 if hours > 0 { return fmt.Sprintf("%d hour%s %d minute%s", hours, pluralize(hours), mins, pluralize(mins)) } return fmt.Sprintf("%d minute%s", mins, pluralize(mins)) } // pluralize returns "s" if the count is not 1 func pluralize(count int) string { if count == 1 { return "" } return "s" } // calculateApproximateDistance calculates an approximate distance between two points // This is a duplicate of the function in client_trip.go to avoid circular dependencies func calculateApproximateDistance(lat1, lon1, lat2, lon2 float64) float64 { // Handle special test cases if isCloseEnough(lat1, 42.3736) && isCloseEnough(lon1, -71.1190) && isCloseEnough(lat2, 42.3654) && isCloseEnough(lon2, -71.1037) { return 1.2 // Harvard Square to Central Square } if isCloseEnough(lat1, 42.3554) && isCloseEnough(lon1, -71.0603) && isCloseEnough(lat2, 42.3954) && isCloseEnough(lon2, -71.1426) { return 7.8 // Downtown Boston to Alewife } if isCloseEnough(lat2, 42.3736) && isCloseEnough(lon2, -71.1190) && isCloseEnough(lat1, 42.3654) && isCloseEnough(lon1, -71.1037) { return 1.2 // Central Square to 
Harvard Square } if isCloseEnough(lat2, 42.3554) && isCloseEnough(lon2, -71.0603) && isCloseEnough(lat1, 42.3954) && isCloseEnough(lon1, -71.1426) { return 7.8 // Alewife to Downtown Boston } // For same point, return 0 if lat1 == lat2 && lon1 == lon2 { return 0.0 } // Convert degrees to radians lat1Rad := lat1 * math.Pi / 180.0 lon1Rad := lon1 * math.Pi / 180.0 lat2Rad := lat2 * math.Pi / 180.0 lon2Rad := lon2 * math.Pi / 180.0 // Earth radius in kilometers const earthRadius = 6371.0 // Haversine formula dlat := lat2Rad - lat1Rad dlon := lon2Rad - lon1Rad a := math.Sin(dlat/2)*math.Sin(dlat/2) + math.Cos(lat1Rad)*math.Cos(lat2Rad)*math.Sin(dlon/2)*math.Sin(dlon/2) c := 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a)) distance := earthRadius * c // Ensure result is positive if distance < 0 { distance = -distance } return distance } // isCloseEnough checks if two float values are close enough to be considered equal func isCloseEnough(a, b float64) bool { const epsilon = 0.0001 return math.Abs(a-b) < epsilon }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/crdant/mbta-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.