xref: /linux/arch/x86/crypto/aesni-intel_glue.c (revision 6f7e6393d1ce636bb7ec77a7fe7b77458fddf701)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Support for AES-NI and VAES instructions.  This file contains glue code.
4  * The real AES implementations are in aesni-intel_asm.S and other .S files.
5  *
6  * Copyright (C) 2008, Intel Corp.
7  *    Author: Huang Ying <ying.huang@intel.com>
8  *
9  * Added RFC4106 AES-GCM support for 128-bit keys under the AEAD
10  * interface for 64-bit kernels.
11  *    Authors: Adrian Hoban <adrian.hoban@intel.com>
12  *             Gabriele Paoloni <gabriele.paoloni@intel.com>
13  *             Tadeusz Struk (tadeusz.struk@intel.com)
14  *             Aidan O'Mahony (aidan.o.mahony@intel.com)
15  *    Copyright (c) 2010, Intel Corporation.
16  *
17  * Copyright 2024 Google LLC
18  */
19 
20 #include <linux/hardirq.h>
21 #include <linux/types.h>
22 #include <linux/module.h>
23 #include <linux/err.h>
24 #include <crypto/algapi.h>
25 #include <crypto/aes.h>
26 #include <crypto/b128ops.h>
27 #include <crypto/gcm.h>
28 #include <crypto/xts.h>
29 #include <asm/cpu_device_id.h>
30 #include <asm/simd.h>
31 #include <crypto/scatterwalk.h>
32 #include <crypto/internal/aead.h>
33 #include <crypto/internal/simd.h>
34 #include <crypto/internal/skcipher.h>
35 #include <linux/jump_label.h>
36 #include <linux/workqueue.h>
37 #include <linux/spinlock.h>
38 #include <linux/static_call.h>
39 
40 
/*
 * The AES-NI assembly code requires its key structures to be 16-byte aligned.
 * kmalloc() only guarantees CRYPTO_MINALIGN, so the context sizes below
 * reserve AESNI_ALIGN_EXTRA slack bytes and the pointer is re-aligned at
 * runtime (see aes_align_addr()).
 */
#define AESNI_ALIGN	16
#define AESNI_ALIGN_ATTR __attribute__ ((__aligned__(AESNI_ALIGN)))
/* Rounds a byte count down to a whole number of AES blocks */
#define AES_BLOCK_MASK	(~(AES_BLOCK_SIZE - 1))
/* Extra bytes needed to realign when the allocator alignment is weaker */
#define AESNI_ALIGN_EXTRA ((AESNI_ALIGN - 1) & ~(CRYPTO_MINALIGN - 1))
#define CRYPTO_AES_CTX_SIZE (sizeof(struct crypto_aes_ctx) + AESNI_ALIGN_EXTRA)
#define XTS_AES_CTX_SIZE (sizeof(struct aesni_xts_ctx) + AESNI_ALIGN_EXTRA)
47 
/*
 * XTS transform context: two independently expanded AES keys, one for
 * encrypting the tweak and one for the data itself.
 */
struct aesni_xts_ctx {
	struct crypto_aes_ctx tweak_ctx AESNI_ALIGN_ATTR;
	struct crypto_aes_ctx crypt_ctx AESNI_ALIGN_ATTR;
};
52 
53 static inline void *aes_align_addr(void *addr)
54 {
55 	if (crypto_tfm_ctx_alignment() >= AESNI_ALIGN)
56 		return addr;
57 	return PTR_ALIGN(addr, AESNI_ALIGN);
58 }
59 
/*
 * Prototypes for the assembly routines in aesni-intel_asm.S.  All of them
 * must be called with the FPU enabled (between kernel_fpu_begin() and
 * kernel_fpu_end()), and @len is in bytes.
 */
asmlinkage void aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
			      unsigned int key_len);
asmlinkage void aesni_enc(const void *ctx, u8 *out, const u8 *in);
asmlinkage void aesni_ecb_enc(struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len);
asmlinkage void aesni_ecb_dec(struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len);
asmlinkage void aesni_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len, u8 *iv);
asmlinkage void aesni_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len, u8 *iv);
asmlinkage void aesni_cts_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out,
				  const u8 *in, unsigned int len, u8 *iv);
asmlinkage void aesni_cts_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
				  const u8 *in, unsigned int len, u8 *iv);

asmlinkage void aesni_xts_enc(const struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len, u8 *iv);

asmlinkage void aesni_xts_dec(const struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len, u8 *iv);

#ifdef CONFIG_X86_64
asmlinkage void aesni_ctr_enc(struct crypto_aes_ctx *ctx, u8 *out,
			      const u8 *in, unsigned int len, u8 *iv);
#endif
86 
/* Return the 16-byte-aligned AES context embedded in a raw tfm context. */
static inline struct crypto_aes_ctx *aes_ctx(void *raw_ctx)
{
	return aes_align_addr(raw_ctx);
}
91 
/* Return the 16-byte-aligned XTS context of an skcipher tfm. */
static inline struct aesni_xts_ctx *aes_xts_ctx(struct crypto_skcipher *tfm)
{
	return aes_align_addr(crypto_skcipher_ctx(tfm));
}
96 
/*
 * Expand @in_key into @ctx.  When the FPU is usable, use the AES-NI key
 * expansion instructions; otherwise fall back to the generic C expansion
 * (which performs its own key-length validation).
 *
 * Return: 0 on success, or a negative errno for an invalid key length.
 */
static int aes_set_key_common(struct crypto_aes_ctx *ctx,
			      const u8 *in_key, unsigned int key_len)
{
	int err;

	if (!crypto_simd_usable())
		return aes_expandkey(ctx, in_key, key_len);

	/* Validate before entering the FPU section; aesni_set_key() doesn't. */
	err = aes_check_keylen(key_len);
	if (err)
		return err;

	kernel_fpu_begin();
	aesni_set_key(ctx, in_key, key_len);
	kernel_fpu_end();
	return 0;
}
114 
/* skcipher ->setkey() glue for the single-key (non-XTS) algorithms. */
static int aesni_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
			         unsigned int len)
{
	return aes_set_key_common(aes_ctx(crypto_skcipher_ctx(tfm)), key, len);
}
120 
121 static int ecb_encrypt(struct skcipher_request *req)
122 {
123 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
124 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
125 	struct skcipher_walk walk;
126 	unsigned int nbytes;
127 	int err;
128 
129 	err = skcipher_walk_virt(&walk, req, false);
130 
131 	while ((nbytes = walk.nbytes)) {
132 		kernel_fpu_begin();
133 		aesni_ecb_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
134 			      nbytes & AES_BLOCK_MASK);
135 		kernel_fpu_end();
136 		nbytes &= AES_BLOCK_SIZE - 1;
137 		err = skcipher_walk_done(&walk, nbytes);
138 	}
139 
140 	return err;
141 }
142 
143 static int ecb_decrypt(struct skcipher_request *req)
144 {
145 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
146 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
147 	struct skcipher_walk walk;
148 	unsigned int nbytes;
149 	int err;
150 
151 	err = skcipher_walk_virt(&walk, req, false);
152 
153 	while ((nbytes = walk.nbytes)) {
154 		kernel_fpu_begin();
155 		aesni_ecb_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
156 			      nbytes & AES_BLOCK_MASK);
157 		kernel_fpu_end();
158 		nbytes &= AES_BLOCK_SIZE - 1;
159 		err = skcipher_walk_done(&walk, nbytes);
160 	}
161 
162 	return err;
163 }
164 
165 static int cbc_encrypt(struct skcipher_request *req)
166 {
167 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
168 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
169 	struct skcipher_walk walk;
170 	unsigned int nbytes;
171 	int err;
172 
173 	err = skcipher_walk_virt(&walk, req, false);
174 
175 	while ((nbytes = walk.nbytes)) {
176 		kernel_fpu_begin();
177 		aesni_cbc_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
178 			      nbytes & AES_BLOCK_MASK, walk.iv);
179 		kernel_fpu_end();
180 		nbytes &= AES_BLOCK_SIZE - 1;
181 		err = skcipher_walk_done(&walk, nbytes);
182 	}
183 
184 	return err;
185 }
186 
187 static int cbc_decrypt(struct skcipher_request *req)
188 {
189 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
190 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
191 	struct skcipher_walk walk;
192 	unsigned int nbytes;
193 	int err;
194 
195 	err = skcipher_walk_virt(&walk, req, false);
196 
197 	while ((nbytes = walk.nbytes)) {
198 		kernel_fpu_begin();
199 		aesni_cbc_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
200 			      nbytes & AES_BLOCK_MASK, walk.iv);
201 		kernel_fpu_end();
202 		nbytes &= AES_BLOCK_SIZE - 1;
203 		err = skcipher_walk_done(&walk, nbytes);
204 	}
205 
206 	return err;
207 }
208 
/*
 * CBC with ciphertext stealing (CS3), encryption.  All but the last two
 * (possibly partial) blocks are encrypted with plain CBC via a subrequest;
 * the final two blocks are handed to the CTS assembly routine, which must see
 * them together.  Requires cryptlen >= AES_BLOCK_SIZE.
 */
