xref: /freebsd/crypto/openssh/sshkey-xmss.c (revision e8d8bef961a50d4dc22501cde4fb9fb0be1b2532)
1 /* $OpenBSD: sshkey-xmss.c,v 1.3 2018/07/09 21:59:10 markus Exp $ */
2 /*
3  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include "includes.h"
27 #ifdef WITH_XMSS
28 
29 #include <sys/types.h>
30 #include <sys/uio.h>
31 
32 #include <stdio.h>
33 #include <string.h>
34 #include <unistd.h>
35 #include <fcntl.h>
36 #include <errno.h>
37 #ifdef HAVE_SYS_FILE_H
38 # include <sys/file.h>
39 #endif
40 
41 #include "ssh2.h"
42 #include "ssherr.h"
43 #include "sshbuf.h"
44 #include "cipher.h"
45 #include "sshkey.h"
46 #include "sshkey-xmss.h"
47 #include "atomicio.h"
48 
49 #include "xmss_fast.h"
50 
/* opaque internal XMSS state */

/*
 * Magic (written including its terminating NUL) that prefixes every
 * encrypted state blob produced by sshkey_xmss_encrypt_state() and checked
 * by sshkey_xmss_decrypt_state().
 */
#define XMSS_MAGIC		"xmss-state-v1"
/* Cipher used to encrypt the state; see sshkey_xmss_generate_private_key(). */
#define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"

/*
 * XMSS state attached to a key via key->xmss_state: the parameter set, the
 * BDS traversal state together with the buffers it points into, and the
 * bookkeeping for the on-disk state file and its encryption key.
 */
struct ssh_xmss_state {
	xmss_params	params;
	/*
	 * n: node length in bytes; w, h: the W and H values selected by the
	 * key name (e.g. W16/H10); k: set to 2 by sshkey_xmss_init().
	 */
	u_int32_t	n, w, h, k;

	/*
	 * BDS traversal state.  'bds' is pointed at the buffers below by
	 * xmss_set_bds_state(); they are sized by the num_*() macros and
	 * released by sshkey_xmss_free_bds().
	 */
	bds_state	bds;
	u_char		*stack;
	u_int32_t	stackoffset;	/* copy of bds.stackoffset when (de)serialized */
	u_char		*stacklevels;
	u_char		*auth;
	u_char		*keep;
	u_char		*th_nodes;	/* n-byte node slot for each treehash instance */
	u_char		*retain;
	treehash_inst	*treehash;

	u_int32_t	idx;		/* state read from file */
	u_int32_t	maxidx;		/* restricted # of signatures */
	int		have_state;	/* .state file exists */
	int		lockfd;		/* locked in sshkey_xmss_get_state() */
	int		allow_update;	/* allow sshkey_xmss_update_state() */
	char		*enc_ciphername;/* encrypt state with cipher */
	u_char		*enc_keyiv;	/* encrypt state with key */
	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
};
77 
/*
 * Helpers defined later in this file, declared here so that functions
 * appearing before their definitions can call them.
 */
int	 sshkey_xmss_init_bds_state(struct sshkey *);
int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
void	 sshkey_xmss_free_bds(struct sshkey *);
int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
	    int *, sshkey_printfn *);
int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);

/*
 * Forward a printf-style message to the caller-supplied callback 'pr' (a
 * variable that must be in scope at the use site), if one was given.
 * Uses the GNU named variadic macro form "s...".
 */
