// client.rs • 19.7 kB
use std::borrow::Cow;
use thiserror::Error;
use super::*;
use crate::{
model::{
ArgumentInfo, CallToolRequest, CallToolRequestParam, CallToolResult, CancelledNotification,
CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, ClientNotification,
ClientRequest, ClientResult, CompleteRequest, CompleteRequestParam, CompleteResult,
CompletionContext, CompletionInfo, GetPromptRequest, GetPromptRequestParam,
GetPromptResult, InitializeRequest, InitializedNotification, JsonRpcResponse,
ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest,
ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest,
ListToolsResult, PaginatedRequestParam, ProgressNotification, ProgressNotificationParam,
ReadResourceRequest, ReadResourceRequestParam, ReadResourceResult, Reference, RequestId,
RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification,
ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParam, SubscribeRequest,
SubscribeRequestParam, UnsubscribeRequest, UnsubscribeRequestParam,
},
transport::DynamicTransportError,
};
/// Errors that can occur while a client performs the MCP initialization handshake.
///
/// This is the error type returned by [`serve_client`] and [`serve_client_with_ct`]
/// (as `Result<RunningService<RoleClient, S>, ClientInitializeError>`); match on it if
/// you need to handle individual failure modes.
#[derive(Error, Debug)]
pub enum ClientInitializeError {
    /// A message other than the expected response to the `initialize` request was received.
    #[error("expect initialized response, but received: {0:?}")]
    ExpectedInitResponse(Option<ServerJsonRpcMessage>),
    /// The response to the `initialize` request carried a result that is not an
    /// `InitializeResult`.
    #[error("expect initialized result, but received: {0:?}")]
    ExpectedInitResult(Option<ServerResult>),
    /// The response id did not match the `initialize` request id
    /// (fields: expected id, received id).
    #[error("conflict initialized response id: expected {0}, got {1}")]
    ConflictInitResponseId(RequestId, RequestId),
    /// The transport yielded no further messages; the payload describes what was
    /// being awaited at the time.
    #[error("connection closed: {0}")]
    ConnectionClosed(String),
    /// Sending a message over the transport failed.
    #[error("Send message error {error}, when {context}")]
    TransportError {
        /// The underlying transport error.
        error: DynamicTransportError,
        /// Short description of the operation that failed.
        context: Cow<'static, str>,
    },
    /// The cancellation token fired before the handshake completed.
    #[error("Cancelled")]
    Cancelled,
}
impl ClientInitializeError {
pub fn transport<T: Transport<RoleClient> + 'static>(
error: T::Error,
context: impl Into<Cow<'static, str>>,
) -> Self {
Self::TransportError {
error: DynamicTransportError::new::<T, _>(error),
context: context.into(),
}
}
}
/// Awaits the next message from `transport`.
///
/// When the transport yields nothing more, returns
/// [`ClientInitializeError::ConnectionClosed`] carrying `context`, which describes
/// what the caller was waiting for.
async fn expect_next_message<T>(
    transport: &mut T,
    context: &str,
) -> Result<ServerJsonRpcMessage, ClientInitializeError>
where
    T: Transport<RoleClient>,
{
    match transport.receive().await {
        Some(message) => Ok(message),
        None => Err(ClientInitializeError::ConnectionClosed(context.to_owned())),
    }
}
/// Helper function to expect a response from the stream
async fn expect_response<T, S>(
transport: &mut T,
context: &str,
service: &S,
peer: Peer<RoleClient>,
) -> Result<(ServerResult, RequestId), ClientInitializeError>
where
T: Transport<RoleClient>,
S: Service<RoleClient>,
{
loop {
let message = expect_next_message(transport, context).await?;
match message {
// Expected message to complete the initialization
ServerJsonRpcMessage::Response(JsonRpcResponse { id, result, .. }) => {
break Ok((result, id));
}
// Server could send logging messages before handshake
ServerJsonRpcMessage::Notification(mut notification) => {
let ServerNotification::LoggingMessageNotification(logging) =
&mut notification.notification
else {
tracing::warn!(?notification, "Received unexpected message");
continue;
};
let mut context = NotificationContext {
peer: peer.clone(),
meta: Meta::default(),
extensions: Extensions::default(),
};
if let Some(meta) = logging.extensions.get_mut::<Meta>() {
std::mem::swap(&mut context.meta, meta);
}
std::mem::swap(&mut context.extensions, &mut logging.extensions);
if let Err(error) = service
.handle_notification(notification.notification, context)
.await
{
tracing::warn!(?error, "Handle logging before handshake failed.");
}
}
// Server could send pings before handshake
ServerJsonRpcMessage::Request(ref request)
if matches!(request.request, ServerRequest::PingRequest(_)) =>
{
tracing::trace!("Received ping request. Ignored.")
}
// Server SHOULD NOT send any other messages before handshake. We ignore them anyway
_ => tracing::warn!(?message, "Received unexpected message"),
}
}
}
/// Marker type for the client side of an MCP connection.
///
/// As a [`ServiceRole`], its own message types are the `Client*` model types and its
/// peer's message types are the `Server*` model types.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RoleClient;

