worker.rs•6.69 kB
use std::borrow::Cow;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, Level};
use super::{IntoTransport, Transport};
use crate::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage};
/// Reason a worker stopped; this is the error type returned by [`Worker::run`].
///
/// `E` is the worker's own error type and is only carried by the `Fatal` variant.
#[derive(Debug, thiserror::Error)]
pub enum WorkerQuitReason<E> {
/// Wraps a `tokio` join error; the `#[from]` attribute provides the conversion.
#[error("Join error {0}")]
Join(#[from] tokio::task::JoinError),
/// An error of the worker's own type together with a description of what
/// was being done when it occurred.
#[error("Transport fatal {error}, when {context}")]
Fatal {
/// The underlying error.
error: E,
/// Description of the activity during which `error` occurred.
context: Cow<'static, str>,
},
/// NOTE(review): not constructed in this file; presumably returned by a
/// worker that observed its cancellation token -- confirm against implementors.
#[error("Transport cancelled")]
Cancelled,
/// NOTE(review): not constructed in this file; presumably returned by a
/// worker whose underlying connection ended -- confirm against implementors.
#[error("Transport closed")]
TransportClosed,
/// The handler side of a worker channel is gone; produced by
/// `WorkerContext::send_to_handler` and `WorkerContext::recv_from_handler`.
#[error("Handler terminated")]
HandlerTerminated,
}
impl<E: std::error::Error + Send + 'static> WorkerQuitReason<E> {
    /// Builds a [`WorkerQuitReason::Fatal`] from `error` and a description of
    /// what was being done when it occurred.
    pub fn fatal(error: E, context: impl Into<Cow<'static, str>>) -> Self {
        let context = context.into();
        Self::Fatal { error, context }
    }

    /// Returns a one-shot closure that wraps an error into
    /// [`WorkerQuitReason::Fatal`] with the given `context`.
    ///
    /// The closure has the shape expected by `Result::map_err`. The context is
    /// only converted into a `Cow` when the closure is actually invoked.
    pub fn fatal_context(context: impl Into<Cow<'static, str>>) -> impl FnOnce(E) -> Self {
        move |error| Self::fatal(error, context)
    }
}
/// A unit of transport logic that `WorkerTransport` runs on its own tokio task.
///
/// The worker exchanges messages with the transport through the channels in
/// the [`WorkerContext`] passed to [`Worker::run`].
pub trait Worker: Sized + Send + 'static {
/// Error type surfaced by the resulting transport (`WorkerTransport` uses it
/// as its `Transport::Error`).
type Error: std::error::Error + Send + Sync + 'static;
/// Service role that selects the incoming and outgoing message types.
type Role: ServiceRole;
/// Error reported by `WorkerTransport::send` when the request channel to the
/// worker, or the per-request response channel, is closed.
fn err_closed() -> Self::Error;
/// Converts the join error of the worker task into the worker's error type;
/// used by `WorkerTransport::close`.
fn err_join(e: tokio::task::JoinError) -> Self::Error;
/// Runs the worker until it quits.
///
/// `context` holds the channel ends and the cancellation token created for
/// this worker by `WorkerTransport::spawn_with_ct`.
fn run(
self,
context: WorkerContext<Self>,
) -> impl Future<Output = Result<(), WorkerQuitReason<Self::Error>>> + Send;
/// Settings read once by `WorkerTransport::spawn_with_ct` before the worker
/// is spawned; defaults to `WorkerConfig::default()`.
fn config(&self) -> WorkerConfig {
WorkerConfig::default()
}
}
/// A request, created by `WorkerTransport::send`, asking the worker to send
/// one outgoing message.
pub struct WorkerSendRequest<W: Worker> {
/// The outgoing message.
pub message: TxJsonRpcMessage<W::Role>,
/// One-shot channel on which the result of the send is reported back;
/// `WorkerTransport::send` awaits the receiving half.
pub responder: tokio::sync::oneshot::Sender<Result<(), W::Error>>,
}
/// Transport handle for a spawned [`Worker`]; created by `spawn` / `spawn_with_ct`.
pub struct WorkerTransport<W: Worker> {
// Incoming messages forwarded by the worker.
rx: tokio::sync::mpsc::Receiver<RxJsonRpcMessage<W::Role>>,
// Send requests going to the worker.
send_service: tokio::sync::mpsc::Sender<WorkerSendRequest<W>>,
// Handle of the worker task; taken (left as `None`) by the first `close`.
join_handle: Option<tokio::task::JoinHandle<Result<(), WorkerQuitReason<W::Error>>>>,
// Guard obtained from the token's `drop_guard()`; per tokio-util it cancels
// the token when this struct is dropped.
_drop_guard: tokio_util::sync::DropGuard,
// Clone of the cancellation token that was handed to the worker.
ct: CancellationToken,
}
/// Per-worker settings, returned by `Worker::config` and read when the worker
/// is spawned.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Optional worker name; recorded as the `name` field of the worker's
    /// tracing span.
    pub name: Option<String>,
    /// Capacity used for each of the two bounded channels between the
    /// transport and the worker.
    pub channel_buffer_capacity: usize,
}

