We provide all the information about MCP servers via our MCP API.
curl -X GET 'https://glama.ai/api/mcp/v1/servers/pydantic/pydantic-ai'
If you have feedback or need assistance with the MCP directory API, please join our Discord server.
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Literal, cast
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UnexpectedModelBehavior
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai.usage import RequestUsage
from .base import EmbeddingModel, EmbedInputType
from .result import EmbeddingResult
from .settings import EmbeddingSettings
try:
from cohere import AsyncClientV2
from cohere.core.api_error import ApiError
from cohere.core.request_options import RequestOptions
from cohere.types.embed_by_type_response import EmbedByTypeResponse
from cohere.types.embed_input_type import EmbedInputType as CohereEmbedInputType
from cohere.v2.types.v2embed_request_truncate import V2EmbedRequestTruncate
from pydantic_ai.providers.cohere import CohereProvider
except ImportError as _import_error:
raise ImportError(
'Please install `cohere` to use the Cohere embeddings model, '
'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`'
) from _import_error
LatestCohereEmbeddingModelNames = Literal[
'embed-v4.0',
'embed-english-v3.0',
'embed-english-light-v3.0',
'embed-multilingual-v3.0',
'embed-multilingual-light-v3.0',
]
"""Latest Cohere embeddings models.
See the [Cohere Embed documentation](https://docs.cohere.com/docs/cohere-embed)
for available models and their capabilities.
"""
CohereEmbeddingModelName = str | LatestCohereEmbeddingModelNames
"""Possible Cohere embeddings model names."""
# Taken from https://docs.cohere.com/docs/cohere-embed
_MAX_INPUT_TOKENS: dict[CohereEmbeddingModelName, int] = {
'embed-v4.0': 128000,
'embed-english-v3.0': 512,
'embed-english-light-v3.0': 512,
'embed-multilingual-v3.0': 512,
'embed-multilingual-light-v3.0': 512,
}
class CohereEmbeddingSettings(EmbeddingSettings, total=False):
    """Per-request settings for a Cohere embedding model.

    Everything defined on [`EmbeddingSettings`][pydantic_ai.embeddings.EmbeddingSettings] is accepted,
    together with the Cohere-only keys below, all of which carry a `cohere_` prefix.
    """

    # ALL FIELDS MUST BE `cohere_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.

    cohere_max_tokens: int
    """Upper bound on the number of tokens to embed; sent to Cohere as `max_tokens`."""

    cohere_input_type: CohereEmbedInputType
    """Cohere's own input type for the embedding request.

    When set, it takes precedence over the standard `input_type` argument. Options include:
    `'search_query'`, `'search_document'`, `'classification'`, `'clustering'`, and `'image'`.
    """

    cohere_truncate: V2EmbedRequestTruncate
    """How input that exceeds the maximum token count is handled:

    - `'NONE'` (default): Raise an error if input exceeds max tokens.
    - `'END'`: Truncate the end of the input text.
    - `'START'`: Truncate the start of the input text.
    """
