1 /*
2 * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3 *
4 * Licensed under the Apache License 2.0 (the "License"). You may not use
5 * this file except in compliance with the License. You can obtain a copy
6 * in the file LICENSE in the source distribution or at
7 * https://www.openssl.org/source/license.html
8 */
9
10 #include <assert.h>
11 #include "ml_dsa_poly.h"
12
/*
 * A vector of ML-DSA polynomials.
 *
 * |poly| points to an array of |num_poly| POLY elements. Whether the vector
 * owns that array depends on how it was set up: vector_alloc() allocates it
 * (release with vector_free()), while vector_init() only borrows storage the
 * caller provides.
 */
struct vector_st {
    POLY *poly;         /* array of |num_poly| polynomials */
    size_t num_poly;    /* number of entries in |poly| */
};
17
18 /**
19 * @brief Initialize a Vector object.
20 *
21 * @param v The vector to initialize.
22 * @param polys Preallocated storage for an array of Polynomials blocks. |v|
23 * does not own/free this.
24 * @param num_polys The number of |polys| blocks (k or l)
25 */
26 static ossl_inline ossl_unused
vector_init(VECTOR * v,POLY * polys,size_t num_polys)27 void vector_init(VECTOR *v, POLY *polys, size_t num_polys)
28 {
29 v->poly = polys;
30 v->num_poly = num_polys;
31 }
32
33 static ossl_inline ossl_unused
vector_alloc(VECTOR * v,size_t num_polys)34 int vector_alloc(VECTOR *v, size_t num_polys)
35 {
36 v->poly = OPENSSL_malloc(num_polys * sizeof(POLY));
37 if (v->poly == NULL)
38 return 0;
39 v->num_poly = num_polys;
40 return 1;
41 }
42
43 static ossl_inline ossl_unused
vector_free(VECTOR * v)44 void vector_free(VECTOR *v)
45 {
46 OPENSSL_free(v->poly);
47 v->poly = NULL;
48 v->num_poly = 0;
49 }
50
51 /* @brief zeroize a vectors polynomial coefficients */
52 static ossl_inline ossl_unused
vector_zero(VECTOR * va)53 void vector_zero(VECTOR *va)
54 {
55 if (va->poly != NULL)
56 memset(va->poly, 0, va->num_poly * sizeof(va->poly[0]));
57 }
58
59 /*
60 * @brief copy a vector
61 * The assumption is that |dst| has already been initialized
62 */
63 static ossl_inline ossl_unused void
vector_copy(VECTOR * dst,const VECTOR * src)64 vector_copy(VECTOR *dst, const VECTOR *src)
65 {
66 assert(dst->num_poly == src->num_poly);
67 memcpy(dst->poly, src->poly, src->num_poly * sizeof(src->poly[0]));
68 }
69
70 /* @brief return 1 if 2 vectors are equal, or 0 otherwise */
71 static ossl_inline ossl_unused int
vector_equal(const VECTOR * a,const VECTOR * b)72 vector_equal(const VECTOR *a, const VECTOR *b)
73 {
74 size_t i;
75
76 if (a->num_poly != b->num_poly)
77 return 0;
78 for (i = 0; i < a->num_poly; ++i) {
79 if (!poly_equal(a->poly + i, b->poly + i))
80 return 0;
81 }
82 return 1;
83 }
84
85 /* @brief add 2 vectors */
86 static ossl_inline ossl_unused void
vector_add(const VECTOR * lhs,const VECTOR * rhs,VECTOR * out)87 vector_add(const VECTOR *lhs, const VECTOR *rhs, VECTOR *out)
88 {
89 size_t i;
90
91 for (i = 0; i < lhs->num_poly; i++)
92 poly_add(lhs->poly + i, rhs->poly + i, out->poly + i);
93 }
94
95 /* @brief subtract 2 vectors */
96 static ossl_inline ossl_unused void
vector_sub(const VECTOR * lhs,const VECTOR * rhs,VECTOR * out)97 vector_sub(const VECTOR *lhs, const VECTOR *rhs, VECTOR *out)
98 {
99 size_t i;
100
101 for (i = 0; i < lhs->num_poly; i++)
102 poly_sub(lhs->poly + i, rhs->poly + i, out->poly + i);
103 }
104
105 /* @brief convert a vector in place into NTT form */
106 static ossl_inline ossl_unused void
vector_ntt(VECTOR * va)107 vector_ntt(VECTOR *va)
108 {
109 size_t i;
110
111 for (i = 0; i < va->num_poly; i++)
112 ossl_ml_dsa_poly_ntt(va->poly + i);
113 }
114
115 /* @brief convert a vector in place into inverse NTT form */
116 static ossl_inline ossl_unused void
vector_ntt_inverse(VECTOR * va)117 vector_ntt_inverse(VECTOR *va)
118 {
119 size_t i;
120
121 for (i = 0; i < va->num_poly; i++)
122 ossl_ml_dsa_poly_ntt_inverse(va->poly + i);
123 }
124
125 /* @brief multiply a vector by a SCALAR polynomial */
126 static ossl_inline ossl_unused void
vector_mult_scalar(const VECTOR * lhs,const POLY * rhs,VECTOR * out)127 vector_mult_scalar(const VECTOR *lhs, const POLY *rhs, VECTOR *out)
128 {
129 size_t i;
130
131 for (i = 0; i < lhs->num_poly; i++)
132 ossl_ml_dsa_poly_ntt_mult(lhs->poly + i, rhs, out->poly + i);
133 }
134
135 static ossl_inline ossl_unused int
vector_expand_S(EVP_MD_CTX * h_ctx,const EVP_MD * md,int eta,const uint8_t * seed,VECTOR * s1,VECTOR * s2)136 vector_expand_S(EVP_MD_CTX *h_ctx, const EVP_MD *md, int eta,
137 const uint8_t *seed, VECTOR *s1, VECTOR *s2)
138 {
139 return ossl_ml_dsa_vector_expand_S(h_ctx, md, eta, seed, s1, s2);
140 }
141
142 static ossl_inline ossl_unused void
vector_expand_mask(VECTOR * out,const uint8_t * rho_prime,size_t rho_prime_len,uint32_t kappa,uint32_t gamma1,EVP_MD_CTX * h_ctx,const EVP_MD * md)143 vector_expand_mask(VECTOR *out, const uint8_t *rho_prime, size_t rho_prime_len,
144 uint32_t kappa, uint32_t gamma1,
145 EVP_MD_CTX *h_ctx, const EVP_MD *md)
146 {
147 size_t i;
148 uint8_t derived_seed[ML_DSA_RHO_PRIME_BYTES + 2];
149
150 memcpy(derived_seed, rho_prime, ML_DSA_RHO_PRIME_BYTES);
151
152 for (i = 0; i < out->num_poly; i++) {
153 size_t index = kappa + i;
154
155 derived_seed[ML_DSA_RHO_PRIME_BYTES] = index & 0xFF;
156 derived_seed[ML_DSA_RHO_PRIME_BYTES + 1] = (index >> 8) & 0xFF;
157 poly_expand_mask(out->poly + i, derived_seed, sizeof(derived_seed),
158 gamma1, h_ctx, md);
159 }
160 }
161
162 /* Scale back previously rounded value */
163 static ossl_inline ossl_unused void
vector_scale_power2_round_ntt(const VECTOR * in,VECTOR * out)164 vector_scale_power2_round_ntt(const VECTOR *in, VECTOR *out)
165 {
166 size_t i;
167
168 for (i = 0; i < in->num_poly; i++)
169 poly_scale_power2_round(in->poly + i, out->poly + i);
170 vector_ntt(out);
171 }
172
173 /*
174 * @brief Decompose all polynomial coefficients of a vector into (t1, t0) such
175 * that coeff[i] == t1[i] * 2^13 + t0[i] mod q.
176 * See FIPS 204, Algorithm 35, Power2Round()
177 */
178 static ossl_inline ossl_unused void
vector_power2_round(const VECTOR * t,VECTOR * t1,VECTOR * t0)179 vector_power2_round(const VECTOR *t, VECTOR *t1, VECTOR *t0)
180 {
181 size_t i;
182
183 for (i = 0; i < t->num_poly; i++)
184 poly_power2_round(t->poly + i, t1->poly + i, t0->poly + i);
185 }
186
187 static ossl_inline ossl_unused void
vector_high_bits(const VECTOR * in,uint32_t gamma2,VECTOR * out)188 vector_high_bits(const VECTOR *in, uint32_t gamma2, VECTOR *out)
189 {
190 size_t i;
191
192 for (i = 0; i < out->num_poly; i++)
193 poly_high_bits(in->poly + i, gamma2, out->poly + i);
194 }
195
196 static ossl_inline ossl_unused void
vector_low_bits(const VECTOR * in,uint32_t gamma2,VECTOR * out)197 vector_low_bits(const VECTOR *in, uint32_t gamma2, VECTOR *out)
198 {
199 size_t i;
200
201 for (i = 0; i < out->num_poly; i++)
202 poly_low_bits(in->poly + i, gamma2, out->poly + i);
203 }
204
205 static ossl_inline ossl_unused uint32_t
vector_max(const VECTOR * v)206 vector_max(const VECTOR *v)
207 {
208 size_t i;
209 uint32_t mx = 0;
210
211 for (i = 0; i < v->num_poly; i++)
212 poly_max(v->poly + i, &mx);
213 return mx;
214 }
215
216 static ossl_inline ossl_unused uint32_t
vector_max_signed(const VECTOR * v)217 vector_max_signed(const VECTOR *v)
218 {
219 size_t i;
220 uint32_t mx = 0;
221
222 for (i = 0; i < v->num_poly; i++)
223 poly_max_signed(v->poly + i, &mx);
224 return mx;
225 }
226
227 static ossl_inline ossl_unused size_t
vector_count_ones(const VECTOR * v)228 vector_count_ones(const VECTOR *v)
229 {
230 int j;
231 size_t i, count = 0;
232
233 for (i = 0; i < v->num_poly; i++)
234 for (j = 0; j < ML_DSA_NUM_POLY_COEFFICIENTS; j++)
235 count += v->poly[i].coeff[j];
236 return count;
237 }
238
239 static ossl_inline ossl_unused void
vector_make_hint(const VECTOR * ct0,const VECTOR * cs2,const VECTOR * w,uint32_t gamma2,VECTOR * out)240 vector_make_hint(const VECTOR *ct0, const VECTOR *cs2, const VECTOR *w,
241 uint32_t gamma2, VECTOR *out)
242 {
243 size_t i;
244
245 for (i = 0; i < out->num_poly; i++)
246 poly_make_hint(ct0->poly + i, cs2->poly + i, w->poly + i, gamma2,
247 out->poly + i);
248 }
249
250 static ossl_inline ossl_unused void
vector_use_hint(const VECTOR * h,const VECTOR * r,uint32_t gamma2,VECTOR * out)251 vector_use_hint(const VECTOR *h, const VECTOR *r, uint32_t gamma2, VECTOR *out)
252 {
253 size_t i;
254
255 for (i = 0; i < out->num_poly; i++)
256 poly_use_hint(h->poly + i, r->poly + i, gamma2, out->poly + i);
257 }
258