impl Default for WorkerConfig {
    /// An unnamed worker with a channel capacity of 16.
    fn default() -> Self {
        Self {
            name: None,
            channel_buffer_capacity: 16,
        }
    }
}
/// Variant-less marker type; used only as the adapter type parameter of the
/// `IntoTransport` implementation for [`Worker`] types.
pub enum WorkerAdapter {}
/// Lets any [`Worker`] be used where a transport is expected.
impl<W: Worker> IntoTransport<W::Role, W::Error, WorkerAdapter> for W {
    /// Spawns the worker and returns the transport connected to it.
    fn into_transport(self) -> impl Transport<W::Role, Error = W::Error> + 'static {
        WorkerTransport::<W>::spawn(self)
    }
}
impl<W: Worker> WorkerTransport<W> {
    /// Returns a clone of the cancellation token shared with the worker task.
    pub fn cancel_token(&self) -> CancellationToken {
        self.ct.clone()
    }

    /// Spawns `worker` with a freshly created cancellation token.
    pub fn spawn(worker: W) -> Self {
        Self::spawn_with_ct(worker, CancellationToken::new())
    }

    /// Spawns `worker` on a tokio task and returns the transport connected to it.
    ///
    /// A clone of `transport_task_ct` is handed to the worker through its
    /// [`WorkerContext`]; the transport keeps another clone plus a drop guard
    /// created from the token.
    ///
    /// Both channels between the transport and the worker are created with
    /// `channel_buffer_capacity` from [`Worker::config`]. A configured capacity
    /// of `0` is raised to `1`, because `tokio::sync::mpsc::channel` panics on
    /// a zero capacity.
    ///
    /// The worker's result is logged when it quits: `debug` for a clean exit
    /// and for the cancelled / closed / handler-terminated reasons, `error`
    /// for join and fatal errors.
    pub fn spawn_with_ct(worker: W, transport_task_ct: CancellationToken) -> Self {
        let config = worker.config();
        let worker_name = config.name;
        // `tokio::sync::mpsc::channel` panics when asked for a capacity of 0,
        // so a misconfigured worker gets the smallest valid buffer instead of
        // crashing the caller.
        let capacity = config.channel_buffer_capacity.max(1);
        let (to_transport_tx, from_handler_rx) =
            tokio::sync::mpsc::channel::<WorkerSendRequest<W>>(capacity);
        let (to_handler_tx, from_transport_rx) =
            tokio::sync::mpsc::channel::<RxJsonRpcMessage<W::Role>>(capacity);
        let context = WorkerContext {
            to_handler_tx,
            from_handler_rx,
            cancellation_token: transport_task_ct.clone(),
        };
        let join_handle = tokio::spawn(async move {
            worker
                .run(context)
                .instrument(tracing::span!(
                    Level::TRACE,
                    "transport_worker",
                    name = worker_name,
                ))
                .await
                .inspect_err(|e| match e {
                    WorkerQuitReason::Cancelled
                    | WorkerQuitReason::TransportClosed
                    | WorkerQuitReason::HandlerTerminated => {
                        tracing::debug!("worker quit with reason: {:?}", e);
                    }
                    WorkerQuitReason::Join(e) => {
                        tracing::error!("worker quit with join error: {:?}", e);
                    }
                    WorkerQuitReason::Fatal { error, context } => {
                        tracing::error!("worker quit with fatal: {error}, when {context}");
                    }
                })
                .inspect(|_| {
                    tracing::debug!("worker quit");
                })
        });
        Self {
            rx: from_transport_rx,
            send_service: to_transport_tx,
            join_handle: Some(join_handle),
            ct: transport_task_ct.clone(),
            _drop_guard: transport_task_ct.drop_guard(),
        }
    }
}
/// An outgoing message paired with a one-shot sender for an incoming message.
///
/// NOTE(review): not referenced elsewhere in this file; presumably meant for
/// request/response style exchanges -- confirm against callers.
pub struct SendRequest<W: Worker> {
/// The outgoing message.
pub message: TxJsonRpcMessage<W::Role>,
/// One-shot channel carrying an incoming message back to the requester.
pub responder: tokio::sync::oneshot::Sender<RxJsonRpcMessage<W::Role>>,
}
/// The worker's side of the connection to its `WorkerTransport`; built by
/// `WorkerTransport::spawn_with_ct` and passed to [`Worker::run`].
pub struct WorkerContext<W: Worker> {
/// Sends incoming messages to the transport's `receive`.
pub to_handler_tx: tokio::sync::mpsc::Sender<RxJsonRpcMessage<W::Role>>,
/// Receives the send requests created by the transport's `send`.
pub from_handler_rx: tokio::sync::mpsc::Receiver<WorkerSendRequest<W>>,
/// Clone of the token that the transport cancels in `close`.
pub cancellation_token: CancellationToken,
}
impl<W: Worker> WorkerContext<W> {
    /// Forwards one incoming message to the transport.
    ///
    /// # Errors
    ///
    /// Returns [`WorkerQuitReason::HandlerTerminated`] when the channel to the
    /// transport rejects the message.
    pub async fn send_to_handler(
        &mut self,
        item: RxJsonRpcMessage<W::Role>,
    ) -> Result<(), WorkerQuitReason<W::Error>> {
        if self.to_handler_tx.send(item).await.is_err() {
            return Err(WorkerQuitReason::HandlerTerminated);
        }
        Ok(())
    }