static int cts_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
	/* Number of leading blocks to handle with plain CBC */
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;
	int err;

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		/* Exactly one block: no stealing needed, plain CBC suffices. */
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = cbc_encrypt(&subreq);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		/* Advance the scatterlists past the CBC-processed prefix. */
		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_fpu_begin();
	aesni_cts_cbc_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
			  walk.nbytes, walk.iv);
	kernel_fpu_end();

	return skcipher_walk_done(&walk, 0);
}
264 
/*
 * CBC with ciphertext stealing (CS3), decryption.  Mirror of
 * cts_cbc_encrypt(): a plain CBC subrequest for the leading blocks, then the
 * final two (possibly partial) blocks go to the CTS assembly routine
 * together.  Requires cryptlen >= AES_BLOCK_SIZE.
 */
static int cts_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
	/* Number of leading blocks to handle with plain CBC */
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;
	int err;

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	if (req->cryptlen <= AES_BLOCK_SIZE) {
		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;
		/* Exactly one block: no stealing needed, plain CBC suffices. */
		cbc_blocks = 1;
	}

	if (cbc_blocks > 0) {
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = cbc_decrypt(&subreq);
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		/* Advance the scatterlists past the CBC-processed prefix. */
		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_fpu_begin();
	aesni_cts_cbc_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
			  walk.nbytes, walk.iv);
	kernel_fpu_end();

	return skcipher_walk_done(&walk, 0);
}
320 
#ifdef CONFIG_X86_64
/*
 * This is the non-AVX version.  CTR mode: full blocks go to the assembly
 * routine; a final partial block is handled by generating one keystream
 * block with aesni_enc() and XOR-ing it into the data.
 */
static int ctr_crypt_aesni(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
	u8 keystream[AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		kernel_fpu_begin();
		if (nbytes & AES_BLOCK_MASK)
			aesni_ctr_enc(ctx, walk.dst.virt.addr,
				      walk.src.virt.addr,
				      nbytes & AES_BLOCK_MASK, walk.iv);
		/* nbytes now holds only the partial-block remainder */
		nbytes &= ~AES_BLOCK_MASK;

		/* Only the very last chunk may contain a partial block. */
		if (walk.nbytes == walk.total && nbytes > 0) {
			aesni_enc(ctx, keystream, walk.iv);
			crypto_xor_cpy(walk.dst.virt.addr + walk.nbytes - nbytes,
				       walk.src.virt.addr + walk.nbytes - nbytes,
				       keystream, nbytes);
			crypto_inc(walk.iv, AES_BLOCK_SIZE);
			nbytes = 0;
		}
		kernel_fpu_end();
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}
#endif
356 
357 static int xts_setkey_aesni(struct crypto_skcipher *tfm, const u8 *key,
358 			    unsigned int keylen)
359 {
360 	struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
361 	int err;
362 
363 	err = xts_verify_key(tfm, key, keylen);
364 	if (err)
365 		return err;
366 
367 	keylen /= 2;
368 
369 	/* first half of xts-key is for crypt */
370 	err = aes_set_key_common(&ctx->crypt_ctx, key, keylen);
371 	if (err)
372 		return err;
373 
374 	/* second half of xts-key is for tweak */
375 	return aes_set_key_common(&ctx->tweak_ctx, key + keylen, keylen);
376 }
377 
/* Encrypt the XTS tweak IV in place with the tweak key. */
typedef void (*xts_encrypt_iv_func)(const struct crypto_aes_ctx *tweak_key,
				    u8 iv[AES_BLOCK_SIZE]);
/* En/decrypt @len bytes of XTS data; @tweak is updated across calls. */
typedef void (*xts_crypt_func)(const struct crypto_aes_ctx *key,
			       const u8 *src, u8 *dst, int len,
			       u8 tweak[AES_BLOCK_SIZE]);
383 
/*
 * This handles cases where the source and/or destination span pages.
 * Called from xts_crypt() after the tweak IV has already been encrypted;
 * req->iv holds the running tweak throughout.
 */
static noinline int
xts_crypt_slowpath(struct skcipher_request *req, xts_crypt_func crypt_func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;
	struct scatterlist *src, *dst;
	int err;

	/*
	 * If the message length isn't divisible by the AES block size, then
	 * separate off the last full block and the partial block.  This ensures
	 * that they are processed in the same call to the assembly function,
	 * which is required for ciphertext stealing.
	 */
	if (tail) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail - AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes) {
		kernel_fpu_begin();
		(*crypt_func)(&ctx->crypt_ctx,
			      walk.src.virt.addr, walk.dst.virt.addr,
			      walk.nbytes & ~(AES_BLOCK_SIZE - 1), req->iv);
		kernel_fpu_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes & (AES_BLOCK_SIZE - 1));
	}

	if (err || !tail)
		return err;

	/* Do ciphertext stealing with the last full block and partial block. */

	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	kernel_fpu_begin();
	(*crypt_func)(&ctx->crypt_ctx, walk.src.virt.addr, walk.dst.virt.addr,
		      walk.nbytes, req->iv);
	kernel_fpu_end();

	return skcipher_walk_done(&walk, 0);
}
449 
/*
 * Common XTS entry point.  Encrypts the tweak IV, then either runs the whole
 * message in one assembly call (single-element scatterlists, the common case)
 * or falls back to the page-walking slowpath.
 *
 * __always_inline to avoid indirect call in fastpath
 */
static __always_inline int
xts_crypt(struct skcipher_request *req, xts_encrypt_iv_func encrypt_iv,
	  xts_crypt_func crypt_func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);

	/* XTS requires at least one full block. */
	if (unlikely(req->cryptlen < AES_BLOCK_SIZE))
		return -EINVAL;

	kernel_fpu_begin();
	(*encrypt_iv)(&ctx->tweak_ctx, req->iv);

	/*
	 * In practice, virtually all XTS plaintexts and ciphertexts are either
	 * 512 or 4096 bytes and do not use multiple scatterlist elements.  To
	 * optimize the performance of these cases, the below fast-path handles
	 * single-scatterlist-element messages as efficiently as possible.  The
	 * code is 64-bit specific, as it assumes no page mapping is needed.
	 */
	if (IS_ENABLED(CONFIG_X86_64) &&
	    likely(req->src->length >= req->cryptlen &&
		   req->dst->length >= req->cryptlen)) {
		(*crypt_func)(&ctx->crypt_ctx, sg_virt(req->src),
			      sg_virt(req->dst), req->cryptlen, req->iv);
		kernel_fpu_end();
		return 0;
	}
	kernel_fpu_end();
	return xts_crypt_slowpath(req, crypt_func);
}
482 
/* Encrypt the XTS tweak IV in place using the tweak key (AES-NI). */
static void aesni_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
				 u8 iv[AES_BLOCK_SIZE])
{
	aesni_enc(tweak_key, iv, iv);
}
488 
/*
 * Adapt the asm (ctx, out, in) argument order to the xts_crypt_func
 * (key, src, dst) convention.
 */
static void aesni_xts_encrypt(const struct crypto_aes_ctx *key,
			      const u8 *src, u8 *dst, int len,
			      u8 tweak[AES_BLOCK_SIZE])
{
	aesni_xts_enc(key, dst, src, len, tweak);
}
495 
/*
 * Adapt the asm (ctx, out, in) argument order to the xts_crypt_func
 * (key, src, dst) convention.
 */
static void aesni_xts_decrypt(const struct crypto_aes_ctx *key,
			      const u8 *src, u8 *dst, int len,
			      u8 tweak[AES_BLOCK_SIZE])
{
	aesni_xts_dec(key, dst, src, len, tweak);
}
502 
/* skcipher ->encrypt() for "xts-aes-aesni". */
static int xts_encrypt_aesni(struct skcipher_request *req)
{
	return xts_crypt(req, aesni_xts_encrypt_iv, aesni_xts_encrypt);
}
507 
/* skcipher ->decrypt() for "xts-aes-aesni"; the tweak IV is always encrypted. */
static int xts_decrypt_aesni(struct skcipher_request *req)
{
	return xts_crypt(req, aesni_xts_encrypt_iv, aesni_xts_decrypt);
}
512 
/*
 * skcipher algorithms implemented with plain AES-NI (no AVX).  The AVX/VAES
 * variants defined below register at higher priorities and take precedence
 * when the CPU supports them.
 */
