xref: /freebsd/crypto/openssh/kex.c (revision 59144db3fca192c4637637dfe6b5a5d98632cd47)
1 /* $OpenBSD: kex.c,v 1.185 2024/01/08 00:34:33 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include "includes.h"
27 
28 #include <sys/types.h>
29 #include <errno.h>
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39 
40 #ifdef WITH_OPENSSL
41 #include <openssl/crypto.h>
42 #include <openssl/dh.h>
43 #endif
44 
45 #include "ssh.h"
46 #include "ssh2.h"
47 #include "atomicio.h"
48 #include "version.h"
49 #include "packet.h"
50 #include "compat.h"
51 #include "cipher.h"
52 #include "sshkey.h"
53 #include "kex.h"
54 #include "log.h"
55 #include "mac.h"
56 #include "match.h"
57 #include "misc.h"
58 #include "dispatch.h"
59 #include "monitor.h"
60 #include "myproposal.h"
61 
62 #include "ssherr.h"
63 #include "sshbuf.h"
64 #include "digest.h"
65 #include "xmalloc.h"
66 
/*
 * Forward declarations for static functions that are referenced before
 * their definitions in this file.
 */
static int kex_choose_conf(struct ssh *, uint32_t seq);
static int kex_input_newkeys(int, u_int32_t, struct ssh *);

/*
 * Human-readable labels for the PROPOSAL_MAX KEXINIT name-list slots, in
 * the same order as a proposal array. Used to label the debug output
 * emitted by kex_buf2prop().
 */
static const char * const proposal_names[PROPOSAL_MAX] = {
	"KEX algorithms",
	"host key algorithms",
	"ciphers ctos",
	"ciphers stoc",
	"MACs ctos",
	"MACs stoc",
	"compression ctos",
	"compression stoc",
	"languages ctos",
	"languages stoc",
};
83 
/* Description of one key exchange method supported by this build. */
struct kexalg {
	char *name;	/* method name as used in the KEX name-list */
	u_int type;	/* KEX_* implementation type, copied to kex->kex_type */
	int ec_nid;	/* curve NID for the ECDH entries, 0 for the others */
	int hash_alg;	/* SSH_DIGEST_* value, copied to kex->hash_alg */
};
/*
 * All KEX methods compiled into this build, terminated by an entry with a
 * NULL name. Which entries exist depends on WITH_OPENSSL, HAVE_EVP_SHA256,
 * OPENSSL_HAS_ECC, OPENSSL_HAS_NISTP521 and USE_SNTRUP761X25519.
 * Searched linearly by kex_alg_by_name() and listed by kex_alg_list().
 */
static const struct kexalg kexalgs[] = {
#ifdef WITH_OPENSSL
	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
#ifdef HAVE_EVP_SHA256
	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
#endif /* HAVE_EVP_SHA256 */
#ifdef OPENSSL_HAS_ECC
	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
	    SSH_DIGEST_SHA384 },
# ifdef OPENSSL_HAS_NISTP521
	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
	    SSH_DIGEST_SHA512 },
# endif /* OPENSSL_HAS_NISTP521 */
#endif /* OPENSSL_HAS_ECC */
#endif /* WITH_OPENSSL */
	/*
	 * curve25519 is available without OpenSSL, or with it when SHA-256
	 * is available. Two names share the same KEX_C25519_SHA256 type.
	 */
#if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
#ifdef USE_SNTRUP761X25519
	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
	    SSH_DIGEST_SHA512 },
#endif
#endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
	/* sentinel: loops stop at the NULL name */
	{ NULL, 0, -1, -1},
};
122 
123 char *
124 kex_alg_list(char sep)
125 {
126 	char *ret = NULL, *tmp;
127 	size_t nlen, rlen = 0;
128 	const struct kexalg *k;
129 
130 	for (k = kexalgs; k->name != NULL; k++) {
131 		if (ret != NULL)
132 			ret[rlen++] = sep;
133 		nlen = strlen(k->name);
134 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
135 			free(ret);
136 			return NULL;
137 		}
138 		ret = tmp;
139 		memcpy(ret + rlen, k->name, nlen + 1);
140 		rlen += nlen;
141 	}
142 	return ret;
143 }
144 
145 static const struct kexalg *
146 kex_alg_by_name(const char *name)
147 {
148 	const struct kexalg *k;
149 
150 	for (k = kexalgs; k->name != NULL; k++) {
151 		if (strcmp(k->name, name) == 0)
152 			return k;
153 	}
154 	return NULL;
155 }
156 
/*
 * Validate a comma-separated KEX method name list.
 * Returns 1 if every name is supported by this build. Returns 0 if the
 * list is NULL or empty, names an unsupported method (which is logged),
 * or cannot be copied. Scanning stops at the first empty element.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *rest, *name;
	int ok = 1;

	if (names == NULL || *names == '\0')
		return 0;
	if ((copy = strdup(names)) == NULL)
		return 0;
	rest = copy;
	while ((name = strsep(&rest, ",")) != NULL && *name != '\0') {
		if (kex_alg_by_name(name) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", name);
			ok = 0;
			break;
		}
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
179 
/*
 * Return 1 if `proposal` contains any algorithm from the comma-separated
 * list `algs`, 0 otherwise (including when match_list() yields nothing).
 */
static int
has_any_alg(const char *proposal, const char *algs)
{
	char *common = match_list(proposal, algs, NULL);
	int found = common != NULL;

	free(common);
	return found;
}
191 
/*
 * Concatenate algorithm names, avoiding duplicates in the process.
 * The result is "a" followed by each comma-separated name from "b" that
 * is not already present. If "a" is NULL or empty a copy of "b" is
 * returned (an empty string if "b" is NULL too); if "b" is NULL or empty
 * a copy of "a" is returned. Names in "b" after an empty element are
 * ignored.
 * Returns NULL on allocation failure or if "b" exceeds 1MiB.
 * Caller must free returned string.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p;
	size_t len;

	/* Never strdup(NULL): that is undefined behaviour. */
	if (a == NULL || *a == '\0')
		return strdup(b == NULL ? "" : b);
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* Worst case: all of a, a comma, all of b and the NUL. */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		if (has_any_alg(ret, p))
			continue; /* Algorithm already present */
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
228 
/*
 * Assemble a list of algorithms from a default list and a string from a
 * configuration file. The user-provided string may begin with '+' to
 * indicate that it should be appended to the default, '-' that the
 * specified names should be removed, or '^' that they should be placed
 * at the head. Any other string replaces the default outright.
 *
 * listp: in/out. If *listp is NULL or empty it is set to a copy of def.
 *        Otherwise this function takes ownership of *listp (it is always
 *        freed). On success *listp receives a newly allocated list. On a
 *        failure after the argument checks it is left NULL.
 * def:   the default algorithm list.
 * all:   every algorithm the caller supports. For the '+', '^' and
 *        explicit forms, each comma-separated entry (which may be a
 *        pattern) is expanded against this list with
 *        match_filter_allowlist(). The '-' form is not filtered against it.
 *
 * Returns 0 on success. Returns SSH_ERR_INVALID_ARGUMENT for NULL
 * arguments, a negated ('!') entry or an empty result. Returns
 * SSH_ERR_ALLOC_FAIL when a helper returns NULL.
 */
