xref: /linux/arch/x86/crypto/aesni-intel_glue.c (revision 7fc2cd2e4b398c57c9cf961cfea05eadbf34c05c)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Support for AES-NI and VAES instructions.  This file contains glue code.
4  * The real AES implementations are in aesni-intel_asm.S and other .S files.
5  *
6  * Copyright (C) 2008, Intel Corp.
7  *    Author: Huang Ying <ying.huang@intel.com>
8  *
9  * Added RFC4106 AES-GCM support for 128-bit keys under the AEAD
10  * interface for 64-bit kernels.
11  *    Authors: Adrian Hoban <adrian.hoban@intel.com>
12  *             Gabriele Paoloni <gabriele.paoloni@intel.com>
13  *             Tadeusz Struk (tadeusz.struk@intel.com)
14  *             Aidan O'Mahony (aidan.o.mahony@intel.com)
15  *    Copyright (c) 2010, Intel Corporation.
16  *
17  * Copyright 2024 Google LLC
18  */
19 
20 #include <linux/hardirq.h>
21 #include <linux/types.h>
22 #include <linux/module.h>
23 #include <linux/err.h>
24 #include <crypto/algapi.h>
25 #include <crypto/aes.h>
26 #include <crypto/b128ops.h>
27 #include <crypto/gcm.h>
28 #include <crypto/xts.h>
29 #include <asm/cpu_device_id.h>
30 #include <asm/simd.h>
31 #include <crypto/scatterwalk.h>
32 #include <crypto/internal/aead.h>
33 #include <crypto/internal/simd.h>
34 #include <crypto/internal/skcipher.h>
35 #include <linux/jump_label.h>
36 #include <linux/workqueue.h>
37 #include <linux/spinlock.h>
38 #include <linux/static_call.h>
39 
40 
41 #define AESNI_ALIGN	16
42 #define AESNI_ALIGN_ATTR __attribute__ ((__aligned__(AESNI_ALIGN)))
43 #define AES_BLOCK_MASK	(~(AES_BLOCK_SIZE - 1))
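/*
 * With AES_BLOCK_SIZE == 16, "nbytes & AES_BLOCK_MASK" rounds a length down to
 * a whole number of blocks, and "nbytes & (AES_BLOCK_SIZE - 1)" is the
 * remaining partial-block length.
 */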
44 #define AESNI_ALIGN_EXTRA ((AESNI_ALIGN - 1) & ~(CRYPTO_MINALIGN - 1))
45 #define CRYPTO_AES_CTX_SIZE (sizeof(struct crypto_aes_ctx) + AESNI_ALIGN_EXTRA)
46 #define XTS_AES_CTX_SIZE (sizeof(struct aesni_xts_ctx) + AESNI_ALIGN_EXTRA)
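/*
 * The *_CTX_SIZE values include AESNI_ALIGN_EXTRA spare bytes so that the
 * context can be aligned up to a 16-byte boundary (see aes_align_addr()) even
 * when the crypto API only guarantees CRYPTO_MINALIGN-aligned context memory.
 */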
47 
48 struct aesni_xts_ctx {
49 	struct crypto_aes_ctx tweak_ctx AESNI_ALIGN_ATTR;
50 	struct crypto_aes_ctx crypt_ctx AESNI_ALIGN_ATTR;
51 };
52 
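/*
 * Return @addr aligned up to a 16-byte boundary.  When the crypto API already
 * guarantees at least 16-byte alignment of the tfm context, @addr is returned
 * unchanged.
 */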
53 static inline void *aes_align_addr(void *addr)
54 {
55 	if (crypto_tfm_ctx_alignment() >= AESNI_ALIGN)
56 		return addr;
57 	return PTR_ALIGN(addr, AESNI_ALIGN);
58 }
59 
60 asmlinkage void aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
61 			      unsigned int key_len);
62 asmlinkage void aesni_enc(const void *ctx, u8 *out, const u8 *in);
63 asmlinkage void aesni_dec(const void *ctx, u8 *out, const u8 *in);
64 asmlinkage void aesni_ecb_enc(struct crypto_aes_ctx *ctx, u8 *out,
65 			      const u8 *in, unsigned int len);
66 asmlinkage void aesni_ecb_dec(struct crypto_aes_ctx *ctx, u8 *out,
67 			      const u8 *in, unsigned int len);
68 asmlinkage void aesni_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out,
69 			      const u8 *in, unsigned int len, u8 *iv);
70 asmlinkage void aesni_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
71 			      const u8 *in, unsigned int len, u8 *iv);
72 asmlinkage void aesni_cts_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out,
73 				  const u8 *in, unsigned int len, u8 *iv);
74 asmlinkage void aesni_cts_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
75 				  const u8 *in, unsigned int len, u8 *iv);
76 
77 asmlinkage void aesni_xts_enc(const struct crypto_aes_ctx *ctx, u8 *out,
78 			      const u8 *in, unsigned int len, u8 *iv);
79 
80 asmlinkage void aesni_xts_dec(const struct crypto_aes_ctx *ctx, u8 *out,
81 			      const u8 *in, unsigned int len, u8 *iv);
82 
83 #ifdef CONFIG_X86_64
84 asmlinkage void aesni_ctr_enc(struct crypto_aes_ctx *ctx, u8 *out,
85 			      const u8 *in, unsigned int len, u8 *iv);
86 #endif
87 
88 static inline struct crypto_aes_ctx *aes_ctx(void *raw_ctx)
89 {
90 	return aes_align_addr(raw_ctx);
91 }
92 
93 static inline struct aesni_xts_ctx *aes_xts_ctx(struct crypto_skcipher *tfm)
94 {
95 	return aes_align_addr(crypto_skcipher_ctx(tfm));
96 }
97 
98 static int aes_set_key_common(struct crypto_aes_ctx *ctx,
99 			      const u8 *in_key, unsigned int key_len)
100 {
101 	int err;
102 
103 	if (!crypto_simd_usable())
104 		return aes_expandkey(ctx, in_key, key_len);
105 
106 	err = aes_check_keylen(key_len);
107 	if (err)
108 		return err;
109 
110 	kernel_fpu_begin();
111 	aesni_set_key(ctx, in_key, key_len);
112 	kernel_fpu_end();
113 	return 0;
114 }
115 
116 static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
117 		       unsigned int key_len)
118 {
119 	return aes_set_key_common(aes_ctx(crypto_tfm_ctx(tfm)), in_key,
120 				  key_len);
121 }
122 
123 static void aesni_encrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
124 {
125 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_tfm_ctx(tfm));
126 
127 	if (!crypto_simd_usable()) {
128 		aes_encrypt(ctx, dst, src);
129 	} else {
130 		kernel_fpu_begin();
131 		aesni_enc(ctx, dst, src);
132 		kernel_fpu_end();
133 	}
134 }
135 
136 static void aesni_decrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
137 {
138 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_tfm_ctx(tfm));
139 
140 	if (!crypto_simd_usable()) {
141 		aes_decrypt(ctx, dst, src);
142 	} else {
143 		kernel_fpu_begin();
144 		aesni_dec(ctx, dst, src);
145 		kernel_fpu_end();
146 	}
147 }
148 
149 static int aesni_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
150 			         unsigned int len)
151 {
152 	return aes_set_key_common(aes_ctx(crypto_skcipher_ctx(tfm)), key, len);
153 }
154 
155 static int ecb_encrypt(struct skcipher_request *req)
156 {
157 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
158 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
159 	struct skcipher_walk walk;
160 	unsigned int nbytes;
161 	int err;
162 
163 	err = skcipher_walk_virt(&walk, req, false);
164 
165 	while ((nbytes = walk.nbytes)) {
166 		kernel_fpu_begin();
167 		aesni_ecb_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
168 			      nbytes & AES_BLOCK_MASK);
169 		kernel_fpu_end();
170 		nbytes &= AES_BLOCK_SIZE - 1;
171 		err = skcipher_walk_done(&walk, nbytes);
172 	}
173 
174 	return err;
175 }
176 
177 static int ecb_decrypt(struct skcipher_request *req)
178 {
179 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
180 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
181 	struct skcipher_walk walk;
182 	unsigned int nbytes;
183 	int err;
184 
185 	err = skcipher_walk_virt(&walk, req, false);
186 
187 	while ((nbytes = walk.nbytes)) {
188 		kernel_fpu_begin();
189 		aesni_ecb_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
190 			      nbytes & AES_BLOCK_MASK);
191 		kernel_fpu_end();
192 		nbytes &= AES_BLOCK_SIZE - 1;
193 		err = skcipher_walk_done(&walk, nbytes);
194 	}
195 
196 	return err;
197 }
198 
199 static int cbc_encrypt(struct skcipher_request *req)
200 {
201 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
202 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
203 	struct skcipher_walk walk;
204 	unsigned int nbytes;
205 	int err;
206 
207 	err = skcipher_walk_virt(&walk, req, false);
208 
209 	while ((nbytes = walk.nbytes)) {
210 		kernel_fpu_begin();
211 		aesni_cbc_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
212 			      nbytes & AES_BLOCK_MASK, walk.iv);
213 		kernel_fpu_end();
214 		nbytes &= AES_BLOCK_SIZE - 1;
215 		err = skcipher_walk_done(&walk, nbytes);
216 	}
217 
218 	return err;
219 }
220 
221 static int cbc_decrypt(struct skcipher_request *req)
222 {
223 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
224 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
225 	struct skcipher_walk walk;
226 	unsigned int nbytes;
227 	int err;
228 
229 	err = skcipher_walk_virt(&walk, req, false);
230 
231 	while ((nbytes = walk.nbytes)) {
232 		kernel_fpu_begin();
233 		aesni_cbc_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
234 			      nbytes & AES_BLOCK_MASK, walk.iv);
235 		kernel_fpu_end();
236 		nbytes &= AES_BLOCK_SIZE - 1;
237 		err = skcipher_walk_done(&walk, nbytes);
238 	}
239 
240 	return err;
241 }
242 
243 static int cts_cbc_encrypt(struct skcipher_request *req)
244 {
245 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
246 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
247 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
248 	struct scatterlist *src = req->src, *dst = req->dst;
249 	struct scatterlist sg_src[2], sg_dst[2];
250 	struct skcipher_request subreq;
251 	struct skcipher_walk walk;
252 	int err;
253 
254 	skcipher_request_set_tfm(&subreq, tfm);
255 	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
256 				      NULL, NULL);
257 
258 	if (req->cryptlen <= AES_BLOCK_SIZE) {
259 		if (req->cryptlen < AES_BLOCK_SIZE)
260 			return -EINVAL;
261 		cbc_blocks = 1;
262 	}
263 
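	/*
	 * Example of the block accounting (illustrative values): for
	 * cryptlen == 50, cbc_blocks == DIV_ROUND_UP(50, 16) - 2 == 2, so the
	 * first 32 bytes are handled as plain CBC below and the remaining 18
	 * bytes (one full block plus a 2-byte partial block) go to the
	 * ciphertext stealing call at the end.
	 */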
264 	if (cbc_blocks > 0) {
265 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
266 					   cbc_blocks * AES_BLOCK_SIZE,
267 					   req->iv);
268 
269 		err = cbc_encrypt(&subreq);
270 		if (err)
271 			return err;
272 
273 		if (req->cryptlen == AES_BLOCK_SIZE)
274 			return 0;
275 
276 		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
277 		if (req->dst != req->src)
278 			dst = scatterwalk_ffwd(sg_dst, req->dst,
279 					       subreq.cryptlen);
280 	}
281 
282 	/* handle ciphertext stealing */
283 	skcipher_request_set_crypt(&subreq, src, dst,
284 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
285 				   req->iv);
286 
287 	err = skcipher_walk_virt(&walk, &subreq, false);
288 	if (err)
289 		return err;
290 
291 	kernel_fpu_begin();
292 	aesni_cts_cbc_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
293 			  walk.nbytes, walk.iv);
294 	kernel_fpu_end();
295 
296 	return skcipher_walk_done(&walk, 0);
297 }
298 
299 static int cts_cbc_decrypt(struct skcipher_request *req)
300 {
301 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
302 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
303 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
304 	struct scatterlist *src = req->src, *dst = req->dst;
305 	struct scatterlist sg_src[2], sg_dst[2];
306 	struct skcipher_request subreq;
307 	struct skcipher_walk walk;
308 	int err;
309 
310 	skcipher_request_set_tfm(&subreq, tfm);
311 	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
312 				      NULL, NULL);
313 
314 	if (req->cryptlen <= AES_BLOCK_SIZE) {
315 		if (req->cryptlen < AES_BLOCK_SIZE)
316 			return -EINVAL;
317 		cbc_blocks = 1;
318 	}
319 
320 	if (cbc_blocks > 0) {
321 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
322 					   cbc_blocks * AES_BLOCK_SIZE,
323 					   req->iv);
324 
325 		err = cbc_decrypt(&subreq);
326 		if (err)
327 			return err;
328 
329 		if (req->cryptlen == AES_BLOCK_SIZE)
330 			return 0;
331 
332 		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
333 		if (req->dst != req->src)
334 			dst = scatterwalk_ffwd(sg_dst, req->dst,
335 					       subreq.cryptlen);
336 	}
337 
338 	/* handle ciphertext stealing */
339 	skcipher_request_set_crypt(&subreq, src, dst,
340 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
341 				   req->iv);
342 
343 	err = skcipher_walk_virt(&walk, &subreq, false);
344 	if (err)
345 		return err;
346 
347 	kernel_fpu_begin();
348 	aesni_cts_cbc_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
349 			  walk.nbytes, walk.iv);
350 	kernel_fpu_end();
351 
352 	return skcipher_walk_done(&walk, 0);
353 }
354 
355 #ifdef CONFIG_X86_64
356 /* This is the non-AVX version. */
357 static int ctr_crypt_aesni(struct skcipher_request *req)
358 {
359 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
360 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
361 	u8 keystream[AES_BLOCK_SIZE];
362 	struct skcipher_walk walk;
363 	unsigned int nbytes;
364 	int err;
365 
366 	err = skcipher_walk_virt(&walk, req, false);
367 
368 	while ((nbytes = walk.nbytes) > 0) {
369 		kernel_fpu_begin();
370 		if (nbytes & AES_BLOCK_MASK)
371 			aesni_ctr_enc(ctx, walk.dst.virt.addr,
372 				      walk.src.virt.addr,
373 				      nbytes & AES_BLOCK_MASK, walk.iv);
374 		nbytes &= ~AES_BLOCK_MASK;
375 
376 		if (walk.nbytes == walk.total && nbytes > 0) {
377 			aesni_enc(ctx, keystream, walk.iv);
378 			crypto_xor_cpy(walk.dst.virt.addr + walk.nbytes - nbytes,
379 				       walk.src.virt.addr + walk.nbytes - nbytes,
380 				       keystream, nbytes);
381 			crypto_inc(walk.iv, AES_BLOCK_SIZE);
382 			nbytes = 0;
383 		}
384 		kernel_fpu_end();
385 		err = skcipher_walk_done(&walk, nbytes);
386 	}
387 	return err;
388 }
389 #endif
390 
391 static int xts_setkey_aesni(struct crypto_skcipher *tfm, const u8 *key,
392 			    unsigned int keylen)
393 {
394 	struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
395 	int err;
396 
397 	err = xts_verify_key(tfm, key, keylen);
398 	if (err)
399 		return err;
400 
401 	keylen /= 2;
402 
403 	/* first half of xts-key is for crypt */
404 	err = aes_set_key_common(&ctx->crypt_ctx, key, keylen);
405 	if (err)
406 		return err;
407 
408 	/* second half of xts-key is for tweak */
409 	return aes_set_key_common(&ctx->tweak_ctx, key + keylen, keylen);
410 }
411 
412 typedef void (*xts_encrypt_iv_func)(const struct crypto_aes_ctx *tweak_key,
413 				    u8 iv[AES_BLOCK_SIZE]);
414 typedef void (*xts_crypt_func)(const struct crypto_aes_ctx *key,
415 			       const u8 *src, u8 *dst, int len,
416 			       u8 tweak[AES_BLOCK_SIZE]);
417 
418 /* This handles cases where the source and/or destination span pages. */
419 static noinline int
420 xts_crypt_slowpath(struct skcipher_request *req, xts_crypt_func crypt_func)
421 {
422 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
423 	const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
424 	int tail = req->cryptlen % AES_BLOCK_SIZE;
425 	struct scatterlist sg_src[2], sg_dst[2];
426 	struct skcipher_request subreq;
427 	struct skcipher_walk walk;
428 	struct scatterlist *src, *dst;
429 	int err;
430 
431 	/*
432 	 * If the message length isn't divisible by the AES block size, then
433 	 * separate off the last full block and the partial block.  This ensures
434 	 * that they are processed in the same call to the assembly function,
435 	 * which is required for ciphertext stealing.
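	 * For example, with cryptlen == 50 (tail == 2), the subrequest below
	 * covers the first 50 - 2 - 16 == 32 bytes, and the final 18 bytes
	 * (the last full block plus the 2-byte tail) are processed together
	 * afterwards.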
436 	 */
437 	if (tail) {
438 		skcipher_request_set_tfm(&subreq, tfm);
439 		skcipher_request_set_callback(&subreq,
440 					      skcipher_request_flags(req),
441 					      NULL, NULL);
442 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
443 					   req->cryptlen - tail - AES_BLOCK_SIZE,
444 					   req->iv);
445 		req = &subreq;
446 	}
447 
448 	err = skcipher_walk_virt(&walk, req, false);
449 
450 	while (walk.nbytes) {
451 		kernel_fpu_begin();
452 		(*crypt_func)(&ctx->crypt_ctx,
453 			      walk.src.virt.addr, walk.dst.virt.addr,
454 			      walk.nbytes & ~(AES_BLOCK_SIZE - 1), req->iv);
455 		kernel_fpu_end();
456 		err = skcipher_walk_done(&walk,
457 					 walk.nbytes & (AES_BLOCK_SIZE - 1));
458 	}
459 
460 	if (err || !tail)
461 		return err;
462 
463 	/* Do ciphertext stealing with the last full block and partial block. */
464 
465 	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
466 	if (req->dst != req->src)
467 		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
468 
469 	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
470 				   req->iv);
471 
472 	err = skcipher_walk_virt(&walk, req, false);
473 	if (err)
474 		return err;
475 
476 	kernel_fpu_begin();
477 	(*crypt_func)(&ctx->crypt_ctx, walk.src.virt.addr, walk.dst.virt.addr,
478 		      walk.nbytes, req->iv);
479 	kernel_fpu_end();
480 
481 	return skcipher_walk_done(&walk, 0);
482 }
483 
484 /* __always_inline to avoid indirect call in fastpath */
485 static __always_inline int
486 xts_crypt(struct skcipher_request *req, xts_encrypt_iv_func encrypt_iv,
487 	  xts_crypt_func crypt_func)
488 {
489 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
490 	const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
491 
492 	if (unlikely(req->cryptlen < AES_BLOCK_SIZE))
493 		return -EINVAL;
494 
495 	kernel_fpu_begin();
496 	(*encrypt_iv)(&ctx->tweak_ctx, req->iv);
497 
498 	/*
499 	 * In practice, virtually all XTS plaintexts and ciphertexts are either
500 	 * 512 or 4096 bytes and do not use multiple scatterlist elements.  To
501 	 * optimize the performance of these cases, the below fast-path handles
502 	 * single-scatterlist-element messages as efficiently as possible.  The
503 	 * code is 64-bit specific, as it assumes no page mapping is needed.
504 	 */
505 	if (IS_ENABLED(CONFIG_X86_64) &&
506 	    likely(req->src->length >= req->cryptlen &&
507 		   req->dst->length >= req->cryptlen)) {
508 		(*crypt_func)(&ctx->crypt_ctx, sg_virt(req->src),
509 			      sg_virt(req->dst), req->cryptlen, req->iv);
510 		kernel_fpu_end();
511 		return 0;
512 	}
513 	kernel_fpu_end();
514 	return xts_crypt_slowpath(req, crypt_func);
515 }
516 
517 static void aesni_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
518 				 u8 iv[AES_BLOCK_SIZE])
519 {
520 	aesni_enc(tweak_key, iv, iv);
521 }
522 
523 static void aesni_xts_encrypt(const struct crypto_aes_ctx *key,
524 			      const u8 *src, u8 *dst, int len,
525 			      u8 tweak[AES_BLOCK_SIZE])
526 {
527 	aesni_xts_enc(key, dst, src, len, tweak);
528 }
529 
530 static void aesni_xts_decrypt(const struct crypto_aes_ctx *key,
531 			      const u8 *src, u8 *dst, int len,
532 			      u8 tweak[AES_BLOCK_SIZE])
533 {
534 	aesni_xts_dec(key, dst, src, len, tweak);
535 }
536 
537 static int xts_encrypt_aesni(struct skcipher_request *req)
538 {
539 	return xts_crypt(req, aesni_xts_encrypt_iv, aesni_xts_encrypt);
540 }
541 
542 static int xts_decrypt_aesni(struct skcipher_request *req)
543 {
544 	return xts_crypt(req, aesni_xts_encrypt_iv, aesni_xts_decrypt);
545 }
546 
547 static struct crypto_alg aesni_cipher_alg = {
548 	.cra_name		= "aes",
549 	.cra_driver_name	= "aes-aesni",
550 	.cra_priority		= 300,
551 	.cra_flags		= CRYPTO_ALG_TYPE_CIPHER,
552 	.cra_blocksize		= AES_BLOCK_SIZE,
553 	.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
554 	.cra_module		= THIS_MODULE,
555 	.cra_u	= {
556 		.cipher	= {
557 			.cia_min_keysize	= AES_MIN_KEY_SIZE,
558 			.cia_max_keysize	= AES_MAX_KEY_SIZE,
559 			.cia_setkey		= aes_set_key,
560 			.cia_encrypt		= aesni_encrypt,
561 			.cia_decrypt		= aesni_decrypt
562 		}
563 	}
564 };
565 
566 static struct skcipher_alg aesni_skciphers[] = {
567 	{
568 		.base = {
569 			.cra_name		= "ecb(aes)",
570 			.cra_driver_name	= "ecb-aes-aesni",
571 			.cra_priority		= 400,
572 			.cra_blocksize		= AES_BLOCK_SIZE,
573 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
574 			.cra_module		= THIS_MODULE,
575 		},
576 		.min_keysize	= AES_MIN_KEY_SIZE,
577 		.max_keysize	= AES_MAX_KEY_SIZE,
578 		.setkey		= aesni_skcipher_setkey,
579 		.encrypt	= ecb_encrypt,
580 		.decrypt	= ecb_decrypt,
581 	}, {
582 		.base = {
583 			.cra_name		= "cbc(aes)",
584 			.cra_driver_name	= "cbc-aes-aesni",
585 			.cra_priority		= 400,
586 			.cra_blocksize		= AES_BLOCK_SIZE,
587 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
588 			.cra_module		= THIS_MODULE,
589 		},
590 		.min_keysize	= AES_MIN_KEY_SIZE,
591 		.max_keysize	= AES_MAX_KEY_SIZE,
592 		.ivsize		= AES_BLOCK_SIZE,
593 		.setkey		= aesni_skcipher_setkey,
594 		.encrypt	= cbc_encrypt,
595 		.decrypt	= cbc_decrypt,
596 	}, {
597 		.base = {
598 			.cra_name		= "cts(cbc(aes))",
599 			.cra_driver_name	= "cts-cbc-aes-aesni",
600 			.cra_priority		= 400,
601 			.cra_blocksize		= AES_BLOCK_SIZE,
602 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
603 			.cra_module		= THIS_MODULE,
604 		},
605 		.min_keysize	= AES_MIN_KEY_SIZE,
606 		.max_keysize	= AES_MAX_KEY_SIZE,
607 		.ivsize		= AES_BLOCK_SIZE,
608 		.walksize	= 2 * AES_BLOCK_SIZE,
609 		.setkey		= aesni_skcipher_setkey,
610 		.encrypt	= cts_cbc_encrypt,
611 		.decrypt	= cts_cbc_decrypt,
612 #ifdef CONFIG_X86_64
613 	}, {
614 		.base = {
615 			.cra_name		= "ctr(aes)",
616 			.cra_driver_name	= "ctr-aes-aesni",
617 			.cra_priority		= 400,
618 			.cra_blocksize		= 1,
619 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
620 			.cra_module		= THIS_MODULE,
621 		},
622 		.min_keysize	= AES_MIN_KEY_SIZE,
623 		.max_keysize	= AES_MAX_KEY_SIZE,
624 		.ivsize		= AES_BLOCK_SIZE,
625 		.chunksize	= AES_BLOCK_SIZE,
626 		.setkey		= aesni_skcipher_setkey,
627 		.encrypt	= ctr_crypt_aesni,
628 		.decrypt	= ctr_crypt_aesni,
629 #endif
630 	}, {
631 		.base = {
632 			.cra_name		= "xts(aes)",
633 			.cra_driver_name	= "xts-aes-aesni",
634 			.cra_priority		= 401,
635 			.cra_blocksize		= AES_BLOCK_SIZE,
636 			.cra_ctxsize		= XTS_AES_CTX_SIZE,
637 			.cra_module		= THIS_MODULE,
638 		},
639 		.min_keysize	= 2 * AES_MIN_KEY_SIZE,
640 		.max_keysize	= 2 * AES_MAX_KEY_SIZE,
641 		.ivsize		= AES_BLOCK_SIZE,
642 		.walksize	= 2 * AES_BLOCK_SIZE,
643 		.setkey		= xts_setkey_aesni,
644 		.encrypt	= xts_encrypt_aesni,
645 		.decrypt	= xts_decrypt_aesni,
646 	}
647 };
648 
649 #ifdef CONFIG_X86_64
650 asmlinkage void aes_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
651 				   u8 iv[AES_BLOCK_SIZE]);
652 
653 /* __always_inline to avoid indirect call */
654 static __always_inline int
655 ctr_crypt(struct skcipher_request *req,
656 	  void (*ctr64_func)(const struct crypto_aes_ctx *key,
657 			     const u8 *src, u8 *dst, int len,
658 			     const u64 le_ctr[2]))
659 {
660 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
661 	const struct crypto_aes_ctx *key = aes_ctx(crypto_skcipher_ctx(tfm));
662 	unsigned int nbytes, p1_nbytes, nblocks;
663 	struct skcipher_walk walk;
664 	u64 le_ctr[2];
665 	u64 ctr64;
666 	int err;
667 
668 	ctr64 = le_ctr[0] = get_unaligned_be64(&req->iv[8]);
669 	le_ctr[1] = get_unaligned_be64(&req->iv[0]);
670 
671 	err = skcipher_walk_virt(&walk, req, false);
672 
673 	while ((nbytes = walk.nbytes) != 0) {
674 		if (nbytes < walk.total) {
675 			/* Not the end yet, so keep the length block-aligned. */
676 			nbytes = round_down(nbytes, AES_BLOCK_SIZE);
677 			nblocks = nbytes / AES_BLOCK_SIZE;
678 		} else {
679 			/* It's the end, so include any final partial block. */
680 			nblocks = DIV_ROUND_UP(nbytes, AES_BLOCK_SIZE);
681 		}
682 		ctr64 += nblocks;
683 
684 		kernel_fpu_begin();
685 		if (likely(ctr64 >= nblocks)) {
686 			/* The low 64 bits of the counter won't overflow. */
687 			(*ctr64_func)(key, walk.src.virt.addr,
688 				      walk.dst.virt.addr, nbytes, le_ctr);
689 		} else {
690 			/*
691 			 * The low 64 bits of the counter will overflow.  The
692 			 * assembly doesn't handle this case, so split the
693 			 * operation into two at the point where the overflow
694 			 * will occur.  After the first part, add the carry bit.
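			 * For example (illustrative values): if le_ctr[0] is
			 * 0xfffffffffffffffe and nblocks is 4, ctr64 wraps to
			 * 2, so the first 4 - 2 = 2 blocks use the original
			 * high half, then le_ctr[0] restarts at 0 with
			 * le_ctr[1] incremented for the remaining blocks.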
695 			 */
696 			p1_nbytes = min_t(unsigned int, nbytes,
697 					  (nblocks - ctr64) * AES_BLOCK_SIZE);
698 			(*ctr64_func)(key, walk.src.virt.addr,
699 				      walk.dst.virt.addr, p1_nbytes, le_ctr);
700 			le_ctr[0] = 0;
701 			le_ctr[1]++;
702 			(*ctr64_func)(key, walk.src.virt.addr + p1_nbytes,
703 				      walk.dst.virt.addr + p1_nbytes,
704 				      nbytes - p1_nbytes, le_ctr);
705 		}
706 		kernel_fpu_end();
707 		le_ctr[0] = ctr64;
708 
709 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
710 	}
711 
712 	put_unaligned_be64(ctr64, &req->iv[8]);
713 	put_unaligned_be64(le_ctr[1], &req->iv[0]);
714 
715 	return err;
716 }
717 
718 /* __always_inline to avoid indirect call */
719 static __always_inline int
720 xctr_crypt(struct skcipher_request *req,
721 	   void (*xctr_func)(const struct crypto_aes_ctx *key,
722 			     const u8 *src, u8 *dst, int len,
723 			     const u8 iv[AES_BLOCK_SIZE], u64 ctr))
724 {
725 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
726 	const struct crypto_aes_ctx *key = aes_ctx(crypto_skcipher_ctx(tfm));
727 	struct skcipher_walk walk;
728 	unsigned int nbytes;
729 	u64 ctr = 1;
730 	int err;
731 
732 	err = skcipher_walk_virt(&walk, req, false);
733 	while ((nbytes = walk.nbytes) != 0) {
734 		if (nbytes < walk.total)
735 			nbytes = round_down(nbytes, AES_BLOCK_SIZE);
736 
737 		kernel_fpu_begin();
738 		(*xctr_func)(key, walk.src.virt.addr, walk.dst.virt.addr,
739 			     nbytes, req->iv, ctr);
740 		kernel_fpu_end();
741 
742 		ctr += DIV_ROUND_UP(nbytes, AES_BLOCK_SIZE);
743 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
744 	}
745 	return err;
746 }
747 
748 #define DEFINE_AVX_SKCIPHER_ALGS(suffix, driver_name_suffix, priority)	       \
749 									       \
750 asmlinkage void								       \
751 aes_xts_encrypt_##suffix(const struct crypto_aes_ctx *key, const u8 *src,      \
752 			 u8 *dst, int len, u8 tweak[AES_BLOCK_SIZE]);	       \
753 asmlinkage void								       \
754 aes_xts_decrypt_##suffix(const struct crypto_aes_ctx *key, const u8 *src,      \
755 			 u8 *dst, int len, u8 tweak[AES_BLOCK_SIZE]);	       \
756 									       \
757 static int xts_encrypt_##suffix(struct skcipher_request *req)		       \
758 {									       \
759 	return xts_crypt(req, aes_xts_encrypt_iv, aes_xts_encrypt_##suffix);   \
760 }									       \
761 									       \
762 static int xts_decrypt_##suffix(struct skcipher_request *req)		       \
763 {									       \
764 	return xts_crypt(req, aes_xts_encrypt_iv, aes_xts_decrypt_##suffix);   \
765 }									       \
766 									       \
767 asmlinkage void								       \
768 aes_ctr64_crypt_##suffix(const struct crypto_aes_ctx *key,		       \
769 			 const u8 *src, u8 *dst, int len, const u64 le_ctr[2]);\
770 									       \
771 static int ctr_crypt_##suffix(struct skcipher_request *req)		       \
772 {									       \
773 	return ctr_crypt(req, aes_ctr64_crypt_##suffix);		       \
774 }									       \
775 									       \
776 asmlinkage void								       \
777 aes_xctr_crypt_##suffix(const struct crypto_aes_ctx *key,		       \
778 			const u8 *src, u8 *dst, int len,		       \
779 			const u8 iv[AES_BLOCK_SIZE], u64 ctr);		       \
780 									       \
781 static int xctr_crypt_##suffix(struct skcipher_request *req)		       \
782 {									       \
783 	return xctr_crypt(req, aes_xctr_crypt_##suffix);		       \
784 }									       \
785 									       \
786 static struct skcipher_alg skcipher_algs_##suffix[] = {{		       \
787 	.base.cra_name		= "xts(aes)",				       \
788 	.base.cra_driver_name	= "xts-aes-" driver_name_suffix,	       \
789 	.base.cra_priority	= priority,				       \
790 	.base.cra_blocksize	= AES_BLOCK_SIZE,			       \
791 	.base.cra_ctxsize	= XTS_AES_CTX_SIZE,			       \
792 	.base.cra_module	= THIS_MODULE,				       \
793 	.min_keysize		= 2 * AES_MIN_KEY_SIZE,			       \
794 	.max_keysize		= 2 * AES_MAX_KEY_SIZE,			       \
795 	.ivsize			= AES_BLOCK_SIZE,			       \
796 	.walksize		= 2 * AES_BLOCK_SIZE,			       \
797 	.setkey			= xts_setkey_aesni,			       \
798 	.encrypt		= xts_encrypt_##suffix,			       \
799 	.decrypt		= xts_decrypt_##suffix,			       \
800 }, {									       \
801 	.base.cra_name		= "ctr(aes)",				       \
802 	.base.cra_driver_name	= "ctr-aes-" driver_name_suffix,	       \
803 	.base.cra_priority	= priority,				       \
804 	.base.cra_blocksize	= 1,					       \
805 	.base.cra_ctxsize	= CRYPTO_AES_CTX_SIZE,			       \
806 	.base.cra_module	= THIS_MODULE,				       \
807 	.min_keysize		= AES_MIN_KEY_SIZE,			       \
808 	.max_keysize		= AES_MAX_KEY_SIZE,			       \
809 	.ivsize			= AES_BLOCK_SIZE,			       \
810 	.chunksize		= AES_BLOCK_SIZE,			       \
811 	.setkey			= aesni_skcipher_setkey,		       \
812 	.encrypt		= ctr_crypt_##suffix,			       \
813 	.decrypt		= ctr_crypt_##suffix,			       \
814 }, {									       \
815 	.base.cra_name		= "xctr(aes)",				       \
816 	.base.cra_driver_name	= "xctr-aes-" driver_name_suffix,	       \
817 	.base.cra_priority	= priority,				       \
818 	.base.cra_blocksize	= 1,					       \
819 	.base.cra_ctxsize	= CRYPTO_AES_CTX_SIZE,			       \
820 	.base.cra_module	= THIS_MODULE,				       \
821 	.min_keysize		= AES_MIN_KEY_SIZE,			       \
822 	.max_keysize		= AES_MAX_KEY_SIZE,			       \
823 	.ivsize			= AES_BLOCK_SIZE,			       \
824 	.chunksize		= AES_BLOCK_SIZE,			       \
825 	.setkey			= aesni_skcipher_setkey,		       \
826 	.encrypt		= xctr_crypt_##suffix,			       \
827 	.decrypt		= xctr_crypt_##suffix,			       \
828 }}
829 
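/*
 * Each invocation below declares the assembly entry points for the given
 * suffix and defines xts(aes), ctr(aes), and xctr(aes) algorithms whose driver
 * names end in the given driver_name_suffix and register with the given
 * priority.
 */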
830 DEFINE_AVX_SKCIPHER_ALGS(aesni_avx, "aesni-avx", 500);
831 DEFINE_AVX_SKCIPHER_ALGS(vaes_avx2, "vaes-avx2", 600);
832 DEFINE_AVX_SKCIPHER_ALGS(vaes_avx512, "vaes-avx512", 800);
833 
834 /* The common part of the x86_64 AES-GCM key struct */
835 struct aes_gcm_key {
836 	/* Expanded AES key and the AES key length in bytes */
837 	struct crypto_aes_ctx aes_key;
838 
839 	/* RFC4106 nonce (used only by the rfc4106 algorithms) */
840 	u32 rfc4106_nonce;
841 };
842 
843 /* Key struct used by the AES-NI implementations of AES-GCM */
844 struct aes_gcm_key_aesni {
845 	/*
846 	 * Common part of the key.  The assembly code requires 16-byte alignment
847 	 * for the round keys; we get this by them being located at the start of
848 	 * the struct and the whole struct being 16-byte aligned.
849 	 */
850 	struct aes_gcm_key base;
851 
852 	/*
853 	 * Powers of the hash key H^8 through H^1.  These are 128-bit values.
854 	 * They all have an extra factor of x^-1 and are byte-reversed.  16-byte
855 	 * alignment is required by the assembly code.
856 	 */
857 	u64 h_powers[8][2] __aligned(16);
858 
859 	/*
860 	 * h_powers_xored[i] contains the two 64-bit halves of h_powers[i] XOR'd
861 	 * together.  It's used for Karatsuba multiplication.  16-byte alignment
862 	 * is required by the assembly code.
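	 * (Recall the Karatsuba identity for a carry-less product: with
	 * a = a1*x^64 + a0 and b = b1*x^64 + b0, the product a*b is assembled
	 * from a1*b1, a0*b0, and (a1 ^ a0)*(b1 ^ b0); the value stored here is
	 * the precomputed (b1 ^ b0) term for each hash key power.)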
863 	 */
864 	u64 h_powers_xored[8] __aligned(16);
865 
866 	/*
867 	 * H^1 times x^64 (and also the usual extra factor of x^-1).  16-byte
868 	 * alignment is required by the assembly code.
869 	 */
870 	u64 h_times_x64[2] __aligned(16);
871 };
872 #define AES_GCM_KEY_AESNI(key)	\
873 	container_of((key), struct aes_gcm_key_aesni, base)
874 #define AES_GCM_KEY_AESNI_SIZE	\
875 	(sizeof(struct aes_gcm_key_aesni) + (15 & ~(CRYPTO_MINALIGN - 1)))
876 
877 /* Key struct used by the VAES + AVX2 implementation of AES-GCM */
878 struct aes_gcm_key_vaes_avx2 {
879 	/*
880 	 * Common part of the key.  The assembly code prefers 16-byte alignment
881 	 * for the round keys; we get this by them being located at the start of
882 	 * the struct and the whole struct being 32-byte aligned.
883 	 */
884 	struct aes_gcm_key base;
885 
886 	/*
887 	 * Powers of the hash key H^8 through H^1.  These are 128-bit values.
888 	 * They all have an extra factor of x^-1 and are byte-reversed.
889 	 * The assembly code prefers 32-byte alignment for this.
890 	 */
891 	u64 h_powers[8][2] __aligned(32);
892 
893 	/*
894 	 * Each entry in this array contains the two halves of an entry of
895 	 * h_powers XOR'd together, in the following order:
896 	 * H^8,H^6,H^7,H^5,H^4,H^2,H^3,H^1 i.e. indices 0,2,1,3,4,6,5,7.
897 	 * This is used for Karatsuba multiplication.
898 	 */
899 	u64 h_powers_xored[8];
900 };
901 
902 #define AES_GCM_KEY_VAES_AVX2(key) \
903 	container_of((key), struct aes_gcm_key_vaes_avx2, base)
904 #define AES_GCM_KEY_VAES_AVX2_SIZE \
905 	(sizeof(struct aes_gcm_key_vaes_avx2) + (31 & ~(CRYPTO_MINALIGN - 1)))
906 
907 /* Key struct used by the VAES + AVX512 implementation of AES-GCM */
908 struct aes_gcm_key_vaes_avx512 {
909 	/*
910 	 * Common part of the key.  The assembly code prefers 16-byte alignment
911 	 * for the round keys; we get this by them being located at the start of
912 	 * the struct and the whole struct being 64-byte aligned.
913 	 */
914 	struct aes_gcm_key base;
915 
916 	/*
917 	 * Powers of the hash key H^16 through H^1.  These are 128-bit values.
918 	 * They all have an extra factor of x^-1 and are byte-reversed.  This
919 	 * array is aligned to a 64-byte boundary to make it naturally aligned
920 	 * for 512-bit loads, which can improve performance.  (The assembly code
921 	 * doesn't *need* the alignment; this is just an optimization.)
922 	 */
923 	u64 h_powers[16][2] __aligned(64);
924 
925 	/* Three padding blocks required by the assembly code */
926 	u64 padding[3][2];
927 };
928 #define AES_GCM_KEY_VAES_AVX512(key) \
929 	container_of((key), struct aes_gcm_key_vaes_avx512, base)
930 #define AES_GCM_KEY_VAES_AVX512_SIZE \
931 	(sizeof(struct aes_gcm_key_vaes_avx512) + (63 & ~(CRYPTO_MINALIGN - 1)))
932 
933 /*
934  * These flags are passed to the AES-GCM helper functions to specify the
935  * specific version of AES-GCM (RFC4106 or not), whether it's encryption or
936  * decryption, and which assembly functions should be called.  Assembly
937  * functions are selected using flags instead of function pointers to avoid
938  * indirect calls (which are very expensive on x86) regardless of inlining.
939  */
940 #define FLAG_RFC4106	BIT(0)
941 #define FLAG_ENC	BIT(1)
942 #define FLAG_AVX	BIT(2)
943 #define FLAG_VAES_AVX2	BIT(3)
944 #define FLAG_VAES_AVX512 BIT(4)
945 
946 static inline struct aes_gcm_key *
947 aes_gcm_key_get(struct crypto_aead *tfm, int flags)
948 {
949 	if (flags & FLAG_VAES_AVX512)
950 		return PTR_ALIGN(crypto_aead_ctx(tfm), 64);
951 	else if (flags & FLAG_VAES_AVX2)
952 		return PTR_ALIGN(crypto_aead_ctx(tfm), 32);
953 	else
954 		return PTR_ALIGN(crypto_aead_ctx(tfm), 16);
955 }
956 
957 asmlinkage void
958 aes_gcm_precompute_aesni(struct aes_gcm_key_aesni *key);
959 asmlinkage void
960 aes_gcm_precompute_aesni_avx(struct aes_gcm_key_aesni *key);
961 asmlinkage void
962 aes_gcm_precompute_vaes_avx2(struct aes_gcm_key_vaes_avx2 *key);
963 asmlinkage void
964 aes_gcm_precompute_vaes_avx512(struct aes_gcm_key_vaes_avx512 *key);
965 
966 static void aes_gcm_precompute(struct aes_gcm_key *key, int flags)
967 {
968 	if (flags & FLAG_VAES_AVX512)
969 		aes_gcm_precompute_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key));
970 	else if (flags & FLAG_VAES_AVX2)
971 		aes_gcm_precompute_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key));
972 	else if (flags & FLAG_AVX)
973 		aes_gcm_precompute_aesni_avx(AES_GCM_KEY_AESNI(key));
974 	else
975 		aes_gcm_precompute_aesni(AES_GCM_KEY_AESNI(key));
976 }
977 
978 asmlinkage void
979 aes_gcm_aad_update_aesni(const struct aes_gcm_key_aesni *key,
980 			 u8 ghash_acc[16], const u8 *aad, int aadlen);
981 asmlinkage void
982 aes_gcm_aad_update_aesni_avx(const struct aes_gcm_key_aesni *key,
983 			     u8 ghash_acc[16], const u8 *aad, int aadlen);
984 asmlinkage void
985 aes_gcm_aad_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
986 			     u8 ghash_acc[16], const u8 *aad, int aadlen);
987 asmlinkage void
988 aes_gcm_aad_update_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
989 			       u8 ghash_acc[16], const u8 *aad, int aadlen);
990 
991 static void aes_gcm_aad_update(const struct aes_gcm_key *key, u8 ghash_acc[16],
992 			       const u8 *aad, int aadlen, int flags)
993 {
994 	if (flags & FLAG_VAES_AVX512)
995 		aes_gcm_aad_update_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
996 					       ghash_acc, aad, aadlen);
997 	else if (flags & FLAG_VAES_AVX2)
998 		aes_gcm_aad_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
999 					     ghash_acc, aad, aadlen);
1000 	else if (flags & FLAG_AVX)
1001 		aes_gcm_aad_update_aesni_avx(AES_GCM_KEY_AESNI(key), ghash_acc,
1002 					     aad, aadlen);
1003 	else
1004 		aes_gcm_aad_update_aesni(AES_GCM_KEY_AESNI(key), ghash_acc,
1005 					 aad, aadlen);
1006 }
1007 
1008 asmlinkage void
1009 aes_gcm_enc_update_aesni(const struct aes_gcm_key_aesni *key,
1010 			 const u32 le_ctr[4], u8 ghash_acc[16],
1011 			 const u8 *src, u8 *dst, int datalen);
1012 asmlinkage void
1013 aes_gcm_enc_update_aesni_avx(const struct aes_gcm_key_aesni *key,
1014 			     const u32 le_ctr[4], u8 ghash_acc[16],
1015 			     const u8 *src, u8 *dst, int datalen);
1016 asmlinkage void
1017 aes_gcm_enc_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
1018 			     const u32 le_ctr[4], u8 ghash_acc[16],
1019 			     const u8 *src, u8 *dst, int datalen);
1020 asmlinkage void
1021 aes_gcm_enc_update_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
1022 			       const u32 le_ctr[4], u8 ghash_acc[16],
1023 			       const u8 *src, u8 *dst, int datalen);
1024 
1025 asmlinkage void
1026 aes_gcm_dec_update_aesni(const struct aes_gcm_key_aesni *key,
1027 			 const u32 le_ctr[4], u8 ghash_acc[16],
1028 			 const u8 *src, u8 *dst, int datalen);
1029 asmlinkage void
1030 aes_gcm_dec_update_aesni_avx(const struct aes_gcm_key_aesni *key,
1031 			     const u32 le_ctr[4], u8 ghash_acc[16],
1032 			     const u8 *src, u8 *dst, int datalen);
1033 asmlinkage void
1034 aes_gcm_dec_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
1035 			     const u32 le_ctr[4], u8 ghash_acc[16],
1036 			     const u8 *src, u8 *dst, int datalen);
1037 asmlinkage void
1038 aes_gcm_dec_update_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
1039 			       const u32 le_ctr[4], u8 ghash_acc[16],
1040 			       const u8 *src, u8 *dst, int datalen);
1041 
1042 /* __always_inline to optimize out the branches based on @flags */
1043 static __always_inline void
1044 aes_gcm_update(const struct aes_gcm_key *key,
1045 	       const u32 le_ctr[4], u8 ghash_acc[16],
1046 	       const u8 *src, u8 *dst, int datalen, int flags)
1047 {
1048 	if (flags & FLAG_ENC) {
1049 		if (flags & FLAG_VAES_AVX512)
1050 			aes_gcm_enc_update_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
1051 						       le_ctr, ghash_acc,
1052 						       src, dst, datalen);
1053 		else if (flags & FLAG_VAES_AVX2)
1054 			aes_gcm_enc_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
1055 						     le_ctr, ghash_acc,
1056 						     src, dst, datalen);
1057 		else if (flags & FLAG_AVX)
1058 			aes_gcm_enc_update_aesni_avx(AES_GCM_KEY_AESNI(key),
1059 						     le_ctr, ghash_acc,
1060 						     src, dst, datalen);
1061 		else
1062 			aes_gcm_enc_update_aesni(AES_GCM_KEY_AESNI(key), le_ctr,
1063 						 ghash_acc, src, dst, datalen);
1064 	} else {
1065 		if (flags & FLAG_VAES_AVX512)
1066 			aes_gcm_dec_update_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
1067 						       le_ctr, ghash_acc,
1068 						       src, dst, datalen);
1069 		else if (flags & FLAG_VAES_AVX2)
1070 			aes_gcm_dec_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
1071 						     le_ctr, ghash_acc,
1072 						     src, dst, datalen);
1073 		else if (flags & FLAG_AVX)
1074 			aes_gcm_dec_update_aesni_avx(AES_GCM_KEY_AESNI(key),
1075 						     le_ctr, ghash_acc,
1076 						     src, dst, datalen);
1077 		else
1078 			aes_gcm_dec_update_aesni(AES_GCM_KEY_AESNI(key),
1079 						 le_ctr, ghash_acc,
1080 						 src, dst, datalen);
1081 	}
1082 }
1083 
1084 asmlinkage void
1085 aes_gcm_enc_final_aesni(const struct aes_gcm_key_aesni *key,
1086 			const u32 le_ctr[4], u8 ghash_acc[16],
1087 			u64 total_aadlen, u64 total_datalen);
1088 asmlinkage void
1089 aes_gcm_enc_final_aesni_avx(const struct aes_gcm_key_aesni *key,
1090 			    const u32 le_ctr[4], u8 ghash_acc[16],
1091 			    u64 total_aadlen, u64 total_datalen);
1092 asmlinkage void
1093 aes_gcm_enc_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
1094 			    const u32 le_ctr[4], u8 ghash_acc[16],
1095 			    u64 total_aadlen, u64 total_datalen);
1096 asmlinkage void
1097 aes_gcm_enc_final_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
1098 			      const u32 le_ctr[4], u8 ghash_acc[16],
1099 			      u64 total_aadlen, u64 total_datalen);
1100 
1101 /* __always_inline to optimize out the branches based on @flags */
1102 static __always_inline void
1103 aes_gcm_enc_final(const struct aes_gcm_key *key,
1104 		  const u32 le_ctr[4], u8 ghash_acc[16],
1105 		  u64 total_aadlen, u64 total_datalen, int flags)
1106 {
1107 	if (flags & FLAG_VAES_AVX512)
1108 		aes_gcm_enc_final_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
1109 					      le_ctr, ghash_acc,
1110 					      total_aadlen, total_datalen);
1111 	else if (flags & FLAG_VAES_AVX2)
1112 		aes_gcm_enc_final_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
1113 					    le_ctr, ghash_acc,
1114 					    total_aadlen, total_datalen);
1115 	else if (flags & FLAG_AVX)
1116 		aes_gcm_enc_final_aesni_avx(AES_GCM_KEY_AESNI(key),
1117 					    le_ctr, ghash_acc,
1118 					    total_aadlen, total_datalen);
1119 	else
1120 		aes_gcm_enc_final_aesni(AES_GCM_KEY_AESNI(key),
1121 					le_ctr, ghash_acc,
1122 					total_aadlen, total_datalen);
1123 }
1124 
1125 asmlinkage bool __must_check
1126 aes_gcm_dec_final_aesni(const struct aes_gcm_key_aesni *key,
1127 			const u32 le_ctr[4], const u8 ghash_acc[16],
1128 			u64 total_aadlen, u64 total_datalen,
1129 			const u8 tag[16], int taglen);
1130 asmlinkage bool __must_check
1131 aes_gcm_dec_final_aesni_avx(const struct aes_gcm_key_aesni *key,
1132 			    const u32 le_ctr[4], const u8 ghash_acc[16],
1133 			    u64 total_aadlen, u64 total_datalen,
1134 			    const u8 tag[16], int taglen);
1135 asmlinkage bool __must_check
1136 aes_gcm_dec_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
1137 			    const u32 le_ctr[4], const u8 ghash_acc[16],
1138 			    u64 total_aadlen, u64 total_datalen,
1139 			    const u8 tag[16], int taglen);
1140 asmlinkage bool __must_check
1141 aes_gcm_dec_final_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
1142 			      const u32 le_ctr[4], const u8 ghash_acc[16],
1143 			      u64 total_aadlen, u64 total_datalen,
1144 			      const u8 tag[16], int taglen);
1145 
1146 /* __always_inline to optimize out the branches based on @flags */
1147 static __always_inline bool __must_check
1148 aes_gcm_dec_final(const struct aes_gcm_key *key, const u32 le_ctr[4],
1149 		  u8 ghash_acc[16], u64 total_aadlen, u64 total_datalen,
1150 		  u8 tag[16], int taglen, int flags)
1151 {
1152 	if (flags & FLAG_VAES_AVX512)
1153 		return aes_gcm_dec_final_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
1154 						     le_ctr, ghash_acc,
1155 						     total_aadlen, total_datalen,
1156 						     tag, taglen);
1157 	else if (flags & FLAG_VAES_AVX2)
1158 		return aes_gcm_dec_final_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
1159 						   le_ctr, ghash_acc,
1160 						   total_aadlen, total_datalen,
1161 						   tag, taglen);
1162 	else if (flags & FLAG_AVX)
1163 		return aes_gcm_dec_final_aesni_avx(AES_GCM_KEY_AESNI(key),
1164 						   le_ctr, ghash_acc,
1165 						   total_aadlen, total_datalen,
1166 						   tag, taglen);
1167 	else
1168 		return aes_gcm_dec_final_aesni(AES_GCM_KEY_AESNI(key),
1169 					       le_ctr, ghash_acc,
1170 					       total_aadlen, total_datalen,
1171 					       tag, taglen);
1172 }
1173 
1174 /*
1175  * This is the length of the Integrity Check Value (aka the authentication
1176  * tag); it can be 8, 12 or 16 bytes.
1177  */
1178 static int common_rfc4106_set_authsize(struct crypto_aead *aead,
1179 				       unsigned int authsize)
1180 {
1181 	switch (authsize) {
1182 	case 8:
1183 	case 12:
1184 	case 16:
1185 		break;
1186 	default:
1187 		return -EINVAL;
1188 	}
1189 
1190 	return 0;
1191 }
1192 
1193 static int generic_gcmaes_set_authsize(struct crypto_aead *tfm,
1194 				       unsigned int authsize)
1195 {
1196 	switch (authsize) {
1197 	case 4:
1198 	case 8:
1199 	case 12:
1200 	case 13:
1201 	case 14:
1202 	case 15:
1203 	case 16:
1204 		break;
1205 	default:
1206 		return -EINVAL;
1207 	}
1208 
1209 	return 0;
1210 }
1211 
1212 /*
1213  * This is the setkey function for the x86_64 implementations of AES-GCM.  It
1214  * saves the RFC4106 nonce if applicable, expands the AES key, and precomputes
1215  * powers of the hash key.
1216  *
1217  * To comply with the crypto_aead API, this has to be usable in no-SIMD context.
1218  * For that reason, this function includes a portable C implementation of the
1219  * needed logic.  However, the portable C implementation is very slow, taking
1220  * about the same time as encrypting 37 KB of data.  To be ready for users that
1221  * may set a key even somewhat frequently, we therefore also include a SIMD
1222  * assembly implementation, expanding the AES key using AES-NI and precomputing
1223  * the hash key powers using PCLMULQDQ or VPCLMULQDQ.
1224  */
1225 static int gcm_setkey(struct crypto_aead *tfm, const u8 *raw_key,
1226 		      unsigned int keylen, int flags)
1227 {
1228 	struct aes_gcm_key *key = aes_gcm_key_get(tfm, flags);
1229 	int err;
1230 
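	/*
	 * For rfc4106, the last 4 bytes of the key material are the nonce
	 * (salt) that forms the leading part of the GCM IV; they are not part
	 * of the AES key itself.
	 */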
1231 	if (flags & FLAG_RFC4106) {
1232 		if (keylen < 4)
1233 			return -EINVAL;
1234 		keylen -= 4;
1235 		key->rfc4106_nonce = get_unaligned_be32(raw_key + keylen);
1236 	}
1237 
1238 	/* The assembly code assumes the following offsets. */
1239 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, base.aes_key.key_enc) != 0);
1240 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, base.aes_key.key_length) != 480);
1241 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_powers) != 496);
1242 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_powers_xored) != 624);
1243 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_aesni, h_times_x64) != 688);
1244 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, base.aes_key.key_enc) != 0);
1245 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, base.aes_key.key_length) != 480);
1246 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, h_powers) != 512);
1247 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx2, h_powers_xored) != 640);
1248 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx512, base.aes_key.key_enc) != 0);
1249 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx512, base.aes_key.key_length) != 480);
1250 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx512, h_powers) != 512);
1251 	BUILD_BUG_ON(offsetof(struct aes_gcm_key_vaes_avx512, padding) != 768);
1252 
1253 	if (likely(crypto_simd_usable())) {
1254 		err = aes_check_keylen(keylen);
1255 		if (err)
1256 			return err;
1257 		kernel_fpu_begin();
1258 		aesni_set_key(&key->aes_key, raw_key, keylen);
1259 		aes_gcm_precompute(key, flags);
1260 		kernel_fpu_end();
1261 	} else {
1262 		static const u8 x_to_the_minus1[16] __aligned(__alignof__(be128)) = {
1263 			[0] = 0xc2, [15] = 1
1264 		};
1265 		static const u8 x_to_the_63[16] __aligned(__alignof__(be128)) = {
1266 			[7] = 1,
1267 		};
1268 		be128 h1 = {};
1269 		be128 h;
1270 		int i;
1271 
1272 		err = aes_expandkey(&key->aes_key, raw_key, keylen);
1273 		if (err)
1274 			return err;
1275 
1276 		/* Encrypt the all-zeroes block to get the hash key H^1 */
1277 		aes_encrypt(&key->aes_key, (u8 *)&h1, (u8 *)&h1);
1278 
1279 		/* Compute H^1 * x^-1 */
1280 		h = h1;
1281 		gf128mul_lle(&h, (const be128 *)x_to_the_minus1);
1282 
1283 		/* Compute the needed key powers */
1284 		if (flags & FLAG_VAES_AVX512) {
1285 			struct aes_gcm_key_vaes_avx512 *k =
1286 				AES_GCM_KEY_VAES_AVX512(key);
1287 
1288 			for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) {
1289 				k->h_powers[i][0] = be64_to_cpu(h.b);
1290 				k->h_powers[i][1] = be64_to_cpu(h.a);
1291 				gf128mul_lle(&h, &h1);
1292 			}
1293 			memset(k->padding, 0, sizeof(k->padding));
1294 		} else if (flags & FLAG_VAES_AVX2) {
1295 			struct aes_gcm_key_vaes_avx2 *k =
1296 				AES_GCM_KEY_VAES_AVX2(key);
1297 			static const u8 indices[8] = { 0, 2, 1, 3, 4, 6, 5, 7 };
1298 
1299 			for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) {
1300 				k->h_powers[i][0] = be64_to_cpu(h.b);
1301 				k->h_powers[i][1] = be64_to_cpu(h.a);
1302 				gf128mul_lle(&h, &h1);
1303 			}
1304 			for (i = 0; i < ARRAY_SIZE(k->h_powers_xored); i++) {
1305 				int j = indices[i];
1306 
1307 				k->h_powers_xored[i] = k->h_powers[j][0] ^
1308 						       k->h_powers[j][1];
1309 			}
1310 		} else {
1311 			struct aes_gcm_key_aesni *k = AES_GCM_KEY_AESNI(key);
1312 
1313 			for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) {
1314 				k->h_powers[i][0] = be64_to_cpu(h.b);
1315 				k->h_powers[i][1] = be64_to_cpu(h.a);
1316 				k->h_powers_xored[i] = k->h_powers[i][0] ^
1317 						       k->h_powers[i][1];
1318 				gf128mul_lle(&h, &h1);
1319 			}
1320 			gf128mul_lle(&h1, (const be128 *)x_to_the_63);
1321 			k->h_times_x64[0] = be64_to_cpu(h1.b);
1322 			k->h_times_x64[1] = be64_to_cpu(h1.a);
1323 		}
1324 	}
1325 	return 0;
1326 }
1327 
1328 /*
1329  * Initialize @ghash_acc, then pass all @assoclen bytes of associated data
1330  * (a.k.a. additional authenticated data) from @sg_src through the GHASH update
1331  * assembly function.  kernel_fpu_begin() must have already been called.
1332  */
1333 static void gcm_process_assoc(const struct aes_gcm_key *key, u8 ghash_acc[16],
1334 			      struct scatterlist *sg_src, unsigned int assoclen,
1335 			      int flags)
1336 {
1337 	struct scatter_walk walk;
1338 	/*
1339 	 * The assembly function requires that the length of any non-last
1340 	 * segment of associated data be a multiple of 16 bytes, so this
1341 	 * function does the buffering needed to achieve that.
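	 * For example, 20 bytes of AAD split 10 + 10 across two segments
	 * results in the first 10 bytes being buffered, one 16-byte call once
	 * the second segment supplies 6 more bytes, and one final 4-byte call
	 * for the remainder.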
1342 	 */
1343 	unsigned int pos = 0;
1344 	u8 buf[16];
1345 
1346 	memset(ghash_acc, 0, 16);
1347 	scatterwalk_start(&walk, sg_src);
1348 
1349 	while (assoclen) {
1350 		unsigned int orig_len_this_step = scatterwalk_next(
1351 			&walk, assoclen);
1352 		unsigned int len_this_step = orig_len_this_step;
1353 		unsigned int len;
1354 		const u8 *src = walk.addr;
1355 
1356 		if (unlikely(pos)) {
1357 			len = min(len_this_step, 16 - pos);
1358 			memcpy(&buf[pos], src, len);
1359 			pos += len;
1360 			src += len;
1361 			len_this_step -= len;
1362 			if (pos < 16)
1363 				goto next;
1364 			aes_gcm_aad_update(key, ghash_acc, buf, 16, flags);
1365 			pos = 0;
1366 		}
1367 		len = len_this_step;
1368 		if (unlikely(assoclen != orig_len_this_step)) /* Not the last segment yet? */
1369 			len = round_down(len, 16);
1370 		aes_gcm_aad_update(key, ghash_acc, src, len, flags);
1371 		src += len;
1372 		len_this_step -= len;
1373 		if (unlikely(len_this_step)) {
1374 			memcpy(buf, src, len_this_step);
1375 			pos = len_this_step;
1376 		}
1377 next:
1378 		scatterwalk_done_src(&walk, orig_len_this_step);
1379 		if (need_resched()) {
1380 			kernel_fpu_end();
1381 			kernel_fpu_begin();
1382 		}
1383 		assoclen -= orig_len_this_step;
1384 	}
1385 	if (unlikely(pos))
1386 		aes_gcm_aad_update(key, ghash_acc, buf, pos, flags);
1387 }
1388 
1389 
1390 /* __always_inline to optimize out the branches based on @flags */
1391 static __always_inline int
1392 gcm_crypt(struct aead_request *req, int flags)
1393 {
1394 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
1395 	const struct aes_gcm_key *key = aes_gcm_key_get(tfm, flags);
1396 	unsigned int assoclen = req->assoclen;
1397 	struct skcipher_walk walk;
1398 	unsigned int nbytes;
1399 	u8 ghash_acc[16]; /* GHASH accumulator */
1400 	u32 le_ctr[4]; /* Counter in little-endian format */
1401 	int taglen;
1402 	int err;
1403 
1404 	/* Initialize the counter and determine the associated data length. */
1405 	le_ctr[0] = 2;
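	/*
	 * Counter block 1 is consumed by the tag computation, so encryption of
	 * the data starts at counter block 2 (as specified for GCM).
	 */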
1406 	if (flags & FLAG_RFC4106) {
1407 		if (unlikely(assoclen != 16 && assoclen != 20))
1408 			return -EINVAL;
1409 		assoclen -= 8;
1410 		le_ctr[1] = get_unaligned_be32(req->iv + 4);
1411 		le_ctr[2] = get_unaligned_be32(req->iv + 0);
1412 		le_ctr[3] = key->rfc4106_nonce; /* already byte-swapped */
1413 	} else {
1414 		le_ctr[1] = get_unaligned_be32(req->iv + 8);
1415 		le_ctr[2] = get_unaligned_be32(req->iv + 4);
1416 		le_ctr[3] = get_unaligned_be32(req->iv + 0);
1417 	}
1418 
1419 	/* Begin walking through the plaintext or ciphertext. */
1420 	if (flags & FLAG_ENC)
1421 		err = skcipher_walk_aead_encrypt(&walk, req, false);
1422 	else
1423 		err = skcipher_walk_aead_decrypt(&walk, req, false);
1424 	if (err)
1425 		return err;
1426 
1427 	/*
1428 	 * Since the AES-GCM assembly code requires that at least three assembly
1429 	 * functions be called to process any message (this is needed to support
1430 	 * incremental updates cleanly), to reduce overhead we try to do all
1431 	 * three calls in the same kernel FPU section if possible.  We close the
1432 	 * section and start a new one if there are multiple data segments or if
1433 	 * rescheduling is needed while processing the associated data.
1434 	 */
1435 	kernel_fpu_begin();
1436 
1437 	/* Pass the associated data through GHASH. */
1438 	gcm_process_assoc(key, ghash_acc, req->src, assoclen, flags);
1439 
1440 	/* En/decrypt the data and pass the ciphertext through GHASH. */
1441 	while (unlikely((nbytes = walk.nbytes) < walk.total)) {
1442 		/*
1443 		 * Non-last segment.  In this case, the assembly function
1444 		 * requires that the length be a multiple of 16 (AES_BLOCK_SIZE)
1445 		 * bytes.  The needed buffering of up to 16 bytes is handled by
1446 		 * the skcipher_walk.  Here we just need to round down to a
1447 		 * multiple of 16.
1448 		 */
1449 		nbytes = round_down(nbytes, AES_BLOCK_SIZE);
1450 		aes_gcm_update(key, le_ctr, ghash_acc, walk.src.virt.addr,
1451 			       walk.dst.virt.addr, nbytes, flags);
1452 		le_ctr[0] += nbytes / AES_BLOCK_SIZE;
1453 		kernel_fpu_end();
1454 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
1455 		if (err)
1456 			return err;
1457 		kernel_fpu_begin();
1458 	}
1459 	/* Last segment: process all remaining data. */
1460 	aes_gcm_update(key, le_ctr, ghash_acc, walk.src.virt.addr,
1461 		       walk.dst.virt.addr, nbytes, flags);
1462 	/*
1463 	 * The low word of the counter isn't used by the finalize, so there's no
1464 	 * need to increment it here.
1465 	 */
1466 
1467 	/* Finalize */
1468 	taglen = crypto_aead_authsize(tfm);
1469 	if (flags & FLAG_ENC) {
1470 		/* Finish computing the auth tag. */
1471 		aes_gcm_enc_final(key, le_ctr, ghash_acc, assoclen,
1472 				  req->cryptlen, flags);
1473 
1474 		/* Store the computed auth tag in the dst scatterlist. */
1475 		scatterwalk_map_and_copy(ghash_acc, req->dst, req->assoclen +
1476 					 req->cryptlen, taglen, 1);
1477 	} else {
1478 		unsigned int datalen = req->cryptlen - taglen;
1479 		u8 tag[16];
1480 
1481 		/* Get the transmitted auth tag from the src scatterlist. */
1482 		scatterwalk_map_and_copy(tag, req->src, req->assoclen + datalen,
1483 					 taglen, 0);
1484 		/*
1485 		 * Finish computing the auth tag and compare it to the
1486 		 * transmitted one.  The assembly function does the actual tag
1487 		 * comparison.  Here, just check the boolean result.
1488 		 */
1489 		if (!aes_gcm_dec_final(key, le_ctr, ghash_acc, assoclen,
1490 				       datalen, tag, taglen, flags))
1491 			err = -EBADMSG;
1492 	}
1493 	kernel_fpu_end();
1494 	if (nbytes)
1495 		skcipher_walk_done(&walk, 0);
1496 	return err;
1497 }
1498 
1499 #define DEFINE_GCM_ALGS(suffix, flags, generic_driver_name, rfc_driver_name,   \
1500 			ctxsize, priority)				       \
1501 									       \
1502 static int gcm_setkey_##suffix(struct crypto_aead *tfm, const u8 *raw_key,     \
1503 			       unsigned int keylen)			       \
1504 {									       \
1505 	return gcm_setkey(tfm, raw_key, keylen, (flags));		       \
1506 }									       \
1507 									       \
1508 static int gcm_encrypt_##suffix(struct aead_request *req)		       \
1509 {									       \
1510 	return gcm_crypt(req, (flags) | FLAG_ENC);			       \
1511 }									       \
1512 									       \
1513 static int gcm_decrypt_##suffix(struct aead_request *req)		       \
1514 {									       \
1515 	return gcm_crypt(req, (flags));					       \
1516 }									       \
1517 									       \
1518 static int rfc4106_setkey_##suffix(struct crypto_aead *tfm, const u8 *raw_key, \
1519 				   unsigned int keylen)			       \
1520 {									       \
1521 	return gcm_setkey(tfm, raw_key, keylen, (flags) | FLAG_RFC4106);       \
1522 }									       \
1523 									       \
1524 static int rfc4106_encrypt_##suffix(struct aead_request *req)		       \
1525 {									       \
1526 	return gcm_crypt(req, (flags) | FLAG_RFC4106 | FLAG_ENC);	       \
1527 }									       \
1528 									       \
1529 static int rfc4106_decrypt_##suffix(struct aead_request *req)		       \
1530 {									       \
1531 	return gcm_crypt(req, (flags) | FLAG_RFC4106);			       \
1532 }									       \
1533 									       \
1534 static struct aead_alg aes_gcm_algs_##suffix[] = { {			       \
1535 	.setkey			= gcm_setkey_##suffix,			       \
1536 	.setauthsize		= generic_gcmaes_set_authsize,		       \
1537 	.encrypt		= gcm_encrypt_##suffix,			       \
1538 	.decrypt		= gcm_decrypt_##suffix,			       \
1539 	.ivsize			= GCM_AES_IV_SIZE,			       \
1540 	.chunksize		= AES_BLOCK_SIZE,			       \
1541 	.maxauthsize		= 16,					       \
1542 	.base = {							       \
1543 		.cra_name		= "gcm(aes)",			       \
1544 		.cra_driver_name	= generic_driver_name,		       \
1545 		.cra_priority		= (priority),			       \
1546 		.cra_blocksize		= 1,				       \
1547 		.cra_ctxsize		= (ctxsize),			       \
1548 		.cra_module		= THIS_MODULE,			       \
1549 	},								       \
1550 }, {									       \
1551 	.setkey			= rfc4106_setkey_##suffix,		       \
1552 	.setauthsize		= common_rfc4106_set_authsize,		       \
1553 	.encrypt		= rfc4106_encrypt_##suffix,		       \
1554 	.decrypt		= rfc4106_decrypt_##suffix,		       \
1555 	.ivsize			= GCM_RFC4106_IV_SIZE,			       \
1556 	.chunksize		= AES_BLOCK_SIZE,			       \
1557 	.maxauthsize		= 16,					       \
1558 	.base = {							       \
1559 		.cra_name		= "rfc4106(gcm(aes))",		       \
1560 		.cra_driver_name	= rfc_driver_name,		       \
1561 		.cra_priority		= (priority),			       \
1562 		.cra_blocksize		= 1,				       \
1563 		.cra_ctxsize		= (ctxsize),			       \
1564 		.cra_module		= THIS_MODULE,			       \
1565 	},								       \
1566 } }
1567 
1568 /* aes_gcm_algs_aesni */
1569 DEFINE_GCM_ALGS(aesni, /* no flags */ 0,
1570 		"generic-gcm-aesni", "rfc4106-gcm-aesni",
1571 		AES_GCM_KEY_AESNI_SIZE, 400);
1572 
1573 /* aes_gcm_algs_aesni_avx */
1574 DEFINE_GCM_ALGS(aesni_avx, FLAG_AVX,
1575 		"generic-gcm-aesni-avx", "rfc4106-gcm-aesni-avx",
1576 		AES_GCM_KEY_AESNI_SIZE, 500);
1577 
1578 /* aes_gcm_algs_vaes_avx2 */
1579 DEFINE_GCM_ALGS(vaes_avx2, FLAG_VAES_AVX2,
1580 		"generic-gcm-vaes-avx2", "rfc4106-gcm-vaes-avx2",
1581 		AES_GCM_KEY_VAES_AVX2_SIZE, 600);
1582 
1583 /* aes_gcm_algs_vaes_avx512 */
1584 DEFINE_GCM_ALGS(vaes_avx512, FLAG_VAES_AVX512,
1585 		"generic-gcm-vaes-avx512", "rfc4106-gcm-vaes-avx512",
1586 		AES_GCM_KEY_VAES_AVX512_SIZE, 800);
1587 
1588 static int __init register_avx_algs(void)
1589 {
1590 	int err;
1591 
1592 	if (!boot_cpu_has(X86_FEATURE_AVX))
1593 		return 0;
1594 	err = crypto_register_skciphers(skcipher_algs_aesni_avx,
1595 					ARRAY_SIZE(skcipher_algs_aesni_avx));
1596 	if (err)
1597 		return err;
1598 	err = crypto_register_aeads(aes_gcm_algs_aesni_avx,
1599 				    ARRAY_SIZE(aes_gcm_algs_aesni_avx));
1600 	if (err)
1601 		return err;
1602 	/*
1603 	 * Note: not all the algorithms registered below actually require
1604 	 * VPCLMULQDQ.  But in practice every CPU with VAES also has VPCLMULQDQ.
1605 	 * Similarly, assembler support for both was added at about the same time.
1606 	 * For simplicity, just always check for VAES and VPCLMULQDQ together.
1607 	 */
1608 	if (!boot_cpu_has(X86_FEATURE_AVX2) ||
1609 	    !boot_cpu_has(X86_FEATURE_VAES) ||
1610 	    !boot_cpu_has(X86_FEATURE_VPCLMULQDQ) ||
1611 	    !boot_cpu_has(X86_FEATURE_PCLMULQDQ) ||
1612 	    !cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL))
1613 		return 0;
1614 	err = crypto_register_skciphers(skcipher_algs_vaes_avx2,
1615 					ARRAY_SIZE(skcipher_algs_vaes_avx2));
1616 	if (err)
1617 		return err;
1618 	err = crypto_register_aeads(aes_gcm_algs_vaes_avx2,
1619 				    ARRAY_SIZE(aes_gcm_algs_vaes_avx2));
1620 	if (err)
1621 		return err;
1622 
1623 	if (!boot_cpu_has(X86_FEATURE_AVX512BW) ||
1624 	    !boot_cpu_has(X86_FEATURE_AVX512VL) ||
1625 	    !boot_cpu_has(X86_FEATURE_BMI2) ||
1626 	    !cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM |
1627 			       XFEATURE_MASK_AVX512, NULL))
1628 		return 0;
1629 
1630 	if (boot_cpu_has(X86_FEATURE_PREFER_YMM)) {
1631 		int i;
1632 
1633 		for (i = 0; i < ARRAY_SIZE(skcipher_algs_vaes_avx512); i++)
1634 			skcipher_algs_vaes_avx512[i].base.cra_priority = 1;
1635 		for (i = 0; i < ARRAY_SIZE(aes_gcm_algs_vaes_avx512); i++)
1636 			aes_gcm_algs_vaes_avx512[i].base.cra_priority = 1;
1637 	}
1638 
1639 	err = crypto_register_skciphers(skcipher_algs_vaes_avx512,
1640 					ARRAY_SIZE(skcipher_algs_vaes_avx512));
1641 	if (err)
1642 		return err;
1643 	err = crypto_register_aeads(aes_gcm_algs_vaes_avx512,
1644 				    ARRAY_SIZE(aes_gcm_algs_vaes_avx512));
1645 	if (err)
1646 		return err;
1647 
1648 	return 0;
1649 }
1650 
1651 #define unregister_skciphers(A) \
1652 	if (refcount_read(&(A)[0].base.cra_refcnt) != 0) \
1653 		crypto_unregister_skciphers((A), ARRAY_SIZE(A))
1654 #define unregister_aeads(A) \
1655 	if (refcount_read(&(A)[0].base.cra_refcnt) != 0) \
1656 		crypto_unregister_aeads((A), ARRAY_SIZE(A))
1657 
1658 static void unregister_avx_algs(void)
1659 {
1660 	unregister_skciphers(skcipher_algs_aesni_avx);
1661 	unregister_aeads(aes_gcm_algs_aesni_avx);
1662 	unregister_skciphers(skcipher_algs_vaes_avx2);
1663 	unregister_skciphers(skcipher_algs_vaes_avx512);
1664 	unregister_aeads(aes_gcm_algs_vaes_avx2);
1665 	unregister_aeads(aes_gcm_algs_vaes_avx512);
1666 }
1667 #else /* CONFIG_X86_64 */
1668 static struct aead_alg aes_gcm_algs_aesni[0];
1669 
1670 static int __init register_avx_algs(void)
1671 {
1672 	return 0;
1673 }
1674 
1675 static void unregister_avx_algs(void)
1676 {
1677 }
1678 #endif /* !CONFIG_X86_64 */
1679 
1680 static const struct x86_cpu_id aesni_cpu_id[] = {
1681 	X86_MATCH_FEATURE(X86_FEATURE_AES, NULL),
1682 	{}
1683 };
1684 MODULE_DEVICE_TABLE(x86cpu, aesni_cpu_id);
1685 
1686 static int __init aesni_init(void)
1687 {
1688 	int err;
1689 
1690 	if (!x86_match_cpu(aesni_cpu_id))
1691 		return -ENODEV;
1692 
1693 	err = crypto_register_alg(&aesni_cipher_alg);
1694 	if (err)
1695 		return err;
1696 
1697 	err = crypto_register_skciphers(aesni_skciphers,
1698 					ARRAY_SIZE(aesni_skciphers));
1699 	if (err)
1700 		goto unregister_cipher;
1701 
1702 	err = crypto_register_aeads(aes_gcm_algs_aesni,
1703 				    ARRAY_SIZE(aes_gcm_algs_aesni));
1704 	if (err)
1705 		goto unregister_skciphers;
1706 
1707 	err = register_avx_algs();
1708 	if (err)
1709 		goto unregister_avx;
1710 
1711 	return 0;
1712 
1713 unregister_avx:
1714 	unregister_avx_algs();
1715 	crypto_unregister_aeads(aes_gcm_algs_aesni,
1716 				ARRAY_SIZE(aes_gcm_algs_aesni));
1717 unregister_skciphers:
1718 	crypto_unregister_skciphers(aesni_skciphers,
1719 				    ARRAY_SIZE(aesni_skciphers));
1720 unregister_cipher:
1721 	crypto_unregister_alg(&aesni_cipher_alg);
1722 	return err;
1723 }
1724 
1725 static void __exit aesni_exit(void)
1726 {
1727 	crypto_unregister_aeads(aes_gcm_algs_aesni,
1728 				ARRAY_SIZE(aes_gcm_algs_aesni));
1729 	crypto_unregister_skciphers(aesni_skciphers,
1730 				    ARRAY_SIZE(aesni_skciphers));
1731 	crypto_unregister_alg(&aesni_cipher_alg);
1732 	unregister_avx_algs();
1733 }
1734 
1735 module_init(aesni_init);
1736 module_exit(aesni_exit);
1737 
1738 MODULE_DESCRIPTION("AES cipher and modes, optimized with AES-NI or VAES instructions");
1739 MODULE_LICENSE("GPL");
1740 MODULE_ALIAS_CRYPTO("aes");
1741