#define PRINT(s...) do { if (pr) pr(s); } while (0)
91 
92 int
93 sshkey_xmss_init(struct sshkey *key, const char *name)
94 {
95 	struct ssh_xmss_state *state;
96 
97 	if (key->xmss_state != NULL)
98 		return SSH_ERR_INVALID_FORMAT;
99 	if (name == NULL)
100 		return SSH_ERR_INVALID_FORMAT;
101 	state = calloc(sizeof(struct ssh_xmss_state), 1);
102 	if (state == NULL)
103 		return SSH_ERR_ALLOC_FAIL;
104 	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
105 		state->n = 32;
106 		state->w = 16;
107 		state->h = 10;
108 	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
109 		state->n = 32;
110 		state->w = 16;
111 		state->h = 16;
112 	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
113 		state->n = 32;
114 		state->w = 16;
115 		state->h = 20;
116 	} else {
117 		free(state);
118 		return SSH_ERR_KEY_TYPE_UNKNOWN;
119 	}
120 	if ((key->xmss_name = strdup(name)) == NULL) {
121 		free(state);
122 		return SSH_ERR_ALLOC_FAIL;
123 	}
124 	state->k = 2;	/* XXX hardcoded */
125 	state->lockfd = -1;
126 	if (xmss_set_params(&state->params, state->n, state->h, state->w,
127 	    state->k) != 0) {
128 		free(state);
129 		return SSH_ERR_INVALID_FORMAT;
130 	}
131 	key->xmss_state = state;
132 	return 0;
133 }
134 
135 void
136 sshkey_xmss_free_state(struct sshkey *key)
137 {
138 	struct ssh_xmss_state *state = key->xmss_state;
139 
140 	sshkey_xmss_free_bds(key);
141 	if (state) {
142 		if (state->enc_keyiv) {
143 			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
144 			free(state->enc_keyiv);
145 		}
146 		free(state->enc_ciphername);
147 		free(state);
148 	}
149 	key->xmss_state = NULL;
150 }
151 
/*
 * Magic string at the start of a serialized BDS state (see
 * sshkey_xmss_serialize_state()); the "k=2" matches the value that
 * sshkey_xmss_init() hardcodes for state->k.
 */
#define SSH_XMSS_K2_MAGIC	"k=2"

/*
 * Element counts for the BDS buffers of the state pointed to by 'x', as
 * used by the calloc() calls in sshkey_xmss_init_bds_state(): byte counts
 * for every buffer except the treehash array, which counts treehash_inst
 * entries.  The argument is parenthesized so any pointer expression may be
 * passed.
 */