static struct skcipher_alg aesni_skciphers[] = {
	{
		.base = {
			.cra_name		= "ecb(aes)",
			.cra_driver_name	= "ecb-aes-aesni",
			.cra_priority		= 400,
			.cra_blocksize		= AES_BLOCK_SIZE,
			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.setkey		= aesni_skcipher_setkey,
		.encrypt	= ecb_encrypt,
		.decrypt	= ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "cbc(aes)",
			.cra_driver_name	= "cbc-aes-aesni",
			.cra_priority		= 400,
			.cra_blocksize		= AES_BLOCK_SIZE,
			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		.setkey		= aesni_skcipher_setkey,
		.encrypt	= cbc_encrypt,
		.decrypt	= cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "cts(cbc(aes))",
			.cra_driver_name	= "cts-cbc-aes-aesni",
			.cra_priority		= 400,
			.cra_blocksize		= AES_BLOCK_SIZE,
			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		/* CTS needs the last two blocks in one walk step */
		.walksize	= 2 * AES_BLOCK_SIZE,
		.setkey		= aesni_skcipher_setkey,
		.encrypt	= cts_cbc_encrypt,
		.decrypt	= cts_cbc_decrypt,
#ifdef CONFIG_X86_64
	}, {
		.base = {
			.cra_name		= "ctr(aes)",
			.cra_driver_name	= "ctr-aes-aesni",
			.cra_priority		= 400,
			.cra_blocksize		= 1,
			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= AES_MIN_KEY_SIZE,
		.max_keysize	= AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		.chunksize	= AES_BLOCK_SIZE,
		.setkey		= aesni_skcipher_setkey,
		.encrypt	= ctr_crypt_aesni,
		.decrypt	= ctr_crypt_aesni,
#endif
	}, {
		.base = {
			.cra_name		= "xts(aes)",
			.cra_driver_name	= "xts-aes-aesni",
			.cra_priority		= 401,
			.cra_blocksize		= AES_BLOCK_SIZE,
			.cra_ctxsize		= XTS_AES_CTX_SIZE,
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= 2 * AES_MIN_KEY_SIZE,
		.max_keysize	= 2 * AES_MAX_KEY_SIZE,
		.ivsize		= AES_BLOCK_SIZE,
		/* XTS CTS needs the last two blocks in one walk step */
		.walksize	= 2 * AES_BLOCK_SIZE,
		.setkey		= xts_setkey_aesni,
		.encrypt	= xts_encrypt_aesni,
		.decrypt	= xts_decrypt_aesni,
	}
};
595 
#ifdef CONFIG_X86_64
/* Asm routine shared by all AVX XTS variants to encrypt the tweak IV. */
asmlinkage void aes_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
				   u8 iv[AES_BLOCK_SIZE]);
599 
/*
 * Common CTR glue for the AVX implementations.  The assembly works on the low
 * 64 bits of the big-endian counter only, so this function detects when those
 * 64 bits would wrap within a chunk and splits the call at the wrap point,
 * manually carrying into the high 64 bits.
 *
 * __always_inline to avoid indirect call
 */
static __always_inline int
ctr_crypt(struct skcipher_request *req,
	  void (*ctr64_func)(const struct crypto_aes_ctx *key,
			     const u8 *src, u8 *dst, int len,
			     const u64 le_ctr[2]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct crypto_aes_ctx *key = aes_ctx(crypto_skcipher_ctx(tfm));
	unsigned int nbytes, p1_nbytes, nblocks;
	struct skcipher_walk walk;
	u64 le_ctr[2];
	u64 ctr64;
	int err;

	/* IV is big-endian; le_ctr[0] = low half, le_ctr[1] = high half. */
	ctr64 = le_ctr[0] = get_unaligned_be64(&req->iv[8]);
	le_ctr[1] = get_unaligned_be64(&req->iv[0]);

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) != 0) {
		if (nbytes < walk.total) {
			/* Not the end yet, so keep the length block-aligned. */
			nbytes = round_down(nbytes, AES_BLOCK_SIZE);
			nblocks = nbytes / AES_BLOCK_SIZE;
		} else {
			/* It's the end, so include any final partial block. */
			nblocks = DIV_ROUND_UP(nbytes, AES_BLOCK_SIZE);
		}
		ctr64 += nblocks;

		kernel_fpu_begin();
		if (likely(ctr64 >= nblocks)) {
			/* The low 64 bits of the counter won't overflow. */
			(*ctr64_func)(key, walk.src.virt.addr,
				      walk.dst.virt.addr, nbytes, le_ctr);
		} else {
			/*
			 * The low 64 bits of the counter will overflow.  The
			 * assembly doesn't handle this case, so split the
			 * operation into two at the point where the overflow
			 * will occur.  After the first part, add the carry bit.
			 */
			p1_nbytes = min(nbytes, (nblocks - ctr64) * AES_BLOCK_SIZE);
			(*ctr64_func)(key, walk.src.virt.addr,
				      walk.dst.virt.addr, p1_nbytes, le_ctr);
			le_ctr[0] = 0;
			le_ctr[1]++;
			(*ctr64_func)(key, walk.src.virt.addr + p1_nbytes,
				      walk.dst.virt.addr + p1_nbytes,
				      nbytes - p1_nbytes, le_ctr);
		}
		kernel_fpu_end();
		le_ctr[0] = ctr64;

		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	/* Write the updated counter back into the request IV. */
	put_unaligned_be64(ctr64, &req->iv[8]);
	put_unaligned_be64(le_ctr[1], &req->iv[0]);

	return err;
}
663 
/*
 * Common XCTR glue for the AVX implementations.  Unlike CTR, XCTR keeps the
 * IV fixed and passes a separate 64-bit block counter that starts at 1; the
 * assembly XORs the counter into the IV per block.
 *
 * __always_inline to avoid indirect call
 */
static __always_inline int
xctr_crypt(struct skcipher_request *req,
	   void (*xctr_func)(const struct crypto_aes_ctx *key,
			     const u8 *src, u8 *dst, int len,
			     const u8 iv[AES_BLOCK_SIZE], u64 ctr))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct crypto_aes_ctx *key = aes_ctx(crypto_skcipher_ctx(tfm));
	struct skcipher_walk walk;
	unsigned int nbytes;
	u64 ctr = 1;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	while ((nbytes = walk.nbytes) != 0) {
		/* Keep intermediate chunks block-aligned. */
		if (nbytes < walk.total)
			nbytes = round_down(nbytes, AES_BLOCK_SIZE);

		kernel_fpu_begin();
		(*xctr_func)(key, walk.src.virt.addr, walk.dst.virt.addr,
			     nbytes, req->iv, ctr);
		kernel_fpu_end();

		ctr += DIV_ROUND_UP(nbytes, AES_BLOCK_SIZE);
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}
	return err;
}
693 
/*
 * For one AVX-based implementation (identified by @suffix), declare the
 * assembly routine prototypes, define the XTS/CTR/XCTR glue functions that
 * call the shared xts_crypt()/ctr_crypt()/xctr_crypt() helpers, and define
 * the skcipher_alg array registered at @priority.  Comments are placed here
 * rather than inside the macro because of the line continuations.
 */
#define DEFINE_AVX_SKCIPHER_ALGS(suffix, driver_name_suffix, priority)	       \
									       \
asmlinkage void								       \
aes_xts_encrypt_##suffix(const struct crypto_aes_ctx *key, const u8 *src,      \
			 u8 *dst, int len, u8 tweak[AES_BLOCK_SIZE]);	       \
asmlinkage void								       \
aes_xts_decrypt_##suffix(const struct crypto_aes_ctx *key, const u8 *src,      \
			 u8 *dst, int len, u8 tweak[AES_BLOCK_SIZE]);	       \
									       \
static int xts_encrypt_##suffix(struct skcipher_request *req)		       \
{									       \
	return xts_crypt(req, aes_xts_encrypt_iv, aes_xts_encrypt_##suffix);   \
}									       \
									       \
static int xts_decrypt_##suffix(struct skcipher_request *req)		       \
{									       \
	return xts_crypt(req, aes_xts_encrypt_iv, aes_xts_decrypt_##suffix);   \
}									       \
									       \
asmlinkage void								       \
aes_ctr64_crypt_##suffix(const struct crypto_aes_ctx *key,		       \
			 const u8 *src, u8 *dst, int len, const u64 le_ctr[2]);\
									       \
static int ctr_crypt_##suffix(struct skcipher_request *req)		       \
{									       \
	return ctr_crypt(req, aes_ctr64_crypt_##suffix);		       \
}									       \
									       \
asmlinkage void								       \
aes_xctr_crypt_##suffix(const struct crypto_aes_ctx *key,		       \
			const u8 *src, u8 *dst, int len,		       \
			const u8 iv[AES_BLOCK_SIZE], u64 ctr);		       \
									       \
static int xctr_crypt_##suffix(struct skcipher_request *req)		       \
{									       \
	return xctr_crypt(req, aes_xctr_crypt_##suffix);		       \
}									       \
									       \
static struct skcipher_alg skcipher_algs_##suffix[] = {{		       \
	.base.cra_name		= "xts(aes)",				       \
	.base.cra_driver_name	= "xts-aes-" driver_name_suffix,	       \
	.base.cra_priority	= priority,				       \
	.base.cra_blocksize	= AES_BLOCK_SIZE,			       \
	.base.cra_ctxsize	= XTS_AES_CTX_SIZE,			       \
	.base.cra_module	= THIS_MODULE,				       \
	.min_keysize		= 2 * AES_MIN_KEY_SIZE,			       \
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,			       \
	.ivsize			= AES_BLOCK_SIZE,			       \
	.walksize		= 2 * AES_BLOCK_SIZE,			       \
	.setkey			= xts_setkey_aesni,			       \
	.encrypt		= xts_encrypt_##suffix,			       \
	.decrypt		= xts_decrypt_##suffix,			       \
}, {									       \
	.base.cra_name		= "ctr(aes)",				       \
	.base.cra_driver_name	= "ctr-aes-" driver_name_suffix,	       \
	.base.cra_priority	= priority,				       \
	.base.cra_blocksize	= 1,					       \
	.base.cra_ctxsize	= CRYPTO_AES_CTX_SIZE,			       \
	.base.cra_module	= THIS_MODULE,				       \
	.min_keysize		= AES_MIN_KEY_SIZE,			       \
	.max_keysize		= AES_MAX_KEY_SIZE,			       \
	.ivsize			= AES_BLOCK_SIZE,			       \
	.chunksize		= AES_BLOCK_SIZE,			       \
	.setkey			= aesni_skcipher_setkey,		       \
	.encrypt		= ctr_crypt_##suffix,			       \
	.decrypt		= ctr_crypt_##suffix,			       \
}, {									       \
	.base.cra_name		= "xctr(aes)",				       \
	.base.cra_driver_name	= "xctr-aes-" driver_name_suffix,	       \
	.base.cra_priority	= priority,				       \
	.base.cra_blocksize	= 1,					       \
	.base.cra_ctxsize	= CRYPTO_AES_CTX_SIZE,			       \
	.base.cra_module	= THIS_MODULE,				       \
	.min_keysize		= AES_MIN_KEY_SIZE,			       \
	.max_keysize		= AES_MAX_KEY_SIZE,			       \
	.ivsize			= AES_BLOCK_SIZE,			       \
	.chunksize		= AES_BLOCK_SIZE,			       \
	.setkey			= aesni_skcipher_setkey,		       \
	.encrypt		= xctr_crypt_##suffix,			       \
	.decrypt		= xctr_crypt_##suffix,			       \
}}

