xref: /freebsd/crypto/openssh/kex.c (revision 45dd2eaac379e5576f745380260470204c49beac)
1 /* $OpenBSD: kex.c,v 1.172 2022/02/01 23:32:51 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include "includes.h"
27 
28 #include <sys/types.h>
29 #include <errno.h>
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39 
40 #ifdef WITH_OPENSSL
41 #include <openssl/crypto.h>
42 #include <openssl/dh.h>
43 #endif
44 
45 #include "ssh.h"
46 #include "ssh2.h"
47 #include "atomicio.h"
48 #include "version.h"
49 #include "packet.h"
50 #include "compat.h"
51 #include "cipher.h"
52 #include "sshkey.h"
53 #include "kex.h"
54 #include "log.h"
55 #include "mac.h"
56 #include "match.h"
57 #include "misc.h"
58 #include "dispatch.h"
59 #include "monitor.h"
60 
61 #include "ssherr.h"
62 #include "sshbuf.h"
63 #include "digest.h"
64 
65 /* prototype */
66 static int kex_choose_conf(struct ssh *);
67 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
68 
69 static const char * const proposal_names[PROPOSAL_MAX] = {
70 	"KEX algorithms",
71 	"host key algorithms",
72 	"ciphers ctos",
73 	"ciphers stoc",
74 	"MACs ctos",
75 	"MACs stoc",
76 	"compression ctos",
77 	"compression stoc",
78 	"languages ctos",
79 	"languages stoc",
80 };
81 
82 struct kexalg {
83 	char *name;
84 	u_int type;
85 	int ec_nid;
86 	int hash_alg;
87 };
/*
 * Key exchange methods compiled into this build, in the order reported by
 * kex_alg_list(). Each entry maps a wire name to its KEX_* type, EC curve
 * NID (0 for non-ECDH methods) and SSH_DIGEST_* hash. Availability depends
 * on build options (OpenSSL, SHA-256, ECC, NIST P-521, sntrup761).
 * kex_alg_list() and kex_alg_by_name() walk the table until the
 * NULL-name sentinel.
 */
static const struct kexalg kexalgs[] = {
#ifdef WITH_OPENSSL
	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
#ifdef HAVE_EVP_SHA256
	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
#endif /* HAVE_EVP_SHA256 */
#ifdef OPENSSL_HAS_ECC
	/* ECDH methods share one type; the curve is selected by ec_nid */
	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
	    SSH_DIGEST_SHA384 },
# ifdef OPENSSL_HAS_NISTP521
	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
	    SSH_DIGEST_SHA512 },
# endif /* OPENSSL_HAS_NISTP521 */
#endif /* OPENSSL_HAS_ECC */
#endif /* WITH_OPENSSL */
#if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
	/* alternative name mapped to the same KEX_C25519_SHA256 type */
	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
#ifdef USE_SNTRUP761X25519
	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
	    SSH_DIGEST_SHA512 },
