/* $OpenBSD: sshkey-xmss.c,v 1.12 2022/10/28 00:39:29 djm Exp $ */
/*
 * Copyright (c) 2017 Markus Friedl.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include "includes.h"
#ifdef WITH_XMSS

#include <sys/types.h>
#include <sys/uio.h>

#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <errno.h>
#ifdef HAVE_SYS_FILE_H
# include <sys/file.h>
#endif

#include "ssh2.h"
#include "ssherr.h"
#include "sshbuf.h"
#include "cipher.h"
#include "sshkey.h"
#include "sshkey-xmss.h"
#include "atomicio.h"
#include "log.h"

#include "xmss_fast.h"

/* opaque internal XMSS state */
/* Magic written (including its terminating NUL) at the start of an encrypted state blob. */
#define XMSS_MAGIC		"xmss-state-v1"
/* Cipher used by sshkey_xmss_generate_private_key() to protect the state. */
#define XMSS_CIPHERNAME		"aes256-gcm@openssh.com"
/*
 * Per-key XMSS state.  n/w/h/k describe the parameter set chosen in
 * sshkey_xmss_init(); the byte buffers below back the BDS traversal state
 * held in 'bds' and are wired into it with xmss_set_bds_state().
 */
struct ssh_xmss_state {
	xmss_params	params;
	u_int32_t	n, w, h, k;

	bds_state	bds;
	u_char		*stack;
	u_int32_t	stackoffset;
	u_char		*stacklevels;
	u_char		*auth;
	u_char		*keep;
	u_char		*th_nodes;
	u_char		*retain;
	treehash_inst	*treehash;

	u_int32_t	idx;		/* state read from file */
	u_int32_t	maxidx;		/* restricted # of signatures */
	int		have_state;	/* .state file exists */
	int		lockfd;		/* locked in sshkey_xmss_get_state() */
	u_char		allow_update;	/* allow sshkey_xmss_update_state() */
	char		*enc_ciphername;/* encrypt state with cipher */
	u_char		*enc_keyiv;	/* encrypt state with key */
	u_int32_t	enc_keyiv_len;	/* length of enc_keyiv */
};

/*
 * Forward declarations of helpers defined later in this file (some, e.g.
 * sshkey_xmss_free_bds(), are used before their definition).
 */
int	 sshkey_xmss_init_bds_state(struct sshkey *);
int	 sshkey_xmss_init_enc_key(struct sshkey *, const char *);
void	 sshkey_xmss_free_bds(struct sshkey *);
int	 sshkey_xmss_get_state_from_file(struct sshkey *, const char *,
	    int *, int);
int	 sshkey_xmss_encrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_decrypt_state(const struct sshkey *, struct sshbuf *,
	    struct sshbuf **);
int	 sshkey_xmss_serialize_enc_key(const struct sshkey *, struct sshbuf *);
int	 sshkey_xmss_deserialize_enc_key(struct sshkey *, struct sshbuf *);

/*
 * Log an error through sshlog() only when a variable named 'printerror'
 * (a parameter of the calling function) is non-zero.
 */
#define PRINT(...) do { if (printerror) sshlog(__FILE__, __func__, __LINE__, \
    0, SYSLOG_LEVEL_ERROR, NULL, __VA_ARGS__); } while (0)

/*
 * Select the XMSS parameter set named 'name' and attach a freshly
 * allocated ssh_xmss_state to 'key'.  On success key->xmss_name holds a
 * copy of 'name'; on failure neither key->xmss_state nor key->xmss_name
 * is modified.
 *
 * Returns 0 on success; SSH_ERR_INVALID_FORMAT if the key already has
 * state, name is NULL or parameter setup fails; SSH_ERR_KEY_TYPE_UNKNOWN
 * for an unsupported name; SSH_ERR_ALLOC_FAIL on allocation failure.
 */
int
sshkey_xmss_init(struct sshkey *key, const char *name)
{
	struct ssh_xmss_state *state;
	char *xmss_name;

	if (key->xmss_state != NULL)
		return SSH_ERR_INVALID_FORMAT;
	if (name == NULL)
		return SSH_ERR_INVALID_FORMAT;
	if ((state = calloc(1, sizeof(*state))) == NULL)
		return SSH_ERR_ALLOC_FAIL;
	if (strcmp(name, XMSS_SHA2_256_W16_H10_NAME) == 0) {
		state->n = 32;
		state->w = 16;
		state->h = 10;
	} else if (strcmp(name, XMSS_SHA2_256_W16_H16_NAME) == 0) {
		state->n = 32;
		state->w = 16;
		state->h = 16;
	} else if (strcmp(name, XMSS_SHA2_256_W16_H20_NAME) == 0) {
		state->n = 32;
		state->w = 16;
		state->h = 20;
	} else {
		free(state);
		return SSH_ERR_KEY_TYPE_UNKNOWN;
	}
	state->k = 2;	/* XXX hardcoded */
	state->lockfd = -1;
	if (xmss_set_params(&state->params, state->n, state->h, state->w,
	    state->k) != 0) {
		free(state);
		return SSH_ERR_INVALID_FORMAT;
	}
	/*
	 * Publish the name only once the state is known to be usable, so a
	 * failure never leaves xmss_name set without a matching xmss_state.
	 */
	if ((xmss_name = strdup(name)) == NULL) {
		free(state);
		return SSH_ERR_ALLOC_FAIL;
	}
	key->xmss_name = xmss_name;
	key->xmss_state = state;
	return 0;
}

/*
 * Tear down all XMSS state hanging off 'key': the BDS buffers, the state
 * encryption key/IV (wiped before release) and the cipher name.
 * Tolerates key->xmss_state == NULL and always leaves it NULL.
 */
void
sshkey_xmss_free_state(struct sshkey *key)
{
	struct ssh_xmss_state *st = key->xmss_state;

	sshkey_xmss_free_bds(key);
	key->xmss_state = NULL;
	if (st == NULL)
		return;
	if (st->enc_keyiv != NULL) {
		/* secret key material: scrub before handing back to malloc */
		explicit_bzero(st->enc_keyiv, st->enc_keyiv_len);
		free(st->enc_keyiv);
	}
	free(st->enc_ciphername);
	free(st);
}

