mod.rs
use std::{
    collections::HashMap,
    fs::{self, create_dir_all},
    io::Write,
    os::unix::fs::PermissionsExt,
    path::{Path, PathBuf},
    process::{Command, Stdio},
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    },
    time::{Duration, Instant},
};

use async_nats::jetstream::{
    consumer::{AckPolicy, DeliverPolicy, ReplayPolicy, pull::Config as PullConsumerConfig},
    stream::{Config, DiscardPolicy, RetentionPolicy, Source, StorageType},
};
use aws_credential_types::Credentials as AwsCredentials;
use aws_sdk_s3::{
    Client,
    config::{Builder as ConfigBuilder, Region as AwsRegion},
    primitives::ByteStream,
};
use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion;
use bedrock_core::{ArtifactStoreConfig, PublishResult};
use futures::{StreamExt, TryStreamExt};
use s3::{bucket::Bucket, creds::Credentials, region::Region};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use si_data_nats::{NatsClient, jetstream};
use telemetry::tracing::{error, info};
use tokio::{
    fs::{File as TokioFile, File},
    io::{AsyncReadExt, AsyncWriteExt, BufReader},
    task,
    time::{interval, timeout},
};

#[derive(Serialize, Deserialize)]
struct JsonMessage {
    subject: String,
    headers: HashMap<String, String>,
    payload_hex: String,
}

fn progress_bar_line(percentage: f64, current: usize, total: usize) -> String {
    let width = 30;
    let filled = (percentage * width as f64).round() as usize;
    let empty = width - filled;
    let bar = format!(
        "[{}{}] artifact @ {:>3}% | artifact {}/{}",
        "#".repeat(filled),
        " ".repeat(empty),
        (percentage * 100.0).round() as u64,
        current,
        total
    );
    bar
}

pub async fn capture_nats(
    nats_client: &NatsClient,
    nats_streams: &[String],
    recording_id: &str,
) -> Result<(), String> {
    let js = jetstream::new(nats_client.clone());
    let base_dir = PathBuf::from("./recordings/datasources")
        .join(recording_id)
        .join("nats_sequences");
    create_dir_all(&base_dir).map_err(|e| format!("Failed to create output directory: {e}"))?;

    for source_stream in nats_streams {
        let mirror_stream_name = format!("{source_stream}_AUDIT");
        let output_path = base_dir.join(format!("{source_stream}.sequence"));
        println!("Dumping stream: {mirror_stream_name}");

        let mut stream = js
            .get_stream(&mirror_stream_name)
            .await
            .map_err(|e| format!("Failed to get stream {mirror_stream_name}: {e}"))?;

        let mut messages = vec![];
        let mut seq = 1u64;

        loop {
            // Refresh stream info to get the latest last_sequence
            let stream_info = stream
                .info()
                .await
                .map_err(|e| format!("Failed to fetch stream info: {e}"))?;
            let last_seq = stream_info.state.last_sequence;

            if seq > last_seq {
                println!("Reached end of stream at seq {last_seq}. Ending capture.");
                break;
            }

            match timeout(Duration::from_secs(2), stream.get_raw_message(seq)).await {
                Ok(Ok(msg)) => {
                    let mut headers = HashMap::new();
                    for (k, vs) in msg.headers.iter() {
                        if let Some(v) = vs.first() {
                            if let Ok(vstr) = std::str::from_utf8(v.as_ref()) {
                                headers.insert(k.to_string(), vstr.to_string());
                            }
                        }
                    }

                    messages.push(JsonMessage {
                        subject: msg.subject.to_string(),
                        headers,
                        payload_hex: hex::encode(&msg.payload),
                    });
                    seq += 1;
                }
                Ok(Err(e)) => {
                    return Err(format!("Failed to get message {seq}: {e}"));
                }
                Err(_) => {
                    println!("Timeout waiting for message at seq {seq}. Ending capture.");
                    break;
                }
            }
        }

        let mut file = TokioFile::create(&output_path)
            .await
            .map_err(|e| format!("Failed to create output file: {e}"))?;
        let serialized = serde_json::to_string_pretty(&messages)
            .map_err(|e| format!("Failed to serialize messages: {e}"))?;
        file.write_all(serialized.as_bytes())
            .await
            .map_err(|e| format!("Failed to write output file: {e}"))?;

        println!(
            "Dumped {} messages to {}",
            messages.len(),
            output_path.display()
        );

        js.delete_stream(&mirror_stream_name)
            .await
            .map_err(|e| format!("Failed to delete stream {mirror_stream_name}: {e}"))?;
        println!("🗑 Deleted stream {mirror_stream_name}");
    }

    Ok(())
}

pub fn resolve_local_sql_files(recording_id: &str) -> Result<Vec<String>, String> {
    let base_dir = PathBuf::from("./recordings/datasources")
        .join(recording_id)
        .join("database_restore_points/start");

    if !base_dir.exists() {
        return Err(format!("Directory does not exist: {}", base_dir.display()));
    }

    let entries = fs::read_dir(&base_dir)
        .map_err(|e| format!("Failed to read directory {}: {}", base_dir.display(), e))?;

    let mut sql_paths = Vec::new();

    for entry in entries {
        let entry = entry.map_err(|e| format!("Failed to read entry: {e}"))?;
        let path = entry.path();

        if path.is_file() && path.extension().is_some_and(|e| e == "sql") {
            sql_paths.push(
                path.canonicalize()
                    .map_err(|e| format!("Failed to resolve absolute path: {e}"))?
                    .to_string_lossy()
                    .to_string(),
            );
        }
    }

    if sql_paths.is_empty() {
        Err(format!(
            "No SQL restore files found in {}",
            base_dir.display()
        ))
    } else {
        println!(
            "Found {} SQL restore file(s) in {}",
            sql_paths.len(),
            base_dir.display()
        );
        Ok(sql_paths)
    }
}

pub async fn resolve_remote_artifact_files(
    recording_id: &str,
    config: &ArtifactStoreConfig,
) -> Result<Vec<String>, String> {
    let bucket_name = config
        .metadata
        .get("bucketName")
        .and_then(|v| v.as_str())
        .ok_or_else(|| "Missing or invalid 'bucketName' in artifact config".to_string())?;

    let prefix = format!("bedrock/datasources/{recording_id}/");
    info!("🔍 Using prefix '{}' in bucket '{}'", prefix, bucket_name);

    let credentials = Credentials::anonymous()
        .map_err(|e| format!("Failed to get anonymous credentials: {e}"))?;
    let region = Region::Custom {
        region: "us-east-1".to_string(),
        endpoint: "https://s3.amazonaws.com".to_string(),
    };

    let bucket = Bucket::new(bucket_name, region, credentials)
        .map_err(|e| format!("Failed to create bucket: {e}"))?
        .with_path_style();

    let results = bucket
        .list(prefix.clone(), None)
        .await
        .map_err(|e| format!("Failed to list objects: {e}"))?;

    let mut all_objects = vec![];
    for result in results {
        all_objects.extend(
            result
                .contents
                .into_iter()
                .filter(|obj| !obj.key.ends_with('/')),
        );
    }

    let total = all_objects.len();
    if total == 0 {
        return Err(format!("No downloadable files found under prefix {prefix}"));
    }

    info!("Found {} files under S3 prefix '{}'", total, prefix);

    let mut downloaded_paths = Vec::new();
    let mut found_sql = false;
    let mut found_sequence = false;

    for (index, obj) in all_objects.into_iter().enumerate() {
        let key = obj.key;
        let relative_s3_path = key
            .strip_prefix("bedrock/datasources/")
            .ok_or_else(|| format!("Unexpected object key format: {key}"))?;
        let relative_path = Path::new("recordings")
            .join("datasources")
            .join(relative_s3_path);

        info!("[{}/{}] Fetching {}", index + 1, total, key);

        if let Some(parent) = relative_path.parent() {
            create_dir_all(parent)
                .map_err(|e| format!("Failed to create directory {parent:?}: {e}"))?;
        }

        let mut file = File::create(&relative_path)
            .await
            .map_err(|e| format!("Failed to create file {relative_path:?}: {e}"))?;

        // Start downloading with progress tracking
        let mut response = bucket
            .get_object_stream(&key)
            .await
            .map_err(|e| format!("Failed to fetch object {key}: {e}"))?;

        let total_bytes = Arc::new(AtomicUsize::new(0));
        let size_bytes = obj.size as usize;
        let total_bytes_clone = Arc::clone(&total_bytes);
        let artifact_index = index + 1;

        let downloading = Arc::new(AtomicBool::new(true));
        let downloading_clone = Arc::clone(&downloading);

        let logger_handle = tokio::spawn(async move {
            let mut ticker = interval(Duration::from_secs(5));
            while downloading_clone.load(Ordering::Relaxed) {
                ticker.tick().await;
                let downloaded = total_bytes_clone.load(Ordering::Relaxed);
                let percent = (downloaded as f64 / size_bytes as f64).min(1.0);
                let bar = progress_bar_line(percent, artifact_index, total);
                info!("{}", bar);
            }
        });

        while let Some(chunk) = response.bytes().next().await {
            let bytes = chunk.map_err(|e| format!("Stream error: {e}"))?;
            total_bytes.fetch_add(bytes.len(), Ordering::Relaxed);
            file.write_all(&bytes)
                .await
                .map_err(|e| format!("Failed to write chunk to file {relative_path:?}: {e}"))?;
        }

        downloading.store(false, Ordering::Relaxed);
        logger_handle.abort();

        let bar = progress_bar_line(1.0, index + 1, total);
        info!("{}", bar);

        if key.ends_with(".sql") {
            found_sql = true;
        } else if key.ends_with(".sequence") {
            found_sequence = true;
        }

        downloaded_paths.push(relative_path.to_string_lossy().to_string());

        let progress_bar = format!(
            "[{}>{}]",
            "=".repeat((index + 1) * 20 / total),
            " ".repeat(20 - (index + 1) * 20 / total)
        );
        info!("{} Finished: {}", progress_bar, relative_path.display());
    }

    if !found_sql && !found_sequence {
        return Err(format!(
            "No .sql or .sequence files found under prefix {prefix}"
        ));
    } else if !found_sql {
        return Err(format!("No .sql files found under prefix {prefix}"));
    }
    // No sequence files is totally valid for a DB restore point only, i.e. not recording.

    info!(
        "Downloaded {} file(s) to ./recordings/datasources/{}",
        downloaded_paths.len(),
        recording_id
    );

    Ok(downloaded_paths)
}

pub async fn resolve_test(
    recording_id: &String,
    artifact_config: ArtifactStoreConfig,
) -> Result<Vec<String>, String> {
    match resolve_local_sql_files(recording_id) {
        Ok(paths) => Ok(paths),
        Err(local_err) => {
            println!("Local resolution failed: {local_err}. Trying S3...");

            let all_paths = resolve_remote_artifact_files(recording_id, &artifact_config).await?;

            let sql_paths: Vec<String> = all_paths
                .into_iter()
                .filter(|p| p.ends_with(".sql"))
                .collect();

            if sql_paths.is_empty() {
                Err(format!(
                    "No .sql files found remotely for recording {recording_id}"
                ))
            } else {
                println!(
                    "✅ Found {} SQL file(s) from remote download",
                    sql_paths.len()
                );
                Ok(sql_paths)
            }
        }
    }
}

pub async fn collect_files(recording_id: &str) -> Result<Vec<PathBuf>, String> {
    fn collect_files_rec(dir: &Path, files: &mut Vec<PathBuf>) -> Result<(), String> {
        for entry in fs::read_dir(dir).map_err(|e| format!("Read dir error: {e}"))? {
            let entry = entry.map_err(|e| format!("Dir entry error: {e}"))?;
            let path = entry.path();
            if path.is_dir() {
                collect_files_rec(&path, files)?;
            } else if path.is_file() {
                files.push(path);
            }
        }
        Ok(())
    }

    let base_path = PathBuf::from("recordings/datasources").join(recording_id);
    if !base_path.exists() {
        return Err(format!("Path does not exist: {base_path:?}"));
    }

    let mut file_paths = Vec::new();
    collect_files_rec(&base_path, &mut file_paths)?;
    Ok(file_paths)
}

pub async fn configure_nats(
    nats_client: &NatsClient,
    nats_streams: &[String],
    recording_id: &str,
) -> Result<(), String> {
    let js = jetstream::new(nats_client.clone());

    for source_stream in nats_streams {
        let mirror_stream_name = format!("{source_stream}_AUDIT");

        if js.get_stream(&mirror_stream_name).await.is_ok() {
            println!("🗑 Deleting existing stream: {mirror_stream_name}");
            js.delete_stream(&mirror_stream_name)
                .await
                .map_err(|e| format!("Failed to delete stream {mirror_stream_name}: {e}"))?;
        }

        let stream_config = Config {
            name: mirror_stream_name.clone(),
            description: Some(format!(
                "Passive copy of {source_stream} stream for recording ID {recording_id}"
            )),
            storage: StorageType::File,
            retention: RetentionPolicy::Limits,
            discard: DiscardPolicy::Old,
            allow_direct: true,
            sources: Some(vec![Source {
                name: source_stream.clone(),
                ..Default::default()
            }]),
            duplicate_window: Duration::from_secs(0),
            ..Default::default()
        };

        js.create_stream(stream_config.clone())
            .await
            .map_err(|e| format!("Failed to create stream {mirror_stream_name}: {e}"))?;

        let consumer_config = PullConsumerConfig {
            durable_name: Some(format!("{mirror_stream_name}_SINK")),
            deliver_policy: DeliverPolicy::All,
            ack_policy: AckPolicy::None,
            replay_policy: ReplayPolicy::Instant,
            max_ack_pending: 1024,
            ..Default::default()
        };

        js.create_consumer_on_stream(consumer_config, mirror_stream_name.clone())
            .await
            .map_err(|e| format!("Failed to create consumer for {mirror_stream_name}: {e}"))?;
    }

    Ok(())
}

/*
use futures::TryStreamExt;

let client = Client::connect_with_options(
    "localhost:4222",
    None,
    ConnectOptions::default(),
).await?;

let jetstream = si_data_nats::jetstream::new(client);

let mut names = jetstream.stream_names();
while let Some(stream) = names.try_next().await? {
    println!("stream: {}", stream);
}
*/

pub async fn clear_nats(nats_client: &NatsClient) -> Result<(), String> {
    let js = si_data_nats::jetstream::new(nats_client.clone());
    let mut names = js.stream_names();

    while let Ok(Some(stream_name)) = TryStreamExt::try_next(&mut names).await {
        println!("stream: {stream_name}");
        if stream_name.ends_with("_AUDIT") {
            println!("Deleting stream: {stream_name}");
            js.delete_stream(&stream_name)
                .await
                .map_err(|e| format!("Failed to delete stream {stream_name}: {e}"))?;
        } else {
            println!("Purging stream: {stream_name}");
            let stream = js
                .get_stream(&stream_name)
                .await
                .map_err(|e| format!("Failed to get stream {stream_name}: {e}"))?;
            stream
                .purge()
                .await
                .map_err(|e| format!("Failed to purge stream {stream_name}: {e}"))?;
        }
    }

    Ok(())
}

const DATABASE_DUMP_SCRIPT: &str = include_str!("../../scripts/dump-database.sh");

pub async fn dump_databases(
    databases: &[String],
    recording_id: &str,
    variant: &str,
) -> Result<(), String> {
    let script_path = std::env::temp_dir().join("dump-database.sh");

    {
        let mut file = fs::File::create(&script_path)
            .map_err(|e| format!("Failed to create script file: {e}"))?;
        file.write_all(DATABASE_DUMP_SCRIPT.as_bytes())
            .map_err(|e| format!("Failed to write script: {e}"))?;
        let mut perms = file
            .metadata()
            .map_err(|e| format!("Failed to read metadata: {e}"))?
            .permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&script_path, perms)
            .map_err(|e| format!("Failed to set permissions: {e}"))?;
    }

    for db in databases {
        let db = db.clone();
        let recording_id = recording_id.to_string();
        let variant = variant.to_string();
        let script_path = script_path.clone();

        task::spawn_blocking(move || {
            Command::new(&script_path)
                .arg(&db)
                .arg(&recording_id)
                .arg(&variant)
                .status()
                .map_err(|e| format!("Failed to run script: {e}"))
                .and_then(|status| {
                    if status.success() {
                        Ok(())
                    } else {
                        Err(format!("Script failed for {db} with exit code: {status}"))
                    }
                })
        })
        .await
        .map_err(|e| format!("Join error: {e}"))??;
    }

    Ok(())
}

const DATABASE_PREPARE_SCRIPT: &str = include_str!("../../scripts/prepare-database.sh");

pub async fn prepare_databases(sql_paths: Vec<String>) -> Result<(), String> {
    let script_path = std::env::temp_dir().join("prepare-database.sh");

    {
        let mut file = fs::File::create(&script_path)
            .map_err(|e| format!("Failed to create script file: {e}"))?;
        file.write_all(DATABASE_PREPARE_SCRIPT.as_bytes())
            .map_err(|e| format!("Failed to write script: {e}"))?;
        let mut perms = file
            .metadata()
            .map_err(|e| format!("Failed to read metadata: {e}"))?
            .permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&script_path, perms)
            .map_err(|e| format!("Failed to set permissions: {e}"))?;
    }

    // Spawn a task for each SQL file to be restored
    for sql_path in sql_paths {
        let path = Path::new(&sql_path);
        let file_name = path
            .file_name()
            .and_then(|f| f.to_str())
            .ok_or_else(|| format!("Invalid file name in path: {sql_path}"))?;

        let database_name = if file_name == "globals.sql" {
            "postgres".to_string()
        } else if file_name.ends_with("public_schema.sql") {
            file_name
                .strip_suffix("_public_schema.sql")
                .ok_or_else(|| format!("Invalid public_schema filename: {file_name}"))?
                .to_string()
        } else {
            return Err(format!("Unknown SQL filename pattern: {file_name}"));
        };

        let script_path = script_path.clone();
        let sql_path = sql_path.clone();

        task::spawn_blocking(move || {
            Command::new(&script_path)
                .arg(&sql_path)
                .arg(&database_name)
                .stdout(Stdio::null()) // suppress stdout as it's super chatty
                .stderr(Stdio::inherit())
                .status()
                .map_err(|e| format!("Failed to run script: {e}"))
                .and_then(|status| {
                    if status.success() {
                        Ok(())
                    } else {
                        Err(format!(
                            "Script failed for {sql_path} with exit code: {status}"
                        ))
                    }
                })
        })
        .await
        .map_err(|e| format!("Join error: {e}"))??;
    }

    Ok(())
}

pub async fn publish_artifact(
    artifact_id: &str,
    aws_credentials: AwsCredentials,
    config: &ArtifactStoreConfig,
) -> PublishResult {
    let start_time = Instant::now();

    let result: Result<(), String> = async {
        let bucket_name = config
            .metadata
            .get("bucketName")
            .and_then(Value::as_str)
            .ok_or_else(|| "Missing `bucketName` in config metadata".to_string())?;

        // Check if credentials are empty
        let access_key_empty = aws_credentials.access_key_id().trim().is_empty();
        let secret_key_empty = aws_credentials.secret_access_key().trim().is_empty();
        if access_key_empty || secret_key_empty {
            return Err(format!(
                "Credentials are required to publish to the artifact store: {artifact_id}"
            ));
        }

        let region = AwsRegion::new("us-east-1");
        let config = ConfigBuilder::new()
            .behavior_version(BehaviorVersion::latest())
            .region(region)
            .credentials_provider(aws_credentials)
            .build();
        let client = Client::from_conf(config);

        let s3_prefix = format!("bedrock/datasources/{artifact_id}/");

        let existing = client
            .list_objects_v2()
            .bucket(bucket_name)
            .prefix(&s3_prefix)
            .send()
            .await
            .map_err(|e| format!("Failed to list objects: {e}"))?;

        if existing.key_count().unwrap_or(0) > 0 {
            return Err(format!(
                "Test '{artifact_id}' already exists. Please re-identify and retry."
            ));
        }

        let base_path = PathBuf::from("recordings/datasources").join(artifact_id);
        if !base_path.exists() {
            return Err(format!("Local artifact path does not exist: {base_path:?}"));
        }

        let file_paths: Vec<PathBuf> = collect_files(artifact_id).await?;

        for (index, path) in file_paths.iter().enumerate() {
            let key = path
                .strip_prefix("recordings/datasources")
                .map_err(|e| e.to_string())?
                .to_string_lossy()
                .replace('\\', "/");
            let s3_key = format!("bedrock/datasources/{key}");
            let total = file_paths.len();
            let artifact_index = index + 1;

            info!(
                "[{}/{}] Uploading -> s3://{}/{}",
                artifact_index, total, bucket_name, s3_key
            );

            let file = TokioFile::open(&path)
                .await
                .map_err(|e| format!("Failed to open file {path:?}: {e}"))?;
            let metadata = file.metadata().await.map_err(|e| e.to_string())?;
            let size_bytes = metadata.len() as usize;

            let total_bytes = Arc::new(AtomicUsize::new(0));
            let downloading = Arc::new(AtomicBool::new(true));
            let total_bytes_clone = Arc::clone(&total_bytes);
            let downloading_clone = Arc::clone(&downloading);

            let logger_handle = tokio::spawn(async move {
                let mut ticker = interval(Duration::from_secs(5));
                while downloading_clone.load(Ordering::Relaxed) {
                    ticker.tick().await;
                    let downloaded = total_bytes_clone.load(Ordering::Relaxed);
                    let percent = if size_bytes > 0 {
                        (downloaded as f64 / size_bytes as f64).min(1.0)
                    } else {
                        1.0
                    };
                    let bar = progress_bar_line(percent, artifact_index, total);
                    info!("{}", bar);
                }
            });

            let mut buffer = Vec::with_capacity(size_bytes);
            let mut reader = BufReader::new(file);
            let mut chunk = [0u8; 8192];
            loop {
                let n = reader.read(&mut chunk).await.map_err(|e| e.to_string())?;
                if n == 0 {
                    break;
                }
                total_bytes.fetch_add(n, Ordering::Relaxed);
                buffer.extend_from_slice(&chunk[..n]);
            }

            downloading.store(false, Ordering::Relaxed);
            logger_handle.abort();

            client
                .put_object()
                .bucket(bucket_name)
                .key(&s3_key)
                .body(ByteStream::from(buffer))
                .send()
                .await
                .map_err(|e| format!("Upload failed: {e}"))?;

            let bar = progress_bar_line(1.0, artifact_index, total);
            info!("{}", bar);

            let progress_bar = format!(
                "[{}>{}]",
                "=".repeat((artifact_index * 20) / total),
                " ".repeat(20 - (artifact_index * 20) / total)
            );
            info!("{} Finished: {}", progress_bar, path.display());
        }

        Ok(())
    }
    .await;

    let duration = start_time.elapsed().as_millis() as u64;

    match result {
        Ok(_) => PublishResult {
            success: true,
            message: "Artifact published successfully".into(),
            duration_ms: Some(duration),
            output: None,
        },
        Err(e) => {
            error!("Publishing failed: {}", e);
            PublishResult {
                success: false,
                message: e,
                duration_ms: Some(duration),
                output: None,
            }
        }
    }
}
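
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original module): wiring the
// resolve and restore helpers above together. The function name
// `restore_recording_example` is hypothetical, and it assumes the caller
// already has a recording id and an `ArtifactStoreConfig`.
#[allow(dead_code)]
async fn restore_recording_example(
    recording_id: String,
    artifact_config: ArtifactStoreConfig,
) -> Result<(), String> {
    // Resolve .sql restore files locally first, falling back to the remote
    // artifact store (see `resolve_test` above).
    let sql_paths = resolve_test(&recording_id, artifact_config).await?;
    // Restore each resolved dump via the bundled prepare-database.sh script.
    prepare_databases(sql_paths).await
}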