#endif
#endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
	/* sentinel: table walks stop at name == NULL */
	{ NULL, 0, -1, -1},
};
120 
121 char *
122 kex_alg_list(char sep)
123 {
124 	char *ret = NULL, *tmp;
125 	size_t nlen, rlen = 0;
126 	const struct kexalg *k;
127 
128 	for (k = kexalgs; k->name != NULL; k++) {
129 		if (ret != NULL)
130 			ret[rlen++] = sep;
131 		nlen = strlen(k->name);
132 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
133 			free(ret);
134 			return NULL;
135 		}
136 		ret = tmp;
137 		memcpy(ret + rlen, k->name, nlen + 1);
138 		rlen += nlen;
139 	}
140 	return ret;
141 }
142 
143 static const struct kexalg *
144 kex_alg_by_name(const char *name)
145 {
146 	const struct kexalg *k;
147 
148 	for (k = kexalgs; k->name != NULL; k++) {
149 		if (strcmp(k->name, name) == 0)
150 			return k;
151 	}
152 	return NULL;
153 }
154 
/*
 * Validate a comma-separated list of KEX method names.
 * Returns 0 in any of these cases:
 *   - the list is NULL or empty;
 *   - a name before the first empty element is unsupported (logged);
 *   - memory cannot be allocated.
 * Otherwise returns 1. Elements after the first empty one are not
 * examined.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *rest, *name;
	int ok = 1;

	if (names == NULL || *names == '\0')
		return 0;
	if ((copy = rest = strdup(names)) == NULL)
		return 0;
	while ((name = strsep(&rest, ",")) != NULL && *name != '\0') {
		if (kex_alg_by_name(name) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", name);
			ok = 0;
			break;
		}
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
177 
/*
 * Concatenate algorithm names, avoiding duplicates in the process.
 *
 * Each element of "b" is appended to a copy of "a" unless match_list()
 * finds it already present in the result. Processing of "b" stops at its
 * first empty element.
 *
 * Special cases:
 *   - If "a" is NULL or empty, a copy of "b" is returned unchanged (an
 *     empty string if "b" is NULL as well).
 *   - If "b" is NULL or empty, a copy of "a" is returned.
 *
 * Returns NULL on allocation failure, or if both lists are non-empty and
 * "b" is longer than 1 MiB. Caller must free returned string.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p, *m;
	size_t len;

	/* Never pass NULL to strdup() when both inputs are absent. */
	if (a == NULL || *a == '\0')
		return strdup(b == NULL ? "" : b);
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* worst case: a + ',' + every element of b + NUL */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		if ((m = match_list(ret, p, NULL)) != NULL) {
			free(m);
			continue; /* Algorithm already present */
		}
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
216 
217 /*
218  * Assemble a list of algorithms from a default list and a string from a
219  * configuration file. The user-provided string may begin with '+' to
220  * indicate that it should be appended to the default, '-' that the
221  * specified names should be removed, or '^' that they should be placed
222  * at the head.
223  */
224 int
225 kex_assemble_names(char **listp, const char *def, const char *all)
226 {
227 	char *cp, *tmp, *patterns;
228 	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
229 	int r = SSH_ERR_INTERNAL_ERROR;
230 
231 	if (listp == NULL || def == NULL || all == NULL)
232 		return SSH_ERR_INVALID_ARGUMENT;
233 
234 	if (*listp == NULL || **listp == '\0') {
235 		if ((*listp = strdup(def)) == NULL)
236 			return SSH_ERR_ALLOC_FAIL;
237 		return 0;
238 	}
239 
240 	list = *listp;
241 	*listp = NULL;
242 	if (*list == '+') {
243 		/* Append names to default list */
244 		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
245 			r = SSH_ERR_ALLOC_FAIL;
246 			goto fail;
247 		}
248 		free(list);
249 		list = tmp;
250 	} else if (*list == '-') {
251 		/* Remove names from default list */
252 		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
253 			r = SSH_ERR_ALLOC_FAIL;
254 			goto fail;
255 		}
256 		free(list);
257 		/* filtering has already been done */
258 		return 0;
259 	} else if (*list == '^') {
260 		/* Place names at head of default list */
261 		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
262 			r = SSH_ERR_ALLOC_FAIL;
263 			goto fail;
264 		}
265 		free(list);
266 		list = tmp;
267 	} else {
268 		/* Explicit list, overrides default - just use "list" as is */
269 	}
270 
271 	/*
272 	 * The supplied names may be a pattern-list. For the -list case,
273 	 * the patterns are applied above. For the +list and explicit list
274 	 * cases we need to do it now.
275 	 */
276 	ret = NULL;
277 	if ((patterns = opatterns = strdup(list)) == NULL) {
278 		r = SSH_ERR_ALLOC_FAIL;
279 		goto fail;
280 	}
281 	/* Apply positive (i.e. non-negated) patterns from the list */
282 	while ((cp = strsep(&patterns, ",")) != NULL) {
283 		if (*cp == '!') {
284 			/* negated matches are not supported here */
285 			r = SSH_ERR_INVALID_ARGUMENT;
286 			goto fail;
287 		}
288 		free(matching);
289 		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
290 			r = SSH_ERR_ALLOC_FAIL;
291 			goto fail;
292 		}
293 		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
294 			r = SSH_ERR_ALLOC_FAIL;
295 			goto fail;
296 		}
297 		free(ret);
298 		ret = tmp;
299 	}
300 	if (ret == NULL || *ret == '\0') {
301 		/* An empty name-list is an error */
302 		/* XXX better error code? */
303 		r = SSH_ERR_INVALID_ARGUMENT;
304 		goto fail;
305 	}
306 
307 	/* success */
308 	*listp = ret;
309 	ret = NULL;
310 	r = 0;
311 
312  fail:
313 	free(matching);
314 	free(opatterns);
315 	free(list);
316 	free(ret);
317 	return r;
318 }
319 
320 /* put algorithm proposal into buffer */
321 int
322 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
323 {
324 	u_int i;
325 	int r;
326 
327 	sshbuf_reset(b);
328 
329 	/*
330 	 * add a dummy cookie, the cookie will be overwritten by
331 	 * kex_send_kexinit(), each time a kexinit is set
332 	 */
333 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
334 		if ((r = sshbuf_put_u8(b, 0)) != 0)
335 			return r;
336 	}
337 	for (i = 0; i < PROPOSAL_MAX; i++) {
338 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
339 			return r;
340 	}
341 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
342 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
343 		return r;
344 	return 0;
345 }
346 
347 /* parse buffer and return algorithm proposal */
348 int
349 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
350 {
351 	struct sshbuf *b = NULL;
352 	u_char v;
353 	u_int i;
354 	char **proposal = NULL;
355 	int r;
356 
357 	*propp = NULL;
358 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
359 		return SSH_ERR_ALLOC_FAIL;
360 	if ((b = sshbuf_fromb(raw)) == NULL) {
361 		r = SSH_ERR_ALLOC_FAIL;
362 		goto out;
363 	}
364 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
365 		error_fr(r, "consume cookie");
366 		goto out;
367 	}
368 	/* extract kex init proposal strings */
369 	for (i = 0; i < PROPOSAL_MAX; i++) {
370 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
371 			error_fr(r, "parse proposal %u", i);
372 			goto out;
373 		}
374 		debug2("%s: %s", proposal_names[i], proposal[i]);
375 	}
376 	/* first kex follows / reserved */
377 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
378 	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
379 		error_fr(r, "parse");
380 		goto out;
381 	}
382 	if (first_kex_follows != NULL)
383 		*first_kex_follows = v;
384 	debug2("first_kex_follows %d ", v);
385 	debug2("reserved %u ", i);
386 	r = 0;
387 	*propp = proposal;
388  out:
389 	if (r != 0 && proposal != NULL)
390 		kex_prop_free(proposal);
391 	sshbuf_free(b);
392 	return r;
393 }
394 
395 void
396 kex_prop_free(char **proposal)
397 {
398 	u_int i;
399 
400 	if (proposal == NULL)
401 		return;
402 	for (i = 0; i < PROPOSAL_MAX; i++)
403 		free(proposal[i]);
404 	free(proposal);
405 }
406 
407 /* ARGSUSED */
408 int
409 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
410 {
411 	int r;
412 
413 	error("kex protocol error: type %d seq %u", type, seq);
414 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
415 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
416 	    (r = sshpkt_send(ssh)) != 0)
417 		return r;
418 	return 0;
419 }
420 
421 static void
422 kex_reset_dispatch(struct ssh *ssh)
423 {
424 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
425 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
426 }
427 
428 static int
429 kex_send_ext_info(struct ssh *ssh)
430 {
431 	int r;
432 	char *algs;
433 
434 	debug("Sending SSH2_MSG_EXT_INFO");
435 	if ((algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
436 		return SSH_ERR_ALLOC_FAIL;
437 	/* XXX filter algs list by allowed pubkey/hostbased types */
438 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
439 	    (r = sshpkt_put_u32(ssh, 2)) != 0 ||
440 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
441 	    (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
442 	    (r = sshpkt_put_cstring(ssh,
443 	    "publickey-hostbound@openssh.com")) != 0 ||
444 	    (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
445 	    (r = sshpkt_send(ssh)) != 0) {
446 		error_fr(r, "compose");
447 		goto out;
448 	}
449 	/* success */
450 	r = 0;
451  out:
452 	free(algs);
453 	return r;
454 }
455 
456 int
457 kex_send_newkeys(struct ssh *ssh)
458 {
459 	int r;
460 
461 	kex_reset_dispatch(ssh);
462 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
463 	    (r = sshpkt_send(ssh)) != 0)
464 		return r;
465 	debug("SSH2_MSG_NEWKEYS sent");
466 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
467 	if (ssh->kex->ext_info_c && (ssh->kex->flags & KEX_INITIAL) != 0)
468 		if ((r = kex_send_ext_info(ssh)) != 0)
469 			return r;
470 	debug("expecting SSH2_MSG_NEWKEYS");
471 	return 0;
472 }
473 
474 int
475 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
476 {
477 	struct kex *kex = ssh->kex;
478 	u_int32_t i, ninfo;
479 	char *name;
480 	u_char *val;
481 	size_t vlen;
482 	int r;
483 
484 	debug("SSH2_MSG_EXT_INFO received");
485 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
486 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
487 		return r;
488 	for (i = 0; i < ninfo; i++) {
489 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
490 			return r;
491 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
492 			free(name);
493 			return r;
494 		}
495 		if (strcmp(name, "server-sig-algs") == 0) {
496 			/* Ensure no \0 lurking in value */
497 			if (memchr(val, '\0', vlen) != NULL) {
498 				error_f("nul byte in %s", name);
499 				return SSH_ERR_INVALID_FORMAT;
500 			}
501 			debug_f("%s=<%s>", name, val);
502 			kex->server_sig_algs = val;
503 			val = NULL;
504 		} else if (strcmp(name,
505 		    "publickey-hostbound@openssh.com") == 0) {
506 			/* XXX refactor */
507 			/* Ensure no \0 lurking in value */
508 			if (memchr(val, '\0', vlen) != NULL) {
509 				error_f("nul byte in %s", name);
510 				return SSH_ERR_INVALID_FORMAT;
511 			}
512 			debug_f("%s=<%s>", name, val);
513 			if (strcmp(val, "0") == 0)
514 				kex->flags |= KEX_HAS_PUBKEY_HOSTBOUND;
515 			else {
516 				debug_f("unsupported version of %s extension",
517 				    name);
518 			}
519 		} else
520 			debug_f("%s (unrecognised)", name);
521 		free(name);
522 		free(val);
523 	}
524 	return sshpkt_get_end(ssh);
525 }
526 
527 static int
528 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
529 {
530 	struct kex *kex = ssh->kex;
531 	int r;
532 
533 	debug("SSH2_MSG_NEWKEYS received");
534 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
535 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
536 	if ((r = sshpkt_get_end(ssh)) != 0)
537 		return r;
538 	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
539 		return r;
540 	kex->done = 1;
541 	kex->flags &= ~KEX_INITIAL;
542 	sshbuf_reset(kex->peer);
543 	/* sshbuf_reset(kex->my); */
544 	kex->flags &= ~KEX_INIT_SENT;
545 	free(kex->name);
546 	kex->name = NULL;
547 	return 0;
548 }
549 
550 int
551 kex_send_kexinit(struct ssh *ssh)
552 {
553 	u_char *cookie;
554 	struct kex *kex = ssh->kex;
555 	int r;
556 
557 	if (kex == NULL) {
558 		error_f("no kex");
559 		return SSH_ERR_INTERNAL_ERROR;
560 	}
561 	if (kex->flags & KEX_INIT_SENT)
562 		return 0;
563 	kex->done = 0;
564 
565 	/* generate a random cookie */
566 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
567 		error_f("bad kex length: %zu < %d",
568 		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
569 		return SSH_ERR_INVALID_FORMAT;
570 	}
571 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
572 		error_f("buffer error");
573 		return SSH_ERR_INTERNAL_ERROR;
574 	}
575 	arc4random_buf(cookie, KEX_COOKIE_LEN);
576 
577 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
578 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
579 	    (r = sshpkt_send(ssh)) != 0) {
580 		error_fr(r, "compose reply");
581 		return r;
582 	}
583 	debug("SSH2_MSG_KEXINIT sent");
584 	kex->flags |= KEX_INIT_SENT;
585 	return 0;
586 }
587 
588 /* ARGSUSED */
589 int
590 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
591 {
592 	struct kex *kex = ssh->kex;
593 	const u_char *ptr;
594 	u_int i;
595 	size_t dlen;
596 	int r;
597 
598 	debug("SSH2_MSG_KEXINIT received");
599 	if (kex == NULL) {
600 		error_f("no kex");
601 		return SSH_ERR_INTERNAL_ERROR;
602 	}
603 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
604 	ptr = sshpkt_ptr(ssh, &dlen);
605 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
606 		return r;
607 
608 	/* discard packet */
609 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
610 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
611 			error_fr(r, "discard cookie");
612 			return r;
613 		}
614 	}
615 	for (i = 0; i < PROPOSAL_MAX; i++) {
616 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
617 			error_fr(r, "discard proposal");
618 			return r;
619 		}
620 	}
621 	/*
622 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
623 	 * KEX method has the server move first, but a server might be using
624 	 * a custom method or one that we otherwise don't support. We should
625 	 * be prepared to remember first_kex_follows here so we can eat a
626 	 * packet later.
627 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
628 	 * for cases where the server *doesn't* go first. I guess we should
629 	 * ignore it when it is set for these cases, which is what we do now.
630 	 */
631 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
632 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
633 	    (r = sshpkt_get_end(ssh)) != 0)
634 			return r;
635 
636 	if (!(kex->flags & KEX_INIT_SENT))
637 		if ((r = kex_send_kexinit(ssh)) != 0)
638 			return r;
639 	if ((r = kex_choose_conf(ssh)) != 0)
640 		return r;
641 
642 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
643 		return (kex->kex[kex->kex_type])(ssh);
644 
645 	error_f("unknown kex type %u", kex->kex_type);
646 	return SSH_ERR_INTERNAL_ERROR;
647 }
648 
649 struct kex *
650 kex_new(void)
651 {
652 	struct kex *kex;
653 
654 	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
655 	    (kex->peer = sshbuf_new()) == NULL ||
656 	    (kex->my = sshbuf_new()) == NULL ||
657 	    (kex->client_version = sshbuf_new()) == NULL ||
658 	    (kex->server_version = sshbuf_new()) == NULL ||
659 	    (kex->session_id = sshbuf_new()) == NULL) {
660 		kex_free(kex);
661 		return NULL;
662 	}
663 	return kex;
664 }
665 
666 void
667 kex_free_newkeys(struct newkeys *newkeys)
668 {
669 	if (newkeys == NULL)
670 		return;
671 	if (newkeys->enc.key) {
672 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
673 		free(newkeys->enc.key);
674 		newkeys->enc.key = NULL;
675 	}
676 	if (newkeys->enc.iv) {
677 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
678 		free(newkeys->enc.iv);
679 		newkeys->enc.iv = NULL;
680 	}
681 	free(newkeys->enc.name);
682 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
683 	free(newkeys->comp.name);
684 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
685 	mac_clear(&newkeys->mac);
686 	if (newkeys->mac.key) {
687 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
688 		free(newkeys->mac.key);
689 		newkeys->mac.key = NULL;
690 	}
691 	free(newkeys->mac.name);
692 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
693 	freezero(newkeys, sizeof(*newkeys));
694 }
695 
696 void
697 kex_free(struct kex *kex)
698 {
699 	u_int mode;
700 
701 	if (kex == NULL)
702 		return;
703 
704 #ifdef WITH_OPENSSL
705 	DH_free(kex->dh);
706 #ifdef OPENSSL_HAS_ECC
707 	EC_KEY_free(kex->ec_client_key);
708 #endif /* OPENSSL_HAS_ECC */
709 #endif /* WITH_OPENSSL */
710 	for (mode = 0; mode < MODE_MAX; mode++) {
711 		kex_free_newkeys(kex->newkeys[mode]);
712 		kex->newkeys[mode] = NULL;
713 	}
714 	sshbuf_free(kex->peer);
715 	sshbuf_free(kex->my);
716 	sshbuf_free(kex->client_version);
717 	sshbuf_free(kex->server_version);
718 	sshbuf_free(kex->client_pub);
719 	sshbuf_free(kex->session_id);
720 	sshbuf_free(kex->initial_sig);
721 	sshkey_free(kex->initial_hostkey);
722 	free(kex->failed_choice);
723 	free(kex->hostkey_alg);
724 	free(kex->name);
725 	free(kex);
726 }
727 
728 int
729 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
730 {
731 	int r;
732 
733 	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
734 		return r;
735 	ssh->kex->flags = KEX_INITIAL;
736 	kex_reset_dispatch(ssh);
737 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
738 	return 0;
739 }
740 
741 int
742 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
743 {
744 	int r;
745 
746 	if ((r = kex_ready(ssh, proposal)) != 0)
747 		return r;
748 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
749 		kex_free(ssh->kex);
750 		ssh->kex = NULL;
751 		return r;
752 	}
753 	return 0;
754 }
755 
756 /*
757  * Request key re-exchange, returns 0 on success or a ssherr.h error
758  * code otherwise. Must not be called if KEX is incomplete or in-progress.
759  */
760 int
761 kex_start_rekex(struct ssh *ssh)
762 {
763 	if (ssh->kex == NULL) {
764 		error_f("no kex");
765 		return SSH_ERR_INTERNAL_ERROR;
766 	}
767 	if (ssh->kex->done == 0) {
768 		error_f("requested twice");
769 		return SSH_ERR_INTERNAL_ERROR;
770 	}
771 	ssh->kex->done = 0;
772 	return kex_send_kexinit(ssh);
773 }
774 
775 static int
776 choose_enc(struct sshenc *enc, char *client, char *server)
777 {
778 	char *name = match_list(client, server, NULL);
779 
780 	if (name == NULL)
781 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
782 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
783 		error_f("unsupported cipher %s", name);
784 		free(name);
785 		return SSH_ERR_INTERNAL_ERROR;
786 	}
787 	enc->name = name;
788 	enc->enabled = 0;
789 	enc->iv = NULL;
790 	enc->iv_len = cipher_ivlen(enc->cipher);
791 	enc->key = NULL;
792 	enc->key_len = cipher_keylen(enc->cipher);
793 	enc->block_size = cipher_blocksize(enc->cipher);
794 	return 0;
795 }
796 
797 static int
798 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
799 {
800 	char *name = match_list(client, server, NULL);
801 
802 	if (name == NULL)
803 		return SSH_ERR_NO_MAC_ALG_MATCH;
804 	if (mac_setup(mac, name) < 0) {
805 		error_f("unsupported MAC %s", name);
806 		free(name);
807 		return SSH_ERR_INTERNAL_ERROR;
808 	}
809 	mac->name = name;
810 	mac->key = NULL;
811 	mac->enabled = 0;
812 	return 0;
813 }
814 
815 static int
816 choose_comp(struct sshcomp *comp, char *client, char *server)
817 {
818 	char *name = match_list(client, server, NULL);
819 
820 	if (name == NULL)
821 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
822 #ifdef WITH_ZLIB
823 	if (strcmp(name, "zlib@openssh.com") == 0) {
824 		comp->type = COMP_DELAYED;
825 	} else if (strcmp(name, "zlib") == 0) {
826 		comp->type = COMP_ZLIB;
827 	} else
828 #endif	/* WITH_ZLIB */
829 	if (strcmp(name, "none") == 0) {
830 		comp->type = COMP_NONE;
831 	} else {
832 		error_f("unsupported compression scheme %s", name);
833 		free(name);
834 		return SSH_ERR_INTERNAL_ERROR;
835 	}
836 	comp->name = name;
837 	return 0;
838 }
839 
840 static int
841 choose_kex(struct kex *k, char *client, char *server)
842 {
843 	const struct kexalg *kexalg;
844 
845 	k->name = match_list(client, server, NULL);
846 
847 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
848 	if (k->name == NULL)
849 		return SSH_ERR_NO_KEX_ALG_MATCH;
850 	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
851 		error_f("unsupported KEX method %s", k->name);
852 		return SSH_ERR_INTERNAL_ERROR;
853 	}
854 	k->kex_type = kexalg->type;
855 	k->hash_alg = kexalg->hash_alg;
856 	k->ec_nid = kexalg->ec_nid;
857 	return 0;
858 }
859 
860 static int
861 choose_hostkeyalg(struct kex *k, char *client, char *server)
862 {
863 	free(k->hostkey_alg);
864 	k->hostkey_alg = match_list(client, server, NULL);
865 
866 	debug("kex: host key algorithm: %s",
867 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
868 	if (k->hostkey_alg == NULL)
869 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
870 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
871 	if (k->hostkey_type == KEY_UNSPEC) {
872 		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
873 		return SSH_ERR_INTERNAL_ERROR;
874 	}
875 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
876 	return 0;
877 }
878 
879 static int
880 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
881 {
882 	static int check[] = {
883 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
884 	};
885 	int *idx;
886 	char *p;
887 
888 	for (idx = &check[0]; *idx != -1; idx++) {
889 		if ((p = strchr(my[*idx], ',')) != NULL)
890 			*p = '\0';
891 		if ((p = strchr(peer[*idx], ',')) != NULL)
892 			*p = '\0';
893 		if (strcmp(my[*idx], peer[*idx]) != 0) {
894 			debug2("proposal mismatch: my %s peer %s",
895 			    my[*idx], peer[*idx]);
896 			return (0);
897 		}
898 	}
899 	debug2("proposals match");
900 	return (1);
901 }
902 
/*
 * Returns non-zero if match_list() finds any algorithm from "algs" in
 * "proposal", zero otherwise.
 */
static int
has_any_alg(const char *proposal, const char *algs)
{
	char *common = match_list(proposal, algs, NULL);
	int found = (common != NULL);

	free(common);
	return found;
}
914 
/*
 * Negotiate this key exchange from our KEXINIT (kex->my) and the peer's
 * (kex->peer).
 *
 * Server, initial exchange only: record whether the client offered
 * "ext-info-c" (kex->ext_info_c) and the rsa-sha2-256/512 host key
 * algorithms (KEX_RSA_SHA2_*_SUPPORTED flags).
 *
 * Selection, in order:
 *   - the KEX method and the host key algorithm;
 *   - for each direction, into a freshly allocated kex->newkeys[mode]:
 *     cipher, MAC (skipped for authenticated-encryption ciphers, i.e.
 *     cipher_authlen() != 0) and compression.
 *
 * Key sizes: kex->we_need is set to the largest key/block/IV/MAC-key size
 * over both directions; kex->dh_need uses cipher_seclen() in place of the
 * key length.
 *
 * On a failed choice, the peer's offending name-list is moved into
 * kex->failed_choice (ownership transferred) and the error is returned.
 * If the peer set first_kex_follows but the preferred algorithms differ,
 * ssh->dispatch_skip_packets is set so its guessed packet is skipped.
 * Returns 0 on success or an ssherr.h code.
 */
static int
kex_choose_conf(struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	struct newkeys *newkeys;
	char **my = NULL, **peer = NULL;
	char **cprop, **sprop;
	int nenc, nmac, ncomp;
	u_int mode, ctos, need, dh_need, authlen;
	int r, first_kex_follows;

	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
		goto out;
	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
		goto out;

	/* Orient the proposals as client/server regardless of our role. */
	if (kex->server) {
		cprop=peer;
		sprop=my;
	} else {
		cprop=my;
		sprop=peer;
	}

	/* Check whether client supports ext_info_c */
	if (kex->server && (kex->flags & KEX_INITIAL)) {
		char *ext;

		ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
		kex->ext_info_c = (ext != NULL);
		free(ext);
	}

	/* Check whether client supports rsa-sha2 algorithms */
	if (kex->server && (kex->flags & KEX_INITIAL)) {
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
	}

	/*
	 * Algorithm Negotiation.
	 * On each failure the peer's list is moved to kex->failed_choice;
	 * NULLing the slot stops kex_prop_free() below from freeing it.
	 */
	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
		peer[PROPOSAL_KEX_ALGS] = NULL;
		goto out;
	}
	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
		goto out;
	}
	for (mode = 0; mode < MODE_MAX; mode++) {
		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto out;
		}
		kex->newkeys[mode] = newkeys;
		/* this direction carries client-to-server traffic? */
		ctos = (!kex->server && mode == MODE_OUT) ||
		    (kex->server && mode == MODE_IN);
		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
		    sprop[nenc])) != 0) {
			kex->failed_choice = peer[nenc];
			peer[nenc] = NULL;
			goto out;
		}
		authlen = cipher_authlen(newkeys->enc.cipher);
		/* ignore mac for authenticated encryption */
		if (authlen == 0 &&
		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
		    sprop[nmac])) != 0) {
			kex->failed_choice = peer[nmac];
			peer[nmac] = NULL;
			goto out;
		}
		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
		    sprop[ncomp])) != 0) {
			kex->failed_choice = peer[ncomp];
			peer[ncomp] = NULL;
			goto out;
		}
		debug("kex: %s cipher: %s MAC: %s compression: %s",
		    ctos ? "client->server" : "server->client",
		    newkeys->enc.name,
		    authlen == 0 ? newkeys->mac.name : "<implicit>",
		    newkeys->comp.name);
	}
	/* Size of each derived key (we_need) and the DH security need. */
	need = dh_need = 0;
	for (mode = 0; mode < MODE_MAX; mode++) {
		newkeys = kex->newkeys[mode];
		need = MAXIMUM(need, newkeys->enc.key_len);
		need = MAXIMUM(need, newkeys->enc.block_size);
		need = MAXIMUM(need, newkeys->enc.iv_len);
		need = MAXIMUM(need, newkeys->mac.key_len);
		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
	}
	/* XXX should "need" be rounded up? */
	kex->we_need = need;
	kex->dh_need = dh_need;

	/*
	 * ignore the next message if the proposals do not match
	 * (proposals_match() truncates my/peer entries; they are freed next)
	 */
	if (first_kex_follows && !proposals_match(my, peer))
		ssh->dispatch_skip_packets = 1;
	r = 0;
 out:
	kex_prop_free(my);
	kex_prop_free(peer);
	return r;
}
1036 
1037 static int
1038 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1039     const struct sshbuf *shared_secret, u_char **keyp)
1040 {
1041 	struct kex *kex = ssh->kex;
1042 	struct ssh_digest_ctx *hashctx = NULL;
1043 	char c = id;
1044 	u_int have;
1045 	size_t mdsz;
1046 	u_char *digest;
1047 	int r;
1048 
1049 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1050 		return SSH_ERR_INVALID_ARGUMENT;
1051 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1052 		r = SSH_ERR_ALLOC_FAIL;
1053 		goto out;
1054 	}
1055 
1056 	/* K1 = HASH(K || H || "A" || session_id) */
1057 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1058 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1059 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1060 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1061 	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1062 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1063 		r = SSH_ERR_LIBCRYPTO_ERROR;
1064 		error_f("KEX hash failed");
1065 		goto out;
1066 	}
1067 	ssh_digest_free(hashctx);
1068 	hashctx = NULL;
1069 
1070 	/*
1071 	 * expand key:
1072 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1073 	 * Key = K1 || K2 || ... || Kn
1074 	 */
1075 	for (have = mdsz; need > have; have += mdsz) {
1076 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1077 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1078 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1079 		    ssh_digest_update(hashctx, digest, have) != 0 ||
1080 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1081 			error_f("KDF failed");
1082 			r = SSH_ERR_LIBCRYPTO_ERROR;
1083 			goto out;
1084 		}
1085 		ssh_digest_free(hashctx);
1086 		hashctx = NULL;
1087 	}
1088 #ifdef DEBUG_KEX
1089 	fprintf(stderr, "key '%c'== ", c);
1090 	dump_digest("key", digest, need);
1091 #endif
1092 	*keyp = digest;
1093 	digest = NULL;
1094 	r = 0;
1095  out:
1096 	free(digest);
1097 	ssh_digest_free(hashctx);
1098 	return r;
1099 }
1100 
1101 #define NKEYS	6
1102 int
1103 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1104     const struct sshbuf *shared_secret)
1105 {
1106 	struct kex *kex = ssh->kex;
1107 	u_char *keys[NKEYS];
1108 	u_int i, j, mode, ctos;
1109 	int r;
1110 
1111 	/* save initial hash as session id */
1112 	if ((kex->flags & KEX_INITIAL) != 0) {
1113 		if (sshbuf_len(kex->session_id) != 0) {
1114 			error_f("already have session ID at kex");
1115 			return SSH_ERR_INTERNAL_ERROR;
1116 		}
1117 		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1118 			return r;
1119 	} else if (sshbuf_len(kex->session_id) == 0) {
1120 		error_f("no session ID in rekex");
1121 		return SSH_ERR_INTERNAL_ERROR;
1122 	}
1123 	for (i = 0; i < NKEYS; i++) {
1124 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1125 		    shared_secret, &keys[i])) != 0) {
1126 			for (j = 0; j < i; j++)
1127 				free(keys[j]);
1128 			return r;
1129 		}
1130 	}
1131 	for (mode = 0; mode < MODE_MAX; mode++) {
1132 		ctos = (!kex->server && mode == MODE_OUT) ||
1133 		    (kex->server && mode == MODE_IN);
1134 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1135 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1136 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1137 	}
1138 	return 0;
1139 }
1140 
1141 int
1142 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1143 {
1144 	struct kex *kex = ssh->kex;
1145 
1146 	*pubp = NULL;
1147 	*prvp = NULL;
1148 	if (kex->load_host_public_key == NULL ||
1149 	    kex->load_host_private_key == NULL) {
1150 		error_f("missing hostkey loader");
1151 		return SSH_ERR_INVALID_ARGUMENT;
1152 	}
1153 	*pubp = kex->load_host_public_key(kex->hostkey_type,
1154 	    kex->hostkey_nid, ssh);
1155 	*prvp = kex->load_host_private_key(kex->hostkey_type,
1156 	    kex->hostkey_nid, ssh);
1157 	if (*pubp == NULL)
1158 		return SSH_ERR_NO_HOSTKEY_LOADED;
1159 	return 0;
1160 }
1161 
1162 int
1163 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1164 {
1165 	struct kex *kex = ssh->kex;
1166 
1167 	if (kex->verify_host_key == NULL) {
1168 		error_f("missing hostkey verifier");
1169 		return SSH_ERR_INVALID_ARGUMENT;
1170 	}
1171 	if (server_host_key->type != kex->hostkey_type ||
1172 	    (kex->hostkey_type == KEY_ECDSA &&
1173 	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1174 		return SSH_ERR_KEY_TYPE_MISMATCH;
1175 	if (kex->verify_host_key(server_host_key, ssh) == -1)
1176 		return  SSH_ERR_SIGNATURE_INVALID;
1177 	return 0;
1178 }
1179 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/*
 * Debug-only helper: print the label msg on its own line to stderr, then
 * dump len bytes of digest using sshbuf_dump_data().
 */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	FILE *out = stderr;

	fprintf(out, "%s\n", msg);
	sshbuf_dump_data(digest, len, out);
}
#endif
1188 
1189 /*
1190  * Send a plaintext error message to the peer, suffixed by \r\n.
1191  * Only used during banner exchange, and there only for the server.
1192  */
1193 static void
1194 send_error(struct ssh *ssh, char *msg)
1195 {
1196 	char *crnl = "\r\n";
1197 
1198 	if (!ssh->kex->server)
1199 		return;
1200 
1201 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1202 	    msg, strlen(msg)) != strlen(msg) ||
1203 	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1204 	    crnl, strlen(crnl)) != strlen(crnl))
1205 		error_f("write: %.100s", strerror(errno));
1206 }
1207 
1208 /*
1209  * Sends our identification string and waits for the peer's. Will block for
1210  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1211  * Returns on 0 success or a ssherr.h code on failure.
1212  */
1213 int
1214 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1215     const char *version_addendum)
1216 {
1217 	int remote_major, remote_minor, mismatch, oerrno = 0;
1218 	size_t len, i, n;
1219 	int r, expect_nl;
1220 	u_char c;
1221 	struct sshbuf *our_version = ssh->kex->server ?
1222 	    ssh->kex->server_version : ssh->kex->client_version;
1223 	struct sshbuf *peer_version = ssh->kex->server ?
1224 	    ssh->kex->client_version : ssh->kex->server_version;
1225 	char *our_version_string = NULL, *peer_version_string = NULL;
1226 	char *cp, *remote_version = NULL;
1227 
1228 	/* Prepare and send our banner */
1229 	sshbuf_reset(our_version);
1230 	if (version_addendum != NULL && *version_addendum == '\0')
1231 		version_addendum = NULL;
1232 	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%.100s%s%s\r\n",
1233 	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1234 	    version_addendum == NULL ? "" : " ",
1235 	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1236 		oerrno = errno;
1237 		error_fr(r, "sshbuf_putf");
1238 		goto out;
1239 	}
1240 
1241 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1242 	    sshbuf_mutable_ptr(our_version),
1243 	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1244 		oerrno = errno;
1245 		debug_f("write: %.100s", strerror(errno));
1246 		r = SSH_ERR_SYSTEM_ERROR;
1247 		goto out;
1248 	}
1249 	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1250 		oerrno = errno;
1251 		error_fr(r, "sshbuf_consume_end");
1252 		goto out;
1253 	}
1254 	our_version_string = sshbuf_dup_string(our_version);
1255 	if (our_version_string == NULL) {
1256 		error_f("sshbuf_dup_string failed");
1257 		r = SSH_ERR_ALLOC_FAIL;
1258 		goto out;
1259 	}
1260 	debug("Local version string %.100s", our_version_string);
1261 
1262 	/* Read other side's version identification. */
1263 	for (n = 0; ; n++) {
1264 		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1265 			send_error(ssh, "No SSH identification string "
1266 			    "received.");
1267 			error_f("No SSH version received in first %u lines "
1268 			    "from server", SSH_MAX_PRE_BANNER_LINES);
1269 			r = SSH_ERR_INVALID_FORMAT;
1270 			goto out;
1271 		}
1272 		sshbuf_reset(peer_version);
1273 		expect_nl = 0;
1274 		for (i = 0; ; i++) {
1275 			if (timeout_ms > 0) {
1276 				r = waitrfd(ssh_packet_get_connection_in(ssh),
1277 				    &timeout_ms);
1278 				if (r == -1 && errno == ETIMEDOUT) {
1279 					send_error(ssh, "Timed out waiting "
1280 					    "for SSH identification string.");
1281 					error("Connection timed out during "
1282 					    "banner exchange");
1283 					r = SSH_ERR_CONN_TIMEOUT;
1284 					goto out;
1285 				} else if (r == -1) {
1286 					oerrno = errno;
1287 					error_f("%s", strerror(errno));
1288 					r = SSH_ERR_SYSTEM_ERROR;
1289 					goto out;
1290 				}
1291 			}
1292 
1293 			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1294 			    &c, 1);
1295 			if (len != 1 && errno == EPIPE) {
1296 				error_f("Connection closed by remote host");
1297 				r = SSH_ERR_CONN_CLOSED;
1298 				goto out;
1299 			} else if (len != 1) {
1300 				oerrno = errno;
1301 				error_f("read: %.100s", strerror(errno));
1302 				r = SSH_ERR_SYSTEM_ERROR;
1303 				goto out;
1304 			}
1305 			if (c == '\r') {
1306 				expect_nl = 1;
1307 				continue;
1308 			}
1309 			if (c == '\n')
1310 				break;
1311 			if (c == '\0' || expect_nl) {
1312 				error_f("banner line contains invalid "
1313 				    "characters");
1314 				goto invalid;
1315 			}
1316 			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1317 				oerrno = errno;
1318 				error_fr(r, "sshbuf_put");
1319 				goto out;
1320 			}
1321 			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1322 				error_f("banner line too long");
1323 				goto invalid;
1324 			}
1325 		}
1326 		/* Is this an actual protocol banner? */
1327 		if (sshbuf_len(peer_version) > 4 &&
1328 		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1329 			break;
1330 		/* If not, then just log the line and continue */
1331 		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1332 			error_f("sshbuf_dup_string failed");
1333 			r = SSH_ERR_ALLOC_FAIL;
1334 			goto out;
1335 		}
1336 		/* Do not accept lines before the SSH ident from a client */
1337 		if (ssh->kex->server) {
1338 			error_f("client sent invalid protocol identifier "
1339 			    "\"%.256s\"", cp);
1340 			free(cp);
1341 			goto invalid;
1342 		}
1343 		debug_f("banner line %zu: %s", n, cp);
1344 		free(cp);
1345 	}
1346 	peer_version_string = sshbuf_dup_string(peer_version);
1347 	if (peer_version_string == NULL)
1348 		error_f("sshbuf_dup_string failed");
1349 	/* XXX must be same size for sscanf */
1350 	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1351 		error_f("calloc failed");
1352 		r = SSH_ERR_ALLOC_FAIL;
1353 		goto out;
1354 	}
1355 
1356 	/*
1357 	 * Check that the versions match.  In future this might accept
1358 	 * several versions and set appropriate flags to handle them.
1359 	 */
1360 	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1361 	    &remote_major, &remote_minor, remote_version) != 3) {
1362 		error("Bad remote protocol version identification: '%.100s'",
1363 		    peer_version_string);
1364  invalid:
1365 		send_error(ssh, "Invalid SSH identification string.");
1366 		r = SSH_ERR_INVALID_FORMAT;
1367 		goto out;
1368 	}
1369 	debug("Remote protocol version %d.%d, remote software version %.100s",
1370 	    remote_major, remote_minor, remote_version);
1371 	compat_banner(ssh, remote_version);
1372 
1373 	mismatch = 0;
1374 	switch (remote_major) {
1375 	case 2:
1376 		break;
1377 	case 1:
1378 		if (remote_minor != 99)
1379 			mismatch = 1;
1380 		break;
1381 	default:
1382 		mismatch = 1;
1383 		break;
1384 	}
1385 	if (mismatch) {
1386 		error("Protocol major versions differ: %d vs. %d",
1387 		    PROTOCOL_MAJOR_2, remote_major);
1388 		send_error(ssh, "Protocol major versions differ.");
1389 		r = SSH_ERR_NO_PROTOCOL_VERSION;
1390 		goto out;
1391 	}
1392 
1393 	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1394 		logit("probed from %s port %d with %s.  Don't panic.",
1395 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1396 		    peer_version_string);
1397 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1398 		goto out;
1399 	}
1400 	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1401 		logit("scanned from %s port %d with %s.  Don't panic.",
1402 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1403 		    peer_version_string);
1404 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1405 		goto out;
1406 	}
1407 	if ((ssh->compat & SSH_BUG_RSASIGMD5) != 0) {
1408 		logit("Remote version \"%.100s\" uses unsafe RSA signature "
1409 		    "scheme; disabling use of RSA keys", remote_version);
1410 	}
1411 	/* success */
1412 	r = 0;
1413  out:
1414 	free(our_version_string);
1415 	free(peer_version_string);
1416 	free(remote_version);
1417 	if (r == SSH_ERR_SYSTEM_ERROR)
1418 		errno = oerrno;
1419 	return r;
1420 }
1421 
1422