    /// Waits for the next send request from the transport.
    ///
    /// # Errors
    ///
    /// Returns [`WorkerQuitReason::HandlerTerminated`] when the request channel
    /// yields no further items.
    pub async fn recv_from_handler(
        &mut self,
    ) -> Result<WorkerSendRequest<W>, WorkerQuitReason<W::Error>> {
        let Some(request) = self.from_handler_rx.recv().await else {
            return Err(WorkerQuitReason::HandlerTerminated);
        };
        Ok(request)
    }
}
impl<W: Worker> Transport<W::Role> for WorkerTransport<W> {
    type Error = W::Error;

    /// Hands `item` to the worker and resolves with the worker's reported
    /// result for that send.
    ///
    /// The returned future owns everything it needs, so it does not borrow
    /// the transport. It fails with `W::err_closed()` when either the request
    /// channel or the per-request response channel is closed.
    fn send(
        &mut self,
        item: TxJsonRpcMessage<W::Role>,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
        let (responder, response) = tokio::sync::oneshot::channel();
        let request = WorkerSendRequest {
            message: item,
            responder,
        };
        let request_tx = self.send_service.clone();
        async move {
            if request_tx.send(request).await.is_err() {
                return Err(W::err_closed());
            }
            match response.await {
                Ok(outcome) => outcome,
                Err(_) => Err(W::err_closed()),
            }
        }
    }

    /// Receives the next message forwarded by the worker, or `None` once the
    /// channel from the worker yields no further items.
    async fn receive(&mut self) -> Option<RxJsonRpcMessage<W::Role>> {
        self.rx.recv().await
    }

    /// Cancels the worker's token and waits for the worker task to finish.
    ///
    /// Only the first call does any work; later calls return `Ok(())`
    /// immediately because the join handle has already been taken. A join
    /// failure is converted with `W::err_join`.
    async fn close(&mut self) -> Result<(), Self::Error> {
        let Some(handle) = self.join_handle.take() else {
            return Ok(());
        };
        self.ct.cancel();
        match handle.await {
            // The worker's own quit reason is not surfaced to the caller; the
            // spawned task logs it when the worker returns.
            Ok(_quit_reason) => Ok(()),
            Err(join_error) => Err(W::err_join(join_error)),
        }
    }
}