impl ServiceRole for RoleClient {
    const IS_CLIENT: bool = true;

    // Types for messages produced by this (client) side.
    type Req = ClientRequest;
    type Resp = ClientResult;
    type Not = ClientNotification;
    type Info = ClientInfo;

    // Types for messages produced by the remote (server) side.
    type PeerReq = ServerRequest;
    type PeerResp = ServerResult;
    type PeerNot = ServerNotification;
    type PeerInfo = ServerInfo;

    type InitializeError = ClientInitializeError;
}

/// Client-side peer handle; its methods send client requests and notifications to
/// the connected server.
pub type ServerSink = Peer<RoleClient>;
impl<S: Service<RoleClient>> ServiceExt<RoleClient> for S {
    /// Runs the client handshake over `transport` and, on success, starts serving
    /// `self`. This delegates to [`serve_client_with_ct`], so cancelling `ct` before
    /// the handshake completes yields [`ClientInitializeError::Cancelled`].
    fn serve_with_ct<T, E, A>(
        self,
        transport: T,
        ct: CancellationToken,
    ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError>> + Send
    where
        T: IntoTransport<RoleClient, E, A>,
        E: std::error::Error + Send + Sync + 'static,
        Self: Sized,
    {
        serve_client_with_ct::<Self, T, E, A>(self, transport, ct)
    }
}
pub async fn serve_client<S, T, E, A>(
service: S,
transport: T,
) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
where
S: Service<RoleClient>,
T: IntoTransport<RoleClient, E, A>,
E: std::error::Error + Send + Sync + 'static,
{
serve_client_with_ct(service, transport, Default::default()).await
}
pub async fn serve_client_with_ct<S, T, E, A>(
service: S,
transport: T,
ct: CancellationToken,
) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
where
S: Service<RoleClient>,
T: IntoTransport<RoleClient, E, A>,
E: std::error::Error + Send + Sync + 'static,
{
tokio::select! {
result = serve_client_with_ct_inner(service, transport.into_transport(), ct.clone()) => { result }
_ = ct.cancelled() => {
Err(ClientInitializeError::Cancelled)
}
}
}
async fn serve_client_with_ct_inner<S, T>(
service: S,
transport: T,
ct: CancellationToken,
) -> Result<RunningService<RoleClient, S>, ClientInitializeError>
where
S: Service<RoleClient>,
T: Transport<RoleClient> + 'static,
{
let mut transport = transport.into_transport();
let id_provider = <Arc<AtomicU32RequestIdProvider>>::default();
// service
let id = id_provider.next_request_id();
let init_request = InitializeRequest {
method: Default::default(),
params: service.get_info(),
extensions: Default::default(),
};
transport
.send(ClientJsonRpcMessage::request(
ClientRequest::InitializeRequest(init_request),
id.clone(),
))
.await
.map_err(|error| ClientInitializeError::TransportError {
error: DynamicTransportError::new::<T, _>(error),
context: "send initialize request".into(),
})?;
let (peer, peer_rx) = Peer::new(id_provider, None);
let (response, response_id) = expect_response(
&mut transport,
"initialize response",
&service,
peer.clone(),
)
.await?;
if id != response_id {
return Err(ClientInitializeError::ConflictInitResponseId(
id,
response_id,
));
}
let ServerResult::InitializeResult(initialize_result) = response else {
return Err(ClientInitializeError::ExpectedInitResult(Some(response)));
};
peer.set_peer_info(initialize_result);
// send notification
let notification = ClientJsonRpcMessage::notification(
ClientNotification::InitializedNotification(InitializedNotification {
method: Default::default(),
extensions: Default::default(),
}),
);
transport.send(notification).await.map_err(|error| {
ClientInitializeError::transport::<T>(error, "send initialized notification")
})?;
Ok(serve_inner(service, transport, peer, peer_rx, ct))
}
/// Generates a `Peer<RoleClient>` method that sends one client request or notification.
///
/// Arms:
/// - `peer_req name Req() => Resp`: request without params, returns `Resp`.
/// - `peer_req name Req(Param) => Resp`: request with required params, returns `Resp`.
/// - `peer_req name Req(Param)? => Resp`: request with optional params, returns `Resp`.
/// - `peer_req name Req(Param)`: request whose reply must be `ServerResult::EmptyResult`.
/// - `peer_not name Not(Param)`: notification with params.
/// - `peer_not name Not`: notification without params.
///
/// Request methods return `ServiceError::UnexpectedResponse` when the server replies
/// with a different `ServerResult` variant than expected.
macro_rules! method {
    (peer_req $method:ident $Req:ident() => $Resp: ident ) => {
        pub async fn $method(&self) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    // Every other arm (including the no-param notification arm) sets
                    // `extensions`; omitting it here would fail to compile on first use.
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident) => $Resp: ident ) => {
        pub async fn $method(&self, params: $Param) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)? => $Resp: ident ) => {
        pub async fn $method(&self, params: Option<$Param>) -> Result<$Resp, ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::$Resp(result) => Ok(result),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_req $method:ident $Req:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            let result = self
                .send_request(ClientRequest::$Req($Req {
                    method: Default::default(),
                    params,
                    extensions: Default::default(),
                }))
                .await?;
            match result {
                ServerResult::EmptyResult(_) => Ok(()),
                _ => Err(ServiceError::UnexpectedResponse),
            }
        }
    };
    (peer_not $method:ident $Not:ident($Param: ident)) => {
        pub async fn $method(&self, params: $Param) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                params,
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
    (peer_not $method:ident $Not:ident) => {
        pub async fn $method(&self) -> Result<(), ServiceError> {
            self.send_notification(ClientNotification::$Not($Not {
                method: Default::default(),
                extensions: Default::default(),
            }))
            .await?;
            Ok(())
        }
    };
}
impl Peer<RoleClient> {
    // Completion and logging.
    method!(peer_req complete CompleteRequest(CompleteRequestParam) => CompleteResult);
    method!(peer_req set_level SetLevelRequest(SetLevelRequestParam));

