# base.py • 13 kB
"""
Base classes for MCP tools, resources, and prompts.
Provides consistent interfaces and shared functionality for all MCP components.
"""
import logging
from abc import ABC, abstractmethod
from typing import Any
from fastmcp import Context
from splunklib import client
logger = logging.getLogger(__name__)
class SplunkContext:
    """
    Shared context for Splunk operations.

    Plain container for the three values assigned in ``__init__``.

    Attributes:
        service: Splunk service handle, or None when none was supplied.
        is_connected: Connection flag, stored exactly as passed in.
        client_config: Client configuration dict; an empty dict when the
            caller passed None (or any other falsy value).
    """

    def __init__(
        self,
        service: client.Service | None,
        is_connected: bool,
        client_config: dict[str, Any] | None = None,
    ):
        """
        Store the service handle, connection flag and client configuration.

        Args:
            service: Splunk service handle, or None.
            is_connected: Whether the service is considered connected.
            client_config: Optional client configuration; None is accepted
                and normalised to an empty dict.
        """
        self.service = service
        self.is_connected = is_connected
        # Normalise "not provided" to {} so the attribute is always a mapping.
        self.client_config = client_config or {}
class BaseTool(ABC):
    """
    Base class for all MCP tools.

    Provides common functionality like error handling, logging, and Splunk connection validation.
    Concrete tools must implement ``execute``.
    """

    def __init__(self, name: str, description: str):
        """
        Store the tool identity and create a logger named after the concrete class.

        Args:
            name: Tool name
            description: Tool description
        """
        self.name = name
        self.description = description
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

    def extract_client_config(self, kwargs: dict[str, Any]) -> dict[str, Any] | None:
        """
        Extract Splunk configuration from tool parameters.

        Looks for parameters that start with 'splunk_' and returns them as a dict.
        Also removes them from kwargs (in place) to prevent passing to tool logic.
        Keys whose value is None are removed from kwargs but not included in the result.

        Args:
            kwargs: Tool parameters that may contain Splunk config (mutated in place)

        Returns:
            Dict containing extracted Splunk configuration, or None when no
            'splunk_' parameter carried a non-None value
        """
        # Snapshot the matching keys first: kwargs is mutated while they are popped.
        splunk_keys = [key for key in kwargs if key.startswith("splunk_")]
        client_config = {}
        for key in splunk_keys:
            value = kwargs.pop(key)
            # Ignore unset values to avoid passing None (e.g., int(None)) downstream
            if value is not None:
                client_config[key] = value
        return client_config or None

    def get_client_config_from_context(self, ctx: Context) -> dict[str, Any] | None:
        """
        Get client configuration from MCP context.

        Checks multiple sources in priority order:
        1. Context state (set by middleware per request)
        2. HTTP request headers (for HTTP transport)
        3. MCP client environment variables (lifespan context)

        Each lookup is best-effort: any exception is logged at debug level and
        the next source is tried.

        Args:
            ctx: MCP context

        Returns:
            Client configuration dict or None
        """
        # Priority 1: Context state (preferred, set by middleware)
        try:
            if hasattr(ctx, "get_state"):
                state_cfg = ctx.get_state("client_config")  # type: ignore[attr-defined]
                if state_cfg:
                    self.logger.info(
                        "Using client config from context state (keys=%s)", list(state_cfg.keys())
                    )
                    return state_cfg
        except Exception as e:
            self.logger.debug("Failed to get client config from context state: %s", e)

        # Priority 2: HTTP headers (using FastMCP runtime dependencies)
        try:
            from fastmcp.server.dependencies import get_http_headers

            headers = get_http_headers(include_all=True)
            if headers:
                self.logger.debug(
                    "Attempting to extract config from HTTP headers (available: %s)",
                    list(headers.keys()),
                )
                # Extract Splunk configuration from headers
                from src.server import extract_client_config_from_headers

                client_config = extract_client_config_from_headers(headers)
                if client_config:
                    self.logger.info(
                        "Using client config from HTTP headers (keys=%s)",
                        list(client_config.keys()),
                    )
                    return client_config
        except Exception as e:
            self.logger.debug("Failed to get client config from HTTP headers: %s", e)

        # Priority 3: Lifespan context (client environment)
        try:
            splunk_ctx = ctx.request_context.lifespan_context
            if hasattr(splunk_ctx, "client_config") and splunk_ctx.client_config:
                self.logger.info("Using client config from environment variables")
                return splunk_ctx.client_config
        except Exception as e:
            self.logger.debug("Failed to get client config from lifespan context: %s", e)

        self.logger.debug("No client config found in any source")
        return None

    async def get_splunk_service(
        self, ctx: Context, tool_level_config: dict[str, Any] | None = None
    ) -> client.Service:
        """
        Get Splunk service connection using client config or fallback to server default.

        Priority order:
        1. Tool-level configuration (passed as parameter)
        2. MCP client configuration (from headers or environment)
        3. Server default connection

        A failure at priority 1 or 2 is logged as a warning and the next
        priority is tried.

        Args:
            ctx: MCP context
            tool_level_config: Optional tool-level Splunk configuration

        Returns:
            Splunk service connection

        Raises:
            Exception: If no connection available and client config doesn't work
        """
        # Priority 1: Tool-level configuration
        if tool_level_config:
            try:
                from src.client.splunk_client import get_splunk_service

                self.logger.info("Using tool-level Splunk configuration")
                return get_splunk_service(tool_level_config)
            except Exception as e:
                self.logger.warning("Failed to connect with tool-level config: %s", e)

        # Priority 2: MCP client configuration
        client_config = self.get_client_config_from_context(ctx)
        if client_config:
            try:
                from src.client.splunk_client import get_splunk_service

                self.logger.info("Using MCP client configuration")
                return get_splunk_service(client_config)
            except Exception as e:
                self.logger.warning("Failed to connect with MCP client config: %s", e)

        # Priority 3: Server default connection
        is_available, service, error = self.check_splunk_available(ctx)
        if not is_available:
            raise Exception(f"Splunk connection not available: {error}")
        return service

    def check_splunk_available(self, ctx: Context) -> tuple[bool, client.Service | None, str]:
        """
        Check if Splunk is available and return status.

        A per-request client configuration is tried first; if none is found or
        connecting with it fails, the service stored on the Splunk context is
        used instead.

        Args:
            ctx: MCP context

        Returns:
            Tuple of (is_available, service, error_message); error_message is
            an empty string when is_available is True
        """
        # Get splunk context from available sources
        splunk_ctx = self._get_splunk_context(ctx)

        # First, prefer per-request client configuration (HTTP headers / client env)
        try:
            client_config = None
            # From HTTP request state (preferred for HTTP transport)
            if (
                hasattr(ctx.request_context, "request")
                and hasattr(ctx.request_context.request, "state")
                and hasattr(ctx.request_context.request.state, "client_config")
            ):
                client_config = ctx.request_context.request.state.client_config
            # Or from context client_config (handle both dict and object)
            elif splunk_ctx:
                if hasattr(splunk_ctx, "client_config"):
                    client_config = splunk_ctx.client_config
                elif isinstance(splunk_ctx, dict) and "client_config" in splunk_ctx:
                    client_config = splunk_ctx["client_config"]

            if client_config:
                try:
                    from src.client.splunk_client import get_splunk_service

                    service = get_splunk_service(client_config)
                    return True, service, ""
                except Exception as e:
                    # Fall back to server default if client-config connection fails
                    self.logger.warning(
                        "Client-config Splunk connection failed in availability check: %s", e
                    )
        except Exception as e:
            # Extraction problems are non-fatal: record them and continue with server default
            self.logger.debug(
                "Skipping per-request client config in availability check: %s", e
            )

        # Fallback: use server default service established at startup
        # Handle both SplunkContext objects and dict formats
        is_connected = False
        service = None
        if splunk_ctx:
            if hasattr(splunk_ctx, "is_connected") and hasattr(splunk_ctx, "service"):
                # SplunkContext object
                is_connected = splunk_ctx.is_connected
                service = splunk_ctx.service
            elif isinstance(splunk_ctx, dict):
                # Dict format
                is_connected = splunk_ctx.get("is_connected", False)
                service = splunk_ctx.get("service", None)

        if not is_connected or not service:
            return (
                False,
                None,
                "Splunk service is not available. MCP server is running in degraded mode.",
            )
        return True, service, ""

    def _get_splunk_context(self, ctx: Context):
        """
        Get Splunk context from available sources with proper fallback handling.

        The request's lifespan context is returned whenever that attribute
        exists (even if its value is None); otherwise the server instance's
        ``_splunk_context`` attribute is tried. Lookup errors are logged at
        debug level and never raised.

        Args:
            ctx: MCP context

        Returns:
            SplunkContext object, dict, or None
        """
        try:
            # Try lifespan context first (traditional path)
            if hasattr(ctx.request_context, "lifespan_context"):
                return ctx.request_context.lifespan_context
        except Exception as e:
            self.logger.debug("Lifespan context lookup failed: %s", e)

        try:
            # Fallback: try to get from server instance (module initialization path)
            from fastmcp.server.dependencies import get_server

            server = get_server()
            if hasattr(server, "_splunk_context"):
                return server._splunk_context
        except Exception as e:
            self.logger.debug("Server-instance Splunk context lookup failed: %s", e)

        return None

    def format_error_response(self, error: str, **kwargs) -> dict[str, Any]:
        """
        Format a consistent error response.

        Args:
            error: Error message stored under the "error" key
            **kwargs: Extra fields merged into the response after the standard keys

        Returns:
            Dict with "status" set to "error", the error message, and the extra fields
        """
        return {"status": "error", "error": error, **kwargs}

    def format_success_response(self, data: dict[str, Any]) -> dict[str, Any]:
        """
        Format a consistent success response.

        Args:
            data: Payload fields merged into the response after the "status" key

        Returns:
            Dict with "status" set to "success" plus the payload fields
        """
        return {"status": "success", **data}

    @abstractmethod
    async def execute(self, ctx: Context, **kwargs) -> dict[str, Any]:
        """Execute the tool's main functionality"""
        pass