/* Instantiate the three AVX-based implementations, in increasing priority. */
DEFINE_AVX_SKCIPHER_ALGS(aesni_avx, "aesni-avx", 500);
DEFINE_AVX_SKCIPHER_ALGS(vaes_avx2, "vaes-avx2", 600);
DEFINE_AVX_SKCIPHER_ALGS(vaes_avx512, "vaes-avx512", 800);
779 
/*
 * The common part of the x86_64 AES-GCM key struct.  Each implementation
 * embeds this at the start of its own key struct (see the container_of()
 * based AES_GCM_KEY_* macros below).
 */
struct aes_gcm_key {
	/* Expanded AES key and the AES key length in bytes */
	struct aes_enckey aes_key;

	/* RFC4106 nonce (used only by the rfc4106 algorithms) */
	u32 rfc4106_nonce;
};
788 
/* Key struct used by the AES-NI implementations of AES-GCM */
struct aes_gcm_key_aesni {
	/*
	 * Common part of the key.  16-byte alignment is required by the
	 * assembly code.
	 */
	struct aes_gcm_key base __aligned(16);

	/*
	 * Powers of the hash key H^8 through H^1.  These are 128-bit values.
	 * They all have an extra factor of x^-1 and are byte-reversed.  16-byte
	 * alignment is required by the assembly code.
	 */
	u64 h_powers[8][2] __aligned(16);

	/*
	 * h_powers_xored[i] contains the two 64-bit halves of h_powers[i] XOR'd
	 * together.  It's used for Karatsuba multiplication.  16-byte alignment
	 * is required by the assembly code.
	 */
	u64 h_powers_xored[8] __aligned(16);

	/*
	 * H^1 times x^64 (and also the usual extra factor of x^-1).  16-byte
	 * alignment is required by the assembly code.
	 */
	u64 h_times_x64[2] __aligned(16);
};
/* Recover the containing key struct from the embedded common part. */
#define AES_GCM_KEY_AESNI(key)	\
	container_of((key), struct aes_gcm_key_aesni, base)
/* Context size including slack for runtime 16-byte realignment. */
#define AES_GCM_KEY_AESNI_SIZE	\
	(sizeof(struct aes_gcm_key_aesni) + (15 & ~(CRYPTO_MINALIGN - 1)))
821 
/* Key struct used by the VAES + AVX2 implementation of AES-GCM */
struct aes_gcm_key_vaes_avx2 {
	/*
	 * Common part of the key.  The assembly code prefers 16-byte alignment
	 * for this.
	 */
	struct aes_gcm_key base __aligned(16);

	/*
	 * Powers of the hash key H^8 through H^1.  These are 128-bit values.
	 * They all have an extra factor of x^-1 and are byte-reversed.
	 * The assembly code prefers 32-byte alignment for this.
	 */
	u64 h_powers[8][2] __aligned(32);

	/*
	 * Each entry in this array contains the two halves of an entry of
	 * h_powers XOR'd together, in the following order:
	 * H^8,H^6,H^7,H^5,H^4,H^2,H^3,H^1 i.e. indices 0,2,1,3,4,6,5,7.
	 * This is used for Karatsuba multiplication.
	 */
	u64 h_powers_xored[8];
};

/* Recover the containing key struct from the embedded common part. */
#define AES_GCM_KEY_VAES_AVX2(key) \
	container_of((key), struct aes_gcm_key_vaes_avx2, base)
/* Context size including slack for runtime 32-byte realignment. */
#define AES_GCM_KEY_VAES_AVX2_SIZE \
	(sizeof(struct aes_gcm_key_vaes_avx2) + (31 & ~(CRYPTO_MINALIGN - 1)))
850 
/* Key struct used by the VAES + AVX512 implementation of AES-GCM */
struct aes_gcm_key_vaes_avx512 {
	/*
	 * Common part of the key.  The assembly code prefers 16-byte alignment
	 * for this.
	 */
	struct aes_gcm_key base __aligned(16);

	/*
	 * Powers of the hash key H^16 through H^1.  These are 128-bit values.
	 * They all have an extra factor of x^-1 and are byte-reversed.  This
	 * array is aligned to a 64-byte boundary to make it naturally aligned
	 * for 512-bit loads, which can improve performance.  (The assembly code
	 * doesn't *need* the alignment; this is just an optimization.)
	 */
	u64 h_powers[16][2] __aligned(64);

	/* Three padding blocks required by the assembly code */
	u64 padding[3][2];
};
/* Recover the containing key struct from the embedded common part. */
#define AES_GCM_KEY_VAES_AVX512(key) \
	container_of((key), struct aes_gcm_key_vaes_avx512, base)
/* Context size including slack for runtime 64-byte realignment. */
#define AES_GCM_KEY_VAES_AVX512_SIZE \
	(sizeof(struct aes_gcm_key_vaes_avx512) + (63 & ~(CRYPTO_MINALIGN - 1)))
875 
/*
 * These flags are passed to the AES-GCM helper functions to specify the
 * specific version of AES-GCM (RFC4106 or not), whether it's encryption or
 * decryption, and which assembly functions should be called.  Assembly
 * functions are selected using flags instead of function pointers to avoid
 * indirect calls (which are very expensive on x86) regardless of inlining.
 */
#define FLAG_RFC4106	BIT(0)	/* RFC4106 variant (fixed nonce prefix) */
#define FLAG_ENC	BIT(1)	/* encryption (otherwise decryption) */
#define FLAG_AVX	BIT(2)	/* use the AES-NI + AVX asm routines */
#define FLAG_VAES_AVX2	BIT(3)	/* use the VAES + AVX2 asm routines */
#define FLAG_VAES_AVX512 BIT(4)	/* use the VAES + AVX512 asm routines */
888 
889 static inline struct aes_gcm_key *
890 aes_gcm_key_get(struct crypto_aead *tfm, int flags)
891 {
892 	if (flags & FLAG_VAES_AVX512)
893 		return PTR_ALIGN(crypto_aead_ctx(tfm), 64);
894 	else if (flags & FLAG_VAES_AVX2)
895 		return PTR_ALIGN(crypto_aead_ctx(tfm), 32);
896 	else
897 		return PTR_ALIGN(crypto_aead_ctx(tfm), 16);
898 }
899 
/* Per-implementation asm routines to precompute the GHASH key powers. */
asmlinkage void
aes_gcm_precompute_aesni(struct aes_gcm_key_aesni *key);
asmlinkage void
aes_gcm_precompute_aesni_avx(struct aes_gcm_key_aesni *key);
asmlinkage void
aes_gcm_precompute_vaes_avx2(struct aes_gcm_key_vaes_avx2 *key);
asmlinkage void
aes_gcm_precompute_vaes_avx512(struct aes_gcm_key_vaes_avx512 *key);
908 
/*
 * Dispatch key precomputation to the implementation selected by @flags.
 * Checked from most to least specific so that exactly one asm routine runs.
 */
