# azure_openai.py • 14.5 kB
"""Azure OpenAI provider built on the OpenAI-compatible implementation."""
from __future__ import annotations

import logging
from dataclasses import asdict, replace

# The ``openai`` package is an optional dependency: when it is missing the
# name is bound to ``None`` and the failure is reported lazily, at the point
# where a client is actually needed.
try:  # pragma: no cover - optional dependency
    from openai import AzureOpenAI
except ImportError:  # pragma: no cover
    AzureOpenAI = None  # type: ignore[assignment]

from utils.env import get_env, suppress_env_vars

from .openai import OpenAIModelProvider
from .openai_compatible import OpenAICompatibleProvider
from .registries.azure import AzureModelRegistry
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint

# Module-level logger shared by everything defined in this file.
logger = logging.getLogger(__name__)
class AzureOpenAIProvider(OpenAICompatibleProvider):
    """Thin Azure wrapper that reuses the OpenAI-compatible request pipeline.

    Each configured model has a *canonical* name (the key used in the registry
    or in the ``deployments`` mapping) and an Azure *deployment* name. Requests
    are sent with the deployment name, while responses are reported under the
    canonical name. Lookups by either name are case-insensitive.
    """

    FRIENDLY_NAME = "Azure OpenAI"
    DEFAULT_API_VERSION = "2024-02-15-preview"

    # The OpenAI-compatible base expects subclasses to expose capabilities via
    # ``get_all_model_capabilities``. Azure deployments are user-defined, so we
    # build the catalogue dynamically from environment configuration instead of
    # relying on a static ``MODEL_CAPABILITIES`` map.
    MODEL_CAPABILITIES: dict[str, ModelCapabilities] = {}

    def __init__(
        self,
        api_key: str,
        *,
        azure_endpoint: str | None = None,
        api_version: str | None = None,
        deployments: dict[str, object] | None = None,
        **kwargs,
    ) -> None:
        """Configure the provider and build its deployment catalogue.

        Args:
            api_key: Azure OpenAI API key, forwarded to the base provider.
            azure_endpoint: Endpoint URL. Falls back to the
                ``AZURE_OPENAI_ENDPOINT`` environment setting when omitted.
            api_version: API version string. Falls back to
                ``AZURE_OPENAI_API_VERSION`` and then ``DEFAULT_API_VERSION``.
            deployments: Optional mapping of canonical model name to either a
                deployment name or a dict containing ``deployment`` (or
                ``deployment_name``) plus capability overrides. Entries are
                merged on top of the registry entries.
            **kwargs: Passed through unchanged to the base provider.

        Raises:
            ValueError: If no endpoint can be determined, or if no model ends
                up with a deployment after merging registry and overrides.
        """
        # Resolve the endpoint *before* initialising the base class so that
        # ``base_url`` carries the same value whether the endpoint was passed
        # explicitly or came from the environment.
        endpoint = azure_endpoint or get_env("AZURE_OPENAI_ENDPOINT")
        if not endpoint:
            raise ValueError("Azure OpenAI endpoint is required via parameter or AZURE_OPENAI_ENDPOINT")

        # Let the OpenAI-compatible base handle shared configuration such as
        # timeouts, restriction-aware allowlists, and logging. ``base_url`` maps
        # directly onto Azure's endpoint URL.
        super().__init__(api_key, base_url=endpoint, **kwargs)

        self.azure_endpoint = endpoint.rstrip("/")
        # The trailing ``or`` guards against an environment value that is set
        # but empty, which would otherwise yield an unusable API version.
        self.api_version = (
            api_version
            or get_env("AZURE_OPENAI_API_VERSION", self.DEFAULT_API_VERSION)
            or self.DEFAULT_API_VERSION
        )

        registry_specs = self._load_registry_entries()
        override_specs = self._normalise_deployments(deployments) if deployments else {}
        self._model_specs = self._merge_specs(registry_specs, override_specs)
        if not self._model_specs:
            raise ValueError(
                "Azure OpenAI provider requires at least one configured deployment. "
                "Populate conf/azure_models.json or set AZURE_MODELS_CONFIG_PATH."
            )

        self._capabilities = self._build_capabilities_map()
        # canonical name -> deployment name
        self._deployment_map = {name: spec["deployment"] for name, spec in self._model_specs.items()}
        # lower-cased deployment name -> canonical name
        self._deployment_alias_lookup = {
            deployment.lower(): canonical for canonical, deployment in self._deployment_map.items()
        }
        # lower-cased canonical name -> canonical name
        self._canonical_lookup = {name.lower(): name for name in self._model_specs}
        self._invalidate_capability_cache()

    # ------------------------------------------------------------------
    # Capability helpers
    # ------------------------------------------------------------------
    def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
        """Return a shallow copy of the canonical-name -> capabilities map."""
        return dict(self._capabilities)

    def get_provider_type(self) -> ProviderType:
        """Return ``ProviderType.AZURE``."""
        return ProviderType.AZURE

    def get_capabilities(self, model_name: str) -> ModelCapabilities:  # type: ignore[override]
        """Look up capabilities by deployment name or canonical name.

        The name is matched case-insensitively, first against deployment
        names and then against canonical names. Unmatched names are handed
        to the base implementation unchanged.
        """
        lowered = model_name.lower()
        if lowered in self._deployment_alias_lookup:
            return super().get_capabilities(self._deployment_alias_lookup[lowered])
        canonical = self._canonical_lookup.get(lowered)
        if canonical:
            return super().get_capabilities(canonical)
        return super().get_capabilities(model_name)

    def validate_model_name(self, model_name: str) -> bool:  # type: ignore[override]
        """Return True for any known deployment or canonical name.

        Names that match neither lookup are validated by the base class.
        """
        lowered = model_name.lower()
        if lowered in self._deployment_alias_lookup or lowered in self._canonical_lookup:
            return True
        return super().validate_model_name(model_name)

    def _build_capabilities_map(self) -> dict[str, ModelCapabilities]:
        """Build one ``ModelCapabilities`` per entry in ``self._model_specs``.

        For each model the starting point is, in order of preference: the
        capability stored in the spec, the matching entry in
        ``OpenAIModelProvider.MODEL_CAPABILITIES`` (re-labelled for Azure), or
        a minimal freshly constructed capability. Spec overrides are applied
        on top, and the provider is always forced to ``ProviderType.AZURE``.
        """
        capabilities: dict[str, ModelCapabilities] = {}

        for canonical_name, spec in self._model_specs.items():
            template_capability: ModelCapabilities | None = spec.get("capability")

            if template_capability:
                # Copy so later mutation does not touch the spec's object.
                cloned = replace(template_capability)
            else:
                template = OpenAIModelProvider.MODEL_CAPABILITIES.get(canonical_name)
                if template:
                    friendly = template.friendly_name.replace("OpenAI", "Azure OpenAI", 1)
                    cloned = replace(
                        template,
                        provider=ProviderType.AZURE,
                        friendly_name=friendly,
                        aliases=list(template.aliases),
                    )
                else:
                    deployment_name = spec.get("deployment", "")
                    cloned = ModelCapabilities(
                        provider=ProviderType.AZURE,
                        model_name=canonical_name,
                        friendly_name=f"Azure OpenAI ({canonical_name})",
                        description=f"Azure deployment '{deployment_name}' for {canonical_name}",
                        aliases=[],
                    )

            # Work on a private copy so the spec's own dict is never mutated.
            overrides = dict(spec.get("overrides") or {})
            if overrides:
                # String forms are accepted for convenience and converted to
                # the types the dataclass expects.
                temp_override = overrides.get("temperature_constraint")
                if isinstance(temp_override, str):
                    overrides["temperature_constraint"] = TemperatureConstraint.create(temp_override)

                aliases_override = overrides.get("aliases")
                if isinstance(aliases_override, str):
                    overrides["aliases"] = [alias.strip() for alias in aliases_override.split(",") if alias.strip()]

                # The provider is not user-configurable; it is pinned to AZURE
                # below regardless of what the override says.
                overrides.pop("provider", None)

                try:
                    cloned = replace(cloned, **overrides)
                except TypeError:
                    # replace() rejected the overrides; retry by rebuilding the
                    # capability from a plain dict of its fields.
                    base_data = asdict(cloned)
                    base_data.update(overrides)
                    base_data["provider"] = ProviderType.AZURE
                    temp_value = base_data.get("temperature_constraint")
                    if isinstance(temp_value, str):
                        base_data["temperature_constraint"] = TemperatureConstraint.create(temp_value)
                    cloned = ModelCapabilities(**base_data)

            if cloned.provider != ProviderType.AZURE:
                cloned.provider = ProviderType.AZURE

            capabilities[canonical_name] = cloned

        return capabilities

    def _load_registry_entries(self) -> dict[str, dict]:
        """Read model entries from ``AzureModelRegistry``.

        Returns:
            Mapping of model name to ``{"deployment": ..., "capability": ...}``.
            Entries without a deployment are skipped with a warning. An empty
            dict is returned if the registry cannot be constructed.
        """
        try:
            registry = AzureModelRegistry()
        except Exception as exc:  # pragma: no cover - registry failure should not crash provider
            logger.warning("Unable to load Azure model registry: %s", exc)
            return {}

        entries: dict[str, dict] = {}
        for model_name, capability, extra in registry.iter_entries():
            deployment = extra.get("deployment")
            if not deployment:
                logger.warning("Azure model '%s' missing deployment in registry", model_name)
                continue
            entries[model_name] = {"deployment": deployment, "capability": capability}
        return entries

    @staticmethod
    def _merge_specs(
        registry_specs: dict[str, dict],
        override_specs: dict[str, dict],
    ) -> dict[str, dict]:
        """Overlay ``override_specs`` on ``registry_specs``.

        Every resulting spec has the keys ``deployment``, ``capability`` and
        ``overrides``. An override entry may replace the deployment; all of
        its other keys except ``capability`` are collected under
        ``overrides``. Specs that end up without a deployment are dropped.
        """
        specs: dict[str, dict] = {
            canonical: {
                "deployment": entry.get("deployment"),
                "capability": entry.get("capability"),
                "overrides": {},
            }
            for canonical, entry in registry_specs.items()
        }

        for canonical, entry in override_specs.items():
            spec = specs.setdefault(canonical, {"deployment": None, "capability": None, "overrides": {}})
            deployment = entry.get("deployment")
            if deployment:
                spec["deployment"] = deployment
            # ``capability`` can only come from the registry, never from an override.
            spec["overrides"].update({k: v for k, v in entry.items() if k not in {"deployment", "capability"}})

        return {name: spec for name, spec in specs.items() if spec.get("deployment")}

    @staticmethod
    def _normalise_deployments(mapping: dict[str, object]) -> dict[str, dict]:
        """Convert a user-supplied deployments mapping into a uniform shape.

        Args:
            mapping: Canonical model name -> deployment name (``str``) or a
                dict holding ``deployment`` / ``deployment_name`` plus
                arbitrary override keys.

        Returns:
            Canonical name -> ``{"deployment": <stripped name>, **overrides}``.
            Entries with a blank canonical name, or without a non-blank
            string deployment name, are omitted.
        """
        normalised: dict[str, dict] = {}

        for canonical, spec in mapping.items():
            canonical_name = (canonical or "").strip()
            if not canonical_name:
                continue

            overrides: dict[str, object] = {}
            if isinstance(spec, str):
                candidates: tuple[object, ...] = (spec,)
            elif isinstance(spec, dict):
                candidates = (spec.get("deployment"), spec.get("deployment_name"))
                overrides = {k: v for k, v in spec.items() if k not in {"deployment", "deployment_name"}}
            else:
                candidates = ()

            # Strip before testing for emptiness so a whitespace-only value is
            # treated as missing, and ignore non-string values instead of
            # failing on ``.strip()``.
            deployment_name = next(
                (value.strip() for value in candidates if isinstance(value, str) and value.strip()),
                None,
            )
            if not deployment_name:
                continue

            normalised[canonical_name] = {"deployment": deployment_name, **overrides}

        return normalised

    # ------------------------------------------------------------------
    # Azure-specific configuration
    # ------------------------------------------------------------------
    @property
    def client(self):  # type: ignore[override]
        """Instantiate the Azure OpenAI client on first use.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            Exception: Any error raised while building the client is logged
                and re-raised unchanged.
        """
        if self._client is not None:
            return self._client

        if AzureOpenAI is None:
            raise ImportError(
                "Azure OpenAI support requires the 'openai' package. Install it with `pip install openai`."
            )

        import httpx

        proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]

        # NOTE(review): presumably the proxy variables are hidden so the HTTP
        # client does not pick up proxy settings - confirm against
        # ``suppress_env_vars``.
        with suppress_env_vars(*proxy_env_vars):
            http_client = None
            try:
                timeout_config = self.timeout_config
                http_client = httpx.Client(timeout=timeout_config, follow_redirects=True)

                client_kwargs = {
                    "api_key": self.api_key,
                    "azure_endpoint": self.azure_endpoint,
                    "api_version": self.api_version,
                    "http_client": http_client,
                }
                if self.DEFAULT_HEADERS:
                    client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

                logger.debug(
                    "Initializing Azure OpenAI client endpoint=%s api_version=%s timeouts=%s",
                    self.azure_endpoint,
                    self.api_version,
                    timeout_config,
                )
                self._client = AzureOpenAI(**client_kwargs)
            except Exception as exc:
                logger.error("Failed to create Azure OpenAI client: %s", exc)
                # Do not leak the HTTP client when the SDK client could not
                # be constructed around it.
                if http_client is not None:
                    http_client.close()
                raise

        return self._client

    # ------------------------------------------------------------------
    # Request delegation
    # ------------------------------------------------------------------
    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: str | None = None,
        temperature: float = 0.3,
        max_output_tokens: int | None = None,
        images: list[str] | None = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate a completion through the configured Azure deployment.

        Args:
            prompt: User prompt text.
            model_name: Canonical model name or deployment name.
            system_prompt: Optional system prompt.
            temperature: Sampling temperature.
            max_output_tokens: Optional cap on generated tokens.
            images: Optional list of image references.
            **kwargs: Forwarded unchanged to the base implementation.

        Returns:
            A ``ModelResponse`` carrying the canonical model name, with the
            deployment name added to its metadata under ``"deployment"``.

        Raises:
            ValueError: If ``model_name`` is not configured for Azure OpenAI.
        """
        canonical_name, deployment_name = self._resolve_canonical_and_deployment(model_name)

        # Delegate to the shared OpenAI-compatible implementation using the
        # deployment name — Azure requires the deployment identifier in the
        # ``model`` field. The returned ``ModelResponse`` is normalised so
        # downstream consumers continue to see the canonical model name.
        raw_response = super().generate_content(
            prompt=prompt,
            model_name=deployment_name,
            system_prompt=system_prompt,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            images=images,
            **kwargs,
        )

        capabilities = self._capabilities.get(canonical_name)
        friendly_name = capabilities.friendly_name if capabilities else self.FRIENDLY_NAME

        return ModelResponse(
            content=raw_response.content,
            usage=raw_response.usage,
            model_name=canonical_name,
            friendly_name=friendly_name,
            provider=ProviderType.AZURE,
            # Tolerate a response whose metadata is unset.
            metadata={**(raw_response.metadata or {}), "deployment": deployment_name},
        )

    def _resolve_canonical_and_deployment(self, model_name: str) -> tuple[str, str]:
        """Map ``model_name`` to a ``(canonical name, deployment name)`` pair.

        Raises:
            ValueError: If the resolved name is neither a canonical name nor
                (case-insensitively) a deployment name.
        """
        resolved_canonical = self._resolve_model_name(model_name)

        if resolved_canonical in self._deployment_map:
            return resolved_canonical, self._deployment_map[resolved_canonical]

        # The base resolver may hand back the deployment alias. Try to map it
        # back to a canonical entry; the first matching entry wins.
        wanted = resolved_canonical.lower()
        for canonical, deployment in self._deployment_map.items():
            if deployment.lower() == wanted:
                return canonical, deployment

        raise ValueError(f"Model '{model_name}' is not configured for Azure OpenAI")

    def _parse_allowed_models(self) -> set[str] | None:  # type: ignore[override]
        """Return the allowlist from ``AZURE_OPENAI_ALLOWED_MODELS`` if set.

        The value is split on commas, stripped and lower-cased. When it is
        unset or yields no names, the base implementation decides.
        """
        # Support both AZURE_ALLOWED_MODELS (inherited behaviour) and the
        # clearer AZURE_OPENAI_ALLOWED_MODELS alias.
        explicit = get_env("AZURE_OPENAI_ALLOWED_MODELS")
        if explicit:
            models = {m.strip().lower() for m in explicit.split(",") if m.strip()}
            if models:
                logger.info("Configured allowed models for Azure OpenAI: %s", sorted(models))
                self._allowed_alias_cache = {}
                return models

        return super()._parse_allowed_models()