]> git.proxmox.com Git - rustc.git/blame - vendor/elliptic-curve/src/hash2curve/hash2field/expand_msg.rs
New upstream version 1.71.1+dfsg1
[rustc.git] / vendor / elliptic-curve / src / hash2curve / hash2field / expand_msg.rs
CommitLineData
0a29b90c
FG
//! `expand_message` interface for `hash_to_field`.
2
3pub(super) mod xmd;
4pub(super) mod xof;
5
6use crate::{Error, Result};
7use digest::{Digest, ExtendableOutput, Update, XofReader};
8use generic_array::typenum::{IsLess, U256};
9use generic_array::{ArrayLength, GenericArray};
10
/// Prefix fed to the hash ahead of a domain separation tag that is too long
/// to be used directly.
const OVERSIZE_DST_SALT: &[u8] = b"H2C-OVERSIZE-DST-";

/// Longest domain separation tag, in bytes, that is used as-is; longer tags
/// are hashed down first. This is the largest length a single byte can encode.
const MAX_DST_LEN: usize = u8::MAX as usize;
15
16/// Trait for types implementing expand_message interface for `hash_to_field`.
17///
18/// # Errors
19/// See implementors of [`ExpandMsg`] for errors.
20pub trait ExpandMsg<'a> {
21 /// Type holding data for the [`Expander`].
22 type Expander: Expander + Sized;
23
24 /// Expands `msg` to the required number of bytes.
25 ///
26 /// Returns an expander that can be used to call `read` until enough
27 /// bytes have been consumed
49aad941
FG
28 fn expand_message(
29 msgs: &[&[u8]],
30 dsts: &'a [&'a [u8]],
31 len_in_bytes: usize,
32 ) -> Result<Self::Expander>;
0a29b90c
FG
33}
34
/// Source of `expand_message` output bytes: call
/// [`fill_bytes`](Expander::fill_bytes) until enough bytes have been consumed.
pub trait Expander {
    /// Fills all of `okm` with expanded bytes.
    fn fill_bytes(&mut self, okm: &mut [u8]);
}
40
41/// The domain separation tag
42///
43/// Implements [section 5.4.3 of `draft-irtf-cfrg-hash-to-curve-13`][dst].
44///
45/// [dst]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-hash-to-curve-13#section-5.4.3
46pub(crate) enum Domain<'a, L>
47where
48 L: ArrayLength<u8> + IsLess<U256>,
49{
50 /// > 255
51 Hashed(GenericArray<u8, L>),
52 /// <= 255
49aad941 53 Array(&'a [&'a [u8]]),
0a29b90c
FG
54}
55
56impl<'a, L> Domain<'a, L>
57where
58 L: ArrayLength<u8> + IsLess<U256>,
59{
49aad941 60 pub fn xof<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
0a29b90c
FG
61 where
62 X: Default + ExtendableOutput + Update,
63 {
49aad941 64 if dsts.is_empty() {
0a29b90c 65 Err(Error)
49aad941 66 } else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
0a29b90c 67 let mut data = GenericArray::<u8, L>::default();
49aad941
FG
68 let mut hash = X::default();
69 hash.update(OVERSIZE_DST_SALT);
70
71 for dst in dsts {
72 hash.update(dst);
73 }
74
75 hash.finalize_xof().read(&mut data);
76
0a29b90c
FG
77 Ok(Self::Hashed(data))
78 } else {
49aad941 79 Ok(Self::Array(dsts))
0a29b90c
FG
80 }
81 }
82
49aad941 83 pub fn xmd<X>(dsts: &'a [&'a [u8]]) -> Result<Self>
0a29b90c
FG
84 where
85 X: Digest<OutputSize = L>,
86 {
49aad941 87 if dsts.is_empty() {
0a29b90c 88 Err(Error)
49aad941 89 } else if dsts.iter().map(|dst| dst.len()).sum::<usize>() > MAX_DST_LEN {
0a29b90c
FG
90 Ok(Self::Hashed({
91 let mut hash = X::new();
92 hash.update(OVERSIZE_DST_SALT);
49aad941
FG
93
94 for dst in dsts {
95 hash.update(dst);
96 }
97
0a29b90c
FG
98 hash.finalize()
99 }))
100 } else {
49aad941 101 Ok(Self::Array(dsts))
0a29b90c
FG
102 }
103 }
104
49aad941 105 pub fn update_hash<HashT: Update>(&self, hash: &mut HashT) {
0a29b90c 106 match self {
49aad941
FG
107 Self::Hashed(d) => hash.update(d),
108 Self::Array(d) => {
109 for d in d.iter() {
110 hash.update(d)
111 }
112 }
0a29b90c
FG
113 }
114 }
115
116 pub fn len(&self) -> u8 {
117 match self {
118 // Can't overflow because it's enforced on a type level.
119 Self::Hashed(_) => L::to_u8(),
120 // Can't overflow because it's checked on creation.
49aad941
FG
121 Self::Array(d) => {
122 u8::try_from(d.iter().map(|d| d.len()).sum::<usize>()).expect("length overflow")
123 }
0a29b90c
FG
124 }
125 }
126
127 #[cfg(test)]
128 pub fn assert(&self, bytes: &[u8]) {
49aad941
FG
129 let data = match self {
130 Domain::Hashed(d) => d.to_vec(),
131 Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
132 };
133 assert_eq!(data, bytes);
134 }
135
136 #[cfg(test)]
137 pub fn assert_dst(&self, bytes: &[u8]) {
138 let data = match self {
139 Domain::Hashed(d) => d.to_vec(),
140 Domain::Array(d) => d.iter().copied().flatten().copied().collect(),
141 };
142 assert_eq!(data, &bytes[..bytes.len() - 1]);
0a29b90c
FG
143 assert_eq!(self.len(), bytes[bytes.len() - 1]);
144 }
145}