utils/simd/
avx2.rs

//! AVX2 vector implementations.

use std::array::from_fn;
use std::ops::{Add, BitAnd, BitOr, BitXor, Not};

#[cfg(target_arch = "x86_64")]
#[allow(clippy::allow_attributes, clippy::wildcard_imports)]
use std::arch::x86_64::*;

#[cfg(target_arch = "x86")]
#[allow(clippy::allow_attributes, clippy::wildcard_imports)]
use std::arch::x86::*;
/// AVX2 [u32] vector implementation.
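///
/// Backed by `V` 256-bit registers holding `L` total [u32] lanes; referencing
/// the `LANES` constant asserts that `L == V * 8`.
///
/// # Examples
///
/// A roundtrip sketch (illustrative only; assumes the single-register
/// `avx2` backend below is in scope and AVX2 is available at runtime):
///
/// ```ignore
/// let lanes: [u32; 8] = std::array::from_fn(|i| i as u32);
/// let v = avx2::U32Vector::from(lanes);
/// assert_eq!(<[u32; 8]>::from(v), lanes);
/// ```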
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct U32Vector<const V: usize, const L: usize>([__m256i; V]);

impl<const V: usize, const L: usize> From<[u32; L]> for U32Vector<V, L> {
    #[inline]
    fn from(value: [u32; L]) -> Self {
        // Fill each 256-bit register from eight consecutive lanes with an
        // unaligned load.
        Self(from_fn(|i| unsafe {
            #[expect(
                clippy::cast_ptr_alignment,
                reason = "_mm256_loadu_si256 is an unaligned load which requires no alignment"
            )]
            _mm256_loadu_si256(value[i * 8..].as_ptr().cast::<__m256i>())
        }))
    }
}

impl<const V: usize, const L: usize> From<U32Vector<V, L>> for [u32; L] {
    #[inline]
    fn from(value: U32Vector<V, L>) -> Self {
        let mut result = [0; L];
        // Store each 256-bit register back into eight consecutive lanes.
        for (&v, r) in value.0.iter().zip(result.chunks_exact_mut(8)) {
            unsafe {
                #[expect(
                    clippy::cast_ptr_alignment,
                    reason = "_mm256_storeu_si256 is an unaligned store which requires no alignment"
                )]
                _mm256_storeu_si256(r.as_mut_ptr().cast::<__m256i>(), v);
            }
        }
        result
    }
}

impl<const V: usize, const L: usize> Add for U32Vector<V, L> {
    type Output = Self;

    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        // Lane-wise wrapping addition.
        Self(from_fn(|i| unsafe {
            _mm256_add_epi32(self.0[i], rhs.0[i])
        }))
    }
}

impl<const V: usize, const L: usize> BitAnd for U32Vector<V, L> {
    type Output = Self;

    #[inline]
    fn bitand(self, rhs: Self) -> Self::Output {
        Self(from_fn(|i| unsafe {
            _mm256_and_si256(self.0[i], rhs.0[i])
        }))
    }
}

impl<const V: usize, const L: usize> BitOr for U32Vector<V, L> {
    type Output = Self;

    #[inline]
    fn bitor(self, rhs: Self) -> Self::Output {
        Self(from_fn(|i| unsafe { _mm256_or_si256(self.0[i], rhs.0[i]) }))
    }
}

impl<const V: usize, const L: usize> BitXor for U32Vector<V, L> {
    type Output = Self;

    #[inline]
    fn bitxor(self, rhs: Self) -> Self::Output {
        Self(from_fn(|i| unsafe {
            _mm256_xor_si256(self.0[i], rhs.0[i])
        }))
    }
}

impl<const V: usize, const L: usize> Not for U32Vector<V, L> {
    type Output = Self;

    #[inline]
    fn not(self) -> Self::Output {
        // AVX2 has no NOT instruction; XOR against an all-ones register instead.
        Self(from_fn(|i| unsafe {
            _mm256_xor_si256(self.0[i], _mm256_set1_epi8(!0))
        }))
    }
}

impl<const V: usize, const L: usize> U32Vector<V, L> {
    /// The number of [u32] lanes; referencing this constant asserts that the
    /// `V` and `L` parameters are consistent (`L == V * 8`).
    pub const LANES: usize = {
        assert!(V * 8 == L);
        L
    };

    /// Computes `self & !rhs` in a single instruction per register.
    #[inline]
    #[must_use]
    #[target_feature(enable = "avx2")]
    pub fn andnot(self, rhs: Self) -> Self {
        // _mm256_andnot_si256(a, b) computes !a & b, hence the swapped operands.
        Self(from_fn(|i| _mm256_andnot_si256(rhs.0[i], self.0[i])))
    }

    /// Broadcasts `v` to every lane.
    #[inline]
    #[must_use]
    #[target_feature(enable = "avx2")]
    pub fn splat(v: u32) -> Self {
        Self(
            #[expect(clippy::cast_possible_wrap)]
            [_mm256_set1_epi32(v as i32); V],
        )
    }

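    /// Rotates every lane left by `n` bits.
    ///
    /// The rotation is composed from a left and a right shift, so `n` must be
    /// in `0..=32`; larger values silently produce all-zero lanes.
    ///
    /// # Examples
    ///
    /// A sketch of the intended behavior (illustrative only; assumes AVX2 is
    /// available at runtime):
    ///
    /// ```ignore
    /// let v = unsafe { avx2::U32Vector::splat(0x8000_0001) };
    /// assert_eq!(<[u32; 8]>::from(unsafe { v.rotate_left(1) }), [3; 8]);
    /// ```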
    #[inline]
    #[must_use]
    #[target_feature(enable = "avx2")]
    pub fn rotate_left(self, n: u32) -> Self {
        Self(from_fn(|i| {
            // (x << n) | (x >> (32 - n)); the shift count lives in an XMM register.
            #[expect(clippy::cast_possible_wrap)]
            _mm256_or_si256(
                _mm256_sll_epi32(self.0[i], _mm_cvtsi32_si128(n as i32)),
                _mm256_srl_epi32(self.0[i], _mm_cvtsi32_si128(32 - n as i32)),
            )
        }))
    }
}

/// Vector implementations using a single AVX2 vector.
pub mod avx2 {
    /// The name of this backend.
    pub const SIMD_BACKEND: &str = "avx2";

    /// AVX2 vector with eight [u32] lanes.
    pub type U32Vector = super::U32Vector<1, 8>;
}

/// Vector implementations using two AVX2 vectors.
#[cfg(feature = "all-simd")]
pub mod avx2x2 {
    /// The name of this backend.
    pub const SIMD_BACKEND: &str = "avx2x2";

    /// Two AVX2 vectors with sixteen total [u32] lanes.
    pub type U32Vector = super::U32Vector<2, 16>;
}

/// Vector implementations using four AVX2 vectors.
#[cfg(feature = "all-simd")]
pub mod avx2x4 {
    /// The name of this backend.
    pub const SIMD_BACKEND: &str = "avx2x4";

    /// Four AVX2 vectors with thirty-two total [u32] lanes.
    pub type U32Vector = super::U32Vector<4, 32>;
}

/// Vector implementations using eight AVX2 vectors.
#[cfg(feature = "all-simd")]
pub mod avx2x8 {
    /// The name of this backend.
    pub const SIMD_BACKEND: &str = "avx2x8";

    /// Eight AVX2 vectors with sixty-four total [u32] lanes.
    pub type U32Vector = super::U32Vector<8, 64>;
}
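
// A minimal sanity-check sketch, not part of the original file: it exercises
// the single-register backend behind a runtime AVX2 check. The input values
// and the `tests` module itself are illustrative assumptions.
#[cfg(test)]
mod tests {
    use super::avx2::U32Vector;

    #[test]
    fn lanes_match_scalar_ops() {
        if !is_x86_feature_detected!("avx2") {
            return; // Nothing to test on hardware without AVX2.
        }

        let a: [u32; 8] = std::array::from_fn(|i| 0x0101_0101 * i as u32);
        let b: [u32; 8] = [0xFFFF_0000; 8];
        let (va, vb) = (U32Vector::from(a), U32Vector::from(b));

        // Load/store roundtrip preserves every lane.
        assert_eq!(<[u32; 8]>::from(va), a);

        // Operators behave like their lane-wise scalar counterparts.
        assert_eq!(
            <[u32; 8]>::from(va + vb),
            std::array::from_fn(|i| a[i].wrapping_add(b[i]))
        );
        assert_eq!(<[u32; 8]>::from(!va), std::array::from_fn(|i| !a[i]));
        assert_eq!(
            <[u32; 8]>::from(va & vb),
            std::array::from_fn(|i| a[i] & b[i])
        );

        // `target_feature` functions are unsafe to call unless AVX2 is known
        // to be enabled; the runtime check above makes these calls sound.
        unsafe {
            assert_eq!(<[u32; 8]>::from(U32Vector::splat(42)), [42; 8]);
            assert_eq!(
                <[u32; 8]>::from(va.rotate_left(7)),
                std::array::from_fn(|i| a[i].rotate_left(7))
            );
            assert_eq!(
                <[u32; 8]>::from(va.andnot(vb)),
                std::array::from_fn(|i| a[i] & !b[i])
            );
        }
    }
}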