static void aes_gcm_precompute(struct aes_gcm_key *key, int flags)
{
	if (flags & FLAG_VAES_AVX512)
		aes_gcm_precompute_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key));
	else if (flags & FLAG_VAES_AVX2)
		aes_gcm_precompute_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key));
	else if (flags & FLAG_AVX)
		aes_gcm_precompute_aesni_avx(AES_GCM_KEY_AESNI(key));
	else
		aes_gcm_precompute_aesni(AES_GCM_KEY_AESNI(key));
}
920 
921 asmlinkage void
922 aes_gcm_aad_update_aesni(const struct aes_gcm_key_aesni *key,
923 			 u8 ghash_acc[16], const u8 *aad, int aadlen);
924 asmlinkage void
925 aes_gcm_aad_update_aesni_avx(const struct aes_gcm_key_aesni *key,
926 			     u8 ghash_acc[16], const u8 *aad, int aadlen);
927 asmlinkage void
928 aes_gcm_aad_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
929 			     u8 ghash_acc[16], const u8 *aad, int aadlen);
930 asmlinkage void
931 aes_gcm_aad_update_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
932 			       u8 ghash_acc[16], const u8 *aad, int aadlen);
933 
934 static void aes_gcm_aad_update(const struct aes_gcm_key *key, u8 ghash_acc[16],
935 			       const u8 *aad, int aadlen, int flags)
936 {
937 	if (flags & FLAG_VAES_AVX512)
938 		aes_gcm_aad_update_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
939 					       ghash_acc, aad, aadlen);
940 	else if (flags & FLAG_VAES_AVX2)
941 		aes_gcm_aad_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
942 					     ghash_acc, aad, aadlen);
943 	else if (flags & FLAG_AVX)
944 		aes_gcm_aad_update_aesni_avx(AES_GCM_KEY_AESNI(key), ghash_acc,
945 					     aad, aadlen);
946 	else
947 		aes_gcm_aad_update_aesni(AES_GCM_KEY_AESNI(key), ghash_acc,
948 					 aad, aadlen);
949 }
950 
951 asmlinkage void
952 aes_gcm_enc_update_aesni(const struct aes_gcm_key_aesni *key,
953 			 const u32 le_ctr[4], u8 ghash_acc[16],
954 			 const u8 *src, u8 *dst, int datalen);
955 asmlinkage void
956 aes_gcm_enc_update_aesni_avx(const struct aes_gcm_key_aesni *key,
957 			     const u32 le_ctr[4], u8 ghash_acc[16],
958 			     const u8 *src, u8 *dst, int datalen);
959 asmlinkage void
960 aes_gcm_enc_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
961 			     const u32 le_ctr[4], u8 ghash_acc[16],
962 			     const u8 *src, u8 *dst, int datalen);
963 asmlinkage void
964 aes_gcm_enc_update_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
965 			       const u32 le_ctr[4], u8 ghash_acc[16],
966 			       const u8 *src, u8 *dst, int datalen);
967 
968 asmlinkage void
969 aes_gcm_dec_update_aesni(const struct aes_gcm_key_aesni *key,
970 			 const u32 le_ctr[4], u8 ghash_acc[16],
971 			 const u8 *src, u8 *dst, int datalen);
972 asmlinkage void
973 aes_gcm_dec_update_aesni_avx(const struct aes_gcm_key_aesni *key,
974 			     const u32 le_ctr[4], u8 ghash_acc[16],
975 			     const u8 *src, u8 *dst, int datalen);
976 asmlinkage void
977 aes_gcm_dec_update_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
978 			     const u32 le_ctr[4], u8 ghash_acc[16],
979 			     const u8 *src, u8 *dst, int datalen);
980 asmlinkage void
981 aes_gcm_dec_update_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
982 			       const u32 le_ctr[4], u8 ghash_acc[16],
983 			       const u8 *src, u8 *dst, int datalen);
984 
985 /* __always_inline to optimize out the branches based on @flags */
986 static __always_inline void
987 aes_gcm_update(const struct aes_gcm_key *key,
988 	       const u32 le_ctr[4], u8 ghash_acc[16],
989 	       const u8 *src, u8 *dst, int datalen, int flags)
990 {
991 	if (flags & FLAG_ENC) {
992 		if (flags & FLAG_VAES_AVX512)
993 			aes_gcm_enc_update_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
994 						       le_ctr, ghash_acc,
995 						       src, dst, datalen);
996 		else if (flags & FLAG_VAES_AVX2)
997 			aes_gcm_enc_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
998 						     le_ctr, ghash_acc,
999 						     src, dst, datalen);
1000 		else if (flags & FLAG_AVX)
1001 			aes_gcm_enc_update_aesni_avx(AES_GCM_KEY_AESNI(key),
1002 						     le_ctr, ghash_acc,
1003 						     src, dst, datalen);
1004 		else
1005 			aes_gcm_enc_update_aesni(AES_GCM_KEY_AESNI(key), le_ctr,
1006 						 ghash_acc, src, dst, datalen);
1007 	} else {
1008 		if (flags & FLAG_VAES_AVX512)
1009 			aes_gcm_dec_update_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
1010 						       le_ctr, ghash_acc,
1011 						       src, dst, datalen);
1012 		else if (flags & FLAG_VAES_AVX2)
1013 			aes_gcm_dec_update_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
1014 						     le_ctr, ghash_acc,
1015 						     src, dst, datalen);
1016 		else if (flags & FLAG_AVX)
1017 			aes_gcm_dec_update_aesni_avx(AES_GCM_KEY_AESNI(key),
1018 						     le_ctr, ghash_acc,
1019 						     src, dst, datalen);
1020 		else
1021 			aes_gcm_dec_update_aesni(AES_GCM_KEY_AESNI(key),
1022 						 le_ctr, ghash_acc,
1023 						 src, dst, datalen);
1024 	}
1025 }
1026 
1027 asmlinkage void
1028 aes_gcm_enc_final_aesni(const struct aes_gcm_key_aesni *key,
1029 			const u32 le_ctr[4], u8 ghash_acc[16],
1030 			u64 total_aadlen, u64 total_datalen);
1031 asmlinkage void
1032 aes_gcm_enc_final_aesni_avx(const struct aes_gcm_key_aesni *key,
1033 			    const u32 le_ctr[4], u8 ghash_acc[16],
1034 			    u64 total_aadlen, u64 total_datalen);
1035 asmlinkage void
1036 aes_gcm_enc_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
1037 			    const u32 le_ctr[4], u8 ghash_acc[16],
1038 			    u64 total_aadlen, u64 total_datalen);
1039 asmlinkage void
1040 aes_gcm_enc_final_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
1041 			      const u32 le_ctr[4], u8 ghash_acc[16],
1042 			      u64 total_aadlen, u64 total_datalen);
1043 
1044 /* __always_inline to optimize out the branches based on @flags */
1045 static __always_inline void
1046 aes_gcm_enc_final(const struct aes_gcm_key *key,
1047 		  const u32 le_ctr[4], u8 ghash_acc[16],
1048 		  u64 total_aadlen, u64 total_datalen, int flags)
1049 {
1050 	if (flags & FLAG_VAES_AVX512)
1051 		aes_gcm_enc_final_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
1052 					      le_ctr, ghash_acc,
1053 					      total_aadlen, total_datalen);
1054 	else if (flags & FLAG_VAES_AVX2)
1055 		aes_gcm_enc_final_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
1056 					    le_ctr, ghash_acc,
1057 					    total_aadlen, total_datalen);
1058 	else if (flags & FLAG_AVX)
1059 		aes_gcm_enc_final_aesni_avx(AES_GCM_KEY_AESNI(key),
1060 					    le_ctr, ghash_acc,
1061 					    total_aadlen, total_datalen);
1062 	else
1063 		aes_gcm_enc_final_aesni(AES_GCM_KEY_AESNI(key),
1064 					le_ctr, ghash_acc,
1065 					total_aadlen, total_datalen);
1066 }
1067 
1068 asmlinkage bool __must_check
1069 aes_gcm_dec_final_aesni(const struct aes_gcm_key_aesni *key,
1070 			const u32 le_ctr[4], const u8 ghash_acc[16],
1071 			u64 total_aadlen, u64 total_datalen,
1072 			const u8 tag[16], int taglen);
1073 asmlinkage bool __must_check
1074 aes_gcm_dec_final_aesni_avx(const struct aes_gcm_key_aesni *key,
1075 			    const u32 le_ctr[4], const u8 ghash_acc[16],
1076 			    u64 total_aadlen, u64 total_datalen,
1077 			    const u8 tag[16], int taglen);
1078 asmlinkage bool __must_check
1079 aes_gcm_dec_final_vaes_avx2(const struct aes_gcm_key_vaes_avx2 *key,
1080 			    const u32 le_ctr[4], const u8 ghash_acc[16],
1081 			    u64 total_aadlen, u64 total_datalen,
1082 			    const u8 tag[16], int taglen);
1083 asmlinkage bool __must_check
1084 aes_gcm_dec_final_vaes_avx512(const struct aes_gcm_key_vaes_avx512 *key,
1085 			      const u32 le_ctr[4], const u8 ghash_acc[16],
1086 			      u64 total_aadlen, u64 total_datalen,
1087 			      const u8 tag[16], int taglen);
1088 
1089 /* __always_inline to optimize out the branches based on @flags */
1090 static __always_inline bool __must_check
1091 aes_gcm_dec_final(const struct aes_gcm_key *key, const u32 le_ctr[4],
1092 		  u8 ghash_acc[16], u64 total_aadlen, u64 total_datalen,
1093 		  u8 tag[16], int taglen, int flags)
1094 {
1095 	if (flags & FLAG_VAES_AVX512)
1096 		return aes_gcm_dec_final_vaes_avx512(AES_GCM_KEY_VAES_AVX512(key),
1097 						     le_ctr, ghash_acc,
1098 						     total_aadlen, total_datalen,
1099 						     tag, taglen);
1100 	else if (flags & FLAG_VAES_AVX2)
1101 		return aes_gcm_dec_final_vaes_avx2(AES_GCM_KEY_VAES_AVX2(key),
1102 						   le_ctr, ghash_acc,
1103 						   total_aadlen, total_datalen,
1104 						   tag, taglen);
1105 	else if (flags & FLAG_AVX)
1106 		return aes_gcm_dec_final_aesni_avx(AES_GCM_KEY_AESNI(key),
1107 						   le_ctr, ghash_acc,
1108 						   total_aadlen, total_datalen,
1109 						   tag, taglen);
1110 	else
1111 		return aes_gcm_dec_final_aesni(AES_GCM_KEY_AESNI(key),
1112 					       le_ctr, ghash_acc,
1113 					       total_aadlen, total_datalen,
1114 					       tag, taglen);
1115 }
1116 
1117 /*
1118  * This is the Integrity Check Value (aka the authentication tag) length and can
1119  * be 8, 12 or 16 bytes long.
1120  */
1121 static int common_rfc4106_set_authsize(struct crypto_aead *aead,
1122 				       unsigned int authsize)
1123 {
1124 	switch (authsize) {
1125 	case 8:
1126 	case 12:
1127 	case 16:
1128 		break;
1129 	default:
1130 		return -EINVAL;
1131 	}
1132 
1133 	return 0;
1134 }
1135 
static int generic_gcmaes_set_authsize(struct crypto_aead *tfm,
				       unsigned int authsize)
{
	/* GCM allows tag lengths of 4, 8, and 12 through 16 bytes. */
	if (authsize == 4 || authsize == 8 ||
	    (authsize >= 12 && authsize <= 16))
		return 0;
	return -EINVAL;
}
1154 
1155 /*
1156  * This is the setkey function for the x86_64 implementations of AES-GCM.  It
1157  * saves the RFC4106 nonce if applicable, expands the AES key, and precomputes
1158  * powers of the hash key.
1159  *
1160  * To comply with the crypto_aead API, this has to be usable in no-SIMD context.
1161  * For that reason, this function includes a portable C implementation of the
1162  * needed logic.  However, the portable C implementation is very slow, taking
1163  * about the same time as encrypting 37 KB of data.  To be ready for users that
1164  * may set a key even somewhat frequently, we therefore also include a SIMD
1165  * assembly implementation, expanding the AES key using AES-NI and precomputing
1166  * the hash key powers using PCLMULQDQ or VPCLMULQDQ.
1167  */
1168 static int gcm_setkey(struct crypto_aead *tfm, const u8 *raw_key,
1169 		      unsigned int keylen, int flags)
1170 {
1171 	struct aes_gcm_key *key = aes_gcm_key_get(tfm, flags);
1172 	int err;
1173 
1174 	if (flags & FLAG_RFC4106) {
1175 		if (keylen < 4)
1176 			return -EINVAL;
1177 		keylen -= 4;
1178 		key->rfc4106_nonce = get_unaligned_be32(raw_key + keylen);
1179 	}
1180 
1181 	/* The assembly code assumes the following offsets. */
1182 	static_assert(offsetof(struct aes_gcm_key_aesni, base.aes_key.len) == 0);
1183 	static_assert(offsetof(struct aes_gcm_key_aesni, base.aes_key.k.rndkeys) == 16);
1184 	static_assert(offsetof(struct aes_gcm_key_aesni, h_powers) == 272);
1185 	static_assert(offsetof(struct aes_gcm_key_aesni, h_powers_xored) == 400);
1186 	static_assert(offsetof(struct aes_gcm_key_aesni, h_times_x64) == 464);
1187 	static_assert(offsetof(struct aes_gcm_key_vaes_avx2, base.aes_key.len) == 0);
1188 	static_assert(offsetof(struct aes_gcm_key_vaes_avx2, base.aes_key.k.rndkeys) == 16);
1189 	static_assert(offsetof(struct aes_gcm_key_vaes_avx2, h_powers) == 288);
1190 	static_assert(offsetof(struct aes_gcm_key_vaes_avx2, h_powers_xored) == 416);
1191 	static_assert(offsetof(struct aes_gcm_key_vaes_avx512, base.aes_key.len) == 0);
1192 	static_assert(offsetof(struct aes_gcm_key_vaes_avx512, base.aes_key.k.rndkeys) == 16);
1193 	static_assert(offsetof(struct aes_gcm_key_vaes_avx512, h_powers) == 320);
1194 	static_assert(offsetof(struct aes_gcm_key_vaes_avx512, padding) == 576);
1195 
1196 	err = aes_prepareenckey(&key->aes_key, raw_key, keylen);
1197 	if (err)
1198 		return err;
1199 
1200 	if (likely(crypto_simd_usable())) {
1201 		kernel_fpu_begin();
1202 		aes_gcm_precompute(key, flags);
1203 		kernel_fpu_end();
1204 	} else {
1205 		static const u8 x_to_the_minus1[16] __aligned(__alignof__(be128)) = {
1206 			[0] = 0xc2, [15] = 1
1207 		};
1208 		static const u8 x_to_the_63[16] __aligned(__alignof__(be128)) = {
1209 			[7] = 1,
1210 		};
1211 		be128 h1 = {};
1212 		be128 h;
1213 		int i;
1214 
1215 		/* Encrypt the all-zeroes block to get the hash key H^1 */
1216 		aes_encrypt(&key->aes_key, (u8 *)&h1, (u8 *)&h1);
1217 
1218 		/* Compute H^1 * x^-1 */
1219 		h = h1;
1220 		gf128mul_lle(&h, (const be128 *)x_to_the_minus1);
1221 
1222 		/* Compute the needed key powers */
1223 		if (flags & FLAG_VAES_AVX512) {
1224 			struct aes_gcm_key_vaes_avx512 *k =
1225 				AES_GCM_KEY_VAES_AVX512(key);
1226 
1227 			for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) {
1228 				k->h_powers[i][0] = be64_to_cpu(h.b);
1229 				k->h_powers[i][1] = be64_to_cpu(h.a);
1230 				gf128mul_lle(&h, &h1);
1231 			}
1232 			memset(k->padding, 0, sizeof(k->padding));
1233 		} else if (flags & FLAG_VAES_AVX2) {
1234 			struct aes_gcm_key_vaes_avx2 *k =
1235 				AES_GCM_KEY_VAES_AVX2(key);
1236 			static const u8 indices[8] = { 0, 2, 1, 3, 4, 6, 5, 7 };
1237 
1238 			for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) {
1239 				k->h_powers[i][0] = be64_to_cpu(h.b);
1240 				k->h_powers[i][1] = be64_to_cpu(h.a);
1241 				gf128mul_lle(&h, &h1);
1242 			}
1243 			for (i = 0; i < ARRAY_SIZE(k->h_powers_xored); i++) {
1244 				int j = indices[i];
1245 
1246 				k->h_powers_xored[i] = k->h_powers[j][0] ^
1247 						       k->h_powers[j][1];
1248 			}
1249 		} else {
1250 			struct aes_gcm_key_aesni *k = AES_GCM_KEY_AESNI(key);
1251 
1252 			for (i = ARRAY_SIZE(k->h_powers) - 1; i >= 0; i--) {
1253 				k->h_powers[i][0] = be64_to_cpu(h.b);
1254 				k->h_powers[i][1] = be64_to_cpu(h.a);
1255 				k->h_powers_xored[i] = k->h_powers[i][0] ^
1256 						       k->h_powers[i][1];
1257 				gf128mul_lle(&h, &h1);
1258 			}
1259 			gf128mul_lle(&h1, (const be128 *)x_to_the_63);
1260 			k->h_times_x64[0] = be64_to_cpu(h1.b);
1261 			k->h_times_x64[1] = be64_to_cpu(h1.a);
1262 		}
1263 	}
1264 	return 0;
1265 }
1266 
1267 /*
1268  * Initialize @ghash_acc, then pass all @assoclen bytes of associated data
1269  * (a.k.a. additional authenticated data) from @sg_src through the GHASH update
1270  * assembly function.  kernel_fpu_begin() must have already been called.
1271  */
1272 static void gcm_process_assoc(const struct aes_gcm_key *key, u8 ghash_acc[16],
1273 			      struct scatterlist *sg_src, unsigned int assoclen,
1274 			      int flags)
1275 {
1276 	struct scatter_walk walk;
1277 	/*
1278 	 * The assembly function requires that the length of any non-last
1279 	 * segment of associated data be a multiple of 16 bytes, so this
1280 	 * function does the buffering needed to achieve that.
1281 	 */
1282 	unsigned int pos = 0;
1283 	u8 buf[16];
1284 
1285 	memset(ghash_acc, 0, 16);
1286 	scatterwalk_start(&walk, sg_src);
1287 
1288 	while (assoclen) {
1289 		unsigned int orig_len_this_step = scatterwalk_next(
1290 			&walk, assoclen);
1291 		unsigned int len_this_step = orig_len_this_step;
1292 		unsigned int len;
1293 		const u8 *src = walk.addr;
1294 
1295 		if (unlikely(pos)) {
1296 			len = min(len_this_step, 16 - pos);
1297 			memcpy(&buf[pos], src, len);
1298 			pos += len;
1299 			src += len;
1300 			len_this_step -= len;
1301 			if (pos < 16)
1302 				goto next;
1303 			aes_gcm_aad_update(key, ghash_acc, buf, 16, flags);
1304 			pos = 0;
1305 		}
1306 		len = len_this_step;
1307 		if (unlikely(assoclen)) /* Not the last segment yet? */
1308 			len = round_down(len, 16);
1309 		aes_gcm_aad_update(key, ghash_acc, src, len, flags);
1310 		src += len;
1311 		len_this_step -= len;
1312 		if (unlikely(len_this_step)) {
1313 			memcpy(buf, src, len_this_step);
1314 			pos = len_this_step;
1315 		}
1316 next:
1317 		scatterwalk_done_src(&walk, orig_len_this_step);
1318 		if (need_resched()) {
1319 			kernel_fpu_end();
1320 			kernel_fpu_begin();
1321 		}
1322 		assoclen -= orig_len_this_step;
1323 	}
1324 	if (unlikely(pos))
1325 		aes_gcm_aad_update(key, ghash_acc, buf, pos, flags);
1326 }
1327 
1328 
/*
 * Do one AES-GCM en/decryption operation: authenticate the associated data,
 * en/decrypt the payload, and produce (FLAG_ENC) or verify the auth tag.
 * Returns 0 on success, -EINVAL for a bad RFC4106 assoclen, -EBADMSG when
 * tag verification fails on decryption, or an skcipher_walk error.
 *
 * __always_inline to optimize out the branches based on @flags
 */
