xref: /freebsd/contrib/llvm-project/llvm/lib/Support/BLAKE3/blake3_avx512.c (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 #include "blake3_impl.h"
2 
3 #include <immintrin.h>
4 
// Integer-typed wrapper around _mm_shuffle_ps: the two __m128i operands are
// reinterpreted as packed floats, shuffled with the immediate selector `c`,
// and the result is reinterpreted back as __m128i.
#define _mm_shuffle_ps2(a, b, c)                                               \
  (_mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b),   \
                                   (c))))
8 
loadu_128(const uint8_t src[16])9 INLINE __m128i loadu_128(const uint8_t src[16]) {
10   return _mm_loadu_si128((const __m128i *)src);
11 }
12 
loadu_256(const uint8_t src[32])13 INLINE __m256i loadu_256(const uint8_t src[32]) {
14   return _mm256_loadu_si256((const __m256i *)src);
15 }
16 
loadu_512(const uint8_t src[64])17 INLINE __m512i loadu_512(const uint8_t src[64]) {
18   return _mm512_loadu_si512((const __m512i *)src);
19 }
20 
storeu_128(__m128i src,uint8_t dest[16])21 INLINE void storeu_128(__m128i src, uint8_t dest[16]) {
22   _mm_storeu_si128((__m128i *)dest, src);
23 }
24 
storeu_256(__m256i src,uint8_t dest[32])25 INLINE void storeu_256(__m256i src, uint8_t dest[32]) {
26   _mm256_storeu_si256((__m256i *)dest, src);
27 }
28 
storeu_512(__m512i src,uint8_t dest[64])29 INLINE void storeu_512(__m512i src, uint8_t dest[64]) {
30   _mm512_storeu_si512((__m512i *)dest, src);
31 }
32 
add_128(__m128i a,__m128i b)33 INLINE __m128i add_128(__m128i a, __m128i b) { return _mm_add_epi32(a, b); }
34 
add_256(__m256i a,__m256i b)35 INLINE __m256i add_256(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); }
36 
add_512(__m512i a,__m512i b)37 INLINE __m512i add_512(__m512i a, __m512i b) { return _mm512_add_epi32(a, b); }
38 
xor_128(__m128i a,__m128i b)39 INLINE __m128i xor_128(__m128i a, __m128i b) { return _mm_xor_si128(a, b); }
40 
xor_256(__m256i a,__m256i b)41 INLINE __m256i xor_256(__m256i a, __m256i b) { return _mm256_xor_si256(a, b); }
42 
xor_512(__m512i a,__m512i b)43 INLINE __m512i xor_512(__m512i a, __m512i b) { return _mm512_xor_si512(a, b); }
44 
set1_128(uint32_t x)45 INLINE __m128i set1_128(uint32_t x) { return _mm_set1_epi32((int32_t)x); }
46 
set1_256(uint32_t x)47 INLINE __m256i set1_256(uint32_t x) { return _mm256_set1_epi32((int32_t)x); }
48 
set1_512(uint32_t x)49 INLINE __m512i set1_512(uint32_t x) { return _mm512_set1_epi32((int32_t)x); }
50 
set4(uint32_t a,uint32_t b,uint32_t c,uint32_t d)51 INLINE __m128i set4(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
52   return _mm_setr_epi32((int32_t)a, (int32_t)b, (int32_t)c, (int32_t)d);
53 }
54 
rot16_128(__m128i x)55 INLINE __m128i rot16_128(__m128i x) { return _mm_ror_epi32(x, 16); }
56 
rot16_256(__m256i x)57 INLINE __m256i rot16_256(__m256i x) { return _mm256_ror_epi32(x, 16); }
58 
rot16_512(__m512i x)59 INLINE __m512i rot16_512(__m512i x) { return _mm512_ror_epi32(x, 16); }
60 
rot12_128(__m128i x)61 INLINE __m128i rot12_128(__m128i x) { return _mm_ror_epi32(x, 12); }
62 
rot12_256(__m256i x)63 INLINE __m256i rot12_256(__m256i x) { return _mm256_ror_epi32(x, 12); }
64 
rot12_512(__m512i x)65 INLINE __m512i rot12_512(__m512i x) { return _mm512_ror_epi32(x, 12); }
66 
rot8_128(__m128i x)67 INLINE __m128i rot8_128(__m128i x) { return _mm_ror_epi32(x, 8); }
68 
rot8_256(__m256i x)69 INLINE __m256i rot8_256(__m256i x) { return _mm256_ror_epi32(x, 8); }
70 
rot8_512(__m512i x)71 INLINE __m512i rot8_512(__m512i x) { return _mm512_ror_epi32(x, 8); }
72 
rot7_128(__m128i x)73 INLINE __m128i rot7_128(__m128i x) { return _mm_ror_epi32(x, 7); }
74 
rot7_256(__m256i x)75 INLINE __m256i rot7_256(__m256i x) { return _mm256_ror_epi32(x, 7); }
76 
rot7_512(__m512i x)77 INLINE __m512i rot7_512(__m512i x) { return _mm512_ror_epi32(x, 7); }
78 
/*
 * ----------------------------------------------------------------------------
 * compress_avx512
 * ----------------------------------------------------------------------------
 */
