# search.py
"""Flight search tools using Duffel API."""
import logging
from typing import Dict
import json
from mcp.server.fastmcp import FastMCP
# Request models used as the parameter types of the MCP tools below
from ..models.flight_search import (
    FlightSearch,
    MultiCityRequest,
    OfferDetails
)
from ..models.time_specs import TimeSpec
from ..api import DuffelClient
# Module-level logger; it is also handed to the Duffel client below.
logger = logging.getLogger(__name__)
# FastMCP server that every @mcp.tool() function in this module registers on,
# plus one shared Duffel API client instance. Each tool enters this same
# instance via ``async with flight_client as client`` around its API call.
# NOTE(review): concurrent tool calls re-enter the same object; confirm that
# DuffelClient supports overlapping ``async with`` usage.
mcp = FastMCP("find-flights-mcp")
flight_client = DuffelClient(logger)
def _create_slice(origin: str, destination: str, date: str,
                  departure_time: TimeSpec | None = None,
                  arrival_time: TimeSpec | None = None) -> Dict:
    """Build one Duffel slice dict for a single origin -> destination leg.

    Args:
        origin: Value stored under the slice's ``origin`` key.
        destination: Value stored under ``destination``.
        date: Value stored under ``departure_date``.
        departure_time: Optional time window. When truthy, its ``from_time`` /
            ``to_time`` replace the all-day departure window.
        arrival_time: Optional time window, handled the same way for arrival.

    Returns:
        A dict with ``origin``, ``destination``, ``departure_date``,
        ``departure_time`` and ``arrival_time`` keys. Each ``*_time`` key maps
        to a ``{"from": ..., "to": ...}`` window that defaults to
        ``00:00``-``23:59``.
    """
    def window(spec):
        # No (truthy) window supplied: search the whole day.
        if not spec:
            return {"from": "00:00", "to": "23:59"}
        return {"from": spec.from_time, "to": spec.to_time}

    return {
        "origin": origin,
        "destination": destination,
        "departure_date": date,
        "departure_time": window(departure_time),
        "arrival_time": window(arrival_time),
    }
def _format_slice_summary(slice_data: Dict) -> Dict:
    """Condense one offer slice from a Duffel response into the returned fields.

    Args:
        slice_data: A slice from an offer. It must have ``origin`` and
            ``destination`` dicts with an ``iata_code`` key and a non-empty
            ``segments`` list; callers skip slices without segments.

    Returns:
        A dict with the route endpoints, the first segment's departure and the
        last segment's arrival, the slice duration, the first segment's
        marketing carrier name, the stop count, a human-readable stop
        description and a list of connection details.

    Raises:
        KeyError: If the origin/destination IATA code is missing.
    """
    segments = slice_data['segments']
    stop_count = len(segments) - 1
    if stop_count == 0:
        stops_description = 'Non-stop'
    else:
        stops_description = f'{stop_count} stop{"s" if stop_count > 1 else ""}'

    # Each connection pairs the segment arriving at an airport with the
    # segment departing from it.
    # NOTE(review): 'duration' is the onward segment's own 'duration' field,
    # not a computed layover time; confirm this label is intended.
    connections = [
        {
            'airport': inbound.get('destination', {}).get('iata_code'),
            'arrival': inbound.get('arriving_at'),
            'departure': outbound.get('departing_at'),
            'duration': outbound.get('duration'),
        }
        for inbound, outbound in zip(segments, segments[1:])
    ]

    return {
        'origin': slice_data['origin']['iata_code'],
        'destination': slice_data['destination']['iata_code'],
        'departure': segments[0].get('departing_at'),
        'arrival': segments[-1].get('arriving_at'),
        'duration': slice_data.get('duration'),
        'carrier': segments[0].get('marketing_carrier', {}).get('name'),
        'stops': stop_count,
        'stops_description': stops_description,
        'connections': connections,
    }


def _format_offers_response(response: Dict, max_offers: int) -> Dict:
    """Reduce a raw offer-request response to a compact, JSON-ready summary.

    Args:
        response: Result of ``create_offer_request``. It must contain
            ``request_id`` and may contain an ``offers`` list.
        max_offers: Maximum number of offers to keep, in response order.

    Returns:
        ``{'request_id': ..., 'offers': [...]}``. Each offer carries its id,
        total amount/currency and summarized slices; slices with no segments
        are omitted.

    Raises:
        KeyError: If ``request_id`` or a slice's IATA code is missing.
    """
    request_id = response['request_id']
    offers = []
    for offer in response.get('offers', [])[:max_offers]:
        offers.append({
            'offer_id': offer.get('id'),
            'price': {
                'amount': offer.get('total_amount'),
                'currency': offer.get('total_currency'),
            },
            'slices': [
                _format_slice_summary(slice_data)
                for slice_data in offer.get('slices', [])
                if slice_data.get('segments')
            ],
        })
    return {'request_id': request_id, 'offers': offers}


