xref: /freebsd/crypto/openssh/sshkey-xmss.c (revision 9729f076e4d93c5a37e78d427bfe0f1ab99bbcc6)
1 /* $OpenBSD: sshkey-xmss.c,v 1.11 2021/04/03 06:18:41 djm Exp $ */
2 /*
3  * Copyright (c) 2017 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include "includes.h"
27 #ifdef WITH_XMSS
28 
29 #include <sys/types.h>
30 #include <sys/uio.h>
31 
32 #include <stdio.h>
33 #include <string.h>
34 #include <unistd.h>
35 #include <fcntl.h>
36 #include <errno.h>
37 #ifdef HAVE_SYS_FILE_H
38 # include <sys/file.h>
39 #endif
40 
41 #include "ssh2.h"
42 #include "ssherr.h"
43 #include "sshbuf.h"
44 #include "cipher.h"
45 #include "sshkey.h"
46 #include "sshkey-xmss.h"
47 #include "atomicio.h"
48 #include "log.h"
49 
50 #include "xmss_fast.h"
51 
/* opaque internal XMSS state */
/* leading magic of an encrypted state blob; written with its NUL byte */
#define XMSS_MAGIC		"xmss-state-v1"
/* cipher used to encrypt the on-disk state */
#define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
struct ssh_xmss_state {
	xmss_params	params;	/* filled in by xmss_set_params() */
	/* hash length (bytes), Winternitz w, tree height, BDS k (fixed to 2) */
	u_int32_t	n, w, h, k;

	/* traversal state passed to xmss_keypair()/xmss_sign() */
	bds_state	bds;
	/* buffers referenced by bds; sizes are given by the num_*() macros */
	u_char		*stack;
	u_int32_t	stackoffset;	/* copy of bds.stackoffset, synced on serialize */
	u_char		*stacklevels;
	u_char		*auth;
	u_char		*keep;
	u_char		*th_nodes;	/* n-byte node slots of the treehash instances */
	u_char		*retain;
	treehash_inst	*treehash;

	u_int32_t	idx;		/* state read from file */
	u_int32_t	maxidx;		/* restricted # of signatures */
	int		have_state;	/* .state file exists */
	int		lockfd;		/* locked in sshkey_xmss_get_state() */
	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
	char		*enc_ciphername;/* encrypt state with cipher */
	u_char		*enc_keyiv;	/* encrypt state with key */
	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
};
78 
/* Forward declarations for helpers defined later in this file. */
int	 sshkey_xmss_init_bds_state(struct sshkey *);
int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
void	 sshkey_xmss_free_bds(struct sshkey *);
int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
	    int *, int);
int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);

/*
 * Log an error through sshlog() at SYSLOG_LEVEL_ERROR, but only when a
 * variable named 'printerror' in the calling scope is nonzero.
 */
#define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \
    0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); } while (0)
93 
94 int
95 sshkey_xmss_init(struct sshkey *key, const char *name)
96 {
97 	struct ssh_xmss_state *state;
98 
99 	if (key->xmss_state != NULL)
100 		return SSH_ERR_INVALID_FORMAT;
101 	if (name == NULL)
102 		return SSH_ERR_INVALID_FORMAT;
103 	state = calloc(sizeof(struct ssh_xmss_state), 1);
104 	if (state == NULL)
105 		return SSH_ERR_ALLOC_FAIL;
106 	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
107 		state->n = 32;
108 		state->w = 16;
109 		state->h = 10;
110 	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
111 		state->n = 32;
112 		state->w = 16;
113 		state->h = 16;
114 	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
115 		state->n = 32;
116 		state->w = 16;
117 		state->h = 20;
118 	} else {
119 		free(state);
120 		return SSH_ERR_KEY_TYPE_UNKNOWN;
121 	}
122 	if ((key->xmss_name = strdup(name)) == NULL) {
123 		free(state);
124 		return SSH_ERR_ALLOC_FAIL;
125 	}
126 	state->k = 2;	/* XXX hardcoded */
127 	state->lockfd = -1;
128 	if (xmss_set_params(&state->params, state->n, state->h, state->w,
129 	    state->k) != 0) {
130 		free(state);
131 		return SSH_ERR_INVALID_FORMAT;
132 	}
133 	key->xmss_state = state;
134 	return 0;
135 }
136 
137 void
138 sshkey_xmss_free_state(struct sshkey *key)
139 {
140 	struct ssh_xmss_state *state = key->xmss_state;
141 
142 	sshkey_xmss_free_bds(key);
143 	if (state) {
144 		if (state->enc_keyiv) {
145 			explicit_bzero(state->enc_keyiv, state->enc_keyiv_len);
146 			free(state->enc_keyiv);
147 		}
148 		free(state->enc_ciphername);
149 		free(state);
150 	}
151 	key->xmss_state = NULL;
152 }
153 
/* Magic string at the start of serialized BDS state (BDS parameter k == 2). */
#define SSH_XMSS_K2_MAGIC	"k=2"
/*
 * Sizes of the BDS traversal buffers for a state pointer 'x' with tree
 * height h, hash length n and BDS parameter k.  All are byte counts except
 * num_stacklevels (one byte per stack entry) and num_treehash (number of
 * treehash_inst elements).  Arguments are parenthesized so any pointer
 * expression may be passed.
 */
