xref: /freebsd/crypto/openssh/kex.c (revision f73124b077d867990cbcb4d903b48be2ca55e4ca)
1 /* $OpenBSD: kex.c,v 1.184 2023/12/18 14:45:49 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include "includes.h"
27 
28 #include <sys/types.h>
29 #include <errno.h>
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39 
40 #ifdef WITH_OPENSSL
41 #include <openssl/crypto.h>
42 #include <openssl/dh.h>
43 #endif
44 
45 #include "ssh.h"
46 #include "ssh2.h"
47 #include "atomicio.h"
48 #include "version.h"
49 #include "packet.h"
50 #include "compat.h"
51 #include "cipher.h"
52 #include "sshkey.h"
53 #include "kex.h"
54 #include "log.h"
55 #include "mac.h"
56 #include "match.h"
57 #include "misc.h"
58 #include "dispatch.h"
59 #include "monitor.h"
60 #include "myproposal.h"
61 
62 #include "ssherr.h"
63 #include "sshbuf.h"
64 #include "digest.h"
65 #include "xmalloc.h"
66 
67 /* prototype */
68 static int kex_choose_conf(struct ssh *, uint32_t seq);
69 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
70 
71 static const char * const proposal_names[PROPOSAL_MAX] = {
72 	"KEX algorithms",
73 	"host key algorithms",
74 	"ciphers ctos",
75 	"ciphers stoc",
76 	"MACs ctos",
77 	"MACs stoc",
78 	"compression ctos",
79 	"compression stoc",
80 	"languages ctos",
81 	"languages stoc",
82 };
83 
84 struct kexalg {
85 	char *name;
86 	u_int type;
87 	int ec_nid;
88 	int hash_alg;
89 };
90 static const struct kexalg kexalgs[] = {
91 #ifdef WITH_OPENSSL
92 	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
93 	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
94 	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
95 	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
96 	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
97 	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
98 #ifdef HAVE_EVP_SHA256
99 	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
100 #endif /* HAVE_EVP_SHA256 */
101 #ifdef OPENSSL_HAS_ECC
102 	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
103 	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
104 	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
105 	    SSH_DIGEST_SHA384 },
106 # ifdef OPENSSL_HAS_NISTP521
107 	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
108 	    SSH_DIGEST_SHA512 },
109 # endif /* OPENSSL_HAS_NISTP521 */
110 #endif /* OPENSSL_HAS_ECC */
111 #endif /* WITH_OPENSSL */
112 #if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
113 	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
114 	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
115 #ifdef USE_SNTRUP761X25519
116 	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
117 	    SSH_DIGEST_SHA512 },
118 #endif
119 #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
120 	{ NULL, 0, -1, -1},
121 };
122 
123 char *
124 kex_alg_list(char sep)
125 {
126 	char *ret = NULL, *tmp;
127 	size_t nlen, rlen = 0;
128 	const struct kexalg *k;
129 
130 	for (k = kexalgs; k->name != NULL; k++) {
131 		if (ret != NULL)
132 			ret[rlen++] = sep;
133 		nlen = strlen(k->name);
134 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
135 			free(ret);
136 			return NULL;
137 		}
138 		ret = tmp;
139 		memcpy(ret + rlen, k->name, nlen + 1);
140 		rlen += nlen;
141 	}
142 	return ret;
143 }
144 
145 static const struct kexalg *
146 kex_alg_by_name(const char *name)
147 {
148 	const struct kexalg *k;
149 
150 	for (k = kexalgs; k->name != NULL; k++) {
151 		if (strcmp(k->name, name) == 0)
152 			return k;
153 	}
154 	return NULL;
155 }
156 
/*
 * Validate a comma-separated list of KEX method names.
 * Returns 1 if every non-empty element names a supported method. Returns 0
 * if the list is NULL or empty, contains an unsupported name, or the working
 * copy cannot be allocated.
 *
 * Empty elements (from "a,,b" or a trailing comma) are skipped rather than
 * ending the scan, so names that follow them are still checked.
 */
int
kex_names_valid(const char *names)
{
	char *s, *cp, *p;

	if (names == NULL || *names == '\0')
		return 0;
	if ((s = cp = strdup(names)) == NULL)
		return 0;
	while ((p = strsep(&cp, ",")) != NULL) {
		if (*p == '\0')
			continue;
		if (kex_alg_by_name(p) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", p);
			free(s);
			return 0;
		}
	}
	debug3("kex names ok: [%s]", names);
	free(s);
	return 1;
}
179 
/*
 * Return 1 if match_list() finds any name common to the 'proposal' and
 * 'algs' name-lists, 0 otherwise. The match itself is discarded.
 */
