xref: /linux/lib/crypto/riscv/aes.h (revision a4e573db06a4e8c519ec4c42f8e1249a0853367a)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 /*
3  * Copyright (C) 2023 VRULL GmbH
4  * Copyright (C) 2023 SiFive, Inc.
5  * Copyright 2024 Google LLC
6  */
7 
8 #include <asm/simd.h>
9 #include <asm/vector.h>
10 
11 static __ro_after_init DEFINE_STATIC_KEY_FALSE(have_zvkned);
12 
13 void aes_encrypt_zvkned(const u32 rndkeys[], int key_len,
14 			u8 out[AES_BLOCK_SIZE], const u8 in[AES_BLOCK_SIZE]);
15 void aes_decrypt_zvkned(const u32 rndkeys[], int key_len,
16 			u8 out[AES_BLOCK_SIZE], const u8 in[AES_BLOCK_SIZE]);
17 
18 static void aes_preparekey_arch(union aes_enckey_arch *k,
19 				union aes_invkey_arch *inv_k,
20 				const u8 *in_key, int key_len, int nrounds)
21 {
22 	aes_expandkey_generic(k->rndkeys, inv_k ? inv_k->inv_rndkeys : NULL,
23 			      in_key, key_len);
24 }
25 
26 static void aes_encrypt_arch(const struct aes_enckey *key,
27 			     u8 out[AES_BLOCK_SIZE],
28 			     const u8 in[AES_BLOCK_SIZE])
29 {
30 	if (static_branch_likely(&have_zvkned) && likely(may_use_simd())) {
31 		kernel_vector_begin();
32 		aes_encrypt_zvkned(key->k.rndkeys, key->len, out, in);
33 		kernel_vector_end();
34 	} else {
35 		aes_encrypt_generic(key->k.rndkeys, key->nrounds, out, in);
36 	}
37 }
38 
39 static void aes_decrypt_arch(const struct aes_key *key,
40 			     u8 out[AES_BLOCK_SIZE],
41 			     const u8 in[AES_BLOCK_SIZE])
42 {
43 	/*
44 	 * Note that the Zvkned code uses the standard round keys, while the
45 	 * fallback uses the inverse round keys.  Thus both must be present.
46 	 */
47 	if (static_branch_likely(&have_zvkned) && likely(may_use_simd())) {
48 		kernel_vector_begin();
49 		aes_decrypt_zvkned(key->k.rndkeys, key->len, out, in);
50 		kernel_vector_end();
51 	} else {
52 		aes_decrypt_generic(key->inv_k.inv_rndkeys, key->nrounds,
53 				    out, in);
54 	}
55 }
56 
57 #define aes_mod_init_arch aes_mod_init_arch
58 static void aes_mod_init_arch(void)
59 {
60 	if (riscv_isa_extension_available(NULL, ZVKNED) &&
61 	    riscv_vector_vlen() >= 128)
62 		static_branch_enable(&have_zvkned);
63 }
64