#define num_stack(x)		((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)	(((x)->h) + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)		(((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
160 
161 int
162 sshkey_xmss_init_bds_state(struct sshkey *key)
163 {
164 	struct ssh_xmss_state *state = key->xmss_state;
165 	u_int32_t i;
166 
167 	state->stackoffset = 0;
168 	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
169 	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
170 	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
171 	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
172 	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
173 	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
174 	    (state->treehash = calloc(num_treehash(state),
175 	    sizeof(treehash_inst))) == NULL) {
176 		sshkey_xmss_free_bds(key);
177 		return SSH_ERR_ALLOC_FAIL;
178 	}
179 	for (i = 0; i < state->h - state->k; i++)
180 		state->treehash[i].node = &state->th_nodes[state->n*i];
181 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
182 	    state->stacklevels, state->auth, state->keep, state->treehash,
183 	    state->retain, 0);
184 	return 0;
185 }
186 
187 void
188 sshkey_xmss_free_bds(struct sshkey *key)
189 {
190 	struct ssh_xmss_state *state = key->xmss_state;
191 
192 	if (state == NULL)
193 		return;
194 	free(state->stack);
195 	free(state->stacklevels);
196 	free(state->auth);
197 	free(state->keep);
198 	free(state->th_nodes);
199 	free(state->retain);
200 	free(state->treehash);
201 	state->stack = NULL;
202 	state->stacklevels = NULL;
203 	state->auth = NULL;
204 	state->keep = NULL;
205 	state->th_nodes = NULL;
206 	state->retain = NULL;
207 	state->treehash = NULL;
208 }
209 
210 void *
211 sshkey_xmss_params(const struct sshkey *key)
212 {
213 	struct ssh_xmss_state *state = key->xmss_state;
214 
215 	if (state == NULL)
216 		return NULL;
217 	return &state->params;
218 }
219 
220 void *
221 sshkey_xmss_bds_state(const struct sshkey *key)
222 {
223 	struct ssh_xmss_state *state = key->xmss_state;
224 
225 	if (state == NULL)
226 		return NULL;
227 	return &state->bds;
228 }
229 
230 int
231 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
232 {
233 	struct ssh_xmss_state *state = key->xmss_state;
234 
235 	if (lenp == NULL)
236 		return SSH_ERR_INVALID_ARGUMENT;
237 	if (state == NULL)
238 		return SSH_ERR_INVALID_FORMAT;
239 	*lenp = 4 + state->n +
240 	    state->params.wots_par.keysize +
241 	    state->h * state->n;
242 	return 0;
243 }
244 
245 size_t
246 sshkey_xmss_pklen(const struct sshkey *key)
247 {
248 	struct ssh_xmss_state *state = key->xmss_state;
249 
250 	if (state == NULL)
251 		return 0;
252 	return state->n * 2;
253 }
254 
255 size_t
256 sshkey_xmss_sklen(const struct sshkey *key)
257 {
258 	struct ssh_xmss_state *state = key->xmss_state;
259 
260 	if (state == NULL)
261 		return 0;
262 	return state->n * 4 + 4;
263 }
264 
265 int
266 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
267 {
268 	struct ssh_xmss_state *state = k->xmss_state;
269 	const struct sshcipher *cipher;
270 	size_t keylen = 0, ivlen = 0;
271 
272 	if (state == NULL)
273 		return SSH_ERR_INVALID_ARGUMENT;
274 	if ((cipher = cipher_by_name(ciphername)) == NULL)
275 		return SSH_ERR_INTERNAL_ERROR;
276 	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
277 		return SSH_ERR_ALLOC_FAIL;
278 	keylen = cipher_keylen(cipher);
279 	ivlen = cipher_ivlen(cipher);
280 	state->enc_keyiv_len = keylen + ivlen;
281 	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
282 		free(state->enc_ciphername);
283 		state->enc_ciphername = NULL;
284 		return SSH_ERR_ALLOC_FAIL;
285 	}
286 	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
287 	return 0;
288 }
289 
290 int
291 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
292 {
293 	struct ssh_xmss_state *state = k->xmss_state;
294 	int r;
295 
296 	if (state == NULL || state->enc_keyiv == NULL ||
297 	    state->enc_ciphername == NULL)
298 		return SSH_ERR_INVALID_ARGUMENT;
299 	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
300 	    (r = sshbuf_put_string(b, state->enc_keyiv,
301 	    state->enc_keyiv_len)) != 0)
302 		return r;
303 	return 0;
304 }
305 
306 int
307 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
308 {
309 	struct ssh_xmss_state *state = k->xmss_state;
310 	size_t len;
311 	int r;
312 
313 	if (state == NULL)
314 		return SSH_ERR_INVALID_ARGUMENT;
315 	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
316 	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
317 		return r;
318 	state->enc_keyiv_len = len;
319 	return 0;
320 }
321 
322 int
323 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
324     enum sshkey_serialize_rep opts)
325 {
326 	struct ssh_xmss_state *state = k->xmss_state;
327 	u_char have_info = 1;
328 	u_int32_t idx;
329 	int r;
330 
331 	if (state == NULL)
332 		return SSH_ERR_INVALID_ARGUMENT;
333 	if (opts != SSHKEY_SERIALIZE_INFO)
334 		return 0;
335 	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
336 	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
337 	    (r = sshbuf_put_u32(b, idx)) != 0 ||
338 	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
339 		return r;
340 	return 0;
341 }
342 
343 int
344 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
345 {
346 	struct ssh_xmss_state *state = k->xmss_state;
347 	u_char have_info;
348 	int r;
349 
350 	if (state == NULL)
351 		return SSH_ERR_INVALID_ARGUMENT;
352 	/* optional */
353 	if (sshbuf_len(b) == 0)
354 		return 0;
355 	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
356 		return r;
357 	if (have_info != 1)
358 		return SSH_ERR_INVALID_ARGUMENT;
359 	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
360 	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
361 		return r;
362 	return 0;
363 }
364 
365 int
366 sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
367 {
368 	int r;
369 	const char *name;
370 
371 	if (bits == 10) {
372 		name = XMSS_SHA2_256_W16_H10_NAME;
373 	} else if (bits == 16) {
374 		name = XMSS_SHA2_256_W16_H16_NAME;
375 	} else if (bits == 20) {
376 		name = XMSS_SHA2_256_W16_H20_NAME;
377 	} else {
378 		name = XMSS_DEFAULT_NAME;
379 	}
380 	if ((r = sshkey_xmss_init(k, name)) != 0 ||
381 	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
382 	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
383 		return r;
384 	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
385 	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
386 		return SSH_ERR_ALLOC_FAIL;
387 	}
388 	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
389 	    sshkey_xmss_params(k));
390 	return 0;
391 }
392 
393 int
394 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
395     int *have_file, sshkey_printfn *pr)
396 {
397 	struct sshbuf *b = NULL, *enc = NULL;
398 	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
399 	u_int32_t len;
400 	unsigned char buf[4], *data = NULL;
401 
402 	*have_file = 0;
403 	if ((fd = open(filename, O_RDONLY)) >= 0) {
404 		*have_file = 1;
405 		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
406 			PRINT("%s: corrupt state file: %s", __func__, filename);
407 			goto done;
408 		}
409 		len = PEEK_U32(buf);
410 		if ((data = calloc(len, 1)) == NULL) {
411 			ret = SSH_ERR_ALLOC_FAIL;
412 			goto done;
413 		}
414 		if (atomicio(read, fd, data, len) != len) {
415 			PRINT("%s: cannot read blob: %s", __func__, filename);
416 			goto done;
417 		}
418 		if ((enc = sshbuf_from(data, len)) == NULL) {
419 			ret = SSH_ERR_ALLOC_FAIL;
420 			goto done;
421 		}
422 		sshkey_xmss_free_bds(k);
423 		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
424 			ret = r;
425 			goto done;
426 		}
427 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
428 			ret = r;
429 			goto done;
430 		}
431 		ret = 0;
432 	}
433 done:
434 	if (fd != -1)
435 		close(fd);
436 	free(data);
437 	sshbuf_free(enc);
438 	sshbuf_free(b);
439 	return ret;
440 }
441 
442 int
443 sshkey_xmss_get_state(const struct sshkey *k, sshkey_printfn *pr)
444 {
445 	struct ssh_xmss_state *state = k->xmss_state;
446 	u_int32_t idx = 0;
447 	char *filename = NULL;
448 	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
449 	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
450 	int ret = SSH_ERR_INVALID_ARGUMENT, r;
451 
452 	if (state == NULL)
453 		goto done;
454 	/*
455 	 * If maxidx is set, then we are allowed a limited number
456 	 * of signatures, but don't need to access the disk.
457 	 * Otherwise we need to deal with the on-disk state.
458 	 */
459 	if (state->maxidx) {
460 		/* xmss_sk always contains the current state */
461 		idx = PEEK_U32(k->xmss_sk);
462 		if (idx < state->maxidx) {
463 			state->allow_update = 1;
464 			return 0;
465 		}
466 		return SSH_ERR_INVALID_ARGUMENT;
467 	}
468 	if ((filename = k->xmss_filename) == NULL)
469 		goto done;
470 	if (asprintf(&lockfile, "%s.lock", filename) < 0 ||
471 	    asprintf(&statefile, "%s.state", filename) < 0 ||
472 	    asprintf(&ostatefile, "%s.ostate", filename) < 0) {
473 		ret = SSH_ERR_ALLOC_FAIL;
474 		goto done;
475 	}
476 	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) < 0) {
477 		ret = SSH_ERR_SYSTEM_ERROR;
478 		PRINT("%s: cannot open/create: %s", __func__, lockfile);
479 		goto done;
480 	}
481 	while (flock(lockfd, LOCK_EX|LOCK_NB) < 0) {
482 		if (errno != EWOULDBLOCK) {
483 			ret = SSH_ERR_SYSTEM_ERROR;
484 			PRINT("%s: cannot lock: %s", __func__, lockfile);
485 			goto done;
486 		}
487 		if (++tries > 10) {
488 			ret = SSH_ERR_SYSTEM_ERROR;
489 			PRINT("%s: giving up on: %s", __func__, lockfile);
490 			goto done;
491 		}
492 		usleep(1000*100*tries);
493 	}
494 	/* XXX no longer const */
495 	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
496 	    statefile, &have_state, pr)) != 0) {
497 		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
498 		    ostatefile, &have_ostate, pr)) == 0) {
499 			state->allow_update = 1;
500 			r = sshkey_xmss_forward_state(k, 1);
501 			state->idx = PEEK_U32(k->xmss_sk);
502 			state->allow_update = 0;
503 		}
504 	}
505 	if (!have_state && !have_ostate) {
506 		/* check that bds state is initialized */
507 		if (state->bds.auth == NULL)
508 			goto done;
509 		PRINT("%s: start from scratch idx 0: %u", __func__, state->idx);
510 	} else if (r != 0) {
511 		ret = r;
512 		goto done;
513 	}
514 	if (state->idx + 1 < state->idx) {
515 		PRINT("%s: state wrap: %u", __func__, state->idx);
516 		goto done;
517 	}
518 	state->have_state = have_state;
519 	state->lockfd = lockfd;
520 	state->allow_update = 1;
521 	lockfd = -1;
522 	ret = 0;
523 done:
524 	if (lockfd != -1)
525 		close(lockfd);
526 	free(lockfile);
527 	free(statefile);
528 	free(ostatefile);
529 	return ret;
530 }
531 
532 int
533 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
534 {
535 	struct ssh_xmss_state *state = k->xmss_state;
536 	u_char *sig = NULL;
537 	size_t required_siglen;
538 	unsigned long long smlen;
539 	u_char data;
540 	int ret, r;
541 
542 	if (state == NULL || !state->allow_update)
543 		return SSH_ERR_INVALID_ARGUMENT;
544 	if (reserve == 0)
545 		return SSH_ERR_INVALID_ARGUMENT;
546 	if (state->idx + reserve <= state->idx)
547 		return SSH_ERR_INVALID_ARGUMENT;
548 	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
549 		return r;
550 	if ((sig = malloc(required_siglen)) == NULL)
551 		return SSH_ERR_ALLOC_FAIL;
552 	while (reserve-- > 0) {
553 		state->idx = PEEK_U32(k->xmss_sk);
554 		smlen = required_siglen;
555 		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
556 		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
557 			r = SSH_ERR_INVALID_ARGUMENT;
558 			break;
559 		}
560 	}
561 	free(sig);
562 	return r;
563 }
564 
/*
 * Write the key's BDS state back to disk after signing.
 *
 * Requires allow_update (set by a successful sshkey_xmss_get_state()).
 * With a signature limit (maxidx) nothing is written.  Otherwise the index
 * in xmss_sk must equal state->idx (no signature happened: nothing to do)
 * or state->idx + 1 (exactly one signature); anything else is an error.
 *
 * Steps:
 * - Serialize and encrypt the state.
 * - Write it to "<file>.nstate" as a 4-byte length followed by the blob,
 *   then fsync() and close it.
 * - If a ".state" file was loaded, hard-link it to ".ostate" as a backup.
 * - rename() ".nstate" over ".state".
 *
 * Every exit through 'done' closes the lock fd taken by
 * sshkey_xmss_get_state() and removes any leftover ".nstate".  Returns 0
 * or an SSH_ERR_* code.
 */