static int
has_any_alg(const char *proposal, const char *algs)
{
	char *common = match_list(proposal, algs, NULL);
	int found = (common != NULL);

	free(common);
	return found;
}
191 
/*
 * Concatenate algorithm names, avoiding duplicates in the process.
 * Each name in 'b' is appended to 'a' unless already present; the order of
 * 'a' is preserved. If either list is NULL or empty, a copy of the other is
 * returned (an empty string if both are). Empty elements in 'b' are
 * skipped, so names after "x,,y" are not lost.
 * Returns NULL on allocation failure or if 'b' exceeds 1MB.
 * Caller must free returned string.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p;
	size_t len;

	if (a == NULL || *a == '\0')
		return strdup(b == NULL ? "" : b);	/* never strdup(NULL) */
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* Worst case: all of 'b' plus one extra ',' and the NUL. */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	while ((p = strsep(&cp, ",")) != NULL) {
		if (*p == '\0')
			continue; /* tolerate empty elements */
		if (has_any_alg(ret, p))
			continue; /* Algorithm already present */
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
228 
229 /*
230  * Assemble a list of algorithms from a default list and a string from a
231  * configuration file. The user-provided string may begin with '+' to
232  * indicate that it should be appended to the default, '-' that the
233  * specified names should be removed, or '^' that they should be placed
234  * at the head.
235  */
236 int
237 kex_assemble_names(char **listp, const char *def, const char *all)
238 {
239 	char *cp, *tmp, *patterns;
240 	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
241 	int r = SSH_ERR_INTERNAL_ERROR;
242 
243 	if (listp == NULL || def == NULL || all == NULL)
244 		return SSH_ERR_INVALID_ARGUMENT;
245 
246 	if (*listp == NULL || **listp == '\0') {
247 		if ((*listp = strdup(def)) == NULL)
248 			return SSH_ERR_ALLOC_FAIL;
249 		return 0;
250 	}
251 
252 	list = *listp;
253 	*listp = NULL;
254 	if (*list == '+') {
255 		/* Append names to default list */
256 		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
257 			r = SSH_ERR_ALLOC_FAIL;
258 			goto fail;
259 		}
260 		free(list);
261 		list = tmp;
262 	} else if (*list == '-') {
263 		/* Remove names from default list */
264 		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
265 			r = SSH_ERR_ALLOC_FAIL;
266 			goto fail;
267 		}
268 		free(list);
269 		/* filtering has already been done */
270 		return 0;
271 	} else if (*list == '^') {
272 		/* Place names at head of default list */
273 		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
274 			r = SSH_ERR_ALLOC_FAIL;
275 			goto fail;
276 		}
277 		free(list);
278 		list = tmp;
279 	} else {
280 		/* Explicit list, overrides default - just use "list" as is */
281 	}
282 
283 	/*
284 	 * The supplied names may be a pattern-list. For the -list case,
285 	 * the patterns are applied above. For the +list and explicit list
286 	 * cases we need to do it now.
287 	 */
288 	ret = NULL;
289 	if ((patterns = opatterns = strdup(list)) == NULL) {
290 		r = SSH_ERR_ALLOC_FAIL;
291 		goto fail;
292 	}
293 	/* Apply positive (i.e. non-negated) patterns from the list */
294 	while ((cp = strsep(&patterns, ",")) != NULL) {
295 		if (*cp == '!') {
296 			/* negated matches are not supported here */
297 			r = SSH_ERR_INVALID_ARGUMENT;
298 			goto fail;
299 		}
300 		free(matching);
301 		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
302 			r = SSH_ERR_ALLOC_FAIL;
303 			goto fail;
304 		}
305 		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
306 			r = SSH_ERR_ALLOC_FAIL;
307 			goto fail;
308 		}
309 		free(ret);
310 		ret = tmp;
311 	}
312 	if (ret == NULL || *ret == '\0') {
313 		/* An empty name-list is an error */
314 		/* XXX better error code? */
315 		r = SSH_ERR_INVALID_ARGUMENT;
316 		goto fail;
317 	}
318 
319 	/* success */
320 	*listp = ret;
321 	ret = NULL;
322 	r = 0;
323 
324  fail:
325 	free(matching);
326 	free(opatterns);
327 	free(list);
328 	free(ret);
329 	return r;
330 }
331 
332 /*
333  * Fill out a proposal array with dynamically allocated values, which may
334  * be modified as required for compatibility reasons.
335  * Any of the options may be NULL, in which case the default is used.
336  * Array contents must be freed by calling kex_proposal_free_entries.
337  */
338 void
339 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
340     const char *kexalgos, const char *ciphers, const char *macs,
341     const char *comp, const char *hkalgs)
342 {
343 	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
344 	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
345 	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
346 	u_int i;
347 	char *cp;
348 
349 	if (prop == NULL)
350 		fatal_f("proposal missing");
351 
352 	/* Append EXT_INFO signalling to KexAlgorithms */
353 	if (kexalgos == NULL)
354 		kexalgos = defprop[PROPOSAL_KEX_ALGS];
355 	if ((cp = kex_names_cat(kexalgos, ssh->kex->server ?
356 	    "ext-info-s,kex-strict-s-v00@openssh.com" :
357 	    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL)
358 		fatal_f("kex_names_cat");
359 
360 	for (i = 0; i < PROPOSAL_MAX; i++) {
361 		switch(i) {
362 		case PROPOSAL_KEX_ALGS:
363 			prop[i] = compat_kex_proposal(ssh, cp);
364 			break;
365 		case PROPOSAL_ENC_ALGS_CTOS:
366 		case PROPOSAL_ENC_ALGS_STOC:
367 			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
368 			break;
369 		case PROPOSAL_MAC_ALGS_CTOS:
370 		case PROPOSAL_MAC_ALGS_STOC:
371 			prop[i]  = xstrdup(macs ? macs : defprop[i]);
372 			break;
373 		case PROPOSAL_COMP_ALGS_CTOS:
374 		case PROPOSAL_COMP_ALGS_STOC:
375 			prop[i] = xstrdup(comp ? comp : defprop[i]);
376 			break;
377 		case PROPOSAL_SERVER_HOST_KEY_ALGS:
378 			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
379 			break;
380 		default:
381 			prop[i] = xstrdup(defprop[i]);
382 		}
383 	}
384 	free(cp);
385 }
386 
387 void
388 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
389 {
390 	u_int i;
391 
392 	for (i = 0; i < PROPOSAL_MAX; i++)
393 		free(prop[i]);
394 }
395 
396 /* put algorithm proposal into buffer */
397 int
398 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
399 {
400 	u_int i;
401 	int r;
402 
403 	sshbuf_reset(b);
404 
405 	/*
406 	 * add a dummy cookie, the cookie will be overwritten by
407 	 * kex_send_kexinit(), each time a kexinit is set
408 	 */
409 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
410 		if ((r = sshbuf_put_u8(b, 0)) != 0)
411 			return r;
412 	}
413 	for (i = 0; i < PROPOSAL_MAX; i++) {
414 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
415 			return r;
416 	}
417 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
418 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
419 		return r;
420 	return 0;
421 }
422 
423 /* parse buffer and return algorithm proposal */
424 int
425 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
426 {
427 	struct sshbuf *b = NULL;
428 	u_char v;
429 	u_int i;
430 	char **proposal = NULL;
431 	int r;
432 
433 	*propp = NULL;
434 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
435 		return SSH_ERR_ALLOC_FAIL;
436 	if ((b = sshbuf_fromb(raw)) == NULL) {
437 		r = SSH_ERR_ALLOC_FAIL;
438 		goto out;
439 	}
440 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
441 		error_fr(r, "consume cookie");
442 		goto out;
443 	}
444 	/* extract kex init proposal strings */
445 	for (i = 0; i < PROPOSAL_MAX; i++) {
446 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
447 			error_fr(r, "parse proposal %u", i);
448 			goto out;
449 		}
450 		debug2("%s: %s", proposal_names[i], proposal[i]);
451 	}
452 	/* first kex follows / reserved */
453 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
454 	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
455 		error_fr(r, "parse");
456 		goto out;
457 	}
458 	if (first_kex_follows != NULL)
459 		*first_kex_follows = v;
460 	debug2("first_kex_follows %d ", v);
461 	debug2("reserved %u ", i);
462 	r = 0;
463 	*propp = proposal;
464  out:
465 	if (r != 0 && proposal != NULL)
466 		kex_prop_free(proposal);
467 	sshbuf_free(b);
468 	return r;
469 }
470 
471 void
472 kex_prop_free(char **proposal)
473 {
474 	u_int i;
475 
476 	if (proposal == NULL)
477 		return;
478 	for (i = 0; i < PROPOSAL_MAX; i++)
479 		free(proposal[i]);
480 	free(proposal);
481 }
482 
483 int
484 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
485 {
486 	int r;
487 
488 	/* If in strict mode, any unexpected message is an error */
489 	if ((ssh->kex->flags & KEX_INITIAL) && ssh->kex->kex_strict) {
490 		ssh_packet_disconnect(ssh, "strict KEX violation: "
491 		    "unexpected packet type %u (seqnr %u)", type, seq);
492 	}
493 	error_f("type %u seq %u", type, seq);
494 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
495 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
496 	    (r = sshpkt_send(ssh)) != 0)
497 		return r;
498 	return 0;
499 }
500 
501 static void
502 kex_reset_dispatch(struct ssh *ssh)
503 {
504 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
505 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
506 }
507 
508 void
509 kex_set_server_sig_algs(struct ssh *ssh, const char *allowed_algs)
510 {
511 	char *alg, *oalgs, *algs, *sigalgs;
512 	const char *sigalg;
513 
514 	/*
515 	 * NB. allowed algorithms may contain certificate algorithms that
516 	 * map to a specific plain signature type, e.g.
517 	 * rsa-sha2-512-cert-v01@openssh.com => rsa-sha2-512
518 	 * We need to be careful here to match these, retain the mapping
519 	 * and only add each signature algorithm once.
520 	 */
521 	if ((sigalgs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
522 		fatal_f("sshkey_alg_list failed");
523 	oalgs = algs = xstrdup(allowed_algs);
524 	free(ssh->kex->server_sig_algs);
525 	ssh->kex->server_sig_algs = NULL;
526 	for ((alg = strsep(&algs, ",")); alg != NULL && *alg != '\0';
527 	    (alg = strsep(&algs, ","))) {
528 		if ((sigalg = sshkey_sigalg_by_name(alg)) == NULL)
529 			continue;
530 		if (!has_any_alg(sigalg, sigalgs))
531 			continue;
532 		/* Don't add an algorithm twice. */
533 		if (ssh->kex->server_sig_algs != NULL &&
534 		    has_any_alg(sigalg, ssh->kex->server_sig_algs))
535 			continue;
536 		xextendf(&ssh->kex->server_sig_algs, ",", "%s", sigalg);
537 	}
538 	free(oalgs);
539 	free(sigalgs);
540 	if (ssh->kex->server_sig_algs == NULL)
541 		ssh->kex->server_sig_algs = xstrdup("");
542 }
543 
544 static int
545 kex_compose_ext_info_server(struct ssh *ssh, struct sshbuf *m)
546 {
547 	int r;
548 
549 	if (ssh->kex->server_sig_algs == NULL &&
550 	    (ssh->kex->server_sig_algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
551 		return SSH_ERR_ALLOC_FAIL;
552 	if ((r = sshbuf_put_u32(m, 3)) != 0 ||
553 	    (r = sshbuf_put_cstring(m, "server-sig-algs")) != 0 ||
554 	    (r = sshbuf_put_cstring(m, ssh->kex->server_sig_algs)) != 0 ||
555 	    (r = sshbuf_put_cstring(m,
556 	    "publickey-hostbound@openssh.com")) != 0 ||
557 	    (r = sshbuf_put_cstring(m, "0")) != 0 ||
558 	    (r = sshbuf_put_cstring(m, "ping@openssh.com")) != 0 ||
559 	    (r = sshbuf_put_cstring(m, "0")) != 0) {
560 		error_fr(r, "compose");
561 		return r;
562 	}
563 	return 0;
564 }
565 
/*
 * Append the client's EXT_INFO body to 'm'. The body holds a single
 * extension, ext-info-in-auth@openssh.com, with value "0".
 * 'ssh' is unused. Returns 0 or the sshbuf error.
 */
static int
kex_compose_ext_info_client(struct ssh *ssh, struct sshbuf *m)
{
	int r;

	if ((r = sshbuf_put_u32(m, 1)) == 0 &&
	    (r = sshbuf_put_cstring(m, "ext-info-in-auth@openssh.com")) == 0 &&
	    (r = sshbuf_put_cstring(m, "0")) == 0)
		return 0;

	error_fr(r, "compose");
	return r;
}
582 
583 static int
584 kex_maybe_send_ext_info(struct ssh *ssh)
585 {
586 	int r;
587 	struct sshbuf *m = NULL;
588 
589 	if ((ssh->kex->flags & KEX_INITIAL) == 0)
590 		return 0;
591 	if (!ssh->kex->ext_info_c && !ssh->kex->ext_info_s)
592 		return 0;
593 
594 	/* Compose EXT_INFO packet. */
595 	if ((m = sshbuf_new()) == NULL)
596 		fatal_f("sshbuf_new failed");
597 	if (ssh->kex->ext_info_c &&
598 	    (r = kex_compose_ext_info_server(ssh, m)) != 0)
599 		goto fail;
600 	if (ssh->kex->ext_info_s &&
601 	    (r = kex_compose_ext_info_client(ssh, m)) != 0)
602 		goto fail;
603 
604 	/* Send the actual KEX_INFO packet */
605 	debug("Sending SSH2_MSG_EXT_INFO");
606 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
607 	    (r = sshpkt_putb(ssh, m)) != 0 ||
608 	    (r = sshpkt_send(ssh)) != 0) {
609 		error_f("send EXT_INFO");
610 		goto fail;
611 	}
612 
613 	r = 0;
614 
615  fail:
616 	sshbuf_free(m);
617 	return r;
618 }
619 
620 int
621 kex_server_update_ext_info(struct ssh *ssh)
622 {
623 	int r;
624 
625 	if ((ssh->kex->flags & KEX_HAS_EXT_INFO_IN_AUTH) == 0)
626 		return 0;
627 
628 	debug_f("Sending SSH2_MSG_EXT_INFO");
629 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
630 	    (r = sshpkt_put_u32(ssh, 1)) != 0 ||
631 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
632 	    (r = sshpkt_put_cstring(ssh, ssh->kex->server_sig_algs)) != 0 ||
633 	    (r = sshpkt_send(ssh)) != 0) {
634 		error_f("send EXT_INFO");
635 		return r;
636 	}
637 	return 0;
638 }
639 
640 int
641 kex_send_newkeys(struct ssh *ssh)
642 {
643 	int r;
644 
645 	kex_reset_dispatch(ssh);
646 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
647 	    (r = sshpkt_send(ssh)) != 0)
648 		return r;
649 	debug("SSH2_MSG_NEWKEYS sent");
650 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
651 	if ((r = kex_maybe_send_ext_info(ssh)) != 0)
652 		return r;
653 	debug("expecting SSH2_MSG_NEWKEYS");
654 	return 0;
655 }
656 
657 /* Check whether an ext_info value contains the expected version string */
658 static int
659 kex_ext_info_check_ver(struct kex *kex, const char *name,
660     const u_char *val, size_t len, const char *want_ver, u_int flag)
661 {
662 	if (memchr(val, '\0', len) != NULL) {
663 		error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
664 		return SSH_ERR_INVALID_FORMAT;
665 	}
666 	debug_f("%s=<%s>", name, val);
667 	if (strcmp(val, want_ver) == 0)
668 		kex->flags |= flag;
669 	else
670 		debug_f("unsupported version of %s extension", name);
671 	return 0;
672 }
673 
674 static int
675 kex_ext_info_client_parse(struct ssh *ssh, const char *name,
676     const u_char *value, size_t vlen)
677 {
678 	int r;
679 
680 	/* NB. some messages are only accepted in the initial EXT_INFO */
681 	if (strcmp(name, "server-sig-algs") == 0) {
682 		/* Ensure no \0 lurking in value */
683 		if (memchr(value, '\0', vlen) != NULL) {
684 			error_f("nul byte in %s", name);
685 			return SSH_ERR_INVALID_FORMAT;
686 		}
687 		debug_f("%s=<%s>", name, value);
688 		free(ssh->kex->server_sig_algs);
689 		ssh->kex->server_sig_algs = xstrdup((const char *)value);
690 	} else if (ssh->kex->ext_info_received == 1 &&
691 	    strcmp(name, "publickey-hostbound@openssh.com") == 0) {
692 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
693 		    "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
694 			return r;
695 		}
696 	} else if (ssh->kex->ext_info_received == 1 &&
697 	    strcmp(name, "ping@openssh.com") == 0) {
698 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
699 		    "0", KEX_HAS_PING)) != 0) {
700 			return r;
701 		}
702 	} else
703 		debug_f("%s (unrecognised)", name);
704 
705 	return 0;
706 }
707 
708 static int
709 kex_ext_info_server_parse(struct ssh *ssh, const char *name,
710     const u_char *value, size_t vlen)
711 {
712 	int r;
713 
714 	if (strcmp(name, "ext-info-in-auth@openssh.com") == 0) {
715 		if ((r = kex_ext_info_check_ver(ssh->kex, name, value, vlen,
716 		    "0", KEX_HAS_EXT_INFO_IN_AUTH)) != 0) {
717 			return r;
718 		}
719 	} else
720 		debug_f("%s (unrecognised)", name);
721 	return 0;
722 }
723 
724 int
725 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
726 {
727 	struct kex *kex = ssh->kex;
728 	const int max_ext_info = kex->server ? 1 : 2;
729 	u_int32_t i, ninfo;
730 	char *name;
731 	u_char *val;
732 	size_t vlen;
733 	int r;
734 
735 	debug("SSH2_MSG_EXT_INFO received");
736 	if (++kex->ext_info_received > max_ext_info) {
737 		error("too many SSH2_MSG_EXT_INFO messages sent by peer");
738 		return dispatch_protocol_error(type, seq, ssh);
739 	}
740 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
741 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
742 		return r;
743 	if (ninfo >= 1024) {
744 		error("SSH2_MSG_EXT_INFO with too many entries, expected "
745 		    "<=1024, received %u", ninfo);
746 		return dispatch_protocol_error(type, seq, ssh);
747 	}
748 	for (i = 0; i < ninfo; i++) {
749 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
750 			return r;
751 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
752 			free(name);
753 			return r;
754 		}
755 		debug3_f("extension %s", name);
756 		if (kex->server) {
757 			if ((r = kex_ext_info_server_parse(ssh, name,
758 			    val, vlen)) != 0)
759 				return r;
760 		} else {
761 			if ((r = kex_ext_info_client_parse(ssh, name,
762 			    val, vlen)) != 0)
763 				return r;
764 		}
765 		free(name);
766 		free(val);
767 	}
768 	return sshpkt_get_end(ssh);
769 }
770 
771 static int
772 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
773 {
774 	struct kex *kex = ssh->kex;
775 	int r;
776 
777 	debug("SSH2_MSG_NEWKEYS received");
778 	if (kex->ext_info_c && (kex->flags & KEX_INITIAL) != 0)
779 		ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_input_ext_info);
780 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
781 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
782 	if ((r = sshpkt_get_end(ssh)) != 0)
783 		return r;
784 	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
785 		return r;
786 	kex->done = 1;
787 	kex->flags &= ~KEX_INITIAL;
788 	sshbuf_reset(kex->peer);
789 	/* sshbuf_reset(kex->my); */
790 	kex->flags &= ~KEX_INIT_SENT;
791 	free(kex->name);
792 	kex->name = NULL;
793 	return 0;
794 }
795 
796 int
797 kex_send_kexinit(struct ssh *ssh)
798 {
799 	u_char *cookie;
800 	struct kex *kex = ssh->kex;
801 	int r;
802 
803 	if (kex == NULL) {
804 		error_f("no kex");
805 		return SSH_ERR_INTERNAL_ERROR;
806 	}
807 	if (kex->flags & KEX_INIT_SENT)
808 		return 0;
809 	kex->done = 0;
810 
811 	/* generate a random cookie */
812 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
813 		error_f("bad kex length: %zu < %d",
814 		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
815 		return SSH_ERR_INVALID_FORMAT;
816 	}
817 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
818 		error_f("buffer error");
819 		return SSH_ERR_INTERNAL_ERROR;
820 	}
821 	arc4random_buf(cookie, KEX_COOKIE_LEN);
822 
823 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
824 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
825 	    (r = sshpkt_send(ssh)) != 0) {
826 		error_fr(r, "compose reply");
827 		return r;
828 	}
829 	debug("SSH2_MSG_KEXINIT sent");
830 	kex->flags |= KEX_INIT_SENT;
831 	return 0;
832 }
833 
834 int
835 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
836 {
837 	struct kex *kex = ssh->kex;
838 	const u_char *ptr;
839 	u_int i;
840 	size_t dlen;
841 	int r;
842 
843 	debug("SSH2_MSG_KEXINIT received");
844 	if (kex == NULL) {
845 		error_f("no kex");
846 		return SSH_ERR_INTERNAL_ERROR;
847 	}
848 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_protocol_error);
849 	ptr = sshpkt_ptr(ssh, &dlen);
850 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
851 		return r;
852 
853 	/* discard packet */
854 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
855 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
856 			error_fr(r, "discard cookie");
857 			return r;
858 		}
859 	}
860 	for (i = 0; i < PROPOSAL_MAX; i++) {
861 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
862 			error_fr(r, "discard proposal");
863 			return r;
864 		}
865 	}
866 	/*
867 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
868 	 * KEX method has the server move first, but a server might be using
869 	 * a custom method or one that we otherwise don't support. We should
870 	 * be prepared to remember first_kex_follows here so we can eat a
871 	 * packet later.
872 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
873 	 * for cases where the server *doesn't* go first. I guess we should
874 	 * ignore it when it is set for these cases, which is what we do now.
875 	 */
876 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
877 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
878 	    (r = sshpkt_get_end(ssh)) != 0)
879 			return r;
880 
881 	if (!(kex->flags & KEX_INIT_SENT))
882 		if ((r = kex_send_kexinit(ssh)) != 0)
883 			return r;
884 	if ((r = kex_choose_conf(ssh, seq)) != 0)
885 		return r;
886 
887 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
888 		return (kex->kex[kex->kex_type])(ssh);
889 
890 	error_f("unknown kex type %u", kex->kex_type);
891 	return SSH_ERR_INTERNAL_ERROR;
892 }
893 
894 struct kex *
895 kex_new(void)
896 {
897 	struct kex *kex;
898 
899 	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
900 	    (kex->peer = sshbuf_new()) == NULL ||
901 	    (kex->my = sshbuf_new()) == NULL ||
902 	    (kex->client_version = sshbuf_new()) == NULL ||
903 	    (kex->server_version = sshbuf_new()) == NULL ||
904 	    (kex->session_id = sshbuf_new()) == NULL) {
905 		kex_free(kex);
906 		return NULL;
907 	}
908 	return kex;
909 }
910 
911 void
912 kex_free_newkeys(struct newkeys *newkeys)
913 {
914 	if (newkeys == NULL)
915 		return;
916 	if (newkeys->enc.key) {
917 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
918 		free(newkeys->enc.key);
919 		newkeys->enc.key = NULL;
920 	}
921 	if (newkeys->enc.iv) {
922 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
923 		free(newkeys->enc.iv);
924 		newkeys->enc.iv = NULL;
925 	}
926 	free(newkeys->enc.name);
927 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
928 	free(newkeys->comp.name);
929 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
930 	mac_clear(&newkeys->mac);
931 	if (newkeys->mac.key) {
932 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
933 		free(newkeys->mac.key);
934 		newkeys->mac.key = NULL;
935 	}
936 	free(newkeys->mac.name);
937 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
938 	freezero(newkeys, sizeof(*newkeys));
939 }
940 
941 void
942 kex_free(struct kex *kex)
943 {
944 	u_int mode;
945 
946 	if (kex == NULL)
947 		return;
948 
949 #ifdef WITH_OPENSSL
950 	DH_free(kex->dh);
951 #ifdef OPENSSL_HAS_ECC
952 	EC_KEY_free(kex->ec_client_key);
953 #endif /* OPENSSL_HAS_ECC */
954 #endif /* WITH_OPENSSL */
955 	for (mode = 0; mode < MODE_MAX; mode++) {
956 		kex_free_newkeys(kex->newkeys[mode]);
957 		kex->newkeys[mode] = NULL;
958 	}
959 	sshbuf_free(kex->peer);
960 	sshbuf_free(kex->my);
961 	sshbuf_free(kex->client_version);
962 	sshbuf_free(kex->server_version);
963 	sshbuf_free(kex->client_pub);
964 	sshbuf_free(kex->session_id);
965 	sshbuf_free(kex->initial_sig);
966 	sshkey_free(kex->initial_hostkey);
967 	free(kex->failed_choice);
968 	free(kex->hostkey_alg);
969 	free(kex->name);
970 	free(kex);
971 }
972 
973 int
974 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
975 {
976 	int r;
977 
978 	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
979 		return r;
980 	ssh->kex->flags = KEX_INITIAL;
981 	kex_reset_dispatch(ssh);
982 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
983 	return 0;
984 }
985 
986 int
987 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
988 {
989 	int r;
990 
991 	if ((r = kex_ready(ssh, proposal)) != 0)
992 		return r;
993 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
994 		kex_free(ssh->kex);
995 		ssh->kex = NULL;
996 		return r;
997 	}
998 	return 0;
999 }
1000 
1001 /*
1002  * Request key re-exchange, returns 0 on success or a ssherr.h error
1003  * code otherwise. Must not be called if KEX is incomplete or in-progress.
1004  */
1005 int
1006 kex_start_rekex(struct ssh *ssh)
1007 {
1008 	if (ssh->kex == NULL) {
1009 		error_f("no kex");
1010 		return SSH_ERR_INTERNAL_ERROR;
1011 	}
1012 	if (ssh->kex->done == 0) {
1013 		error_f("requested twice");
1014 		return SSH_ERR_INTERNAL_ERROR;
1015 	}
1016 	ssh->kex->done = 0;
1017 	return kex_send_kexinit(ssh);
1018 }
1019 
1020 static int
1021 choose_enc(struct sshenc *enc, char *client, char *server)
1022 {
1023 	char *name = match_list(client, server, NULL);
1024 
1025 	if (name == NULL)
1026 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
1027 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
1028 		error_f("unsupported cipher %s", name);
1029 		free(name);
1030 		return SSH_ERR_INTERNAL_ERROR;
1031 	}
1032 	enc->name = name;
1033 	enc->enabled = 0;
1034 	enc->iv = NULL;
1035 	enc->iv_len = cipher_ivlen(enc->cipher);
1036 	enc->key = NULL;
1037 	enc->key_len = cipher_keylen(enc->cipher);
1038 	enc->block_size = cipher_blocksize(enc->cipher);
1039 	return 0;
1040 }
1041 
1042 static int
1043 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
1044 {
1045 	char *name = match_list(client, server, NULL);
1046 
1047 	if (name == NULL)
1048 		return SSH_ERR_NO_MAC_ALG_MATCH;
1049 	if (mac_setup(mac, name) < 0) {
1050 		error_f("unsupported MAC %s", name);
1051 		free(name);
1052 		return SSH_ERR_INTERNAL_ERROR;
1053 	}
1054 	mac->name = name;
1055 	mac->key = NULL;
1056 	mac->enabled = 0;
1057 	return 0;
1058 }
1059 
1060 static int
1061 choose_comp(struct sshcomp *comp, char *client, char *server)
1062 {
1063 	char *name = match_list(client, server, NULL);
1064 
1065 	if (name == NULL)
1066 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
1067 #ifdef WITH_ZLIB
1068 	if (strcmp(name, "zlib@openssh.com") == 0) {
1069 		comp->type = COMP_DELAYED;
1070 	} else if (strcmp(name, "zlib") == 0) {
1071 		comp->type = COMP_ZLIB;
1072 	} else
1073 #endif	/* WITH_ZLIB */
1074 	if (strcmp(name, "none") == 0) {
1075 		comp->type = COMP_NONE;
1076 	} else {
1077 		error_f("unsupported compression scheme %s", name);
1078 		free(name);
1079 		return SSH_ERR_INTERNAL_ERROR;
1080 	}
1081 	comp->name = name;
1082 	return 0;
1083 }
1084 
1085 static int
1086 choose_kex(struct kex *k, char *client, char *server)
1087 {
1088 	const struct kexalg *kexalg;
1089 
1090 	k->name = match_list(client, server, NULL);
1091 
1092 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
1093 	if (k->name == NULL)
1094 		return SSH_ERR_NO_KEX_ALG_MATCH;
1095 	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
1096 		error_f("unsupported KEX method %s", k->name);
1097 		return SSH_ERR_INTERNAL_ERROR;
1098 	}
1099 	k->kex_type = kexalg->type;
1100 	k->hash_alg = kexalg->hash_alg;
1101 	k->ec_nid = kexalg->ec_nid;
1102 	return 0;
1103 }
1104 
1105 static int
1106 choose_hostkeyalg(struct kex *k, char *client, char *server)
1107 {
1108 	free(k->hostkey_alg);
1109 	k->hostkey_alg = match_list(client, server, NULL);
1110 
1111 	debug("kex: host key algorithm: %s",
1112 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
1113 	if (k->hostkey_alg == NULL)
1114 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
1115 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
1116 	if (k->hostkey_type == KEY_UNSPEC) {
1117 		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
1118 		return SSH_ERR_INTERNAL_ERROR;
1119 	}
1120 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
1121 	return 0;
1122 }
1123 
1124 static int
1125 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
1126 {
1127 	static int check[] = {
1128 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
1129 	};
1130 	int *idx;
1131 	char *p;
1132 
1133 	for (idx = &check[0]; *idx != -1; idx++) {
1134 		if ((p = strchr(my[*idx], ',')) != NULL)
1135 			*p = '\0';
1136 		if ((p = strchr(peer[*idx], ',')) != NULL)
1137 			*p = '\0';
1138 		if (strcmp(my[*idx], peer[*idx]) != 0) {
1139 			debug2("proposal mismatch: my %s peer %s",
1140 			    my[*idx], peer[*idx]);
1141 			return (0);
1142 		}
1143 	}
1144 	debug2("proposals match");
1145 	return (1);
1146 }
1147 
1148 static int
1149 kexalgs_contains(char **peer, const char *ext)
1150 {
1151 	return has_any_alg(peer[PROPOSAL_KEX_ALGS], ext);
1152 }
1153 
1154 static int
1155 kex_choose_conf(struct ssh *ssh, uint32_t seq)
1156 {
1157 	struct kex *kex = ssh->kex;
1158 	struct newkeys *newkeys;
1159 	char **my = NULL, **peer = NULL;
1160 	char **cprop, **sprop;
1161 	int nenc, nmac, ncomp;
1162 	u_int mode, ctos, need, dh_need, authlen;
1163 	int r, first_kex_follows;
1164 
1165 	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
1166 	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
1167 		goto out;
1168 	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
1169 	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
1170 		goto out;
1171 
1172 	if (kex->server) {
1173 		cprop=peer;
1174 		sprop=my;
1175 	} else {
1176 		cprop=my;
1177 		sprop=peer;
1178 	}
1179 
1180 	/* Check whether peer supports ext_info/kex_strict */
1181 	if ((kex->flags & KEX_INITIAL) != 0) {
1182 		if (kex->server) {
1183 			kex->ext_info_c = kexalgs_contains(peer, "ext-info-c");
1184 			kex->kex_strict = kexalgs_contains(peer,
1185 			    "kex-strict-c-v00@openssh.com");
1186 		} else {
1187 			kex->ext_info_s = kexalgs_contains(peer, "ext-info-s");
1188 			kex->kex_strict = kexalgs_contains(peer,
1189 			    "kex-strict-s-v00@openssh.com");
1190 		}
1191 		if (kex->kex_strict) {
1192 			debug3_f("will use strict KEX ordering");
1193 			if (seq != 0)
1194 				ssh_packet_disconnect(ssh,
1195 				    "strict KEX violation: "
1196 				    "KEXINIT was not the first packet");
1197 		}
1198 	}
1199 
1200 	/* Check whether client supports rsa-sha2 algorithms */
1201 	if (kex->server && (kex->flags & KEX_INITIAL)) {
1202 		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1203 		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
1204 			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
1205 		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1206 		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
1207 			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
1208 	}
1209 
1210 	/* Algorithm Negotiation */
1211 	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
1212 	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
1213 		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
1214 		peer[PROPOSAL_KEX_ALGS] = NULL;
1215 		goto out;
1216 	}
1217 	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
1218 	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
1219 		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
1220 		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
1221 		goto out;
1222 	}
1223 	for (mode = 0; mode < MODE_MAX; mode++) {
1224 		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
1225 			r = SSH_ERR_ALLOC_FAIL;
1226 			goto out;
1227 		}
1228 		kex->newkeys[mode] = newkeys;
1229 		ctos = (!kex->server && mode == MODE_OUT) ||
1230 		    (kex->server && mode == MODE_IN);
1231 		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1232 		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1233 		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1234 		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1235 		    sprop[nenc])) != 0) {
1236 			kex->failed_choice = peer[nenc];
1237 			peer[nenc] = NULL;
1238 			goto out;
1239 		}
1240 		authlen = cipher_authlen(newkeys->enc.cipher);
1241 		/* ignore mac for authenticated encryption */
1242 		if (authlen == 0 &&
1243 		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1244 		    sprop[nmac])) != 0) {
1245 			kex->failed_choice = peer[nmac];
1246 			peer[nmac] = NULL;
1247 			goto out;
1248 		}
1249 		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1250 		    sprop[ncomp])) != 0) {
1251 			kex->failed_choice = peer[ncomp];
1252 			peer[ncomp] = NULL;
1253 			goto out;
1254 		}
1255 		debug("kex: %s cipher: %s MAC: %s compression: %s",
1256 		    ctos ? "client->server" : "server->client",
1257 		    newkeys->enc.name,
1258 		    authlen == 0 ? newkeys->mac.name : "<implicit>",
1259 		    newkeys->comp.name);
1260 	}
1261 	need = dh_need = 0;
1262 	for (mode = 0; mode < MODE_MAX; mode++) {
1263 		newkeys = kex->newkeys[mode];
1264 		need = MAXIMUM(need, newkeys->enc.key_len);
1265 		need = MAXIMUM(need, newkeys->enc.block_size);
1266 		need = MAXIMUM(need, newkeys->enc.iv_len);
1267 		need = MAXIMUM(need, newkeys->mac.key_len);
1268 		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1269 		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1270 		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1271 		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1272 	}
1273 	/* XXX need runden? */
1274 	kex->we_need = need;
1275 	kex->dh_need = dh_need;
1276 
1277 	/* ignore the next message if the proposals do not match */
1278 	if (first_kex_follows && !proposals_match(my, peer))
1279 		ssh->dispatch_skip_packets = 1;
1280 	r = 0;
1281  out:
1282 	kex_prop_free(my);
1283 	kex_prop_free(peer);
1284 	return r;
1285 }
1286 
1287 static int
1288 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1289     const struct sshbuf *shared_secret, u_char **keyp)
1290 {
1291 	struct kex *kex = ssh->kex;
1292 	struct ssh_digest_ctx *hashctx = NULL;
1293 	char c = id;
1294 	u_int have;
1295 	size_t mdsz;
1296 	u_char *digest;
1297 	int r;
1298 
1299 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1300 		return SSH_ERR_INVALID_ARGUMENT;
1301 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1302 		r = SSH_ERR_ALLOC_FAIL;
1303 		goto out;
1304 	}
1305 
1306 	/* K1 = HASH(K || H || "A" || session_id) */
1307 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1308 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1309 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1310 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1311 	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1312 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1313 		r = SSH_ERR_LIBCRYPTO_ERROR;
1314 		error_f("KEX hash failed");
1315 		goto out;
1316 	}
1317 	ssh_digest_free(hashctx);
1318 	hashctx = NULL;
1319 
1320 	/*
1321 	 * expand key:
1322 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1323 	 * Key = K1 || K2 || ... || Kn
1324 	 */
1325 	for (have = mdsz; need > have; have += mdsz) {
1326 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1327 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1328 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1329 		    ssh_digest_update(hashctx, digest, have) != 0 ||
1330 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1331 			error_f("KDF failed");
1332 			r = SSH_ERR_LIBCRYPTO_ERROR;
1333 			goto out;
1334 		}
1335 		ssh_digest_free(hashctx);
1336 		hashctx = NULL;
1337 	}
1338 #ifdef DEBUG_KEX
1339 	fprintf(stderr, "key '%c'== ", c);
1340 	dump_digest("key", digest, need);
1341 #endif
1342 	*keyp = digest;
1343 	digest = NULL;
1344 	r = 0;
1345  out:
1346 	free(digest);
1347 	ssh_digest_free(hashctx);
1348 	return r;
1349 }
1350 
1351 #define NKEYS	6
1352 int
1353 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1354     const struct sshbuf *shared_secret)
1355 {
1356 	struct kex *kex = ssh->kex;
1357 	u_char *keys[NKEYS];
1358 	u_int i, j, mode, ctos;
1359 	int r;
1360 
1361 	/* save initial hash as session id */
1362 	if ((kex->flags & KEX_INITIAL) != 0) {
1363 		if (sshbuf_len(kex->session_id) != 0) {
1364 			error_f("already have session ID at kex");
1365 			return SSH_ERR_INTERNAL_ERROR;
1366 		}
1367 		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1368 			return r;
1369 	} else if (sshbuf_len(kex->session_id) == 0) {
1370 		error_f("no session ID in rekex");
1371 		return SSH_ERR_INTERNAL_ERROR;
1372 	}
1373 	for (i = 0; i < NKEYS; i++) {
1374 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1375 		    shared_secret, &keys[i])) != 0) {
1376 			for (j = 0; j < i; j++)
1377 				free(keys[j]);
1378 			return r;
1379 		}
1380 	}
1381 	for (mode = 0; mode < MODE_MAX; mode++) {
1382 		ctos = (!kex->server && mode == MODE_OUT) ||
1383 		    (kex->server && mode == MODE_IN);
1384 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1385 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1386 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1387 	}
1388 	return 0;
1389 }
1390 
1391 int
1392 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1393 {
1394 	struct kex *kex = ssh->kex;
1395 
1396 	*pubp = NULL;
1397 	*prvp = NULL;
1398 	if (kex->load_host_public_key == NULL ||
1399 	    kex->load_host_private_key == NULL) {
1400 		error_f("missing hostkey loader");
1401 		return SSH_ERR_INVALID_ARGUMENT;
1402 	}
1403 	*pubp = kex->load_host_public_key(kex->hostkey_type,
1404 	    kex->hostkey_nid, ssh);
1405 	*prvp = kex->load_host_private_key(kex->hostkey_type,
1406 	    kex->hostkey_nid, ssh);
1407 	if (*pubp == NULL)
1408 		return SSH_ERR_NO_HOSTKEY_LOADED;
1409 	return 0;
1410 }
1411 
1412 int
1413 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1414 {
1415 	struct kex *kex = ssh->kex;
1416 
1417 	if (kex->verify_host_key == NULL) {
1418 		error_f("missing hostkey verifier");
1419 		return SSH_ERR_INVALID_ARGUMENT;
1420 	}
1421 	if (server_host_key->type != kex->hostkey_type ||
1422 	    (kex->hostkey_type == KEY_ECDSA &&
1423 	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1424 		return SSH_ERR_KEY_TYPE_MISMATCH;
1425 	if (kex->verify_host_key(server_host_key, ssh) == -1)
1426 		return  SSH_ERR_SIGNATURE_INVALID;
1427 	return 0;
1428 }
1429 
1430 #if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
1431 void
1432 dump_digest(const char *msg, const u_char *digest, int len)
1433 {
1434 	fprintf(stderr, "%s\n", msg);
1435 	sshbuf_dump_data(digest, len, stderr);
1436 }
1437 #endif
1438 
1439 /*
1440  * Send a plaintext error message to the peer, suffixed by \r\n.
1441  * Only used during banner exchange, and there only for the server.
1442  */
1443 static void
1444 send_error(struct ssh *ssh, char *msg)
1445 {
1446 	char *crnl = "\r\n";
1447 
1448 	if (!ssh->kex->server)
1449 		return;
1450 
1451 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1452 	    msg, strlen(msg)) != strlen(msg) ||
1453 	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1454 	    crnl, strlen(crnl)) != strlen(crnl))
1455 		error_f("write: %.100s", strerror(errno));
1456 }
1457 
1458 /*
1459  * Sends our identification string and waits for the peer's. Will block for
1460  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1461  * Returns on 0 success or a ssherr.h code on failure.
1462  */
1463 int
1464 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1465     const char *version_addendum)
1466 {
1467 	int remote_major, remote_minor, mismatch, oerrno = 0;
1468 	size_t len, n;
1469 	int r, expect_nl;
1470 	u_char c;
1471 	struct sshbuf *our_version = ssh->kex->server ?
1472 	    ssh->kex->server_version : ssh->kex->client_version;
1473 	struct sshbuf *peer_version = ssh->kex->server ?
1474 	    ssh->kex->client_version : ssh->kex->server_version;
1475 	char *our_version_string = NULL, *peer_version_string = NULL;
1476 	char *cp, *remote_version = NULL;
1477 
1478 	/* Prepare and send our banner */
1479 	sshbuf_reset(our_version);
1480 	if (version_addendum != NULL && *version_addendum == '\0')
1481 		version_addendum = NULL;
1482 	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%s%s%s\r\n",
1483 	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1484 	    version_addendum == NULL ? "" : " ",
1485 	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1486 		oerrno = errno;
1487 		error_fr(r, "sshbuf_putf");
1488 		goto out;
1489 	}
1490 
1491 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1492 	    sshbuf_mutable_ptr(our_version),
1493 	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1494 		oerrno = errno;
1495 		debug_f("write: %.100s", strerror(errno));
1496 		r = SSH_ERR_SYSTEM_ERROR;
1497 		goto out;
1498 	}
1499 	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1500 		oerrno = errno;
1501 		error_fr(r, "sshbuf_consume_end");
1502 		goto out;
1503 	}
1504 	our_version_string = sshbuf_dup_string(our_version);
1505 	if (our_version_string == NULL) {
1506 		error_f("sshbuf_dup_string failed");
1507 		r = SSH_ERR_ALLOC_FAIL;
1508 		goto out;
1509 	}
1510 	debug("Local version string %.100s", our_version_string);
1511 
1512 	/* Read other side's version identification. */
1513 	for (n = 0; ; n++) {
1514 		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1515 			send_error(ssh, "No SSH identification string "
1516 			    "received.");
1517 			error_f("No SSH version received in first %u lines "
1518 			    "from server", SSH_MAX_PRE_BANNER_LINES);
1519 			r = SSH_ERR_INVALID_FORMAT;
1520 			goto out;
1521 		}
1522 		sshbuf_reset(peer_version);
1523 		expect_nl = 0;
1524 		for (;;) {
1525 			if (timeout_ms > 0) {
1526 				r = waitrfd(ssh_packet_get_connection_in(ssh),
1527 				    &timeout_ms, NULL);
1528 				if (r == -1 && errno == ETIMEDOUT) {
1529 					send_error(ssh, "Timed out waiting "
1530 					    "for SSH identification string.");
1531 					error("Connection timed out during "
1532 					    "banner exchange");
1533 					r = SSH_ERR_CONN_TIMEOUT;
1534 					goto out;
1535 				} else if (r == -1) {
1536 					oerrno = errno;
1537 					error_f("%s", strerror(errno));
1538 					r = SSH_ERR_SYSTEM_ERROR;
1539 					goto out;
1540 				}
1541 			}
1542 
1543 			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1544 			    &c, 1);
1545 			if (len != 1 && errno == EPIPE) {
1546 				verbose_f("Connection closed by remote host");
1547 				r = SSH_ERR_CONN_CLOSED;
1548 				goto out;
1549 			} else if (len != 1) {
1550 				oerrno = errno;
1551 				error_f("read: %.100s", strerror(errno));
1552 				r = SSH_ERR_SYSTEM_ERROR;
1553 				goto out;
1554 			}
1555 			if (c == '\r') {
1556 				expect_nl = 1;
1557 				continue;
1558 			}
1559 			if (c == '\n')
1560 				break;
1561 			if (c == '\0' || expect_nl) {
1562 				verbose_f("banner line contains invalid "
1563 				    "characters");
1564 				goto invalid;
1565 			}
1566 			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1567 				oerrno = errno;
1568 				error_fr(r, "sshbuf_put");
1569 				goto out;
1570 			}
1571 			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1572 				verbose_f("banner line too long");
1573 				goto invalid;
1574 			}
1575 		}
1576 		/* Is this an actual protocol banner? */
1577 		if (sshbuf_len(peer_version) > 4 &&
1578 		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1579 			break;
1580 		/* If not, then just log the line and continue */
1581 		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1582 			error_f("sshbuf_dup_string failed");
1583 			r = SSH_ERR_ALLOC_FAIL;
1584 			goto out;
1585 		}
1586 		/* Do not accept lines before the SSH ident from a client */
1587 		if (ssh->kex->server) {
1588 			verbose_f("client sent invalid protocol identifier "
1589 			    "\"%.256s\"", cp);
1590 			free(cp);
1591 			goto invalid;
1592 		}
1593 		debug_f("banner line %zu: %s", n, cp);
1594 		free(cp);
1595 	}
1596 	peer_version_string = sshbuf_dup_string(peer_version);
1597 	if (peer_version_string == NULL)
1598 		fatal_f("sshbuf_dup_string failed");
1599 	/* XXX must be same size for sscanf */
1600 	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1601 		error_f("calloc failed");
1602 		r = SSH_ERR_ALLOC_FAIL;
1603 		goto out;
1604 	}
1605 
1606 	/*
1607 	 * Check that the versions match.  In future this might accept
1608 	 * several versions and set appropriate flags to handle them.
1609 	 */
1610 	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1611 	    &remote_major, &remote_minor, remote_version) != 3) {
1612 		error("Bad remote protocol version identification: '%.100s'",
1613 		    peer_version_string);
1614  invalid:
1615 		send_error(ssh, "Invalid SSH identification string.");
1616 		r = SSH_ERR_INVALID_FORMAT;
1617 		goto out;
1618 	}
1619 	debug("Remote protocol version %d.%d, remote software version %.100s",
1620 	    remote_major, remote_minor, remote_version);
1621 	compat_banner(ssh, remote_version);
1622 
1623 	mismatch = 0;
1624 	switch (remote_major) {
1625 	case 2:
1626 		break;
1627 	case 1:
1628 		if (remote_minor != 99)
1629 			mismatch = 1;
1630 		break;
1631 	default:
1632 		mismatch = 1;
1633 		break;
1634 	}
1635 	if (mismatch) {
1636 		error("Protocol major versions differ: %d vs. %d",
1637 		    PROTOCOL_MAJOR_2, remote_major);
1638 		send_error(ssh, "Protocol major versions differ.");
1639 		r = SSH_ERR_NO_PROTOCOL_VERSION;
1640 		goto out;
1641 	}
1642 
1643 	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1644 		logit("probed from %s port %d with %s.  Don't panic.",
1645 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1646 		    peer_version_string);
1647 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1648 		goto out;
1649 	}
1650 	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1651 		logit("scanned from %s port %d with %s.  Don't panic.",
1652 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1653 		    peer_version_string);
1654 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1655 		goto out;
1656 	}
1657 	/* success */
1658 	r = 0;
1659  out:
1660 	free(our_version_string);
1661 	free(peer_version_string);
1662 	free(remote_version);
1663 	if (r == SSH_ERR_SYSTEM_ERROR)
1664 		errno = oerrno;
1665 	return r;
1666 }
1667 
1668