/* Magic string heading a serialized BDS state (sshkey_xmss_serialize_state). */
#define SSH_XMSS_K2_MAGIC	"k=2"
/*
 * Sizes of the BDS buffers for state 'x' (a struct ssh_xmss_state pointer
 * expression).  num_treehash() is an element count; the others are byte
 * counts.  Arguments are parenthesized so any pointer expression is safe.
 */
#define num_stack(x)		((((x)->h) + 1) * ((x)->n))
#define num_stacklevels(x)	(((x)->h) + 1)
#define num_auth(x)		(((x)->h) * ((x)->n))
#define num_keep(x)		((((x)->h) >> 1) * ((x)->n))
#define num_th_nodes(x)		((((x)->h) - ((x)->k)) * ((x)->n))
#define num_retain(x)		(((1ULL << ((x)->k)) - ((x)->k) - 1) * ((x)->n))
#define num_treehash(x)		(((x)->h) - ((x)->k))

/*
 * Allocate zero-filled BDS traversal buffers sized for the key's parameter
 * set and register them in state->bds.  If any allocation fails, all BDS
 * buffers are released again and SSH_ERR_ALLOC_FAIL is returned.
 */
int
sshkey_xmss_init_bds_state(struct sshkey *key)
{
	struct ssh_xmss_state *st = key->xmss_state;
	u_int32_t t;

	st->stackoffset = 0;
	if ((st->stack = calloc(num_stack(st), 1)) == NULL)
		goto nomem;
	if ((st->stacklevels = calloc(num_stacklevels(st), 1)) == NULL)
		goto nomem;
	if ((st->auth = calloc(num_auth(st), 1)) == NULL)
		goto nomem;
	if ((st->keep = calloc(num_keep(st), 1)) == NULL)
		goto nomem;
	if ((st->th_nodes = calloc(num_th_nodes(st), 1)) == NULL)
		goto nomem;
	if ((st->retain = calloc(num_retain(st), 1)) == NULL)
		goto nomem;
	if ((st->treehash = calloc(num_treehash(st),
	    sizeof(treehash_inst))) == NULL)
		goto nomem;

	/* each treehash instance owns an n-byte slot inside th_nodes */
	for (t = 0; t < st->h - st->k; t++)
		st->treehash[t].node = st->th_nodes + st->n * t;
	xmss_set_bds_state(&st->bds, st->stack, st->stackoffset,
	    st->stacklevels, st->auth, st->keep, st->treehash,
	    st->retain, 0);
	return 0;

 nomem:
	sshkey_xmss_free_bds(key);
	return SSH_ERR_ALLOC_FAIL;
}

/*
 * Release the BDS buffers owned by the key's XMSS state and reset the
 * pointers to NULL.  A no-op when the key has no XMSS state.
 */
void
sshkey_xmss_free_bds(struct sshkey *key)
{
	struct ssh_xmss_state *st;

	if ((st = key->xmss_state) == NULL)
		return;
	free(st->stack);
	st->stack = NULL;
	free(st->stacklevels);
	st->stacklevels = NULL;
	free(st->auth);
	st->auth = NULL;
	free(st->keep);
	st->keep = NULL;
	free(st->th_nodes);
	st->th_nodes = NULL;
	free(st->retain);
	st->retain = NULL;
	free(st->treehash);
	st->treehash = NULL;
}

/* Return the key's xmss_params (as void *), or NULL without XMSS state. */
void *
sshkey_xmss_params(const struct sshkey *key)
{
	struct ssh_xmss_state *st = key->xmss_state;

	return st == NULL ? NULL : (void *)&st->params;
}

/* Return the key's BDS traversal state (as void *), or NULL without XMSS state. */
void *
sshkey_xmss_bds_state(const struct sshkey *key)
{
	struct ssh_xmss_state *st = key->xmss_state;

	return st == NULL ? NULL : (void *)&st->bds;
}

/*
 * Store the XMSS signature length for this key in *lenp.
 * Returns SSH_ERR_INVALID_ARGUMENT if lenp is NULL and
 * SSH_ERR_INVALID_FORMAT if the key carries no XMSS state.
 */
