#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use crate::guts::{
    assemble_count, count_high, count_low, final_block, flag_word, input_debug_asserts, Finalize,
    Job, LastNode, Stride,
};
use crate::{Count, Word, BLOCKBYTES, IV, SIGMA};
use arrayref::{array_refs, mut_array_refs};
use core::cmp;
use core::mem;

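// DEGREE is the number of BLAKE2b states that the 4-way kernel below
// processes in parallel: four 64-bit words fit in one 256-bit AVX2 vector.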
pub const DEGREE: usize = 4;

#[inline(always)]
unsafe fn loadu(src: *const [Word; DEGREE]) -> __m256i {
    // This is an unaligned load, so the pointer cast is allowed.
    _mm256_loadu_si256(src as *const __m256i)
}

#[inline(always)]
unsafe fn storeu(src: __m256i, dest: *mut [Word; DEGREE]) {
    // This is an unaligned store, so the pointer cast is allowed.
    _mm256_storeu_si256(dest as *mut __m256i, src)
}

#[inline(always)]
unsafe fn loadu_128(mem_addr: &[u8; 16]) -> __m128i {
    _mm_loadu_si128(mem_addr.as_ptr() as *const __m128i)
}

#[inline(always)]
unsafe fn add(a: __m256i, b: __m256i) -> __m256i {
    _mm256_add_epi64(a, b)
}

#[inline(always)]
unsafe fn eq(a: __m256i, b: __m256i) -> __m256i {
    _mm256_cmpeq_epi64(a, b)
}

#[inline(always)]
unsafe fn and(a: __m256i, b: __m256i) -> __m256i {
    _mm256_and_si256(a, b)
}

#[inline(always)]
unsafe fn negate_and(a: __m256i, b: __m256i) -> __m256i {
    // Note that "and not" implies the reverse of the actual arg order.
    _mm256_andnot_si256(a, b)
}

#[inline(always)]
unsafe fn xor(a: __m256i, b: __m256i) -> __m256i {
    _mm256_xor_si256(a, b)
}

#[inline(always)]
unsafe fn set1(x: u64) -> __m256i {
    _mm256_set1_epi64x(x as i64)
}

#[inline(always)]
unsafe fn set4(a: u64, b: u64, c: u64, d: u64) -> __m256i {
    _mm256_setr_epi64x(a as i64, b as i64, c as i64, d as i64)
}

// Adapted from https://github.com/rust-lang-nursery/stdsimd/pull/479.
macro_rules! _MM_SHUFFLE {
    ($z:expr, $y:expr, $x:expr, $w:expr) => {
        ($z << 6) | ($y << 4) | ($x << 2) | $w
    };
}

// These rotations are the "simple version". For the "complicated version", see
// https://github.com/sneves/blake2-avx2/blob/b3723921f668df09ece52dcd225a36d4a4eea1d9/blake2b-common.h#L43-L46.
// For a discussion of the tradeoffs, see
// https://github.com/sneves/blake2-avx2/pull/5. In short:
// - This version performs better on modern x86 chips, Skylake and later.
// - LLVM is able to optimize this version to AVX-512 rotation instructions
//   when those are enabled.

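// Each rotation helper below is the lane-wise equivalent of the scalar
// x.rotate_right(N), applied to every 64-bit lane of the vector.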
#[inline(always)]
unsafe fn rot32(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi64(x, 32), _mm256_slli_epi64(x, 64 - 32))
}

#[inline(always)]
unsafe fn rot24(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi64(x, 24), _mm256_slli_epi64(x, 64 - 24))
}

#[inline(always)]
unsafe fn rot16(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi64(x, 16), _mm256_slli_epi64(x, 64 - 16))
}

#[inline(always)]
unsafe fn rot63(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi64(x, 63), _mm256_slli_epi64(x, 64 - 63))
}

#[inline(always)]
unsafe fn g1(a: &mut __m256i, b: &mut __m256i, c: &mut __m256i, d: &mut __m256i, m: &mut __m256i) {
    *a = add(*a, *m);
    *a = add(*a, *b);
    *d = xor(*d, *a);
    *d = rot32(*d);
    *c = add(*c, *d);
    *b = xor(*b, *c);
    *b = rot24(*b);
}

#[inline(always)]
unsafe fn g2(a: &mut __m256i, b: &mut __m256i, c: &mut __m256i, d: &mut __m256i, m: &mut __m256i) {
    *a = add(*a, *m);
    *a = add(*a, *b);
    *d = xor(*d, *a);
    *d = rot16(*d);
    *c = add(*c, *d);
    *b = xor(*b, *c);
    *b = rot63(*b);
}

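// Together, g1 and g2 make up the BLAKE2b G function: g1 is the first half
// (the rot32/rot24 mixing) and g2 the second half (rot16/rot63), applied to
// all four columns, or all four diagonals, of the state at once.
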
// Note the optimization here of leaving b as the unrotated row, rather than a.
// All the message loads below are adjusted to compensate for this. See
// discussion at https://github.com/sneves/blake2-avx2/pull/4
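//
// Roughly, diagonalize applies these lane permutations (and undiagonalize
// applies their inverses):
//     a: [a0, a1, a2, a3] -> [a3, a0, a1, a2]
//     d: [d0, d1, d2, d3] -> [d2, d3, d0, d1]
//     c: [c0, c1, c2, c3] -> [c1, c2, c3, c0]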
#[inline(always)]
unsafe fn diagonalize(a: &mut __m256i, _b: &mut __m256i, c: &mut __m256i, d: &mut __m256i) {
    *a = _mm256_permute4x64_epi64(*a, _MM_SHUFFLE!(2, 1, 0, 3));
    *d = _mm256_permute4x64_epi64(*d, _MM_SHUFFLE!(1, 0, 3, 2));
    *c = _mm256_permute4x64_epi64(*c, _MM_SHUFFLE!(0, 3, 2, 1));
}

#[inline(always)]
unsafe fn undiagonalize(a: &mut __m256i, _b: &mut __m256i, c: &mut __m256i, d: &mut __m256i) {
    *a = _mm256_permute4x64_epi64(*a, _MM_SHUFFLE!(0, 3, 2, 1));
    *d = _mm256_permute4x64_epi64(*d, _MM_SHUFFLE!(1, 0, 3, 2));
    *c = _mm256_permute4x64_epi64(*c, _MM_SHUFFLE!(2, 1, 0, 3));
}

#[inline(always)]
unsafe fn compress_block(
    block: &[u8; BLOCKBYTES],
    words: &mut [Word; 8],
    count: Count,
    last_block: Word,
    last_node: Word,
) {
    let (words_low, words_high) = mut_array_refs!(words, DEGREE, DEGREE);
    let (iv_low, iv_high) = array_refs!(&IV, DEGREE, DEGREE);
    let mut a = loadu(words_low);
    let mut b = loadu(words_high);
    let mut c = loadu(iv_low);
    let flags = set4(count_low(count), count_high(count), last_block, last_node);
    let mut d = xor(loadu(iv_high), flags);

    let msg_chunks = array_refs!(block, 16, 16, 16, 16, 16, 16, 16, 16);
    let m0 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.0));
    let m1 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.1));
    let m2 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.2));
    let m3 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.3));
    let m4 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.4));
    let m5 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.5));
    let m6 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.6));
    let m7 = _mm256_broadcastsi128_si256(loadu_128(msg_chunks.7));

    let iv0 = a;
    let iv1 = b;
    let mut t0;
    let mut t1;
    let mut b0;

    // round 1
    t0 = _mm256_unpacklo_epi64(m0, m1);
    t1 = _mm256_unpacklo_epi64(m2, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m0, m1);
    t1 = _mm256_unpackhi_epi64(m2, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpacklo_epi64(m7, m4);
    t1 = _mm256_unpacklo_epi64(m5, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m7, m4);
    t1 = _mm256_unpackhi_epi64(m5, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 2
    t0 = _mm256_unpacklo_epi64(m7, m2);
    t1 = _mm256_unpackhi_epi64(m4, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m5, m4);
    t1 = _mm256_alignr_epi8(m3, m7, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpackhi_epi64(m2, m0);
    t1 = _mm256_blend_epi32(m5, m0, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_alignr_epi8(m6, m1, 8);
    t1 = _mm256_blend_epi32(m3, m1, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 3
    t0 = _mm256_alignr_epi8(m6, m5, 8);
    t1 = _mm256_unpackhi_epi64(m2, m7);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m4, m0);
    t1 = _mm256_blend_epi32(m6, m1, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_alignr_epi8(m5, m4, 8);
    t1 = _mm256_unpackhi_epi64(m1, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m2, m7);
    t1 = _mm256_blend_epi32(m0, m3, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 4
    t0 = _mm256_unpackhi_epi64(m3, m1);
    t1 = _mm256_unpackhi_epi64(m6, m5);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m4, m0);
    t1 = _mm256_unpacklo_epi64(m6, m7);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_alignr_epi8(m1, m7, 8);
    t1 = _mm256_shuffle_epi32(m2, _MM_SHUFFLE!(1, 0, 3, 2));
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m4, m3);
    t1 = _mm256_unpacklo_epi64(m5, m0);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 5
    t0 = _mm256_unpackhi_epi64(m4, m2);
    t1 = _mm256_unpacklo_epi64(m1, m5);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_blend_epi32(m3, m0, 0x33);
    t1 = _mm256_blend_epi32(m7, m2, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_alignr_epi8(m7, m1, 8);
    t1 = _mm256_alignr_epi8(m3, m5, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m6, m0);
    t1 = _mm256_unpacklo_epi64(m6, m4);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 6
    t0 = _mm256_unpacklo_epi64(m1, m3);
    t1 = _mm256_unpacklo_epi64(m0, m4);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m6, m5);
    t1 = _mm256_unpackhi_epi64(m5, m1);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_alignr_epi8(m2, m0, 8);
    t1 = _mm256_unpackhi_epi64(m3, m7);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m4, m6);
    t1 = _mm256_alignr_epi8(m7, m2, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 7
    t0 = _mm256_blend_epi32(m0, m6, 0x33);
    t1 = _mm256_unpacklo_epi64(m7, m2);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m2, m7);
    t1 = _mm256_alignr_epi8(m5, m6, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpacklo_epi64(m4, m0);
    t1 = _mm256_blend_epi32(m4, m3, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m5, m3);
    t1 = _mm256_shuffle_epi32(m1, _MM_SHUFFLE!(1, 0, 3, 2));
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 8
    t0 = _mm256_unpackhi_epi64(m6, m3);
    t1 = _mm256_blend_epi32(m1, m6, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_alignr_epi8(m7, m5, 8);
    t1 = _mm256_unpackhi_epi64(m0, m4);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_blend_epi32(m2, m1, 0x33);
    t1 = _mm256_alignr_epi8(m4, m7, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m5, m0);
    t1 = _mm256_unpacklo_epi64(m2, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 9
    t0 = _mm256_unpacklo_epi64(m3, m7);
    t1 = _mm256_alignr_epi8(m0, m5, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m7, m4);
    t1 = _mm256_alignr_epi8(m4, m1, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpacklo_epi64(m5, m6);
    t1 = _mm256_unpackhi_epi64(m6, m0);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_alignr_epi8(m1, m2, 8);
    t1 = _mm256_alignr_epi8(m2, m3, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 10
    t0 = _mm256_unpacklo_epi64(m5, m4);
    t1 = _mm256_unpackhi_epi64(m3, m0);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m1, m2);
    t1 = _mm256_blend_epi32(m2, m3, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpackhi_epi64(m6, m7);
    t1 = _mm256_unpackhi_epi64(m4, m1);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_blend_epi32(m5, m0, 0x33);
    t1 = _mm256_unpacklo_epi64(m7, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 11
    t0 = _mm256_unpacklo_epi64(m0, m1);
    t1 = _mm256_unpacklo_epi64(m2, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m0, m1);
    t1 = _mm256_unpackhi_epi64(m2, m3);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpacklo_epi64(m7, m4);
    t1 = _mm256_unpacklo_epi64(m5, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpackhi_epi64(m7, m4);
    t1 = _mm256_unpackhi_epi64(m5, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    // round 12
    t0 = _mm256_unpacklo_epi64(m7, m2);
    t1 = _mm256_unpackhi_epi64(m4, m6);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_unpacklo_epi64(m5, m4);
    t1 = _mm256_alignr_epi8(m3, m7, 8);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    diagonalize(&mut a, &mut b, &mut c, &mut d);
    t0 = _mm256_unpackhi_epi64(m2, m0);
    t1 = _mm256_blend_epi32(m5, m0, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g1(&mut a, &mut b, &mut c, &mut d, &mut b0);
    t0 = _mm256_alignr_epi8(m6, m1, 8);
    t1 = _mm256_blend_epi32(m3, m1, 0x33);
    b0 = _mm256_blend_epi32(t0, t1, 0xF0);
    g2(&mut a, &mut b, &mut c, &mut d, &mut b0);
    undiagonalize(&mut a, &mut b, &mut c, &mut d);

    a = xor(a, c);
    b = xor(b, d);
    a = xor(a, iv0);
    b = xor(b, iv1);

    storeu(a, words_low);
    storeu(b, words_high);
}

#[target_feature(enable = "avx2")]
pub unsafe fn compress1_loop(
    input: &[u8],
    words: &mut [Word; 8],
    mut count: Count,
    last_node: LastNode,
    finalize: Finalize,
    stride: Stride,
) {
    input_debug_asserts(input, finalize);

    let mut local_words = *words;

    let mut fin_offset = input.len().saturating_sub(1);
    fin_offset -= fin_offset % stride.padded_blockbytes();
    let mut buf = [0; BLOCKBYTES];
    let (fin_block, fin_len, _) = final_block(input, fin_offset, &mut buf, stride);
    let fin_last_block = flag_word(finalize.yes());
    let fin_last_node = flag_word(finalize.yes() && last_node.yes());

    let mut offset = 0;
    loop {
        let block;
        let count_delta;
        let last_block;
        let last_node;
        if offset == fin_offset {
            block = fin_block;
            count_delta = fin_len;
            last_block = fin_last_block;
            last_node = fin_last_node;
        } else {
            // This unsafe cast avoids bounds checks. There's guaranteed to be
            // enough input because `offset < fin_offset`.
            block = &*(input.as_ptr().add(offset) as *const [u8; BLOCKBYTES]);
            count_delta = BLOCKBYTES;
            last_block = flag_word(false);
            last_node = flag_word(false);
        };

        count = count.wrapping_add(count_delta as Count);
        compress_block(block, &mut local_words, count, last_block, last_node);

        // Check for termination before bumping the offset, to avoid overflow.
        if offset == fin_offset {
            break;
        }

        offset += stride.padded_blockbytes();
    }

    *words = local_words;
}

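// A hypothetical in-crate caller of compress1_loop looks roughly like the
// sketch below. The exact variant names of the crate-private guts enums
// (Finalize, LastNode, Stride) are assumptions here:
//
//     let mut words = [0u64; 8]; // initialized from the BLAKE2b parameter block
//     unsafe {
//         compress1_loop(input, &mut words, 0, LastNode::No, Finalize::Yes, Stride::Serial);
//     }
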
// Performance note: Factoring out a G function here doesn't hurt performance,
// unlike in the case of BLAKE2s where it hurts substantially. In fact, on my
// machine, it helps a tiny bit. But the difference is tiny, so I'm going to
// stick to the approach used by https://github.com/sneves/blake2-avx2
// until/unless I can be sure the (tiny) improvement is consistent across
// different Intel microarchitectures. Smaller code size is nice, but a
// divergence between the BLAKE2b and BLAKE2s implementations is less nice.
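//
// Each call to round() below performs one full BLAKE2b round on all four
// interleaved states at once: the first half is the column step, and the
// second half is the diagonal step, expressed by rotating the v[...] indices
// rather than shuffling lanes.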
#[inline(always)]
unsafe fn round(v: &mut [__m256i; 16], m: &[__m256i; 16], r: usize) {
    v[0] = add(v[0], m[SIGMA[r][0] as usize]);
    v[1] = add(v[1], m[SIGMA[r][2] as usize]);
    v[2] = add(v[2], m[SIGMA[r][4] as usize]);
    v[3] = add(v[3], m[SIGMA[r][6] as usize]);
    v[0] = add(v[0], v[4]);
    v[1] = add(v[1], v[5]);
    v[2] = add(v[2], v[6]);
    v[3] = add(v[3], v[7]);
    v[12] = xor(v[12], v[0]);
    v[13] = xor(v[13], v[1]);
    v[14] = xor(v[14], v[2]);
    v[15] = xor(v[15], v[3]);
    v[12] = rot32(v[12]);
    v[13] = rot32(v[13]);
    v[14] = rot32(v[14]);
    v[15] = rot32(v[15]);
    v[8] = add(v[8], v[12]);
    v[9] = add(v[9], v[13]);
    v[10] = add(v[10], v[14]);
    v[11] = add(v[11], v[15]);
    v[4] = xor(v[4], v[8]);
    v[5] = xor(v[5], v[9]);
    v[6] = xor(v[6], v[10]);
    v[7] = xor(v[7], v[11]);
    v[4] = rot24(v[4]);
    v[5] = rot24(v[5]);
    v[6] = rot24(v[6]);
    v[7] = rot24(v[7]);
    v[0] = add(v[0], m[SIGMA[r][1] as usize]);
    v[1] = add(v[1], m[SIGMA[r][3] as usize]);
    v[2] = add(v[2], m[SIGMA[r][5] as usize]);
    v[3] = add(v[3], m[SIGMA[r][7] as usize]);
    v[0] = add(v[0], v[4]);
    v[1] = add(v[1], v[5]);
    v[2] = add(v[2], v[6]);
    v[3] = add(v[3], v[7]);
    v[12] = xor(v[12], v[0]);
    v[13] = xor(v[13], v[1]);
    v[14] = xor(v[14], v[2]);
    v[15] = xor(v[15], v[3]);
    v[12] = rot16(v[12]);
    v[13] = rot16(v[13]);
    v[14] = rot16(v[14]);
    v[15] = rot16(v[15]);
    v[8] = add(v[8], v[12]);
    v[9] = add(v[9], v[13]);
    v[10] = add(v[10], v[14]);
    v[11] = add(v[11], v[15]);
    v[4] = xor(v[4], v[8]);
    v[5] = xor(v[5], v[9]);
    v[6] = xor(v[6], v[10]);
    v[7] = xor(v[7], v[11]);
    v[4] = rot63(v[4]);
    v[5] = rot63(v[5]);
    v[6] = rot63(v[6]);
    v[7] = rot63(v[7]);

    v[0] = add(v[0], m[SIGMA[r][8] as usize]);
    v[1] = add(v[1], m[SIGMA[r][10] as usize]);
    v[2] = add(v[2], m[SIGMA[r][12] as usize]);
    v[3] = add(v[3], m[SIGMA[r][14] as usize]);
    v[0] = add(v[0], v[5]);
    v[1] = add(v[1], v[6]);
    v[2] = add(v[2], v[7]);
    v[3] = add(v[3], v[4]);
    v[15] = xor(v[15], v[0]);
    v[12] = xor(v[12], v[1]);
    v[13] = xor(v[13], v[2]);
    v[14] = xor(v[14], v[3]);
    v[15] = rot32(v[15]);
    v[12] = rot32(v[12]);
    v[13] = rot32(v[13]);
    v[14] = rot32(v[14]);
    v[10] = add(v[10], v[15]);
    v[11] = add(v[11], v[12]);
    v[8] = add(v[8], v[13]);
    v[9] = add(v[9], v[14]);
    v[5] = xor(v[5], v[10]);
    v[6] = xor(v[6], v[11]);
    v[7] = xor(v[7], v[8]);
    v[4] = xor(v[4], v[9]);
    v[5] = rot24(v[5]);
    v[6] = rot24(v[6]);
    v[7] = rot24(v[7]);
    v[4] = rot24(v[4]);
    v[0] = add(v[0], m[SIGMA[r][9] as usize]);
    v[1] = add(v[1], m[SIGMA[r][11] as usize]);
    v[2] = add(v[2], m[SIGMA[r][13] as usize]);
    v[3] = add(v[3], m[SIGMA[r][15] as usize]);
    v[0] = add(v[0], v[5]);
    v[1] = add(v[1], v[6]);
    v[2] = add(v[2], v[7]);
    v[3] = add(v[3], v[4]);
    v[15] = xor(v[15], v[0]);
    v[12] = xor(v[12], v[1]);
    v[13] = xor(v[13], v[2]);
    v[14] = xor(v[14], v[3]);
    v[15] = rot16(v[15]);
    v[12] = rot16(v[12]);
    v[13] = rot16(v[13]);
    v[14] = rot16(v[14]);
    v[10] = add(v[10], v[15]);
    v[11] = add(v[11], v[12]);
    v[8] = add(v[8], v[13]);
    v[9] = add(v[9], v[14]);
    v[5] = xor(v[5], v[10]);
    v[6] = xor(v[6], v[11]);
    v[7] = xor(v[7], v[8]);
    v[4] = xor(v[4], v[9]);
    v[5] = rot63(v[5]);
    v[6] = rot63(v[6]);
    v[7] = rot63(v[7]);
    v[4] = rot63(v[4]);
}

// We'd rather make this a regular function with #[inline(always)], but for
// some reason that blows up compile times by about 10 seconds, at least in
// some cases (BLAKE2b avx2.rs). This macro seems to get the same performance
// result, without the compile time issue.
macro_rules! compress4_transposed {
    (
        $h_vecs:expr,
        $msg_vecs:expr,
        $count_low:expr,
        $count_high:expr,
        $lastblock:expr,
        $lastnode:expr,
    ) => {
        let h_vecs: &mut [__m256i; 8] = $h_vecs;
        let msg_vecs: &[__m256i; 16] = $msg_vecs;
        let count_low: __m256i = $count_low;
        let count_high: __m256i = $count_high;
        let lastblock: __m256i = $lastblock;
        let lastnode: __m256i = $lastnode;

        let mut v = [
            h_vecs[0],
            h_vecs[1],
            h_vecs[2],
            h_vecs[3],
            h_vecs[4],
            h_vecs[5],
            h_vecs[6],
            h_vecs[7],
            set1(IV[0]),
            set1(IV[1]),
            set1(IV[2]),
            set1(IV[3]),
            xor(set1(IV[4]), count_low),
            xor(set1(IV[5]), count_high),
            xor(set1(IV[6]), lastblock),
            xor(set1(IV[7]), lastnode),
        ];

        round(&mut v, &msg_vecs, 0);
        round(&mut v, &msg_vecs, 1);
        round(&mut v, &msg_vecs, 2);
        round(&mut v, &msg_vecs, 3);
        round(&mut v, &msg_vecs, 4);
        round(&mut v, &msg_vecs, 5);
        round(&mut v, &msg_vecs, 6);
        round(&mut v, &msg_vecs, 7);
        round(&mut v, &msg_vecs, 8);
        round(&mut v, &msg_vecs, 9);
        round(&mut v, &msg_vecs, 10);
        round(&mut v, &msg_vecs, 11);

        h_vecs[0] = xor(xor(h_vecs[0], v[0]), v[8]);
        h_vecs[1] = xor(xor(h_vecs[1], v[1]), v[9]);
        h_vecs[2] = xor(xor(h_vecs[2], v[2]), v[10]);
        h_vecs[3] = xor(xor(h_vecs[3], v[3]), v[11]);
        h_vecs[4] = xor(xor(h_vecs[4], v[4]), v[12]);
        h_vecs[5] = xor(xor(h_vecs[5], v[5]), v[13]);
        h_vecs[6] = xor(xor(h_vecs[6], v[6]), v[14]);
        h_vecs[7] = xor(xor(h_vecs[7], v[7]), v[15]);
    };
}

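// _mm256_permute2x128_si256 with control 0x20 concatenates the low 128-bit
// halves of a and b, and control 0x31 concatenates the high halves, so this
// returns ([a_lo, b_lo], [a_hi, b_hi]).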
#[inline(always)]
unsafe fn interleave128(a: __m256i, b: __m256i) -> (__m256i, __m256i) {
    (
        _mm256_permute2x128_si256(a, b, 0x20),
        _mm256_permute2x128_si256(a, b, 0x31),
    )
}

// There are several ways to do a transposition. We could do it naively, with 8 separate
// _mm256_set_epi64x instructions, referencing each of the 64 words explicitly. Or we could copy
// the vecs into contiguous storage and then use gather instructions. This third approach is to use
// a series of unpack instructions to interleave the vectors. In my benchmarks, interleaving is the
// fastest approach. To test this, run `cargo +nightly bench --bench libtest load_4` in the
// https://github.com/oconnor663/bao_experiments repo.
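//
// Sketch of the interleave-based transpose, with vec_a = [a0, a1, a2, a3] and
// so on:
//     ab_02 = [a0, b0, a2, b2]    ab_13 = [a1, b1, a3, b3]
//     cd_02 = [c0, d0, c2, d2]    cd_13 = [c1, d1, c3, d3]
//     abcd_0 = [a0, b0, c0, d0]   abcd_1 = [a1, b1, c1, d1]
//     abcd_2 = [a2, b2, c2, d2]   abcd_3 = [a3, b3, c3, d3]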
#[inline(always)]
unsafe fn transpose_vecs(
    vec_a: __m256i,
    vec_b: __m256i,
    vec_c: __m256i,
    vec_d: __m256i,
) -> [__m256i; DEGREE] {
    // Interleave 64-bit lanes. The low unpack is lanes 00/22 and the high is 11/33.
    let ab_02 = _mm256_unpacklo_epi64(vec_a, vec_b);
    let ab_13 = _mm256_unpackhi_epi64(vec_a, vec_b);
    let cd_02 = _mm256_unpacklo_epi64(vec_c, vec_d);
    let cd_13 = _mm256_unpackhi_epi64(vec_c, vec_d);

    // Interleave 128-bit lanes.
    let (abcd_0, abcd_2) = interleave128(ab_02, cd_02);
    let (abcd_1, abcd_3) = interleave128(ab_13, cd_13);

    [abcd_0, abcd_1, abcd_2, abcd_3]
}

#[inline(always)]
unsafe fn transpose_state_vecs(jobs: &[Job; DEGREE]) -> [__m256i; 8] {
    // Load all the state words into transposed vectors, where the first vector
    // has the first word of each state, etc. Transposing once at the beginning
    // and once at the end is more efficient than repeating it for each block.
    let words0 = array_refs!(&jobs[0].words, DEGREE, DEGREE);
    let words1 = array_refs!(&jobs[1].words, DEGREE, DEGREE);
    let words2 = array_refs!(&jobs[2].words, DEGREE, DEGREE);
    let words3 = array_refs!(&jobs[3].words, DEGREE, DEGREE);
    let [h0, h1, h2, h3] = transpose_vecs(
        loadu(words0.0),
        loadu(words1.0),
        loadu(words2.0),
        loadu(words3.0),
    );
    let [h4, h5, h6, h7] = transpose_vecs(
        loadu(words0.1),
        loadu(words1.1),
        loadu(words2.1),
        loadu(words3.1),
    );
    [h0, h1, h2, h3, h4, h5, h6, h7]
}

#[inline(always)]
unsafe fn untranspose_state_vecs(h_vecs: &[__m256i; 8], jobs: &mut [Job; DEGREE]) {
    // Un-transpose the updated state vectors back into the caller's arrays.
    let [job0, job1, job2, job3] = jobs;
    let words0 = mut_array_refs!(&mut job0.words, DEGREE, DEGREE);
    let words1 = mut_array_refs!(&mut job1.words, DEGREE, DEGREE);
    let words2 = mut_array_refs!(&mut job2.words, DEGREE, DEGREE);
    let words3 = mut_array_refs!(&mut job3.words, DEGREE, DEGREE);
    let out = transpose_vecs(h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3]);
    storeu(out[0], words0.0);
    storeu(out[1], words1.0);
    storeu(out[2], words2.0);
    storeu(out[3], words3.0);
    let out = transpose_vecs(h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7]);
    storeu(out[0], words0.1);
    storeu(out[1], words1.1);
    storeu(out[2], words2.1);
    storeu(out[3], words3.1);
}

#[inline(always)]
unsafe fn transpose_msg_vecs(blocks: [*const [u8; BLOCKBYTES]; DEGREE]) -> [__m256i; 16] {
    // These input arrays have no particular alignment, so we use unaligned
    // loads to read from them.
    let block0 = blocks[0] as *const [Word; DEGREE];
    let block1 = blocks[1] as *const [Word; DEGREE];
    let block2 = blocks[2] as *const [Word; DEGREE];
    let block3 = blocks[3] as *const [Word; DEGREE];
    let [m0, m1, m2, m3] = transpose_vecs(
        loadu(block0.add(0)),
        loadu(block1.add(0)),
        loadu(block2.add(0)),
        loadu(block3.add(0)),
    );
    let [m4, m5, m6, m7] = transpose_vecs(
        loadu(block0.add(1)),
        loadu(block1.add(1)),
        loadu(block2.add(1)),
        loadu(block3.add(1)),
    );
    let [m8, m9, m10, m11] = transpose_vecs(
        loadu(block0.add(2)),
        loadu(block1.add(2)),
        loadu(block2.add(2)),
        loadu(block3.add(2)),
    );
    let [m12, m13, m14, m15] = transpose_vecs(
        loadu(block0.add(3)),
        loadu(block1.add(3)),
        loadu(block2.add(3)),
        loadu(block3.add(3)),
    );
    [
        m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15,
    ]
}

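// Each job's offset counter is 128 bits. count_low/count_high split it into
// the two 64-bit halves (t0 and t1 in the BLAKE2 spec), one lane per job.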
#[inline(always)]
unsafe fn load_counts(jobs: &[Job; DEGREE]) -> (__m256i, __m256i) {
    (
        set4(
            count_low(jobs[0].count),
            count_low(jobs[1].count),
            count_low(jobs[2].count),
            count_low(jobs[3].count),
        ),
        set4(
            count_high(jobs[0].count),
            count_high(jobs[1].count),
            count_high(jobs[2].count),
            count_high(jobs[3].count),
        ),
    )
}

#[inline(always)]
unsafe fn store_counts(jobs: &mut [Job; DEGREE], low: __m256i, high: __m256i) {
    let low_ints: [Word; DEGREE] = mem::transmute(low);
    let high_ints: [Word; DEGREE] = mem::transmute(high);
    for i in 0..DEGREE {
        jobs[i].count = assemble_count(low_ints[i], high_ints[i]);
    }
}

#[inline(always)]
unsafe fn add_to_counts(lo: &mut __m256i, hi: &mut __m256i, delta: __m256i) {
    // If the low counts reach zero, that means they wrapped, unless the delta
    // was also zero.
    *lo = add(*lo, delta);
    let lo_reached_zero = eq(*lo, set1(0));
    let delta_was_zero = eq(delta, set1(0));
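    // The mask below is (!delta_was_zero) & lo_reached_zero, so hi is bumped
    // only in lanes where a nonzero delta wrapped the low word around to zero.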
    let hi_inc = and(set1(1), negate_and(delta_was_zero, lo_reached_zero));
    *hi = add(*hi, hi_inc);
}

#[inline(always)]
unsafe fn flags_vec(flags: [bool; DEGREE]) -> __m256i {
    set4(
        flag_word(flags[0]),
        flag_word(flags[1]),
        flag_word(flags[2]),
        flag_word(flags[3]),
    )
}

#[target_feature(enable = "avx2")]
pub unsafe fn compress4_loop(jobs: &mut [Job; DEGREE], finalize: Finalize, stride: Stride) {
    // If we're not finalizing, there can't be a partial block at the end.
    for job in jobs.iter() {
        input_debug_asserts(job.input, finalize);
    }

    let msg_ptrs = [
        jobs[0].input.as_ptr(),
        jobs[1].input.as_ptr(),
        jobs[2].input.as_ptr(),
        jobs[3].input.as_ptr(),
    ];
    let mut h_vecs = transpose_state_vecs(&jobs);
    let (mut counts_lo, mut counts_hi) = load_counts(&jobs);

    // Prepare the final blocks (note that they could be empty if the input is
    // empty). Do all this before entering the main loop.
    let min_len = jobs.iter().map(|job| job.input.len()).min().unwrap();
    let mut fin_offset = min_len.saturating_sub(1);
    fin_offset -= fin_offset % stride.padded_blockbytes();
    // Performance note: making these buffers mem::uninitialized() seems to
    // cause problems in the optimizer.
    let mut buf0: [u8; BLOCKBYTES] = [0; BLOCKBYTES];
    let mut buf1: [u8; BLOCKBYTES] = [0; BLOCKBYTES];
    let mut buf2: [u8; BLOCKBYTES] = [0; BLOCKBYTES];
    let mut buf3: [u8; BLOCKBYTES] = [0; BLOCKBYTES];
    let (block0, len0, finalize0) = final_block(jobs[0].input, fin_offset, &mut buf0, stride);
    let (block1, len1, finalize1) = final_block(jobs[1].input, fin_offset, &mut buf1, stride);
    let (block2, len2, finalize2) = final_block(jobs[2].input, fin_offset, &mut buf2, stride);
    let (block3, len3, finalize3) = final_block(jobs[3].input, fin_offset, &mut buf3, stride);
    let fin_blocks: [*const [u8; BLOCKBYTES]; DEGREE] = [block0, block1, block2, block3];
    let fin_counts_delta = set4(len0 as Word, len1 as Word, len2 as Word, len3 as Word);
    let fin_last_block;
    let fin_last_node;
    if finalize.yes() {
        fin_last_block = flags_vec([finalize0, finalize1, finalize2, finalize3]);
        fin_last_node = flags_vec([
            finalize0 && jobs[0].last_node.yes(),
            finalize1 && jobs[1].last_node.yes(),
            finalize2 && jobs[2].last_node.yes(),
            finalize3 && jobs[3].last_node.yes(),
        ]);
    } else {
        fin_last_block = set1(0);
        fin_last_node = set1(0);
    }

    // The main loop.
    let mut offset = 0;
    loop {
        let blocks;
        let counts_delta;
        let last_block;
        let last_node;
        if offset == fin_offset {
            blocks = fin_blocks;
            counts_delta = fin_counts_delta;
            last_block = fin_last_block;
            last_node = fin_last_node;
        } else {
            blocks = [
                msg_ptrs[0].add(offset) as *const [u8; BLOCKBYTES],
                msg_ptrs[1].add(offset) as *const [u8; BLOCKBYTES],
                msg_ptrs[2].add(offset) as *const [u8; BLOCKBYTES],
                msg_ptrs[3].add(offset) as *const [u8; BLOCKBYTES],
            ];
            counts_delta = set1(BLOCKBYTES as Word);
            last_block = set1(0);
            last_node = set1(0);
        };

        let m_vecs = transpose_msg_vecs(blocks);
        add_to_counts(&mut counts_lo, &mut counts_hi, counts_delta);
        compress4_transposed!(
            &mut h_vecs,
            &m_vecs,
            counts_lo,
            counts_hi,
            last_block,
            last_node,
        );

        // Check for termination before bumping the offset, to avoid overflow.
        if offset == fin_offset {
            break;
        }

        offset += stride.padded_blockbytes();
    }

    // Write out the results.
    untranspose_state_vecs(&h_vecs, &mut *jobs);
    store_counts(&mut *jobs, counts_lo, counts_hi);
    let max_consumed = offset.saturating_add(stride.padded_blockbytes());
    for job in jobs.iter_mut() {
        let consumed = cmp::min(max_consumed, job.input.len());
        job.input = &job.input[consumed..];
    }
}
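
// A minimal sanity-check sketch for the rotation helpers above. It assumes
// the crate exposes a "std" feature for runtime CPU detection; the test is a
// no-op on machines without AVX2.
#[cfg(all(test, feature = "std"))]
mod avx2_rotation_tests {
    use super::*;

    #[target_feature(enable = "avx2")]
    unsafe fn check_rotations(x: u64) {
        // Each helper should match the scalar rotate_right in every lane.
        let v = set1(x);
        let lanes: [u64; DEGREE] = mem::transmute(rot32(v));
        assert_eq!(lanes, [x.rotate_right(32); DEGREE]);
        let lanes: [u64; DEGREE] = mem::transmute(rot24(v));
        assert_eq!(lanes, [x.rotate_right(24); DEGREE]);
        let lanes: [u64; DEGREE] = mem::transmute(rot16(v));
        assert_eq!(lanes, [x.rotate_right(16); DEGREE]);
        let lanes: [u64; DEGREE] = mem::transmute(rot63(v));
        assert_eq!(lanes, [x.rotate_right(63); DEGREE]);
    }

    #[test]
    fn test_rotations_match_scalar() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        unsafe {
            check_rotations(0x0123_4567_89ab_cdef);
            check_rotations(u64::MAX);
        }
    }
}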