]> git.proxmox.com Git - rustc.git/blob - vendor/crypto-bigint/src/non_zero.rs
New upstream version 1.76.0+dfsg1
[rustc.git] / vendor / crypto-bigint / src / non_zero.rs
1 //! Wrapper type for non-zero integers.
2
3 use crate::{CtChoice, Encoding, Integer, Limb, Uint, Zero};
4 use core::{
5 fmt,
6 num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8},
7 ops::Deref,
8 };
9 use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
10
11 #[cfg(feature = "generic-array")]
12 use crate::{ArrayEncoding, ByteArray};
13
14 #[cfg(feature = "rand_core")]
15 use {crate::Random, rand_core::CryptoRngCore};
16
17 #[cfg(feature = "serde")]
18 use serdect::serde::{
19 de::{Error, Unexpected},
20 Deserialize, Deserializer, Serialize, Serializer,
21 };
22
23 /// Wrapper type for non-zero integers.
24 #[derive(Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord)]
25 pub struct NonZero<T: Zero>(T);
26
27 impl NonZero<Limb> {
28 /// Creates a new non-zero limb in a const context.
29 /// The second return value is `FALSE` if `n` is zero, `TRUE` otherwise.
30 pub const fn const_new(n: Limb) -> (Self, CtChoice) {
31 (Self(n), n.ct_is_nonzero())
32 }
33 }
34
35 impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
36 /// Creates a new non-zero integer in a const context.
37 /// The second return value is `FALSE` if `n` is zero, `TRUE` otherwise.
38 pub const fn const_new(n: Uint<LIMBS>) -> (Self, CtChoice) {
39 (Self(n), n.ct_is_nonzero())
40 }
41 }
42
43 impl<T> NonZero<T>
44 where
45 T: Zero,
46 {
47 /// Create a new non-zero integer.
48 pub fn new(n: T) -> CtOption<Self> {
49 let is_zero = n.is_zero();
50 CtOption::new(Self(n), !is_zero)
51 }
52 }
53
54 impl<T> NonZero<T>
55 where
56 T: Integer,
57 {
58 /// The value `1`.
59 pub const ONE: Self = Self(T::ONE);
60
61 /// Maximum value this integer can express.
62 pub const MAX: Self = Self(T::MAX);
63 }
64
65 impl<T> NonZero<T>
66 where
67 T: Encoding + Zero,
68 {
69 /// Decode from big endian bytes.
70 pub fn from_be_bytes(bytes: T::Repr) -> CtOption<Self> {
71 Self::new(T::from_be_bytes(bytes))
72 }
73
74 /// Decode from little endian bytes.
75 pub fn from_le_bytes(bytes: T::Repr) -> CtOption<Self> {
76 Self::new(T::from_le_bytes(bytes))
77 }
78 }
79
#[cfg(feature = "generic-array")]
impl<T> NonZero<T>
where
    T: ArrayEncoding + Zero,
{
    /// Decode a non-zero integer from big endian bytes.
    ///
    /// The returned [`CtOption`] is empty when the decoded value is zero.
    pub fn from_be_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        Self::new(T::from_be_byte_array(bytes))
    }

    /// Decode a non-zero integer from little endian bytes.
    ///
    /// The returned [`CtOption`] is empty when the decoded value is zero.
    pub fn from_le_byte_array(bytes: ByteArray<T>) -> CtOption<Self> {
        // Use the little-endian decoder so the byte order matches this method's
        // contract (and its counterpart `from_le_bytes`).
        Self::new(T::from_le_byte_array(bytes))
    }
}
95
96 impl<T> AsRef<T> for NonZero<T>
97 where
98 T: Zero,
99 {
100 fn as_ref(&self) -> &T {
101 &self.0
102 }
103 }
104
105 impl<T> ConditionallySelectable for NonZero<T>
106 where
107 T: ConditionallySelectable + Zero,
108 {
109 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
110 Self(T::conditional_select(&a.0, &b.0, choice))
111 }
112 }
113
114 impl<T> ConstantTimeEq for NonZero<T>
115 where
116 T: Zero,
117 {
118 fn ct_eq(&self, other: &Self) -> Choice {
119 self.0.ct_eq(&other.0)
120 }
121 }
122
123 impl<T> Deref for NonZero<T>
124 where
125 T: Zero,
126 {
127 type Target = T;
128
129 fn deref(&self) -> &T {
130 &self.0
131 }
132 }
133
#[cfg(feature = "rand_core")]
impl<T: Random + Zero> Random for NonZero<T> {
    /// Draws random `T` values until a non-zero one appears, then wraps it.
    fn random(mut rng: &mut impl CryptoRngCore) -> Self {
        // Rejection sampling: zero draws are discarded. The number of iterations is
        // not constant, but with a CSRNG it reveals nothing about the accepted output.
        loop {
            let candidate = T::random(&mut rng);
            let accepted: Option<Self> = Self::new(candidate).into();
            if let Some(nonzero) = accepted {
                return nonzero;
            }
        }
    }
}
151
152 impl<T> fmt::Display for NonZero<T>
153 where
154 T: fmt::Display + Zero,
155 {
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 fmt::Display::fmt(&self.0, f)
158 }
159 }
160
161 impl<T> fmt::Binary for NonZero<T>
162 where
163 T: fmt::Binary + Zero,
164 {
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 fmt::Binary::fmt(&self.0, f)
167 }
168 }
169
170 impl<T> fmt::Octal for NonZero<T>
171 where
172 T: fmt::Octal + Zero,
173 {
174 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175 fmt::Octal::fmt(&self.0, f)
176 }
177 }
178
179 impl<T> fmt::LowerHex for NonZero<T>
180 where
181 T: fmt::LowerHex + Zero,
182 {
183 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184 fmt::LowerHex::fmt(&self.0, f)
185 }
186 }
187
188 impl<T> fmt::UpperHex for NonZero<T>
189 where
190 T: fmt::UpperHex + Zero,
191 {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 fmt::UpperHex::fmt(&self.0, f)
194 }
195 }
196
197 impl NonZero<Limb> {
198 /// Create a [`NonZero<Limb>`] from a [`NonZeroU8`] (const-friendly)
199 // TODO(tarcieri): replace with `const impl From<NonZeroU8>` when stable
200 pub const fn from_u8(n: NonZeroU8) -> Self {
201 Self(Limb::from_u8(n.get()))
202 }
203
204 /// Create a [`NonZero<Limb>`] from a [`NonZeroU16`] (const-friendly)
205 // TODO(tarcieri): replace with `const impl From<NonZeroU16>` when stable
206 pub const fn from_u16(n: NonZeroU16) -> Self {
207 Self(Limb::from_u16(n.get()))
208 }
209
210 /// Create a [`NonZero<Limb>`] from a [`NonZeroU32`] (const-friendly)
211 // TODO(tarcieri): replace with `const impl From<NonZeroU32>` when stable
212 pub const fn from_u32(n: NonZeroU32) -> Self {
213 Self(Limb::from_u32(n.get()))
214 }
215
216 /// Create a [`NonZero<Limb>`] from a [`NonZeroU64`] (const-friendly)
217 // TODO(tarcieri): replace with `const impl From<NonZeroU64>` when stable
218 #[cfg(target_pointer_width = "64")]
219 pub const fn from_u64(n: NonZeroU64) -> Self {
220 Self(Limb::from_u64(n.get()))
221 }
222 }
223
224 impl From<NonZeroU8> for NonZero<Limb> {
225 fn from(integer: NonZeroU8) -> Self {
226 Self::from_u8(integer)
227 }
228 }
229
230 impl From<NonZeroU16> for NonZero<Limb> {
231 fn from(integer: NonZeroU16) -> Self {
232 Self::from_u16(integer)
233 }
234 }
235
236 impl From<NonZeroU32> for NonZero<Limb> {
237 fn from(integer: NonZeroU32) -> Self {
238 Self::from_u32(integer)
239 }
240 }
241
242 #[cfg(target_pointer_width = "64")]
243 impl From<NonZeroU64> for NonZero<Limb> {
244 fn from(integer: NonZeroU64) -> Self {
245 Self::from_u64(integer)
246 }
247 }
248
249 impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
250 /// Create a [`NonZero<Uint>`] from a [`Uint`] (const-friendly)
251 pub const fn from_uint(n: Uint<LIMBS>) -> Self {
252 let mut i = 0;
253 let mut found_non_zero = false;
254 while i < LIMBS {
255 if n.as_limbs()[i].0 != 0 {
256 found_non_zero = true;
257 }
258 i += 1;
259 }
260 assert!(found_non_zero, "found zero");
261 Self(n)
262 }
263
264 /// Create a [`NonZero<Uint>`] from a [`NonZeroU8`] (const-friendly)
265 // TODO(tarcieri): replace with `const impl From<NonZeroU8>` when stable
266 pub const fn from_u8(n: NonZeroU8) -> Self {
267 Self(Uint::from_u8(n.get()))
268 }
269
270 /// Create a [`NonZero<Uint>`] from a [`NonZeroU16`] (const-friendly)
271 // TODO(tarcieri): replace with `const impl From<NonZeroU16>` when stable
272 pub const fn from_u16(n: NonZeroU16) -> Self {
273 Self(Uint::from_u16(n.get()))
274 }
275
276 /// Create a [`NonZero<Uint>`] from a [`NonZeroU32`] (const-friendly)
277 // TODO(tarcieri): replace with `const impl From<NonZeroU32>` when stable
278 pub const fn from_u32(n: NonZeroU32) -> Self {
279 Self(Uint::from_u32(n.get()))
280 }
281
282 /// Create a [`NonZero<Uint>`] from a [`NonZeroU64`] (const-friendly)
283 // TODO(tarcieri): replace with `const impl From<NonZeroU64>` when stable
284 pub const fn from_u64(n: NonZeroU64) -> Self {
285 Self(Uint::from_u64(n.get()))
286 }
287
288 /// Create a [`NonZero<Uint>`] from a [`NonZeroU128`] (const-friendly)
289 // TODO(tarcieri): replace with `const impl From<NonZeroU128>` when stable
290 pub const fn from_u128(n: NonZeroU128) -> Self {
291 Self(Uint::from_u128(n.get()))
292 }
293 }
294
295 impl<const LIMBS: usize> From<NonZeroU8> for NonZero<Uint<LIMBS>> {
296 fn from(integer: NonZeroU8) -> Self {
297 Self::from_u8(integer)
298 }
299 }
300
301 impl<const LIMBS: usize> From<NonZeroU16> for NonZero<Uint<LIMBS>> {
302 fn from(integer: NonZeroU16) -> Self {
303 Self::from_u16(integer)
304 }
305 }
306
307 impl<const LIMBS: usize> From<NonZeroU32> for NonZero<Uint<LIMBS>> {
308 fn from(integer: NonZeroU32) -> Self {
309 Self::from_u32(integer)
310 }
311 }
312
313 impl<const LIMBS: usize> From<NonZeroU64> for NonZero<Uint<LIMBS>> {
314 fn from(integer: NonZeroU64) -> Self {
315 Self::from_u64(integer)
316 }
317 }
318
319 impl<const LIMBS: usize> From<NonZeroU128> for NonZero<Uint<LIMBS>> {
320 fn from(integer: NonZeroU128) -> Self {
321 Self::from_u128(integer)
322 }
323 }
324
#[cfg(feature = "serde")]
impl<'de, T> Deserialize<'de> for NonZero<T>
where
    T: Deserialize<'de> + Zero,
{
    /// Deserializes the inner `T`, rejecting a zero value with an `invalid_value` error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let inner = T::deserialize(deserializer)?;

        if !bool::from(inner.is_zero()) {
            return Ok(Self(inner));
        }

        Err(D::Error::invalid_value(
            Unexpected::Other("zero"),
            &"a non-zero value",
        ))
    }
}
343
#[cfg(feature = "serde")]
impl<T> Serialize for NonZero<T>
where
    T: Serialize + Zero,
{
    /// Serializes transparently as the wrapped `T`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        Serialize::serialize(&self.0, serializer)
    }
}
353
#[cfg(all(test, feature = "serde"))]
#[allow(clippy::unwrap_used)]
mod tests {
    use crate::{NonZero, U64};
    use bincode::ErrorKind;

