xref: /freebsd/crypto/openssl/crypto/slh_dsa/slh_hash.c (revision e7be843b4a162e68651d3911f0357ed464915629)
1 /*
2  * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include "internal/deprecated.h" /* PKCS1_MGF1() */
11 
12 #include <string.h>
13 #include <openssl/evp.h>
14 #include <openssl/core_names.h>
15 #include <openssl/rsa.h> /* PKCS1_MGF1() */
16 #include "slh_dsa_local.h"
17 #include "slh_dsa_key.h"
18 
19 #define MAX_DIGEST_SIZE 64 /* SHA-512 is used for security category 3 & 5 */
20 
21 static OSSL_SLH_HASHFUNC_H_MSG slh_hmsg_sha2;
22 static OSSL_SLH_HASHFUNC_PRF slh_prf_sha2;
23 static OSSL_SLH_HASHFUNC_PRF_MSG slh_prf_msg_sha2;
24 static OSSL_SLH_HASHFUNC_F slh_f_sha2;
25 static OSSL_SLH_HASHFUNC_H slh_h_sha2;
26 static OSSL_SLH_HASHFUNC_T slh_t_sha2;
27 
28 static OSSL_SLH_HASHFUNC_H_MSG slh_hmsg_shake;
29 static OSSL_SLH_HASHFUNC_PRF slh_prf_shake;
30 static OSSL_SLH_HASHFUNC_PRF_MSG slh_prf_msg_shake;
31 static OSSL_SLH_HASHFUNC_F slh_f_shake;
32 static OSSL_SLH_HASHFUNC_H slh_h_shake;
33 static OSSL_SLH_HASHFUNC_T slh_t_shake;
34 
xof_digest_3(EVP_MD_CTX * ctx,const uint8_t * in1,size_t in1_len,const uint8_t * in2,size_t in2_len,const uint8_t * in3,size_t in3_len,uint8_t * out,size_t out_len)35 static ossl_inline int xof_digest_3(EVP_MD_CTX *ctx,
36                                     const uint8_t *in1, size_t in1_len,
37                                     const uint8_t *in2, size_t in2_len,
38                                     const uint8_t *in3, size_t in3_len,
39                                     uint8_t *out, size_t out_len)
40 {
41     return (EVP_DigestInit_ex2(ctx, NULL, NULL) == 1
42             && EVP_DigestUpdate(ctx, in1, in1_len) == 1
43             && EVP_DigestUpdate(ctx, in2, in2_len) == 1
44             && EVP_DigestUpdate(ctx, in3, in3_len) == 1
45             && EVP_DigestFinalXOF(ctx, out, out_len) == 1);
46 }
47 
xof_digest_4(EVP_MD_CTX * ctx,const uint8_t * in1,size_t in1_len,const uint8_t * in2,size_t in2_len,const uint8_t * in3,size_t in3_len,const uint8_t * in4,size_t in4_len,uint8_t * out,size_t out_len)48 static ossl_inline int xof_digest_4(EVP_MD_CTX *ctx,
49                                     const uint8_t *in1, size_t in1_len,
50                                     const uint8_t *in2, size_t in2_len,
51                                     const uint8_t *in3, size_t in3_len,
52                                     const uint8_t *in4, size_t in4_len,
53                                     uint8_t *out, size_t out_len)
54 {
55     return (EVP_DigestInit_ex2(ctx, NULL, NULL) == 1
56             && EVP_DigestUpdate(ctx, in1, in1_len) == 1
57             && EVP_DigestUpdate(ctx, in2, in2_len) == 1
58             && EVP_DigestUpdate(ctx, in3, in3_len) == 1
59             && EVP_DigestUpdate(ctx, in4, in4_len) == 1
60             && EVP_DigestFinalXOF(ctx, out, out_len) == 1);
61 }
62 
63 /* See FIPS 205 Section 11.1 */
64 static int
slh_hmsg_shake(SLH_DSA_HASH_CTX * ctx,const uint8_t * r,const uint8_t * pk_seed,const uint8_t * pk_root,const uint8_t * msg,size_t msg_len,uint8_t * out,size_t out_len)65 slh_hmsg_shake(SLH_DSA_HASH_CTX *ctx, const uint8_t *r,
66                const uint8_t *pk_seed, const uint8_t *pk_root,
67                const uint8_t *msg, size_t msg_len,
68                uint8_t *out, size_t out_len)
69 {
70     const SLH_DSA_PARAMS *params = ctx->key->params;
71     size_t m = params->m;
72     size_t n = params->n;
73 
74     return xof_digest_4(ctx->md_ctx, r, n, pk_seed, n, pk_root, n,
75                         msg, msg_len, out, m);
76 }
77 
78 static int
slh_prf_shake(SLH_DSA_HASH_CTX * ctx,const uint8_t * pk_seed,const uint8_t * sk_seed,const uint8_t * adrs,uint8_t * out,size_t out_len)79 slh_prf_shake(SLH_DSA_HASH_CTX *ctx,
80               const uint8_t *pk_seed, const uint8_t *sk_seed,
81               const uint8_t *adrs, uint8_t *out, size_t out_len)
82 {
83     const SLH_DSA_PARAMS *params = ctx->key->params;
84     size_t n = params->n;
85 
86     return xof_digest_3(ctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE,
87                         sk_seed, n, out, n);
88 }
89 
90 static int
slh_prf_msg_shake(SLH_DSA_HASH_CTX * ctx,const uint8_t * sk_prf,const uint8_t * opt_rand,const uint8_t * msg,size_t msg_len,WPACKET * pkt)91 slh_prf_msg_shake(SLH_DSA_HASH_CTX *ctx, const uint8_t *sk_prf,
92                   const uint8_t *opt_rand, const uint8_t *msg, size_t msg_len,
93                   WPACKET *pkt)
94 {
95     unsigned char out[SLH_MAX_N];
96     const SLH_DSA_PARAMS *params = ctx->key->params;
97     size_t n = params->n;
98 
99     return xof_digest_3(ctx->md_ctx, sk_prf, n, opt_rand, n, msg, msg_len, out, n)
100         && WPACKET_memcpy(pkt, out, n);
101 }
102 
103 static int
slh_f_shake(SLH_DSA_HASH_CTX * ctx,const uint8_t * pk_seed,const uint8_t * adrs,const uint8_t * m1,size_t m1_len,uint8_t * out,size_t out_len)104 slh_f_shake(SLH_DSA_HASH_CTX *ctx, const uint8_t *pk_seed, const uint8_t *adrs,
105             const uint8_t *m1, size_t m1_len, uint8_t *out, size_t out_len)
106 {
107     const SLH_DSA_PARAMS *params = ctx->key->params;
108     size_t n = params->n;
109 
110     return xof_digest_3(ctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE, m1, m1_len, out, n);
111 }
112 
113 static int
slh_h_shake(SLH_DSA_HASH_CTX * ctx,const uint8_t * pk_seed,const uint8_t * adrs,const uint8_t * m1,const uint8_t * m2,uint8_t * out,size_t out_len)114 slh_h_shake(SLH_DSA_HASH_CTX *ctx, const uint8_t *pk_seed, const uint8_t *adrs,
115             const uint8_t *m1, const uint8_t *m2, uint8_t *out, size_t out_len)
116 {
117     const SLH_DSA_PARAMS *params = ctx->key->params;
118     size_t n = params->n;
119 
120     return xof_digest_4(ctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE, m1, n, m2, n, out, n);
121 }
122 
123 static int
slh_t_shake(SLH_DSA_HASH_CTX * ctx,const uint8_t * pk_seed,const uint8_t * adrs,const uint8_t * ml,size_t ml_len,uint8_t * out,size_t out_len)124 slh_t_shake(SLH_DSA_HASH_CTX *ctx, const uint8_t *pk_seed, const uint8_t *adrs,
125             const uint8_t *ml, size_t ml_len, uint8_t *out, size_t out_len)
126 {
127     const SLH_DSA_PARAMS *params = ctx->key->params;
128     size_t n = params->n;
129 
130     return xof_digest_3(ctx->md_ctx, pk_seed, n, adrs, SLH_ADRS_SIZE, ml, ml_len, out, n);
131 }
132 
133 static ossl_inline int
digest_4(EVP_MD_CTX * ctx,const uint8_t * in1,size_t in1_len,const uint8_t * in2,size_t in2_len,const uint8_t * in3,size_t in3_len,const uint8_t * in4,size_t in4_len,uint8_t * out)134 digest_4(EVP_MD_CTX *ctx,
135          const uint8_t *in1, size_t in1_len, const uint8_t *in2, size_t in2_len,
136          const uint8_t *in3, size_t in3_len, const uint8_t *in4, size_t in4_len,
137          uint8_t *out)
138 {
139     return (EVP_DigestInit_ex2(ctx, NULL, NULL) == 1
140             && EVP_DigestUpdate(ctx, in1, in1_len) == 1
141             && EVP_DigestUpdate(ctx, in2, in2_len) == 1
142             && EVP_DigestUpdate(ctx, in3, in3_len) == 1
143             && EVP_DigestUpdate(ctx, in4, in4_len) == 1
144             && EVP_DigestFinal_ex(ctx, out, NULL) == 1);
145 }
146 
147 /* FIPS 205 Section 11.2.1 and 11.2.2 */
148 
149 static int
slh_hmsg_sha2(SLH_DSA_HASH_CTX * hctx,const uint8_t * r,const uint8_t * pk_seed,const uint8_t * pk_root,const uint8_t * msg,size_t msg_len,uint8_t * out,size_t out_len)150 slh_hmsg_sha2(SLH_DSA_HASH_CTX *hctx, const uint8_t *r, const uint8_t *pk_seed,
151               const uint8_t *pk_root, const uint8_t *msg, size_t msg_len,
152               uint8_t *out, size_t out_len)
153 {
154     const SLH_DSA_PARAMS *params = hctx->key->params;
155     size_t m = params->m;
156     size_t n = params->n;
157     uint8_t seed[2 * SLH_MAX_N + MAX_DIGEST_SIZE];
158     int sz = EVP_MD_get_size(hctx->key->md_big);
159     size_t seed_len = (size_t)sz + 2 * n;
160 
161     memcpy(seed, r, n);
162     memcpy(seed + n, pk_seed, n);
163     return digest_4(hctx->md_big_ctx, r, n, pk_seed, n, pk_root, n, msg, msg_len,
164                     seed + 2 * n)
165         && (PKCS1_MGF1(out, m, seed, seed_len, hctx->key->md_big) == 0);
166 }
167 
168 static int
slh_prf_msg_sha2(SLH_DSA_HASH_CTX * hctx,const uint8_t * sk_prf,const uint8_t * opt_rand,const uint8_t * msg,size_t msg_len,WPACKET * pkt)169 slh_prf_msg_sha2(SLH_DSA_HASH_CTX *hctx,
170                  const uint8_t *sk_prf, const uint8_t *opt_rand,
171                  const uint8_t *msg, size_t msg_len, WPACKET *pkt)
172 {
173     int ret;
174     const SLH_DSA_KEY *key = hctx->key;
175     EVP_MAC_CTX *mctx = hctx->hmac_ctx;
176     const SLH_DSA_PARAMS *prms = key->params;
177     size_t n = prms->n;
178     uint8_t mac[MAX_DIGEST_SIZE] = {0};
179     OSSL_PARAM *p = NULL;
180     OSSL_PARAM params[3];
181 
182     /*
183      * Due to the way HMAC works, it is not possible to do this code early
184      * in hmac_ctx_new() since it requires a key in order to set the digest.
185      * So we do a lazy update here on the first call.
186      */
187     if (hctx->hmac_digest_used == 0) {
188         p = params;
189         /* The underlying digest to be used */
190         *p++ = OSSL_PARAM_construct_utf8_string(OSSL_MAC_PARAM_DIGEST,
191                                                 (char *)EVP_MD_get0_name(key->md_big), 0);
192         if (key->propq != NULL)
193             *p++ = OSSL_PARAM_construct_utf8_string(OSSL_MAC_PARAM_PROPERTIES,
194                                                     (char *)key->propq, 0);
195         *p = OSSL_PARAM_construct_end();
196         p = params;
197         hctx->hmac_digest_used = 1;
198     }
199 
200     ret = EVP_MAC_init(mctx, sk_prf, n, p) == 1
201         && EVP_MAC_update(mctx, opt_rand, n) == 1
202         && EVP_MAC_update(mctx, msg, msg_len) == 1
203         && EVP_MAC_final(mctx, mac, NULL, sizeof(mac)) == 1
204         && WPACKET_memcpy(pkt, mac, n); /* Truncate output to n bytes */
205     return ret;
206 }
207 
208 static ossl_inline int
do_hash(EVP_MD_CTX * ctx,size_t n,const uint8_t * pk_seed,const uint8_t * adrs,const uint8_t * m,size_t m_len,size_t b,uint8_t * out,size_t out_len)209 do_hash(EVP_MD_CTX *ctx, size_t n, const uint8_t *pk_seed, const uint8_t *adrs,
210         const uint8_t *m, size_t m_len, size_t b, uint8_t *out, size_t out_len)
211 {
212     int ret;
213     uint8_t zeros[128] = { 0 };
214     uint8_t digest[MAX_DIGEST_SIZE];
215 
216     ret = digest_4(ctx, pk_seed, n, zeros, b - n, adrs, SLH_ADRSC_SIZE,
217                    m, m_len, digest);
218     /* Truncated returned value is n = 16 bytes */
219     memcpy(out, digest, n);
220     return ret;
221 }
222 
223 static int
slh_prf_sha2(SLH_DSA_HASH_CTX * hctx,const uint8_t * pk_seed,const uint8_t * sk_seed,const uint8_t * adrs,uint8_t * out,size_t out_len)224 slh_prf_sha2(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed,
225              const uint8_t *sk_seed, const uint8_t *adrs,
226              uint8_t *out, size_t out_len)
227 {
228     size_t n = hctx->key->params->n;
229 
230     return do_hash(hctx->md_ctx, n, pk_seed, adrs, sk_seed, n,
231                    OSSL_SLH_DSA_SHA2_NUM_ZEROS_H_AND_T_BOUND1, out, out_len);
232 }
233 
234 static int
slh_f_sha2(SLH_DSA_HASH_CTX * hctx,const uint8_t * pk_seed,const uint8_t * adrs,const uint8_t * m1,size_t m1_len,uint8_t * out,size_t out_len)235 slh_f_sha2(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
236            const uint8_t *m1, size_t m1_len, uint8_t *out, size_t out_len)
237 {
238     return do_hash(hctx->md_ctx, hctx->key->params->n, pk_seed, adrs, m1, m1_len,
239                    OSSL_SLH_DSA_SHA2_NUM_ZEROS_H_AND_T_BOUND1, out, out_len);
240 }
241 
242 static int
slh_h_sha2(SLH_DSA_HASH_CTX * hctx,const uint8_t * pk_seed,const uint8_t * adrs,const uint8_t * m1,const uint8_t * m2,uint8_t * out,size_t out_len)243 slh_h_sha2(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
244            const uint8_t *m1, const uint8_t *m2, uint8_t *out, size_t out_len)
245 {
246     uint8_t m[SLH_MAX_N * 2];
247     const SLH_DSA_PARAMS *prms = hctx->key->params;
248     size_t n = prms->n;
249 
250     memcpy(m, m1, n);
251     memcpy(m + n, m2, n);
252     return do_hash(hctx->md_big_ctx, n, pk_seed, adrs, m, 2 * n,
253                    prms->sha2_h_and_t_bound, out, out_len);
254 }
255 
256 static int
slh_t_sha2(SLH_DSA_HASH_CTX * hctx,const uint8_t * pk_seed,const uint8_t * adrs,const uint8_t * ml,size_t ml_len,uint8_t * out,size_t out_len)257 slh_t_sha2(SLH_DSA_HASH_CTX *hctx, const uint8_t *pk_seed, const uint8_t *adrs,
258            const uint8_t *ml, size_t ml_len, uint8_t *out, size_t out_len)
259 {
260     const SLH_DSA_PARAMS *prms = hctx->key->params;
261 
262     return do_hash(hctx->md_big_ctx, prms->n, pk_seed, adrs, ml, ml_len,
263                    prms->sha2_h_and_t_bound, out, out_len);
264 }
265 
ossl_slh_get_hash_fn(int is_shake)266 const SLH_HASH_FUNC *ossl_slh_get_hash_fn(int is_shake)
267 {
268     static const SLH_HASH_FUNC methods[] = {
269         {
270             slh_hmsg_shake,
271             slh_prf_shake,
272             slh_prf_msg_shake,
273             slh_f_shake,
274             slh_h_shake,
275             slh_t_shake
276         },
277         {
278             slh_hmsg_sha2,
279             slh_prf_sha2,
280             slh_prf_msg_sha2,
281             slh_f_sha2,
282             slh_h_sha2,
283             slh_t_sha2
284         }
285     };
286     return &methods[is_shake ? 0 : 1];
287 }
288