xref: /linux/arch/x86/crypto/aesni-intel_glue.c (revision 566ab427f827b0256d3e8ce0235d088e6a9c28bd)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Support for AES-NI and VAES instructions.  This file contains glue code.
4  * The real AES implementations are in aesni-intel_asm.S and other .S files.
5  *
6  * Copyright (C) 2008, Intel Corp.
7  *    Author: Huang Ying <ying.huang@intel.com>
8  *
9  * Added RFC4106 AES-GCM support for 128-bit keys under the AEAD
10  * interface for 64-bit kernels.
11  *    Authors: Adrian Hoban <adrian.hoban@intel.com>
12  *             Gabriele Paoloni <gabriele.paoloni@intel.com>
13  *             Tadeusz Struk (tadeusz.struk@intel.com)
14  *             Aidan O'Mahony (aidan.o.mahony@intel.com)
15  *    Copyright (c) 2010, Intel Corporation.
16  *
17  * Copyright 2024 Google LLC
18  */
19 
20 #include <linux/hardirq.h>
21 #include <linux/types.h>
22 #include <linux/module.h>
23 #include <linux/err.h>
24 #include <crypto/algapi.h>
25 #include <crypto/aes.h>
26 #include <crypto/ctr.h>
27 #include <crypto/b128ops.h>
28 #include <crypto/gcm.h>
29 #include <crypto/xts.h>
30 #include <asm/cpu_device_id.h>
31 #include <asm/simd.h>
32 #include <crypto/scatterwalk.h>
33 #include <crypto/internal/aead.h>
34 #include <crypto/internal/simd.h>
35 #include <crypto/internal/skcipher.h>
36 #include <linux/jump_label.h>
37 #include <linux/workqueue.h>
38 #include <linux/spinlock.h>
39 #include <linux/static_call.h>
40 
41 
/* Alignment, in bytes, that this file enforces on expanded AES key schedules. */
#define AESNI_ALIGN	16
/* Attribute forcing AESNI_ALIGN alignment on a struct member. */
#define AESNI_ALIGN_ATTR __attribute__ ((__aligned__(AESNI_ALIGN)))
/*
 * AND-ing a byte count with this rounds it down to a whole number of AES
 * blocks; AND-ing with ~AES_BLOCK_MASK keeps only the sub-block remainder.
 */
#define AES_BLOCK_MASK	(~(AES_BLOCK_SIZE - 1))
/*
 * Slack bytes reserved in each tfm context so that aes_align_addr() can round
 * a CRYPTO_MINALIGN-aligned context pointer up to AESNI_ALIGN without running
 * past the end of the allocation.
 */
#define AESNI_ALIGN_EXTRA ((AESNI_ALIGN - 1) & ~(CRYPTO_MINALIGN - 1))
/* cra_ctxsize values: the key struct plus the alignment slack above. */
#define CRYPTO_AES_CTX_SIZE (sizeof(struct crypto_aes_ctx) + AESNI_ALIGN_EXTRA)
#define XTS_AES_CTX_SIZE (sizeof(struct aesni_xts_ctx) + AESNI_ALIGN_EXTRA)
48 
/*
 * Context for the XTS algorithms: two independent AES key schedules, each
 * AESNI_ALIGN-aligned.  xts_setkey_aesni() fills crypt_ctx from the first half
 * of the XTS key and tweak_ctx from the second half.
 */
struct aesni_xts_ctx {
	/* Key used to encrypt the IV into the initial tweak */
	struct crypto_aes_ctx tweak_ctx AESNI_ALIGN_ATTR;
	/* Key used to encrypt/decrypt the data blocks */
	struct crypto_aes_ctx crypt_ctx AESNI_ALIGN_ATTR;
};
53 
54 static inline void *aes_align_addr(void *addr)
55 {
56 	if (crypto_tfm_ctx_alignment() >= AESNI_ALIGN)
57 		return addr;
58 	return PTR_ALIGN(addr, AESNI_ALIGN);
59 }
60 
/*
 * AES-NI routines implemented in assembly (aesni-intel_asm.S or another .S
 * file).  Every ctx passed in by this file comes from aes_ctx() or
 * aes_xts_ctx(), so it is AESNI_ALIGN-aligned.
 */

/* Expand in_key (key_len bytes) into the round keys stored in ctx. */
asmlinkage void aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
			      unsigned int key_len);
/* Encrypt / decrypt the single AES block at in into out. */
asmlinkage void aesni_enc(const void *ctx, u8 *out, const u8 *in);
asmlinkage void aesni_dec(const void *ctx, u8 *out, const u8 *in);
/*
 * Bulk ECB/CBC routines.  len is a byte count; the callers in this file always
 * pass a multiple of AES_BLOCK_SIZE.  NOTE(review): the CBC callers pass the
 * same walk.iv on every step, which suggests iv is updated in place.
 */
asmlinkage void aesni_ecb_enc(struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len);
asmlinkage void aesni_ecb_dec(struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len);
asmlinkage void aesni_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len, u8 *iv);
asmlinkage void aesni_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len, u8 *iv);
/*
 * CBC ciphertext-stealing routines.  The callers in this file pass the final
 * tail of the message: exactly one block, or 17..32 bytes.
 */
asmlinkage void aesni_cts_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out,
				  const u8 *in, unsigned int len, u8 *iv);
asmlinkage void aesni_cts_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
				  const u8 *in, unsigned int len, u8 *iv);

/*
 * XTS bulk routines.  iv holds the current (already encrypted) tweak.  They
 * are reached only through the aesni_xts_encrypt()/aesni_xts_decrypt()
 * adapters, which swap the argument order.
 */
asmlinkage void aesni_xts_enc(const struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len, u8 *iv);

asmlinkage void aesni_xts_dec(const struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len, u8 *iv);
83 
#ifdef CONFIG_X86_64

/*
 * Bulk CTR routine.  It is the initial target of the aesni_ctr_enc_tfm static
 * call that ctr_crypt() goes through.  NOTE(review): the target can presumably
 * be switched to aesni_ctr_enc_avx_tfm() elsewhere; the update site is outside
 * this view.
 */
asmlinkage void aesni_ctr_enc(struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len, u8 *iv);
DEFINE_STATIC_CALL(aesni_ctr_enc_tfm, aesni_ctr_enc);

/*
 * AVX "by8" CTR routines, one per key size.  Their argument order differs from
 * aesni_ctr_enc(): (in, iv, keys, out, num_bytes).  aesni_ctr_enc_avx_tfm()
 * adapts the signature and picks the variant by key length.
 */
asmlinkage void aes_ctr_enc_128_avx_by8(const u8 *in, u8 *iv,
		void *keys, u8 *out, unsigned int num_bytes);
asmlinkage void aes_ctr_enc_192_avx_by8(const u8 *in, u8 *iv,
		void *keys, u8 *out, unsigned int num_bytes);
asmlinkage void aes_ctr_enc_256_avx_by8(const u8 *in, u8 *iv,
		void *keys, u8 *out, unsigned int num_bytes);


/*
 * AVX "by8" XCTR routines, one per key size.  byte_ctr is the number of
 * message bytes already processed before this call, as tracked by
 * xctr_crypt().
 */
asmlinkage void aes_xctr_enc_128_avx_by8(const u8 *in, const u8 *iv,
	const void *keys, u8 *out, unsigned int num_bytes,
	unsigned int byte_ctr);

asmlinkage void aes_xctr_enc_192_avx_by8(const u8 *in, const u8 *iv,
	const void *keys, u8 *out, unsigned int num_bytes,
	unsigned int byte_ctr);

asmlinkage void aes_xctr_enc_256_avx_by8(const u8 *in, const u8 *iv,
	const void *keys, u8 *out, unsigned int num_bytes,
	unsigned int byte_ctr);
#endif
110 
/* View a raw (possibly under-aligned) tfm context as its aligned AES key. */
static inline struct crypto_aes_ctx *aes_ctx(void *raw_ctx)
{
	struct crypto_aes_ctx *aligned = aes_align_addr(raw_ctx);

	return aligned;
}
115 
/* Fetch the aligned XTS key pair stored in an skcipher tfm's context. */
static inline struct aesni_xts_ctx *aes_xts_ctx(struct crypto_skcipher *tfm)
{
	void *raw = crypto_skcipher_ctx(tfm);

	return aes_align_addr(raw);
}
120 
121 static int aes_set_key_common(struct crypto_aes_ctx *ctx,
122 			      const u8 *in_key, unsigned int key_len)
123 {
124 	int err;
125 
126 	if (!crypto_simd_usable())
127 		return aes_expandkey(ctx, in_key, key_len);
128 
129 	err = aes_check_keylen(key_len);
130 	if (err)
131 		return err;
132 
133 	kernel_fpu_begin();
134 	aesni_set_key(ctx, in_key, key_len);
135 	kernel_fpu_end();
136 	return 0;
137 }
138 
139 static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
140 		       unsigned int key_len)
141 {
142 	return aes_set_key_common(aes_ctx(crypto_tfm_ctx(tfm)), in_key,
143 				  key_len);
144 }
145 
146 static void aesni_encrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
147 {
148 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_tfm_ctx(tfm));
149 
150 	if (!crypto_simd_usable()) {
151 		aes_encrypt(ctx, dst, src);
152 	} else {
153 		kernel_fpu_begin();
154 		aesni_enc(ctx, dst, src);
155 		kernel_fpu_end();
156 	}
157 }
158 
159 static void aesni_decrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
160 {
161 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_tfm_ctx(tfm));
162 
163 	if (!crypto_simd_usable()) {
164 		aes_decrypt(ctx, dst, src);
165 	} else {
166 		kernel_fpu_begin();
167 		aesni_dec(ctx, dst, src);
168 		kernel_fpu_end();
169 	}
170 }
171 
172 static int aesni_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
173 			         unsigned int len)
174 {
175 	return aes_set_key_common(aes_ctx(crypto_skcipher_ctx(tfm)), key, len);
176 }
177 
178 static int ecb_encrypt(struct skcipher_request *req)
179 {
180 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
181 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
182 	struct skcipher_walk walk;
183 	unsigned int nbytes;
184 	int err;
185 
186 	err = skcipher_walk_virt(&walk, req, false);
187 
188 	while ((nbytes = walk.nbytes)) {
189 		kernel_fpu_begin();
190 		aesni_ecb_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
191 			      nbytes & AES_BLOCK_MASK);
192 		kernel_fpu_end();
193 		nbytes &= AES_BLOCK_SIZE - 1;
194 		err = skcipher_walk_done(&walk, nbytes);
195 	}
196 
197 	return err;
198 }
199 
200 static int ecb_decrypt(struct skcipher_request *req)
201 {
202 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
203 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
204 	struct skcipher_walk walk;
205 	unsigned int nbytes;
206 	int err;
207 
208 	err = skcipher_walk_virt(&walk, req, false);
209 
210 	while ((nbytes = walk.nbytes)) {
211 		kernel_fpu_begin();
212 		aesni_ecb_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
213 			      nbytes & AES_BLOCK_MASK);
214 		kernel_fpu_end();
215 		nbytes &= AES_BLOCK_SIZE - 1;
216 		err = skcipher_walk_done(&walk, nbytes);
217 	}
218 
219 	return err;
220 }
221 
222 static int cbc_encrypt(struct skcipher_request *req)
223 {
224 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
225 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
226 	struct skcipher_walk walk;
227 	unsigned int nbytes;
228 	int err;
229 
230 	err = skcipher_walk_virt(&walk, req, false);
231 
232 	while ((nbytes = walk.nbytes)) {
233 		kernel_fpu_begin();
234 		aesni_cbc_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
235 			      nbytes & AES_BLOCK_MASK, walk.iv);
236 		kernel_fpu_end();
237 		nbytes &= AES_BLOCK_SIZE - 1;
238 		err = skcipher_walk_done(&walk, nbytes);
239 	}
240 
241 	return err;
242 }
243 
244 static int cbc_decrypt(struct skcipher_request *req)
245 {
246 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
247 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
248 	struct skcipher_walk walk;
249 	unsigned int nbytes;
250 	int err;
251 
252 	err = skcipher_walk_virt(&walk, req, false);
253 
254 	while ((nbytes = walk.nbytes)) {
255 		kernel_fpu_begin();
256 		aesni_cbc_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
257 			      nbytes & AES_BLOCK_MASK, walk.iv);
258 		kernel_fpu_end();
259 		nbytes &= AES_BLOCK_SIZE - 1;
260 		err = skcipher_walk_done(&walk, nbytes);
261 	}
262 
263 	return err;
264 }
265 
/*
 * CBC encryption with ciphertext stealing ("cts(cbc(aes))").
 *
 * A one-block message is plain CBC.  Otherwise, all but the final 17..32 bytes
 * are encrypted as plain CBC through cbc_encrypt() on an on-stack subrequest.
 * The remaining tail goes to a single aesni_cts_cbc_enc() call.  The
 * subrequest shares req->iv, so the CTS step continues the chain from where
 * the CBC step left it.
 *
 * Returns -EINVAL if req->cryptlen < AES_BLOCK_SIZE, otherwise 0 or an error
 * from the walk.
 */
