import json
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from typing import Any, Literal, TypeVar, overload
import httpx
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, RateLimitError
from openai.types.chat import ChatCompletionContentPartTextParam
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.shared.chat_model import ChatModel
from openai.types.shared_params.reasoning_effort import ReasoningEffort
from openai.types.shared_params.response_format_json_schema import JSONSchema, ResponseFormatJSONSchema
from pydantic import BaseModel
from browser_use.llm.base import BaseChatModel
from browser_use.llm.exceptions import ModelProviderError
from browser_use.llm.messages import BaseMessage
from browser_use.llm.openai.serializer import OpenAIMessageSerializer
from browser_use.llm.schema import SchemaOptimizer
from browser_use.llm.views import ChatInvokeCompletion, ChatInvokeUsage
T = TypeVar('T', bound=BaseModel)
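# Models that accept a `reasoning_effort` parameter and reject sampling params
# such as temperature; ainvoke() matches them by substring against `self.model`.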
ReasoningModels: list[ChatModel | str] = [
'o4-mini',
'o3',
'o3-mini',
'o1',
'o1-pro',
'o3-pro',
'gpt-5',
'gpt-5-mini',
'gpt-5-nano',
]
@dataclass
class ChatOpenAI(BaseChatModel):
"""
A wrapper around AsyncOpenAI that implements the BaseLLM protocol.
This class accepts all AsyncOpenAI parameters while adding model
and temperature parameters for the LLM interface (if temperature it not `None`).
"""
# Model configuration
model: ChatModel | str
# Model params
	# kept low because browser-use aims to be more reliable and deterministic
	temperature: float | None = 0.2
frequency_penalty: float | None = 0.1
reasoning_effort: ReasoningEffort = 'low'
seed: int | None = None
service_tier: Literal['auto', 'default', 'flex', 'priority', 'scale'] | None = None
top_p: float | None = None
add_schema_to_system_prompt: bool = False # Add JSON schema to system prompt instead of using response_format
# Client initialization parameters
api_key: str | None = None
organization: str | None = None
project: str | None = None
base_url: str | httpx.URL | None = None
websocket_base_url: str | httpx.URL | None = None
timeout: float | httpx.Timeout | None = None
max_retries: int = 10 # Increase default retries for automation reliability
default_headers: Mapping[str, str] | None = None
default_query: Mapping[str, object] | None = None
http_client: httpx.AsyncClient | None = None
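	# Forwarded to AsyncOpenAI, which accepts this underscore-prefixed kwarg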
_strict_response_validation: bool = False
max_completion_tokens: int | None = 8000
# Static
@property
def provider(self) -> str:
return 'openai'
def _get_client_params(self) -> dict[str, Any]:
"""Prepare client parameters dictionary."""
# Define base client params
base_params = {
'api_key': self.api_key,
'organization': self.organization,
'project': self.project,
'base_url': self.base_url,
'websocket_base_url': self.websocket_base_url,
'timeout': self.timeout,
'max_retries': self.max_retries,
'default_headers': self.default_headers,
'default_query': self.default_query,
'_strict_response_validation': self._strict_response_validation,
}
# Create client_params dict with non-None values
client_params = {k: v for k, v in base_params.items() if v is not None}
# Add http_client if provided
if self.http_client is not None:
client_params['http_client'] = self.http_client
return client_params
def get_client(self) -> AsyncOpenAI:
"""
Returns an AsyncOpenAI client.
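
		Note: a fresh client is constructed on every call; pass `http_client`
		to share a connection pool across invocations.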
Returns:
AsyncOpenAI: An instance of the AsyncOpenAI client.
"""
client_params = self._get_client_params()
return AsyncOpenAI(**client_params)
@property
def name(self) -> str:
return str(self.model)
	def _get_usage(self, response: ChatCompletion) -> ChatInvokeUsage | None:
		if response.usage is None:
			return None

		completion_tokens = response.usage.completion_tokens
		completion_token_details = response.usage.completion_tokens_details
		if completion_token_details is not None:
			reasoning_tokens = completion_token_details.reasoning_tokens
			if reasoning_tokens is not None:
				# Count reasoning tokens toward the completion total so usage
				# reflects what reasoning models actually generate
				completion_tokens += reasoning_tokens

		return ChatInvokeUsage(
			prompt_tokens=response.usage.prompt_tokens,
			prompt_cached_tokens=response.usage.prompt_tokens_details.cached_tokens
			if response.usage.prompt_tokens_details is not None
			else None,
			prompt_cache_creation_tokens=None,  # not reported by the OpenAI API
			prompt_image_tokens=None,  # not reported by the OpenAI API
			# Completion
			completion_tokens=completion_tokens,
			total_tokens=response.usage.total_tokens,
		)
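	# The overloads below give callers precise return types: ChatInvokeCompletion[str]
	# without an output_format, ChatInvokeCompletion[T] with a Pydantic model.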
@overload
async def ainvoke(self, messages: list[BaseMessage], output_format: None = None) -> ChatInvokeCompletion[str]: ...
@overload
async def ainvoke(self, messages: list[BaseMessage], output_format: type[T]) -> ChatInvokeCompletion[T]: ...
async def ainvoke(
self, messages: list[BaseMessage], output_format: type[T] | None = None
) -> ChatInvokeCompletion[T] | ChatInvokeCompletion[str]:
"""
Invoke the model with the given messages.
Args:
messages: List of chat messages
output_format: Optional Pydantic model class for structured output
Returns:
Either a string response or an instance of output_format
"""
openai_messages = OpenAIMessageSerializer.serialize_messages(messages)
try:
model_params: dict[str, Any] = {}
if self.temperature is not None:
model_params['temperature'] = self.temperature
if self.frequency_penalty is not None:
model_params['frequency_penalty'] = self.frequency_penalty
if self.max_completion_tokens is not None:
model_params['max_completion_tokens'] = self.max_completion_tokens
if self.top_p is not None:
model_params['top_p'] = self.top_p
if self.seed is not None:
model_params['seed'] = self.seed
if self.service_tier is not None:
model_params['service_tier'] = self.service_tier
			# Reasoning models take `reasoning_effort` but reject sampling params.
			# Matching is by substring, so e.g. 'o3' also matches 'o3-2025-04-16'.
			if any(str(m).lower() in str(self.model).lower() for m in ReasoningModels):
				model_params['reasoning_effort'] = self.reasoning_effort
				# pop() with a default avoids a KeyError when the attribute was
				# None and the key was therefore never added above
				model_params.pop('temperature', None)
				model_params.pop('frequency_penalty', None)
if output_format is None:
# Return string response
response = await self.get_client().chat.completions.create(
model=self.model,
messages=openai_messages,
**model_params,
)
usage = self._get_usage(response)
return ChatInvokeCompletion(
completion=response.choices[0].message.content or '',
usage=usage,
)
else:
response_format: JSONSchema = {
'name': 'agent_output',
'strict': True,
'schema': SchemaOptimizer.create_optimized_json_schema(output_format),
}
				# Add JSON schema to system prompt if requested
				if self.add_schema_to_system_prompt and openai_messages and openai_messages[0]['role'] == 'system':
					# Serialize to actual JSON rather than the dict's Python repr
					schema_text = f'\n<json_schema>\n{json.dumps(response_format)}\n</json_schema>'
					if isinstance(openai_messages[0]['content'], str):
						openai_messages[0]['content'] += schema_text
					elif isinstance(openai_messages[0]['content'], Iterable):
						openai_messages[0]['content'] = list(openai_messages[0]['content']) + [
							ChatCompletionContentPartTextParam(text=schema_text, type='text')
						]
# Return structured response
response = await self.get_client().chat.completions.create(
model=self.model,
messages=openai_messages,
response_format=ResponseFormatJSONSchema(json_schema=response_format, type='json_schema'),
**model_params,
)
if response.choices[0].message.content is None:
raise ModelProviderError(
message='Failed to parse structured output from model response',
status_code=500,
model=self.name,
)
usage = self._get_usage(response)
parsed = output_format.model_validate_json(response.choices[0].message.content)
return ChatInvokeCompletion(
completion=parsed,
usage=usage,
)
		except RateLimitError as e:
			try:
				error_message = e.response.json().get('error', {})
			except Exception:
				error_message = e.response.text
			error_message = (
				error_message.get('message', 'Unknown model error') if isinstance(error_message, dict) else error_message
			)
raise ModelProviderError(
message=error_message,
status_code=e.response.status_code,
model=self.name,
) from e
except APIConnectionError as e:
raise ModelProviderError(message=str(e), model=self.name) from e
except APIStatusError as e:
try:
error_message = e.response.json().get('error', {})
except Exception:
error_message = e.response.text
error_message = (
error_message.get('message', 'Unknown model error') if isinstance(error_message, dict) else error_message
)
raise ModelProviderError(
message=error_message,
status_code=e.response.status_code,
model=self.name,
) from e
except Exception as e:
raise ModelProviderError(message=str(e), model=self.name) from e
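
# Illustrative usage (a sketch, not part of this module's API surface;
# `AgentOutput` is a hypothetical Pydantic model and `messages` an assumed
# list of browser_use BaseMessage instances):
#
#	llm = ChatOpenAI(model='gpt-4.1-mini')
#	text_result = await llm.ainvoke(messages)  # ChatInvokeCompletion[str]
#	parsed_result = await llm.ainvoke(messages, output_format=AgentOutput)  # ChatInvokeCompletion[AgentOutput]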