xref: /linux/lib/crypto/x86/chacha.h (revision d8768fb12a14c30436bd0466b4fc28edeef45078)
1*13cecc52SEric Biggers /* SPDX-License-Identifier: GPL-2.0-or-later */
2*13cecc52SEric Biggers /*
3*13cecc52SEric Biggers  * ChaCha and HChaCha functions (x86_64 optimized)
4*13cecc52SEric Biggers  *
5*13cecc52SEric Biggers  * Copyright (C) 2015 Martin Willi
6*13cecc52SEric Biggers  */
7*13cecc52SEric Biggers 
8*13cecc52SEric Biggers #include <asm/simd.h>
9*13cecc52SEric Biggers #include <linux/jump_label.h>
10*13cecc52SEric Biggers #include <linux/kernel.h>
11*13cecc52SEric Biggers #include <linux/sizes.h>
12*13cecc52SEric Biggers 
/*
 * SIMD primitives declared asmlinkage; their definitions are not in this
 * file.  All of them take @state through a const pointer, so callers here
 * (see chacha_dosimd()) advance state->x[12] themselves after each call.
 *
 * NOTE(review): going by the names, a *_Nblock_* routine handles up to N
 * ChaCha blocks per call and @len may be shorter than N full blocks --
 * confirm against the assembly.
 */

/* SSSE3 variants. */
asmlinkage void chacha_block_xor_ssse3(const struct chacha_state *state,
				       u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_ssse3(const struct chacha_state *state,
					u8 *dst, const u8 *src,
					unsigned int len, int nrounds);
asmlinkage void hchacha_block_ssse3(const struct chacha_state *state,
				    u32 out[HCHACHA_OUT_WORDS], int nrounds);

/* AVX2 variants. */
asmlinkage void chacha_2block_xor_avx2(const struct chacha_state *state,
				       u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_avx2(const struct chacha_state *state,
				       u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);
asmlinkage void chacha_8block_xor_avx2(const struct chacha_state *state,
				       u8 *dst, const u8 *src,
				       unsigned int len, int nrounds);

/* AVX-512VL variants. */
asmlinkage void chacha_2block_xor_avx512vl(const struct chacha_state *state,
					   u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);
asmlinkage void chacha_4block_xor_avx512vl(const struct chacha_state *state,
					   u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);
asmlinkage void chacha_8block_xor_avx512vl(const struct chacha_state *state,
					   u8 *dst, const u8 *src,
					   unsigned int len, int nrounds);
41*13cecc52SEric Biggers 
/*
 * Static keys selecting the code path.  All start out false and are only
 * switched on by chacha_mod_init_arch() according to the boot CPU's features:
 * chacha_use_simd for SSSE3, chacha_use_avx2 and chacha_use_avx512vl for the
 * wider tiers.
 */
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_simd);
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx2);
static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx512vl);
45*13cecc52SEric Biggers 
chacha_advance(unsigned int len,unsigned int maxblocks)46*13cecc52SEric Biggers static unsigned int chacha_advance(unsigned int len, unsigned int maxblocks)
47*13cecc52SEric Biggers {
48*13cecc52SEric Biggers 	len = min(len, maxblocks * CHACHA_BLOCK_SIZE);
49*13cecc52SEric Biggers 	return round_up(len, CHACHA_BLOCK_SIZE) / CHACHA_BLOCK_SIZE;
50*13cecc52SEric Biggers }
51*13cecc52SEric Biggers 
/*
 * XOR @bytes bytes of @src into @dst using the widest enabled SIMD tier.
 *
 * Tiers are tried in the order AVX-512VL, AVX2, SSSE3.  Within a tier, full
 * batches are consumed in a loop and the remainder is handed to the narrowest
 * routine of that tier whose block count covers it.  state->x[12] is advanced
 * here by the number of blocks consumed (a partial final block counts as one,
 * see chacha_advance()); the assembly routines only get @state as const.
 *
 * chacha_crypt_arch() calls this between kernel_fpu_begin() and
 * kernel_fpu_end().
 */