    /// Error text produced when a serialized zero is deserialized into `NonZero`.
    const ZERO_ERROR: &str = "invalid value: zero, expected a non-zero value";

    /// Non-zero fixture shared by the round-trip tests.
    fn sample() -> NonZero<U64> {
        Option::<NonZero<U64>>::from(NonZero::new(U64::from_u64(0x0011223344556677))).unwrap()
    }

    #[test]
    fn serde() {
        let original = sample();

        let bytes = bincode::serialize(&original).unwrap();
        let decoded: NonZero<U64> = bincode::deserialize(&bytes).unwrap();
        assert_eq!(original, decoded);

        let zero_bytes = bincode::serialize(&U64::ZERO).unwrap();
        let err = bincode::deserialize::<NonZero<U64>>(&zero_bytes).unwrap_err();
        assert!(matches!(*err, ErrorKind::Custom(message) if message == ZERO_ERROR));
    }

    #[test]
    fn serde_owned() {
        let original = sample();

        let bytes = bincode::serialize(&original).unwrap();
        let decoded: NonZero<U64> = bincode::deserialize_from(bytes.as_slice()).unwrap();
        assert_eq!(original, decoded);

        let zero_bytes = bincode::serialize(&U64::ZERO).unwrap();
        let err = bincode::deserialize_from::<_, NonZero<U64>>(zero_bytes.as_slice()).unwrap_err();
        assert!(matches!(*err, ErrorKind::Custom(message) if message == ZERO_ERROR));
    }
}