int
sshkey_xmss_update_state(const struct sshkey *k, sshkey_printfn *pr)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *b = NULL, *enc = NULL;
	u_int32_t idx = 0;
	unsigned char buf[4];
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
	int fd = -1;
	int ret = SSH_ERR_INVALID_ARGUMENT;

	/* only allowed after sshkey_xmss_get_state() succeeded */
	if (state == NULL || !state->allow_update)
		return ret;
	if (state->maxidx) {
		/* no update since the number of signatures is limited */
		ret = 0;
		goto done;
	}
	idx = PEEK_U32(k->xmss_sk);
	if (idx == state->idx) {
		/* no signature happened, no need to update */
		ret = 0;
		goto done;
	} else if (idx != state->idx + 1) {
		PRINT("%s: more than one signature happened: idx %u state %u",
		     __func__, idx, state->idx);
		goto done;
	}
	/* the in-memory index advances even if persisting fails below */
	state->idx = idx;
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&statefile, "%s.state", filename) < 0 ||
	    asprintf(&ostatefile, "%s.ostate", filename) < 0 ||
	    asprintf(&nstatefile, "%s.nstate", filename) < 0) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	/* drop a stale temporary so the O_EXCL open below can succeed */
	unlink(nstatefile);
	if ((b = sshbuf_new()) == NULL) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
		PRINT("%s: SERLIALIZE FAILED: %d", __func__, ret);
		goto done;
	}
	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
		PRINT("%s: ENCRYPT FAILED: %d", __func__, ret);
		goto done;
	}
	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) < 0) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: open new state file: %s", __func__, nstatefile);
		goto done;
	}
	/* 4-byte length header; read back by sshkey_xmss_get_state_from_file() */
	POKE_U32(buf, sshbuf_len(enc));
	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: write new state file hdr: %s", __func__, nstatefile);
		close(fd);
		goto done;
	}
	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
	    sshbuf_len(enc)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: write new state file data: %s", __func__, nstatefile);
		close(fd);
		goto done;
	}
	/* make sure the new state is on stable storage before it replaces .state */
	if (fsync(fd) < 0) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: sync new state file: %s", __func__, nstatefile);
		close(fd);
		goto done;
	}
	if (close(fd) < 0) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: close new state file: %s", __func__, nstatefile);
		goto done;
	}
	/* keep the current state as the .ostate backup */
	if (state->have_state) {
		unlink(ostatefile);
		if (link(statefile, ostatefile)) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("%s: backup state %s to %s", __func__, statefile,
			    ostatefile);
			goto done;
		}
	}
	/* replace .state with the fully written and synced new file */
	if (rename(nstatefile, statefile) < 0) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("%s: rename %s to %s", __func__, nstatefile, statefile);
		goto done;
	}
	ret = 0;
