// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/cipher.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_DESCRIPTION("Bit sliced AES using NEON instructions");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)-all");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

MODULE_IMPORT_NS(CRYPTO_INTERNAL);

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int bytes, u8 ctr[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);

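/*
 * rk holds the expanded key in the bit-sliced layout consumed by the NEON
 * asm routines above; it is produced by aesbs_convert_key().
 */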
struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_skcipher	*enc_tfm;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_cipher	*cts_tfm;
	struct crypto_cipher	*tweak_tfm;
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

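/*
 * Expand the key with the generic AES library and convert the encryption
 * round keys into the bit-sliced format used by the NEON code.
 */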
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

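/*
 * Common ECB helper: walk the request and hand full blocks to the NEON
 * routine, rounding down to a multiple of the 8-block walk stride except on
 * the final chunk.
 */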
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
}

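/*
 * CBC encryption is inherently sequential, so the 8-way bit-sliced NEON
 * code cannot help here; forward the request to the fallback cbc(aes)
 * transform allocated in cbc_init().
 */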
static int cbc_encrypt(struct skcipher_request *req)
{
	struct skcipher_request *subreq = skcipher_request_ctx(req);
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	skcipher_request_set_tfm(subreq, ctx->enc_tfm);
	skcipher_request_set_callback(subreq,
				      skcipher_request_flags(req),
				      NULL, NULL);
	skcipher_request_set_crypt(subreq, req->src, req->dst,
				   req->cryptlen, req->iv);

	return crypto_skcipher_encrypt(subreq);
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

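/*
 * Allocate the fallback cbc(aes) transform used on the encryption path and
 * reserve room for its subrequest in our own request context.
 */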
static int cbc_init(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned int reqsize;

	ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
					     CRYPTO_ALG_NEED_FALLBACK);
	if (IS_ERR(ctx->enc_tfm))
		return PTR_ERR(ctx->enc_tfm);

	reqsize = sizeof(struct skcipher_request);
	reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
	crypto_skcipher_set_reqsize(tfm, reqsize);

	return 0;
}

static void cbc_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_skcipher(ctx->enc_tfm);
}

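/*
 * The synchronous CTR variant also keeps the generic expanded key so it can
 * fall back to the scalar AES library when NEON cannot be used.
 */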
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

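/*
 * CTR encryption (decryption is identical): intermediate chunks are rounded
 * down to a multiple of 8 blocks; a trailing partial block is staged at the
 * end of a block-sized stack buffer so the NEON code never touches memory
 * beyond the data supplied by the walk.
 */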
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		int bytes = walk.nbytes;

		if (unlikely(bytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - bytes,
					   src, bytes);
		else if (walk.nbytes < walk.total)
			bytes &= ~(8 * AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
		kernel_neon_end();

		if (unlikely(bytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - bytes, bytes);

		err = skcipher_walk_done(&walk, walk.nbytes - bytes);
	}

	return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned long flags;

	/*
	 * Temporarily disable interrupts to avoid races where
	 * cachelines are evicted when the CPU is interrupted
	 * to do something else.
	 */
	local_irq_save(flags);
	aes_encrypt(&ctx->fallback, dst, src);
	local_irq_restore(flags);
}

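/*
 * Synchronous CTR entry point: fall back to the scalar AES library via
 * crypto_ctr_encrypt_walk() when NEON may not be used in the current
 * context.
 */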
static int ctr_encrypt_sync(struct skcipher_request *req)
{
	if (!crypto_simd_usable())
		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

	return ctr_encrypt(req);
}

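/*
 * XTS keys are twice the AES key size: the first half keys the bit-sliced
 * data transform and the ciphertext stealing cipher, the second half keys
 * the tweak cipher.
 */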
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
	if (err)
		return err;
	err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
	if (err)
		return err;

	return aesbs_setkey(tfm, in_key, key_len);
}

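/*
 * Allocate the plain AES ciphers used for the tweak and for ciphertext
 * stealing.
 */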
static int xts_init(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->cts_tfm))
		return PTR_ERR(ctx->cts_tfm);

	ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
	if (IS_ERR(ctx->tweak_tfm))
		crypto_free_cipher(ctx->cts_tfm);

	return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
}

static void xts_exit(struct crypto_skcipher *tfm)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

	crypto_free_cipher(ctx->tweak_tfm);
	crypto_free_cipher(ctx->cts_tfm);
}

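/*
 * Main XTS helper: full blocks are handled by the NEON code; when the
 * request length is not block aligned, ciphertext stealing for the final
 * partial block is done here using the single-block cts_tfm cipher.
 */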
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	if (unlikely(tail)) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

	crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
	else
		crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

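/*
 * The __-prefixed algorithms are internal only and reach users through the
 * simd wrappers created in aes_init(); ctr-aes-neonbs-sync is registered
 * directly and falls back to the AES library when NEON is unavailable.
 */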
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL |
				  CRYPTO_ALG_NEED_FALLBACK,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
	.init			= cbc_init,
	.exit			= cbc_exit,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs-sync",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
	.init			= xts_init,
	.exit			= xts_exit,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

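/*
 * Register the algorithms above, then wrap each internal one in a simd
 * helper that exposes it under the unprefixed name ("ecb(aes)" etc.).
 */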
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

late_initcall(aes_init);
module_exit(aes_exit);