1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Support for Intel AES-NI instructions. This file contains glue
4  * code; the real AES implementation is in aesni-intel_asm.S.
5  *
6  * Copyright (C) 2008, Intel Corp.
7  *    Author: Huang Ying <ying.huang@intel.com>
8  *
9  * Added RFC4106 AES-GCM support for 128-bit keys under the AEAD
10  * interface for 64-bit kernels.
11  *    Authors: Adrian Hoban <adrian.hoban@intel.com>
12  *             Gabriele Paoloni <gabriele.paoloni@intel.com>
13  *             Tadeusz Struk (tadeusz.struk@intel.com)
14  *             Aidan O'Mahony (aidan.o.mahony@intel.com)
15  *    Copyright (c) 2010, Intel Corporation.
16  */
17 
18 #include <linux/hardirq.h>
19 #include <linux/types.h>
20 #include <linux/module.h>
21 #include <linux/err.h>
22 #include <crypto/algapi.h>
23 #include <crypto/aes.h>
24 #include <crypto/ctr.h>
25 #include <crypto/b128ops.h>
26 #include <crypto/gcm.h>
27 #include <crypto/xts.h>
28 #include <asm/cpu_device_id.h>
29 #include <asm/simd.h>
30 #include <crypto/scatterwalk.h>
31 #include <crypto/internal/aead.h>
32 #include <crypto/internal/simd.h>
33 #include <crypto/internal/skcipher.h>
34 #include <linux/jump_label.h>
35 #include <linux/workqueue.h>
36 #include <linux/spinlock.h>
37 #include <linux/static_call.h>
38 
39 
40 #define AESNI_ALIGN	16
41 #define AESNI_ALIGN_ATTR __attribute__ ((__aligned__(AESNI_ALIGN)))
42 #define AES_BLOCK_MASK	(~(AES_BLOCK_SIZE - 1))
43 #define RFC4106_HASH_SUBKEY_SIZE 16
44 #define AESNI_ALIGN_EXTRA ((AESNI_ALIGN - 1) & ~(CRYPTO_MINALIGN - 1))
45 #define CRYPTO_AES_CTX_SIZE (sizeof(struct crypto_aes_ctx) + AESNI_ALIGN_EXTRA)
46 #define XTS_AES_CTX_SIZE (sizeof(struct aesni_xts_ctx) + AESNI_ALIGN_EXTRA)
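/*
 * CRYPTO_AES_CTX_SIZE and XTS_AES_CTX_SIZE reserve AESNI_ALIGN_EXTRA slack
 * bytes beyond the raw context size so that the context can be realigned to
 * a 16-byte boundary at runtime (see aes_align_addr() below).
 */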
47 
48 /* This data is stored at the end of the crypto_tfm struct.
49  * It is a per-"session" data storage location and
50  * must be 16-byte aligned.
51  */
52 struct aesni_rfc4106_gcm_ctx {
53 	u8 hash_subkey[16] AESNI_ALIGN_ATTR;
54 	struct crypto_aes_ctx aes_key_expanded AESNI_ALIGN_ATTR;
55 	u8 nonce[4];
56 };
57 
58 struct generic_gcmaes_ctx {
59 	u8 hash_subkey[16] AESNI_ALIGN_ATTR;
60 	struct crypto_aes_ctx aes_key_expanded AESNI_ALIGN_ATTR;
61 };
62 
63 struct aesni_xts_ctx {
64 	struct crypto_aes_ctx tweak_ctx AESNI_ALIGN_ATTR;
65 	struct crypto_aes_ctx crypt_ctx AESNI_ALIGN_ATTR;
66 };
67 
68 #define GCM_BLOCK_LEN 16
69 
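/*
 * Per-request GCM state passed to the assembly routines across the
 * init/update/finalize calls.  The assembly code accesses these fields by
 * fixed offset, so the layout must be kept in sync with it.
 */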
70 struct gcm_context_data {
71 	/* init, update and finalize context data */
72 	u8 aad_hash[GCM_BLOCK_LEN];
73 	u64 aad_length;
74 	u64 in_length;
75 	u8 partial_block_enc_key[GCM_BLOCK_LEN];
76 	u8 orig_IV[GCM_BLOCK_LEN];
77 	u8 current_counter[GCM_BLOCK_LEN];
78 	u64 partial_block_len;
79 	u64 unused;
80 	u8 hash_keys[GCM_BLOCK_LEN * 16];
81 };
82 
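/*
 * The crypto API only guarantees CRYPTO_MINALIGN alignment for context
 * memory; if that is already at least AESNI_ALIGN, return the address
 * unchanged, otherwise round it up into the AESNI_ALIGN_EXTRA slack.
 */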
83 static inline void *aes_align_addr(void *addr)
84 {
85 	if (crypto_tfm_ctx_alignment() >= AESNI_ALIGN)
86 		return addr;
87 	return PTR_ALIGN(addr, AESNI_ALIGN);
88 }
89 
90 asmlinkage void aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
91 			      unsigned int key_len);
92 asmlinkage void aesni_enc(const void *ctx, u8 *out, const u8 *in);
93 asmlinkage void aesni_dec(const void *ctx, u8 *out, const u8 *in);
94 asmlinkage void aesni_ecb_enc(struct crypto_aes_ctx *ctx, u8 *out,
95 			      const u8 *in, unsigned int len);
96 asmlinkage void aesni_ecb_dec(struct crypto_aes_ctx *ctx, u8 *out,
97 			      const u8 *in, unsigned int len);
98 asmlinkage void aesni_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out,
99 			      const u8 *in, unsigned int len, u8 *iv);
100 asmlinkage void aesni_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
101 			      const u8 *in, unsigned int len, u8 *iv);
102 asmlinkage void aesni_cts_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out,
103 				  const u8 *in, unsigned int len, u8 *iv);
104 asmlinkage void aesni_cts_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
105 				  const u8 *in, unsigned int len, u8 *iv);
106 
107 #define AVX_GEN2_OPTSIZE 640
108 #define AVX_GEN4_OPTSIZE 4096
109 
110 asmlinkage void aesni_xts_enc(const struct crypto_aes_ctx *ctx, u8 *out,
111 			      const u8 *in, unsigned int len, u8 *iv);
112 
113 asmlinkage void aesni_xts_dec(const struct crypto_aes_ctx *ctx, u8 *out,
114 			      const u8 *in, unsigned int len, u8 *iv);
115 
116 #ifdef CONFIG_X86_64
117 
118 asmlinkage void aesni_ctr_enc(struct crypto_aes_ctx *ctx, u8 *out,
119 			      const u8 *in, unsigned int len, u8 *iv);
120 DEFINE_STATIC_CALL(aesni_ctr_enc_tfm, aesni_ctr_enc);
121 
122 /* Scatter / Gather routines, with args similar to above */
123 asmlinkage void aesni_gcm_init(void *ctx,
124 			       struct gcm_context_data *gdata,
125 			       u8 *iv,
126 			       u8 *hash_subkey, const u8 *aad,
127 			       unsigned long aad_len);
128 asmlinkage void aesni_gcm_enc_update(void *ctx,
129 				     struct gcm_context_data *gdata, u8 *out,
130 				     const u8 *in, unsigned long plaintext_len);
131 asmlinkage void aesni_gcm_dec_update(void *ctx,
132 				     struct gcm_context_data *gdata, u8 *out,
133 				     const u8 *in,
134 				     unsigned long ciphertext_len);
135 asmlinkage void aesni_gcm_finalize(void *ctx,
136 				   struct gcm_context_data *gdata,
137 				   u8 *auth_tag, unsigned long auth_tag_len);
138 
139 asmlinkage void aes_ctr_enc_128_avx_by8(const u8 *in, u8 *iv,
140 		void *keys, u8 *out, unsigned int num_bytes);
141 asmlinkage void aes_ctr_enc_192_avx_by8(const u8 *in, u8 *iv,
142 		void *keys, u8 *out, unsigned int num_bytes);
143 asmlinkage void aes_ctr_enc_256_avx_by8(const u8 *in, u8 *iv,
144 		void *keys, u8 *out, unsigned int num_bytes);
145 
146 
147 asmlinkage void aes_xctr_enc_128_avx_by8(const u8 *in, const u8 *iv,
148 	const void *keys, u8 *out, unsigned int num_bytes,
149 	unsigned int byte_ctr);
150 
151 asmlinkage void aes_xctr_enc_192_avx_by8(const u8 *in, const u8 *iv,
152 	const void *keys, u8 *out, unsigned int num_bytes,
153 	unsigned int byte_ctr);
154 
155 asmlinkage void aes_xctr_enc_256_avx_by8(const u8 *in, const u8 *iv,
156 	const void *keys, u8 *out, unsigned int num_bytes,
157 	unsigned int byte_ctr);
158 
159 /*
160  * asmlinkage void aesni_gcm_init_avx_gen2()
161  * gcm_data *my_ctx_data - context data
162  * u8 *hash_subkey - the hash subkey input; data starts on a 16-byte boundary.
163  */
164 asmlinkage void aesni_gcm_init_avx_gen2(void *my_ctx_data,
165 					struct gcm_context_data *gdata,
166 					u8 *iv,
167 					u8 *hash_subkey,
168 					const u8 *aad,
169 					unsigned long aad_len);
170 
171 asmlinkage void aesni_gcm_enc_update_avx_gen2(void *ctx,
172 				     struct gcm_context_data *gdata, u8 *out,
173 				     const u8 *in, unsigned long plaintext_len);
174 asmlinkage void aesni_gcm_dec_update_avx_gen2(void *ctx,
175 				     struct gcm_context_data *gdata, u8 *out,
176 				     const u8 *in,
177 				     unsigned long ciphertext_len);
178 asmlinkage void aesni_gcm_finalize_avx_gen2(void *ctx,
179 				   struct gcm_context_data *gdata,
180 				   u8 *auth_tag, unsigned long auth_tag_len);
181 
182 /*
183  * asmlinkage void aesni_gcm_init_avx_gen4()
184  * gcm_data *my_ctx_data - context data
185  * u8 *hash_subkey - the hash subkey input; data starts on a 16-byte boundary.
186  */
187 asmlinkage void aesni_gcm_init_avx_gen4(void *my_ctx_data,
188 					struct gcm_context_data *gdata,
189 					u8 *iv,
190 					u8 *hash_subkey,
191 					const u8 *aad,
192 					unsigned long aad_len);
193 
194 asmlinkage void aesni_gcm_enc_update_avx_gen4(void *ctx,
195 				     struct gcm_context_data *gdata, u8 *out,
196 				     const u8 *in, unsigned long plaintext_len);
197 asmlinkage void aesni_gcm_dec_update_avx_gen4(void *ctx,
198 				     struct gcm_context_data *gdata, u8 *out,
199 				     const u8 *in,
200 				     unsigned long ciphertext_len);
201 asmlinkage void aesni_gcm_finalize_avx_gen4(void *ctx,
202 				   struct gcm_context_data *gdata,
203 				   u8 *auth_tag, unsigned long auth_tag_len);
204 
205 static __ro_after_init DEFINE_STATIC_KEY_FALSE(gcm_use_avx);
206 static __ro_after_init DEFINE_STATIC_KEY_FALSE(gcm_use_avx2);
207 
208 static inline struct
209 aesni_rfc4106_gcm_ctx *aesni_rfc4106_gcm_ctx_get(struct crypto_aead *tfm)
210 {
211 	return aes_align_addr(crypto_aead_ctx(tfm));
212 }
213 
214 static inline struct
215 generic_gcmaes_ctx *generic_gcmaes_ctx_get(struct crypto_aead *tfm)
216 {
217 	return aes_align_addr(crypto_aead_ctx(tfm));
218 }
219 #endif
220 
221 static inline struct crypto_aes_ctx *aes_ctx(void *raw_ctx)
222 {
223 	return aes_align_addr(raw_ctx);
224 }
225 
226 static inline struct aesni_xts_ctx *aes_xts_ctx(struct crypto_skcipher *tfm)
227 {
228 	return aes_align_addr(crypto_skcipher_ctx(tfm));
229 }
230 
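/*
 * Expand the AES key.  When SIMD is usable the AES-NI key schedule is built
 * inside a kernel_fpu_begin()/kernel_fpu_end() section; otherwise fall back
 * to the generic C key expansion so setkey also works in no-SIMD contexts.
 */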
231 static int aes_set_key_common(struct crypto_aes_ctx *ctx,
232 			      const u8 *in_key, unsigned int key_len)
233 {
234 	int err;
235 
236 	if (!crypto_simd_usable())
237 		return aes_expandkey(ctx, in_key, key_len);
238 
239 	err = aes_check_keylen(key_len);
240 	if (err)
241 		return err;
242 
243 	kernel_fpu_begin();
244 	aesni_set_key(ctx, in_key, key_len);
245 	kernel_fpu_end();
246 	return 0;
247 }
248 
249 static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
250 		       unsigned int key_len)
251 {
252 	return aes_set_key_common(aes_ctx(crypto_tfm_ctx(tfm)), in_key,
253 				  key_len);
254 }
255 
256 static void aesni_encrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
257 {
258 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_tfm_ctx(tfm));
259 
260 	if (!crypto_simd_usable()) {
261 		aes_encrypt(ctx, dst, src);
262 	} else {
263 		kernel_fpu_begin();
264 		aesni_enc(ctx, dst, src);
265 		kernel_fpu_end();
266 	}
267 }
268 
269 static void aesni_decrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
270 {
271 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_tfm_ctx(tfm));
272 
273 	if (!crypto_simd_usable()) {
274 		aes_decrypt(ctx, dst, src);
275 	} else {
276 		kernel_fpu_begin();
277 		aesni_dec(ctx, dst, src);
278 		kernel_fpu_end();
279 	}
280 }
281 
282 static int aesni_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
283 			         unsigned int len)
284 {
285 	return aes_set_key_common(aes_ctx(crypto_skcipher_ctx(tfm)), key, len);
286 }
287 
288 static int ecb_encrypt(struct skcipher_request *req)
289 {
290 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
291 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
292 	struct skcipher_walk walk;
293 	unsigned int nbytes;
294 	int err;
295 
296 	err = skcipher_walk_virt(&walk, req, false);
297 
298 	while ((nbytes = walk.nbytes)) {
299 		kernel_fpu_begin();
300 		aesni_ecb_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
301 			      nbytes & AES_BLOCK_MASK);
302 		kernel_fpu_end();
303 		nbytes &= AES_BLOCK_SIZE - 1;
304 		err = skcipher_walk_done(&walk, nbytes);
305 	}
306 
307 	return err;
308 }
309 
310 static int ecb_decrypt(struct skcipher_request *req)
311 {
312 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
313 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
314 	struct skcipher_walk walk;
315 	unsigned int nbytes;
316 	int err;
317 
318 	err = skcipher_walk_virt(&walk, req, false);
319 
320 	while ((nbytes = walk.nbytes)) {
321 		kernel_fpu_begin();
322 		aesni_ecb_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
323 			      nbytes & AES_BLOCK_MASK);
324 		kernel_fpu_end();
325 		nbytes &= AES_BLOCK_SIZE - 1;
326 		err = skcipher_walk_done(&walk, nbytes);
327 	}
328 
329 	return err;
330 }
331 
332 static int cbc_encrypt(struct skcipher_request *req)
333 {
334 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
335 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
336 	struct skcipher_walk walk;
337 	unsigned int nbytes;
338 	int err;
339 
340 	err = skcipher_walk_virt(&walk, req, false);
341 
342 	while ((nbytes = walk.nbytes)) {
343 		kernel_fpu_begin();
344 		aesni_cbc_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
345 			      nbytes & AES_BLOCK_MASK, walk.iv);
346 		kernel_fpu_end();
347 		nbytes &= AES_BLOCK_SIZE - 1;
348 		err = skcipher_walk_done(&walk, nbytes);
349 	}
350 
351 	return err;
352 }
353 
354 static int cbc_decrypt(struct skcipher_request *req)
355 {
356 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
357 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
358 	struct skcipher_walk walk;
359 	unsigned int nbytes;
360 	int err;
361 
362 	err = skcipher_walk_virt(&walk, req, false);
363 
364 	while ((nbytes = walk.nbytes)) {
365 		kernel_fpu_begin();
366 		aesni_cbc_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
367 			      nbytes & AES_BLOCK_MASK, walk.iv);
368 		kernel_fpu_end();
369 		nbytes &= AES_BLOCK_SIZE - 1;
370 		err = skcipher_walk_done(&walk, nbytes);
371 	}
372 
373 	return err;
374 }
375 
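/*
 * CBC with ciphertext stealing (CTS): all but the final (up to) two blocks
 * are processed as plain CBC, and the remainder is handed to
 * aesni_cts_cbc_enc()/aesni_cts_cbc_dec() in a single call, which performs
 * the final block swap that CTS requires.
 */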
376 static int cts_cbc_encrypt(struct skcipher_request *req)
377 {
378 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
379 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
380 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
381 	struct scatterlist *src = req->src, *dst = req->dst;
382 	struct scatterlist sg_src[2], sg_dst[2];
383 	struct skcipher_request subreq;
384 	struct skcipher_walk walk;
385 	int err;
386 
387 	skcipher_request_set_tfm(&subreq, tfm);
388 	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
389 				      NULL, NULL);
390 
391 	if (req->cryptlen <= AES_BLOCK_SIZE) {
392 		if (req->cryptlen < AES_BLOCK_SIZE)
393 			return -EINVAL;
394 		cbc_blocks = 1;
395 	}
396 
397 	if (cbc_blocks > 0) {
398 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
399 					   cbc_blocks * AES_BLOCK_SIZE,
400 					   req->iv);
401 
402 		err = cbc_encrypt(&subreq);
403 		if (err)
404 			return err;
405 
406 		if (req->cryptlen == AES_BLOCK_SIZE)
407 			return 0;
408 
409 		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
410 		if (req->dst != req->src)
411 			dst = scatterwalk_ffwd(sg_dst, req->dst,
412 					       subreq.cryptlen);
413 	}
414 
415 	/* handle ciphertext stealing */
416 	skcipher_request_set_crypt(&subreq, src, dst,
417 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
418 				   req->iv);
419 
420 	err = skcipher_walk_virt(&walk, &subreq, false);
421 	if (err)
422 		return err;
423 
424 	kernel_fpu_begin();
425 	aesni_cts_cbc_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
426 			  walk.nbytes, walk.iv);
427 	kernel_fpu_end();
428 
429 	return skcipher_walk_done(&walk, 0);
430 }
431 
432 static int cts_cbc_decrypt(struct skcipher_request *req)
433 {
434 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
435 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
436 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
437 	struct scatterlist *src = req->src, *dst = req->dst;
438 	struct scatterlist sg_src[2], sg_dst[2];
439 	struct skcipher_request subreq;
440 	struct skcipher_walk walk;
441 	int err;
442 
443 	skcipher_request_set_tfm(&subreq, tfm);
444 	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
445 				      NULL, NULL);
446 
447 	if (req->cryptlen <= AES_BLOCK_SIZE) {
448 		if (req->cryptlen < AES_BLOCK_SIZE)
449 			return -EINVAL;
450 		cbc_blocks = 1;
451 	}
452 
453 	if (cbc_blocks > 0) {
454 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
455 					   cbc_blocks * AES_BLOCK_SIZE,
456 					   req->iv);
457 
458 		err = cbc_decrypt(&subreq);
459 		if (err)
460 			return err;
461 
462 		if (req->cryptlen == AES_BLOCK_SIZE)
463 			return 0;
464 
465 		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
466 		if (req->dst != req->src)
467 			dst = scatterwalk_ffwd(sg_dst, req->dst,
468 					       subreq.cryptlen);
469 	}
470 
471 	/* handle ciphertext stealing */
472 	skcipher_request_set_crypt(&subreq, src, dst,
473 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
474 				   req->iv);
475 
476 	err = skcipher_walk_virt(&walk, &subreq, false);
477 	if (err)
478 		return err;
479 
480 	kernel_fpu_begin();
481 	aesni_cts_cbc_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
482 			  walk.nbytes, walk.iv);
483 	kernel_fpu_end();
484 
485 	return skcipher_walk_done(&walk, 0);
486 }
487 
488 #ifdef CONFIG_X86_64
489 static void aesni_ctr_enc_avx_tfm(struct crypto_aes_ctx *ctx, u8 *out,
490 			      const u8 *in, unsigned int len, u8 *iv)
491 {
492 	/*
493 	 * Based on the key length, override with the by8 version of CTR
494 	 * mode encryption/decryption for improved performance.
495 	 * aes_set_key_common() ensures that the key length is one of
496 	 * {128, 192, 256} bits.
497 	 */
498 	if (ctx->key_length == AES_KEYSIZE_128)
499 		aes_ctr_enc_128_avx_by8(in, iv, (void *)ctx, out, len);
500 	else if (ctx->key_length == AES_KEYSIZE_192)
501 		aes_ctr_enc_192_avx_by8(in, iv, (void *)ctx, out, len);
502 	else
503 		aes_ctr_enc_256_avx_by8(in, iv, (void *)ctx, out, len);
504 }
505 
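/*
 * CTR mode: full blocks are processed by the routine behind the
 * aesni_ctr_enc_tfm static call (plain AES-NI, or the AVX "by8" variant
 * selected at init time).  A trailing partial block is handled by
 * encrypting the counter block and XORing the keystream into the data.
 */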
506 static int ctr_crypt(struct skcipher_request *req)
507 {
508 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
509 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
510 	u8 keystream[AES_BLOCK_SIZE];
511 	struct skcipher_walk walk;
512 	unsigned int nbytes;
513 	int err;
514 
515 	err = skcipher_walk_virt(&walk, req, false);
516 
517 	while ((nbytes = walk.nbytes) > 0) {
518 		kernel_fpu_begin();
519 		if (nbytes & AES_BLOCK_MASK)
520 			static_call(aesni_ctr_enc_tfm)(ctx, walk.dst.virt.addr,
521 						       walk.src.virt.addr,
522 						       nbytes & AES_BLOCK_MASK,
523 						       walk.iv);
524 		nbytes &= ~AES_BLOCK_MASK;
525 
526 		if (walk.nbytes == walk.total && nbytes > 0) {
527 			aesni_enc(ctx, keystream, walk.iv);
528 			crypto_xor_cpy(walk.dst.virt.addr + walk.nbytes - nbytes,
529 				       walk.src.virt.addr + walk.nbytes - nbytes,
530 				       keystream, nbytes);
531 			crypto_inc(walk.iv, AES_BLOCK_SIZE);
532 			nbytes = 0;
533 		}
534 		kernel_fpu_end();
535 		err = skcipher_walk_done(&walk, nbytes);
536 	}
537 	return err;
538 }
539 
540 static void aesni_xctr_enc_avx_tfm(struct crypto_aes_ctx *ctx, u8 *out,
541 				   const u8 *in, unsigned int len, u8 *iv,
542 				   unsigned int byte_ctr)
543 {
544 	if (ctx->key_length == AES_KEYSIZE_128)
545 		aes_xctr_enc_128_avx_by8(in, iv, (void *)ctx, out, len,
546 					 byte_ctr);
547 	else if (ctx->key_length == AES_KEYSIZE_192)
548 		aes_xctr_enc_192_avx_by8(in, iv, (void *)ctx, out, len,
549 					 byte_ctr);
550 	else
551 		aes_xctr_enc_256_avx_by8(in, iv, (void *)ctx, out, len,
552 					 byte_ctr);
553 }
554 
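/*
 * XCTR (used by HCTR2) differs from CTR in that the little-endian block
 * counter is XORed with the IV rather than stored in it; byte_ctr tracks
 * the offset into the message so the partial-block path below can derive
 * the correct counter value.
 */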
555 static int xctr_crypt(struct skcipher_request *req)
556 {
557 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
558 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
559 	u8 keystream[AES_BLOCK_SIZE];
560 	struct skcipher_walk walk;
561 	unsigned int nbytes;
562 	unsigned int byte_ctr = 0;
563 	int err;
564 	__le32 block[AES_BLOCK_SIZE / sizeof(__le32)];
565 
566 	err = skcipher_walk_virt(&walk, req, false);
567 
568 	while ((nbytes = walk.nbytes) > 0) {
569 		kernel_fpu_begin();
570 		if (nbytes & AES_BLOCK_MASK)
571 			aesni_xctr_enc_avx_tfm(ctx, walk.dst.virt.addr,
572 				walk.src.virt.addr, nbytes & AES_BLOCK_MASK,
573 				walk.iv, byte_ctr);
574 		nbytes &= ~AES_BLOCK_MASK;
575 		byte_ctr += walk.nbytes - nbytes;
576 
577 		if (walk.nbytes == walk.total && nbytes > 0) {
578 			memcpy(block, walk.iv, AES_BLOCK_SIZE);
579 			block[0] ^= cpu_to_le32(1 + byte_ctr / AES_BLOCK_SIZE);
580 			aesni_enc(ctx, keystream, (u8 *)block);
581 			crypto_xor_cpy(walk.dst.virt.addr + walk.nbytes -
582 				       nbytes, walk.src.virt.addr + walk.nbytes
583 				       - nbytes, keystream, nbytes);
584 			byte_ctr += nbytes;
585 			nbytes = 0;
586 		}
587 		kernel_fpu_end();
588 		err = skcipher_walk_done(&walk, nbytes);
589 	}
590 	return err;
591 }
592 
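/*
 * Derive the GHASH subkey H = AES_K(0^128).  The generic C AES routines are
 * used so no SIMD section is needed, and the temporary key schedule is wiped
 * afterwards.
 */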
593 static int
594 rfc4106_set_hash_subkey(u8 *hash_subkey, const u8 *key, unsigned int key_len)
595 {
596 	struct crypto_aes_ctx ctx;
597 	int ret;
598 
599 	ret = aes_expandkey(&ctx, key, key_len);
600 	if (ret)
601 		return ret;
602 
603 	/* Clear the hash subkey container to zero: ciphering an all-zero
604 	 * block with the AES key produces the hash subkey. */
605 	memset(hash_subkey, 0, RFC4106_HASH_SUBKEY_SIZE);
606 
607 	aes_encrypt(&ctx, hash_subkey, hash_subkey);
608 
609 	memzero_explicit(&ctx, sizeof(ctx));
610 	return 0;
611 }
612 
613 static int common_rfc4106_set_key(struct crypto_aead *aead, const u8 *key,
614 				  unsigned int key_len)
615 {
616 	struct aesni_rfc4106_gcm_ctx *ctx = aesni_rfc4106_gcm_ctx_get(aead);
617 
618 	if (key_len < 4)
619 		return -EINVAL;
620 
621 	/*Account for 4 byte nonce at the end.*/
622 	key_len -= 4;
623 
624 	memcpy(ctx->nonce, key + key_len, sizeof(ctx->nonce));
625 
626 	return aes_set_key_common(&ctx->aes_key_expanded, key, key_len) ?:
627 	       rfc4106_set_hash_subkey(ctx->hash_subkey, key, key_len);
628 }
629 
630 /* This is the Integrity Check Value (aka the authentication tag) length and can
631  * be 8, 12 or 16 bytes long. */
632 static int common_rfc4106_set_authsize(struct crypto_aead *aead,
633 				       unsigned int authsize)
634 {
635 	switch (authsize) {
636 	case 8:
637 	case 12:
638 	case 16:
639 		break;
640 	default:
641 		return -EINVAL;
642 	}
643 
644 	return 0;
645 }
646 
647 static int generic_gcmaes_set_authsize(struct crypto_aead *tfm,
648 				       unsigned int authsize)
649 {
650 	switch (authsize) {
651 	case 4:
652 	case 8:
653 	case 12:
654 	case 13:
655 	case 14:
656 	case 15:
657 	case 16:
658 		break;
659 	default:
660 		return -EINVAL;
661 	}
662 
663 	return 0;
664 }
665 
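/*
 * Common GCM path for encryption and decryption: linearize the AAD, run the
 * init/update/finalize sequence over the request scatterlists, and write the
 * computed authentication tag to auth_tag.  Which assembly variant is used
 * (SSE, AVX or AVX2) depends on the static keys set at module init and on
 * the message length relative to AVX_GEN2_OPTSIZE/AVX_GEN4_OPTSIZE.
 */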
666 static int gcmaes_crypt_by_sg(bool enc, struct aead_request *req,
667 			      unsigned int assoclen, u8 *hash_subkey,
668 			      u8 *iv, void *aes_ctx, u8 *auth_tag,
669 			      unsigned long auth_tag_len)
670 {
671 	u8 databuf[sizeof(struct gcm_context_data) + (AESNI_ALIGN - 8)] __aligned(8);
672 	struct gcm_context_data *data = PTR_ALIGN((void *)databuf, AESNI_ALIGN);
673 	unsigned long left = req->cryptlen;
674 	struct scatter_walk assoc_sg_walk;
675 	struct skcipher_walk walk;
676 	bool do_avx, do_avx2;
677 	u8 *assocmem = NULL;
678 	u8 *assoc;
679 	int err;
680 
681 	if (!enc)
682 		left -= auth_tag_len;
683 
684 	do_avx = (left >= AVX_GEN2_OPTSIZE);
685 	do_avx2 = (left >= AVX_GEN4_OPTSIZE);
686 
687 	/* Linearize assoc, if not already linear */
688 	if (req->src->length >= assoclen && req->src->length) {
689 		scatterwalk_start(&assoc_sg_walk, req->src);
690 		assoc = scatterwalk_map(&assoc_sg_walk);
691 	} else {
692 		gfp_t flags = (req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP) ?
693 			      GFP_KERNEL : GFP_ATOMIC;
694 
695 		/* assoc can be any length, so must be on heap */
696 		assocmem = kmalloc(assoclen, flags);
697 		if (unlikely(!assocmem))
698 			return -ENOMEM;
699 		assoc = assocmem;
700 
701 		scatterwalk_map_and_copy(assoc, req->src, 0, assoclen, 0);
702 	}
703 
704 	kernel_fpu_begin();
705 	if (static_branch_likely(&gcm_use_avx2) && do_avx2)
706 		aesni_gcm_init_avx_gen4(aes_ctx, data, iv, hash_subkey, assoc,
707 					assoclen);
708 	else if (static_branch_likely(&gcm_use_avx) && do_avx)
709 		aesni_gcm_init_avx_gen2(aes_ctx, data, iv, hash_subkey, assoc,
710 					assoclen);
711 	else
712 		aesni_gcm_init(aes_ctx, data, iv, hash_subkey, assoc, assoclen);
713 	kernel_fpu_end();
714 
715 	if (!assocmem)
716 		scatterwalk_unmap(assoc);
717 	else
718 		kfree(assocmem);
719 
720 	err = enc ? skcipher_walk_aead_encrypt(&walk, req, false)
721 		  : skcipher_walk_aead_decrypt(&walk, req, false);
722 
723 	while (walk.nbytes > 0) {
724 		kernel_fpu_begin();
725 		if (static_branch_likely(&gcm_use_avx2) && do_avx2) {
726 			if (enc)
727 				aesni_gcm_enc_update_avx_gen4(aes_ctx, data,
728 							      walk.dst.virt.addr,
729 							      walk.src.virt.addr,
730 							      walk.nbytes);
731 			else
732 				aesni_gcm_dec_update_avx_gen4(aes_ctx, data,
733 							      walk.dst.virt.addr,
734 							      walk.src.virt.addr,
735 							      walk.nbytes);
736 		} else if (static_branch_likely(&gcm_use_avx) && do_avx) {
737 			if (enc)
738 				aesni_gcm_enc_update_avx_gen2(aes_ctx, data,
739 							      walk.dst.virt.addr,
740 							      walk.src.virt.addr,
741 							      walk.nbytes);
742 			else
743 				aesni_gcm_dec_update_avx_gen2(aes_ctx, data,
744 							      walk.dst.virt.addr,
745 							      walk.src.virt.addr,
746 							      walk.nbytes);
747 		} else if (enc) {
748 			aesni_gcm_enc_update(aes_ctx, data, walk.dst.virt.addr,
749 					     walk.src.virt.addr, walk.nbytes);
750 		} else {
751 			aesni_gcm_dec_update(aes_ctx, data, walk.dst.virt.addr,
752 					     walk.src.virt.addr, walk.nbytes);
753 		}
754 		kernel_fpu_end();
755 
756 		err = skcipher_walk_done(&walk, 0);
757 	}
758 
759 	if (err)
760 		return err;
761 
762 	kernel_fpu_begin();
763 	if (static_branch_likely(&gcm_use_avx2) && do_avx2)
764 		aesni_gcm_finalize_avx_gen4(aes_ctx, data, auth_tag,
765 					    auth_tag_len);
766 	else if (static_branch_likely(&gcm_use_avx) && do_avx)
767 		aesni_gcm_finalize_avx_gen2(aes_ctx, data, auth_tag,
768 					    auth_tag_len);
769 	else
770 		aesni_gcm_finalize(aes_ctx, data, auth_tag, auth_tag_len);
771 	kernel_fpu_end();
772 
773 	return 0;
774 }
775 
776 static int gcmaes_encrypt(struct aead_request *req, unsigned int assoclen,
777 			  u8 *hash_subkey, u8 *iv, void *aes_ctx)
778 {
779 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
780 	unsigned long auth_tag_len = crypto_aead_authsize(tfm);
781 	u8 auth_tag[16];
782 	int err;
783 
784 	err = gcmaes_crypt_by_sg(true, req, assoclen, hash_subkey, iv, aes_ctx,
785 				 auth_tag, auth_tag_len);
786 	if (err)
787 		return err;
788 
789 	scatterwalk_map_and_copy(auth_tag, req->dst,
790 				 req->assoclen + req->cryptlen,
791 				 auth_tag_len, 1);
792 	return 0;
793 }
794 
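/*
 * Decrypt and authenticate: the tag carried at the end of the source data
 * is compared against the computed tag with crypto_memneq() (constant
 * time); on mismatch the computed tag is wiped and -EBADMSG is returned.
 */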
795 static int gcmaes_decrypt(struct aead_request *req, unsigned int assoclen,
796 			  u8 *hash_subkey, u8 *iv, void *aes_ctx)
797 {
798 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
799 	unsigned long auth_tag_len = crypto_aead_authsize(tfm);
800 	u8 auth_tag_msg[16];
801 	u8 auth_tag[16];
802 	int err;
803 
804 	err = gcmaes_crypt_by_sg(false, req, assoclen, hash_subkey, iv, aes_ctx,
805 				 auth_tag, auth_tag_len);
806 	if (err)
807 		return err;
808 
809 	/* Copy out original auth_tag */
810 	scatterwalk_map_and_copy(auth_tag_msg, req->src,
811 				 req->assoclen + req->cryptlen - auth_tag_len,
812 				 auth_tag_len, 0);
813 
814 	/* Compare generated tag with passed in tag. */
815 	if (crypto_memneq(auth_tag_msg, auth_tag, auth_tag_len)) {
816 		memzero_explicit(auth_tag, sizeof(auth_tag));
817 		return -EBADMSG;
818 	}
819 	return 0;
820 }
821 
822 static int helper_rfc4106_encrypt(struct aead_request *req)
823 {
824 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
825 	struct aesni_rfc4106_gcm_ctx *ctx = aesni_rfc4106_gcm_ctx_get(tfm);
826 	void *aes_ctx = &(ctx->aes_key_expanded);
827 	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
828 	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
829 	unsigned int i;
830 	__be32 counter = cpu_to_be32(1);
831 
832 	/* Assuming we support rfc4106 64-bit extended sequence numbers,
833 	 * the AAD length must be 16 or 20 bytes.
834 	 */
835 	if (unlikely(req->assoclen != 16 && req->assoclen != 20))
836 		return -EINVAL;
837 
838 	/* Build the IV: 4-byte nonce || 8-byte explicit IV || 4-byte counter */
839 	for (i = 0; i < 4; i++)
840 		*(iv+i) = ctx->nonce[i];
841 	for (i = 0; i < 8; i++)
842 		*(iv+4+i) = req->iv[i];
843 	*((__be32 *)(iv+12)) = counter;
844 
845 	return gcmaes_encrypt(req, req->assoclen - 8, ctx->hash_subkey, iv,
846 			      aes_ctx);
847 }
848 
849 static int helper_rfc4106_decrypt(struct aead_request *req)
850 {
851 	__be32 counter = cpu_to_be32(1);
852 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
853 	struct aesni_rfc4106_gcm_ctx *ctx = aesni_rfc4106_gcm_ctx_get(tfm);
854 	void *aes_ctx = &(ctx->aes_key_expanded);
855 	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
856 	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
857 	unsigned int i;
858 
859 	if (unlikely(req->assoclen != 16 && req->assoclen != 20))
860 		return -EINVAL;
861 
862 	/* Assuming we support rfc4106 64-bit extended sequence numbers,
863 	 * the AAD length must be 16 or 20 bytes.
864 	 */
865 
866 	/* Build the IV: 4-byte nonce || 8-byte explicit IV || 4-byte counter */
867 	for (i = 0; i < 4; i++)
868 		*(iv+i) = ctx->nonce[i];
869 	for (i = 0; i < 8; i++)
870 		*(iv+4+i) = req->iv[i];
871 	*((__be32 *)(iv+12)) = counter;
872 
873 	return gcmaes_decrypt(req, req->assoclen - 8, ctx->hash_subkey, iv,
874 			      aes_ctx);
875 }
876 #endif
877 
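/*
 * XTS uses a double-length key: the first half encrypts the data and the
 * second half encrypts the tweak.  xts_verify_key() validates the combined
 * length (and, in FIPS mode, that the two halves differ) before each half
 * is expanded separately.
 */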
878 static int xts_setkey_aesni(struct crypto_skcipher *tfm, const u8 *key,
879 			    unsigned int keylen)
880 {
881 	struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
882 	int err;
883 
884 	err = xts_verify_key(tfm, key, keylen);
885 	if (err)
886 		return err;
887 
888 	keylen /= 2;
889 
890 	/* first half of xts-key is for crypt */
891 	err = aes_set_key_common(&ctx->crypt_ctx, key, keylen);
892 	if (err)
893 		return err;
894 
895 	/* second half of xts-key is for tweak */
896 	return aes_set_key_common(&ctx->tweak_ctx, key + keylen, keylen);
897 }
898 
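/*
 * Per-implementation hooks: one encrypts the IV into the initial tweak with
 * the tweak key, the other processes a run of data blocks.  They allow
 * xts_crypt()/xts_crypt_slowpath() to drive the AES-NI and VAES assembly
 * back ends through the same walk logic.
 */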
899 typedef void (*xts_encrypt_iv_func)(const struct crypto_aes_ctx *tweak_key,
900 				    u8 iv[AES_BLOCK_SIZE]);
901 typedef void (*xts_crypt_func)(const struct crypto_aes_ctx *key,
902 			       const u8 *src, u8 *dst, unsigned int len,
903 			       u8 tweak[AES_BLOCK_SIZE]);
904 
905 /* This handles cases where the source and/or destination span pages. */
906 static noinline int
907 xts_crypt_slowpath(struct skcipher_request *req, xts_crypt_func crypt_func)
908 {
909 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
910 	const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
911 	int tail = req->cryptlen % AES_BLOCK_SIZE;
912 	struct scatterlist sg_src[2], sg_dst[2];
913 	struct skcipher_request subreq;
914 	struct skcipher_walk walk;
915 	struct scatterlist *src, *dst;
916 	int err;
917 
918 	/*
919 	 * If the message length isn't divisible by the AES block size, then
920 	 * separate off the last full block and the partial block.  This ensures
921 	 * that they are processed in the same call to the assembly function,
922 	 * which is required for ciphertext stealing.
923 	 */
924 	if (tail) {
925 		skcipher_request_set_tfm(&subreq, tfm);
926 		skcipher_request_set_callback(&subreq,
927 					      skcipher_request_flags(req),
928 					      NULL, NULL);
929 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
930 					   req->cryptlen - tail - AES_BLOCK_SIZE,
931 					   req->iv);
932 		req = &subreq;
933 	}
934 
935 	err = skcipher_walk_virt(&walk, req, false);
936 
937 	while (walk.nbytes) {
938 		unsigned int nbytes = walk.nbytes;
939 
940 		if (nbytes < walk.total)
941 			nbytes = round_down(nbytes, AES_BLOCK_SIZE);
942 
943 		kernel_fpu_begin();
944 		(*crypt_func)(&ctx->crypt_ctx, walk.src.virt.addr,
945 			      walk.dst.virt.addr, nbytes, req->iv);
946 		kernel_fpu_end();
947 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
948 	}
949 
950 	if (err || !tail)
951 		return err;
952 
953 	/* Do ciphertext stealing with the last full block and partial block. */
954 
955 	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
956 	if (req->dst != req->src)
957 		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
958 
959 	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
960 				   req->iv);
961 
962 	err = skcipher_walk_virt(&walk, req, false);
963 	if (err)
964 		return err;
965 
966 	kernel_fpu_begin();
967 	(*crypt_func)(&ctx->crypt_ctx, walk.src.virt.addr, walk.dst.virt.addr,
968 		      walk.nbytes, req->iv);
969 	kernel_fpu_end();
970 
971 	return skcipher_walk_done(&walk, 0);
972 }
973 
974 /* __always_inline to avoid indirect call in fastpath */
975 static __always_inline int
976 xts_crypt(struct skcipher_request *req, xts_encrypt_iv_func encrypt_iv,
977 	  xts_crypt_func crypt_func)
978 {
979 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
980 	const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
981 	const unsigned int cryptlen = req->cryptlen;
982 	struct scatterlist *src = req->src;
983 	struct scatterlist *dst = req->dst;
984 
985 	if (unlikely(cryptlen < AES_BLOCK_SIZE))
986 		return -EINVAL;
987 
988 	kernel_fpu_begin();
989 	(*encrypt_iv)(&ctx->tweak_ctx, req->iv);
990 
991 	/*
992 	 * In practice, virtually all XTS plaintexts and ciphertexts are either
993 	 * 512 or 4096 bytes, aligned such that they don't span page boundaries.
994 	 * To optimize the performance of these cases, and also any other case
995 	 * where no page boundary is spanned, the below fast-path handles
996 	 * single-page sources and destinations as efficiently as possible.
997 	 */
998 	if (likely(src->length >= cryptlen && dst->length >= cryptlen &&
999 		   src->offset + cryptlen <= PAGE_SIZE &&
1000 		   dst->offset + cryptlen <= PAGE_SIZE)) {
1001 		struct page *src_page = sg_page(src);
1002 		struct page *dst_page = sg_page(dst);
1003 		void *src_virt = kmap_local_page(src_page) + src->offset;
1004 		void *dst_virt = kmap_local_page(dst_page) + dst->offset;
1005 
1006 		(*crypt_func)(&ctx->crypt_ctx, src_virt, dst_virt, cryptlen,
1007 			      req->iv);
1008 		kunmap_local(dst_virt);
1009 		kunmap_local(src_virt);
1010 		kernel_fpu_end();
1011 		return 0;
1012 	}
1013 	kernel_fpu_end();
1014 	return xts_crypt_slowpath(req, crypt_func);
1015 }
1016 
1017 static void aesni_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
1018 				 u8 iv[AES_BLOCK_SIZE])
1019 {
1020 	aesni_enc(tweak_key, iv, iv);
1021 }
1022 
1023 static void aesni_xts_encrypt(const struct crypto_aes_ctx *key,
1024 			      const u8 *src, u8 *dst, unsigned int len,
1025 			      u8 tweak[AES_BLOCK_SIZE])
1026 {
1027 	aesni_xts_enc(key, dst, src, len, tweak);
1028 }
1029 
1030 static void aesni_xts_decrypt(const struct crypto_aes_ctx *key,
1031 			      const u8 *src, u8 *dst, unsigned int len,
1032 			      u8 tweak[AES_BLOCK_SIZE])
1033 {
1034 	aesni_xts_dec(key, dst, src, len, tweak);
1035 }
1036 
1037 static int xts_encrypt_aesni(struct skcipher_request *req)
1038 {
1039 	return xts_crypt(req, aesni_xts_encrypt_iv, aesni_xts_encrypt);
1040 }
1041 
1042 static int xts_decrypt_aesni(struct skcipher_request *req)
1043 {
1044 	return xts_crypt(req, aesni_xts_encrypt_iv, aesni_xts_decrypt);
1045 }
1046 
1047 static struct crypto_alg aesni_cipher_alg = {
1048 	.cra_name		= "aes",
1049 	.cra_driver_name	= "aes-aesni",
1050 	.cra_priority		= 300,
1051 	.cra_flags		= CRYPTO_ALG_TYPE_CIPHER,
1052 	.cra_blocksize		= AES_BLOCK_SIZE,
1053 	.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1054 	.cra_module		= THIS_MODULE,
1055 	.cra_u	= {
1056 		.cipher	= {
1057 			.cia_min_keysize	= AES_MIN_KEY_SIZE,
1058 			.cia_max_keysize	= AES_MAX_KEY_SIZE,
1059 			.cia_setkey		= aes_set_key,
1060 			.cia_encrypt		= aesni_encrypt,
1061 			.cia_decrypt		= aesni_decrypt
1062 		}
1063 	}
1064 };
1065 
1066 static struct skcipher_alg aesni_skciphers[] = {
1067 	{
1068 		.base = {
1069 			.cra_name		= "__ecb(aes)",
1070 			.cra_driver_name	= "__ecb-aes-aesni",
1071 			.cra_priority		= 400,
1072 			.cra_flags		= CRYPTO_ALG_INTERNAL,
1073 			.cra_blocksize		= AES_BLOCK_SIZE,
1074 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1075 			.cra_module		= THIS_MODULE,
1076 		},
1077 		.min_keysize	= AES_MIN_KEY_SIZE,
1078 		.max_keysize	= AES_MAX_KEY_SIZE,
1079 		.setkey		= aesni_skcipher_setkey,
1080 		.encrypt	= ecb_encrypt,
1081 		.decrypt	= ecb_decrypt,
1082 	}, {
1083 		.base = {
1084 			.cra_name		= "__cbc(aes)",
1085 			.cra_driver_name	= "__cbc-aes-aesni",
1086 			.cra_priority		= 400,
1087 			.cra_flags		= CRYPTO_ALG_INTERNAL,
1088 			.cra_blocksize		= AES_BLOCK_SIZE,
1089 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1090 			.cra_module		= THIS_MODULE,
1091 		},
1092 		.min_keysize	= AES_MIN_KEY_SIZE,
1093 		.max_keysize	= AES_MAX_KEY_SIZE,
1094 		.ivsize		= AES_BLOCK_SIZE,
1095 		.setkey		= aesni_skcipher_setkey,
1096 		.encrypt	= cbc_encrypt,
1097 		.decrypt	= cbc_decrypt,
1098 	}, {
1099 		.base = {
1100 			.cra_name		= "__cts(cbc(aes))",
1101 			.cra_driver_name	= "__cts-cbc-aes-aesni",
1102 			.cra_priority		= 400,
1103 			.cra_flags		= CRYPTO_ALG_INTERNAL,
1104 			.cra_blocksize		= AES_BLOCK_SIZE,
1105 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1106 			.cra_module		= THIS_MODULE,
1107 		},
1108 		.min_keysize	= AES_MIN_KEY_SIZE,
1109 		.max_keysize	= AES_MAX_KEY_SIZE,
1110 		.ivsize		= AES_BLOCK_SIZE,
1111 		.walksize	= 2 * AES_BLOCK_SIZE,
1112 		.setkey		= aesni_skcipher_setkey,
1113 		.encrypt	= cts_cbc_encrypt,
1114 		.decrypt	= cts_cbc_decrypt,
1115 #ifdef CONFIG_X86_64
1116 	}, {
1117 		.base = {
1118 			.cra_name		= "__ctr(aes)",
1119 			.cra_driver_name	= "__ctr-aes-aesni",
1120 			.cra_priority		= 400,
1121 			.cra_flags		= CRYPTO_ALG_INTERNAL,
1122 			.cra_blocksize		= 1,
1123 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1124 			.cra_module		= THIS_MODULE,
1125 		},
1126 		.min_keysize	= AES_MIN_KEY_SIZE,
1127 		.max_keysize	= AES_MAX_KEY_SIZE,
1128 		.ivsize		= AES_BLOCK_SIZE,
1129 		.chunksize	= AES_BLOCK_SIZE,
1130 		.setkey		= aesni_skcipher_setkey,
1131 		.encrypt	= ctr_crypt,
1132 		.decrypt	= ctr_crypt,
1133 #endif
1134 	}, {
1135 		.base = {
1136 			.cra_name		= "__xts(aes)",
1137 			.cra_driver_name	= "__xts-aes-aesni",
1138 			.cra_priority		= 401,
1139 			.cra_flags		= CRYPTO_ALG_INTERNAL,
1140 			.cra_blocksize		= AES_BLOCK_SIZE,
1141 			.cra_ctxsize		= XTS_AES_CTX_SIZE,
1142 			.cra_module		= THIS_MODULE,
1143 		},
1144 		.min_keysize	= 2 * AES_MIN_KEY_SIZE,
1145 		.max_keysize	= 2 * AES_MAX_KEY_SIZE,
1146 		.ivsize		= AES_BLOCK_SIZE,
1147 		.walksize	= 2 * AES_BLOCK_SIZE,
1148 		.setkey		= xts_setkey_aesni,
1149 		.encrypt	= xts_encrypt_aesni,
1150 		.decrypt	= xts_decrypt_aesni,
1151 	}
1152 };
1153 
1154 static
1155 struct simd_skcipher_alg *aesni_simd_skciphers[ARRAY_SIZE(aesni_skciphers)];
1156 
1157 #ifdef CONFIG_X86_64
1158 /*
1159  * XCTR does not have a non-AVX implementation, so it must be enabled
1160  * conditionally.
1161  */
1162 static struct skcipher_alg aesni_xctr = {
1163 	.base = {
1164 		.cra_name		= "__xctr(aes)",
1165 		.cra_driver_name	= "__xctr-aes-aesni",
1166 		.cra_priority		= 400,
1167 		.cra_flags		= CRYPTO_ALG_INTERNAL,
1168 		.cra_blocksize		= 1,
1169 		.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1170 		.cra_module		= THIS_MODULE,
1171 	},
1172 	.min_keysize	= AES_MIN_KEY_SIZE,
1173 	.max_keysize	= AES_MAX_KEY_SIZE,
1174 	.ivsize		= AES_BLOCK_SIZE,
1175 	.chunksize	= AES_BLOCK_SIZE,
1176 	.setkey		= aesni_skcipher_setkey,
1177 	.encrypt	= xctr_crypt,
1178 	.decrypt	= xctr_crypt,
1179 };
1180 
1181 static struct simd_skcipher_alg *aesni_simd_xctr;
1182 
1183 asmlinkage void aes_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
1184 				   u8 iv[AES_BLOCK_SIZE]);
1185 
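/*
 * DEFINE_XTS_ALG() declares the assembly encrypt/decrypt routines for one
 * implementation, generates thin wrappers that feed them to xts_crypt(),
 * and defines the matching skcipher_alg plus the simd alg pointer used at
 * (un)registration time.
 */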
1186 #define DEFINE_XTS_ALG(suffix, driver_name, priority)			       \
1187 									       \
1188 asmlinkage void								       \
1189 aes_xts_encrypt_##suffix(const struct crypto_aes_ctx *key, const u8 *src,      \
1190 			 u8 *dst, unsigned int len, u8 tweak[AES_BLOCK_SIZE]); \
1191 asmlinkage void								       \
1192 aes_xts_decrypt_##suffix(const struct crypto_aes_ctx *key, const u8 *src,      \
1193 			 u8 *dst, unsigned int len, u8 tweak[AES_BLOCK_SIZE]); \
1194 									       \
1195 static int xts_encrypt_##suffix(struct skcipher_request *req)		       \
1196 {									       \
1197 	return xts_crypt(req, aes_xts_encrypt_iv, aes_xts_encrypt_##suffix);   \
1198 }									       \
1199 									       \
1200 static int xts_decrypt_##suffix(struct skcipher_request *req)		       \
1201 {									       \
1202 	return xts_crypt(req, aes_xts_encrypt_iv, aes_xts_decrypt_##suffix);   \
1203 }									       \
1204 									       \
1205 static struct skcipher_alg aes_xts_alg_##suffix = {			       \
1206 	.base = {							       \
1207 		.cra_name		= "__xts(aes)",			       \
1208 		.cra_driver_name	= "__" driver_name,		       \
1209 		.cra_priority		= priority,			       \
1210 		.cra_flags		= CRYPTO_ALG_INTERNAL,		       \
1211 		.cra_blocksize		= AES_BLOCK_SIZE,		       \
1212 		.cra_ctxsize		= XTS_AES_CTX_SIZE,		       \
1213 		.cra_module		= THIS_MODULE,			       \
1214 	},								       \
1215 	.min_keysize	= 2 * AES_MIN_KEY_SIZE,				       \
1216 	.max_keysize	= 2 * AES_MAX_KEY_SIZE,				       \
1217 	.ivsize		= AES_BLOCK_SIZE,				       \
1218 	.walksize	= 2 * AES_BLOCK_SIZE,				       \
1219 	.setkey		= xts_setkey_aesni,				       \
1220 	.encrypt	= xts_encrypt_##suffix,				       \
1221 	.decrypt	= xts_decrypt_##suffix,				       \
1222 };									       \
1223 									       \
1224 static struct simd_skcipher_alg *aes_xts_simdalg_##suffix
1225 
1226 DEFINE_XTS_ALG(aesni_avx, "xts-aes-aesni-avx", 500);
1227 #if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ)
1228 DEFINE_XTS_ALG(vaes_avx2, "xts-aes-vaes-avx2", 600);
1229 DEFINE_XTS_ALG(vaes_avx10_256, "xts-aes-vaes-avx10_256", 700);
1230 DEFINE_XTS_ALG(vaes_avx10_512, "xts-aes-vaes-avx10_512", 800);
1231 #endif
1232 
1233 /*
1234  * This is a list of CPU models that are known to suffer from downclocking when
1235  * zmm registers (512-bit vectors) are used.  On these CPUs, the AES-XTS
1236  * implementation with zmm registers won't be used by default.  An
1237  * implementation with ymm registers (256-bit vectors) will be used instead.
1238  */
1239 static const struct x86_cpu_id zmm_exclusion_list[] = {
1240 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_SKYLAKE_X },
1241 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_ICELAKE_X },
1242 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_ICELAKE_D },
1243 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_ICELAKE },
1244 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_ICELAKE_L },
1245 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_ICELAKE_NNPI },
1246 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_TIGERLAKE_L },
1247 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_TIGERLAKE },
1248 	/* Allow Rocket Lake and later, and Sapphire Rapids and later. */
1249 	/* Also allow AMD CPUs (starting with Zen 4, the first with AVX-512). */
1250 	{},
1251 };
1252 
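/*
 * Register the optimized XTS implementations supported by the CPU, from the
 * baseline AVX version up to the VAES/AVX512 ones.  On CPUs in
 * zmm_exclusion_list the 512-bit variant is still registered, but at
 * priority 1 so it is not selected by default.
 */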
1253 static int __init register_xts_algs(void)
1254 {
1255 	int err;
1256 
1257 	if (!boot_cpu_has(X86_FEATURE_AVX))
1258 		return 0;
1259 	err = simd_register_skciphers_compat(&aes_xts_alg_aesni_avx, 1,
1260 					     &aes_xts_simdalg_aesni_avx);
1261 	if (err)
1262 		return err;
1263 #if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ)
1264 	if (!boot_cpu_has(X86_FEATURE_AVX2) ||
1265 	    !boot_cpu_has(X86_FEATURE_VAES) ||
1266 	    !boot_cpu_has(X86_FEATURE_VPCLMULQDQ) ||
1267 	    !boot_cpu_has(X86_FEATURE_PCLMULQDQ) ||
1268 	    !cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL))
1269 		return 0;
1270 	err = simd_register_skciphers_compat(&aes_xts_alg_vaes_avx2, 1,
1271 					     &aes_xts_simdalg_vaes_avx2);
1272 	if (err)
1273 		return err;
1274 
1275 	if (!boot_cpu_has(X86_FEATURE_AVX512BW) ||
1276 	    !boot_cpu_has(X86_FEATURE_AVX512VL) ||
1277 	    !boot_cpu_has(X86_FEATURE_BMI2) ||
1278 	    !cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM |
1279 			       XFEATURE_MASK_AVX512, NULL))
1280 		return 0;
1281 
1282 	err = simd_register_skciphers_compat(&aes_xts_alg_vaes_avx10_256, 1,
1283 					     &aes_xts_simdalg_vaes_avx10_256);
1284 	if (err)
1285 		return err;
1286 
1287 	if (x86_match_cpu(zmm_exclusion_list))
1288 		aes_xts_alg_vaes_avx10_512.base.cra_priority = 1;
1289 
1290 	err = simd_register_skciphers_compat(&aes_xts_alg_vaes_avx10_512, 1,
1291 					     &aes_xts_simdalg_vaes_avx10_512);
1292 	if (err)
1293 		return err;
1294 #endif /* CONFIG_AS_VAES && CONFIG_AS_VPCLMULQDQ */
1295 	return 0;
1296 }
1297 
1298 static void unregister_xts_algs(void)
1299 {
1300 	if (aes_xts_simdalg_aesni_avx)
1301 		simd_unregister_skciphers(&aes_xts_alg_aesni_avx, 1,
1302 					  &aes_xts_simdalg_aesni_avx);
1303 #if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ)
1304 	if (aes_xts_simdalg_vaes_avx2)
1305 		simd_unregister_skciphers(&aes_xts_alg_vaes_avx2, 1,
1306 					  &aes_xts_simdalg_vaes_avx2);
1307 	if (aes_xts_simdalg_vaes_avx10_256)
1308 		simd_unregister_skciphers(&aes_xts_alg_vaes_avx10_256, 1,
1309 					  &aes_xts_simdalg_vaes_avx10_256);
1310 	if (aes_xts_simdalg_vaes_avx10_512)
1311 		simd_unregister_skciphers(&aes_xts_alg_vaes_avx10_512, 1,
1312 					  &aes_xts_simdalg_vaes_avx10_512);
1313 #endif
1314 }
1315 #else /* CONFIG_X86_64 */
1316 static int __init register_xts_algs(void)
1317 {
1318 	return 0;
1319 }
1320 
1321 static void unregister_xts_algs(void)
1322 {
1323 }
1324 #endif /* !CONFIG_X86_64 */
1325 
1326 #ifdef CONFIG_X86_64
1327 static int generic_gcmaes_set_key(struct crypto_aead *aead, const u8 *key,
1328 				  unsigned int key_len)
1329 {
1330 	struct generic_gcmaes_ctx *ctx = generic_gcmaes_ctx_get(aead);
1331 
1332 	return aes_set_key_common(&ctx->aes_key_expanded, key, key_len) ?:
1333 	       rfc4106_set_hash_subkey(ctx->hash_subkey, key, key_len);
1334 }
1335 
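/*
 * Plain gcm(aes): the 12-byte IV from the request is used as-is and a
 * 32-bit block counter starting at 1 is appended, matching the standard
 * GCM construction for 96-bit IVs.
 */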
1336 static int generic_gcmaes_encrypt(struct aead_request *req)
1337 {
1338 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
1339 	struct generic_gcmaes_ctx *ctx = generic_gcmaes_ctx_get(tfm);
1340 	void *aes_ctx = &(ctx->aes_key_expanded);
1341 	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
1342 	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
1343 	__be32 counter = cpu_to_be32(1);
1344 
1345 	memcpy(iv, req->iv, 12);
1346 	*((__be32 *)(iv+12)) = counter;
1347 
1348 	return gcmaes_encrypt(req, req->assoclen, ctx->hash_subkey, iv,
1349 			      aes_ctx);
1350 }
1351 
1352 static int generic_gcmaes_decrypt(struct aead_request *req)
1353 {
1354 	__be32 counter = cpu_to_be32(1);
1355 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
1356 	struct generic_gcmaes_ctx *ctx = generic_gcmaes_ctx_get(tfm);
1357 	void *aes_ctx = &(ctx->aes_key_expanded);
1358 	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
1359 	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
1360 
1361 	memcpy(iv, req->iv, 12);
1362 	*((__be32 *)(iv+12)) = counter;
1363 
1364 	return gcmaes_decrypt(req, req->assoclen, ctx->hash_subkey, iv,
1365 			      aes_ctx);
1366 }
1367 
1368 static struct aead_alg aesni_aeads[] = { {
1369 	.setkey			= common_rfc4106_set_key,
1370 	.setauthsize		= common_rfc4106_set_authsize,
1371 	.encrypt		= helper_rfc4106_encrypt,
1372 	.decrypt		= helper_rfc4106_decrypt,
1373 	.ivsize			= GCM_RFC4106_IV_SIZE,
1374 	.maxauthsize		= 16,
1375 	.base = {
1376 		.cra_name		= "__rfc4106(gcm(aes))",
1377 		.cra_driver_name	= "__rfc4106-gcm-aesni",
1378 		.cra_priority		= 400,
1379 		.cra_flags		= CRYPTO_ALG_INTERNAL,
1380 		.cra_blocksize		= 1,
1381 		.cra_ctxsize		= sizeof(struct aesni_rfc4106_gcm_ctx),
1382 		.cra_alignmask		= 0,
1383 		.cra_module		= THIS_MODULE,
1384 	},
1385 }, {
1386 	.setkey			= generic_gcmaes_set_key,
1387 	.setauthsize		= generic_gcmaes_set_authsize,
1388 	.encrypt		= generic_gcmaes_encrypt,
1389 	.decrypt		= generic_gcmaes_decrypt,
1390 	.ivsize			= GCM_AES_IV_SIZE,
1391 	.maxauthsize		= 16,
1392 	.base = {
1393 		.cra_name		= "__gcm(aes)",
1394 		.cra_driver_name	= "__generic-gcm-aesni",
1395 		.cra_priority		= 400,
1396 		.cra_flags		= CRYPTO_ALG_INTERNAL,
1397 		.cra_blocksize		= 1,
1398 		.cra_ctxsize		= sizeof(struct generic_gcmaes_ctx),
1399 		.cra_alignmask		= 0,
1400 		.cra_module		= THIS_MODULE,
1401 	},
1402 } };
1403 #else
1404 static struct aead_alg aesni_aeads[0];
1405 #endif
1406 
1407 static struct simd_aead_alg *aesni_simd_aeads[ARRAY_SIZE(aesni_aeads)];
1408 
1409 static const struct x86_cpu_id aesni_cpu_id[] = {
1410 	X86_MATCH_FEATURE(X86_FEATURE_AES, NULL),
1411 	{}
1412 };
1413 MODULE_DEVICE_TABLE(x86cpu, aesni_cpu_id);
1414 
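/*
 * Module init: require the AES-NI CPU feature, select the widest usable GCM
 * implementation via the gcm_use_avx/gcm_use_avx2 static keys, switch CTR
 * mode to the AVX by8 routine when available, then register the cipher,
 * skcipher, AEAD, XCTR and XTS algorithms, unwinding on failure.
 */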
1415 static int __init aesni_init(void)
1416 {
1417 	int err;
1418 
1419 	if (!x86_match_cpu(aesni_cpu_id))
1420 		return -ENODEV;
1421 #ifdef CONFIG_X86_64
1422 	if (boot_cpu_has(X86_FEATURE_AVX2)) {
1423 		pr_info("AVX2 version of gcm_enc/dec engaged.\n");
1424 		static_branch_enable(&gcm_use_avx);
1425 		static_branch_enable(&gcm_use_avx2);
1426 	} else if (boot_cpu_has(X86_FEATURE_AVX)) {
1428 		pr_info("AVX version of gcm_enc/dec engaged.\n");
1429 		static_branch_enable(&gcm_use_avx);
1430 	} else {
1431 		pr_info("SSE version of gcm_enc/dec engaged.\n");
1432 	}
1433 	if (boot_cpu_has(X86_FEATURE_AVX)) {
1434 		/* optimize performance of ctr mode encryption transform */
1435 		static_call_update(aesni_ctr_enc_tfm, aesni_ctr_enc_avx_tfm);
1436 		pr_info("AES CTR mode by8 optimization enabled\n");
1437 	}
1438 #endif /* CONFIG_X86_64 */
1439 
1440 	err = crypto_register_alg(&aesni_cipher_alg);
1441 	if (err)
1442 		return err;
1443 
1444 	err = simd_register_skciphers_compat(aesni_skciphers,
1445 					     ARRAY_SIZE(aesni_skciphers),
1446 					     aesni_simd_skciphers);
1447 	if (err)
1448 		goto unregister_cipher;
1449 
1450 	err = simd_register_aeads_compat(aesni_aeads, ARRAY_SIZE(aesni_aeads),
1451 					 aesni_simd_aeads);
1452 	if (err)
1453 		goto unregister_skciphers;
1454 
1455 #ifdef CONFIG_X86_64
1456 	if (boot_cpu_has(X86_FEATURE_AVX))
1457 		err = simd_register_skciphers_compat(&aesni_xctr, 1,
1458 						     &aesni_simd_xctr);
1459 	if (err)
1460 		goto unregister_aeads;
1461 #endif /* CONFIG_X86_64 */
1462 
1463 	err = register_xts_algs();
1464 	if (err)
1465 		goto unregister_xts;
1466 
1467 	return 0;
1468 
1469 unregister_xts:
1470 	unregister_xts_algs();
1471 #ifdef CONFIG_X86_64
1472 	if (aesni_simd_xctr)
1473 		simd_unregister_skciphers(&aesni_xctr, 1, &aesni_simd_xctr);
1474 unregister_aeads:
1475 #endif /* CONFIG_X86_64 */
1476 	simd_unregister_aeads(aesni_aeads, ARRAY_SIZE(aesni_aeads),
1477 				aesni_simd_aeads);
1478 
1479 unregister_skciphers:
1480 	simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers),
1481 				  aesni_simd_skciphers);
1482 unregister_cipher:
1483 	crypto_unregister_alg(&aesni_cipher_alg);
1484 	return err;
1485 }
1486 
1487 static void __exit aesni_exit(void)
1488 {
1489 	simd_unregister_aeads(aesni_aeads, ARRAY_SIZE(aesni_aeads),
1490 			      aesni_simd_aeads);
1491 	simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers),
1492 				  aesni_simd_skciphers);
1493 	crypto_unregister_alg(&aesni_cipher_alg);
1494 #ifdef CONFIG_X86_64
1495 	if (boot_cpu_has(X86_FEATURE_AVX))
1496 		simd_unregister_skciphers(&aesni_xctr, 1, &aesni_simd_xctr);
1497 #endif /* CONFIG_X86_64 */
1498 	unregister_xts_algs();
1499 }
1500 
1501 late_initcall(aesni_init);
1502 module_exit(aesni_exit);
1503 
1504 MODULE_DESCRIPTION("Rijndael (AES) Cipher Algorithm, Intel AES-NI instructions optimized");
1505 MODULE_LICENSE("GPL");
1506 MODULE_ALIAS_CRYPTO("aes");
1507