cohere.py
from __future__ import annotations as _annotations

from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Literal, cast

from typing_extensions import assert_never

from pydantic_ai.exceptions import UserError

from .. import ModelHTTPError, usage
from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
from ..messages import (
    BuiltinToolCallPart,
    BuiltinToolReturnPart,
    FilePart,
    FinishReason,
    ModelMessage,
    ModelRequest,
    ModelResponse,
    ModelResponsePart,
    RetryPromptPart,
    SystemPromptPart,
    TextPart,
    ThinkingPart,
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
)
from ..profiles import ModelProfileSpec
from ..providers import Provider, infer_provider
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import Model, ModelRequestParameters, check_allow_model_requests

try:
    from cohere import (
        AssistantChatMessageV2,
        AssistantMessageV2ContentItem,
        AsyncClientV2,
        ChatFinishReason,
        ChatMessageV2,
        SystemChatMessageV2,
        TextAssistantMessageV2ContentItem,
        ThinkingAssistantMessageV2ContentItem,
        ToolCallV2,
        ToolCallV2Function,
        ToolChatMessageV2,
        ToolV2,
        ToolV2Function,
        UserChatMessageV2,
        V2ChatResponse,
    )
    from cohere.core.api_error import ApiError
    from cohere.v2.client import OMIT
except ImportError as _import_error:
    raise ImportError(
        'Please install `cohere` to use the Cohere model, '
        'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`'
    ) from _import_error

LatestCohereModelNames = Literal[
    'c4ai-aya-expanse-32b',
    'c4ai-aya-expanse-8b',
    'command-nightly',
    'command-r-08-2024',
    'command-r-plus-08-2024',
    'command-r7b-12-2024',
]
"""Latest Cohere models."""

CohereModelName = str | LatestCohereModelNames
"""Possible Cohere model names.

Since Cohere supports a variety of date-stamped models, we explicitly list the latest models
but allow any name in the type hints.
See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
"""

_FINISH_REASON_MAP: dict[ChatFinishReason, FinishReason] = {
    'COMPLETE': 'stop',
    'STOP_SEQUENCE': 'stop',
    'MAX_TOKENS': 'length',
    'TOOL_CALL': 'tool_call',
    'ERROR': 'error',
}


class CohereModelSettings(ModelSettings, total=False):
    """Settings used for a Cohere model request."""

    # ALL FIELDS MUST BE `cohere_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.

    # This class is a placeholder for any future cohere-specific settings


@dataclass(init=False)
class CohereModel(Model):
    """A model that uses the Cohere API.

    Internally, this uses the [Cohere Python client](https://github.com/cohere-ai/cohere-python)
    to interact with the API.

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: AsyncClientV2 = field(repr=False)

    _model_name: CohereModelName = field(repr=False)
    _provider: Provider[AsyncClientV2] = field(repr=False)

    def __init__(
        self,
        model_name: CohereModelName,
        *,
        provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
        profile: ModelProfileSpec | None = None,
        settings: ModelSettings | None = None,
    ):
        """Initialize a Cohere model.

        Args:
            model_name: The name of the Cohere model to use. List of model names
                available [here](https://docs.cohere.com/docs/models#command).
            provider: The provider to use for authentication and API access. Can be either the
                string 'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided,
                a new provider will be created using the other parameters.
            profile: The model profile to use. Defaults to a profile picked by the provider
                based on the model name.
            settings: Model-specific settings that will be used as defaults for this model.
        """
        self._model_name = model_name

        if isinstance(provider, str):
            provider = infer_provider(provider)
        self._provider = provider
        self.client = provider.client

        super().__init__(settings=settings, profile=profile or provider.model_profile)

    @property
    def base_url(self) -> str:
        client_wrapper = self.client._client_wrapper  # type: ignore
        return str(client_wrapper.get_base_url())

    @property
    def model_name(self) -> CohereModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The model provider."""
        return self._provider.name

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        check_allow_model_requests()
        model_settings, model_request_parameters = self.prepare_request(
            model_settings,
            model_request_parameters,
        )
        response = await self._chat(
            messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters
        )
        model_response = self._process_response(response)
        return model_response

    async def _chat(
        self,
        messages: list[ModelMessage],
        model_settings: CohereModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> V2ChatResponse:
        tools = self._get_tools(model_request_parameters)

        if model_request_parameters.builtin_tools:
            raise UserError('Cohere does not support built-in tools')

        cohere_messages = self._map_messages(messages)
        try:
            return await self.client.chat(
                model=self._model_name,
                messages=cohere_messages,
                tools=tools or OMIT,
                max_tokens=model_settings.get('max_tokens', OMIT),
                stop_sequences=model_settings.get('stop_sequences', OMIT),
                temperature=model_settings.get('temperature', OMIT),
                p=model_settings.get('top_p', OMIT),
                seed=model_settings.get('seed', OMIT),
                presence_penalty=model_settings.get('presence_penalty', OMIT),
                frequency_penalty=model_settings.get('frequency_penalty', OMIT),
            )
        except ApiError as e:
            if (status_code := e.status_code) and status_code >= 400:
                raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
            raise  # pragma: lax no cover

    def _process_response(self, response: V2ChatResponse) -> ModelResponse:
        """Process a non-streamed response, and prepare a message to return."""
        parts: list[ModelResponsePart] = []
        if response.message.content is not None:
            for content in response.message.content:
                if content.type == 'text':
                    parts.append(TextPart(content=content.text))
                elif content.type == 'thinking':  # pragma: no branch
                    parts.append(ThinkingPart(content=content.thinking))
        for c in response.message.tool_calls or []:
            if c.function and c.function.name and c.function.arguments:  # pragma: no branch
                parts.append(
                    ToolCallPart(
                        tool_name=c.function.name,
                        args=c.function.arguments,
                        tool_call_id=c.id or _generate_tool_call_id(),
                    )
                )

        raw_finish_reason = response.finish_reason
        provider_details = {'finish_reason': raw_finish_reason}
        finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)

        return ModelResponse(
            parts=parts,
            usage=_map_usage(response),
            model_name=self._model_name,
            provider_name=self._provider.name,
            finish_reason=finish_reason,
            provider_details=provider_details,
        )

    def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
        """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
        cohere_messages: list[ChatMessageV2] = []
        for message in messages:
            if isinstance(message, ModelRequest):
                cohere_messages.extend(self._map_user_message(message))
            elif isinstance(message, ModelResponse):
                texts: list[str] = []
                thinking: list[str] = []
                tool_calls: list[ToolCallV2] = []
                for item in message.parts:
                    if isinstance(item, TextPart):
                        texts.append(item.content)
                    elif isinstance(item, ThinkingPart):
                        thinking.append(item.content)
                    elif isinstance(item, ToolCallPart):
                        tool_calls.append(self._map_tool_call(item))
                    elif isinstance(item, BuiltinToolCallPart | BuiltinToolReturnPart):  # pragma: no cover
                        # This is currently never returned from cohere
                        pass
                    elif isinstance(item, FilePart):  # pragma: no cover
                        # Files generated by models are not sent back to models that don't themselves generate files.
                        pass
                    else:
                        assert_never(item)

                message_param = AssistantChatMessageV2(role='assistant')
                if texts or thinking:
                    contents: list[AssistantMessageV2ContentItem] = []
                    if thinking:
                        contents.append(ThinkingAssistantMessageV2ContentItem(thinking='\n\n'.join(thinking)))
                    if texts:  # pragma: no branch
                        contents.append(TextAssistantMessageV2ContentItem(text='\n\n'.join(texts)))
                    message_param.content = contents
                if tool_calls:
                    message_param.tool_calls = tool_calls
                cohere_messages.append(message_param)
            else:
                assert_never(message)
        if instructions := self._get_instructions(messages):
            cohere_messages.insert(0, SystemChatMessageV2(role='system', content=instructions))
        return cohere_messages

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
        return [self._map_tool_definition(r) for r in model_request_parameters.tool_defs.values()]

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
        return ToolCallV2(
            id=_guard_tool_call_id(t=t),
            type='function',
            function=ToolCallV2Function(
                name=t.tool_name,
                arguments=t.args_as_json_str(),
            ),
        )

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
        return ToolV2(
            type='function',
            function=ToolV2Function(
                name=f.name,
                description=f.description,
                parameters=f.parameters_json_schema,
            ),
        )

    @classmethod
    def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                yield SystemChatMessageV2(role='system', content=part.content)
            elif isinstance(part, UserPromptPart):
                if isinstance(part.content, str):
                    yield UserChatMessageV2(role='user', content=part.content)
                else:
                    raise RuntimeError('Cohere does not yet support multi-modal inputs.')
            elif isinstance(part, ToolReturnPart):
                yield ToolChatMessageV2(
                    role='tool',
                    tool_call_id=_guard_tool_call_id(t=part),
                    content=part.model_response_str(),
                )
            elif isinstance(part, RetryPromptPart):
                if part.tool_name is None:
                    yield UserChatMessageV2(role='user', content=part.model_response())  # pragma: no cover
                else:
                    yield ToolChatMessageV2(
                        role='tool',
                        tool_call_id=_guard_tool_call_id(t=part),
                        content=part.model_response(),
                    )
            else:
                assert_never(part)


def _map_usage(response: V2ChatResponse) -> usage.RequestUsage:
    u = response.usage
    if u is None:
        return usage.RequestUsage()
    else:
        details: dict[str, int] = {}
        if u.billed_units is not None:
            if u.billed_units.input_tokens:  # pragma: no branch
                details['input_tokens'] = int(u.billed_units.input_tokens)
            if u.billed_units.output_tokens:
                details['output_tokens'] = int(u.billed_units.output_tokens)
            if u.billed_units.search_units:  # pragma: no cover
                details['search_units'] = int(u.billed_units.search_units)
            if u.billed_units.classifications:  # pragma: no cover
                details['classifications'] = int(u.billed_units.classifications)
        request_tokens = int(u.tokens.input_tokens) if u.tokens and u.tokens.input_tokens else 0
        response_tokens = int(u.tokens.output_tokens) if u.tokens and u.tokens.output_tokens else 0
        return usage.RequestUsage(
            input_tokens=request_tokens,
            output_tokens=response_tokens,
            details=details,
        )
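
Below is a minimal usage sketch for the model above. It is illustrative, not part of the module: it assumes `pydantic-ai-slim[cohere]` is installed and that a Cohere API key is available, which the default 'cohere' provider typically reads from the `CO_API_KEY` environment variable.

from pydantic_ai import Agent
from pydantic_ai.models.cohere import CohereModel
from pydantic_ai.providers.cohere import CohereProvider

# Default provider: the 'cohere' string is resolved via infer_provider(),
# which typically reads CO_API_KEY from the environment.
model = CohereModel('command-r-plus-08-2024')

# Alternatively, pass a provider instance explicitly ('your-api-key' is a placeholder):
# model = CohereModel('command-r-plus-08-2024', provider=CohereProvider(api_key='your-api-key'))

agent = Agent(model)
result = agent.run_sync('What is the capital of France?')
print(result.output)  # `.output` on recent pydantic-ai versions; older releases used `.data`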