int
kex_assemble_names(char **listp, const char *def, const char *all)
{
	char *cp, *tmp, *patterns;
	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
	int r = SSH_ERR_INTERNAL_ERROR;

	if (listp == NULL || def == NULL || all == NULL)
		return SSH_ERR_INVALID_ARGUMENT;

	/*
	 * NOTE(review): an empty (non-NULL) *listp is overwritten here
	 * without being freed; confirm callers never pass a heap-allocated "".
	 */
	if (*listp == NULL || **listp == '\0') {
		if ((*listp = strdup(def)) == NULL)
			return SSH_ERR_ALLOC_FAIL;
		return 0;
	}

	/* Take ownership: from here on "list" is freed on every path. */
	list = *listp;
	*listp = NULL;
	if (*list == '+') {
		/* Append names to default list */
		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		list = tmp;
	} else if (*list == '-') {
		/* Remove names from default list */
		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		/* filtering has already been done */
		return 0;
	} else if (*list == '^') {
		/* Place names at head of default list */
		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		list = tmp;
	} else {
		/* Explicit list, overrides default - just use "list" as is */
	}

	/*
	 * The supplied names may be a pattern-list. For the -list case,
	 * the patterns are applied above. For the +list and explicit list
	 * cases we need to do it now.
	 */
	ret = NULL;
	if ((patterns = opatterns = strdup(list)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto fail;
	}
	/* Apply positive (i.e. non-negated) patterns from the list */
	while ((cp = strsep(&patterns, ",")) != NULL) {
		if (*cp == '!') {
			/* negated matches are not supported here */
			r = SSH_ERR_INVALID_ARGUMENT;
			goto fail;
		}
		/* Expand this pattern against "all" and append, de-duplicated. */
		free(matching);
		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(ret);
		ret = tmp;
	}
	if (ret == NULL || *ret == '\0') {
		/* An empty name-list is an error */
		/* XXX better error code? */
		r = SSH_ERR_INVALID_ARGUMENT;
		goto fail;
	}

	/* success */
	*listp = ret;
	ret = NULL;
	r = 0;

 fail:
	/* Common cleanup; ret is NULL here on success. */
	free(matching);
	free(opatterns);
	free(list);
	free(ret);
	return r;
}
331 
332 /*
333  * Fill out a proposal array with dynamically allocated values, which may
334  * be modified as required for compatibility reasons.
335  * Any of the options may be NULL, in which case the default is used.
336  * Array contents must be freed by calling kex_proposal_free_entries.
337  */
338 void
339 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
340     const char *kexalgos, const char *ciphers, const char *macs,
341     const char *comp, const char *hkalgs)
342 {
343 	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
344 	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
345 	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
346 	u_int i;
347 	char *cp;
348 
349 	if (prop == NULL)
350 		fatal_f("proposal missing");
351 
352 	/* Append EXT_INFO signalling to KexAlgorithms */
353 	if (kexalgos == NULL)
354 		kexalgos = defprop[PROPOSAL_KEX_ALGS];
355 	if ((cp = kex_names_cat(kexalgos, ssh->kex->server ?
356 	    "ext-info-s,kex-strict-s-v00@openssh.com" :
357 	    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL)
358 		fatal_f("kex_names_cat");
359 
360 	for (i = 0; i < PROPOSAL_MAX; i++) {
361 		switch(i) {
362 		case PROPOSAL_KEX_ALGS:
363 			prop[i] = compat_kex_proposal(ssh, cp);
364 			break;
365 		case PROPOSAL_ENC_ALGS_CTOS:
366 		case PROPOSAL_ENC_ALGS_STOC:
367 			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
368 			break;
369 		case PROPOSAL_MAC_ALGS_CTOS:
370 		case PROPOSAL_MAC_ALGS_STOC:
371 			prop[i]  = xstrdup(macs ? macs : defprop[i]);
372 			break;
373 		case PROPOSAL_COMP_ALGS_CTOS:
374 		case PROPOSAL_COMP_ALGS_STOC:
375 			prop[i] = xstrdup(comp ? comp : defprop[i]);
376 			break;
377 		case PROPOSAL_SERVER_HOST_KEY_ALGS:
378 			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
379 			break;
380 		default:
381 			prop[i] = xstrdup(defprop[i]);
382 		}
383 	}
384 	free(cp);
385 }
386 
387 void
388 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
389 {
390 	u_int i;
391 
392 	for (i = 0; i < PROPOSAL_MAX; i++)
393 		free(prop[i]);
394 }
395 
396 /* put algorithm proposal into buffer */
397 int
398 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
399 {
400 	u_int i;
401 	int r;
402 
403 	sshbuf_reset(b);
404 
405 	/*
406 	 * add a dummy cookie, the cookie will be overwritten by
407 	 * kex_send_kexinit(), each time a kexinit is set
408 	 */
409 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
410 		if ((r = sshbuf_put_u8(b, 0)) != 0)
411 			return r;
412 	}
413 	for (i = 0; i < PROPOSAL_MAX; i++) {
414 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
415 			return r;
416 	}
417 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
418 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
419 		return r;
420 	return 0;
421 }
422 
423 /* parse buffer and return algorithm proposal */
424 int
425 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
426 {
427 	struct sshbuf *b = NULL;
428 	u_char v;
429 	u_int i;
430 	char **proposal = NULL;
431 	int r;
432 
433 	*propp = NULL;
434 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
435 		return SSH_ERR_ALLOC_FAIL;
436 	if ((b = sshbuf_fromb(raw)) == NULL) {
437 		r = SSH_ERR_ALLOC_FAIL;
438 		goto out;
439 	}
440 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
441 		error_fr(r, "consume cookie");
442 		goto out;
443 	}
444 	/* extract kex init proposal strings */
445 	for (i = 0; i < PROPOSAL_MAX; i++) {
446 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
447 			error_fr(r, "parse proposal %u", i);
448 			goto out;
449 		}
450 		debug2("%s: %s", proposal_names[i], proposal[i]);
451 	}
452 	/* first kex follows / reserved */
453 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
454 	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
455 		error_fr(r, "parse");
456 		goto out;
457 	}
458 	if (first_kex_follows != NULL)
459 		*first_kex_follows = v;
460 	debug2("first_kex_follows %d ", v);
461 	debug2("reserved %u ", i);
462 	r = 0;
463 	*propp = proposal;
464  out:
465 	if (r != 0 && proposal != NULL)
466 		kex_prop_free(proposal);
467 	sshbuf_free(b);
468 	return r;
469 }
470 
471 void
472 kex_prop_free(char **proposal)
473 {
474 	u_int i;
475 
476 	if (proposal == NULL)
477 		return;
478 	for (i = 0; i < PROPOSAL_MAX; i++)
479 		free(proposal[i]);
480 	free(proposal);
481 }
482 
483 int
484 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
485 {
486 	int r;
487 
488 	/* If in strict mode, any unexpected message is an error */
489 	if ((ssh->kex->flags & KEX_INITIAL) && ssh->kex->kex_strict) {
490 		ssh_packet_disconnect(ssh, "strict KEX violation: "
491 		    "unexpected packet type %u (seqnr %u)", type, seq);
492 	}
493 	error_f("type %u seq %u", type, seq);
494 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
495 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
496 	    (r = sshpkt_send(ssh)) != 0)
497 		return r;
498 	return 0;
499 }
500 
501 static void
502 kex_reset_dispatch(struct ssh *ssh)
503 {
504 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
505 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
506 }
507 
508 void
509 kex_set_server_sig_algs(struct ssh *ssh, const char *allowed_algs)
510 {
511 	char *alg, *oalgs, *algs, *sigalgs;
512 	const char *sigalg;
513 
514 	/*
515 	 * NB. allowed algorithms may contain certificate algorithms that
516 	 * map to a specific plain signature type, e.g.
517 	 * rsa-sha2-512-cert-v01@openssh.com => rsa-sha2-512
518 	 * We need to be careful here to match these, retain the mapping
519 	 * and only add each signature algorithm once.
520 	 */
521 	if ((sigalgs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
522 		fatal_f("sshkey_alg_list failed");
523 	oalgs = algs = xstrdup(allowed_algs);
524 	free(ssh->kex->server_sig_algs);
525 	ssh->kex->server_sig_algs = NULL;
526 	for ((alg = strsep(&algs, ",")); alg != NULL && *alg != '\0';
527 	    (alg = strsep(&algs, ","))) {
528 		if ((sigalg = sshkey_sigalg_by_name(alg)) == NULL)
529 			continue;
530 		if (!has_any_alg(sigalg, sigalgs))
531 			continue;
532 		/* Don't add an algorithm twice. */
533 		if (ssh->kex->server_sig_algs != NULL &&
534 		    has_any_alg(sigalg, ssh->kex->server_sig_algs))
535 			continue;
536 		xextendf(&ssh->kex->server_sig_algs, ",", "%s", sigalg);
537 	}
538 	free(oalgs);
539 	free(sigalgs);
540 	if (ssh->kex->server_sig_algs == NULL)
541 		ssh->kex->server_sig_algs = xstrdup("");
542 }
543 
544 static int
545 kex_compose_ext_info_server(struct ssh *ssh, struct sshbuf *m)
546 {
547 	int r;
548 
549 	if (ssh->kex->server_sig_algs == NULL &&
550 	    (ssh->kex->server_sig_algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
551 		return SSH_ERR_ALLOC_FAIL;
552 	if ((r = sshbuf_put_u32(m, 3)) != 0 ||
553 	    (r = sshbuf_put_cstring(m, "server-sig-algs")) != 0 ||
554 	    (r = sshbuf_put_cstring(m, ssh->kex->server_sig_algs)) != 0 ||
555 	    (r = sshbuf_put_cstring(m,
556 	    "publickey-hostbound@openssh.com")) != 0 ||
557 	    (r = sshbuf_put_cstring(m, "0")) != 0 ||
558 	    (r = sshbuf_put_cstring(m, "ping@openssh.com")) != 0 ||
559 	    (r = sshbuf_put_cstring(m, "0")) != 0) {
560 		error_fr(r, "compose");
561 		return r;
562 	}
563 	return 0;
564 }
565 
/*
 * Append the client-side EXT_INFO body to `m`: a single entry,
 * ext-info-in-auth@openssh.com with value "0".
 * Returns 0 or the sshbuf error. `ssh` is currently unused.
 */
static int
kex_compose_ext_info_client(struct ssh *ssh, struct sshbuf *m)
{
	int r;

	if ((r = sshbuf_put_u32(m, 1)) != 0 ||
	    (r = sshbuf_put_cstring(m, "ext-info-in-auth@openssh.com")) != 0 ||
	    (r = sshbuf_put_cstring(m, "0")) != 0) {
		error_fr(r, "compose");
		return r;
	}
	return 0;
}
582 
583 static int
584 kex_maybe_send_ext_info(struct ssh *ssh)
585 {
586 	int r;
587 	struct sshbuf *m = NULL;
588 
589 	if ((ssh->kex->flags & KEX_INITIAL) == 0)
590 		return 0;
591 	if (!ssh->kex->ext_info_c && !ssh->kex->ext_info_s)
592 		return 0;
593 
594 	/* Compose EXT_INFO packet. */
595 	if ((m = sshbuf_new()) == NULL)
596 		fatal_f("sshbuf_new failed");
597 	if (ssh->kex->ext_info_c &&
598 	    (r = kex_compose_ext_info_server(ssh, m)) != 0)
599 		goto fail;
600 	if (ssh->kex->ext_info_s &&
601 	    (r = kex_compose_ext_info_client(ssh, m)) != 0)
602 		goto fail;
603 
604 	/* Send the actual KEX_INFO packet */
605 	debug("Sending SSH2_MSG_EXT_INFO");
606 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
607 	    (r = sshpkt_putb(ssh, m)) != 0 ||
608 	    (r = sshpkt_send(ssh)) != 0) {
609 		error_f("send EXT_INFO");
610 		goto fail;
611 	}
612 
613 	r = 0;
614 
615  fail:
616 	sshbuf_free(m);
617 	return r;
618 }
619 
620 int
621 kex_server_update_ext_info(struct ssh *ssh)
622 {
623 	int r;
624 
625 	if ((ssh->kex->flags & KEX_HAS_EXT_INFO_IN_AUTH) == 0)
626 		return 0;
627 
628 	debug_f("Sending SSH2_MSG_EXT_INFO");
629 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
630 	    (r = sshpkt_put_u32(ssh, 1)) != 0 ||
631 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
632 	    (r = sshpkt_put_cstring(ssh, ssh->kex->server_sig_algs)) != 0 ||
633 	    (r = sshpkt_send(ssh)) != 0) {
634 		error_f("send EXT_INFO");
635 		return r;
636 	}
637 	return 0;
638 }
639 
640 int
641 kex_send_newkeys(struct ssh *ssh)
642 {
643 	int r;
644 
645 	kex_reset_dispatch(ssh);
646 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
647 	    (r = sshpkt_send(ssh)) != 0)
648 		return r;
649 	debug("SSH2_MSG_NEWKEYS sent");
650 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
651 	if ((r = kex_maybe_send_ext_info(ssh)) != 0)
652 		return r;
653 	debug("expecting SSH2_MSG_NEWKEYS");
654 	return 0;
655 }
656 
657 /* Check whether an ext_info value contains the expected version string */
658 static int
659 kex_ext_info_check_ver(struct kex *kex, const char *name,
660     const u_char *val, size_t len, const char *want_ver, u_int flag)
661 {
662 	if (memchr(val, '\0', len) != NULL) {
663 		error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
664 		return SSH_ERR_INVALID_FORMAT;
665 	}
666 	debug_f("%s=<%s>", name, val);
667 	if (strcmp(val, want_ver) == 0)
668 		kex->flags |= flag;
669 	else
670 		debug_f("unsupported version of %s extension", name);
671 	return 0;
672 }
673 
674 static int
675 kex_ext_info_client_parse(struct ssh *ssh, const char *name,
676     const u_char *value, size_t vlen)
677 {
678 	int r;
679 
680 	/* NB. some messages are only accepted in the initial EXT_INFO */
681 	if (strcmp(name, "server-sig-algs") == 0) {
682 		/* Ensure no \0 lurking in value */
683 		if (memchr(value, '\0', vlen) != NULL) {
684 			error_f("nul byte in %s", name);
685 			return SSH_ERR_INVALID_FORMAT;
686 		}
687 		debug_f("%s=<%s>", name, value);
688 		free(ssh->kex->server_sig_algs);
689 		ssh->kex->server_sig_algs = xstrdup((const char *)value);
690 	} else if (ssh->kex->ext_info_received == 1 &&
691 	    strcmp(name, "publickey-hostbound@openssh.com") == 0) {
692 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
693 		    "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
694 			return r;
695 		}
696 	} else if (ssh->kex->ext_info_received == 1 &&
697 	    strcmp(name, "ping@openssh.com") == 0) {
698 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
699 		    "0", KEX_HAS_PING)) != 0) {
700 			return r;
701 		}
702 	} else
703 		debug_f("%s (unrecognised)", name);
704 
705 	return 0;
706 }
707 
708 static int
709 kex_ext_info_server_parse(struct ssh *ssh, const char *name,
710     const u_char *value, size_t vlen)
711 {
712 	int r;
713 
714 	if (strcmp(name, "ext-info-in-auth@openssh.com") == 0) {
715 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
716 		    "0", KEX_HAS_EXT_INFO_IN_AUTH)) != 0) {
717 			return r;
718 		}
719 	} else
720 		debug_f("%s (unrecognised)", name);
721 	return 0;
722 }
723 
724 int
725 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
726 {
727 	struct kex *kex = ssh->kex;
728 	const int max_ext_info = kex->server ? 1 : 2;
729 	u_int32_t i, ninfo;
730 	char *name;
731 	u_char *val;
732 	size_t vlen;
733 	int r;
734 
735 	debug("SSH2_MSG_EXT_INFO received");
736 	if (++kex->ext_info_received > max_ext_info) {
737 		error("too many SSH2_MSG_EXT_INFO messages sent by peer");
738 		return dispatch_protocol_error(type, seq, ssh);
739 	}
740 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
741 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
742 		return r;
743 	if (ninfo >= 1024) {
744 		error("SSH2_MSG_EXT_INFO with too many entries, expected "
745 		    "<=1024, received %u", ninfo);
746 		return dispatch_protocol_error(type, seq, ssh);
747 	}
748 	for (i = 0; i < ninfo; i++) {
749 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
750 			return r;
751 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
752 			free(name);
753 			return r;
754 		}
755 		debug3_f("extension %s", name);
756 		if (kex->server) {
757 			if ((r = kex_ext_info_server_parse(ssh, name,
758 			    val, vlen)) != 0)
759 				return r;
760 		} else {
761 			if ((r = kex_ext_info_client_parse(ssh, name,
762 			    val, vlen)) != 0)
763 				return r;
764 		}
765 		free(name);
766 		free(val);
767 	}
768 	return sshpkt_get_end(ssh);
769 }
770 
/*
 * Dispatch handler for the peer's SSH2_MSG_NEWKEYS; completes the current
 * key exchange on the receive side. It:
 *  - re-arms the dispatch table: EXT_INFO is accepted only after the
 *    initial exchange and only when kex->ext_info_c is set, a further
 *    NEWKEYS becomes a protocol error, and KEXINIT is accepted again so
 *    the peer may start a rekey;
 *  - activates the new inbound keys via ssh_set_newkeys(MODE_IN);
 *  - after the initial exchange, strips the one-time signalling names
 *    (ext-info-[sc], kex-strict-[sc]-v00@openssh.com) from our stored
 *    proposal in kex->my so they are not re-advertised on rekey;
 *  - marks the exchange done and clears per-exchange state.
 * Returns 0 on success or an SSH_ERR_* code.
 */
