xref: /freebsd/crypto/openssl/crypto/ml_dsa/ml_dsa_local.h (revision e7be843b4a162e68651d3911f0357ed464915629)
1 /*
2  * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #ifndef OSSL_CRYPTO_ML_DSA_LOCAL_H
11 # define OSSL_CRYPTO_ML_DSA_LOCAL_H
12 
13 # include "crypto/ml_dsa.h"
14 # include "internal/constant_time.h"
15 # include "internal/packet.h"
16 
/*
 * Parameters shared by all three parameter sets:
 * ML-DSA-44, ML-DSA-65 and ML-DSA-87.
 */

/* The modulus q = 2^23 - 2^13 + 1, a 23-bit value. */
# define ML_DSA_Q 8380417
# define ML_DSA_Q_MINUS1_DIV2 ((ML_DSA_Q - 1) / 2)
# define ML_DSA_Q_BITS 23

/* Constants for Montgomery arithmetic with R = 2^32. */
# define ML_DSA_Q_INV 58728449              /* q^-1 * q == 1 (mod 2^32) */
# define ML_DSA_Q_NEG_INV 4236238847        /* (-q)^-1 mod 2^32 */
# define ML_DSA_DEGREE_INV_MONTGOMERY 41978 /* Inverse of 256 mod q, in Montgomery form */

/* Bit counts and byte lengths of the seeds and hashes. */
# define ML_DSA_D_BITS 13                   /* Bits dropped from the public vector t */
# define ML_DSA_NUM_POLY_COEFFICIENTS 256   /* Coefficients in each polynomial */
# define ML_DSA_RHO_BYTES 32                /* rho: public random seed */
# define ML_DSA_PRIV_SEED_BYTES 64          /* rho': private random seed */
# define ML_DSA_K_BYTES 32                  /* K: private random seed for signing */
# define ML_DSA_TR_BYTES 64                 /* tr: hash of the public key, used when signing */
# define ML_DSA_MU_BYTES 64                 /* mu: hash for the message representative */
# define ML_DSA_RHO_PRIME_BYTES 64          /* Size of the private random seed */

/*
 * The encoding/decoding code contains special cases that test for the
 * exact values listed below. Each group names the constant whose users must
 * be revisited if a new value is ever added to that group.
 */

/* Possible values of eta (revisit all users of ML_DSA_ETA_4 if extended). */
# define ML_DSA_ETA_4 4
# define ML_DSA_ETA_2 2

/* Possible values of gamma1 (revisit all users of ML_DSA_GAMMA1_TWO_POWER_19). */
# define ML_DSA_GAMMA1_TWO_POWER_19 (1 << 19)
# define ML_DSA_GAMMA1_TWO_POWER_17 (1 << 17)

