xref: /freebsd/crypto/libecc/src/hash/sm3.c (revision 0e8011faf58b743cc652e3b2ad0f7671227610df)
1 /*
2  *  Copyright (C) 2021 - This file is part of libecc project
3  *
4  *  Authors:
5  *      Arnaud EBALARD <arnaud.ebalard@ssi.gouv.fr>
6  *      Ryad BENADJILA <ryadbenadjila@gmail.com>
7  *
8  *  This software is licensed under a dual BSD and GPL v2 license.
9  *  See LICENSE file at the root folder of the project.
10  */
11 #include <libecc/lib_ecc_config.h>
12 #ifdef WITH_HASH_SM3
13 
14 #include <libecc/hash/sm3.h>
15 
/*
 * Big-endian 32-bit load and store helpers.
 *
 * GET_UINT32_BE(n, b, i): assemble bytes b[i], b[i+1], b[i+2], b[i+3]
 * (most significant byte first) into the lvalue n.
 * PUT_UINT32_BE(n, b, i): store the low 32 bits of n into b[i..i+3],
 * most significant byte first.
 *
 * Each macro is only provided when no earlier definition exists.
 */
#ifndef GET_UINT32_BE
#define GET_UINT32_BE(n, b, i)					\
do {								\
	(n) = ((u32)(b)[(i) + 0] << 24) |			\
	      ((u32)(b)[(i) + 1] << 16) |			\
	      ((u32)(b)[(i) + 2] <<  8) |			\
	      ((u32)(b)[(i) + 3]);				\
} while (0)
#endif

#ifndef PUT_UINT32_BE
#define PUT_UINT32_BE(n, b, i)					\
do {								\
	(b)[(i) + 0] = (u8)(((n) >> 24) & 0xFF);		\
	(b)[(i) + 1] = (u8)(((n) >> 16) & 0xFF);		\
	(b)[(i) + 2] = (u8)(((n) >>  8) & 0xFF);		\
	(b)[(i) + 3] = (u8)( (n)        & 0xFF);		\
} while (0)
#endif
38 
/*
 * 64-bit integer manipulation macro (big endian).
 *
 * PUT_UINT64_BE(n, b, i): store n as 8 bytes into b[i..i+7], most
 * significant byte first. n is converted to u64 before shifting so that the
 * shifts by 32..56 bits stay well defined even when n has a narrower type
 * (shifting a 32-bit value by 32 or more is undefined behavior in C); for
 * 64-bit arguments the stored bytes are unchanged.
 */