static int cts_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
	/*
	 * Leading blocks handled as plain CBC: everything except the last two
	 * (possibly partial) blocks.  Zero or negative for short messages;
	 * the <= one-block cases are fixed up below.
	 */
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	/* Scratch entries for scatterwalk_ffwd() */
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;
	int err;

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		/* Exactly one block: no stealing, just CBC. */
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = cbc_encrypt(&subreq);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		/* Point src/dst at the bytes that follow the CBC prefix. */
		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	/*
	 * Only the first walk step is processed.  NOTE(review): this presumably
	 * relies on .walksize = 2 * AES_BLOCK_SIZE (see aesni_skciphers) so the
	 * whole 17..32 byte tail arrives contiguously in that step — confirm.
	 */
	kernel_fpu_begin();
	aesni_cts_cbc_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
			  walk.nbytes, walk.iv);
	kernel_fpu_end();

	return skcipher_walk_done(&walk, 0);
}
321 
/*
 * CBC decryption with ciphertext stealing ("cts(cbc(aes))").
 *
 * This mirrors cts_cbc_encrypt().  A one-block message is plain CBC.
 * Otherwise, all but the final 17..32 bytes are decrypted as plain CBC through
 * cbc_decrypt() on an on-stack subrequest sharing req->iv.  The remaining tail
 * goes to a single aesni_cts_cbc_dec() call.
 *
 * Returns -EINVAL if req->cryptlen < AES_BLOCK_SIZE, otherwise 0 or an error
 * from the walk.
 */
static int cts_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
	/*
	 * Leading blocks handled as plain CBC: everything except the last two
	 * (possibly partial) blocks.  Zero or negative for short messages;
	 * the <= one-block cases are fixed up below.
	 */
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	/* Scratch entries for scatterwalk_ffwd() */
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;
	int err;

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		/* Exactly one block: no stealing, just CBC. */
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = cbc_decrypt(&subreq);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		/* Point src/dst at the bytes that follow the CBC prefix. */
		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	/*
	 * Only the first walk step is processed.  NOTE(review): this presumably
	 * relies on .walksize = 2 * AES_BLOCK_SIZE (see aesni_skciphers) so the
	 * whole 17..32 byte tail arrives contiguously in that step — confirm.
	 */
	kernel_fpu_begin();
	aesni_cts_cbc_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
			  walk.nbytes, walk.iv);
	kernel_fpu_end();

	return skcipher_walk_done(&walk, 0);
}
377 
378 #ifdef CONFIG_X86_64
379 static void aesni_ctr_enc_avx_tfm(struct crypto_aes_ctx *ctx, u8 *out,
380 			      const u8 *in, unsigned int len, u8 *iv)
381 {
382 	/*
383 	 * based on key length, override with the by8 version
384 	 * of ctr mode encryption/decryption for improved performance
385 	 * aes_set_key_common() ensures that key length is one of
386 	 * {128,192,256}
387 	 */
388 	if (ctx->key_length == AES_KEYSIZE_128)
389 		aes_ctr_enc_128_avx_by8(in, iv, (void *)ctx, out, len);
390 	else if (ctx->key_length == AES_KEYSIZE_192)
391 		aes_ctr_enc_192_avx_by8(in, iv, (void *)ctx, out, len);
392 	else
393 		aes_ctr_enc_256_avx_by8(in, iv, (void *)ctx, out, len);
394 }
395 
396 static int ctr_crypt(struct skcipher_request *req)
397 {
398 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
399 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
400 	u8 keystream[AES_BLOCK_SIZE];
401 	struct skcipher_walk walk;
402 	unsigned int nbytes;
403 	int err;
404 
405 	err = skcipher_walk_virt(&walk, req, false);
406 
407 	while ((nbytes = walk.nbytes) > 0) {
408 		kernel_fpu_begin();
409 		if (nbytes & AES_BLOCK_MASK)
410 			static_call(aesni_ctr_enc_tfm)(ctx, walk.dst.virt.addr,
411 						       walk.src.virt.addr,
412 						       nbytes & AES_BLOCK_MASK,
413 						       walk.iv);
414 		nbytes &= ~AES_BLOCK_MASK;
415 
416 		if (walk.nbytes == walk.total && nbytes > 0) {
417 			aesni_enc(ctx, keystream, walk.iv);
418 			crypto_xor_cpy(walk.dst.virt.addr + walk.nbytes - nbytes,
419 				       walk.src.virt.addr + walk.nbytes - nbytes,
420 				       keystream, nbytes);
421 			crypto_inc(walk.iv, AES_BLOCK_SIZE);
422 			nbytes = 0;
423 		}
424 		kernel_fpu_end();
425 		err = skcipher_walk_done(&walk, nbytes);
426 	}
427 	return err;
428 }
429 
430 static void aesni_xctr_enc_avx_tfm(struct crypto_aes_ctx *ctx, u8 *out,
431 				   const u8 *in, unsigned int len, u8 *iv,
432 				   unsigned int byte_ctr)
433 {
434 	if (ctx->key_length == AES_KEYSIZE_128)
435 		aes_xctr_enc_128_avx_by8(in, iv, (void *)ctx, out, len,
436 					 byte_ctr);
437 	else if (ctx->key_length == AES_KEYSIZE_192)
438 		aes_xctr_enc_192_avx_by8(in, iv, (void *)ctx, out, len,
439 					 byte_ctr);
440 	else
441 		aes_xctr_enc_256_avx_by8(in, iv, (void *)ctx, out, len,
442 					 byte_ctr);
443 }
444 
445 static int xctr_crypt(struct skcipher_request *req)
446 {
447 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
448 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
449 	u8 keystream[AES_BLOCK_SIZE];
450 	struct skcipher_walk walk;
451 	unsigned int nbytes;
452 	unsigned int byte_ctr = 0;
453 	int err;
454 	__le32 block[AES_BLOCK_SIZE / sizeof(__le32)];
455 
456 	err = skcipher_walk_virt(&walk, req, false);
457 
458 	while ((nbytes = walk.nbytes) > 0) {
459 		kernel_fpu_begin();
460 		if (nbytes & AES_BLOCK_MASK)
461 			aesni_xctr_enc_avx_tfm(ctx, walk.dst.virt.addr,
462 				walk.src.virt.addr, nbytes & AES_BLOCK_MASK,
463 				walk.iv, byte_ctr);
464 		nbytes &= ~AES_BLOCK_MASK;
465 		byte_ctr += walk.nbytes - nbytes;
466 
467 		if (walk.nbytes == walk.total && nbytes > 0) {
468 			memcpy(block, walk.iv, AES_BLOCK_SIZE);
469 			block[0] ^= cpu_to_le32(1 + byte_ctr / AES_BLOCK_SIZE);
470 			aesni_enc(ctx, keystream, (u8 *)block);
471 			crypto_xor_cpy(walk.dst.virt.addr + walk.nbytes -
472 				       nbytes, walk.src.virt.addr + walk.nbytes
473 				       - nbytes, keystream, nbytes);
474 			byte_ctr += nbytes;
475 			nbytes = 0;
476 		}
477 		kernel_fpu_end();
478 		err = skcipher_walk_done(&walk, nbytes);
479 	}
480 	return err;
481 }
482 #endif
483 
484 static int xts_setkey_aesni(struct crypto_skcipher *tfm, const u8 *key,
485 			    unsigned int keylen)
486 {
487 	struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
488 	int err;
489 
490 	err = xts_verify_key(tfm, key, keylen);
491 	if (err)
492 		return err;
493 
494 	keylen /= 2;
495 
496 	/* first half of xts-key is for crypt */
497 	err = aes_set_key_common(&ctx->crypt_ctx, key, keylen);
498 	if (err)
499 		return err;
500 
501 	/* second half of xts-key is for tweak */
502 	return aes_set_key_common(&ctx->tweak_ctx, key + keylen, keylen);
503 }
504 
/*
 * Encrypts the XTS IV in place with the tweak key, turning it into the
 * initial tweak.  xts_crypt() calls this once per request.
 */
typedef void (*xts_encrypt_iv_func)(const struct crypto_aes_ctx *tweak_key,
				    u8 iv[AES_BLOCK_SIZE]);
/*
 * Bulk XTS en/decryption of len bytes from src to dst with the data key,
 * starting from the tweak in tweak[].  Note the src-before-dst order.
 * NOTE(review): xts_crypt_slowpath() passes the same tweak buffer on every
 * step, which implies the assembly writes the next tweak back — confirm.
 */
