xref: /linux/arch/riscv/crypto/sm3-riscv64-glue.c (revision 1d5198dd08ac04b13a8b7539131baf0980998032)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * SM3 using the RISC-V vector crypto extensions
4  *
5  * Copyright (C) 2023 VRULL GmbH
6  * Author: Heiko Stuebner <heiko.stuebner@vrull.eu>
7  *
8  * Copyright (C) 2023 SiFive, Inc.
9  * Author: Jerry Shih <jerry.shih@sifive.com>
10  */
11 
12 #include <asm/simd.h>
13 #include <asm/vector.h>
14 #include <crypto/internal/hash.h>
15 #include <crypto/internal/simd.h>
16 #include <crypto/sm3_base.h>
17 #include <linux/linkage.h>
18 #include <linux/module.h>
19 
20 /*
21  * Note: the asm function only uses the 'state' field of struct sm3_state.
22  * It is assumed to be the first field.
23  */
24 asmlinkage void sm3_transform_zvksh_zvkb(
25 	struct sm3_state *state, const u8 *data, int num_blocks);
26 
27 static int riscv64_sm3_update(struct shash_desc *desc, const u8 *data,
28 			      unsigned int len)
29 {
30 	/*
31 	 * Ensure struct sm3_state begins directly with the SM3
32 	 * 256-bit internal state, as this is what the asm function expects.
33 	 */
34 	BUILD_BUG_ON(offsetof(struct sm3_state, state) != 0);
35 
36 	if (crypto_simd_usable()) {
37 		kernel_vector_begin();
38 		sm3_base_do_update(desc, data, len, sm3_transform_zvksh_zvkb);
39 		kernel_vector_end();
40 	} else {
41 		sm3_update(shash_desc_ctx(desc), data, len);
42 	}
43 	return 0;
44 }
45 
46 static int riscv64_sm3_finup(struct shash_desc *desc, const u8 *data,
47 			     unsigned int len, u8 *out)
48 {
49 	struct sm3_state *ctx;
50 
51 	if (crypto_simd_usable()) {
52 		kernel_vector_begin();
53 		if (len)
54 			sm3_base_do_update(desc, data, len,
55 					   sm3_transform_zvksh_zvkb);
56 		sm3_base_do_finalize(desc, sm3_transform_zvksh_zvkb);
57 		kernel_vector_end();
58 
59 		return sm3_base_finish(desc, out);
60 	}
61 
62 	ctx = shash_desc_ctx(desc);
63 	if (len)
64 		sm3_update(ctx, data, len);
65 	sm3_final(ctx, out);
66 
67 	return 0;
68 }
69 
70 static int riscv64_sm3_final(struct shash_desc *desc, u8 *out)
71 {
72 	return riscv64_sm3_finup(desc, NULL, 0, out);
73 }
74 
75 static struct shash_alg riscv64_sm3_alg = {
76 	.init = sm3_base_init,
77 	.update = riscv64_sm3_update,
78 	.final = riscv64_sm3_final,
79 	.finup = riscv64_sm3_finup,
80 	.descsize = sizeof(struct sm3_state),
81 	.digestsize = SM3_DIGEST_SIZE,
82 	.base = {
83 		.cra_blocksize = SM3_BLOCK_SIZE,
84 		.cra_priority = 300,
85 		.cra_name = "sm3",
86 		.cra_driver_name = "sm3-riscv64-zvksh-zvkb",
87 		.cra_module = THIS_MODULE,
88 	},
89 };
90 
91 static int __init riscv64_sm3_mod_init(void)
92 {
93 	if (riscv_isa_extension_available(NULL, ZVKSH) &&
94 	    riscv_isa_extension_available(NULL, ZVKB) &&
95 	    riscv_vector_vlen() >= 128)
96 		return crypto_register_shash(&riscv64_sm3_alg);
97 
98 	return -ENODEV;
99 }
100 
101 static void __exit riscv64_sm3_mod_exit(void)
102 {
103 	crypto_unregister_shash(&riscv64_sm3_alg);
104 }
105 
106 module_init(riscv64_sm3_mod_init);
107 module_exit(riscv64_sm3_mod_exit);
108 
109 MODULE_DESCRIPTION("SM3 (RISC-V accelerated)");
110 MODULE_AUTHOR("Heiko Stuebner <heiko.stuebner@vrull.eu>");
111 MODULE_LICENSE("GPL");
112 MODULE_ALIAS_CRYPTO("sm3");
113