#define num_stack(x)		(((x)->h + 1) * ((x)->n))
#define num_stacklevels(x)	((x)->h + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		(((x)->h >> 1) * ((x)->n))
#define num_th_nodes(x)		(((x)->h - (x)->k) * ((x)->n))
#define num_retain(x)		(((1ULL << (x)->k) - (x)->k - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))
162 
163 int
164 sshkey_xmss_init_bds_state(struct sshkey *key)
165 {
166 	struct ssh_xmss_state *state = key->xmss_state;
167 	u_int32_t i;
168 
169 	state->stackoffset = 0;
170 	if ((state->stack = calloc(num_stack(state), 1)) == NULL ||
171 	    (state->stacklevels = calloc(num_stacklevels(state), 1))== NULL ||
172 	    (state->auth = calloc(num_auth(state), 1)) == NULL ||
173 	    (state->keep = calloc(num_keep(state), 1)) == NULL ||
174 	    (state->th_nodes = calloc(num_th_nodes(state), 1)) == NULL ||
175 	    (state->retain = calloc(num_retain(state), 1)) == NULL ||
176 	    (state->treehash = calloc(num_treehash(state),
177 	    sizeof(treehash_inst))) == NULL) {
178 		sshkey_xmss_free_bds(key);
179 		return SSH_ERR_ALLOC_FAIL;
180 	}
181 	for (i = 0; i < state->h - state->k; i++)
182 		state->treehash[i].node = &state->th_nodes[state->n*i];
183 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
184 	    state->stacklevels, state->auth, state->keep, state->treehash,
185 	    state->retain, 0);
186 	return 0;
187 }
188 
189 void
190 sshkey_xmss_free_bds(struct sshkey *key)
191 {
192 	struct ssh_xmss_state *state = key->xmss_state;
193 
194 	if (state == NULL)
195 		return;
196 	free(state->stack);
197 	free(state->stacklevels);
198 	free(state->auth);
199 	free(state->keep);
200 	free(state->th_nodes);
201 	free(state->retain);
202 	free(state->treehash);
203 	state->stack = NULL;
204 	state->stacklevels = NULL;
205 	state->auth = NULL;
206 	state->keep = NULL;
207 	state->th_nodes = NULL;
208 	state->retain = NULL;
209 	state->treehash = NULL;
210 }
211 
212 void *
213 sshkey_xmss_params(const struct sshkey *key)
214 {
215 	struct ssh_xmss_state *state = key->xmss_state;
216 
217 	if (state == NULL)
218 		return NULL;
219 	return &state->params;
220 }
221 
222 void *
223 sshkey_xmss_bds_state(const struct sshkey *key)
224 {
225 	struct ssh_xmss_state *state = key->xmss_state;
226 
227 	if (state == NULL)
228 		return NULL;
229 	return &state->bds;
230 }
231 
232 int
233 sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
234 {
235 	struct ssh_xmss_state *state = key->xmss_state;
236 
237 	if (lenp == NULL)
238 		return SSH_ERR_INVALID_ARGUMENT;
239 	if (state == NULL)
240 		return SSH_ERR_INVALID_FORMAT;
241 	*lenp = 4 + state->n +
242 	    state->params.wots_par.keysize +
243 	    state->h * state->n;
244 	return 0;
245 }
246 
247 size_t
248 sshkey_xmss_pklen(const struct sshkey *key)
249 {
250 	struct ssh_xmss_state *state = key->xmss_state;
251 
252 	if (state == NULL)
253 		return 0;
254 	return state->n * 2;
255 }
256 
257 size_t
258 sshkey_xmss_sklen(const struct sshkey *key)
259 {
260 	struct ssh_xmss_state *state = key->xmss_state;
261 
262 	if (state == NULL)
263 		return 0;
264 	return state->n * 4 + 4;
265 }
266 
267 int
268 sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
269 {
270 	struct ssh_xmss_state *state = k->xmss_state;
271 	const struct sshcipher *cipher;
272 	size_t keylen = 0, ivlen = 0;
273 
274 	if (state == NULL)
275 		return SSH_ERR_INVALID_ARGUMENT;
276 	if ((cipher = cipher_by_name(ciphername)) == NULL)
277 		return SSH_ERR_INTERNAL_ERROR;
278 	if ((state->enc_ciphername = strdup(ciphername)) == NULL)
279 		return SSH_ERR_ALLOC_FAIL;
280 	keylen = cipher_keylen(cipher);
281 	ivlen = cipher_ivlen(cipher);
282 	state->enc_keyiv_len = keylen + ivlen;
283 	if ((state->enc_keyiv = calloc(state->enc_keyiv_len, 1)) == NULL) {
284 		free(state->enc_ciphername);
285 		state->enc_ciphername = NULL;
286 		return SSH_ERR_ALLOC_FAIL;
287 	}
288 	arc4random_buf(state->enc_keyiv, state->enc_keyiv_len);
289 	return 0;
290 }
291 
292 int
293 sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
294 {
295 	struct ssh_xmss_state *state = k->xmss_state;
296 	int r;
297 
298 	if (state == NULL || state->enc_keyiv == NULL ||
299 	    state->enc_ciphername == NULL)
300 		return SSH_ERR_INVALID_ARGUMENT;
301 	if ((r = sshbuf_put_cstring(b, state->enc_ciphername)) != 0 ||
302 	    (r = sshbuf_put_string(b, state->enc_keyiv,
303 	    state->enc_keyiv_len)) != 0)
304 		return r;
305 	return 0;
306 }
307 
308 int
309 sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
310 {
311 	struct ssh_xmss_state *state = k->xmss_state;
312 	size_t len;
313 	int r;
314 
315 	if (state == NULL)
316 		return SSH_ERR_INVALID_ARGUMENT;
317 	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
318 	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
319 		return r;
320 	state->enc_keyiv_len = len;
321 	return 0;
322 }
323 
324 int
325 sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
326     enum sshkey_serialize_rep opts)
327 {
328 	struct ssh_xmss_state *state = k->xmss_state;
329 	u_char have_info = 1;
330 	u_int32_t idx;
331 	int r;
332 
333 	if (state == NULL)
334 		return SSH_ERR_INVALID_ARGUMENT;
335 	if (opts != SSHKEY_SERIALIZE_INFO)
336 		return 0;
337 	idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
338 	if ((r = sshbuf_put_u8(b, have_info)) != 0 ||
339 	    (r = sshbuf_put_u32(b, idx)) != 0 ||
340 	    (r = sshbuf_put_u32(b, state->maxidx)) != 0)
341 		return r;
342 	return 0;
343 }
344 
345 int
346 sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
347 {
348 	struct ssh_xmss_state *state = k->xmss_state;
349 	u_char have_info;
350 	int r;
351 
352 	if (state == NULL)
353 		return SSH_ERR_INVALID_ARGUMENT;
354 	/* optional */
355 	if (sshbuf_len(b) == 0)
356 		return 0;
357 	if ((r = sshbuf_get_u8(b, &have_info)) != 0)
358 		return r;
359 	if (have_info != 1)
360 		return SSH_ERR_INVALID_ARGUMENT;
361 	if ((r = sshbuf_get_u32(b, &state->idx)) != 0 ||
362 	    (r = sshbuf_get_u32(b, &state->maxidx)) != 0)
363 		return r;
364 	return 0;
365 }
366 
367 int
368 sshkey_xmss_generate_private_key(struct sshkey *k, u_int bits)
369 {
370 	int r;
371 	const char *name;
372 
373 	if (bits == 10) {
374 		name = XMSS_SHA2_256_W16_H10_NAME;
375 	} else if (bits == 16) {
376 		name = XMSS_SHA2_256_W16_H16_NAME;
377 	} else if (bits == 20) {
378 		name = XMSS_SHA2_256_W16_H20_NAME;
379 	} else {
380 		name = XMSS_DEFAULT_NAME;
381 	}
382 	if ((r = sshkey_xmss_init(k, name)) != 0 ||
383 	    (r = sshkey_xmss_init_bds_state(k)) != 0 ||
384 	    (r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
385 		return r;
386 	if ((k->xmss_pk = malloc(sshkey_xmss_pklen(k))) == NULL ||
387 	    (k->xmss_sk = malloc(sshkey_xmss_sklen(k))) == NULL) {
388 		return SSH_ERR_ALLOC_FAIL;
389 	}
390 	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
391 	    sshkey_xmss_params(k));
392 	return 0;
393 }
394 
395 int
396 sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
397     int *have_file, int printerror)
398 {
399 	struct sshbuf *b = NULL, *enc = NULL;
400 	int ret = SSH_ERR_SYSTEM_ERROR, r, fd = -1;
401 	u_int32_t len;
402 	unsigned char buf[4], *data = NULL;
403 
404 	*have_file = 0;
405 	if ((fd = open(filename, O_RDONLY)) >= 0) {
406 		*have_file = 1;
407 		if (atomicio(read, fd, buf, sizeof(buf)) != sizeof(buf)) {
408 			PRINT("corrupt state file: %s", filename);
409 			goto done;
410 		}
411 		len = PEEK_U32(buf);
412 		if ((data = calloc(len, 1)) == NULL) {
413 			ret = SSH_ERR_ALLOC_FAIL;
414 			goto done;
415 		}
416 		if (atomicio(read, fd, data, len) != len) {
417 			PRINT("cannot read blob: %s", filename);
418 			goto done;
419 		}
420 		if ((enc = sshbuf_from(data, len)) == NULL) {
421 			ret = SSH_ERR_ALLOC_FAIL;
422 			goto done;
423 		}
424 		sshkey_xmss_free_bds(k);
425 		if ((r = sshkey_xmss_decrypt_state(k, enc, &b)) != 0) {
426 			ret = r;
427 			goto done;
428 		}
429 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0) {
430 			ret = r;
431 			goto done;
432 		}
433 		ret = 0;
434 	}
435 done:
436 	if (fd != -1)
437 		close(fd);
438 	free(data);
439 	sshbuf_free(enc);
440 	sshbuf_free(b);
441 	return ret;
442 }
443 
444 int
445 sshkey_xmss_get_state(const struct sshkey *k, int printerror)
446 {
447 	struct ssh_xmss_state *state = k->xmss_state;
448 	u_int32_t idx = 0;
449 	char *filename = NULL;
450 	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
451 	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
452 	int ret = SSH_ERR_INVALID_ARGUMENT, r;
453 
454 	if (state == NULL)
455 		goto done;
456 	/*
457 	 * If maxidx is set, then we are allowed a limited number
458 	 * of signatures, but don't need to access the disk.
459 	 * Otherwise we need to deal with the on-disk state.
460 	 */
461 	if (state->maxidx) {
462 		/* xmss_sk always contains the current state */
463 		idx = PEEK_U32(k->xmss_sk);
464 		if (idx < state->maxidx) {
465 			state->allow_update = 1;
466 			return 0;
467 		}
468 		return SSH_ERR_INVALID_ARGUMENT;
469 	}
470 	if ((filename = k->xmss_filename) == NULL)
471 		goto done;
472 	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
473 	    asprintf(&statefile, "%s.state", filename) == -1 ||
474 	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
475 		ret = SSH_ERR_ALLOC_FAIL;
476 		goto done;
477 	}
478 	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
479 		ret = SSH_ERR_SYSTEM_ERROR;
480 		PRINT("cannot open/create: %s", lockfile);
481 		goto done;
482 	}
483 	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
484 		if (errno != EWOULDBLOCK) {
485 			ret = SSH_ERR_SYSTEM_ERROR;
486 			PRINT("cannot lock: %s", lockfile);
487 			goto done;
488 		}
489 		if (++tries > 10) {
490 			ret = SSH_ERR_SYSTEM_ERROR;
491 			PRINT("giving up on: %s", lockfile);
492 			goto done;
493 		}
494 		usleep(1000*100*tries);
495 	}
496 	/* XXX no longer const */
497 	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
498 	    statefile, &have_state, printerror)) != 0) {
499 		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
500 		    ostatefile, &have_ostate, printerror)) == 0) {
501 			state->allow_update = 1;
502 			r = sshkey_xmss_forward_state(k, 1);
503 			state->idx = PEEK_U32(k->xmss_sk);
504 			state->allow_update = 0;
505 		}
506 	}
507 	if (!have_state && !have_ostate) {
508 		/* check that bds state is initialized */
509 		if (state->bds.auth == NULL)
510 			goto done;
511 		PRINT("start from scratch idx 0: %u", state->idx);
512 	} else if (r != 0) {
513 		ret = r;
514 		goto done;
515 	}
516 	if (state->idx + 1 < state->idx) {
517 		PRINT("state wrap: %u", state->idx);
518 		goto done;
519 	}
520 	state->have_state = have_state;
521 	state->lockfd = lockfd;
522 	state->allow_update = 1;
523 	lockfd = -1;
524 	ret = 0;
525 done:
526 	if (lockfd != -1)
527 		close(lockfd);
528 	free(lockfile);
529 	free(statefile);
530 	free(ostatefile);
531 	return ret;
532 }
533 
534 int
535 sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
536 {
537 	struct ssh_xmss_state *state = k->xmss_state;
538 	u_char *sig = NULL;
539 	size_t required_siglen;
540 	unsigned long long smlen;
541 	u_char data;
542 	int ret, r;
543 
544 	if (state == NULL || !state->allow_update)
545 		return SSH_ERR_INVALID_ARGUMENT;
546 	if (reserve == 0)
547 		return SSH_ERR_INVALID_ARGUMENT;
548 	if (state->idx + reserve <= state->idx)
549 		return SSH_ERR_INVALID_ARGUMENT;
550 	if ((r = sshkey_xmss_siglen(k, &required_siglen)) != 0)
551 		return r;
552 	if ((sig = malloc(required_siglen)) == NULL)
553 		return SSH_ERR_ALLOC_FAIL;
554 	while (reserve-- > 0) {
555 		state->idx = PEEK_U32(k->xmss_sk);
556 		smlen = required_siglen;
557 		if ((ret = xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
558 		    sig, &smlen, &data, 0, sshkey_xmss_params(k))) != 0) {
559 			r = SSH_ERR_INVALID_ARGUMENT;
560 			break;
561 		}
562 	}
563 	free(sig);
564 	return r;
565 }
566 
/*
 * Persist the XMSS state after (at most) one signature.
 *
 * Requires state->allow_update (set by sshkey_xmss_get_state()).  With a
 * signature limit (maxidx) nothing is written.  If the index in
 * k->xmss_sk is unchanged nothing is written either; if it advanced by
 * more than one the update is refused.  Otherwise the state is serialized,
 * encrypted and written (fsync'ed) to "<xmss_filename>.nstate".  When a
 * ".state" file was loaded it is first hard-linked to ".ostate" as a
 * backup, and then ".nstate" is renamed over ".state".
 * Unless the initial NULL/allow_update check fails, the lock held in
 * state->lockfd is closed on exit.
 */