#ifndef PUT_UINT64_BE
#define PUT_UINT64_BE(n, b, i)				\
do {							\
	(b)[(i)    ] = (u8) ( ((u64) (n)) >> 56 );	\
	(b)[(i) + 1] = (u8) ( ((u64) (n)) >> 48 );	\
	(b)[(i) + 2] = (u8) ( ((u64) (n)) >> 40 );	\
	(b)[(i) + 3] = (u8) ( ((u64) (n)) >> 32 );	\
	(b)[(i) + 4] = (u8) ( ((u64) (n)) >> 24 );	\
	(b)[(i) + 5] = (u8) ( ((u64) (n)) >> 16 );	\
	(b)[(i) + 6] = (u8) ( ((u64) (n)) >>  8 );	\
	(b)[(i) + 7] = (u8) ( ((u64) (n))       );	\
} while( 0 )
#endif /* PUT_UINT64_BE */
55 
56 
57 
58 static const u32 SM3_Tj_low  = 0x79cc4519;
59 static const u32 SM3_Tj_high = 0x7a879d8a;
60 
61 /* Boolean functions FF_j and GG_j for 0 <= j <= 15 */
62 #define FF_j_low(X, Y, Z) (((u32)(X)) ^ ((u32)(Y)) ^ ((u32)(Z)))
63 #define GG_j_low(X, Y, Z) (((u32)(X)) ^ ((u32)(Y)) ^ ((u32)(Z)))
64 
65 /* Boolean functions FF_j and GG_j for 16 <= j <= 63 */
66 #define FF_j_high(X, Y, Z) ((((u32)(X)) & ((u32)(Y))) | \
67 			    (((u32)(X)) & ((u32)(Z))) | \
68 			    (((u32)(Y)) & ((u32)(Z))))
69 #define GG_j_high(X, Y, Z) ((((u32)(X)) & ((u32)(Y))) | \
70 			    ((~((u32)(X))) & ((u32)(Z))))
71 
72 /* 32-bit bitwise cyclic shift. Only support shifts value y < 32 */
73 #define _SM3_ROTL_(x, y) ((((u32)(x)) << (y)) | \
74 			(((u32)(x)) >> ((sizeof(u32) * 8) - (y))))
75 
76 #define SM3_ROTL(x, y) ((((y) < (sizeof(u32) * 8)) && ((y) > 0)) ? (_SM3_ROTL_(x, y)) : (x))
77 
78 /* Permutation Functions P_0 and P_1 */
79 #define SM3_P_0(X) (((u32)X) ^ SM3_ROTL((X),  9) ^ SM3_ROTL((X), 17))
80 #define SM3_P_1(X) (((u32)X) ^ SM3_ROTL((X), 15) ^ SM3_ROTL((X), 23))
81 
82 /* SM3 Iterative Compression Process
83  * NOTE: ctx and data sanity checks are performed by the caller (this is an internal function)
84  */
85 ATTRIBUTE_WARN_UNUSED_RET static int sm3_process(sm3_context *ctx, const u8 data[SM3_BLOCK_SIZE])
86 {
87 	u32 A, B, C, D, E, F, G, H;
88 	u32 SS1, SS2, TT1, TT2;
89 	u32 W[68 + 64];
90 	unsigned int j;
91 	int ret;
92 
93 	/* Message Expansion Function ME */
94 
95 	for (j = 0; j < 16; j++) {
96 		GET_UINT32_BE(W[j], data, 4 * j);
97 	}
98 
99 	for (j = 16; j < 68; j++) {
100 		W[j] = SM3_P_1(W[j - 16] ^ W[j - 9] ^ (SM3_ROTL(W[j - 3], 15))) ^
101 		       (SM3_ROTL(W[j - 13], 7)) ^ W[j - 6];
102 	}
103 
104 	for (j = 0; j < 64; j++) {
105 	   W[j + 68] = W[j] ^ W[j + 4];
106 	}
107 
108 	/* Compression Function CF */
109 
110 	A = ctx->sm3_state[0];
111 	B = ctx->sm3_state[1];
112 	C = ctx->sm3_state[2];
113 	D = ctx->sm3_state[3];
114 	E = ctx->sm3_state[4];
115 	F = ctx->sm3_state[5];
116 	G = ctx->sm3_state[6];
117 	H = ctx->sm3_state[7];
118 
119 	/*
120 	 * Note: in a previous version of the code, we had two loops for j from
121 	 * 0 to 15 and then from 16 to 63 with SM3_ROTL(SM3_Tj_low, (j & 0x1F))
122 	 * inside but clang-12 was smart enough to detect cases where SM3_ROTL
123 	 * macro is useless. On the other side, clang address sanitizer does not
124 	 * allow to remove the check for too high shift values in the macro
125 	 * itself. Creating 3 distinct loops instead of 2 to remove the & 0x1F
126 	 * is sufficient to satisfy everyone.
127 	 */
128 
129 	for (j = 0; j < 16; j++) {
130 		SS1 = SM3_ROTL(SM3_ROTL(A, 12) + E + SM3_ROTL(SM3_Tj_low, j),7);
131 		SS2 = SS1 ^ SM3_ROTL(A, 12);
132 		TT1 = FF_j_low(A, B, C) + D + SS2 + W[j + 68];
133 		TT2 = GG_j_low(E, F, G) + H + SS1 + W[j];
134 		D = C;
135 		C = SM3_ROTL(B, 9);
136 		B = A;
137 		A = TT1;
138 		H = G;
139 		G = SM3_ROTL(F, 19);
140 		F = E;
141 		E = SM3_P_0(TT2);
142 	}
143 
144 	for (j = 16; j < 32; j++) {
145 		SS1 = SM3_ROTL(SM3_ROTL(A, 12) + E + SM3_ROTL(SM3_Tj_high, j), 7);
146 		SS2 = SS1 ^ SM3_ROTL(A, 12);
147 		TT1 = FF_j_high(A, B, C) + D + SS2 + W[j + 68];
148 		TT2 = GG_j_high(E, F, G) + H + SS1 + W[j];
149 		D = C;
150 		C = SM3_ROTL(B, 9);
151 		B = A;
152 		A = TT1;
153 		H = G;
154 		G = SM3_ROTL(F, 19);
155 		F = E;
156 		E = SM3_P_0(TT2);
157 	}
158 
159 	for (j = 32; j < 64; j++) {
160 		SS1 = SM3_ROTL(SM3_ROTL(A, 12) + E + SM3_ROTL(SM3_Tj_high, (j - 32)), 7);
161 		SS2 = SS1 ^ SM3_ROTL(A, 12);
162 		TT1 = FF_j_high(A, B, C) + D + SS2 + W[j + 68];
163 		TT2 = GG_j_high(E, F, G) + H + SS1 + W[j];
164 		D = C;
165 		C = SM3_ROTL(B, 9);
166 		B = A;
167 		A = TT1;
168 		H = G;
169 		G = SM3_ROTL(F, 19);
170 		F = E;
171 		E = SM3_P_0(TT2);
172 	}
173 
174 	ctx->sm3_state[0] ^= A;
175 	ctx->sm3_state[1] ^= B;
176 	ctx->sm3_state[2] ^= C;
177 	ctx->sm3_state[3] ^= D;
178 	ctx->sm3_state[4] ^= E;
179 	ctx->sm3_state[5] ^= F;
180 	ctx->sm3_state[6] ^= G;
181 	ctx->sm3_state[7] ^= H;
182 
183 	ret = 0;
184 
185 	return ret;
186 }
187 
188 /* Init hash function. Initialize state to SM3 defined IV. */
189 int sm3_init(sm3_context *ctx)
190 {
191 	int ret;
192 
193 	MUST_HAVE(ctx != NULL, ret, err);
194 
195 	ctx->sm3_total = 0;
196 	ctx->sm3_state[0] = 0x7380166F;
197 	ctx->sm3_state[1] = 0x4914B2B9;
198 	ctx->sm3_state[2] = 0x172442D7;
199 	ctx->sm3_state[3] = 0xDA8A0600;
200 	ctx->sm3_state[4] = 0xA96F30BC;
201 	ctx->sm3_state[5] = 0x163138AA;
202 	ctx->sm3_state[6] = 0xE38DEE4D;
203 	ctx->sm3_state[7] = 0xB0FB0E4E;
204 
205 	/* Tell that we are initialized */
206 	ctx->magic = SM3_HASH_MAGIC;
207 
208 	ret = 0;
209 
210 err:
211 	return ret;
212 }
213 
214 /* Update hash function */
215 int sm3_update(sm3_context *ctx, const u8 *input, u32 ilen)
216 {
217 	const u8 *data_ptr = input;
218 	u32 remain_ilen = ilen;
219 	u16 fill;
220 	u8 left;
221 	int ret;
222 
223 	MUST_HAVE((input != NULL) || (ilen == 0), ret, err);
224 	SM3_HASH_CHECK_INITIALIZED(ctx, ret, err);
225 
226 	/* Nothing to process, return */
227 	if (ilen == 0) {
228 		ret = 0;
229 		goto err;
230 	}
231 
232 	/* Get what's left in our local buffer */
233 	left = (ctx->sm3_total & 0x3F);
234 	fill = (u16)(SM3_BLOCK_SIZE - left);
235 
236 	ctx->sm3_total += ilen;
237 
238 	if ((left > 0) && (remain_ilen >= fill)) {
239 		/* Copy data at the end of the buffer */
240 		ret = local_memcpy(ctx->sm3_buffer + left, data_ptr, fill); EG(ret, err);
241 		ret = sm3_process(ctx, ctx->sm3_buffer); EG(ret, err);
242 		data_ptr += fill;
243 		remain_ilen -= fill;
244 		left = 0;
245 	}
246 
247 	while (remain_ilen >= SM3_BLOCK_SIZE) {
248 		ret = sm3_process(ctx, data_ptr); EG(ret, err);
249 		data_ptr += SM3_BLOCK_SIZE;
250 		remain_ilen -= SM3_BLOCK_SIZE;
251 	}
252 
253 	if (remain_ilen > 0) {
254 		ret = local_memcpy(ctx->sm3_buffer + left, data_ptr, remain_ilen); EG(ret, err);
255 	}
256 
257 	ret = 0;
258 
259 err:
260 	return ret;
261 }
262 
263 /* Finalize */
264 int sm3_final(sm3_context *ctx, u8 output[SM3_DIGEST_SIZE])
265 {
266 	unsigned int block_present = 0;
267 	u8 last_padded_block[2 * SM3_BLOCK_SIZE];
268 	int ret;
269 
270 	MUST_HAVE((output != NULL), ret, err);
271 	SM3_HASH_CHECK_INITIALIZED(ctx, ret, err);
272 
273 	/* Fill in our last block with zeroes */
274 	ret = local_memset(last_padded_block, 0, sizeof(last_padded_block)); EG(ret, err);
275 
276 	/* This is our final step, so we proceed with the padding */
277 	block_present = (ctx->sm3_total % SM3_BLOCK_SIZE);
278 	if (block_present != 0) {
279 		/* Copy what's left in our temporary context buffer */
280 		ret = local_memcpy(last_padded_block, ctx->sm3_buffer,
281 			     block_present); EG(ret, err);
282 	}
283 
284 	/* Put the 0x80 byte, beginning of padding  */
285 	last_padded_block[block_present] = 0x80;
286 
287 	/* Handle possible additional block */
288 	if (block_present > (SM3_BLOCK_SIZE - 1 - sizeof(u64))) {
289 		/* We need an additional block */
290 		PUT_UINT64_BE(8 * ctx->sm3_total, last_padded_block,
291 			      (2 * SM3_BLOCK_SIZE) - sizeof(u64));
292 		ret = sm3_process(ctx, last_padded_block); EG(ret, err);
293 		ret = sm3_process(ctx, last_padded_block + SM3_BLOCK_SIZE); EG(ret, err);
294 	} else {
295 		/* We do not need an additional block */
296 		PUT_UINT64_BE(8 * ctx->sm3_total, last_padded_block,
297 			      SM3_BLOCK_SIZE - sizeof(u64));
298 		ret = sm3_process(ctx, last_padded_block); EG(ret, err);
299 	}
300 
301 	/* Output the hash result */
302 	PUT_UINT32_BE(ctx->sm3_state[0], output, 0);
303 	PUT_UINT32_BE(ctx->sm3_state[1], output, 4);
304 	PUT_UINT32_BE(ctx->sm3_state[2], output, 8);
305 	PUT_UINT32_BE(ctx->sm3_state[3], output, 12);
306 	PUT_UINT32_BE(ctx->sm3_state[4], output, 16);
307 	PUT_UINT32_BE(ctx->sm3_state[5], output, 20);
308 	PUT_UINT32_BE(ctx->sm3_state[6], output, 24);
309 	PUT_UINT32_BE(ctx->sm3_state[7], output, 28);
310 
311 	/* Tell that we are uninitialized */
312 	ctx->magic = WORD(0);
313 
314 	ret = 0;
315 
316 err:
317 	return ret;
318 }
319 
320 int sm3_scattered(const u8 **inputs, const u32 *ilens,
321 		  u8 output[SM3_DIGEST_SIZE])
322 {
323 	sm3_context ctx;
324 	int pos = 0, ret;
325 
326 	MUST_HAVE((inputs != NULL) && (ilens != NULL) && (output != NULL), ret, err);
327 
328 	ret = sm3_init(&ctx); EG(ret, err);
329 
330 	while (inputs[pos] != NULL) {
331 		ret = sm3_update(&ctx, inputs[pos], ilens[pos]); EG(ret, err);
332 		pos += 1;
333 	}
334 
335 	ret = sm3_final(&ctx, output);
336 
337 err:
338 	return ret;
339 }
340 
341 int sm3(const u8 *input, u32 ilen, u8 output[SM3_DIGEST_SIZE])
342 {
343 	sm3_context ctx;
344 	int ret;
345 
346 	ret = sm3_init(&ctx); EG(ret, err);
347 	ret = sm3_update(&ctx, input, ilen); EG(ret, err);
348 	ret = sm3_final(&ctx, output);
349 
350 err:
351 	return ret;
352 }
353 
354 #else /* WITH_HASH_SM3 */
355 
/*
 * Placeholder declaration used when SM3 support is compiled out, so that the
 * translation unit is not empty (ISO C warns about empty translation units).
 */
typedef signed int dummy;
360 #endif /* WITH_HASH_SM3 */
361