typedef void (*xts_crypt_func)(const struct crypto_aes_ctx *key,
			       const u8 *src, u8 *dst, unsigned int len,
			       u8 tweak[AES_BLOCK_SIZE]);
510 
/*
 * This handles cases where the source and/or destination span pages.
 *
 * Processes @req with a regular skcipher walk instead of the single-mapping
 * fast path in xts_crypt().  When called from xts_crypt(), req->iv already
 * holds the encrypted tweak, and it is passed to @crypt_func on every step.
 * noinline keeps this rarely used code out of the inlined xts_crypt() bodies.
 *
 * Returns 0 or a negative errno from the walk.
 */
static noinline int
xts_crypt_slowpath(struct skcipher_request *req, xts_crypt_func crypt_func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	/* Scratch entries for scatterwalk_ffwd() */
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;
	struct scatterlist *src, *dst;
	int err;

	/*
	 * If the message length isn't divisible by the AES block size, then
	 * separate off the last full block and the partial block.  This ensures
	 * that they are processed in the same call to the assembly function,
	 * which is required for ciphertext stealing.
	 */
	if (tail) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail - AES_BLOCK_SIZE,
					   req->iv);
		/* From here on, req covers only the leading whole blocks. */
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes) {
		kernel_fpu_begin();
		/* Process only the whole blocks of this step. */
		(*crypt_func)(&ctx->crypt_ctx,
			      walk.src.virt.addr, walk.dst.virt.addr,
			      walk.nbytes & ~(AES_BLOCK_SIZE - 1), req->iv);
		kernel_fpu_end();
		/* Report the unprocessed sub-block remainder back to the walk. */
		err = skcipher_walk_done(&walk,
					 walk.nbytes & (AES_BLOCK_SIZE - 1));
	}

	if (err || !tail)
		return err;

	/* Do ciphertext stealing with the last full block and partial block. */

	/*
	 * req is &subreq here, so req->cryptlen is the length of the prefix
	 * handled above; skip past it in both scatterlists.
	 */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	/*
	 * The final AES_BLOCK_SIZE + tail bytes must go to one call.
	 * NOTE(review): presumably guaranteed by .walksize = 2 * AES_BLOCK_SIZE
	 * on the XTS algorithms — confirm.
	 */
	kernel_fpu_begin();
	(*crypt_func)(&ctx->crypt_ctx, walk.src.virt.addr, walk.dst.virt.addr,
		      walk.nbytes, req->iv);
	kernel_fpu_end();

	return skcipher_walk_done(&walk, 0);
}
576 
/*
 * Common XTS en/decryption entry point.
 *
 * First encrypts req->iv in place with the tweak key via @encrypt_iv to form
 * the initial tweak.  If both src and dst fit in their first scatterlist
 * entries without crossing a page boundary, the whole message is processed in
 * one @crypt_func call.  Otherwise the work is handed to xts_crypt_slowpath().
 *
 * Returns -EINVAL for messages shorter than one AES block (XTS is not defined
 * for them), otherwise 0 or the slow path's result.
 */
/* __always_inline to avoid indirect call in fastpath */
static __always_inline int
xts_crypt(struct skcipher_request *req, xts_encrypt_iv_func encrypt_iv,
	  xts_crypt_func crypt_func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
	const unsigned int cryptlen = req->cryptlen;
	struct scatterlist *src = req->src;
	struct scatterlist *dst = req->dst;

	if (unlikely(cryptlen < AES_BLOCK_SIZE))
		return -EINVAL;

	/* One FPU section covers both the tweak setup and the fast path. */
	kernel_fpu_begin();
	(*encrypt_iv)(&ctx->tweak_ctx, req->iv);

	/*
	 * In practice, virtually all XTS plaintexts and ciphertexts are either
	 * 512 or 4096 bytes, aligned such that they don't span page boundaries.
	 * To optimize the performance of these cases, and also any other case
	 * where no page boundary is spanned, the below fast-path handles
	 * single-page sources and destinations as efficiently as possible.
	 */
	if (likely(src->length >= cryptlen && dst->length >= cryptlen &&
		   src->offset + cryptlen <= PAGE_SIZE &&
		   dst->offset + cryptlen <= PAGE_SIZE)) {
		struct page *src_page = sg_page(src);
		struct page *dst_page = sg_page(dst);
		void *src_virt = kmap_local_page(src_page) + src->offset;
		void *dst_virt = kmap_local_page(dst_page) + dst->offset;

		(*crypt_func)(&ctx->crypt_ctx, src_virt, dst_virt, cryptlen,
			      req->iv);
		/* kmap_local mappings nest, so unmap in reverse order. */
		kunmap_local(dst_virt);
		kunmap_local(src_virt);
		kernel_fpu_end();
		return 0;
	}
	/* The slow path manages its own FPU sections per walk step. */
	kernel_fpu_end();
	return xts_crypt_slowpath(req, crypt_func);
}
619 
620 static void aesni_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
621 				 u8 iv[AES_BLOCK_SIZE])
622 {
623 	aesni_enc(tweak_key, iv, iv);
624 }
625 
626 static void aesni_xts_encrypt(const struct crypto_aes_ctx *key,
627 			      const u8 *src, u8 *dst, unsigned int len,
628 			      u8 tweak[AES_BLOCK_SIZE])
629 {
630 	aesni_xts_enc(key, dst, src, len, tweak);
631 }
632 
633 static void aesni_xts_decrypt(const struct crypto_aes_ctx *key,
634 			      const u8 *src, u8 *dst, unsigned int len,
635 			      u8 tweak[AES_BLOCK_SIZE])
636 {
637 	aesni_xts_dec(key, dst, src, len, tweak);
638 }
639 
640 static int xts_encrypt_aesni(struct skcipher_request *req)
641 {
642 	return xts_crypt(req, aesni_xts_encrypt_iv, aesni_xts_encrypt);
643 }
644 
645 static int xts_decrypt_aesni(struct skcipher_request *req)
646 {
647 	return xts_crypt(req, aesni_xts_encrypt_iv, aesni_xts_decrypt);
648 }
649 
650 static struct crypto_alg aesni_cipher_alg = {
651 	.cra_name		= "aes",
652 	.cra_driver_name	= "aes-aesni",
653 	.cra_priority		= 300,
654 	.cra_flags		= CRYPTO_ALG_TYPE_CIPHER,
655 	.cra_blocksize		= AES_BLOCK_SIZE,
656 	.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
657 	.cra_module		= THIS_MODULE,
658 	.cra_u	= {
659 		.cipher	= {
660 			.cia_min_keysize	= AES_MIN_KEY_SIZE,
661 			.cia_max_keysize	= AES_MAX_KEY_SIZE,
662 			.cia_setkey		= aes_set_key,
663 			.cia_encrypt		= aesni_encrypt,
664 			.cia_decrypt		= aesni_decrypt
665 		}
666 	}
667 };
668 
/*
 * Internal (CRYPTO_ALG_INTERNAL, "__"-prefixed) AES-NI skcipher algorithms.
 * aesni_simd_skciphers has one slot per entry of this array.  NOTE(review):
 * the SIMD wrapper registration happens outside this view.
 */
static struct skcipher_alg aesni_skciphers[] = {
	{
		.base = {
			.cra_name		= "__ecb(aes)",
			.cra_driver_name	= "__ecb-aes-aesni",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= AES_BLOCK_SIZE,
			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.setkey		= aesni_skcipher_setkey,
		.encrypt	= ecb_encrypt,
		.decrypt	= ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "__cbc(aes)",
			.cra_driver_name	= "__cbc-aes-aesni",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= AES_BLOCK_SIZE,
			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		.setkey		= aesni_skcipher_setkey,
		.encrypt	= cbc_encrypt,
		.decrypt	= cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "__cts(cbc(aes))",
			.cra_driver_name	= "__cts-cbc-aes-aesni",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= AES_BLOCK_SIZE,
			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		/* Presumably lets the final two blocks arrive in one walk step. */
		.walksize	= 2 * AES_BLOCK_SIZE,
		.setkey		= aesni_skcipher_setkey,
		.encrypt	= cts_cbc_encrypt,
		.decrypt	= cts_cbc_decrypt,
/*
 * The CTR entry (including the "}, {" that opens it) is compiled only on
 * x86_64, where ctr_crypt() is built.
 */
#ifdef CONFIG_X86_64
	}, {
		.base = {
			.cra_name		= "__ctr(aes)",
			.cra_driver_name	= "__ctr-aes-aesni",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		.chunksize	= AES_BLOCK_SIZE,
		.setkey		= aesni_skcipher_setkey,
		.encrypt	= ctr_crypt,
		.decrypt	= ctr_crypt,
#endif
	}, {
		.base = {
			.cra_name		= "__xts(aes)",
			.cra_driver_name	= "__xts-aes-aesni",
			.cra_priority		= 401,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= AES_BLOCK_SIZE,
			.cra_ctxsize		= XTS_AES_CTX_SIZE,
			.cra_module		= THIS_MODULE,
		},
		/* Two AES keys back to back: data key, then tweak key. */
		.min_keysize	= 2 * AES_MIN_KEY_SIZE,
		.max_keysize	= 2 * AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		.walksize	= 2 * AES_BLOCK_SIZE,
		.setkey		= xts_setkey_aesni,
		.encrypt	= xts_encrypt_aesni,
		.decrypt	= xts_decrypt_aesni,
	}
};
756 
/* SIMD wrapper slots, one per entry of aesni_skciphers. */
static
struct simd_skcipher_alg *aesni_simd_skciphers[ARRAY_SIZE(aesni_skciphers)];
759 
#ifdef CONFIG_X86_64
/*
 * XCTR does not have a non-AVX implementation, so it must be enabled
 * conditionally.
 *
 * cra_blocksize is 1 because xctr_crypt() handles a trailing partial block
 * itself; chunksize tells the walker the underlying unit is one AES block.
 */
static struct skcipher_alg aesni_xctr = {
	.base = {
		.cra_name		= "__xctr(aes)",
		.cra_driver_name	= "__xctr-aes-aesni",
		.cra_priority		= 400,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= 1,
		.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.setkey		= aesni_skcipher_setkey,
	.encrypt	= xctr_crypt,
	.decrypt	= xctr_crypt,
};

/* SIMD wrapper for aesni_xctr (set up outside this view, presumably at init). */
static struct simd_skcipher_alg *aesni_simd_xctr;
785 
/*
 * Assembly xts_encrypt_iv_func shared by every DEFINE_XTS_ALG() instance:
 * encrypts iv in place with the tweak key.
 */
asmlinkage void aes_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
				   u8 iv[AES_BLOCK_SIZE]);
788 
/*
 * DEFINE_XTS_ALG(suffix, driver_name, priority) - glue for one assembly XTS
 * implementation.
 *
 * Declares aes_xts_encrypt_<suffix>() and aes_xts_decrypt_<suffix>(), which
 * are provided by assembly.  Defines xts_{en,de}crypt_<suffix>(), which run
 * them through xts_crypt() with aes_xts_encrypt_iv() forming the tweak.
 * Defines the internal skcipher aes_xts_alg_<suffix> ("__xts(aes)", driver
 * name "__" driver_name) and declares the pointer aes_xts_simdalg_<suffix>
 * for its SIMD wrapper.  The expansion ends without a semicolon, so each
 * invocation must supply one.
 */
#define DEFINE_XTS_ALG(suffix, driver_name, priority)			       \
									       \
asmlinkage void								       \
aes_xts_encrypt_##suffix(const struct crypto_aes_ctx *key, const u8 *src,      \
			 u8 *dst, unsigned int len, u8 tweak[AES_BLOCK_SIZE]); \
asmlinkage void								       \
aes_xts_decrypt_##suffix(const struct crypto_aes_ctx *key, const u8 *src,      \
			 u8 *dst, unsigned int len, u8 tweak[AES_BLOCK_SIZE]); \
									       \