int
sshkey_xmss_update_state(const struct sshkey *k, int printerror)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *b = NULL, *enc = NULL;
	u_int32_t idx = 0;
	unsigned char buf[4];
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
	int fd = -1;
	int ret = SSH_ERR_INVALID_ARGUMENT;

	if (state == NULL || !state->allow_update)
		return ret;
	if (state->maxidx) {
		/* no update since the number of signatures is limited */
		ret = 0;
		goto done;
	}
	idx = PEEK_U32(k->xmss_sk);
	if (idx == state->idx) {
		/* no signature happened, no need to update */
		ret = 0;
		goto done;
	} else if (idx != state->idx + 1) {
		PRINT("more than one signature happened: idx %u state %u",
		    idx, state->idx);
		goto done;
	}
	/*
	 * NOTE(review): state->idx is advanced before the file is written;
	 * a later write failure leaves it ahead of the on-disk state.
	 */
	state->idx = idx;
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&statefile, "%s.state", filename) == -1 ||
	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	/* remove a stale temporary so the O_EXCL open below can succeed */
	unlink(nstatefile);
	if ((b = sshbuf_new()) == NULL) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
		PRINT("SERLIALIZE FAILED: %d", ret);
		goto done;
	}
	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
		PRINT("ENCRYPT FAILED: %d", ret);
		goto done;
	}
	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("open new state file: %s", nstatefile);
		goto done;
	}
	/* file format: 4-byte big-endian length, then the encrypted blob */
	POKE_U32(buf, sshbuf_len(enc));
	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("write new state file hdr: %s", nstatefile);
		close(fd);
		goto done;
	}
	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
	    sshbuf_len(enc)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("write new state file data: %s", nstatefile);
		close(fd);
		goto done;
	}
	/* make the new state durable before it replaces the old one */
	if (fsync(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("sync new state file: %s", nstatefile);
		close(fd);
		goto done;
	}
	if (close(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("close new state file: %s", nstatefile);
		goto done;
	}
	/* keep the previous state as .ostate (used as fallback on load) */
	if (state->have_state) {
		unlink(ostatefile);
		if (link(statefile, ostatefile)) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("backup state %s to %s", statefile, ostatefile);
			goto done;
		}
	}
	if (rename(nstatefile, statefile) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("rename %s to %s", nstatefile, statefile);
		goto done;
	}
	ret = 0;
done:
	/* release the lock taken by sshkey_xmss_get_state() */
	if (state->lockfd != -1) {
		close(state->lockfd);
		state->lockfd = -1;
	}
	if (nstatefile)
		unlink(nstatefile);
	free(statefile);
	free(ostatefile);
	free(nstatefile);
	sshbuf_free(b);
	sshbuf_free(enc);
	return ret;
}
676 
677 int
678 sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
679 {
680 	struct ssh_xmss_state *state = k->xmss_state;
681 	treehash_inst *th;
682 	u_int32_t i, node;
683 	int r;
684 
685 	if (state == NULL)
686 		return SSH_ERR_INVALID_ARGUMENT;
687 	if (state->stack == NULL)
688 		return SSH_ERR_INVALID_ARGUMENT;
689 	state->stackoffset = state->bds.stackoffset;	/* copy back */
690 	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
691 	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
692 	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
693 	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
694 	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
695 	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
696 	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
697 	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
698 	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
699 	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
700 		return r;
701 	for (i = 0; i < num_treehash(state); i++) {
702 		th = &state->treehash[i];
703 		node = th->node - state->th_nodes;
704 		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
705 		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
706 		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
707 		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
708 		    (r = sshbuf_put_u32(b, node)) != 0)
709 			return r;
710 	}
711 	return 0;
712 }
713 
714 int
715 sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
716     enum sshkey_serialize_rep opts)
717 {
718 	struct ssh_xmss_state *state = k->xmss_state;
719 	int r = SSH_ERR_INVALID_ARGUMENT;
720 	u_char have_stack, have_filename, have_enc;
721 
722 	if (state == NULL)
723 		return SSH_ERR_INVALID_ARGUMENT;
724 	if ((r = sshbuf_put_u8(b, opts)) != 0)
725 		return r;
726 	switch (opts) {
727 	case SSHKEY_SERIALIZE_STATE:
728 		r = sshkey_xmss_serialize_state(k, b);
729 		break;
730 	case SSHKEY_SERIALIZE_FULL:
731 		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
732 			return r;
733 		r = sshkey_xmss_serialize_state(k, b);
734 		break;
735 	case SSHKEY_SERIALIZE_SHIELD:
736 		/* all of stack/filename/enc are optional */
737 		have_stack = state->stack != NULL;
738 		if ((r = sshbuf_put_u8(b, have_stack)) != 0)
739 			return r;
740 		if (have_stack) {
741 			state->idx = PEEK_U32(k->xmss_sk);	/* update */
742 			if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
743 				return r;
744 		}
745 		have_filename = k->xmss_filename != NULL;
746 		if ((r = sshbuf_put_u8(b, have_filename)) != 0)
747 			return r;
748 		if (have_filename &&
749 		    (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
750 			return r;
751 		have_enc = state->enc_keyiv != NULL;
752 		if ((r = sshbuf_put_u8(b, have_enc)) != 0)
753 			return r;
754 		if (have_enc &&
755 		    (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
756 			return r;
757 		if ((r = sshbuf_put_u32(b, state->maxidx)) != 0 ||
758 		    (r = sshbuf_put_u8(b, state->allow_update)) != 0)
759 			return r;
760 		break;
761 	case SSHKEY_SERIALIZE_DEFAULT:
762 		r = 0;
763 		break;
764 	default:
765 		r = SSH_ERR_INVALID_ARGUMENT;
766 		break;
767 	}
768 	return r;
769 }
770 
771 int
772 sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
773 {
774 	struct ssh_xmss_state *state = k->xmss_state;
775 	treehash_inst *th;
776 	u_int32_t i, lh, node;
777 	size_t ls, lsl, la, lk, ln, lr;
778 	char *magic;
779 	int r = SSH_ERR_INTERNAL_ERROR;
780 
781 	if (state == NULL)
782 		return SSH_ERR_INVALID_ARGUMENT;
783 	if (k->xmss_sk == NULL)
784 		return SSH_ERR_INVALID_ARGUMENT;
785 	if ((state->treehash = calloc(num_treehash(state),
786 	    sizeof(treehash_inst))) == NULL)
787 		return SSH_ERR_ALLOC_FAIL;
788 	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
789 	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
790 	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
791 	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
792 	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
793 	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
794 	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
795 	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
796 	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
797 	    (r = sshbuf_get_u32(b, &lh)) != 0)
798 		goto out;
799 	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
800 		r = SSH_ERR_INVALID_ARGUMENT;
801 		goto out;
802 	}
803 	/* XXX check stackoffset */
804 	if (ls != num_stack(state) ||
805 	    lsl != num_stacklevels(state) ||
806 	    la != num_auth(state) ||
807 	    lk != num_keep(state) ||
808 	    ln != num_th_nodes(state) ||
809 	    lr != num_retain(state) ||
810 	    lh != num_treehash(state)) {
811 		r = SSH_ERR_INVALID_ARGUMENT;
812 		goto out;
813 	}
814 	for (i = 0; i < num_treehash(state); i++) {
815 		th = &state->treehash[i];
816 		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
817 		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
818 		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
819 		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
820 		    (r = sshbuf_get_u32(b, &node)) != 0)
821 			goto out;
822 		if (node < num_th_nodes(state))
823 			th->node = &state->th_nodes[node];
824 	}
825 	POKE_U32(k->xmss_sk, state->idx);
826 	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
827 	    state->stacklevels, state->auth, state->keep, state->treehash,
828 	    state->retain, 0);
829 	/* success */
830 	r = 0;
831  out:
832 	free(magic);
833 	return r;
834 }
835 
836 int
837 sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
838 {
839 	struct ssh_xmss_state *state = k->xmss_state;
840 	enum sshkey_serialize_rep opts;
841 	u_char have_state, have_stack, have_filename, have_enc;
842 	int r;
843 
844 	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
845 		return r;
846 
847 	opts = have_state;
848 	switch (opts) {
849 	case SSHKEY_SERIALIZE_DEFAULT:
850 		r = 0;
851 		break;
852 	case SSHKEY_SERIALIZE_SHIELD:
853 		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
854 			return r;
855 		if (have_stack &&
856 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
857 			return r;
858 		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
859 			return r;
860 		if (have_filename &&
861 		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
862 			return r;
863 		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
864 			return r;
865 		if (have_enc &&
866 		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
867 			return r;
868 		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
869 		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
870 			return r;
871 		break;
872 	case SSHKEY_SERIALIZE_STATE:
873 		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
874 			return r;
875 		break;
876 	case SSHKEY_SERIALIZE_FULL:
877 		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
878 		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
879 			return r;
880 		break;
881 	default:
882 		r = SSH_ERR_INVALID_FORMAT;
883 		break;
884 	}
885 	return r;
886 }
887 
888 int
889 sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
890    struct sshbuf **retp)
891 {
892 	struct ssh_xmss_state *state = k->xmss_state;
893 	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
894 	struct sshcipher_ctx *ciphercontext = NULL;
895 	const struct sshcipher *cipher;
896 	u_char *cp, *key, *iv = NULL;
897 	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
898 	int r = SSH_ERR_INTERNAL_ERROR;
899 
900 	if (retp != NULL)
901 		*retp = NULL;
902 	if (state == NULL ||
903 	    state->enc_keyiv == NULL ||
904 	    state->enc_ciphername == NULL)
905 		return SSH_ERR_INTERNAL_ERROR;
906 	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
907 		r = SSH_ERR_INTERNAL_ERROR;
908 		goto out;
909 	}
910 	blocksize = cipher_blocksize(cipher);
911 	keylen = cipher_keylen(cipher);
912 	ivlen = cipher_ivlen(cipher);
913 	authlen = cipher_authlen(cipher);
914 	if (state->enc_keyiv_len != keylen + ivlen) {
915 		r = SSH_ERR_INVALID_FORMAT;
916 		goto out;
917 	}
918 	key = state->enc_keyiv;
919 	if ((encrypted = sshbuf_new()) == NULL ||
920 	    (encoded = sshbuf_new()) == NULL ||
921 	    (padded = sshbuf_new()) == NULL ||
922 	    (iv = malloc(ivlen)) == NULL) {
923 		r = SSH_ERR_ALLOC_FAIL;
924 		goto out;
925 	}
926 
927 	/* replace first 4 bytes of IV with index to ensure uniqueness */
928 	memcpy(iv, key + keylen, ivlen);
929 	POKE_U32(iv, state->idx);
930 
931 	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
932 	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
933 		goto out;
934 
935 	/* padded state will be encrypted */
936 	if ((r = sshbuf_putb(padded, b)) != 0)
937 		goto out;
938 	i = 0;
939 	while (sshbuf_len(padded) % blocksize) {
940 		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
941 			goto out;
942 	}
943 	encrypted_len = sshbuf_len(padded);
944 
945 	/* header including the length of state is used as AAD */
946 	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
947 		goto out;
948 	aadlen = sshbuf_len(encoded);
949 
950 	/* concat header and state */
951 	if ((r = sshbuf_putb(encoded, padded)) != 0)
952 		goto out;
953 
954 	/* reserve space for encryption of encoded data plus auth tag */
955 	/* encrypt at offset addlen */
956 	if ((r = sshbuf_reserve(encrypted,
957 	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
958 	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
959 	    iv, ivlen, 1)) != 0 ||
960 	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
961 	    encrypted_len, aadlen, authlen)) != 0)
962 		goto out;
963 
964 	/* success */
965 	r = 0;
966  out:
967 	if (retp != NULL) {
968 		*retp = encrypted;
969 		encrypted = NULL;
970 	}
971 	sshbuf_free(padded);
972 	sshbuf_free(encoded);
973 	sshbuf_free(encrypted);
974 	cipher_free(ciphercontext);
975 	free(iv);
976 	return r;
977 }
978 
/*
 * Decrypt a state blob produced by sshkey_xmss_encrypt_state().
 *
 * Expects XMSS_MAGIC (with NUL) || u32 index || u32 encrypted_len, followed
 * by encrypted_len encrypted bytes and authlen bytes of tag, with nothing
 * trailing.  The header is passed to cipher_crypt() as AAD.  'encoded' is
 * consumed as it is parsed.  On success *retp receives the decrypted
 * (still padded) state, which the caller must free; otherwise it stays NULL.
 */