class BaseResource(ABC):
    """
    Abstract parent of every MCP resource.

    A resource exposes read-only data to MCP clients; concrete subclasses
    supply that data by implementing ``get_content``.
    """

    def __init__(self, uri: str, name: str, description: str, mime_type: str = "text/plain"):
        """
        Record the resource's identity and create its logger.

        Args:
            uri: Resource URI
            name: Resource name
            description: Resource description
            mime_type: MIME type of the content, "text/plain" unless overridden
        """
        # Logger is named "<module>.<concrete class>".
        logger_name = ".".join((__name__, self.__class__.__name__))
        self.logger = logging.getLogger(logger_name)
        self.mime_type = mime_type
        self.description = description
        self.name = name
        self.uri = uri

    @abstractmethod
    async def get_content(self, ctx: Context) -> str:
        """Return the resource content as a string."""
        ...
class BasePrompt(ABC):
    """
    Abstract parent of every MCP prompt.

    A prompt is a templated interaction for a common use case; concrete
    subclasses produce it by implementing ``get_prompt``.
    """

    def __init__(self, name: str, description: str):
        """
        Record the prompt's identity and create its logger.

        Args:
            name: Prompt name
            description: Prompt description
        """
        # Logger is named "<module>.<concrete class>".
        logger_name = ".".join((__name__, self.__class__.__name__))
        self.logger = logging.getLogger(logger_name)
        self.description = description
        self.name = name

    @abstractmethod
    async def get_prompt(self, ctx: Context, **kwargs) -> dict[str, Any]:
        """Return the prompt content with any dynamic variables filled in."""
        ...
class ToolMetadata:
    """Metadata for tool registration and discovery"""
    def __init__(
        self,
        name: str,
        description: str,
        category: str,
        tags: list[str] | None = None,
        requires_connection: bool = True,
        version: str = "1.0.0",
    ):
        self.name = name
        self.description = description
        self.category = category
        self.tags = tags or []
        self.requires_connection = requires_connection
        self.version = version
class ResourceMetadata:
    """Metadata for resource registration and discovery"""
    def __init__(
        self,
        uri: str,
        name: str,
        description: str,
        mime_type: str = "text/plain",
        category: str = "general",
        tags: list[str] | None = None,
    ):
        self.uri = uri
        self.name = name
        self.description = description
        self.mime_type = mime_type
        self.category = category
        self.tags = tags or []
class PromptMetadata:
    """Metadata for prompt registration and discovery"""
    def __init__(
        self,
        name: str,
        description: str,
        category: str,
        tags: list[str] | None = None,
        arguments: list[dict[str, Any]] | None = None,
    ):
        self.name = name
        self.description = description
        self.category = category
        self.tags = tags or []
        self.arguments = arguments or []