utils/simd/avx512.rs

//! AVX512 vector implementations.
//!
//! Currently only requires AVX-512F.
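//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`: it assumes an AVX-512F capable
//! CPU and a crate-specific import path, so it is not verified as a doc-test):
//!
//! ```ignore
//! use utils::simd::avx512::avx512::U32Vector;
//!
//! let v = U32Vector::from([1u32; 16]);
//! // SAFETY: `rotate_left` requires AVX-512F at runtime.
//! let rotated = unsafe { v.rotate_left(7) };
//! let lanes: [u32; 16] = rotated.into();
//! assert!(lanes.iter().all(|&lane| lane == 1 << 7));
//! ```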

use std::array::from_fn;
use std::ops::{Add, BitAnd, BitOr, BitXor, Not};

#[cfg(target_arch = "x86_64")]
#[allow(clippy::allow_attributes, clippy::wildcard_imports)]
use std::arch::x86_64::*;

#[cfg(target_arch = "x86")]
#[allow(clippy::allow_attributes, clippy::wildcard_imports)]
use std::arch::x86::*;

/// AVX512 [u32] vector implementation.
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct U32Vector<const V: usize, const L: usize>([__m512i; V]);

impl<const V: usize, const L: usize> From<[u32; L]> for U32Vector<V, L> {
    #[inline]
    fn from(value: [u32; L]) -> Self {
        Self(from_fn(|i| unsafe {
            #[expect(
                clippy::cast_ptr_alignment,
                reason = "_mm512_loadu_si512 is an unaligned load which requires no alignment"
            )]
            _mm512_loadu_si512(value[i * 16..].as_ptr().cast::<__m512i>())
        }))
    }
}

impl<const V: usize, const L: usize> From<U32Vector<V, L>> for [u32; L] {
    #[inline]
    fn from(value: U32Vector<V, L>) -> Self {
        let mut result = [0; L];
        for (&v, r) in value.0.iter().zip(result.chunks_exact_mut(16)) {
            unsafe {
                #[expect(
                    clippy::cast_ptr_alignment,
                    reason = "_mm512_storeu_si512 is an unaligned store which requires no alignment"
                )]
                _mm512_storeu_si512(r.as_mut_ptr().cast::<__m512i>(), v);
            }
        }
        result
    }
}

impl<const V: usize, const L: usize> Add for U32Vector<V, L> {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        Self(from_fn(|i| unsafe {
            _mm512_add_epi32(self.0[i], rhs.0[i])
        }))
    }
}

impl<const V: usize, const L: usize> BitAnd for U32Vector<V, L> {
    type Output = Self;

    #[inline]
    fn bitand(self, rhs: Self) -> Self::Output {
        Self(from_fn(|i| unsafe {
            _mm512_and_si512(self.0[i], rhs.0[i])
        }))
    }
}

impl<const V: usize, const L: usize> BitOr for U32Vector<V, L> {
    type Output = Self;

    #[inline]
    fn bitor(self, rhs: Self) -> Self::Output {
        Self(from_fn(|i| unsafe { _mm512_or_si512(self.0[i], rhs.0[i]) }))
    }
}

impl<const V: usize, const L: usize> BitXor for U32Vector<V, L> {
    type Output = Self;

    #[inline]
    fn bitxor(self, rhs: Self) -> Self::Output {
        Self(from_fn(|i| unsafe {
            _mm512_xor_si512(self.0[i], rhs.0[i])
        }))
    }
}

impl<const V: usize, const L: usize> Not for U32Vector<V, L> {
    type Output = Self;

    #[inline]
    fn not(self) -> Self::Output {
        // XOR against an all-ones vector flips every bit, giving bitwise NOT.
        Self(from_fn(|i| unsafe {
            _mm512_xor_si512(self.0[i], _mm512_set1_epi8(!0))
        }))
    }
}

impl<const V: usize, const L: usize> U32Vector<V, L> {
    /// The number of [u32] lanes, checked at compile time against the `V`
    /// underlying 512-bit vectors (16 lanes each).
    pub const LANES: usize = {
        assert!(V * 16 == L);
        L
    };

    /// Computes `self & !rhs` lane-wise.
    ///
    /// `_mm512_andnot_si512(a, b)` computes `!a & b`, so the operands are
    /// swapped here to negate `rhs` rather than `self`.
    #[inline]
    #[must_use]
    #[target_feature(enable = "avx512f")]
    pub fn andnot(self, rhs: Self) -> Self {
        Self(from_fn(|i| _mm512_andnot_si512(rhs.0[i], self.0[i])))
    }

    /// Broadcasts `v` to every lane.
    #[inline]
    #[must_use]
    #[target_feature(enable = "avx512f")]
    pub fn splat(v: u32) -> Self {
        Self(
            #[expect(clippy::cast_possible_wrap)]
            [_mm512_set1_epi32(v as i32); V],
        )
    }

    /// Rotates every lane left by `n` bits.
    #[inline]
    #[must_use]
    #[target_feature(enable = "avx512f")]
    pub fn rotate_left(self, n: u32) -> Self {
        Self(from_fn(|i| {
            #[expect(clippy::cast_possible_wrap)]
            _mm512_or_si512(
                _mm512_sll_epi32(self.0[i], _mm_cvtsi32_si128(n as i32)),
                _mm512_srl_epi32(self.0[i], _mm_cvtsi32_si128(32 - n as i32)),
            )
        }))
    }
}

/// Vector implementations using a single AVX512 vector.
pub mod avx512 {
    /// The name of this backend.
    pub const SIMD_BACKEND: &str = "avx512";

    /// AVX512 vector with sixteen [u32] lanes.
    pub type U32Vector = super::U32Vector<1, 16>;
}
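
// The wider backends below run `V` 512-bit vectors in lockstep, trading
// register pressure for larger batches while keeping the same `U32Vector`
// API. A consumer might select one at compile time, e.g. (the `backend`
// alias and import path here are a hypothetical sketch, not part of this
// crate):
//
//     #[cfg(feature = "all-simd")]
//     use utils::simd::avx512::avx512x4 as backend;
//     #[cfg(not(feature = "all-simd"))]
//     use utils::simd::avx512::avx512 as backend;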

/// Vector implementations using two AVX512 vectors.
#[cfg(feature = "all-simd")]
pub mod avx512x2 {
    /// The name of this backend.
    pub const SIMD_BACKEND: &str = "avx512x2";

    /// Two AVX512 vectors with thirty-two total [u32] lanes.
    pub type U32Vector = super::U32Vector<2, 32>;
}

/// Vector implementations using four AVX512 vectors.
#[cfg(feature = "all-simd")]
pub mod avx512x4 {
    /// The name of this backend.
    pub const SIMD_BACKEND: &str = "avx512x4";

    /// Four AVX512 vectors with sixty-four total [u32] lanes.
    pub type U32Vector = super::U32Vector<4, 64>;
}

/// Vector implementations using eight AVX512 vectors.
#[cfg(feature = "all-simd")]
pub mod avx512x8 {
    /// The name of this backend.
    pub const SIMD_BACKEND: &str = "avx512x8";

    /// Eight AVX512 vectors with 128 total [u32] lanes.
    pub type U32Vector = super::U32Vector<8, 128>;
}
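
// A minimal test sketch, assuming `std` with runtime feature detection via
// `is_x86_feature_detected!`; the checks skip on CPUs without AVX-512F.
#[cfg(test)]
mod tests {
    use super::avx512::U32Vector;
    use std::array::from_fn;

    #[test]
    fn matches_scalar_semantics() {
        if !is_x86_feature_detected!("avx512f") {
            return; // the intrinsics cannot be exercised on this CPU
        }
        let a: [u32; 16] = from_fn(|i| i as u32);
        let b: [u32; 16] = from_fn(|i| (i as u32).wrapping_mul(0x9E37_79B9));
        // Round-trip through the vector type and compare against scalar ops.
        let sum: [u32; 16] = (U32Vector::from(a) + U32Vector::from(b)).into();
        // SAFETY: AVX-512F availability was verified above.
        let rot: [u32; 16] = unsafe { U32Vector::from(a).rotate_left(7) }.into();
        for i in 0..16 {
            assert_eq!(sum[i], a[i].wrapping_add(b[i]));
            assert_eq!(rot[i], a[i].rotate_left(7));
        }
    }
}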