int
sshkey_xmss_siglen(const struct sshkey *key, size_t *lenp)
{
	struct ssh_xmss_state *st = key->xmss_state;
	size_t total;

	if (lenp == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if (st == NULL)
		return SSH_ERR_INVALID_FORMAT;
	/*
	 * 4-byte index + n bytes + WOTS signature + h authentication nodes of
	 * n bytes each (presumably the RFC 8391 layout; confirm vs xmss_fast).
	 */
	total = 4 + st->n +
	    st->params.wots_par.keysize +
	    st->h * st->n;
	*lenp = total;
	return 0;
}

/* Public key length in bytes (2 * n), or 0 when the key has no XMSS state. */
size_t
sshkey_xmss_pklen(const struct sshkey *key)
{
	struct ssh_xmss_state *st = key->xmss_state;

	if (st != NULL)
		return st->n * 2;
	return 0;
}

/*
 * Secret key length in bytes (4 * n + 4), or 0 when the key has no XMSS
 * state.  The leading 4 bytes hold the signature index (read elsewhere in
 * this file with PEEK_U32(k->xmss_sk)).
 */
size_t
sshkey_xmss_sklen(const struct sshkey *key)
{
	struct ssh_xmss_state *st = key->xmss_state;

	if (st != NULL)
		return st->n * 4 + 4;
	return 0;
}

/*
 * Set up the symmetric key used to encrypt the XMSS state: remember
 * 'ciphername' and fill enc_keyiv with random bytes (key followed by IV,
 * sized from the cipher).  Returns SSH_ERR_INVALID_ARGUMENT without XMSS
 * state, SSH_ERR_INTERNAL_ERROR for an unknown cipher and
 * SSH_ERR_ALLOC_FAIL on allocation failure.
 */
int
sshkey_xmss_init_enc_key(struct sshkey *k, const char *ciphername)
{
	struct ssh_xmss_state *st = k->xmss_state;
	const struct sshcipher *c;
	size_t klen, vlen;

	if (st == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	c = cipher_by_name(ciphername);
	if (c == NULL)
		return SSH_ERR_INTERNAL_ERROR;
	st->enc_ciphername = strdup(ciphername);
	if (st->enc_ciphername == NULL)
		return SSH_ERR_ALLOC_FAIL;
	klen = cipher_keylen(c);
	vlen = cipher_ivlen(c);
	st->enc_keyiv_len = klen + vlen;
	st->enc_keyiv = calloc(st->enc_keyiv_len, 1);
	if (st->enc_keyiv == NULL) {
		free(st->enc_ciphername);
		st->enc_ciphername = NULL;
		return SSH_ERR_ALLOC_FAIL;
	}
	arc4random_buf(st->enc_keyiv, st->enc_keyiv_len);
	return 0;
}

/*
 * Append the state cipher name (cstring) and key/IV (string) to 'b'.
 * Fails with SSH_ERR_INVALID_ARGUMENT when no encryption key is set up.
 */
int
sshkey_xmss_serialize_enc_key(const struct sshkey *k, struct sshbuf *b)
{
	struct ssh_xmss_state *st = k->xmss_state;
	int r;

	if (st == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if (st->enc_keyiv == NULL || st->enc_ciphername == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	r = sshbuf_put_cstring(b, st->enc_ciphername);
	if (r != 0)
		return r;
	return sshbuf_put_string(b, st->enc_keyiv, st->enc_keyiv_len);
}

int
sshkey_xmss_deserialize_enc_key(struct sshkey *k, struct sshbuf *b)
{
	struct ssh_xmss_state *state = k->xmss_state;
	size_t len;
	int r;

	if (state == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if ((r = sshbuf_get_cstring(b, &state->enc_ciphername, NULL)) != 0 ||
	    (r = sshbuf_get_string(b, &state->enc_keyiv, &len)) != 0)
		return r;
	state->enc_keyiv_len = len;
	return 0;
}

/*
 * For SSHKEY_SERIALIZE_INFO, append a have-info byte (1), the current
 * signature index and maxidx to 'b'; any other 'opts' appends nothing.
 * The index is taken from xmss_sk when it is loaded, else from state->idx.
 */
int
sshkey_xmss_serialize_pk_info(const struct sshkey *k, struct sshbuf *b,
    enum sshkey_serialize_rep opts)
{
	struct ssh_xmss_state *st = k->xmss_state;
	u_int32_t cur;
	int r;

	if (st == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if (opts != SSHKEY_SERIALIZE_INFO)
		return 0;
	if (k->xmss_sk != NULL)
		cur = PEEK_U32(k->xmss_sk);
	else
		cur = st->idx;
	if ((r = sshbuf_put_u8(b, 1)) != 0)
		return r;
	if ((r = sshbuf_put_u32(b, cur)) != 0)
		return r;
	if ((r = sshbuf_put_u32(b, st->maxidx)) != 0)
		return r;
	return 0;
}

/*
 * Parse the optional info trailer written by
 * sshkey_xmss_serialize_pk_info().  An empty buffer is accepted as "no
 * info"; otherwise the marker byte must be 1, followed by idx and maxidx.
 */
int
sshkey_xmss_deserialize_pk_info(struct sshkey *k, struct sshbuf *b)
{
	struct ssh_xmss_state *st = k->xmss_state;
	u_char marker;
	int r;

	if (st == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if (sshbuf_len(b) == 0)
		return 0;
	r = sshbuf_get_u8(b, &marker);
	if (r != 0)
		return r;
	if (marker != 1)
		return SSH_ERR_INVALID_ARGUMENT;
	r = sshbuf_get_u32(b, &st->idx);
	if (r != 0)
		return r;
	r = sshbuf_get_u32(b, &st->maxidx);
	if (r != 0)
		return r;
	return 0;
}

/*
 * Generate a new XMSS key pair into 'k'.  'bits' selects the tree height
 * (10, 16 or 20); any other value selects XMSS_DEFAULT_NAME.  Initializes
 * the parameter set, BDS state and state-encryption key, then allocates
 * xmss_pk/xmss_sk and calls xmss_keypair().
 */
int
sshkey_xmss_generate_private_key(struct sshkey *k, int bits)
{
	const char *name;
	int r;

	switch (bits) {
	case 10:
		name = XMSS_SHA2_256_W16_H10_NAME;
		break;
	case 16:
		name = XMSS_SHA2_256_W16_H16_NAME;
		break;
	case 20:
		name = XMSS_SHA2_256_W16_H20_NAME;
		break;
	default:
		name = XMSS_DEFAULT_NAME;
		break;
	}
	if ((r = sshkey_xmss_init(k, name)) != 0)
		return r;
	if ((r = sshkey_xmss_init_bds_state(k)) != 0)
		return r;
	if ((r = sshkey_xmss_init_enc_key(k, XMSS_CIPHERNAME)) != 0)
		return r;
	k->xmss_pk = malloc(sshkey_xmss_pklen(k));
	if (k->xmss_pk == NULL)
		return SSH_ERR_ALLOC_FAIL;
	k->xmss_sk = malloc(sshkey_xmss_sklen(k));
	if (k->xmss_sk == NULL)
		return SSH_ERR_ALLOC_FAIL;
	xmss_keypair(k->xmss_pk, k->xmss_sk, sshkey_xmss_bds_state(k),
	    sshkey_xmss_params(k));
	return 0;
}

/*
 * Load an encrypted XMSS state file into key 'k'.
 *
 * The file holds a 4-byte length header (decoded with PEEK_U32) followed
 * by that many bytes of encrypted state.  *have_file is set to 1 if the
 * file could be opened, else 0.  The key's current BDS buffers are freed
 * just before decryption and replaced by the decoded state.
 *
 * Returns 0 on success.  If the file cannot be opened this returns
 * SSH_ERR_SYSTEM_ERROR with *have_file == 0.
 */
int
sshkey_xmss_get_state_from_file(struct sshkey *k, const char *filename,
    int *have_file, int printerror)
{
	struct sshbuf *plain = NULL, *blob = NULL;
	unsigned char hdr[4], *raw = NULL;
	u_int32_t bloblen;
	int fd, r, ret = SSH_ERR_SYSTEM_ERROR;

	*have_file = 0;
	fd = open(filename, O_RDONLY);
	if (fd < 0)
		return ret;
	*have_file = 1;

	if (atomicio(read, fd, hdr, sizeof(hdr)) != sizeof(hdr)) {
		PRINT("corrupt state file: %s", filename);
		goto out;
	}
	bloblen = PEEK_U32(hdr);
	if ((raw = calloc(bloblen, 1)) == NULL) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto out;
	}
	if (atomicio(read, fd, raw, bloblen) != bloblen) {
		PRINT("cannot read blob: %s", filename);
		goto out;
	}
	if ((blob = sshbuf_from(raw, bloblen)) == NULL) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto out;
	}
	sshkey_xmss_free_bds(k);
	if ((r = sshkey_xmss_decrypt_state(k, blob, &plain)) != 0 ||
	    (r = sshkey_xmss_deserialize_state(k, plain)) != 0) {
		ret = r;
		goto out;
	}
	ret = 0;
 out:
	close(fd);
	free(raw);
	sshbuf_free(blob);
	sshbuf_free(plain);
	return ret;
}

/*
 * Prepare key 'k' for signing by loading and locking its XMSS state.
 *
 * If state->maxidx is set, signing is allowed while the index in xmss_sk
 * is below maxidx and no files are touched.  Otherwise this takes an
 * exclusive flock(2) on "<xmss_filename>.lock" (retrying up to 10 times
 * with a growing delay) and loads "<xmss_filename>.state".  If that fails,
 * it loads "<xmss_filename>.ostate" and skips forward one signature via
 * sshkey_xmss_forward_state(k, 1).  If neither file exists, the in-memory
 * BDS state is used as long as it has been initialized.
 *
 * On success the lock descriptor is kept in state->lockfd (closed by
 * sshkey_xmss_update_state()) and state->allow_update is set.
 * Returns 0 on success or an SSH_ERR_* code.
 */
int
sshkey_xmss_get_state(const struct sshkey *k, int printerror)
{
	struct ssh_xmss_state *state = k->xmss_state;
	u_int32_t idx = 0;
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *lockfile = NULL;
	int lockfd = -1, have_state = 0, have_ostate, tries = 0;
	int ret = SSH_ERR_INVALID_ARGUMENT, r;

	if (state == NULL)
		goto done;
	/*
	 * If maxidx is set, then we are allowed a limited number
	 * of signatures, but don't need to access the disk.
	 * Otherwise we need to deal with the on-disk state.
	 */
	if (state->maxidx) {
		/* xmss_sk always contains the current state */
		idx = PEEK_U32(k->xmss_sk);
		if (idx < state->maxidx) {
			state->allow_update = 1;
			return 0;
		}
		return SSH_ERR_INVALID_ARGUMENT;
	}
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&lockfile, "%s.lock", filename) == -1 ||
	    asprintf(&statefile, "%s.state", filename) == -1 ||
	    asprintf(&ostatefile, "%s.ostate", filename) == -1) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((lockfd = open(lockfile, O_CREAT|O_RDONLY, 0600)) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("cannot open/create: %s", lockfile);
		goto done;
	}
	/* non-blocking lock attempts, sleeping 100ms * attempt between tries */
	while (flock(lockfd, LOCK_EX|LOCK_NB) == -1) {
		if (errno != EWOULDBLOCK) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("cannot lock: %s", lockfile);
			goto done;
		}
		if (++tries > 10) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("giving up on: %s", lockfile);
			goto done;
		}
		usleep(1000*100*tries);
	}
	/* XXX no longer const */
	if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
	    statefile, &have_state, printerror)) != 0) {
		/*
		 * Fall back to the backup state.  It may already have been
		 * used for a signature, so skip ahead by one index.
		 */
		if ((r = sshkey_xmss_get_state_from_file((struct sshkey *)k,
		    ostatefile, &have_ostate, printerror)) == 0) {
			state->allow_update = 1;
			r = sshkey_xmss_forward_state(k, 1);
			state->idx = PEEK_U32(k->xmss_sk);
			state->allow_update = 0;
		}
	}
	/*
	 * have_ostate is only read when have_state is 0, i.e. when the
	 * .state load failed and the .ostate attempt above has set it.
	 */
	if (!have_state && !have_ostate) {
		/* check that bds state is initialized */
		if (state->bds.auth == NULL)
			goto done;
		PRINT("start from scratch idx 0: %u", state->idx);
	} else if (r != 0) {
		ret = r;
		goto done;
	}
	/* refuse to continue if the next signature index would wrap */
	if (state->idx + 1 < state->idx) {
		PRINT("state wrap: %u", state->idx);
		goto done;
	}
	/* success: hand the lock over to the state; released on update */
	state->have_state = have_state;
	state->lockfd = lockfd;
	state->allow_update = 1;
	lockfd = -1;
	ret = 0;
