xref: /freebsd/crypto/openssl/crypto/slh_dsa/slh_hypertree.c (revision e7be843b4a162e68651d3911f0357ed464915629)
1 /*
2  * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include <string.h>
11 #include "slh_dsa_local.h"
12 #include "slh_dsa_key.h"
13 
14 /**
15  * @brief Generate a Hypertree Signature
16  * See FIPS 205 Section 7.1 Algorithm 12
17  *
18  * This writes |d| XMSS signatures i.e. ((|h| + |d| * |len|) * |n|)
19  * where the first signature uses the XMSS key at the lowest layer, and the last
20  * signature uses the XMSS key at the top layer.
21  *
22  * @param ctx Contains SLH_DSA algorithm functions and constants.
23  * @param msg A message of size |n|.
24  * @param sk_seed The private key seed of size |n|
25  * @param pk_seed The public key seed of size |n|
26  * @param tree_id Index of the XMSS tree that will sign the message
27  * @param leaf_id Index of the WOTS+ key within the XMSS tree that will sign the message
28  * @param sig_wpkt A WPACKET object to write the Hypertree Signature to.
29  * @returns 1 on success, or 0 on error.
30  */
ossl_slh_ht_sign(SLH_DSA_HASH_CTX * ctx,const uint8_t * msg,const uint8_t * sk_seed,const uint8_t * pk_seed,uint64_t tree_id,uint32_t leaf_id,WPACKET * sig_wpkt)31 int ossl_slh_ht_sign(SLH_DSA_HASH_CTX *ctx,
32                      const uint8_t *msg, const uint8_t *sk_seed,
33                      const uint8_t *pk_seed,
34                      uint64_t tree_id, uint32_t leaf_id, WPACKET *sig_wpkt)
35 {
36     const SLH_DSA_KEY *key = ctx->key;
37     SLH_ADRS_FUNC_DECLARE(key, adrsf);
38     SLH_ADRS_DECLARE(adrs);
39     uint8_t root[SLH_MAX_N];
40     uint32_t layer, mask;
41     const SLH_DSA_PARAMS *params = key->params;
42     uint32_t n = params->n;
43     uint32_t d = params->d;
44     uint32_t hm = params->hm;
45     uint8_t *psig;
46     PACKET rpkt, *xmss_sig_rpkt = &rpkt;
47 
48     mask = (1 << hm) - 1; /* A mod 2^h = A & ((2^h - 1))) */
49 
50     adrsf->zero(adrs);
51     /*
52      * For each XMSS tree there is a current leaf node that is used for signing.
53      * The first iteration of the loop signs the input message using the bottom
54      * tree. Subsequent passes use the parent trees leaf node to sign the current
55      * trees public key.
56      * Each node in an XMSS tree has a sibling (except for the root node),
57      * so starting at the leaf node it traverses up the tree calculating
58      * hashes for all the siblings in the path to the root node,
59      * which are then stored in the XMSS signature. The verify then just needs
60      * the hash of the leaf node which is can then combine with the signature
61      * path hashes to work all the way up to the root node to calculate the
62      * public key.
63      */
64     memcpy(root, msg, n);
65 
66     for (layer = 0; layer < d; ++layer) {
67         /* type = SLH_ADRS_TYPE_WOTS_HASH */
68         adrsf->set_layer_address(adrs, layer);
69         adrsf->set_tree_address(adrs, tree_id);
70         psig = WPACKET_get_curr(sig_wpkt);
71         if (!ossl_slh_xmss_sign(ctx, root, sk_seed, leaf_id, pk_seed, adrs,
72                                 sig_wpkt))
73             return 0;
74         /*
75          * On the last loop it skips getting the public key since it is not needed
76          * to calculate another signature. If this was called it should equal
77          * the PK_ROOT (i.e. the public key of the top level tree).
78          */
79         if (layer < d - 1) {
80             if (!PACKET_buf_init(xmss_sig_rpkt, psig,
81                                  WPACKET_get_curr(sig_wpkt) - psig))
82                 return 0;
83             if (!ossl_slh_xmss_pk_from_sig(ctx, leaf_id, xmss_sig_rpkt, root,
84                                            pk_seed, adrs, root, sizeof(root)))
85                 return 0;
86             leaf_id = tree_id & mask;
87             tree_id >>= hm;
88         }
89     }
90     return 1;
91 }
92 
93 /**
94  * @brief Verify a Hypertree Signature
95  * See FIPS 205 Section 7.2 Algorithm 13
96  *
97  * @param ctx Contains SLH_DSA algorithm functions and constants.
98  * @param msg A message of size |n| bytes
99  * @param sig A HT signature of size (|h| + |d| * |len|) * |n| bytes
100  * @param pk_seed SLH_DSA public key seed of size |n|
101  * @param tree_id Index of the XMSS tree that signed the message
102  * @param leaf_id Index of the WOTS+ key within the XMSS tree that signed the message
103  * @param pk_root The known Hypertree public key of size |n|
104  *
105  * @returns 1 if the computed XMSS public key matches pk_root, or 0 otherwise.
106  */
ossl_slh_ht_verify(SLH_DSA_HASH_CTX * ctx,const uint8_t * msg,PACKET * sig_pkt,const uint8_t * pk_seed,uint64_t tree_id,uint32_t leaf_id,const uint8_t * pk_root)107 int ossl_slh_ht_verify(SLH_DSA_HASH_CTX *ctx, const uint8_t *msg, PACKET *sig_pkt,
108                        const uint8_t *pk_seed, uint64_t tree_id, uint32_t leaf_id,
109                        const uint8_t *pk_root)
110 {
111     const SLH_DSA_KEY *key = ctx->key;
112     SLH_ADRS_FUNC_DECLARE(key, adrsf);
113     SLH_ADRS_DECLARE(adrs);
114     uint8_t node[SLH_MAX_N];
115     const SLH_DSA_PARAMS *params = key->params;
116     uint32_t tree_height = params->hm;
117     uint32_t n = params->n;
118     uint32_t d = params->d;
119     uint32_t mask = (1 << tree_height) - 1;
120     uint32_t layer;
121 
122     adrsf->zero(adrs);
123     memcpy(node, msg, n);
124 
125     for (layer = 0; layer < d; ++layer) {
126         adrsf->set_layer_address(adrs, layer);
127         adrsf->set_tree_address(adrs, tree_id);
128         if (!ossl_slh_xmss_pk_from_sig(ctx, leaf_id, sig_pkt, node,
129                                        pk_seed, adrs, node, sizeof(node)))
130             return 0;
131         leaf_id = tree_id & mask;
132         tree_id >>= tree_height;
133     }
134     return (memcmp(node, pk_root, n) == 0);
135 }
136