xref: /linux/arch/x86/crypto/aesni-intel_glue.c (revision 8f0e0cf74ccef41b383daddcf5447bba655031b3)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Support for Intel AES-NI instructions. This file contains the glue
4  * code; the actual AES implementation is in aesni-intel_asm.S.
5  *
6  * Copyright (C) 2008, Intel Corp.
7  *    Author: Huang Ying <ying.huang@intel.com>
8  *
9  * Added RFC4106 AES-GCM support for 128-bit keys under the AEAD
10  * interface for 64-bit kernels.
11  *    Authors: Adrian Hoban <adrian.hoban@intel.com>
12  *             Gabriele Paoloni <gabriele.paoloni@intel.com>
13  *             Tadeusz Struk (tadeusz.struk@intel.com)
14  *             Aidan O'Mahony (aidan.o.mahony@intel.com)
15  *    Copyright (c) 2010, Intel Corporation.
16  */
17 
18 #include <linux/hardirq.h>
19 #include <linux/types.h>
20 #include <linux/module.h>
21 #include <linux/err.h>
22 #include <crypto/algapi.h>
23 #include <crypto/aes.h>
24 #include <crypto/ctr.h>
25 #include <crypto/b128ops.h>
26 #include <crypto/gcm.h>
27 #include <crypto/xts.h>
28 #include <asm/cpu_device_id.h>
29 #include <asm/simd.h>
30 #include <crypto/scatterwalk.h>
31 #include <crypto/internal/aead.h>
32 #include <crypto/internal/simd.h>
33 #include <crypto/internal/skcipher.h>
34 #include <linux/jump_label.h>
35 #include <linux/workqueue.h>
36 #include <linux/spinlock.h>
37 #include <linux/static_call.h>
38 
39 
40 #define AESNI_ALIGN	16
41 #define AESNI_ALIGN_ATTR __attribute__ ((__aligned__(AESNI_ALIGN)))
42 #define AES_BLOCK_MASK	(~(AES_BLOCK_SIZE - 1))
43 #define RFC4106_HASH_SUBKEY_SIZE 16
44 #define AESNI_ALIGN_EXTRA ((AESNI_ALIGN - 1) & ~(CRYPTO_MINALIGN - 1))
45 #define CRYPTO_AES_CTX_SIZE (sizeof(struct crypto_aes_ctx) + AESNI_ALIGN_EXTRA)
46 #define XTS_AES_CTX_SIZE (sizeof(struct aesni_xts_ctx) + AESNI_ALIGN_EXTRA)
47 
48 /* This data is stored at the end of the crypto_tfm struct.
49  * It serves as per-"session" data storage and
50  * must be 16-byte aligned.
51  */
52 struct aesni_rfc4106_gcm_ctx {
53 	u8 hash_subkey[16] AESNI_ALIGN_ATTR;
54 	struct crypto_aes_ctx aes_key_expanded AESNI_ALIGN_ATTR;
55 	u8 nonce[4];
56 };
57 
58 struct generic_gcmaes_ctx {
59 	u8 hash_subkey[16] AESNI_ALIGN_ATTR;
60 	struct crypto_aes_ctx aes_key_expanded AESNI_ALIGN_ATTR;
61 };
62 
63 struct aesni_xts_ctx {
64 	struct crypto_aes_ctx tweak_ctx AESNI_ALIGN_ATTR;
65 	struct crypto_aes_ctx crypt_ctx AESNI_ALIGN_ATTR;
66 };
67 
68 #define GCM_BLOCK_LEN 16
69 
70 struct gcm_context_data {
71 	/* init, update and finalize context data */
72 	u8 aad_hash[GCM_BLOCK_LEN];
73 	u64 aad_length;
74 	u64 in_length;
75 	u8 partial_block_enc_key[GCM_BLOCK_LEN];
76 	u8 orig_IV[GCM_BLOCK_LEN];
77 	u8 current_counter[GCM_BLOCK_LEN];
78 	u64 partial_block_len;
79 	u64 unused;
80 	u8 hash_keys[GCM_BLOCK_LEN * 16];
81 };
82 
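/*
 * The crypto API only guarantees CRYPTO_MINALIGN alignment for the tfm
 * context, but the AES-NI code needs its key material 16-byte aligned.
 * The *_CTX_SIZE definitions above therefore reserve AESNI_ALIGN_EXTRA
 * spare bytes, and this helper rounds the context pointer up to the next
 * 16-byte boundary when the API's alignment is not already sufficient.
 */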
83 static inline void *aes_align_addr(void *addr)
84 {
85 	if (crypto_tfm_ctx_alignment() >= AESNI_ALIGN)
86 		return addr;
87 	return PTR_ALIGN(addr, AESNI_ALIGN);
88 }
89 
90 asmlinkage void aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
91 			      unsigned int key_len);
92 asmlinkage void aesni_enc(const void *ctx, u8 *out, const u8 *in);
93 asmlinkage void aesni_dec(const void *ctx, u8 *out, const u8 *in);
94 asmlinkage void aesni_ecb_enc(struct crypto_aes_ctx *ctx, u8 *out,
95 			      const u8 *in, unsigned int len);
96 asmlinkage void aesni_ecb_dec(struct crypto_aes_ctx *ctx, u8 *out,
97 			      const u8 *in, unsigned int len);
98 asmlinkage void aesni_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out,
99 			      const u8 *in, unsigned int len, u8 *iv);
100 asmlinkage void aesni_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
101 			      const u8 *in, unsigned int len, u8 *iv);
102 asmlinkage void aesni_cts_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out,
103 				  const u8 *in, unsigned int len, u8 *iv);
104 asmlinkage void aesni_cts_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
105 				  const u8 *in, unsigned int len, u8 *iv);
106 
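/*
 * Minimum data sizes (in bytes) at which gcmaes_crypt_by_sg() switches to
 * the AVX ("gen2") and AVX2 ("gen4") GCM routines; smaller requests use
 * the plain AES-NI/SSE implementation.
 */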
107 #define AVX_GEN2_OPTSIZE 640
108 #define AVX_GEN4_OPTSIZE 4096
109 
110 asmlinkage void aesni_xts_encrypt(const struct crypto_aes_ctx *ctx, u8 *out,
111 				  const u8 *in, unsigned int len, u8 *iv);
112 
113 asmlinkage void aesni_xts_decrypt(const struct crypto_aes_ctx *ctx, u8 *out,
114 				  const u8 *in, unsigned int len, u8 *iv);
115 
116 #ifdef CONFIG_X86_64
117 
118 asmlinkage void aesni_ctr_enc(struct crypto_aes_ctx *ctx, u8 *out,
119 			      const u8 *in, unsigned int len, u8 *iv);
120 DEFINE_STATIC_CALL(aesni_ctr_enc_tfm, aesni_ctr_enc);
121 
122 /* Scatter / Gather routines, with args similar to above */
123 asmlinkage void aesni_gcm_init(void *ctx,
124 			       struct gcm_context_data *gdata,
125 			       u8 *iv,
126 			       u8 *hash_subkey, const u8 *aad,
127 			       unsigned long aad_len);
128 asmlinkage void aesni_gcm_enc_update(void *ctx,
129 				     struct gcm_context_data *gdata, u8 *out,
130 				     const u8 *in, unsigned long plaintext_len);
131 asmlinkage void aesni_gcm_dec_update(void *ctx,
132 				     struct gcm_context_data *gdata, u8 *out,
133 				     const u8 *in,
134 				     unsigned long ciphertext_len);
135 asmlinkage void aesni_gcm_finalize(void *ctx,
136 				   struct gcm_context_data *gdata,
137 				   u8 *auth_tag, unsigned long auth_tag_len);
138 
139 asmlinkage void aes_ctr_enc_128_avx_by8(const u8 *in, u8 *iv,
140 		void *keys, u8 *out, unsigned int num_bytes);
141 asmlinkage void aes_ctr_enc_192_avx_by8(const u8 *in, u8 *iv,
142 		void *keys, u8 *out, unsigned int num_bytes);
143 asmlinkage void aes_ctr_enc_256_avx_by8(const u8 *in, u8 *iv,
144 		void *keys, u8 *out, unsigned int num_bytes);
145 
146 
147 asmlinkage void aes_xctr_enc_128_avx_by8(const u8 *in, const u8 *iv,
148 	const void *keys, u8 *out, unsigned int num_bytes,
149 	unsigned int byte_ctr);
150 
151 asmlinkage void aes_xctr_enc_192_avx_by8(const u8 *in, const u8 *iv,
152 	const void *keys, u8 *out, unsigned int num_bytes,
153 	unsigned int byte_ctr);
154 
155 asmlinkage void aes_xctr_enc_256_avx_by8(const u8 *in, const u8 *iv,
156 	const void *keys, u8 *out, unsigned int num_bytes,
157 	unsigned int byte_ctr);
158 
159 /*
160  * asmlinkage void aesni_gcm_init_avx_gen2()
161  *   gcm_data *my_ctx_data: context data
162  *   u8 *hash_subkey: the hash subkey input; data starts on a 16-byte boundary.
163  */
164 asmlinkage void aesni_gcm_init_avx_gen2(void *my_ctx_data,
165 					struct gcm_context_data *gdata,
166 					u8 *iv,
167 					u8 *hash_subkey,
168 					const u8 *aad,
169 					unsigned long aad_len);
170 
171 asmlinkage void aesni_gcm_enc_update_avx_gen2(void *ctx,
172 				     struct gcm_context_data *gdata, u8 *out,
173 				     const u8 *in, unsigned long plaintext_len);
174 asmlinkage void aesni_gcm_dec_update_avx_gen2(void *ctx,
175 				     struct gcm_context_data *gdata, u8 *out,
176 				     const u8 *in,
177 				     unsigned long ciphertext_len);
178 asmlinkage void aesni_gcm_finalize_avx_gen2(void *ctx,
179 				   struct gcm_context_data *gdata,
180 				   u8 *auth_tag, unsigned long auth_tag_len);
181 
182 /*
183  * asmlinkage void aesni_gcm_init_avx_gen4()
184  *   gcm_data *my_ctx_data: context data
185  *   u8 *hash_subkey: the hash subkey input; data starts on a 16-byte boundary.
186  */
187 asmlinkage void aesni_gcm_init_avx_gen4(void *my_ctx_data,
188 					struct gcm_context_data *gdata,
189 					u8 *iv,
190 					u8 *hash_subkey,
191 					const u8 *aad,
192 					unsigned long aad_len);
193 
194 asmlinkage void aesni_gcm_enc_update_avx_gen4(void *ctx,
195 				     struct gcm_context_data *gdata, u8 *out,
196 				     const u8 *in, unsigned long plaintext_len);
197 asmlinkage void aesni_gcm_dec_update_avx_gen4(void *ctx,
198 				     struct gcm_context_data *gdata, u8 *out,
199 				     const u8 *in,
200 				     unsigned long ciphertext_len);
201 asmlinkage void aesni_gcm_finalize_avx_gen4(void *ctx,
202 				   struct gcm_context_data *gdata,
203 				   u8 *auth_tag, unsigned long auth_tag_len);
204 
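/*
 * Static keys selecting the GCM code path.  They are enabled once in
 * aesni_init() according to the CPU's AVX/AVX2 support, so the checks in
 * gcmaes_crypt_by_sg() become patched jump labels rather than per-request
 * flag tests.
 */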
205 static __ro_after_init DEFINE_STATIC_KEY_FALSE(gcm_use_avx);
206 static __ro_after_init DEFINE_STATIC_KEY_FALSE(gcm_use_avx2);
207 
208 static inline struct
209 aesni_rfc4106_gcm_ctx *aesni_rfc4106_gcm_ctx_get(struct crypto_aead *tfm)
210 {
211 	return aes_align_addr(crypto_aead_ctx(tfm));
212 }
213 
214 static inline struct
215 generic_gcmaes_ctx *generic_gcmaes_ctx_get(struct crypto_aead *tfm)
216 {
217 	return aes_align_addr(crypto_aead_ctx(tfm));
218 }
219 #endif
220 
221 static inline struct crypto_aes_ctx *aes_ctx(void *raw_ctx)
222 {
223 	return aes_align_addr(raw_ctx);
224 }
225 
226 static inline struct aesni_xts_ctx *aes_xts_ctx(struct crypto_skcipher *tfm)
227 {
228 	return aes_align_addr(crypto_skcipher_ctx(tfm));
229 }
230 
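/*
 * Expand the AES key.  When SIMD is usable, the AES-NI key schedule is run
 * inside a kernel_fpu_begin()/end() section; otherwise (e.g. in a hard IRQ)
 * the generic C fallback aes_expandkey() is used, which also performs its
 * own key-length validation.
 */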
231 static int aes_set_key_common(struct crypto_aes_ctx *ctx,
232 			      const u8 *in_key, unsigned int key_len)
233 {
234 	int err;
235 
236 	if (!crypto_simd_usable())
237 		return aes_expandkey(ctx, in_key, key_len);
238 
239 	err = aes_check_keylen(key_len);
240 	if (err)
241 		return err;
242 
243 	kernel_fpu_begin();
244 	aesni_set_key(ctx, in_key, key_len);
245 	kernel_fpu_end();
246 	return 0;
247 }
248 
249 static int aes_set_key(struct crypto_tfm *tfm, const u8 *in_key,
250 		       unsigned int key_len)
251 {
252 	return aes_set_key_common(aes_ctx(crypto_tfm_ctx(tfm)), in_key,
253 				  key_len);
254 }
255 
256 static void aesni_encrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
257 {
258 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_tfm_ctx(tfm));
259 
260 	if (!crypto_simd_usable()) {
261 		aes_encrypt(ctx, dst, src);
262 	} else {
263 		kernel_fpu_begin();
264 		aesni_enc(ctx, dst, src);
265 		kernel_fpu_end();
266 	}
267 }
268 
269 static void aesni_decrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
270 {
271 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_tfm_ctx(tfm));
272 
273 	if (!crypto_simd_usable()) {
274 		aes_decrypt(ctx, dst, src);
275 	} else {
276 		kernel_fpu_begin();
277 		aesni_dec(ctx, dst, src);
278 		kernel_fpu_end();
279 	}
280 }
281 
282 static int aesni_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
283 			         unsigned int len)
284 {
285 	return aes_set_key_common(aes_ctx(crypto_skcipher_ctx(tfm)), key, len);
286 }
287 
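/*
 * The ECB/CBC helpers below all follow the same pattern: skcipher_walk maps
 * the source and destination in chunks, each chunk's whole blocks are
 * processed by the AES-NI assembly under kernel_fpu_begin()/end(), and any
 * remainder is handed back to the walk via skcipher_walk_done().
 */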
288 static int ecb_encrypt(struct skcipher_request *req)
289 {
290 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
291 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
292 	struct skcipher_walk walk;
293 	unsigned int nbytes;
294 	int err;
295 
296 	err = skcipher_walk_virt(&walk, req, false);
297 
298 	while ((nbytes = walk.nbytes)) {
299 		kernel_fpu_begin();
300 		aesni_ecb_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
301 			      nbytes & AES_BLOCK_MASK);
302 		kernel_fpu_end();
303 		nbytes &= AES_BLOCK_SIZE - 1;
304 		err = skcipher_walk_done(&walk, nbytes);
305 	}
306 
307 	return err;
308 }
309 
310 static int ecb_decrypt(struct skcipher_request *req)
311 {
312 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
313 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
314 	struct skcipher_walk walk;
315 	unsigned int nbytes;
316 	int err;
317 
318 	err = skcipher_walk_virt(&walk, req, false);
319 
320 	while ((nbytes = walk.nbytes)) {
321 		kernel_fpu_begin();
322 		aesni_ecb_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
323 			      nbytes & AES_BLOCK_MASK);
324 		kernel_fpu_end();
325 		nbytes &= AES_BLOCK_SIZE - 1;
326 		err = skcipher_walk_done(&walk, nbytes);
327 	}
328 
329 	return err;
330 }
331 
332 static int cbc_encrypt(struct skcipher_request *req)
333 {
334 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
335 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
336 	struct skcipher_walk walk;
337 	unsigned int nbytes;
338 	int err;
339 
340 	err = skcipher_walk_virt(&walk, req, false);
341 
342 	while ((nbytes = walk.nbytes)) {
343 		kernel_fpu_begin();
344 		aesni_cbc_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
345 			      nbytes & AES_BLOCK_MASK, walk.iv);
346 		kernel_fpu_end();
347 		nbytes &= AES_BLOCK_SIZE - 1;
348 		err = skcipher_walk_done(&walk, nbytes);
349 	}
350 
351 	return err;
352 }
353 
354 static int cbc_decrypt(struct skcipher_request *req)
355 {
356 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
357 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
358 	struct skcipher_walk walk;
359 	unsigned int nbytes;
360 	int err;
361 
362 	err = skcipher_walk_virt(&walk, req, false);
363 
364 	while ((nbytes = walk.nbytes)) {
365 		kernel_fpu_begin();
366 		aesni_cbc_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
367 			      nbytes & AES_BLOCK_MASK, walk.iv);
368 		kernel_fpu_end();
369 		nbytes &= AES_BLOCK_SIZE - 1;
370 		err = skcipher_walk_done(&walk, nbytes);
371 	}
372 
373 	return err;
374 }
375 
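/*
 * CBC with ciphertext stealing (CTS-CBC): all but the last two blocks are
 * handled as regular CBC through a subrequest; the remainder (the last full
 * block plus the partial block, or the last two full blocks for an aligned
 * length) is then passed to aesni_cts_cbc_enc()/aesni_cts_cbc_dec(), which
 * performs the ciphertext-stealing swap.  A single-block request degenerates
 * to plain CBC.
 */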
376 static int cts_cbc_encrypt(struct skcipher_request *req)
377 {
378 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
379 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
380 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
381 	struct scatterlist *src = req->src, *dst = req->dst;
382 	struct scatterlist sg_src[2], sg_dst[2];
383 	struct skcipher_request subreq;
384 	struct skcipher_walk walk;
385 	int err;
386 
387 	skcipher_request_set_tfm(&subreq, tfm);
388 	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
389 				      NULL, NULL);
390 
391 	if (req->cryptlen <= AES_BLOCK_SIZE) {
392 		if (req->cryptlen < AES_BLOCK_SIZE)
393 			return -EINVAL;
394 		cbc_blocks = 1;
395 	}
396 
397 	if (cbc_blocks > 0) {
398 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
399 					   cbc_blocks * AES_BLOCK_SIZE,
400 					   req->iv);
401 
402 		err = cbc_encrypt(&subreq);
403 		if (err)
404 			return err;
405 
406 		if (req->cryptlen == AES_BLOCK_SIZE)
407 			return 0;
408 
409 		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
410 		if (req->dst != req->src)
411 			dst = scatterwalk_ffwd(sg_dst, req->dst,
412 					       subreq.cryptlen);
413 	}
414 
415 	/* handle ciphertext stealing */
416 	skcipher_request_set_crypt(&subreq, src, dst,
417 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
418 				   req->iv);
419 
420 	err = skcipher_walk_virt(&walk, &subreq, false);
421 	if (err)
422 		return err;
423 
424 	kernel_fpu_begin();
425 	aesni_cts_cbc_enc(ctx, walk.dst.virt.addr, walk.src.virt.addr,
426 			  walk.nbytes, walk.iv);
427 	kernel_fpu_end();
428 
429 	return skcipher_walk_done(&walk, 0);
430 }
431 
432 static int cts_cbc_decrypt(struct skcipher_request *req)
433 {
434 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
435 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
436 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
437 	struct scatterlist *src = req->src, *dst = req->dst;
438 	struct scatterlist sg_src[2], sg_dst[2];
439 	struct skcipher_request subreq;
440 	struct skcipher_walk walk;
441 	int err;
442 
443 	skcipher_request_set_tfm(&subreq, tfm);
444 	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
445 				      NULL, NULL);
446 
447 	if (req->cryptlen <= AES_BLOCK_SIZE) {
448 		if (req->cryptlen < AES_BLOCK_SIZE)
449 			return -EINVAL;
450 		cbc_blocks = 1;
451 	}
452 
453 	if (cbc_blocks > 0) {
454 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
455 					   cbc_blocks * AES_BLOCK_SIZE,
456 					   req->iv);
457 
458 		err = cbc_decrypt(&subreq);
459 		if (err)
460 			return err;
461 
462 		if (req->cryptlen == AES_BLOCK_SIZE)
463 			return 0;
464 
465 		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
466 		if (req->dst != req->src)
467 			dst = scatterwalk_ffwd(sg_dst, req->dst,
468 					       subreq.cryptlen);
469 	}
470 
471 	/* handle ciphertext stealing */
472 	skcipher_request_set_crypt(&subreq, src, dst,
473 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
474 				   req->iv);
475 
476 	err = skcipher_walk_virt(&walk, &subreq, false);
477 	if (err)
478 		return err;
479 
480 	kernel_fpu_begin();
481 	aesni_cts_cbc_dec(ctx, walk.dst.virt.addr, walk.src.virt.addr,
482 			  walk.nbytes, walk.iv);
483 	kernel_fpu_end();
484 
485 	return skcipher_walk_done(&walk, 0);
486 }
487 
488 #ifdef CONFIG_X86_64
489 static void aesni_ctr_enc_avx_tfm(struct crypto_aes_ctx *ctx, u8 *out,
490 			      const u8 *in, unsigned int len, u8 *iv)
491 {
492 	/*
493 	 * Based on the key length, use the matching "by8" version of CTR
494 	 * mode encryption/decryption for improved performance.
495 	 * aes_set_key_common() ensures that the key length is one of
496 	 * 128, 192 or 256 bits.
497 	 */
498 	if (ctx->key_length == AES_KEYSIZE_128)
499 		aes_ctr_enc_128_avx_by8(in, iv, (void *)ctx, out, len);
500 	else if (ctx->key_length == AES_KEYSIZE_192)
501 		aes_ctr_enc_192_avx_by8(in, iv, (void *)ctx, out, len);
502 	else
503 		aes_ctr_enc_256_avx_by8(in, iv, (void *)ctx, out, len);
504 }
505 
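/*
 * CTR mode: full blocks are encrypted by the assembly routine selected via
 * the aesni_ctr_enc_tfm static call (the AVX "by8" variant when available).
 * A final partial block is handled here by encrypting the current counter
 * block and XORing the keystream into the remaining bytes.
 */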
506 static int ctr_crypt(struct skcipher_request *req)
507 {
508 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
509 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
510 	u8 keystream[AES_BLOCK_SIZE];
511 	struct skcipher_walk walk;
512 	unsigned int nbytes;
513 	int err;
514 
515 	err = skcipher_walk_virt(&walk, req, false);
516 
517 	while ((nbytes = walk.nbytes) > 0) {
518 		kernel_fpu_begin();
519 		if (nbytes & AES_BLOCK_MASK)
520 			static_call(aesni_ctr_enc_tfm)(ctx, walk.dst.virt.addr,
521 						       walk.src.virt.addr,
522 						       nbytes & AES_BLOCK_MASK,
523 						       walk.iv);
524 		nbytes &= ~AES_BLOCK_MASK;
525 
526 		if (walk.nbytes == walk.total && nbytes > 0) {
527 			aesni_enc(ctx, keystream, walk.iv);
528 			crypto_xor_cpy(walk.dst.virt.addr + walk.nbytes - nbytes,
529 				       walk.src.virt.addr + walk.nbytes - nbytes,
530 				       keystream, nbytes);
531 			crypto_inc(walk.iv, AES_BLOCK_SIZE);
532 			nbytes = 0;
533 		}
534 		kernel_fpu_end();
535 		err = skcipher_walk_done(&walk, nbytes);
536 	}
537 	return err;
538 }
539 
540 static void aesni_xctr_enc_avx_tfm(struct crypto_aes_ctx *ctx, u8 *out,
541 				   const u8 *in, unsigned int len, u8 *iv,
542 				   unsigned int byte_ctr)
543 {
544 	if (ctx->key_length == AES_KEYSIZE_128)
545 		aes_xctr_enc_128_avx_by8(in, iv, (void *)ctx, out, len,
546 					 byte_ctr);
547 	else if (ctx->key_length == AES_KEYSIZE_192)
548 		aes_xctr_enc_192_avx_by8(in, iv, (void *)ctx, out, len,
549 					 byte_ctr);
550 	else
551 		aes_xctr_enc_256_avx_by8(in, iv, (void *)ctx, out, len,
552 					 byte_ctr);
553 }
554 
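/*
 * XCTR (used by HCTR2) differs from CTR in that the block counter is XORed
 * into the IV as a little-endian value instead of being carried as a
 * big-endian counter.  Full blocks go to the AVX "by8" assembly; a trailing
 * partial block is handled here by building the keystream block manually
 * (note the counter is 1-based).
 */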
555 static int xctr_crypt(struct skcipher_request *req)
556 {
557 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
558 	struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
559 	u8 keystream[AES_BLOCK_SIZE];
560 	struct skcipher_walk walk;
561 	unsigned int nbytes;
562 	unsigned int byte_ctr = 0;
563 	int err;
564 	__le32 block[AES_BLOCK_SIZE / sizeof(__le32)];
565 
566 	err = skcipher_walk_virt(&walk, req, false);
567 
568 	while ((nbytes = walk.nbytes) > 0) {
569 		kernel_fpu_begin();
570 		if (nbytes & AES_BLOCK_MASK)
571 			aesni_xctr_enc_avx_tfm(ctx, walk.dst.virt.addr,
572 				walk.src.virt.addr, nbytes & AES_BLOCK_MASK,
573 				walk.iv, byte_ctr);
574 		nbytes &= ~AES_BLOCK_MASK;
575 		byte_ctr += walk.nbytes - nbytes;
576 
577 		if (walk.nbytes == walk.total && nbytes > 0) {
578 			memcpy(block, walk.iv, AES_BLOCK_SIZE);
579 			block[0] ^= cpu_to_le32(1 + byte_ctr / AES_BLOCK_SIZE);
580 			aesni_enc(ctx, keystream, (u8 *)block);
581 			crypto_xor_cpy(walk.dst.virt.addr + walk.nbytes -
582 				       nbytes, walk.src.virt.addr + walk.nbytes
583 				       - nbytes, keystream, nbytes);
584 			byte_ctr += nbytes;
585 			nbytes = 0;
586 		}
587 		kernel_fpu_end();
588 		err = skcipher_walk_done(&walk, nbytes);
589 	}
590 	return err;
591 }
592 
593 static int
594 rfc4106_set_hash_subkey(u8 *hash_subkey, const u8 *key, unsigned int key_len)
595 {
596 	struct crypto_aes_ctx ctx;
597 	int ret;
598 
599 	ret = aes_expandkey(&ctx, key, key_len);
600 	if (ret)
601 		return ret;
602 
603 	/* Clear the hash subkey container to zero: the hash subkey is */
604 	/* created by enciphering an all-zero block. */
605 	memset(hash_subkey, 0, RFC4106_HASH_SUBKEY_SIZE);
606 
607 	aes_encrypt(&ctx, hash_subkey, hash_subkey);
608 
609 	memzero_explicit(&ctx, sizeof(ctx));
610 	return 0;
611 }
612 
613 static int common_rfc4106_set_key(struct crypto_aead *aead, const u8 *key,
614 				  unsigned int key_len)
615 {
616 	struct aesni_rfc4106_gcm_ctx *ctx = aesni_rfc4106_gcm_ctx_get(aead);
617 
618 	if (key_len < 4)
619 		return -EINVAL;
620 
621 	/* Account for the 4-byte nonce at the end of the key. */
622 	key_len -= 4;
623 
624 	memcpy(ctx->nonce, key + key_len, sizeof(ctx->nonce));
625 
626 	return aes_set_key_common(&ctx->aes_key_expanded, key, key_len) ?:
627 	       rfc4106_set_hash_subkey(ctx->hash_subkey, key, key_len);
628 }
629 
630 /* This is the Integrity Check Value (aka the authentication tag) length and can
631  * be 8, 12 or 16 bytes long. */
632 static int common_rfc4106_set_authsize(struct crypto_aead *aead,
633 				       unsigned int authsize)
634 {
635 	switch (authsize) {
636 	case 8:
637 	case 12:
638 	case 16:
639 		break;
640 	default:
641 		return -EINVAL;
642 	}
643 
644 	return 0;
645 }
646 
647 static int generic_gcmaes_set_authsize(struct crypto_aead *tfm,
648 				       unsigned int authsize)
649 {
650 	switch (authsize) {
651 	case 4:
652 	case 8:
653 	case 12:
654 	case 13:
655 	case 14:
656 	case 15:
657 	case 16:
658 		break;
659 	default:
660 		return -EINVAL;
661 	}
662 
663 	return 0;
664 }
665 
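/*
 * Core GCM path shared by RFC 4106 and plain gcm(aes): linearize the
 * associated data if needed, run the init/update/finalize assembly
 * (choosing the SSE, AVX or AVX2 variant based on the static keys and the
 * request size), and return the computed authentication tag to the caller.
 */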
666 static int gcmaes_crypt_by_sg(bool enc, struct aead_request *req,
667 			      unsigned int assoclen, u8 *hash_subkey,
668 			      u8 *iv, void *aes_ctx, u8 *auth_tag,
669 			      unsigned long auth_tag_len)
670 {
671 	u8 databuf[sizeof(struct gcm_context_data) + (AESNI_ALIGN - 8)] __aligned(8);
672 	struct gcm_context_data *data = PTR_ALIGN((void *)databuf, AESNI_ALIGN);
673 	unsigned long left = req->cryptlen;
674 	struct scatter_walk assoc_sg_walk;
675 	struct skcipher_walk walk;
676 	bool do_avx, do_avx2;
677 	u8 *assocmem = NULL;
678 	u8 *assoc;
679 	int err;
680 
681 	if (!enc)
682 		left -= auth_tag_len;
683 
684 	do_avx = (left >= AVX_GEN2_OPTSIZE);
685 	do_avx2 = (left >= AVX_GEN4_OPTSIZE);
686 
687 	/* Linearize assoc, if not already linear */
688 	if (req->src->length >= assoclen && req->src->length) {
689 		scatterwalk_start(&assoc_sg_walk, req->src);
690 		assoc = scatterwalk_map(&assoc_sg_walk);
691 	} else {
692 		gfp_t flags = (req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP) ?
693 			      GFP_KERNEL : GFP_ATOMIC;
694 
695 		/* assoc can be any length, so must be on heap */
696 		assocmem = kmalloc(assoclen, flags);
697 		if (unlikely(!assocmem))
698 			return -ENOMEM;
699 		assoc = assocmem;
700 
701 		scatterwalk_map_and_copy(assoc, req->src, 0, assoclen, 0);
702 	}
703 
704 	kernel_fpu_begin();
705 	if (static_branch_likely(&gcm_use_avx2) && do_avx2)
706 		aesni_gcm_init_avx_gen4(aes_ctx, data, iv, hash_subkey, assoc,
707 					assoclen);
708 	else if (static_branch_likely(&gcm_use_avx) && do_avx)
709 		aesni_gcm_init_avx_gen2(aes_ctx, data, iv, hash_subkey, assoc,
710 					assoclen);
711 	else
712 		aesni_gcm_init(aes_ctx, data, iv, hash_subkey, assoc, assoclen);
713 	kernel_fpu_end();
714 
715 	if (!assocmem)
716 		scatterwalk_unmap(assoc);
717 	else
718 		kfree(assocmem);
719 
720 	err = enc ? skcipher_walk_aead_encrypt(&walk, req, false)
721 		  : skcipher_walk_aead_decrypt(&walk, req, false);
722 
723 	while (walk.nbytes > 0) {
724 		kernel_fpu_begin();
725 		if (static_branch_likely(&gcm_use_avx2) && do_avx2) {
726 			if (enc)
727 				aesni_gcm_enc_update_avx_gen4(aes_ctx, data,
728 							      walk.dst.virt.addr,
729 							      walk.src.virt.addr,
730 							      walk.nbytes);
731 			else
732 				aesni_gcm_dec_update_avx_gen4(aes_ctx, data,
733 							      walk.dst.virt.addr,
734 							      walk.src.virt.addr,
735 							      walk.nbytes);
736 		} else if (static_branch_likely(&gcm_use_avx) && do_avx) {
737 			if (enc)
738 				aesni_gcm_enc_update_avx_gen2(aes_ctx, data,
739 							      walk.dst.virt.addr,
740 							      walk.src.virt.addr,
741 							      walk.nbytes);
742 			else
743 				aesni_gcm_dec_update_avx_gen2(aes_ctx, data,
744 							      walk.dst.virt.addr,
745 							      walk.src.virt.addr,
746 							      walk.nbytes);
747 		} else if (enc) {
748 			aesni_gcm_enc_update(aes_ctx, data, walk.dst.virt.addr,
749 					     walk.src.virt.addr, walk.nbytes);
750 		} else {
751 			aesni_gcm_dec_update(aes_ctx, data, walk.dst.virt.addr,
752 					     walk.src.virt.addr, walk.nbytes);
753 		}
754 		kernel_fpu_end();
755 
756 		err = skcipher_walk_done(&walk, 0);
757 	}
758 
759 	if (err)
760 		return err;
761 
762 	kernel_fpu_begin();
763 	if (static_branch_likely(&gcm_use_avx2) && do_avx2)
764 		aesni_gcm_finalize_avx_gen4(aes_ctx, data, auth_tag,
765 					    auth_tag_len);
766 	else if (static_branch_likely(&gcm_use_avx) && do_avx)
767 		aesni_gcm_finalize_avx_gen2(aes_ctx, data, auth_tag,
768 					    auth_tag_len);
769 	else
770 		aesni_gcm_finalize(aes_ctx, data, auth_tag, auth_tag_len);
771 	kernel_fpu_end();
772 
773 	return 0;
774 }
775 
776 static int gcmaes_encrypt(struct aead_request *req, unsigned int assoclen,
777 			  u8 *hash_subkey, u8 *iv, void *aes_ctx)
778 {
779 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
780 	unsigned long auth_tag_len = crypto_aead_authsize(tfm);
781 	u8 auth_tag[16];
782 	int err;
783 
784 	err = gcmaes_crypt_by_sg(true, req, assoclen, hash_subkey, iv, aes_ctx,
785 				 auth_tag, auth_tag_len);
786 	if (err)
787 		return err;
788 
789 	scatterwalk_map_and_copy(auth_tag, req->dst,
790 				 req->assoclen + req->cryptlen,
791 				 auth_tag_len, 1);
792 	return 0;
793 }
794 
795 static int gcmaes_decrypt(struct aead_request *req, unsigned int assoclen,
796 			  u8 *hash_subkey, u8 *iv, void *aes_ctx)
797 {
798 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
799 	unsigned long auth_tag_len = crypto_aead_authsize(tfm);
800 	u8 auth_tag_msg[16];
801 	u8 auth_tag[16];
802 	int err;
803 
804 	err = gcmaes_crypt_by_sg(false, req, assoclen, hash_subkey, iv, aes_ctx,
805 				 auth_tag, auth_tag_len);
806 	if (err)
807 		return err;
808 
809 	/* Copy out original auth_tag */
810 	scatterwalk_map_and_copy(auth_tag_msg, req->src,
811 				 req->assoclen + req->cryptlen - auth_tag_len,
812 				 auth_tag_len, 0);
813 
814 	/* Compare generated tag with passed in tag. */
815 	if (crypto_memneq(auth_tag_msg, auth_tag, auth_tag_len)) {
816 		memzero_explicit(auth_tag, sizeof(auth_tag));
817 		return -EBADMSG;
818 	}
819 	return 0;
820 }
821 
822 static int helper_rfc4106_encrypt(struct aead_request *req)
823 {
824 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
825 	struct aesni_rfc4106_gcm_ctx *ctx = aesni_rfc4106_gcm_ctx_get(tfm);
826 	void *aes_ctx = &(ctx->aes_key_expanded);
827 	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
828 	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
829 	unsigned int i;
830 	__be32 counter = cpu_to_be32(1);
831 
832 	/* Assuming we are supporting RFC 4106 64-bit extended */
833 	/* sequence numbers, the AAD length must be equal */
834 	/* to 16 or 20 bytes. */
835 	if (unlikely(req->assoclen != 16 && req->assoclen != 20))
836 		return -EINVAL;
837 
838 	/* Build the IV: 4-byte nonce, 8-byte per-request IV, 4-byte initial counter. */
839 	for (i = 0; i < 4; i++)
840 		*(iv+i) = ctx->nonce[i];
841 	for (i = 0; i < 8; i++)
842 		*(iv+4+i) = req->iv[i];
843 	*((__be32 *)(iv+12)) = counter;
844 
845 	return gcmaes_encrypt(req, req->assoclen - 8, ctx->hash_subkey, iv,
846 			      aes_ctx);
847 }
848 
849 static int helper_rfc4106_decrypt(struct aead_request *req)
850 {
851 	__be32 counter = cpu_to_be32(1);
852 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
853 	struct aesni_rfc4106_gcm_ctx *ctx = aesni_rfc4106_gcm_ctx_get(tfm);
854 	void *aes_ctx = &(ctx->aes_key_expanded);
855 	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
856 	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
857 	unsigned int i;
858 
859 	if (unlikely(req->assoclen != 16 && req->assoclen != 20))
860 		return -EINVAL;
861 
862 	/* Assuming we are supporting RFC 4106 64-bit extended */
863 	/* sequence numbers, the AAD length must be equal */
864 	/* to 16 or 20 bytes (checked above). */
865 
866 	/* Build the IV: 4-byte nonce, 8-byte per-request IV, 4-byte initial counter. */
867 	for (i = 0; i < 4; i++)
868 		*(iv+i) = ctx->nonce[i];
869 	for (i = 0; i < 8; i++)
870 		*(iv+4+i) = req->iv[i];
871 	*((__be32 *)(iv+12)) = counter;
872 
873 	return gcmaes_decrypt(req, req->assoclen - 8, ctx->hash_subkey, iv,
874 			      aes_ctx);
875 }
876 #endif
877 
878 static int xts_aesni_setkey(struct crypto_skcipher *tfm, const u8 *key,
879 			    unsigned int keylen)
880 {
881 	struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
882 	int err;
883 
884 	err = xts_verify_key(tfm, key, keylen);
885 	if (err)
886 		return err;
887 
888 	keylen /= 2;
889 
890 	/* first half of xts-key is for crypt */
891 	err = aes_set_key_common(&ctx->crypt_ctx, key, keylen);
892 	if (err)
893 		return err;
894 
895 	/* second half of xts-key is for tweak */
896 	return aes_set_key_common(&ctx->tweak_ctx, key + keylen, keylen);
897 }
898 
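/*
 * Original SSE/AES-NI XTS path; on 64-bit CPUs with AVX, the implementations
 * generated by DEFINE_XTS_ALG() below register at higher priority.  The IV
 * is first encrypted with the tweak key to produce the initial tweak, whole
 * blocks are processed by aesni_xts_encrypt()/aesni_xts_decrypt(), and a
 * message that is not a multiple of the block size has its last full block
 * split off so it can be processed together with the partial tail for
 * ciphertext stealing.
 */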
899 static int xts_crypt(struct skcipher_request *req, bool encrypt)
900 {
901 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
902 	struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
903 	int tail = req->cryptlen % AES_BLOCK_SIZE;
904 	struct skcipher_request subreq;
905 	struct skcipher_walk walk;
906 	int err;
907 
908 	if (req->cryptlen < AES_BLOCK_SIZE)
909 		return -EINVAL;
910 
911 	err = skcipher_walk_virt(&walk, req, false);
912 	if (!walk.nbytes)
913 		return err;
914 
915 	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
916 		int blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
917 
918 		skcipher_walk_abort(&walk);
919 
920 		skcipher_request_set_tfm(&subreq, tfm);
921 		skcipher_request_set_callback(&subreq,
922 					      skcipher_request_flags(req),
923 					      NULL, NULL);
924 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
925 					   blocks * AES_BLOCK_SIZE, req->iv);
926 		req = &subreq;
927 
928 		err = skcipher_walk_virt(&walk, req, false);
929 		if (!walk.nbytes)
930 			return err;
931 	} else {
932 		tail = 0;
933 	}
934 
935 	kernel_fpu_begin();
936 
937 	/* calculate first value of T */
938 	aesni_enc(&ctx->tweak_ctx, walk.iv, walk.iv);
939 
940 	while (walk.nbytes > 0) {
941 		int nbytes = walk.nbytes;
942 
943 		if (nbytes < walk.total)
944 			nbytes &= ~(AES_BLOCK_SIZE - 1);
945 
946 		if (encrypt)
947 			aesni_xts_encrypt(&ctx->crypt_ctx,
948 					  walk.dst.virt.addr, walk.src.virt.addr,
949 					  nbytes, walk.iv);
950 		else
951 			aesni_xts_decrypt(&ctx->crypt_ctx,
952 					  walk.dst.virt.addr, walk.src.virt.addr,
953 					  nbytes, walk.iv);
954 		kernel_fpu_end();
955 
956 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
957 
958 		if (walk.nbytes > 0)
959 			kernel_fpu_begin();
960 	}
961 
962 	if (unlikely(tail > 0 && !err)) {
963 		struct scatterlist sg_src[2], sg_dst[2];
964 		struct scatterlist *src, *dst;
965 
966 		dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
967 		if (req->dst != req->src)
968 			dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
969 
970 		skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
971 					   req->iv);
972 
973 		err = skcipher_walk_virt(&walk, &subreq, false);
974 		if (err)
975 			return err;
976 
977 		kernel_fpu_begin();
978 		if (encrypt)
979 			aesni_xts_encrypt(&ctx->crypt_ctx,
980 					  walk.dst.virt.addr, walk.src.virt.addr,
981 					  walk.nbytes, walk.iv);
982 		else
983 			aesni_xts_decrypt(&ctx->crypt_ctx,
984 					  walk.dst.virt.addr, walk.src.virt.addr,
985 					  walk.nbytes, walk.iv);
986 		kernel_fpu_end();
987 
988 		err = skcipher_walk_done(&walk, 0);
989 	}
990 	return err;
991 }
992 
993 static int xts_encrypt(struct skcipher_request *req)
994 {
995 	return xts_crypt(req, true);
996 }
997 
998 static int xts_decrypt(struct skcipher_request *req)
999 {
1000 	return xts_crypt(req, false);
1001 }
1002 
1003 static struct crypto_alg aesni_cipher_alg = {
1004 	.cra_name		= "aes",
1005 	.cra_driver_name	= "aes-aesni",
1006 	.cra_priority		= 300,
1007 	.cra_flags		= CRYPTO_ALG_TYPE_CIPHER,
1008 	.cra_blocksize		= AES_BLOCK_SIZE,
1009 	.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1010 	.cra_module		= THIS_MODULE,
1011 	.cra_u	= {
1012 		.cipher	= {
1013 			.cia_min_keysize	= AES_MIN_KEY_SIZE,
1014 			.cia_max_keysize	= AES_MAX_KEY_SIZE,
1015 			.cia_setkey		= aes_set_key,
1016 			.cia_encrypt		= aesni_encrypt,
1017 			.cia_decrypt		= aesni_decrypt
1018 		}
1019 	}
1020 };
1021 
1022 static struct skcipher_alg aesni_skciphers[] = {
1023 	{
1024 		.base = {
1025 			.cra_name		= "__ecb(aes)",
1026 			.cra_driver_name	= "__ecb-aes-aesni",
1027 			.cra_priority		= 400,
1028 			.cra_flags		= CRYPTO_ALG_INTERNAL,
1029 			.cra_blocksize		= AES_BLOCK_SIZE,
1030 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1031 			.cra_module		= THIS_MODULE,
1032 		},
1033 		.min_keysize	= AES_MIN_KEY_SIZE,
1034 		.max_keysize	= AES_MAX_KEY_SIZE,
1035 		.setkey		= aesni_skcipher_setkey,
1036 		.encrypt	= ecb_encrypt,
1037 		.decrypt	= ecb_decrypt,
1038 	}, {
1039 		.base = {
1040 			.cra_name		= "__cbc(aes)",
1041 			.cra_driver_name	= "__cbc-aes-aesni",
1042 			.cra_priority		= 400,
1043 			.cra_flags		= CRYPTO_ALG_INTERNAL,
1044 			.cra_blocksize		= AES_BLOCK_SIZE,
1045 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1046 			.cra_module		= THIS_MODULE,
1047 		},
1048 		.min_keysize	= AES_MIN_KEY_SIZE,
1049 		.max_keysize	= AES_MAX_KEY_SIZE,
1050 		.ivsize		= AES_BLOCK_SIZE,
1051 		.setkey		= aesni_skcipher_setkey,
1052 		.encrypt	= cbc_encrypt,
1053 		.decrypt	= cbc_decrypt,
1054 	}, {
1055 		.base = {
1056 			.cra_name		= "__cts(cbc(aes))",
1057 			.cra_driver_name	= "__cts-cbc-aes-aesni",
1058 			.cra_priority		= 400,
1059 			.cra_flags		= CRYPTO_ALG_INTERNAL,
1060 			.cra_blocksize		= AES_BLOCK_SIZE,
1061 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1062 			.cra_module		= THIS_MODULE,
1063 		},
1064 		.min_keysize	= AES_MIN_KEY_SIZE,
1065 		.max_keysize	= AES_MAX_KEY_SIZE,
1066 		.ivsize		= AES_BLOCK_SIZE,
1067 		.walksize	= 2 * AES_BLOCK_SIZE,
1068 		.setkey		= aesni_skcipher_setkey,
1069 		.encrypt	= cts_cbc_encrypt,
1070 		.decrypt	= cts_cbc_decrypt,
1071 #ifdef CONFIG_X86_64
1072 	}, {
1073 		.base = {
1074 			.cra_name		= "__ctr(aes)",
1075 			.cra_driver_name	= "__ctr-aes-aesni",
1076 			.cra_priority		= 400,
1077 			.cra_flags		= CRYPTO_ALG_INTERNAL,
1078 			.cra_blocksize		= 1,
1079 			.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1080 			.cra_module		= THIS_MODULE,
1081 		},
1082 		.min_keysize	= AES_MIN_KEY_SIZE,
1083 		.max_keysize	= AES_MAX_KEY_SIZE,
1084 		.ivsize		= AES_BLOCK_SIZE,
1085 		.chunksize	= AES_BLOCK_SIZE,
1086 		.setkey		= aesni_skcipher_setkey,
1087 		.encrypt	= ctr_crypt,
1088 		.decrypt	= ctr_crypt,
1089 #endif
1090 	}, {
1091 		.base = {
1092 			.cra_name		= "__xts(aes)",
1093 			.cra_driver_name	= "__xts-aes-aesni",
1094 			.cra_priority		= 401,
1095 			.cra_flags		= CRYPTO_ALG_INTERNAL,
1096 			.cra_blocksize		= AES_BLOCK_SIZE,
1097 			.cra_ctxsize		= XTS_AES_CTX_SIZE,
1098 			.cra_module		= THIS_MODULE,
1099 		},
1100 		.min_keysize	= 2 * AES_MIN_KEY_SIZE,
1101 		.max_keysize	= 2 * AES_MAX_KEY_SIZE,
1102 		.ivsize		= AES_BLOCK_SIZE,
1103 		.walksize	= 2 * AES_BLOCK_SIZE,
1104 		.setkey		= xts_aesni_setkey,
1105 		.encrypt	= xts_encrypt,
1106 		.decrypt	= xts_decrypt,
1107 	}
1108 };
1109 
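/*
 * The algorithms above are marked CRYPTO_ALG_INTERNAL and carry a "__" name
 * prefix; simd_register_skciphers_compat() wraps each one in a SIMD helper
 * that is registered under the normal name and defers to cryptd whenever the
 * calling context does not allow FPU use.  A kernel user simply requests the
 * generic name, e.g.:
 *
 *	tfm = crypto_alloc_skcipher("xts(aes)", 0, 0);
 *
 * and the crypto core selects the highest-priority implementation.
 */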
1110 static
1111 struct simd_skcipher_alg *aesni_simd_skciphers[ARRAY_SIZE(aesni_skciphers)];
1112 
1113 #ifdef CONFIG_X86_64
1114 /*
1115  * XCTR does not have a non-AVX implementation, so it must be enabled
1116  * conditionally.
1117  */
1118 static struct skcipher_alg aesni_xctr = {
1119 	.base = {
1120 		.cra_name		= "__xctr(aes)",
1121 		.cra_driver_name	= "__xctr-aes-aesni",
1122 		.cra_priority		= 400,
1123 		.cra_flags		= CRYPTO_ALG_INTERNAL,
1124 		.cra_blocksize		= 1,
1125 		.cra_ctxsize		= CRYPTO_AES_CTX_SIZE,
1126 		.cra_module		= THIS_MODULE,
1127 	},
1128 	.min_keysize	= AES_MIN_KEY_SIZE,
1129 	.max_keysize	= AES_MAX_KEY_SIZE,
1130 	.ivsize		= AES_BLOCK_SIZE,
1131 	.chunksize	= AES_BLOCK_SIZE,
1132 	.setkey		= aesni_skcipher_setkey,
1133 	.encrypt	= xctr_crypt,
1134 	.decrypt	= xctr_crypt,
1135 };
1136 
1137 static struct simd_skcipher_alg *aesni_simd_xctr;
1138 
1139 asmlinkage void aes_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
1140 				   u8 iv[AES_BLOCK_SIZE]);
1141 
1142 typedef void (*xts_asm_func)(const struct crypto_aes_ctx *key,
1143 			     const u8 *src, u8 *dst, size_t len,
1144 			     u8 tweak[AES_BLOCK_SIZE]);
1145 
1146 /* This handles cases where the source and/or destination span pages. */
1147 static noinline int
1148 xts_crypt_slowpath(struct skcipher_request *req, xts_asm_func asm_func)
1149 {
1150 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
1151 	const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
1152 	int tail = req->cryptlen % AES_BLOCK_SIZE;
1153 	struct scatterlist sg_src[2], sg_dst[2];
1154 	struct skcipher_request subreq;
1155 	struct skcipher_walk walk;
1156 	struct scatterlist *src, *dst;
1157 	int err;
1158 
1159 	/*
1160 	 * If the message length isn't divisible by the AES block size, then
1161 	 * separate off the last full block and the partial block.  This ensures
1162 	 * that they are processed in the same call to the assembly function,
1163 	 * which is required for ciphertext stealing.
1164 	 */
1165 	if (tail) {
1166 		skcipher_request_set_tfm(&subreq, tfm);
1167 		skcipher_request_set_callback(&subreq,
1168 					      skcipher_request_flags(req),
1169 					      NULL, NULL);
1170 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
1171 					   req->cryptlen - tail - AES_BLOCK_SIZE,
1172 					   req->iv);
1173 		req = &subreq;
1174 	}
1175 
1176 	err = skcipher_walk_virt(&walk, req, false);
1177 
1178 	while (walk.nbytes) {
1179 		unsigned int nbytes = walk.nbytes;
1180 
1181 		if (nbytes < walk.total)
1182 			nbytes = round_down(nbytes, AES_BLOCK_SIZE);
1183 
1184 		kernel_fpu_begin();
1185 		(*asm_func)(&ctx->crypt_ctx, walk.src.virt.addr,
1186 			    walk.dst.virt.addr, nbytes, req->iv);
1187 		kernel_fpu_end();
1188 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
1189 	}
1190 
1191 	if (err || !tail)
1192 		return err;
1193 
1194 	/* Do ciphertext stealing with the last full block and partial block. */
1195 
1196 	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
1197 	if (req->dst != req->src)
1198 		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
1199 
1200 	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
1201 				   req->iv);
1202 
1203 	err = skcipher_walk_virt(&walk, req, false);
1204 	if (err)
1205 		return err;
1206 
1207 	kernel_fpu_begin();
1208 	(*asm_func)(&ctx->crypt_ctx, walk.src.virt.addr, walk.dst.virt.addr,
1209 		    walk.nbytes, req->iv);
1210 	kernel_fpu_end();
1211 
1212 	return skcipher_walk_done(&walk, 0);
1213 }
1214 
1215 /* __always_inline to avoid indirect call in fastpath */
1216 static __always_inline int
1217 xts_crypt2(struct skcipher_request *req, xts_asm_func asm_func)
1218 {
1219 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
1220 	const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
1221 	const unsigned int cryptlen = req->cryptlen;
1222 	struct scatterlist *src = req->src;
1223 	struct scatterlist *dst = req->dst;
1224 
1225 	if (unlikely(cryptlen < AES_BLOCK_SIZE))
1226 		return -EINVAL;
1227 
1228 	kernel_fpu_begin();
1229 	aes_xts_encrypt_iv(&ctx->tweak_ctx, req->iv);
1230 
1231 	/*
1232 	 * In practice, virtually all XTS plaintexts and ciphertexts are either
1233 	 * 512 or 4096 bytes, aligned such that they don't span page boundaries.
1234 	 * To optimize the performance of these cases, and also any other case
1235 	 * where no page boundary is spanned, the below fast-path handles
1236 	 * single-page sources and destinations as efficiently as possible.
1237 	 */
1238 	if (likely(src->length >= cryptlen && dst->length >= cryptlen &&
1239 		   src->offset + cryptlen <= PAGE_SIZE &&
1240 		   dst->offset + cryptlen <= PAGE_SIZE)) {
1241 		struct page *src_page = sg_page(src);
1242 		struct page *dst_page = sg_page(dst);
1243 		void *src_virt = kmap_local_page(src_page) + src->offset;
1244 		void *dst_virt = kmap_local_page(dst_page) + dst->offset;
1245 
1246 		(*asm_func)(&ctx->crypt_ctx, src_virt, dst_virt, cryptlen,
1247 			    req->iv);
1248 		kunmap_local(dst_virt);
1249 		kunmap_local(src_virt);
1250 		kernel_fpu_end();
1251 		return 0;
1252 	}
1253 	kernel_fpu_end();
1254 	return xts_crypt_slowpath(req, asm_func);
1255 }
1256 
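/*
 * DEFINE_XTS_ALG() declares the encrypt/decrypt assembly entry points for
 * one AES-XTS implementation and generates the corresponding skcipher_alg
 * plus the simd alg pointer used when (un)registering it.  The suffix names
 * the instruction-set variant and the priority orders the implementations.
 */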
1257 #define DEFINE_XTS_ALG(suffix, driver_name, priority)			       \
1258 									       \
1259 asmlinkage void aes_xts_encrypt_##suffix(const struct crypto_aes_ctx *key,     \
1260 					 const u8 *src, u8 *dst, size_t len,   \
1261 					 u8 tweak[AES_BLOCK_SIZE]);	       \
1262 asmlinkage void aes_xts_decrypt_##suffix(const struct crypto_aes_ctx *key,     \
1263 					 const u8 *src, u8 *dst, size_t len,   \
1264 					 u8 tweak[AES_BLOCK_SIZE]);	       \
1265 									       \
1266 static int xts_encrypt_##suffix(struct skcipher_request *req)		       \
1267 {									       \
1268 	return xts_crypt2(req, aes_xts_encrypt_##suffix);		       \
1269 }									       \
1270 									       \
1271 static int xts_decrypt_##suffix(struct skcipher_request *req)		       \
1272 {									       \
1273 	return xts_crypt2(req, aes_xts_decrypt_##suffix);		       \
1274 }									       \
1275 									       \
1276 static struct skcipher_alg aes_xts_alg_##suffix = {			       \
1277 	.base = {							       \
1278 		.cra_name		= "__xts(aes)",			       \
1279 		.cra_driver_name	= "__" driver_name,		       \
1280 		.cra_priority		= priority,			       \
1281 		.cra_flags		= CRYPTO_ALG_INTERNAL,		       \
1282 		.cra_blocksize		= AES_BLOCK_SIZE,		       \
1283 		.cra_ctxsize		= XTS_AES_CTX_SIZE,		       \
1284 		.cra_module		= THIS_MODULE,			       \
1285 	},								       \
1286 	.min_keysize	= 2 * AES_MIN_KEY_SIZE,				       \
1287 	.max_keysize	= 2 * AES_MAX_KEY_SIZE,				       \
1288 	.ivsize		= AES_BLOCK_SIZE,				       \
1289 	.walksize	= 2 * AES_BLOCK_SIZE,				       \
1290 	.setkey		= xts_aesni_setkey,				       \
1291 	.encrypt	= xts_encrypt_##suffix,				       \
1292 	.decrypt	= xts_decrypt_##suffix,				       \
1293 };									       \
1294 									       \
1295 static struct simd_skcipher_alg *aes_xts_simdalg_##suffix
1296 
1297 DEFINE_XTS_ALG(aesni_avx, "xts-aes-aesni-avx", 500);
1298 #if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ)
1299 DEFINE_XTS_ALG(vaes_avx2, "xts-aes-vaes-avx2", 600);
1300 DEFINE_XTS_ALG(vaes_avx10_256, "xts-aes-vaes-avx10_256", 700);
1301 DEFINE_XTS_ALG(vaes_avx10_512, "xts-aes-vaes-avx10_512", 800);
1302 #endif
1303 
1304 /*
1305  * This is a list of CPU models that are known to suffer from downclocking when
1306  * zmm registers (512-bit vectors) are used.  On these CPUs, the AES-XTS
1307  * implementation with zmm registers won't be used by default.  An
1308  * implementation with ymm registers (256-bit vectors) will be used instead.
1309  */
1310 static const struct x86_cpu_id zmm_exclusion_list[] = {
1311 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_SKYLAKE_X },
1312 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_ICELAKE_X },
1313 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_ICELAKE_D },
1314 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_ICELAKE },
1315 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_ICELAKE_L },
1316 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_ICELAKE_NNPI },
1317 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_TIGERLAKE_L },
1318 	{ .vendor = X86_VENDOR_INTEL, .family = 6, .model = INTEL_FAM6_TIGERLAKE },
1319 	/* Allow Rocket Lake and later, and Sapphire Rapids and later. */
1320 	/* Also allow AMD CPUs (starting with Zen 4, the first with AVX-512). */
1321 	{},
1322 };
1323 
1324 static int __init register_xts_algs(void)
1325 {
1326 	int err;
1327 
1328 	if (!boot_cpu_has(X86_FEATURE_AVX))
1329 		return 0;
1330 	err = simd_register_skciphers_compat(&aes_xts_alg_aesni_avx, 1,
1331 					     &aes_xts_simdalg_aesni_avx);
1332 	if (err)
1333 		return err;
1334 #if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ)
1335 	if (!boot_cpu_has(X86_FEATURE_AVX2) ||
1336 	    !boot_cpu_has(X86_FEATURE_VAES) ||
1337 	    !boot_cpu_has(X86_FEATURE_VPCLMULQDQ) ||
1338 	    !boot_cpu_has(X86_FEATURE_PCLMULQDQ) ||
1339 	    !cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL))
1340 		return 0;
1341 	err = simd_register_skciphers_compat(&aes_xts_alg_vaes_avx2, 1,
1342 					     &aes_xts_simdalg_vaes_avx2);
1343 	if (err)
1344 		return err;
1345 
1346 	if (!boot_cpu_has(X86_FEATURE_AVX512BW) ||
1347 	    !boot_cpu_has(X86_FEATURE_AVX512VL) ||
1348 	    !boot_cpu_has(X86_FEATURE_BMI2) ||
1349 	    !cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM |
1350 			       XFEATURE_MASK_AVX512, NULL))
1351 		return 0;
1352 
1353 	err = simd_register_skciphers_compat(&aes_xts_alg_vaes_avx10_256, 1,
1354 					     &aes_xts_simdalg_vaes_avx10_256);
1355 	if (err)
1356 		return err;
1357 
1358 	if (x86_match_cpu(zmm_exclusion_list))
1359 		aes_xts_alg_vaes_avx10_512.base.cra_priority = 1;
1360 
1361 	err = simd_register_skciphers_compat(&aes_xts_alg_vaes_avx10_512, 1,
1362 					     &aes_xts_simdalg_vaes_avx10_512);
1363 	if (err)
1364 		return err;
1365 #endif /* CONFIG_AS_VAES && CONFIG_AS_VPCLMULQDQ */
1366 	return 0;
1367 }
1368 
1369 static void unregister_xts_algs(void)
1370 {
1371 	if (aes_xts_simdalg_aesni_avx)
1372 		simd_unregister_skciphers(&aes_xts_alg_aesni_avx, 1,
1373 					  &aes_xts_simdalg_aesni_avx);
1374 #if defined(CONFIG_AS_VAES) && defined(CONFIG_AS_VPCLMULQDQ)
1375 	if (aes_xts_simdalg_vaes_avx2)
1376 		simd_unregister_skciphers(&aes_xts_alg_vaes_avx2, 1,
1377 					  &aes_xts_simdalg_vaes_avx2);
1378 	if (aes_xts_simdalg_vaes_avx10_256)
1379 		simd_unregister_skciphers(&aes_xts_alg_vaes_avx10_256, 1,
1380 					  &aes_xts_simdalg_vaes_avx10_256);
1381 	if (aes_xts_simdalg_vaes_avx10_512)
1382 		simd_unregister_skciphers(&aes_xts_alg_vaes_avx10_512, 1,
1383 					  &aes_xts_simdalg_vaes_avx10_512);
1384 #endif
1385 }
1386 #else /* CONFIG_X86_64 */
1387 static int __init register_xts_algs(void)
1388 {
1389 	return 0;
1390 }
1391 
1392 static void unregister_xts_algs(void)
1393 {
1394 }
1395 #endif /* !CONFIG_X86_64 */
1396 
1397 #ifdef CONFIG_X86_64
1398 static int generic_gcmaes_set_key(struct crypto_aead *aead, const u8 *key,
1399 				  unsigned int key_len)
1400 {
1401 	struct generic_gcmaes_ctx *ctx = generic_gcmaes_ctx_get(aead);
1402 
1403 	return aes_set_key_common(&ctx->aes_key_expanded, key, key_len) ?:
1404 	       rfc4106_set_hash_subkey(ctx->hash_subkey, key, key_len);
1405 }
1406 
1407 static int generic_gcmaes_encrypt(struct aead_request *req)
1408 {
1409 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
1410 	struct generic_gcmaes_ctx *ctx = generic_gcmaes_ctx_get(tfm);
1411 	void *aes_ctx = &(ctx->aes_key_expanded);
1412 	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
1413 	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
1414 	__be32 counter = cpu_to_be32(1);
1415 
1416 	memcpy(iv, req->iv, 12);
1417 	*((__be32 *)(iv+12)) = counter;
1418 
1419 	return gcmaes_encrypt(req, req->assoclen, ctx->hash_subkey, iv,
1420 			      aes_ctx);
1421 }
1422 
1423 static int generic_gcmaes_decrypt(struct aead_request *req)
1424 {
1425 	__be32 counter = cpu_to_be32(1);
1426 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
1427 	struct generic_gcmaes_ctx *ctx = generic_gcmaes_ctx_get(tfm);
1428 	void *aes_ctx = &(ctx->aes_key_expanded);
1429 	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
1430 	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
1431 
1432 	memcpy(iv, req->iv, 12);
1433 	*((__be32 *)(iv+12)) = counter;
1434 
1435 	return gcmaes_decrypt(req, req->assoclen, ctx->hash_subkey, iv,
1436 			      aes_ctx);
1437 }
1438 
1439 static struct aead_alg aesni_aeads[] = { {
1440 	.setkey			= common_rfc4106_set_key,
1441 	.setauthsize		= common_rfc4106_set_authsize,
1442 	.encrypt		= helper_rfc4106_encrypt,
1443 	.decrypt		= helper_rfc4106_decrypt,
1444 	.ivsize			= GCM_RFC4106_IV_SIZE,
1445 	.maxauthsize		= 16,
1446 	.base = {
1447 		.cra_name		= "__rfc4106(gcm(aes))",
1448 		.cra_driver_name	= "__rfc4106-gcm-aesni",
1449 		.cra_priority		= 400,
1450 		.cra_flags		= CRYPTO_ALG_INTERNAL,
1451 		.cra_blocksize		= 1,
1452 		.cra_ctxsize		= sizeof(struct aesni_rfc4106_gcm_ctx),
1453 		.cra_alignmask		= 0,
1454 		.cra_module		= THIS_MODULE,
1455 	},
1456 }, {
1457 	.setkey			= generic_gcmaes_set_key,
1458 	.setauthsize		= generic_gcmaes_set_authsize,
1459 	.encrypt		= generic_gcmaes_encrypt,
1460 	.decrypt		= generic_gcmaes_decrypt,
1461 	.ivsize			= GCM_AES_IV_SIZE,
1462 	.maxauthsize		= 16,
1463 	.base = {
1464 		.cra_name		= "__gcm(aes)",
1465 		.cra_driver_name	= "__generic-gcm-aesni",
1466 		.cra_priority		= 400,
1467 		.cra_flags		= CRYPTO_ALG_INTERNAL,
1468 		.cra_blocksize		= 1,
1469 		.cra_ctxsize		= sizeof(struct generic_gcmaes_ctx),
1470 		.cra_alignmask		= 0,
1471 		.cra_module		= THIS_MODULE,
1472 	},
1473 } };
1474 #else
1475 static struct aead_alg aesni_aeads[0];
1476 #endif
1477 
1478 static struct simd_aead_alg *aesni_simd_aeads[ARRAY_SIZE(aesni_aeads)];
1479 
1480 static const struct x86_cpu_id aesni_cpu_id[] = {
1481 	X86_MATCH_FEATURE(X86_FEATURE_AES, NULL),
1482 	{}
1483 };
1484 MODULE_DEVICE_TABLE(x86cpu, aesni_cpu_id);
1485 
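/*
 * Module init: require the AES-NI CPU feature, pick the GCM and CTR code
 * paths based on AVX/AVX2 support, then register the bare cipher, the
 * skciphers, the AEADs, the AVX-only XCTR algorithm and finally the
 * AES-XTS implementations, unwinding in reverse order on failure.
 */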
1486 static int __init aesni_init(void)
1487 {
1488 	int err;
1489 
1490 	if (!x86_match_cpu(aesni_cpu_id))
1491 		return -ENODEV;
1492 #ifdef CONFIG_X86_64
1493 	if (boot_cpu_has(X86_FEATURE_AVX2)) {
1494 		pr_info("AVX2 version of gcm_enc/dec engaged.\n");
1495 		static_branch_enable(&gcm_use_avx);
1496 		static_branch_enable(&gcm_use_avx2);
1497 	} else if (boot_cpu_has(X86_FEATURE_AVX)) {
1499 		pr_info("AVX version of gcm_enc/dec engaged.\n");
1500 		static_branch_enable(&gcm_use_avx);
1501 	} else {
1502 		pr_info("SSE version of gcm_enc/dec engaged.\n");
1503 	}
1504 	if (boot_cpu_has(X86_FEATURE_AVX)) {
1505 		/* optimize performance of ctr mode encryption transform */
1506 		static_call_update(aesni_ctr_enc_tfm, aesni_ctr_enc_avx_tfm);
1507 		pr_info("AES CTR mode by8 optimization enabled\n");
1508 	}
1509 #endif /* CONFIG_X86_64 */
1510 
1511 	err = crypto_register_alg(&aesni_cipher_alg);
1512 	if (err)
1513 		return err;
1514 
1515 	err = simd_register_skciphers_compat(aesni_skciphers,
1516 					     ARRAY_SIZE(aesni_skciphers),
1517 					     aesni_simd_skciphers);
1518 	if (err)
1519 		goto unregister_cipher;
1520 
1521 	err = simd_register_aeads_compat(aesni_aeads, ARRAY_SIZE(aesni_aeads),
1522 					 aesni_simd_aeads);
1523 	if (err)
1524 		goto unregister_skciphers;
1525 
1526 #ifdef CONFIG_X86_64
1527 	if (boot_cpu_has(X86_FEATURE_AVX))
1528 		err = simd_register_skciphers_compat(&aesni_xctr, 1,
1529 						     &aesni_simd_xctr);
1530 	if (err)
1531 		goto unregister_aeads;
1532 #endif /* CONFIG_X86_64 */
1533 
1534 	err = register_xts_algs();
1535 	if (err)
1536 		goto unregister_xts;
1537 
1538 	return 0;
1539 
1540 unregister_xts:
1541 	unregister_xts_algs();
1542 #ifdef CONFIG_X86_64
1543 	if (aesni_simd_xctr)
1544 		simd_unregister_skciphers(&aesni_xctr, 1, &aesni_simd_xctr);
1545 unregister_aeads:
1546 #endif /* CONFIG_X86_64 */
1547 	simd_unregister_aeads(aesni_aeads, ARRAY_SIZE(aesni_aeads),
1548 				aesni_simd_aeads);
1549 
1550 unregister_skciphers:
1551 	simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers),
1552 				  aesni_simd_skciphers);
1553 unregister_cipher:
1554 	crypto_unregister_alg(&aesni_cipher_alg);
1555 	return err;
1556 }
1557 
1558 static void __exit aesni_exit(void)
1559 {
1560 	simd_unregister_aeads(aesni_aeads, ARRAY_SIZE(aesni_aeads),
1561 			      aesni_simd_aeads);
1562 	simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers),
1563 				  aesni_simd_skciphers);
1564 	crypto_unregister_alg(&aesni_cipher_alg);
1565 #ifdef CONFIG_X86_64
1566 	if (boot_cpu_has(X86_FEATURE_AVX))
1567 		simd_unregister_skciphers(&aesni_xctr, 1, &aesni_simd_xctr);
1568 #endif /* CONFIG_X86_64 */
1569 	unregister_xts_algs();
1570 }
1571 
1572 late_initcall(aesni_init);
1573 module_exit(aesni_exit);
1574 
1575 MODULE_DESCRIPTION("Rijndael (AES) Cipher Algorithm, Intel AES-NI instructions optimized");
1576 MODULE_LICENSE("GPL");
1577 MODULE_ALIAS_CRYPTO("aes");
1578