static int xts_encrypt_##suffix(struct skcipher_request *req)		       \
{									       \
	return xts_crypt(req, aes_xts_encrypt_iv, aes_xts_encrypt_##suffix);   \
}									       \
									       \
static int xts_decrypt_##suffix(struct skcipher_request *req)		       \
{									       \
	return xts_crypt(req, aes_xts_encrypt_iv, aes_xts_decrypt_##suffix);   \
}									       \
									       \
static struct skcipher_alg aes_xts_alg_##suffix = {			       \
	.base = {							       \
		.cra_name		= "__xts(aes)",			       \
		.cra_driver_name	= "__" driver_name,		       \
		.cra_priority		= priority,			       \
		.cra_flags		= CRYPTO_ALG_INTERNAL,		       \
		.cra_blocksize		= AES_BLOCK_SIZE,		       \
		.cra_ctxsize		= XTS_AES_CTX_SIZE,		       \
		.cra_module		= THIS_MODULE,			       \
	},								       \
	.min_keysize	= 2 * AES_MIN_KEY_SIZE,				       \
	.max_keysize	= 2 * AES_MAX_KEY_SIZE,				       \
	.ivsize		= AES_BLOCK_SIZE,				       \
	.walksize	= 2 * AES_BLOCK_SIZE,				       \
	.setkey		= xts_setkey_aesni,				       \
	.encrypt	= xts_encrypt_##suffix,				       \
	.decrypt	= xts_decrypt_##suffix,				       \
};									       \
									       \
static struct simd_skcipher_alg *aes_xts_simdalg_##suffix
828 
/*
 * XTS implementations.  The AES-NI+AVX one is built whenever this x86_64
 * section is.  The VAES ones additionally need assembler support for VAES
 * and VPCLMULQDQ.  For a given cra_name, the crypto API prefers the
 * registered implementation with the highest cra_priority.  NOTE(review):
 * which ones get registered depends on CPU feature checks outside this view.
 */
DEFINE_XTS_ALG(aesni_avx, "xts-aes-aesni-avx", 500);
#if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ)
DEFINE_XTS_ALG(vaes_avx2, "xts-aes-vaes-avx2", 600);
DEFINE_XTS_ALG(vaes_avx10_256, "xts-aes-vaes-avx10_256", 700);
DEFINE_XTS_ALG(vaes_avx10_512, "xts-aes-vaes-avx10_512", 800);
#endif
835 
/*
 * The common part of the x86_64 AES-GCM key struct.  It is embedded as the
 * first member ("base") of both aes_gcm_key_aesni and aes_gcm_key_avx10.
 */
struct aes_gcm_key {
	/* Expanded AES key and the AES key length in bytes */
	struct crypto_aes_ctx aes_key;

	/* RFC4106 nonce (used only by the rfc4106 algorithms) */
	u32 rfc4106_nonce;
};
844 
/* Key struct used by the AES-NI implementations of AES-GCM */
struct aes_gcm_key_aesni {
	/*
	 * Common part of the key.  The assembly code requires 16-byte alignment
	 * for the round keys; we get this by them being located at the start of
	 * the struct and the whole struct being 16-byte aligned.
	 */
	struct aes_gcm_key base;

	/*
	 * Powers of the hash key H^8 through H^1.  These are 128-bit values.
	 * They all have an extra factor of x^-1 and are byte-reversed.  16-byte
	 * alignment is required by the assembly code.
	 */
	u64 h_powers[8][2] __aligned(16);

	/*
	 * h_powers_xored[i] contains the two 64-bit halves of h_powers[i] XOR'd
	 * together.  It's used for Karatsuba multiplication.  16-byte alignment
	 * is required by the assembly code.
	 */
	u64 h_powers_xored[8] __aligned(16);

	/*
	 * H^1 times x^64 (and also the usual extra factor of x^-1).  16-byte
	 * alignment is required by the assembly code.
	 */
	u64 h_times_x64[2] __aligned(16);
};
/* Recover the full AES-NI key struct from a pointer to its embedded base. */
#define AES_GCM_KEY_AESNI(key)	\
	container_of((key), struct aes_gcm_key_aesni, base)
/*
 * Context size for the AES-NI GCM algorithms: the struct plus enough slack
 * for aes_gcm_key_get() to PTR_ALIGN the context up to 16 bytes.
 */
#define AES_GCM_KEY_AESNI_SIZE	\
	(sizeof(struct aes_gcm_key_aesni) + (15 & ~(CRYPTO_MINALIGN - 1)))
878 
/* Key struct used by the VAES + AVX10 implementations of AES-GCM */
struct aes_gcm_key_avx10 {
	/*
	 * Common part of the key.  The assembly code prefers 16-byte alignment
	 * for the round keys; we get this by them being located at the start of
	 * the struct and the whole struct being 64-byte aligned.
	 */
	struct aes_gcm_key base;

	/*
	 * Powers of the hash key H^16 through H^1.  These are 128-bit values.
	 * They all have an extra factor of x^-1 and are byte-reversed.  This
	 * array is aligned to a 64-byte boundary to make it naturally aligned
	 * for 512-bit loads, which can improve performance.  (The assembly code
	 * doesn't *need* the alignment; this is just an optimization.)
	 */
	u64 h_powers[16][2] __aligned(64);

	/* Three padding blocks required by the assembly code */
	u64 padding[3][2];
};
/* Recover the full AVX10 key struct from a pointer to its embedded base. */
#define AES_GCM_KEY_AVX10(key)	\
	container_of((key), struct aes_gcm_key_avx10, base)
/*
 * Context size for the AVX10 GCM algorithms: the struct plus enough slack
 * for aes_gcm_key_get() to PTR_ALIGN the context up to 64 bytes.
 */
#define AES_GCM_KEY_AVX10_SIZE	\
	(sizeof(struct aes_gcm_key_avx10) + (63 & ~(CRYPTO_MINALIGN - 1)))
904 
/*
 * These flags are passed to the AES-GCM helper functions to specify the
 * specific version of AES-GCM (RFC4106 or not), whether it's encryption or
 * decryption, and which assembly functions should be called.  Assembly
 * functions are selected using flags instead of function pointers to avoid
 * indirect calls (which are very expensive on x86) regardless of inlining.
 */
#define FLAG_RFC4106	BIT(0)	/* RFC4106 variant rather than plain GCM */
#define FLAG_ENC	BIT(1)	/* encryption; clear means decryption */
#define FLAG_AVX	BIT(2)	/* AES-NI + AVX assembly (aesni_avx) */
#if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ)
#  define FLAG_AVX10_256	BIT(3)	/* VAES + AVX10, 256-bit vectors */
#  define FLAG_AVX10_512	BIT(4)	/* VAES + AVX10, 512-bit vectors */
#else
   /*
    * This should cause all calls to the AVX10 assembly functions to be
    * optimized out, avoiding the need to ifdef each call individually.
    */
#  define FLAG_AVX10_256	0
#  define FLAG_AVX10_512	0
#endif
926 
927 static inline struct aes_gcm_key *
928 aes_gcm_key_get(struct crypto_aead *tfm, int flags)
929 {
930 	if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512))
931 		return PTR_ALIGN(crypto_aead_ctx(tfm), 64);
932 	else
933 		return PTR_ALIGN(crypto_aead_ctx(tfm), 16);
934 }
935 
/*
 * Assembly routines that fill in the hash-key power tables of a GCM key
 * struct.  NOTE(review): they presumably derive H from base.aes_key, so they
 * should be called after the AES key is set — confirm.
 */
asmlinkage void
aes_gcm_precompute_aesni(struct aes_gcm_key_aesni *key);
asmlinkage void
aes_gcm_precompute_aesni_avx(struct aes_gcm_key_aesni *key);
asmlinkage void
aes_gcm_precompute_vaes_avx10_256(struct aes_gcm_key_avx10 *key);
asmlinkage void
aes_gcm_precompute_vaes_avx10_512(struct aes_gcm_key_avx10 *key);
944 
945 static void aes_gcm_precompute(struct aes_gcm_key *key, int flags)
946 {
947 	/*
948 	 * To make things a bit easier on the assembly side, the AVX10
949 	 * implementations use the same key format.  Therefore, a single
950 	 * function using 256-bit vectors would suffice here.  However, it's
951 	 * straightforward to provide a 512-bit one because of how the assembly
952 	 * code is structured, and it works nicely because the total size of the
953 	 * key powers is a multiple of 512 bits.  So we take advantage of that.
954 	 *
955 	 * A similar situation applies to the AES-NI implementations.
956 	 */
957 	if (flags & FLAG_AVX10_512)
958 		aes_gcm_precompute_vaes_avx10_512(AES_GCM_KEY_AVX10(key));
959 	else if (flags & FLAG_AVX10_256)
960 		aes_gcm_precompute_vaes_avx10_256(AES_GCM_KEY_AVX10(key));
961 	else if (flags & FLAG_AVX)
962 		aes_gcm_precompute_aesni_avx(AES_GCM_KEY_AESNI(key));
963 	else
964 		aes_gcm_precompute_aesni(AES_GCM_KEY_AESNI(key));
965 }
966 
/*
 * Assembly: fold @aadlen bytes of associated data at @aad into the GHASH
 * accumulator @ghash_acc.  As gcm_process_assoc() notes, every call for a
 * message except the last must pass a length that is a multiple of 16 bytes.
 * A single AVX10 routine serves both vector lengths.
 */
