xref: /freebsd/crypto/openssl/crypto/ml_dsa/ml_dsa_vector.h (revision f25b8c9fb4f58cf61adb47d7570abe7caa6d385d)
1 /*
2  * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include <assert.h>
11 #include "ml_dsa_poly.h"
12 
13 struct vector_st {
14     POLY *poly;
15     size_t num_poly;
16 };
17 
18 /**
19  * @brief Initialize a Vector object.
20  *
21  * @param v The vector to initialize.
22  * @param polys Preallocated storage for an array of Polynomials blocks. |v|
23  *              does not own/free this.
24  * @param num_polys The number of |polys| blocks (k or l)
25  */
vector_init(VECTOR * v,POLY * polys,size_t num_polys)26 static ossl_inline ossl_unused void vector_init(VECTOR *v, POLY *polys, size_t num_polys)
27 {
28     v->poly = polys;
29     v->num_poly = num_polys;
30 }
31 
vector_alloc(VECTOR * v,size_t num_polys)32 static ossl_inline ossl_unused int vector_alloc(VECTOR *v, size_t num_polys)
33 {
34     v->poly = OPENSSL_malloc(num_polys * sizeof(POLY));
35     if (v->poly == NULL)
36         return 0;
37     v->num_poly = num_polys;
38     return 1;
39 }
40 
vector_free(VECTOR * v)41 static ossl_inline ossl_unused void vector_free(VECTOR *v)
42 {
43     OPENSSL_free(v->poly);
44     v->poly = NULL;
45     v->num_poly = 0;
46 }
47 
48 /* @brief zeroize a vectors polynomial coefficients */
vector_zero(VECTOR * va)49 static ossl_inline ossl_unused void vector_zero(VECTOR *va)
50 {
51     if (va->poly != NULL)
52         memset(va->poly, 0, va->num_poly * sizeof(va->poly[0]));
53 }
54 
55 /*
56  * @brief copy a vector
57  * The assumption is that |dst| has already been initialized
58  */
59 static ossl_inline ossl_unused void
vector_copy(VECTOR * dst,const VECTOR * src)60 vector_copy(VECTOR *dst, const VECTOR *src)
61 {
62     assert(dst->num_poly == src->num_poly);
63     memcpy(dst->poly, src->poly, src->num_poly * sizeof(src->poly[0]));
64 }
65 
66 /* @brief return 1 if 2 vectors are equal, or 0 otherwise */
67 static ossl_inline ossl_unused int
vector_equal(const VECTOR * a,const VECTOR * b)68 vector_equal(const VECTOR *a, const VECTOR *b)
69 {
70     size_t i;
71 
72     if (a->num_poly != b->num_poly)
73         return 0;
74     for (i = 0; i < a->num_poly; ++i) {
75         if (!poly_equal(a->poly + i, b->poly + i))
76             return 0;
77     }
78     return 1;
79 }
80 
81 /* @brief add 2 vectors */
82 static ossl_inline ossl_unused void
vector_add(const VECTOR * lhs,const VECTOR * rhs,VECTOR * out)83 vector_add(const VECTOR *lhs, const VECTOR *rhs, VECTOR *out)
84 {
85     size_t i;
86 
87     for (i = 0; i < lhs->num_poly; i++)
88         poly_add(lhs->poly + i, rhs->poly + i, out->poly + i);
89 }
90 
91 /* @brief subtract 2 vectors */
92 static ossl_inline ossl_unused void
vector_sub(const VECTOR * lhs,const VECTOR * rhs,VECTOR * out)93 vector_sub(const VECTOR *lhs, const VECTOR *rhs, VECTOR *out)
94 {
95     size_t i;
96 
97     for (i = 0; i < lhs->num_poly; i++)
98         poly_sub(lhs->poly + i, rhs->poly + i, out->poly + i);
99 }
100 
101 /* @brief convert a vector in place into NTT form */
102 static ossl_inline ossl_unused void
vector_ntt(VECTOR * va)103 vector_ntt(VECTOR *va)
104 {
105     size_t i;
106 
107     for (i = 0; i < va->num_poly; i++)
108         ossl_ml_dsa_poly_ntt(va->poly + i);
109 }
110 
111 /* @brief convert a vector in place into inverse NTT form */
112 static ossl_inline ossl_unused void
vector_ntt_inverse(VECTOR * va)113 vector_ntt_inverse(VECTOR *va)
114 {
115     size_t i;
116 
117     for (i = 0; i < va->num_poly; i++)
118         ossl_ml_dsa_poly_ntt_inverse(va->poly + i);
119 }
120 
121 /* @brief multiply a vector by a SCALAR polynomial */
122 static ossl_inline ossl_unused void
vector_mult_scalar(const VECTOR * lhs,const POLY * rhs,VECTOR * out)123 vector_mult_scalar(const VECTOR *lhs, const POLY *rhs, VECTOR *out)
124 {
125     size_t i;
126 
127     for (i = 0; i < lhs->num_poly; i++)
128         ossl_ml_dsa_poly_ntt_mult(lhs->poly + i, rhs, out->poly + i);
129 }
130 
131 static ossl_inline ossl_unused int
vector_expand_S(EVP_MD_CTX * h_ctx,const EVP_MD * md,int eta,const uint8_t * seed,VECTOR * s1,VECTOR * s2)132 vector_expand_S(EVP_MD_CTX *h_ctx, const EVP_MD *md, int eta,
133     const uint8_t *seed, VECTOR *s1, VECTOR *s2)
134 {
135     return ossl_ml_dsa_vector_expand_S(h_ctx, md, eta, seed, s1, s2);
136 }
137 
138 static ossl_inline ossl_unused void
vector_expand_mask(VECTOR * out,const uint8_t * rho_prime,size_t rho_prime_len,uint32_t kappa,uint32_t gamma1,EVP_MD_CTX * h_ctx,const EVP_MD * md)139 vector_expand_mask(VECTOR *out, const uint8_t *rho_prime, size_t rho_prime_len,
140     uint32_t kappa, uint32_t gamma1,
141     EVP_MD_CTX *h_ctx, const EVP_MD *md)
142 {
143     size_t i;
144     uint8_t derived_seed[ML_DSA_RHO_PRIME_BYTES + 2];
145 
146     memcpy(derived_seed, rho_prime, ML_DSA_RHO_PRIME_BYTES);
147 
148     for (i = 0; i < out->num_poly; i++) {
149         size_t index = kappa + i;
150 
151         derived_seed[ML_DSA_RHO_PRIME_BYTES] = index & 0xFF;
152         derived_seed[ML_DSA_RHO_PRIME_BYTES + 1] = (index >> 8) & 0xFF;
153         poly_expand_mask(out->poly + i, derived_seed, sizeof(derived_seed),
154             gamma1, h_ctx, md);
155     }
156 }
157 
158 /* Scale back previously rounded value */
159 static ossl_inline ossl_unused void
vector_scale_power2_round_ntt(const VECTOR * in,VECTOR * out)160 vector_scale_power2_round_ntt(const VECTOR *in, VECTOR *out)
161 {
162     size_t i;
163 
164     for (i = 0; i < in->num_poly; i++)
165         poly_scale_power2_round(in->poly + i, out->poly + i);
166     vector_ntt(out);
167 }
168 
169 /*
170  * @brief Decompose all polynomial coefficients of a vector into (t1, t0) such
171  * that coeff[i] == t1[i] * 2^13 + t0[i] mod q.
172  * See FIPS 204, Algorithm 35, Power2Round()
173  */
174 static ossl_inline ossl_unused void
vector_power2_round(const VECTOR * t,VECTOR * t1,VECTOR * t0)175 vector_power2_round(const VECTOR *t, VECTOR *t1, VECTOR *t0)
176 {
177     size_t i;
178 
179     for (i = 0; i < t->num_poly; i++)
180         poly_power2_round(t->poly + i, t1->poly + i, t0->poly + i);
181 }
182 
183 static ossl_inline ossl_unused void
vector_high_bits(const VECTOR * in,uint32_t gamma2,VECTOR * out)184 vector_high_bits(const VECTOR *in, uint32_t gamma2, VECTOR *out)
185 {
186     size_t i;
187 
188     for (i = 0; i < out->num_poly; i++)
189         poly_high_bits(in->poly + i, gamma2, out->poly + i);
190 }
191 
192 static ossl_inline ossl_unused void
vector_low_bits(const VECTOR * in,uint32_t gamma2,VECTOR * out)193 vector_low_bits(const VECTOR *in, uint32_t gamma2, VECTOR *out)
194 {
195     size_t i;
196 
197     for (i = 0; i < out->num_poly; i++)
198         poly_low_bits(in->poly + i, gamma2, out->poly + i);
199 }
200 
201 static ossl_inline ossl_unused uint32_t
vector_max(const VECTOR * v)202 vector_max(const VECTOR *v)
203 {
204     size_t i;
205     uint32_t mx = 0;
206 
207     for (i = 0; i < v->num_poly; i++)
208         poly_max(v->poly + i, &mx);
209     return mx;
210 }
211 
212 static ossl_inline ossl_unused uint32_t
vector_max_signed(const VECTOR * v)213 vector_max_signed(const VECTOR *v)
214 {
215     size_t i;
216     uint32_t mx = 0;
217 
218     for (i = 0; i < v->num_poly; i++)
219         poly_max_signed(v->poly + i, &mx);
220     return mx;
221 }
222 
223 static ossl_inline ossl_unused size_t
vector_count_ones(const VECTOR * v)224 vector_count_ones(const VECTOR *v)
225 {
226     int j;
227     size_t i, count = 0;
228 
229     for (i = 0; i < v->num_poly; i++)
230         for (j = 0; j < ML_DSA_NUM_POLY_COEFFICIENTS; j++)
231             count += v->poly[i].coeff[j];
232     return count;
233 }
234 
235 static ossl_inline ossl_unused void
vector_make_hint(const VECTOR * ct0,const VECTOR * cs2,const VECTOR * w,uint32_t gamma2,VECTOR * out)236 vector_make_hint(const VECTOR *ct0, const VECTOR *cs2, const VECTOR *w,
237     uint32_t gamma2, VECTOR *out)
238 {
239     size_t i;
240 
241     for (i = 0; i < out->num_poly; i++)
242         poly_make_hint(ct0->poly + i, cs2->poly + i, w->poly + i, gamma2,
243             out->poly + i);
244 }
245 
246 static ossl_inline ossl_unused void
vector_use_hint(const VECTOR * h,const VECTOR * r,uint32_t gamma2,VECTOR * out)247 vector_use_hint(const VECTOR *h, const VECTOR *r, uint32_t gamma2, VECTOR *out)
248 {
249     size_t i;
250 
251     for (i = 0; i < out->num_poly; i++)
252         poly_use_hint(h->poly + i, r->poly + i, gamma2, out->poly + i);
253 }
254