xref: /freebsd/contrib/llvm-project/llvm/lib/Support/BLAKE3/blake3_avx512.c (revision 81ad626541db97eb356e2c1d4a20eb2a26a766ab)
1 #include "blake3_impl.h"
2 
3 #include <immintrin.h>
4 
/*
 * Integer-vector form of _mm_shuffle_ps: lanes 0-1 of the result are picked
 * from `lo_src` and lanes 2-3 from `hi_src`, as selected by the _MM_SHUFFLE
 * immediate `imm`. This has to remain a macro because `imm` must be an integer
 * constant expression. NOTE(review): the leading underscore places this name
 * in the implementation's reserved namespace; renaming would touch all callers.
 */
#define _mm_shuffle_ps2(lo_src, hi_src, imm)                                   \
  (_mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(lo_src),                   \
                                   _mm_castsi128_ps(hi_src), (imm))))
8 
loadu_128(const uint8_t src[16])9 INLINE __m128i loadu_128(const uint8_t src[16]) {
10   return _mm_loadu_si128((const __m128i *)src);
11 }
12 
loadu_256(const uint8_t src[32])13 INLINE __m256i loadu_256(const uint8_t src[32]) {
14   return _mm256_loadu_si256((const __m256i *)src);
15 }
16 
loadu_512(const uint8_t src[64])17 INLINE __m512i loadu_512(const uint8_t src[64]) {
18   return _mm512_loadu_si512((const __m512i *)src);
19 }
20 
storeu_128(__m128i src,uint8_t dest[16])21 INLINE void storeu_128(__m128i src, uint8_t dest[16]) {
22   _mm_storeu_si128((__m128i *)dest, src);
23 }
24 
storeu_256(__m256i src,uint8_t dest[16])25 INLINE void storeu_256(__m256i src, uint8_t dest[16]) {
26   _mm256_storeu_si256((__m256i *)dest, src);
27 }
28 
add_128(__m128i a,__m128i b)29 INLINE __m128i add_128(__m128i a, __m128i b) { return _mm_add_epi32(a, b); }
30 
add_256(__m256i a,__m256i b)31 INLINE __m256i add_256(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); }
32 
add_512(__m512i a,__m512i b)33 INLINE __m512i add_512(__m512i a, __m512i b) { return _mm512_add_epi32(a, b); }
34 
xor_128(__m128i a,__m128i b)35 INLINE __m128i xor_128(__m128i a, __m128i b) { return _mm_xor_si128(a, b); }
36 
xor_256(__m256i a,__m256i b)37 INLINE __m256i xor_256(__m256i a, __m256i b) { return _mm256_xor_si256(a, b); }
38 
xor_512(__m512i a,__m512i b)39 INLINE __m512i xor_512(__m512i a, __m512i b) { return _mm512_xor_si512(a, b); }
40 
set1_128(uint32_t x)41 INLINE __m128i set1_128(uint32_t x) { return _mm_set1_epi32((int32_t)x); }
42 
set1_256(uint32_t x)43 INLINE __m256i set1_256(uint32_t x) { return _mm256_set1_epi32((int32_t)x); }
44 
set1_512(uint32_t x)45 INLINE __m512i set1_512(uint32_t x) { return _mm512_set1_epi32((int32_t)x); }
46 
set4(uint32_t a,uint32_t b,uint32_t c,uint32_t d)47 INLINE __m128i set4(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
48   return _mm_setr_epi32((int32_t)a, (int32_t)b, (int32_t)c, (int32_t)d);
49 }
50 
rot16_128(__m128i x)51 INLINE __m128i rot16_128(__m128i x) { return _mm_ror_epi32(x, 16); }
52 
rot16_256(__m256i x)53 INLINE __m256i rot16_256(__m256i x) { return _mm256_ror_epi32(x, 16); }
54 
rot16_512(__m512i x)55 INLINE __m512i rot16_512(__m512i x) { return _mm512_ror_epi32(x, 16); }
56 
rot12_128(__m128i x)57 INLINE __m128i rot12_128(__m128i x) { return _mm_ror_epi32(x, 12); }
58 
rot12_256(__m256i x)59 INLINE __m256i rot12_256(__m256i x) { return _mm256_ror_epi32(x, 12); }
60 
rot12_512(__m512i x)61 INLINE __m512i rot12_512(__m512i x) { return _mm512_ror_epi32(x, 12); }
62 
rot8_128(__m128i x)63 INLINE __m128i rot8_128(__m128i x) { return _mm_ror_epi32(x, 8); }
64 
rot8_256(__m256i x)65 INLINE __m256i rot8_256(__m256i x) { return _mm256_ror_epi32(x, 8); }
66 
rot8_512(__m512i x)67 INLINE __m512i rot8_512(__m512i x) { return _mm512_ror_epi32(x, 8); }
68 
rot7_128(__m128i x)69 INLINE __m128i rot7_128(__m128i x) { return _mm_ror_epi32(x, 7); }
70 
rot7_256(__m256i x)71 INLINE __m256i rot7_256(__m256i x) { return _mm256_ror_epi32(x, 7); }
72 
rot7_512(__m512i x)73 INLINE __m512i rot7_512(__m512i x) { return _mm512_ror_epi32(x, 7); }
74 
75 /*
76  * ----------------------------------------------------------------------------
77  * compress_avx512
78  * ----------------------------------------------------------------------------
79  */
80 
g1(__m128i * row0,__m128i * row1,__m128i * row2,__m128i * row3,__m128i m)81 INLINE void g1(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
82                __m128i m) {
83   *row0 = add_128(add_128(*row0, m), *row1);
84   *row3 = xor_128(*row3, *row0);
85   *row3 = rot16_128(*row3);
86   *row2 = add_128(*row2, *row3);
87   *row1 = xor_128(*row1, *row2);
88   *row1 = rot12_128(*row1);
89 }
90 
g2(__m128i * row0,__m128i * row1,__m128i * row2,__m128i * row3,__m128i m)91 INLINE void g2(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
92                __m128i m) {
93   *row0 = add_128(add_128(*row0, m), *row1);
94   *row3 = xor_128(*row3, *row0);
95   *row3 = rot8_128(*row3);
96   *row2 = add_128(*row2, *row3);
97   *row1 = xor_128(*row1, *row2);
98   *row1 = rot7_128(*row1);
99 }
100 
101 // Note the optimization here of leaving row1 as the unrotated row, rather than
102 // row0. All the message loads below are adjusted to compensate for this. See
103 // discussion at https://github.com/sneves/blake2-avx2/pull/4
diagonalize(__m128i * row0,__m128i * row2,__m128i * row3)104 INLINE void diagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
105   *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(2, 1, 0, 3));
106   *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
107   *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(0, 3, 2, 1));
108 }
109 
undiagonalize(__m128i * row0,__m128i * row2,__m128i * row3)110 INLINE void undiagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
111   *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(0, 3, 2, 1));
112   *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
113   *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(2, 1, 0, 3));
114 }
115 
compress_pre(__m128i rows[4],const uint32_t cv[8],const uint8_t block[BLAKE3_BLOCK_LEN],uint8_t block_len,uint64_t counter,uint8_t flags)116 INLINE void compress_pre(__m128i rows[4], const uint32_t cv[8],
117                          const uint8_t block[BLAKE3_BLOCK_LEN],
118                          uint8_t block_len, uint64_t counter, uint8_t flags) {
119   rows[0] = loadu_128((uint8_t *)&cv[0]);
120   rows[1] = loadu_128((uint8_t *)&cv[4]);
121   rows[2] = set4(IV[0], IV[1], IV[2], IV[3]);
122   rows[3] = set4(counter_low(counter), counter_high(counter),
123                  (uint32_t)block_len, (uint32_t)flags);
124 
125   __m128i m0 = loadu_128(&block[sizeof(__m128i) * 0]);
126   __m128i m1 = loadu_128(&block[sizeof(__m128i) * 1]);
127   __m128i m2 = loadu_128(&block[sizeof(__m128i) * 2]);
128   __m128i m3 = loadu_128(&block[sizeof(__m128i) * 3]);
129 
130   __m128i t0, t1, t2, t3, tt;
131 
132   // Round 1. The first round permutes the message words from the original
133   // input order, into the groups that get mixed in parallel.
134   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(2, 0, 2, 0)); //  6  4  2  0
135   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
136   t1 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 3, 1)); //  7  5  3  1
137   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
138   diagonalize(&rows[0], &rows[2], &rows[3]);
139   t2 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(2, 0, 2, 0)); // 14 12 10  8
140   t2 = _mm_shuffle_epi32(t2, _MM_SHUFFLE(2, 1, 0, 3));   // 12 10  8 14
141   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
142   t3 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 1, 3, 1)); // 15 13 11  9
143   t3 = _mm_shuffle_epi32(t3, _MM_SHUFFLE(2, 1, 0, 3));   // 13 11  9 15
144   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
145   undiagonalize(&rows[0], &rows[2], &rows[3]);
146   m0 = t0;
147   m1 = t1;
148   m2 = t2;
149   m3 = t3;
150 
151   // Round 2. This round and all following rounds apply a fixed permutation
152   // to the message words from the round before.
153   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
154   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
155   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
156   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
157   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
158   t1 = _mm_blend_epi16(tt, t1, 0xCC);
159   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
160   diagonalize(&rows[0], &rows[2], &rows[3]);
161   t2 = _mm_unpacklo_epi64(m3, m1);
162   tt = _mm_blend_epi16(t2, m2, 0xC0);
163   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
164   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
165   t3 = _mm_unpackhi_epi32(m1, m3);
166   tt = _mm_unpacklo_epi32(m2, t3);
167   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
168   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
169   undiagonalize(&rows[0], &rows[2], &rows[3]);
170   m0 = t0;
171   m1 = t1;
172   m2 = t2;
173   m3 = t3;
174 
175   // Round 3
176   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
177   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
178   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
179   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
180   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
181   t1 = _mm_blend_epi16(tt, t1, 0xCC);
182   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
183   diagonalize(&rows[0], &rows[2], &rows[3]);
184   t2 = _mm_unpacklo_epi64(m3, m1);
185   tt = _mm_blend_epi16(t2, m2, 0xC0);
186   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
187   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
188   t3 = _mm_unpackhi_epi32(m1, m3);
189   tt = _mm_unpacklo_epi32(m2, t3);
190   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
191   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
192   undiagonalize(&rows[0], &rows[2], &rows[3]);
193   m0 = t0;
194   m1 = t1;
195   m2 = t2;
196   m3 = t3;
197 
198   // Round 4
199   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
200   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
201   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
202   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
203   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
204   t1 = _mm_blend_epi16(tt, t1, 0xCC);
205   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
206   diagonalize(&rows[0], &rows[2], &rows[3]);
207   t2 = _mm_unpacklo_epi64(m3, m1);
208   tt = _mm_blend_epi16(t2, m2, 0xC0);
209   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
210   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
211   t3 = _mm_unpackhi_epi32(m1, m3);
212   tt = _mm_unpacklo_epi32(m2, t3);
213   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
214   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
215   undiagonalize(&rows[0], &rows[2], &rows[3]);
216   m0 = t0;
217   m1 = t1;
218   m2 = t2;
219   m3 = t3;
220 
221   // Round 5
222   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
223   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
224   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
225   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
226   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
227   t1 = _mm_blend_epi16(tt, t1, 0xCC);
228   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
229   diagonalize(&rows[0], &rows[2], &rows[3]);
230   t2 = _mm_unpacklo_epi64(m3, m1);
231   tt = _mm_blend_epi16(t2, m2, 0xC0);
232   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
233   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
234   t3 = _mm_unpackhi_epi32(m1, m3);
235   tt = _mm_unpacklo_epi32(m2, t3);
236   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
237   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
238   undiagonalize(&rows[0], &rows[2], &rows[3]);
239   m0 = t0;
240   m1 = t1;
241   m2 = t2;
242   m3 = t3;
243 
244   // Round 6
245   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
246   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
247   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
248   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
249   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
250   t1 = _mm_blend_epi16(tt, t1, 0xCC);
251   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
252   diagonalize(&rows[0], &rows[2], &rows[3]);
253   t2 = _mm_unpacklo_epi64(m3, m1);
254   tt = _mm_blend_epi16(t2, m2, 0xC0);
255   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
256   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
257   t3 = _mm_unpackhi_epi32(m1, m3);
258   tt = _mm_unpacklo_epi32(m2, t3);
259   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
260   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
261   undiagonalize(&rows[0], &rows[2], &rows[3]);
262   m0 = t0;
263   m1 = t1;
264   m2 = t2;
265   m3 = t3;
266 
267   // Round 7
268   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
269   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
270   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
271   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
272   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
273   t1 = _mm_blend_epi16(tt, t1, 0xCC);
274   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
275   diagonalize(&rows[0], &rows[2], &rows[3]);
276   t2 = _mm_unpacklo_epi64(m3, m1);
277   tt = _mm_blend_epi16(t2, m2, 0xC0);
278   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
279   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
280   t3 = _mm_unpackhi_epi32(m1, m3);
281   tt = _mm_unpacklo_epi32(m2, t3);
282   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
283   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
284   undiagonalize(&rows[0], &rows[2], &rows[3]);
285 }
286 
/*
 * Compresses one block and writes the full 64-byte extended output to `out`.
 * The first 32 bytes are the new chaining value (row0^row2, row1^row3). The
 * last 32 bytes are row2 and row3 XORed with the input chaining value `cv`.
 * `cv` is only read, so it is reinterpreted as const bytes and the qualifier
 * is not cast away.
 */
