xref: /linux/arch/arm64/crypto/aes-glue.c (revision 78c3925c048c752334873f56c3a3d1c9d53e0416)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
4  *
5  * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7 
8 #include <asm/neon.h>
9 #include <asm/hwcap.h>
10 #include <asm/simd.h>
11 #include <crypto/aes.h>
12 #include <crypto/ctr.h>
13 #include <crypto/sha2.h>
14 #include <crypto/internal/hash.h>
15 #include <crypto/internal/simd.h>
16 #include <crypto/internal/skcipher.h>
17 #include <crypto/scatterwalk.h>
18 #include <linux/module.h>
19 #include <linux/cpufeature.h>
20 #include <crypto/xts.h>
21 
22 #include "aes-ce-setkey.h"
23 
/*
 * Two flavours of this driver exist, selected at build time by
 * USE_V8_CRYPTO_EXTENSIONS: "ce" (ARMv8 Crypto Extensions, priority 300)
 * and "neon" (plain NEON, priority 200). The generic aes_* names used in
 * the rest of this file are mapped onto the flavour-specific assembler
 * symbols here. In the NEON flavour aes_expandkey is not remapped, so
 * presumably the generic library routine from <crypto/aes.h> is used.
 */
#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE			"ce"
#define PRIO			300
#define aes_expandkey		ce_aes_expandkey
#define aes_ecb_encrypt		ce_aes_ecb_encrypt
#define aes_ecb_decrypt		ce_aes_ecb_decrypt
#define aes_cbc_encrypt		ce_aes_cbc_encrypt
#define aes_cbc_decrypt		ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		ce_aes_ctr_encrypt
#define aes_xctr_encrypt	ce_aes_xctr_encrypt
#define aes_xts_encrypt		ce_aes_xts_encrypt
#define aes_xts_decrypt		ce_aes_xts_decrypt
#define aes_mac_update		ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions");
#else
#define MODE			"neon"
#define PRIO			200
#define aes_ecb_encrypt		neon_aes_ecb_encrypt
#define aes_ecb_decrypt		neon_aes_ecb_decrypt
#define aes_cbc_encrypt		neon_aes_cbc_encrypt
#define aes_cbc_decrypt		neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		neon_aes_ctr_encrypt
#define aes_xctr_encrypt	neon_aes_xctr_encrypt
#define aes_xts_encrypt		neon_aes_xts_encrypt
#define aes_xts_decrypt		neon_aes_xts_decrypt
#define aes_mac_update		neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON");
#endif
/*
 * The plain ecb/cbc/ctr/xts/xctr algorithms are only registered under this
 * same condition (see aes_algs[]), so only advertise aliases for them then.
 */
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
MODULE_ALIAS_CRYPTO("xctr(aes)");
#endif
/* These are registered by both flavours unconditionally. */
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");
75 
/*
 * Assembler entry points, defined in aes-modes.S.
 *
 * In every routine @rounds is the AES round count; callers in this file
 * pass 6 + key_length / 4 (10, 12 or 14). The ECB, CBC and ESSIV routines
 * take a count of whole @blocks. The CTS, CTR, XCTR and XTS routines take
 * a length in @bytes.
 */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);

asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 ctr[]);

/* @byte_ctr: number of bytes of the request already processed */
asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				 int rounds, int bytes, u8 ctr[], int byte_ctr);

/* @rk1: data key schedule, @rk2: tweak key schedule, @first: nonzero on the first call of a request */
asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);

/* @rk1: data key schedule, @rk2: ESSIV (IV) key schedule */
asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);

/*
 * CBC-MAC update of @dg over @blocks blocks. Returns the number of blocks
 * left unprocessed; mac_do_update() calls it again until this reaches 0.
 */
asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
			      int blocks, u8 dg[], int enc_before,
			      int enc_after);
