xref: /linux/arch/arm/crypto/aes-neonbs-glue.c (revision 5afca7e996c42aed1b4a42d4712817601ba42aff)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Bit sliced AES using NEON instructions
4  *
5  * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7 
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>
#include <linux/string.h>
#include "aes-cipher.h"
18 
MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_DESCRIPTION("Bit sliced AES using NEON instructions");
MODULE_LICENSE("GPL v2");

/*
 * Aliases for the algorithm names this module ends up exposing: the SIMD
 * wrappers created in aes_init() strip the "__" prefix of the internal
 * algorithms, and "ctr(aes)" is additionally registered directly by the
 * sync CTR variant.
 */
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
27 
28 asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);
29 
30 asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
31 				  int rounds, int blocks);
32 asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
33 				  int rounds, int blocks);
34 
35 asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
36 				  int rounds, int blocks, u8 iv[]);
37 
38 asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
39 				  int rounds, int blocks, u8 ctr[]);
40 
41 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
42 				  int rounds, int blocks, u8 iv[], int);
43 asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
44 				  int rounds, int blocks, u8 iv[], int);
45 
/*
 * Bit-sliced key schedule used by all modes.
 *
 * @rounds: AES round count, computed by the setkey routines as
 *          6 + key_len / 4 (10, 12 or 14).
 * @rk:     schedule in the layout written by aesbs_convert_key() and read
 *          by the aesbs_*() NEON routines.  NOTE(review): 13 * 128 + 32
 *          bytes looks sized for the 14-round (AES-256) case — confirm
 *          against aesbs_convert_key().
 */
struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

/*
 * CBC context.  @key drives the NEON decryption path.  @fallback holds the
 * generic expanded key used by the scalar per-block code in cbc_encrypt().
 */
struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_aes_ctx	fallback;
};

/*
 * XTS context.  @key (bit-sliced) and @fallback (generic) both hold the
 * data key; @fallback is used for the ciphertext-stealing tail.
 * @tweak_key is the generic expanded tweak key.  @key is kept first so the
 * context can also be viewed as a struct aesbs_ctx.
 */
struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_aes_ctx	fallback;
	struct crypto_aes_ctx	tweak_key;
};

/*
 * Context of the non-internal "ctr-aes-neonbs-sync" algorithm.  @fallback
 * is used by the scalar path taken when NEON is not usable.
 */