/* Possible values of gamma2 (revisit all users of ML_DSA_GAMMA2_Q_MINUS1_DIV32). */
# define ML_DSA_GAMMA2_Q_MINUS1_DIV32 ((ML_DSA_Q - 1) / 32)
# define ML_DSA_GAMMA2_Q_MINUS1_DIV88 ((ML_DSA_Q - 1) / 88)
57 
58 typedef struct poly_st POLY;
59 typedef struct vector_st VECTOR;
60 typedef struct matrix_st MATRIX;
61 typedef struct ml_dsa_sig_st ML_DSA_SIG;
62 
63 int ossl_ml_dsa_matrix_expand_A(EVP_MD_CTX *g_ctx, const EVP_MD *md,
64                                 const uint8_t *rho, MATRIX *out);
65 int ossl_ml_dsa_vector_expand_S(EVP_MD_CTX *h_ctx, const EVP_MD *md, int eta,
66                                 const uint8_t *seed, VECTOR *s1, VECTOR *s2);
67 void ossl_ml_dsa_matrix_mult_vector(const MATRIX *matrix_kl, const VECTOR *vl,
68                                     VECTOR *vk);
69 int ossl_ml_dsa_poly_expand_mask(POLY *out, const uint8_t *seed, size_t seed_len,
70                                  uint32_t gamma1,
71                                  EVP_MD_CTX *h_ctx, const EVP_MD *md);
72 int ossl_ml_dsa_poly_sample_in_ball(POLY *out_c, const uint8_t *seed, int seed_len,
73                                     EVP_MD_CTX *h_ctx, const EVP_MD *md,
74                                     uint32_t tau);
75 
76 void ossl_ml_dsa_poly_ntt(POLY *s);
77 void ossl_ml_dsa_poly_ntt_inverse(POLY *s);
78 void ossl_ml_dsa_poly_ntt_mult(const POLY *lhs, const POLY *rhs, POLY *out);
79 
80 void ossl_ml_dsa_key_compress_power2_round(uint32_t r, uint32_t *r1, uint32_t *r0);
81 uint32_t ossl_ml_dsa_key_compress_high_bits(uint32_t r, uint32_t gamma2);
82 void ossl_ml_dsa_key_compress_decompose(uint32_t r, uint32_t gamma2,
83                                         uint32_t *r1, int32_t *r0);
84 void ossl_ml_dsa_key_compress_decompose(uint32_t r, uint32_t gamma2,
85                                         uint32_t *r1, int32_t *r0);
86 int32_t ossl_ml_dsa_key_compress_low_bits(uint32_t r, uint32_t gamma2);
87 int32_t ossl_ml_dsa_key_compress_make_hint(uint32_t ct0, uint32_t cs2,
88                                            uint32_t gamma2, uint32_t w);
89 uint32_t ossl_ml_dsa_key_compress_use_hint(uint32_t hint, uint32_t r,
90                                            uint32_t gamma2);
91 
92 int ossl_ml_dsa_pk_encode(ML_DSA_KEY *key);
93 int ossl_ml_dsa_sk_encode(ML_DSA_KEY *key);
94 
95 int ossl_ml_dsa_sig_encode(const ML_DSA_SIG *sig, const ML_DSA_PARAMS *params,
96                            uint8_t *out);
97 int ossl_ml_dsa_sig_decode(ML_DSA_SIG *sig, const uint8_t *in, size_t in_len,
98                            const ML_DSA_PARAMS *params);
99 int ossl_ml_dsa_w1_encode(const VECTOR *w1, uint32_t gamma2,
100                           uint8_t *out, size_t out_len);
101 int ossl_ml_dsa_poly_decode_expand_mask(POLY *out,
102                                         const uint8_t *in, size_t in_len,
103                                         uint32_t gamma1);
104 
105 /*
106  * @brief Reduces x mod q in constant time
107  * i.e. return x < q ? x : x - q;
108  *
109  * @param x Where x is assumed to be in the range 0 <= x < 2*q
110  * @returns the difference in the range 0..q-1
111  */
reduce_once(uint32_t x)112 static ossl_inline ossl_unused uint32_t reduce_once(uint32_t x)
113 {
114     return constant_time_select_32(constant_time_lt_32(x, ML_DSA_Q), x, x - ML_DSA_Q);
115 }
116 
117 /*
118  * @brief Calculate The positive value of (a-b) mod q in constant time.
119  *
120  * a - b mod q gives a value in the range -(q-1)..(q-1)
121  * By adding q we get a range of 1..(2q-1).
122  * Reducing this once then gives the range 0..q-1
123  *
124  * @param a The minuend assumed to be in the range 0..q-1
125  * @param b The subtracthend assumed to be in the range 0..q-1.
126  * @returns The value (q + a - b) mod q
127  */
mod_sub(uint32_t a,uint32_t b)128 static ossl_inline ossl_unused uint32_t mod_sub(uint32_t a, uint32_t b)
129 {
130     return reduce_once(ML_DSA_Q + a - b);
131 }
132 
133 /*
134  * @brief Returns the absolute value in constant time.
135  * i.e. return is_positive(x) ? x : -x;
136  */
abs_signed(uint32_t x)137 static ossl_inline ossl_unused uint32_t abs_signed(uint32_t x)
138 {
139     return constant_time_select_32(constant_time_lt_32(x, 0x80000000), x, 0u - x);
140 }
141 
142 /*
143  * @brief Returns the absolute value modulo q in constant time
144  * i.e return x > (q - 1) / 2 ? q - x : x;
145  */
abs_mod_prime(uint32_t x)146 static ossl_inline ossl_unused uint32_t abs_mod_prime(uint32_t x)
147 {
148     return constant_time_select_32(constant_time_lt_32(ML_DSA_Q_MINUS1_DIV2, x),
149                                                        ML_DSA_Q - x, x);
150 }
151 
152 /*
153  * @brief Returns the maximum of two values in constant time.
154  * i.e return x < y ? y : x;
155  */
maximum(uint32_t x,uint32_t y)156 static ossl_inline ossl_unused uint32_t maximum(uint32_t x, uint32_t y)
157 {
158     return constant_time_select_int(constant_time_lt(x, y), y, x);
159 }
160 
161 #endif /* OSSL_CRYPTO_ML_DSA_LOCAL_H */
162