// connection.rs • 8.93 kB
use crate::heartbeat::HeartbeatManager;
use crate::message::*;
use crate::protocol::{handshake, parse_response_typed, McpProtocol};
use crate::transport::{backoff_durations, IncomingFrame, Transport, WebSocketTransport};
use crate::version::{ProtocolVersion, VersionNegotiator, DEFAULT_VERSION};
use crate::{McpError, Result};
use dashmap::DashMap;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::{json, Value};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, oneshot, RwLock};
use tracing::warn;
use url::Url;
/// Settings used to establish an [`McpConnection`].
#[derive(Debug, Clone)]
pub struct McpClientConfig {
/// Endpoint handed to `WebSocketTransport::connect`.
pub url: Url,
/// Client name passed to the `initialize` handshake.
pub client_name: String,
/// Client version passed to the `initialize` handshake.
pub client_version: String,
/// Per-request timeout; `new` sets it to 30 seconds.
/// NOTE(review): no code visible in this file reads this field — confirm where it is consumed.
pub request_timeout: Duration,
/// Passed to `backoff_durations` to build the connect retry schedule.
pub connect_max_retries: usize,
/// Optional heartbeat manager; cloned and forwarded to the transport on each connect attempt.
pub heartbeat: Option<HeartbeatManager>,
}
impl McpClientConfig {
    /// Builds a configuration for `url` using the crate defaults:
    /// client name `codegraph-mcp-rs`, the crate's own package version,
    /// a 30 second request timeout, 5 connect retries and no heartbeat.
    pub fn new(url: Url) -> Self {
        let client_name = String::from("codegraph-mcp-rs");
        let client_version = String::from(env!("CARGO_PKG_VERSION"));
        Self {
            url,
            client_name,
            client_version,
            request_timeout: Duration::from_secs(30),
            connect_max_retries: 5,
            heartbeat: None,
        }
    }

