xref: /linux/arch/arm/crypto/aes-neonbs-glue.c (revision 95298d63c67673c654c08952672d016212b26054)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Bit sliced AES using NEON instructions
4  *
5  * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7 
8 #include <asm/neon.h>
9 #include <asm/simd.h>
10 #include <crypto/aes.h>
11 #include <crypto/cbc.h>
12 #include <crypto/ctr.h>
13 #include <crypto/internal/simd.h>
14 #include <crypto/internal/skcipher.h>
15 #include <crypto/scatterwalk.h>
16 #include <crypto/xts.h>
17 #include <linux/module.h>
18 
19 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
20 MODULE_LICENSE("GPL v2");
21 
22 MODULE_ALIAS_CRYPTO("ecb(aes)");
23 MODULE_ALIAS_CRYPTO("cbc(aes)");
24 MODULE_ALIAS_CRYPTO("ctr(aes)");
25 MODULE_ALIAS_CRYPTO("xts(aes)");
26 
27 asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);
28 
29 asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
30 				  int rounds, int blocks);
31 asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
32 				  int rounds, int blocks);
33 
34 asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
35 				  int rounds, int blocks, u8 iv[]);
36 
37 asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
38 				  int rounds, int blocks, u8 ctr[], u8 final[]);
39 
40 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
41 				  int rounds, int blocks, u8 iv[], int);
42 asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
43 				  int rounds, int blocks, u8 iv[], int);
44 
45 struct aesbs_ctx {
46 	int	rounds;
47 	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
48 };
49 
50 struct aesbs_cbc_ctx {
51 	struct aesbs_ctx	key;
52 	struct crypto_cipher	*enc_tfm;
53 };
54 
55 struct aesbs_xts_ctx {
56 	struct aesbs_ctx	key;
57 	struct crypto_cipher	*cts_tfm;
58 	struct crypto_cipher	*tweak_tfm;
59 };
60 
61 struct aesbs_ctr_ctx {
62 	struct aesbs_ctx	key;		/* must be first member */
63 	struct crypto_aes_ctx	fallback;
64 };
65 
66 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
67 			unsigned int key_len)
68 {
69 	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
70 	struct crypto_aes_ctx rk;
71 	int err;
72 
73 	err = aes_expandkey(&rk, in_key, key_len);
74 	if (err)
75 		return err;
76 
77 	ctx->rounds = 6 + key_len / 4;
78 
79 	kernel_neon_begin();
80 	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
81 	kernel_neon_end();
82 
83 	return 0;
84 }
85 
86 static int __ecb_crypt(struct skcipher_request *req,
87 		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
88 				  int rounds, int blocks))
89 {
90 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
91 	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
92 	struct skcipher_walk walk;
93 	int err;
94 
95 	err = skcipher_walk_virt(&walk, req, false);
96 
97 	while (walk.nbytes >= AES_BLOCK_SIZE) {
98 		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
99 
100 		if (walk.nbytes < walk.total)
101 			blocks = round_down(blocks,
102 					    walk.stride / AES_BLOCK_SIZE);
103 
104 		kernel_neon_begin();
105 		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
106 		   ctx->rounds, blocks);
107 		kernel_neon_end();
108 		err = skcipher_walk_done(&walk,
109 					 walk.nbytes - blocks * AES_BLOCK_SIZE);
110 	}
111 
112 	return err;
113 }
114 
115 static int ecb_encrypt(struct skcipher_request *req)
116 {
117 	return __ecb_crypt(req, aesbs_ecb_encrypt);
118 }
119 
120 static int ecb_decrypt(struct skcipher_request *req)
121 {
122 	return __ecb_crypt(req, aesbs_ecb_decrypt);
123 }
124 
125 static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
126 			    unsigned int key_len)
127 {
128 	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
129 	struct crypto_aes_ctx rk;
130 	int err;
131 
132 	err = aes_expandkey(&rk, in_key, key_len);
133 	if (err)
134 		return err;
135 
136 	ctx->key.rounds = 6 + key_len / 4;
137 
138 	kernel_neon_begin();
139 	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
140 	kernel_neon_end();
141 	memzero_explicit(&rk, sizeof(rk));
142 
143 	return crypto_cipher_setkey(ctx->enc_tfm, in_key, key_len);
144 }
145 
146 static void cbc_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
147 {
148 	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
149 
150 	crypto_cipher_encrypt_one(ctx->enc_tfm, dst, src);
151 }
152 
153 static int cbc_encrypt(struct skcipher_request *req)
154 {
155 	return crypto_cbc_encrypt_walk(req, cbc_encrypt_one);
156 }
157 
158 static int cbc_decrypt(struct skcipher_request *req)
159 {
160 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
161 	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
162 	struct skcipher_walk walk;
163 	int err;
164 
165 	err = skcipher_walk_virt(&walk, req, false);
166 
167 	while (walk.nbytes >= AES_BLOCK_SIZE) {
168 		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
169 
170 		if (walk.nbytes < walk.total)
171 			blocks = round_down(blocks,
172 					    walk.stride / AES_BLOCK_SIZE);
173 
174 		kernel_neon_begin();
175 		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
176 				  ctx->key.rk, ctx->key.rounds, blocks,
177 				  walk.iv);
178 		kernel_neon_end();
179 		err = skcipher_walk_done(&walk,
180 					 walk.nbytes - blocks * AES_BLOCK_SIZE);
181 	}
182 
183 	return err;
184 }
185 
186 static int cbc_init(struct crypto_tfm *tfm)
187 {
188 	struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);
189 
190 	ctx->enc_tfm = crypto_alloc_cipher("aes", 0, 0);
191 
192 	return PTR_ERR_OR_ZERO(ctx->enc_tfm);
193 }
194 
195 static void cbc_exit(struct crypto_tfm *tfm)
196 {
197 	struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);
198 
199 	crypto_free_cipher(ctx->enc_tfm);
200 }
201 
202 static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
203 				 unsigned int key_len)
204 {
205 	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
206 	int err;
207 
208 	err = aes_expandkey(&ctx->fallback, in_key, key_len);
209 	if (err)
210 		return err;
211 
212 	ctx->key.rounds = 6 + key_len / 4;
213 
214 	kernel_neon_begin();
215 	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
216 	kernel_neon_end();
217 
218 	return 0;
219 }
220 
221 static int ctr_encrypt(struct skcipher_request *req)
222 {
223 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
224 	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
225 	struct skcipher_walk walk;
226 	u8 buf[AES_BLOCK_SIZE];
227 	int err;
228 
229 	err = skcipher_walk_virt(&walk, req, false);
230 
231 	while (walk.nbytes > 0) {
232 		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
233 		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;
234 
235 		if (walk.nbytes < walk.total) {
236 			blocks = round_down(blocks,
237 					    walk.stride / AES_BLOCK_SIZE);
238 			final = NULL;
239 		}
240 
241 		kernel_neon_begin();
242 		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
243 				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
244 		kernel_neon_end();
245 
246 		if (final) {
247 			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
248 			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;
249 
250 			crypto_xor_cpy(dst, src, final,
251 				       walk.total % AES_BLOCK_SIZE);
252 
253 			err = skcipher_walk_done(&walk, 0);
254 			break;
255 		}
256 		err = skcipher_walk_done(&walk,
257 					 walk.nbytes - blocks * AES_BLOCK_SIZE);
258 	}
259 
260 	return err;
261 }
262 
263 static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
264 {
265 	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
266 	unsigned long flags;
267 
268 	/*
269 	 * Temporarily disable interrupts to avoid races where
270 	 * cachelines are evicted when the CPU is interrupted
271 	 * to do something else.
272 	 */
273 	local_irq_save(flags);
274 	aes_encrypt(&ctx->fallback, dst, src);
275 	local_irq_restore(flags);
276 }
277 
278 static int ctr_encrypt_sync(struct skcipher_request *req)
279 {
280 	if (!crypto_simd_usable())
281 		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
282 
283 	return ctr_encrypt(req);
284 }
285 
286 static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
287 			    unsigned int key_len)
288 {
289 	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
290 	int err;
291 
292 	err = xts_verify_key(tfm, in_key, key_len);
293 	if (err)
294 		return err;
295 
296 	key_len /= 2;
297 	err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
298 	if (err)
299 		return err;
300 	err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
301 	if (err)
302 		return err;
303 
304 	return aesbs_setkey(tfm, in_key, key_len);
305 }
306 
307 static int xts_init(struct crypto_tfm *tfm)
308 {
309 	struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);
310 
311 	ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
312 	if (IS_ERR(ctx->cts_tfm))
313 		return PTR_ERR(ctx->cts_tfm);
314 
315 	ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
316 	if (IS_ERR(ctx->tweak_tfm))
317 		crypto_free_cipher(ctx->cts_tfm);
318 
319 	return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
320 }
321 
322 static void xts_exit(struct crypto_tfm *tfm)
323 {
324 	struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);
325 
326 	crypto_free_cipher(ctx->tweak_tfm);
327 	crypto_free_cipher(ctx->cts_tfm);
328 }
329 
330 static int __xts_crypt(struct skcipher_request *req, bool encrypt,
331 		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
332 				  int rounds, int blocks, u8 iv[], int))
333 {
334 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
335 	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
336 	int tail = req->cryptlen % AES_BLOCK_SIZE;
337 	struct skcipher_request subreq;
338 	u8 buf[2 * AES_BLOCK_SIZE];
339 	struct skcipher_walk walk;
340 	int err;
341 
342 	if (req->cryptlen < AES_BLOCK_SIZE)
343 		return -EINVAL;
344 
345 	if (unlikely(tail)) {
346 		skcipher_request_set_tfm(&subreq, tfm);
347 		skcipher_request_set_callback(&subreq,
348 					      skcipher_request_flags(req),
349 					      NULL, NULL);
350 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
351 					   req->cryptlen - tail, req->iv);
352 		req = &subreq;
353 	}
354 
355 	err = skcipher_walk_virt(&walk, req, true);
356 	if (err)
357 		return err;
358 
359 	crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);
360 
361 	while (walk.nbytes >= AES_BLOCK_SIZE) {
362 		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
363 		int reorder_last_tweak = !encrypt && tail > 0;
364 
365 		if (walk.nbytes < walk.total) {
366 			blocks = round_down(blocks,
367 					    walk.stride / AES_BLOCK_SIZE);
368 			reorder_last_tweak = 0;
369 		}
370 
371 		kernel_neon_begin();
372 		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
373 		   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
374 		kernel_neon_end();
375 		err = skcipher_walk_done(&walk,
376 					 walk.nbytes - blocks * AES_BLOCK_SIZE);
377 	}
378 
379 	if (err || likely(!tail))
380 		return err;
381 
382 	/* handle ciphertext stealing */
383 	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
384 				 AES_BLOCK_SIZE, 0);
385 	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
386 	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);
387 
388 	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
389 
390 	if (encrypt)
391 		crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
392 	else
393 		crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);
394 
395 	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);
396 
397 	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
398 				 AES_BLOCK_SIZE + tail, 1);
399 	return 0;
400 }
401 
402 static int xts_encrypt(struct skcipher_request *req)
403 {
404 	return __xts_crypt(req, true, aesbs_xts_encrypt);
405 }
406 
407 static int xts_decrypt(struct skcipher_request *req)
408 {
409 	return __xts_crypt(req, false, aesbs_xts_decrypt);
410 }
411 
412 static struct skcipher_alg aes_algs[] = { {
413 	.base.cra_name		= "__ecb(aes)",
414 	.base.cra_driver_name	= "__ecb-aes-neonbs",
415 	.base.cra_priority	= 250,
416 	.base.cra_blocksize	= AES_BLOCK_SIZE,
417 	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
418 	.base.cra_module	= THIS_MODULE,
419 	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
420 
421 	.min_keysize		= AES_MIN_KEY_SIZE,
422 	.max_keysize		= AES_MAX_KEY_SIZE,
423 	.walksize		= 8 * AES_BLOCK_SIZE,
424 	.setkey			= aesbs_setkey,
425 	.encrypt		= ecb_encrypt,
426 	.decrypt		= ecb_decrypt,
427 }, {
428 	.base.cra_name		= "__cbc(aes)",
429 	.base.cra_driver_name	= "__cbc-aes-neonbs",
430 	.base.cra_priority	= 250,
431 	.base.cra_blocksize	= AES_BLOCK_SIZE,
432 	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
433 	.base.cra_module	= THIS_MODULE,
434 	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
435 	.base.cra_init		= cbc_init,
436 	.base.cra_exit		= cbc_exit,
437 
438 	.min_keysize		= AES_MIN_KEY_SIZE,
439 	.max_keysize		= AES_MAX_KEY_SIZE,
440 	.walksize		= 8 * AES_BLOCK_SIZE,
441 	.ivsize			= AES_BLOCK_SIZE,
442 	.setkey			= aesbs_cbc_setkey,
443 	.encrypt		= cbc_encrypt,
444 	.decrypt		= cbc_decrypt,
445 }, {
446 	.base.cra_name		= "__ctr(aes)",
447 	.base.cra_driver_name	= "__ctr-aes-neonbs",
448 	.base.cra_priority	= 250,
449 	.base.cra_blocksize	= 1,
450 	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
451 	.base.cra_module	= THIS_MODULE,
452 	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
453 
454 	.min_keysize		= AES_MIN_KEY_SIZE,
455 	.max_keysize		= AES_MAX_KEY_SIZE,
456 	.chunksize		= AES_BLOCK_SIZE,
457 	.walksize		= 8 * AES_BLOCK_SIZE,
458 	.ivsize			= AES_BLOCK_SIZE,
459 	.setkey			= aesbs_setkey,
460 	.encrypt		= ctr_encrypt,
461 	.decrypt		= ctr_encrypt,
462 }, {
463 	.base.cra_name		= "ctr(aes)",
464 	.base.cra_driver_name	= "ctr-aes-neonbs-sync",
465 	.base.cra_priority	= 250 - 1,
466 	.base.cra_blocksize	= 1,
467 	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
468 	.base.cra_module	= THIS_MODULE,
469 
470 	.min_keysize		= AES_MIN_KEY_SIZE,
471 	.max_keysize		= AES_MAX_KEY_SIZE,
472 	.chunksize		= AES_BLOCK_SIZE,
473 	.walksize		= 8 * AES_BLOCK_SIZE,
474 	.ivsize			= AES_BLOCK_SIZE,
475 	.setkey			= aesbs_ctr_setkey_sync,
476 	.encrypt		= ctr_encrypt_sync,
477 	.decrypt		= ctr_encrypt_sync,
478 }, {
479 	.base.cra_name		= "__xts(aes)",
480 	.base.cra_driver_name	= "__xts-aes-neonbs",
481 	.base.cra_priority	= 250,
482 	.base.cra_blocksize	= AES_BLOCK_SIZE,
483 	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
484 	.base.cra_module	= THIS_MODULE,
485 	.base.cra_flags		= CRYPTO_ALG_INTERNAL,
486 	.base.cra_init		= xts_init,
487 	.base.cra_exit		= xts_exit,
488 
489 	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
490 	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
491 	.walksize		= 8 * AES_BLOCK_SIZE,
492 	.ivsize			= AES_BLOCK_SIZE,
493 	.setkey			= aesbs_xts_setkey,
494 	.encrypt		= xts_encrypt,
495 	.decrypt		= xts_decrypt,
496 } };
497 
498 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
499 
500 static void aes_exit(void)
501 {
502 	int i;
503 
504 	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
505 		if (aes_simd_algs[i])
506 			simd_skcipher_free(aes_simd_algs[i]);
507 
508 	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
509 }
510 
511 static int __init aes_init(void)
512 {
513 	struct simd_skcipher_alg *simd;
514 	const char *basename;
515 	const char *algname;
516 	const char *drvname;
517 	int err;
518 	int i;
519 
520 	if (!(elf_hwcap & HWCAP_NEON))
521 		return -ENODEV;
522 
523 	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
524 	if (err)
525 		return err;
526 
527 	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
528 		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
529 			continue;
530 
531 		algname = aes_algs[i].base.cra_name + 2;
532 		drvname = aes_algs[i].base.cra_driver_name + 2;
533 		basename = aes_algs[i].base.cra_driver_name;
534 		simd = simd_skcipher_create_compat(algname, drvname, basename);
535 		err = PTR_ERR(simd);
536 		if (IS_ERR(simd))
537 			goto unregister_simds;
538 
539 		aes_simd_algs[i] = simd;
540 	}
541 	return 0;
542 
543 unregister_simds:
544 	aes_exit();
545 	return err;
546 }
547 
548 late_initcall(aes_init);
549 module_exit(aes_exit);
550