    // Prompts.
    method!(peer_req get_prompt GetPromptRequest(GetPromptRequestParam) => GetPromptResult);
    method!(peer_req list_prompts ListPromptsRequest(PaginatedRequestParam)? => ListPromptsResult);

    // Resources and resource subscriptions.
    method!(peer_req list_resources ListResourcesRequest(PaginatedRequestParam)? => ListResourcesResult);
    method!(peer_req list_resource_templates ListResourceTemplatesRequest(PaginatedRequestParam)? => ListResourceTemplatesResult);
    method!(peer_req read_resource ReadResourceRequest(ReadResourceRequestParam) => ReadResourceResult);
    method!(peer_req subscribe SubscribeRequest(SubscribeRequestParam));
    method!(peer_req unsubscribe UnsubscribeRequest(UnsubscribeRequestParam));

    // Tools.
    method!(peer_req list_tools ListToolsRequest(PaginatedRequestParam)? => ListToolsResult);
    method!(peer_req call_tool CallToolRequest(CallToolRequestParam) => CallToolResult);

    // Outgoing notifications.
    method!(peer_not notify_cancelled CancelledNotification(CancelledNotificationParam));
    method!(peer_not notify_progress ProgressNotification(ProgressNotificationParam));
    method!(peer_not notify_initialized InitializedNotification);
    method!(peer_not notify_roots_list_changed RootsListChangedNotification);
}
impl Peer<RoleClient> {
    /// Collects every tool by paging through [`Peer<RoleClient>::list_tools`].
    ///
    /// The first call is made without a cursor; each returned `next_cursor` is passed
    /// to the following call until the server returns none.
    pub async fn list_all_tools(&self) -> Result<Vec<crate::model::Tool>, ServiceError> {
        let mut collected = Vec::new();
        let mut cursor = None;
        loop {
            let page = self
                .list_tools(Some(PaginatedRequestParam { cursor }))
                .await?;
            collected.extend(page.tools);
            match page.next_cursor {
                Some(next) => cursor = Some(next),
                None => return Ok(collected),
            }
        }
    }

    /// Collects every prompt by paging through [`Peer<RoleClient>::list_prompts`].
    ///
    /// The first call is made without a cursor; each returned `next_cursor` is passed
    /// to the following call until the server returns none.
    pub async fn list_all_prompts(&self) -> Result<Vec<crate::model::Prompt>, ServiceError> {
        let mut collected = Vec::new();
        let mut cursor = None;
        loop {
            let page = self
                .list_prompts(Some(PaginatedRequestParam { cursor }))
                .await?;
            collected.extend(page.prompts);
            match page.next_cursor {
                Some(next) => cursor = Some(next),
                None => return Ok(collected),
            }
        }
    }