asmlinkage void
aes_gcm_aad_update_aesni(const struct aes_gcm_key_aesni *key,
			 u8 ghash_acc[16], const u8 *aad, int aadlen);
asmlinkage void
aes_gcm_aad_update_aesni_avx(const struct aes_gcm_key_aesni *key,
			     u8 ghash_acc[16], const u8 *aad, int aadlen);
asmlinkage void
aes_gcm_aad_update_vaes_avx10(const struct aes_gcm_key_avx10 *key,
			      u8 ghash_acc[16], const u8 *aad, int aadlen);
976 
977 static void aes_gcm_aad_update(const struct aes_gcm_key *key, u8 ghash_acc[16],
978 			       const u8 *aad, int aadlen, int flags)
979 {
980 	if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512))
981 		aes_gcm_aad_update_vaes_avx10(AES_GCM_KEY_AVX10(key), ghash_acc,
982 					      aad, aadlen);
983 	else if (flags & FLAG_AVX)
984 		aes_gcm_aad_update_aesni_avx(AES_GCM_KEY_AESNI(key), ghash_acc,
985 					     aad, aadlen);
986 	else
987 		aes_gcm_aad_update_aesni(AES_GCM_KEY_AESNI(key), ghash_acc,
988 					 aad, aadlen);
989 }
990 
/*
 * Assembly: en/decrypt @datalen bytes from @src to @dst, starting at the
 * little-endian counter block @le_ctr.  The ciphertext is folded into the
 * GHASH accumulator @ghash_acc.
 *
 * @le_ctr is read-only here; gcm_crypt() advances le_ctr[0] itself between
 * calls.  Every call for a message except the last must pass a multiple of 16
 * bytes.
 */
asmlinkage void
aes_gcm_enc_update_aesni(const struct aes_gcm_key_aesni *key,
			 const u32 le_ctr[4], u8 ghash_acc[16],
			 const u8 *src, u8 *dst, int datalen);
asmlinkage void
aes_gcm_enc_update_aesni_avx(const struct aes_gcm_key_aesni *key,
			     const u32 le_ctr[4], u8 ghash_acc[16],
			     const u8 *src, u8 *dst, int datalen);
asmlinkage void
aes_gcm_enc_update_vaes_avx10_256(const struct aes_gcm_key_avx10 *key,
				  const u32 le_ctr[4], u8 ghash_acc[16],
				  const u8 *src, u8 *dst, int datalen);
asmlinkage void
aes_gcm_enc_update_vaes_avx10_512(const struct aes_gcm_key_avx10 *key,
				  const u32 le_ctr[4], u8 ghash_acc[16],
				  const u8 *src, u8 *dst, int datalen);

asmlinkage void
aes_gcm_dec_update_aesni(const struct aes_gcm_key_aesni *key,
			 const u32 le_ctr[4], u8 ghash_acc[16],
			 const u8 *src, u8 *dst, int datalen);
asmlinkage void
aes_gcm_dec_update_aesni_avx(const struct aes_gcm_key_aesni *key,
			     const u32 le_ctr[4], u8 ghash_acc[16],
			     const u8 *src, u8 *dst, int datalen);
asmlinkage void
aes_gcm_dec_update_vaes_avx10_256(const struct aes_gcm_key_avx10 *key,
				  const u32 le_ctr[4], u8 ghash_acc[16],
				  const u8 *src, u8 *dst, int datalen);
asmlinkage void
aes_gcm_dec_update_vaes_avx10_512(const struct aes_gcm_key_avx10 *key,
				  const u32 le_ctr[4], u8 ghash_acc[16],
				  const u8 *src, u8 *dst, int datalen);
1024 
1025 /* __always_inline to optimize out the branches based on @flags */
1026 static __always_inline void
1027 aes_gcm_update(const struct aes_gcm_key *key,
1028 	       const u32 le_ctr[4], u8 ghash_acc[16],
1029 	       const u8 *src, u8 *dst, int datalen, int flags)
1030 {
1031 	if (flags & FLAG_ENC) {
1032 		if (flags & FLAG_AVX10_512)
1033 			aes_gcm_enc_update_vaes_avx10_512(AES_GCM_KEY_AVX10(key),
1034 							  le_ctr, ghash_acc,
1035 							  src, dst, datalen);
1036 		else if (flags & FLAG_AVX10_256)
1037 			aes_gcm_enc_update_vaes_avx10_256(AES_GCM_KEY_AVX10(key),
1038 							  le_ctr, ghash_acc,
1039 							  src, dst, datalen);
1040 		else if (flags & FLAG_AVX)
1041 			aes_gcm_enc_update_aesni_avx(AES_GCM_KEY_AESNI(key),
1042 						     le_ctr, ghash_acc,
1043 						     src, dst, datalen);
1044 		else
1045 			aes_gcm_enc_update_aesni(AES_GCM_KEY_AESNI(key), le_ctr,
1046 						 ghash_acc, src, dst, datalen);
1047 	} else {
1048 		if (flags & FLAG_AVX10_512)
1049 			aes_gcm_dec_update_vaes_avx10_512(AES_GCM_KEY_AVX10(key),
1050 							  le_ctr, ghash_acc,
1051 							  src, dst, datalen);
1052 		else if (flags & FLAG_AVX10_256)
1053 			aes_gcm_dec_update_vaes_avx10_256(AES_GCM_KEY_AVX10(key),
1054 							  le_ctr, ghash_acc,
1055 							  src, dst, datalen);
1056 		else if (flags & FLAG_AVX)
1057 			aes_gcm_dec_update_aesni_avx(AES_GCM_KEY_AESNI(key),
1058 						     le_ctr, ghash_acc,
1059 						     src, dst, datalen);
1060 		else
1061 			aes_gcm_dec_update_aesni(AES_GCM_KEY_AESNI(key),
1062 						 le_ctr, ghash_acc,
1063 						 src, dst, datalen);
1064 	}
1065 }
1066 
/*
 * Assembly: finish the GHASH computation using the total associated-data
 * length @total_aadlen and data length @total_datalen.  The authentication
 * tag is left in @ghash_acc; gcm_crypt() copies it to the destination from
 * there.
 */
asmlinkage void
aes_gcm_enc_final_aesni(const struct aes_gcm_key_aesni *key,
			const u32 le_ctr[4], u8 ghash_acc[16],
			u64 total_aadlen, u64 total_datalen);
asmlinkage void
aes_gcm_enc_final_aesni_avx(const struct aes_gcm_key_aesni *key,
			    const u32 le_ctr[4], u8 ghash_acc[16],
			    u64 total_aadlen, u64 total_datalen);
asmlinkage void
aes_gcm_enc_final_vaes_avx10(const struct aes_gcm_key_avx10 *key,
			     const u32 le_ctr[4], u8 ghash_acc[16],
			     u64 total_aadlen, u64 total_datalen);
1079 
1080 /* __always_inline to optimize out the branches based on @flags */
1081 static __always_inline void
1082 aes_gcm_enc_final(const struct aes_gcm_key *key,
1083 		  const u32 le_ctr[4], u8 ghash_acc[16],
1084 		  u64 total_aadlen, u64 total_datalen, int flags)
1085 {
1086 	if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512))
1087 		aes_gcm_enc_final_vaes_avx10(AES_GCM_KEY_AVX10(key),
1088 					     le_ctr, ghash_acc,
1089 					     total_aadlen, total_datalen);
1090 	else if (flags & FLAG_AVX)
1091 		aes_gcm_enc_final_aesni_avx(AES_GCM_KEY_AESNI(key),
1092 					    le_ctr, ghash_acc,
1093 					    total_aadlen, total_datalen);
1094 	else
1095 		aes_gcm_enc_final_aesni(AES_GCM_KEY_AESNI(key),
1096 					le_ctr, ghash_acc,
1097 					total_aadlen, total_datalen);
1098 }
1099 
/*
 * Assembly: finish the GHASH computation, compute the expected authentication
 * tag, and compare it with the transmitted @taglen-byte tag @tag.  Returns
 * true if they match.  The tag comparison itself is done in assembly.
 */
asmlinkage bool __must_check
aes_gcm_dec_final_aesni(const struct aes_gcm_key_aesni *key,
			const u32 le_ctr[4], const u8 ghash_acc[16],
			u64 total_aadlen, u64 total_datalen,
			const u8 tag[16], int taglen);
asmlinkage bool __must_check
aes_gcm_dec_final_aesni_avx(const struct aes_gcm_key_aesni *key,
			    const u32 le_ctr[4], const u8 ghash_acc[16],
			    u64 total_aadlen, u64 total_datalen,
			    const u8 tag[16], int taglen);
asmlinkage bool __must_check
aes_gcm_dec_final_vaes_avx10(const struct aes_gcm_key_avx10 *key,
			     const u32 le_ctr[4], const u8 ghash_acc[16],
			     u64 total_aadlen, u64 total_datalen,
			     const u8 tag[16], int taglen);
1115 
1116 /* __always_inline to optimize out the branches based on @flags */
1117 static __always_inline bool __must_check
1118 aes_gcm_dec_final(const struct aes_gcm_key *key, const u32 le_ctr[4],
1119 		  u8 ghash_acc[16], u64 total_aadlen, u64 total_datalen,
1120 		  u8 tag[16], int taglen, int flags)
1121 {
1122 	if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512))
1123 		return aes_gcm_dec_final_vaes_avx10(AES_GCM_KEY_AVX10(key),
1124 						    le_ctr, ghash_acc,
1125 						    total_aadlen, total_datalen,
1126 						    tag, taglen);
1127 	else if (flags & FLAG_AVX)
1128 		return aes_gcm_dec_final_aesni_avx(AES_GCM_KEY_AESNI(key),
1129 						   le_ctr, ghash_acc,
1130 						   total_aadlen, total_datalen,
1131 						   tag, taglen);
1132 	else
1133 		return aes_gcm_dec_final_aesni(AES_GCM_KEY_AESNI(key),
1134 					       le_ctr, ghash_acc,
1135 					       total_aadlen, total_datalen,
1136 					       tag, taglen);
1137 }
1138 
/*
 * setauthsize for rfc4106(gcm(aes)): the Integrity Check Value (i.e. the
 * authentication tag) may be 8, 12 or 16 bytes long.  Any other length is
 * rejected with -EINVAL.  @aead is not used.
 */
static int common_rfc4106_set_authsize(struct crypto_aead *aead,
				       unsigned int authsize)
{
	if (authsize == 8 || authsize == 12 || authsize == 16)
		return 0;
	return -EINVAL;
}
1157 
/*
 * setauthsize for gcm(aes): accept tag lengths of 4 and 8 bytes, and any
 * length from 12 through 16 bytes inclusive (32-, 64-, and 96..128-bit tags).
 * Any other length is rejected with -EINVAL.  @tfm is not used.
 */
