use crate::Result;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// SIMD-optimized vector operations for high-performance embedding computations
/// Provides a roughly 4-8x speedup over scalar operations using AVX2/FMA instructions
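///
/// Illustrative usage sketch (not compiled as a doctest; the vector contents
/// are placeholders). `adaptive_cosine_similarity` picks the AVX2 path or the
/// scalar fallback at runtime:
///
/// ```ignore
/// let a = vec![1.0_f32; 128];
/// let b = vec![0.5_f32; 128];
/// let similarity = SIMDVectorOps::adaptive_cosine_similarity(&a, &b)?;
/// assert!((-1.0..=1.0).contains(&similarity));
/// ```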
pub struct SIMDVectorOps;
impl SIMDVectorOps {
/// Optimized cosine similarity using AVX2 (8 f32 operations in parallel)
    /// Performance: typically ~4x faster than the scalar implementation
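    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2 and FMA (e.g. by checking
    /// `SIMDVectorOps::is_avx2_available`) before calling this function.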
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
pub unsafe fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> Result<f32> {
if a.len() != b.len() {
return Err(crate::VectorError::DimensionMismatch(a.len(), b.len()).into());
}
let len = a.len();
if len == 0 {
return Ok(0.0);
}
let mut dot_product = _mm256_setzero_ps();
let mut norm_a_squared = _mm256_setzero_ps();
let mut norm_b_squared = _mm256_setzero_ps();
// Process 8 elements at a time
let chunks = len / 8;
for i in 0..chunks {
let idx = i * 8;
// Load 8 f32 values from each array
let va = _mm256_loadu_ps(a.as_ptr().add(idx));
let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
// Parallel operations:
// dot_product += va * vb (using FMA)
dot_product = _mm256_fmadd_ps(va, vb, dot_product);
// norm_a_squared += va * va (using FMA)
norm_a_squared = _mm256_fmadd_ps(va, va, norm_a_squared);
// norm_b_squared += vb * vb (using FMA)
norm_b_squared = _mm256_fmadd_ps(vb, vb, norm_b_squared);
}
// Horizontal sum of the 8-element vectors
let dp = Self::horizontal_sum_avx2(dot_product);
let na_sq = Self::horizontal_sum_avx2(norm_a_squared);
let nb_sq = Self::horizontal_sum_avx2(norm_b_squared);
// Handle remaining elements (scalar fallback)
let mut dp_remainder = 0.0f32;
let mut na_sq_remainder = 0.0f32;
let mut nb_sq_remainder = 0.0f32;
for i in (chunks * 8)..len {
let va = a[i];
let vb = b[i];
dp_remainder += va * vb;
na_sq_remainder += va * va;
nb_sq_remainder += vb * vb;
}
let final_dp = dp + dp_remainder;
let final_na_sq = na_sq + na_sq_remainder;
let final_nb_sq = nb_sq + nb_sq_remainder;
// Compute cosine similarity
let norm_product = (final_na_sq * final_nb_sq).sqrt();
if norm_product == 0.0 {
Ok(0.0)
} else {
Ok(final_dp / norm_product)
}
}
/// Batch cosine similarity computation for multiple queries
/// Optimized for processing large query batches efficiently
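    ///
    /// # Safety
    ///
    /// Same requirement as `cosine_similarity_avx2`: the CPU must support AVX2 and FMA.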
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
pub unsafe fn batch_cosine_similarity_avx2(
query: &[f32],
embeddings: &[&[f32]],
results: &mut [f32],
) -> Result<()> {
if embeddings.len() != results.len() {
return Err(crate::VectorError::BatchSizeMismatch.into());
}
for (embedding, result) in embeddings.iter().zip(results.iter_mut()) {
*result = Self::cosine_similarity_avx2(query, embedding)?;
}
Ok(())
}
/// Optimized L2 distance computation using AVX2
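    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2 and FMA.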
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
pub unsafe fn l2_distance_avx2(a: &[f32], b: &[f32]) -> Result<f32> {
if a.len() != b.len() {
return Err(crate::VectorError::DimensionMismatch(a.len(), b.len()).into());
}
let len = a.len();
if len == 0 {
return Ok(0.0);
}
let mut sum_squared_diff = _mm256_setzero_ps();
// Process 8 elements at a time
let chunks = len / 8;
for i in 0..chunks {
let idx = i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(idx));
let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
// Compute difference: va - vb
let diff = _mm256_sub_ps(va, vb);
// Square the differences and add to sum: sum += diff^2
sum_squared_diff = _mm256_fmadd_ps(diff, diff, sum_squared_diff);
}
// Horizontal sum
let sum_sq = Self::horizontal_sum_avx2(sum_squared_diff);
// Handle remaining elements
let mut remainder_sum = 0.0f32;
for i in (chunks * 8)..len {
let diff = a[i] - b[i];
remainder_sum += diff * diff;
}
Ok((sum_sq + remainder_sum).sqrt())
}
/// Optimized dot product computation using AVX2
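    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2 and FMA.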
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
pub unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> Result<f32> {
if a.len() != b.len() {
return Err(crate::VectorError::DimensionMismatch(a.len(), b.len()).into());
}
let len = a.len();
if len == 0 {
return Ok(0.0);
}
let mut dot_product = _mm256_setzero_ps();
// Process 8 elements at a time
let chunks = len / 8;
for i in 0..chunks {
let idx = i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(idx));
let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
// dot_product += va * vb
dot_product = _mm256_fmadd_ps(va, vb, dot_product);
}
// Horizontal sum
let dp = Self::horizontal_sum_avx2(dot_product);
// Handle remaining elements
let mut remainder = 0.0f32;
for i in (chunks * 8)..len {
remainder += a[i] * b[i];
}
Ok(dp + remainder)
}
/// Normalize vector in-place using AVX2
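    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2 and FMA.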
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[target_feature(enable = "fma")]
pub unsafe fn normalize_avx2(vector: &mut [f32]) -> Result<()> {
if vector.is_empty() {
return Ok(());
}
// Compute norm
let norm_squared = Self::dot_product_avx2(vector, vector)?;
if norm_squared == 0.0 {
return Ok(()); // Zero vector remains zero
}
let norm = norm_squared.sqrt();
let inv_norm = 1.0 / norm;
let inv_norm_vec = _mm256_set1_ps(inv_norm);
// Normalize 8 elements at a time
let len = vector.len();
let chunks = len / 8;
for i in 0..chunks {
let idx = i * 8;
let v = _mm256_loadu_ps(vector.as_ptr().add(idx));
let normalized = _mm256_mul_ps(v, inv_norm_vec);
_mm256_storeu_ps(vector.as_mut_ptr().add(idx), normalized);
}
// Handle remaining elements
for i in (chunks * 8)..len {
vector[i] *= inv_norm;
}
Ok(())
}
/// Efficient horizontal sum of 8 f32 values in AVX2 register
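    ///
    /// Worked example (illustrative numbers): for v = [1, 2, 3, 4, 5, 6, 7, 8],
    /// the cross-lane permute/add leaves [6, 8, 10, 12] in each 128-bit lane,
    /// the first `hadd` yields [14, 22, 14, 22], and the second `hadd` places
    /// the total (36) in the lowest element.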
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_avx2(v: __m256) -> f32 {
// v = [a, b, c, d, e, f, g, h]
// Permute and add to get [e+a, f+b, g+c, h+d, a+e, b+f, c+g, d+h]
let v_perm = _mm256_permute2f128_ps(v, v, 0x01);
let v_add1 = _mm256_add_ps(v, v_perm);
        // Both 128-bit lanes now hold the pairwise sums [a+e, b+f, c+g, d+h]
// Horizontal add to get [e+a+f+b, g+c+h+d, *, *]
let v_hadd1 = _mm256_hadd_ps(v_add1, v_add1);
// Final horizontal add to get sum in lowest element
let v_hadd2 = _mm256_hadd_ps(v_hadd1, v_hadd1);
// Extract lowest element
_mm256_cvtss_f32(v_hadd2)
}
/// Check if AVX2 is available at runtime
pub fn is_avx2_available() -> bool {
#[cfg(target_arch = "x86_64")]
{
is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
}
#[cfg(not(target_arch = "x86_64"))]
{
false
}
}
/// Fallback scalar implementation for non-AVX2 systems
pub fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> Result<f32> {
if a.len() != b.len() {
return Err(crate::VectorError::DimensionMismatch(a.len(), b.len()).into());
}
let mut dot_product = 0.0f32;
let mut norm_a_squared = 0.0f32;
let mut norm_b_squared = 0.0f32;
for (&va, &vb) in a.iter().zip(b.iter()) {
dot_product += va * vb;
norm_a_squared += va * va;
norm_b_squared += vb * vb;
}
let norm_product = (norm_a_squared * norm_b_squared).sqrt();
if norm_product == 0.0 {
Ok(0.0)
} else {
Ok(dot_product / norm_product)
}
}
/// Adaptive similarity computation that chooses the best implementation
pub fn adaptive_cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32> {
#[cfg(target_arch = "x86_64")]
{
if Self::is_avx2_available() && a.len() >= 32 {
// Use SIMD for vectors with at least 32 elements
unsafe { Self::cosine_similarity_avx2(a, b) }
} else {
Self::cosine_similarity_scalar(a, b)
}
}
#[cfg(not(target_arch = "x86_64"))]
{
Self::cosine_similarity_scalar(a, b)
}
}
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn cosine_similarity_avx2(_a: &[f32], _b: &[f32]) -> Result<f32> {
Err(
crate::VectorError::SimdError("AVX2 not supported on this architecture".to_string())
.into(),
)
}
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn batch_cosine_similarity_avx2(
_query: &[f32],
_embeddings: &[&[f32]],
_results: &mut [f32],
) -> Result<()> {
Err(
crate::VectorError::SimdError("AVX2 not supported on this architecture".to_string())
.into(),
)
}
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn l2_distance_avx2(_a: &[f32], _b: &[f32]) -> Result<f32> {
Err(
crate::VectorError::SimdError("AVX2 not supported on this architecture".to_string())
.into(),
)
}
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn dot_product_avx2(_a: &[f32], _b: &[f32]) -> Result<f32> {
Err(
crate::VectorError::SimdError("AVX2 not supported on this architecture".to_string())
.into(),
)
}
#[cfg(not(target_arch = "x86_64"))]
pub unsafe fn normalize_avx2(_vector: &mut [f32]) -> Result<()> {
Err(
crate::VectorError::SimdError("AVX2 not supported on this architecture".to_string())
.into(),
)
}
}
/// Parallel vector operations using Rayon for CPU-bound tasks
pub struct ParallelVectorOps;
impl ParallelVectorOps {
/// Parallel batch similarity computation using thread pool
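    ///
    /// Illustrative sketch (not compiled as a doctest): any
    /// `fn(&[f32], &[f32]) -> Result<f32>` can be supplied, for example the
    /// scalar cosine similarity from `SIMDVectorOps`:
    ///
    /// ```ignore
    /// let query = vec![1.0_f32; 64];
    /// let embeddings = vec![vec![0.5_f32; 64]; 16];
    /// let scores = ParallelVectorOps::parallel_batch_similarity(
    ///     &query,
    ///     &embeddings,
    ///     SIMDVectorOps::cosine_similarity_scalar,
    /// )?;
    /// assert_eq!(scores.len(), embeddings.len());
    /// ```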
pub fn parallel_batch_similarity(
query: &[f32],
embeddings: &[Vec<f32>],
similarity_fn: fn(&[f32], &[f32]) -> Result<f32>,
) -> Result<Vec<f32>> {
use rayon::prelude::*;
embeddings
.par_iter()
.map(|embedding| similarity_fn(query, embedding))
.collect()
}
/// Parallel top-k similarity search
pub fn parallel_top_k_search(
query: &[f32],
embeddings: &[Vec<f32>],
k: usize,
) -> Result<Vec<(usize, f32)>> {
use rayon::prelude::*;
let mut similarities: Vec<(usize, f32)> = embeddings
.par_iter()
.enumerate()
.map(|(idx, embedding)| {
                // Map errors (e.g. dimension mismatch) to a similarity of 0.0 so
                // that one malformed embedding does not abort the whole search.
                let sim =
                    SIMDVectorOps::adaptive_cosine_similarity(query, embedding).unwrap_or(0.0);
(idx, sim)
})
.collect();
        // Sort by similarity (descending) and take top-k. `total_cmp` gives a
        // total order on f32, so a stray NaN cannot panic the sort.
        similarities.par_sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
similarities.truncate(k);
Ok(similarities)
}
/// Parallel vector normalization
pub fn parallel_normalize_vectors(vectors: &mut [Vec<f32>]) -> Result<()> {
use rayon::prelude::*;
vectors.par_iter_mut().try_for_each(|vector| {
#[cfg(target_arch = "x86_64")]
{
if SIMDVectorOps::is_avx2_available() {
unsafe { SIMDVectorOps::normalize_avx2(vector) }
} else {
// Scalar normalization fallback
let norm_squared: f32 = vector.iter().map(|&x| x * x).sum();
if norm_squared > 0.0 {
let norm = norm_squared.sqrt();
for x in vector.iter_mut() {
*x /= norm;
}
}
Ok(())
}
}
#[cfg(not(target_arch = "x86_64"))]
{
// Scalar normalization fallback
let norm_squared: f32 = vector.iter().map(|&x| x * x).sum();
if norm_squared > 0.0 {
let norm = norm_squared.sqrt();
for x in vector.iter_mut() {
*x /= norm;
}
}
Ok(())
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
#[test]
fn test_simd_cosine_similarity() {
let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let b = vec![8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
let scalar_result = SIMDVectorOps::cosine_similarity_scalar(&a, &b).unwrap();
#[cfg(target_arch = "x86_64")]
{
if SIMDVectorOps::is_avx2_available() {
let simd_result = unsafe { SIMDVectorOps::cosine_similarity_avx2(&a, &b).unwrap() };
assert_abs_diff_eq!(scalar_result, simd_result, epsilon = 1e-6);
println!(
"SIMD vs Scalar similarity: {} vs {}",
simd_result, scalar_result
);
}
}
}
#[test]
fn test_adaptive_similarity() {
let a: Vec<f32> = (0..100).map(|i| i as f32).collect();
let b: Vec<f32> = (0..100).map(|i| (100 - i) as f32).collect();
let result = SIMDVectorOps::adaptive_cosine_similarity(&a, &b).unwrap();
println!("Adaptive similarity result: {}", result);
// Should be valid similarity score
assert!(result >= -1.0 && result <= 1.0);
}
#[test]
fn test_parallel_operations() {
let query = vec![1.0; 256];
let embeddings: Vec<Vec<f32>> = (0..1000)
.map(|i| (0..256).map(|j| (i + j) as f32).collect())
.collect();
let results = ParallelVectorOps::parallel_top_k_search(&query, &embeddings, 10).unwrap();
assert_eq!(results.len(), 10);
println!("Top-10 parallel search results: {:?}", &results[0..3]);
}
#[test]
fn test_memory_alignment() {
use codegraph_core::optimized_types::AlignedVec;
let mut aligned_vec = AlignedVec::new_aligned(1024, 32);
for i in 0..512 {
aligned_vec.push(i as f32).unwrap();
}
// Test that the memory is properly aligned for SIMD operations
let ptr = aligned_vec.as_slice().as_ptr() as usize;
assert_eq!(ptr % 32, 0, "Vector not properly aligned for AVX2");
}
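    // Illustrative sketch: after parallel normalization every non-zero vector
    // should have approximately unit L2 norm (the inputs below are arbitrary).
    #[test]
    fn test_parallel_normalize_vectors() {
        let mut vectors: Vec<Vec<f32>> = (1..=8)
            .map(|i| (0..64).map(|j| (i * j) as f32).collect())
            .collect();
        ParallelVectorOps::parallel_normalize_vectors(&mut vectors).unwrap();
        for v in &vectors {
            let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
            assert_abs_diff_eq!(norm, 1.0f32, epsilon = 1e-4);
        }
    }
    // Illustrative sketch: cross-checks the AVX2 dot product and L2 distance
    // against plain scalar sums; length 37 exercises the scalar remainder path.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_dot_product_and_l2_distance_avx2() {
        if !SIMDVectorOps::is_avx2_available() {
            return;
        }
        let a: Vec<f32> = (0..37).map(|i| i as f32 * 0.5).collect();
        let b: Vec<f32> = (0..37).map(|i| (37 - i) as f32 * 0.25).collect();
        let scalar_dot: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        let scalar_l2: f32 = a
            .iter()
            .zip(&b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt();
        let simd_dot = unsafe { SIMDVectorOps::dot_product_avx2(&a, &b).unwrap() };
        let simd_l2 = unsafe { SIMDVectorOps::l2_distance_avx2(&a, &b).unwrap() };
        assert_abs_diff_eq!(simd_dot, scalar_dot, epsilon = 1e-2);
        assert_abs_diff_eq!(simd_l2, scalar_l2, epsilon = 1e-3);
    }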
}