xref: /freebsd/contrib/llvm-project/llvm/lib/Support/BLAKE3/blake3_neon.c (revision e64bea71c21eb42e97aa615188ba91f6cce0d36d)
1 #include "blake3_impl.h"
2 
3 #if BLAKE3_USE_NEON
4 
5 #include <arm_neon.h>
6 
7 #ifdef __ARM_BIG_ENDIAN
8 #error "This implementation only supports little-endian ARM."
9 // It might be that all we need for big-endian support here is to get the loads
10 // and stores right, but step zero would be finding a way to test it in CI.
11 #endif
12 
loadu_128(const uint8_t src[16])13 INLINE uint32x4_t loadu_128(const uint8_t src[16]) {
14   // vld1q_u32 has alignment requirements. Don't use it.
15   return vreinterpretq_u32_u8(vld1q_u8(src));
16 }
17 
storeu_128(uint32x4_t src,uint8_t dest[16])18 INLINE void storeu_128(uint32x4_t src, uint8_t dest[16]) {
19   // vst1q_u32 has alignment requirements. Don't use it.
20   vst1q_u8(dest, vreinterpretq_u8_u32(src));
21 }
22 
add_128(uint32x4_t a,uint32x4_t b)23 INLINE uint32x4_t add_128(uint32x4_t a, uint32x4_t b) {
24   return vaddq_u32(a, b);
25 }
26 
xor_128(uint32x4_t a,uint32x4_t b)27 INLINE uint32x4_t xor_128(uint32x4_t a, uint32x4_t b) {
28   return veorq_u32(a, b);
29 }
30 
set1_128(uint32_t x)31 INLINE uint32x4_t set1_128(uint32_t x) { return vld1q_dup_u32(&x); }
32 
set4(uint32_t a,uint32_t b,uint32_t c,uint32_t d)33 INLINE uint32x4_t set4(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
34   uint32_t array[4] = {a, b, c, d};
35   return vld1q_u32(array);
36 }
37 
rot16_128(uint32x4_t x)38 INLINE uint32x4_t rot16_128(uint32x4_t x) {
39   // The straightforward implementation would be two shifts and an or, but that's
40   // slower on microarchitectures we've tested. See
41   // https://github.com/BLAKE3-team/BLAKE3/pull/319.
42   // return vorrq_u32(vshrq_n_u32(x, 16), vshlq_n_u32(x, 32 - 16));
43   return vreinterpretq_u32_u16(vrev32q_u16(vreinterpretq_u16_u32(x)));
44 }
45 
rot12_128(uint32x4_t x)46 INLINE uint32x4_t rot12_128(uint32x4_t x) {
47   // See comment in rot16_128.
48   // return vorrq_u32(vshrq_n_u32(x, 12), vshlq_n_u32(x, 32 - 12));
49   return vsriq_n_u32(vshlq_n_u32(x, 32-12), x, 12);
50 }
51 
rot8_128(uint32x4_t x)52 INLINE uint32x4_t rot8_128(uint32x4_t x) {
53   // See comment in rot16_128.
54   // return vorrq_u32(vshrq_n_u32(x, 8), vshlq_n_u32(x, 32 - 8));
55 #if defined(__clang__)
56   return vreinterpretq_u32_u8(__builtin_shufflevector(vreinterpretq_u8_u32(x), vreinterpretq_u8_u32(x), 1,2,3,0,5,6,7,4,9,10,11,8,13,14,15,12));
57 #elif __GNUC__ * 10000 + __GNUC_MINOR__ * 100 >=40700
58   static const uint8x16_t r8 = {1,2,3,0,5,6,7,4,9,10,11,8,13,14,15,12};
59   return vreinterpretq_u32_u8(__builtin_shuffle(vreinterpretq_u8_u32(x), vreinterpretq_u8_u32(x), r8));
60 #else
61   return vsriq_n_u32(vshlq_n_u32(x, 32-8), x, 8);
62 #endif
63 }
64 
rot7_128(uint32x4_t x)65 INLINE uint32x4_t rot7_128(uint32x4_t x) {
66   // See comment in rot16_128.
67   // return vorrq_u32(vshrq_n_u32(x, 7), vshlq_n_u32(x, 32 - 7));
68   return vsriq_n_u32(vshlq_n_u32(x, 32-7), x, 7);
69 }
70 
71 // TODO: compress_neon
72 
73 // TODO: hash2_neon
74 
75 /*
76  * ----------------------------------------------------------------------------
77  * hash4_neon
78  * ----------------------------------------------------------------------------
79  */
80 
round_fn4(uint32x4_t v[16],uint32x4_t m[16],size_t r)81 INLINE void round_fn4(uint32x4_t v[16], uint32x4_t m[16], size_t r) {
82   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
83   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
84   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
85   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
86   v[0] = add_128(v[0], v[4]);
87   v[1] = add_128(v[1], v[5]);
88   v[2] = add_128(v[2], v[6]);
89   v[3] = add_128(v[3], v[7]);
90   v[12] = xor_128(v[12], v[0]);
91   v[13] = xor_128(v[13], v[1]);
92   v[14] = xor_128(v[14], v[2]);
93   v[15] = xor_128(v[15], v[3]);
94   v[12] = rot16_128(v[12]);
95   v[13] = rot16_128(v[13]);
96   v[14] = rot16_128(v[14]);
97   v[15] = rot16_128(v[15]);
98   v[8] = add_128(v[8], v[12]);
99   v[9] = add_128(v[9], v[13]);
100   v[10] = add_128(v[10], v[14]);
101   v[11] = add_128(v[11], v[15]);
102   v[4] = xor_128(v[4], v[8]);
103   v[5] = xor_128(v[5], v[9]);
104   v[6] = xor_128(v[6], v[10]);
105   v[7] = xor_128(v[7], v[11]);
106   v[4] = rot12_128(v[4]);
107   v[5] = rot12_128(v[5]);
108   v[6] = rot12_128(v[6]);
109   v[7] = rot12_128(v[7]);
110   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
111   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
112   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
113   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
114   v[0] = add_128(v[0], v[4]);
115   v[1] = add_128(v[1], v[5]);
116   v[2] = add_128(v[2], v[6]);
117   v[3] = add_128(v[3], v[7]);
118   v[12] = xor_128(v[12], v[0]);
119   v[13] = xor_128(v[13], v[1]);
120   v[14] = xor_128(v[14], v[2]);
121   v[15] = xor_128(v[15], v[3]);
122   v[12] = rot8_128(v[12]);
123   v[13] = rot8_128(v[13]);
124   v[14] = rot8_128(v[14]);
125   v[15] = rot8_128(v[15]);
126   v[8] = add_128(v[8], v[12]);
127   v[9] = add_128(v[9], v[13]);
128   v[10] = add_128(v[10], v[14]);
129   v[11] = add_128(v[11], v[15]);
130   v[4] = xor_128(v[4], v[8]);
131   v[5] = xor_128(v[5], v[9]);
132   v[6] = xor_128(v[6], v[10]);
133   v[7] = xor_128(v[7], v[11]);
134   v[4] = rot7_128(v[4]);
135   v[5] = rot7_128(v[5]);
136   v[6] = rot7_128(v[6]);
137   v[7] = rot7_128(v[7]);
138 
139   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
140   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
141   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
142   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
143   v[0] = add_128(v[0], v[5]);
144   v[1] = add_128(v[1], v[6]);
145   v[2] = add_128(v[2], v[7]);
146   v[3] = add_128(v[3], v[4]);
147   v[15] = xor_128(v[15], v[0]);
148   v[12] = xor_128(v[12], v[1]);
149   v[13] = xor_128(v[13], v[2]);
150   v[14] = xor_128(v[14], v[3]);
151   v[15] = rot16_128(v[15]);
152   v[12] = rot16_128(v[12]);
153   v[13] = rot16_128(v[13]);
154   v[14] = rot16_128(v[14]);
155   v[10] = add_128(v[10], v[15]);
156   v[11] = add_128(v[11], v[12]);
157   v[8] = add_128(v[8], v[13]);
158   v[9] = add_128(v[9], v[14]);
159   v[5] = xor_128(v[5], v[10]);
160   v[6] = xor_128(v[6], v[11]);
161   v[7] = xor_128(v[7], v[8]);
162   v[4] = xor_128(v[4], v[9]);
163   v[5] = rot12_128(v[5]);
164   v[6] = rot12_128(v[6]);
165   v[7] = rot12_128(v[7]);
166   v[4] = rot12_128(v[4]);
167   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
168   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
169   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
170   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
171   v[0] = add_128(v[0], v[5]);
172   v[1] = add_128(v[1], v[6]);
173   v[2] = add_128(v[2], v[7]);
174   v[3] = add_128(v[3], v[4]);
175   v[15] = xor_128(v[15], v[0]);
176   v[12] = xor_128(v[12], v[1]);
177   v[13] = xor_128(v[13], v[2]);
178   v[14] = xor_128(v[14], v[3]);
179   v[15] = rot8_128(v[15]);
180   v[12] = rot8_128(v[12]);
181   v[13] = rot8_128(v[13]);
182   v[14] = rot8_128(v[14]);
183   v[10] = add_128(v[10], v[15]);
184   v[11] = add_128(v[11], v[12]);
185   v[8] = add_128(v[8], v[13]);
186   v[9] = add_128(v[9], v[14]);
187   v[5] = xor_128(v[5], v[10]);
188   v[6] = xor_128(v[6], v[11]);
189   v[7] = xor_128(v[7], v[8]);
190   v[4] = xor_128(v[4], v[9]);
191   v[5] = rot7_128(v[5]);
192   v[6] = rot7_128(v[6]);
193   v[7] = rot7_128(v[7]);
194   v[4] = rot7_128(v[4]);
195 }
196 
transpose_vecs_128(uint32x4_t vecs[4])197 INLINE void transpose_vecs_128(uint32x4_t vecs[4]) {
198   // Individually transpose the four 2x2 sub-matrices in each corner.
199   uint32x4x2_t rows01 = vtrnq_u32(vecs[0], vecs[1]);
200   uint32x4x2_t rows23 = vtrnq_u32(vecs[2], vecs[3]);
201 
202   // Swap the top-right and bottom-left 2x2s (which just got transposed).
203   vecs[0] =
204       vcombine_u32(vget_low_u32(rows01.val[0]), vget_low_u32(rows23.val[0]));
205   vecs[1] =
206       vcombine_u32(vget_low_u32(rows01.val[1]), vget_low_u32(rows23.val[1]));
207   vecs[2] =
208       vcombine_u32(vget_high_u32(rows01.val[0]), vget_high_u32(rows23.val[0]));
209   vecs[3] =
210       vcombine_u32(vget_high_u32(rows01.val[1]), vget_high_u32(rows23.val[1]));
211 }
212 
transpose_msg_vecs4(const uint8_t * const * inputs,size_t block_offset,uint32x4_t out[16])213 INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
214                                 size_t block_offset, uint32x4_t out[16]) {
215   out[0] = loadu_128(&inputs[0][block_offset + 0 * sizeof(uint32x4_t)]);
216   out[1] = loadu_128(&inputs[1][block_offset + 0 * sizeof(uint32x4_t)]);
217   out[2] = loadu_128(&inputs[2][block_offset + 0 * sizeof(uint32x4_t)]);
218   out[3] = loadu_128(&inputs[3][block_offset + 0 * sizeof(uint32x4_t)]);
219   out[4] = loadu_128(&inputs[0][block_offset + 1 * sizeof(uint32x4_t)]);
220   out[5] = loadu_128(&inputs[1][block_offset + 1 * sizeof(uint32x4_t)]);
221   out[6] = loadu_128(&inputs[2][block_offset + 1 * sizeof(uint32x4_t)]);
222   out[7] = loadu_128(&inputs[3][block_offset + 1 * sizeof(uint32x4_t)]);
223   out[8] = loadu_128(&inputs[0][block_offset + 2 * sizeof(uint32x4_t)]);
224   out[9] = loadu_128(&inputs[1][block_offset + 2 * sizeof(uint32x4_t)]);
225   out[10] = loadu_128(&inputs[2][block_offset + 2 * sizeof(uint32x4_t)]);
226   out[11] = loadu_128(&inputs[3][block_offset + 2 * sizeof(uint32x4_t)]);
227   out[12] = loadu_128(&inputs[0][block_offset + 3 * sizeof(uint32x4_t)]);
228   out[13] = loadu_128(&inputs[1][block_offset + 3 * sizeof(uint32x4_t)]);
229   out[14] = loadu_128(&inputs[2][block_offset + 3 * sizeof(uint32x4_t)]);
230   out[15] = loadu_128(&inputs[3][block_offset + 3 * sizeof(uint32x4_t)]);
231   transpose_vecs_128(&out[0]);
232   transpose_vecs_128(&out[4]);
233   transpose_vecs_128(&out[8]);
234   transpose_vecs_128(&out[12]);
235 }
236 
load_counters4(uint64_t counter,bool increment_counter,uint32x4_t * out_low,uint32x4_t * out_high)237 INLINE void load_counters4(uint64_t counter, bool increment_counter,
238                            uint32x4_t *out_low, uint32x4_t *out_high) {
239   uint64_t mask = (increment_counter ? ~0 : 0);
240   *out_low = set4(
241       counter_low(counter + (mask & 0)), counter_low(counter + (mask & 1)),
242       counter_low(counter + (mask & 2)), counter_low(counter + (mask & 3)));
243   *out_high = set4(
244       counter_high(counter + (mask & 0)), counter_high(counter + (mask & 1)),
245       counter_high(counter + (mask & 2)), counter_high(counter + (mask & 3)));
246 }
247 
blake3_hash4_neon(const uint8_t * const * inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)248 static void blake3_hash4_neon(const uint8_t *const *inputs, size_t blocks,
249                               const uint32_t key[8], uint64_t counter,
250                               bool increment_counter, uint8_t flags,
251                               uint8_t flags_start, uint8_t flags_end,
252                               uint8_t *out) {
253   uint32x4_t h_vecs[8] = {
254       set1_128(key[0]), set1_128(key[1]), set1_128(key[2]), set1_128(key[3]),
255       set1_128(key[4]), set1_128(key[5]), set1_128(key[6]), set1_128(key[7]),
256   };
257   uint32x4_t counter_low_vec, counter_high_vec;
258   load_counters4(counter, increment_counter, &counter_low_vec,
259                  &counter_high_vec);
260   uint8_t block_flags = flags | flags_start;
261 
262   for (size_t block = 0; block < blocks; block++) {
263     if (block + 1 == blocks) {
264       block_flags |= flags_end;
265     }
266     uint32x4_t block_len_vec = set1_128(BLAKE3_BLOCK_LEN);
267     uint32x4_t block_flags_vec = set1_128(block_flags);
268     uint32x4_t msg_vecs[16];
269     transpose_msg_vecs4(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
270 
271     uint32x4_t v[16] = {
272         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
273         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
274         set1_128(IV[0]), set1_128(IV[1]),  set1_128(IV[2]), set1_128(IV[3]),
275         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
276     };
277     round_fn4(v, msg_vecs, 0);
278     round_fn4(v, msg_vecs, 1);
279     round_fn4(v, msg_vecs, 2);
280     round_fn4(v, msg_vecs, 3);
281     round_fn4(v, msg_vecs, 4);
282     round_fn4(v, msg_vecs, 5);
283     round_fn4(v, msg_vecs, 6);
284     h_vecs[0] = xor_128(v[0], v[8]);
285     h_vecs[1] = xor_128(v[1], v[9]);
286     h_vecs[2] = xor_128(v[2], v[10]);
287     h_vecs[3] = xor_128(v[3], v[11]);
288     h_vecs[4] = xor_128(v[4], v[12]);
289     h_vecs[5] = xor_128(v[5], v[13]);
290     h_vecs[6] = xor_128(v[6], v[14]);
291     h_vecs[7] = xor_128(v[7], v[15]);
292 
293     block_flags = flags;
294   }
295 
296   transpose_vecs_128(&h_vecs[0]);
297   transpose_vecs_128(&h_vecs[4]);
298   // The first four vecs now contain the first half of each output, and the
299   // second four vecs contain the second half of each output.
300   storeu_128(h_vecs[0], &out[0 * sizeof(uint32x4_t)]);
301   storeu_128(h_vecs[4], &out[1 * sizeof(uint32x4_t)]);
302   storeu_128(h_vecs[1], &out[2 * sizeof(uint32x4_t)]);
303   storeu_128(h_vecs[5], &out[3 * sizeof(uint32x4_t)]);
304   storeu_128(h_vecs[2], &out[4 * sizeof(uint32x4_t)]);
305   storeu_128(h_vecs[6], &out[5 * sizeof(uint32x4_t)]);
306   storeu_128(h_vecs[3], &out[6 * sizeof(uint32x4_t)]);
307   storeu_128(h_vecs[7], &out[7 * sizeof(uint32x4_t)]);
308 }
309 
310 /*
311  * ----------------------------------------------------------------------------
312  * hash_many_neon
313  * ----------------------------------------------------------------------------
314  */
315 
316 void blake3_compress_in_place_portable(uint32_t cv[8],
317                                        const uint8_t block[BLAKE3_BLOCK_LEN],
318                                        uint8_t block_len, uint64_t counter,
319                                        uint8_t flags);
320 
hash_one_neon(const uint8_t * input,size_t blocks,const uint32_t key[8],uint64_t counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t out[BLAKE3_OUT_LEN])321 INLINE void hash_one_neon(const uint8_t *input, size_t blocks,
322                           const uint32_t key[8], uint64_t counter,
323                           uint8_t flags, uint8_t flags_start, uint8_t flags_end,
324                           uint8_t out[BLAKE3_OUT_LEN]) {
325   uint32_t cv[8];
326   memcpy(cv, key, BLAKE3_KEY_LEN);
327   uint8_t block_flags = flags | flags_start;
328   while (blocks > 0) {
329     if (blocks == 1) {
330       block_flags |= flags_end;
331     }
332     // TODO: Implement compress_neon. However note that according to
333     // https://github.com/BLAKE2/BLAKE2/commit/7965d3e6e1b4193438b8d3a656787587d2579227,
334     // compress_neon might not be any faster than compress_portable.
335     blake3_compress_in_place_portable(cv, input, BLAKE3_BLOCK_LEN, counter,
336                                       block_flags);
337     input = &input[BLAKE3_BLOCK_LEN];
338     blocks -= 1;
339     block_flags = flags;
340   }
341   memcpy(out, cv, BLAKE3_OUT_LEN);
342 }
343 
blake3_hash_many_neon(const uint8_t * const * inputs,size_t num_inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)344 void blake3_hash_many_neon(const uint8_t *const *inputs, size_t num_inputs,
345                            size_t blocks, const uint32_t key[8],
346                            uint64_t counter, bool increment_counter,
347                            uint8_t flags, uint8_t flags_start,
348                            uint8_t flags_end, uint8_t *out) {
349   while (num_inputs >= 4) {
350     blake3_hash4_neon(inputs, blocks, key, counter, increment_counter, flags,
351                       flags_start, flags_end, out);
352     if (increment_counter) {
353       counter += 4;
354     }
355     inputs += 4;
356     num_inputs -= 4;
357     out = &out[4 * BLAKE3_OUT_LEN];
358   }
359   while (num_inputs > 0) {
360     hash_one_neon(inputs[0], blocks, key, counter, flags, flags_start,
361                   flags_end, out);
362     if (increment_counter) {
363       counter += 1;
364     }
365     inputs += 1;
366     num_inputs -= 1;
367     out = &out[BLAKE3_OUT_LEN];
368   }
369 }
370 
371 #endif // BLAKE3_USE_NEON
372