static int generic_gcmaes_set_authsize(struct crypto_aead *tfm,
				       unsigned int authsize)
{
	if (authsize == 4 || authsize == 8 ||
	    (authsize >= 12 && authsize <= 16))
		return 0;
	return -EINVAL;
}
1176 
/*
 * This is the setkey function for the x86_64 implementations of AES-GCM.  It
 * saves the RFC4106 nonce if applicable, expands the AES key, and precomputes
 * powers of the hash key.
 *
 * To comply with the crypto_aead API, this has to be usable in no-SIMD context.
 * For that reason, this function includes a portable C implementation of the
 * needed logic.  However, the portable C implementation is very slow, taking
 * about the same time as encrypting 37 KB of data.  To be ready for users that
 * may set a key even somewhat frequently, we therefore also include a SIMD
 * assembly implementation, expanding the AES key using AES-NI and precomputing
 * the hash key powers using PCLMULQDQ or VPCLMULQDQ.
 *
 * @tfm:     AEAD transform; the key is stored at aes_gcm_key_get(tfm, flags)
 * @raw_key: key material; for RFC4106 its last 4 bytes are the nonce
 * @keylen:  length of @raw_key in bytes (including the nonce for RFC4106)
 * @flags:   FLAG_* values selecting RFC4106 handling and the key layout
 *
 * Return: 0 on success; -EINVAL if an RFC4106 key is shorter than the nonce;
 * otherwise the error from aes_check_keylen()/aes_expandkey() for an invalid
 * AES key length.
 */
static int gcm_setkey(struct crypto_aead *tfm, const u8 *raw_key,
		      unsigned int keylen, int flags)
{
	struct aes_gcm_key *key = aes_gcm_key_get(tfm, flags);
	int err;

	if (flags & FLAG_RFC4106) {
		/* Split off the trailing 4-byte nonce; the rest is the AES key. */
		if (keylen < 4)
			return -EINVAL;
		keylen -= 4;
		key->rfc4106_nonce = get_unaligned_be32(raw_key + keylen);
	}

	/* The assembly code assumes the following offsets. */
	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, base.aes_key.key_enc) != 0);
	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, base.aes_key.key_length) != 480);
	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_powers) != 496);
	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_powers_xored) != 624);
	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_times_x64) != 688);
	BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, base.aes_key.key_enc) != 0);
	BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, base.aes_key.key_length) != 480);
	BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, h_powers) != 512);
	BUILD_BUG_ON(offsetof(struct aes_gcm_key_avx10, padding) != 768);

	if (likely(crypto_simd_usable())) {
		/*
		 * Fast path: validate the key length first, then expand the
		 * key with AES-NI and compute the key powers in assembly.
		 */
		err = aes_check_keylen(keylen);
		if (err)
			return err;
		kernel_fpu_begin();
		aesni_set_key(&key->aes_key, raw_key, keylen);
		aes_gcm_precompute(key, flags);
		kernel_fpu_end();
	} else {
		/* Slow path: portable C, no SIMD registers used. */

		/* The field element x^-1, as an operand for gf128mul_lle() */
		static const u8 x_to_the_minus1[16] __aligned(__alignof__(be128)) = {
			[0] = 0xc2, [15] = 1
		};
		/* The field element x^63, as an operand for gf128mul_lle() */
		static const u8 x_to_the_63[16] __aligned(__alignof__(be128)) = {
			[7] = 1,
		};
		be128 h1 = {};
		be128 h;
		int i;

		err = aes_expandkey(&key->aes_key, raw_key, keylen);
		if (err)
			return err;

		/* Encrypt the all-zeroes block to get the hash key H^1 */
		aes_encrypt(&key->aes_key, (u8 *)&h1, (u8 *)&h1);

		/* Compute H^1 * x^-1 */
		h = h1;
		gf128mul_lle(&h, (const be128 *)x_to_the_minus1);

		/*
		 * Compute the needed key powers.  The array is filled from the
		 * end: the last entry gets H^1 * x^-1, and each earlier entry
		 * has one more factor of H.  Storing the low half (h.b) in
		 * [0] and the high half (h.a) in [1], each via be64_to_cpu(),
		 * produces the byte-reversed form the assembly uses.
		 */
		if (flags & (FLAG_AVX10_256 | FLAG_AVX10_512)) {
			struct aes_gcm_key_avx10 *k = AES_GCM_KEY_AVX10(key);

			for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) {
				k->h_powers[i][0] = be64_to_cpu(h.b);
				k->h_powers[i][1] = be64_to_cpu(h.a);
				gf128mul_lle(&h, &h1);
			}
			/* The assembly's padding blocks must be zero. */
			memset(k->padding, 0, sizeof(k->padding));
		} else {
			struct aes_gcm_key_aesni *k = AES_GCM_KEY_AESNI(key);

			for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) {
				k->h_powers[i][0] = be64_to_cpu(h.b);
				k->h_powers[i][1] = be64_to_cpu(h.a);
				/* XOR of the two halves, also used by the asm */
				k->h_powers_xored[i] = k->h_powers[i][0] ^
						       k->h_powers[i][1];
				gf128mul_lle(&h, &h1);
			}
			/* h_times_x64 = H^1 * x^63 = (H^1 * x^-1) * x^64 */
			gf128mul_lle(&h1, (const be128 *)x_to_the_63);
			k->h_times_x64[0] = be64_to_cpu(h1.b);
			k->h_times_x64[1] = be64_to_cpu(h1.a);
		}
	}
	return 0;
}
1271 
/*
 * Initialize @ghash_acc, then pass all @assoclen bytes of associated data
 * (a.k.a. additional authenticated data) from @sg_src through the GHASH update
 * assembly function.  kernel_fpu_begin() must have already been called.
 *
 * @key:       expanded key with precomputed hash key powers
 * @ghash_acc: GHASH accumulator; zeroed here, then updated
 * @sg_src:    scatterlist whose first @assoclen bytes are the associated data
 * @assoclen:  number of associated-data bytes to process
 * @flags:     FLAG_* values selecting the assembly implementation
 *
 * The FPU section may be ended and restarted between pages if rescheduling
 * is needed.  On return, the caller is still inside a kernel FPU section.
 */
static void gcm_process_assoc(const struct aes_gcm_key *key, u8 ghash_acc[16],
			      struct scatterlist *sg_src, unsigned int assoclen,
			      int flags)
{
	struct scatter_walk walk;
	/*
	 * The assembly function requires that the length of any non-last
	 * segment of associated data be a multiple of 16 bytes, so this
	 * function does the buffering needed to achieve that.
	 */
	unsigned int pos = 0;	/* number of bytes currently held in buf */
	u8 buf[16];

	memset(ghash_acc, 0, 16);
	scatterwalk_start(&walk, sg_src);

	while (assoclen) {
		unsigned int len_this_page = scatterwalk_clamp(&walk, assoclen);
		void *mapped = scatterwalk_map(&walk);
		const void *src = mapped;
		unsigned int len;

		assoclen -= len_this_page;
		scatterwalk_advance(&walk, len_this_page);
		if (unlikely(pos)) {
			/* Top up the partial block left over from last page. */
			len = min(len_this_page, 16 - pos);
			memcpy(&buf[pos], src, len);
			pos += len;
			src += len;
			len_this_page -= len;
			if (pos < 16)
				goto next;
			aes_gcm_aad_update(key, ghash_acc, buf, 16, flags);
			pos = 0;
		}
		/* Process directly from the page as much as allowed. */
		len = len_this_page;
		if (unlikely(assoclen)) /* Not the last segment yet? */
			len = round_down(len, 16);
		aes_gcm_aad_update(key, ghash_acc, src, len, flags);
		src += len;
		len_this_page -= len;
		if (unlikely(len_this_page)) {
			/* Stash the <16-byte tail for the next iteration. */
			memcpy(buf, src, len_this_page);
			pos = len_this_page;
		}
next:
		scatterwalk_unmap(mapped);
		scatterwalk_pagedone(&walk, 0, assoclen);
		/* Briefly leave the FPU section if rescheduling is needed. */
		if (need_resched()) {
			kernel_fpu_end();
			kernel_fpu_begin();
		}
	}
	/* Flush any buffered partial block as the final segment. */
	if (unlikely(pos))
		aes_gcm_aad_update(key, ghash_acc, buf, pos, flags);
}
1333 
1334 
/*
 * Encrypt or decrypt one AEAD request with AES-GCM or RFC4106 AES-GCM.
 *
 * @req:   the request.  The source holds req->assoclen bytes of associated
 *         data followed by the data.  When decrypting, the last authsize
 *         bytes of the req->cryptlen data bytes are the transmitted tag.  When
 *         encrypting, the tag is written to the destination after the
 *         ciphertext.
 * @flags: FLAG_* values.  FLAG_ENC selects encryption and FLAG_RFC4106 the
 *         RFC4106 variant; the remaining flags select the assembly
 *         implementation.
 *
 * For RFC4106, req->assoclen must be 16 or 20.  Its last 8 bytes are excluded
 * from GHASH (presumably the copy of the IV; confirm against RFC 4106 usage).
 *
 * Return: 0 on success; -EINVAL for an unsupported RFC4106 assoclen;
 * -EBADMSG if decryption finds an invalid tag; or an error from the
 * skcipher walk.
 *
 * __always_inline to optimize out the branches based on @flags.
 */