    /// Collects every resource by paging through [`Peer<RoleClient>::list_resources`].
    ///
    /// The first call is made without a cursor; each returned `next_cursor` is passed
    /// to the following call until the server returns none.
    pub async fn list_all_resources(&self) -> Result<Vec<crate::model::Resource>, ServiceError> {
        let mut collected = Vec::new();
        let mut cursor = None;
        loop {
            let page = self
                .list_resources(Some(PaginatedRequestParam { cursor }))
                .await?;
            collected.extend(page.resources);
            match page.next_cursor {
                Some(next) => cursor = Some(next),
                None => return Ok(collected),
            }
        }
    }

    /// Collects every resource template by paging through
    /// [`Peer<RoleClient>::list_resource_templates`].
    ///
    /// The first call is made without a cursor; each returned `next_cursor` is passed
    /// to the following call until the server returns none.
    pub async fn list_all_resource_templates(
        &self,
    ) -> Result<Vec<crate::model::ResourceTemplate>, ServiceError> {
        let mut collected = Vec::new();
        let mut cursor = None;
        loop {
            let page = self
                .list_resource_templates(Some(PaginatedRequestParam { cursor }))
                .await?;
            collected.extend(page.resource_templates);
            match page.next_cursor {
                Some(next) => cursor = Some(next),
                None => return Ok(collected),
            }
        }
    }

    /// Requests completion suggestions for an argument of a prompt.
    ///
    /// # Arguments
    /// * `prompt_name` - name of the prompt whose argument is being completed
    /// * `argument_name` - name of the argument being completed
    /// * `current_value` - the partial value typed so far
    /// * `context` - optional context with previously resolved arguments
    ///
    /// # Returns
    /// The `completion` field of the server's [`CompleteResult`].
    pub async fn complete_prompt_argument(
        &self,
        prompt_name: impl Into<String>,
        argument_name: impl Into<String>,
        current_value: impl Into<String>,
        context: Option<CompletionContext>,
    ) -> Result<CompletionInfo, ServiceError> {
        let reference = Reference::for_prompt(prompt_name);
        let argument = ArgumentInfo {
            name: argument_name.into(),
            value: current_value.into(),
        };
        self.complete(CompleteRequestParam {
            r#ref: reference,
            argument,
            context,
        })
        .await
        .map(|result| result.completion)
    }

    /// Requests completion suggestions for a parameter of a resource URI template.
    ///
    /// # Arguments
    /// * `uri_template` - the URI template whose parameter is being completed
    /// * `argument_name` - name of the URI parameter being completed
    /// * `current_value` - the partial value typed so far
    /// * `context` - optional context with previously resolved arguments
    ///
    /// # Returns
    /// The `completion` field of the server's [`CompleteResult`].
    pub async fn complete_resource_argument(
        &self,
        uri_template: impl Into<String>,
        argument_name: impl Into<String>,
        current_value: impl Into<String>,
        context: Option<CompletionContext>,
    ) -> Result<CompletionInfo, ServiceError> {
        let reference = Reference::for_resource(uri_template);
        let argument = ArgumentInfo {
            name: argument_name.into(),
            value: current_value.into(),
        };
        self.complete(CompleteRequestParam {
            r#ref: reference,
            argument,
            context,
        })
        .await
        .map(|result| result.completion)
    }

    /// Returns only the suggested values for a prompt argument, without context.
    ///
    /// Shorthand for `complete_prompt_argument(.., None)` that keeps just the
    /// `values` of the returned [`CompletionInfo`].
    pub async fn complete_prompt_simple(
        &self,
        prompt_name: impl Into<String>,
        argument_name: impl Into<String>,
        current_value: impl Into<String>,
    ) -> Result<Vec<String>, ServiceError> {
        self.complete_prompt_argument(prompt_name, argument_name, current_value, None)
            .await
            .map(|info| info.values)
    }

    /// Returns only the suggested values for a resource URI parameter, without context.
    ///
    /// Shorthand for `complete_resource_argument(.., None)` that keeps just the
    /// `values` of the returned [`CompletionInfo`].
    pub async fn complete_resource_simple(
        &self,
        uri_template: impl Into<String>,
        argument_name: impl Into<String>,
        current_value: impl Into<String>,
    ) -> Result<Vec<String>, ServiceError> {
        self.complete_resource_argument(uri_template, argument_name, current_value, None)
            .await
            .map(|info| info.values)
    }
}