/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * ChaCha and HChaCha functions (x86_64 optimized)
 *
 * Copyright (C) 2015 Martin Willi
 */

#include <asm/simd.h>
#include <linux/jump_label.h>
#include <linux/kernel.h>
#include <linux/sizes.h>

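/*
 * Assembly implementations.  Each *_xor_* routine generates keystream for up
 * to the indicated number of ChaCha blocks, XORs it with 'len' bytes of 'src',
 * and writes the result to 'dst'.  The block counter in 'state' is not
 * modified; the callers below advance it.
 */

/* SSSE3: single-block and 4-block kernels, plus the HChaCha core. */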
asmlinkage void chacha_block_xor_ssse3(const struct chacha_state *state,
                                       u8 *dst, const u8 *src,
                                       unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_ssse3(const struct chacha_state *state,
                                        u8 *dst, const u8 *src,
                                        unsigned int len, int nrounds);
asmlinkage void hchacha_block_ssse3(const struct chacha_state *state,
                                    u32 out[HCHACHA_OUT_WORDS], int nrounds);

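/* AVX2: 2-, 4- and 8-block kernels. */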
asmlinkage void chacha_2block_xor_avx2(const struct chacha_state *state,
                                       u8 *dst, const u8 *src,
                                       unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_avx2(const struct chacha_state *state,
                                       u8 *dst, const u8 *src,
                                       unsigned int len, int nrounds);
asmlinkage void chacha_8block_xor_avx2(const struct chacha_state *state,
                                       u8 *dst, const u8 *src,
                                       unsigned int len, int nrounds);

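/* AVX-512VL: 2-, 4- and 8-block kernels. */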
asmlinkage void chacha_2block_xor_avx512vl(const struct chacha_state *state,
                                           u8 *dst, const u8 *src,
                                           unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_avx512vl(const struct chacha_state *state,
                                           u8 *dst, const u8 *src,
                                           unsigned int len, int nrounds);
asmlinkage void chacha_8block_xor_avx512vl(const struct chacha_state *state,
                                           u8 *dst, const u8 *src,
                                           unsigned int len, int nrounds);

static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_simd);
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx2);
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx512vl);

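/*
 * Number of blocks the counter must be advanced by after an assembly routine
 * that handles up to 'maxblocks' blocks has consumed 'len' bytes.  A partial
 * final block still uses up one whole counter value.
 */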
static unsigned int chacha_advance(unsigned int len, unsigned int maxblocks)
{
        len = min(len, maxblocks * CHACHA_BLOCK_SIZE);
        return round_up(len, CHACHA_BLOCK_SIZE) / CHACHA_BLOCK_SIZE;
}

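/*
 * Encrypt or decrypt 'bytes' bytes using the best SIMD implementation detected
 * at init time, advancing the block counter in state->x[12].  Must be called
 * between kernel_fpu_begin() and kernel_fpu_end().
 */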
static void chacha_dosimd(struct chacha_state *state, u8 *dst, const u8 *src,
                          unsigned int bytes, int nrounds)
{
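        /*
         * AVX-512VL: run the 8-block kernel in bulk, then use the narrowest
         * kernel that covers the remaining bytes.
         */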
        if (static_branch_likely(&chacha_use_avx512vl)) {
                while (bytes >= CHACHA_BLOCK_SIZE * 8) {
                        chacha_8block_xor_avx512vl(state, dst, src, bytes,
                                                   nrounds);
                        bytes -= CHACHA_BLOCK_SIZE * 8;
                        src += CHACHA_BLOCK_SIZE * 8;
                        dst += CHACHA_BLOCK_SIZE * 8;
                        state->x[12] += 8;
                }
                if (bytes > CHACHA_BLOCK_SIZE * 4) {
                        chacha_8block_xor_avx512vl(state, dst, src, bytes,
                                                   nrounds);
                        state->x[12] += chacha_advance(bytes, 8);
                        return;
                }
                if (bytes > CHACHA_BLOCK_SIZE * 2) {
                        chacha_4block_xor_avx512vl(state, dst, src, bytes,
                                                   nrounds);
                        state->x[12] += chacha_advance(bytes, 4);
                        return;
                }
                if (bytes) {
                        chacha_2block_xor_avx512vl(state, dst, src, bytes,
                                                   nrounds);
                        state->x[12] += chacha_advance(bytes, 2);
                        return;
                }
        }

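        /*
         * AVX2: same structure as above, except a remainder of at most one
         * block falls through to the SSSE3 code below.
         */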
        if (static_branch_likely(&chacha_use_avx2)) {
                while (bytes >= CHACHA_BLOCK_SIZE * 8) {
                        chacha_8block_xor_avx2(state, dst, src, bytes, nrounds);
                        bytes -= CHACHA_BLOCK_SIZE * 8;
                        src += CHACHA_BLOCK_SIZE * 8;
                        dst += CHACHA_BLOCK_SIZE * 8;
                        state->x[12] += 8;
                }
                if (bytes > CHACHA_BLOCK_SIZE * 4) {
                        chacha_8block_xor_avx2(state, dst, src, bytes, nrounds);
                        state->x[12] += chacha_advance(bytes, 8);
                        return;
                }
                if (bytes > CHACHA_BLOCK_SIZE * 2) {
                        chacha_4block_xor_avx2(state, dst, src, bytes, nrounds);
                        state->x[12] += chacha_advance(bytes, 4);
                        return;
                }
                if (bytes > CHACHA_BLOCK_SIZE) {
                        chacha_2block_xor_avx2(state, dst, src, bytes, nrounds);
                        state->x[12] += chacha_advance(bytes, 2);
                        return;
                }
        }

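        /* SSSE3 baseline: 4-block loop, then a 4-block or single-block tail. */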
        while (bytes >= CHACHA_BLOCK_SIZE * 4) {
                chacha_4block_xor_ssse3(state, dst, src, bytes, nrounds);
                bytes -= CHACHA_BLOCK_SIZE * 4;
                src += CHACHA_BLOCK_SIZE * 4;
                dst += CHACHA_BLOCK_SIZE * 4;
                state->x[12] += 4;
        }
        if (bytes > CHACHA_BLOCK_SIZE) {
                chacha_4block_xor_ssse3(state, dst, src, bytes, nrounds);
                state->x[12] += chacha_advance(bytes, 4);
                return;
        }
        if (bytes) {
                chacha_block_xor_ssse3(state, dst, src, bytes, nrounds);
                state->x[12]++;
        }
}

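/*
 * HChaCha core, used for XChaCha key derivation.  Uses the SSSE3 routine when
 * SIMD support was detected, otherwise the portable generic implementation.
 */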
static void hchacha_block_arch(const struct chacha_state *state,
                               u32 out[HCHACHA_OUT_WORDS], int nrounds)
{
        if (!static_branch_likely(&chacha_use_simd)) {
                hchacha_block_generic(state, out, nrounds);
        } else {
                kernel_fpu_begin();
                hchacha_block_ssse3(state, out, nrounds);
                kernel_fpu_end();
        }
}

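/*
 * Arch-specific (X)ChaCha: requests of at most one block, or CPUs without
 * SSSE3, go to the generic code; everything else uses SIMD in bounded chunks.
 */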
static void chacha_crypt_arch(struct chacha_state *state, u8 *dst,
                              const u8 *src, unsigned int bytes, int nrounds)
{
        if (!static_branch_likely(&chacha_use_simd) ||
            bytes <= CHACHA_BLOCK_SIZE)
                return chacha_crypt_generic(state, dst, src, bytes, nrounds);

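        /* Limit each kernel_fpu section to 4 KiB to bound preemption latency. */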
        do {
                unsigned int todo = min_t(unsigned int, bytes, SZ_4K);

                kernel_fpu_begin();
                chacha_dosimd(state, dst, src, todo, nrounds);
                kernel_fpu_end();

                bytes -= todo;
                src += todo;
                dst += todo;
        } while (bytes);
}

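/*
 * Boot-time CPU feature detection.  SSSE3 is the baseline; AVX2 additionally
 * requires OS-enabled SSE and YMM state, and the AVX-512VL path also needs
 * AVX512BW for the kmovq instruction used by the assembly.
 */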
#define chacha_mod_init_arch chacha_mod_init_arch
static void chacha_mod_init_arch(void)
{
        if (!boot_cpu_has(X86_FEATURE_SSSE3))
                return;

        static_branch_enable(&chacha_use_simd);

        if (boot_cpu_has(X86_FEATURE_AVX) &&
            boot_cpu_has(X86_FEATURE_AVX2) &&
            cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL)) {
                static_branch_enable(&chacha_use_avx2);

                if (boot_cpu_has(X86_FEATURE_AVX512VL) &&
                    boot_cpu_has(X86_FEATURE_AVX512BW)) /* kmovq */
                        static_branch_enable(&chacha_use_avx512vl);
        }
}