Skip to main content
Glama
pydantic

mcp-run-python

Official
by pydantic
cohere.py — 8.92 kB
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Literal, cast

from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UnexpectedModelBehavior
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai.usage import RequestUsage

from .base import EmbeddingModel, EmbedInputType
from .result import EmbeddingResult
from .settings import EmbeddingSettings

try:
    from cohere import AsyncClientV2
    from cohere.core.api_error import ApiError
    from cohere.core.request_options import RequestOptions
    from cohere.types.embed_by_type_response import EmbedByTypeResponse
    from cohere.types.embed_input_type import EmbedInputType as CohereEmbedInputType
    from cohere.v2.types.v2embed_request_truncate import V2EmbedRequestTruncate

    from pydantic_ai.providers.cohere import CohereProvider
except ImportError as _import_error:
    raise ImportError(
        'Please install `cohere` to use the Cohere embeddings model, '
        'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`'
    ) from _import_error

LatestCohereEmbeddingModelNames = Literal[
    'embed-v4.0',
    'embed-english-v3.0',
    'embed-english-light-v3.0',
    'embed-multilingual-v3.0',
    'embed-multilingual-light-v3.0',
]
"""Latest Cohere embeddings models.

See the [Cohere Embed documentation](https://docs.cohere.com/docs/cohere-embed) for available models
and their capabilities.
"""

CohereEmbeddingModelName = str | LatestCohereEmbeddingModelNames
"""Possible Cohere embeddings model names."""

# Taken from https://docs.cohere.com/docs/cohere-embed
_MAX_INPUT_TOKENS: dict[CohereEmbeddingModelName, int] = {
    'embed-v4.0': 128000,
    'embed-english-v3.0': 512,
    'embed-english-light-v3.0': 512,
    'embed-multilingual-v3.0': 512,
    'embed-multilingual-light-v3.0': 512,
}


class CohereEmbeddingSettings(EmbeddingSettings, total=False):
    """Settings used for a Cohere embedding model request.

    All fields from [`EmbeddingSettings`][pydantic_ai.embeddings.EmbeddingSettings] are supported,
    plus Cohere-specific settings prefixed with `cohere_`.
    """

    # ALL FIELDS MUST BE `cohere_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.

    cohere_max_tokens: int
    """The maximum number of tokens to embed."""

    cohere_input_type: CohereEmbedInputType
    """The Cohere-specific input type for the embedding. Overrides the standard `input_type` argument.

    Options include: `'search_query'`, `'search_document'`, `'classification'`, `'clustering'`, and `'image'`.
    """

    cohere_truncate: V2EmbedRequestTruncate
    """The truncation strategy to use:

    - `'NONE'` (default): Raise an error if input exceeds max tokens.
    - `'END'`: Truncate the end of the input text.
    - `'START'`: Truncate the start of the input text.
    """


