xref: /freebsd/crypto/openssh/kex.c (revision 5e3190f700637fcfc1a52daeaa4a031fdd2557c7)
1 /* $OpenBSD: kex.c,v 1.181 2023/08/28 03:28:43 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include "includes.h"
27 
28 #include <sys/types.h>
29 #include <errno.h>
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39 
40 #ifdef WITH_OPENSSL
41 #include <openssl/crypto.h>
42 #include <openssl/dh.h>
43 #endif
44 
45 #include "ssh.h"
46 #include "ssh2.h"
47 #include "atomicio.h"
48 #include "version.h"
49 #include "packet.h"
50 #include "compat.h"
51 #include "cipher.h"
52 #include "sshkey.h"
53 #include "kex.h"
54 #include "log.h"
55 #include "mac.h"
56 #include "match.h"
57 #include "misc.h"
58 #include "dispatch.h"
59 #include "monitor.h"
60 #include "myproposal.h"
61 
62 #include "ssherr.h"
63 #include "sshbuf.h"
64 #include "digest.h"
65 #include "xmalloc.h"
66 
67 /* prototype */
68 static int kex_choose_conf(struct ssh *);
69 static int kex_input_newkeys(int, u_int32_t, struct ssh *);
70 
71 static const char * const proposal_names[PROPOSAL_MAX] = {
72 	"KEX algorithms",
73 	"host key algorithms",
74 	"ciphers ctos",
75 	"ciphers stoc",
76 	"MACs ctos",
77 	"MACs stoc",
78 	"compression ctos",
79 	"compression stoc",
80 	"languages ctos",
81 	"languages stoc",
82 };
83 
84 struct kexalg {
85 	char *name;
86 	u_int type;
87 	int ec_nid;
88 	int hash_alg;
89 };
90 static const struct kexalg kexalgs[] = {
91 #ifdef WITH_OPENSSL
92 	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
93 	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
94 	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
95 	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
96 	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
97 	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
98 #ifdef HAVE_EVP_SHA256
99 	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
100 #endif /* HAVE_EVP_SHA256 */
101 #ifdef OPENSSL_HAS_ECC
102 	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
103 	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
104 	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
105 	    SSH_DIGEST_SHA384 },
106 # ifdef OPENSSL_HAS_NISTP521
107 	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
108 	    SSH_DIGEST_SHA512 },
109 # endif /* OPENSSL_HAS_NISTP521 */
110 #endif /* OPENSSL_HAS_ECC */
111 #endif /* WITH_OPENSSL */
112 #if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
113 	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
114 	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
115 #ifdef USE_SNTRUP761X25519
116 	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
117 	    SSH_DIGEST_SHA512 },
118 #endif
119 #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
120 	{ NULL, 0, -1, -1},
121 };
122 
123 char *
124 kex_alg_list(char sep)
125 {
126 	char *ret = NULL, *tmp;
127 	size_t nlen, rlen = 0;
128 	const struct kexalg *k;
129 
130 	for (k = kexalgs; k->name != NULL; k++) {
131 		if (ret != NULL)
132 			ret[rlen++] = sep;
133 		nlen = strlen(k->name);
134 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
135 			free(ret);
136 			return NULL;
137 		}
138 		ret = tmp;
139 		memcpy(ret + rlen, k->name, nlen + 1);
140 		rlen += nlen;
141 	}
142 	return ret;
143 }
144 
145 static const struct kexalg *
146 kex_alg_by_name(const char *name)
147 {
148 	const struct kexalg *k;
149 
150 	for (k = kexalgs; k->name != NULL; k++) {
151 		if (strcmp(k->name, name) == 0)
152 			return k;
153 	}
154 	return NULL;
155 }
156 
/*
 * Check that each name in the comma-separated KEX method list 'names' is
 * supported.  Scanning stops at the first empty element.  Returns 1 when
 * valid, 0 when 'names' is NULL/empty, a name is unknown, or a copy of
 * the list cannot be allocated.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *rest, *alg;
	int ok = 1;

	if (names == NULL || *names == '\0')
		return 0;
	if ((copy = strdup(names)) == NULL)
		return 0;
	rest = copy;
	while ((alg = strsep(&rest, ",")) != NULL && *alg != '\0') {
		if (kex_alg_by_name(alg) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", alg);
			ok = 0;
			break;
		}
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
179 
/*
 * Concatenate two comma-separated algorithm lists: each name from 'b' is
 * appended to 'a' unless match_list() reports it as already present.  'b'
 * is scanned only up to its first empty element.  If either list is NULL
 * or empty, a copy of the other is returned (an empty string when both
 * are).  Returns NULL on allocation failure or when 'b' exceeds 1MB.
 * Caller must free the returned string.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p, *m;
	size_t len;

	if (a == NULL || *a == '\0') {
		/* strdup(NULL) is undefined: treat a missing 'b' as empty */
		return strdup(b != NULL ? b : "");
	}
	if (b == NULL || *b == '\0')
		return strdup(a);
	/* Refuse absurdly long input rather than allocate for it. */
	if (strlen(b) > 1024*1024)
		return NULL;
	/* Room for 'a', each name of 'b' preceded by a comma, and the NUL. */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		if ((m = match_list(ret, p, NULL)) != NULL) {
			free(m);
			continue; /* Algorithm already present */
		}
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
218 
219 /*
220  * Assemble a list of algorithms from a default list and a string from a
221  * configuration file. The user-provided string may begin with '+' to
222  * indicate that it should be appended to the default, '-' that the
223  * specified names should be removed, or '^' that they should be placed
224  * at the head.
225  */
226 int
227 kex_assemble_names(char **listp, const char *def, const char *all)
228 {
229 	char *cp, *tmp, *patterns;
230 	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
231 	int r = SSH_ERR_INTERNAL_ERROR;
232 
233 	if (listp == NULL || def == NULL || all == NULL)
234 		return SSH_ERR_INVALID_ARGUMENT;
235 
236 	if (*listp == NULL || **listp == '\0') {
237 		if ((*listp = strdup(def)) == NULL)
238 			return SSH_ERR_ALLOC_FAIL;
239 		return 0;
240 	}
241 
242 	list = *listp;
243 	*listp = NULL;
244 	if (*list == '+') {
245 		/* Append names to default list */
246 		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
247 			r = SSH_ERR_ALLOC_FAIL;
248 			goto fail;
249 		}
250 		free(list);
251 		list = tmp;
252 	} else if (*list == '-') {
253 		/* Remove names from default list */
254 		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
255 			r = SSH_ERR_ALLOC_FAIL;
256 			goto fail;
257 		}
258 		free(list);
259 		/* filtering has already been done */
260 		return 0;
261 	} else if (*list == '^') {
262 		/* Place names at head of default list */
263 		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
264 			r = SSH_ERR_ALLOC_FAIL;
265 			goto fail;
266 		}
267 		free(list);
268 		list = tmp;
269 	} else {
270 		/* Explicit list, overrides default - just use "list" as is */
271 	}
272 
273 	/*
274 	 * The supplied names may be a pattern-list. For the -list case,
275 	 * the patterns are applied above. For the +list and explicit list
276 	 * cases we need to do it now.
277 	 */
278 	ret = NULL;
279 	if ((patterns = opatterns = strdup(list)) == NULL) {
280 		r = SSH_ERR_ALLOC_FAIL;
281 		goto fail;
282 	}
283 	/* Apply positive (i.e. non-negated) patterns from the list */
284 	while ((cp = strsep(&patterns, ",")) != NULL) {
285 		if (*cp == '!') {
286 			/* negated matches are not supported here */
287 			r = SSH_ERR_INVALID_ARGUMENT;
288 			goto fail;
289 		}
290 		free(matching);
291 		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
292 			r = SSH_ERR_ALLOC_FAIL;
293 			goto fail;
294 		}
295 		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
296 			r = SSH_ERR_ALLOC_FAIL;
297 			goto fail;
298 		}
299 		free(ret);
300 		ret = tmp;
301 	}
302 	if (ret == NULL || *ret == '\0') {
303 		/* An empty name-list is an error */
304 		/* XXX better error code? */
305 		r = SSH_ERR_INVALID_ARGUMENT;
306 		goto fail;
307 	}
308 
309 	/* success */
310 	*listp = ret;
311 	ret = NULL;
312 	r = 0;
313 
314  fail:
315 	free(matching);
316 	free(opatterns);
317 	free(list);
318 	free(ret);
319 	return r;
320 }
321 
322 /*
323  * Fill out a proposal array with dynamically allocated values, which may
324  * be modified as required for compatibility reasons.
325  * Any of the options may be NULL, in which case the default is used.
326  * Array contents must be freed by calling kex_proposal_free_entries.
327  */
328 void
329 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
330     const char *kexalgos, const char *ciphers, const char *macs,
331     const char *comp, const char *hkalgs)
332 {
333 	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
334 	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
335 	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
336 	u_int i;
337 
338 	if (prop == NULL)
339 		fatal_f("proposal missing");
340 
341 	for (i = 0; i < PROPOSAL_MAX; i++) {
342 		switch(i) {
343 		case PROPOSAL_KEX_ALGS:
344 			prop[i] = compat_kex_proposal(ssh,
345 			    kexalgos ? kexalgos : defprop[i]);
346 			break;
347 		case PROPOSAL_ENC_ALGS_CTOS:
348 		case PROPOSAL_ENC_ALGS_STOC:
349 			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
350 			break;
351 		case PROPOSAL_MAC_ALGS_CTOS:
352 		case PROPOSAL_MAC_ALGS_STOC:
353 			prop[i]  = xstrdup(macs ? macs : defprop[i]);
354 			break;
355 		case PROPOSAL_COMP_ALGS_CTOS:
356 		case PROPOSAL_COMP_ALGS_STOC:
357 			prop[i] = xstrdup(comp ? comp : defprop[i]);
358 			break;
359 		case PROPOSAL_SERVER_HOST_KEY_ALGS:
360 			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
361 			break;
362 		default:
363 			prop[i] = xstrdup(defprop[i]);
364 		}
365 	}
366 }
367 
368 void
369 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
370 {
371 	u_int i;
372 
373 	for (i = 0; i < PROPOSAL_MAX; i++)
374 		free(prop[i]);
375 }
376 
377 /* put algorithm proposal into buffer */
378 int
379 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
380 {
381 	u_int i;
382 	int r;
383 
384 	sshbuf_reset(b);
385 
386 	/*
387 	 * add a dummy cookie, the cookie will be overwritten by
388 	 * kex_send_kexinit(), each time a kexinit is set
389 	 */
390 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
391 		if ((r = sshbuf_put_u8(b, 0)) != 0)
392 			return r;
393 	}
394 	for (i = 0; i < PROPOSAL_MAX; i++) {
395 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
396 			return r;
397 	}
398 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
399 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
400 		return r;
401 	return 0;
402 }
403 
404 /* parse buffer and return algorithm proposal */
405 int
406 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
407 {
408 	struct sshbuf *b = NULL;
409 	u_char v;
410 	u_int i;
411 	char **proposal = NULL;
412 	int r;
413 
414 	*propp = NULL;
415 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
416 		return SSH_ERR_ALLOC_FAIL;
417 	if ((b = sshbuf_fromb(raw)) == NULL) {
418 		r = SSH_ERR_ALLOC_FAIL;
419 		goto out;
420 	}
421 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
422 		error_fr(r, "consume cookie");
423 		goto out;
424 	}
425 	/* extract kex init proposal strings */
426 	for (i = 0; i < PROPOSAL_MAX; i++) {
427 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
428 			error_fr(r, "parse proposal %u", i);
429 			goto out;
430 		}
431 		debug2("%s: %s", proposal_names[i], proposal[i]);
432 	}
433 	/* first kex follows / reserved */
434 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
435 	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
436 		error_fr(r, "parse");
437 		goto out;
438 	}
439 	if (first_kex_follows != NULL)
440 		*first_kex_follows = v;
441 	debug2("first_kex_follows %d ", v);
442 	debug2("reserved %u ", i);
443 	r = 0;
444 	*propp = proposal;
445  out:
446 	if (r != 0 && proposal != NULL)
447 		kex_prop_free(proposal);
448 	sshbuf_free(b);
449 	return r;
450 }
451 
452 void
453 kex_prop_free(char **proposal)
454 {
455 	u_int i;
456 
457 	if (proposal == NULL)
458 		return;
459 	for (i = 0; i < PROPOSAL_MAX; i++)
460 		free(proposal[i]);
461 	free(proposal);
462 }
463 
464 int
465 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
466 {
467 	int r;
468 
469 	error("kex protocol error: type %d seq %u", type, seq);
470 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
471 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
472 	    (r = sshpkt_send(ssh)) != 0)
473 		return r;
474 	return 0;
475 }
476 
477 static void
478 kex_reset_dispatch(struct ssh *ssh)
479 {
480 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
481 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
482 }
483 
484 static int
485 kex_send_ext_info(struct ssh *ssh)
486 {
487 	int r;
488 	char *algs;
489 
490 	debug("Sending SSH2_MSG_EXT_INFO");
491 	if ((algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
492 		return SSH_ERR_ALLOC_FAIL;
493 	/* XXX filter algs list by allowed pubkey/hostbased types */
494 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
495 	    (r = sshpkt_put_u32(ssh, 3)) != 0 ||
496 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
497 	    (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
498 	    (r = sshpkt_put_cstring(ssh,
499 	    "publickey-hostbound@openssh.com")) != 0 ||
500 	    (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
501 	    (r = sshpkt_put_cstring(ssh, "ping@openssh.com")) != 0 ||
502 	    (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
503 	    (r = sshpkt_send(ssh)) != 0) {
504 		error_fr(r, "compose");
505 		goto out;
506 	}
507 	/* success */
508 	r = 0;
509  out:
510 	free(algs);
511 	return r;
512 }
513 
514 int
515 kex_send_newkeys(struct ssh *ssh)
516 {
517 	int r;
518 
519 	kex_reset_dispatch(ssh);
520 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
521 	    (r = sshpkt_send(ssh)) != 0)
522 		return r;
523 	debug("SSH2_MSG_NEWKEYS sent");
524 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
525 	if (ssh->kex->ext_info_c && (ssh->kex->flags & KEX_INITIAL) != 0)
526 		if ((r = kex_send_ext_info(ssh)) != 0)
527 			return r;
528 	debug("expecting SSH2_MSG_NEWKEYS");
529 	return 0;
530 }
531 
532 /* Check whether an ext_info value contains the expected version string */
533 static int
534 kex_ext_info_check_ver(struct kex *kex, const char *name,
535     const u_char *val, size_t len, const char *want_ver, u_int flag)
536 {
537 	if (memchr(val, '\0', len) != NULL) {
538 		error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
539 		return SSH_ERR_INVALID_FORMAT;
540 	}
541 	debug_f("%s=<%s>", name, val);
542 	if (strcmp(val, want_ver) == 0)
543 		kex->flags |= flag;
544 	else
545 		debug_f("unsupported version of %s extension", name);
546 	return 0;
547 }
548 
549 int
550 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
551 {
552 	struct kex *kex = ssh->kex;
553 	u_int32_t i, ninfo;
554 	char *name;
555 	u_char *val;
556 	size_t vlen;
557 	int r;
558 
559 	debug("SSH2_MSG_EXT_INFO received");
560 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
561 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
562 		return r;
563 	if (ninfo >= 1024) {
564 		error("SSH2_MSG_EXT_INFO with too many entries, expected "
565 		    "<=1024, received %u", ninfo);
566 		return SSH_ERR_INVALID_FORMAT;
567 	}
568 	for (i = 0; i < ninfo; i++) {
569 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
570 			return r;
571 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
572 			free(name);
573 			return r;
574 		}
575 		if (strcmp(name, "server-sig-algs") == 0) {
576 			/* Ensure no \0 lurking in value */
577 			if (memchr(val, '\0', vlen) != NULL) {
578 				error_f("nul byte in %s", name);
579 				free(name);
580 				free(val);
581 				return SSH_ERR_INVALID_FORMAT;
582 			}
583 			debug_f("%s=<%s>", name, val);
584 			kex->server_sig_algs = val;
585 			val = NULL;
586 		} else if (strcmp(name,
587 		    "publickey-hostbound@openssh.com") == 0) {
588 			if ((r = kex_ext_info_check_ver(kex, name, val, vlen,
589 			    "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
590 				free(name);
591 				free(val);
592 				return r;
593 			}
594 		} else if (strcmp(name, "ping@openssh.com") == 0) {
595 			if ((r = kex_ext_info_check_ver(kex, name, val, vlen,
596 			    "0", KEX_HAS_PING)) != 0) {
597 				free(name);
598 				free(val);
599 				return r;
600 			}
601 		} else
602 			debug_f("%s (unrecognised)", name);
603 		free(name);
604 		free(val);
605 	}
606 	return sshpkt_get_end(ssh);
607 }
608 
609 static int
610 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
611 {
612 	struct kex *kex = ssh->kex;
613 	int r;
614 
615 	debug("SSH2_MSG_NEWKEYS received");
616 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
617 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
618 	if ((r = sshpkt_get_end(ssh)) != 0)
619 		return r;
620 	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
621 		return r;
622 	kex->done = 1;
623 	kex->flags &= ~KEX_INITIAL;
624 	sshbuf_reset(kex->peer);
625 	/* sshbuf_reset(kex->my); */
626 	kex->flags &= ~KEX_INIT_SENT;
627 	free(kex->name);
628 	kex->name = NULL;
629 	return 0;
630 }
631 
632 int
633 kex_send_kexinit(struct ssh *ssh)
634 {
635 	u_char *cookie;
636 	struct kex *kex = ssh->kex;
637 	int r;
638 
639 	if (kex == NULL) {
640 		error_f("no kex");
641 		return SSH_ERR_INTERNAL_ERROR;
642 	}
643 	if (kex->flags & KEX_INIT_SENT)
644 		return 0;
645 	kex->done = 0;
646 
647 	/* generate a random cookie */
648 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
649 		error_f("bad kex length: %zu < %d",
650 		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
651 		return SSH_ERR_INVALID_FORMAT;
652 	}
653 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
654 		error_f("buffer error");
655 		return SSH_ERR_INTERNAL_ERROR;
656 	}
657 	arc4random_buf(cookie, KEX_COOKIE_LEN);
658 
659 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
660 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
661 	    (r = sshpkt_send(ssh)) != 0) {
662 		error_fr(r, "compose reply");
663 		return r;
664 	}
665 	debug("SSH2_MSG_KEXINIT sent");
666 	kex->flags |= KEX_INIT_SENT;
667 	return 0;
668 }
669 
670 int
671 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
672 {
673 	struct kex *kex = ssh->kex;
674 	const u_char *ptr;
675 	u_int i;
676 	size_t dlen;
677 	int r;
678 
679 	debug("SSH2_MSG_KEXINIT received");
680 	if (kex == NULL) {
681 		error_f("no kex");
682 		return SSH_ERR_INTERNAL_ERROR;
683 	}
684 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
685 	ptr = sshpkt_ptr(ssh, &dlen);
686 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
687 		return r;
688 
689 	/* discard packet */
690 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
691 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
692 			error_fr(r, "discard cookie");
693 			return r;
694 		}
695 	}
696 	for (i = 0; i < PROPOSAL_MAX; i++) {
697 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
698 			error_fr(r, "discard proposal");
699 			return r;
700 		}
701 	}
702 	/*
703 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
704 	 * KEX method has the server move first, but a server might be using
705 	 * a custom method or one that we otherwise don't support. We should
706 	 * be prepared to remember first_kex_follows here so we can eat a
707 	 * packet later.
708 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
709 	 * for cases where the server *doesn't* go first. I guess we should
710 	 * ignore it when it is set for these cases, which is what we do now.
711 	 */
712 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
713 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
714 	    (r = sshpkt_get_end(ssh)) != 0)
715 			return r;
716 
717 	if (!(kex->flags & KEX_INIT_SENT))
718 		if ((r = kex_send_kexinit(ssh)) != 0)
719 			return r;
720 	if ((r = kex_choose_conf(ssh)) != 0)
721 		return r;
722 
723 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
724 		return (kex->kex[kex->kex_type])(ssh);
725 
726 	error_f("unknown kex type %u", kex->kex_type);
727 	return SSH_ERR_INTERNAL_ERROR;
728 }
729 
730 struct kex *
731 kex_new(void)
732 {
733 	struct kex *kex;
734 
735 	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
736 	    (kex->peer = sshbuf_new()) == NULL ||
737 	    (kex->my = sshbuf_new()) == NULL ||
738 	    (kex->client_version = sshbuf_new()) == NULL ||
739 	    (kex->server_version = sshbuf_new()) == NULL ||
740 	    (kex->session_id = sshbuf_new()) == NULL) {
741 		kex_free(kex);
742 		return NULL;
743 	}
744 	return kex;
745 }
746 
747 void
748 kex_free_newkeys(struct newkeys *newkeys)
749 {
750 	if (newkeys == NULL)
751 		return;
752 	if (newkeys->enc.key) {
753 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
754 		free(newkeys->enc.key);
755 		newkeys->enc.key = NULL;
756 	}
757 	if (newkeys->enc.iv) {
758 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
759 		free(newkeys->enc.iv);
760 		newkeys->enc.iv = NULL;
761 	}
762 	free(newkeys->enc.name);
763 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
764 	free(newkeys->comp.name);
765 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
766 	mac_clear(&newkeys->mac);
767 	if (newkeys->mac.key) {
768 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
769 		free(newkeys->mac.key);
770 		newkeys->mac.key = NULL;
771 	}
772 	free(newkeys->mac.name);
773 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
774 	freezero(newkeys, sizeof(*newkeys));
775 }
776 
777 void
778 kex_free(struct kex *kex)
779 {
780 	u_int mode;
781 
782 	if (kex == NULL)
783 		return;
784 
785 #ifdef WITH_OPENSSL
786 	DH_free(kex->dh);
787 #ifdef OPENSSL_HAS_ECC
788 	EC_KEY_free(kex->ec_client_key);
789 #endif /* OPENSSL_HAS_ECC */
790 #endif /* WITH_OPENSSL */
791 	for (mode = 0; mode < MODE_MAX; mode++) {
792 		kex_free_newkeys(kex->newkeys[mode]);
793 		kex->newkeys[mode] = NULL;
794 	}
795 	sshbuf_free(kex->peer);
796 	sshbuf_free(kex->my);
797 	sshbuf_free(kex->client_version);
798 	sshbuf_free(kex->server_version);
799 	sshbuf_free(kex->client_pub);
800 	sshbuf_free(kex->session_id);
801 	sshbuf_free(kex->initial_sig);
802 	sshkey_free(kex->initial_hostkey);
803 	free(kex->failed_choice);
804 	free(kex->hostkey_alg);
805 	free(kex->name);
806 	free(kex);
807 }
808 
809 int
810 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
811 {
812 	int r;
813 
814 	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
815 		return r;
816 	ssh->kex->flags = KEX_INITIAL;
817 	kex_reset_dispatch(ssh);
818 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
819 	return 0;
820 }
821 
822 int
823 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
824 {
825 	int r;
826 
827 	if ((r = kex_ready(ssh, proposal)) != 0)
828 		return r;
829 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
830 		kex_free(ssh->kex);
831 		ssh->kex = NULL;
832 		return r;
833 	}
834 	return 0;
835 }
836 
837 /*
838  * Request key re-exchange, returns 0 on success or a ssherr.h error
839  * code otherwise. Must not be called if KEX is incomplete or in-progress.
840  */
841 int
842 kex_start_rekex(struct ssh *ssh)
843 {
844 	if (ssh->kex == NULL) {
845 		error_f("no kex");
846 		return SSH_ERR_INTERNAL_ERROR;
847 	}
848 	if (ssh->kex->done == 0) {
849 		error_f("requested twice");
850 		return SSH_ERR_INTERNAL_ERROR;
851 	}
852 	ssh->kex->done = 0;
853 	return kex_send_kexinit(ssh);
854 }
855 
856 static int
857 choose_enc(struct sshenc *enc, char *client, char *server)
858 {
859 	char *name = match_list(client, server, NULL);
860 
861 	if (name == NULL)
862 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
863 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
864 		error_f("unsupported cipher %s", name);
865 		free(name);
866 		return SSH_ERR_INTERNAL_ERROR;
867 	}
868 	enc->name = name;
869 	enc->enabled = 0;
870 	enc->iv = NULL;
871 	enc->iv_len = cipher_ivlen(enc->cipher);
872 	enc->key = NULL;
873 	enc->key_len = cipher_keylen(enc->cipher);
874 	enc->block_size = cipher_blocksize(enc->cipher);
875 	return 0;
876 }
877 
878 static int
879 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
880 {
881 	char *name = match_list(client, server, NULL);
882 
883 	if (name == NULL)
884 		return SSH_ERR_NO_MAC_ALG_MATCH;
885 	if (mac_setup(mac, name) < 0) {
886 		error_f("unsupported MAC %s", name);
887 		free(name);
888 		return SSH_ERR_INTERNAL_ERROR;
889 	}
890 	mac->name = name;
891 	mac->key = NULL;
892 	mac->enabled = 0;
893 	return 0;
894 }
895 
896 static int
897 choose_comp(struct sshcomp *comp, char *client, char *server)
898 {
899 	char *name = match_list(client, server, NULL);
900 
901 	if (name == NULL)
902 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
903 #ifdef WITH_ZLIB
904 	if (strcmp(name, "zlib@openssh.com") == 0) {
905 		comp->type = COMP_DELAYED;
906 	} else if (strcmp(name, "zlib") == 0) {
907 		comp->type = COMP_ZLIB;
908 	} else
909 #endif	/* WITH_ZLIB */
910 	if (strcmp(name, "none") == 0) {
911 		comp->type = COMP_NONE;
912 	} else {
913 		error_f("unsupported compression scheme %s", name);
914 		free(name);
915 		return SSH_ERR_INTERNAL_ERROR;
916 	}
917 	comp->name = name;
918 	return 0;
919 }
920 
921 static int
922 choose_kex(struct kex *k, char *client, char *server)
923 {
924 	const struct kexalg *kexalg;
925 
926 	k->name = match_list(client, server, NULL);
927 
928 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
929 	if (k->name == NULL)
930 		return SSH_ERR_NO_KEX_ALG_MATCH;
931 	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
932 		error_f("unsupported KEX method %s", k->name);
933 		return SSH_ERR_INTERNAL_ERROR;
934 	}
935 	k->kex_type = kexalg->type;
936 	k->hash_alg = kexalg->hash_alg;
937 	k->ec_nid = kexalg->ec_nid;
938 	return 0;
939 }
940 
941 static int
942 choose_hostkeyalg(struct kex *k, char *client, char *server)
943 {
944 	free(k->hostkey_alg);
945 	k->hostkey_alg = match_list(client, server, NULL);
946 
947 	debug("kex: host key algorithm: %s",
948 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
949 	if (k->hostkey_alg == NULL)
950 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
951 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
952 	if (k->hostkey_type == KEY_UNSPEC) {
953 		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
954 		return SSH_ERR_INTERNAL_ERROR;
955 	}
956 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
957 	return 0;
958 }
959 
960 static int
961 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
962 {
963 	static int check[] = {
964 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
965 	};
966 	int *idx;
967 	char *p;
968 
969 	for (idx = &check[0]; *idx != -1; idx++) {
970 		if ((p = strchr(my[*idx], ',')) != NULL)
971 			*p = '\0';
972 		if ((p = strchr(peer[*idx], ',')) != NULL)
973 			*p = '\0';
974 		if (strcmp(my[*idx], peer[*idx]) != 0) {
975 			debug2("proposal mismatch: my %s peer %s",
976 			    my[*idx], peer[*idx]);
977 			return (0);
978 		}
979 	}
980 	debug2("proposals match");
981 	return (1);
982 }
983 
/*
 * Return 1 if match_list() finds any algorithm from 'algs' in 'proposal',
 * otherwise 0.
 */
static int
has_any_alg(const char *proposal, const char *algs)
{
	char *common = match_list(proposal, algs, NULL);
	int found = (common != NULL);

	free(common);
	return found;
}
995 
996 static int
997 kex_choose_conf(struct ssh *ssh)
998 {
999 	struct kex *kex = ssh->kex;
1000 	struct newkeys *newkeys;
1001 	char **my = NULL, **peer = NULL;
1002 	char **cprop, **sprop;
1003 	int nenc, nmac, ncomp;
1004 	u_int mode, ctos, need, dh_need, authlen;
1005 	int r, first_kex_follows;
1006 
1007 	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
1008 	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
1009 		goto out;
1010 	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
1011 	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
1012 		goto out;
1013 
1014 	if (kex->server) {
1015 		cprop=peer;
1016 		sprop=my;
1017 	} else {
1018 		cprop=my;
1019 		sprop=peer;
1020 	}
1021 
1022 	/* Check whether client supports ext_info_c */
1023 	if (kex->server && (kex->flags & KEX_INITIAL)) {
1024 		char *ext;
1025 
1026 		ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
1027 		kex->ext_info_c = (ext != NULL);
1028 		free(ext);
1029 	}
1030 
1031 	/* Check whether client supports rsa-sha2 algorithms */
1032 	if (kex->server && (kex->flags & KEX_INITIAL)) {
1033 		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1034 		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
1035 			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
1036 		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1037 		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
1038 			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
1039 	}
1040 
1041 	/* Algorithm Negotiation */
1042 	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
1043 	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
1044 		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
1045 		peer[PROPOSAL_KEX_ALGS] = NULL;
1046 		goto out;
1047 	}
1048 	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
1049 	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
1050 		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
1051 		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
1052 		goto out;
1053 	}
1054 	for (mode = 0; mode < MODE_MAX; mode++) {
1055 		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
1056 			r = SSH_ERR_ALLOC_FAIL;
1057 			goto out;
1058 		}
1059 		kex->newkeys[mode] = newkeys;
1060 		ctos = (!kex->server && mode == MODE_OUT) ||
1061 		    (kex->server && mode == MODE_IN);
1062 		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1063 		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1064 		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1065 		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1066 		    sprop[nenc])) != 0) {
1067 			kex->failed_choice = peer[nenc];
1068 			peer[nenc] = NULL;
1069 			goto out;
1070 		}
1071 		authlen = cipher_authlen(newkeys->enc.cipher);
1072 		/* ignore mac for authenticated encryption */
1073 		if (authlen == 0 &&
1074 		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1075 		    sprop[nmac])) != 0) {
1076 			kex->failed_choice = peer[nmac];
1077 			peer[nmac] = NULL;
1078 			goto out;
1079 		}
1080 		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1081 		    sprop[ncomp])) != 0) {
1082 			kex->failed_choice = peer[ncomp];
1083 			peer[ncomp] = NULL;
1084 			goto out;
1085 		}
1086 		debug("kex: %s cipher: %s MAC: %s compression: %s",
1087 		    ctos ? "client->server" : "server->client",
1088 		    newkeys->enc.name,
1089 		    authlen == 0 ? newkeys->mac.name : "<implicit>",
1090 		    newkeys->comp.name);
1091 	}
1092 	need = dh_need = 0;
1093 	for (mode = 0; mode < MODE_MAX; mode++) {
1094 		newkeys = kex->newkeys[mode];
1095 		need = MAXIMUM(need, newkeys->enc.key_len);
1096 		need = MAXIMUM(need, newkeys->enc.block_size);
1097 		need = MAXIMUM(need, newkeys->enc.iv_len);
1098 		need = MAXIMUM(need, newkeys->mac.key_len);
1099 		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1100 		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1101 		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1102 		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1103 	}
1104 	/* XXX need runden? */
1105 	kex->we_need = need;
1106 	kex->dh_need = dh_need;
1107 
1108 	/* ignore the next message if the proposals do not match */
1109 	if (first_kex_follows && !proposals_match(my, peer))
1110 		ssh->dispatch_skip_packets = 1;
1111 	r = 0;
1112  out:
1113 	kex_prop_free(my);
1114 	kex_prop_free(peer);
1115 	return r;
1116 }
1117 
1118 static int
1119 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1120     const struct sshbuf *shared_secret, u_char **keyp)
1121 {
1122 	struct kex *kex = ssh->kex;
1123 	struct ssh_digest_ctx *hashctx = NULL;
1124 	char c = id;
1125 	u_int have;
1126 	size_t mdsz;
1127 	u_char *digest;
1128 	int r;
1129 
1130 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1131 		return SSH_ERR_INVALID_ARGUMENT;
1132 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1133 		r = SSH_ERR_ALLOC_FAIL;
1134 		goto out;
1135 	}
1136 
1137 	/* K1 = HASH(K || H || "A" || session_id) */
1138 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1139 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1140 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1141 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1142 	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1143 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1144 		r = SSH_ERR_LIBCRYPTO_ERROR;
1145 		error_f("KEX hash failed");
1146 		goto out;
1147 	}
1148 	ssh_digest_free(hashctx);
1149 	hashctx = NULL;
1150 
1151 	/*
1152 	 * expand key:
1153 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1154 	 * Key = K1 || K2 || ... || Kn
1155 	 */
1156 	for (have = mdsz; need > have; have += mdsz) {
1157 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1158 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1159 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1160 		    ssh_digest_update(hashctx, digest, have) != 0 ||
1161 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1162 			error_f("KDF failed");
1163 			r = SSH_ERR_LIBCRYPTO_ERROR;
1164 			goto out;
1165 		}
1166 		ssh_digest_free(hashctx);
1167 		hashctx = NULL;
1168 	}
1169 #ifdef DEBUG_KEX
1170 	fprintf(stderr, "key '%c'== ", c);
1171 	dump_digest("key", digest, need);
1172 #endif
1173 	*keyp = digest;
1174 	digest = NULL;
1175 	r = 0;
1176  out:
1177 	free(digest);
1178 	ssh_digest_free(hashctx);
1179 	return r;
1180 }
1181 
1182 #define NKEYS	6
1183 int
1184 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1185     const struct sshbuf *shared_secret)
1186 {
1187 	struct kex *kex = ssh->kex;
1188 	u_char *keys[NKEYS];
1189 	u_int i, j, mode, ctos;
1190 	int r;
1191 
1192 	/* save initial hash as session id */
1193 	if ((kex->flags & KEX_INITIAL) != 0) {
1194 		if (sshbuf_len(kex->session_id) != 0) {
1195 			error_f("already have session ID at kex");
1196 			return SSH_ERR_INTERNAL_ERROR;
1197 		}
1198 		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1199 			return r;
1200 	} else if (sshbuf_len(kex->session_id) == 0) {
1201 		error_f("no session ID in rekex");
1202 		return SSH_ERR_INTERNAL_ERROR;
1203 	}
1204 	for (i = 0; i < NKEYS; i++) {
1205 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1206 		    shared_secret, &keys[i])) != 0) {
1207 			for (j = 0; j < i; j++)
1208 				free(keys[j]);
1209 			return r;
1210 		}
1211 	}
1212 	for (mode = 0; mode < MODE_MAX; mode++) {
1213 		ctos = (!kex->server && mode == MODE_OUT) ||
1214 		    (kex->server && mode == MODE_IN);
1215 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1216 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1217 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1218 	}
1219 	return 0;
1220 }
1221 
1222 int
1223 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1224 {
1225 	struct kex *kex = ssh->kex;
1226 
1227 	*pubp = NULL;
1228 	*prvp = NULL;
1229 	if (kex->load_host_public_key == NULL ||
1230 	    kex->load_host_private_key == NULL) {
1231 		error_f("missing hostkey loader");
1232 		return SSH_ERR_INVALID_ARGUMENT;
1233 	}
1234 	*pubp = kex->load_host_public_key(kex->hostkey_type,
1235 	    kex->hostkey_nid, ssh);
1236 	*prvp = kex->load_host_private_key(kex->hostkey_type,
1237 	    kex->hostkey_nid, ssh);
1238 	if (*pubp == NULL)
1239 		return SSH_ERR_NO_HOSTKEY_LOADED;
1240 	return 0;
1241 }
1242 
1243 int
1244 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1245 {
1246 	struct kex *kex = ssh->kex;
1247 
1248 	if (kex->verify_host_key == NULL) {
1249 		error_f("missing hostkey verifier");
1250 		return SSH_ERR_INVALID_ARGUMENT;
1251 	}
1252 	if (server_host_key->type != kex->hostkey_type ||
1253 	    (kex->hostkey_type == KEY_ECDSA &&
1254 	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1255 		return SSH_ERR_KEY_TYPE_MISMATCH;
1256 	if (kex->verify_host_key(server_host_key, ssh) == -1)
1257 		return  SSH_ERR_SIGNATURE_INVALID;
1258 	return 0;
1259 }
1260 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/*
 * Debug helper, compiled only with the KEX debug macros: print the caption
 * msg on its own line, then dump len bytes of digest, all to stderr.
 */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
1269 
1270 /*
1271  * Send a plaintext error message to the peer, suffixed by \r\n.
1272  * Only used during banner exchange, and there only for the server.
1273  */
1274 static void
1275 send_error(struct ssh *ssh, char *msg)
1276 {
1277 	char *crnl = "\r\n";
1278 
1279 	if (!ssh->kex->server)
1280 		return;
1281 
1282 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1283 	    msg, strlen(msg)) != strlen(msg) ||
1284 	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1285 	    crnl, strlen(crnl)) != strlen(crnl))
1286 		error_f("write: %.100s", strerror(errno));
1287 }
1288 
1289 /*
1290  * Sends our identification string and waits for the peer's. Will block for
1291  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1292  * Returns on 0 success or a ssherr.h code on failure.
1293  */
1294 int
1295 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1296     const char *version_addendum)
1297 {
1298 	int remote_major, remote_minor, mismatch, oerrno = 0;
1299 	size_t len, n;
1300 	int r, expect_nl;
1301 	u_char c;
1302 	struct sshbuf *our_version = ssh->kex->server ?
1303 	    ssh->kex->server_version : ssh->kex->client_version;
1304 	struct sshbuf *peer_version = ssh->kex->server ?
1305 	    ssh->kex->client_version : ssh->kex->server_version;
1306 	char *our_version_string = NULL, *peer_version_string = NULL;
1307 	char *cp, *remote_version = NULL;
1308 
1309 	/* Prepare and send our banner */
1310 	sshbuf_reset(our_version);
1311 	if (version_addendum != NULL && *version_addendum == '\0')
1312 		version_addendum = NULL;
1313 	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%.100s%s%s\r\n",
1314 	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1315 	    version_addendum == NULL ? "" : " ",
1316 	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1317 		oerrno = errno;
1318 		error_fr(r, "sshbuf_putf");
1319 		goto out;
1320 	}
1321 
1322 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1323 	    sshbuf_mutable_ptr(our_version),
1324 	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1325 		oerrno = errno;
1326 		debug_f("write: %.100s", strerror(errno));
1327 		r = SSH_ERR_SYSTEM_ERROR;
1328 		goto out;
1329 	}
1330 	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1331 		oerrno = errno;
1332 		error_fr(r, "sshbuf_consume_end");
1333 		goto out;
1334 	}
1335 	our_version_string = sshbuf_dup_string(our_version);
1336 	if (our_version_string == NULL) {
1337 		error_f("sshbuf_dup_string failed");
1338 		r = SSH_ERR_ALLOC_FAIL;
1339 		goto out;
1340 	}
1341 	debug("Local version string %.100s", our_version_string);
1342 
1343 	/* Read other side's version identification. */
1344 	for (n = 0; ; n++) {
1345 		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1346 			send_error(ssh, "No SSH identification string "
1347 			    "received.");
1348 			error_f("No SSH version received in first %u lines "
1349 			    "from server", SSH_MAX_PRE_BANNER_LINES);
1350 			r = SSH_ERR_INVALID_FORMAT;
1351 			goto out;
1352 		}
1353 		sshbuf_reset(peer_version);
1354 		expect_nl = 0;
1355 		for (;;) {
1356 			if (timeout_ms > 0) {
1357 				r = waitrfd(ssh_packet_get_connection_in(ssh),
1358 				    &timeout_ms, NULL);
1359 				if (r == -1 && errno == ETIMEDOUT) {
1360 					send_error(ssh, "Timed out waiting "
1361 					    "for SSH identification string.");
1362 					error("Connection timed out during "
1363 					    "banner exchange");
1364 					r = SSH_ERR_CONN_TIMEOUT;
1365 					goto out;
1366 				} else if (r == -1) {
1367 					oerrno = errno;
1368 					error_f("%s", strerror(errno));
1369 					r = SSH_ERR_SYSTEM_ERROR;
1370 					goto out;
1371 				}
1372 			}
1373 
1374 			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1375 			    &c, 1);
1376 			if (len != 1 && errno == EPIPE) {
1377 				verbose_f("Connection closed by remote host");
1378 				r = SSH_ERR_CONN_CLOSED;
1379 				goto out;
1380 			} else if (len != 1) {
1381 				oerrno = errno;
1382 				error_f("read: %.100s", strerror(errno));
1383 				r = SSH_ERR_SYSTEM_ERROR;
1384 				goto out;
1385 			}
1386 			if (c == '\r') {
1387 				expect_nl = 1;
1388 				continue;
1389 			}
1390 			if (c == '\n')
1391 				break;
1392 			if (c == '\0' || expect_nl) {
1393 				verbose_f("banner line contains invalid "
1394 				    "characters");
1395 				goto invalid;
1396 			}
1397 			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1398 				oerrno = errno;
1399 				error_fr(r, "sshbuf_put");
1400 				goto out;
1401 			}
1402 			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1403 				verbose_f("banner line too long");
1404 				goto invalid;
1405 			}
1406 		}
1407 		/* Is this an actual protocol banner? */
1408 		if (sshbuf_len(peer_version) > 4 &&
1409 		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1410 			break;
1411 		/* If not, then just log the line and continue */
1412 		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1413 			error_f("sshbuf_dup_string failed");
1414 			r = SSH_ERR_ALLOC_FAIL;
1415 			goto out;
1416 		}
1417 		/* Do not accept lines before the SSH ident from a client */
1418 		if (ssh->kex->server) {
1419 			verbose_f("client sent invalid protocol identifier "
1420 			    "\"%.256s\"", cp);
1421 			free(cp);
1422 			goto invalid;
1423 		}
1424 		debug_f("banner line %zu: %s", n, cp);
1425 		free(cp);
1426 	}
1427 	peer_version_string = sshbuf_dup_string(peer_version);
1428 	if (peer_version_string == NULL)
1429 		fatal_f("sshbuf_dup_string failed");
1430 	/* XXX must be same size for sscanf */
1431 	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1432 		error_f("calloc failed");
1433 		r = SSH_ERR_ALLOC_FAIL;
1434 		goto out;
1435 	}
1436 
1437 	/*
1438 	 * Check that the versions match.  In future this might accept
1439 	 * several versions and set appropriate flags to handle them.
1440 	 */
1441 	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1442 	    &remote_major, &remote_minor, remote_version) != 3) {
1443 		error("Bad remote protocol version identification: '%.100s'",
1444 		    peer_version_string);
1445  invalid:
1446 		send_error(ssh, "Invalid SSH identification string.");
1447 		r = SSH_ERR_INVALID_FORMAT;
1448 		goto out;
1449 	}
1450 	debug("Remote protocol version %d.%d, remote software version %.100s",
1451 	    remote_major, remote_minor, remote_version);
1452 	compat_banner(ssh, remote_version);
1453 
1454 	mismatch = 0;
1455 	switch (remote_major) {
1456 	case 2:
1457 		break;
1458 	case 1:
1459 		if (remote_minor != 99)
1460 			mismatch = 1;
1461 		break;
1462 	default:
1463 		mismatch = 1;
1464 		break;
1465 	}
1466 	if (mismatch) {
1467 		error("Protocol major versions differ: %d vs. %d",
1468 		    PROTOCOL_MAJOR_2, remote_major);
1469 		send_error(ssh, "Protocol major versions differ.");
1470 		r = SSH_ERR_NO_PROTOCOL_VERSION;
1471 		goto out;
1472 	}
1473 
1474 	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1475 		logit("probed from %s port %d with %s.  Don't panic.",
1476 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1477 		    peer_version_string);
1478 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1479 		goto out;
1480 	}
1481 	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1482 		logit("scanned from %s port %d with %s.  Don't panic.",
1483 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1484 		    peer_version_string);
1485 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1486 		goto out;
1487 	}
1488 	/* success */
1489 	r = 0;
1490  out:
1491 	free(our_version_string);
1492 	free(peer_version_string);
1493 	free(remote_version);
1494 	if (r == SSH_ERR_SYSTEM_ERROR)
1495 		errno = oerrno;
1496 	return r;
1497 }
1498 
1499