done:
	if (lockfd != -1)
		close(lockfd);
	free(lockfile);
	free(statefile);
	free(ostatefile);
	return ret;
}

/*
 * Skip 'reserve' signature indices by producing and discarding that many
 * signatures over an empty message.  Each xmss_sign() call is expected to
 * advance the index stored in xmss_sk; state->idx records the index before
 * each call.  Requires state->allow_update.  Returns 0 on success,
 * SSH_ERR_INVALID_ARGUMENT for bad arguments/overflow or a signing failure.
 */
int
sshkey_xmss_forward_state(const struct sshkey *k, u_int32_t reserve)
{
	struct ssh_xmss_state *st = k->xmss_state;
	unsigned long long outlen;
	size_t siglen;
	u_char *scratch, dummy;
	int r;

	if (st == NULL || !st->allow_update)
		return SSH_ERR_INVALID_ARGUMENT;
	/* reserve must be positive and must not wrap the 32-bit index */
	if (reserve == 0 || st->idx + reserve <= st->idx)
		return SSH_ERR_INVALID_ARGUMENT;
	if ((r = sshkey_xmss_siglen(k, &siglen)) != 0)
		return r;
	scratch = malloc(siglen);
	if (scratch == NULL)
		return SSH_ERR_ALLOC_FAIL;
	for (; reserve > 0; reserve--) {
		st->idx = PEEK_U32(k->xmss_sk);
		outlen = siglen;
		if (xmss_sign(k->xmss_sk, sshkey_xmss_bds_state(k),
		    scratch, &outlen, &dummy, 0, sshkey_xmss_params(k)) != 0) {
			r = SSH_ERR_INVALID_ARGUMENT;
			break;
		}
	}
	free(scratch);
	return r;
}