@dataclass(init=False)
class CohereEmbeddingModel(EmbeddingModel):
    """Cohere embedding model implementation.

    This model works with Cohere's embeddings API, which offers multilingual support
    and various model sizes.

    Example:
    ```python
    from pydantic_ai.embeddings.cohere import CohereEmbeddingModel

    model = CohereEmbeddingModel('embed-v4.0')
    ```
    """

    _model_name: CohereEmbeddingModelName = field(repr=False)
    _provider: Provider[AsyncClientV2] = field(repr=False)

    def __init__(
        self,
        model_name: CohereEmbeddingModelName,
        *,
        provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
        settings: EmbeddingSettings | None = None,
    ):
        """Initialize a Cohere embedding model.

        Args:
            model_name: The name of the Cohere model to use. See
                [Cohere Embed documentation](https://docs.cohere.com/docs/cohere-embed) for available models.
            provider: The provider to use for authentication and API access. Can be:

                - `'cohere'` (default): Uses the standard Cohere API
                - A [`CohereProvider`][pydantic_ai.providers.cohere.CohereProvider] instance for custom configuration
            settings: Model-specific [`EmbeddingSettings`][pydantic_ai.embeddings.EmbeddingSettings] to use as
                defaults for this model.
        """
        self._model_name = model_name
        if isinstance(provider, str):
            provider = infer_provider(provider)
        self._provider = provider
        self._client = provider.client
        # The v1 client is only used by `count_tokens`; only a `CohereProvider` exposes one.
        self._v1_client = provider.v1_client if isinstance(provider, CohereProvider) else None
        super().__init__(settings=settings)

    @property
    def base_url(self) -> str:
        """The base URL for the provider API, if available."""
        return self._provider.base_url

    @property
    def model_name(self) -> CohereEmbeddingModelName:
        """The embedding model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The embedding model provider."""
        return self._provider.name

    async def embed(
        self, inputs: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        """Embed `inputs` with the Cohere v2 embeddings API and return an `EmbeddingResult`.

        `input_type='document'` maps to Cohere's `'search_document'` and anything else to
        `'search_query'`, unless overridden via the `cohere_input_type` setting. Raises
        `ModelHTTPError` for HTTP-level API errors (status >= 400), `ModelAPIError` for other
        API failures, and `UnexpectedModelBehavior` if the response carries no float embeddings.
        """
        inputs, settings = self.prepare_embed(inputs, settings)
        settings = cast(CohereEmbeddingSettings, settings)

        cohere_input_type = settings.get(
            'cohere_input_type', 'search_document' if input_type == 'document' else 'search_query'
        )

        request_options: RequestOptions = {}
        if extra_headers := settings.get('extra_headers'):  # pragma: no cover
            request_options['additional_headers'] = extra_headers
        if extra_body := settings.get('extra_body'):  # pragma: no cover
            request_options['additional_body_parameters'] = cast(dict[str, Any], extra_body)

        try:
            response = await self._client.embed(
                model=self.model_name,
                texts=inputs,
                output_dimension=settings.get('dimensions'),
                input_type=cohere_input_type,
                max_tokens=settings.get('cohere_max_tokens'),
                truncate=settings.get('cohere_truncate', 'NONE'),
                request_options=request_options,
            )
        except ApiError as e:
            if (status_code := e.status_code) and status_code >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise ModelAPIError(model_name=self.model_name, message=str(e)) from e  # pragma: no cover

        embeddings = response.embeddings.float_
        if embeddings is None:
            raise UnexpectedModelBehavior(  # pragma: no cover
                'The Cohere embeddings response did not have an `embeddings` field holding a list of floats',
                str(response),
            )

        return EmbeddingResult(
            embeddings=embeddings,
            inputs=inputs,
            input_type=input_type,
            usage=_map_usage(response, self.system, self.base_url, self.model_name),
            model_name=self.model_name,
            provider_name=self.system,
            provider_response_id=response.id,
        )

    async def max_input_tokens(self) -> int | None:
        """Return the max input tokens for known models, or `None` for unknown model names."""
        return _MAX_INPUT_TOKENS.get(self.model_name)

    async def count_tokens(self, text: str) -> int:
        """Count tokens in `text` using the Cohere v1 `tokenize` endpoint.

        Raises `NotImplementedError` when the provider did not supply a v1 client, and
        `ModelHTTPError`/`ModelAPIError` on API failures.
        """
        if self._v1_client is None:
            raise NotImplementedError('Counting tokens requires the Cohere v1 client')
        try:
            result = await self._v1_client.tokenize(
                model=self.model_name,
                text=text,  # Has a max length of 65536 characters
                offline=False,
            )
        except ApiError as e:  # pragma: no cover
            if (status_code := e.status_code) and status_code >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise ModelAPIError(model_name=self.model_name, message=str(e)) from e
        return len(result.tokens)


def _map_usage(response: EmbedByTypeResponse, provider: str, provider_url: str, model: str) -> RequestUsage:
    """Map the billed-units metadata of a Cohere embed response to a `RequestUsage`."""
    u = response.meta
    if u is None or u.billed_units is None:
        return RequestUsage()  # pragma: no cover
    usage_data = {
        k: int(v)
        for k, v in u.billed_units.model_dump(exclude_none=True).items()
        if isinstance(v, int | float) and v > 0
    }
    # `usage_data` values are already positive ints, so only the key filter is needed here.
    details = {k: v for k, v in usage_data.items() if k != 'input_tokens'}
    response_data = dict(model=model, meta=dict(billed_units=usage_data))
    return RequestUsage.extract(
        response_data,
        provider=provider,
        provider_url=provider_url,
        provider_fallback='cohere',
        api_flavor='embeddings',
        details=details,
    )

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/pydantic/pydantic-ai'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.