void blake3_compress_xof_avx512(const uint32_t cv[8],
                                const uint8_t block[BLAKE3_BLOCK_LEN],
                                uint8_t block_len, uint64_t counter,
                                uint8_t flags, uint8_t out[64]) {
  __m128i rows[4];
  compress_pre(rows, cv, block, block_len, counter, flags);
  storeu_128(xor_128(rows[0], rows[2]), &out[0]);
  storeu_128(xor_128(rows[1], rows[3]), &out[16]);
  storeu_128(xor_128(rows[2], loadu_128((const uint8_t *)&cv[0])), &out[32]);
  storeu_128(xor_128(rows[3], loadu_128((const uint8_t *)&cv[4])), &out[48]);
}
298 
blake3_compress_in_place_avx512(uint32_t cv[8],const uint8_t block[BLAKE3_BLOCK_LEN],uint8_t block_len,uint64_t counter,uint8_t flags)299 void blake3_compress_in_place_avx512(uint32_t cv[8],
300                                      const uint8_t block[BLAKE3_BLOCK_LEN],
301                                      uint8_t block_len, uint64_t counter,
302                                      uint8_t flags) {
303   __m128i rows[4];
304   compress_pre(rows, cv, block, block_len, counter, flags);
305   storeu_128(xor_128(rows[0], rows[2]), (uint8_t *)&cv[0]);
306   storeu_128(xor_128(rows[1], rows[3]), (uint8_t *)&cv[4]);
307 }
308 
309 /*
310  * ----------------------------------------------------------------------------
311  * hash4_avx512
312  * ----------------------------------------------------------------------------
313  */
314 
round_fn4(__m128i v[16],__m128i m[16],size_t r)315 INLINE void round_fn4(__m128i v[16], __m128i m[16], size_t r) {
316   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
317   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
318   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
319   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
320   v[0] = add_128(v[0], v[4]);
321   v[1] = add_128(v[1], v[5]);
322   v[2] = add_128(v[2], v[6]);
323   v[3] = add_128(v[3], v[7]);
324   v[12] = xor_128(v[12], v[0]);
325   v[13] = xor_128(v[13], v[1]);
326   v[14] = xor_128(v[14], v[2]);
327   v[15] = xor_128(v[15], v[3]);
328   v[12] = rot16_128(v[12]);
329   v[13] = rot16_128(v[13]);
330   v[14] = rot16_128(v[14]);
331   v[15] = rot16_128(v[15]);
332   v[8] = add_128(v[8], v[12]);
333   v[9] = add_128(v[9], v[13]);
334   v[10] = add_128(v[10], v[14]);
335   v[11] = add_128(v[11], v[15]);
336   v[4] = xor_128(v[4], v[8]);
337   v[5] = xor_128(v[5], v[9]);
338   v[6] = xor_128(v[6], v[10]);
339   v[7] = xor_128(v[7], v[11]);
340   v[4] = rot12_128(v[4]);
341   v[5] = rot12_128(v[5]);
342   v[6] = rot12_128(v[6]);
343   v[7] = rot12_128(v[7]);
344   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
345   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
346   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
347   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
348   v[0] = add_128(v[0], v[4]);
349   v[1] = add_128(v[1], v[5]);
350   v[2] = add_128(v[2], v[6]);
351   v[3] = add_128(v[3], v[7]);
352   v[12] = xor_128(v[12], v[0]);
353   v[13] = xor_128(v[13], v[1]);
354   v[14] = xor_128(v[14], v[2]);
355   v[15] = xor_128(v[15], v[3]);
356   v[12] = rot8_128(v[12]);
357   v[13] = rot8_128(v[13]);
358   v[14] = rot8_128(v[14]);
359   v[15] = rot8_128(v[15]);
360   v[8] = add_128(v[8], v[12]);
361   v[9] = add_128(v[9], v[13]);
362   v[10] = add_128(v[10], v[14]);
363   v[11] = add_128(v[11], v[15]);
364   v[4] = xor_128(v[4], v[8]);
365   v[5] = xor_128(v[5], v[9]);
366   v[6] = xor_128(v[6], v[10]);
367   v[7] = xor_128(v[7], v[11]);
368   v[4] = rot7_128(v[4]);
369   v[5] = rot7_128(v[5]);
370   v[6] = rot7_128(v[6]);
371   v[7] = rot7_128(v[7]);
372 
373   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
374   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
375   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
376   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
377   v[0] = add_128(v[0], v[5]);
378   v[1] = add_128(v[1], v[6]);
379   v[2] = add_128(v[2], v[7]);
380   v[3] = add_128(v[3], v[4]);
381   v[15] = xor_128(v[15], v[0]);
382   v[12] = xor_128(v[12], v[1]);
383   v[13] = xor_128(v[13], v[2]);
384   v[14] = xor_128(v[14], v[3]);
385   v[15] = rot16_128(v[15]);
386   v[12] = rot16_128(v[12]);
387   v[13] = rot16_128(v[13]);
388   v[14] = rot16_128(v[14]);
389   v[10] = add_128(v[10], v[15]);
390   v[11] = add_128(v[11], v[12]);
391   v[8] = add_128(v[8], v[13]);
392   v[9] = add_128(v[9], v[14]);
393   v[5] = xor_128(v[5], v[10]);
394   v[6] = xor_128(v[6], v[11]);
395   v[7] = xor_128(v[7], v[8]);
396   v[4] = xor_128(v[4], v[9]);
397   v[5] = rot12_128(v[5]);
398   v[6] = rot12_128(v[6]);
399   v[7] = rot12_128(v[7]);
400   v[4] = rot12_128(v[4]);
401   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
402   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
403   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
404   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
405   v[0] = add_128(v[0], v[5]);
406   v[1] = add_128(v[1], v[6]);
407   v[2] = add_128(v[2], v[7]);
408   v[3] = add_128(v[3], v[4]);
409   v[15] = xor_128(v[15], v[0]);
410   v[12] = xor_128(v[12], v[1]);
411   v[13] = xor_128(v[13], v[2]);
412   v[14] = xor_128(v[14], v[3]);
413   v[15] = rot8_128(v[15]);
414   v[12] = rot8_128(v[12]);
415   v[13] = rot8_128(v[13]);
416   v[14] = rot8_128(v[14]);
417   v[10] = add_128(v[10], v[15]);
418   v[11] = add_128(v[11], v[12]);
419   v[8] = add_128(v[8], v[13]);
420   v[9] = add_128(v[9], v[14]);
421   v[5] = xor_128(v[5], v[10]);
422   v[6] = xor_128(v[6], v[11]);
423   v[7] = xor_128(v[7], v[8]);
424   v[4] = xor_128(v[4], v[9]);
425   v[5] = rot7_128(v[5]);
426   v[6] = rot7_128(v[6]);
427   v[7] = rot7_128(v[7]);
428   v[4] = rot7_128(v[4]);
429 }
430 
transpose_vecs_128(__m128i vecs[4])431 INLINE void transpose_vecs_128(__m128i vecs[4]) {
432   // Interleave 32-bit lates. The low unpack is lanes 00/11 and the high is
433   // 22/33. Note that this doesn't split the vector into two lanes, as the
434   // AVX2 counterparts do.
435   __m128i ab_01 = _mm_unpacklo_epi32(vecs[0], vecs[1]);
436   __m128i ab_23 = _mm_unpackhi_epi32(vecs[0], vecs[1]);
437   __m128i cd_01 = _mm_unpacklo_epi32(vecs[2], vecs[3]);
438   __m128i cd_23 = _mm_unpackhi_epi32(vecs[2], vecs[3]);
439 
440   // Interleave 64-bit lanes.
441   __m128i abcd_0 = _mm_unpacklo_epi64(ab_01, cd_01);
442   __m128i abcd_1 = _mm_unpackhi_epi64(ab_01, cd_01);
443   __m128i abcd_2 = _mm_unpacklo_epi64(ab_23, cd_23);
444   __m128i abcd_3 = _mm_unpackhi_epi64(ab_23, cd_23);
445 
446   vecs[0] = abcd_0;
447   vecs[1] = abcd_1;
448   vecs[2] = abcd_2;
449   vecs[3] = abcd_3;
450 }
451 
transpose_msg_vecs4(const uint8_t * const * inputs,size_t block_offset,__m128i out[16])452 INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
453                                 size_t block_offset, __m128i out[16]) {
454   out[0] = loadu_128(&inputs[0][block_offset + 0 * sizeof(__m128i)]);
455   out[1] = loadu_128(&inputs[1][block_offset + 0 * sizeof(__m128i)]);
456   out[2] = loadu_128(&inputs[2][block_offset + 0 * sizeof(__m128i)]);
457   out[3] = loadu_128(&inputs[3][block_offset + 0 * sizeof(__m128i)]);
458   out[4] = loadu_128(&inputs[0][block_offset + 1 * sizeof(__m128i)]);
459   out[5] = loadu_128(&inputs[1][block_offset + 1 * sizeof(__m128i)]);
460   out[6] = loadu_128(&inputs[2][block_offset + 1 * sizeof(__m128i)]);
461   out[7] = loadu_128(&inputs[3][block_offset + 1 * sizeof(__m128i)]);
462   out[8] = loadu_128(&inputs[0][block_offset + 2 * sizeof(__m128i)]);
463   out[9] = loadu_128(&inputs[1][block_offset + 2 * sizeof(__m128i)]);
464   out[10] = loadu_128(&inputs[2][block_offset + 2 * sizeof(__m128i)]);
465   out[11] = loadu_128(&inputs[3][block_offset + 2 * sizeof(__m128i)]);
466   out[12] = loadu_128(&inputs[0][block_offset + 3 * sizeof(__m128i)]);
467   out[13] = loadu_128(&inputs[1][block_offset + 3 * sizeof(__m128i)]);
468   out[14] = loadu_128(&inputs[2][block_offset + 3 * sizeof(__m128i)]);
469   out[15] = loadu_128(&inputs[3][block_offset + 3 * sizeof(__m128i)]);
470   for (size_t i = 0; i < 4; ++i) {
471     _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
472   }
473   transpose_vecs_128(&out[0]);
474   transpose_vecs_128(&out[4]);
475   transpose_vecs_128(&out[8]);
476   transpose_vecs_128(&out[12]);
477 }
478 
load_counters4(uint64_t counter,bool increment_counter,__m128i * out_lo,__m128i * out_hi)479 INLINE void load_counters4(uint64_t counter, bool increment_counter,
480                            __m128i *out_lo, __m128i *out_hi) {
481   uint64_t mask = (increment_counter ? ~0 : 0);
482   __m256i mask_vec = _mm256_set1_epi64x(mask);
483   __m256i deltas = _mm256_setr_epi64x(0, 1, 2, 3);
484   deltas = _mm256_and_si256(mask_vec, deltas);
485   __m256i counters =
486       _mm256_add_epi64(_mm256_set1_epi64x((int64_t)counter), deltas);
487   *out_lo = _mm256_cvtepi64_epi32(counters);
488   *out_hi = _mm256_cvtepi64_epi32(_mm256_srli_epi64(counters, 32));
489 }
490 
491 static
blake3_hash4_avx512(const uint8_t * const * inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)492 void blake3_hash4_avx512(const uint8_t *const *inputs, size_t blocks,
493                          const uint32_t key[8], uint64_t counter,
494                          bool increment_counter, uint8_t flags,
495                          uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
496   __m128i h_vecs[8] = {
497       set1_128(key[0]), set1_128(key[1]), set1_128(key[2]), set1_128(key[3]),
498       set1_128(key[4]), set1_128(key[5]), set1_128(key[6]), set1_128(key[7]),
499   };
500   __m128i counter_low_vec, counter_high_vec;
501   load_counters4(counter, increment_counter, &counter_low_vec,
502                  &counter_high_vec);
503   uint8_t block_flags = flags | flags_start;
504 
505   for (size_t block = 0; block < blocks; block++) {
506     if (block + 1 == blocks) {
507       block_flags |= flags_end;
508     }
509     __m128i block_len_vec = set1_128(BLAKE3_BLOCK_LEN);
510     __m128i block_flags_vec = set1_128(block_flags);
511     __m128i msg_vecs[16];
512     transpose_msg_vecs4(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
513 
514     __m128i v[16] = {
515         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
516         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
517         set1_128(IV[0]), set1_128(IV[1]),  set1_128(IV[2]), set1_128(IV[3]),
518         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
519     };
520     round_fn4(v, msg_vecs, 0);
521     round_fn4(v, msg_vecs, 1);
522     round_fn4(v, msg_vecs, 2);
523     round_fn4(v, msg_vecs, 3);
524     round_fn4(v, msg_vecs, 4);
525     round_fn4(v, msg_vecs, 5);
526     round_fn4(v, msg_vecs, 6);
527     h_vecs[0] = xor_128(v[0], v[8]);
528     h_vecs[1] = xor_128(v[1], v[9]);
529     h_vecs[2] = xor_128(v[2], v[10]);
530     h_vecs[3] = xor_128(v[3], v[11]);
531     h_vecs[4] = xor_128(v[4], v[12]);
532     h_vecs[5] = xor_128(v[5], v[13]);
533     h_vecs[6] = xor_128(v[6], v[14]);
534     h_vecs[7] = xor_128(v[7], v[15]);
535 
536     block_flags = flags;
537   }
538 
539   transpose_vecs_128(&h_vecs[0]);
540   transpose_vecs_128(&h_vecs[4]);
541   // The first four vecs now contain the first half of each output, and the
542   // second four vecs contain the second half of each output.
543   storeu_128(h_vecs[0], &out[0 * sizeof(__m128i)]);
544   storeu_128(h_vecs[4], &out[1 * sizeof(__m128i)]);
545   storeu_128(h_vecs[1], &out[2 * sizeof(__m128i)]);
546   storeu_128(h_vecs[5], &out[3 * sizeof(__m128i)]);
547   storeu_128(h_vecs[2], &out[4 * sizeof(__m128i)]);
548   storeu_128(h_vecs[6], &out[5 * sizeof(__m128i)]);
549   storeu_128(h_vecs[3], &out[6 * sizeof(__m128i)]);
550   storeu_128(h_vecs[7], &out[7 * sizeof(__m128i)]);
551 }
552 
553 /*
554  * ----------------------------------------------------------------------------
555  * hash8_avx512
556  * ----------------------------------------------------------------------------
557  */
558 
round_fn8(__m256i v[16],__m256i m[16],size_t r)559 INLINE void round_fn8(__m256i v[16], __m256i m[16], size_t r) {
560   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
561   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
562   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
563   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
564   v[0] = add_256(v[0], v[4]);
565   v[1] = add_256(v[1], v[5]);
566   v[2] = add_256(v[2], v[6]);
567   v[3] = add_256(v[3], v[7]);
568   v[12] = xor_256(v[12], v[0]);
569   v[13] = xor_256(v[13], v[1]);
570   v[14] = xor_256(v[14], v[2]);
571   v[15] = xor_256(v[15], v[3]);
572   v[12] = rot16_256(v[12]);
573   v[13] = rot16_256(v[13]);
574   v[14] = rot16_256(v[14]);
575   v[15] = rot16_256(v[15]);
576   v[8] = add_256(v[8], v[12]);
577   v[9] = add_256(v[9], v[13]);
578   v[10] = add_256(v[10], v[14]);
579   v[11] = add_256(v[11], v[15]);
580   v[4] = xor_256(v[4], v[8]);
581   v[5] = xor_256(v[5], v[9]);
582   v[6] = xor_256(v[6], v[10]);
583   v[7] = xor_256(v[7], v[11]);
584   v[4] = rot12_256(v[4]);
585   v[5] = rot12_256(v[5]);
586   v[6] = rot12_256(v[6]);
587   v[7] = rot12_256(v[7]);
588   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
589   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
590   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
591   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
592   v[0] = add_256(v[0], v[4]);
593   v[1] = add_256(v[1], v[5]);
594   v[2] = add_256(v[2], v[6]);
595   v[3] = add_256(v[3], v[7]);
596   v[12] = xor_256(v[12], v[0]);
597   v[13] = xor_256(v[13], v[1]);
598   v[14] = xor_256(v[14], v[2]);
599   v[15] = xor_256(v[15], v[3]);
600   v[12] = rot8_256(v[12]);
601   v[13] = rot8_256(v[13]);
602   v[14] = rot8_256(v[14]);
603   v[15] = rot8_256(v[15]);
604   v[8] = add_256(v[8], v[12]);
605   v[9] = add_256(v[9], v[13]);
606   v[10] = add_256(v[10], v[14]);
607   v[11] = add_256(v[11], v[15]);
608   v[4] = xor_256(v[4], v[8]);
609   v[5] = xor_256(v[5], v[9]);
610   v[6] = xor_256(v[6], v[10]);
611   v[7] = xor_256(v[7], v[11]);
612   v[4] = rot7_256(v[4]);
613   v[5] = rot7_256(v[5]);
614   v[6] = rot7_256(v[6]);
615   v[7] = rot7_256(v[7]);
616 
617   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
618   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
619   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
620   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
621   v[0] = add_256(v[0], v[5]);
622   v[1] = add_256(v[1], v[6]);
623   v[2] = add_256(v[2], v[7]);
624   v[3] = add_256(v[3], v[4]);
625   v[15] = xor_256(v[15], v[0]);
626   v[12] = xor_256(v[12], v[1]);
627   v[13] = xor_256(v[13], v[2]);
628   v[14] = xor_256(v[14], v[3]);
629   v[15] = rot16_256(v[15]);
630   v[12] = rot16_256(v[12]);
631   v[13] = rot16_256(v[13]);
632   v[14] = rot16_256(v[14]);
633   v[10] = add_256(v[10], v[15]);
634   v[11] = add_256(v[11], v[12]);
635   v[8] = add_256(v[8], v[13]);
636   v[9] = add_256(v[9], v[14]);
637   v[5] = xor_256(v[5], v[10]);
638   v[6] = xor_256(v[6], v[11]);
639   v[7] = xor_256(v[7], v[8]);
640   v[4] = xor_256(v[4], v[9]);
641   v[5] = rot12_256(v[5]);
642   v[6] = rot12_256(v[6]);
643   v[7] = rot12_256(v[7]);
644   v[4] = rot12_256(v[4]);
645   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
646   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
647   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
648   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
649   v[0] = add_256(v[0], v[5]);
650   v[1] = add_256(v[1], v[6]);
651   v[2] = add_256(v[2], v[7]);
652   v[3] = add_256(v[3], v[4]);
653   v[15] = xor_256(v[15], v[0]);
654   v[12] = xor_256(v[12], v[1]);
655   v[13] = xor_256(v[13], v[2]);
656   v[14] = xor_256(v[14], v[3]);
657   v[15] = rot8_256(v[15]);
658   v[12] = rot8_256(v[12]);
659   v[13] = rot8_256(v[13]);
660   v[14] = rot8_256(v[14]);
661   v[10] = add_256(v[10], v[15]);
662   v[11] = add_256(v[11], v[12]);
663   v[8] = add_256(v[8], v[13]);
664   v[9] = add_256(v[9], v[14]);
665   v[5] = xor_256(v[5], v[10]);
666   v[6] = xor_256(v[6], v[11]);
667   v[7] = xor_256(v[7], v[8]);
668   v[4] = xor_256(v[4], v[9]);
669   v[5] = rot7_256(v[5]);
670   v[6] = rot7_256(v[6]);
671   v[7] = rot7_256(v[7]);
672   v[4] = rot7_256(v[4]);
673 }
674 
transpose_vecs_256(__m256i vecs[8])675 INLINE void transpose_vecs_256(__m256i vecs[8]) {
676   // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high
677   // is 22/33/66/77.
678   __m256i ab_0145 = _mm256_unpacklo_epi32(vecs[0], vecs[1]);
679   __m256i ab_2367 = _mm256_unpackhi_epi32(vecs[0], vecs[1]);
680   __m256i cd_0145 = _mm256_unpacklo_epi32(vecs[2], vecs[3]);
681   __m256i cd_2367 = _mm256_unpackhi_epi32(vecs[2], vecs[3]);
682   __m256i ef_0145 = _mm256_unpacklo_epi32(vecs[4], vecs[5]);
683   __m256i ef_2367 = _mm256_unpackhi_epi32(vecs[4], vecs[5]);
684   __m256i gh_0145 = _mm256_unpacklo_epi32(vecs[6], vecs[7]);
685   __m256i gh_2367 = _mm256_unpackhi_epi32(vecs[6], vecs[7]);
686 
687   // Interleave 64-bit lates. The low unpack is lanes 00/22 and the high is
688   // 11/33.
689   __m256i abcd_04 = _mm256_unpacklo_epi64(ab_0145, cd_0145);
690   __m256i abcd_15 = _mm256_unpackhi_epi64(ab_0145, cd_0145);
691   __m256i abcd_26 = _mm256_unpacklo_epi64(ab_2367, cd_2367);
692   __m256i abcd_37 = _mm256_unpackhi_epi64(ab_2367, cd_2367);
693   __m256i efgh_04 = _mm256_unpacklo_epi64(ef_0145, gh_0145);
694   __m256i efgh_15 = _mm256_unpackhi_epi64(ef_0145, gh_0145);
695   __m256i efgh_26 = _mm256_unpacklo_epi64(ef_2367, gh_2367);
696   __m256i efgh_37 = _mm256_unpackhi_epi64(ef_2367, gh_2367);
697 
698   // Interleave 128-bit lanes.
699   vecs[0] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x20);
700   vecs[1] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x20);
701   vecs[2] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x20);
702   vecs[3] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x20);
703   vecs[4] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x31);
704   vecs[5] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x31);
705   vecs[6] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x31);
706   vecs[7] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x31);
707 }
708 
transpose_msg_vecs8(const uint8_t * const * inputs,size_t block_offset,__m256i out[16])709 INLINE void transpose_msg_vecs8(const uint8_t *const *inputs,
710                                 size_t block_offset, __m256i out[16]) {
711   out[0] = loadu_256(&inputs[0][block_offset + 0 * sizeof(__m256i)]);
712   out[1] = loadu_256(&inputs[1][block_offset + 0 * sizeof(__m256i)]);
713   out[2] = loadu_256(&inputs[2][block_offset + 0 * sizeof(__m256i)]);
714   out[3] = loadu_256(&inputs[3][block_offset + 0 * sizeof(__m256i)]);
715   out[4] = loadu_256(&inputs[4][block_offset + 0 * sizeof(__m256i)]);
716   out[5] = loadu_256(&inputs[5][block_offset + 0 * sizeof(__m256i)]);
717   out[6] = loadu_256(&inputs[6][block_offset + 0 * sizeof(__m256i)]);
718   out[7] = loadu_256(&inputs[7][block_offset + 0 * sizeof(__m256i)]);
719   out[8] = loadu_256(&inputs[0][block_offset + 1 * sizeof(__m256i)]);
720   out[9] = loadu_256(&inputs[1][block_offset + 1 * sizeof(__m256i)]);
721   out[10] = loadu_256(&inputs[2][block_offset + 1 * sizeof(__m256i)]);
722   out[11] = loadu_256(&inputs[3][block_offset + 1 * sizeof(__m256i)]);
723   out[12] = loadu_256(&inputs[4][block_offset + 1 * sizeof(__m256i)]);
724   out[13] = loadu_256(&inputs[5][block_offset + 1 * sizeof(__m256i)]);
725   out[14] = loadu_256(&inputs[6][block_offset + 1 * sizeof(__m256i)]);
726   out[15] = loadu_256(&inputs[7][block_offset + 1 * sizeof(__m256i)]);
727   for (size_t i = 0; i < 8; ++i) {
728     _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
729   }
730   transpose_vecs_256(&out[0]);
731   transpose_vecs_256(&out[8]);
732 }
733 
load_counters8(uint64_t counter,bool increment_counter,__m256i * out_lo,__m256i * out_hi)734 INLINE void load_counters8(uint64_t counter, bool increment_counter,
735                            __m256i *out_lo, __m256i *out_hi) {
736   uint64_t mask = (increment_counter ? ~0 : 0);
737   __m512i mask_vec = _mm512_set1_epi64(mask);
738   __m512i deltas = _mm512_setr_epi64(0, 1, 2, 3, 4, 5, 6, 7);
739   deltas = _mm512_and_si512(mask_vec, deltas);
740   __m512i counters =
741       _mm512_add_epi64(_mm512_set1_epi64((int64_t)counter), deltas);
742   *out_lo = _mm512_cvtepi64_epi32(counters);
743   *out_hi = _mm512_cvtepi64_epi32(_mm512_srli_epi64(counters, 32));
744 }
745 
746 static
blake3_hash8_avx512(const uint8_t * const * inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)747 void blake3_hash8_avx512(const uint8_t *const *inputs, size_t blocks,
748                          const uint32_t key[8], uint64_t counter,
749                          bool increment_counter, uint8_t flags,
750                          uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
751   __m256i h_vecs[8] = {
752       set1_256(key[0]), set1_256(key[1]), set1_256(key[2]), set1_256(key[3]),
753       set1_256(key[4]), set1_256(key[5]), set1_256(key[6]), set1_256(key[7]),
754   };
755   __m256i counter_low_vec, counter_high_vec;
756   load_counters8(counter, increment_counter, &counter_low_vec,
757                  &counter_high_vec);
758   uint8_t block_flags = flags | flags_start;
759 
760   for (size_t block = 0; block < blocks; block++) {
761     if (block + 1 == blocks) {
762       block_flags |= flags_end;
763     }
764     __m256i block_len_vec = set1_256(BLAKE3_BLOCK_LEN);
765     __m256i block_flags_vec = set1_256(block_flags);
766     __m256i msg_vecs[16];
767     transpose_msg_vecs8(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
768 
769     __m256i v[16] = {
770         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
771         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
772         set1_256(IV[0]), set1_256(IV[1]),  set1_256(IV[2]), set1_256(IV[3]),
773         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
774     };
775     round_fn8(v, msg_vecs, 0);
776     round_fn8(v, msg_vecs, 1);
777     round_fn8(v, msg_vecs, 2);
778     round_fn8(v, msg_vecs, 3);
779     round_fn8(v, msg_vecs, 4);
780     round_fn8(v, msg_vecs, 5);
781     round_fn8(v, msg_vecs, 6);
782     h_vecs[0] = xor_256(v[0], v[8]);
783     h_vecs[1] = xor_256(v[1], v[9]);
784     h_vecs[2] = xor_256(v[2], v[10]);
785     h_vecs[3] = xor_256(v[3], v[11]);
786     h_vecs[4] = xor_256(v[4], v[12]);
787     h_vecs[5] = xor_256(v[5], v[13]);
788     h_vecs[6] = xor_256(v[6], v[14]);
789     h_vecs[7] = xor_256(v[7], v[15]);
790 
791     block_flags = flags;
792   }
793 
794   transpose_vecs_256(h_vecs);
795   storeu_256(h_vecs[0], &out[0 * sizeof(__m256i)]);
796   storeu_256(h_vecs[1], &out[1 * sizeof(__m256i)]);
797   storeu_256(h_vecs[2], &out[2 * sizeof(__m256i)]);
798   storeu_256(h_vecs[3], &out[3 * sizeof(__m256i)]);
799   storeu_256(h_vecs[4], &out[4 * sizeof(__m256i)]);
800   storeu_256(h_vecs[5], &out[5 * sizeof(__m256i)]);
801   storeu_256(h_vecs[6], &out[6 * sizeof(__m256i)]);
802   storeu_256(h_vecs[7], &out[7 * sizeof(__m256i)]);
803 }
804 
805 /*
806  * ----------------------------------------------------------------------------
807  * hash16_avx512
808  * ----------------------------------------------------------------------------
809  */
810 
round_fn16(__m512i v[16],__m512i m[16],size_t r)811 INLINE void round_fn16(__m512i v[16], __m512i m[16], size_t r) {
812   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
813   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
814   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
815   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
816   v[0] = add_512(v[0], v[4]);
817   v[1] = add_512(v[1], v[5]);
818   v[2] = add_512(v[2], v[6]);
819   v[3] = add_512(v[3], v[7]);
820   v[12] = xor_512(v[12], v[0]);
821   v[13] = xor_512(v[13], v[1]);
822   v[14] = xor_512(v[14], v[2]);
823   v[15] = xor_512(v[15], v[3]);
824   v[12] = rot16_512(v[12]);
825   v[13] = rot16_512(v[13]);
826   v[14] = rot16_512(v[14]);
827   v[15] = rot16_512(v[15]);
828   v[8] = add_512(v[8], v[12]);
829   v[9] = add_512(v[9], v[13]);
830   v[10] = add_512(v[10], v[14]);
831   v[11] = add_512(v[11], v[15]);
832   v[4] = xor_512(v[4], v[8]);
833   v[5] = xor_512(v[5], v[9]);
834   v[6] = xor_512(v[6], v[10]);
835   v[7] = xor_512(v[7], v[11]);
836   v[4] = rot12_512(v[4]);
837   v[5] = rot12_512(v[5]);
838   v[6] = rot12_512(v[6]);
839   v[7] = rot12_512(v[7]);
840   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
841   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
842   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
843   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
844   v[0] = add_512(v[0], v[4]);
845   v[1] = add_512(v[1], v[5]);
846   v[2] = add_512(v[2], v[6]);
847   v[3] = add_512(v[3], v[7]);
848   v[12] = xor_512(v[12], v[0]);
849   v[13] = xor_512(v[13], v[1]);
850   v[14] = xor_512(v[14], v[2]);
851   v[15] = xor_512(v[15], v[3]);
852   v[12] = rot8_512(v[12]);
853   v[13] = rot8_512(v[13]);
854   v[14] = rot8_512(v[14]);
855   v[15] = rot8_512(v[15]);
856   v[8] = add_512(v[8], v[12]);
857   v[9] = add_512(v[9], v[13]);
858   v[10] = add_512(v[10], v[14]);
859   v[11] = add_512(v[11], v[15]);
860   v[4] = xor_512(v[4], v[8]);
861   v[5] = xor_512(v[5], v[9]);
862   v[6] = xor_512(v[6], v[10]);
863   v[7] = xor_512(v[7], v[11]);
864   v[4] = rot7_512(v[4]);
865   v[5] = rot7_512(v[5]);
866   v[6] = rot7_512(v[6]);
867   v[7] = rot7_512(v[7]);
868 
869   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
870   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
871   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
872   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
873   v[0] = add_512(v[0], v[5]);
874   v[1] = add_512(v[1], v[6]);
875   v[2] = add_512(v[2], v[7]);
876   v[3] = add_512(v[3], v[4]);
877   v[15] = xor_512(v[15], v[0]);
878   v[12] = xor_512(v[12], v[1]);
879   v[13] = xor_512(v[13], v[2]);
880   v[14] = xor_512(v[14], v[3]);
881   v[15] = rot16_512(v[15]);
882   v[12] = rot16_512(v[12]);
883   v[13] = rot16_512(v[13]);
884   v[14] = rot16_512(v[14]);
885   v[10] = add_512(v[10], v[15]);
886   v[11] = add_512(v[11], v[12]);
887   v[8] = add_512(v[8], v[13]);
888   v[9] = add_512(v[9], v[14]);
889   v[5] = xor_512(v[5], v[10]);
890   v[6] = xor_512(v[6], v[11]);
891   v[7] = xor_512(v[7], v[8]);
892   v[4] = xor_512(v[4], v[9]);
893   v[5] = rot12_512(v[5]);
894   v[6] = rot12_512(v[6]);
895   v[7] = rot12_512(v[7]);
896   v[4] = rot12_512(v[4]);
897   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
898   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
899   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
900   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
901   v[0] = add_512(v[0], v[5]);
902   v[1] = add_512(v[1], v[6]);
903   v[2] = add_512(v[2], v[7]);
904   v[3] = add_512(v[3], v[4]);
905   v[15] = xor_512(v[15], v[0]);
906   v[12] = xor_512(v[12], v[1]);
907   v[13] = xor_512(v[13], v[2]);
908   v[14] = xor_512(v[14], v[3]);
909   v[15] = rot8_512(v[15]);
910   v[12] = rot8_512(v[12]);
911   v[13] = rot8_512(v[13]);
912   v[14] = rot8_512(v[14]);
913   v[10] = add_512(v[10], v[15]);
914   v[11] = add_512(v[11], v[12]);
915   v[8] = add_512(v[8], v[13]);
916   v[9] = add_512(v[9], v[14]);
917   v[5] = xor_512(v[5], v[10]);
918   v[6] = xor_512(v[6], v[11]);
919   v[7] = xor_512(v[7], v[8]);
920   v[4] = xor_512(v[4], v[9]);
921   v[5] = rot7_512(v[5]);
922   v[6] = rot7_512(v[6]);
923   v[7] = rot7_512(v[7]);
924   v[4] = rot7_512(v[4]);
925 }
926 
927 // 0b10001000, or lanes a0/a2/b0/b2 in little-endian order
928 #define LO_IMM8 0x88
929 
unpack_lo_128(__m512i a,__m512i b)930 INLINE __m512i unpack_lo_128(__m512i a, __m512i b) {
931   return _mm512_shuffle_i32x4(a, b, LO_IMM8);
932 }
933 
934 // 0b11011101, or lanes a1/a3/b1/b3 in little-endian order
935 #define HI_IMM8 0xdd
936 
unpack_hi_128(__m512i a,__m512i b)937 INLINE __m512i unpack_hi_128(__m512i a, __m512i b) {
938   return _mm512_shuffle_i32x4(a, b, HI_IMM8);
939 }
940 
transpose_vecs_512(__m512i vecs[16])941 INLINE void transpose_vecs_512(__m512i vecs[16]) {
942   // Interleave 32-bit lanes. The _0 unpack is lanes
943   // 0/0/1/1/4/4/5/5/8/8/9/9/12/12/13/13, and the _2 unpack is lanes
944   // 2/2/3/3/6/6/7/7/10/10/11/11/14/14/15/15.
945   __m512i ab_0 = _mm512_unpacklo_epi32(vecs[0], vecs[1]);
946   __m512i ab_2 = _mm512_unpackhi_epi32(vecs[0], vecs[1]);
947   __m512i cd_0 = _mm512_unpacklo_epi32(vecs[2], vecs[3]);
948   __m512i cd_2 = _mm512_unpackhi_epi32(vecs[2], vecs[3]);
949   __m512i ef_0 = _mm512_unpacklo_epi32(vecs[4], vecs[5]);
950   __m512i ef_2 = _mm512_unpackhi_epi32(vecs[4], vecs[5]);
951   __m512i gh_0 = _mm512_unpacklo_epi32(vecs[6], vecs[7]);
952   __m512i gh_2 = _mm512_unpackhi_epi32(vecs[6], vecs[7]);
953   __m512i ij_0 = _mm512_unpacklo_epi32(vecs[8], vecs[9]);
954   __m512i ij_2 = _mm512_unpackhi_epi32(vecs[8], vecs[9]);
955   __m512i kl_0 = _mm512_unpacklo_epi32(vecs[10], vecs[11]);
956   __m512i kl_2 = _mm512_unpackhi_epi32(vecs[10], vecs[11]);
957   __m512i mn_0 = _mm512_unpacklo_epi32(vecs[12], vecs[13]);
958   __m512i mn_2 = _mm512_unpackhi_epi32(vecs[12], vecs[13]);
959   __m512i op_0 = _mm512_unpacklo_epi32(vecs[14], vecs[15]);
960   __m512i op_2 = _mm512_unpackhi_epi32(vecs[14], vecs[15]);
961 
962   // Interleave 64-bit lates. The _0 unpack is lanes
963   // 0/0/0/0/4/4/4/4/8/8/8/8/12/12/12/12, the _1 unpack is lanes
964   // 1/1/1/1/5/5/5/5/9/9/9/9/13/13/13/13, the _2 unpack is lanes
965   // 2/2/2/2/6/6/6/6/10/10/10/10/14/14/14/14, and the _3 unpack is lanes
966   // 3/3/3/3/7/7/7/7/11/11/11/11/15/15/15/15.
967   __m512i abcd_0 = _mm512_unpacklo_epi64(ab_0, cd_0);
968   __m512i abcd_1 = _mm512_unpackhi_epi64(ab_0, cd_0);
969   __m512i abcd_2 = _mm512_unpacklo_epi64(ab_2, cd_2);
970   __m512i abcd_3 = _mm512_unpackhi_epi64(ab_2, cd_2);
971   __m512i efgh_0 = _mm512_unpacklo_epi64(ef_0, gh_0);
972   __m512i efgh_1 = _mm512_unpackhi_epi64(ef_0, gh_0);
973   __m512i efgh_2 = _mm512_unpacklo_epi64(ef_2, gh_2);
974   __m512i efgh_3 = _mm512_unpackhi_epi64(ef_2, gh_2);
975   __m512i ijkl_0 = _mm512_unpacklo_epi64(ij_0, kl_0);
976   __m512i ijkl_1 = _mm512_unpackhi_epi64(ij_0, kl_0);
977   __m512i ijkl_2 = _mm512_unpacklo_epi64(ij_2, kl_2);
978   __m512i ijkl_3 = _mm512_unpackhi_epi64(ij_2, kl_2);
979   __m512i mnop_0 = _mm512_unpacklo_epi64(mn_0, op_0);
980   __m512i mnop_1 = _mm512_unpackhi_epi64(mn_0, op_0);
981   __m512i mnop_2 = _mm512_unpacklo_epi64(mn_2, op_2);
982   __m512i mnop_3 = _mm512_unpackhi_epi64(mn_2, op_2);
983 
984   // Interleave 128-bit lanes. The _0 unpack is
985   // 0/0/0/0/8/8/8/8/0/0/0/0/8/8/8/8, the _1 unpack is
986   // 1/1/1/1/9/9/9/9/1/1/1/1/9/9/9/9, and so on.
987   __m512i abcdefgh_0 = unpack_lo_128(abcd_0, efgh_0);
988   __m512i abcdefgh_1 = unpack_lo_128(abcd_1, efgh_1);
989   __m512i abcdefgh_2 = unpack_lo_128(abcd_2, efgh_2);
990   __m512i abcdefgh_3 = unpack_lo_128(abcd_3, efgh_3);
991   __m512i abcdefgh_4 = unpack_hi_128(abcd_0, efgh_0);
992   __m512i abcdefgh_5 = unpack_hi_128(abcd_1, efgh_1);
993   __m512i abcdefgh_6 = unpack_hi_128(abcd_2, efgh_2);
994   __m512i abcdefgh_7 = unpack_hi_128(abcd_3, efgh_3);
995   __m512i ijklmnop_0 = unpack_lo_128(ijkl_0, mnop_0);
996   __m512i ijklmnop_1 = unpack_lo_128(ijkl_1, mnop_1);
997   __m512i ijklmnop_2 = unpack_lo_128(ijkl_2, mnop_2);
998   __m512i ijklmnop_3 = unpack_lo_128(ijkl_3, mnop_3);
999   __m512i ijklmnop_4 = unpack_hi_128(ijkl_0, mnop_0);
1000   __m512i ijklmnop_5 = unpack_hi_128(ijkl_1, mnop_1);
1001   __m512i ijklmnop_6 = unpack_hi_128(ijkl_2, mnop_2);
1002   __m512i ijklmnop_7 = unpack_hi_128(ijkl_3, mnop_3);
1003 
1004   // Interleave 128-bit lanes again for the final outputs.
1005   vecs[0] = unpack_lo_128(abcdefgh_0, ijklmnop_0);
1006   vecs[1] = unpack_lo_128(abcdefgh_1, ijklmnop_1);
1007   vecs[2] = unpack_lo_128(abcdefgh_2, ijklmnop_2);
1008   vecs[3] = unpack_lo_128(abcdefgh_3, ijklmnop_3);
1009   vecs[4] = unpack_lo_128(abcdefgh_4, ijklmnop_4);
1010   vecs[5] = unpack_lo_128(abcdefgh_5, ijklmnop_5);
1011   vecs[6] = unpack_lo_128(abcdefgh_6, ijklmnop_6);
1012   vecs[7] = unpack_lo_128(abcdefgh_7, ijklmnop_7);
1013   vecs[8] = unpack_hi_128(abcdefgh_0, ijklmnop_0);
1014   vecs[9] = unpack_hi_128(abcdefgh_1, ijklmnop_1);
1015   vecs[10] = unpack_hi_128(abcdefgh_2, ijklmnop_2);
1016   vecs[11] = unpack_hi_128(abcdefgh_3, ijklmnop_3);
1017   vecs[12] = unpack_hi_128(abcdefgh_4, ijklmnop_4);
1018   vecs[13] = unpack_hi_128(abcdefgh_5, ijklmnop_5);
1019   vecs[14] = unpack_hi_128(abcdefgh_6, ijklmnop_6);
1020   vecs[15] = unpack_hi_128(abcdefgh_7, ijklmnop_7);
1021 }
1022 
transpose_msg_vecs16(const uint8_t * const * inputs,size_t block_offset,__m512i out[16])1023 INLINE void transpose_msg_vecs16(const uint8_t *const *inputs,
1024                                  size_t block_offset, __m512i out[16]) {
1025   out[0] = loadu_512(&inputs[0][block_offset]);
1026   out[1] = loadu_512(&inputs[1][block_offset]);
1027   out[2] = loadu_512(&inputs[2][block_offset]);
1028   out[3] = loadu_512(&inputs[3][block_offset]);
1029   out[4] = loadu_512(&inputs[4][block_offset]);
1030   out[5] = loadu_512(&inputs[5][block_offset]);
1031   out[6] = loadu_512(&inputs[6][block_offset]);
1032   out[7] = loadu_512(&inputs[7][block_offset]);
1033   out[8] = loadu_512(&inputs[8][block_offset]);
1034   out[9] = loadu_512(&inputs[9][block_offset]);
1035   out[10] = loadu_512(&inputs[10][block_offset]);
1036   out[11] = loadu_512(&inputs[11][block_offset]);
1037   out[12] = loadu_512(&inputs[12][block_offset]);
1038   out[13] = loadu_512(&inputs[13][block_offset]);
1039   out[14] = loadu_512(&inputs[14][block_offset]);
1040   out[15] = loadu_512(&inputs[15][block_offset]);
1041   for (size_t i = 0; i < 16; ++i) {
1042     _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
1043   }
1044   transpose_vecs_512(out);
1045 }
1046 
load_counters16(uint64_t counter,bool increment_counter,__m512i * out_lo,__m512i * out_hi)1047 INLINE void load_counters16(uint64_t counter, bool increment_counter,
1048                             __m512i *out_lo, __m512i *out_hi) {
1049   const __m512i mask = _mm512_set1_epi32(-(int32_t)increment_counter);
1050   const __m512i add0 = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1051   const __m512i add1 = _mm512_and_si512(mask, add0);
1052   __m512i l = _mm512_add_epi32(_mm512_set1_epi32((int32_t)counter), add1);
1053   __mmask16 carry = _mm512_cmp_epu32_mask(l, add1, _MM_CMPINT_LT);
1054   __m512i h = _mm512_mask_add_epi32(_mm512_set1_epi32((int32_t)(counter >> 32)), carry, _mm512_set1_epi32((int32_t)(counter >> 32)), _mm512_set1_epi32(1));
1055   *out_lo = l;
1056   *out_hi = h;
1057 }
1058 
1059 static
blake3_hash16_avx512(const uint8_t * const * inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)1060 void blake3_hash16_avx512(const uint8_t *const *inputs, size_t blocks,
1061                           const uint32_t key[8], uint64_t counter,
1062                           bool increment_counter, uint8_t flags,
1063                           uint8_t flags_start, uint8_t flags_end,
1064                           uint8_t *out) {
1065   __m512i h_vecs[8] = {
1066       set1_512(key[0]), set1_512(key[1]), set1_512(key[2]), set1_512(key[3]),
1067       set1_512(key[4]), set1_512(key[5]), set1_512(key[6]), set1_512(key[7]),
1068   };
1069   __m512i counter_low_vec, counter_high_vec;
1070   load_counters16(counter, increment_counter, &counter_low_vec,
1071                   &counter_high_vec);
1072   uint8_t block_flags = flags | flags_start;
1073 
1074   for (size_t block = 0; block < blocks; block++) {
1075     if (block + 1 == blocks) {
1076       block_flags |= flags_end;
1077     }
1078     __m512i block_len_vec = set1_512(BLAKE3_BLOCK_LEN);
1079     __m512i block_flags_vec = set1_512(block_flags);
1080     __m512i msg_vecs[16];
1081     transpose_msg_vecs16(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
1082 
1083     __m512i v[16] = {
1084         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
1085         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
1086         set1_512(IV[0]), set1_512(IV[1]),  set1_512(IV[2]), set1_512(IV[3]),
1087         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
1088     };
1089     round_fn16(v, msg_vecs, 0);
1090     round_fn16(v, msg_vecs, 1);
1091     round_fn16(v, msg_vecs, 2);
1092     round_fn16(v, msg_vecs, 3);
1093     round_fn16(v, msg_vecs, 4);
1094     round_fn16(v, msg_vecs, 5);
1095     round_fn16(v, msg_vecs, 6);
1096     h_vecs[0] = xor_512(v[0], v[8]);
1097     h_vecs[1] = xor_512(v[1], v[9]);
1098     h_vecs[2] = xor_512(v[2], v[10]);
1099     h_vecs[3] = xor_512(v[3], v[11]);
1100     h_vecs[4] = xor_512(v[4], v[12]);
1101     h_vecs[5] = xor_512(v[5], v[13]);
1102     h_vecs[6] = xor_512(v[6], v[14]);
1103     h_vecs[7] = xor_512(v[7], v[15]);
1104 
1105     block_flags = flags;
1106   }
1107 
1108   // transpose_vecs_512 operates on a 16x16 matrix of words, but we only have 8
1109   // state vectors. Pad the matrix with zeros. After transposition, store the
1110   // lower half of each vector.
1111   __m512i padded[16] = {
1112       h_vecs[0],   h_vecs[1],   h_vecs[2],   h_vecs[3],
1113       h_vecs[4],   h_vecs[5],   h_vecs[6],   h_vecs[7],
1114       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
1115       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
1116   };
1117   transpose_vecs_512(padded);
1118   _mm256_mask_storeu_epi32(&out[0 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[0]));
1119   _mm256_mask_storeu_epi32(&out[1 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[1]));
1120   _mm256_mask_storeu_epi32(&out[2 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[2]));
1121   _mm256_mask_storeu_epi32(&out[3 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[3]));
1122   _mm256_mask_storeu_epi32(&out[4 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[4]));
1123   _mm256_mask_storeu_epi32(&out[5 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[5]));
1124   _mm256_mask_storeu_epi32(&out[6 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[6]));
1125   _mm256_mask_storeu_epi32(&out[7 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[7]));
1126   _mm256_mask_storeu_epi32(&out[8 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[8]));
1127   _mm256_mask_storeu_epi32(&out[9 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[9]));
1128   _mm256_mask_storeu_epi32(&out[10 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[10]));
1129   _mm256_mask_storeu_epi32(&out[11 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[11]));
1130   _mm256_mask_storeu_epi32(&out[12 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[12]));
1131   _mm256_mask_storeu_epi32(&out[13 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[13]));
1132   _mm256_mask_storeu_epi32(&out[14 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[14]));
1133   _mm256_mask_storeu_epi32(&out[15 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[15]));
1134 }
1135 
1136 /*
1137  * ----------------------------------------------------------------------------
1138  * hash_many_avx512
1139  * ----------------------------------------------------------------------------
1140  */
1141 
hash_one_avx512(const uint8_t * input,size_t blocks,const uint32_t key[8],uint64_t counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t out[BLAKE3_OUT_LEN])1142 INLINE void hash_one_avx512(const uint8_t *input, size_t blocks,
1143                             const uint32_t key[8], uint64_t counter,
1144                             uint8_t flags, uint8_t flags_start,
1145                             uint8_t flags_end, uint8_t out[BLAKE3_OUT_LEN]) {
1146   uint32_t cv[8];
1147   memcpy(cv, key, BLAKE3_KEY_LEN);
1148   uint8_t block_flags = flags | flags_start;
1149   while (blocks > 0) {
1150     if (blocks == 1) {
1151       block_flags |= flags_end;
1152     }
1153     blake3_compress_in_place_avx512(cv, input, BLAKE3_BLOCK_LEN, counter,
1154                                     block_flags);
1155     input = &input[BLAKE3_BLOCK_LEN];
1156     blocks -= 1;
1157     block_flags = flags;
1158   }
1159   memcpy(out, cv, BLAKE3_OUT_LEN);
1160 }
1161 
blake3_hash_many_avx512(const uint8_t * const * inputs,size_t num_inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)1162 void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs,
1163                              size_t blocks, const uint32_t key[8],
1164                              uint64_t counter, bool increment_counter,
1165                              uint8_t flags, uint8_t flags_start,
1166                              uint8_t flags_end, uint8_t *out) {
1167   while (num_inputs >= 16) {
1168     blake3_hash16_avx512(inputs, blocks, key, counter, increment_counter, flags,
1169                          flags_start, flags_end, out);
1170     if (increment_counter) {
1171       counter += 16;
1172     }
1173     inputs += 16;
1174     num_inputs -= 16;
1175     out = &out[16 * BLAKE3_OUT_LEN];
1176   }
1177   while (num_inputs >= 8) {
1178     blake3_hash8_avx512(inputs, blocks, key, counter, increment_counter, flags,
1179                         flags_start, flags_end, out);
1180     if (increment_counter) {
1181       counter += 8;
1182     }
1183     inputs += 8;
1184     num_inputs -= 8;
1185     out = &out[8 * BLAKE3_OUT_LEN];
1186   }
1187   while (num_inputs >= 4) {
1188     blake3_hash4_avx512(inputs, blocks, key, counter, increment_counter, flags,
1189                         flags_start, flags_end, out);
1190     if (increment_counter) {
1191       counter += 4;
1192     }
1193     inputs += 4;
1194     num_inputs -= 4;
1195     out = &out[4 * BLAKE3_OUT_LEN];
1196   }
1197   while (num_inputs > 0) {
1198     hash_one_avx512(inputs[0], blocks, key, counter, flags, flags_start,
1199                     flags_end, out);
1200     if (increment_counter) {
1201       counter += 1;
1202     }
1203     inputs += 1;
1204     num_inputs -= 1;
1205     out = &out[BLAKE3_OUT_LEN];
1206   }
1207 }
1208