/*
 * Persist the XMSS state after a signature made under
 * sshkey_xmss_get_state().
 *
 * With state->maxidx set nothing is written.  Otherwise the index in
 * xmss_sk must equal state->idx (nothing to do) or state->idx + 1.  In the
 * latter case the state is serialized, encrypted and written to
 * "<xmss_filename>.nstate" (fsync'd).  If a .state file was loaded, it is
 * hard-linked to ".ostate" as a backup, and ".nstate" is then renamed over
 * ".state".
 *
 * Past the initial argument checks, the lock taken by
 * sshkey_xmss_get_state() (state->lockfd) is always released.
 * Returns 0 on success or an SSH_ERR_* code.
 */
int
sshkey_xmss_update_state(const struct sshkey *k, int printerror)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *b = NULL, *enc = NULL;
	u_int32_t idx = 0;
	unsigned char buf[4];
	char *filename = NULL;
	char *statefile = NULL, *ostatefile = NULL, *nstatefile = NULL;
	int fd = -1;
	int ret = SSH_ERR_INVALID_ARGUMENT;

	if (state == NULL || !state->allow_update)
		return ret;
	if (state->maxidx) {
		/* no update since the number of signatures is limited */
		ret = 0;
		goto done;
	}
	idx = PEEK_U32(k->xmss_sk);
	if (idx == state->idx) {
		/* no signature happened, no need to update */
		ret = 0;
		goto done;
	} else if (idx != state->idx + 1) {
		PRINT("more than one signature happened: idx %u state %u",
		    idx, state->idx);
		goto done;
	}
	state->idx = idx;
	if ((filename = k->xmss_filename) == NULL)
		goto done;
	if (asprintf(&statefile, "%s.state", filename) == -1 ||
	    asprintf(&ostatefile, "%s.ostate", filename) == -1 ||
	    asprintf(&nstatefile, "%s.nstate", filename) == -1) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	/* remove a stale temporary left behind by an interrupted update */
	unlink(nstatefile);
	if ((b = sshbuf_new()) == NULL) {
		ret = SSH_ERR_ALLOC_FAIL;
		goto done;
	}
	if ((ret = sshkey_xmss_serialize_state(k, b)) != 0) {
		PRINT("SERIALIZE FAILED: %d", ret);
		goto done;
	}
	if ((ret = sshkey_xmss_encrypt_state(k, b, &enc)) != 0) {
		PRINT("ENCRYPT FAILED: %d", ret);
		goto done;
	}
	if ((fd = open(nstatefile, O_CREAT|O_WRONLY|O_EXCL, 0600)) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("open new state file: %s", nstatefile);
		goto done;
	}
	/* file layout: 4-byte length header, then the encrypted blob */
	POKE_U32(buf, sshbuf_len(enc));
	if (atomicio(vwrite, fd, buf, sizeof(buf)) != sizeof(buf)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("write new state file hdr: %s", nstatefile);
		close(fd);
		goto done;
	}
	if (atomicio(vwrite, fd, sshbuf_mutable_ptr(enc), sshbuf_len(enc)) !=
	    sshbuf_len(enc)) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("write new state file data: %s", nstatefile);
		close(fd);
		goto done;
	}
	if (fsync(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("sync new state file: %s", nstatefile);
		close(fd);
		goto done;
	}
	if (close(fd) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("close new state file: %s", nstatefile);
		goto done;
	}
	/* keep the previous state as .ostate before replacing .state */
	if (state->have_state) {
		unlink(ostatefile);
		if (link(statefile, ostatefile)) {
			ret = SSH_ERR_SYSTEM_ERROR;
			PRINT("backup state %s to %s", statefile, ostatefile);
			goto done;
		}
	}
	if (rename(nstatefile, statefile) == -1) {
		ret = SSH_ERR_SYSTEM_ERROR;
		PRINT("rename %s to %s", nstatefile, statefile);
		goto done;
	}
	ret = 0;
done:
	if (state->lockfd != -1) {
		close(state->lockfd);
		state->lockfd = -1;
	}
	if (nstatefile)
		unlink(nstatefile);
	free(statefile);
	free(ostatefile);
	free(nstatefile);
	sshbuf_free(b);
	sshbuf_free(enc);
	return ret;
}

