cli.rs
use rusty_flow::data::{DataLoader, Vocabulary};
use rusty_flow::gpu;
use rusty_flow::loss;
use rusty_flow::models::transformer::{Transformer, TransformerConfig};
use rusty_flow::nn::Module;
use rusty_flow::optimizer::{Optimizer, SGD};
use rusty_flow::tensor::Tensor;
use std::env;
use std::fs::{self, File};
use std::io::{self, BufReader, BufWriter, Write};
use std::path::Path;
use std::time::Instant;
use sysinfo::{ProcessExt, System, SystemExt};

use bincode;
use ndarray::{s, ArrayD};
use serde::{Deserialize, Serialize};

// --- Serialization ---

/// A container for saving/loading the model, its config, and vocabulary.
#[derive(Serialize, Deserialize)]
struct SerializableModel {
    config: TransformerConfig,
    vocab: Vocabulary,
    params: Vec<ArrayD<f32>>,
    seq_len: usize,
}

/// Helper to save the model and vocab to a file.
fn save_model(
    model: &Transformer,
    config: &TransformerConfig,
    vocab: &Vocabulary,
    seq_len: usize,
    path: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    // Ensure the target directory exists.
    if let Some(parent) = Path::new(path).parent() {
        fs::create_dir_all(parent)?;
    }

    let params_data: Vec<ArrayD<f32>> =
        model.parameters().iter().map(|p| p.data().clone()).collect();

    let serializable = SerializableModel {
        config: config.clone(),
        vocab: vocab.clone(),
        params: params_data,
        seq_len,
    };

    let file = File::create(path)?;
    let writer = BufWriter::new(file);
    bincode::serialize_into(writer, &serializable)?;

    println!(" - Model saved to {}", path);
    Ok(())
}

/// Helper to load a model and vocab from a file.
fn load_model(path: &str) -> Result<(Transformer, Vocabulary, usize), Box<dyn std::error::Error>> {
    let file = File::open(path)?;
    let reader = BufReader::new(file);
    let loaded: SerializableModel = bincode::deserialize_from(reader)?;

    let model = Transformer::new(&loaded.config);
    let model_params = model.parameters();

    if loaded.params.len() != model_params.len() {
        return Err(format!(
            "Parameter count mismatch: loaded {} vs model {}",
            loaded.params.len(),
            model_params.len()
        )
        .into());
    }

    for (p, data) in model_params.iter().zip(loaded.params) {
        if p.shape() != data.shape() {
            return Err(format!(
                "Shape mismatch for parameter: model {:?} vs loaded {:?}",
                p.shape(),
                data.shape()
            )
            .into());
        }
        p.inner.borrow_mut().data = data;
    }

    println!(" - Model loaded successfully from {}", path);
    Ok((model, loaded.vocab, loaded.seq_len))
}

// --- Helper Functions ---

/// A simple helper function to parse command-line arguments.
fn get_arg_value(args: &[String], key: &str) -> Option<String> {
    args.iter()
        .position(|arg| arg == key)
        .and_then(|pos| args.get(pos + 1).cloned())
}

/// Reports the current process's memory usage.
fn report_memory_usage(system: &mut System, stage: &str) {
    if let Ok(pid) = sysinfo::get_current_pid() {
        system.refresh_process(pid);
        if let Some(process) = system.process(pid) {
            // NOTE: sysinfo's `memory()` returns bytes in 0.26+ (earlier
            // releases returned KiB); with a newer sysinfo this divisor
            // reports KiB, not MB. Adjust to match the pinned version.
            let memory_mb = process.memory() as f32 / 1024.0;
            println!(" - MEMORY USAGE ({}): {:.2} MB", stage, memory_mb);
        }
    }
}
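// A quick sanity test of the `--key value` parsing convention used throughout
// this file (a minimal sketch; `arg_tests` is not part of the original file):
#[cfg(test)]
mod arg_tests {
    use super::get_arg_value;

    #[test]
    fn returns_value_following_key() {
        let args: Vec<String> = ["prog", "--batch-size", "32"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        assert_eq!(get_arg_value(&args, "--batch-size"), Some("32".to_string()));
        assert_eq!(get_arg_value(&args, "--seq-len"), None);
    }

    #[test]
    fn key_at_end_has_no_value() {
        let args = vec!["prog".to_string(), "--chat".to_string()];
        assert_eq!(get_arg_value(&args, "--chat"), None);
    }
}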
/// Splits a corpus into training and validation sets by words.
fn split_corpus_by_words(corpus: &str, train_ratio: f32) -> (String, String) {
    let words: Vec<&str> = corpus.split_whitespace().collect();
    let train_len = (words.len() as f32 * train_ratio).floor() as usize;
    if train_len >= words.len() {
        return (corpus.to_string(), String::new());
    }
    (
        words[..train_len].join(" "),
        words[train_len..].join(" "),
    )
}
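// A small sanity test of the word-level split above (assumes whitespace
// tokenization as implemented; `split_tests` is illustrative, not original):
#[cfg(test)]
mod split_tests {
    use super::split_corpus_by_words;

    #[test]
    fn splits_by_word_count() {
        let (train, valid) = split_corpus_by_words("a b c d e f g h i j", 0.8);
        assert_eq!(train, "a b c d e f g h");
        assert_eq!(valid, "i j");
    }

    #[test]
    fn full_ratio_yields_empty_validation_set() {
        let (train, valid) = split_corpus_by_words("a b c", 1.0);
        assert_eq!(train, "a b c");
        assert!(valid.is_empty());
    }
}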
// --- Main Logic ---

fn main() {
    let args: Vec<String> = env::args().collect();

    // --- Global Setup: Handle GPU initialization early ---
    let use_gpu_arg = get_arg_value(&args, "--use-gpu")
        .map(|s| s.parse().unwrap_or(false))
        .unwrap_or(false);

    if use_gpu_arg {
        // Attempt to initialize the GPU context and set the flag if successful.
        if let Some(_ctx) = &*gpu::GPU_CONTEXT {
            // The Lazy will print the adapter info on first access.
            gpu::USE_GPU.store(true, std::sync::atomic::Ordering::Relaxed);
        } else {
            println!("WARN: --use-gpu was specified, but no compatible GPU was found. Falling back to CPU.");
            gpu::USE_GPU.store(false, std::sync::atomic::Ordering::Relaxed);
        }
    }

    // --- Dispatch to the appropriate mode ---
    if args.contains(&"--list-models".to_string()) {
        list_models();
    } else if args.contains(&"--chat".to_string()) {
        run_chat_session(args);
    } else {
        run_training_session(args);
    }
}

fn list_models() {
    println!("--- Available Models ---");
    let models_dir = "models";
    if let Ok(entries) = fs::read_dir(models_dir) {
        let mut models_found = false;
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("bin") {
                println!(" - {}", path.display());
                models_found = true;
            }
        }
        if !models_found {
            println!(" No models found in '{}' directory.", models_dir);
        }
    } else {
        println!(" '{}' directory not found or could not be read.", models_dir);
        println!(" (You might need to create it: mkdir models)");
    }
    println!("\nTo chat with a model, update MODEL_PATH in your config file or run './train.sh' to create one.");
}

fn run_training_session(args: Vec<String>) {
    println!("--- RustyFlow Language Model Training ---");

    // --- 1. Define Dataset and Hyperparameters ---
    println!("\n[Step 1: Dataset and Hyperparameters]");
    let dataset_arg =
        get_arg_value(&args, "--dataset").unwrap_or_else(|| "tinyshakespeare".to_string());
    let save_path = get_arg_value(&args, "--save-path");
    let load_path = get_arg_value(&args, "--load-path");
    let batch_size = get_arg_value(&args, "--batch-size")
        .and_then(|s| s.parse().ok())
        .unwrap_or(16);
    let seq_len = get_arg_value(&args, "--seq-len")
        .and_then(|s| s.parse().ok())
        .unwrap_or(64);
    let num_epochs = get_arg_value(&args, "--num-epochs")
        .and_then(|s| s.parse().ok())
        .unwrap_or(10);
    let learning_rate = get_arg_value(&args, "--learning-rate")
        .and_then(|s| s.parse().ok())
        .unwrap_or(0.01);

    println!(" - Dataset: {}", dataset_arg);
    println!(" - Batch Size: {}", batch_size);
    println!(" - Sequence Length: {}", seq_len);
    println!(" - Epochs: {}", num_epochs);
    println!(" - Learning Rate: {}", learning_rate);

    // --- 2. Load Data and Build Vocabulary ---
    println!("\n[Step 2: Vocabulary and Data Loading]");
    let (train_corpus, valid_corpus, data_source_display) = match dataset_arg.as_str() {
        "short" => (
            "A short paragraph for testing. It has a small vocabulary.".to_string(),
            String::new(),
            "in-memory short paragraph".to_string(),
        ),
        "tinyshakespeare" => {
            let path = "data/tinyshakespeare/input.txt";
            if !Path::new(path).exists() {
                panic!(
                    "Dataset not found at '{}'. Please run './run.sh get-data tinyshakespeare'.",
                    path
                );
            }
            let corpus = fs::read_to_string(path).expect("Failed to read dataset file.");
            let (train, valid) = split_corpus_by_words(&corpus, 0.9);
            (train, valid, path.to_string())
        }
        "wikitext-2" => {
            let (train_path, valid_path) = (
                "data/wikitext-2/wiki.train.tokens",
                "data/wikitext-2/wiki.valid.tokens",
            );
            if !Path::new(train_path).exists() {
                panic!("WikiText-2 dataset not found. The project should include these files.");
            }
            let train = fs::read_to_string(train_path).expect("Failed to read train file.");
            let valid = fs::read_to_string(valid_path).expect("Failed to read valid file.");
            (train, valid, "wikitext-2 (tokenized)".to_string())
        }
        path if Path::new(path).exists() => {
            let corpus = fs::read_to_string(path).expect("Failed to read custom dataset file.");
            let (train, valid) = split_corpus_by_words(&corpus, 0.9);
            (train, valid, path.to_string())
        }
        _ => panic!("Dataset '{}' not found.", dataset_arg),
    };
    println!(" - Using data source: {}", data_source_display);

    let vocab = Vocabulary::new(&train_corpus);
    println!(" - Vocabulary built with {} unique tokens.", vocab.size());

    // --- 3. Model Configuration ---
    println!("\n[Step 3: Model Configuration]");
    let embed_dim = get_arg_value(&args, "--embed-dim")
        .and_then(|s| s.parse().ok())
        .unwrap_or(128);
    let num_heads = get_arg_value(&args, "--num-heads")
        .and_then(|s| s.parse().ok())
        .unwrap_or(4);
    let num_layers = get_arg_value(&args, "--num-layers")
        .and_then(|s| s.parse().ok())
        .unwrap_or(2);

    let config = TransformerConfig {
        vocab_size: vocab.size(),
        embed_dim,
        num_heads,
        num_layers,
    };
    println!(" - Vocabulary Size: {}", config.vocab_size);
    println!(" - Embedding Dimension: {}", config.embed_dim);
    println!(" - Number of Heads: {}", config.num_heads);
    println!(" - Number of Layers: {}", config.num_layers);

    // --- 4. Instantiate Model and Optimizer ---
    println!("\n[Step 4: Instantiate Model and Optimizer]");
    let model = Transformer::new(&config);

    if let Some(path) = load_path {
        println!("\n[INFO] Loading model from '{}' to continue training...", path);
        let file = File::open(&path).expect("Failed to open model file.");
        let reader = BufReader::new(file);
        let loaded: SerializableModel =
            bincode::deserialize_from(reader).expect("Failed to deserialize model.");
        let model_params = model.parameters();
        if loaded.params.len() == model_params.len() {
            for (p, data) in model_params.iter().zip(loaded.params) {
                if p.shape() == data.shape() {
                    p.inner.borrow_mut().data = data;
                } else {
                    panic!("Shape mismatch while loading model for continued training.");
                }
            }
            println!(" - Loaded weights successfully.");
        } else {
            panic!("Parameter count mismatch while loading model for continued training.");
        }
    }

    let mut optimizer = SGD::new(model.parameters(), learning_rate);
    println!(" - Full Transformer model and SGD optimizer created.");
    // --- 5. Training Loop ---
    println!("\n[Step 5: Training Loop]");

    // Reset profiling counters before training.
    gpu::TOTAL_GPU_TIME_NS.store(0, std::sync::atomic::Ordering::Relaxed);
    rusty_flow::tensor::CPU_MATMUL_TIME_NS.store(0, std::sync::atomic::Ordering::Relaxed);

    let mut system = System::new_all();
    report_memory_usage(&mut system, "Before Training");

    let training_start_time = Instant::now();
    for epoch in 0..num_epochs {
        let mut epoch_loss = 0.0;
        let mut num_batches = 0;
        let data_loader = DataLoader::new(&train_corpus, &vocab, batch_size, seq_len);
        if epoch == 0 {
            println!(
                " - DataLoader created {} training sequences.",
                data_loader.num_sequences()
            );
        }
        for (i, batch) in data_loader.enumerate() {
            let logits = model.forward(&batch.inputs);
            let loss = loss::cross_entropy_loss(&logits, &batch.targets);
            epoch_loss += *loss.data().first().unwrap();
            num_batches += 1;

            optimizer.zero_grad();
            loss.backward();
            optimizer.step();

            if (i + 1) % 50 == 0 {
                println!(
                    " - Epoch {:>3}/{} | Batch {:>4} | Batch Loss = {:.6}",
                    epoch + 1,
                    num_epochs,
                    i + 1,
                    *loss.data().first().unwrap()
                );
            }
        }
        if num_batches > 0 {
            println!(
                " - Epoch {:>3}/{}: Average Loss = {:.6}",
                epoch + 1,
                num_epochs,
                epoch_loss / num_batches as f32
            );
        }
        report_memory_usage(&mut system, &format!("End of Epoch {}", epoch + 1));
    }
    let training_duration = training_start_time.elapsed();
    println!(" - Training finished in {:?}.", training_duration);

    // --- 5.5 Profiling Report ---
    println!("\n[Step 5.5: Profiling Report]");
    let use_gpu = gpu::USE_GPU.load(std::sync::atomic::Ordering::Relaxed);
    let gpu_time_ns = gpu::TOTAL_GPU_TIME_NS.load(std::sync::atomic::Ordering::Relaxed);
    let cpu_matmul_time_ns =
        rusty_flow::tensor::CPU_MATMUL_TIME_NS.load(std::sync::atomic::Ordering::Relaxed);

    if use_gpu && gpu_time_ns > 0 {
        let gpu_time_s = gpu_time_ns as f64 / 1e9;
        println!(" - Total time in GPU matmul kernels: {:.4}s", gpu_time_s);
    }
    if cpu_matmul_time_ns > 0 {
        let cpu_time_s = cpu_matmul_time_ns as f64 / 1e9;
        println!(" - Total time in CPU matmul fallback: {:.4}s", cpu_time_s);
    }

    let total_training_s = training_duration.as_secs_f64();
    let total_matmul_ns = gpu_time_ns + cpu_matmul_time_ns;
    let total_matmul_s = total_matmul_ns as f64 / 1e9;
    if total_training_s > 0.0 && total_matmul_s > 0.0 {
        println!(
            " - Total matmul time (CPU + GPU): {:.4}s ({:.1}% of total training time)",
            total_matmul_s,
            (total_matmul_s / total_training_s) * 100.0
        );
        let other_time_s = (total_training_s - total_matmul_s).max(0.0);
        println!(
            " - Other CPU time (data loading, other ops, etc.): {:.4}s ({:.1}% of total training time)",
            other_time_s,
            (other_time_s / total_training_s) * 100.0
        );
    }
    // --- End Profiling Report ---

    // Parseable output for logging.
    if num_epochs > 0 {
        let avg_epoch_time_secs = training_duration.as_secs_f32() / num_epochs as f32;
        println!("[AVG_EPOCH_TIME] {:.4}", avg_epoch_time_secs);
    }

    // --- 6. Evaluation ---
    println!("\n[Step 6: Evaluation]");
    let eval_start_time = Instant::now();
    run_evaluation(&model, &valid_corpus, &vocab, batch_size, seq_len);
    println!(" - Evaluation finished in {:?}.", eval_start_time.elapsed());

    // --- 7. Save Model ---
    if let Some(path) = save_path {
        println!("\n[Step 7: Saving Model]");
        save_model(&model, &config, &vocab, seq_len, &path).expect("Failed to save model.");
    }

    println!("\n--- Training Finished ---");
}
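// Example invocations for training (above) and chat (defined below). These
// are illustrative only: the binary name depends on this crate's Cargo
// layout, so `cargo run --release --` is used as a stand-in, and the model
// path is hypothetical. All flags and defaults come from the parsing code
// in this file.
//
//   Train on Tiny Shakespeare and save the weights:
//     cargo run --release -- --dataset tinyshakespeare --batch-size 16 \
//         --seq-len 64 --num-epochs 10 --learning-rate 0.01 \
//         --save-path models/tinyshakespeare.bin
//
//   Chat with a saved model:
//     cargo run --release -- --chat --load-path models/tinyshakespeare.bin \
//         --temperature 0.8 --top-p 0.9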
fn run_chat_session(args: Vec<String>) {
    println!("--- RustyFlow Interactive Chat ---");

    let load_path =
        get_arg_value(&args, "--load-path").expect("--load-path is required for chat mode.");
    let temperature = get_arg_value(&args, "--temperature")
        .and_then(|s| s.parse::<f32>().ok())
        .unwrap_or(0.8);
    let top_p = get_arg_value(&args, "--top-p")
        .and_then(|s| s.parse::<f32>().ok())
        .unwrap_or(0.9);

    println!("\n[Step 1: Loading Model]");
    let (model, vocab, seq_len) = load_model(&load_path).expect("Failed to load model.");
    println!(" - Context length from model file: {}", seq_len);

    println!("\n[Step 2: Start Chatting]");
    println!("Model loaded. You can now chat with the model.");
    println!("Type a prompt and press Enter. Type 'exit' or 'quit' to end.");

    loop {
        print!("> ");
        io::stdout().flush().unwrap();

        let mut prompt = String::new();
        io::stdin().read_line(&mut prompt).expect("Failed to read line");
        let prompt = prompt.trim();

        if prompt == "exit" || prompt == "quit" {
            break;
        }
        if prompt.is_empty() {
            continue;
        }

        generate_text(&model, &vocab, prompt, seq_len, 50, temperature, top_p);
    }
}

// --- Standalone Functions (used by both modes) ---

fn run_evaluation(
    model: &Transformer,
    corpus: &str,
    vocab: &Vocabulary,
    batch_size: usize,
    seq_len: usize,
) {
    if corpus.trim().is_empty() {
        println!(" - Validation set is empty, skipping evaluation.");
        return;
    }
    println!(" - Running evaluation on validation set...");

    let mut total_loss = 0.0;
    let mut num_batches = 0;
    let data_loader = DataLoader::new(corpus, vocab, batch_size, seq_len);

    model.zero_grad();
    for batch in data_loader {
        let logits = model.forward(&batch.inputs);
        let loss = loss::cross_entropy_loss(&logits, &batch.targets);
        total_loss += *loss.data().first().unwrap();
        num_batches += 1;
        loss.backward(); // Clear the computation graph.
    }

    if num_batches > 0 {
        let avg_loss = total_loss / num_batches as f32;
        // Perplexity is the exponential of the average cross-entropy loss.
        let perplexity = avg_loss.exp();
        println!(" ------------------------------------");
        println!(" - Validation Average Loss: {:.4}", avg_loss);
        println!(" - Validation Perplexity: {:.4}", perplexity);
        println!(" ------------------------------------");
    }
}
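// `generate_text` below windows the running token list to the last `seq_len`
// tokens and left-pads with <pad> when the context is still short. A
// self-contained sketch of that windowing logic, mirroring the slicing and
// splicing in the generation loop (`window_and_pad` is a hypothetical helper,
// not part of the crate's API):
#[cfg(test)]
mod context_tests {
    fn window_and_pad(tokens: &[u32], seq_len: usize, pad_id: u32) -> Vec<u32> {
        let start = tokens.len().saturating_sub(seq_len);
        let mut window = tokens[start..].to_vec();
        let padding_needed = seq_len.saturating_sub(window.len());
        if padding_needed > 0 {
            // Left-pad so the most recent tokens stay at the end.
            window.splice(0..0, vec![pad_id; padding_needed]);
        }
        window
    }

    #[test]
    fn left_pads_short_context() {
        assert_eq!(window_and_pad(&[7, 8], 4, 0), vec![0, 0, 7, 8]);
    }

    #[test]
    fn keeps_only_last_seq_len_tokens() {
        assert_eq!(window_and_pad(&[1, 2, 3, 4, 5], 3, 0), vec![3, 4, 5]);
    }
}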
fn generate_text(
    model: &Transformer,
    vocab: &Vocabulary,
    prompt: &str,
    seq_len: usize,
    num_tokens_to_generate: usize,
    temperature: f32,
    top_p: f32,
) {
    println!(" - Generating {} tokens...", num_tokens_to_generate);

    let mut generated_tokens = vocab.tokenize(prompt);
    let mut result_string = prompt.to_string();
    let pad_id = vocab.word_to_id["<pad>"];

    for _ in 0..num_tokens_to_generate {
        // Window the context to the last `seq_len` tokens, left-padding if short.
        let context_start = generated_tokens.len().saturating_sub(seq_len);
        let mut input_tokens = generated_tokens[context_start..].to_vec();
        let padding_needed = seq_len.saturating_sub(input_tokens.len());
        if padding_needed > 0 {
            input_tokens.splice(0..0, vec![pad_id; padding_needed]);
        }

        let input_f32: Vec<f32> = input_tokens.iter().map(|&id| id as f32).collect();
        let input_tensor = Tensor::new(input_f32, vec![1, seq_len]);
        let logits = model.forward(&input_tensor);

        // Scope the data borrow so it's released before backward() is called.
        let next_token_id = {
            let logits_data = logits.data();
            let last_token_logits_slice = logits_data.slice(s![0, seq_len - 1, ..]);

            // --- Advanced Sampling (Temperature and Top-P) ---

            // 1. Apply temperature.
            let scaled_logits: Vec<f32> = last_token_logits_slice
                .iter()
                .map(|&l| l / temperature)
                .collect();

            // 2. Calculate softmax probabilities (max-subtracted for numerical stability).
            let max_logit = scaled_logits
                .iter()
                .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let exps: Vec<f32> = scaled_logits
                .iter()
                .map(|&l| (l - max_logit).exp())
                .collect();
            let sum_exps: f32 = exps.iter().sum();
            let probs: Vec<f32> = exps.iter().map(|&e| e / sum_exps).collect();

            // --- Top-P (Nucleus) Sampling ---
            let mut sampled_id: u32;
            if top_p < 1.0 && top_p > 0.0 {
                // Create a vector of (index, probability) pairs to sort.
                let mut indexed_probs: Vec<(usize, f32)> =
                    probs.iter().enumerate().map(|(i, &p)| (i, p)).collect();
                // Sort by probability in descending order.
                indexed_probs.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

                // Find the nucleus of tokens.
                let mut cumulative_prob = 0.0;
                let mut nucleus_indices = Vec::new();
                for &(index, prob) in &indexed_probs {
                    nucleus_indices.push((index, prob));
                    cumulative_prob += prob;
                    if cumulative_prob >= top_p {
                        break;
                    }
                }

                // Re-normalize the probabilities of the nucleus tokens.
                let nucleus_sum_probs: f32 = nucleus_indices.iter().map(|&(_, p)| p).sum();
                let nucleus_probs: Vec<(usize, f32)> = nucleus_indices
                    .iter()
                    .map(|&(i, p)| (i, p / nucleus_sum_probs))
                    .collect();

                // Sample from the nucleus distribution.
                let r: f32 = rand::random();
                let mut cum_prob = 0.0;
                sampled_id = nucleus_probs.last().unwrap().0 as u32; // Default to last in nucleus.
                for &(index, prob) in &nucleus_probs {
                    cum_prob += prob;
                    if r < cum_prob {
                        sampled_id = index as u32;
                        break;
                    }
                }
            } else {
                // Inverse transform sampling on the full distribution.
                let r: f32 = rand::random();
                let mut cum_prob = 0.0;
                sampled_id = (vocab.size() - 1) as u32; // Default to last token.
                for (i, &prob) in probs.iter().enumerate() {
                    cum_prob += prob;
                    if r < cum_prob {
                        sampled_id = i as u32;
                        break;
                    }
                }
            }
            sampled_id
        };

        // Clear the computation graph to free memory.
        logits.backward();

        if next_token_id == pad_id || next_token_id == vocab.word_to_id["<unk>"] {
            break;
        }

        generated_tokens.push(next_token_id);
        if let Some(word) = vocab.id_to_word.get(&next_token_id) {
            result_string.push(' ');
            result_string.push_str(word);
        }
    }

    println!(" - Result: \"{}\"", result_string);
}
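// A self-contained restatement of the nucleus construction used in
// `generate_text` above, as a testable helper: keep the smallest set of
// highest-probability tokens whose cumulative mass reaches `top_p`
// (`nucleus` is a hypothetical helper for illustration, not crate API):
#[cfg(test)]
mod sampling_tests {
    fn nucleus(probs: &[f32], top_p: f32) -> Vec<usize> {
        let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
        // Sort descending by probability, as the generation loop does.
        indexed.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
        let mut cumulative = 0.0;
        let mut kept = Vec::new();
        for (i, p) in indexed {
            kept.push(i);
            cumulative += p;
            if cumulative >= top_p {
                break;
            }
        }
        kept
    }

    #[test]
    fn keeps_smallest_prefix_reaching_top_p() {
        // 0.5 + 0.3 = 0.8 >= 0.75, so the nucleus is the two largest tokens.
        let probs = [0.5, 0.3, 0.15, 0.05];
        assert_eq!(nucleus(&probs, 0.75), vec![0, 1]);
    }
}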
