# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/services/llm_provider_service.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
LLM Provider Service
This module implements LLM provider management for ContextForge.
It handles provider registration, CRUD operations, model management,
and health checks for the internal LLM Chat feature.
"""
# Standard
from datetime import datetime, timezone
from typing import List, Optional, Tuple
# Third-Party
import httpx
from sqlalchemy import and_, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
# First-Party
from mcpgateway.db import LLMModel, LLMProvider, LLMProviderType
from mcpgateway.llm_schemas import (
GatewayModelInfo,
HealthStatus,
LLMModelCreate,
LLMModelResponse,
LLMModelUpdate,
LLMProviderCreate,
LLMProviderResponse,
LLMProviderUpdate,
ProviderHealthCheck,
)
from mcpgateway.services.logging_service import LoggingService
from mcpgateway.utils.create_slug import slugify
from mcpgateway.utils.services_auth import decode_auth, encode_auth
# Initialize logging
# A LoggingService instance is created at import time and used to obtain
# the module-level logger, which is named after this module (__name__).
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
class LLMProviderError(Exception):
    """Root of the exception hierarchy for LLM provider and model errors."""
class LLMProviderNotFoundError(LLMProviderError):
    """Signal that no LLM provider matches the requested identifier."""
class LLMProviderNameConflictError(LLMProviderError):
    """Signal that a provider name is already used by another provider."""

    def __init__(self, name: str, provider_id: Optional[str] = None):
        """Build the error message and keep the conflict details on the instance.

        Args:
            name: Provider name that collided.
            provider_id: ID of the provider already using the name, if known.
        """
        # The ID suffix is only appended when a (truthy) ID was supplied.
        id_suffix = f" (ID: {provider_id})" if provider_id else ""
        super().__init__(f"LLM Provider already exists with name: {name}{id_suffix}")
        self.name = name
        self.provider_id = provider_id
class LLMModelNotFoundError(LLMProviderError):
    """Signal that no LLM model matches the requested identifier."""
class LLMModelConflictError(LLMProviderError):
    """Signal that an LLM model collides with one that already exists."""
class LLMProviderService:
    """Service for managing LLM providers and their models.

    Offers create, list, retrieve, update, delete and enable/disable
    operations on ``LLMProvider`` and ``LLMModel`` rows, the enabled-model
    listing used by the LLM Chat dropdown, provider health checks, and
    converters from ORM rows to response schemas.
    """

    # Fields copied as-is from an LLMProviderUpdate when they are not None.
    # name, provider_type and api_key need special handling and are excluded.
    _PROVIDER_UPDATE_FIELDS = (
        "description",
        "api_base",
        "api_version",
        "config",
        "default_model",
        "default_temperature",
        "default_max_tokens",
        "enabled",
        "plugin_ids",
    )

    # Fields copied as-is from an LLMModelUpdate when they are not None.
    _MODEL_UPDATE_FIELDS = (
        "model_id",
        "model_name",
        "model_alias",
        "description",
        "supports_chat",
        "supports_streaming",
        "supports_function_calling",
        "supports_vision",
        "context_window",
        "max_output_tokens",
        "enabled",
        "deprecated",
    )

    def __init__(self) -> None:
        """Initialize a new LLMProviderService instance."""
        self._initialized = False

    async def initialize(self) -> None:
        """Mark the service as initialized (logs once; repeated calls are no-ops)."""
        if not self._initialized:
            logger.info("Initializing LLM Provider Service")
            self._initialized = True

    async def shutdown(self) -> None:
        """Mark the service as shut down (no-op when it was never initialized)."""
        if self._initialized:
            logger.info("Shutting down LLM Provider Service")
            self._initialized = False

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------
    @staticmethod
    def _pagination_window(page: int, page_size: int) -> Tuple[int, int]:
        """Translate a 1-indexed page request into ``(offset, limit)``.

        Args:
            page: Requested page number; values below 1 are treated as 1.
            page_size: Items per page; negative values are treated as 0.

        Returns:
            Tuple of (offset, limit), both guaranteed to be non-negative.
        """
        limit = max(page_size, 0)
        offset = (max(page, 1) - 1) * limit
        return offset, limit

    @staticmethod
    def _auth_headers(api_key: Optional[str]) -> dict:
        """Build bearer-token headers, or no headers when there is no key.

        Args:
            api_key: Plain API key, or None/empty when none is configured.

        Returns:
            ``{"Authorization": "Bearer <key>"}`` or an empty dict.
        """
        return {"Authorization": f"Bearer {api_key}"} if api_key else {}

    @staticmethod
    def _commit(db: Session) -> None:
        """Commit the session, rolling it back before re-raising on failure.

        Args:
            db: Database session.

        Raises:
            Exception: Whatever ``db.commit()`` raised, after ``db.rollback()``.
        """
        try:
            db.commit()
        except Exception:
            # Without the rollback the session stays in a failed state and
            # cannot be used for further work.
            db.rollback()
            raise

    # ---------------------------------------------------------------------------
    # Provider CRUD Operations
    # ---------------------------------------------------------------------------
    def create_provider(
        self,
        db: Session,
        provider_data: LLMProviderCreate,
        created_by: Optional[str] = None,
    ) -> LLMProvider:
        """Create a new LLM provider.

        Args:
            db: Database session.
            provider_data: Provider data to create.
            created_by: Username of creator.

        Returns:
            Created LLMProvider instance.

        Raises:
            LLMProviderNameConflictError: If provider name already exists, or
                if the insert violates a database integrity constraint.
        """
        # Check for name conflict
        existing = db.execute(select(LLMProvider).where(LLMProvider.name == provider_data.name)).scalar_one_or_none()
        if existing:
            raise LLMProviderNameConflictError(provider_data.name, existing.id)

        # The API key is passed through encode_auth before being stored.
        encrypted_api_key = None
        if provider_data.api_key:
            encrypted_api_key = encode_auth({"api_key": provider_data.api_key})

        provider = LLMProvider(
            name=provider_data.name,
            slug=slugify(provider_data.name),
            description=provider_data.description,
            provider_type=provider_data.provider_type.value,
            api_key=encrypted_api_key,
            api_base=provider_data.api_base,
            api_version=provider_data.api_version,
            config=provider_data.config,
            default_model=provider_data.default_model,
            default_temperature=provider_data.default_temperature,
            default_max_tokens=provider_data.default_max_tokens,
            enabled=provider_data.enabled,
            plugin_ids=provider_data.plugin_ids,
            created_by=created_by,
        )
        try:
            db.add(provider)
            db.commit()
            db.refresh(provider)
        except IntegrityError as e:
            db.rollback()
            logger.error(f"Failed to create LLM provider: {e}")
            # Keep the database error as the cause so the real constraint
            # violation stays visible in tracebacks.
            raise LLMProviderNameConflictError(provider_data.name) from e
        logger.info(f"Created LLM provider: {provider.name} (ID: {provider.id})")
        return provider

    def get_provider(self, db: Session, provider_id: str) -> LLMProvider:
        """Get an LLM provider by ID.

        Args:
            db: Database session.
            provider_id: Provider ID to retrieve.

        Returns:
            LLMProvider instance.

        Raises:
            LLMProviderNotFoundError: If provider not found.
        """
        provider = db.execute(select(LLMProvider).where(LLMProvider.id == provider_id)).scalar_one_or_none()
        if not provider:
            raise LLMProviderNotFoundError(f"Provider not found: {provider_id}")
        return provider

    def get_provider_by_slug(self, db: Session, slug: str) -> LLMProvider:
        """Get an LLM provider by slug.

        Args:
            db: Database session.
            slug: Provider slug to retrieve.

        Returns:
            LLMProvider instance.

        Raises:
            LLMProviderNotFoundError: If provider not found.
        """
        provider = db.execute(select(LLMProvider).where(LLMProvider.slug == slug)).scalar_one_or_none()
        if not provider:
            raise LLMProviderNotFoundError(f"Provider not found: {slug}")
        return provider

    def list_providers(
        self,
        db: Session,
        enabled_only: bool = False,
        page: int = 1,
        page_size: int = 50,
    ) -> Tuple[List[LLMProvider], int]:
        """List LLM providers ordered by name, one page at a time.

        Args:
            db: Database session.
            enabled_only: Only return enabled providers.
            page: Page number (1-indexed); values below 1 are treated as 1.
            page_size: Items per page; negative values are treated as 0.

        Returns:
            Tuple of (providers on the requested page, total matching count).
        """
        query = select(LLMProvider)
        count_query = select(func.count(LLMProvider.id))  # pylint: disable=not-callable
        if enabled_only:
            # The same filter must apply to both queries so the total matches.
            enabled_filter = LLMProvider.enabled.is_(True)
            query = query.where(enabled_filter)
            count_query = count_query.where(enabled_filter)
        total = db.execute(count_query).scalar() or 0

        offset, limit = self._pagination_window(page, page_size)
        query = query.order_by(LLMProvider.name).offset(offset).limit(limit)
        providers = list(db.execute(query).scalars().all())
        return providers, total

    def update_provider(
        self,
        db: Session,
        provider_id: str,
        provider_data: LLMProviderUpdate,
        modified_by: Optional[str] = None,
    ) -> LLMProvider:
        """Update an LLM provider.

        Only fields of ``provider_data`` that are not None are applied;
        ``modified_by`` is always overwritten with the given value.

        Args:
            db: Database session.
            provider_id: Provider ID to update.
            provider_data: Updated provider data.
            modified_by: Username of modifier.

        Returns:
            Updated LLMProvider instance.

        Raises:
            LLMProviderNotFoundError: If provider not found.
            LLMProviderNameConflictError: If new name conflicts.
            IntegrityError: If database constraint violation.
        """
        provider = self.get_provider(db, provider_id)

        # Check for name conflict if name is being changed
        if provider_data.name and provider_data.name != provider.name:
            existing = db.execute(
                select(LLMProvider).where(
                    and_(
                        LLMProvider.name == provider_data.name,
                        LLMProvider.id != provider_id,
                    )
                )
            ).scalar_one_or_none()
            if existing:
                raise LLMProviderNameConflictError(provider_data.name, existing.id)
            provider.name = provider_data.name
            # The slug is derived from the name, so it changes with it.
            provider.slug = slugify(provider_data.name)

        # Fields that need conversion before they are stored
        if provider_data.provider_type is not None:
            provider.provider_type = provider_data.provider_type.value
        if provider_data.api_key is not None:
            provider.api_key = encode_auth({"api_key": provider_data.api_key})

        # Remaining fields are copied unchanged when provided
        for field_name in self._PROVIDER_UPDATE_FIELDS:
            value = getattr(provider_data, field_name)
            if value is not None:
                setattr(provider, field_name, value)
        provider.modified_by = modified_by

        try:
            db.commit()
            db.refresh(provider)
        except IntegrityError as e:
            db.rollback()
            logger.error(f"Failed to update LLM provider: {e}")
            raise
        logger.info(f"Updated LLM provider: {provider.name} (ID: {provider.id})")
        return provider

    def delete_provider(self, db: Session, provider_id: str) -> bool:
        """Delete an LLM provider.

        Args:
            db: Database session.
            provider_id: Provider ID to delete.

        Returns:
            True if deleted successfully.

        Raises:
            LLMProviderNotFoundError: If provider not found.
        """
        provider = self.get_provider(db, provider_id)
        # Read the name before deleting; it is only needed for the log line.
        provider_name = provider.name
        db.delete(provider)
        self._commit(db)
        logger.info(f"Deleted LLM provider: {provider_name} (ID: {provider_id})")
        return True

    def set_provider_state(self, db: Session, provider_id: str, activate: Optional[bool] = None) -> LLMProvider:
        """Set provider enabled state.

        Args:
            db: Database session.
            provider_id: Provider ID to update.
            activate: If provided, sets enabled to this value. If None, inverts current state (legacy behavior).

        Returns:
            Updated LLMProvider instance.

        Raises:
            LLMProviderNotFoundError: If provider not found.
        """
        provider = self.get_provider(db, provider_id)
        # None means "toggle" (legacy behavior kept for backward compatibility).
        provider.enabled = (not provider.enabled) if activate is None else activate
        self._commit(db)
        db.refresh(provider)
        logger.info(f"Set LLM provider state: {provider.name} enabled={provider.enabled}")
        return provider

    # ---------------------------------------------------------------------------
    # Model CRUD Operations
    # ---------------------------------------------------------------------------
    def create_model(
        self,
        db: Session,
        model_data: LLMModelCreate,
    ) -> LLMModel:
        """Create a new LLM model.

        Args:
            db: Database session.
            model_data: Model data to create.

        Returns:
            Created LLMModel instance.

        Raises:
            LLMProviderNotFoundError: If provider not found.
            LLMModelConflictError: If model already exists for provider, or
                if the insert violates a database integrity constraint.
        """
        # Verify provider exists (raises if it does not)
        self.get_provider(db, model_data.provider_id)

        # A model is identified by its provider together with its model_id.
        existing = db.execute(
            select(LLMModel).where(
                and_(
                    LLMModel.provider_id == model_data.provider_id,
                    LLMModel.model_id == model_data.model_id,
                )
            )
        ).scalar_one_or_none()
        if existing:
            raise LLMModelConflictError(f"Model {model_data.model_id} already exists for provider {model_data.provider_id}")

        model = LLMModel(
            provider_id=model_data.provider_id,
            model_id=model_data.model_id,
            model_name=model_data.model_name,
            model_alias=model_data.model_alias,
            description=model_data.description,
            supports_chat=model_data.supports_chat,
            supports_streaming=model_data.supports_streaming,
            supports_function_calling=model_data.supports_function_calling,
            supports_vision=model_data.supports_vision,
            context_window=model_data.context_window,
            max_output_tokens=model_data.max_output_tokens,
            enabled=model_data.enabled,
            deprecated=model_data.deprecated,
        )
        try:
            db.add(model)
            db.commit()
            db.refresh(model)
        except IntegrityError as e:
            db.rollback()
            logger.error(f"Failed to create LLM model: {e}")
            raise LLMModelConflictError(f"Model conflict: {model_data.model_id}") from e
        logger.info(f"Created LLM model: {model.model_id} (ID: {model.id})")
        return model

    def get_model(self, db: Session, model_id: str) -> LLMModel:
        """Get an LLM model by ID.

        Args:
            db: Database session.
            model_id: Primary-key ID of the model row (``LLMModel.id``).

        Returns:
            LLMModel instance.

        Raises:
            LLMModelNotFoundError: If model not found.
        """
        model = db.execute(select(LLMModel).where(LLMModel.id == model_id)).scalar_one_or_none()
        if not model:
            raise LLMModelNotFoundError(f"Model not found: {model_id}")
        return model

    def list_models(
        self,
        db: Session,
        provider_id: Optional[str] = None,
        enabled_only: bool = False,
        page: int = 1,
        page_size: int = 50,
    ) -> Tuple[List[LLMModel], int]:
        """List LLM models ordered by model name, one page at a time.

        Args:
            db: Database session.
            provider_id: Filter by provider ID.
            enabled_only: Only return enabled models.
            page: Page number (1-indexed); values below 1 are treated as 1.
            page_size: Items per page; negative values are treated as 0.

        Returns:
            Tuple of (models on the requested page, total matching count).
        """
        filters = []
        if provider_id:
            filters.append(LLMModel.provider_id == provider_id)
        if enabled_only:
            filters.append(LLMModel.enabled.is_(True))

        query = select(LLMModel)
        count_query = select(func.count(LLMModel.id))  # pylint: disable=not-callable
        if filters:
            # The same filters must apply to both queries so the total matches.
            query = query.where(*filters)
            count_query = count_query.where(*filters)
        total = db.execute(count_query).scalar() or 0

        offset, limit = self._pagination_window(page, page_size)
        # model_name alone may not be unique; the ID tiebreaker keeps page
        # boundaries stable between requests.
        query = query.order_by(LLMModel.model_name, LLMModel.id).offset(offset).limit(limit)
        models = list(db.execute(query).scalars().all())
        return models, total

    def update_model(
        self,
        db: Session,
        model_id: str,
        model_data: LLMModelUpdate,
    ) -> LLMModel:
        """Update an LLM model.

        Only fields of ``model_data`` that are not None are applied.

        Args:
            db: Database session.
            model_id: Primary-key ID of the model row to update.
            model_data: Updated model data.

        Returns:
            Updated LLMModel instance.

        Raises:
            LLMModelNotFoundError: If model not found.
        """
        model = self.get_model(db, model_id)
        for field_name in self._MODEL_UPDATE_FIELDS:
            value = getattr(model_data, field_name)
            if value is not None:
                setattr(model, field_name, value)
        # A failed commit is rolled back and the original error re-raised.
        self._commit(db)
        db.refresh(model)
        logger.info(f"Updated LLM model: {model.model_id} (ID: {model.id})")
        return model

    def delete_model(self, db: Session, model_id: str) -> bool:
        """Delete an LLM model.

        Args:
            db: Database session.
            model_id: Primary-key ID of the model row to delete.

        Returns:
            True if deleted successfully.

        Raises:
            LLMModelNotFoundError: If model not found.
        """
        model = self.get_model(db, model_id)
        # Read the provider-side model identifier before deleting; it is
        # only needed for the log line.
        model_name = model.model_id
        db.delete(model)
        self._commit(db)
        logger.info(f"Deleted LLM model: {model_name} (ID: {model_id})")
        return True

    def set_model_state(self, db: Session, model_id: str, activate: Optional[bool] = None) -> LLMModel:
        """Set model enabled state.

        Args:
            db: Database session.
            model_id: Primary-key ID of the model row to update.
            activate: If provided, sets enabled to this value. If None, inverts current state (legacy behavior).

        Returns:
            Updated LLMModel instance.

        Raises:
            LLMModelNotFoundError: If model not found.
        """
        model = self.get_model(db, model_id)
        # None means "toggle" (legacy behavior kept for backward compatibility).
        model.enabled = (not model.enabled) if activate is None else activate
        self._commit(db)
        db.refresh(model)
        logger.info(f"Set LLM model state: {model.model_id} enabled={model.enabled}")
        return model

    # ---------------------------------------------------------------------------
    # Gateway Models (for LLM Chat dropdown)
    # ---------------------------------------------------------------------------
    def get_gateway_models(self, db: Session) -> List[GatewayModelInfo]:
        """Get enabled models for the LLM Chat dropdown.

        A model is included only when it is enabled, supports chat, and
        belongs to an enabled provider. Results are ordered by provider
        name and then model name.

        Args:
            db: Database session.

        Returns:
            List of GatewayModelInfo for enabled models.
        """
        query = (
            select(LLMModel, LLMProvider)
            .join(LLMProvider, LLMModel.provider_id == LLMProvider.id)
            .where(
                and_(
                    LLMModel.enabled.is_(True),
                    LLMProvider.enabled.is_(True),
                    LLMModel.supports_chat.is_(True),
                )
            )
            .order_by(LLMProvider.name, LLMModel.model_name)
        )
        return [
            GatewayModelInfo(
                id=model.id,
                model_id=model.model_id,
                model_name=model.model_name,
                provider_id=provider.id,
                provider_name=provider.name,
                provider_type=provider.provider_type,
                supports_streaming=model.supports_streaming,
                supports_function_calling=model.supports_function_calling,
                supports_vision=model.supports_vision,
            )
            for model, provider in db.execute(query).all()
        ]

    # ---------------------------------------------------------------------------
    # Health Check Operations
    # ---------------------------------------------------------------------------
    async def _probe_provider(
        self,
        client,
        provider: LLMProvider,
        api_key: Optional[str],
    ) -> Tuple[HealthStatus, Optional[str]]:
        """Send one GET request to the provider and classify the answer.

        Args:
            client: Async HTTP client returned by ``get_http_client``
                (presumably an ``httpx.AsyncClient`` - confirm against that helper).
            provider: Provider to probe.
            api_key: Decoded API key, or None when the provider has none.

        Returns:
            Tuple of (health status, error message or None).
        """
        if provider.provider_type == LLMProviderType.OPENAI:
            # OpenAI: list models. Strip a trailing slash so the joined URL
            # has no "//", and only send the bearer header when a key exists.
            base_url = (provider.api_base or "https://api.openai.com/v1").rstrip("/")
            response = await client.get(f"{base_url}/models", headers=self._auth_headers(api_key), timeout=10.0)
            healthy = response.status_code == 200
        elif provider.provider_type == LLMProviderType.OLLAMA:
            base_url = (provider.api_base or "http://localhost:11434").rstrip("/")
            # A base URL ending in /v1 is the OpenAI-compatible endpoint;
            # otherwise use the native Ollama tags endpoint.
            path = "/models" if base_url.endswith("/v1") else "/api/tags"
            response = await client.get(f"{base_url}{path}", timeout=10.0)
            healthy = response.status_code == 200
        else:
            # Generic check - just verify connectivity
            if not provider.api_base:
                return HealthStatus.UNKNOWN, "No API base URL configured"
            response = await client.get(provider.api_base, timeout=5.0)
            # Any non-5xx answer shows that the endpoint is reachable.
            healthy = response.status_code < 500

        if healthy:
            return HealthStatus.HEALTHY, None
        return HealthStatus.UNHEALTHY, f"HTTP {response.status_code}"

    async def check_provider_health(
        self,
        db: Session,
        provider_id: str,
    ) -> ProviderHealthCheck:
        """Check health of an LLM provider and persist the outcome.

        The provider's ``health_status`` and ``last_health_check`` columns
        are updated and committed. Errors raised while probing are not
        propagated; they are reported as an UNHEALTHY result.

        Args:
            db: Database session.
            provider_id: Provider ID to check.

        Returns:
            ProviderHealthCheck result.

        Raises:
            LLMProviderNotFoundError: If provider not found.
        """
        provider = self.get_provider(db, provider_id)
        start_time = datetime.now(timezone.utc)
        try:
            # Get API key
            api_key = None
            if provider.api_key:
                api_key = decode_auth(provider.api_key).get("api_key")

            # Perform health check based on provider type using shared HTTP client
            # First-Party
            from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

            client = await get_http_client()
            status, error_msg = await self._probe_provider(client, provider, api_key)
        except httpx.TimeoutException:
            status = HealthStatus.UNHEALTHY
            error_msg = "Connection timeout"
        except httpx.RequestError as e:
            status = HealthStatus.UNHEALTHY
            error_msg = f"Connection error: {str(e)}"
        except Exception as e:
            # Deliberately broad: a health check reports failures instead
            # of raising them.
            status = HealthStatus.UNHEALTHY
            error_msg = f"Error: {str(e)}"

        end_time = datetime.now(timezone.utc)
        response_time_ms = (end_time - start_time).total_seconds() * 1000

        # Update provider health status
        provider.health_status = status.value
        provider.last_health_check = end_time
        self._commit(db)

        return ProviderHealthCheck(
            provider_id=provider.id,
            provider_name=provider.name,
            provider_type=provider.provider_type,
            status=status,
            response_time_ms=response_time_ms,
            error=error_msg,
            checked_at=end_time,
        )

    def to_provider_response(
        self,
        provider: LLMProvider,
        model_count: int = 0,
    ) -> LLMProviderResponse:
        """Convert LLMProvider to LLMProviderResponse.

        The stored ``api_key`` is not copied into the response.

        Args:
            provider: LLMProvider instance.
            model_count: Number of models for this provider.

        Returns:
            LLMProviderResponse instance.
        """
        return LLMProviderResponse(
            id=provider.id,
            name=provider.name,
            slug=provider.slug,
            description=provider.description,
            provider_type=provider.provider_type,
            api_base=provider.api_base,
            api_version=provider.api_version,
            config=provider.config,
            default_model=provider.default_model,
            default_temperature=provider.default_temperature,
            default_max_tokens=provider.default_max_tokens,
            enabled=provider.enabled,
            health_status=provider.health_status,
            last_health_check=provider.last_health_check,
            plugin_ids=provider.plugin_ids,
            created_at=provider.created_at,
            updated_at=provider.updated_at,
            created_by=provider.created_by,
            modified_by=provider.modified_by,
            model_count=model_count,
        )

    def to_model_response(
        self,
        model: LLMModel,
        provider: Optional[LLMProvider] = None,
    ) -> LLMModelResponse:
        """Convert LLMModel to LLMModelResponse.

        Args:
            model: LLMModel instance.
            provider: Optional provider for name/type info; when omitted,
                ``provider_name`` and ``provider_type`` are None.

        Returns:
            LLMModelResponse instance.
        """
        return LLMModelResponse(
            id=model.id,
            provider_id=model.provider_id,
            model_id=model.model_id,
            model_name=model.model_name,
            model_alias=model.model_alias,
            description=model.description,
            supports_chat=model.supports_chat,
            supports_streaming=model.supports_streaming,
            supports_function_calling=model.supports_function_calling,
            supports_vision=model.supports_vision,
            context_window=model.context_window,
            max_output_tokens=model.max_output_tokens,
            enabled=model.enabled,
            deprecated=model.deprecated,
            created_at=model.created_at,
            updated_at=model.updated_at,
            provider_name=provider.name if provider else None,
            provider_type=provider.provider_type if provider else None,
        )