base.py•10.1 kB
"""
Base endpoint class for the cBioPortal MCP server.
This module provides a base class that contains common functionality
shared across all endpoint classes, including pagination logic,
error handling, and response formatting.
"""
import asyncio
from functools import wraps
from typing import Any, Dict, List, Optional, Union
from ..api_client import APIClient
from ..constants import FETCH_ALL_PAGE_SIZE
from ..utils.logging import get_logger
from ..utils.pagination import collect_all_results
from ..utils.validation import validate_page_params, validate_sort_params
logger = get_logger(__name__)
def handle_api_errors(operation_name: str):
    """
    Decorator to handle common API errors consistently across endpoints.

    Before the wrapped coroutine runs, the first positional argument is
    checked for an ``_ensure_api_client_ready`` coroutine method (as
    ``BaseEndpoint`` instances provide). If present, it is awaited so the
    API client is started.

    ``ValueError`` and ``TypeError`` propagate unchanged so validation
    failures stay visible to callers and tests. Any other ``Exception`` is
    logged and converted into a ``{"error": "Failed to <operation_name>: ..."}``
    dictionary, which is returned instead of raising.

    Args:
        operation_name: Description of the operation being performed; used
            in the log line and in the returned error message.

    Returns:
        A decorator to apply to ``async`` functions.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                # Ensure API client is ready for BaseEndpoint instances
                if args and hasattr(args[0], '_ensure_api_client_ready'):
                    await args[0]._ensure_api_client_ready()
                return await func(*args, **kwargs)
            except (ValueError, TypeError):
                # Re-raise validation errors so they can be caught by tests.
                # A bare ``raise`` keeps the original traceback untouched.
                raise
            except Exception as e:
                logger.error(f"Error in {operation_name}: {str(e)}")
                return {"error": f"Failed to {operation_name}: {str(e)}"}
        return wrapper
    return decorator
def validate_paginated_params(func):
    """
    Decorator to validate common pagination parameters.

    On each call, the arguments are bound to the wrapped method's signature,
    with defaults applied. Then ``validate_page_params(page_number, page_size,
    limit)`` and ``validate_sort_params(sort_by, direction)`` are called.
    Validation always runs: if the wrapped method has no parameter of a given
    name, a fallback is used (``page_number=0``, ``page_size=50``,
    ``limit=None``, ``sort_by=None``, ``direction='ASC'``). Exceptions raised
    by the validators propagate to the caller.

    Args:
        func: An ``async`` method whose first parameter is ``self``.

    Returns:
        The wrapped coroutine function.
    """
    import inspect

    # The signature never changes, so introspect once at decoration time
    # instead of on every call.
    sig = inspect.signature(func)

    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        # Map positional and keyword arguments (plus defaults) onto names.
        bound_args = sig.bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments

        page_number = arguments.get('page_number', 0)
        page_size = arguments.get('page_size', 50)
        limit = arguments.get('limit', None)
        sort_by = arguments.get('sort_by', None)
        direction = arguments.get('direction', 'ASC')

        # Validate pagination parameters - let exceptions bubble up
        validate_page_params(page_number, page_size, limit)
        validate_sort_params(sort_by, direction)

        return await func(self, *args, **kwargs)
    return wrapper
class BaseEndpoint:
    """
    Base class for all endpoint classes.

    Provides common functionality including pagination logic,
    error handling, and response formatting. Subclasses are constructed
    with an ``APIClient`` and typically build on :meth:`paginated_request`
    and :meth:`concurrent_fetch`.
    """

    def __init__(self, api_client: APIClient):
        self.api_client = api_client

    async def _ensure_api_client_ready(self):
        """
        Ensure APIClient is initialized before making requests.

        The client counts as uninitialized when it has no ``_client``
        attribute or that attribute is ``None``. In that case
        ``api_client.startup()`` is awaited and an info line is logged.
        """
        if getattr(self.api_client, '_client', None) is None:
            await self.api_client.startup()
            logger.info("APIClient initialized via BaseEndpoint._ensure_api_client_ready")

    def build_pagination_params(
        self,
        page_number: int,
        page_size: int,
        sort_by: Optional[str],
        direction: str,
        limit: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Build standardized pagination parameters for API requests.

        Args:
            page_number: Page number to retrieve (0-based)
            page_size: Number of items per page
            sort_by: Field to sort by; omitted from the params when falsy
            direction: Sort direction (ASC or DESC)
            limit: Maximum number of items to return. ``0`` means
                "fetch all" and replaces ``pageSize`` with
                ``FETCH_ALL_PAGE_SIZE``.

        Returns:
            Dictionary with ``pageNumber``, ``pageSize``, ``direction`` and,
            when given, ``sortBy``.
        """
        params = {
            "pageNumber": page_number,
            "pageSize": page_size,
            "direction": direction,
        }

        if sort_by:
            params["sortBy"] = sort_by

        if limit == 0:
            params["pageSize"] = FETCH_ALL_PAGE_SIZE

        return params

    def build_pagination_response(
        self,
        results: List[Dict[str, Any]],
        page_number: int,
        page_size: int,
        has_more: bool,
        data_key: str
    ) -> Dict[str, Any]:
        """
        Build standardized pagination response structure.

        Args:
            results: List of result items
            page_number: Current page number
            page_size: Page size used
            has_more: Whether more results are available
            data_key: Key name for the results in the response

        Returns:
            ``{data_key: results, "pagination": {...}}``. Note that
            ``total_found`` is the number of items in *this* response, not a
            server-side grand total.
        """
        return {
            data_key: results,
            "pagination": {
                "page": page_number,
                "page_size": page_size,
                "total_found": len(results),
                "has_more": has_more,
            },
        }

    def determine_has_more(
        self,
        results: List[Dict[str, Any]],
        api_params: Dict[str, Any]
    ) -> bool:
        """
        Determine if more results are available based on the API response.

        Args:
            results: Results returned from the API
            api_params: Parameters used in the API call (must contain
                ``pageSize``)

        Returns:
            True when the API returned exactly a full page, so another page
            may exist. Returns False for a short page, including the
            "fetch all" case where fewer than ``FETCH_ALL_PAGE_SIZE`` items
            came back.
        """
        # Exact equality is deliberate: a response larger than the requested
        # page size suggests the endpoint ignored pagination and already
        # returned everything.
        return len(results) == api_params["pageSize"]

    def apply_limit(
        self,
        results: List[Dict[str, Any]],
        limit: Optional[int]
    ) -> List[Dict[str, Any]]:
        """
        Apply limit to results if specified.

        Args:
            results: List of results to limit
            limit: Maximum number of results to return; ``None`` or values
                ``<= 0`` disable limiting

        Returns:
            A truncated copy when ``results`` exceeds a positive ``limit``;
            otherwise the original ``results`` list itself.
        """
        if limit and limit > 0 and len(results) > limit:
            return results[:limit]
        return results

    async def paginated_request(
        self,
        endpoint: str,
        page_number: int = 0,
        page_size: int = 50,
        sort_by: Optional[str] = None,
        direction: str = "ASC",
        limit: Optional[int] = None,
        data_key: str = "results",
        additional_params: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Make a paginated API request with standardized handling.

        Args:
            endpoint: API endpoint to call
            page_number: Page number to retrieve (0-based)
            page_size: Number of items per page
            sort_by: Field to sort by
            direction: Sort direction (ASC or DESC)
            limit: Maximum number of items to return; ``0`` fetches all
                pages via ``collect_all_results``
            data_key: Key name for results in response
            additional_params: Additional parameters merged into (and able
                to override) the pagination parameters

        Returns:
            Standardized paginated response (see ``build_pagination_response``).
            ``has_more`` is True when the API returned a full page or when
            ``limit`` discarded items from the fetched page.
        """
        # Build base pagination parameters
        api_params = self.build_pagination_params(
            page_number, page_size, sort_by, direction, limit
        )

        # Add any additional parameters
        if additional_params:
            api_params.update(additional_params)

        if limit == 0:
            # Special behavior for limit=0: walk every page.
            results_for_response = await collect_all_results(
                self.api_client, endpoint, params=api_params
            )
            has_more = False  # We fetched everything
        else:
            # Fetch just the requested page
            results_from_api = await self.api_client.make_api_request(
                endpoint, params=api_params
            )

            results_for_response = self.apply_limit(results_from_api, limit)

            # Items dropped by ``limit`` still exist, so report them even when
            # the API page itself was short (i.e. the last page).
            truncated = len(results_for_response) < len(results_from_api)
            has_more = (
                self.determine_has_more(results_from_api, api_params)
                or truncated
            )

        return self.build_pagination_response(
            results_for_response, page_number, page_size, has_more, data_key
        )

    @staticmethod
    def _is_successful_result(result: Any) -> bool:
        """
        Classify a single ``concurrent_fetch`` result as success or failure.

        An explicit ``"success"`` key decides. If that key is absent, a dict
        carrying an ``"error"`` key counts as a failure; this is the shape
        ``handle_api_errors`` returns. Anything else, including non-dict
        results, counts as a success.
        """
        if not isinstance(result, dict):
            return True
        if "success" in result:
            return bool(result["success"])
        return "error" not in result

    async def concurrent_fetch(
        self,
        fetch_tasks: List[asyncio.Task],
        operation_name: str = "concurrent operation"
    ) -> Dict[str, Any]:
        """
        Execute multiple API requests concurrently and return structured results.

        The awaitables are run with ``asyncio.gather`` (without
        ``return_exceptions``), so the first raised exception propagates to
        the caller.

        Args:
            fetch_tasks: List of tasks/awaitables to execute
            operation_name: Description of the operation for logging

        Returns:
            ``{"results": [...], "metadata": {...}}``. ``metadata.errors``
            counts results that report ``success`` falsy, or that carry an
            ``"error"`` key without a ``success`` key.
        """
        import time

        start_time = time.perf_counter()
        results = await asyncio.gather(*fetch_tasks)
        elapsed = time.perf_counter() - start_time

        error_count = sum(
            1 for result in results if not self._is_successful_result(result)
        )

        return {
            "results": results,
            "metadata": {
                "count": len(results),
                "errors": error_count,
                "execution_time": round(elapsed, 3),
                "concurrent": True,
                "operation": operation_name,
            },
        }