static int
kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	int r, initial = (kex->flags & KEX_INITIAL) != 0;
	char *cp, **prop;

	debug("SSH2_MSG_NEWKEYS received");
	if (kex->ext_info_c && initial)
		ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_input_ext_info);
	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
	if ((r = sshpkt_get_end(ssh)) != 0)
		return r;
	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
		return r;
	if (initial) {
		/* Remove initial KEX signalling from proposal for rekeying */
		if ((r = kex_buf2prop(kex->my, NULL, &prop)) != 0)
			return r;
		if ((cp = match_filter_denylist(prop[PROPOSAL_KEX_ALGS],
		    kex->server ?
		    "ext-info-s,kex-strict-s-v00@openssh.com" :
		    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL) {
			error_f("match_filter_denylist failed");
			goto fail;
		}
		free(prop[PROPOSAL_KEX_ALGS]);
		prop[PROPOSAL_KEX_ALGS] = cp;
		/* Re-serialise the trimmed proposal for the next KEXINIT. */
		if ((r = kex_prop2buf(ssh->kex->my, prop)) != 0) {
			error_f("kex_prop2buf failed");
			/*
			 * Shared cleanup for both failures above; always
			 * reported as SSH_ERR_INTERNAL_ERROR.
			 */
 fail:
			kex_proposal_free_entries(prop);
			free(prop);
			return SSH_ERR_INTERNAL_ERROR;
		}
		kex_proposal_free_entries(prop);
		free(prop);
	}
	/* Exchange complete: clear per-exchange state ready for a rekey. */
	kex->done = 1;
	kex->flags &= ~KEX_INITIAL;
	sshbuf_reset(kex->peer);
	kex->flags &= ~KEX_INIT_SENT;
	free(kex->name);
	kex->name = NULL;
	return 0;
}
818 
819 int
820 kex_send_kexinit(struct ssh *ssh)
821 {
822 	u_char *cookie;
823 	struct kex *kex = ssh->kex;
824 	int r;
825 
826 	if (kex == NULL) {
827 		error_f("no kex");
828 		return SSH_ERR_INTERNAL_ERROR;
829 	}
830 	if (kex->flags & KEX_INIT_SENT)
831 		return 0;
832 	kex->done = 0;
833 
834 	/* generate a random cookie */
835 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
836 		error_f("bad kex length: %zu < %d",
837 		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
838 		return SSH_ERR_INVALID_FORMAT;
839 	}
840 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
841 		error_f("buffer error");
842 		return SSH_ERR_INTERNAL_ERROR;
843 	}
844 	arc4random_buf(cookie, KEX_COOKIE_LEN);
845 
846 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
847 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
848 	    (r = sshpkt_send(ssh)) != 0) {
849 		error_fr(r, "compose reply");
850 		return r;
851 	}
852 	debug("SSH2_MSG_KEXINIT sent");
853 	kex->flags |= KEX_INIT_SENT;
854 	return 0;
855 }
856 
857 int
858 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
859 {
860 	struct kex *kex = ssh->kex;
861 	const u_char *ptr;
862 	u_int i;
863 	size_t dlen;
864 	int r;
865 
866 	debug("SSH2_MSG_KEXINIT received");
867 	if (kex == NULL) {
868 		error_f("no kex");
869 		return SSH_ERR_INTERNAL_ERROR;
870 	}
871 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_protocol_error);
872 	ptr = sshpkt_ptr(ssh, &dlen);
873 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
874 		return r;
875 
876 	/* discard packet */
877 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
878 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
879 			error_fr(r, "discard cookie");
880 			return r;
881 		}
882 	}
883 	for (i = 0; i < PROPOSAL_MAX; i++) {
884 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
885 			error_fr(r, "discard proposal");
886 			return r;
887 		}
888 	}
889 	/*
890 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
891 	 * KEX method has the server move first, but a server might be using
892 	 * a custom method or one that we otherwise don't support. We should
893 	 * be prepared to remember first_kex_follows here so we can eat a
894 	 * packet later.
895 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
896 	 * for cases where the server *doesn't* go first. I guess we should
897 	 * ignore it when it is set for these cases, which is what we do now.
898 	 */
899 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
900 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
901 	    (r = sshpkt_get_end(ssh)) != 0)
902 			return r;
903 
904 	if (!(kex->flags & KEX_INIT_SENT))
905 		if ((r = kex_send_kexinit(ssh)) != 0)
906 			return r;
907 	if ((r = kex_choose_conf(ssh, seq)) != 0)
908 		return r;
909 
910 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
911 		return (kex->kex[kex->kex_type])(ssh);
912 
913 	error_f("unknown kex type %u", kex->kex_type);
914 	return SSH_ERR_INTERNAL_ERROR;
915 }
916 
917 struct kex *
918 kex_new(void)
919 {
920 	struct kex *kex;
921 
922 	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
923 	    (kex->peer = sshbuf_new()) == NULL ||
924 	    (kex->my = sshbuf_new()) == NULL ||
925 	    (kex->client_version = sshbuf_new()) == NULL ||
926 	    (kex->server_version = sshbuf_new()) == NULL ||
927 	    (kex->session_id = sshbuf_new()) == NULL) {
928 		kex_free(kex);
929 		return NULL;
930 	}
931 	return kex;
932 }
933 
/*
 * Free a newkeys structure and everything it owns, scrubbing secrets
 * first. The cipher key and IV are zeroed with explicit_bzero() before
 * being freed. mac_clear() runs before the MAC key is scrubbed and freed.
 * Each of the enc/comp/mac sub-structures is wiped after its name is
 * freed, and the allocation itself is released with freezero().
 * A NULL argument is ignored.
 */
