by 8b-is
loss.rs (1.29 kB)
//! Loss functions.

use crate::tensor::Tensor;

/// Computes the cross-entropy loss between logits and targets.
///
/// This is the standard loss function for multi-class classification.
/// It is equivalent to applying `log_softmax` to the logits and then `nll_loss`.
///
/// # Arguments
/// * `logits` - The raw, unnormalized scores from the model. Shape `[..., num_classes]`.
/// * `targets` - The ground truth labels (class indices). Shape `[...]`.
///
/// `targets` must have the same shape as `logits` without the final class
/// dimension, so that both flatten to the same number of rows `N`.
pub fn cross_entropy_loss(logits: &Tensor, targets: &Tensor) -> Tensor {
    let logits_shape = logits.shape();
    let num_classes = logits_shape[logits_shape.len() - 1];

    // Reshape logits to 2D ([N, num_classes]) and targets to 1D ([N]),
    // where N is the product of all other dimensions (e.g., batch_size * seq_len).
    let n: usize = logits_shape[..logits_shape.len() - 1].iter().product();
    let logits_2d = logits.reshape(vec![n, num_classes]);
    let targets_1d = targets.reshape(vec![n]);

    // 1. Calculate log_softmax of the logits over the class dimension (axis 1).
    let log_probs = logits_2d.log_softmax(1);

    // 2. Compute the negative log-likelihood loss against the class indices.
    log_probs.nll_loss(&targets_1d)
}
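The `log_softmax` then `nll_loss` decomposition described in the doc comment can be illustrated without the crate's `Tensor` type. The following is a minimal sketch on plain slices, not the crate's implementation: the helper names `log_softmax_row` and `cross_entropy_mean` are hypothetical, the logits are assumed to be already flattened to `N * num_classes` values in row-major order, and mean reduction over the `N` rows is an assumption, since the source does not say how `nll_loss` reduces.

```rust
/// Numerically stable log-softmax over a single row of logits.
/// (Hypothetical helper; stands in for `Tensor::log_softmax` along the class axis.)
fn log_softmax_row(row: &[f32]) -> Vec<f32> {
    // Subtract the row maximum before exponentiating to avoid overflow.
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = row.iter().map(|&x| (x - max).exp()).sum::<f32>().ln() + max;
    row.iter().map(|&x| x - log_sum_exp).collect()
}

/// Cross-entropy over flattened logits of length `N * num_classes`.
/// Assumption: the per-row losses are averaged (mean reduction).
fn cross_entropy_mean(logits: &[f32], targets: &[usize], num_classes: usize) -> f32 {
    assert!(!targets.is_empty(), "need at least one target");
    assert_eq!(logits.len(), targets.len() * num_classes, "shape mismatch");

    let total: f32 = logits
        .chunks(num_classes) // one chunk per row, i.e. the [N, num_classes] view
        .zip(targets)
        // Negative log-probability assigned to the correct class of each row.
        .map(|(row, &t)| -log_softmax_row(row)[t])
        .sum();

    total / targets.len() as f32
}
```

With uniform logits over two classes, for example `cross_entropy_mean(&[0.0, 0.0, 0.0, 0.0], &[0, 1], 2)`, every class has probability 1/2, so the loss is ln 2 regardless of the targets.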