@dataclass(init=False)
class CohereEmbeddingModel(EmbeddingModel):
    """Embedding model backed by Cohere's embeddings API.

    Cohere's embeddings API offers multilingual support and a range of model sizes.

    Example:
    ```python
    from pydantic_ai.embeddings.cohere import CohereEmbeddingModel

    model = CohereEmbeddingModel('embed-v4.0')
    ```
    """

    _model_name: CohereEmbeddingModelName = field(repr=False)
    _provider: Provider[AsyncClientV2] = field(repr=False)

    def __init__(
        self,
        model_name: CohereEmbeddingModelName,
        *,
        provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
        settings: EmbeddingSettings | None = None,
    ):
        """Create a Cohere embedding model.

        Args:
            model_name: Name of the Cohere model to call.
                The [Cohere Embed documentation](https://docs.cohere.com/docs/cohere-embed)
                lists the available models.
            provider: Where authentication and API access come from. Either:

                - `'cohere'` (default): the standard Cohere API, resolved through `infer_provider`
                - a [`CohereProvider`][pydantic_ai.providers.cohere.CohereProvider] instance
                  for custom configuration
            settings: Model-specific [`EmbeddingSettings`][pydantic_ai.embeddings.EmbeddingSettings]
                used as defaults for this model.
        """
        self._model_name = model_name

        # A string is a provider name that still has to be resolved to a provider object.
        resolved_provider = infer_provider(provider) if isinstance(provider, str) else provider
        self._provider = resolved_provider
        self._client = resolved_provider.client
        # Only `CohereProvider` exposes the v1 client that `count_tokens` relies on.
        self._v1_client = resolved_provider.v1_client if isinstance(resolved_provider, CohereProvider) else None

        super().__init__(settings=settings)

    @property
    def base_url(self) -> str:
        """Base URL of the provider API, as reported by the provider."""
        return self._provider.base_url

    @property
    def model_name(self) -> CohereEmbeddingModelName:
        """Name of the embedding model."""
        return self._model_name

    @property
    def system(self) -> str:
        """Name of the embedding model provider."""
        return self._provider.name

    async def embed(
        self, inputs: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        """Embed `inputs` through the Cohere client.

        Args:
            inputs: A single string or a sequence of strings to embed.
            input_type: Standard input type. `'document'` is sent to Cohere as `'search_document'`
                and anything else as `'search_query'`, unless the `cohere_input_type` setting is present.
            settings: Optional settings for this request; they are passed through `prepare_embed`.

        Returns:
            The float embeddings together with usage and response metadata.

        Raises:
            ModelHTTPError: If the Cohere client raises an `ApiError` whose status code is 400 or higher.
            ModelAPIError: If the Cohere client raises an `ApiError` without such a status code.
            UnexpectedModelBehavior: If the response carries no float embeddings.
        """
        inputs, settings = self.prepare_embed(inputs, settings)
        settings = cast(CohereEmbeddingSettings, settings)

        # The Cohere-specific setting, when given, wins over the mapped standard input type.
        mapped_input_type = 'search_document' if input_type == 'document' else 'search_query'
        cohere_input_type = settings.get('cohere_input_type', mapped_input_type)

        request_options: RequestOptions = {}
        extra_headers = settings.get('extra_headers')
        if extra_headers:  # pragma: no cover
            request_options['additional_headers'] = extra_headers
        extra_body = settings.get('extra_body')
        if extra_body:  # pragma: no cover
            request_options['additional_body_parameters'] = cast(dict[str, Any], extra_body)

        try:
            response = await self._client.embed(
                model=self.model_name,
                texts=inputs,
                output_dimension=settings.get('dimensions'),
                input_type=cohere_input_type,
                max_tokens=settings.get('cohere_max_tokens'),
                truncate=settings.get('cohere_truncate', 'NONE'),
                request_options=request_options,
            )
        except ApiError as exc:
            status_code = exc.status_code
            if status_code and status_code >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=exc.body) from exc
            raise ModelAPIError(model_name=self.model_name, message=str(exc)) from exc  # pragma: no cover

        float_embeddings = response.embeddings.float_
        if float_embeddings is None:
            raise UnexpectedModelBehavior(  # pragma: no cover
                'The Cohere embeddings response did not have an `embeddings` field holding a list of floats',
                str(response),
            )

        return EmbeddingResult(
            embeddings=float_embeddings,
            inputs=inputs,
            input_type=input_type,
            usage=_map_usage(response, self.system, self.base_url, self.model_name),
            model_name=self.model_name,
            provider_name=self.system,
            provider_response_id=response.id,
        )

    async def max_input_tokens(self) -> int | None:
        """Return the known maximum input token count for this model, or `None` if it is not listed."""
        return _MAX_INPUT_TOKENS.get(self.model_name)

    async def count_tokens(self, text: str) -> int:
        """Count the tokens in `text` using the Cohere v1 client's `tokenize` call.

        Args:
            text: The text to tokenize.

        Returns:
            The number of tokens in the tokenize result.

        Raises:
            NotImplementedError: If no v1 client is available (the provider is not a `CohereProvider`).
            ModelHTTPError: If the Cohere client raises an `ApiError` whose status code is 400 or higher.
            ModelAPIError: If the Cohere client raises an `ApiError` without such a status code.
        """
        v1_client = self._v1_client
        if v1_client is None:
            raise NotImplementedError('Counting tokens requires the Cohere v1 client')

        try:
            tokenized = await v1_client.tokenize(
                model=self.model_name,
                text=text,  # Has a max length of 65536 characters
                offline=False,
            )
        except ApiError as exc:  # pragma: no cover
            status_code = exc.status_code
            if status_code and status_code >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=exc.body) from exc
            raise ModelAPIError(model_name=self.model_name, message=str(exc)) from exc

        return len(tokenized.tokens)
def _map_usage(response: EmbedByTypeResponse, provider: str, provider_url: str, model: str) -> RequestUsage:
    """Build a `RequestUsage` from the billed units reported on a Cohere embed response.

    Args:
        response: The Cohere embed response whose `meta.billed_units` is read.
        provider: Provider name forwarded to `RequestUsage.extract`.
        provider_url: Provider base URL forwarded to `RequestUsage.extract`.
        model: Model name placed in the data handed to `RequestUsage.extract`.

    Returns:
        An empty `RequestUsage` when the response has no meta or no billed units; otherwise
        whatever `RequestUsage.extract` produces from the positive billed-unit counts.
    """
    meta = response.meta
    if meta is None or meta.billed_units is None:
        return RequestUsage()  # pragma: no cover

    # Keep only positive numeric counts, stored as ints.
    billed_units: dict[str, int] = {}
    for unit_name, amount in meta.billed_units.model_dump(exclude_none=True).items():
        if isinstance(amount, int | float) and amount > 0:
            billed_units[unit_name] = int(amount)

    # Everything except `input_tokens` is passed along as usage details; the values are
    # already positive ints, so no further filtering or conversion is needed here.
    details = {unit_name: amount for unit_name, amount in billed_units.items() if unit_name != 'input_tokens'}

    return RequestUsage.extract(
        dict(model=model, meta=dict(billed_units=billed_units)),
        provider=provider,
        provider_url=provider_url,
        provider_fallback='cohere',
        api_flavor='embeddings',
        details=details,
    )