Skip to main content
Glama
8b-is
by 8b-is
embedding.rs (1.84 kB)
//! Embedding module. use crate::nn::Module; use crate::tensor::Tensor; pub struct Embedding { weights: Tensor, } impl Embedding { pub fn new(vocab_size: usize, embed_dim: usize) -> Self { println!( "INFO: Initializing Embedding layer with vocab_size={}, embed_dim={}", vocab_size, embed_dim ); // Initialize weights randomly. This matrix will be learned during training. let weights = Tensor::rand(vec![vocab_size, embed_dim]); Self { weights } } } impl Module for Embedding { fn forward(&self, input: &Tensor) -> Tensor { // Input can be 1D [seq_len] or 2D [batch_size, seq_len]. let input_shape = input.shape(); let embed_dim = self.weights.shape()[1]; if input_shape.len() == 1 { // Handle 1D input (single sequence) for backward compatibility or inference. self.weights.gather(input, 0) } else if input_shape.len() == 2 { // Handle 2D input (batch of sequences). let batch_size = input_shape[0]; let seq_len = input_shape[1]; // 1. Reshape [batch_size, seq_len] -> [batch_size * seq_len] let flat_input = input.reshape(vec![batch_size * seq_len]); // 2. Gather embeddings: [batch_size * seq_len] -> [batch_size * seq_len, embed_dim] let gathered = self.weights.gather(&flat_input, 0); // 3. Reshape back to [batch_size, seq_len, embed_dim] gathered.reshape(vec![batch_size, seq_len, embed_dim]) } else { panic!( "Embedding input must be a 1D or 2D tensor of token IDs, but got shape {:?}", input_shape ); } } fn parameters(&self) -> Vec<Tensor> { vec![self.weights.clone()] } }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/8b-is/smart-tree'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.