struct aesbs_ctr_ctx {
	/* must be first member: ctr_encrypt() reads the context as a struct aesbs_ctx */
	struct aesbs_ctx	key;
	struct crypto_aes_ctx	fallback;
};
66 
67 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
68 			unsigned int key_len)
69 {
70 	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
71 	struct crypto_aes_ctx rk;
72 	int err;
73 
74 	err = aes_expandkey(&rk, in_key, key_len);
75 	if (err)
76 		return err;
77 
78 	ctx->rounds = 6 + key_len / 4;
79 
80 	kernel_neon_begin();
81 	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
82 	kernel_neon_end();
83 
84 	return 0;
85 }
86 
87 static int __ecb_crypt(struct skcipher_request *req,
88 		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
89 				  int rounds, int blocks))
90 {
91 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
92 	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
93 	struct skcipher_walk walk;
94 	int err;
95 
96 	err = skcipher_walk_virt(&walk, req, false);
97 
98 	while (walk.nbytes >= AES_BLOCK_SIZE) {
99 		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
100 
101 		if (walk.nbytes < walk.total)
102 			blocks = round_down(blocks,
103 					    walk.stride / AES_BLOCK_SIZE);
104 
105 		kernel_neon_begin();
106 		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
107 		   ctx->rounds, blocks);
108 		kernel_neon_end();
109 		err = skcipher_walk_done(&walk,
110 					 walk.nbytes - blocks * AES_BLOCK_SIZE);
111 	}
112 
113 	return err;
114 }
115 
116 static int ecb_encrypt(struct skcipher_request *req)
117 {
118 	return __ecb_crypt(req, aesbs_ecb_encrypt);
119 }
120 
121 static int ecb_decrypt(struct skcipher_request *req)
122 {
123 	return __ecb_crypt(req, aesbs_ecb_decrypt);
124 }
125 
126 static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
127 			    unsigned int key_len)
128 {
129 	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
130 	int err;
131 
132 	err = aes_expandkey(&ctx->fallback, in_key, key_len);
133 	if (err)
134 		return err;
135 
136 	ctx->key.rounds = 6 + key_len / 4;
137 
138 	kernel_neon_begin();
139 	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
140 	kernel_neon_end();
141 
142 	return 0;
143 }
144 
145 static int cbc_encrypt(struct skcipher_request *req)
146 {
147 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
148 	const struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
149 	struct skcipher_walk walk;
150 	unsigned int nbytes;
151 	int err;
152 
153 	err = skcipher_walk_virt(&walk, req, false);
154 
155 	while ((nbytes = walk.nbytes) >= AES_BLOCK_SIZE) {
156 		const u8 *src = walk.src.virt.addr;
157 		u8 *dst = walk.dst.virt.addr;
158 		u8 *prev = walk.iv;
159 
160 		do {
161 			crypto_xor_cpy(dst, src, prev, AES_BLOCK_SIZE);
162 			__aes_arm_encrypt(ctx->fallback.key_enc,
163 					  ctx->key.rounds, dst, dst);
164 			prev = dst;
165 			src += AES_BLOCK_SIZE;
166 			dst += AES_BLOCK_SIZE;
167 			nbytes -= AES_BLOCK_SIZE;
168 		} while (nbytes >= AES_BLOCK_SIZE);
169 		memcpy(walk.iv, prev, AES_BLOCK_SIZE);
170 		err = skcipher_walk_done(&walk, nbytes);
171 	}
172 	return err;
173 }
174 
175 static int cbc_decrypt(struct skcipher_request *req)
176 {
177 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
178 	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
179 	struct skcipher_walk walk;
180 	int err;
181 
182 	err = skcipher_walk_virt(&walk, req, false);
183 
184 	while (walk.nbytes >= AES_BLOCK_SIZE) {
185 		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
186 
187 		if (walk.nbytes < walk.total)
188 			blocks = round_down(blocks,
189 					    walk.stride / AES_BLOCK_SIZE);
190 
191 		kernel_neon_begin();
192 		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
193 				  ctx->key.rk, ctx->key.rounds, blocks,
194 				  walk.iv);
195 		kernel_neon_end();
196 		err = skcipher_walk_done(&walk,
197 					 walk.nbytes - blocks * AES_BLOCK_SIZE);
198 	}
199 
200 	return err;
201 }
202 
203 static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
204 				 unsigned int key_len)
205 {
206 	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
207 	int err;
208 
209 	err = aes_expandkey(&ctx->fallback, in_key, key_len);
210 	if (err)
211 		return err;
212 
213 	ctx->key.rounds = 6 + key_len / 4;
214 
215 	kernel_neon_begin();
216 	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
217 	kernel_neon_end();
218 
219 	return 0;
220 }
221 
/*
 * CTR en/decryption (both hooks point here) using the bit-sliced NEON
 * routine.  It serves "__ctr(aes)" directly, and "ctr-aes-neonbs-sync"
 * through ctr_encrypt_sync() when NEON is usable.  In both cases the
 * context is read as a struct aesbs_ctx.
 *
 * aesbs_ctr_encrypt() is given a byte count rather than a block count, so
 * a trailing partial block is handled by the routine itself.  The counter
 * block lives in walk.iv.
 *
 * Returns the status of the skcipher walk.
 */
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	/* bounce buffer for a chunk shorter than one block */
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		int bytes = walk.nbytes;

		/*
		 * Less than one block left: copy it to the END of buf and
		 * process it in place there.  NOTE(review): presumably the
		 * assembly addresses a short final block relative to its end
		 * — confirm.
		 * Otherwise, if more data follows this chunk, only process a
		 * multiple of 8 blocks and leave the rest to the walker.
		 */
		if (unlikely(bytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - bytes,
					   src, bytes);
		else if (walk.nbytes < walk.total)
			bytes &= ~(8 * AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
		kernel_neon_end();

		/* Copy the bounced partial block out to the real destination. */
		if (unlikely(bytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - bytes, bytes);

		/* Report how many bytes of this chunk were left unprocessed. */
		err = skcipher_walk_done(&walk, walk.nbytes - bytes);
	}

	return err;
}
256 
257 static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
258 {
259 	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
260 
261 	__aes_arm_encrypt(ctx->fallback.key_enc, ctx->key.rounds, src, dst);
262 }
263 
264 static int ctr_encrypt_sync(struct skcipher_request *req)
265 {
266 	if (!crypto_simd_usable())
267 		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
268 
269 	return ctr_encrypt(req);
270 }
271 
272 static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
273 			    unsigned int key_len)
274 {
275 	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
276 	int err;
277 
278 	err = xts_verify_key(tfm, in_key, key_len);
279 	if (err)
280 		return err;
281 
282 	key_len /= 2;
283 	err = aes_expandkey(&ctx->fallback, in_key, key_len);
284 	if (err)
285 		return err;
286 	err = aes_expandkey(&ctx->tweak_key, in_key + key_len, key_len);
287 	if (err)
288 		return err;
289 
290 	return aesbs_setkey(tfm, in_key, key_len);
291 }
292 
/*
 * Common XTS path for xts_encrypt() and xts_decrypt().
 *
 * @req:     the request; cryptlen must be at least AES_BLOCK_SIZE
 * @encrypt: true to encrypt, false to decrypt
 * @fn:      bit-sliced NEON routine that processes whole blocks
 *
 * Whole blocks are processed with @fn under a skcipher walk.  If cryptlen
 * is not a multiple of AES_BLOCK_SIZE, the walk covers only the
 * block-aligned prefix (via an on-stack subrequest).  The last full block
 * and the partial tail are then finished with ciphertext stealing using
 * the scalar AES code and the generic data key in ->fallback.
 *
 * Returns 0, -EINVAL if cryptlen < AES_BLOCK_SIZE, or an error from the
 * skcipher walk.
 */
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	const int rounds = ctx->key.rounds;
	/* number of bytes in the trailing partial block, if any */
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	/* ciphertext-stealing scratch: one full block plus up to one tail */
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	if (unlikely(tail)) {
		/*
		 * Walk only the block-aligned part.  The subrequest is used
		 * solely to drive the walk below and is never submitted.
		 * From here on, req->cryptlen excludes the tail.
		 */
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

	/* The initial tweak is the IV encrypted with the tweak key. */
	__aes_arm_encrypt(ctx->tweak_key.key_enc, rounds, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		/*
		 * Only set when decrypting a request with a partial tail: XTS
		 * ciphertext stealing decrypts the last full block with the
		 * tweak that would normally follow it.  NOTE(review):
		 * presumably @fn swaps the final two tweaks when this is set
		 * — confirm in the assembly.
		 */
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			/*
			 * More data follows: process whole strides only.  The
			 * last block of the request is not in this chunk, so
			 * no tweak reordering yet.
			 */
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	/* Done unless there is a partial tail left to process. */
	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	/*
	 * buf[0..16):        last full output block written by the walk
	 * buf[16..16+tail):  its first @tail bytes, which become the final
	 *                    partial output block
	 * buf[0..tail) is then overwritten with the input tail, forming the
	 * block that is processed once more below.
	 */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	/*
	 * xor / AES / xor with the tweak left in req->iv.  NOTE(review):
	 * presumably the walk and @fn leave the next tweak there — confirm.
	 */
	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		__aes_arm_encrypt(ctx->fallback.key_enc, rounds, buf, buf);
	else
		__aes_arm_decrypt(ctx->fallback.key_dec, rounds, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	/* Write back the new last full block followed by the partial tail. */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}
365 
366 static int xts_encrypt(struct skcipher_request *req)
367 {
368 	return __xts_crypt(req, true, aesbs_xts_encrypt);
369 }
370 
371 static int xts_decrypt(struct skcipher_request *req)
372 {
373 	return __xts_crypt(req, false, aesbs_xts_decrypt);
374 }
375 
/*
 * Algorithms registered by this module.
 *
 * Entries flagged CRYPTO_ALG_INTERNAL ("__"-prefixed names) are wrapped by
 * aes_init() in SIMD skciphers registered under the name with the "__"
 * prefix stripped.  "ctr-aes-neonbs-sync" is not internal.  It is
 * registered as-is and handles the no-NEON case itself (see
 * ctr_encrypt_sync()).
 *
 * NOTE(review): the walksize of 8 blocks presumably matches the 8-way
 * bit-sliced NEON routines — confirm.
 */
static struct skcipher_alg aes_algs[] = { {
	/* internal ECB */
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	/* internal CBC: scalar encryption, NEON decryption */
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	/* internal CTR: a stream mode, hence cra_blocksize 1 */
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	/*
	 * Non-internal CTR with its own scalar fallback; priority one below
	 * the internal algorithms above.
	 */
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs-sync",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	/* internal XTS: the key is two AES keys back to back */
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

/*
 * SIMD wrappers created by aes_init(), indexed in parallel with aes_algs[].
 * Slots of non-internal algorithms stay NULL.
 */
static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
459 
460 static void aes_exit(void)
461 {
462 	int i;
463 
464 	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
465 		if (aes_simd_algs[i])
466 			simd_skcipher_free(aes_simd_algs[i]);
467 
468 	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
469 }
470 
471 static int __init aes_init(void)
472 {
473 	struct simd_skcipher_alg *simd;
474 	const char *basename;
475 	const char *algname;
476 	const char *drvname;
477 	int err;
478 	int i;
479 
480 	if (!(elf_hwcap & HWCAP_NEON))
481 		return -ENODEV;
482 
483 	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
484 	if (err)
485 		return err;
486 
487 	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
488 		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
489 			continue;
490 
491 		algname = aes_algs[i].base.cra_name + 2;
492 		drvname = aes_algs[i].base.cra_driver_name + 2;
493 		basename = aes_algs[i].base.cra_driver_name;
494 		simd = simd_skcipher_create_compat(aes_algs + i, algname, drvname, basename);
495 		err = PTR_ERR(simd);
496 		if (IS_ERR(simd))
497 			goto unregister_simds;
498 
499 		aes_simd_algs[i] = simd;
500 	}
501 	return 0;
502 
503 unregister_simds:
504 	aes_exit();
505 	return err;
506 }
507 
/*
 * aes_init() runs as a late initcall rather than via module_init().
 * NOTE(review): the reason is not visible here — presumably initcall
 * ordering relative to other crypto code; confirm.
 */
late_initcall(aes_init);
module_exit(aes_exit);
510