xref: /freebsd/crypto/libecc/src/examples/sig/common/common.h (revision dd21556857e8d40f66bf5ad54754d9d52669ebf7)
1 /* Some common helpers useful for many algorithms */
2 #ifndef __COMMON_H__
3 #define __COMMON_H__
4 
5 /* Include our arithmetic layer */
6 #include <libecc/libarith.h>
7 
8 /* I2OSP and OS2IP internal primitives */
9 ATTRIBUTE_WARN_UNUSED_RET static inline int _i2osp(nn_src_t x, u8 *buf, u16 buflen)
10 {
11 	int ret;
12 	bitcnt_t blen;
13 
14 	/* Sanity checks */
15 	MUST_HAVE((buf != NULL), ret, err);
16 	ret = nn_check_initialized(x); EG(ret, err);
17 
18 	/* If x >= 256^xLen (the integer does not fit in the buffer),
19 	 * return an error.
20 	 */
21 	ret = nn_bitlen(x, &blen); EG(ret, err);
22 	MUST_HAVE(((8 * buflen) >= blen), ret, err);
23 
24 	/* Export to the buffer */
25 	ret = nn_export_to_buf(buf, buflen, x);
26 
27 err:
28 	return ret;
29 }
30 
31 ATTRIBUTE_WARN_UNUSED_RET static inline int _os2ip(nn_t x, const u8 *buf, u16 buflen)
32 {
33 	int ret;
34 
35 	/* We do not want to exceed our computation compatible
36 	 * size.
37 	 */
38 	MUST_HAVE((buflen <= NN_USABLE_MAX_BYTE_LEN), ret, err);
39 
40 	/* Import the NN */
41 	ret = nn_init_from_buf(x, buf, buflen);
42 
43 err:
44 	return ret;
45 }
46 
47 /* Reverses the endiannes of a buffer in place */
48 ATTRIBUTE_WARN_UNUSED_RET static inline int _reverse_endianness(u8 *buf, u16 buf_size)
49 {
50 	u16 i;
51 	u8 tmp;
52 	int ret;
53 
54 	MUST_HAVE((buf != NULL), ret, err);
55 
56 	if(buf_size > 1){
57 		for(i = 0; i < (buf_size / 2); i++){
58 			tmp = buf[i];
59 			buf[i] = buf[buf_size - 1 - i];
60 			buf[buf_size - 1 - i] = tmp;
61 		}
62 	}
63 
64 	ret = 0;
65 
66 err:
67         return ret;
68 }
69 
70 /* Helper to fix the MSB of a scalar using the trick in
71  * https://eprint.iacr.org/2011/232.pdf
72  *
73  *  We distinguish three situations:
74  *     - The scalar m is < q (the order), in this case we compute:
75  *         -
76  *        | m' = m + (2 * q) if [log(k + q)] == [log(q)],
77  *        | m' = m + q otherwise.
78  *         -
79  *     - The scalar m is >= q and < q**2, in this case we compute:
80  *         -
81  *        | m' = m + (2 * (q**2)) if [log(k + (q**2))] == [log(q**2)],
82  *        | m' = m + (q**2) otherwise.
83  *         -
84  *     - The scalar m is >= (q**2), in this case m == m'
85  *  We only deal with 0 <= m < (q**2) using the countermeasure. When m >= (q**2),
86  *  we stick with m' = m, accepting MSB issues (not much can be done in this case
87  *  anyways).
88  */
89 ATTRIBUTE_WARN_UNUSED_RET static inline int _fix_scalar_msb(nn_src_t m, nn_src_t q, nn_t m_msb_fixed)
90 {
91 	int ret, cmp;
92 	/* _m_msb_fixed to handle aliasing */
93 	nn q_square, _m_msb_fixed;
94 	q_square.magic = _m_msb_fixed.magic = WORD(0);
95 
96 	/* Sanity checks */
97 	ret = nn_check_initialized(m); EG(ret, err);
98 	ret = nn_check_initialized(q); EG(ret, err);
99 	ret = nn_check_initialized(m_msb_fixed); EG(ret, err);
100 
101 	ret = nn_init(&q_square, 0); EG(ret, err);
102 	ret = nn_init(&_m_msb_fixed, 0); EG(ret, err);
103 
104 	/* First compute q**2 */
105 	ret = nn_sqr(&q_square, q); EG(ret, err);
106 	/* Then compute m' depending on m size */
107 	ret = nn_cmp(m, q, &cmp); EG(ret, err);
108 	if (cmp < 0){
109 		bitcnt_t msb_bit_len, q_bitlen;
110 
111 		/* Case where m < q */
112 		ret = nn_add(&_m_msb_fixed, m, q); EG(ret, err);
113 		ret = nn_bitlen(&_m_msb_fixed, &msb_bit_len); EG(ret, err);
114 		ret = nn_bitlen(q, &q_bitlen); EG(ret, err);
115 		ret = nn_cnd_add((msb_bit_len == q_bitlen), m_msb_fixed,
116 				  &_m_msb_fixed, q); EG(ret, err);
117 	} else {
118 		ret = nn_cmp(m, &q_square, &cmp); EG(ret, err);
119 		if (cmp < 0) {
120 			bitcnt_t msb_bit_len, q_square_bitlen;
121 
122 			/* Case where m >= q and m < (q**2) */
123 			ret = nn_add(&_m_msb_fixed, m, &q_square); EG(ret, err);
124 			ret = nn_bitlen(&_m_msb_fixed, &msb_bit_len); EG(ret, err);
125 			ret = nn_bitlen(&q_square, &q_square_bitlen); EG(ret, err);
126 			ret = nn_cnd_add((msb_bit_len == q_square_bitlen),
127 					m_msb_fixed, &_m_msb_fixed, &q_square); EG(ret, err);
128 		} else {
129 			/* Case where m >= (q**2) */
130 			ret = nn_copy(m_msb_fixed, m); EG(ret, err);
131 		}
132 	}
133 
134 err:
135 	nn_uninit(&q_square);
136 	nn_uninit(&_m_msb_fixed);
137 
138 	return ret;
139 }
140 
141 /* Helper to blind the scalar.
142  * Compute m_blind = m + (b * q) where b is a random value modulo q.
143  * Aliasing is supported.
144  */
145 ATTRIBUTE_WARN_UNUSED_RET static inline int _blind_scalar(nn_src_t m, nn_src_t q, nn_t m_blind)
146 {
147         int ret;
148 	nn tmp;
149 	tmp.magic = WORD(0);
150 
151 	/* Sanity checks */
152         ret = nn_check_initialized(m); EG(ret, err);
153         ret = nn_check_initialized(q); EG(ret, err);
154         ret = nn_check_initialized(m_blind); EG(ret, err);
155 
156 	ret = nn_get_random_mod(&tmp, q); EG(ret, err);
157 
158 	ret = nn_mul(&tmp, &tmp, q); EG(ret, err);
159 	ret = nn_add(m_blind, &tmp, m);
160 
161 err:
162 	nn_uninit(&tmp);
163 
164 	return ret;
165 }
166 
167 /*
168  * NOT constant time at all and not secure against side-channels. This is
169  * an internal function only used for DSA verification on public data.
170  *
171  * Compute (base ** exp) mod (mod) using a square and multiply algorithm.
172  * Internally, this computes Montgomery coefficients and uses the redc
173  * function.
174  *
175  * Returns 0 on success, -1 on error.
176  */
177 ATTRIBUTE_WARN_UNUSED_RET static inline int _nn_mod_pow_insecure(nn_t out, nn_src_t base,
178 							  nn_src_t exp, nn_src_t mod)
179 {
180 	int ret, isodd, cmp;
181 	bitcnt_t explen;
182 	u8 expbit;
183 	nn r, r_square, _base, one;
184 	word_t mpinv;
185 	r.magic = r_square.magic = _base.magic = one.magic = WORD(0);
186 
187 	/* Aliasing is not supported for this internal helper */
188 	MUST_HAVE((out != base) && (out != exp) && (out != mod), ret, err);
189 
190 	/* Check initializations */
191 	ret = nn_check_initialized(base); EG(ret, err);
192 	ret = nn_check_initialized(exp); EG(ret, err);
193 	ret = nn_check_initialized(mod); EG(ret, err);
194 
195 	ret = nn_bitlen(exp, &explen); EG(ret, err);
196 	/* Sanity check */
197 	MUST_HAVE((explen > 0), ret, err);
198 
199 	/* Check that the modulo is indeed odd */
200 	ret = nn_isodd(mod, &isodd); EG(ret, err);
201 	MUST_HAVE(isodd, ret, err);
202 
203 	/* Compute the Montgomery coefficients */
204 	ret = nn_compute_redc1_coefs(&r, &r_square, mod, &mpinv); EG(ret, err);
205 
206 	/* Reduce the base if necessary */
207 	ret = nn_cmp(base, mod, &cmp); EG(ret, err);
208 	if(cmp >= 0){
209 		ret = nn_mod(&_base, base, mod); EG(ret, err);
210 	}
211 	else{
212 		ret = nn_copy(&_base, base); EG(ret, err);
213 	}
214 
215 	ret = nn_mul_redc1(&_base, &_base, &r_square, mod, mpinv); EG(ret, err);
216 	ret = nn_copy(out, &r); EG(ret, err);
217 
218 	ret = nn_init(&one, 0); EG(ret, err);
219 	ret = nn_one(&one); EG(ret, err);
220 
221 	while (explen > 0) {
222 		explen = (bitcnt_t)(explen - 1);
223 
224 		/* Get the bit */
225 		ret = nn_getbit(exp, explen, &expbit); EG(ret, err);
226 
227 		/* Square */
228 		ret = nn_mul_redc1(out, out, out, mod, mpinv); EG(ret, err);
229 
230 		if(expbit){
231 			/* Multiply */
232 			ret = nn_mul_redc1(out, out, &_base, mod, mpinv); EG(ret, err);
233 		}
234 	}
235 	/* Unredcify the output */
236 	ret = nn_mul_redc1(out, out, &one, mod, mpinv);
237 
238 err:
239 	nn_uninit(&r);
240 	nn_uninit(&r_square);
241 	nn_uninit(&_base);
242 	nn_uninit(&one);
243 
244 	return ret;
245 }
246 
247 
248 #endif /* __COMMON_H__ */
249