84 
g1(__m128i * row0,__m128i * row1,__m128i * row2,__m128i * row3,__m128i m)85 INLINE void g1(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
86                __m128i m) {
87   *row0 = add_128(add_128(*row0, m), *row1);
88   *row3 = xor_128(*row3, *row0);
89   *row3 = rot16_128(*row3);
90   *row2 = add_128(*row2, *row3);
91   *row1 = xor_128(*row1, *row2);
92   *row1 = rot12_128(*row1);
93 }
94 
g2(__m128i * row0,__m128i * row1,__m128i * row2,__m128i * row3,__m128i m)95 INLINE void g2(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
96                __m128i m) {
97   *row0 = add_128(add_128(*row0, m), *row1);
98   *row3 = xor_128(*row3, *row0);
99   *row3 = rot8_128(*row3);
100   *row2 = add_128(*row2, *row3);
101   *row1 = xor_128(*row1, *row2);
102   *row1 = rot7_128(*row1);
103 }
104 
105 // Note the optimization here of leaving row1 as the unrotated row, rather than
106 // row0. All the message loads below are adjusted to compensate for this. See
107 // discussion at https://github.com/sneves/blake2-avx2/pull/4
diagonalize(__m128i * row0,__m128i * row2,__m128i * row3)108 INLINE void diagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
109   *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(2, 1, 0, 3));
110   *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
111   *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(0, 3, 2, 1));
112 }
113 
undiagonalize(__m128i * row0,__m128i * row2,__m128i * row3)114 INLINE void undiagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
115   *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(0, 3, 2, 1));
116   *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
117   *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(2, 1, 0, 3));
118 }
119 
compress_pre(__m128i rows[4],const uint32_t cv[8],const uint8_t block[BLAKE3_BLOCK_LEN],uint8_t block_len,uint64_t counter,uint8_t flags)120 INLINE void compress_pre(__m128i rows[4], const uint32_t cv[8],
121                          const uint8_t block[BLAKE3_BLOCK_LEN],
122                          uint8_t block_len, uint64_t counter, uint8_t flags) {
123   rows[0] = loadu_128((uint8_t *)&cv[0]);
124   rows[1] = loadu_128((uint8_t *)&cv[4]);
125   rows[2] = set4(IV[0], IV[1], IV[2], IV[3]);
126   rows[3] = set4(counter_low(counter), counter_high(counter),
127                  (uint32_t)block_len, (uint32_t)flags);
128 
129   __m128i m0 = loadu_128(&block[sizeof(__m128i) * 0]);
130   __m128i m1 = loadu_128(&block[sizeof(__m128i) * 1]);
131   __m128i m2 = loadu_128(&block[sizeof(__m128i) * 2]);
132   __m128i m3 = loadu_128(&block[sizeof(__m128i) * 3]);
133 
134   __m128i t0, t1, t2, t3, tt;
135 
136   // Round 1. The first round permutes the message words from the original
137   // input order, into the groups that get mixed in parallel.
138   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(2, 0, 2, 0)); //  6  4  2  0
139   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
140   t1 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 3, 1)); //  7  5  3  1
141   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
142   diagonalize(&rows[0], &rows[2], &rows[3]);
143   t2 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(2, 0, 2, 0)); // 14 12 10  8
144   t2 = _mm_shuffle_epi32(t2, _MM_SHUFFLE(2, 1, 0, 3));   // 12 10  8 14
145   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
146   t3 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 1, 3, 1)); // 15 13 11  9
147   t3 = _mm_shuffle_epi32(t3, _MM_SHUFFLE(2, 1, 0, 3));   // 13 11  9 15
148   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
149   undiagonalize(&rows[0], &rows[2], &rows[3]);
150   m0 = t0;
151   m1 = t1;
152   m2 = t2;
153   m3 = t3;
154 
155   // Round 2. This round and all following rounds apply a fixed permutation
156   // to the message words from the round before.
157   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
158   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
159   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
160   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
161   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
162   t1 = _mm_blend_epi16(tt, t1, 0xCC);
163   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
164   diagonalize(&rows[0], &rows[2], &rows[3]);
165   t2 = _mm_unpacklo_epi64(m3, m1);
166   tt = _mm_blend_epi16(t2, m2, 0xC0);
167   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
168   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
169   t3 = _mm_unpackhi_epi32(m1, m3);
170   tt = _mm_unpacklo_epi32(m2, t3);
171   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
172   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
173   undiagonalize(&rows[0], &rows[2], &rows[3]);
174   m0 = t0;
175   m1 = t1;
176   m2 = t2;
177   m3 = t3;
178 
179   // Round 3
180   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
181   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
182   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
183   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
184   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
185   t1 = _mm_blend_epi16(tt, t1, 0xCC);
186   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
187   diagonalize(&rows[0], &rows[2], &rows[3]);
188   t2 = _mm_unpacklo_epi64(m3, m1);
189   tt = _mm_blend_epi16(t2, m2, 0xC0);
190   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
191   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
192   t3 = _mm_unpackhi_epi32(m1, m3);
193   tt = _mm_unpacklo_epi32(m2, t3);
194   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
195   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
196   undiagonalize(&rows[0], &rows[2], &rows[3]);
197   m0 = t0;
198   m1 = t1;
199   m2 = t2;
200   m3 = t3;
201 
202   // Round 4
203   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
204   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
205   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
206   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
207   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
208   t1 = _mm_blend_epi16(tt, t1, 0xCC);
209   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
210   diagonalize(&rows[0], &rows[2], &rows[3]);
211   t2 = _mm_unpacklo_epi64(m3, m1);
212   tt = _mm_blend_epi16(t2, m2, 0xC0);
213   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
214   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
215   t3 = _mm_unpackhi_epi32(m1, m3);
216   tt = _mm_unpacklo_epi32(m2, t3);
217   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
218   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
219   undiagonalize(&rows[0], &rows[2], &rows[3]);
220   m0 = t0;
221   m1 = t1;
222   m2 = t2;
223   m3 = t3;
224 
225   // Round 5
226   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
227   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
228   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
229   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
230   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
231   t1 = _mm_blend_epi16(tt, t1, 0xCC);
232   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
233   diagonalize(&rows[0], &rows[2], &rows[3]);
234   t2 = _mm_unpacklo_epi64(m3, m1);
235   tt = _mm_blend_epi16(t2, m2, 0xC0);
236   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
237   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
238   t3 = _mm_unpackhi_epi32(m1, m3);
239   tt = _mm_unpacklo_epi32(m2, t3);
240   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
241   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
242   undiagonalize(&rows[0], &rows[2], &rows[3]);
243   m0 = t0;
244   m1 = t1;
245   m2 = t2;
246   m3 = t3;
247 
248   // Round 6
249   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
250   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
251   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
252   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
253   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
254   t1 = _mm_blend_epi16(tt, t1, 0xCC);
255   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
256   diagonalize(&rows[0], &rows[2], &rows[3]);
257   t2 = _mm_unpacklo_epi64(m3, m1);
258   tt = _mm_blend_epi16(t2, m2, 0xC0);
259   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
260   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
261   t3 = _mm_unpackhi_epi32(m1, m3);
262   tt = _mm_unpacklo_epi32(m2, t3);
263   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
264   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
265   undiagonalize(&rows[0], &rows[2], &rows[3]);
266   m0 = t0;
267   m1 = t1;
268   m2 = t2;
269   m3 = t3;
270 
271   // Round 7
272   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
273   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
274   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
275   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
276   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
277   t1 = _mm_blend_epi16(tt, t1, 0xCC);
278   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
279   diagonalize(&rows[0], &rows[2], &rows[3]);
280   t2 = _mm_unpacklo_epi64(m3, m1);
281   tt = _mm_blend_epi16(t2, m2, 0xC0);
282   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
283   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
284   t3 = _mm_unpackhi_epi32(m1, m3);
285   tt = _mm_unpacklo_epi32(m2, t3);
286   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
287   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
288   undiagonalize(&rows[0], &rows[2], &rows[3]);
289 }
290 
// Compresses one block and writes the full 64-byte extended output to `out`.
//
// The first 32 bytes are rows[0]^rows[2] and rows[1]^rows[3]; the last 32
// bytes are rows[2] and rows[3] XORed with the two halves of the input
// chaining value `cv`. `cv` itself is not modified.
void blake3_compress_xof_avx512(const uint32_t cv[8],
                                const uint8_t block[BLAKE3_BLOCK_LEN],
                                uint8_t block_len, uint64_t counter,
                                uint8_t flags, uint8_t out[64]) {
  __m128i rows[4];
  compress_pre(rows, cv, block, block_len, counter, flags);
  storeu_128(xor_128(rows[0], rows[2]), &out[0]);
  storeu_128(xor_128(rows[1], rows[3]), &out[16]);
  // cv is only read, so the byte view keeps its const qualifier.
  storeu_128(xor_128(rows[2], loadu_128((const uint8_t *)&cv[0])), &out[32]);
  storeu_128(xor_128(rows[3], loadu_128((const uint8_t *)&cv[4])), &out[48]);
}
302 
blake3_compress_in_place_avx512(uint32_t cv[8],const uint8_t block[BLAKE3_BLOCK_LEN],uint8_t block_len,uint64_t counter,uint8_t flags)303 void blake3_compress_in_place_avx512(uint32_t cv[8],
304                                      const uint8_t block[BLAKE3_BLOCK_LEN],
305                                      uint8_t block_len, uint64_t counter,
306                                      uint8_t flags) {
307   __m128i rows[4];
308   compress_pre(rows, cv, block, block_len, counter, flags);
309   storeu_128(xor_128(rows[0], rows[2]), (uint8_t *)&cv[0]);
310   storeu_128(xor_128(rows[1], rows[3]), (uint8_t *)&cv[4]);
311 }
312 
/*
 * ----------------------------------------------------------------------------
 * hash4_avx512
 * ----------------------------------------------------------------------------
 */
