utils/
number.rs

1//! Traits for using numbers as generic data types.
2
3use std::fmt::Debug;
4use std::iter::{Product, Sum};
5use std::ops::{
6    Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
7    Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
8};
9
/// Common interface for the primitive number types (integers and floats).
///
/// Bundles the comparison, arithmetic, compound-assignment, summation and product traits shared
/// by every primitive number, plus a few constants and helper methods.
pub trait Number:
    Copy
    + Debug
    + Default
    + PartialEq
    + PartialOrd
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
    + Rem<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + DivAssign
    + RemAssign
    + Sum<Self>
    + Product<Self>
    + for<'a> Sum<&'a Self>
    + for<'a> Product<&'a Self>
{
    /// The additive identity (`0`).
    const ZERO: Self;
    /// The multiplicative identity (`1`).
    const ONE: Self;
    /// The smallest value of the type. The float implementations in this module use negative
    /// infinity.
    const MIN: Self;
    /// The largest value of the type. The float implementations in this module use positive
    /// infinity.
    const MAX: Self;

    /// Returns the absolute value of `self`. For unsigned integers this is `self` unchanged.
    #[must_use]
    fn abs(self) -> Self;

    /// Returns the Euclidean remainder of `self` divided by `rhs`.
    #[must_use]
    fn rem_euclid(self, rhs: Self) -> Self;

    /// Returns the square of the difference between `self` and `rhs`.
    #[must_use]
    fn squared_diff(self, rhs: Self) -> Self;
}
44
45/// Trait implemented by the primitive signed integer and floating point types.
46pub trait Signed: Number + Neg<Output = Self> + From<i8> {
47    const MINUS_ONE: Self;
48}
49
50/// Trait implemented by the primitive integer types.
51pub trait Integer:
52    Number
53    + Not<Output = Self>
54    + BitAnd<Output = Self>
55    + BitAndAssign
56    + BitOr<Output = Self>
57    + BitOrAssign
58    + BitXor<Output = Self>
59    + BitXorAssign
60    + Shl<Output = Self>
61    + Shl<u32, Output = Self>
62    + ShlAssign
63    + ShlAssign<u32>
64    + Shr<Output = Self>
65    + Shr<u32, Output = Self>
66    + ShrAssign
67    + ShrAssign<u32>
68    + TryInto<i128>
69{
70    type Unsigned: UnsignedInteger;
71    type Signed: SignedInteger;
72
73    #[must_use]
74    fn abs_diff(self, rhs: Self) -> Self::Unsigned;
75    #[must_use]
76    fn checked_add(self, rhs: Self) -> Option<Self>;
77    #[must_use]
78    fn checked_sub(self, rhs: Self) -> Option<Self>;
79    #[must_use]
80    fn checked_mul(self, rhs: Self) -> Option<Self>;
81    #[must_use]
82    fn trailing_ones(self) -> u32;
83    #[must_use]
84    fn trailing_zeros(self) -> u32;
85    #[must_use]
86    fn unsigned_abs(self) -> Self::Unsigned;
87    #[must_use]
88    fn saturating_sub_0(self, rhs: Self) -> Self::Unsigned;
89}
90
91/// Trait implemented by the primitive unsigned integer types.
92pub trait UnsignedInteger: Integer<Unsigned = Self> + From<u8> {
93    #[must_use]
94    fn wrapping_add_signed(self, rhs: Self::Signed) -> Self;
95}
96
97/// Trait implemented by the primitive signed integer types.
98pub trait SignedInteger: Integer<Signed = Self> + Signed {}
99
100macro_rules! number_impl {
101    (int => $($u:ident: $s:ident ),+) => {
102        $(impl Number for $u {
103            const ZERO: Self = 0;
104            const ONE: Self = 1;
105            const MIN: Self = Self::MIN;
106            const MAX: Self = Self::MAX;
107
108            #[inline]
109            fn abs(self) -> Self {
110                self // no-op for unsigned integers
111            }
112
113            #[inline]
114            fn rem_euclid(self, rhs: Self) -> Self {
115                self.rem_euclid(rhs)
116            }
117
118            #[inline]
119            fn squared_diff(self, rhs: Self) -> Self {
120                let diff = self.abs_diff(rhs);
121                diff * diff
122            }
123        })+
124
125        $(impl Integer for $u {
126            type Unsigned = $u;
127            type Signed = $s;
128
129            #[inline]
130            fn abs_diff(self, rhs: Self) -> Self::Unsigned {
131                self.abs_diff(rhs)
132            }
133            #[inline]
134            fn checked_add(self, rhs: Self) -> Option<Self> {
135                self.checked_add(rhs)
136            }
137            #[inline]
138            fn checked_sub(self, rhs: Self) -> Option<Self> {
139                self.checked_sub(rhs)
140            }
141            #[inline]
142            fn checked_mul(self, rhs: Self) -> Option<Self> {
143                self.checked_mul(rhs)
144            }
145            #[inline]
146            fn trailing_ones(self) -> u32 {
147                self.trailing_ones()
148            }
149            #[inline]
150            fn trailing_zeros(self) -> u32 {
151                self.trailing_zeros()
152            }
153            #[inline]
154            fn unsigned_abs(self) -> Self::Unsigned {
155                self // no-op for unsigned integers
156            }
157            #[inline]
158            fn saturating_sub_0(self, rhs: Self) -> Self::Unsigned {
159                self.saturating_sub(rhs)
160            }
161        })+
162
163        $(impl UnsignedInteger for $u {
164            #[inline]
165            fn wrapping_add_signed(self, rhs: Self::Signed) -> Self {
166                self.wrapping_add_signed(rhs)
167            }
168        })+
169
170        $(impl Number for $s {
171            const ZERO: Self = 0;
172            const ONE: Self = 1;
173            const MIN: Self = Self::MIN;
174            const MAX: Self = Self::MAX;
175
176            #[inline]
177            fn abs(self) -> Self {
178                self.abs()
179            }
180
181            #[inline]
182            fn rem_euclid(self, rhs: Self) -> Self {
183                self.rem_euclid(rhs)
184            }
185
186            #[inline]
187            fn squared_diff(self, rhs: Self) -> Self {
188                let diff = self - rhs;
189                diff * diff
190            }
191        })+
192
193        $(impl Signed for $s {
194            const MINUS_ONE: Self = -Self::ONE;
195        })+
196
197        $(impl Integer for $s {
198            type Unsigned = $u;
199            type Signed = $s;
200
201            #[inline]
202            fn abs_diff(self, rhs: Self) -> Self::Unsigned {
203                self.abs_diff(rhs)
204            }
205            #[inline]
206            fn checked_add(self, rhs: Self) -> Option<Self> {
207                self.checked_add(rhs)
208            }
209            #[inline]
210            fn checked_sub(self, rhs: Self) -> Option<Self> {
211                self.checked_sub(rhs)
212            }
213            #[inline]
214            fn checked_mul(self, rhs: Self) -> Option<Self> {
215                self.checked_mul(rhs)
216            }
217            #[inline]
218            fn trailing_ones(self) -> u32 {
219                self.trailing_ones()
220            }
221            #[inline]
222            fn trailing_zeros(self) -> u32 {
223                self.trailing_zeros()
224            }
225            #[inline]
226            fn unsigned_abs(self) -> Self::Unsigned {
227                self.unsigned_abs()
228            }
229            #[inline]
230            #[expect(clippy::cast_sign_loss)]
231            fn saturating_sub_0(self, rhs: Self) -> Self::Unsigned {
232                // Equivalent to `self.saturating_sub(rhs).max(0) as $u`, but avoids overflow for
233                // e.g. i32::MAX - i32::MIN
234                let diff = (self as $u).wrapping_sub(rhs as $u);
235                let mask = (0 as $u).wrapping_sub($u::from(self >= rhs));
236                diff & mask
237            }
238        })+
239
240        $(impl SignedInteger for $s {})+
241    };
242    (float => $($t:ident),+) => {$(
243        impl Number for $t {
244            const ZERO: Self = 0.0;
245            const ONE: Self = 1.0;
246            const MIN: Self = Self::NEG_INFINITY;
247            const MAX: Self = Self::INFINITY;
248
249            #[inline]
250            fn abs(self) -> Self {
251                self.abs()
252            }
253
254            #[inline]
255            fn rem_euclid(self, rhs: Self) -> Self {
256                self.rem_euclid(rhs)
257            }
258
259            #[inline]
260            fn squared_diff(self, rhs: Self) -> Self {
261                let diff = self - rhs;
262                diff * diff
263            }
264        }
265
266        impl Signed for $t {
267            const MINUS_ONE: Self = -Self::ONE;
268        }
269    )+};
270}
271number_impl! {int => u8: i8, u16: i16, u32: i32, u64: i64, u128: i128, usize: isize}
272number_impl! {float => f32, f64}
273
274/// Checks if the provided unsigned integer `n` is a prime number.
275///
276/// # Examples
277/// ```
278/// # use utils::number::is_prime;
279/// assert_eq!(is_prime(7901u32), true);
280/// assert_eq!(is_prime(2147483647u32), true);
281/// assert_eq!(is_prime(4294967291u32), true);
282/// assert_eq!(is_prime(6u32), false);
283/// assert_eq!(is_prime(123u32), false);
284/// ```
285#[inline]
286#[must_use]
287pub fn is_prime<T: UnsignedInteger>(n: T) -> bool {
288    if n <= T::ONE {
289        return false;
290    }
291    if n == T::from(2) || n == T::from(3) {
292        return true;
293    }
294    if n % T::from(2) == T::ZERO || n % T::from(3) == T::ZERO {
295        return false;
296    }
297
298    let mut i = T::from(5);
299    while let Some(square) = i.checked_mul(i)
300        && square <= n
301    {
302        if n % i == T::ZERO || n % (i + T::from(2)) == T::ZERO {
303            return false;
304        }
305
306        if let Some(next) = i.checked_add(T::from(6)) {
307            i = next;
308        } else {
309            break;
310        }
311    }
312
313    true
314}
315
316/// Computes the sum of the divisors for unsigned integer `n`.
317///
318/// Returns `None` if the sum overflows.
319///
320/// # Examples
321/// ```
322/// # use utils::number::sum_of_divisors;
323/// assert_eq!(sum_of_divisors(5u32), Some(6));
324/// assert_eq!(sum_of_divisors(32u32), Some(63));
325/// assert_eq!(sum_of_divisors(50u32), Some(93));
326/// assert_eq!(sum_of_divisors(857_656_800u32), None);
327/// assert_eq!(sum_of_divisors(857_656_800u64), Some(4_376_251_152));
328/// ```
329#[inline]
330#[must_use]
331pub fn sum_of_divisors<T: UnsignedInteger>(n: T) -> Option<T> {
332    if n <= T::ONE {
333        return Some(n);
334    }
335
336    let mut sum = T::ZERO;
337    let mut d = T::ONE;
338    while let Some(square) = d.checked_mul(d)
339        && square <= n
340    {
341        if n % d == T::ZERO {
342            if let Some(s) = sum.checked_add(d) {
343                sum = s;
344            } else {
345                return None;
346            }
347
348            let q = n / d;
349            if q != d {
350                if let Some(s) = sum.checked_add(q) {
351                    sum = s;
352                } else {
353                    return None;
354                }
355            }
356        }
357
358        d += T::ONE;
359    }
360
361    Some(sum)
362}
363
364/// Computes the greatest common divisor (GCD) using the extended Euclidean algorithm.
365///
366/// Returns a tuple `(gcd, x, y)` where `x`, `y` are the coefficients of Bézout's identity:
367/// ```text
368/// ax + by = gcd(a, b)
369/// ```
370///
371/// # Examples
372/// ```
373/// # use utils::number::egcd;
374/// assert_eq!(egcd(252, 105), (21, -2, 5));
375/// assert_eq!((252 * -2) + (105 * 5), 21);
376/// ```
377#[inline]
378#[must_use]
379pub fn egcd<T: SignedInteger>(mut a: T, mut b: T) -> (T, T, T) {
380    let (mut x0, mut x1, mut y0, mut y1) = (T::ONE, T::ZERO, T::ZERO, T::ONE);
381
382    while b != T::ZERO {
383        let q = a / b;
384        (a, b) = (b, a % b);
385        (x0, x1) = (x1, x0 - q * x1);
386        (y0, y1) = (y1, y0 - q * y1);
387    }
388
389    (a, x0, y0)
390}
391
392/// Computes the lowest common multiple (LCM).
393///
394/// # Examples
395/// ```
396/// # use utils::number::lcm;
397/// assert_eq!(lcm(6, 4), 12);
398/// assert_eq!(lcm(21, 6), 42);
399/// ```
400#[inline]
401#[must_use]
402pub fn lcm<T: SignedInteger>(a: T, b: T) -> T {
403    if a == T::ZERO || b == T::ZERO {
404        return T::ZERO;
405    }
406
407    let (gcd, ..) = egcd(a, b);
408    (a / gcd).abs() * b.abs()
409}
410
411/// Computes the modular inverse of `a` modulo `b` if it exists.
412///
413/// # Examples
414/// ```
415/// # use utils::number::mod_inverse;
416/// assert_eq!(mod_inverse(3, 5), Some(2));
417/// assert_eq!((3 * 2) % 5, 1);
418///
419/// assert_eq!(mod_inverse(10, 23), Some(7));
420/// assert_eq!((10 * 7) % 23, 1);
421///
422/// assert_eq!(mod_inverse(2, 8), None);
423/// ```
424#[inline]
425#[must_use]
426pub fn mod_inverse<T: SignedInteger>(a: T, b: T) -> Option<T> {
427    let (gcd, x, _) = egcd(a, b);
428    if gcd == T::ONE {
429        Some(x.rem_euclid(b))
430    } else {
431        None
432    }
433}
434
435/// Solves a system of simultaneous congruences using the Chinese Remainder Theorem.
436///
437/// This function finds the smallest non-negative integer `x` where `x % modulus = residue` for each
438/// provided (residue, modulus) pair.
439///
440/// # Examples
441/// ```
442/// # use utils::number::chinese_remainder;
443/// assert_eq!(chinese_remainder([1, 2, 3], [5, 7, 11]), Some(366));
444/// assert_eq!(366 % 5, 1);
445/// assert_eq!(366 % 7, 2);
446/// assert_eq!(366 % 11, 3);
447/// ```
448#[inline]
449#[must_use]
450pub fn chinese_remainder<T: SignedInteger>(
451    residues: impl IntoIterator<Item = T>,
452    moduli: impl IntoIterator<Item = T, IntoIter: Clone>,
453) -> Option<T> {
454    let moduli = moduli.into_iter();
455    let product = moduli.clone().product();
456
457    let mut sum = T::ZERO;
458    for (residue, modulus) in residues.into_iter().zip(moduli) {
459        let p = product / modulus;
460        sum += residue * mod_inverse(p, modulus)? * p;
461    }
462
463    Some(sum.rem_euclid(product))
464}
465
466/// Calculates `base.pow(exponent) % modulus`.
467///
468/// # Examples
469/// ```
470/// # use utils::number::mod_pow;
471/// assert_eq!(mod_pow::<u64>(2, 10, 1000), 24);
472/// assert_eq!(mod_pow::<u64>(65, 100000, 2147483647), 1085966926);
473/// ```
474#[inline]
475#[must_use]
476pub fn mod_pow<T: UnsignedInteger>(base: T, exponent: T, modulus: T) -> T {
477    let mut result = T::ONE;
478    let mut base = base % modulus;
479    let mut exponent = exponent;
480
481    while exponent > T::ZERO {
482        if exponent % T::from(2) == T::ONE {
483            result = (result * base) % modulus;
484        }
485        exponent >>= 1;
486        base = (base * base) % modulus;
487    }
488
489    result
490}