Skip to main content
Glama
repl.rs (8.55 kB)
use super::{ typed_socket::{websocket_connect, WebSocketRecv, WebSocketSend}, util::authorized_request, ClientError, }; use crate::api::{ api_types::{ExecResult, Instruction}, id_types::{InstructionSeq, MachineName, RequestSeq}, protocol::{MessageFromServer, MessageToServer, StandardOutput}, token::ApiToken, }; use std::{ ops::{Deref, DerefMut}, sync::{atomic::AtomicU32, Arc, Mutex}, }; use tokio::{ sync::{broadcast, oneshot}, task::JoinHandle, }; pub const DEFAULT_INSTRUCTION_TIMEOUT_SECONDS: i32 = 15; #[derive(Default)] pub struct RequestSeqGenerator { next: AtomicU32, } impl RequestSeqGenerator { pub fn next(&self) -> RequestSeq { let r = self.next.fetch_add(1, std::sync::atomic::Ordering::Relaxed); r.into() } } #[derive(Debug)] pub enum ReplConnectionState { Idle, WaitingForInstructionSeq { request_id: RequestSeq, send_result_handle: oneshot::Sender<ExecResultHandle>, }, WaitingForResult { instruction_id: InstructionSeq, output_sender: broadcast::Sender<StandardOutput>, result_sender: oneshot::Sender<ExecResult>, }, } impl Default for ReplConnectionState { fn default() -> Self { Self::Idle } } pub struct ReplConnection { pub machine_name: MachineName, request_seq_generator: RequestSeqGenerator, sender: WebSocketSend<MessageToServer>, receiver_handle: Option<JoinHandle<()>>, state: Arc<Mutex<ReplConnectionState>>, } fn handle_message( message: MessageFromServer, state: Arc<Mutex<ReplConnectionState>>, ) -> Result<(), ClientError> { let msg = message; match msg { MessageFromServer::ExecReceived { seq, request_id } => { let mut state = state.lock().expect("State lock poisoned"); let old_state = std::mem::take(state.deref_mut()); match old_state { ReplConnectionState::WaitingForInstructionSeq { request_id: expected_request_seq, send_result_handle: receiver_sender, } => { if request_id != expected_request_seq { tracing::warn!( ?request_id, ?expected_request_seq, "Unexpected request seq" ); return Ok(()); } let (output_sender, output_receiver) = 
broadcast::channel::<StandardOutput>(50); let (result_sender, result_receiver) = oneshot::channel(); *state = ReplConnectionState::WaitingForResult { instruction_id: seq, output_sender, result_sender, }; let _ = receiver_sender.send(ExecResultHandle { result: result_receiver, receiver: output_receiver, }); } state => { tracing::error!(?state, "Unexpected ExecReceived while in state {state:?}"); } } } MessageFromServer::Result(result) => { let mut state = state.lock().expect("State lock poisoned"); let old_state = std::mem::take(state.deref_mut()); match old_state { ReplConnectionState::WaitingForResult { instruction_id: instruction_seq, result_sender, .. } => { if result.instruction_id != instruction_seq { tracing::warn!( ?instruction_seq, ?result.instruction_id, "Unexpected instruction seq" ); return Ok(()); } let _ = result_sender.send(result.result); } state => { tracing::error!(?state, "Unexpected Result while in state {state:?}"); } } } MessageFromServer::Output { chunk, instruction_id: instruction, } => { let state = state.lock().expect("State lock poisoned"); match state.deref() { ReplConnectionState::WaitingForResult { instruction_id: instruction_seq, output_sender, .. 
} => { if *instruction_seq != instruction { tracing::warn!( ?instruction_seq, ?instruction, "Unexpected instruction seq" ); return Ok(()); } let _ = output_sender.send(chunk); } state => { tracing::error!(?state, "Unexpected Output while in state {state:?}"); } } } MessageFromServer::Error(err) => { return Err(ClientError::ApiError(err)); } MessageFromServer::Connected { machine_name: _ } => {} msg => tracing::warn!("message type not implmented: {msg:?}"), } Ok(()) } async fn receive_loop( mut receiver: WebSocketRecv<MessageFromServer>, state: Arc<Mutex<ReplConnectionState>>, ) { while let Ok(Some(msg)) = receiver.recv().await { if let Err(err) = handle_message(msg, state.clone()) { tracing::error!(?err, "Failed to handle message"); } } } impl ReplConnection { pub async fn new(url: reqwest::Url, token: ApiToken) -> Result<Self, ClientError> { let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); let req = authorized_request(url, token)?; let (sender, mut receiver) = websocket_connect::<MessageToServer, MessageFromServer>(req).await?; let state: Arc<Mutex<ReplConnectionState>> = Arc::default(); let machine_name = match receiver.recv().await? 
{ Some(MessageFromServer::Connected { machine_name }) => machine_name, _ => { return Err(ClientError::Other(String::from( "Expected `connected` message from REPL.", ))) } }; let receiver_handle = tokio::spawn(receive_loop(receiver, state.clone())); Ok(Self { machine_name, request_seq_generator: Default::default(), sender, receiver_handle: Some(receiver_handle), state, }) } pub async fn exec(&mut self, code: &str) -> Result<ExecResultHandle, ClientError> { let instruction = Instruction { code: code.to_string(), timeout_seconds: DEFAULT_INSTRUCTION_TIMEOUT_SECONDS, }; self.exec_instruction(instruction).await } pub async fn exec_instruction( &mut self, instruction: Instruction, ) -> Result<ExecResultHandle, ClientError> { let request_id = self.request_seq_generator.next(); let message = MessageToServer::Exec { instruction, request_id, }; self.sender.send(&message).await?; let (send_result_handle, receive_result_handle) = oneshot::channel::<ExecResultHandle>(); { let mut state = self.state.lock().expect("State lock poisoned"); *state.deref_mut() = ReplConnectionState::WaitingForInstructionSeq { request_id, send_result_handle, }; } receive_result_handle .await .map_err(|_| ClientError::InstructionInterrupted) } } impl Drop for ReplConnection { fn drop(&mut self) { if let Some(handle) = self.receiver_handle.take() { handle.abort(); } } } #[derive(Debug)] pub struct ExecResultHandle { result: oneshot::Receiver<ExecResult>, receiver: broadcast::Receiver<StandardOutput>, } impl ExecResultHandle { pub async fn next(&mut self) -> Option<StandardOutput> { self.receiver.recv().await.ok() } pub async fn result(self) -> Result<ExecResult, ClientError> { self.result .await .map_err(|_| ClientError::InstructionInterrupted) } }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/jamsocket/forevervm'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.