xref: /freebsd/crypto/openssh/kex.c (revision 38c63bdc46252d4d8cd313dff4183ec4546d26d9)
1 /* $OpenBSD: kex.c,v 1.181 2023/08/28 03:28:43 djm Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include "includes.h"
27 
28 #include <sys/types.h>
29 #include <errno.h>
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39 
40 #ifdef WITH_OPENSSL
41 #include <openssl/crypto.h>
42 #include <openssl/dh.h>
43 #endif
44 
45 #include "ssh.h"
46 #include "ssh2.h"
47 #include "atomicio.h"
48 #include "version.h"
49 #include "packet.h"
50 #include "compat.h"
51 #include "cipher.h"
52 #include "sshkey.h"
53 #include "kex.h"
54 #include "log.h"
55 #include "mac.h"
56 #include "match.h"
57 #include "misc.h"
58 #include "dispatch.h"
59 #include "monitor.h"
60 #include "myproposal.h"
61 
62 #include "ssherr.h"
63 #include "sshbuf.h"
64 #include "digest.h"
65 #include "xmalloc.h"
66 
/* Forward declarations for static functions defined later in this file. */
static int kex_choose_conf(struct ssh *ssh, uint32_t seq);
static int kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh);
70 
71 static const char * const proposal_names[PROPOSAL_MAX] = {
72 	"KEX algorithms",
73 	"host key algorithms",
74 	"ciphers ctos",
75 	"ciphers stoc",
76 	"MACs ctos",
77 	"MACs stoc",
78 	"compression ctos",
79 	"compression stoc",
80 	"languages ctos",
81 	"languages stoc",
82 };
83 
84 struct kexalg {
85 	char *name;
86 	u_int type;
87 	int ec_nid;
88 	int hash_alg;
89 };
90 static const struct kexalg kexalgs[] = {
91 #ifdef WITH_OPENSSL
92 	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
93 	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
94 	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
95 	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
96 	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
97 	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
98 #ifdef HAVE_EVP_SHA256
99 	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
100 #endif /* HAVE_EVP_SHA256 */
101 #ifdef OPENSSL_HAS_ECC
102 	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
103 	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
104 	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
105 	    SSH_DIGEST_SHA384 },
106 # ifdef OPENSSL_HAS_NISTP521
107 	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
108 	    SSH_DIGEST_SHA512 },
109 # endif /* OPENSSL_HAS_NISTP521 */
110 #endif /* OPENSSL_HAS_ECC */
111 #endif /* WITH_OPENSSL */
112 #if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
113 	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
114 	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
115 #ifdef USE_SNTRUP761X25519
116 	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
117 	    SSH_DIGEST_SHA512 },
118 #endif
119 #endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
120 	{ NULL, 0, -1, -1},
121 };
122 
123 char *
124 kex_alg_list(char sep)
125 {
126 	char *ret = NULL, *tmp;
127 	size_t nlen, rlen = 0;
128 	const struct kexalg *k;
129 
130 	for (k = kexalgs; k->name != NULL; k++) {
131 		if (ret != NULL)
132 			ret[rlen++] = sep;
133 		nlen = strlen(k->name);
134 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
135 			free(ret);
136 			return NULL;
137 		}
138 		ret = tmp;
139 		memcpy(ret + rlen, k->name, nlen + 1);
140 		rlen += nlen;
141 	}
142 	return ret;
143 }
144 
145 static const struct kexalg *
146 kex_alg_by_name(const char *name)
147 {
148 	const struct kexalg *k;
149 
150 	for (k = kexalgs; k->name != NULL; k++) {
151 		if (strcmp(k->name, name) == 0)
152 			return k;
153 	}
154 	return NULL;
155 }
156 
/*
 * Validate a comma-separated list of KEX method names.  Returns 1 when
 * the list is non-empty and every name before the first empty element
 * is a supported method.  Otherwise returns 0, logging the first
 * unsupported name.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *rest, *name;
	int ok = 1;

	if (names == NULL || *names == '\0')
		return 0;
	if ((copy = strdup(names)) == NULL)
		return 0;
	rest = copy;
	/* Scanning stops at the end of the list or at an empty element. */
	while ((name = strsep(&rest, ",")) != NULL && *name != '\0') {
		if (kex_alg_by_name(name) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", name);
			ok = 0;
			break;
		}
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
179 
/*
 * Return 1 if match_list() finds any algorithm from 'algs' in 'proposal',
 * else 0.  The matched name itself is discarded.
 */
static int
has_any_alg(const char *proposal, const char *algs)
{
	char *common;
	int found;

	common = match_list(proposal, algs, NULL);
	found = common != NULL;
	free(common);
	return found;
}
191 
/*
 * Concatenate two comma-separated algorithm lists.  Names from 'b' that
 * are already present in the result are skipped.  A NULL or empty
 * argument is treated as an empty list (previously an empty 'a' combined
 * with a NULL 'b' reached strdup(NULL)).
 *
 * Returns a newly allocated string that the caller must free.  Returns
 * NULL on allocation failure, or if 'b' is longer than 1 MiB.  Scanning
 * of 'b' stops at its first empty element.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p;
	size_t len;

	if (a == NULL || *a == '\0')
		return strdup(b == NULL ? "" : b);
	if (b == NULL || *b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* Worst case: all of 'a', a ',' and all of 'b', plus the NUL. */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		if (has_any_alg(ret, p))
			continue; /* Algorithm already present */
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
228 
229 /*
230  * Assemble a list of algorithms from a default list and a string from a
231  * configuration file. The user-provided string may begin with '+' to
232  * indicate that it should be appended to the default, '-' that the
233  * specified names should be removed, or '^' that they should be placed
234  * at the head.
235  */
236 int
237 kex_assemble_names(char **listp, const char *def, const char *all)
238 {
239 	char *cp, *tmp, *patterns;
240 	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
241 	int r = SSH_ERR_INTERNAL_ERROR;
242 
243 	if (listp == NULL || def == NULL || all == NULL)
244 		return SSH_ERR_INVALID_ARGUMENT;
245 
246 	if (*listp == NULL || **listp == '\0') {
247 		if ((*listp = strdup(def)) == NULL)
248 			return SSH_ERR_ALLOC_FAIL;
249 		return 0;
250 	}
251 
252 	list = *listp;
253 	*listp = NULL;
254 	if (*list == '+') {
255 		/* Append names to default list */
256 		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
257 			r = SSH_ERR_ALLOC_FAIL;
258 			goto fail;
259 		}
260 		free(list);
261 		list = tmp;
262 	} else if (*list == '-') {
263 		/* Remove names from default list */
264 		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
265 			r = SSH_ERR_ALLOC_FAIL;
266 			goto fail;
267 		}
268 		free(list);
269 		/* filtering has already been done */
270 		return 0;
271 	} else if (*list == '^') {
272 		/* Place names at head of default list */
273 		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
274 			r = SSH_ERR_ALLOC_FAIL;
275 			goto fail;
276 		}
277 		free(list);
278 		list = tmp;
279 	} else {
280 		/* Explicit list, overrides default - just use "list" as is */
281 	}
282 
283 	/*
284 	 * The supplied names may be a pattern-list. For the -list case,
285 	 * the patterns are applied above. For the +list and explicit list
286 	 * cases we need to do it now.
287 	 */
288 	ret = NULL;
289 	if ((patterns = opatterns = strdup(list)) == NULL) {
290 		r = SSH_ERR_ALLOC_FAIL;
291 		goto fail;
292 	}
293 	/* Apply positive (i.e. non-negated) patterns from the list */
294 	while ((cp = strsep(&patterns, ",")) != NULL) {
295 		if (*cp == '!') {
296 			/* negated matches are not supported here */
297 			r = SSH_ERR_INVALID_ARGUMENT;
298 			goto fail;
299 		}
300 		free(matching);
301 		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
302 			r = SSH_ERR_ALLOC_FAIL;
303 			goto fail;
304 		}
305 		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
306 			r = SSH_ERR_ALLOC_FAIL;
307 			goto fail;
308 		}
309 		free(ret);
310 		ret = tmp;
311 	}
312 	if (ret == NULL || *ret == '\0') {
313 		/* An empty name-list is an error */
314 		/* XXX better error code? */
315 		r = SSH_ERR_INVALID_ARGUMENT;
316 		goto fail;
317 	}
318 
319 	/* success */
320 	*listp = ret;
321 	ret = NULL;
322 	r = 0;
323 
324  fail:
325 	free(matching);
326 	free(opatterns);
327 	free(list);
328 	free(ret);
329 	return r;
330 }
331 
332 /*
333  * Fill out a proposal array with dynamically allocated values, which may
334  * be modified as required for compatibility reasons.
335  * Any of the options may be NULL, in which case the default is used.
336  * Array contents must be freed by calling kex_proposal_free_entries.
337  */
338 void
339 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
340     const char *kexalgos, const char *ciphers, const char *macs,
341     const char *comp, const char *hkalgs)
342 {
343 	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
344 	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
345 	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
346 	u_int i;
347 	char *cp;
348 
349 	if (prop == NULL)
350 		fatal_f("proposal missing");
351 
352 	/* Append EXT_INFO signalling to KexAlgorithms */
353 	if (kexalgos == NULL)
354 		kexalgos = defprop[PROPOSAL_KEX_ALGS];
355 	if ((cp = kex_names_cat(kexalgos, ssh->kex->server ?
356 	    "kex-strict-s-v00@openssh.com" :
357 	    "ext-info-c,kex-strict-c-v00@openssh.com")) == NULL)
358 		fatal_f("kex_names_cat");
359 
360 	for (i = 0; i < PROPOSAL_MAX; i++) {
361 		switch(i) {
362 		case PROPOSAL_KEX_ALGS:
363 			prop[i] = compat_kex_proposal(ssh, cp);
364 			break;
365 		case PROPOSAL_ENC_ALGS_CTOS:
366 		case PROPOSAL_ENC_ALGS_STOC:
367 			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
368 			break;
369 		case PROPOSAL_MAC_ALGS_CTOS:
370 		case PROPOSAL_MAC_ALGS_STOC:
371 			prop[i]  = xstrdup(macs ? macs : defprop[i]);
372 			break;
373 		case PROPOSAL_COMP_ALGS_CTOS:
374 		case PROPOSAL_COMP_ALGS_STOC:
375 			prop[i] = xstrdup(comp ? comp : defprop[i]);
376 			break;
377 		case PROPOSAL_SERVER_HOST_KEY_ALGS:
378 			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
379 			break;
380 		default:
381 			prop[i] = xstrdup(defprop[i]);
382 		}
383 	}
384 	free(cp);
385 }
386 
387 void
388 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
389 {
390 	u_int i;
391 
392 	for (i = 0; i < PROPOSAL_MAX; i++)
393 		free(prop[i]);
394 }
395 
396 /* put algorithm proposal into buffer */
397 int
398 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
399 {
400 	u_int i;
401 	int r;
402 
403 	sshbuf_reset(b);
404 
405 	/*
406 	 * add a dummy cookie, the cookie will be overwritten by
407 	 * kex_send_kexinit(), each time a kexinit is set
408 	 */
409 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
410 		if ((r = sshbuf_put_u8(b, 0)) != 0)
411 			return r;
412 	}
413 	for (i = 0; i < PROPOSAL_MAX; i++) {
414 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
415 			return r;
416 	}
417 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
418 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
419 		return r;
420 	return 0;
421 }
422 
423 /* parse buffer and return algorithm proposal */
424 int
425 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
426 {
427 	struct sshbuf *b = NULL;
428 	u_char v;
429 	u_int i;
430 	char **proposal = NULL;
431 	int r;
432 
433 	*propp = NULL;
434 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
435 		return SSH_ERR_ALLOC_FAIL;
436 	if ((b = sshbuf_fromb(raw)) == NULL) {
437 		r = SSH_ERR_ALLOC_FAIL;
438 		goto out;
439 	}
440 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
441 		error_fr(r, "consume cookie");
442 		goto out;
443 	}
444 	/* extract kex init proposal strings */
445 	for (i = 0; i < PROPOSAL_MAX; i++) {
446 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
447 			error_fr(r, "parse proposal %u", i);
448 			goto out;
449 		}
450 		debug2("%s: %s", proposal_names[i], proposal[i]);
451 	}
452 	/* first kex follows / reserved */
453 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
454 	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
455 		error_fr(r, "parse");
456 		goto out;
457 	}
458 	if (first_kex_follows != NULL)
459 		*first_kex_follows = v;
460 	debug2("first_kex_follows %d ", v);
461 	debug2("reserved %u ", i);
462 	r = 0;
463 	*propp = proposal;
464  out:
465 	if (r != 0 && proposal != NULL)
466 		kex_prop_free(proposal);
467 	sshbuf_free(b);
468 	return r;
469 }
470 
471 void
472 kex_prop_free(char **proposal)
473 {
474 	u_int i;
475 
476 	if (proposal == NULL)
477 		return;
478 	for (i = 0; i < PROPOSAL_MAX; i++)
479 		free(proposal[i]);
480 	free(proposal);
481 }
482 
483 int
484 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
485 {
486 	int r;
487 
488 	/* If in strict mode, any unexpected message is an error */
489 	if ((ssh->kex->flags & KEX_INITIAL) && ssh->kex->kex_strict) {
490 		ssh_packet_disconnect(ssh, "strict KEX violation: "
491 		    "unexpected packet type %u (seqnr %u)", type, seq);
492 	}
493 	error_f("type %u seq %u", type, seq);
494 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
495 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
496 	    (r = sshpkt_send(ssh)) != 0)
497 		return r;
498 	return 0;
499 }
500 
501 static void
502 kex_reset_dispatch(struct ssh *ssh)
503 {
504 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
505 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
506 }
507 
508 static int
509 kex_send_ext_info(struct ssh *ssh)
510 {
511 	int r;
512 	char *algs;
513 
514 	debug("Sending SSH2_MSG_EXT_INFO");
515 	if ((algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
516 		return SSH_ERR_ALLOC_FAIL;
517 	/* XXX filter algs list by allowed pubkey/hostbased types */
518 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
519 	    (r = sshpkt_put_u32(ssh, 3)) != 0 ||
520 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
521 	    (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
522 	    (r = sshpkt_put_cstring(ssh,
523 	    "publickey-hostbound@openssh.com")) != 0 ||
524 	    (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
525 	    (r = sshpkt_put_cstring(ssh, "ping@openssh.com")) != 0 ||
526 	    (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
527 	    (r = sshpkt_send(ssh)) != 0) {
528 		error_fr(r, "compose");
529 		goto out;
530 	}
531 	/* success */
532 	r = 0;
533  out:
534 	free(algs);
535 	return r;
536 }
537 
538 int
539 kex_send_newkeys(struct ssh *ssh)
540 {
541 	int r;
542 
543 	kex_reset_dispatch(ssh);
544 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
545 	    (r = sshpkt_send(ssh)) != 0)
546 		return r;
547 	debug("SSH2_MSG_NEWKEYS sent");
548 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
549 	if (ssh->kex->ext_info_c && (ssh->kex->flags & KEX_INITIAL) != 0)
550 		if ((r = kex_send_ext_info(ssh)) != 0)
551 			return r;
552 	debug("expecting SSH2_MSG_NEWKEYS");
553 	return 0;
554 }
555 
556 /* Check whether an ext_info value contains the expected version string */
557 static int
558 kex_ext_info_check_ver(struct kex *kex, const char *name,
559     const u_char *val, size_t len, const char *want_ver, u_int flag)
560 {
561 	if (memchr(val, '\0', len) != NULL) {
562 		error("SSH2_MSG_EXT_INFO: %s value contains nul byte", name);
563 		return SSH_ERR_INVALID_FORMAT;
564 	}
565 	debug_f("%s=<%s>", name, val);
566 	if (strcmp(val, want_ver) == 0)
567 		kex->flags |= flag;
568 	else
569 		debug_f("unsupported version of %s extension", name);
570 	return 0;
571 }
572 
573 int
574 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
575 {
576 	struct kex *kex = ssh->kex;
577 	u_int32_t i, ninfo;
578 	char *name;
579 	u_char *val;
580 	size_t vlen;
581 	int r;
582 
583 	debug("SSH2_MSG_EXT_INFO received");
584 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
585 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
586 		return r;
587 	if (ninfo >= 1024) {
588 		error("SSH2_MSG_EXT_INFO with too many entries, expected "
589 		    "<=1024, received %u", ninfo);
590 		return dispatch_protocol_error(type, seq, ssh);
591 	}
592 	for (i = 0; i < ninfo; i++) {
593 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
594 			return r;
595 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
596 			free(name);
597 			return r;
598 		}
599 		if (strcmp(name, "server-sig-algs") == 0) {
600 			/* Ensure no \0 lurking in value */
601 			if (memchr(val, '\0', vlen) != NULL) {
602 				error_f("nul byte in %s", name);
603 				free(name);
604 				free(val);
605 				return SSH_ERR_INVALID_FORMAT;
606 			}
607 			debug_f("%s=<%s>", name, val);
608 			kex->server_sig_algs = val;
609 			val = NULL;
610 		} else if (strcmp(name,
611 		    "publickey-hostbound@openssh.com") == 0) {
612 			if ((r = kex_ext_info_check_ver(kex, name, val, vlen,
613 			    "0", KEX_HAS_PUBKEY_HOSTBOUND)) != 0) {
614 				free(name);
615 				free(val);
616 				return r;
617 			}
618 		} else if (strcmp(name, "ping@openssh.com") == 0) {
619 			if ((r = kex_ext_info_check_ver(kex, name, val, vlen,
620 			    "0", KEX_HAS_PING)) != 0) {
621 				free(name);
622 				free(val);
623 				return r;
624 			}
625 		} else
626 			debug_f("%s (unrecognised)", name);
627 		free(name);
628 		free(val);
629 	}
630 	return sshpkt_get_end(ssh);
631 }
632 
633 static int
634 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
635 {
636 	struct kex *kex = ssh->kex;
637 	int r;
638 
639 	debug("SSH2_MSG_NEWKEYS received");
640 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
641 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
642 	if ((r = sshpkt_get_end(ssh)) != 0)
643 		return r;
644 	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
645 		return r;
646 	kex->done = 1;
647 	kex->flags &= ~KEX_INITIAL;
648 	sshbuf_reset(kex->peer);
649 	/* sshbuf_reset(kex->my); */
650 	kex->flags &= ~KEX_INIT_SENT;
651 	free(kex->name);
652 	kex->name = NULL;
653 	return 0;
654 }
655 
656 int
657 kex_send_kexinit(struct ssh *ssh)
658 {
659 	u_char *cookie;
660 	struct kex *kex = ssh->kex;
661 	int r;
662 
663 	if (kex == NULL) {
664 		error_f("no kex");
665 		return SSH_ERR_INTERNAL_ERROR;
666 	}
667 	if (kex->flags & KEX_INIT_SENT)
668 		return 0;
669 	kex->done = 0;
670 
671 	/* generate a random cookie */
672 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
673 		error_f("bad kex length: %zu < %d",
674 		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
675 		return SSH_ERR_INVALID_FORMAT;
676 	}
677 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
678 		error_f("buffer error");
679 		return SSH_ERR_INTERNAL_ERROR;
680 	}
681 	arc4random_buf(cookie, KEX_COOKIE_LEN);
682 
683 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
684 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
685 	    (r = sshpkt_send(ssh)) != 0) {
686 		error_fr(r, "compose reply");
687 		return r;
688 	}
689 	debug("SSH2_MSG_KEXINIT sent");
690 	kex->flags |= KEX_INIT_SENT;
691 	return 0;
692 }
693 
694 int
695 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
696 {
697 	struct kex *kex = ssh->kex;
698 	const u_char *ptr;
699 	u_int i;
700 	size_t dlen;
701 	int r;
702 
703 	debug("SSH2_MSG_KEXINIT received");
704 	if (kex == NULL) {
705 		error_f("no kex");
706 		return SSH_ERR_INTERNAL_ERROR;
707 	}
708 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_protocol_error);
709 	ptr = sshpkt_ptr(ssh, &dlen);
710 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
711 		return r;
712 
713 	/* discard packet */
714 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
715 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
716 			error_fr(r, "discard cookie");
717 			return r;
718 		}
719 	}
720 	for (i = 0; i < PROPOSAL_MAX; i++) {
721 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
722 			error_fr(r, "discard proposal");
723 			return r;
724 		}
725 	}
726 	/*
727 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
728 	 * KEX method has the server move first, but a server might be using
729 	 * a custom method or one that we otherwise don't support. We should
730 	 * be prepared to remember first_kex_follows here so we can eat a
731 	 * packet later.
732 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
733 	 * for cases where the server *doesn't* go first. I guess we should
734 	 * ignore it when it is set for these cases, which is what we do now.
735 	 */
736 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
737 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
738 	    (r = sshpkt_get_end(ssh)) != 0)
739 			return r;
740 
741 	if (!(kex->flags & KEX_INIT_SENT))
742 		if ((r = kex_send_kexinit(ssh)) != 0)
743 			return r;
744 	if ((r = kex_choose_conf(ssh, seq)) != 0)
745 		return r;
746 
747 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
748 		return (kex->kex[kex->kex_type])(ssh);
749 
750 	error_f("unknown kex type %u", kex->kex_type);
751 	return SSH_ERR_INTERNAL_ERROR;
752 }
753 
754 struct kex *
755 kex_new(void)
756 {
757 	struct kex *kex;
758 
759 	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
760 	    (kex->peer = sshbuf_new()) == NULL ||
761 	    (kex->my = sshbuf_new()) == NULL ||
762 	    (kex->client_version = sshbuf_new()) == NULL ||
763 	    (kex->server_version = sshbuf_new()) == NULL ||
764 	    (kex->session_id = sshbuf_new()) == NULL) {
765 		kex_free(kex);
766 		return NULL;
767 	}
768 	return kex;
769 }
770 
771 void
772 kex_free_newkeys(struct newkeys *newkeys)
773 {
774 	if (newkeys == NULL)
775 		return;
776 	if (newkeys->enc.key) {
777 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
778 		free(newkeys->enc.key);
779 		newkeys->enc.key = NULL;
780 	}
781 	if (newkeys->enc.iv) {
782 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
783 		free(newkeys->enc.iv);
784 		newkeys->enc.iv = NULL;
785 	}
786 	free(newkeys->enc.name);
787 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
788 	free(newkeys->comp.name);
789 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
790 	mac_clear(&newkeys->mac);
791 	if (newkeys->mac.key) {
792 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
793 		free(newkeys->mac.key);
794 		newkeys->mac.key = NULL;
795 	}
796 	free(newkeys->mac.name);
797 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
798 	freezero(newkeys, sizeof(*newkeys));
799 }
800 
801 void
802 kex_free(struct kex *kex)
803 {
804 	u_int mode;
805 
806 	if (kex == NULL)
807 		return;
808 
809 #ifdef WITH_OPENSSL
810 	DH_free(kex->dh);
811 #ifdef OPENSSL_HAS_ECC
812 	EC_KEY_free(kex->ec_client_key);
813 #endif /* OPENSSL_HAS_ECC */
814 #endif /* WITH_OPENSSL */
815 	for (mode = 0; mode < MODE_MAX; mode++) {
816 		kex_free_newkeys(kex->newkeys[mode]);
817 		kex->newkeys[mode] = NULL;
818 	}
819 	sshbuf_free(kex->peer);
820 	sshbuf_free(kex->my);
821 	sshbuf_free(kex->client_version);
822 	sshbuf_free(kex->server_version);
823 	sshbuf_free(kex->client_pub);
824 	sshbuf_free(kex->session_id);
825 	sshbuf_free(kex->initial_sig);
826 	sshkey_free(kex->initial_hostkey);
827 	free(kex->failed_choice);
828 	free(kex->hostkey_alg);
829 	free(kex->name);
830 	free(kex);
831 }
832 
833 int
834 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
835 {
836 	int r;
837 
838 	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
839 		return r;
840 	ssh->kex->flags = KEX_INITIAL;
841 	kex_reset_dispatch(ssh);
842 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
843 	return 0;
844 }
845 
846 int
847 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
848 {
849 	int r;
850 
851 	if ((r = kex_ready(ssh, proposal)) != 0)
852 		return r;
853 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
854 		kex_free(ssh->kex);
855 		ssh->kex = NULL;
856 		return r;
857 	}
858 	return 0;
859 }
860 
861 /*
862  * Request key re-exchange, returns 0 on success or a ssherr.h error
863  * code otherwise. Must not be called if KEX is incomplete or in-progress.
864  */
865 int
866 kex_start_rekex(struct ssh *ssh)
867 {
868 	if (ssh->kex == NULL) {
869 		error_f("no kex");
870 		return SSH_ERR_INTERNAL_ERROR;
871 	}
872 	if (ssh->kex->done == 0) {
873 		error_f("requested twice");
874 		return SSH_ERR_INTERNAL_ERROR;
875 	}
876 	ssh->kex->done = 0;
877 	return kex_send_kexinit(ssh);
878 }
879 
880 static int
881 choose_enc(struct sshenc *enc, char *client, char *server)
882 {
883 	char *name = match_list(client, server, NULL);
884 
885 	if (name == NULL)
886 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
887 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
888 		error_f("unsupported cipher %s", name);
889 		free(name);
890 		return SSH_ERR_INTERNAL_ERROR;
891 	}
892 	enc->name = name;
893 	enc->enabled = 0;
894 	enc->iv = NULL;
895 	enc->iv_len = cipher_ivlen(enc->cipher);
896 	enc->key = NULL;
897 	enc->key_len = cipher_keylen(enc->cipher);
898 	enc->block_size = cipher_blocksize(enc->cipher);
899 	return 0;
900 }
901 
902 static int
903 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
904 {
905 	char *name = match_list(client, server, NULL);
906 
907 	if (name == NULL)
908 		return SSH_ERR_NO_MAC_ALG_MATCH;
909 	if (mac_setup(mac, name) < 0) {
910 		error_f("unsupported MAC %s", name);
911 		free(name);
912 		return SSH_ERR_INTERNAL_ERROR;
913 	}
914 	mac->name = name;
915 	mac->key = NULL;
916 	mac->enabled = 0;
917 	return 0;
918 }
919 
920 static int
921 choose_comp(struct sshcomp *comp, char *client, char *server)
922 {
923 	char *name = match_list(client, server, NULL);
924 
925 	if (name == NULL)
926 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
927 #ifdef WITH_ZLIB
928 	if (strcmp(name, "zlib@openssh.com") == 0) {
929 		comp->type = COMP_DELAYED;
930 	} else if (strcmp(name, "zlib") == 0) {
931 		comp->type = COMP_ZLIB;
932 	} else
933 #endif	/* WITH_ZLIB */
934 	if (strcmp(name, "none") == 0) {
935 		comp->type = COMP_NONE;
936 	} else {
937 		error_f("unsupported compression scheme %s", name);
938 		free(name);
939 		return SSH_ERR_INTERNAL_ERROR;
940 	}
941 	comp->name = name;
942 	return 0;
943 }
944 
945 static int
946 choose_kex(struct kex *k, char *client, char *server)
947 {
948 	const struct kexalg *kexalg;
949 
950 	k->name = match_list(client, server, NULL);
951 
952 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
953 	if (k->name == NULL)
954 		return SSH_ERR_NO_KEX_ALG_MATCH;
955 	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
956 		error_f("unsupported KEX method %s", k->name);
957 		return SSH_ERR_INTERNAL_ERROR;
958 	}
959 	k->kex_type = kexalg->type;
960 	k->hash_alg = kexalg->hash_alg;
961 	k->ec_nid = kexalg->ec_nid;
962 	return 0;
963 }
964 
965 static int
966 choose_hostkeyalg(struct kex *k, char *client, char *server)
967 {
968 	free(k->hostkey_alg);
969 	k->hostkey_alg = match_list(client, server, NULL);
970 
971 	debug("kex: host key algorithm: %s",
972 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
973 	if (k->hostkey_alg == NULL)
974 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
975 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
976 	if (k->hostkey_type == KEY_UNSPEC) {
977 		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
978 		return SSH_ERR_INTERNAL_ERROR;
979 	}
980 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
981 	return 0;
982 }
983 
984 static int
985 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
986 {
987 	static int check[] = {
988 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
989 	};
990 	int *idx;
991 	char *p;
992 
993 	for (idx = &check[0]; *idx != -1; idx++) {
994 		if ((p = strchr(my[*idx], ',')) != NULL)
995 			*p = '\0';
996 		if ((p = strchr(peer[*idx], ',')) != NULL)
997 			*p = '\0';
998 		if (strcmp(my[*idx], peer[*idx]) != 0) {
999 			debug2("proposal mismatch: my %s peer %s",
1000 			    my[*idx], peer[*idx]);
1001 			return (0);
1002 		}
1003 	}
1004 	debug2("proposals match");
1005 	return (1);
1006 }
1007 
1008 static int
1009 kexalgs_contains(char **peer, const char *ext)
1010 {
1011 	return has_any_alg(peer[PROPOSAL_KEX_ALGS], ext);
1012 }
1013 
1014 static int
1015 kex_choose_conf(struct ssh *ssh, uint32_t seq)
1016 {
1017 	struct kex *kex = ssh->kex;
1018 	struct newkeys *newkeys;
1019 	char **my = NULL, **peer = NULL;
1020 	char **cprop, **sprop;
1021 	int nenc, nmac, ncomp;
1022 	u_int mode, ctos, need, dh_need, authlen;
1023 	int r, first_kex_follows;
1024 
1025 	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
1026 	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
1027 		goto out;
1028 	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
1029 	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
1030 		goto out;
1031 
1032 	if (kex->server) {
1033 		cprop=peer;
1034 		sprop=my;
1035 	} else {
1036 		cprop=my;
1037 		sprop=peer;
1038 	}
1039 
1040 	/* Check whether peer supports ext_info/kex_strict */
1041 	if ((kex->flags & KEX_INITIAL) != 0) {
1042 		if (kex->server) {
1043 			kex->ext_info_c = kexalgs_contains(peer, "ext-info-c");
1044 			kex->kex_strict = kexalgs_contains(peer,
1045 			    "kex-strict-c-v00@openssh.com");
1046 		} else {
1047 			kex->kex_strict = kexalgs_contains(peer,
1048 			    "kex-strict-s-v00@openssh.com");
1049 		}
1050 		if (kex->kex_strict) {
1051 			debug3_f("will use strict KEX ordering");
1052 			if (seq != 0)
1053 				ssh_packet_disconnect(ssh,
1054 				    "strict KEX violation: "
1055 				    "KEXINIT was not the first packet");
1056 		}
1057 	}
1058 
1059 	/* Check whether client supports rsa-sha2 algorithms */
1060 	if (kex->server && (kex->flags & KEX_INITIAL)) {
1061 		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1062 		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
1063 			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
1064 		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
1065 		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
1066 			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
1067 	}
1068 
1069 	/* Algorithm Negotiation */
1070 	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
1071 	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
1072 		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
1073 		peer[PROPOSAL_KEX_ALGS] = NULL;
1074 		goto out;
1075 	}
1076 	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
1077 	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
1078 		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
1079 		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
1080 		goto out;
1081 	}
1082 	for (mode = 0; mode < MODE_MAX; mode++) {
1083 		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
1084 			r = SSH_ERR_ALLOC_FAIL;
1085 			goto out;
1086 		}
1087 		kex->newkeys[mode] = newkeys;
1088 		ctos = (!kex->server && mode == MODE_OUT) ||
1089 		    (kex->server && mode == MODE_IN);
1090 		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
1091 		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
1092 		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
1093 		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
1094 		    sprop[nenc])) != 0) {
1095 			kex->failed_choice = peer[nenc];
1096 			peer[nenc] = NULL;
1097 			goto out;
1098 		}
1099 		authlen = cipher_authlen(newkeys->enc.cipher);
1100 		/* ignore mac for authenticated encryption */
1101 		if (authlen == 0 &&
1102 		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
1103 		    sprop[nmac])) != 0) {
1104 			kex->failed_choice = peer[nmac];
1105 			peer[nmac] = NULL;
1106 			goto out;
1107 		}
1108 		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
1109 		    sprop[ncomp])) != 0) {
1110 			kex->failed_choice = peer[ncomp];
1111 			peer[ncomp] = NULL;
1112 			goto out;
1113 		}
1114 		debug("kex: %s cipher: %s MAC: %s compression: %s",
1115 		    ctos ? "client->server" : "server->client",
1116 		    newkeys->enc.name,
1117 		    authlen == 0 ? newkeys->mac.name : "<implicit>",
1118 		    newkeys->comp.name);
1119 	}
1120 	need = dh_need = 0;
1121 	for (mode = 0; mode < MODE_MAX; mode++) {
1122 		newkeys = kex->newkeys[mode];
1123 		need = MAXIMUM(need, newkeys->enc.key_len);
1124 		need = MAXIMUM(need, newkeys->enc.block_size);
1125 		need = MAXIMUM(need, newkeys->enc.iv_len);
1126 		need = MAXIMUM(need, newkeys->mac.key_len);
1127 		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
1128 		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
1129 		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
1130 		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
1131 	}
1132 	/* XXX need runden? */
1133 	kex->we_need = need;
1134 	kex->dh_need = dh_need;
1135 
1136 	/* ignore the next message if the proposals do not match */
1137 	if (first_kex_follows && !proposals_match(my, peer))
1138 		ssh->dispatch_skip_packets = 1;
1139 	r = 0;
1140  out:
1141 	kex_prop_free(my);
1142 	kex_prop_free(peer);
1143 	return r;
1144 }
1145 
1146 static int
1147 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1148     const struct sshbuf *shared_secret, u_char **keyp)
1149 {
1150 	struct kex *kex = ssh->kex;
1151 	struct ssh_digest_ctx *hashctx = NULL;
1152 	char c = id;
1153 	u_int have;
1154 	size_t mdsz;
1155 	u_char *digest;
1156 	int r;
1157 
1158 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1159 		return SSH_ERR_INVALID_ARGUMENT;
1160 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1161 		r = SSH_ERR_ALLOC_FAIL;
1162 		goto out;
1163 	}
1164 
1165 	/* K1 = HASH(K || H || "A" || session_id) */
1166 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1167 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1168 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1169 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1170 	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1171 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1172 		r = SSH_ERR_LIBCRYPTO_ERROR;
1173 		error_f("KEX hash failed");
1174 		goto out;
1175 	}
1176 	ssh_digest_free(hashctx);
1177 	hashctx = NULL;
1178 
1179 	/*
1180 	 * expand key:
1181 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1182 	 * Key = K1 || K2 || ... || Kn
1183 	 */
1184 	for (have = mdsz; need > have; have += mdsz) {
1185 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1186 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1187 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1188 		    ssh_digest_update(hashctx, digest, have) != 0 ||
1189 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1190 			error_f("KDF failed");
1191 			r = SSH_ERR_LIBCRYPTO_ERROR;
1192 			goto out;
1193 		}
1194 		ssh_digest_free(hashctx);
1195 		hashctx = NULL;
1196 	}
1197 #ifdef DEBUG_KEX
1198 	fprintf(stderr, "key '%c'== ", c);
1199 	dump_digest("key", digest, need);
1200 #endif
1201 	*keyp = digest;
1202 	digest = NULL;
1203 	r = 0;
1204  out:
1205 	free(digest);
1206 	ssh_digest_free(hashctx);
1207 	return r;
1208 }
1209 
1210 #define NKEYS	6
1211 int
1212 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1213     const struct sshbuf *shared_secret)
1214 {
1215 	struct kex *kex = ssh->kex;
1216 	u_char *keys[NKEYS];
1217 	u_int i, j, mode, ctos;
1218 	int r;
1219 
1220 	/* save initial hash as session id */
1221 	if ((kex->flags & KEX_INITIAL) != 0) {
1222 		if (sshbuf_len(kex->session_id) != 0) {
1223 			error_f("already have session ID at kex");
1224 			return SSH_ERR_INTERNAL_ERROR;
1225 		}
1226 		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1227 			return r;
1228 	} else if (sshbuf_len(kex->session_id) == 0) {
1229 		error_f("no session ID in rekex");
1230 		return SSH_ERR_INTERNAL_ERROR;
1231 	}
1232 	for (i = 0; i < NKEYS; i++) {
1233 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1234 		    shared_secret, &keys[i])) != 0) {
1235 			for (j = 0; j < i; j++)
1236 				free(keys[j]);
1237 			return r;
1238 		}
1239 	}
1240 	for (mode = 0; mode < MODE_MAX; mode++) {
1241 		ctos = (!kex->server && mode == MODE_OUT) ||
1242 		    (kex->server && mode == MODE_IN);
1243 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1244 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1245 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1246 	}
1247 	return 0;
1248 }
1249 
1250 int
1251 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1252 {
1253 	struct kex *kex = ssh->kex;
1254 
1255 	*pubp = NULL;
1256 	*prvp = NULL;
1257 	if (kex->load_host_public_key == NULL ||
1258 	    kex->load_host_private_key == NULL) {
1259 		error_f("missing hostkey loader");
1260 		return SSH_ERR_INVALID_ARGUMENT;
1261 	}
1262 	*pubp = kex->load_host_public_key(kex->hostkey_type,
1263 	    kex->hostkey_nid, ssh);
1264 	*prvp = kex->load_host_private_key(kex->hostkey_type,
1265 	    kex->hostkey_nid, ssh);
1266 	if (*pubp == NULL)
1267 		return SSH_ERR_NO_HOSTKEY_LOADED;
1268 	return 0;
1269 }
1270 
1271 int
1272 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1273 {
1274 	struct kex *kex = ssh->kex;
1275 
1276 	if (kex->verify_host_key == NULL) {
1277 		error_f("missing hostkey verifier");
1278 		return SSH_ERR_INVALID_ARGUMENT;
1279 	}
1280 	if (server_host_key->type != kex->hostkey_type ||
1281 	    (kex->hostkey_type == KEY_ECDSA &&
1282 	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1283 		return SSH_ERR_KEY_TYPE_MISMATCH;
1284 	if (kex->verify_host_key(server_host_key, ssh) == -1)
1285 		return  SSH_ERR_SIGNATURE_INVALID;
1286 	return 0;
1287 }
1288 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/* Debug helper: print a label line to stderr, then dump len bytes of data. */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
1297 
1298 /*
1299  * Send a plaintext error message to the peer, suffixed by \r\n.
1300  * Only used during banner exchange, and there only for the server.
1301  */
1302 static void
1303 send_error(struct ssh *ssh, char *msg)
1304 {
1305 	char *crnl = "\r\n";
1306 
1307 	if (!ssh->kex->server)
1308 		return;
1309 
1310 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1311 	    msg, strlen(msg)) != strlen(msg) ||
1312 	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1313 	    crnl, strlen(crnl)) != strlen(crnl))
1314 		error_f("write: %.100s", strerror(errno));
1315 }
1316 
1317 /*
1318  * Sends our identification string and waits for the peer's. Will block for
1319  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1320  * Returns on 0 success or a ssherr.h code on failure.
1321  */
1322 int
1323 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1324     const char *version_addendum)
1325 {
1326 	int remote_major, remote_minor, mismatch, oerrno = 0;
1327 	size_t len, n;
1328 	int r, expect_nl;
1329 	u_char c;
1330 	struct sshbuf *our_version = ssh->kex->server ?
1331 	    ssh->kex->server_version : ssh->kex->client_version;
1332 	struct sshbuf *peer_version = ssh->kex->server ?
1333 	    ssh->kex->client_version : ssh->kex->server_version;
1334 	char *our_version_string = NULL, *peer_version_string = NULL;
1335 	char *cp, *remote_version = NULL;
1336 
1337 	/* Prepare and send our banner */
1338 	sshbuf_reset(our_version);
1339 	if (version_addendum != NULL && *version_addendum == '\0')
1340 		version_addendum = NULL;
1341 	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%.100s%s%s\r\n",
1342 	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1343 	    version_addendum == NULL ? "" : " ",
1344 	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1345 		oerrno = errno;
1346 		error_fr(r, "sshbuf_putf");
1347 		goto out;
1348 	}
1349 
1350 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1351 	    sshbuf_mutable_ptr(our_version),
1352 	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1353 		oerrno = errno;
1354 		debug_f("write: %.100s", strerror(errno));
1355 		r = SSH_ERR_SYSTEM_ERROR;
1356 		goto out;
1357 	}
1358 	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1359 		oerrno = errno;
1360 		error_fr(r, "sshbuf_consume_end");
1361 		goto out;
1362 	}
1363 	our_version_string = sshbuf_dup_string(our_version);
1364 	if (our_version_string == NULL) {
1365 		error_f("sshbuf_dup_string failed");
1366 		r = SSH_ERR_ALLOC_FAIL;
1367 		goto out;
1368 	}
1369 	debug("Local version string %.100s", our_version_string);
1370 
1371 	/* Read other side's version identification. */
1372 	for (n = 0; ; n++) {
1373 		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1374 			send_error(ssh, "No SSH identification string "
1375 			    "received.");
1376 			error_f("No SSH version received in first %u lines "
1377 			    "from server", SSH_MAX_PRE_BANNER_LINES);
1378 			r = SSH_ERR_INVALID_FORMAT;
1379 			goto out;
1380 		}
1381 		sshbuf_reset(peer_version);
1382 		expect_nl = 0;
1383 		for (;;) {
1384 			if (timeout_ms > 0) {
1385 				r = waitrfd(ssh_packet_get_connection_in(ssh),
1386 				    &timeout_ms, NULL);
1387 				if (r == -1 && errno == ETIMEDOUT) {
1388 					send_error(ssh, "Timed out waiting "
1389 					    "for SSH identification string.");
1390 					error("Connection timed out during "
1391 					    "banner exchange");
1392 					r = SSH_ERR_CONN_TIMEOUT;
1393 					goto out;
1394 				} else if (r == -1) {
1395 					oerrno = errno;
1396 					error_f("%s", strerror(errno));
1397 					r = SSH_ERR_SYSTEM_ERROR;
1398 					goto out;
1399 				}
1400 			}
1401 
1402 			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1403 			    &c, 1);
1404 			if (len != 1 && errno == EPIPE) {
1405 				verbose_f("Connection closed by remote host");
1406 				r = SSH_ERR_CONN_CLOSED;
1407 				goto out;
1408 			} else if (len != 1) {
1409 				oerrno = errno;
1410 				error_f("read: %.100s", strerror(errno));
1411 				r = SSH_ERR_SYSTEM_ERROR;
1412 				goto out;
1413 			}
1414 			if (c == '\r') {
1415 				expect_nl = 1;
1416 				continue;
1417 			}
1418 			if (c == '\n')
1419 				break;
1420 			if (c == '\0' || expect_nl) {
1421 				verbose_f("banner line contains invalid "
1422 				    "characters");
1423 				goto invalid;
1424 			}
1425 			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1426 				oerrno = errno;
1427 				error_fr(r, "sshbuf_put");
1428 				goto out;
1429 			}
1430 			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1431 				verbose_f("banner line too long");
1432 				goto invalid;
1433 			}
1434 		}
1435 		/* Is this an actual protocol banner? */
1436 		if (sshbuf_len(peer_version) > 4 &&
1437 		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1438 			break;
1439 		/* If not, then just log the line and continue */
1440 		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1441 			error_f("sshbuf_dup_string failed");
1442 			r = SSH_ERR_ALLOC_FAIL;
1443 			goto out;
1444 		}
1445 		/* Do not accept lines before the SSH ident from a client */
1446 		if (ssh->kex->server) {
1447 			verbose_f("client sent invalid protocol identifier "
1448 			    "\"%.256s\"", cp);
1449 			free(cp);
1450 			goto invalid;
1451 		}
1452 		debug_f("banner line %zu: %s", n, cp);
1453 		free(cp);
1454 	}
1455 	peer_version_string = sshbuf_dup_string(peer_version);
1456 	if (peer_version_string == NULL)
1457 		fatal_f("sshbuf_dup_string failed");
1458 	/* XXX must be same size for sscanf */
1459 	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1460 		error_f("calloc failed");
1461 		r = SSH_ERR_ALLOC_FAIL;
1462 		goto out;
1463 	}
1464 
1465 	/*
1466 	 * Check that the versions match.  In future this might accept
1467 	 * several versions and set appropriate flags to handle them.
1468 	 */
1469 	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1470 	    &remote_major, &remote_minor, remote_version) != 3) {
1471 		error("Bad remote protocol version identification: '%.100s'",
1472 		    peer_version_string);
1473  invalid:
1474 		send_error(ssh, "Invalid SSH identification string.");
1475 		r = SSH_ERR_INVALID_FORMAT;
1476 		goto out;
1477 	}
1478 	debug("Remote protocol version %d.%d, remote software version %.100s",
1479 	    remote_major, remote_minor, remote_version);
1480 	compat_banner(ssh, remote_version);
1481 
1482 	mismatch = 0;
1483 	switch (remote_major) {
1484 	case 2:
1485 		break;
1486 	case 1:
1487 		if (remote_minor != 99)
1488 			mismatch = 1;
1489 		break;
1490 	default:
1491 		mismatch = 1;
1492 		break;
1493 	}
1494 	if (mismatch) {
1495 		error("Protocol major versions differ: %d vs. %d",
1496 		    PROTOCOL_MAJOR_2, remote_major);
1497 		send_error(ssh, "Protocol major versions differ.");
1498 		r = SSH_ERR_NO_PROTOCOL_VERSION;
1499 		goto out;
1500 	}
1501 
1502 	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1503 		logit("probed from %s port %d with %s.  Don't panic.",
1504 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1505 		    peer_version_string);
1506 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1507 		goto out;
1508 	}
1509 	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1510 		logit("scanned from %s port %d with %s.  Don't panic.",
1511 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1512 		    peer_version_string);
1513 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1514 		goto out;
1515 	}
1516 	/* success */
1517 	r = 0;
1518  out:
1519 	free(our_version_string);
1520 	free(peer_version_string);
1521 	free(remote_version);
1522 	if (r == SSH_ERR_SYSTEM_ERROR)
1523 		errno = oerrno;
1524 	return r;
1525 }
1526 
1527