static __always_inline int
gcm_crypt(struct aead_request *req, int flags)
{
	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
	const struct aes_gcm_key *key = aes_gcm_key_get(tfm, flags);
	unsigned int assoclen = req->assoclen;
	struct skcipher_walk walk;
	unsigned int nbytes;
	u8 ghash_acc[16]; /* GHASH accumulator */
	u32 le_ctr[4]; /* Counter in little-endian format */
	int taglen;
	int err;

	/* Initialize the counter and determine the associated data length. */
	le_ctr[0] = 2; /* per GCM, data blocks start at counter value 2 */
	if (flags & FLAG_RFC4106) {
		/* RFC4106: 8-byte explicit IV in the assoc data, 4-byte nonce */
		if (unlikely(assoclen != 16 && assoclen != 20))
			return -EINVAL;
		assoclen -= 8;
		le_ctr[1] = get_unaligned_be32(req->iv + 4);
		le_ctr[2] = get_unaligned_be32(req->iv + 0);
		le_ctr[3] = key->rfc4106_nonce; /* already byte-swapped */
	} else {
		le_ctr[1] = get_unaligned_be32(req->iv + 8);
		le_ctr[2] = get_unaligned_be32(req->iv + 4);
		le_ctr[3] = get_unaligned_be32(req->iv + 0);
	}

	/* Begin walking through the plaintext or ciphertext. */
	if (flags & FLAG_ENC)
		err = skcipher_walk_aead_encrypt(&walk, req, false);
	else
		err = skcipher_walk_aead_decrypt(&walk, req, false);
	if (err)
		return err;

	/*
	 * Since the AES-GCM assembly code requires that at least three assembly
	 * functions be called to process any message (this is needed to support
	 * incremental updates cleanly), to reduce overhead we try to do all
	 * three calls in the same kernel FPU section if possible.  We close the
	 * section and start a new one if there are multiple data segments or if
	 * rescheduling is needed while processing the associated data.
	 */
	kernel_fpu_begin();

	/* Pass the associated data through GHASH. */
	gcm_process_assoc(key, ghash_acc, req->src, assoclen, flags);

	/* En/decrypt the data and pass the ciphertext through GHASH. */
	while (unlikely((nbytes = walk.nbytes) < walk.total)) {
		/*
		 * Non-last segment.  In this case, the assembly function
		 * requires that the length be a multiple of 16 (AES_BLOCK_SIZE)
		 * bytes.  The needed buffering of up to 16 bytes is handled by
		 * the skcipher_walk.  Here we just need to round down to a
		 * multiple of 16.
		 */
		nbytes = round_down(nbytes, AES_BLOCK_SIZE);
		aes_gcm_update(key, le_ctr, ghash_acc, walk.src.virt.addr,
			       walk.dst.virt.addr, nbytes, flags);
		le_ctr[0] += nbytes / AES_BLOCK_SIZE;
		/* Leave the FPU section across the walk step (may sleep). */
		kernel_fpu_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
		if (err)
			return err;
		kernel_fpu_begin();
	}
	/* Last segment: process all remaining data. */
	aes_gcm_update(key, le_ctr, ghash_acc, walk.src.virt.addr,
		       walk.dst.virt.addr, nbytes, flags);
	/*
	 * The low word of the counter isn't used by the finalize, so there's no
	 * need to increment it here.
	 */

	/* Finalize */
	taglen = crypto_aead_authsize(tfm);
	if (flags & FLAG_ENC) {
		/* Finish computing the auth tag. */
		aes_gcm_enc_final(key, le_ctr, ghash_acc, assoclen,
				  req->cryptlen, flags);

		/* Store the computed auth tag in the dst scatterlist. */
		scatterwalk_map_and_copy(ghash_acc, req->dst, req->assoclen +
					 req->cryptlen, taglen, 1);
	} else {
		unsigned int datalen = req->cryptlen - taglen;
		u8 tag[16];

		/* Get the transmitted auth tag from the src scatterlist. */
		scatterwalk_map_and_copy(tag, req->src, req->assoclen + datalen,
					 taglen, 0);
		/*
		 * Finish computing the auth tag and compare it to the
		 * transmitted one.  The assembly function does the actual tag
		 * comparison.  Here, just check the boolean result.
		 */
		if (!aes_gcm_dec_final(key, le_ctr, ghash_acc, assoclen,
				       datalen, tag, taglen, flags))
			err = -EBADMSG;
	}
	kernel_fpu_end();
	/* Finish the walk for the last segment (no bytes left over). */
	if (nbytes)
		skcipher_walk_done(&walk, 0);
	return err;
}
1437 
/*
 * DEFINE_GCM_ALGS() - instantiate "gcm(aes)" and "rfc4106(gcm(aes))" AEADs
 * @suffix: token appended to the generated function and array names
 * @flags: FLAG_* implementation-selection bits baked into each callback
 * @generic_driver_name: cra_driver_name for the plain GCM algorithm
 * @rfc_driver_name: cra_driver_name for the RFC4106 variant
 * @ctxsize: cra_ctxsize, i.e. key struct size plus alignment slack
 * @priority: cra_priority shared by both generated algorithms
 *
 * The generated setkey/encrypt/decrypt wrappers simply call the common
 * helpers with @flags as a compile-time constant, which lets the compiler
 * optimize out the flag checks inside the __always_inline helpers.
 */
