"""
Abstract base class for LLM providers.
This defines the contract that all LLM providers must implement,
ensuring consistent behavior regardless of the underlying service.
Design Pattern: Strategy Pattern
- Each provider is a concrete strategy implementing the same interface
- The orchestrator doesn't need to know which provider it's using
- Easy to add new providers without modifying existing code
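
Example (illustrative only; ``OpenAIProvider`` is a hypothetical concrete
subclass that is not defined in this module; run inside an async context):

    provider = OpenAIProvider({"api_key": "sk-...", "model": "gpt-4o"})
    response = await provider.generate(
        messages=[{"role": "user", "content": "Hello"}]
    )
    print(response.content)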
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
class ToolCallStatus(str, Enum):
"""Status of a tool call in the conversation."""
PENDING = "pending"
COMPLETED = "completed"
FAILED = "failed"
@dataclass
class ToolCall:
"""
Represents a tool/function call made by the LLM.
This is a normalized representation that works across all providers,
    even though each provider uses a different internal format.
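
    Example (illustrative; the tool name and arguments are hypothetical):
        call = ToolCall(
            id="call_001",
            name="get_weather",
            arguments={"city": "Paris"},
        )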
"""
id: str # Unique identifier for this tool call
name: str # Tool/function name
arguments: Dict[str, Any] # Parsed arguments as dict
status: ToolCallStatus = ToolCallStatus.PENDING
result: Optional[Any] = None # Result after execution
error: Optional[str] = None # Error message if failed
@dataclass
class LLMResponse:
"""
Normalized response from any LLM provider.
This abstraction hides provider-specific response formats and provides
a consistent interface for the orchestrator to work with.
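
    Example (illustrative values):
        response = LLMResponse(
            content="The capital of France is Paris.",
            tool_calls=[],
            finish_reason="stop",
            usage={"prompt_tokens": 12, "completion_tokens": 9},
            raw_response=None,
        )
        assert response.is_complete and not response.has_tool_calls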
"""
content: str # The text response from the LLM
tool_calls: List[ToolCall] # Any tool calls the LLM wants to make
finish_reason: str # Why the generation stopped (stop, tool_calls, length, etc.)
usage: Dict[str, int] # Token usage stats
raw_response: Any # Original provider response for debugging
@property
def has_tool_calls(self) -> bool:
"""Check if this response contains tool calls."""
return len(self.tool_calls) > 0
@property
def is_complete(self) -> bool:
"""Check if generation completed naturally (not truncated)."""
return self.finish_reason in ("stop", "end_turn", "end")
class BaseLLMProvider(ABC):
"""
Abstract base class for all LLM providers.
    Any new LLM provider (OpenAI, Anthropic, Gemini, etc.) must implement these methods.
This ensures the orchestrator can work with any provider without modification.
Design Principles:
- Provider-agnostic: Methods work the same regardless of underlying service
- Async-first: All I/O operations are async for better performance
- Type-safe: Full type hints for better IDE support and fewer bugs
- Error-transparent: Errors are normalized but preserve original context
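
    Example (illustrative skeleton of a concrete subclass; the provider and
    model names are assumptions, not part of this contract):
        class MyProvider(BaseLLMProvider):
            @property
            def provider_name(self) -> str:
                return "my_provider"

            @property
            def model_name(self) -> str:
                return self.config.get("model", "my-model-v1")

            async def generate(self, messages, tools=None, **kwargs) -> LLMResponse:
                ...  # call the underlying API and normalize into an LLMResponse

            # (generate_with_tools and format_tool_result omitted for brevity)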
"""
def __init__(self, config: Dict[str, Any]):
"""
Initialize the provider with configuration.
Args:
config: Provider-specific configuration (API keys, models, etc.)
"""
self.config = config
@abstractmethod
async def generate(
self,
messages: List[Dict[str, Any]],
tools: Optional[List[Dict[str, Any]]] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
**kwargs
) -> LLMResponse:
"""
Generate a response from the LLM.
This is the core method that all providers must implement. It takes
a conversation history and optional tools, and returns a normalized response.
Args:
messages: Conversation history in standard format:
[{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
tools: Optional list of tools/functions the LLM can call
            temperature: Sampling temperature (0.0 = most deterministic; higher values increase randomness)
max_tokens: Maximum tokens to generate
**kwargs: Provider-specific options
Returns:
LLMResponse: Normalized response with content and/or tool calls
Raises:
LLMProviderError: If the API call fails
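
        Example (illustrative; assumes ``provider`` is a concrete subclass
        instance created elsewhere):
            response = await provider.generate(
                messages=[{"role": "user", "content": "What is 2 + 2?"}],
                temperature=0.0,
                max_tokens=256,
            )
            if response.has_tool_calls:
                ...  # execute the tools, then call generate() again with the results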
"""
pass
@abstractmethod
async def generate_with_tools(
self,
messages: List[Dict[str, Any]],
tools: List[Dict[str, Any]],
max_turns: int = 10,
**kwargs
    ) -> Tuple[str, List[Dict[str, Any]]]:
"""
Generate a response with automatic tool calling loop.
This method handles the full tool-calling workflow:
1. Send messages + tools to LLM
2. If LLM wants to call tools, execute them
3. Send tool results back to LLM
4. Repeat until LLM provides final answer
This abstracts away the differences in how each provider handles tool calling.
Args:
messages: Initial conversation history
tools: Available tools in MCP format
max_turns: Maximum tool-calling iterations to prevent infinite loops
**kwargs: Provider-specific options
Returns:
tuple: (final_text_response, full_conversation_history)
Raises:
LLMProviderError: If the API call fails
MaxTurnsExceededError: If max_turns is reached without completion
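
        Example (illustrative; the tool entry is a hypothetical MCP-style
        definition, and ``provider`` is a concrete subclass instance):
            text, history = await provider.generate_with_tools(
                messages=[{"role": "user", "content": "What's the weather in Paris?"}],
                tools=[{"name": "get_weather", "inputSchema": {"type": "object"}}],
                max_turns=5,
            )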
"""
pass
@abstractmethod
def format_tool_result(self, tool_call: ToolCall) -> Dict[str, Any]:
"""
Format a tool execution result for this provider.
Each provider expects tool results in a different format.
This method converts our normalized ToolCall into the provider's format.
Args:
tool_call: The completed tool call with result
Returns:
dict: Provider-specific formatted result
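
        Example (illustrative; an OpenAI-style provider might produce a shape
        like the one below, but the exact format is provider-specific):
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "content": str(tool_call.result),
            }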
"""
pass
@property
@abstractmethod
def provider_name(self) -> str:
"""Return the name of this provider (e.g., 'openai', 'anthropic')."""
pass
@property
@abstractmethod
def model_name(self) -> str:
"""Return the model being used (e.g., 'gpt-4o', 'claude-3-5-sonnet')."""
pass
def __repr__(self) -> str:
"""String representation for debugging."""
return f"{self.__class__.__name__}(model={self.model_name})"
class LLMProviderError(Exception):
"""Base exception for LLM provider errors."""
def __init__(self, message: str, provider: str, original_error: Optional[Exception] = None):
self.message = message
self.provider = provider
self.original_error = original_error
super().__init__(f"[{provider}] {message}")
class MaxTurnsExceededError(LLMProviderError):
"""Raised when tool-calling loop exceeds maximum turns."""
pass
class RateLimitError(LLMProviderError):
"""Raised when provider rate limit is hit."""
pass
class AuthenticationError(LLMProviderError):
"""Raised when API key is invalid or missing."""
pass
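

# Example error handling at a call site (illustrative; ``provider`` stands in
# for any concrete BaseLLMProvider implementation and ``logger`` for the
# application's logger; neither is defined in this module):
#
#     try:
#         response = await provider.generate(messages)
#     except RateLimitError:
#         ...  # back off and retry
#     except AuthenticationError:
#         raise  # misconfiguration; retrying will not help
#     except LLMProviderError as err:
#         logger.error("LLM call failed (%s): %s", err.provider, err.message)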