@mcp.tool()
async def search_flights(params: FlightSearch) -> str:
    """Search for flights based on parameters.

    Slices are built according to ``params.type``:

    * ``"one_way"``: a single origin -> destination slice.
    * ``"round_trip"``: the outbound slice plus a destination -> origin slice
      on ``params.return_date``. Both legs use the same time windows.
    * ``"multi_city"``: the origin -> destination leg (with the time windows,
      as for one-way), followed by one all-day leg per entry in
      ``params.additional_stops``. Each entry is a mapping with ``origin``,
      ``destination`` and ``departure_date`` keys.

    Args:
        params: Route, dates, optional time windows, cabin class, adult count
            and maximum connections.

    Returns:
        A JSON string (indent=2) with the request id and up to 50
        summarized offers.

    Raises:
        ValueError: If the trip type's required return date or additional
            stops are missing, or if ``params.type`` is not supported.
        Exception: Any client/API error is logged with its traceback and
            re-raised.
    """
    try:
        if params.type == "one_way":
            slices = [_create_slice(
                params.origin,
                params.destination,
                params.departure_date,
                params.departure_time,
                params.arrival_time,
            )]
        elif params.type == "round_trip":
            if not params.return_date:
                raise ValueError("Return date required for round-trip flights")
            slices = [
                _create_slice(
                    params.origin,
                    params.destination,
                    params.departure_date,
                    params.departure_time,
                    params.arrival_time,
                ),
                _create_slice(
                    params.destination,
                    params.origin,
                    params.return_date,
                    params.departure_time,
                    params.arrival_time,
                ),
            ]
        elif params.type == "multi_city":
            if not params.additional_stops:
                raise ValueError("Additional stops required for multi-city flights")
            # First leg honors the requested time windows, like one-way.
            slices = [_create_slice(
                params.origin,
                params.destination,
                params.departure_date,
                params.departure_time,
                params.arrival_time,
            )]
            # Additional legs search the whole day.
            slices.extend(
                _create_slice(stop["origin"], stop["destination"], stop["departure_date"])
                for stop in params.additional_stops
            )
        else:
            # Without this, an empty slice list would be sent to the API.
            raise ValueError(f"Unsupported flight type: {params.type!r}")

        async with flight_client as client:
            response = await client.create_offer_request(
                slices=slices,
                cabin_class=params.cabin_class,
                adult_count=params.adults,
                max_connections=params.max_connections,
                return_offers=True,
                supplier_timeout=15000,
            )

        # Cap the number of offers to keep the tool response manageable.
        return json.dumps(_format_offers_response(response, max_offers=50), indent=2)

    except Exception as e:
        logger.exception("Error searching flights: %s", e)
        raise
@mcp.tool()
async def get_offer_details(params: OfferDetails) -> str:
    """Fetch one flight offer by its ID and return it as pretty-printed JSON.

    Args:
        params: Tool input; only ``params.offer_id`` is read.

    Returns:
        The payload returned by ``client.get_offer``, serialized with
        ``json.dumps(..., indent=2)``.

    Raises:
        Exception: Any client or serialization error is logged with its
            traceback and re-raised unchanged.
    """
    try:
        async with flight_client as client:
            offer = await client.get_offer(offer_id=params.offer_id)
            return json.dumps(offer, indent=2)
    except Exception as err:
        logger.exception(f"Error getting offer details: {err}")
        raise
@mcp.tool(name="search_multi_city")
async def search_multi_city(params: MultiCityRequest) -> str:
    """Search for multi-city flights.

    Each entry of ``params.segments`` (read via its ``origin``,
    ``destination`` and ``departure_date`` attributes) becomes one all-day
    slice in a single offer request.

    Args:
        params: Multi-city request; ``segments``, ``cabin_class``, ``adults``
            and ``max_connections`` are read.

    Returns:
        A JSON string (indent=2) with the request id and up to 10
        summarized offers.

    Raises:
        ValueError: If ``params.segments`` is empty.
        Exception: Any client/API error is logged with its traceback and
            re-raised.
    """
    try:
        if not params.segments:
            # Avoid sending an offer request with no slices.
            raise ValueError("At least one segment required for multi-city flights")

        slices = [
            _create_slice(leg.origin, leg.destination, leg.departure_date, None, None)
            for leg in params.segments
        ]

        async with flight_client as client:
            response = await client.create_offer_request(
                slices=slices,
                cabin_class=params.cabin_class,
                adult_count=params.adults,
                max_connections=params.max_connections,
                return_offers=True,
                # Increased supplier timeout for multi-city searches
                # (search_flights uses 15000).
                supplier_timeout=30000,
            )

        formatted_response = {
            'request_id': response['request_id'],
            'offers': [],
        }

        for offer in response.get('offers', [])[:10]:
            offer_slices = []
            for slice_data in offer.get('slices', []):
                segments = slice_data.get('segments', [])
                if not segments:
                    continue
                stop_count = len(segments) - 1
                offer_slices.append({
                    'origin': slice_data['origin']['iata_code'],
                    'destination': slice_data['destination']['iata_code'],
                    'departure': segments[0].get('departing_at'),
                    'arrival': segments[-1].get('arriving_at'),
                    'duration': slice_data.get('duration'),
                    'carrier': segments[0].get('marketing_carrier', {}).get('name'),
                    'stops': stop_count,
                    'stops_description': (
                        'Non-stop' if stop_count == 0
                        else f'{stop_count} stop{"s" if stop_count > 1 else ""}'
                    ),
                    # One entry per arriving/onward segment pair.
                    # NOTE(review): 'duration' is the onward segment's own
                    # 'duration' field, not a computed layover time.
                    'connections': [
                        {
                            'airport': inbound.get('destination', {}).get('iata_code'),
                            'arrival': inbound.get('arriving_at'),
                            'departure': outbound.get('departing_at'),
                            'duration': outbound.get('duration'),
                        }
                        for inbound, outbound in zip(segments, segments[1:])
                    ],
                })

            formatted_response['offers'].append({
                'offer_id': offer.get('id'),
                'price': {
                    'amount': offer.get('total_amount'),
                    'currency': offer.get('total_currency'),
                },
                'slices': offer_slices,
            })

        return json.dumps(formatted_response, indent=2)

    except Exception as e:
        logger.exception("Error searching multi-city flights: %s", e)
        raise