/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * ChaCha and HChaCha functions (x86_64 optimized)
 *
 * Copyright (C) 2015 Martin Willi
 */

#include <asm/simd.h>
#include <linux/jump_label.h>
#include <linux/kernel.h>
#include <linux/sizes.h>

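/*
 * Prototypes for the assembly implementations (SSSE3, AVX2, AVX-512VL).
 * Each routine XORs up to 1, 2, 4, or 8 blocks of keystream into @dst;
 * when @len is not a multiple of CHACHA_BLOCK_SIZE, only @len bytes of
 * output are written for the final partial block.
 */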
asmlinkage void chacha_block_xor_ssse3(const struct chacha_state *state,
				       u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_ssse3(const struct chacha_state *state,
					u8 *dst, const u8 *src,
					unsigned int len, int nrounds);
asmlinkage void hchacha_block_ssse3(const struct chacha_state *state,
				    u32 out[HCHACHA_OUT_WORDS], int nrounds);

asmlinkage void chacha_2block_xor_avx2(const struct chacha_state *state,
				       u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_avx2(const struct chacha_state *state,
				       u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_8block_xor_avx2(const struct chacha_state *state,
				       u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);

asmlinkage void chacha_2block_xor_avx512vl(const struct chacha_state *state,
					   u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_avx512vl(const struct chacha_state *state,
					   u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);
asmlinkage void chacha_8block_xor_avx512vl(const struct chacha_state *state,
					   u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);

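/*
 * SIMD tier selection: set once at boot by chacha_mod_init_arch() and
 * never changed afterwards (hence __ro_after_init).
 */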
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_simd);
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx2);
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx512vl);

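/*
 * Number of blocks to advance the block counter (state->x[12]) by after
 * processing @len bytes with a kernel that handles at most @maxblocks
 * blocks; a trailing partial block still advances the counter by one.
 */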
static unsigned int chacha_advance(unsigned int len, unsigned int maxblocks)
{
	len = min(len, maxblocks * CHACHA_BLOCK_SIZE);
	return round_up(len, CHACHA_BLOCK_SIZE) / CHACHA_BLOCK_SIZE;
}

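/*
 * XOR @bytes bytes of keystream into dst using the widest kernels
 * available: full 8-block iterations first, then one call with a
 * progressively narrower kernel for the tail.  The caller must already
 * hold the FPU (kernel_fpu_begin()).
 */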
static void chacha_dosimd(struct chacha_state *state, u8 *dst, const u8 *src,
			  unsigned int bytes, int nrounds)
{
	if (static_branch_likely(&chacha_use_avx512vl)) {
		while (bytes >= CHACHA_BLOCK_SIZE * 8) {
			chacha_8block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			bytes -= CHACHA_BLOCK_SIZE * 8;
			src += CHACHA_BLOCK_SIZE * 8;
			dst += CHACHA_BLOCK_SIZE * 8;
			state->x[12] += 8;
		}
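		/*
		 * More than four blocks' worth of data left: one more
		 * wide call emits the (possibly partial) tail; @bytes
		 * tells the asm how much output is actually needed.
		 */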
		if (bytes > CHACHA_BLOCK_SIZE * 4) {
			chacha_8block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			state->x[12] += chacha_advance(bytes, 8);
			return;
		}
		if (bytes > CHACHA_BLOCK_SIZE * 2) {
			chacha_4block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			state->x[12] += chacha_advance(bytes, 4);
			return;
		}
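		/*
		 * Unlike the AVX2 tier below, even a final single or
		 * partial block is handled here by the 2-block kernel
		 * rather than falling through to SSSE3.
		 */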
		if (bytes) {
			chacha_2block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			state->x[12] += chacha_advance(bytes, 2);
			return;
		}
	}

	if (static_branch_likely(&chacha_use_avx2)) {
		while (bytes >= CHACHA_BLOCK_SIZE * 8) {
			chacha_8block_xor_avx2(state, dst, src, bytes, nrounds);
			bytes -= CHACHA_BLOCK_SIZE * 8;
			src += CHACHA_BLOCK_SIZE * 8;
			dst += CHACHA_BLOCK_SIZE * 8;
			state->x[12] += 8;
		}
		if (bytes > CHACHA_BLOCK_SIZE * 4) {
			chacha_8block_xor_avx2(state, dst, src, bytes, nrounds);
			state->x[12] += chacha_advance(bytes, 8);
			return;
		}
		if (bytes > CHACHA_BLOCK_SIZE * 2) {
			chacha_4block_xor_avx2(state, dst, src, bytes, nrounds);
			state->x[12] += chacha_advance(bytes, 4);
			return;
		}
		if (bytes > CHACHA_BLOCK_SIZE) {
			chacha_2block_xor_avx2(state, dst, src, bytes, nrounds);
			state->x[12] += chacha_advance(bytes, 2);
			return;
		}
	}

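	/* SSSE3 baseline, also the single-block tail of the AVX2 tier. */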
	while (bytes >= CHACHA_BLOCK_SIZE * 4) {
		chacha_4block_xor_ssse3(state, dst, src, bytes, nrounds);
		bytes -= CHACHA_BLOCK_SIZE * 4;
		src += CHACHA_BLOCK_SIZE * 4;
		dst += CHACHA_BLOCK_SIZE * 4;
		state->x[12] += 4;
	}
	if (bytes > CHACHA_BLOCK_SIZE) {
		chacha_4block_xor_ssse3(state, dst, src, bytes, nrounds);
		state->x[12] += chacha_advance(bytes, 4);
		return;
	}
	if (bytes) {
		chacha_block_xor_ssse3(state, dst, src, bytes, nrounds);
		state->x[12]++;
	}
}

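/*
 * HChaCha core used for XChaCha key derivation: runs the ChaCha rounds
 * on @state and emits HCHACHA_OUT_WORDS words of output, without the
 * final feed-forward addition of the input state.
 */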
static void hchacha_block_arch(const struct chacha_state *state,
			       u32 out[HCHACHA_OUT_WORDS], int nrounds)
{
	if (!static_branch_likely(&chacha_use_simd)) {
		hchacha_block_generic(state, out, nrounds);
	} else {
		kernel_fpu_begin();
		hchacha_block_ssse3(state, out, nrounds);
		kernel_fpu_end();
	}
}

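/*
 * Main crypt entry point.  Inputs of at most one block, and CPUs
 * without SSSE3, take the generic code; otherwise work is chunked to
 * SZ_4K per kernel_fpu_begin()/kernel_fpu_end() section so the FPU is
 * not held (and preemption potentially disabled) for too long at once.
 */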
static void chacha_crypt_arch(struct chacha_state *state, u8 *dst,
			      const u8 *src, unsigned int bytes, int nrounds)
{
	if (!static_branch_likely(&chacha_use_simd) ||
	    bytes <= CHACHA_BLOCK_SIZE)
		return chacha_crypt_generic(state, dst, src, bytes, nrounds);

	do {
		unsigned int todo = min_t(unsigned int, bytes, SZ_4K);

		kernel_fpu_begin();
		chacha_dosimd(state, dst, src, todo, nrounds);
		kernel_fpu_end();

		bytes -= todo;
		src += todo;
		dst += todo;
	} while (bytes);
}
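
/*
 * Minimal caller sketch, for illustration only: callers normally go
 * through the generic lib API in include/crypto/chacha.h, which
 * dispatches to chacha_crypt_arch() when the arch code is available:
 *
 *	struct chacha_state state;
 *	u32 key[CHACHA_KEY_SIZE / sizeof(u32)];
 *	u8 iv[CHACHA_IV_SIZE];
 *
 *	chacha_init(&state, key, iv);
 *	chacha_crypt(&state, dst, src, len, 20);
 */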
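/*
 * Boot-time CPU feature probing.  AVX2 additionally requires the OS to
 * have enabled saving/restoring of YMM state (cpu_has_xfeatures());
 * the AVX-512VL kernels also need AVX512BW for the kmovq instruction.
 */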
#define chacha_mod_init_arch chacha_mod_init_arch
static void chacha_mod_init_arch(void)
{
	if (!boot_cpu_has(X86_FEATURE_SSSE3))
		return;

	static_branch_enable(&chacha_use_simd);

	if (boot_cpu_has(X86_FEATURE_AVX) &&
	    boot_cpu_has(X86_FEATURE_AVX2) &&
	    cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL)) {
		static_branch_enable(&chacha_use_avx2);

		if (boot_cpu_has(X86_FEATURE_AVX512VL) &&
		    boot_cpu_has(X86_FEATURE_AVX512BW)) /* kmovq */
			static_branch_enable(&chacha_use_avx512vl);
	}
}