gpu.rs
//! GPU compute backend using wgpu.

use bytemuck;
use ndarray::{s, Array, ArrayD, Axis, Ix2};
use once_cell::sync::Lazy;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Instant;
use wgpu::util::DeviceExt;

/// Holds the wgpu device, queue, and limits for GPU operations.
pub struct GpuContext {
    pub device: wgpu::Device,
    pub queue: wgpu::Queue,
    pub limits: wgpu::Limits,
}

impl GpuContext {
    /// Tries to initialize a new GpuContext.
    /// This is an async function that is run inside a blocking context.
    fn new() -> Option<Self> {
        pollster::block_on(async {
            let instance = wgpu::Instance::new(wgpu::InstanceDescriptor::default());
            let adapter = instance
                .request_adapter(&wgpu::RequestAdapterOptions {
                    power_preference: wgpu::PowerPreference::HighPerformance,
                    compatible_surface: None,
                    force_fallback_adapter: false,
                })
                .await?;

            println!("INFO: Using GPU adapter: {}", adapter.get_info().name);

            let limits = adapter.limits();
            let (device, queue) = adapter
                .request_device(
                    &wgpu::DeviceDescriptor {
                        label: Some("RustyFlow GPU Device"),
                        required_features: wgpu::Features::empty(),
                        required_limits: limits.clone(),
                    },
                    None,
                )
                .await
                .ok()?;

            Some(Self { device, queue, limits })
        })
    }
}

/// A lazily initialized, global GPU context.
/// It will attempt to find a GPU device the first time it's accessed.
pub static GPU_CONTEXT: Lazy<Option<GpuContext>> = Lazy::new(GpuContext::new);

/// A global flag to enable or disable GPU acceleration.
/// This is set by the application at startup.
pub static USE_GPU: AtomicBool = AtomicBool::new(false);

/// An atomic counter to track the total time spent waiting on GPU matmul kernels.
pub static TOTAL_GPU_TIME_NS: AtomicU64 = AtomicU64::new(0);

/// Performs 2D matrix multiplication on the GPU, with tiling for large matrices.
pub fn gpu_matmul_2d(
    gpu_context: &GpuContext,
    a: &ArrayD<f32>,
    b: &ArrayD<f32>,
) -> Result<ArrayD<f32>, &'static str> {
    let a_shape = a.shape();
    let b_shape = b.shape();

    if a.ndim() != 2 || b.ndim() != 2 {
        return Err("GPU matmul currently only supports 2D matrices.");
    }
    if a_shape[1] != b_shape[0] {
        return Err("Incompatible dimensions for matrix multiplication.");
    }

    let m = a_shape[0];
    let n = b_shape[1];
    let output_size_bytes = m * n * std::mem::size_of::<f32>();
    let limit = gpu_context.limits.max_storage_buffer_binding_size as usize;

    if output_size_bytes <= limit {
        // If it fits, run the operation in one go.
        return run_matmul_shader(gpu_context, a, b);
    }

    // --- Tiling logic ---
    println!(
        "INFO: [GPU] Output size ({:.2} MB) exceeds limit ({:.2} MB). Tiling matmul.",
        output_size_bytes as f32 / (1024.0 * 1024.0),
        limit as f32 / (1024.0 * 1024.0)
    );

    let bytes_per_output_column = m * std::mem::size_of::<f32>();
    if bytes_per_output_column > limit {
        return Err("Cannot tile matrix multiplication: a single output column is too large for GPU memory limits.");
    }

    let mut n_tile = limit / bytes_per_output_column;
    let workgroup_size_y = 8;
    n_tile = (n_tile / workgroup_size_y) * workgroup_size_y;
    if n_tile == 0 {
        // This can happen if limit / bytes_per_output_column is < 8.
        // We must process at least one column. The shader can handle this with boundary checks.
        n_tile = (limit / bytes_per_output_column).max(1);
    }

    let a_dyn = a.clone();
    let b_2d = b.clone().into_dimensionality::<Ix2>().unwrap();
    let mut result_tiles = Vec::new();

    for n_start in (0..n).step_by(n_tile) {
        let n_end = (n_start + n_tile).min(n);
        let b_tile_view = b_2d.slice(s![.., n_start..n_end]);
        let b_tile_owned = b_tile_view.to_owned().into_dyn();

        // This call is NOT recursive because `run_matmul_shader` doesn't do tiling.
        let c_tile = run_matmul_shader(gpu_context, &a_dyn, &b_tile_owned)?;
        result_tiles.push(c_tile.into_dimensionality::<Ix2>().unwrap());
    }

    // Concatenate tiles on CPU
    let views: Vec<_> = result_tiles.iter().map(|a| a.view()).collect();
    let final_result = ndarray::concatenate(Axis(1), &views)
        .map_err(|_| "Failed to concatenate result tiles on CPU.")?
        .into_dyn();

    Ok(final_result)
}

/// The core GPU matmul logic that runs a single compute shader.
fn run_matmul_shader(
    gpu_context: &GpuContext,
    a: &ArrayD<f32>,
    b: &ArrayD<f32>,
) -> Result<ArrayD<f32>, &'static str> {
    let GpuContext { device, queue, .. } = gpu_context;
    let a_shape = a.shape();
    let b_shape = b.shape();
    let m = a_shape[0];
    let k = a_shape[1];
    let n = b_shape[1];

    let a_2d = a.clone().into_dimensionality::<Ix2>().unwrap();
    let b_2d = b.clone().into_dimensionality::<Ix2>().unwrap();

    let shader_code = format!(
        r#"
struct Matrix {{
    data: array<f32>,
}};

@group(0) @binding(0) var<storage, read> a: Matrix;
@group(0) @binding(1) var<storage, read> b: Matrix;
@group(0) @binding(2) var<storage, read_write> c: Matrix;

@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let m: u32 = {}u;
    let k: u32 = {}u;
    let n: u32 = {}u;

    if (global_id.x >= m || global_id.y >= n) {{
        return;
    }}

    var sum = 0.0;
    for (var i = 0u; i < k; i = i + 1u) {{
        sum = sum + a.data[global_id.x * k + i] * b.data[i * n + global_id.y];
    }}
    c.data[global_id.x * n + global_id.y] = sum;
}}
"#,
        m, k, n
    );

    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Matmul Shader"),
        source: wgpu::ShaderSource::Wgsl(shader_code.into()),
    });

    let a_contig = a_2d.as_standard_layout();
    let b_contig = b_2d.as_standard_layout();
    let a_slice = a_contig
        .as_slice()
        .ok_or("Matrix A is not contiguous in memory for GPU transfer")?;
    let b_slice = b_contig
        .as_slice()
        .ok_or("Matrix B is not contiguous in memory for GPU transfer")?;

    let a_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Matrix A Buffer"),
        contents: bytemuck::cast_slice(a_slice),
        usage: wgpu::BufferUsages::STORAGE,
    });
    let b_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Matrix B Buffer"),
        contents: bytemuck::cast_slice(b_slice),
        usage: wgpu::BufferUsages::STORAGE,
    });

    let c_buffer_size = (m * n * std::mem::size_of::<f32>()) as u64;
    let c_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Matrix C Buffer"),
        size: c_buffer_size,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some("Matmul Bind Group Layout"),
        entries: &[
            wgpu::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 1,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 2,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
        ],
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("Matmul Pipeline Layout"),
        bind_group_layouts: &[&bind_group_layout],
        push_constant_ranges: &[],
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("Matmul Pipeline"),
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: "main",
    });

    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("Matmul Bind Group"),
        layout: &bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 2, resource: c_buffer.as_entire_binding() },
        ],
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("Matmul Command Encoder"),
    });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("Matmul Compute Pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);

        let workgroup_size_x = 8;
        let workgroup_size_y = 8;
        compute_pass.dispatch_workgroups(
            (m as u32 + workgroup_size_x - 1) / workgroup_size_x,
            (n as u32 + workgroup_size_y - 1) / workgroup_size_y,
            1,
        );
    }

    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Staging Buffer"),
        size: c_buffer_size,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    encoder.copy_buffer_to_buffer(&c_buffer, 0, &staging_buffer, 0, c_buffer_size);

    queue.submit(std::iter::once(encoder.finish()));

    let buffer_slice = staging_buffer.slice(..);
    let (tx, rx) = std::sync::mpsc::channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        tx.send(result).unwrap();
    });

    let wait_start = Instant::now();
    device.poll(wgpu::Maintain::Wait);
    rx.recv().unwrap().unwrap();
    TOTAL_GPU_TIME_NS.fetch_add(wait_start.elapsed().as_nanos() as u64, Ordering::Relaxed);

    let data = buffer_slice.get_mapped_range();
    let result_vec: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
    drop(data);
    staging_buffer.unmap();

    Ok(ArrayD::from_shape_vec(vec![m, n], result_vec).unwrap())
}

/// Performs batched 3D matrix multiplication on the GPU.
/// Shapes: a[B, M, K], b[B, K, N] -> c[B, M, N]
pub fn gpu_matmul_3d(
    gpu_context: &GpuContext,
    a: &ArrayD<f32>,
    b: &ArrayD<f32>,
) -> Result<ArrayD<f32>, &'static str> {
    let GpuContext { device, queue, .. } = gpu_context;
    let a_shape = a.shape();
    let b_shape = b.shape();

    if a.ndim() != 3 || b.ndim() != 3 {
        return Err("GPU matmul_3d requires 3D matrices.");
    }
    if a_shape[0] != b_shape[0] || a_shape[2] != b_shape[1] {
        return Err("Incompatible dimensions for batched matrix multiplication.");
    }

    // Note: This 3D matmul could also be tiled if B * M * N is too large.
    // For now, we assume it fits, as the most common bottleneck is 2D matmul with large vocab.
    let batch_size = a_shape[0];
    let m = a_shape[1];
    let k = a_shape[2];
    let n = b_shape[2];

    let shader_code = format!(
        r#"
struct Matrix {{
    data: array<f32>,
}};

@group(0) @binding(0) var<storage, read> a: Matrix;
@group(0) @binding(1) var<storage, read> b: Matrix;
@group(0) @binding(2) var<storage, read_write> c: Matrix;

@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let B: u32 = {}u;
    let M: u32 = {}u;
    let K: u32 = {}u;
    let N: u32 = {}u;

    let b_idx = global_id.z;
    let m_idx = global_id.x;
    let n_idx = global_id.y;

    if (m_idx >= M || n_idx >= N || b_idx >= B) {{
        return;
    }}

    var sum = 0.0;
    for (var k_idx = 0u; k_idx < K; k_idx = k_idx + 1u) {{
        let a_index = b_idx * M * K + m_idx * K + k_idx;
        let b_index = b_idx * K * N + k_idx * N + n_idx;
        sum = sum + a.data[a_index] * b.data[b_index];
    }}

    let c_index = b_idx * M * N + m_idx * N + n_idx;
    c.data[c_index] = sum;
}}
"#,
        batch_size, m, k, n
    );

    let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("Batched Matmul Shader"),
        source: wgpu::ShaderSource::Wgsl(shader_code.into()),
    });

    let a_contig = a.as_standard_layout();
    let b_contig = b.as_standard_layout();
    let a_slice = a_contig
        .as_slice()
        .ok_or("Matrix A is not contiguous in memory for GPU transfer")?;
    let b_slice = b_contig
        .as_slice()
        .ok_or("Matrix B is not contiguous in memory for GPU transfer")?;

    let a_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Matrix A Buffer (3D)"),
        contents: bytemuck::cast_slice(a_slice),
        usage: wgpu::BufferUsages::STORAGE,
    });
    let b_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("Matrix B Buffer (3D)"),
        contents: bytemuck::cast_slice(b_slice),
        usage: wgpu::BufferUsages::STORAGE,
    });

    let c_buffer_size = (batch_size * m * n * std::mem::size_of::<f32>()) as u64;
    let c_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Matrix C Buffer (3D)"),
        size: c_buffer_size,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: Some("Batched Matmul Bind Group Layout"),
        entries: &[
            wgpu::BindGroupLayoutEntry {
                binding: 0,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 1,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: true },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
            wgpu::BindGroupLayoutEntry {
                binding: 2,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage { read_only: false },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            },
        ],
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: Some("Batched Matmul Pipeline Layout"),
        bind_group_layouts: &[&bind_group_layout],
        push_constant_ranges: &[],
    });

    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("Batched Matmul Pipeline"),
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: "main",
    });

    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("Batched Matmul Bind Group"),
        layout: &bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 2, resource: c_buffer.as_entire_binding() },
        ],
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("Batched Matmul Command Encoder"),
    });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("Batched Matmul Compute Pass"),
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);

        let workgroup_size_x = 8;
        let workgroup_size_y = 8;
        compute_pass.dispatch_workgroups(
            (m as u32 + workgroup_size_x - 1) / workgroup_size_x,
            (n as u32 + workgroup_size_y - 1) / workgroup_size_y,
            batch_size as u32,
        );
    }

    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Staging Buffer (3D)"),
        size: c_buffer_size,
        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    encoder.copy_buffer_to_buffer(&c_buffer, 0, &staging_buffer, 0, c_buffer_size);

    queue.submit(std::iter::once(encoder.finish()));

    let buffer_slice = staging_buffer.slice(..);
    let (tx, rx) = std::sync::mpsc::channel();
    buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
        tx.send(result).unwrap();
    });

    let wait_start = Instant::now();
    device.poll(wgpu::Maintain::Wait);
    rx.recv().unwrap().unwrap();
    TOTAL_GPU_TIME_NS.fetch_add(wait_start.elapsed().as_nanos() as u64, Ordering::Relaxed);

    let data = buffer_slice.get_mapped_range();
    let result_vec: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
    drop(data);
    staging_buffer.unmap();

    Ok(ArrayD::from_shape_vec(vec![batch_size, m, n], result_vec).unwrap())
}
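
Below is a minimal usage sketch, not part of gpu.rs itself. It assumes a hypothetical dispatch helper (`matmul_dispatch`) living elsewhere in the crate that consults the `USE_GPU` flag and the lazy `GPU_CONTEXT` before calling `gpu_matmul_2d`; the CPU fallback branch is only indicated, since it is outside this file.

use ndarray::ArrayD;
use std::sync::atomic::Ordering;
// Assumes the items defined in gpu.rs are in scope, e.g.
// use crate::gpu::{gpu_matmul_2d, GPU_CONTEXT, USE_GPU};

/// Hypothetical dispatch helper: run on the GPU only when the flag is set and a
/// device was actually found; otherwise signal that a CPU path is needed.
fn matmul_dispatch(a: &ArrayD<f32>, b: &ArrayD<f32>) -> Result<ArrayD<f32>, &'static str> {
    if USE_GPU.load(Ordering::Relaxed) {
        if let Some(ctx) = GPU_CONTEXT.as_ref() {
            // gpu_matmul_2d tiles along the output's N axis on its own when
            // m * n * 4 bytes exceeds max_storage_buffer_binding_size.
            return gpu_matmul_2d(ctx, a, b);
        }
    }
    Err("GPU disabled or unavailable; a CPU matmul fallback would go here")
}

Keeping the flag check at the call site matches how this file exposes `GPU_CONTEXT` and `USE_GPU` as globals: the first access to `GPU_CONTEXT` triggers adapter discovery, and subsequent calls reuse the same device and queue.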
