xref: /linux/lib/crypto/mpi/mpicoder.c (revision 2330437da0994321020777c605a2a8cb0ecb7001)
1 /* mpicoder.c  -  Coder for the external representation of MPIs
2  * Copyright (C) 1998, 1999 Free Software Foundation, Inc.
3  *
4  * This file is part of GnuPG.
5  *
6  * GnuPG is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * GnuPG is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA
19  */
20 
21 #include <linux/bitops.h>
22 #include <linux/byteorder/generic.h>
23 #include <linux/count_zeros.h>
24 #include <linux/export.h>
25 #include <linux/scatterlist.h>
26 #include <linux/string.h>
27 #include "mpi-internal.h"
28 
29 #define MAX_EXTERN_MPI_BITS 16384
30 
31 /**
32  * mpi_read_raw_data - Read a raw byte stream as a positive integer
33  * @xbuffer: The data to read
34  * @nbytes: The amount of data to read
35  */
36 MPI mpi_read_raw_data(const void *xbuffer, size_t nbytes)
37 {
38 	const uint8_t *buffer = xbuffer;
39 	int i, j;
40 	unsigned nbits, nlimbs;
41 	mpi_limb_t a;
42 	MPI val = NULL;
43 
44 	while (nbytes > 0 && buffer[0] == 0) {
45 		buffer++;
46 		nbytes--;
47 	}
48 
49 	nbits = nbytes * 8;
50 	if (nbits > MAX_EXTERN_MPI_BITS) {
51 		pr_info("MPI: mpi too large (%u bits)\n", nbits);
52 		return NULL;
53 	}
54 	if (nbytes > 0)
55 		nbits -= count_leading_zeros(buffer[0]) - (BITS_PER_LONG - 8);
56 
57 	nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
58 	val = mpi_alloc(nlimbs);
59 	if (!val)
60 		return NULL;
61 	val->nbits = nbits;
62 	val->sign = 0;
63 	val->nlimbs = nlimbs;
64 
65 	if (nbytes > 0) {
66 		i = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB;
67 		i %= BYTES_PER_MPI_LIMB;
68 		for (j = nlimbs; j > 0; j--) {
69 			a = 0;
70 			for (; i < BYTES_PER_MPI_LIMB; i++) {
71 				a <<= 8;
72 				a |= *buffer++;
73 			}
74 			i = 0;
75 			val->d[j - 1] = a;
76 		}
77 	}
78 	return val;
79 }
80 EXPORT_SYMBOL_GPL(mpi_read_raw_data);
81 
82 MPI mpi_read_from_buffer(const void *xbuffer, unsigned *ret_nread)
83 {
84 	const uint8_t *buffer = xbuffer;
85 	unsigned int nbits, nbytes;
86 	MPI val;
87 
88 	if (*ret_nread < 2)
89 		return ERR_PTR(-EINVAL);
90 	nbits = buffer[0] << 8 | buffer[1];
91 
92 	if (nbits > MAX_EXTERN_MPI_BITS) {
93 		pr_info("MPI: mpi too large (%u bits)\n", nbits);
94 		return ERR_PTR(-EINVAL);
95 	}
96 
97 	nbytes = DIV_ROUND_UP(nbits, 8);
98 	if (nbytes + 2 > *ret_nread) {
99 		pr_info("MPI: mpi larger than buffer nbytes=%u ret_nread=%u\n",
100 				nbytes, *ret_nread);
101 		return ERR_PTR(-EINVAL);
102 	}
103 
104 	val = mpi_read_raw_data(buffer + 2, nbytes);
105 	if (!val)
106 		return ERR_PTR(-ENOMEM);
107 
108 	*ret_nread = nbytes + 2;
109 	return val;
110 }
111 EXPORT_SYMBOL_GPL(mpi_read_from_buffer);
112 
113 static int count_lzeros(MPI a)
114 {
115 	mpi_limb_t alimb;
116 	int i, lzeros = 0;
117 
118 	for (i = a->nlimbs - 1; i >= 0; i--) {
119 		alimb = a->d[i];
120 		if (alimb == 0) {
121 			lzeros += sizeof(mpi_limb_t);
122 		} else {
123 			lzeros += count_leading_zeros(alimb) / 8;
124 			break;
125 		}
126 	}
127 	return lzeros;
128 }
129 
130 /**
131  * mpi_read_buffer() - read MPI to a buffer provided by user (msb first)
132  *
133  * @a:		a multi precision integer
134  * @buf:	buffer to which the output will be written to. Needs to be at
135  *		least mpi_get_size(a) long.
136  * @buf_len:	size of the buf.
137  * @nbytes:	receives the actual length of the data written on success and
138  *		the data to-be-written on -EOVERFLOW in case buf_len was too
139  *		small.
140  * @sign:	if not NULL, it will be set to the sign of a.
141  *
142  * Return:	0 on success or error code in case of error
143  */
144 int mpi_read_buffer(MPI a, uint8_t *buf, unsigned buf_len, unsigned *nbytes,
145 		    int *sign)
146 {
147 	uint8_t *p;
148 #if BYTES_PER_MPI_LIMB == 4
149 	__be32 alimb;
150 #elif BYTES_PER_MPI_LIMB == 8
151 	__be64 alimb;
152 #else
153 #error please implement for this limb size.
154 #endif
155 	unsigned int n = mpi_get_size(a);
156 	int i, lzeros;
157 
158 	if (!buf || !nbytes)
159 		return -EINVAL;
160 
161 	if (sign)
162 		*sign = a->sign;
163 
164 	lzeros = count_lzeros(a);
165 
166 	if (buf_len < n - lzeros) {
167 		*nbytes = n - lzeros;
168 		return -EOVERFLOW;
169 	}
170 
171 	p = buf;
172 	*nbytes = n - lzeros;
173 
174 	for (i = a->nlimbs - 1 - lzeros / BYTES_PER_MPI_LIMB,
175 			lzeros %= BYTES_PER_MPI_LIMB;
176 		i >= 0; i--) {
177 #if BYTES_PER_MPI_LIMB == 4
178 		alimb = cpu_to_be32(a->d[i]);
179 #elif BYTES_PER_MPI_LIMB == 8
180 		alimb = cpu_to_be64(a->d[i]);
181 #else
182 #error please implement for this limb size.
183 #endif
184 		memcpy(p, (u8 *)&alimb + lzeros, BYTES_PER_MPI_LIMB - lzeros);
185 		p += BYTES_PER_MPI_LIMB - lzeros;
186 		lzeros = 0;
187 	}
188 	return 0;
189 }
190 EXPORT_SYMBOL_GPL(mpi_read_buffer);
191 
192 /*
193  * mpi_get_buffer() - Returns an allocated buffer with the MPI (msb first).
194  * Caller must free the return string.
195  * This function does return a 0 byte buffer with nbytes set to zero if the
196  * value of A is zero.
197  *
198  * @a:		a multi precision integer.
199  * @nbytes:	receives the length of this buffer.
200  * @sign:	if not NULL, it will be set to the sign of the a.
201  *
202  * Return:	Pointer to MPI buffer or NULL on error
203  */
204 void *mpi_get_buffer(MPI a, unsigned *nbytes, int *sign)
205 {
206 	uint8_t *buf;
207 	unsigned int n;
208 	int ret;
209 
210 	if (!nbytes)
211 		return NULL;
212 
213 	n = mpi_get_size(a);
214 
215 	if (!n)
216 		n++;
217 
218 	buf = kmalloc(n, GFP_KERNEL);
219 
220 	if (!buf)
221 		return NULL;
222 
223 	ret = mpi_read_buffer(a, buf, n, nbytes, sign);
224 
225 	if (ret) {
226 		kfree(buf);
227 		return NULL;
228 	}
229 	return buf;
230 }
231 EXPORT_SYMBOL_GPL(mpi_get_buffer);
232 
233 /**
234  * mpi_write_to_sgl() - Funnction exports MPI to an sgl (msb first)
235  *
236  * This function works in the same way as the mpi_read_buffer, but it
237  * takes an sgl instead of u8 * buf.
238  *
239  * @a:		a multi precision integer
240  * @sgl:	scatterlist to write to. Needs to be at least
241  *		mpi_get_size(a) long.
242  * @nbytes:	the number of bytes to write.  Leading bytes will be
243  *		filled with zero.
244  * @sign:	if not NULL, it will be set to the sign of a.
245  *
246  * Return:	0 on success or error code in case of error
247  */
248 int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned nbytes,
249 		     int *sign)
250 {
251 	u8 *p, *p2;
252 #if BYTES_PER_MPI_LIMB == 4
253 	__be32 alimb;
254 #elif BYTES_PER_MPI_LIMB == 8
255 	__be64 alimb;
256 #else
257 #error please implement for this limb size.
258 #endif
259 	unsigned int n = mpi_get_size(a);
260 	struct sg_mapping_iter miter;
261 	int i, x, buf_len;
262 	int nents;
263 
264 	if (sign)
265 		*sign = a->sign;
266 
267 	if (nbytes < n)
268 		return -EOVERFLOW;
269 
270 	nents = sg_nents_for_len(sgl, nbytes);
271 	if (nents < 0)
272 		return -EINVAL;
273 
274 	sg_miter_start(&miter, sgl, nents, SG_MITER_ATOMIC | SG_MITER_TO_SG);
275 	sg_miter_next(&miter);
276 	buf_len = miter.length;
277 	p2 = miter.addr;
278 
279 	while (nbytes > n) {
280 		i = min_t(unsigned, nbytes - n, buf_len);
281 		memset(p2, 0, i);
282 		p2 += i;
283 		nbytes -= i;
284 
285 		buf_len -= i;
286 		if (!buf_len) {
287 			sg_miter_next(&miter);
288 			buf_len = miter.length;
289 			p2 = miter.addr;
290 		}
291 	}
292 
293 	for (i = a->nlimbs - 1; i >= 0; i--) {
294 #if BYTES_PER_MPI_LIMB == 4
295 		alimb = a->d[i] ? cpu_to_be32(a->d[i]) : 0;
296 #elif BYTES_PER_MPI_LIMB == 8
297 		alimb = a->d[i] ? cpu_to_be64(a->d[i]) : 0;
298 #else
299 #error please implement for this limb size.
300 #endif
301 		p = (u8 *)&alimb;
302 
303 		for (x = 0; x < sizeof(alimb); x++) {
304 			*p2++ = *p++;
305 			if (!--buf_len) {
306 				sg_miter_next(&miter);
307 				buf_len = miter.length;
308 				p2 = miter.addr;
309 			}
310 		}
311 	}
312 
313 	sg_miter_stop(&miter);
314 	return 0;
315 }
316 EXPORT_SYMBOL_GPL(mpi_write_to_sgl);
317 
318 /*
319  * mpi_read_raw_from_sgl() - Function allocates an MPI and populates it with
320  *			     data from the sgl
321  *
322  * This function works in the same way as the mpi_read_raw_data, but it
323  * takes an sgl instead of void * buffer. i.e. it allocates
324  * a new MPI and reads the content of the sgl to the MPI.
325  *
326  * @sgl:	scatterlist to read from
327  * @nbytes:	number of bytes to read
328  *
329  * Return:	Pointer to a new MPI or NULL on error
330  */
331 MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)
332 {
333 	struct sg_mapping_iter miter;
334 	unsigned int nbits, nlimbs;
335 	int x, j, z, lzeros, ents;
336 	unsigned int len;
337 	const u8 *buff;
338 	mpi_limb_t a;
339 	MPI val = NULL;
340 
341 	ents = sg_nents_for_len(sgl, nbytes);
342 	if (ents < 0)
343 		return NULL;
344 
345 	sg_miter_start(&miter, sgl, ents, SG_MITER_ATOMIC | SG_MITER_FROM_SG);
346 
347 	lzeros = 0;
348 	len = 0;
349 	while (nbytes > 0) {
350 		while (len && !*buff) {
351 			lzeros++;
352 			len--;
353 			buff++;
354 		}
355 
356 		if (len && *buff)
357 			break;
358 
359 		sg_miter_next(&miter);
360 		buff = miter.addr;
361 		len = miter.length;
362 
363 		nbytes -= lzeros;
364 		lzeros = 0;
365 	}
366 
367 	miter.consumed = lzeros;
368 
369 	nbytes -= lzeros;
370 	nbits = nbytes * 8;
371 	if (nbits > MAX_EXTERN_MPI_BITS) {
372 		sg_miter_stop(&miter);
373 		pr_info("MPI: mpi too large (%u bits)\n", nbits);
374 		return NULL;
375 	}
376 
377 	if (nbytes > 0)
378 		nbits -= count_leading_zeros(*buff) - (BITS_PER_LONG - 8);
379 
380 	sg_miter_stop(&miter);
381 
382 	nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
383 	val = mpi_alloc(nlimbs);
384 	if (!val)
385 		return NULL;
386 
387 	val->nbits = nbits;
388 	val->sign = 0;
389 	val->nlimbs = nlimbs;
390 
391 	if (nbytes == 0)
392 		return val;
393 
394 	j = nlimbs - 1;
395 	a = 0;
396 	z = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB;
397 	z %= BYTES_PER_MPI_LIMB;
398 
399 	while (sg_miter_next(&miter)) {
400 		buff = miter.addr;
401 		len = min_t(unsigned, miter.length, nbytes);
402 		nbytes -= len;
403 
404 		for (x = 0; x < len; x++) {
405 			a <<= 8;
406 			a |= *buff++;
407 			if (((z + x + 1) % BYTES_PER_MPI_LIMB) == 0) {
408 				val->d[j--] = a;
409 				a = 0;
410 			}
411 		}
412 		z += x;
413 	}
414 
415 	return val;
416 }
417 EXPORT_SYMBOL_GPL(mpi_read_raw_from_sgl);
418