#define DEFINE_GCM_ALGS(suffix, flags, generic_driver_name, rfc_driver_name,   \
			ctxsize, priority)				       \
									       \
static int gcm_setkey_##suffix(struct crypto_aead *tfm, const u8 *raw_key,     \
			       unsigned int keylen)			       \
{									       \
	return gcm_setkey(tfm, raw_key, keylen, (flags));		       \
}									       \
									       \
static int gcm_encrypt_##suffix(struct aead_request *req)		       \
{									       \
	return gcm_crypt(req, (flags) | FLAG_ENC);			       \
}									       \
									       \
static int gcm_decrypt_##suffix(struct aead_request *req)		       \
{									       \
	return gcm_crypt(req, (flags));					       \
}									       \
									       \
static int rfc4106_setkey_##suffix(struct crypto_aead *tfm, const u8 *raw_key, \
				   unsigned int keylen)			       \
{									       \
	return gcm_setkey(tfm, raw_key, keylen, (flags) | FLAG_RFC4106);       \
}									       \
									       \
static int rfc4106_encrypt_##suffix(struct aead_request *req)		       \
{									       \
	return gcm_crypt(req, (flags) | FLAG_RFC4106 | FLAG_ENC);	       \
}									       \
									       \
static int rfc4106_decrypt_##suffix(struct aead_request *req)		       \
{									       \
	return gcm_crypt(req, (flags) | FLAG_RFC4106);			       \
}									       \
									       \
static struct aead_alg aes_gcm_algs_##suffix[] = { {			       \
	.setkey			= gcm_setkey_##suffix,			       \
	.setauthsize		= generic_gcmaes_set_authsize,		       \
	.encrypt		= gcm_encrypt_##suffix,			       \
	.decrypt		= gcm_decrypt_##suffix,			       \
	.ivsize			= GCM_AES_IV_SIZE,			       \
	.chunksize		= AES_BLOCK_SIZE,			       \
	.maxauthsize		= 16,					       \
	.base = {							       \
		.cra_name		= "gcm(aes)",			       \
		.cra_driver_name	= generic_driver_name,		       \
		.cra_priority		= (priority),			       \
		.cra_blocksize		= 1,				       \
		.cra_ctxsize		= (ctxsize),			       \
		.cra_module		= THIS_MODULE,			       \
	},								       \
}, {									       \
	.setkey			= rfc4106_setkey_##suffix,		       \
	.setauthsize		= common_rfc4106_set_authsize,		       \
	.encrypt		= rfc4106_encrypt_##suffix,		       \
	.decrypt		= rfc4106_decrypt_##suffix,		       \
	.ivsize			= GCM_RFC4106_IV_SIZE,			       \
	.chunksize		= AES_BLOCK_SIZE,			       \
	.maxauthsize		= 16,					       \
	.base = {							       \
		.cra_name		= "rfc4106(gcm(aes))",		       \
		.cra_driver_name	= rfc_driver_name,		       \
		.cra_priority		= (priority),			       \
		.cra_blocksize		= 1,				       \
		.cra_ctxsize		= (ctxsize),			       \
		.cra_module		= THIS_MODULE,			       \
	},								       \
} }
1506 
/*
 * Instantiate the four implementation tiers.  The crypto API picks the
 * registered algorithm with the highest cra_priority on a given CPU.
 */
