use core::{mem::MaybeUninit, ptr};
use zerocopy::transmute;
#[cfg(target_arch = "x86")]
use core::arch::x86::{
__m128, __m128i, __m256, _mm256_cvtph_ps, _mm256_cvtps_ph, _mm_cvtph_ps,
_MM_FROUND_TO_NEAREST_INT,
};
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::{
__m128, __m128i, __m256, _mm256_cvtph_ps, _mm256_cvtps_ph, _mm_cvtph_ps, _mm_cvtps_ph,
_MM_FROUND_TO_NEAREST_INT,
};
#[cfg(target_arch = "x86")]
use core::arch::x86::_mm_cvtps_ph;
use super::convert_chunked_slice_8;
/////////////// x86/x86_64 f16c ////////////////
/// Converts one binary16 bit pattern to `f32` using the F16C `_mm_cvtph_ps` intrinsic.
///
/// The half is placed in lane 0 of an otherwise-zero 128-bit vector, and only lane 0 of
/// the converted vector is returned.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `f16c` target feature.
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f16_to_f32_x86_f16c(i: u16) -> f32 {
    let mut lanes = [0u16; 8];
    lanes[0] = i;
    let packed: __m128i = transmute!(lanes);
    let [first, _, _, _]: [f32; 4] = transmute!(_mm_cvtph_ps(packed));
    first
}
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f32_to_f16_x86_f16c(f: f32) -> u16 {
let vec: __m128 = transmute!([f, 0.0, 0.0, 0.0]);
let retval = _mm_cvtps_ph(vec, _MM_FROUND_TO_NEAREST_INT);
let retval: [u16; 8] = transmute!(retval);
retval[0]
}
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f16x4_to_f32x4_x86_f16c(v: &[u16; 4]) -> [f32; 4] {
let vec: __m128i = transmute!([*v, [0, 0, 0, 0]]);
transmute!(_mm_cvtph_ps(vec))
}
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f32x4_to_f16x4_x86_f16c(v: &[f32; 4]) -> [u16; 4] {
let vec: __m128 = zerocopy::transmute!(*v);
let retval = _mm_cvtps_ph(vec, _MM_FROUND_TO_NEAREST_INT);
let retval: [[u16; 4]; 2] = transmute!(retval);
retval[0]
}
/// Converts four binary16 bit patterns to `f64`.
///
/// The halves are widened to `f32` with F16C and then cast to `f64`. Both steps are
/// exact, so no rounding occurs.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `f16c` target feature.
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f16x4_to_f64x4_x86_f16c(v: &[u16; 4]) -> [f64; 4] {
    let [a, b, c, d] = f16x4_to_f32x4_x86_f16c(v);
    // The f32 -> f64 widening is a plain cast that the compiler may vectorize.
    // TODO: investigate auto-detecting sse2/avx convert features
    [f64::from(a), f64::from(b), f64::from(c), f64::from(d)]
}
/// Converts four `f64` values to binary16 bit patterns using F16C.
///
/// F16C can only narrow from `f32`, so each input is first narrowed to `f32`. A plain
/// `as f32` cast rounds to nearest, and rounding a second time to `f16` can then give a
/// wrong answer (double rounding).
///
/// Example: `1.0 + 2^-11 + 2^-40` lies just above the midpoint between two adjacent
/// `f16` values. `as f32` lands exactly on that midpoint, and the ties-to-even `f16`
/// rounding then goes down instead of up.
///
/// The intermediate narrowing therefore rounds to odd. Because `f32` carries more than
/// two extra significand bits over `f16`, the final round-to-nearest is then correct.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `f16c` target feature.
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f64x4_to_f16x4_x86_f16c(v: &[f64; 4]) -> [u16; 4] {
    // Narrows `x` to `f32` with round-to-odd semantics. Exact conversions are returned
    // unchanged; inexact ones are truncated toward zero and then have the
    // least-significant significand bit forced to 1.
    #[inline]
    fn narrow_round_to_odd(x: f64) -> f32 {
        let nearest = x as f32;
        // Exact results (including zeros and infinities) and NaN need no fix-up.
        if x.is_nan() || f64::from(nearest) == x {
            return nearest;
        }
        // `as` rounds to nearest. If that increased the magnitude, step one ulp back
        // toward zero to obtain the truncated value. For a finite `x` beyond the f32
        // range this also turns the infinity into `f32::MAX`.
        let widened = f64::from(nearest);
        let rounded_away = if x > 0.0 { widened > x } else { widened < x };
        let mut bits = nearest.to_bits();
        if rounded_away {
            bits -= 1;
        }
        // The sticky bit records that the conversion was inexact.
        f32::from_bits(bits | 1)
    }

    // TODO: investigate auto-detecting sse2/avx convert features
    let narrowed = [
        narrow_round_to_odd(v[0]),
        narrow_round_to_odd(v[1]),
        narrow_round_to_odd(v[2]),
        narrow_round_to_odd(v[3]),
    ];
    f32x4_to_f16x4_x86_f16c(&narrowed)
}
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f16x8_to_f32x8_x86_f16c(v: &[u16; 8]) -> [f32; 8] {
let vec: __m128i = transmute!(*v);
transmute!(_mm256_cvtph_ps(vec))
}
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f32x8_to_f16x8_x86_f16c(v: &[f32; 8]) -> [u16; 8] {
let vec: __m256 = transmute!(*v);
let retval = _mm256_cvtps_ph(vec, _MM_FROUND_TO_NEAREST_INT);
transmute!(retval)
}
/// Converts eight binary16 bit patterns to `f64`.
///
/// The halves are widened to `f32` with F16C and then cast to `f64`. Both steps are
/// exact, so no rounding occurs.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `f16c` target feature.
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f16x8_to_f64x8_x86_f16c(v: &[u16; 8]) -> [f64; 8] {
    let singles = f16x8_to_f32x8_x86_f16c(v);
    // The f32 -> f64 widening is a plain cast that the compiler may vectorize.
    // TODO: investigate auto-detecting sse2/avx convert features
    let mut doubles = [0.0f64; 8];
    for (dst, src) in doubles.iter_mut().zip(singles) {
        *dst = f64::from(src);
    }
    doubles
}
/// Converts eight `f64` values to binary16 bit patterns using F16C.
///
/// F16C can only narrow from `f32`, so each input is first narrowed to `f32`. A plain
/// `as f32` cast rounds to nearest, and rounding a second time to `f16` can then give a
/// wrong answer (double rounding).
///
/// Example: `1.0 + 2^-11 + 2^-40` lies just above the midpoint between two adjacent
/// `f16` values. `as f32` lands exactly on that midpoint, and the ties-to-even `f16`
/// rounding then goes down instead of up.
///
/// The intermediate narrowing therefore rounds to odd. Because `f32` carries more than
/// two extra significand bits over `f16`, the final round-to-nearest is then correct.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `f16c` target feature.
#[target_feature(enable = "f16c")]
#[inline]
pub(super) unsafe fn f64x8_to_f16x8_x86_f16c(v: &[f64; 8]) -> [u16; 8] {
    // Narrows `x` to `f32` with round-to-odd semantics. Exact conversions are returned
    // unchanged; inexact ones are truncated toward zero and then have the
    // least-significant significand bit forced to 1.
    #[inline]
    fn narrow_round_to_odd(x: f64) -> f32 {
        let nearest = x as f32;
        // Exact results (including zeros and infinities) and NaN need no fix-up.
        if x.is_nan() || f64::from(nearest) == x {
            return nearest;
        }
        // `as` rounds to nearest. If that increased the magnitude, step one ulp back
        // toward zero to obtain the truncated value. For a finite `x` beyond the f32
        // range this also turns the infinity into `f32::MAX`.
        let widened = f64::from(nearest);
        let rounded_away = if x > 0.0 { widened > x } else { widened < x };
        let mut bits = nearest.to_bits();
        if rounded_away {
            bits -= 1;
        }
        // The sticky bit records that the conversion was inexact.
        f32::from_bits(bits | 1)
    }

    // TODO: investigate auto-detecting sse2/avx convert features
    let mut narrowed = [0.0f32; 8];
    for (dst, &src) in narrowed.iter_mut().zip(v) {
        *dst = narrow_round_to_odd(src);
    }
    f32x8_to_f16x8_x86_f16c(&narrowed)
}