xref: /freebsd/crypto/openssl/crypto/ml_dsa/ml_dsa_poly.h (revision e7be843b4a162e68651d3911f0357ed464915629)
1 /*
2  * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 #include <openssl/crypto.h>
10 
#define ML_DSA_NUM_POLY_COEFFICIENTS 256

/*
 * Fixed-size polynomial used throughout ML-DSA: exactly
 * ML_DSA_NUM_POLY_COEFFICIENTS (256) coefficients, each stored in an
 * unsigned 32-bit word.
 */
struct poly_st {
    uint32_t coeff[ML_DSA_NUM_POLY_COEFFICIENTS];
};
17 
18 static ossl_inline ossl_unused void
poly_zero(POLY * p)19 poly_zero(POLY *p)
20 {
21     memset(p->coeff, 0, sizeof(*p));
22 }
23 
24 /**
25  * @brief Polynomial addition.
26  *
27  * @param lhs A polynomial with coefficients in the range (0..q-1)
28  * @param rhs A polynomial with coefficients in the range (0..q-1) to add
29  *            to the 'lhs'.
30  * @param out The returned addition result with the coefficients all in the
31  *            range 0..q-1
32  */
33 static ossl_inline ossl_unused void
poly_add(const POLY * lhs,const POLY * rhs,POLY * out)34 poly_add(const POLY *lhs, const POLY *rhs, POLY *out)
35 {
36     int i;
37 
38     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
39         out->coeff[i] = reduce_once(lhs->coeff[i] + rhs->coeff[i]);
40 }
41 
42 /**
43  * @brief Polynomial subtraction.
44  *
45  * @param lhs A polynomial with coefficients in the range (0..q-1)
46  * @param rhs A polynomial with coefficients in the range (0..q-1) to subtract
47  *            from the 'lhs'.
48  * @param out The returned subtraction result with the coefficients all in the
49  *            range 0..q-1
50  */
51 static ossl_inline ossl_unused void
poly_sub(const POLY * lhs,const POLY * rhs,POLY * out)52 poly_sub(const POLY *lhs, const POLY *rhs, POLY *out)
53 {
54     int i;
55 
56     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
57         out->coeff[i] = mod_sub(lhs->coeff[i], rhs->coeff[i]);
58 }
59 
60 /* @returns 1 if the polynomials are equal, or 0 otherwise */
61 static ossl_inline ossl_unused int
poly_equal(const POLY * a,const POLY * b)62 poly_equal(const POLY *a, const POLY *b)
63 {
64     return CRYPTO_memcmp(a, b, sizeof(*a)) == 0;
65 }
66 
67 static ossl_inline ossl_unused void
poly_ntt(POLY * p)68 poly_ntt(POLY *p)
69 {
70     ossl_ml_dsa_poly_ntt(p);
71 }
72 
73 static ossl_inline ossl_unused int
poly_sample_in_ball_ntt(POLY * out,const uint8_t * seed,int seed_len,EVP_MD_CTX * h_ctx,const EVP_MD * md,uint32_t tau)74 poly_sample_in_ball_ntt(POLY *out, const uint8_t *seed, int seed_len,
75                         EVP_MD_CTX *h_ctx, const EVP_MD *md, uint32_t tau)
76 {
77     if (!ossl_ml_dsa_poly_sample_in_ball(out, seed, seed_len, h_ctx, md, tau))
78         return 0;
79     poly_ntt(out);
80     return 1;
81 }
82 
83 static ossl_inline ossl_unused int
poly_expand_mask(POLY * out,const uint8_t * seed,size_t seed_len,uint32_t gamma1,EVP_MD_CTX * h_ctx,const EVP_MD * md)84 poly_expand_mask(POLY *out, const uint8_t *seed, size_t seed_len,
85                  uint32_t gamma1, EVP_MD_CTX *h_ctx, const EVP_MD *md)
86 {
87     return ossl_ml_dsa_poly_expand_mask(out, seed, seed_len, gamma1, h_ctx, md);
88 }
89 
90 /**
91  * @brief Decompose the coefficients of a polynomial into (r1, r0) such that
92  * coeff[i] == t1[i] * 2^13 + t0[i] mod q
93  * See FIPS 204, Algorithm 35, Power2Round()
94  *
95  * @param t A polynomial containing coefficients in the range 0..q-1
96  * @param t1 The returned polynomial containing coefficients that represent
97  *           the top 10 MSB of each coefficient in t (i.e each ranging from 0..1023)
98  * @param t0 The remainder coefficients of t in the range (0..4096 or q-4095..q-1)
99  *           Each t0 coefficient has an effective range of 8192 (i.e. 13 bits).
100  */
101 static ossl_inline ossl_unused void
poly_power2_round(const POLY * t,POLY * t1,POLY * t0)102 poly_power2_round(const POLY *t, POLY *t1, POLY *t0)
103 {
104     int i;
105 
106     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
107         ossl_ml_dsa_key_compress_power2_round(t->coeff[i],
108                                               t1->coeff + i, t0->coeff + i);
109 }
110 
111 static ossl_inline ossl_unused void
poly_scale_power2_round(POLY * in,POLY * out)112 poly_scale_power2_round(POLY *in, POLY *out)
113 {
114     int i;
115 
116     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
117         out->coeff[i] = (in->coeff[i] << ML_DSA_D_BITS);
118 }
119 
120 static ossl_inline ossl_unused void
poly_high_bits(const POLY * in,uint32_t gamma2,POLY * out)121 poly_high_bits(const POLY *in, uint32_t gamma2, POLY *out)
122 {
123     int i;
124 
125     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
126         out->coeff[i] = ossl_ml_dsa_key_compress_high_bits(in->coeff[i], gamma2);
127 }
128 
129 static ossl_inline ossl_unused void
poly_low_bits(const POLY * in,uint32_t gamma2,POLY * out)130 poly_low_bits(const POLY *in, uint32_t gamma2, POLY *out)
131 {
132     int i;
133 
134     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
135         out->coeff[i] = ossl_ml_dsa_key_compress_low_bits(in->coeff[i], gamma2);
136 }
137 
138 static ossl_inline ossl_unused void
poly_make_hint(const POLY * ct0,const POLY * cs2,const POLY * w,uint32_t gamma2,POLY * out)139 poly_make_hint(const POLY *ct0, const POLY *cs2, const POLY *w, uint32_t gamma2,
140                POLY *out)
141 {
142     int i;
143 
144     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
145         out->coeff[i] = ossl_ml_dsa_key_compress_make_hint(ct0->coeff[i],
146                                                            cs2->coeff[i],
147                                                            gamma2, w->coeff[i]);
148 }
149 
150 static ossl_inline ossl_unused void
poly_use_hint(const POLY * h,const POLY * r,uint32_t gamma2,POLY * out)151 poly_use_hint(const POLY *h, const POLY *r, uint32_t gamma2, POLY *out)
152 {
153     int i;
154 
155     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
156         out->coeff[i] = ossl_ml_dsa_key_compress_use_hint(h->coeff[i],
157                                                           r->coeff[i], gamma2);
158 }
159 
160 static ossl_inline ossl_unused void
poly_max(const POLY * p,uint32_t * mx)161 poly_max(const POLY *p, uint32_t *mx)
162 {
163     int i;
164 
165     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++) {
166         uint32_t c = p->coeff[i];
167         uint32_t abs = abs_mod_prime(c);
168 
169         *mx = maximum(*mx, abs);
170     }
171 }
172 
173 static ossl_inline ossl_unused void
poly_max_signed(const POLY * p,uint32_t * mx)174 poly_max_signed(const POLY *p, uint32_t *mx)
175 {
176     int i;
177 
178     for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++) {
179         uint32_t c = p->coeff[i];
180         uint32_t abs = abs_signed(c);
181 
182         *mx = maximum(*mx, abs);
183     }
184 }
185