int
sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
   struct sshbuf **retp)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *copy = NULL, *decrypted = NULL;
	struct sshcipher_ctx *ciphercontext = NULL;
	const struct sshcipher *cipher = NULL;
	u_char *key, *iv = NULL, *dp;
	size_t keylen, ivlen, authlen, aadlen;
	u_int blocksize, encrypted_len, index;
	int r = SSH_ERR_INTERNAL_ERROR;

	if (retp != NULL)
		*retp = NULL;
	if (state == NULL ||
	    state->enc_keyiv == NULL ||
	    state->enc_ciphername == NULL)
		return SSH_ERR_INTERNAL_ERROR;
	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	blocksize = cipher_blocksize(cipher);
	keylen = cipher_keylen(cipher);
	ivlen = cipher_ivlen(cipher);
	authlen = cipher_authlen(cipher);
	if (state->enc_keyiv_len != keylen + ivlen) {
		r = SSH_ERR_INTERNAL_ERROR;
		goto out;
	}
	key = state->enc_keyiv;

	/*
	 * 'copy' keeps a view of the whole blob (header + ciphertext + tag)
	 * while 'encoded' is consumed during parsing.
	 */
	if ((copy = sshbuf_fromb(encoded)) == NULL ||
	    (decrypted = sshbuf_new()) == NULL ||
	    (iv = malloc(ivlen)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto out;
	}

	/* check magic */
	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	/* parse public portion */
	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
		goto out;

	/* check size of encrypted key blob */
	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	/* check that an appropriate amount of auth data is present */
	if (sshbuf_len(encoded) < authlen ||
	    sshbuf_len(encoded) - authlen < encrypted_len) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}

	/* AAD = everything consumed so far (magic, index, length) */
	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);

	/* replace first 4 bytes of IV with index to ensure uniqueness */
	memcpy(iv, key + keylen, ivlen);
	POKE_U32(iv, index);

	/* decrypt private state of key */
	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
	    iv, ivlen, 0)) != 0 ||
	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
	    encrypted_len, aadlen, authlen)) != 0)
		goto out;

	/* there should be no trailing data */
	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
		goto out;
	if (sshbuf_len(encoded) != 0) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}

	/* remove AAD */
	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
		goto out;
	/* XXX encrypted includes unchecked padding */

	/* success */
	r = 0;
	if (retp != NULL) {
		*retp = decrypted;
		decrypted = NULL;
	}
 out:
	cipher_free(ciphercontext);
	sshbuf_free(copy);
	sshbuf_free(decrypted);
	free(iv);
	return r;
}
1083 
1084 u_int32_t
1085 sshkey_xmss_signatures_left(const struct sshkey *k)
1086 {
1087 	struct ssh_xmss_state *state = k->xmss_state;
1088 	u_int32_t idx;
1089 
1090 	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
1091 	    state->maxidx) {
1092 		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
1093 		if (idx < state->maxidx)
1094 			return state->maxidx - idx;
1095 	}
1096 	return 0;
1097 }
1098 
1099 int
1100 sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
1101 {
1102 	struct ssh_xmss_state *state = k->xmss_state;
1103 
1104 	if (sshkey_type_plain(k->type) != KEY_XMSS)
1105 		return SSH_ERR_INVALID_ARGUMENT;
1106 	if (maxsign == 0)
1107 		return 0;
1108 	if (state->idx + maxsign < state->idx)
1109 		return SSH_ERR_INVALID_ARGUMENT;
1110 	state->maxidx = state->idx + maxsign;
1111 	return 0;
1112 }
1113 #endif /* WITH_XMSS */
1114