static __always_inline int
gcm_crypt(struct aead_request *req, int flags)
{
	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
	const struct aes_gcm_key *key = aes_gcm_key_get(tfm, flags);
	unsigned int assoclen = req->assoclen;
	struct skcipher_walk walk;
	unsigned int nbytes;
	u8 ghash_acc[16]; /* GHASH accumulator */
	u32 le_ctr[4]; /* Counter in little-endian format */
	int taglen;
	int err;

	/* Initialize the counter and determine the associated data length. */
	le_ctr[0] = 2;	/* data blocks start at block counter value 2 */
	if (flags & FLAG_RFC4106) {
		if (unlikely(assoclen != 16 && assoclen != 20))
			return -EINVAL;
		assoclen -= 8;
		/* Counter = 4-byte nonce from the key || 8-byte request IV */
		le_ctr[1] = get_unaligned_be32(req->iv + 4);
		le_ctr[2] = get_unaligned_be32(req->iv + 0);
		le_ctr[3] = key->rfc4106_nonce; /* already byte-swapped */
	} else {
		/* Counter = 12-byte request IV, stored word-reversed */
		le_ctr[1] = get_unaligned_be32(req->iv + 8);
		le_ctr[2] = get_unaligned_be32(req->iv + 4);
		le_ctr[3] = get_unaligned_be32(req->iv + 0);
	}

	/* Begin walking through the plaintext or ciphertext. */
	if (flags & FLAG_ENC)
		err = skcipher_walk_aead_encrypt(&walk, req, false);
	else
		err = skcipher_walk_aead_decrypt(&walk, req, false);
	if (err)
		return err;

	/*
	 * Since the AES-GCM assembly code requires that at least three assembly
	 * functions be called to process any message (this is needed to support
	 * incremental updates cleanly), to reduce overhead we try to do all
	 * three calls in the same kernel FPU section if possible.  We close the
	 * section and start a new one if there are multiple data segments or if
	 * rescheduling is needed while processing the associated data.
	 */
	kernel_fpu_begin();

	/* Pass the associated data through GHASH. */
	gcm_process_assoc(key, ghash_acc, req->src, assoclen, flags);

	/* En/decrypt the data and pass the ciphertext through GHASH. */
	while (unlikely((nbytes = walk.nbytes) < walk.total)) {
		/*
		 * Non-last segment.  In this case, the assembly function
		 * requires that the length be a multiple of 16 (AES_BLOCK_SIZE)
		 * bytes.  The needed buffering of up to 16 bytes is handled by
		 * the skcipher_walk.  Here we just need to round down to a
		 * multiple of 16.
		 */
		nbytes = round_down(nbytes, AES_BLOCK_SIZE);
		aes_gcm_update(key, le_ctr, ghash_acc, walk.src.virt.addr,
			       walk.dst.virt.addr, nbytes, flags);
		/* The assembly doesn't advance le_ctr; do it here. */
		le_ctr[0] += nbytes / AES_BLOCK_SIZE;
		/* skcipher_walk_done() is called outside the FPU section. */
		kernel_fpu_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
		if (err)
			return err;
		kernel_fpu_begin();
	}
	/* Last segment: process all remaining data. */
	aes_gcm_update(key, le_ctr, ghash_acc, walk.src.virt.addr,
		       walk.dst.virt.addr, nbytes, flags);
	/*
	 * The low word of the counter isn't used by the finalize, so there's no
	 * need to increment it here.
	 */

	/* Finalize */
	taglen = crypto_aead_authsize(tfm);
	if (flags & FLAG_ENC) {
		/* Finish computing the auth tag. */
		aes_gcm_enc_final(key, le_ctr, ghash_acc, assoclen,
				  req->cryptlen, flags);

		/* Store the computed auth tag in the dst scatterlist. */
		scatterwalk_map_and_copy(ghash_acc, req->dst, req->assoclen +
					 req->cryptlen, taglen, 1);
	} else {
		unsigned int datalen = req->cryptlen - taglen;
		u8 tag[16];

		/* Get the transmitted auth tag from the src scatterlist. */
		scatterwalk_map_and_copy(tag, req->src, req->assoclen + datalen,
					 taglen, 0);
		/*
		 * Finish computing the auth tag and compare it to the
		 * transmitted one.  The assembly function does the actual tag
		 * comparison.  Here, just check the boolean result.
		 */
		if (!aes_gcm_dec_final(key, le_ctr, ghash_acc, assoclen,
				       datalen, tag, taglen, flags))
			err = -EBADMSG;
	}
	kernel_fpu_end();
	/* Complete the walk (if it had data) only after leaving the FPU section. */
	if (nbytes)
		skcipher_walk_done(&walk, 0);
	return err;
}
1443 
/*
 * DEFINE_GCM_ALGS() - instantiate the AES-GCM glue for one implementation.
 *
 * @suffix:              suffix appended to every generated identifier
 * @flags:               FLAG_* implementation flags passed to every helper
 * @generic_driver_name: cra_driver_name (minus the "__" prefix) for gcm(aes)
 * @rfc_driver_name:     cra_driver_name (minus the "__" prefix) for
 *                       rfc4106(gcm(aes))
 * @ctxsize:             cra_ctxsize, including slack for key alignment
 * @priority:            cra_priority of both algorithms
 *
 * Generates the following:
 * - setkey/encrypt/decrypt wrappers for gcm(aes) and rfc4106(gcm(aes));
 * - aes_gcm_algs_<suffix>[], the two CRYPTO_ALG_INTERNAL aead_alg definitions;
 * - aes_gcm_simdalgs_<suffix>[2], which receives the SIMD wrappers on
 *   registration (static, so initially NULL).
 *
 * The last declaration has no semicolon; each invocation supplies it.
 */
#define DEFINE_GCM_ALGS(suffix, flags, generic_driver_name, rfc_driver_name,   \
			ctxsize, priority)				       \
									       \
static int gcm_setkey_##suffix(struct crypto_aead *tfm, const u8 *raw_key,     \
			       unsigned int keylen)			       \
{									       \
	return gcm_setkey(tfm, raw_key, keylen, (flags));		       \
}									       \
									       \
static int gcm_encrypt_##suffix(struct aead_request *req)		       \
{									       \
	return gcm_crypt(req, (flags) | FLAG_ENC);			       \
}									       \
									       \
static int gcm_decrypt_##suffix(struct aead_request *req)		       \
{									       \
	return gcm_crypt(req, (flags));					       \
}									       \
									       \
static int rfc4106_setkey_##suffix(struct crypto_aead *tfm, const u8 *raw_key, \
				   unsigned int keylen)			       \
{									       \
	return gcm_setkey(tfm, raw_key, keylen, (flags) | FLAG_RFC4106);       \
}									       \
									       \
static int rfc4106_encrypt_##suffix(struct aead_request *req)		       \
{									       \
	return gcm_crypt(req, (flags) | FLAG_RFC4106 | FLAG_ENC);	       \
}									       \
									       \
static int rfc4106_decrypt_##suffix(struct aead_request *req)		       \
{									       \
	return gcm_crypt(req, (flags) | FLAG_RFC4106);			       \
}									       \
									       \
static struct aead_alg aes_gcm_algs_##suffix[] = { {			       \
	.setkey			= gcm_setkey_##suffix,			       \
	.setauthsize		= generic_gcmaes_set_authsize,		       \
	.encrypt		= gcm_encrypt_##suffix,			       \
	.decrypt		= gcm_decrypt_##suffix,			       \
	.ivsize			= GCM_AES_IV_SIZE,			       \
	.chunksize		= AES_BLOCK_SIZE,			       \
	.maxauthsize		= 16,					       \
	.base = {							       \
		.cra_name		= "__gcm(aes)",			       \
		.cra_driver_name	= "__" generic_driver_name,	       \
		.cra_priority		= (priority),			       \
		.cra_flags		= CRYPTO_ALG_INTERNAL,		       \
		.cra_blocksize		= 1,				       \
		.cra_ctxsize		= (ctxsize),			       \
		.cra_module		= THIS_MODULE,			       \
	},								       \
}, {									       \
	.setkey			= rfc4106_setkey_##suffix,		       \
	.setauthsize		= common_rfc4106_set_authsize,		       \
	.encrypt		= rfc4106_encrypt_##suffix,		       \
	.decrypt		= rfc4106_decrypt_##suffix,		       \
	.ivsize			= GCM_RFC4106_IV_SIZE,			       \
	.chunksize		= AES_BLOCK_SIZE,			       \
	.maxauthsize		= 16,					       \
	.base = {							       \
		.cra_name		= "__rfc4106(gcm(aes))",	       \
		.cra_driver_name	= "__" rfc_driver_name,		       \
		.cra_priority		= (priority),			       \
		.cra_flags		= CRYPTO_ALG_INTERNAL,		       \
		.cra_blocksize		= 1,				       \
		.cra_ctxsize		= (ctxsize),			       \
		.cra_module		= THIS_MODULE,			       \
	},								       \
} };									       \
									       \
static struct simd_aead_alg *aes_gcm_simdalgs_##suffix[2]		       \
1516 
/*
 * Instantiate the AES-GCM algorithms for each implementation.  When several
 * implementations are registered, the higher cra_priority is preferred:
 * AES-NI (400) < AES-NI+AVX (500) < VAES/AVX10-256 (700) < VAES/AVX10-512 (800).
 */

/* aes_gcm_algs_aesni: plain AES-NI, AES-NI key layout */
DEFINE_GCM_ALGS(aesni, /* no flags */ 0,
		"generic-gcm-aesni", "rfc4106-gcm-aesni",
		AES_GCM_KEY_AESNI_SIZE, 400);

/* aes_gcm_algs_aesni_avx: AES-NI with AVX, AES-NI key layout */
DEFINE_GCM_ALGS(aesni_avx, FLAG_AVX,
		"generic-gcm-aesni-avx", "rfc4106-gcm-aesni-avx",
		AES_GCM_KEY_AESNI_SIZE, 500);

#if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ)
/* aes_gcm_algs_vaes_avx10_256: VAES, 256-bit vectors, AVX10 key layout */
DEFINE_GCM_ALGS(vaes_avx10_256, FLAG_AVX10_256,
		"generic-gcm-vaes-avx10_256", "rfc4106-gcm-vaes-avx10_256",
		AES_GCM_KEY_AVX10_SIZE, 700);

/* aes_gcm_algs_vaes_avx10_512: VAES, 512-bit vectors, AVX10 key layout */
DEFINE_GCM_ALGS(vaes_avx10_512, FLAG_AVX10_512,
		"generic-gcm-vaes-avx10_512", "rfc4106-gcm-vaes-avx10_512",
		AES_GCM_KEY_AVX10_SIZE, 800);
#endif /* CONFIG_AS_VAES && CONFIG_AS_VPCLMULQDQ */
1538 
/*
 * This is a list of CPU models that are known to suffer from downclocking when
 * zmm registers (512-bit vectors) are used.  On these CPUs, the AES mode
 * implementations with zmm registers won't be used by default.  Implementations
 * with ymm registers (256-bit vectors) will be used by default instead.
 *
 * register_avx_algs() enforces this.  When x86_match_cpu() matches this list,
 * it registers the 512-bit algorithms with cra_priority 1.
 */
static const struct x86_cpu_id zmm_exclusion_list[] = {
	X86_MATCH_VFM(INTEL_SKYLAKE_X,		0),
	X86_MATCH_VFM(INTEL_ICELAKE_X,		0),
	X86_MATCH_VFM(INTEL_ICELAKE_D,		0),
	X86_MATCH_VFM(INTEL_ICELAKE,		0),
	X86_MATCH_VFM(INTEL_ICELAKE_L,		0),
	X86_MATCH_VFM(INTEL_ICELAKE_NNPI,	0),
	X86_MATCH_VFM(INTEL_TIGERLAKE_L,	0),
	X86_MATCH_VFM(INTEL_TIGERLAKE,		0),
	/* Allow Rocket Lake and later, and Sapphire Rapids and later. */
	/* Also allow AMD CPUs (starting with Zen 4, the first with AVX-512). */
	{},
};
1558 
/*
 * Register the AES-NI+AVX, VAES+AVX2 and VAES+AVX10 algorithm variants, as far
 * as the boot CPU (and, for VAES, the assembler configuration) supports them.
 *
 * Return: 0 on success, including when a tier is skipped because the CPU lacks
 * the needed features; otherwise the first registration error.  On error,
 * earlier registrations are left in place.  The caller (aesni_init()) then
 * calls unregister_avx_algs(), which checks what was registered.
 */
