// test_progress_subscriber.rs (3.81 kB)
use futures::StreamExt;
use rmcp::{
ClientHandler, Peer, RoleServer, ServerHandler, ServiceExt,
handler::{client::progress::ProgressDispatcher, server::tool::ToolRouter},
model::{CallToolRequestParam, ClientRequest, Meta, ProgressNotificationParam, Request},
service::PeerRequestOptions,
tool, tool_handler, tool_router,
};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
/// Test client that routes incoming progress notifications through a
/// [`ProgressDispatcher`].
pub struct MyClient {
    /// Receives every notification passed to `on_progress` and is used by the
    /// test to obtain a per-token subscription.
    progress_handler: ProgressDispatcher,
}

impl Default for MyClient {
    /// Builds a client holding a freshly created dispatcher.
    fn default() -> Self {
        MyClient {
            progress_handler: ProgressDispatcher::new(),
        }
    }
}

impl MyClient {
    /// Creates a client; equivalent to [`MyClient::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
impl ClientHandler for MyClient {
async fn on_progress(
&self,
params: rmcp::model::ProgressNotificationParam,
_context: rmcp::service::NotificationContext<rmcp::RoleClient>,
) {
tracing::info!("Received progress notification: {:?}", params);
self.progress_handler.handle_notification(params).await;
}
}
/// Test server exposing the tools registered in its `ToolRouter`.
pub struct MyServer {
    /// Router obtained from `Self::tool_router()`.
    // NOTE(review): `tool_router()` is presumably generated by the
    // `#[tool_router]` attribute on the other `impl MyServer` block - confirm.
    tool_router: ToolRouter<Self>,
}

impl Default for MyServer {
    /// Builds a server whose router comes from `Self::tool_router()`.
    fn default() -> Self {
        MyServer {
            tool_router: Self::tool_router(),
        }
    }
}

impl MyServer {
    /// Creates a server; equivalent to [`MyServer::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
// Tool implementations for `MyServer`.
#[tool_router]
impl MyServer {
    // Emits ten progress notifications (progress 0 through 9, total 10), one
    // every 100 ms, to the calling peer and then returns `Ok(())`.
    //
    // Returns an `invalid_params` error when the request metadata carries no
    // progress token.
    //
    // Plain `//` comments are used on purpose: the `#[tool]` attribute may pick
    // up `///` doc comments as the tool description, which would change what
    // the server advertises.
    #[tool]
    pub async fn some_progress(
        meta: Meta,
        client: Peer<RoleServer>,
    ) -> Result<(), rmcp::ErrorData> {
        // Build the error lazily so nothing is constructed when the token exists.
        let progress_token = meta.get_progress_token().ok_or_else(|| {
            rmcp::ErrorData::invalid_params("Progress token is required for this tool", None)
        })?;
        for step in 0..10 {
            // Best-effort delivery: a failed notification is ignored and the
            // loop carries on with the next step.
            let _ = client
                .notify_progress(ProgressNotificationParam {
                    progress_token: progress_token.clone(),
                    progress: step as f64,
                    total: Some(10.0),
                    message: Some("Some message".into()),
                })
                .await;
            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        }
        Ok(())
    }
}
// `ServerHandler` implementation for `MyServer` with an empty hand-written body.
// NOTE(review): `#[tool_handler]` presumably generates the tool listing/calling
// methods from the `tool_router` field - confirm against the rmcp macro docs.
#[tool_handler]
impl ServerHandler for MyServer {}
// Connects `MyClient` and `MyServer` over an in-memory duplex pipe, calls the
// `some_progress` tool, subscribes to the request's progress token and logs
// every notification delivered to the subscriber.
#[tokio::test]
async fn test_progress_subscriber() -> anyhow::Result<()> {
    // Logging setup: use the environment's filter when available, otherwise
    // fall back to "debug". The result of `try_init` is deliberately ignored.
    let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| "debug".to_string().into());
    let _ = tracing_subscriber::registry()
        .with(env_filter)
        .with(tracing_subscriber::fmt::layer())
        .try_init();

    let my_client = MyClient::new();
    let my_server = MyServer::new();
    let (server_io, client_io) = tokio::io::duplex(4096);

    // Run the server in a background task for as long as its service is alive.
    tokio::spawn(async move {
        my_server.serve(server_io).await?.waiting().await?;
        anyhow::Ok(())
    });

    let running_client = my_client.serve(client_io).await?;

    // Issue the tool call; the returned handle carries the progress token.
    let tool_call = ClientRequest::CallToolRequest(Request::new(CallToolRequestParam {
        name: "some_progress".into(),
        arguments: None,
    }));
    let request_handle = running_client
        .send_cancellable_request(tool_call, PeerRequestOptions::no_options())
        .await?;

    // Subscribe to notifications for that token via the client's dispatcher.
    let dispatcher = &running_client.service().progress_handler;
    let mut updates = dispatcher
        .subscribe(request_handle.progress_token.clone())
        .await;

    // Log notifications from a detached task while the request is in flight.
    tokio::spawn(async move {
        while let Some(update) = updates.next().await {
            tracing::info!("Progress notification: {:?}", update);
        }
    });

    let _response = request_handle.await_response().await?;

    // Give the detached logging task a moment to drain remaining notifications.
    tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
    Ok(())
}