318 
round_fn4(__m128i v[16],__m128i m[16],size_t r)319 INLINE void round_fn4(__m128i v[16], __m128i m[16], size_t r) {
320   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
321   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
322   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
323   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
324   v[0] = add_128(v[0], v[4]);
325   v[1] = add_128(v[1], v[5]);
326   v[2] = add_128(v[2], v[6]);
327   v[3] = add_128(v[3], v[7]);
328   v[12] = xor_128(v[12], v[0]);
329   v[13] = xor_128(v[13], v[1]);
330   v[14] = xor_128(v[14], v[2]);
331   v[15] = xor_128(v[15], v[3]);
332   v[12] = rot16_128(v[12]);
333   v[13] = rot16_128(v[13]);
334   v[14] = rot16_128(v[14]);
335   v[15] = rot16_128(v[15]);
336   v[8] = add_128(v[8], v[12]);
337   v[9] = add_128(v[9], v[13]);
338   v[10] = add_128(v[10], v[14]);
339   v[11] = add_128(v[11], v[15]);
340   v[4] = xor_128(v[4], v[8]);
341   v[5] = xor_128(v[5], v[9]);
342   v[6] = xor_128(v[6], v[10]);
343   v[7] = xor_128(v[7], v[11]);
344   v[4] = rot12_128(v[4]);
345   v[5] = rot12_128(v[5]);
346   v[6] = rot12_128(v[6]);
347   v[7] = rot12_128(v[7]);
348   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
349   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
350   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
351   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
352   v[0] = add_128(v[0], v[4]);
353   v[1] = add_128(v[1], v[5]);
354   v[2] = add_128(v[2], v[6]);
355   v[3] = add_128(v[3], v[7]);
356   v[12] = xor_128(v[12], v[0]);
357   v[13] = xor_128(v[13], v[1]);
358   v[14] = xor_128(v[14], v[2]);
359   v[15] = xor_128(v[15], v[3]);
360   v[12] = rot8_128(v[12]);
361   v[13] = rot8_128(v[13]);
362   v[14] = rot8_128(v[14]);
363   v[15] = rot8_128(v[15]);
364   v[8] = add_128(v[8], v[12]);
365   v[9] = add_128(v[9], v[13]);
366   v[10] = add_128(v[10], v[14]);
367   v[11] = add_128(v[11], v[15]);
368   v[4] = xor_128(v[4], v[8]);
369   v[5] = xor_128(v[5], v[9]);
370   v[6] = xor_128(v[6], v[10]);
371   v[7] = xor_128(v[7], v[11]);
372   v[4] = rot7_128(v[4]);
373   v[5] = rot7_128(v[5]);
374   v[6] = rot7_128(v[6]);
375   v[7] = rot7_128(v[7]);
376 
377   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
378   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
379   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
380   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
381   v[0] = add_128(v[0], v[5]);
382   v[1] = add_128(v[1], v[6]);
383   v[2] = add_128(v[2], v[7]);
384   v[3] = add_128(v[3], v[4]);
385   v[15] = xor_128(v[15], v[0]);
386   v[12] = xor_128(v[12], v[1]);
387   v[13] = xor_128(v[13], v[2]);
388   v[14] = xor_128(v[14], v[3]);
389   v[15] = rot16_128(v[15]);
390   v[12] = rot16_128(v[12]);
391   v[13] = rot16_128(v[13]);
392   v[14] = rot16_128(v[14]);
393   v[10] = add_128(v[10], v[15]);
394   v[11] = add_128(v[11], v[12]);
395   v[8] = add_128(v[8], v[13]);
396   v[9] = add_128(v[9], v[14]);
397   v[5] = xor_128(v[5], v[10]);
398   v[6] = xor_128(v[6], v[11]);
399   v[7] = xor_128(v[7], v[8]);
400   v[4] = xor_128(v[4], v[9]);
401   v[5] = rot12_128(v[5]);
402   v[6] = rot12_128(v[6]);
403   v[7] = rot12_128(v[7]);
404   v[4] = rot12_128(v[4]);
405   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
406   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
407   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
408   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
409   v[0] = add_128(v[0], v[5]);
410   v[1] = add_128(v[1], v[6]);
411   v[2] = add_128(v[2], v[7]);
412   v[3] = add_128(v[3], v[4]);
413   v[15] = xor_128(v[15], v[0]);
414   v[12] = xor_128(v[12], v[1]);
415   v[13] = xor_128(v[13], v[2]);
416   v[14] = xor_128(v[14], v[3]);
417   v[15] = rot8_128(v[15]);
418   v[12] = rot8_128(v[12]);
419   v[13] = rot8_128(v[13]);
420   v[14] = rot8_128(v[14]);
421   v[10] = add_128(v[10], v[15]);
422   v[11] = add_128(v[11], v[12]);
423   v[8] = add_128(v[8], v[13]);
424   v[9] = add_128(v[9], v[14]);
425   v[5] = xor_128(v[5], v[10]);
426   v[6] = xor_128(v[6], v[11]);
427   v[7] = xor_128(v[7], v[8]);
428   v[4] = xor_128(v[4], v[9]);
429   v[5] = rot7_128(v[5]);
430   v[6] = rot7_128(v[6]);
431   v[7] = rot7_128(v[7]);
432   v[4] = rot7_128(v[4]);
433 }
434 
transpose_vecs_128(__m128i vecs[4])435 INLINE void transpose_vecs_128(__m128i vecs[4]) {
436   // Interleave 32-bit lanes. The low unpack is lanes 00/11 and the high is
437   // 22/33. Note that this doesn't split the vector into two lanes, as the
438   // AVX2 counterparts do.
439   __m128i ab_01 = _mm_unpacklo_epi32(vecs[0], vecs[1]);
440   __m128i ab_23 = _mm_unpackhi_epi32(vecs[0], vecs[1]);
441   __m128i cd_01 = _mm_unpacklo_epi32(vecs[2], vecs[3]);
442   __m128i cd_23 = _mm_unpackhi_epi32(vecs[2], vecs[3]);
443 
444   // Interleave 64-bit lanes.
445   __m128i abcd_0 = _mm_unpacklo_epi64(ab_01, cd_01);
446   __m128i abcd_1 = _mm_unpackhi_epi64(ab_01, cd_01);
447   __m128i abcd_2 = _mm_unpacklo_epi64(ab_23, cd_23);
448   __m128i abcd_3 = _mm_unpackhi_epi64(ab_23, cd_23);
449 
450   vecs[0] = abcd_0;
451   vecs[1] = abcd_1;
452   vecs[2] = abcd_2;
453   vecs[3] = abcd_3;
454 }
455 
transpose_msg_vecs4(const uint8_t * const * inputs,size_t block_offset,__m128i out[16])456 INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
457                                 size_t block_offset, __m128i out[16]) {
458   out[0] = loadu_128(&inputs[0][block_offset + 0 * sizeof(__m128i)]);
459   out[1] = loadu_128(&inputs[1][block_offset + 0 * sizeof(__m128i)]);
460   out[2] = loadu_128(&inputs[2][block_offset + 0 * sizeof(__m128i)]);
461   out[3] = loadu_128(&inputs[3][block_offset + 0 * sizeof(__m128i)]);
462   out[4] = loadu_128(&inputs[0][block_offset + 1 * sizeof(__m128i)]);
463   out[5] = loadu_128(&inputs[1][block_offset + 1 * sizeof(__m128i)]);
464   out[6] = loadu_128(&inputs[2][block_offset + 1 * sizeof(__m128i)]);
465   out[7] = loadu_128(&inputs[3][block_offset + 1 * sizeof(__m128i)]);
466   out[8] = loadu_128(&inputs[0][block_offset + 2 * sizeof(__m128i)]);
467   out[9] = loadu_128(&inputs[1][block_offset + 2 * sizeof(__m128i)]);
468   out[10] = loadu_128(&inputs[2][block_offset + 2 * sizeof(__m128i)]);
469   out[11] = loadu_128(&inputs[3][block_offset + 2 * sizeof(__m128i)]);
470   out[12] = loadu_128(&inputs[0][block_offset + 3 * sizeof(__m128i)]);
471   out[13] = loadu_128(&inputs[1][block_offset + 3 * sizeof(__m128i)]);
472   out[14] = loadu_128(&inputs[2][block_offset + 3 * sizeof(__m128i)]);
473   out[15] = loadu_128(&inputs[3][block_offset + 3 * sizeof(__m128i)]);
474   for (size_t i = 0; i < 4; ++i) {
475     _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
476   }
477   transpose_vecs_128(&out[0]);
478   transpose_vecs_128(&out[4]);
479   transpose_vecs_128(&out[8]);
480   transpose_vecs_128(&out[12]);
481 }
482 
load_counters4(uint64_t counter,bool increment_counter,__m128i * out_lo,__m128i * out_hi)483 INLINE void load_counters4(uint64_t counter, bool increment_counter,
484                            __m128i *out_lo, __m128i *out_hi) {
485   uint64_t mask = (increment_counter ? ~0 : 0);
486   __m256i mask_vec = _mm256_set1_epi64x(mask);
487   __m256i deltas = _mm256_setr_epi64x(0, 1, 2, 3);
488   deltas = _mm256_and_si256(mask_vec, deltas);
489   __m256i counters =
490       _mm256_add_epi64(_mm256_set1_epi64x((int64_t)counter), deltas);
491   *out_lo = _mm256_cvtepi64_epi32(counters);
492   *out_hi = _mm256_cvtepi64_epi32(_mm256_srli_epi64(counters, 32));
493 }
494 
495 static
blake3_hash4_avx512(const uint8_t * const * inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)496 void blake3_hash4_avx512(const uint8_t *const *inputs, size_t blocks,
497                          const uint32_t key[8], uint64_t counter,
498                          bool increment_counter, uint8_t flags,
499                          uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
500   __m128i h_vecs[8] = {
501       set1_128(key[0]), set1_128(key[1]), set1_128(key[2]), set1_128(key[3]),
502       set1_128(key[4]), set1_128(key[5]), set1_128(key[6]), set1_128(key[7]),
503   };
504   __m128i counter_low_vec, counter_high_vec;
505   load_counters4(counter, increment_counter, &counter_low_vec,
506                  &counter_high_vec);
507   uint8_t block_flags = flags | flags_start;
508 
509   for (size_t block = 0; block < blocks; block++) {
510     if (block + 1 == blocks) {
511       block_flags |= flags_end;
512     }
513     __m128i block_len_vec = set1_128(BLAKE3_BLOCK_LEN);
514     __m128i block_flags_vec = set1_128(block_flags);
515     __m128i msg_vecs[16];
516     transpose_msg_vecs4(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
517 
518     __m128i v[16] = {
519         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
520         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
521         set1_128(IV[0]), set1_128(IV[1]),  set1_128(IV[2]), set1_128(IV[3]),
522         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
523     };
524     round_fn4(v, msg_vecs, 0);
525     round_fn4(v, msg_vecs, 1);
526     round_fn4(v, msg_vecs, 2);
527     round_fn4(v, msg_vecs, 3);
528     round_fn4(v, msg_vecs, 4);
529     round_fn4(v, msg_vecs, 5);
530     round_fn4(v, msg_vecs, 6);
531     h_vecs[0] = xor_128(v[0], v[8]);
532     h_vecs[1] = xor_128(v[1], v[9]);
533     h_vecs[2] = xor_128(v[2], v[10]);
534     h_vecs[3] = xor_128(v[3], v[11]);
535     h_vecs[4] = xor_128(v[4], v[12]);
536     h_vecs[5] = xor_128(v[5], v[13]);
537     h_vecs[6] = xor_128(v[6], v[14]);
538     h_vecs[7] = xor_128(v[7], v[15]);
539 
540     block_flags = flags;
541   }
542 
543   transpose_vecs_128(&h_vecs[0]);
544   transpose_vecs_128(&h_vecs[4]);
545   // The first four vecs now contain the first half of each output, and the
546   // second four vecs contain the second half of each output.
547   storeu_128(h_vecs[0], &out[0 * sizeof(__m128i)]);
548   storeu_128(h_vecs[4], &out[1 * sizeof(__m128i)]);
549   storeu_128(h_vecs[1], &out[2 * sizeof(__m128i)]);
550   storeu_128(h_vecs[5], &out[3 * sizeof(__m128i)]);
551   storeu_128(h_vecs[2], &out[4 * sizeof(__m128i)]);
552   storeu_128(h_vecs[6], &out[5 * sizeof(__m128i)]);
553   storeu_128(h_vecs[3], &out[6 * sizeof(__m128i)]);
554   storeu_128(h_vecs[7], &out[7 * sizeof(__m128i)]);
555 }
556 
557 static
blake3_xof4_avx512(const uint32_t cv[8],const uint8_t block[BLAKE3_BLOCK_LEN],uint8_t block_len,uint64_t counter,uint8_t flags,uint8_t out[4* 64])558 void blake3_xof4_avx512(const uint32_t cv[8],
559                         const uint8_t block[BLAKE3_BLOCK_LEN],
560                         uint8_t block_len, uint64_t counter, uint8_t flags,
561                         uint8_t out[4 * 64]) {
562   __m128i h_vecs[8] = {
563       set1_128(cv[0]), set1_128(cv[1]), set1_128(cv[2]), set1_128(cv[3]),
564       set1_128(cv[4]), set1_128(cv[5]), set1_128(cv[6]), set1_128(cv[7]),
565   };
566   uint32_t block_words[16];
567   load_block_words(block, block_words);
568   __m128i msg_vecs[16];
569   for (size_t i = 0; i < 16; i++) {
570       msg_vecs[i] = set1_128(block_words[i]);
571   }
572   __m128i counter_low_vec, counter_high_vec;
573   load_counters4(counter, true, &counter_low_vec, &counter_high_vec);
574   __m128i block_len_vec = set1_128(block_len);
575   __m128i block_flags_vec = set1_128(flags);
576   __m128i v[16] = {
577       h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
578       h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
579       set1_128(IV[0]), set1_128(IV[1]),  set1_128(IV[2]), set1_128(IV[3]),
580       counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
581   };
582   round_fn4(v, msg_vecs, 0);
583   round_fn4(v, msg_vecs, 1);
584   round_fn4(v, msg_vecs, 2);
585   round_fn4(v, msg_vecs, 3);
586   round_fn4(v, msg_vecs, 4);
587   round_fn4(v, msg_vecs, 5);
588   round_fn4(v, msg_vecs, 6);
589   for (size_t i = 0; i < 8; i++) {
590       v[i] = xor_128(v[i], v[i+8]);
591       v[i+8] = xor_128(v[i+8], h_vecs[i]);
592   }
593   transpose_vecs_128(&v[0]);
594   transpose_vecs_128(&v[4]);
595   transpose_vecs_128(&v[8]);
596   transpose_vecs_128(&v[12]);
597   for (size_t i = 0; i < 4; i++) {
598       storeu_128(v[i+ 0], &out[(4*i+0) * sizeof(__m128i)]);
599       storeu_128(v[i+ 4], &out[(4*i+1) * sizeof(__m128i)]);
600       storeu_128(v[i+ 8], &out[(4*i+2) * sizeof(__m128i)]);
601       storeu_128(v[i+12], &out[(4*i+3) * sizeof(__m128i)]);
602   }
603 }
604 
/*
 * ----------------------------------------------------------------------------
 * hash8_avx512
 * ----------------------------------------------------------------------------
 */
610 
round_fn8(__m256i v[16],__m256i m[16],size_t r)611 INLINE void round_fn8(__m256i v[16], __m256i m[16], size_t r) {
612   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
613   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
614   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
615   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
616   v[0] = add_256(v[0], v[4]);
617   v[1] = add_256(v[1], v[5]);
618   v[2] = add_256(v[2], v[6]);
619   v[3] = add_256(v[3], v[7]);
620   v[12] = xor_256(v[12], v[0]);
621   v[13] = xor_256(v[13], v[1]);
622   v[14] = xor_256(v[14], v[2]);
623   v[15] = xor_256(v[15], v[3]);
624   v[12] = rot16_256(v[12]);
625   v[13] = rot16_256(v[13]);
626   v[14] = rot16_256(v[14]);
627   v[15] = rot16_256(v[15]);
628   v[8] = add_256(v[8], v[12]);
629   v[9] = add_256(v[9], v[13]);
630   v[10] = add_256(v[10], v[14]);
631   v[11] = add_256(v[11], v[15]);
632   v[4] = xor_256(v[4], v[8]);
633   v[5] = xor_256(v[5], v[9]);
634   v[6] = xor_256(v[6], v[10]);
635   v[7] = xor_256(v[7], v[11]);
636   v[4] = rot12_256(v[4]);
637   v[5] = rot12_256(v[5]);
638   v[6] = rot12_256(v[6]);
639   v[7] = rot12_256(v[7]);
640   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
641   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
642   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
643   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
644   v[0] = add_256(v[0], v[4]);
645   v[1] = add_256(v[1], v[5]);
646   v[2] = add_256(v[2], v[6]);
647   v[3] = add_256(v[3], v[7]);
648   v[12] = xor_256(v[12], v[0]);
649   v[13] = xor_256(v[13], v[1]);
650   v[14] = xor_256(v[14], v[2]);
651   v[15] = xor_256(v[15], v[3]);
652   v[12] = rot8_256(v[12]);
653   v[13] = rot8_256(v[13]);
654   v[14] = rot8_256(v[14]);
655   v[15] = rot8_256(v[15]);
656   v[8] = add_256(v[8], v[12]);
657   v[9] = add_256(v[9], v[13]);
658   v[10] = add_256(v[10], v[14]);
659   v[11] = add_256(v[11], v[15]);
660   v[4] = xor_256(v[4], v[8]);
661   v[5] = xor_256(v[5], v[9]);
662   v[6] = xor_256(v[6], v[10]);
663   v[7] = xor_256(v[7], v[11]);
664   v[4] = rot7_256(v[4]);
665   v[5] = rot7_256(v[5]);
666   v[6] = rot7_256(v[6]);
667   v[7] = rot7_256(v[7]);
668 
669   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
670   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
671   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
672   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
673   v[0] = add_256(v[0], v[5]);
674   v[1] = add_256(v[1], v[6]);
675   v[2] = add_256(v[2], v[7]);
676   v[3] = add_256(v[3], v[4]);
677   v[15] = xor_256(v[15], v[0]);
678   v[12] = xor_256(v[12], v[1]);
679   v[13] = xor_256(v[13], v[2]);
680   v[14] = xor_256(v[14], v[3]);
681   v[15] = rot16_256(v[15]);
682   v[12] = rot16_256(v[12]);
683   v[13] = rot16_256(v[13]);
684   v[14] = rot16_256(v[14]);
685   v[10] = add_256(v[10], v[15]);
686   v[11] = add_256(v[11], v[12]);
687   v[8] = add_256(v[8], v[13]);
688   v[9] = add_256(v[9], v[14]);
689   v[5] = xor_256(v[5], v[10]);
690   v[6] = xor_256(v[6], v[11]);
691   v[7] = xor_256(v[7], v[8]);
692   v[4] = xor_256(v[4], v[9]);
693   v[5] = rot12_256(v[5]);
694   v[6] = rot12_256(v[6]);
695   v[7] = rot12_256(v[7]);
696   v[4] = rot12_256(v[4]);
697   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
698   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
699   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
700   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
701   v[0] = add_256(v[0], v[5]);
702   v[1] = add_256(v[1], v[6]);
703   v[2] = add_256(v[2], v[7]);
704   v[3] = add_256(v[3], v[4]);
705   v[15] = xor_256(v[15], v[0]);
706   v[12] = xor_256(v[12], v[1]);
707   v[13] = xor_256(v[13], v[2]);
708   v[14] = xor_256(v[14], v[3]);
709   v[15] = rot8_256(v[15]);
710   v[12] = rot8_256(v[12]);
711   v[13] = rot8_256(v[13]);
712   v[14] = rot8_256(v[14]);
713   v[10] = add_256(v[10], v[15]);
714   v[11] = add_256(v[11], v[12]);
715   v[8] = add_256(v[8], v[13]);
716   v[9] = add_256(v[9], v[14]);
717   v[5] = xor_256(v[5], v[10]);
718   v[6] = xor_256(v[6], v[11]);
719   v[7] = xor_256(v[7], v[8]);
720   v[4] = xor_256(v[4], v[9]);
721   v[5] = rot7_256(v[5]);
722   v[6] = rot7_256(v[6]);
723   v[7] = rot7_256(v[7]);
724   v[4] = rot7_256(v[4]);
725 }
726 
transpose_vecs_256(__m256i vecs[8])727 INLINE void transpose_vecs_256(__m256i vecs[8]) {
728   // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high
729   // is 22/33/66/77.
730   __m256i ab_0145 = _mm256_unpacklo_epi32(vecs[0], vecs[1]);
731   __m256i ab_2367 = _mm256_unpackhi_epi32(vecs[0], vecs[1]);
732   __m256i cd_0145 = _mm256_unpacklo_epi32(vecs[2], vecs[3]);
733   __m256i cd_2367 = _mm256_unpackhi_epi32(vecs[2], vecs[3]);
734   __m256i ef_0145 = _mm256_unpacklo_epi32(vecs[4], vecs[5]);
735   __m256i ef_2367 = _mm256_unpackhi_epi32(vecs[4], vecs[5]);
736   __m256i gh_0145 = _mm256_unpacklo_epi32(vecs[6], vecs[7]);
737   __m256i gh_2367 = _mm256_unpackhi_epi32(vecs[6], vecs[7]);
738 
739   // Interleave 64-bit lanes. The low unpack is lanes 00/22 and the high is
740   // 11/33.
741   __m256i abcd_04 = _mm256_unpacklo_epi64(ab_0145, cd_0145);
742   __m256i abcd_15 = _mm256_unpackhi_epi64(ab_0145, cd_0145);
743   __m256i abcd_26 = _mm256_unpacklo_epi64(ab_2367, cd_2367);
744   __m256i abcd_37 = _mm256_unpackhi_epi64(ab_2367, cd_2367);
745   __m256i efgh_04 = _mm256_unpacklo_epi64(ef_0145, gh_0145);
746   __m256i efgh_15 = _mm256_unpackhi_epi64(ef_0145, gh_0145);
747   __m256i efgh_26 = _mm256_unpacklo_epi64(ef_2367, gh_2367);
748   __m256i efgh_37 = _mm256_unpackhi_epi64(ef_2367, gh_2367);
749 
750   // Interleave 128-bit lanes.
751   vecs[0] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x20);
752   vecs[1] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x20);
753   vecs[2] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x20);
754   vecs[3] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x20);
755   vecs[4] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x31);
756   vecs[5] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x31);
757   vecs[6] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x31);
758   vecs[7] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x31);
759 }
760 
transpose_msg_vecs8(const uint8_t * const * inputs,size_t block_offset,__m256i out[16])761 INLINE void transpose_msg_vecs8(const uint8_t *const *inputs,
762                                 size_t block_offset, __m256i out[16]) {
763   out[0] = loadu_256(&inputs[0][block_offset + 0 * sizeof(__m256i)]);
764   out[1] = loadu_256(&inputs[1][block_offset + 0 * sizeof(__m256i)]);
765   out[2] = loadu_256(&inputs[2][block_offset + 0 * sizeof(__m256i)]);
766   out[3] = loadu_256(&inputs[3][block_offset + 0 * sizeof(__m256i)]);
767   out[4] = loadu_256(&inputs[4][block_offset + 0 * sizeof(__m256i)]);
768   out[5] = loadu_256(&inputs[5][block_offset + 0 * sizeof(__m256i)]);
769   out[6] = loadu_256(&inputs[6][block_offset + 0 * sizeof(__m256i)]);
770   out[7] = loadu_256(&inputs[7][block_offset + 0 * sizeof(__m256i)]);
771   out[8] = loadu_256(&inputs[0][block_offset + 1 * sizeof(__m256i)]);
772   out[9] = loadu_256(&inputs[1][block_offset + 1 * sizeof(__m256i)]);
773   out[10] = loadu_256(&inputs[2][block_offset + 1 * sizeof(__m256i)]);
774   out[11] = loadu_256(&inputs[3][block_offset + 1 * sizeof(__m256i)]);
775   out[12] = loadu_256(&inputs[4][block_offset + 1 * sizeof(__m256i)]);
776   out[13] = loadu_256(&inputs[5][block_offset + 1 * sizeof(__m256i)]);
777   out[14] = loadu_256(&inputs[6][block_offset + 1 * sizeof(__m256i)]);
778   out[15] = loadu_256(&inputs[7][block_offset + 1 * sizeof(__m256i)]);
779   for (size_t i = 0; i < 8; ++i) {
780     _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
781   }
782   transpose_vecs_256(&out[0]);
783   transpose_vecs_256(&out[8]);
784 }
785 
load_counters8(uint64_t counter,bool increment_counter,__m256i * out_lo,__m256i * out_hi)786 INLINE void load_counters8(uint64_t counter, bool increment_counter,
787                            __m256i *out_lo, __m256i *out_hi) {
788   uint64_t mask = (increment_counter ? ~0 : 0);
789   __m512i mask_vec = _mm512_set1_epi64(mask);
790   __m512i deltas = _mm512_setr_epi64(0, 1, 2, 3, 4, 5, 6, 7);
791   deltas = _mm512_and_si512(mask_vec, deltas);
792   __m512i counters =
793       _mm512_add_epi64(_mm512_set1_epi64((int64_t)counter), deltas);
794   *out_lo = _mm512_cvtepi64_epi32(counters);
795   *out_hi = _mm512_cvtepi64_epi32(_mm512_srli_epi64(counters, 32));
796 }
797 
798 static
blake3_hash8_avx512(const uint8_t * const * inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)799 void blake3_hash8_avx512(const uint8_t *const *inputs, size_t blocks,
800                          const uint32_t key[8], uint64_t counter,
801                          bool increment_counter, uint8_t flags,
802                          uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
803   __m256i h_vecs[8] = {
804       set1_256(key[0]), set1_256(key[1]), set1_256(key[2]), set1_256(key[3]),
805       set1_256(key[4]), set1_256(key[5]), set1_256(key[6]), set1_256(key[7]),
806   };
807   __m256i counter_low_vec, counter_high_vec;
808   load_counters8(counter, increment_counter, &counter_low_vec,
809                  &counter_high_vec);
810   uint8_t block_flags = flags | flags_start;
811 
812   for (size_t block = 0; block < blocks; block++) {
813     if (block + 1 == blocks) {
814       block_flags |= flags_end;
815     }
816     __m256i block_len_vec = set1_256(BLAKE3_BLOCK_LEN);
817     __m256i block_flags_vec = set1_256(block_flags);
818     __m256i msg_vecs[16];
819     transpose_msg_vecs8(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
820 
821     __m256i v[16] = {
822         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
823         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
824         set1_256(IV[0]), set1_256(IV[1]),  set1_256(IV[2]), set1_256(IV[3]),
825         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
826     };
827     round_fn8(v, msg_vecs, 0);
828     round_fn8(v, msg_vecs, 1);
829     round_fn8(v, msg_vecs, 2);
830     round_fn8(v, msg_vecs, 3);
831     round_fn8(v, msg_vecs, 4);
832     round_fn8(v, msg_vecs, 5);
833     round_fn8(v, msg_vecs, 6);
834     h_vecs[0] = xor_256(v[0], v[8]);
835     h_vecs[1] = xor_256(v[1], v[9]);
836     h_vecs[2] = xor_256(v[2], v[10]);
837     h_vecs[3] = xor_256(v[3], v[11]);
838     h_vecs[4] = xor_256(v[4], v[12]);
839     h_vecs[5] = xor_256(v[5], v[13]);
840     h_vecs[6] = xor_256(v[6], v[14]);
841     h_vecs[7] = xor_256(v[7], v[15]);
842 
843     block_flags = flags;
844   }
845 
846   transpose_vecs_256(h_vecs);
847   storeu_256(h_vecs[0], &out[0 * sizeof(__m256i)]);
848   storeu_256(h_vecs[1], &out[1 * sizeof(__m256i)]);
849   storeu_256(h_vecs[2], &out[2 * sizeof(__m256i)]);
850   storeu_256(h_vecs[3], &out[3 * sizeof(__m256i)]);
851   storeu_256(h_vecs[4], &out[4 * sizeof(__m256i)]);
852   storeu_256(h_vecs[5], &out[5 * sizeof(__m256i)]);
853   storeu_256(h_vecs[6], &out[6 * sizeof(__m256i)]);
854   storeu_256(h_vecs[7], &out[7 * sizeof(__m256i)]);
855 }
856 
857 static
blake3_xof8_avx512(const uint32_t cv[8],const uint8_t block[BLAKE3_BLOCK_LEN],uint8_t block_len,uint64_t counter,uint8_t flags,uint8_t out[8* 64])858 void blake3_xof8_avx512(const uint32_t cv[8],
859                         const uint8_t block[BLAKE3_BLOCK_LEN],
860                         uint8_t block_len, uint64_t counter, uint8_t flags,
861                         uint8_t out[8 * 64]) {
862   __m256i h_vecs[8] = {
863       set1_256(cv[0]), set1_256(cv[1]), set1_256(cv[2]), set1_256(cv[3]),
864       set1_256(cv[4]), set1_256(cv[5]), set1_256(cv[6]), set1_256(cv[7]),
865   };
866   uint32_t block_words[16];
867   load_block_words(block, block_words);
868   __m256i msg_vecs[16];
869   for (size_t i = 0; i < 16; i++) {
870       msg_vecs[i] = set1_256(block_words[i]);
871   }
872   __m256i counter_low_vec, counter_high_vec;
873   load_counters8(counter, true, &counter_low_vec, &counter_high_vec);
874   __m256i block_len_vec = set1_256(block_len);
875   __m256i block_flags_vec = set1_256(flags);
876   __m256i v[16] = {
877       h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
878       h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
879       set1_256(IV[0]), set1_256(IV[1]),  set1_256(IV[2]), set1_256(IV[3]),
880       counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
881   };
882   round_fn8(v, msg_vecs, 0);
883   round_fn8(v, msg_vecs, 1);
884   round_fn8(v, msg_vecs, 2);
885   round_fn8(v, msg_vecs, 3);
886   round_fn8(v, msg_vecs, 4);
887   round_fn8(v, msg_vecs, 5);
888   round_fn8(v, msg_vecs, 6);
889   for (size_t i = 0; i < 8; i++) {
890       v[i] = xor_256(v[i], v[i+8]);
891       v[i+8] = xor_256(v[i+8], h_vecs[i]);
892   }
893   transpose_vecs_256(&v[0]);
894   transpose_vecs_256(&v[8]);
895   for (size_t i = 0; i < 8; i++) {
896       storeu_256(v[i+0], &out[(2*i+0) * sizeof(__m256i)]);
897       storeu_256(v[i+8], &out[(2*i+1) * sizeof(__m256i)]);
898   }
899 }
900 
901 /*
902  * ----------------------------------------------------------------------------
903  * hash16_avx512
904  * ----------------------------------------------------------------------------
905  */
906 
round_fn16(__m512i v[16],__m512i m[16],size_t r)907 INLINE void round_fn16(__m512i v[16], __m512i m[16], size_t r) {
908   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
909   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
910   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
911   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
912   v[0] = add_512(v[0], v[4]);
913   v[1] = add_512(v[1], v[5]);
914   v[2] = add_512(v[2], v[6]);
915   v[3] = add_512(v[3], v[7]);
916   v[12] = xor_512(v[12], v[0]);
917   v[13] = xor_512(v[13], v[1]);
918   v[14] = xor_512(v[14], v[2]);
919   v[15] = xor_512(v[15], v[3]);
920   v[12] = rot16_512(v[12]);
921   v[13] = rot16_512(v[13]);
922   v[14] = rot16_512(v[14]);
923   v[15] = rot16_512(v[15]);
924   v[8] = add_512(v[8], v[12]);
925   v[9] = add_512(v[9], v[13]);
926   v[10] = add_512(v[10], v[14]);
927   v[11] = add_512(v[11], v[15]);
928   v[4] = xor_512(v[4], v[8]);
929   v[5] = xor_512(v[5], v[9]);
930   v[6] = xor_512(v[6], v[10]);
931   v[7] = xor_512(v[7], v[11]);
932   v[4] = rot12_512(v[4]);
933   v[5] = rot12_512(v[5]);
934   v[6] = rot12_512(v[6]);
935   v[7] = rot12_512(v[7]);
936   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
937   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
938   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
939   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
940   v[0] = add_512(v[0], v[4]);
941   v[1] = add_512(v[1], v[5]);
942   v[2] = add_512(v[2], v[6]);
943   v[3] = add_512(v[3], v[7]);
944   v[12] = xor_512(v[12], v[0]);
945   v[13] = xor_512(v[13], v[1]);
946   v[14] = xor_512(v[14], v[2]);
947   v[15] = xor_512(v[15], v[3]);
948   v[12] = rot8_512(v[12]);
949   v[13] = rot8_512(v[13]);
950   v[14] = rot8_512(v[14]);
951   v[15] = rot8_512(v[15]);
952   v[8] = add_512(v[8], v[12]);
953   v[9] = add_512(v[9], v[13]);
954   v[10] = add_512(v[10], v[14]);
955   v[11] = add_512(v[11], v[15]);
956   v[4] = xor_512(v[4], v[8]);
957   v[5] = xor_512(v[5], v[9]);
958   v[6] = xor_512(v[6], v[10]);
959   v[7] = xor_512(v[7], v[11]);
960   v[4] = rot7_512(v[4]);
961   v[5] = rot7_512(v[5]);
962   v[6] = rot7_512(v[6]);
963   v[7] = rot7_512(v[7]);
964 
965   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
966   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
967   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
968   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
969   v[0] = add_512(v[0], v[5]);
970   v[1] = add_512(v[1], v[6]);
971   v[2] = add_512(v[2], v[7]);
972   v[3] = add_512(v[3], v[4]);
973   v[15] = xor_512(v[15], v[0]);
974   v[12] = xor_512(v[12], v[1]);
975   v[13] = xor_512(v[13], v[2]);
976   v[14] = xor_512(v[14], v[3]);
977   v[15] = rot16_512(v[15]);
978   v[12] = rot16_512(v[12]);
979   v[13] = rot16_512(v[13]);
980   v[14] = rot16_512(v[14]);
981   v[10] = add_512(v[10], v[15]);
982   v[11] = add_512(v[11], v[12]);
983   v[8] = add_512(v[8], v[13]);
984   v[9] = add_512(v[9], v[14]);
985   v[5] = xor_512(v[5], v[10]);
986   v[6] = xor_512(v[6], v[11]);
987   v[7] = xor_512(v[7], v[8]);
988   v[4] = xor_512(v[4], v[9]);
989   v[5] = rot12_512(v[5]);
990   v[6] = rot12_512(v[6]);
991   v[7] = rot12_512(v[7]);
992   v[4] = rot12_512(v[4]);
993   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
994   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
995   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
996   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
997   v[0] = add_512(v[0], v[5]);
998   v[1] = add_512(v[1], v[6]);
999   v[2] = add_512(v[2], v[7]);
1000   v[3] = add_512(v[3], v[4]);
1001   v[15] = xor_512(v[15], v[0]);
1002   v[12] = xor_512(v[12], v[1]);
1003   v[13] = xor_512(v[13], v[2]);
1004   v[14] = xor_512(v[14], v[3]);
1005   v[15] = rot8_512(v[15]);
1006   v[12] = rot8_512(v[12]);
1007   v[13] = rot8_512(v[13]);
1008   v[14] = rot8_512(v[14]);
1009   v[10] = add_512(v[10], v[15]);
1010   v[11] = add_512(v[11], v[12]);
1011   v[8] = add_512(v[8], v[13]);
1012   v[9] = add_512(v[9], v[14]);
1013   v[5] = xor_512(v[5], v[10]);
1014   v[6] = xor_512(v[6], v[11]);
1015   v[7] = xor_512(v[7], v[8]);
1016   v[4] = xor_512(v[4], v[9]);
1017   v[5] = rot7_512(v[5]);
1018   v[6] = rot7_512(v[6]);
1019   v[7] = rot7_512(v[7]);
1020   v[4] = rot7_512(v[4]);
1021 }
1022 
1023 // 0b10001000, or lanes a0/a2/b0/b2 in little-endian order
1024 #define LO_IMM8 0x88
1025 
unpack_lo_128(__m512i a,__m512i b)1026 INLINE __m512i unpack_lo_128(__m512i a, __m512i b) {
1027   return _mm512_shuffle_i32x4(a, b, LO_IMM8);
1028 }
1029 
1030 // 0b11011101, or lanes a1/a3/b1/b3 in little-endian order
1031 #define HI_IMM8 0xdd
1032 
unpack_hi_128(__m512i a,__m512i b)1033 INLINE __m512i unpack_hi_128(__m512i a, __m512i b) {
1034   return _mm512_shuffle_i32x4(a, b, HI_IMM8);
1035 }
1036 
transpose_vecs_512(__m512i vecs[16])1037 INLINE void transpose_vecs_512(__m512i vecs[16]) {
1038   // Interleave 32-bit lanes. The _0 unpack is lanes
1039   // 0/0/1/1/4/4/5/5/8/8/9/9/12/12/13/13, and the _2 unpack is lanes
1040   // 2/2/3/3/6/6/7/7/10/10/11/11/14/14/15/15.
1041   __m512i ab_0 = _mm512_unpacklo_epi32(vecs[0], vecs[1]);
1042   __m512i ab_2 = _mm512_unpackhi_epi32(vecs[0], vecs[1]);
1043   __m512i cd_0 = _mm512_unpacklo_epi32(vecs[2], vecs[3]);
1044   __m512i cd_2 = _mm512_unpackhi_epi32(vecs[2], vecs[3]);
1045   __m512i ef_0 = _mm512_unpacklo_epi32(vecs[4], vecs[5]);
1046   __m512i ef_2 = _mm512_unpackhi_epi32(vecs[4], vecs[5]);
1047   __m512i gh_0 = _mm512_unpacklo_epi32(vecs[6], vecs[7]);
1048   __m512i gh_2 = _mm512_unpackhi_epi32(vecs[6], vecs[7]);
1049   __m512i ij_0 = _mm512_unpacklo_epi32(vecs[8], vecs[9]);
1050   __m512i ij_2 = _mm512_unpackhi_epi32(vecs[8], vecs[9]);
1051   __m512i kl_0 = _mm512_unpacklo_epi32(vecs[10], vecs[11]);
1052   __m512i kl_2 = _mm512_unpackhi_epi32(vecs[10], vecs[11]);
1053   __m512i mn_0 = _mm512_unpacklo_epi32(vecs[12], vecs[13]);
1054   __m512i mn_2 = _mm512_unpackhi_epi32(vecs[12], vecs[13]);
1055   __m512i op_0 = _mm512_unpacklo_epi32(vecs[14], vecs[15]);
1056   __m512i op_2 = _mm512_unpackhi_epi32(vecs[14], vecs[15]);
1057 
1058   // Interleave 64-bit lanes. The _0 unpack is lanes
1059   // 0/0/0/0/4/4/4/4/8/8/8/8/12/12/12/12, the _1 unpack is lanes
1060   // 1/1/1/1/5/5/5/5/9/9/9/9/13/13/13/13, the _2 unpack is lanes
1061   // 2/2/2/2/6/6/6/6/10/10/10/10/14/14/14/14, and the _3 unpack is lanes
1062   // 3/3/3/3/7/7/7/7/11/11/11/11/15/15/15/15.
1063   __m512i abcd_0 = _mm512_unpacklo_epi64(ab_0, cd_0);
1064   __m512i abcd_1 = _mm512_unpackhi_epi64(ab_0, cd_0);
1065   __m512i abcd_2 = _mm512_unpacklo_epi64(ab_2, cd_2);
1066   __m512i abcd_3 = _mm512_unpackhi_epi64(ab_2, cd_2);
1067   __m512i efgh_0 = _mm512_unpacklo_epi64(ef_0, gh_0);
1068   __m512i efgh_1 = _mm512_unpackhi_epi64(ef_0, gh_0);
1069   __m512i efgh_2 = _mm512_unpacklo_epi64(ef_2, gh_2);
1070   __m512i efgh_3 = _mm512_unpackhi_epi64(ef_2, gh_2);
1071   __m512i ijkl_0 = _mm512_unpacklo_epi64(ij_0, kl_0);
1072   __m512i ijkl_1 = _mm512_unpackhi_epi64(ij_0, kl_0);
1073   __m512i ijkl_2 = _mm512_unpacklo_epi64(ij_2, kl_2);
1074   __m512i ijkl_3 = _mm512_unpackhi_epi64(ij_2, kl_2);
1075   __m512i mnop_0 = _mm512_unpacklo_epi64(mn_0, op_0);
1076   __m512i mnop_1 = _mm512_unpackhi_epi64(mn_0, op_0);
1077   __m512i mnop_2 = _mm512_unpacklo_epi64(mn_2, op_2);
1078   __m512i mnop_3 = _mm512_unpackhi_epi64(mn_2, op_2);
1079 
1080   // Interleave 128-bit lanes. The _0 unpack is
1081   // 0/0/0/0/8/8/8/8/0/0/0/0/8/8/8/8, the _1 unpack is
1082   // 1/1/1/1/9/9/9/9/1/1/1/1/9/9/9/9, and so on.
1083   __m512i abcdefgh_0 = unpack_lo_128(abcd_0, efgh_0);
1084   __m512i abcdefgh_1 = unpack_lo_128(abcd_1, efgh_1);
1085   __m512i abcdefgh_2 = unpack_lo_128(abcd_2, efgh_2);
1086   __m512i abcdefgh_3 = unpack_lo_128(abcd_3, efgh_3);
1087   __m512i abcdefgh_4 = unpack_hi_128(abcd_0, efgh_0);
1088   __m512i abcdefgh_5 = unpack_hi_128(abcd_1, efgh_1);
1089   __m512i abcdefgh_6 = unpack_hi_128(abcd_2, efgh_2);
1090   __m512i abcdefgh_7 = unpack_hi_128(abcd_3, efgh_3);
1091   __m512i ijklmnop_0 = unpack_lo_128(ijkl_0, mnop_0);
1092   __m512i ijklmnop_1 = unpack_lo_128(ijkl_1, mnop_1);
1093   __m512i ijklmnop_2 = unpack_lo_128(ijkl_2, mnop_2);
1094   __m512i ijklmnop_3 = unpack_lo_128(ijkl_3, mnop_3);
1095   __m512i ijklmnop_4 = unpack_hi_128(ijkl_0, mnop_0);
1096   __m512i ijklmnop_5 = unpack_hi_128(ijkl_1, mnop_1);
1097   __m512i ijklmnop_6 = unpack_hi_128(ijkl_2, mnop_2);
1098   __m512i ijklmnop_7 = unpack_hi_128(ijkl_3, mnop_3);
1099 
1100   // Interleave 128-bit lanes again for the final outputs.
1101   vecs[0] = unpack_lo_128(abcdefgh_0, ijklmnop_0);
1102   vecs[1] = unpack_lo_128(abcdefgh_1, ijklmnop_1);
1103   vecs[2] = unpack_lo_128(abcdefgh_2, ijklmnop_2);
1104   vecs[3] = unpack_lo_128(abcdefgh_3, ijklmnop_3);
1105   vecs[4] = unpack_lo_128(abcdefgh_4, ijklmnop_4);
1106   vecs[5] = unpack_lo_128(abcdefgh_5, ijklmnop_5);
1107   vecs[6] = unpack_lo_128(abcdefgh_6, ijklmnop_6);
1108   vecs[7] = unpack_lo_128(abcdefgh_7, ijklmnop_7);
1109   vecs[8] = unpack_hi_128(abcdefgh_0, ijklmnop_0);
1110   vecs[9] = unpack_hi_128(abcdefgh_1, ijklmnop_1);
1111   vecs[10] = unpack_hi_128(abcdefgh_2, ijklmnop_2);
1112   vecs[11] = unpack_hi_128(abcdefgh_3, ijklmnop_3);
1113   vecs[12] = unpack_hi_128(abcdefgh_4, ijklmnop_4);
1114   vecs[13] = unpack_hi_128(abcdefgh_5, ijklmnop_5);
1115   vecs[14] = unpack_hi_128(abcdefgh_6, ijklmnop_6);
1116   vecs[15] = unpack_hi_128(abcdefgh_7, ijklmnop_7);
1117 }
1118 
transpose_msg_vecs16(const uint8_t * const * inputs,size_t block_offset,__m512i out[16])1119 INLINE void transpose_msg_vecs16(const uint8_t *const *inputs,
1120                                  size_t block_offset, __m512i out[16]) {
1121   out[0] = loadu_512(&inputs[0][block_offset]);
1122   out[1] = loadu_512(&inputs[1][block_offset]);
1123   out[2] = loadu_512(&inputs[2][block_offset]);
1124   out[3] = loadu_512(&inputs[3][block_offset]);
1125   out[4] = loadu_512(&inputs[4][block_offset]);
1126   out[5] = loadu_512(&inputs[5][block_offset]);
1127   out[6] = loadu_512(&inputs[6][block_offset]);
1128   out[7] = loadu_512(&inputs[7][block_offset]);
1129   out[8] = loadu_512(&inputs[8][block_offset]);
1130   out[9] = loadu_512(&inputs[9][block_offset]);
1131   out[10] = loadu_512(&inputs[10][block_offset]);
1132   out[11] = loadu_512(&inputs[11][block_offset]);
1133   out[12] = loadu_512(&inputs[12][block_offset]);
1134   out[13] = loadu_512(&inputs[13][block_offset]);
1135   out[14] = loadu_512(&inputs[14][block_offset]);
1136   out[15] = loadu_512(&inputs[15][block_offset]);
1137   for (size_t i = 0; i < 16; ++i) {
1138     _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
1139   }
1140   transpose_vecs_512(out);
1141 }
1142 
load_counters16(uint64_t counter,bool increment_counter,__m512i * out_lo,__m512i * out_hi)1143 INLINE void load_counters16(uint64_t counter, bool increment_counter,
1144                             __m512i *out_lo, __m512i *out_hi) {
1145   const __m512i mask = _mm512_set1_epi32(-(int32_t)increment_counter);
1146   const __m512i deltas = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1147   const __m512i masked_deltas = _mm512_and_si512(deltas, mask);
1148   const __m512i low_words = _mm512_add_epi32(
1149     _mm512_set1_epi32((int32_t)counter),
1150     masked_deltas);
1151   // The carry bit is 1 if the high bit of the word was 1 before addition and is
1152   // 0 after.
1153   // NOTE: It would be a bit more natural to use _mm512_cmp_epu32_mask to
1154   // compute the carry bits here, and originally we did, but that intrinsic is
1155   // broken under GCC 5.4. See https://github.com/BLAKE3-team/BLAKE3/issues/271.
1156   const __m512i carries = _mm512_srli_epi32(
1157     _mm512_andnot_si512(
1158         low_words, // 0 after (gets inverted by andnot)
1159         _mm512_set1_epi32((int32_t)counter)), // and 1 before
1160     31);
1161   const __m512i high_words = _mm512_add_epi32(
1162     _mm512_set1_epi32((int32_t)(counter >> 32)),
1163     carries);
1164   *out_lo = low_words;
1165   *out_hi = high_words;
1166 }
1167 
1168 static
blake3_hash16_avx512(const uint8_t * const * inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)1169 void blake3_hash16_avx512(const uint8_t *const *inputs, size_t blocks,
1170                           const uint32_t key[8], uint64_t counter,
1171                           bool increment_counter, uint8_t flags,
1172                           uint8_t flags_start, uint8_t flags_end,
1173                           uint8_t *out) {
1174   __m512i h_vecs[8] = {
1175       set1_512(key[0]), set1_512(key[1]), set1_512(key[2]), set1_512(key[3]),
1176       set1_512(key[4]), set1_512(key[5]), set1_512(key[6]), set1_512(key[7]),
1177   };
1178   __m512i counter_low_vec, counter_high_vec;
1179   load_counters16(counter, increment_counter, &counter_low_vec,
1180                   &counter_high_vec);
1181   uint8_t block_flags = flags | flags_start;
1182 
1183   for (size_t block = 0; block < blocks; block++) {
1184     if (block + 1 == blocks) {
1185       block_flags |= flags_end;
1186     }
1187     __m512i block_len_vec = set1_512(BLAKE3_BLOCK_LEN);
1188     __m512i block_flags_vec = set1_512(block_flags);
1189     __m512i msg_vecs[16];
1190     transpose_msg_vecs16(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
1191 
1192     __m512i v[16] = {
1193         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
1194         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
1195         set1_512(IV[0]), set1_512(IV[1]),  set1_512(IV[2]), set1_512(IV[3]),
1196         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
1197     };
1198     round_fn16(v, msg_vecs, 0);
1199     round_fn16(v, msg_vecs, 1);
1200     round_fn16(v, msg_vecs, 2);
1201     round_fn16(v, msg_vecs, 3);
1202     round_fn16(v, msg_vecs, 4);
1203     round_fn16(v, msg_vecs, 5);
1204     round_fn16(v, msg_vecs, 6);
1205     h_vecs[0] = xor_512(v[0], v[8]);
1206     h_vecs[1] = xor_512(v[1], v[9]);
1207     h_vecs[2] = xor_512(v[2], v[10]);
1208     h_vecs[3] = xor_512(v[3], v[11]);
1209     h_vecs[4] = xor_512(v[4], v[12]);
1210     h_vecs[5] = xor_512(v[5], v[13]);
1211     h_vecs[6] = xor_512(v[6], v[14]);
1212     h_vecs[7] = xor_512(v[7], v[15]);
1213 
1214     block_flags = flags;
1215   }
1216 
1217   // transpose_vecs_512 operates on a 16x16 matrix of words, but we only have 8
1218   // state vectors. Pad the matrix with zeros. After transposition, store the
1219   // lower half of each vector.
1220   __m512i padded[16] = {
1221       h_vecs[0],   h_vecs[1],   h_vecs[2],   h_vecs[3],
1222       h_vecs[4],   h_vecs[5],   h_vecs[6],   h_vecs[7],
1223       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
1224       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
1225   };
1226   transpose_vecs_512(padded);
1227   _mm256_mask_storeu_epi32(&out[0 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[0]));
1228   _mm256_mask_storeu_epi32(&out[1 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[1]));
1229   _mm256_mask_storeu_epi32(&out[2 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[2]));
1230   _mm256_mask_storeu_epi32(&out[3 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[3]));
1231   _mm256_mask_storeu_epi32(&out[4 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[4]));
1232   _mm256_mask_storeu_epi32(&out[5 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[5]));
1233   _mm256_mask_storeu_epi32(&out[6 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[6]));
1234   _mm256_mask_storeu_epi32(&out[7 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[7]));
1235   _mm256_mask_storeu_epi32(&out[8 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[8]));
1236   _mm256_mask_storeu_epi32(&out[9 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[9]));
1237   _mm256_mask_storeu_epi32(&out[10 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[10]));
1238   _mm256_mask_storeu_epi32(&out[11 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[11]));
1239   _mm256_mask_storeu_epi32(&out[12 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[12]));
1240   _mm256_mask_storeu_epi32(&out[13 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[13]));
1241   _mm256_mask_storeu_epi32(&out[14 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[14]));
1242   _mm256_mask_storeu_epi32(&out[15 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[15]));
1243 }
1244 
1245 static
blake3_xof16_avx512(const uint32_t cv[8],const uint8_t block[BLAKE3_BLOCK_LEN],uint8_t block_len,uint64_t counter,uint8_t flags,uint8_t out[16* 64])1246 void blake3_xof16_avx512(const uint32_t cv[8],
1247                         const uint8_t block[BLAKE3_BLOCK_LEN],
1248                         uint8_t block_len, uint64_t counter, uint8_t flags,
1249                         uint8_t out[16 * 64]) {
1250   __m512i h_vecs[8] = {
1251       set1_512(cv[0]), set1_512(cv[1]), set1_512(cv[2]), set1_512(cv[3]),
1252       set1_512(cv[4]), set1_512(cv[5]), set1_512(cv[6]), set1_512(cv[7]),
1253   };
1254   uint32_t block_words[16];
1255   load_block_words(block, block_words);
1256   __m512i msg_vecs[16];
1257   for (size_t i = 0; i < 16; i++) {
1258       msg_vecs[i] = set1_512(block_words[i]);
1259   }
1260   __m512i counter_low_vec, counter_high_vec;
1261   load_counters16(counter, true, &counter_low_vec, &counter_high_vec);
1262   __m512i block_len_vec = set1_512(block_len);
1263   __m512i block_flags_vec = set1_512(flags);
1264   __m512i v[16] = {
1265       h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
1266       h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
1267       set1_512(IV[0]), set1_512(IV[1]),  set1_512(IV[2]), set1_512(IV[3]),
1268       counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
1269   };
1270   round_fn16(v, msg_vecs, 0);
1271   round_fn16(v, msg_vecs, 1);
1272   round_fn16(v, msg_vecs, 2);
1273   round_fn16(v, msg_vecs, 3);
1274   round_fn16(v, msg_vecs, 4);
1275   round_fn16(v, msg_vecs, 5);
1276   round_fn16(v, msg_vecs, 6);
1277   for (size_t i = 0; i < 8; i++) {
1278       v[i] = xor_512(v[i], v[i+8]);
1279       v[i+8] = xor_512(v[i+8], h_vecs[i]);
1280   }
1281   transpose_vecs_512(&v[0]);
1282   for (size_t i = 0; i < 16; i++) {
1283       storeu_512(v[i], &out[i * sizeof(__m512i)]);
1284   }
1285 }
1286 
1287 /*
1288  * ----------------------------------------------------------------------------
1289  * hash_many_avx512
1290  * ----------------------------------------------------------------------------
1291  */
1292 
hash_one_avx512(const uint8_t * input,size_t blocks,const uint32_t key[8],uint64_t counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t out[BLAKE3_OUT_LEN])1293 INLINE void hash_one_avx512(const uint8_t *input, size_t blocks,
1294                             const uint32_t key[8], uint64_t counter,
1295                             uint8_t flags, uint8_t flags_start,
1296                             uint8_t flags_end, uint8_t out[BLAKE3_OUT_LEN]) {
1297   uint32_t cv[8];
1298   memcpy(cv, key, BLAKE3_KEY_LEN);
1299   uint8_t block_flags = flags | flags_start;
1300   while (blocks > 0) {
1301     if (blocks == 1) {
1302       block_flags |= flags_end;
1303     }
1304     blake3_compress_in_place_avx512(cv, input, BLAKE3_BLOCK_LEN, counter,
1305                                     block_flags);
1306     input = &input[BLAKE3_BLOCK_LEN];
1307     blocks -= 1;
1308     block_flags = flags;
1309   }
1310   memcpy(out, cv, BLAKE3_OUT_LEN);
1311 }
1312 
blake3_hash_many_avx512(const uint8_t * const * inputs,size_t num_inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)1313 void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs,
1314                              size_t blocks, const uint32_t key[8],
1315                              uint64_t counter, bool increment_counter,
1316                              uint8_t flags, uint8_t flags_start,
1317                              uint8_t flags_end, uint8_t *out) {
1318   while (num_inputs >= 16) {
1319     blake3_hash16_avx512(inputs, blocks, key, counter, increment_counter, flags,
1320                          flags_start, flags_end, out);
1321     if (increment_counter) {
1322       counter += 16;
1323     }
1324     inputs += 16;
1325     num_inputs -= 16;
1326     out = &out[16 * BLAKE3_OUT_LEN];
1327   }
1328   while (num_inputs >= 8) {
1329     blake3_hash8_avx512(inputs, blocks, key, counter, increment_counter, flags,
1330                         flags_start, flags_end, out);
1331     if (increment_counter) {
1332       counter += 8;
1333     }
1334     inputs += 8;
1335     num_inputs -= 8;
1336     out = &out[8 * BLAKE3_OUT_LEN];
1337   }
1338   while (num_inputs >= 4) {
1339     blake3_hash4_avx512(inputs, blocks, key, counter, increment_counter, flags,
1340                         flags_start, flags_end, out);
1341     if (increment_counter) {
1342       counter += 4;
1343     }
1344     inputs += 4;
1345     num_inputs -= 4;
1346     out = &out[4 * BLAKE3_OUT_LEN];
1347   }
1348   while (num_inputs > 0) {
1349     hash_one_avx512(inputs[0], blocks, key, counter, flags, flags_start,
1350                     flags_end, out);
1351     if (increment_counter) {
1352       counter += 1;
1353     }
1354     inputs += 1;
1355     num_inputs -= 1;
1356     out = &out[BLAKE3_OUT_LEN];
1357   }
1358 }
1359 
blake3_xof_many_avx512(const uint32_t cv[8],const uint8_t block[BLAKE3_BLOCK_LEN],uint8_t block_len,uint64_t counter,uint8_t flags,uint8_t * out,size_t outblocks)1360 void blake3_xof_many_avx512(const uint32_t cv[8],
1361                             const uint8_t block[BLAKE3_BLOCK_LEN],
1362                             uint8_t block_len, uint64_t counter, uint8_t flags,
1363                             uint8_t* out, size_t outblocks) {
1364   while (outblocks >= 16) {
1365     blake3_xof16_avx512(cv, block, block_len, counter, flags, out);
1366     counter += 16;
1367     outblocks -= 16;
1368     out += 16 * BLAKE3_BLOCK_LEN;
1369   }
1370   while (outblocks >= 8) {
1371     blake3_xof8_avx512(cv, block, block_len, counter, flags, out);
1372     counter += 8;
1373     outblocks -= 8;
1374     out += 8 * BLAKE3_BLOCK_LEN;
1375   }
1376   while (outblocks >= 4) {
1377     blake3_xof4_avx512(cv, block, block_len, counter, flags, out);
1378     counter += 4;
1379     outblocks -= 4;
1380     out += 4 * BLAKE3_BLOCK_LEN;
1381   }
1382   while (outblocks > 0) {
1383     blake3_compress_xof_avx512(cv, block, block_len, counter, flags, out);
1384     counter += 1;
1385     outblocks -= 1;
1386     out += BLAKE3_BLOCK_LEN;
1387   }
1388 }
1389