115 
116 struct crypto_aes_xts_ctx {
117 	struct crypto_aes_ctx key1;
118 	struct crypto_aes_ctx __aligned(8) key2;
119 };
120 
121 struct crypto_aes_essiv_cbc_ctx {
122 	struct crypto_aes_ctx key1;
123 	struct crypto_aes_ctx __aligned(8) key2;
124 	struct crypto_shash *hash;
125 };
126 
127 struct mac_tfm_ctx {
128 	struct crypto_aes_ctx key;
129 	u8 __aligned(8) consts[];
130 };
131 
132 struct mac_desc_ctx {
133 	unsigned int len;
134 	u8 dg[AES_BLOCK_SIZE];
135 };
136 
137 static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
138 			       unsigned int key_len)
139 {
140 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
141 
142 	return aes_expandkey(ctx, in_key, key_len);
143 }
144 
145 static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
146 				      const u8 *in_key, unsigned int key_len)
147 {
148 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
149 	int ret;
150 
151 	ret = xts_verify_key(tfm, in_key, key_len);
152 	if (ret)
153 		return ret;
154 
155 	ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
156 	if (!ret)
157 		ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
158 				    key_len / 2);
159 	return ret;
160 }
161 
162 static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
163 					    const u8 *in_key,
164 					    unsigned int key_len)
165 {
166 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
167 	u8 digest[SHA256_DIGEST_SIZE];
168 	int ret;
169 
170 	ret = aes_expandkey(&ctx->key1, in_key, key_len);
171 	if (ret)
172 		return ret;
173 
174 	crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);
175 
176 	return aes_expandkey(&ctx->key2, digest, sizeof(digest));
177 }
178 
179 static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
180 {
181 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
182 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
183 	int err, rounds = 6 + ctx->key_length / 4;
184 	struct skcipher_walk walk;
185 	unsigned int blocks;
186 
187 	err = skcipher_walk_virt(&walk, req, false);
188 
189 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
190 		kernel_neon_begin();
191 		aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
192 				ctx->key_enc, rounds, blocks);
193 		kernel_neon_end();
194 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
195 	}
196 	return err;
197 }
198 
199 static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
200 {
201 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
202 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
203 	int err, rounds = 6 + ctx->key_length / 4;
204 	struct skcipher_walk walk;
205 	unsigned int blocks;
206 
207 	err = skcipher_walk_virt(&walk, req, false);
208 
209 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
210 		kernel_neon_begin();
211 		aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
212 				ctx->key_dec, rounds, blocks);
213 		kernel_neon_end();
214 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
215 	}
216 	return err;
217 }
218 
219 static int cbc_encrypt_walk(struct skcipher_request *req,
220 			    struct skcipher_walk *walk)
221 {
222 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
223 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
224 	int err = 0, rounds = 6 + ctx->key_length / 4;
225 	unsigned int blocks;
226 
227 	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
228 		kernel_neon_begin();
229 		aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
230 				ctx->key_enc, rounds, blocks, walk->iv);
231 		kernel_neon_end();
232 		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
233 	}
234 	return err;
235 }
236 
237 static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
238 {
239 	struct skcipher_walk walk;
240 	int err;
241 
242 	err = skcipher_walk_virt(&walk, req, false);
243 	if (err)
244 		return err;
245 	return cbc_encrypt_walk(req, &walk);
246 }
247 
248 static int cbc_decrypt_walk(struct skcipher_request *req,
249 			    struct skcipher_walk *walk)
250 {
251 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
252 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
253 	int err = 0, rounds = 6 + ctx->key_length / 4;
254 	unsigned int blocks;
255 
256 	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
257 		kernel_neon_begin();
258 		aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
259 				ctx->key_dec, rounds, blocks, walk->iv);
260 		kernel_neon_end();
261 		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
262 	}
263 	return err;
264 }
265 
266 static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
267 {
268 	struct skcipher_walk walk;
269 	int err;
270 
271 	err = skcipher_walk_virt(&walk, req, false);
272 	if (err)
273 		return err;
274 	return cbc_decrypt_walk(req, &walk);
275 }
276 
277 static int cts_cbc_encrypt(struct skcipher_request *req)
278 {
279 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
280 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
281 	int err, rounds = 6 + ctx->key_length / 4;
282 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
283 	struct scatterlist *src = req->src, *dst = req->dst;
284 	struct scatterlist sg_src[2], sg_dst[2];
285 	struct skcipher_request subreq;
286 	struct skcipher_walk walk;
287 
288 	skcipher_request_set_tfm(&subreq, tfm);
289 	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
290 				      NULL, NULL);
291 
292 	if (req->cryptlen <= AES_BLOCK_SIZE) {
293 		if (req->cryptlen < AES_BLOCK_SIZE)
294 			return -EINVAL;
295 		cbc_blocks = 1;
296 	}
297 
298 	if (cbc_blocks > 0) {
299 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
300 					   cbc_blocks * AES_BLOCK_SIZE,
301 					   req->iv);
302 
303 		err = skcipher_walk_virt(&walk, &subreq, false) ?:
304 		      cbc_encrypt_walk(&subreq, &walk);
305 		if (err)
306 			return err;
307 
308 		if (req->cryptlen == AES_BLOCK_SIZE)
309 			return 0;
310 
311 		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
312 		if (req->dst != req->src)
313 			dst = scatterwalk_ffwd(sg_dst, req->dst,
314 					       subreq.cryptlen);
315 	}
316 
317 	/* handle ciphertext stealing */
318 	skcipher_request_set_crypt(&subreq, src, dst,
319 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
320 				   req->iv);
321 
322 	err = skcipher_walk_virt(&walk, &subreq, false);
323 	if (err)
324 		return err;
325 
326 	kernel_neon_begin();
327 	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
328 			    ctx->key_enc, rounds, walk.nbytes, walk.iv);
329 	kernel_neon_end();
330 
331 	return skcipher_walk_done(&walk, 0);
332 }
333 
334 static int cts_cbc_decrypt(struct skcipher_request *req)
335 {
336 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
337 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
338 	int err, rounds = 6 + ctx->key_length / 4;
339 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
340 	struct scatterlist *src = req->src, *dst = req->dst;
341 	struct scatterlist sg_src[2], sg_dst[2];
342 	struct skcipher_request subreq;
343 	struct skcipher_walk walk;
344 
345 	skcipher_request_set_tfm(&subreq, tfm);
346 	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
347 				      NULL, NULL);
348 
349 	if (req->cryptlen <= AES_BLOCK_SIZE) {
350 		if (req->cryptlen < AES_BLOCK_SIZE)
351 			return -EINVAL;
352 		cbc_blocks = 1;
353 	}
354 
355 	if (cbc_blocks > 0) {
356 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
357 					   cbc_blocks * AES_BLOCK_SIZE,
358 					   req->iv);
359 
360 		err = skcipher_walk_virt(&walk, &subreq, false) ?:
361 		      cbc_decrypt_walk(&subreq, &walk);
362 		if (err)
363 			return err;
364 
365 		if (req->cryptlen == AES_BLOCK_SIZE)
366 			return 0;
367 
368 		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
369 		if (req->dst != req->src)
370 			dst = scatterwalk_ffwd(sg_dst, req->dst,
371 					       subreq.cryptlen);
372 	}
373 
374 	/* handle ciphertext stealing */
375 	skcipher_request_set_crypt(&subreq, src, dst,
376 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
377 				   req->iv);
378 
379 	err = skcipher_walk_virt(&walk, &subreq, false);
380 	if (err)
381 		return err;
382 
383 	kernel_neon_begin();
384 	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
385 			    ctx->key_dec, rounds, walk.nbytes, walk.iv);
386 	kernel_neon_end();
387 
388 	return skcipher_walk_done(&walk, 0);
389 }
390 
391 static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
392 {
393 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
394 
395 	ctx->hash = crypto_alloc_shash("sha256", 0, 0);
396 
397 	return PTR_ERR_OR_ZERO(ctx->hash);
398 }
399 
400 static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
401 {
402 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
403 
404 	crypto_free_shash(ctx->hash);
405 }
406 
407 static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
408 {
409 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
410 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
411 	int err, rounds = 6 + ctx->key1.key_length / 4;
412 	struct skcipher_walk walk;
413 	unsigned int blocks;
414 
415 	err = skcipher_walk_virt(&walk, req, false);
416 
417 	blocks = walk.nbytes / AES_BLOCK_SIZE;
418 	if (blocks) {
419 		kernel_neon_begin();
420 		aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
421 				      ctx->key1.key_enc, rounds, blocks,
422 				      req->iv, ctx->key2.key_enc);
423 		kernel_neon_end();
424 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
425 	}
426 	return err ?: cbc_encrypt_walk(req, &walk);
427 }
428 
429 static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
430 {
431 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
432 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
433 	int err, rounds = 6 + ctx->key1.key_length / 4;
434 	struct skcipher_walk walk;
435 	unsigned int blocks;
436 
437 	err = skcipher_walk_virt(&walk, req, false);
438 
439 	blocks = walk.nbytes / AES_BLOCK_SIZE;
440 	if (blocks) {
441 		kernel_neon_begin();
442 		aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
443 				      ctx->key1.key_dec, rounds, blocks,
444 				      req->iv, ctx->key2.key_enc);
445 		kernel_neon_end();
446 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
447 	}
448 	return err ?: cbc_decrypt_walk(req, &walk);
449 }
450 
451 static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
452 {
453 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
454 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
455 	int err, rounds = 6 + ctx->key_length / 4;
456 	struct skcipher_walk walk;
457 	unsigned int byte_ctr = 0;
458 
459 	err = skcipher_walk_virt(&walk, req, false);
460 
461 	while (walk.nbytes > 0) {
462 		const u8 *src = walk.src.virt.addr;
463 		unsigned int nbytes = walk.nbytes;
464 		u8 *dst = walk.dst.virt.addr;
465 		u8 buf[AES_BLOCK_SIZE];
466 
467 		/*
468 		 * If given less than 16 bytes, we must copy the partial block
469 		 * into a temporary buffer of 16 bytes to avoid out of bounds
470 		 * reads and writes.  Furthermore, this code is somewhat unusual
471 		 * in that it expects the end of the data to be at the end of
472 		 * the temporary buffer, rather than the start of the data at
473 		 * the start of the temporary buffer.
474 		 */
475 		if (unlikely(nbytes < AES_BLOCK_SIZE))
476 			src = dst = memcpy(buf + sizeof(buf) - nbytes,
477 					   src, nbytes);
478 		else if (nbytes < walk.total)
479 			nbytes &= ~(AES_BLOCK_SIZE - 1);
480 
481 		kernel_neon_begin();
482 		aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
483 						 walk.iv, byte_ctr);
484 		kernel_neon_end();
485 
486 		if (unlikely(nbytes < AES_BLOCK_SIZE))
487 			memcpy(walk.dst.virt.addr,
488 			       buf + sizeof(buf) - nbytes, nbytes);
489 		byte_ctr += nbytes;
490 
491 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
492 	}
493 
494 	return err;
495 }
496 
497 static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
498 {
499 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
500 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
501 	int err, rounds = 6 + ctx->key_length / 4;
502 	struct skcipher_walk walk;
503 
504 	err = skcipher_walk_virt(&walk, req, false);
505 
506 	while (walk.nbytes > 0) {
507 		const u8 *src = walk.src.virt.addr;
508 		unsigned int nbytes = walk.nbytes;
509 		u8 *dst = walk.dst.virt.addr;
510 		u8 buf[AES_BLOCK_SIZE];
511 
512 		/*
513 		 * If given less than 16 bytes, we must copy the partial block
514 		 * into a temporary buffer of 16 bytes to avoid out of bounds
515 		 * reads and writes.  Furthermore, this code is somewhat unusual
516 		 * in that it expects the end of the data to be at the end of
517 		 * the temporary buffer, rather than the start of the data at
518 		 * the start of the temporary buffer.
519 		 */
520 		if (unlikely(nbytes < AES_BLOCK_SIZE))
521 			src = dst = memcpy(buf + sizeof(buf) - nbytes,
522 					   src, nbytes);
523 		else if (nbytes < walk.total)
524 			nbytes &= ~(AES_BLOCK_SIZE - 1);
525 
526 		kernel_neon_begin();
527 		aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
528 				walk.iv);
529 		kernel_neon_end();
530 
531 		if (unlikely(nbytes < AES_BLOCK_SIZE))
532 			memcpy(walk.dst.virt.addr,
533 			       buf + sizeof(buf) - nbytes, nbytes);
534 
535 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
536 	}
537 
538 	return err;
539 }
540 
541 static int __maybe_unused xts_encrypt(struct skcipher_request *req)
542 {
543 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
544 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
545 	int err, first, rounds = 6 + ctx->key1.key_length / 4;
546 	int tail = req->cryptlen % AES_BLOCK_SIZE;
547 	struct scatterlist sg_src[2], sg_dst[2];
548 	struct skcipher_request subreq;
549 	struct scatterlist *src, *dst;
550 	struct skcipher_walk walk;
551 
552 	if (req->cryptlen < AES_BLOCK_SIZE)
553 		return -EINVAL;
554 
555 	err = skcipher_walk_virt(&walk, req, false);
556 
557 	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
558 		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
559 					      AES_BLOCK_SIZE) - 2;
560 
561 		skcipher_walk_abort(&walk);
562 
563 		skcipher_request_set_tfm(&subreq, tfm);
564 		skcipher_request_set_callback(&subreq,
565 					      skcipher_request_flags(req),
566 					      NULL, NULL);
567 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
568 					   xts_blocks * AES_BLOCK_SIZE,
569 					   req->iv);
570 		req = &subreq;
571 		err = skcipher_walk_virt(&walk, req, false);
572 	} else {
573 		tail = 0;
574 	}
575 
576 	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
577 		int nbytes = walk.nbytes;
578 
579 		if (walk.nbytes < walk.total)
580 			nbytes &= ~(AES_BLOCK_SIZE - 1);
581 
582 		kernel_neon_begin();
583 		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
584 				ctx->key1.key_enc, rounds, nbytes,
585 				ctx->key2.key_enc, walk.iv, first);
586 		kernel_neon_end();
587 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
588 	}
589 
590 	if (err || likely(!tail))
591 		return err;
592 
593 	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
594 	if (req->dst != req->src)
595 		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
596 
597 	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
598 				   req->iv);
599 
600 	err = skcipher_walk_virt(&walk, &subreq, false);
601 	if (err)
602 		return err;
603 
604 	kernel_neon_begin();
605 	aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
606 			ctx->key1.key_enc, rounds, walk.nbytes,
607 			ctx->key2.key_enc, walk.iv, first);
608 	kernel_neon_end();
609 
610 	return skcipher_walk_done(&walk, 0);
611 }
612 
613 static int __maybe_unused xts_decrypt(struct skcipher_request *req)
614 {
615 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
616 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
617 	int err, first, rounds = 6 + ctx->key1.key_length / 4;
618 	int tail = req->cryptlen % AES_BLOCK_SIZE;
619 	struct scatterlist sg_src[2], sg_dst[2];
620 	struct skcipher_request subreq;
621 	struct scatterlist *src, *dst;
622 	struct skcipher_walk walk;
623 
624 	if (req->cryptlen < AES_BLOCK_SIZE)
625 		return -EINVAL;
626 
627 	err = skcipher_walk_virt(&walk, req, false);
628 
629 	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
630 		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
631 					      AES_BLOCK_SIZE) - 2;
632 
633 		skcipher_walk_abort(&walk);
634 
635 		skcipher_request_set_tfm(&subreq, tfm);
636 		skcipher_request_set_callback(&subreq,
637 					      skcipher_request_flags(req),
638 					      NULL, NULL);
639 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
640 					   xts_blocks * AES_BLOCK_SIZE,
641 					   req->iv);
642 		req = &subreq;
643 		err = skcipher_walk_virt(&walk, req, false);
644 	} else {
645 		tail = 0;
646 	}
647 
648 	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
649 		int nbytes = walk.nbytes;
650 
651 		if (walk.nbytes < walk.total)
652 			nbytes &= ~(AES_BLOCK_SIZE - 1);
653 
654 		kernel_neon_begin();
655 		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
656 				ctx->key1.key_dec, rounds, nbytes,
657 				ctx->key2.key_enc, walk.iv, first);
658 		kernel_neon_end();
659 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
660 	}
661 
662 	if (err || likely(!tail))
663 		return err;
664 
665 	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
666 	if (req->dst != req->src)
667 		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
668 
669 	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
670 				   req->iv);
671 
672 	err = skcipher_walk_virt(&walk, &subreq, false);
673 	if (err)
674 		return err;
675 
676 
677 	kernel_neon_begin();
678 	aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
679 			ctx->key1.key_dec, rounds, walk.nbytes,
680 			ctx->key2.key_enc, walk.iv, first);
681 	kernel_neon_end();
682 
683 	return skcipher_walk_done(&walk, 0);
684 }
685 
686 static struct skcipher_alg aes_algs[] = { {
687 #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
688 	.base = {
689 		.cra_name		= "ecb(aes)",
690 		.cra_driver_name	= "ecb-aes-" MODE,
691 		.cra_priority		= PRIO,
692 		.cra_blocksize		= AES_BLOCK_SIZE,
693 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
694 		.cra_module		= THIS_MODULE,
695 	},
696 	.min_keysize	= AES_MIN_KEY_SIZE,
697 	.max_keysize	= AES_MAX_KEY_SIZE,
698 	.setkey		= skcipher_aes_setkey,
699 	.encrypt	= ecb_encrypt,
700 	.decrypt	= ecb_decrypt,
701 }, {
702 	.base = {
703 		.cra_name		= "cbc(aes)",
704 		.cra_driver_name	= "cbc-aes-" MODE,
705 		.cra_priority		= PRIO,
706 		.cra_blocksize		= AES_BLOCK_SIZE,
707 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
708 		.cra_module		= THIS_MODULE,
709 	},
710 	.min_keysize	= AES_MIN_KEY_SIZE,
711 	.max_keysize	= AES_MAX_KEY_SIZE,
712 	.ivsize		= AES_BLOCK_SIZE,
713 	.setkey		= skcipher_aes_setkey,
714 	.encrypt	= cbc_encrypt,
715 	.decrypt	= cbc_decrypt,
716 }, {
717 	.base = {
718 		.cra_name		= "ctr(aes)",
719 		.cra_driver_name	= "ctr-aes-" MODE,
720 		.cra_priority		= PRIO,
721 		.cra_blocksize		= 1,
722 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
723 		.cra_module		= THIS_MODULE,
724 	},
725 	.min_keysize	= AES_MIN_KEY_SIZE,
726 	.max_keysize	= AES_MAX_KEY_SIZE,
727 	.ivsize		= AES_BLOCK_SIZE,
728 	.chunksize	= AES_BLOCK_SIZE,
729 	.setkey		= skcipher_aes_setkey,
730 	.encrypt	= ctr_encrypt,
731 	.decrypt	= ctr_encrypt,
732 }, {
733 	.base = {
734 		.cra_name		= "xctr(aes)",
735 		.cra_driver_name	= "xctr-aes-" MODE,
736 		.cra_priority		= PRIO,
737 		.cra_blocksize		= 1,
738 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
739 		.cra_module		= THIS_MODULE,
740 	},
741 	.min_keysize	= AES_MIN_KEY_SIZE,
742 	.max_keysize	= AES_MAX_KEY_SIZE,
743 	.ivsize		= AES_BLOCK_SIZE,
744 	.chunksize	= AES_BLOCK_SIZE,
745 	.setkey		= skcipher_aes_setkey,
746 	.encrypt	= xctr_encrypt,
747 	.decrypt	= xctr_encrypt,
748 }, {
749 	.base = {
750 		.cra_name		= "xts(aes)",
751 		.cra_driver_name	= "xts-aes-" MODE,
752 		.cra_priority		= PRIO,
753 		.cra_blocksize		= AES_BLOCK_SIZE,
754 		.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
755 		.cra_module		= THIS_MODULE,
756 	},
757 	.min_keysize	= 2 * AES_MIN_KEY_SIZE,
758 	.max_keysize	= 2 * AES_MAX_KEY_SIZE,
759 	.ivsize		= AES_BLOCK_SIZE,
760 	.walksize	= 2 * AES_BLOCK_SIZE,
761 	.setkey		= xts_set_key,
762 	.encrypt	= xts_encrypt,
763 	.decrypt	= xts_decrypt,
764 }, {
765 #endif
766 	.base = {
767 		.cra_name		= "cts(cbc(aes))",
768 		.cra_driver_name	= "cts-cbc-aes-" MODE,
769 		.cra_priority		= PRIO,
770 		.cra_blocksize		= AES_BLOCK_SIZE,
771 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
772 		.cra_module		= THIS_MODULE,
773 	},
774 	.min_keysize	= AES_MIN_KEY_SIZE,
775 	.max_keysize	= AES_MAX_KEY_SIZE,
776 	.ivsize		= AES_BLOCK_SIZE,
777 	.walksize	= 2 * AES_BLOCK_SIZE,
778 	.setkey		= skcipher_aes_setkey,
779 	.encrypt	= cts_cbc_encrypt,
780 	.decrypt	= cts_cbc_decrypt,
781 }, {
782 	.base = {
783 		.cra_name		= "essiv(cbc(aes),sha256)",
784 		.cra_driver_name	= "essiv-cbc-aes-sha256-" MODE,
785 		.cra_priority		= PRIO + 1,
786 		.cra_blocksize		= AES_BLOCK_SIZE,
787 		.cra_ctxsize		= sizeof(struct crypto_aes_essiv_cbc_ctx),
788 		.cra_module		= THIS_MODULE,
789 	},
790 	.min_keysize	= AES_MIN_KEY_SIZE,
791 	.max_keysize	= AES_MAX_KEY_SIZE,
792 	.ivsize		= AES_BLOCK_SIZE,
793 	.setkey		= essiv_cbc_set_key,
794 	.encrypt	= essiv_cbc_encrypt,
795 	.decrypt	= essiv_cbc_decrypt,
796 	.init		= essiv_cbc_init_tfm,
797 	.exit		= essiv_cbc_exit_tfm,
798 } };
799 
800 static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
801 			 unsigned int key_len)
802 {
803 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
804 
805 	return aes_expandkey(&ctx->key, in_key, key_len);
806 }
807 
808 static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
809 {
810 	u64 a = be64_to_cpu(x->a);
811 	u64 b = be64_to_cpu(x->b);
812 
813 	y->a = cpu_to_be64((a << 1) | (b >> 63));
814 	y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
815 }
816 
817 static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
818 		       unsigned int key_len)
819 {
820 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
821 	be128 *consts = (be128 *)ctx->consts;
822 	int rounds = 6 + key_len / 4;
823 	int err;
824 
825 	err = cbcmac_setkey(tfm, in_key, key_len);
826 	if (err)
827 		return err;
828 
829 	/* encrypt the zero vector */
830 	kernel_neon_begin();
831 	aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
832 			rounds, 1);
833 	kernel_neon_end();
834 
835 	cmac_gf128_mul_by_x(consts, consts);
836 	cmac_gf128_mul_by_x(consts + 1, consts);
837 
838 	return 0;
839 }
840 
841 static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
842 		       unsigned int key_len)
843 {
844 	static u8 const ks[3][AES_BLOCK_SIZE] = {
845 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
846 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
847 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
848 	};
849 
850 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
851 	int rounds = 6 + key_len / 4;
852 	u8 key[AES_BLOCK_SIZE];
853 	int err;
854 
855 	err = cbcmac_setkey(tfm, in_key, key_len);
856 	if (err)
857 		return err;
858 
859 	kernel_neon_begin();
860 	aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
861 	aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
862 	kernel_neon_end();
863 
864 	return cbcmac_setkey(tfm, key, sizeof(key));
865 }
866 
867 static int mac_init(struct shash_desc *desc)
868 {
869 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
870 
871 	memset(ctx->dg, 0, AES_BLOCK_SIZE);
872 	ctx->len = 0;
873 
874 	return 0;
875 }
876 
877 static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
878 			  u8 dg[], int enc_before, int enc_after)
879 {
880 	int rounds = 6 + ctx->key_length / 4;
881 
882 	if (crypto_simd_usable()) {
883 		int rem;
884 
885 		do {
886 			kernel_neon_begin();
887 			rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
888 					     dg, enc_before, enc_after);
889 			kernel_neon_end();
890 			in += (blocks - rem) * AES_BLOCK_SIZE;
891 			blocks = rem;
892 			enc_before = 0;
893 		} while (blocks);
894 	} else {
895 		if (enc_before)
896 			aes_encrypt(ctx, dg, dg);
897 
898 		while (blocks--) {
899 			crypto_xor(dg, in, AES_BLOCK_SIZE);
900 			in += AES_BLOCK_SIZE;
901 
902 			if (blocks || enc_after)
903 				aes_encrypt(ctx, dg, dg);
904 		}
905 	}
906 }
907 
908 static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
909 {
910 	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
911 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
912 
913 	while (len > 0) {
914 		unsigned int l;
915 
916 		if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
917 		    (ctx->len + len) > AES_BLOCK_SIZE) {
918 
919 			int blocks = len / AES_BLOCK_SIZE;
920 
921 			len %= AES_BLOCK_SIZE;
922 
923 			mac_do_update(&tctx->key, p, blocks, ctx->dg,
924 				      (ctx->len != 0), (len != 0));
925 
926 			p += blocks * AES_BLOCK_SIZE;
927 
928 			if (!len) {
929 				ctx->len = AES_BLOCK_SIZE;
930 				break;
931 			}
932 			ctx->len = 0;
933 		}
934 
935 		l = min(len, AES_BLOCK_SIZE - ctx->len);
936 
937 		if (l <= AES_BLOCK_SIZE) {
938 			crypto_xor(ctx->dg + ctx->len, p, l);
939 			ctx->len += l;
940 			len -= l;
941 			p += l;
942 		}
943 	}
944 
945 	return 0;
946 }
947 
948 static int cbcmac_final(struct shash_desc *desc, u8 *out)
949 {
950 	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
951 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
952 
953 	mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
954 
955 	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
956 
957 	return 0;
958 }
959 
960 static int cmac_final(struct shash_desc *desc, u8 *out)
961 {
962 	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
963 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
964 	u8 *consts = tctx->consts;
965 
966 	if (ctx->len != AES_BLOCK_SIZE) {
967 		ctx->dg[ctx->len] ^= 0x80;
968 		consts += AES_BLOCK_SIZE;
969 	}
970 
971 	mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
972 
973 	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
974 
975 	return 0;
976 }
977 
978 static struct shash_alg mac_algs[] = { {
979 	.base.cra_name		= "cmac(aes)",
980 	.base.cra_driver_name	= "cmac-aes-" MODE,
981 	.base.cra_priority	= PRIO,
982 	.base.cra_blocksize	= AES_BLOCK_SIZE,
983 	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
984 				  2 * AES_BLOCK_SIZE,
985 	.base.cra_module	= THIS_MODULE,
986 
987 	.digestsize		= AES_BLOCK_SIZE,
988 	.init			= mac_init,
989 	.update			= mac_update,
990 	.final			= cmac_final,
991 	.setkey			= cmac_setkey,
992 	.descsize		= sizeof(struct mac_desc_ctx),
993 }, {
994 	.base.cra_name		= "xcbc(aes)",
995 	.base.cra_driver_name	= "xcbc-aes-" MODE,
996 	.base.cra_priority	= PRIO,
997 	.base.cra_blocksize	= AES_BLOCK_SIZE,
998 	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
999 				  2 * AES_BLOCK_SIZE,
1000 	.base.cra_module	= THIS_MODULE,
1001 
1002 	.digestsize		= AES_BLOCK_SIZE,
1003 	.init			= mac_init,
1004 	.update			= mac_update,
1005 	.final			= cmac_final,
1006 	.setkey			= xcbc_setkey,
1007 	.descsize		= sizeof(struct mac_desc_ctx),
1008 }, {
1009 	.base.cra_name		= "cbcmac(aes)",
1010 	.base.cra_driver_name	= "cbcmac-aes-" MODE,
1011 	.base.cra_priority	= PRIO,
1012 	.base.cra_blocksize	= 1,
1013 	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx),
1014 	.base.cra_module	= THIS_MODULE,
1015 
1016 	.digestsize		= AES_BLOCK_SIZE,
1017 	.init			= mac_init,
1018 	.update			= mac_update,
1019 	.final			= cbcmac_final,
1020 	.setkey			= cbcmac_setkey,
1021 	.descsize		= sizeof(struct mac_desc_ctx),
1022 } };
1023 
1024 static void aes_exit(void)
1025 {
1026 	crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1027 	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1028 }
1029 
1030 static int __init aes_init(void)
1031 {
1032 	int err;
1033 
1034 	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1035 	if (err)
1036 		return err;
1037 
1038 	err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1039 	if (err)
1040 		goto unregister_ciphers;
1041 
1042 	return 0;
1043 
1044 unregister_ciphers:
1045 	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1046 	return err;
1047 }
1048 
1049 #ifdef USE_V8_CRYPTO_EXTENSIONS
1050 module_cpu_feature_match(AES, aes_init);
1051 EXPORT_SYMBOL_NS(ce_aes_mac_update, CRYPTO_INTERNAL);
1052 #else
1053 module_init(aes_init);
1054 EXPORT_SYMBOL(neon_aes_ecb_encrypt);
1055 EXPORT_SYMBOL(neon_aes_cbc_encrypt);
1056 EXPORT_SYMBOL(neon_aes_ctr_encrypt);
1057 EXPORT_SYMBOL(neon_aes_xts_encrypt);
1058 EXPORT_SYMBOL(neon_aes_xts_decrypt);
1059 #endif
1060 module_exit(aes_exit);
1061