// SPDX-License-Identifier: GPL-2.0-only
/*
 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
 *
 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/sha2.h>
#include <crypto/utils.h>
#include <crypto/xts.h>
#include <linux/cpufeature.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/string.h>

#include <asm/hwcap.h>
#include <asm/simd.h>

#include "aes-ce-setkey.h"

#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE			"ce"
#define PRIO			300
#define aes_expandkey		ce_aes_expandkey
#define aes_ecb_encrypt		ce_aes_ecb_encrypt
#define aes_ecb_decrypt		ce_aes_ecb_decrypt
#define aes_cbc_encrypt		ce_aes_cbc_encrypt
#define aes_cbc_decrypt		ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		ce_aes_ctr_encrypt
#define aes_xctr_encrypt	ce_aes_xctr_encrypt
#define aes_xts_encrypt		ce_aes_xts_encrypt
#define aes_xts_decrypt		ce_aes_xts_decrypt
#define aes_mac_update		ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions");
#else
#define MODE			"neon"
#define PRIO			200
#define aes_ecb_encrypt		neon_aes_ecb_encrypt
#define aes_ecb_decrypt		neon_aes_ecb_decrypt
#define aes_cbc_encrypt		neon_aes_cbc_encrypt
#define aes_cbc_decrypt		neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		neon_aes_ctr_encrypt
#define aes_xctr_encrypt	neon_aes_xctr_encrypt
#define aes_xts_encrypt		neon_aes_xts_encrypt
#define aes_xts_decrypt		neon_aes_xts_decrypt
#define aes_mac_update		neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON");
#endif
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
MODULE_ALIAS_CRYPTO("xctr(aes)");
#endif
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

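/*
 * The algorithms in this file are not called directly: they are registered
 * with the crypto API under the names listed in the MODULE_ALIAS_CRYPTO()
 * entries above and are selected by priority (PRIO) when a caller asks for
 * e.g. "xts(aes)".  As a rough, illustrative sketch only (error handling and
 * the async completion path omitted; key, iv, src_sg, dst_sg and len are
 * placeholders supplied by the caller), an in-kernel user would do:
 *
 *	struct crypto_skcipher *tfm = crypto_alloc_skcipher("xts(aes)", 0, 0);
 *	struct skcipher_request *req = skcipher_request_alloc(tfm, GFP_KERNEL);
 *
 *	crypto_skcipher_setkey(tfm, key, 64);	// two 256-bit key halves
 *	skcipher_request_set_callback(req, 0, NULL, NULL);
 *	skcipher_request_set_crypt(req, src_sg, dst_sg, len, iv);
 *	crypto_skcipher_encrypt(req);
 *
 *	skcipher_request_free(req);
 *	crypto_free_skcipher(tfm);
 */
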
/* defined in aes-modes.S */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);

asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 ctr[]);

asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				 int rounds, int bytes, u8 ctr[], int byte_ctr);

asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);

asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);

asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
			      int blocks, u8 dg[], int enc_before,
			      int enc_after);

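/*
 * All of the asm helpers above take the expanded round keys (->key_enc or
 * ->key_dec of struct crypto_aes_ctx) plus the number of rounds, which the
 * C glue derives as 6 + key_length / 4: 10 rounds for AES-128 (16-byte
 * keys), 12 for AES-192 and 14 for AES-256.  "blocks" counts whole 16-byte
 * AES blocks, whereas the CTS/CTR/XCTR/XTS helpers take a byte count so
 * they can deal with a trailing partial block themselves.
 */
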
struct crypto_aes_xts_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
};

struct crypto_aes_essiv_cbc_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
};

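/*
 * consts[] is a flexible array holding two extra AES blocks of per-key
 * constants: the CMAC subkeys K1/K2, or the XCBC K2/K3 values.  The MAC
 * algorithms that need them reserve the space via cra_ctxsize below
 * (sizeof(struct mac_tfm_ctx) + 2 * AES_BLOCK_SIZE).
 */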
struct mac_tfm_ctx {
	struct crypto_aes_ctx key;
	u8 __aligned(8) consts[];
};

struct mac_desc_ctx {
	u8 dg[AES_BLOCK_SIZE];
};

static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			       unsigned int key_len)
{
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);

	return aes_expandkey(ctx, in_key, key_len);
}

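/*
 * XTS keys are twice the AES key size: the first half keys the data
 * encryption (key1), the second half keys the tweak encryption (key2).
 * xts_verify_key() enforces the basic keying rules before the two halves
 * are expanded separately.
 */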
static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
				      const u8 *in_key, unsigned int key_len)
{
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int ret;

	ret = xts_verify_key(tfm, in_key, key_len);
	if (ret)
		return ret;

	ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
	if (!ret)
		ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
				    key_len / 2);
	return ret;
}

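/*
 * ESSIV ("essiv(cbc(aes),sha256)"): key2 is derived by hashing the user key
 * with SHA-256 and expanding the digest as an AES key.  The asm code uses
 * key2 to encrypt the provided IV (typically a sector number) into the
 * actual CBC IV, which keeps predictable IV patterns out of the ciphertext.
 */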
static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
					    const u8 *in_key,
					    unsigned int key_len)
{
	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	u8 digest[SHA256_DIGEST_SIZE];
	int ret;

	ret = aes_expandkey(&ctx->key1, in_key, key_len);
	if (ret)
		return ret;

	sha256(in_key, key_len, digest);

	return aes_expandkey(&ctx->key2, digest, sizeof(digest));
}

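/*
 * All the skcipher routines below follow the same walk pattern: map the
 * current chunk with skcipher_walk_virt(), process as many whole blocks as
 * possible with the SIMD unit enabled (scoped_ksimd()), then report the
 * unprocessed remainder back via skcipher_walk_done(), which advances the
 * walk to the next chunk.  Schematically (a sketch, not literal code):
 *
 *	err = skcipher_walk_virt(&walk, req, false);
 *	while (walk.nbytes) {
 *		n = <bytes handled in this pass>;
 *		scoped_ksimd()
 *			<call into aes-modes.S for n bytes>;
 *		err = skcipher_walk_done(&walk, walk.nbytes - n);
 *	}
 */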
static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	struct skcipher_walk walk;
	unsigned int blocks;

	err = skcipher_walk_virt(&walk, req, false);

	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
		scoped_ksimd()
			aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
					ctx->key_enc, rounds, blocks);
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	struct skcipher_walk walk;
	unsigned int blocks;

	err = skcipher_walk_virt(&walk, req, false);

	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
		scoped_ksimd()
			aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
					ctx->key_dec, rounds, blocks);
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_encrypt_walk(struct skcipher_request *req,
			    struct skcipher_walk *walk)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err = 0, rounds = 6 + ctx->key_length / 4;
	unsigned int blocks;

	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
		scoped_ksimd()
			aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
					ctx->key_enc, rounds, blocks, walk->iv);
		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
{
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;
	return cbc_encrypt_walk(req, &walk);
}

static int cbc_decrypt_walk(struct skcipher_request *req,
			    struct skcipher_walk *walk)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err = 0, rounds = 6 + ctx->key_length / 4;
	unsigned int blocks;

	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
		scoped_ksimd()
			aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
					ctx->key_dec, rounds, blocks, walk->iv);
		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
{
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;
	return cbc_decrypt_walk(req, &walk);
}

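/*
 * CBC with ciphertext stealing: everything up to the last two blocks is
 * plain CBC, and aes_cbc_cts_encrypt()/aes_cbc_cts_decrypt() handle the
 * final full block plus the partial one.  Requests shorter than one block
 * are rejected, and a request of exactly one block degenerates to plain
 * CBC.  For example, cryptlen = 100 gives
 * cbc_blocks = DIV_ROUND_UP(100, 16) - 2 = 5, i.e. 80 bytes of ordinary
 * CBC followed by 20 bytes handled by the stealing code.
 */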
static int cts_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_encrypt_walk(&subreq, &walk);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	scoped_ksimd()
		aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				    ctx->key_enc, rounds, walk.nbytes, walk.iv);

	return skcipher_walk_done(&walk, 0);
}

static int cts_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false) ?:
		      cbc_decrypt_walk(&subreq, &walk);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	scoped_ksimd()
		aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				    ctx->key_dec, rounds, walk.nbytes, walk.iv);

	return skcipher_walk_done(&walk, 0);
}

static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key1.key_length / 4;
	struct skcipher_walk walk;
	unsigned int blocks;

	err = skcipher_walk_virt(&walk, req, false);

	blocks = walk.nbytes / AES_BLOCK_SIZE;
	if (blocks) {
		scoped_ksimd()
			aes_essiv_cbc_encrypt(walk.dst.virt.addr,
					      walk.src.virt.addr,
					      ctx->key1.key_enc, rounds, blocks,
					      req->iv, ctx->key2.key_enc);
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err ?: cbc_encrypt_walk(req, &walk);
}

static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key1.key_length / 4;
	struct skcipher_walk walk;
	unsigned int blocks;

	err = skcipher_walk_virt(&walk, req, false);

	blocks = walk.nbytes / AES_BLOCK_SIZE;
	if (blocks) {
		scoped_ksimd()
			aes_essiv_cbc_decrypt(walk.dst.virt.addr,
					      walk.src.virt.addr,
					      ctx->key1.key_dec, rounds, blocks,
					      req->iv, ctx->key2.key_enc);
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err ?: cbc_decrypt_walk(req, &walk);
}

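/*
 * XCTR is the counter mode used by HCTR2: instead of incrementing the IV as
 * a big-endian integer (as CTR does), each block's keystream comes from the
 * IV XORed with a little-endian block counter.  byte_ctr carries the running
 * byte offset into the message so the asm code can derive the block counter
 * when a request is processed over several walk steps.
 */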
static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	struct skcipher_walk walk;
	unsigned int byte_ctr = 0;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		unsigned int nbytes = walk.nbytes;
		u8 *dst = walk.dst.virt.addr;
		u8 buf[AES_BLOCK_SIZE];

		/*
		 * If given less than 16 bytes, we must copy the partial block
		 * into a temporary buffer of 16 bytes to avoid out of bounds
		 * reads and writes.  Furthermore, this code is somewhat unusual
		 * in that it expects the end of the data to be at the end of
		 * the temporary buffer, rather than the start of the data at
		 * the start of the temporary buffer.
		 */
		if (unlikely(nbytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - nbytes,
					   src, nbytes);
		else if (nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		scoped_ksimd()
			aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
					 walk.iv, byte_ctr);

		if (unlikely(nbytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - nbytes, nbytes);
		byte_ctr += nbytes;

		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	return err;
}

static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, rounds = 6 + ctx->key_length / 4;
	struct skcipher_walk walk;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		unsigned int nbytes = walk.nbytes;
		u8 *dst = walk.dst.virt.addr;
		u8 buf[AES_BLOCK_SIZE];

		/*
		 * If given less than 16 bytes, we must copy the partial block
		 * into a temporary buffer of 16 bytes to avoid out of bounds
		 * reads and writes.  Furthermore, this code is somewhat unusual
		 * in that it expects the end of the data to be at the end of
		 * the temporary buffer, rather than the start of the data at
		 * the start of the temporary buffer.
		 */
		if (unlikely(nbytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - nbytes,
					   src, nbytes);
		else if (nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		scoped_ksimd()
			aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
					walk.iv);

		if (unlikely(nbytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - nbytes, nbytes);

		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	return err;
}

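/*
 * XTS with ciphertext stealing: when cryptlen is not a multiple of the block
 * size, the request is split in two.  The main pass covers all but the last
 * two blocks; the final pass feeds AES_BLOCK_SIZE + tail bytes to the asm
 * code, which performs the block swap.  The "first" flag tells
 * aes_xts_encrypt()/aes_xts_decrypt() to derive the initial tweak from the
 * IV on the first call for a given request.  E.g. cryptlen = 100 gives
 * tail = 4 and xts_blocks = 5, so 80 bytes are processed first and the
 * remaining 20 bytes go through the stealing path.
 */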
static int __maybe_unused xts_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = 6 + ctx->key1.key_length / 4;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		tail = 0;
	}

	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		scoped_ksimd()
			aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
					ctx->key1.key_enc, rounds, nbytes,
					ctx->key2.key_enc, walk.iv, first);
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	scoped_ksimd()
		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_enc, rounds, walk.nbytes,
				ctx->key2.key_enc, walk.iv, first);

	return skcipher_walk_done(&walk, 0);
}

static int __maybe_unused xts_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = 6 + ctx->key1.key_length / 4;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		tail = 0;
	}

	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		scoped_ksimd()
			aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
					ctx->key1.key_dec, rounds, nbytes,
					ctx->key2.key_enc, walk.iv, first);
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	scoped_ksimd()
		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_dec, rounds, walk.nbytes,
				ctx->key2.key_enc, walk.iv, first);

	return skcipher_walk_done(&walk, 0);
}

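/*
 * When the NEON build coexists with the bit-sliced NEON driver
 * (CONFIG_CRYPTO_AES_ARM64_BS), the plain ecb/cbc/ctr/xctr/xts entries are
 * left out here and only the modes that driver does not provide (cts(cbc),
 * essiv(cbc)) are registered; the Crypto Extensions build always registers
 * the full set.
 */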
static struct skcipher_alg aes_algs[] = { {
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
	.base = {
		.cra_name		= "ecb(aes)",
		.cra_driver_name	= "ecb-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ecb_encrypt,
	.decrypt	= ecb_decrypt,
}, {
	.base = {
		.cra_name		= "cbc(aes)",
		.cra_driver_name	= "cbc-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cbc_encrypt,
	.decrypt	= cbc_decrypt,
}, {
	.base = {
		.cra_name		= "ctr(aes)",
		.cra_driver_name	= "ctr-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= ctr_encrypt,
	.decrypt	= ctr_encrypt,
}, {
	.base = {
		.cra_name		= "xctr(aes)",
		.cra_driver_name	= "xctr-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= xctr_encrypt,
	.decrypt	= xctr_encrypt,
}, {
	.base = {
		.cra_name		= "xts(aes)",
		.cra_driver_name	= "xts-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= 2 * AES_MIN_KEY_SIZE,
	.max_keysize	= 2 * AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.walksize	= 2 * AES_BLOCK_SIZE,
	.setkey		= xts_set_key,
	.encrypt	= xts_encrypt,
	.decrypt	= xts_decrypt,
}, {
#endif
	.base = {
		.cra_name		= "cts(cbc(aes))",
		.cra_driver_name	= "cts-cbc-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.walksize	= 2 * AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cts_cbc_encrypt,
	.decrypt	= cts_cbc_decrypt,
}, {
	.base = {
		.cra_name		= "essiv(cbc(aes),sha256)",
		.cra_driver_name	= "essiv-cbc-aes-sha256-" MODE,
		.cra_priority		= PRIO + 1,
		.cra_blocksize		= AES_BLOCK_SIZE,
		.cra_ctxsize		= sizeof(struct crypto_aes_essiv_cbc_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.setkey		= essiv_cbc_set_key,
	.encrypt	= essiv_cbc_encrypt,
	.decrypt	= essiv_cbc_decrypt,
} };

static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
			 unsigned int key_len)
{
	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);

	return aes_expandkey(&ctx->key, in_key, key_len);
}

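/*
 * Multiplication by x in GF(2^128) with the field polynomial
 * x^128 + x^7 + x^2 + x + 1: shift the 128-bit value left by one bit and,
 * if a bit was shifted out, reduce by XORing in 0x87.  CMAC derives its
 * subkeys this way: K1 = L * x and K2 = L * x^2, where L = AES_K(0^128) is
 * computed in cmac_setkey() below.
 */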
static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
{
	u64 a = be64_to_cpu(x->a);
	u64 b = be64_to_cpu(x->b);

	y->a = cpu_to_be64((a << 1) | (b >> 63));
	y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
}

static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
		       unsigned int key_len)
{
	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
	be128 *consts = (be128 *)ctx->consts;
	int rounds = 6 + key_len / 4;
	int err;

	err = cbcmac_setkey(tfm, in_key, key_len);
	if (err)
		return err;

	/* encrypt the zero vector */
	scoped_ksimd()
		aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){},
				ctx->key.key_enc, rounds, 1);

	cmac_gf128_mul_by_x(consts, consts);
	cmac_gf128_mul_by_x(consts + 1, consts);

	return 0;
}

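/*
 * XCBC-MAC (RFC 3566) derives three values from the user key: K1 =
 * AES_K(0x01^16) becomes the actual CBC-MAC key (installed via the second
 * cbcmac_setkey() call), while K2 = AES_K(0x02^16) and K3 = AES_K(0x03^16)
 * land in ctx->consts[] and are XORed into the final block, much like the
 * CMAC subkeys.
 */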
static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
		       unsigned int key_len)
{
	static u8 const ks[3][AES_BLOCK_SIZE] = {
		{ [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
		{ [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
		{ [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
	};

	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
	int rounds = 6 + key_len / 4;
	u8 key[AES_BLOCK_SIZE];
	int err;

	err = cbcmac_setkey(tfm, in_key, key_len);
	if (err)
		return err;

	scoped_ksimd() {
		aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
		aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
	}

	return cbcmac_setkey(tfm, key, sizeof(key));
}

static int mac_init(struct shash_desc *desc)
{
	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

	memset(ctx->dg, 0, AES_BLOCK_SIZE);
	return 0;
}

static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
			  u8 dg[], int enc_before)
{
	int rounds = 6 + ctx->key_length / 4;
	int rem;

	do {
		scoped_ksimd()
			rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
					     dg, enc_before, !enc_before);
		in += (blocks - rem) * AES_BLOCK_SIZE;
		blocks = rem;
	} while (blocks);
}

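/*
 * These hashes are registered with CRYPTO_AHASH_ALG_BLOCK_ONLY, so the
 * shash core takes care of buffering partial blocks: update() only needs to
 * consume whole blocks and returns the number of trailing bytes it did not
 * process, and the final (possibly partial) block is delivered to finup().
 */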
static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
{
	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
	int blocks = len / AES_BLOCK_SIZE;

	len %= AES_BLOCK_SIZE;
	mac_do_update(&tctx->key, p, blocks, ctx->dg, 0);
	return len;
}

static int cbcmac_finup(struct shash_desc *desc, const u8 *src,
			unsigned int len, u8 *out)
{
	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

	if (len) {
		crypto_xor(ctx->dg, src, len);
		mac_do_update(&tctx->key, NULL, 0, ctx->dg, 1);
	}
	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
	return 0;
}

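/*
 * CMAC finalisation: the last block is XORed with K1 if it is a full block,
 * or padded with a single 0x80 byte and XORed with K2 otherwise (K1 and K2
 * live back to back in tctx->consts[]), before one more CBC-MAC step
 * produces the tag.
 */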
static int cmac_finup(struct shash_desc *desc, const u8 *src, unsigned int len,
		      u8 *out)
{
	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
	u8 *consts = tctx->consts;

	crypto_xor(ctx->dg, src, len);
	if (len != AES_BLOCK_SIZE) {
		ctx->dg[len] ^= 0x80;
		consts += AES_BLOCK_SIZE;
	}
	mac_do_update(&tctx->key, consts, 1, ctx->dg, 0);
	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
	return 0;
}

static struct shash_alg mac_algs[] = { {
	.base.cra_name		= "cmac(aes)",
	.base.cra_driver_name	= "cmac-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_flags		= CRYPTO_AHASH_ALG_BLOCK_ONLY |
				  CRYPTO_AHASH_ALG_FINAL_NONZERO,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
				  2 * AES_BLOCK_SIZE,
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.finup			= cmac_finup,
	.setkey			= cmac_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
}, {
	.base.cra_name		= "xcbc(aes)",
	.base.cra_driver_name	= "xcbc-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_flags		= CRYPTO_AHASH_ALG_BLOCK_ONLY |
				  CRYPTO_AHASH_ALG_FINAL_NONZERO,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
				  2 * AES_BLOCK_SIZE,
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.finup			= cmac_finup,
	.setkey			= xcbc_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
}, {
	.base.cra_name		= "cbcmac(aes)",
	.base.cra_driver_name	= "cbcmac-aes-" MODE,
	.base.cra_priority	= PRIO,
	.base.cra_flags		= CRYPTO_AHASH_ALG_BLOCK_ONLY,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx),
	.base.cra_module	= THIS_MODULE,

	.digestsize		= AES_BLOCK_SIZE,
	.init			= mac_init,
	.update			= mac_update,
	.finup			= cbcmac_finup,
	.setkey			= cbcmac_setkey,
	.descsize		= sizeof(struct mac_desc_ctx),
} };

static void aes_exit(void)
{
	crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	int err;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
	if (err)
		goto unregister_ciphers;

	return 0;

unregister_ciphers:
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	return err;
}

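/*
 * The Crypto Extensions build only loads on CPUs advertising the AES hwcap
 * (module_cpu_feature_match), whereas the NEON build registers
 * unconditionally at a lower priority.  The neon_aes_* helpers are exported
 * because the bit-sliced NEON driver reuses them as a fallback for the
 * cases its 8-way code does not handle on its own.
 */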
#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
EXPORT_SYMBOL_NS(ce_aes_mac_update, "CRYPTO_INTERNAL");
#else
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_ctr_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);
989