xref: /linux/arch/x86/crypto/sm4_aesni_avx_glue.c (revision 0e9ab8e4d44ae9d9aaf213bfd2c90bbe7289337b)
1 /* SPDX-License-Identifier: GPL-2.0-or-later */
2 /*
3  * SM4 Cipher Algorithm, AES-NI/AVX optimized.
4  * as specified in
5  * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
6  *
7  * Copyright (c) 2021, Alibaba Group.
8  * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
9  */
10 
11 #include <linux/module.h>
12 #include <linux/crypto.h>
13 #include <linux/kernel.h>
14 #include <asm/simd.h>
15 #include <crypto/internal/simd.h>
16 #include <crypto/internal/skcipher.h>
17 #include <crypto/sm4.h>
18 #include "sm4-avx.h"
19 
20 #define SM4_CRYPT8_BLOCK_SIZE	(SM4_BLOCK_SIZE * 8)
21 
22 asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
23 				const u8 *src, int nblocks);
24 asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
25 				const u8 *src, int nblocks);
26 asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
27 				const u8 *src, u8 *iv);
28 asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
29 				const u8 *src, u8 *iv);
30 asmlinkage void sm4_aesni_avx_cfb_dec_blk8(const u32 *rk, u8 *dst,
31 				const u8 *src, u8 *iv);
32 
33 static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
34 			unsigned int key_len)
35 {
36 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
37 
38 	return sm4_expandkey(ctx, key, key_len);
39 }
40 
41 static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
42 {
43 	struct skcipher_walk walk;
44 	unsigned int nbytes;
45 	int err;
46 
47 	err = skcipher_walk_virt(&walk, req, false);
48 
49 	while ((nbytes = walk.nbytes) > 0) {
50 		const u8 *src = walk.src.virt.addr;
51 		u8 *dst = walk.dst.virt.addr;
52 
53 		kernel_fpu_begin();
54 		while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
55 			sm4_aesni_avx_crypt8(rkey, dst, src, 8);
56 			dst += SM4_CRYPT8_BLOCK_SIZE;
57 			src += SM4_CRYPT8_BLOCK_SIZE;
58 			nbytes -= SM4_CRYPT8_BLOCK_SIZE;
59 		}
60 		while (nbytes >= SM4_BLOCK_SIZE) {
61 			unsigned int nblocks = min(nbytes >> 4, 4u);
62 			sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
63 			dst += nblocks * SM4_BLOCK_SIZE;
64 			src += nblocks * SM4_BLOCK_SIZE;
65 			nbytes -= nblocks * SM4_BLOCK_SIZE;
66 		}
67 		kernel_fpu_end();
68 
69 		err = skcipher_walk_done(&walk, nbytes);
70 	}
71 
72 	return err;
73 }
74 
75 int sm4_avx_ecb_encrypt(struct skcipher_request *req)
76 {
77 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
78 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
79 
80 	return ecb_do_crypt(req, ctx->rkey_enc);
81 }
82 EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);
83 
84 int sm4_avx_ecb_decrypt(struct skcipher_request *req)
85 {
86 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
87 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
88 
89 	return ecb_do_crypt(req, ctx->rkey_dec);
90 }
91 EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);
92 
93 int sm4_cbc_encrypt(struct skcipher_request *req)
94 {
95 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
96 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
97 	struct skcipher_walk walk;
98 	unsigned int nbytes;
99 	int err;
100 
101 	err = skcipher_walk_virt(&walk, req, false);
102 
103 	while ((nbytes = walk.nbytes) > 0) {
104 		const u8 *iv = walk.iv;
105 		const u8 *src = walk.src.virt.addr;
106 		u8 *dst = walk.dst.virt.addr;
107 
108 		while (nbytes >= SM4_BLOCK_SIZE) {
109 			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
110 			sm4_crypt_block(ctx->rkey_enc, dst, dst);
111 			iv = dst;
112 			src += SM4_BLOCK_SIZE;
113 			dst += SM4_BLOCK_SIZE;
114 			nbytes -= SM4_BLOCK_SIZE;
115 		}
116 		if (iv != walk.iv)
117 			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
118 
119 		err = skcipher_walk_done(&walk, nbytes);
120 	}
121 
122 	return err;
123 }
124 EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);
125 
126 int sm4_avx_cbc_decrypt(struct skcipher_request *req,
127 			unsigned int bsize, sm4_crypt_func func)
128 {
129 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
130 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
131 	struct skcipher_walk walk;
132 	unsigned int nbytes;
133 	int err;
134 
135 	err = skcipher_walk_virt(&walk, req, false);
136 
137 	while ((nbytes = walk.nbytes) > 0) {
138 		const u8 *src = walk.src.virt.addr;
139 		u8 *dst = walk.dst.virt.addr;
140 
141 		kernel_fpu_begin();
142 
143 		while (nbytes >= bsize) {
144 			func(ctx->rkey_dec, dst, src, walk.iv);
145 			dst += bsize;
146 			src += bsize;
147 			nbytes -= bsize;
148 		}
149 
150 		while (nbytes >= SM4_BLOCK_SIZE) {
151 			u8 keystream[SM4_BLOCK_SIZE * 8];
152 			u8 iv[SM4_BLOCK_SIZE];
153 			unsigned int nblocks = min(nbytes >> 4, 8u);
154 			int i;
155 
156 			sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
157 						src, nblocks);
158 
159 			src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
160 			dst += (nblocks - 1) * SM4_BLOCK_SIZE;
161 			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
162 
163 			for (i = nblocks - 1; i > 0; i--) {
164 				crypto_xor_cpy(dst, src,
165 					&keystream[i * SM4_BLOCK_SIZE],
166 					SM4_BLOCK_SIZE);
167 				src -= SM4_BLOCK_SIZE;
168 				dst -= SM4_BLOCK_SIZE;
169 			}
170 			crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
171 			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
172 			dst += nblocks * SM4_BLOCK_SIZE;
173 			src += (nblocks + 1) * SM4_BLOCK_SIZE;
174 			nbytes -= nblocks * SM4_BLOCK_SIZE;
175 		}
176 
177 		kernel_fpu_end();
178 		err = skcipher_walk_done(&walk, nbytes);
179 	}
180 
181 	return err;
182 }
183 EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);
184 
185 static int cbc_decrypt(struct skcipher_request *req)
186 {
187 	return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
188 				sm4_aesni_avx_cbc_dec_blk8);
189 }
190 
191 int sm4_cfb_encrypt(struct skcipher_request *req)
192 {
193 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
194 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
195 	struct skcipher_walk walk;
196 	unsigned int nbytes;
197 	int err;
198 
199 	err = skcipher_walk_virt(&walk, req, false);
200 
201 	while ((nbytes = walk.nbytes) > 0) {
202 		u8 keystream[SM4_BLOCK_SIZE];
203 		const u8 *iv = walk.iv;
204 		const u8 *src = walk.src.virt.addr;
205 		u8 *dst = walk.dst.virt.addr;
206 
207 		while (nbytes >= SM4_BLOCK_SIZE) {
208 			sm4_crypt_block(ctx->rkey_enc, keystream, iv);
209 			crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
210 			iv = dst;
211 			src += SM4_BLOCK_SIZE;
212 			dst += SM4_BLOCK_SIZE;
213 			nbytes -= SM4_BLOCK_SIZE;
214 		}
215 		if (iv != walk.iv)
216 			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
217 
218 		/* tail */
219 		if (walk.nbytes == walk.total && nbytes > 0) {
220 			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
221 			crypto_xor_cpy(dst, src, keystream, nbytes);
222 			nbytes = 0;
223 		}
224 
225 		err = skcipher_walk_done(&walk, nbytes);
226 	}
227 
228 	return err;
229 }
230 EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);
231 
232 int sm4_avx_cfb_decrypt(struct skcipher_request *req,
233 			unsigned int bsize, sm4_crypt_func func)
234 {
235 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
236 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
237 	struct skcipher_walk walk;
238 	unsigned int nbytes;
239 	int err;
240 
241 	err = skcipher_walk_virt(&walk, req, false);
242 
243 	while ((nbytes = walk.nbytes) > 0) {
244 		const u8 *src = walk.src.virt.addr;
245 		u8 *dst = walk.dst.virt.addr;
246 
247 		kernel_fpu_begin();
248 
249 		while (nbytes >= bsize) {
250 			func(ctx->rkey_enc, dst, src, walk.iv);
251 			dst += bsize;
252 			src += bsize;
253 			nbytes -= bsize;
254 		}
255 
256 		while (nbytes >= SM4_BLOCK_SIZE) {
257 			u8 keystream[SM4_BLOCK_SIZE * 8];
258 			unsigned int nblocks = min(nbytes >> 4, 8u);
259 
260 			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
261 			if (nblocks > 1)
262 				memcpy(&keystream[SM4_BLOCK_SIZE], src,
263 					(nblocks - 1) * SM4_BLOCK_SIZE);
264 			memcpy(walk.iv, src + (nblocks - 1) * SM4_BLOCK_SIZE,
265 				SM4_BLOCK_SIZE);
266 
267 			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
268 						keystream, nblocks);
269 
270 			crypto_xor_cpy(dst, src, keystream,
271 					nblocks * SM4_BLOCK_SIZE);
272 			dst += nblocks * SM4_BLOCK_SIZE;
273 			src += nblocks * SM4_BLOCK_SIZE;
274 			nbytes -= nblocks * SM4_BLOCK_SIZE;
275 		}
276 
277 		kernel_fpu_end();
278 
279 		/* tail */
280 		if (walk.nbytes == walk.total && nbytes > 0) {
281 			u8 keystream[SM4_BLOCK_SIZE];
282 
283 			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
284 			crypto_xor_cpy(dst, src, keystream, nbytes);
285 			nbytes = 0;
286 		}
287 
288 		err = skcipher_walk_done(&walk, nbytes);
289 	}
290 
291 	return err;
292 }
293 EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);
294 
295 static int cfb_decrypt(struct skcipher_request *req)
296 {
297 	return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
298 				sm4_aesni_avx_cfb_dec_blk8);
299 }
300 
301 int sm4_avx_ctr_crypt(struct skcipher_request *req,
302 			unsigned int bsize, sm4_crypt_func func)
303 {
304 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
305 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
306 	struct skcipher_walk walk;
307 	unsigned int nbytes;
308 	int err;
309 
310 	err = skcipher_walk_virt(&walk, req, false);
311 
312 	while ((nbytes = walk.nbytes) > 0) {
313 		const u8 *src = walk.src.virt.addr;
314 		u8 *dst = walk.dst.virt.addr;
315 
316 		kernel_fpu_begin();
317 
318 		while (nbytes >= bsize) {
319 			func(ctx->rkey_enc, dst, src, walk.iv);
320 			dst += bsize;
321 			src += bsize;
322 			nbytes -= bsize;
323 		}
324 
325 		while (nbytes >= SM4_BLOCK_SIZE) {
326 			u8 keystream[SM4_BLOCK_SIZE * 8];
327 			unsigned int nblocks = min(nbytes >> 4, 8u);
328 			int i;
329 
330 			for (i = 0; i < nblocks; i++) {
331 				memcpy(&keystream[i * SM4_BLOCK_SIZE],
332 					walk.iv, SM4_BLOCK_SIZE);
333 				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
334 			}
335 			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
336 					keystream, nblocks);
337 
338 			crypto_xor_cpy(dst, src, keystream,
339 					nblocks * SM4_BLOCK_SIZE);
340 			dst += nblocks * SM4_BLOCK_SIZE;
341 			src += nblocks * SM4_BLOCK_SIZE;
342 			nbytes -= nblocks * SM4_BLOCK_SIZE;
343 		}
344 
345 		kernel_fpu_end();
346 
347 		/* tail */
348 		if (walk.nbytes == walk.total && nbytes > 0) {
349 			u8 keystream[SM4_BLOCK_SIZE];
350 
351 			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
352 			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
353 
354 			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);
355 
356 			crypto_xor_cpy(dst, src, keystream, nbytes);
357 			dst += nbytes;
358 			src += nbytes;
359 			nbytes = 0;
360 		}
361 
362 		err = skcipher_walk_done(&walk, nbytes);
363 	}
364 
365 	return err;
366 }
367 EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);
368 
369 static int ctr_crypt(struct skcipher_request *req)
370 {
371 	return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
372 				sm4_aesni_avx_ctr_enc_blk8);
373 }
374 
375 static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
376 	{
377 		.base = {
378 			.cra_name		= "__ecb(sm4)",
379 			.cra_driver_name	= "__ecb-sm4-aesni-avx",
380 			.cra_priority		= 400,
381 			.cra_flags		= CRYPTO_ALG_INTERNAL,
382 			.cra_blocksize		= SM4_BLOCK_SIZE,
383 			.cra_ctxsize		= sizeof(struct sm4_ctx),
384 			.cra_module		= THIS_MODULE,
385 		},
386 		.min_keysize	= SM4_KEY_SIZE,
387 		.max_keysize	= SM4_KEY_SIZE,
388 		.walksize	= 8 * SM4_BLOCK_SIZE,
389 		.setkey		= sm4_skcipher_setkey,
390 		.encrypt	= sm4_avx_ecb_encrypt,
391 		.decrypt	= sm4_avx_ecb_decrypt,
392 	}, {
393 		.base = {
394 			.cra_name		= "__cbc(sm4)",
395 			.cra_driver_name	= "__cbc-sm4-aesni-avx",
396 			.cra_priority		= 400,
397 			.cra_flags		= CRYPTO_ALG_INTERNAL,
398 			.cra_blocksize		= SM4_BLOCK_SIZE,
399 			.cra_ctxsize		= sizeof(struct sm4_ctx),
400 			.cra_module		= THIS_MODULE,
401 		},
402 		.min_keysize	= SM4_KEY_SIZE,
403 		.max_keysize	= SM4_KEY_SIZE,
404 		.ivsize		= SM4_BLOCK_SIZE,
405 		.walksize	= 8 * SM4_BLOCK_SIZE,
406 		.setkey		= sm4_skcipher_setkey,
407 		.encrypt	= sm4_cbc_encrypt,
408 		.decrypt	= cbc_decrypt,
409 	}, {
410 		.base = {
411 			.cra_name		= "__cfb(sm4)",
412 			.cra_driver_name	= "__cfb-sm4-aesni-avx",
413 			.cra_priority		= 400,
414 			.cra_flags		= CRYPTO_ALG_INTERNAL,
415 			.cra_blocksize		= 1,
416 			.cra_ctxsize		= sizeof(struct sm4_ctx),
417 			.cra_module		= THIS_MODULE,
418 		},
419 		.min_keysize	= SM4_KEY_SIZE,
420 		.max_keysize	= SM4_KEY_SIZE,
421 		.ivsize		= SM4_BLOCK_SIZE,
422 		.chunksize	= SM4_BLOCK_SIZE,
423 		.walksize	= 8 * SM4_BLOCK_SIZE,
424 		.setkey		= sm4_skcipher_setkey,
425 		.encrypt	= sm4_cfb_encrypt,
426 		.decrypt	= cfb_decrypt,
427 	}, {
428 		.base = {
429 			.cra_name		= "__ctr(sm4)",
430 			.cra_driver_name	= "__ctr-sm4-aesni-avx",
431 			.cra_priority		= 400,
432 			.cra_flags		= CRYPTO_ALG_INTERNAL,
433 			.cra_blocksize		= 1,
434 			.cra_ctxsize		= sizeof(struct sm4_ctx),
435 			.cra_module		= THIS_MODULE,
436 		},
437 		.min_keysize	= SM4_KEY_SIZE,
438 		.max_keysize	= SM4_KEY_SIZE,
439 		.ivsize		= SM4_BLOCK_SIZE,
440 		.chunksize	= SM4_BLOCK_SIZE,
441 		.walksize	= 8 * SM4_BLOCK_SIZE,
442 		.setkey		= sm4_skcipher_setkey,
443 		.encrypt	= ctr_crypt,
444 		.decrypt	= ctr_crypt,
445 	}
446 };
447 
448 static struct simd_skcipher_alg *
449 simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];
450 
451 static int __init sm4_init(void)
452 {
453 	const char *feature_name;
454 
455 	if (!boot_cpu_has(X86_FEATURE_AVX) ||
456 	    !boot_cpu_has(X86_FEATURE_AES) ||
457 	    !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
458 		pr_info("AVX or AES-NI instructions are not detected.\n");
459 		return -ENODEV;
460 	}
461 
462 	if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
463 				&feature_name)) {
464 		pr_info("CPU feature '%s' is not supported.\n", feature_name);
465 		return -ENODEV;
466 	}
467 
468 	return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
469 					ARRAY_SIZE(sm4_aesni_avx_skciphers),
470 					simd_sm4_aesni_avx_skciphers);
471 }
472 
473 static void __exit sm4_exit(void)
474 {
475 	simd_unregister_skciphers(sm4_aesni_avx_skciphers,
476 					ARRAY_SIZE(sm4_aesni_avx_skciphers),
477 					simd_sm4_aesni_avx_skciphers);
478 }
479 
480 module_init(sm4_init);
481 module_exit(sm4_exit);
482 
483 MODULE_LICENSE("GPL v2");
484 MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
485 MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
486 MODULE_ALIAS_CRYPTO("sm4");
487 MODULE_ALIAS_CRYPTO("sm4-aesni-avx");
488