static int __init register_avx_algs(void)
{
	int err;

	/* Tier 1: AES-NI with AVX (XTS and AES-GCM). */
	if (!boot_cpu_has(X86_FEATURE_AVX))
		return 0;
	err = simd_register_skciphers_compat(&aes_xts_alg_aesni_avx, 1,
					     &aes_xts_simdalg_aesni_avx);
	if (err)
		return err;
	err = simd_register_aeads_compat(aes_gcm_algs_aesni_avx,
					 ARRAY_SIZE(aes_gcm_algs_aesni_avx),
					 aes_gcm_simdalgs_aesni_avx);
	if (err)
		return err;
#if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ)
	/* Tier 2: VAES + VPCLMULQDQ with AVX2 (XTS only). */
	if (!boot_cpu_has(X86_FEATURE_AVX2) ||
	    !boot_cpu_has(X86_FEATURE_VAES) ||
	    !boot_cpu_has(X86_FEATURE_VPCLMULQDQ) ||
	    !boot_cpu_has(X86_FEATURE_PCLMULQDQ) ||
	    !cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL))
		return 0;
	err = simd_register_skciphers_compat(&aes_xts_alg_vaes_avx2, 1,
					     &aes_xts_simdalg_vaes_avx2);
	if (err)
		return err;

	/*
	 * Tier 3: AVX10 (256- and 512-bit).  Requires AVX512BW/VL and BMI2, and
	 * that the AVX-512 register state is enabled (cpu_has_xfeatures()).
	 */
	if (!boot_cpu_has(X86_FEATURE_AVX512BW) ||
	    !boot_cpu_has(X86_FEATURE_AVX512VL) ||
	    !boot_cpu_has(X86_FEATURE_BMI2) ||
	    !cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM |
			       XFEATURE_MASK_AVX512, NULL))
		return 0;

	err = simd_register_skciphers_compat(&aes_xts_alg_vaes_avx10_256, 1,
					     &aes_xts_simdalg_vaes_avx10_256);
	if (err)
		return err;
	err = simd_register_aeads_compat(aes_gcm_algs_vaes_avx10_256,
					 ARRAY_SIZE(aes_gcm_algs_vaes_avx10_256),
					 aes_gcm_simdalgs_vaes_avx10_256);
	if (err)
		return err;

	/*
	 * On CPUs in zmm_exclusion_list, still register the 512-bit algorithms,
	 * but with priority 1 so that the 256-bit ones are preferred.
	 */
	if (x86_match_cpu(zmm_exclusion_list)) {
		int i;

		aes_xts_alg_vaes_avx10_512.base.cra_priority = 1;
		for (i = 0; i < ARRAY_SIZE(aes_gcm_algs_vaes_avx10_512); i++)
			aes_gcm_algs_vaes_avx10_512[i].base.cra_priority = 1;
	}

	err = simd_register_skciphers_compat(&aes_xts_alg_vaes_avx10_512, 1,
					     &aes_xts_simdalg_vaes_avx10_512);
	if (err)
		return err;
	err = simd_register_aeads_compat(aes_gcm_algs_vaes_avx10_512,
					 ARRAY_SIZE(aes_gcm_algs_vaes_avx10_512),
					 aes_gcm_simdalgs_vaes_avx10_512);
	if (err)
		return err;
#endif /* CONFIG_AS_VAES && CONFIG_AS_VPCLMULQDQ */
	return 0;
}
1623 
/*
 * Unregister whichever algorithms register_avx_algs() registered.  Each group
 * is only unregistered if its simd wrapper pointer is non-NULL.  This makes the
 * function safe after a partial registration (aesni_init()'s error path) and
 * when a tier was skipped for lack of CPU features.
 */
static void unregister_avx_algs(void)
{
	if (aes_xts_simdalg_aesni_avx)
		simd_unregister_skciphers(&aes_xts_alg_aesni_avx, 1,
					  &aes_xts_simdalg_aesni_avx);
	if (aes_gcm_simdalgs_aesni_avx[0])
		simd_unregister_aeads(aes_gcm_algs_aesni_avx,
				      ARRAY_SIZE(aes_gcm_algs_aesni_avx),
				      aes_gcm_simdalgs_aesni_avx);
#if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ)
	if (aes_xts_simdalg_vaes_avx2)
		simd_unregister_skciphers(&aes_xts_alg_vaes_avx2, 1,
					  &aes_xts_simdalg_vaes_avx2);
	if (aes_xts_simdalg_vaes_avx10_256)
		simd_unregister_skciphers(&aes_xts_alg_vaes_avx10_256, 1,
					  &aes_xts_simdalg_vaes_avx10_256);
	if (aes_gcm_simdalgs_vaes_avx10_256[0])
		simd_unregister_aeads(aes_gcm_algs_vaes_avx10_256,
				      ARRAY_SIZE(aes_gcm_algs_vaes_avx10_256),
				      aes_gcm_simdalgs_vaes_avx10_256);
	if (aes_xts_simdalg_vaes_avx10_512)
		simd_unregister_skciphers(&aes_xts_alg_vaes_avx10_512, 1,
					  &aes_xts_simdalg_vaes_avx10_512);
	if (aes_gcm_simdalgs_vaes_avx10_512[0])
		simd_unregister_aeads(aes_gcm_algs_vaes_avx10_512,
				      ARRAY_SIZE(aes_gcm_algs_vaes_avx10_512),
				      aes_gcm_simdalgs_vaes_avx10_512);
#endif
}
#else /* CONFIG_X86_64 */
/*
 * Without CONFIG_X86_64 there are no AES-GCM or AVX implementations here.
 * Zero-length AES-GCM arrays let aesni_init()/aesni_exit() keep their
 * unconditional (de)registration calls, presumably as no-ops with a count
 * of 0.  The AVX hooks are empty stubs.
 */
static struct aead_alg aes_gcm_algs_aesni[0];
static struct simd_aead_alg *aes_gcm_simdalgs_aesni[0];

/* Nothing to register without x86_64; always succeeds. */
static int __init register_avx_algs(void)
{
	return 0;
}

/* Nothing was registered, so nothing to undo. */
static void unregister_avx_algs(void)
{
}
#endif /* !CONFIG_X86_64 */
1666 
/*
 * CPUs this module supports: any with the AES-NI feature (X86_FEATURE_AES).
 * aesni_init() returns -ENODEV on other CPUs.  MODULE_DEVICE_TABLE exports
 * the table, presumably for CPU-feature-based module autoloading.
 */
static const struct x86_cpu_id aesni_cpu_id[] = {
	X86_MATCH_FEATURE(X86_FEATURE_AES, NULL),
	{}
};
MODULE_DEVICE_TABLE(x86cpu, aesni_cpu_id);
1672 
1673 static int __init aesni_init(void)
1674 {
1675 	int err;
1676 
1677 	if (!x86_match_cpu(aesni_cpu_id))
1678 		return -ENODEV;
1679 #ifdef CONFIG_X86_64
1680 	if (boot_cpu_has(X86_FEATURE_AVX)) {
1681 		/* optimize performance of ctr mode encryption transform */
1682 		static_call_update(aesni_ctr_enc_tfm, aesni_ctr_enc_avx_tfm);
1683 		pr_info("AES CTR mode by8 optimization enabled\n");
1684 	}
1685 #endif /* CONFIG_X86_64 */
1686 
1687 	err = crypto_register_alg(&aesni_cipher_alg);
1688 	if (err)
1689 		return err;
1690 
1691 	err = simd_register_skciphers_compat(aesni_skciphers,
1692 					     ARRAY_SIZE(aesni_skciphers),
1693 					     aesni_simd_skciphers);
1694 	if (err)
1695 		goto unregister_cipher;
1696 
1697 	err = simd_register_aeads_compat(aes_gcm_algs_aesni,
1698 					 ARRAY_SIZE(aes_gcm_algs_aesni),
1699 					 aes_gcm_simdalgs_aesni);
1700 	if (err)
1701 		goto unregister_skciphers;
1702 
1703 #ifdef CONFIG_X86_64
1704 	if (boot_cpu_has(X86_FEATURE_AVX))
1705 		err = simd_register_skciphers_compat(&aesni_xctr, 1,
1706 						     &aesni_simd_xctr);
1707 	if (err)
1708 		goto unregister_aeads;
1709 #endif /* CONFIG_X86_64 */
1710 
1711 	err = register_avx_algs();
1712 	if (err)
1713 		goto unregister_avx;
1714 
1715 	return 0;
1716 
1717 unregister_avx:
1718 	unregister_avx_algs();
1719 #ifdef CONFIG_X86_64
1720 	if (aesni_simd_xctr)
1721 		simd_unregister_skciphers(&aesni_xctr, 1, &aesni_simd_xctr);
1722 unregister_aeads:
1723 #endif /* CONFIG_X86_64 */
1724 	simd_unregister_aeads(aes_gcm_algs_aesni,
1725 			      ARRAY_SIZE(aes_gcm_algs_aesni),
1726 			      aes_gcm_simdalgs_aesni);
1727 unregister_skciphers:
1728 	simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers),
1729 				  aesni_simd_skciphers);
1730 unregister_cipher:
1731 	crypto_unregister_alg(&aesni_cipher_alg);
1732 	return err;
1733 }
1734 
1735 static void __exit aesni_exit(void)
1736 {
1737 	simd_unregister_aeads(aes_gcm_algs_aesni,
1738 			      ARRAY_SIZE(aes_gcm_algs_aesni),
1739 			      aes_gcm_simdalgs_aesni);
1740 	simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers),
1741 				  aesni_simd_skciphers);
1742 	crypto_unregister_alg(&aesni_cipher_alg);
1743 #ifdef CONFIG_X86_64
1744 	if (boot_cpu_has(X86_FEATURE_AVX))
1745 		simd_unregister_skciphers(&aesni_xctr, 1, &aesni_simd_xctr);
1746 #endif /* CONFIG_X86_64 */
1747 	unregister_avx_algs();
1748 }
1749 
/* Module entry/exit points and metadata. */
late_initcall(aesni_init);
module_exit(aesni_exit);

MODULE_DESCRIPTION("AES cipher and modes, optimized with AES-NI or VAES instructions");
MODULE_LICENSE("GPL");
/* Crypto-API alias for "aes" (presumably for on-demand module loading). */
MODULE_ALIAS_CRYPTO("aes");
1756