xref: /linux/lib/crypto/mpi/mpicoder.c (revision 2fe3c78a2c26dd5ee811024a1b7d6cfb4d654319)
1 /* mpicoder.c  -  Coder for the external representation of MPIs
2  * Copyright (C) 1998, 1999 Free Software Foundation, Inc.
3  *
4  * This file is part of GnuPG.
5  *
6  * GnuPG is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * GnuPG is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA
19  */
20 
21 #include <linux/bitops.h>
22 #include <linux/count_zeros.h>
23 #include <linux/byteorder/generic.h>
24 #include <linux/scatterlist.h>
25 #include <linux/string.h>
26 #include "mpi-internal.h"
27 
/* Largest MPI, in bits, that the readers in this file accept from input. */
#define MAX_EXTERN_MPI_BITS 16384
29 
30 /**
31  * mpi_read_raw_data - Read a raw byte stream as a positive integer
32  * @xbuffer: The data to read
33  * @nbytes: The amount of data to read
34  */
35 MPI mpi_read_raw_data(const void *xbuffer, size_t nbytes)
36 {
37 	const uint8_t *buffer = xbuffer;
38 	int i, j;
39 	unsigned nbits, nlimbs;
40 	mpi_limb_t a;
41 	MPI val = NULL;
42 
43 	while (nbytes > 0 && buffer[0] == 0) {
44 		buffer++;
45 		nbytes--;
46 	}
47 
48 	nbits = nbytes * 8;
49 	if (nbits > MAX_EXTERN_MPI_BITS) {
50 		pr_info("MPI: mpi too large (%u bits)\n", nbits);
51 		return NULL;
52 	}
53 	if (nbytes > 0)
54 		nbits -= count_leading_zeros(buffer[0]) - (BITS_PER_LONG - 8);
55 
56 	nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
57 	val = mpi_alloc(nlimbs);
58 	if (!val)
59 		return NULL;
60 	val->nbits = nbits;
61 	val->sign = 0;
62 	val->nlimbs = nlimbs;
63 
64 	if (nbytes > 0) {
65 		i = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB;
66 		i %= BYTES_PER_MPI_LIMB;
67 		for (j = nlimbs; j > 0; j--) {
68 			a = 0;
69 			for (; i < BYTES_PER_MPI_LIMB; i++) {
70 				a <<= 8;
71 				a |= *buffer++;
72 			}
73 			i = 0;
74 			val->d[j - 1] = a;
75 		}
76 	}
77 	return val;
78 }
79 EXPORT_SYMBOL_GPL(mpi_read_raw_data);
80 
81 MPI mpi_read_from_buffer(const void *xbuffer, unsigned *ret_nread)
82 {
83 	const uint8_t *buffer = xbuffer;
84 	unsigned int nbits, nbytes;
85 	MPI val;
86 
87 	if (*ret_nread < 2)
88 		return ERR_PTR(-EINVAL);
89 	nbits = buffer[0] << 8 | buffer[1];
90 
91 	if (nbits > MAX_EXTERN_MPI_BITS) {
92 		pr_info("MPI: mpi too large (%u bits)\n", nbits);
93 		return ERR_PTR(-EINVAL);
94 	}
95 
96 	nbytes = DIV_ROUND_UP(nbits, 8);
97 	if (nbytes + 2 > *ret_nread) {
98 		pr_info("MPI: mpi larger than buffer nbytes=%u ret_nread=%u\n",
99 				nbytes, *ret_nread);
100 		return ERR_PTR(-EINVAL);
101 	}
102 
103 	val = mpi_read_raw_data(buffer + 2, nbytes);
104 	if (!val)
105 		return ERR_PTR(-ENOMEM);
106 
107 	*ret_nread = nbytes + 2;
108 	return val;
109 }
110 EXPORT_SYMBOL_GPL(mpi_read_from_buffer);
111 
112 static int count_lzeros(MPI a)
113 {
114 	mpi_limb_t alimb;
115 	int i, lzeros = 0;
116 
117 	for (i = a->nlimbs - 1; i >= 0; i--) {
118 		alimb = a->d[i];
119 		if (alimb == 0) {
120 			lzeros += sizeof(mpi_limb_t);
121 		} else {
122 			lzeros += count_leading_zeros(alimb) / 8;
123 			break;
124 		}
125 	}
126 	return lzeros;
127 }
128 
129 /**
130  * mpi_read_buffer() - read MPI to a buffer provided by user (msb first)
131  *
132  * @a:		a multi precision integer
133  * @buf:	buffer to which the output will be written to. Needs to be at
134  *		least mpi_get_size(a) long.
135  * @buf_len:	size of the buf.
136  * @nbytes:	receives the actual length of the data written on success and
137  *		the data to-be-written on -EOVERFLOW in case buf_len was too
138  *		small.
139  * @sign:	if not NULL, it will be set to the sign of a.
140  *
141  * Return:	0 on success or error code in case of error
142  */
143 int mpi_read_buffer(MPI a, uint8_t *buf, unsigned buf_len, unsigned *nbytes,
144 		    int *sign)
145 {
146 	uint8_t *p;
147 #if BYTES_PER_MPI_LIMB == 4
148 	__be32 alimb;
149 #elif BYTES_PER_MPI_LIMB == 8
150 	__be64 alimb;
151 #else
152 #error please implement for this limb size.
153 #endif
154 	unsigned int n = mpi_get_size(a);
155 	int i, lzeros;
156 
157 	if (!buf || !nbytes)
158 		return -EINVAL;
159 
160 	if (sign)
161 		*sign = a->sign;
162 
163 	lzeros = count_lzeros(a);
164 
165 	if (buf_len < n - lzeros) {
166 		*nbytes = n - lzeros;
167 		return -EOVERFLOW;
168 	}
169 
170 	p = buf;
171 	*nbytes = n - lzeros;
172 
173 	for (i = a->nlimbs - 1 - lzeros / BYTES_PER_MPI_LIMB,
174 			lzeros %= BYTES_PER_MPI_LIMB;
175 		i >= 0; i--) {
176 #if BYTES_PER_MPI_LIMB == 4
177 		alimb = cpu_to_be32(a->d[i]);
178 #elif BYTES_PER_MPI_LIMB == 8
179 		alimb = cpu_to_be64(a->d[i]);
180 #else
181 #error please implement for this limb size.
182 #endif
183 		memcpy(p, (u8 *)&alimb + lzeros, BYTES_PER_MPI_LIMB - lzeros);
184 		p += BYTES_PER_MPI_LIMB - lzeros;
185 		lzeros = 0;
186 	}
187 	return 0;
188 }
189 EXPORT_SYMBOL_GPL(mpi_read_buffer);
190 
191 /*
192  * mpi_get_buffer() - Returns an allocated buffer with the MPI (msb first).
193  * Caller must free the return string.
194  * This function does return a 0 byte buffer with nbytes set to zero if the
195  * value of A is zero.
196  *
197  * @a:		a multi precision integer.
198  * @nbytes:	receives the length of this buffer.
199  * @sign:	if not NULL, it will be set to the sign of the a.
200  *
201  * Return:	Pointer to MPI buffer or NULL on error
202  */
203 void *mpi_get_buffer(MPI a, unsigned *nbytes, int *sign)
204 {
205 	uint8_t *buf;
206 	unsigned int n;
207 	int ret;
208 
209 	if (!nbytes)
210 		return NULL;
211 
212 	n = mpi_get_size(a);
213 
214 	if (!n)
215 		n++;
216 
217 	buf = kmalloc(n, GFP_KERNEL);
218 
219 	if (!buf)
220 		return NULL;
221 
222 	ret = mpi_read_buffer(a, buf, n, nbytes, sign);
223 
224 	if (ret) {
225 		kfree(buf);
226 		return NULL;
227 	}
228 	return buf;
229 }
230 EXPORT_SYMBOL_GPL(mpi_get_buffer);
231 
232 /**
233  * mpi_write_to_sgl() - Funnction exports MPI to an sgl (msb first)
234  *
235  * This function works in the same way as the mpi_read_buffer, but it
236  * takes an sgl instead of u8 * buf.
237  *
238  * @a:		a multi precision integer
239  * @sgl:	scatterlist to write to. Needs to be at least
240  *		mpi_get_size(a) long.
241  * @nbytes:	the number of bytes to write.  Leading bytes will be
242  *		filled with zero.
243  * @sign:	if not NULL, it will be set to the sign of a.
244  *
245  * Return:	0 on success or error code in case of error
246  */
247 int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned nbytes,
248 		     int *sign)
249 {
250 	u8 *p, *p2;
251 #if BYTES_PER_MPI_LIMB == 4
252 	__be32 alimb;
253 #elif BYTES_PER_MPI_LIMB == 8
254 	__be64 alimb;
255 #else
256 #error please implement for this limb size.
257 #endif
258 	unsigned int n = mpi_get_size(a);
259 	struct sg_mapping_iter miter;
260 	int i, x, buf_len;
261 	int nents;
262 
263 	if (sign)
264 		*sign = a->sign;
265 
266 	if (nbytes < n)
267 		return -EOVERFLOW;
268 
269 	nents = sg_nents_for_len(sgl, nbytes);
270 	if (nents < 0)
271 		return -EINVAL;
272 
273 	sg_miter_start(&miter, sgl, nents, SG_MITER_ATOMIC | SG_MITER_TO_SG);
274 	sg_miter_next(&miter);
275 	buf_len = miter.length;
276 	p2 = miter.addr;
277 
278 	while (nbytes > n) {
279 		i = min_t(unsigned, nbytes - n, buf_len);
280 		memset(p2, 0, i);
281 		p2 += i;
282 		nbytes -= i;
283 
284 		buf_len -= i;
285 		if (!buf_len) {
286 			sg_miter_next(&miter);
287 			buf_len = miter.length;
288 			p2 = miter.addr;
289 		}
290 	}
291 
292 	for (i = a->nlimbs - 1; i >= 0; i--) {
293 #if BYTES_PER_MPI_LIMB == 4
294 		alimb = a->d[i] ? cpu_to_be32(a->d[i]) : 0;
295 #elif BYTES_PER_MPI_LIMB == 8
296 		alimb = a->d[i] ? cpu_to_be64(a->d[i]) : 0;
297 #else
298 #error please implement for this limb size.
299 #endif
300 		p = (u8 *)&alimb;
301 
302 		for (x = 0; x < sizeof(alimb); x++) {
303 			*p2++ = *p++;
304 			if (!--buf_len) {
305 				sg_miter_next(&miter);
306 				buf_len = miter.length;
307 				p2 = miter.addr;
308 			}
309 		}
310 	}
311 
312 	sg_miter_stop(&miter);
313 	return 0;
314 }
315 EXPORT_SYMBOL_GPL(mpi_write_to_sgl);
316 
317 /*
318  * mpi_read_raw_from_sgl() - Function allocates an MPI and populates it with
319  *			     data from the sgl
320  *
321  * This function works in the same way as the mpi_read_raw_data, but it
322  * takes an sgl instead of void * buffer. i.e. it allocates
323  * a new MPI and reads the content of the sgl to the MPI.
324  *
325  * @sgl:	scatterlist to read from
326  * @nbytes:	number of bytes to read
327  *
328  * Return:	Pointer to a new MPI or NULL on error
329  */
330 MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)
331 {
332 	struct sg_mapping_iter miter;
333 	unsigned int nbits, nlimbs;
334 	int x, j, z, lzeros, ents;
335 	unsigned int len;
336 	const u8 *buff;
337 	mpi_limb_t a;
338 	MPI val = NULL;
339 
340 	ents = sg_nents_for_len(sgl, nbytes);
341 	if (ents < 0)
342 		return NULL;
343 
344 	sg_miter_start(&miter, sgl, ents, SG_MITER_ATOMIC | SG_MITER_FROM_SG);
345 
346 	lzeros = 0;
347 	len = 0;
348 	while (nbytes > 0) {
349 		while (len && !*buff) {
350 			lzeros++;
351 			len--;
352 			buff++;
353 		}
354 
355 		if (len && *buff)
356 			break;
357 
358 		sg_miter_next(&miter);
359 		buff = miter.addr;
360 		len = miter.length;
361 
362 		nbytes -= lzeros;
363 		lzeros = 0;
364 	}
365 
366 	miter.consumed = lzeros;
367 
368 	nbytes -= lzeros;
369 	nbits = nbytes * 8;
370 	if (nbits > MAX_EXTERN_MPI_BITS) {
371 		sg_miter_stop(&miter);
372 		pr_info("MPI: mpi too large (%u bits)\n", nbits);
373 		return NULL;
374 	}
375 
376 	if (nbytes > 0)
377 		nbits -= count_leading_zeros(*buff) - (BITS_PER_LONG - 8);
378 
379 	sg_miter_stop(&miter);
380 
381 	nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
382 	val = mpi_alloc(nlimbs);
383 	if (!val)
384 		return NULL;
385 
386 	val->nbits = nbits;
387 	val->sign = 0;
388 	val->nlimbs = nlimbs;
389 
390 	if (nbytes == 0)
391 		return val;
392 
393 	j = nlimbs - 1;
394 	a = 0;
395 	z = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB;
396 	z %= BYTES_PER_MPI_LIMB;
397 
398 	while (sg_miter_next(&miter)) {
399 		buff = miter.addr;
400 		len = min_t(unsigned, miter.length, nbytes);
401 		nbytes -= len;
402 
403 		for (x = 0; x < len; x++) {
404 			a <<= 8;
405 			a |= *buff++;
406 			if (((z + x + 1) % BYTES_PER_MPI_LIMB) == 0) {
407 				val->d[j--] = a;
408 				a = 0;
409 			}
410 		}
411 		z += x;
412 	}
413 
414 	return val;
415 }
416 EXPORT_SYMBOL_GPL(mpi_read_raw_from_sgl);
417