int
sshkey_xmss_serialize_state(const struct sshkey *k, struct sshbuf *b)
{
	struct ssh_xmss_state *state = k->xmss_state;
	treehash_inst *th;
	u_int32_t i, node;
	int r;

	if (state == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if (state->stack == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	state->stackoffset = state->bds.stackoffset;	/* copy back */
	if ((r = sshbuf_put_cstring(b, SSH_XMSS_K2_MAGIC)) != 0 ||
	    (r = sshbuf_put_u32(b, state->idx)) != 0 ||
	    (r = sshbuf_put_string(b, state->stack, num_stack(state))) != 0 ||
	    (r = sshbuf_put_u32(b, state->stackoffset)) != 0 ||
	    (r = sshbuf_put_string(b, state->stacklevels, num_stacklevels(state))) != 0 ||
	    (r = sshbuf_put_string(b, state->auth, num_auth(state))) != 0 ||
	    (r = sshbuf_put_string(b, state->keep, num_keep(state))) != 0 ||
	    (r = sshbuf_put_string(b, state->th_nodes, num_th_nodes(state))) != 0 ||
	    (r = sshbuf_put_string(b, state->retain, num_retain(state))) != 0 ||
	    (r = sshbuf_put_u32(b, num_treehash(state))) != 0)
		return r;
	for (i = 0; i < num_treehash(state); i++) {
		th = &state->treehash[i];
		node = th->node - state->th_nodes;
		if ((r = sshbuf_put_u32(b, th->h)) != 0 ||
		    (r = sshbuf_put_u32(b, th->next_idx)) != 0 ||
		    (r = sshbuf_put_u32(b, th->stackusage)) != 0 ||
		    (r = sshbuf_put_u8(b, th->completed)) != 0 ||
		    (r = sshbuf_put_u32(b, node)) != 0)
			return r;
	}
	return 0;
}

/*
 * Serialize XMSS state according to 'opts'.  The opts value is written
 * first as a byte:
 *   DEFAULT - nothing else
 *   STATE   - BDS state
 *   FULL    - encryption key, then BDS state
 *   SHIELD  - optional BDS state, optional filename, optional encryption
 *             key (each preceded by a presence byte), then maxidx and
 *             allow_update
 * Any other value yields SSH_ERR_INVALID_ARGUMENT.
 */
int
sshkey_xmss_serialize_state_opt(const struct sshkey *k, struct sshbuf *b,
    enum sshkey_serialize_rep opts)
{
	struct ssh_xmss_state *st = k->xmss_state;
	u_char present;
	int r;

	if (st == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if ((r = sshbuf_put_u8(b, opts)) != 0)
		return r;
	switch (opts) {
	case SSHKEY_SERIALIZE_DEFAULT:
		return 0;
	case SSHKEY_SERIALIZE_STATE:
		return sshkey_xmss_serialize_state(k, b);
	case SSHKEY_SERIALIZE_FULL:
		if ((r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
			return r;
		return sshkey_xmss_serialize_state(k, b);
	case SSHKEY_SERIALIZE_SHIELD:
		break;
	default:
		return SSH_ERR_INVALID_ARGUMENT;
	}

	/* SHIELD: stack, filename and encryption key are all optional */
	present = st->stack != NULL;
	if ((r = sshbuf_put_u8(b, present)) != 0)
		return r;
	if (present) {
		st->idx = PEEK_U32(k->xmss_sk);	/* update */
		if ((r = sshkey_xmss_serialize_state(k, b)) != 0)
			return r;
	}
	present = k->xmss_filename != NULL;
	if ((r = sshbuf_put_u8(b, present)) != 0)
		return r;
	if (present && (r = sshbuf_put_cstring(b, k->xmss_filename)) != 0)
		return r;
	present = st->enc_keyiv != NULL;
	if ((r = sshbuf_put_u8(b, present)) != 0)
		return r;
	if (present && (r = sshkey_xmss_serialize_enc_key(k, b)) != 0)
		return r;
	if ((r = sshbuf_put_u32(b, st->maxidx)) != 0)
		return r;
	if ((r = sshbuf_put_u8(b, st->allow_update)) != 0)
		return r;
	return 0;
}

/*
 * Parse a BDS traversal state written by sshkey_xmss_serialize_state()
 * into the key's XMSS state and register it with state->bds.  The parsed
 * index is also written back into the first 4 bytes of xmss_sk.
 *
 * All buffer sizes must match the key's parameter set.  stackoffset must
 * not exceed the number of stack levels, and each treehash node offset
 * must leave room for a full n-byte node inside th_nodes.  Otherwise
 * SSH_ERR_INVALID_ARGUMENT is returned.  Returns 0 on success.
 */
int
sshkey_xmss_deserialize_state(struct sshkey *k, struct sshbuf *b)
{
	struct ssh_xmss_state *state = k->xmss_state;
	treehash_inst *th;
	u_int32_t i, lh, node;
	size_t ls, lsl, la, lk, ln, lr;
	char *magic = NULL;
	int r = SSH_ERR_INTERNAL_ERROR;

	if (state == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if (k->xmss_sk == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if ((state->treehash = calloc(num_treehash(state),
	    sizeof(treehash_inst))) == NULL)
		return SSH_ERR_ALLOC_FAIL;
	if ((r = sshbuf_get_cstring(b, &magic, NULL)) != 0 ||
	    (r = sshbuf_get_u32(b, &state->idx)) != 0 ||
	    (r = sshbuf_get_string(b, &state->stack, &ls)) != 0 ||
	    (r = sshbuf_get_u32(b, &state->stackoffset)) != 0 ||
	    (r = sshbuf_get_string(b, &state->stacklevels, &lsl)) != 0 ||
	    (r = sshbuf_get_string(b, &state->auth, &la)) != 0 ||
	    (r = sshbuf_get_string(b, &state->keep, &lk)) != 0 ||
	    (r = sshbuf_get_string(b, &state->th_nodes, &ln)) != 0 ||
	    (r = sshbuf_get_string(b, &state->retain, &lr)) != 0 ||
	    (r = sshbuf_get_u32(b, &lh)) != 0)
		goto out;
	if (strcmp(magic, SSH_XMSS_K2_MAGIC) != 0) {
		r = SSH_ERR_INVALID_ARGUMENT;
		goto out;
	}
	if (ls != num_stack(state) ||
	    lsl != num_stacklevels(state) ||
	    la != num_auth(state) ||
	    lk != num_keep(state) ||
	    ln != num_th_nodes(state) ||
	    lr != num_retain(state) ||
	    lh != num_treehash(state)) {
		r = SSH_ERR_INVALID_ARGUMENT;
		goto out;
	}
	/*
	 * stackoffset indexes stacklevels[] (and n-byte slots of stack[]),
	 * both of which have num_stacklevels() entries; an offset equal to
	 * that count means "full".  Anything larger is corrupt.
	 */
	if (state->stackoffset > num_stacklevels(state)) {
		r = SSH_ERR_INVALID_ARGUMENT;
		goto out;
	}
	for (i = 0; i < num_treehash(state); i++) {
		th = &state->treehash[i];
		if ((r = sshbuf_get_u32(b, &th->h)) != 0 ||
		    (r = sshbuf_get_u32(b, &th->next_idx)) != 0 ||
		    (r = sshbuf_get_u32(b, &th->stackusage)) != 0 ||
		    (r = sshbuf_get_u8(b, &th->completed)) != 0 ||
		    (r = sshbuf_get_u32(b, &node)) != 0)
			goto out;
		/*
		 * A node is n bytes long, so it must start early enough to fit
		 * inside th_nodes.  Reject bad offsets instead of leaving
		 * th->node NULL.
		 */
		if (node >= num_th_nodes(state) ||
		    num_th_nodes(state) - node < state->n) {
			r = SSH_ERR_INVALID_ARGUMENT;
			goto out;
		}
		th->node = &state->th_nodes[node];
	}
	POKE_U32(k->xmss_sk, state->idx);
	xmss_set_bds_state(&state->bds, state->stack, state->stackoffset,
	    state->stacklevels, state->auth, state->keep, state->treehash,
	    state->retain, 0);
	/* success */
	r = 0;
 out:
	free(magic);
	return r;
}

int
sshkey_xmss_deserialize_state_opt(struct sshkey *k, struct sshbuf *b)
{
	struct ssh_xmss_state *state = k->xmss_state;
	enum sshkey_serialize_rep opts;
	u_char have_state, have_stack, have_filename, have_enc;
	int r;

	if ((r = sshbuf_get_u8(b, &have_state)) != 0)
		return r;

	opts = have_state;
	switch (opts) {
	case SSHKEY_SERIALIZE_DEFAULT:
		r = 0;
		break;
	case SSHKEY_SERIALIZE_SHIELD:
		if ((r = sshbuf_get_u8(b, &have_stack)) != 0)
			return r;
		if (have_stack &&
		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
			return r;
		if ((r = sshbuf_get_u8(b, &have_filename)) != 0)
			return r;
		if (have_filename &&
		    (r = sshbuf_get_cstring(b, &k->xmss_filename, NULL)) != 0)
			return r;
		if ((r = sshbuf_get_u8(b, &have_enc)) != 0)
			return r;
		if (have_enc &&
		    (r = sshkey_xmss_deserialize_enc_key(k, b)) != 0)
			return r;
		if ((r = sshbuf_get_u32(b, &state->maxidx)) != 0 ||
		    (r = sshbuf_get_u8(b, &state->allow_update)) != 0)
			return r;
		break;
	case SSHKEY_SERIALIZE_STATE:
		if ((r = sshkey_xmss_deserialize_state(k, b)) != 0)
			return r;
		break;
	case SSHKEY_SERIALIZE_FULL:
		if ((r = sshkey_xmss_deserialize_enc_key(k, b)) != 0 ||
		    (r = sshkey_xmss_deserialize_state(k, b)) != 0)
			return r;
		break;
	default:
		r = SSH_ERR_INVALID_FORMAT;
		break;
	}
	return r;
}

/*
 * Encrypt serialized state 'b' with the key's state cipher.
 *
 * Output layout: XMSS_MAGIC (including NUL) | u32 idx | u32 encrypted_len
 * | ciphertext of the padded state | auth tag.  The header (everything
 * before the ciphertext) is authenticated as AAD.  The IV is the stored IV
 * with its first 4 bytes replaced by state->idx.
 *
 * On success *retp receives the new buffer (caller frees).  On failure
 * *retp is left NULL; this matches sshkey_xmss_decrypt_state() and means
 * no partially built buffer is handed out.
 */
int
sshkey_xmss_encrypt_state(const struct sshkey *k, struct sshbuf *b,
   struct sshbuf **retp)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *encrypted = NULL, *encoded = NULL, *padded = NULL;
	struct sshcipher_ctx *ciphercontext = NULL;
	const struct sshcipher *cipher;
	u_char *cp, *key, *iv = NULL;
	size_t i, keylen, ivlen, blocksize, authlen, encrypted_len, aadlen;
	int r = SSH_ERR_INTERNAL_ERROR;

	if (retp != NULL)
		*retp = NULL;
	if (state == NULL ||
	    state->enc_keyiv == NULL ||
	    state->enc_ciphername == NULL)
		return SSH_ERR_INTERNAL_ERROR;
	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
		r = SSH_ERR_INTERNAL_ERROR;
		goto out;
	}
	blocksize = cipher_blocksize(cipher);
	keylen = cipher_keylen(cipher);
	ivlen = cipher_ivlen(cipher);
	authlen = cipher_authlen(cipher);
	if (state->enc_keyiv_len != keylen + ivlen) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	key = state->enc_keyiv;
	if ((encrypted = sshbuf_new()) == NULL ||
	    (encoded = sshbuf_new()) == NULL ||
	    (padded = sshbuf_new()) == NULL ||
	    (iv = malloc(ivlen)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto out;
	}

	/* replace first 4 bytes of IV with index to ensure uniqueness */
	memcpy(iv, key + keylen, ivlen);
	POKE_U32(iv, state->idx);

	if ((r = sshbuf_put(encoded, XMSS_MAGIC, sizeof(XMSS_MAGIC))) != 0 ||
	    (r = sshbuf_put_u32(encoded, state->idx)) != 0)
		goto out;

	/* padded state will be encrypted */
	if ((r = sshbuf_putb(padded, b)) != 0)
		goto out;
	i = 0;
	while (sshbuf_len(padded) % blocksize) {
		if ((r = sshbuf_put_u8(padded, ++i & 0xff)) != 0)
			goto out;
	}
	encrypted_len = sshbuf_len(padded);

	/* header including the length of state is used as AAD */
	if ((r = sshbuf_put_u32(encoded, encrypted_len)) != 0)
		goto out;
	aadlen = sshbuf_len(encoded);

	/* concat header and state */
	if ((r = sshbuf_putb(encoded, padded)) != 0)
		goto out;

	/* reserve space for encryption of encoded data plus auth tag */
	/* encrypt at offset aadlen */
	if ((r = sshbuf_reserve(encrypted,
	    encrypted_len + aadlen + authlen, &cp)) != 0 ||
	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
	    iv, ivlen, 1)) != 0 ||
	    (r = cipher_crypt(ciphercontext, 0, cp, sshbuf_ptr(encoded),
	    encrypted_len, aadlen, authlen)) != 0)
		goto out;

	/* success: only now transfer ownership of the result */
	r = 0;
	if (retp != NULL) {
		*retp = encrypted;
		encrypted = NULL;
	}
 out:
	sshbuf_free(padded);
	sshbuf_free(encoded);
	sshbuf_free(encrypted);
	cipher_free(ciphercontext);
	free(iv);
	return r;
}

/*
 * Decrypt a state blob produced by sshkey_xmss_encrypt_state().
 *
 * Expected layout of 'encoded': XMSS_MAGIC (including NUL) | u32 index |
 * u32 encrypted_len | ciphertext (encrypted_len bytes) | auth tag (authlen
 * bytes), with nothing after the tag.  Everything before the ciphertext is
 * authenticated as AAD.  The IV is the stored IV with its first 4 bytes
 * replaced by the blob's index.
 *
 * 'encoded' is consumed.  On success *retp receives the decrypted state
 * (caller frees); it still contains the block padding added by the
 * encryptor, which is not checked here.  Returns 0 or an SSH_ERR_* code.
 */
int
sshkey_xmss_decrypt_state(const struct sshkey *k, struct sshbuf *encoded,
   struct sshbuf **retp)
{
	struct ssh_xmss_state *state = k->xmss_state;
	struct sshbuf *copy = NULL, *decrypted = NULL;
	struct sshcipher_ctx *ciphercontext = NULL;
	const struct sshcipher *cipher = NULL;
	u_char *key, *iv = NULL, *dp;
	size_t keylen, ivlen, authlen, aadlen;
	u_int blocksize, encrypted_len, index;
	int r = SSH_ERR_INTERNAL_ERROR;

	if (retp != NULL)
		*retp = NULL;
	if (state == NULL ||
	    state->enc_keyiv == NULL ||
	    state->enc_ciphername == NULL)
		return SSH_ERR_INTERNAL_ERROR;
	if ((cipher = cipher_by_name(state->enc_ciphername)) == NULL) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	blocksize = cipher_blocksize(cipher);
	keylen = cipher_keylen(cipher);
	ivlen = cipher_ivlen(cipher);
	authlen = cipher_authlen(cipher);
	if (state->enc_keyiv_len != keylen + ivlen) {
		r = SSH_ERR_INTERNAL_ERROR;
		goto out;
	}
	key = state->enc_keyiv;

	/*
	 * 'copy' keeps a view of the whole blob (header included) so the
	 * cipher can process AAD + ciphertext in one call after 'encoded'
	 * has been advanced past the header.
	 */
	if ((copy = sshbuf_fromb(encoded)) == NULL ||
	    (decrypted = sshbuf_new()) == NULL ||
	    (iv = malloc(ivlen)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto out;
	}

	/* check magic */
	if (sshbuf_len(encoded) < sizeof(XMSS_MAGIC) ||
	    memcmp(sshbuf_ptr(encoded), XMSS_MAGIC, sizeof(XMSS_MAGIC))) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	/* parse public portion */
	if ((r = sshbuf_consume(encoded, sizeof(XMSS_MAGIC))) != 0 ||
	    (r = sshbuf_get_u32(encoded, &index)) != 0 ||
	    (r = sshbuf_get_u32(encoded, &encrypted_len)) != 0)
		goto out;

	/* check size of encrypted key blob */
	if (encrypted_len < blocksize || (encrypted_len % blocksize) != 0) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}
	/* check that an appropriate amount of auth data is present */
	if (sshbuf_len(encoded) < authlen ||
	    sshbuf_len(encoded) - authlen < encrypted_len) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}

	/* header length = bytes consumed from 'encoded' so far */
	aadlen = sshbuf_len(copy) - sshbuf_len(encoded);

	/* replace first 4 bytes of IV with index to ensure uniqueness */
	memcpy(iv, key + keylen, ivlen);
	POKE_U32(iv, index);

	/* decrypt private state of key */
	if ((r = sshbuf_reserve(decrypted, aadlen + encrypted_len, &dp)) != 0 ||
	    (r = cipher_init(&ciphercontext, cipher, key, keylen,
	    iv, ivlen, 0)) != 0 ||
	    (r = cipher_crypt(ciphercontext, 0, dp, sshbuf_ptr(copy),
	    encrypted_len, aadlen, authlen)) != 0)
		goto out;

	/* there should be no trailing data */
	if ((r = sshbuf_consume(encoded, encrypted_len + authlen)) != 0)
		goto out;
	if (sshbuf_len(encoded) != 0) {
		r = SSH_ERR_INVALID_FORMAT;
		goto out;
	}

	/* remove AAD */
	if ((r = sshbuf_consume(decrypted, aadlen)) != 0)
		goto out;
	/* XXX encrypted includes unchecked padding */

	/* success */
	r = 0;
	if (retp != NULL) {
		*retp = decrypted;
		decrypted = NULL;
	}
 out:
	cipher_free(ciphercontext);
	sshbuf_free(copy);
	sshbuf_free(decrypted);
	free(iv);
	return r;
}

u_int32_t
sshkey_xmss_signatures_left(const struct sshkey *k)
{
	struct ssh_xmss_state *state = k->xmss_state;
	u_int32_t idx;

	if (sshkey_type_plain(k->type) == KEY_XMSS && state &&
	    state->maxidx) {
		idx = k->xmss_sk ? PEEK_U32(k->xmss_sk) : state->idx;
		if (idx < state->maxidx)
			return state->maxidx - idx;
	}
	return 0;
}

/*
 * Limit the key to 'maxsign' further signatures by setting
 * maxidx = idx + maxsign.  maxsign == 0 leaves the key unchanged.
 * Returns SSH_ERR_INVALID_ARGUMENT for non-XMSS keys, keys without XMSS
 * state, or when idx + maxsign would overflow 32 bits.
 */
int
sshkey_xmss_enable_maxsign(struct sshkey *k, u_int32_t maxsign)
{
	struct ssh_xmss_state *state = k->xmss_state;

	if (sshkey_type_plain(k->type) != KEY_XMSS)
		return SSH_ERR_INVALID_ARGUMENT;
	if (maxsign == 0)
		return 0;
	/* e.g. an XMSS public key: nothing to restrict, avoid NULL deref */
	if (state == NULL)
		return SSH_ERR_INVALID_ARGUMENT;
	if (state->idx + maxsign < state->idx)
		return SSH_ERR_INVALID_ARGUMENT;
	state->maxidx = state->idx + maxsign;
	return 0;
}
#endif /* WITH_XMSS */