static void chacha_dosimd(struct chacha_state *state, u8 *dst, const u8 *src,
			  unsigned int bytes, int nrounds)
{
	if (static_branch_likely(&chacha_use_avx512vl)) {
		for (; bytes >= CHACHA_BLOCK_SIZE * 8;
		     bytes -= CHACHA_BLOCK_SIZE * 8) {
			chacha_8block_xor_avx512vl(state, dst, src, bytes,
						   nrounds);
			state->x[12] += 8;
			dst += CHACHA_BLOCK_SIZE * 8;
			src += CHACHA_BLOCK_SIZE * 8;
		}

		/* Any non-empty remainder is finished within this tier. */
		if (bytes) {
			unsigned int width;

			if (bytes > CHACHA_BLOCK_SIZE * 4) {
				width = 8;
				chacha_8block_xor_avx512vl(state, dst, src,
							   bytes, nrounds);
			} else if (bytes > CHACHA_BLOCK_SIZE * 2) {
				width = 4;
				chacha_4block_xor_avx512vl(state, dst, src,
							   bytes, nrounds);
			} else {
				width = 2;
				chacha_2block_xor_avx512vl(state, dst, src,
							   bytes, nrounds);
			}
			state->x[12] += chacha_advance(bytes, width);
			return;
		}
	}

	if (static_branch_likely(&chacha_use_avx2)) {
		for (; bytes >= CHACHA_BLOCK_SIZE * 8;
		     bytes -= CHACHA_BLOCK_SIZE * 8) {
			chacha_8block_xor_avx2(state, dst, src, bytes, nrounds);
			state->x[12] += 8;
			dst += CHACHA_BLOCK_SIZE * 8;
			src += CHACHA_BLOCK_SIZE * 8;
		}

		/* A remainder of one block or less is left to SSSE3 below. */
		if (bytes > CHACHA_BLOCK_SIZE) {
			unsigned int width;

			if (bytes > CHACHA_BLOCK_SIZE * 4) {
				width = 8;
				chacha_8block_xor_avx2(state, dst, src, bytes,
						       nrounds);
			} else if (bytes > CHACHA_BLOCK_SIZE * 2) {
				width = 4;
				chacha_4block_xor_avx2(state, dst, src, bytes,
						       nrounds);
			} else {
				width = 2;
				chacha_2block_xor_avx2(state, dst, src, bytes,
						       nrounds);
			}
			state->x[12] += chacha_advance(bytes, width);
			return;
		}
	}

	for (; bytes >= CHACHA_BLOCK_SIZE * 4; bytes -= CHACHA_BLOCK_SIZE * 4) {
		chacha_4block_xor_ssse3(state, dst, src, bytes, nrounds);
		state->x[12] += 4;
		dst += CHACHA_BLOCK_SIZE * 4;
		src += CHACHA_BLOCK_SIZE * 4;
	}

	if (bytes > CHACHA_BLOCK_SIZE) {
		chacha_4block_xor_ssse3(state, dst, src, bytes, nrounds);
		state->x[12] += chacha_advance(bytes, 4);
	} else if (bytes) {
		/* At most one (possibly partial) block remains. */
		chacha_block_xor_ssse3(state, dst, src, bytes, nrounds);
		state->x[12] += 1;
	}
}
126*13cecc52SEric Biggers 
/*
 * Produce the HChaCha output for @state in @out.
 *
 * With chacha_use_simd enabled the SSSE3 routine is run inside a
 * kernel_fpu_begin()/kernel_fpu_end() section; otherwise the work is done by
 * hchacha_block_generic().
 */
static void hchacha_block_arch(const struct chacha_state *state,
			       u32 out[HCHACHA_OUT_WORDS], int nrounds)
{
	if (static_branch_likely(&chacha_use_simd)) {
		kernel_fpu_begin();
		hchacha_block_ssse3(state, out, nrounds);
		kernel_fpu_end();
		return;
	}

	hchacha_block_generic(state, out, nrounds);
}
138*13cecc52SEric Biggers 
/*
 * XOR @bytes bytes of @src into @dst with the ChaCha keystream of @state.
 *
 * Inputs of at most one block, and every input while chacha_use_simd is off,
 * are passed to chacha_crypt_generic().  Anything larger is fed to
 * chacha_dosimd() in pieces of at most SZ_4K bytes, each piece inside its own
 * kernel_fpu_begin()/kernel_fpu_end() section.
 *
 * NOTE(review): the SZ_4K cap presumably bounds how long a single FPU section
 * lasts -- confirm.
 */
static void chacha_crypt_arch(struct chacha_state *state, u8 *dst,
			      const u8 *src, unsigned int bytes, int nrounds)
{
	if (!static_branch_likely(&chacha_use_simd) ||
	    bytes <= CHACHA_BLOCK_SIZE) {
		chacha_crypt_generic(state, dst, src, bytes, nrounds);
		return;
	}

	/* bytes > CHACHA_BLOCK_SIZE here, so this runs at least once. */
	while (bytes) {
		const unsigned int chunk = min_t(unsigned int, bytes, SZ_4K);

		kernel_fpu_begin();
		chacha_dosimd(state, dst, src, chunk, nrounds);
		kernel_fpu_end();

		dst += chunk;
		src += chunk;
		bytes -= chunk;
	}
}
158*13cecc52SEric Biggers 
159*13cecc52SEric Biggers #define chacha_mod_init_arch chacha_mod_init_arch
chacha_mod_init_arch(void)160*13cecc52SEric Biggers static void chacha_mod_init_arch(void)
161*13cecc52SEric Biggers {
162*13cecc52SEric Biggers 	if (!boot_cpu_has(X86_FEATURE_SSSE3))
163*13cecc52SEric Biggers 		return;
164*13cecc52SEric Biggers 
165*13cecc52SEric Biggers 	static_branch_enable(&chacha_use_simd);
166*13cecc52SEric Biggers 
167*13cecc52SEric Biggers 	if (boot_cpu_has(X86_FEATURE_AVX) &&
168*13cecc52SEric Biggers 	    boot_cpu_has(X86_FEATURE_AVX2) &&
169*13cecc52SEric Biggers 	    cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL)) {
170*13cecc52SEric Biggers 		static_branch_enable(&chacha_use_avx2);
171*13cecc52SEric Biggers 
172*13cecc52SEric Biggers 		if (boot_cpu_has(X86_FEATURE_AVX512VL) &&
173*13cecc52SEric Biggers 		    boot_cpu_has(X86_FEATURE_AVX512BW)) /* kmovq */
174*13cecc52SEric Biggers 			static_branch_enable(&chacha_use_avx512vl);
175*13cecc52SEric Biggers 	}
176*13cecc52SEric Biggers }
177