    /// Consumes the configuration and returns it with `hb` installed as the
    /// heartbeat manager, replacing any previous one.
    pub fn with_heartbeat(self, hb: HeartbeatManager) -> Self {
        Self {
            heartbeat: Some(hb),
            ..self
        }
    }
}
/// Core MCP connection supporting JSON-RPC 2.0 and MCP handshake
pub struct McpConnection {
// Stored at construction; no method visible in this file reads it, hence the allow.
#[allow(dead_code)]
url: Url,
/// Transport handle used to send text frames and to close the connection.
writer: Arc<dyn Transport>,
/// Receiver supplied at construction; the reader task calls `resubscribe()` on it
/// to obtain the receiver it actually polls.
incoming: broadcast::Receiver<IncomingFrame>,
/// Passed to `handshake::build_initialize_request` during `initialize`.
negotiator: VersionNegotiator,
/// Protocol state: replaced with the negotiated version after `initialize`,
/// and cloned to build outgoing notifications.
protocol: RwLock<McpProtocol>,
/// Response channels for outstanding requests, keyed by the request id string.
pending: DashMap<String, oneshot::Sender<JsonRpcMessage>>, // request_id -> tx
/// Number of requests currently awaiting a response; exposed through `inflight()`.
in_flight: AtomicU64,
/// Optional callback invoked for every incoming JSON-RPC notification.
notify_handler: RwLock<Option<Arc<dyn Fn(JsonRpcNotification) + Send + Sync>>>,
}
impl McpConnection {
pub async fn connect(cfg: &McpClientConfig) -> Result<Arc<Self>> {
// Retry with backoff
let mut last_err: Option<McpError> = None;
for d in backoff_durations(cfg.connect_max_retries) {
match WebSocketTransport::connect(&cfg.url, cfg.heartbeat.clone()).await {
Ok((t, rx)) => {
let conn = Arc::new(Self::new_inner(cfg.url.clone(), t, rx));
conn.spawn_reader();
conn.initialize(&cfg.client_name, &cfg.client_version)
.await?;
return Ok(conn);
}
Err(e) => {
last_err = Some(e);
warn!(delay_ms = d.as_millis() as u64, "Connect failed, retrying");
tokio::time::sleep(d).await;
}
}
}
Err(last_err.unwrap_or(McpError::Transport("connect failed".into())))
}
/// Assembles a connection around an already-open transport.
///
/// The negotiator and protocol start from their `new()` / `default()` values,
/// no requests are pending, the in-flight counter is zero and no notification
/// handler is installed.
fn new_inner(
    url: Url,
    writer: Arc<dyn Transport>,
    incoming: broadcast::Receiver<IncomingFrame>,
) -> Self {
    let negotiator = VersionNegotiator::new();
    let protocol = RwLock::new(McpProtocol::default());
    let pending = DashMap::new();
    let in_flight = AtomicU64::new(0);
    let notify_handler = RwLock::new(None);

    Self {
        url,
        writer,
        incoming,
        negotiator,
        protocol,
        pending,
        in_flight,
        notify_handler,
    }
}
fn spawn_reader(self: &Arc<Self>) {
let this = Arc::clone(self);
tokio::spawn(async move {
let mut rx = this.incoming.resubscribe();
while let Ok(frame) = rx.recv().await {
match frame {
IncomingFrame::Text(txt) => {
match serde_json::from_str::<JsonRpcMessage>(&txt) {
Ok(msg) => this.on_message(msg).await,
Err(e) => warn!(%e, "Failed to parse incoming JSON-RPC"),
}
}
IncomingFrame::Close(code_reason) => {
warn!(?code_reason, "Connection closed by server");
break;
}
_ => {}
}
}
});
}
/// Dispatches one parsed incoming JSON-RPC message.
///
/// * Responses are routed to the waiting request whose id matches; string ids are
///   used as-is, any other JSON id is compared by its serialized text. Responses
///   with no matching pending request are logged and dropped.
/// * Notifications are passed to the installed notification handler, if any.
/// * Server-initiated requests are not supported and are only logged.
async fn on_message(&self, msg: JsonRpcMessage) {
    // `msg` is owned, so match on it directly instead of cloning the whole message.
    match msg {
        JsonRpcMessage::V2(JsonRpcV2Message::Response(res)) => {
            let id_str = match &res.id {
                Value::String(s) => s.clone(),
                v => v.to_string(),
            };
            if let Some((_, tx)) = self.pending.remove(&id_str) {
                // A send error only means the requester stopped waiting; nothing to do.
                let _ = tx.send(JsonRpcMessage::V2(JsonRpcV2Message::Response(res)));
            } else {
                warn!(id = id_str, "Response with unknown id");
            }
        }
        JsonRpcMessage::V2(JsonRpcV2Message::Notification(notif)) => {
            // Clone the `Arc` and release the read lock before running user code, so
            // a slow handler does not block `set_notification_handler`.
            let handler = self.notify_handler.read().await.clone();
            if let Some(handler) = handler {
                handler(notif);
            }
        }
        JsonRpcMessage::V2(JsonRpcV2Message::Request(_req)) => {
            // Server initiated request – not supported for now
            warn!("Server-initiated request received; ignoring for now");
        }
    }
}
/// Installs `f` as the callback for incoming notifications, replacing any
/// previously installed handler.
pub async fn set_notification_handler<F>(&self, f: F)
where
    F: Fn(JsonRpcNotification) + Send + Sync + 'static,
{
    let handler: Arc<dyn Fn(JsonRpcNotification) + Send + Sync> = Arc::new(f);
    let mut slot = self.notify_handler.write().await;
    *slot = Some(handler);
}
/// Performs the MCP `initialize` handshake and records the negotiated version.
///
/// Builds the initialize request (proposing `DEFAULT_VERSION`), sends its params
/// under the `initialize` method — an empty JSON object when the request carries
/// none — and replaces `self.protocol` with one built from the protocol version
/// the server returned.
///
/// # Errors
/// Propagates failures from building the request, from the request round-trip,
/// and from validating the returned protocol version.
async fn initialize(&self, client_name: &str, client_version: &str) -> Result<()> {
    let request = handshake::build_initialize_request(
        &self.negotiator,
        Some(DEFAULT_VERSION),
        client_name,
        client_version,
        None,
    )
    .await?;

    let params = request.params.unwrap_or_else(|| json!({}));
    let result: McpInitializeResult = self.send_request_typed("initialize", &params).await?;

    let negotiated = ProtocolVersion::new(result.protocol_version)?;
    let protocol = McpProtocol::new(negotiated);
    *self.protocol.write().await = protocol;
    Ok(())
}
/// Returns the current value of the in-flight request counter.
///
/// The read uses relaxed ordering, so it is a snapshot rather than a
/// synchronization point.
pub fn inflight(&self) -> u64 {
    let counter = &self.in_flight;
    counter.load(Ordering::Relaxed)
}
/// Sends a JSON-RPC notification (`method` with `params`); no response is awaited.
///
/// The notification is built from a clone of the current protocol state,
/// serialized to JSON text and written to the transport.
///
/// # Errors
/// Propagates failures from building or serializing the notification and from
/// the transport send.
pub async fn send_notification<T>(&self, method: &str, params: &T) -> Result<()>
where
    T: Serialize,
{
    let protocol = self.protocol.read().await.clone();
    let notification = protocol.build_notification(method, params)?;
    let envelope = JsonRpcMessage::V2(JsonRpcV2Message::Notification(notification));
    let payload = serde_json::to_string(&envelope)?;
    self.writer.send_text(&payload).await
}
pub async fn send_request_raw(
&self,
method: &str,
params: Value,
_timeout_dur: Duration,
) -> Result<JsonRpcMessage> {
let id = uuid::Uuid::new_v4().to_string();
let req = JsonRpcRequest::new(json!(id.clone()), method.to_string(), Some(params));
let msg = JsonRpcMessage::V2(JsonRpcV2Message::Request(req));
let text = serde_json::to_string(&msg)?;
let (tx, rx) = oneshot::channel();
self.pending.insert(id.clone(), tx);
self.in_flight.fetch_add(1, Ordering::SeqCst);
let send_res = self.writer.send_text(&text).await;
if let Err(e) = send_res {
self.pending.remove(&id);
self.in_flight.fetch_sub(1, Ordering::SeqCst);
return Err(e);
}
let res = rx.await;
self.in_flight.fetch_sub(1, Ordering::SeqCst);
match res {
Ok(msg) => Ok(msg),
Err(_canceled) => {
self.pending.remove(&id);
Err(McpError::ConnectionClosed)
}
}
}
/// Sends a request with serializable `params` and decodes the response into `R`.
///
/// `params` is converted to a JSON value, forwarded to `send_request_raw` together
/// with a fixed 30 second timeout argument, and the returned message is decoded
/// with `parse_response_typed`.
///
/// # Errors
/// Propagates failures from serializing `params`, from `send_request_raw` and
/// from decoding the response.
pub async fn send_request_typed<P, R>(&self, method: &str, params: &P) -> Result<R>
where
    P: Serialize,
    R: DeserializeOwned,
{
    let timeout = Duration::from_secs(30);
    let payload = serde_json::to_value(params)?;
    let response = self.send_request_raw(method, payload, timeout).await?;
    parse_response_typed::<R>(&response)
}
/// Closes the underlying transport, returning whatever result the transport reports.
pub async fn close(&self) -> Result<()> {
    let transport = &self.writer;
    transport.close().await
}
}
/// Simple connection pool for multiplexed MCP clients
pub struct McpClientPool {
/// Connections opened by `connect`; `acquire` hands out clones of these `Arc`s.
clients: Vec<Arc<McpConnection>>, // shared connections
}
impl McpClientPool {
pub async fn connect(url: Url, size: usize) -> Result<Self> {
let mut clients = Vec::with_capacity(size);
for _ in 0..size.max(1) {
let cfg = McpClientConfig::new(url.clone());
clients.push(McpConnection::connect(&cfg).await?);
}
Ok(Self { clients })
}
/// Get the least busy connection (based on in-flight requests)
pub fn acquire(&self) -> Arc<McpConnection> {
let mut best = None;
let mut best_load = u64::MAX;
for c in &self.clients {
let load = c.inflight();
if load < best_load {
best = Some(Arc::clone(c));
best_load = load;
}
}
best.expect("pool should contain at least one client")
}
}