Skip to main content
Glama
tower.rs (8.28 kB)
pub use web_socket_trace::WebSocketTraceLayer; mod limit_requests { use std::{ future::Future, pin::Pin, sync::{ Arc, atomic::{ AtomicU32, Ordering, }, }, task::{ Context, Poll, }, }; use pin_project_lite::pin_project; use telemetry::prelude::*; use tokio::sync::mpsc; use tower::{ Layer, Service, }; use crate::ShutdownSource; #[derive(Clone, Debug)] pub struct LimitRequestLayer { remaining: Arc<Option<AtomicU32>>, shutdown_tx: mpsc::Sender<ShutdownSource>, } impl LimitRequestLayer { #[allow(dead_code)] pub fn new( remaining: Arc<Option<AtomicU32>>, shutdown_tx: mpsc::Sender<ShutdownSource>, ) -> Self { Self { remaining, shutdown_tx, } } } impl<S> Layer<S> for LimitRequestLayer { type Service = LimitRequest<S>; fn layer(&self, inner: S) -> Self::Service { LimitRequest::new(inner, self.remaining.clone(), self.shutdown_tx.clone()) } } #[derive(Clone, Debug)] pub struct LimitRequest<T> { inner: T, remaining: Arc<Option<AtomicU32>>, shutdown_tx: mpsc::Sender<ShutdownSource>, } impl<T> LimitRequest<T> { pub fn new( inner: T, remaining: Arc<Option<AtomicU32>>, shutdown_tx: mpsc::Sender<ShutdownSource>, ) -> Self { Self { inner, remaining, shutdown_tx, } } } impl<S, Request> Service<Request> for LimitRequest<S> where S: Service<Request>, S::Error: Sync, { type Response = S::Response; type Error = S::Error; type Future = ResponseFuture<S::Future>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { self.inner.poll_ready(cx) } fn call(&mut self, request: Request) -> Self::Future { let shutdown_tx = match (*self.remaining).as_ref() { // If we're limiting requests, then decrement by one and if we're at zero, pass // along a shutdown tx handle to the future. Otherwise, the future doesn't get a // shutdown handle. 
Some(remaining) => { let mut updated = remaining.load(Ordering::Relaxed); updated = updated.saturating_sub(1); remaining.store(updated, Ordering::Relaxed); debug!("requests remaining: {}", updated); if updated > 0 { None } else { Some(self.shutdown_tx.clone()) } } // If we're not limiting requests, then pass `None` None => None, }; let response = self.inner.call(request); ResponseFuture::new(response, shutdown_tx) } } pin_project! { #[derive(Debug)] pub struct ResponseFuture<T> { #[pin] response: T, shutdown_tx: Option<mpsc::Sender<ShutdownSource>>, } } impl<T> ResponseFuture<T> { fn new(response: T, shutdown_tx: Option<mpsc::Sender<ShutdownSource>>) -> Self { Self { response, shutdown_tx, } } } impl<F, T, E> Future for ResponseFuture<F> where F: Future<Output = Result<T, E>>, { type Output = Result<T, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let this = self.project(); match this.response.poll(cx) { Poll::Ready(value) => { if let Some(tx) = this.shutdown_tx { let tx = tx.clone(); tokio::spawn(async move { trace!("sending shutdown to limit request shutdown receiver"); if tx.send(ShutdownSource::LimitRequest).await.is_err() { trace!( "the limit request shutdown receiver has already been dropped" ); } }); } Poll::Ready(value) } Poll::Pending => Poll::Pending, } } } } mod web_socket_trace { use std::{ future::Future, pin::Pin, task::{ Context, Poll, }, }; use axum::{ body::Body, extract::MatchedPath, http::Request, response::Response, }; use pin_project_lite::pin_project; use telemetry::prelude::*; use telemetry_http::{ ParentSpan, propagation, }; use tower::{ Layer, Service, }; // Vendored and modified from futures_core::task::poll::ready macro_rules! ready { ($e:expr_2021 $(,)?) 
=> { match $e { ::std::task::Poll::Ready(t) => t, ::std::task::Poll::Pending => return ::std::task::Poll::Pending, } }; } #[derive(Clone, Debug)] pub struct WebSocketTraceLayer; impl WebSocketTraceLayer { pub fn new() -> Self { Self } } impl<S> Layer<S> for WebSocketTraceLayer { type Service = WebSocketTrace<S>; fn layer(&self, inner: S) -> Self::Service { WebSocketTrace::new(inner) } } #[derive(Clone, Debug)] pub struct WebSocketTrace<S> { inner: S, } impl<S> WebSocketTrace<S> { pub fn new(inner: S) -> Self { Self { inner } } } impl<S> Service<Request<Body>> for WebSocketTrace<S> where S: Service<Request<Body>, Response = Response> + Send + 'static, S::Future: Send + 'static, { type Response = S::Response; type Error = S::Error; type Future = ResponseFuture<S::Future>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { self.inner.poll_ready(cx) } fn call(&mut self, mut req: Request<Body>) -> Self::Future { let span = span_from_request(&req); req.extensions_mut().insert(ParentSpan::new(span.clone())); let response = { let _guard = span.enter(); self.inner.call(req) }; ResponseFuture { inner: response } } } pin_project! 
{ #[derive(Debug)] pub struct ResponseFuture<S> { #[pin] inner: S, } } impl<F, T, E> Future for ResponseFuture<F> where F: Future<Output = Result<T, E>>, { type Output = Result<T, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let this = self.project(); let result = ready!(this.inner.poll(cx)); Poll::Ready(result) } } fn span_from_request(req: &Request<Body>) -> Span { let span = span!( Level::INFO, "GET websocket", otel.kind = SpanKind::Server.as_str(), otel.name = Empty, otel.status_code = Empty, otel.status_message = Empty, ); let method = req.method().as_str(); let matched_path = req.extensions().get::<MatchedPath>().map(|mp| mp.as_str()); span.record( "otel.name", match matched_path { Some(path) => format!("WS {method} {path}"), None => format!("WS {method}"), }, ); // Extract OpenTelemetry parent span metadata from the request headers (if it exists) and // associate it with this request span propagation::parent_span_from_headers(&span, req.headers()); span } }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/systeminit/si'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.