/* aes_gcm_algs_aesni */
DEFINE_GCM_ALGS(aesni, /* no flags */ 0,
		"generic-gcm-aesni", "rfc4106-gcm-aesni",
		AES_GCM_KEY_AESNI_SIZE, 400);

/* aes_gcm_algs_aesni_avx */
DEFINE_GCM_ALGS(aesni_avx, FLAG_AVX,
		"generic-gcm-aesni-avx", "rfc4106-gcm-aesni-avx",
		AES_GCM_KEY_AESNI_SIZE, 500);

/* aes_gcm_algs_vaes_avx2 */
DEFINE_GCM_ALGS(vaes_avx2, FLAG_VAES_AVX2,
		"generic-gcm-vaes-avx2", "rfc4106-gcm-vaes-avx2",
		AES_GCM_KEY_VAES_AVX2_SIZE, 600);

/* aes_gcm_algs_vaes_avx512 */
DEFINE_GCM_ALGS(vaes_avx512, FLAG_VAES_AVX512,
		"generic-gcm-vaes-avx512", "rfc4106-gcm-vaes-avx512",
		AES_GCM_KEY_VAES_AVX512_SIZE, 800);
1526 
/*
 * Register the AVX, VAES/AVX2, and VAES/AVX512 algorithms, each tier gated
 * on the CPU features it needs.  A failure returns without unregistering
 * earlier tiers; the caller cleans up via unregister_avx_algs(), which skips
 * arrays that were never registered.
 */
static int __init register_avx_algs(void)
{
	int err;

	if (!boot_cpu_has(X86_FEATURE_AVX))
		return 0;
	err = crypto_register_skciphers(skcipher_algs_aesni_avx,
					ARRAY_SIZE(skcipher_algs_aesni_avx));
	if (err)
		return err;
	err = crypto_register_aeads(aes_gcm_algs_aesni_avx,
				    ARRAY_SIZE(aes_gcm_algs_aesni_avx));
	if (err)
		return err;
	/*
	 * Note: not all the algorithms registered below actually require
	 * VPCLMULQDQ.  But in practice every CPU with VAES also has VPCLMULQDQ.
	 * Similarly, the assembler support was added at about the same time.
	 * For simplicity, just always check for VAES and VPCLMULQDQ together.
	 */
	if (!boot_cpu_has(X86_FEATURE_AVX2) ||
	    !boot_cpu_has(X86_FEATURE_VAES) ||
	    !boot_cpu_has(X86_FEATURE_VPCLMULQDQ) ||
	    !boot_cpu_has(X86_FEATURE_PCLMULQDQ) ||
	    !cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL))
		return 0;
	err = crypto_register_skciphers(skcipher_algs_vaes_avx2,
					ARRAY_SIZE(skcipher_algs_vaes_avx2));
	if (err)
		return err;
	err = crypto_register_aeads(aes_gcm_algs_vaes_avx2,
				    ARRAY_SIZE(aes_gcm_algs_vaes_avx2));
	if (err)
		return err;

	if (!boot_cpu_has(X86_FEATURE_AVX512BW) ||
	    !boot_cpu_has(X86_FEATURE_AVX512VL) ||
	    !boot_cpu_has(X86_FEATURE_BMI2) ||
	    !cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM |
			       XFEATURE_MASK_AVX512, NULL))
		return 0;

	if (boot_cpu_has(X86_FEATURE_PREFER_YMM)) {
		/*
		 * CPU prefers 256-bit vectors: drop the AVX512 algorithms to
		 * the lowest priority so the YMM ones are selected instead,
		 * while still keeping them available by explicit name.
		 */
		int i;

		for (i = 0; i < ARRAY_SIZE(skcipher_algs_vaes_avx512); i++)
			skcipher_algs_vaes_avx512[i].base.cra_priority = 1;
		for (i = 0; i < ARRAY_SIZE(aes_gcm_algs_vaes_avx512); i++)
			aes_gcm_algs_vaes_avx512[i].base.cra_priority = 1;
	}

	err = crypto_register_skciphers(skcipher_algs_vaes_avx512,
					ARRAY_SIZE(skcipher_algs_vaes_avx512));
	if (err)
		return err;
	err = crypto_register_aeads(aes_gcm_algs_vaes_avx512,
				    ARRAY_SIZE(aes_gcm_algs_vaes_avx512));
	if (err)
		return err;

	return 0;
}
1589 
/*
 * Unregister the algorithms in array @A only if they were actually
 * registered: a nonzero refcount on the first entry indicates registration
 * reached this array, so these macros are safe on the partial-failure path
 * of register_avx_algs() as well as on module exit.
 *
 * Wrapped in do { } while (0) so each macro expands to a single statement
 * and behaves correctly in unbraced if/else bodies.
 */
#define unregister_skciphers(A)						\
	do {								\
		if (refcount_read(&(A)[0].base.cra_refcnt) != 0)	\
			crypto_unregister_skciphers((A), ARRAY_SIZE(A)); \
	} while (0)
#define unregister_aeads(A)						\
	do {								\
		if (refcount_read(&(A)[0].base.cra_refcnt) != 0)	\
			crypto_unregister_aeads((A), ARRAY_SIZE(A));	\
	} while (0)
1596 
/*
 * Undo register_avx_algs().  Each unregister_* macro checks whether its
 * array was actually registered, so this is safe to call even after a
 * partial registration failure.
 */
static void unregister_avx_algs(void)
{
	unregister_skciphers(skcipher_algs_aesni_avx);
	unregister_aeads(aes_gcm_algs_aesni_avx);
	unregister_skciphers(skcipher_algs_vaes_avx2);
	unregister_skciphers(skcipher_algs_vaes_avx512);
	unregister_aeads(aes_gcm_algs_vaes_avx2);
	unregister_aeads(aes_gcm_algs_vaes_avx512);
}
#else /* CONFIG_X86_64 */
/* Zero-length placeholder so aesni_init()/aesni_exit() build on 32-bit. */
static struct aead_alg aes_gcm_algs_aesni[0];

/* The AVX/VAES algorithms are 64-bit only; nothing to register here. */
static int __init register_avx_algs(void)
{
	return 0;
}

static void unregister_avx_algs(void)
{
}
#endif /* !CONFIG_X86_64 */
1618 
/* Match (and auto-load the module on) CPUs with the AES-NI extension. */
static const struct x86_cpu_id aesni_cpu_id[] = {
	X86_MATCH_FEATURE(X86_FEATURE_AES, NULL),
	{}
};
MODULE_DEVICE_TABLE(x86cpu, aesni_cpu_id);
1624 
/*
 * Module init: require AES-NI, then register the baseline skciphers and
 * AEADs followed by the optional AVX/VAES tiers.  On failure, unwind in
 * reverse order via the goto-cleanup labels below.
 */
static int __init aesni_init(void)
{
	int err;

	if (!x86_match_cpu(aesni_cpu_id))
		return -ENODEV;

	err = crypto_register_skciphers(aesni_skciphers,
					ARRAY_SIZE(aesni_skciphers));
	if (err)
		return err;

	err = crypto_register_aeads(aes_gcm_algs_aesni,
				    ARRAY_SIZE(aes_gcm_algs_aesni));
	if (err)
		goto unregister_skciphers;

	err = register_avx_algs();
	if (err)
		goto unregister_avx;

	return 0;

unregister_avx:
	/*
	 * unregister_avx_algs() skips arrays that register_avx_algs() never
	 * got around to registering, so it is safe after a partial failure.
	 */
	unregister_avx_algs();
	crypto_unregister_aeads(aes_gcm_algs_aesni,
				ARRAY_SIZE(aes_gcm_algs_aesni));
unregister_skciphers:
	crypto_unregister_skciphers(aesni_skciphers,
				    ARRAY_SIZE(aesni_skciphers));
	return err;
}
1657 
/* Module exit: unregister everything that aesni_init() registered. */
static void __exit aesni_exit(void)
{
	crypto_unregister_aeads(aes_gcm_algs_aesni,
				ARRAY_SIZE(aes_gcm_algs_aesni));
	crypto_unregister_skciphers(aesni_skciphers,
				    ARRAY_SIZE(aesni_skciphers));
	unregister_avx_algs();
}
1666 
1667 module_init(aesni_init);
1668 module_exit(aesni_exit);
1669 
1670 MODULE_DESCRIPTION("AES cipher and modes, optimized with AES-NI or VAES instructions");
1671 MODULE_LICENSE("GPL");
1672 MODULE_ALIAS_CRYPTO("aes");
1673