done:
	/* release the lock taken in sshkey_xmss_get_state() */
	if (state->lockfd != -1) {
		close(state->lockfd);
		state->lockfd = -1;
	}
	if (nstatefile)
		unlink(nstatefile);
	free(statefile);
	free(ostatefile);
	free(nstatefile);
	sshbuf_free(b);
	sshbuf_free(enc);
	return ret;
}
675 
676 int
677 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
678 {
679 	struct ssh_xmss_state *state = k->xmss_state;
680 	treehash_inst *th;
681 	u_int32_t i, node;
682 	int r;
683 
684 	if (state == NULL)
685 		return SSH_ERR_INVALID_ARGUMENT;
686 	if (state->stack == NULL)
687 		return SSH_ERR_INVALID_ARGUMENT;
688 	state->stackoffset = state->bds.stackoffset;	/* copy back */
689 	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
690 	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
691 	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
692 	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
693 	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
694 	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
695 	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
696 	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
697 	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
698 	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
699 		return r;
700 	for (i = 0; i < num_treehash(state); i++) {
701 		th = &state->treehash[i];
702 		node = th->node - state->th_nodes;
703 		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
704 		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
705 		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
706 		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
707 		    (r = sshbuf_put_u32(b, node)) != 0)
708 			return r;
709 	}
710 	return 0;
711 }
712 
713 int
714 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
715     enum sshkey_serialize_rep opts)
716 {
717 	struct ssh_xmss_state *state = k->xmss_state;
718 	int r = SSH_ERR_INVALID_ARGUMENT;
719 
720 	if (state == NULL)
721 		return SSH_ERR_INVALID_ARGUMENT;
722 	if ((r = sshbuf_put_u8(b, opts)) != 0)
723 		return r;
724 	switch (opts) {
725 	case SSHKEY_SERIALIZE_STATE:
726 		r = sshkey_xmss_serialize_state(k, b);
727 		break;
728 	case SSHKEY_SERIALIZE_FULL:
729 		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
730 			break;
731 		r = sshkey_xmss_serialize_state(k, b);
732 		break;
733 	case SSHKEY_SERIALIZE_DEFAULT:
734 		r = 0;
735 		break;
736 	default:
737 		r = SSH_ERR_INVALID_ARGUMENT;
738 		break;
739 	}
740 	return r;
741 }
742 
743 int
744 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
745 {
746 	struct ssh_xmss_state *state = k->xmss_state;
747 	treehash_inst *th;
748 	u_int32_t i, lh, node;
749 	size_t ls, lsl, la, lk, ln, lr;
750 	char *magic;
751 	int r;
752 
753 	if (state == NULL)
754 		return SSH_ERR_INVALID_ARGUMENT;
755 	if (k->xmss_sk == NULL)
756 		return SSH_ERR_INVALID_ARGUMENT;
757 	if ((state->treehash = calloc(num_treehash(state),
758 	    sizeof(treehash_inst))) == NULL)
759 		return SSH_ERR_ALLOC_FAIL;
760 	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
761 	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
762 	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
763 	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
764 	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
765 	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
766 	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
767 	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
768 	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
769 	    (r = sshbuf_get_u32(b, &lh)) != 0)
770 		return r;
771 	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0)
772 		return SSH_ERR_INVALID_ARGUMENT;
773 	/* XXX check stackoffset */
774 	if (ls != num_stack(state) ||
775 	    lsl != num_stacklevels(state) ||
776 	    la != num_auth(state) ||
777 	    lk != num_keep(state) ||
778 	    ln != num_th_nodes(state) ||
779 	    lr != num_retain(state) ||
780 	    lh != num_treehash(state))
781 		return SSH_ERR_INVALID_ARGUMENT;
782 	for (i = 0; i < num_treehash(state); i++) {
783 		th = &state->treehash[i];
784 		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
785 		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
786 		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
787 		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
788 		    (r = sshbuf_get_u32(b, &node)) != 0)
789 			return r;
790 		if (node < num_th_nodes(state))
791 			th->node = &state->th_nodes[node];
792 	}
793 	POKE_U32(k->xmss_sk, state->idx);
794 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
795 	    state->stacklevels, state->auth, state->keep, state->treehash,
796 	    state->retain, 0);
797 	return 0;
798 }
799 
800 int
801 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
802 {
803 	enum sshkey_serialize_rep opts;
804 	u_char have_state;
805 	int r;
806 
807 	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
808 		return r;
809 
810 	opts = have_state;
811 	switch (opts) {
812 	case SSHKEY_SERIALIZE_DEFAULT:
813 		r = 0;
814 		break;
815 	case SSHKEY_SERIALIZE_STATE:
816 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
817 			return r;
818 		break;
819 	case SSHKEY_SERIALIZE_FULL:
820 		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
821 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
822 			return r;
823 		break;
824 	default:
825 		r = SSH_ERR_INVALID_FORMAT;
826 		break;
827 	}
828 	return r;
829 }
830 
831 int
832 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
833    struct sshbuf **retp)
834 {
835 	struct ssh_xmss_state *state = k->xmss_state;
836 	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
837 	struct sshcipher_ctx *ciphercontext = NULL;
838 	const struct sshcipher *cipher;
839 	u_char *cp, *key, *iv = NULL;
840 	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
841 	int r = SSH_ERR_INTERNAL_ERROR;
842 
843 	if (retp != NULL)
844 		*retp = NULL;
845 	if (state == NULL ||
846 	    state->enc_keyiv == NULL ||
847 	    state->enc_ciphername == NULL)
848 		return SSH_ERR_INTERNAL_ERROR;
849 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
850 		r = SSH_ERR_INTERNAL_ERROR;
851 		goto out;
852 	}
853 	blocksize = cipher_blocksize(cipher);
854 	keylen = cipher_keylen(cipher);
855 	ivlen = cipher_ivlen(cipher);
856 	authlen = cipher_authlen(cipher);
857 	if (state->enc_keyiv_len != keylen + ivlen) {
858 		r = SSH_ERR_INVALID_FORMAT;
859 		goto out;
860 	}
861 	key = state->enc_keyiv;
862 	if ((encrypted = sshbuf_new()) == NULL ||
863 	    (encoded = sshbuf_new()) == NULL ||
864 	    (padded = sshbuf_new()) == NULL ||
865 	    (iv = malloc(ivlen)) == NULL) {
866 		r = SSH_ERR_ALLOC_FAIL;
867 		goto out;
868 	}
869 
870 	/* replace first 4 bytes of IV with index to ensure uniqueness */
871 	memcpy(iv, key + keylen, ivlen);
872 	POKE_U32(iv, state->idx);
873 
874 	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
875 	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
876 		goto out;
877 
878 	/* padded state will be encrypted */
879 	if ((r = sshbuf_putb(padded, b)) != 0)
880 		goto out;
881 	i = 0;
882 	while (sshbuf_len(padded) % blocksize) {
883 		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
884 			goto out;
885 	}
886 	encrypted_len = sshbuf_len(padded);
887 
888 	/* header including the length of state is used as AAD */
889 	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
890 		goto out;
891 	aadlen = sshbuf_len(encoded);
892 
893 	/* concat header and state */
894 	if ((r = sshbuf_putb(encoded, padded)) != 0)
895 		goto out;
896 
897 	/* reserve space for encryption of encoded data plus auth tag */
898 	/* encrypt at offset addlen */
899 	if ((r = sshbuf_reserve(encrypted,
900 	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
901 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
902 	    iv, ivlen, 1)) != 0 ||
903 	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
904 	    encrypted_len, aadlen, authlen)) != 0)
905 		goto out;
906 
907 	/* success */
908 	r = 0;
909  out:
910 	if (retp != NULL) {
911 		*retp = encrypted;
912 		encrypted = NULL;
913 	}
914 	sshbuf_free(padded);
915 	sshbuf_free(encoded);
916 	sshbuf_free(encrypted);
917 	cipher_free(ciphercontext);
918 	free(iv);
919 	return r;
920 }
921 
922 int
923 sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
924    struct sshbuf **retp)
925 {
926 	struct ssh_xmss_state *state = k->xmss_state;
927 	struct sshbuf *copy = NULL, *decrypted = NULL;
928 	struct sshcipher_ctx *ciphercontext = NULL;
929 	const struct sshcipher *cipher = NULL;
930 	u_char *key, *iv = NULL, *dp;
931 	size_t keylen, ivlen, authlen, aadlen;
932 	u_int blocksize, encrypted_len, index;
933 	int r = SSH_ERR_INTERNAL_ERROR;
934 
935 	if (retp != NULL)
936 		*retp = NULL;
937 	if (state == NULL ||
938 	    state->enc_keyiv == NULL ||
939 	    state->enc_ciphername == NULL)
940 		return SSH_ERR_INTERNAL_ERROR;
941 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
942 		r = SSH_ERR_INVALID_FORMAT;
943 		goto out;
944 	}
945 	blocksize = cipher_blocksize(cipher);
946 	keylen = cipher_keylen(cipher);
947 	ivlen = cipher_ivlen(cipher);
948 	authlen = cipher_authlen(cipher);
949 	if (state->enc_keyiv_len != keylen + ivlen) {
950 		r = SSH_ERR_INTERNAL_ERROR;
951 		goto out;
952 	}
953 	key = state->enc_keyiv;
954 
955 	if ((copy = sshbuf_fromb(encoded)) == NULL ||
956 	    (decrypted = sshbuf_new()) == NULL ||
957 	    (iv = malloc(ivlen)) == NULL) {
958 		r = SSH_ERR_ALLOC_FAIL;
959 		goto out;
960 	}
961 
962 	/* check magic */
963 	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
964 	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
965 		r = SSH_ERR_INVALID_FORMAT;
966 		goto out;
967 	}
968 	/* parse public portion */
969 	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
970 	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
971 	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
972 		goto out;
973 
974 	/* check size of encrypted key blob */
975 	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
976 		r = SSH_ERR_INVALID_FORMAT;
977 		goto out;
978 	}
979 	/* check that an appropriate amount of auth data is present */
980 	if (sshbuf_len(encoded) < encrypted_len + authlen) {
981 		r = SSH_ERR_INVALID_FORMAT;
982 		goto out;
983 	}
984 
985 	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);
986 
987 	/* replace first 4 bytes of IV with index to ensure uniqueness */
988 	memcpy(iv, key + keylen, ivlen);
989 	POKE_U32(iv, index);
990 
991 	/* decrypt private state of key */
992 	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
993 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
994 	    iv, ivlen, 0)) != 0 ||
995 	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
996 	    encrypted_len, aadlen, authlen)) != 0)
997 		goto out;
998 
999 	/* there should be no trailing data */
1000 	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
1001 		goto out;
1002 	if (sshbuf_len(encoded) != 0) {
1003 		r = SSH_ERR_INVALID_FORMAT;
1004 		goto out;
1005 	}
1006 
1007 	/* remove AAD */
1008 	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
1009 		goto out;
1010 	/* XXX encrypted includes unchecked padding */
1011 
1012 	/* success */
1013 	r = 0;
1014 	if (retp != NULL) {
1015 		*retp = decrypted;
1016 		decrypted = NULL;
1017 	}
1018  out:
1019 	cipher_free(ciphercontext);
1020 	sshbuf_free(copy);
1021 	sshbuf_free(decrypted);
1022 	free(iv);
1023 	return r;
1024 }
1025 
1026 u_int32_t
1027 sshkey_xmss_signatures_left(const struct sshkey *k)
1028 {
1029 	struct ssh_xmss_state *state = k->xmss_state;
1030 	u_int32_t idx;
1031 
1032 	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1033 	    state->maxidx) {
1034 		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1035 		if (idx < state->maxidx)
1036 			return state->maxidx - idx;
1037 	}
1038 	return 0;
1039 }
1040 
1041 int
1042 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1043 {
1044 	struct ssh_xmss_state *state = k->xmss_state;
1045 
1046 	if (sshkey_type_plain(k->type) != KEY_XMSS)
1047 		return SSH_ERR_INVALID_ARGUMENT;
1048 	if (maxsign == 0)
1049 		return 0;
1050 	if (state->idx + maxsign < state->idx)
1051 		return SSH_ERR_INVALID_ARGUMENT;
1052 	state->maxidx = state->idx + maxsign;
1053 	return 0;
1054 }
1055 #endif /* WITH_XMSS */
1056