void
kex_free_newkeys(struct newkeys *newkeys)
{
	if (newkeys == NULL)
		return;
	if (newkeys->enc.key) {
		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
		free(newkeys->enc.key);
		newkeys->enc.key = NULL;
	}
	if (newkeys->enc.iv) {
		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
		free(newkeys->enc.iv);
		newkeys->enc.iv = NULL;
	}
	free(newkeys->enc.name);
	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
	free(newkeys->comp.name);
	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
	/*
	 * NOTE(review): mac_clear() is called while mac.key is still live;
	 * presumably it releases MAC context state — confirm in mac.c before
	 * reordering.
	 */
	mac_clear(&newkeys->mac);
	if (newkeys->mac.key) {
		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
		free(newkeys->mac.key);
		newkeys->mac.key = NULL;
	}
	free(newkeys->mac.name);
	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
	freezero(newkeys, sizeof(*newkeys));
}
963 
964 void
965 kex_free(struct kex *kex)
966 {
967 	u_int mode;
968 
969 	if (kex == NULL)
970 		return;
971 
972 #ifdef WITH_OPENSSL
973 	DH_free(kex->dh);
974 #ifdef OPENSSL_HAS_ECC
975 	EC_KEY_free(kex->ec_client_key);
976 #endif /* OPENSSL_HAS_ECC */
977 #endif /* WITH_OPENSSL */
978 	for (mode = 0; mode < MODE_MAX; mode++) {
979 		kex_free_newkeys(kex->newkeys[mode]);
980 		kex->newkeys[mode] = NULL;
981 	}
982 	sshbuf_free(kex->peer);
983 	sshbuf_free(kex->my);
984 	sshbuf_free(kex->client_version);
985 	sshbuf_free(kex->server_version);
986 	sshbuf_free(kex->client_pub);
987 	sshbuf_free(kex->session_id);
988 	sshbuf_free(kex->initial_sig);
989 	sshkey_free(kex->initial_hostkey);
990 	free(kex->failed_choice);
991 	free(kex->hostkey_alg);
992 	free(kex->name);
993 	free(kex);
994 }
995 
996 int
997 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
998 {
999 	int r;
1000 
1001 	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
1002 		return r;
1003 	ssh->kex->flags = KEX_INITIAL;
1004 	kex_reset_dispatch(ssh);
1005 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
1006 	return 0;
1007 }
1008 
1009 int
1010 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
1011 {
1012 	int r;
1013 
1014 	if ((r = kex_ready(ssh, proposal)) != 0)
1015 		return r;
1016 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
1017 		kex_free(ssh->kex);
1018 		ssh->kex = NULL;
1019 		return r;
1020 	}
1021 	return 0;
1022 }
1023 
1024 /*
1025  * Request key re-exchange, returns 0 on success or a ssherr.h error
1026  * code otherwise. Must not be called if KEX is incomplete or in-progress.
1027  */
1028 int
1029 kex_start_rekex(struct ssh *ssh)
1030 {
1031 	if (ssh->kex == NULL) {
1032 		error_f("no kex");
1033 		return SSH_ERR_INTERNAL_ERROR;
1034 	}
1035 	if (ssh->kex->done == 0) {
1036 		error_f("requested twice");
1037 		return SSH_ERR_INTERNAL_ERROR;
1038 	}
1039 	ssh->kex->done = 0;
1040 	return kex_send_kexinit(ssh);
1041 }
1042 
1043 static int
1044 choose_enc(struct sshenc *enc, char *client, char *server)
1045 {
1046 	char *name = match_list(client, server, NULL);
1047 
1048 	if (name == NULL)
1049 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
1050 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
1051 		error_f("unsupported cipher %s", name);
1052 		free(name);
1053 		return SSH_ERR_INTERNAL_ERROR;
1054 	}
1055 	enc->name = name;
1056 	enc->enabled = 0;
1057 	enc->iv = NULL;
1058 	enc->iv_len = cipher_ivlen(enc->cipher);
1059 	enc->key = NULL;
1060 	enc->key_len = cipher_keylen(enc->cipher);
1061 	enc->block_size = cipher_blocksize(enc->cipher);
1062 	return 0;
1063 }
1064 
1065 static int
1066 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
1067 {
1068 	char *name = match_list(client, server, NULL);
1069 
1070 	if (name == NULL)
1071 		return SSH_ERR_NO_MAC_ALG_MATCH;
1072 	if (mac_setup(mac, name) < 0) {
1073 		error_f("unsupported MAC %s", name);
1074 		free(name);
1075 		return SSH_ERR_INTERNAL_ERROR;
1076 	}
1077 	mac->name = name;
1078 	mac->key = NULL;
1079 	mac->enabled = 0;
1080 	return 0;
1081 }
1082 
1083 static int
1084 choose_comp(struct sshcomp *comp, char *client, char *server)
1085 {
1086 	char *name = match_list(client, server, NULL);
1087 
1088 	if (name == NULL)
1089 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
1090 #ifdef WITH_ZLIB
1091 	if (strcmp(name, "zlib@openssh.com") == 0) {
1092 		comp->type = COMP_DELAYED;
1093 	} else if (strcmp(name, "zlib") == 0) {
1094 		comp->type = COMP_ZLIB;
1095 	} else
1096 #endif	/* WITH_ZLIB */
1097 	if (strcmp(name, "none") == 0) {
1098 		comp->type = COMP_NONE;
1099 	} else {
1100 		error_f("unsupported compression scheme %s", name);
1101 		free(name);
1102 		return SSH_ERR_INTERNAL_ERROR;
1103 	}
1104 	comp->name = name;
1105 	return 0;
1106 }
1107 
1108 static int
1109 choose_kex(struct kex *k, char *client, char *server)
1110 {
1111 	const struct kexalg *kexalg;
1112 
1113 	k->name = match_list(client, server, NULL);
1114 
1115 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
1116 	if (k->name == NULL)
1117 		return SSH_ERR_NO_KEX_ALG_MATCH;
1118 	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
1119 		error_f("unsupported KEX method %s", k->name);
1120 		return SSH_ERR_INTERNAL_ERROR;
1121 	}
1122 	k->kex_type = kexalg->type;
1123 	k->hash_alg = kexalg->hash_alg;
1124 	k->ec_nid = kexalg->ec_nid;
1125 	return 0;
1126 }
1127 
1128 static int
1129 choose_hostkeyalg(struct kex *k, char *client, char *server)
1130 {
1131 	free(k->hostkey_alg);
1132 	k->hostkey_alg = match_list(client, server, NULL);
1133 
1134 	debug("kex: host key algorithm: %s",
1135 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
1136 	if (k->hostkey_alg == NULL)
1137 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
1138 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
1139 	if (k->hostkey_type == KEY_UNSPEC) {
1140 		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
1141 		return SSH_ERR_INTERNAL_ERROR;
1142 	}
1143 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
1144 	return 0;
1145 }
1146 
1147 static int
1148 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
1149 {
1150 	static int check[] = {
1151 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
1152 	};
1153 	int *idx;
1154 	char *p;
1155 
1156 	for (idx = &check[0]; *idx != -1; idx++) {
1157 		if ((p = strchr(my[*idx], ',')) != NULL)
1158 			*p = '\0';
1159 		if ((p = strchr(peer[*idx], ',')) != NULL)
1160 			*p = '\0';
1161 		if (strcmp(my[*idx], peer[*idx]) != 0) {
1162 			debug2("proposal mismatch: my %s peer %s",
1163 			    my[*idx], peer[*idx]);
1164 			return (0);
1165 		}
1166 	}
1167 	debug2("proposals match");
1168 	return (1);
1169 }
1170 
1171 static int
1172 kexalgs_contains(char **peer, const char *ext)
1173 {
1174 	return has_any_alg(peer[PROPOSAL_KEX_ALGS], ext);
1175 }
1176 
/*
 * Negotiate the algorithms for this key exchange from our KEXINIT
 * (kex->my) and the peer's (kex->peer): the KEX method, the host key
 * algorithm and, for each direction, cipher, MAC and compression.  One
 * struct newkeys is allocated per direction into kex->newkeys[mode], and
 * the amount of key material required is stored in kex->we_need and
 * kex->dh_need.
 *
 * seq: presumably the sequence number of the received KEXINIT.  When strict
 * KEX is negotiated on the initial exchange, a non-zero value disconnects
 * the session.
 *
 * On a negotiation failure the peer's offending proposal string is moved
 * into kex->failed_choice (its slot in peer[] is set to NULL so it is not
 * freed here).  Returns 0 on success or a ssherr.h error code.
 */
static int
kex_choose_conf(struct ssh *ssh, uint32_t seq)
{
	struct kex *kex = ssh->kex;
	struct newkeys *newkeys;
	char **my = NULL, **peer = NULL;
	char **cprop, **sprop;
	int nenc, nmac, ncomp;
	u_int mode, ctos, need, dh_need, authlen;
	int r, first_kex_follows;

	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
		goto out;
	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
		goto out;

	/*
	 * The choose_*() helpers take the client's list first and the
	 * server's second, so map our/peer proposals onto those roles.
	 */
	if (kex->server) {
		cprop=peer;
		sprop=my;
	} else {
		cprop=my;
		sprop=peer;
	}

	/*
	 * Check whether peer supports ext_info/kex_strict.  Both are signalled
	 * as pseudo-algorithms in the peer's KEX list and are only examined on
	 * the initial exchange.
	 */
	if ((kex->flags & KEX_INITIAL) != 0) {
		if (kex->server) {
			kex->ext_info_c = kexalgs_contains(peer, "ext-info-c");
			kex->kex_strict = kexalgs_contains(peer,
			    "kex-strict-c-v00@openssh.com");
		} else {
			kex->ext_info_s = kexalgs_contains(peer, "ext-info-s");
			kex->kex_strict = kexalgs_contains(peer,
			    "kex-strict-s-v00@openssh.com");
		}
		if (kex->kex_strict) {
			debug3_f("will use strict KEX ordering");
			/* Strict KEX: the KEXINIT must be the very first packet. */
			if (seq != 0)
				ssh_packet_disconnect(ssh,
				    "strict KEX violation: "
				    "KEXINIT was not the first packet");
		}
	}

	/* Check whether client supports rsa-sha2 algorithms */
	if (kex->server && (kex->flags & KEX_INITIAL)) {
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
	}

	/*
	 * Algorithm Negotiation.  On each failure, ownership of the peer's
	 * list moves to kex->failed_choice; the NULL prevents kex_prop_free()
	 * below from freeing it.
	 */
	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
		peer[PROPOSAL_KEX_ALGS] = NULL;
		goto out;
	}
	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
		goto out;
	}
	for (mode = 0; mode < MODE_MAX; mode++) {
		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto out;
		}
		kex->newkeys[mode] = newkeys;
		/* ctos: this direction carries client->server traffic. */
		ctos = (!kex->server && mode == MODE_OUT) ||
		    (kex->server && mode == MODE_IN);
		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
		    sprop[nenc])) != 0) {
			kex->failed_choice = peer[nenc];
			peer[nenc] = NULL;
			goto out;
		}
		authlen = cipher_authlen(newkeys->enc.cipher);
		/* ignore mac for authenticated encryption */
		if (authlen == 0 &&
		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
		    sprop[nmac])) != 0) {
			kex->failed_choice = peer[nmac];
			peer[nmac] = NULL;
			goto out;
		}
		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
		    sprop[ncomp])) != 0) {
			kex->failed_choice = peer[ncomp];
			peer[ncomp] = NULL;
			goto out;
		}
		debug("kex: %s cipher: %s MAC: %s compression: %s",
		    ctos ? "client->server" : "server->client",
		    newkeys->enc.name,
		    authlen == 0 ? newkeys->mac.name : "<implicit>",
		    newkeys->comp.name);
	}
	/*
	 * need: the largest key/block/IV/MAC-key size over both directions;
	 * kex_derive_keys() derives every key at this length.
	 * dh_need: the same, but using cipher_seclen() rather than the raw key
	 * length.  Presumably it sizes the key-exchange group elsewhere;
	 * confirm in callers.
	 */
	need = dh_need = 0;
	for (mode = 0; mode < MODE_MAX; mode++) {
		newkeys = kex->newkeys[mode];
		need = MAXIMUM(need, newkeys->enc.key_len);
		need = MAXIMUM(need, newkeys->enc.block_size);
		need = MAXIMUM(need, newkeys->enc.iv_len);
		need = MAXIMUM(need, newkeys->mac.key_len);
		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
	}
	/* XXX need rounding? */
	kex->we_need = need;
	kex->dh_need = dh_need;

	/*
	 * ignore the next message if the proposals do not match: the peer
	 * guessed its first KEX packet, but the guess was wrong.
	 */
	if (first_kex_follows && !proposals_match(my, peer))
		ssh->dispatch_skip_packets = 1;
	r = 0;
 out:
	kex_prop_free(my);
	kex_prop_free(peer);
	return r;
}
1309 
1310 static int
1311 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1312     const struct sshbuf *shared_secret, u_char **keyp)
1313 {
1314 	struct kex *kex = ssh->kex;
1315 	struct ssh_digest_ctx *hashctx = NULL;
1316 	char c = id;
1317 	u_int have;
1318 	size_t mdsz;
1319 	u_char *digest;
1320 	int r;
1321 
1322 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1323 		return SSH_ERR_INVALID_ARGUMENT;
1324 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1325 		r = SSH_ERR_ALLOC_FAIL;
1326 		goto out;
1327 	}
1328 
1329 	/* K1 = HASH(K || H || "A" || session_id) */
1330 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1331 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1332 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1333 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1334 	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1335 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1336 		r = SSH_ERR_LIBCRYPTO_ERROR;
1337 		error_f("KEX hash failed");
1338 		goto out;
1339 	}
1340 	ssh_digest_free(hashctx);
1341 	hashctx = NULL;
1342 
1343 	/*
1344 	 * expand key:
1345 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1346 	 * Key = K1 || K2 || ... || Kn
1347 	 */
1348 	for (have = mdsz; need > have; have += mdsz) {
1349 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1350 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1351 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1352 		    ssh_digest_update(hashctx, digest, have) != 0 ||
1353 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1354 			error_f("KDF failed");
1355 			r = SSH_ERR_LIBCRYPTO_ERROR;
1356 			goto out;
1357 		}
1358 		ssh_digest_free(hashctx);
1359 		hashctx = NULL;
1360 	}
1361 #ifdef DEBUG_KEX
1362 	fprintf(stderr, "key '%c'== ", c);
1363 	dump_digest("key", digest, need);
1364 #endif
1365 	*keyp = digest;
1366 	digest = NULL;
1367 	r = 0;
1368  out:
1369 	free(digest);
1370 	ssh_digest_free(hashctx);
1371 	return r;
1372 }
1373 
1374 #define NKEYS	6
1375 int
1376 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1377     const struct sshbuf *shared_secret)
1378 {
1379 	struct kex *kex = ssh->kex;
1380 	u_char *keys[NKEYS];
1381 	u_int i, j, mode, ctos;
1382 	int r;
1383 
1384 	/* save initial hash as session id */
1385 	if ((kex->flags & KEX_INITIAL) != 0) {
1386 		if (sshbuf_len(kex->session_id) != 0) {
1387 			error_f("already have session ID at kex");
1388 			return SSH_ERR_INTERNAL_ERROR;
1389 		}
1390 		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1391 			return r;
1392 	} else if (sshbuf_len(kex->session_id) == 0) {
1393 		error_f("no session ID in rekex");
1394 		return SSH_ERR_INTERNAL_ERROR;
1395 	}
1396 	for (i = 0; i < NKEYS; i++) {
1397 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1398 		    shared_secret, &keys[i])) != 0) {
1399 			for (j = 0; j < i; j++)
1400 				free(keys[j]);
1401 			return r;
1402 		}
1403 	}
1404 	for (mode = 0; mode < MODE_MAX; mode++) {
1405 		ctos = (!kex->server && mode == MODE_OUT) ||
1406 		    (kex->server && mode == MODE_IN);
1407 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1408 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1409 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1410 	}
1411 	return 0;
1412 }
1413 
1414 int
1415 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1416 {
1417 	struct kex *kex = ssh->kex;
1418 
1419 	*pubp = NULL;
1420 	*prvp = NULL;
1421 	if (kex->load_host_public_key == NULL ||
1422 	    kex->load_host_private_key == NULL) {
1423 		error_f("missing hostkey loader");
1424 		return SSH_ERR_INVALID_ARGUMENT;
1425 	}
1426 	*pubp = kex->load_host_public_key(kex->hostkey_type,
1427 	    kex->hostkey_nid, ssh);
1428 	*prvp = kex->load_host_private_key(kex->hostkey_type,
1429 	    kex->hostkey_nid, ssh);
1430 	if (*pubp == NULL)
1431 		return SSH_ERR_NO_HOSTKEY_LOADED;
1432 	return 0;
1433 }
1434 
1435 int
1436 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1437 {
1438 	struct kex *kex = ssh->kex;
1439 
1440 	if (kex->verify_host_key == NULL) {
1441 		error_f("missing hostkey verifier");
1442 		return SSH_ERR_INVALID_ARGUMENT;
1443 	}
1444 	if (server_host_key->type != kex->hostkey_type ||
1445 	    (kex->hostkey_type == KEY_ECDSA &&
1446 	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1447 		return SSH_ERR_KEY_TYPE_MISMATCH;
1448 	if (kex->verify_host_key(server_host_key, ssh) == -1)
1449 		return  SSH_ERR_SIGNATURE_INVALID;
1450 	return 0;
1451 }
1452 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/* Debug helper: print msg on its own line, then dump len bytes of digest. */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
1461 
1462 /*
1463  * Send a plaintext error message to the peer, suffixed by \r\n.
1464  * Only used during banner exchange, and there only for the server.
1465  */
1466 static void
1467 send_error(struct ssh *ssh, char *msg)
1468 {
1469 	char *crnl = "\r\n";
1470 
1471 	if (!ssh->kex->server)
1472 		return;
1473 
1474 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1475 	    msg, strlen(msg)) != strlen(msg) ||
1476 	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1477 	    crnl, strlen(crnl)) != strlen(crnl))
1478 		error_f("write: %.100s", strerror(errno));
1479 }
1480 
1481 /*
1482  * Sends our identification string and waits for the peer's. Will block for
1483  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1484  * Returns on 0 success or a ssherr.h code on failure.
1485  */
1486 int
1487 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1488     const char *version_addendum)
1489 {
1490 	int remote_major, remote_minor, mismatch, oerrno = 0;
1491 	size_t len, n;
1492 	int r, expect_nl;
1493 	u_char c;
1494 	struct sshbuf *our_version = ssh->kex->server ?
1495 	    ssh->kex->server_version : ssh->kex->client_version;
1496 	struct sshbuf *peer_version = ssh->kex->server ?
1497 	    ssh->kex->client_version : ssh->kex->server_version;
1498 	char *our_version_string = NULL, *peer_version_string = NULL;
1499 	char *cp, *remote_version = NULL;
1500 
1501 	/* Prepare and send our banner */
1502 	sshbuf_reset(our_version);
1503 	if (version_addendum != NULL && *version_addendum == '\0')
1504 		version_addendum = NULL;
1505 	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%s%s%s\r\n",
1506 	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1507 	    version_addendum == NULL ? "" : " ",
1508 	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1509 		oerrno = errno;
1510 		error_fr(r, "sshbuf_putf");
1511 		goto out;
1512 	}
1513 
1514 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1515 	    sshbuf_mutable_ptr(our_version),
1516 	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1517 		oerrno = errno;
1518 		debug_f("write: %.100s", strerror(errno));
1519 		r = SSH_ERR_SYSTEM_ERROR;
1520 		goto out;
1521 	}
1522 	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1523 		oerrno = errno;
1524 		error_fr(r, "sshbuf_consume_end");
1525 		goto out;
1526 	}
1527 	our_version_string = sshbuf_dup_string(our_version);
1528 	if (our_version_string == NULL) {
1529 		error_f("sshbuf_dup_string failed");
1530 		r = SSH_ERR_ALLOC_FAIL;
1531 		goto out;
1532 	}
1533 	debug("Local version string %.100s", our_version_string);
1534 
1535 	/* Read other side's version identification. */
1536 	for (n = 0; ; n++) {
1537 		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1538 			send_error(ssh, "No SSH identification string "
1539 			    "received.");
1540 			error_f("No SSH version received in first %u lines "
1541 			    "from server", SSH_MAX_PRE_BANNER_LINES);
1542 			r = SSH_ERR_INVALID_FORMAT;
1543 			goto out;
1544 		}
1545 		sshbuf_reset(peer_version);
1546 		expect_nl = 0;
1547 		for (;;) {
1548 			if (timeout_ms > 0) {
1549 				r = waitrfd(ssh_packet_get_connection_in(ssh),
1550 				    &timeout_ms, NULL);
1551 				if (r == -1 && errno == ETIMEDOUT) {
1552 					send_error(ssh, "Timed out waiting "
1553 					    "for SSH identification string.");
1554 					error("Connection timed out during "
1555 					    "banner exchange");
1556 					r = SSH_ERR_CONN_TIMEOUT;
1557 					goto out;
1558 				} else if (r == -1) {
1559 					oerrno = errno;
1560 					error_f("%s", strerror(errno));
1561 					r = SSH_ERR_SYSTEM_ERROR;
1562 					goto out;
1563 				}
1564 			}
1565 
1566 			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1567 			    &c, 1);
1568 			if (len != 1 && errno == EPIPE) {
1569 				verbose_f("Connection closed by remote host");
1570 				r = SSH_ERR_CONN_CLOSED;
1571 				goto out;
1572 			} else if (len != 1) {
1573 				oerrno = errno;
1574 				error_f("read: %.100s", strerror(errno));
1575 				r = SSH_ERR_SYSTEM_ERROR;
1576 				goto out;
1577 			}
1578 			if (c == '\r') {
1579 				expect_nl = 1;
1580 				continue;
1581 			}
1582 			if (c == '\n')
1583 				break;
1584 			if (c == '\0' || expect_nl) {
1585 				verbose_f("banner line contains invalid "
1586 				    "characters");
1587 				goto invalid;
1588 			}
1589 			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1590 				oerrno = errno;
1591 				error_fr(r, "sshbuf_put");
1592 				goto out;
1593 			}
1594 			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1595 				verbose_f("banner line too long");
1596 				goto invalid;
1597 			}
1598 		}
1599 		/* Is this an actual protocol banner? */
1600 		if (sshbuf_len(peer_version) > 4 &&
1601 		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1602 			break;
1603 		/* If not, then just log the line and continue */
1604 		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1605 			error_f("sshbuf_dup_string failed");
1606 			r = SSH_ERR_ALLOC_FAIL;
1607 			goto out;
1608 		}
1609 		/* Do not accept lines before the SSH ident from a client */
1610 		if (ssh->kex->server) {
1611 			verbose_f("client sent invalid protocol identifier "
1612 			    "\"%.256s\"", cp);
1613 			free(cp);
1614 			goto invalid;
1615 		}
1616 		debug_f("banner line %zu: %s", n, cp);
1617 		free(cp);
1618 	}
1619 	peer_version_string = sshbuf_dup_string(peer_version);
1620 	if (peer_version_string == NULL)
1621 		fatal_f("sshbuf_dup_string failed");
1622 	/* XXX must be same size for sscanf */
1623 	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1624 		error_f("calloc failed");
1625 		r = SSH_ERR_ALLOC_FAIL;
1626 		goto out;
1627 	}
1628 
1629 	/*
1630 	 * Check that the versions match.  In future this might accept
1631 	 * several versions and set appropriate flags to handle them.
1632 	 */
1633 	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1634 	    &remote_major, &remote_minor, remote_version) != 3) {
1635 		error("Bad remote protocol version identification: '%.100s'",
1636 		    peer_version_string);
1637  invalid:
1638 		send_error(ssh, "Invalid SSH identification string.");
1639 		r = SSH_ERR_INVALID_FORMAT;
1640 		goto out;
1641 	}
1642 	debug("Remote protocol version %d.%d, remote software version %.100s",
1643 	    remote_major, remote_minor, remote_version);
1644 	compat_banner(ssh, remote_version);
1645 
1646 	mismatch = 0;
1647 	switch (remote_major) {
1648 	case 2:
1649 		break;
1650 	case 1:
1651 		if (remote_minor != 99)
1652 			mismatch = 1;
1653 		break;
1654 	default:
1655 		mismatch = 1;
1656 		break;
1657 	}
1658 	if (mismatch) {
1659 		error("Protocol major versions differ: %d vs. %d",
1660 		    PROTOCOL_MAJOR_2, remote_major);
1661 		send_error(ssh, "Protocol major versions differ.");
1662 		r = SSH_ERR_NO_PROTOCOL_VERSION;
1663 		goto out;
1664 	}
1665 
1666 	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1667 		logit("probed from %s port %d with %s.  Don't panic.",
1668 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1669 		    peer_version_string);
1670 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1671 		goto out;
1672 	}
1673 	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1674 		logit("scanned from %s port %d with %s.  Don't panic.",
1675 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1676 		    peer_version_string);
1677 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1678 		goto out;
1679 	}
1680 	/* success */
1681 	r = 0;
1682  out:
1683 	free(our_version_string);
1684 	free(peer_version_string);
1685 	free(remote_version);
1686 	if (r == SSH_ERR_SYSTEM_ERROR)
1687 		errno = oerrno;
1688 	return r;
1689 }
1690 
1691