xref: /freebsd/crypto/openssh/kex.c (revision bdd1243df58e60e85101c09001d9812a789b6bc4)
1 /* $OpenBSD: kex.c,v 1.178 2023/03/12 10:40:39 dtucker Exp $ */
2 /*
3  * Copyright (c) 2000, 2001 Markus Friedl.  All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
15  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
16  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
18  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
19  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
21  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
23  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24  */
25 
26 #include "includes.h"
27 
28 #include <sys/types.h>
29 #include <errno.h>
30 #include <signal.h>
31 #include <stdarg.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <unistd.h>
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39 
40 #ifdef WITH_OPENSSL
41 #include <openssl/crypto.h>
42 #include <openssl/dh.h>
43 #endif
44 
45 #include "ssh.h"
46 #include "ssh2.h"
47 #include "atomicio.h"
48 #include "version.h"
49 #include "packet.h"
50 #include "compat.h"
51 #include "cipher.h"
52 #include "sshkey.h"
53 #include "kex.h"
54 #include "log.h"
55 #include "mac.h"
56 #include "match.h"
57 #include "misc.h"
58 #include "dispatch.h"
59 #include "monitor.h"
60 #include "myproposal.h"
61 
62 #include "ssherr.h"
63 #include "sshbuf.h"
64 #include "digest.h"
65 #include "xmalloc.h"
66 
/*
 * Prototypes for file-local functions used before their definitions:
 * kex_choose_conf() is called from kex_input_kexinit(), and
 * kex_input_newkeys() is installed as a handler by kex_send_newkeys().
 */
static int kex_choose_conf(struct ssh *);
static int kex_input_newkeys(int, u_int32_t, struct ssh *);
70 
/*
 * Human-readable labels for each slot of a KEXINIT proposal, indexed by
 * proposal slot (0 .. PROPOSAL_MAX-1). Used for debug logging in
 * kex_buf2prop(). NOTE(review): entry order is assumed to match the
 * PROPOSAL_* index constants declared in kex.h; confirm when adding slots.
 */
static const char * const proposal_names[PROPOSAL_MAX] = {
	"KEX algorithms",
	"host key algorithms",
	"ciphers ctos",
	"ciphers stoc",
	"MACs ctos",
	"MACs stoc",
	"compression ctos",
	"compression stoc",
	"languages ctos",
	"languages stoc",
};
83 
/* Description of one key-exchange method this build supports. */
struct kexalg {
	char *name;	/* method name as it appears in KEXINIT lists */
	u_int type;	/* KEX_* selector; choose_kex() copies it to kex->kex_type */
	int ec_nid;	/* curve NID for ECDH methods, 0 for the others */
	int hash_alg;	/* SSH_DIGEST_* copied to kex->hash_alg (used by derive_key) */
};
/*
 * Supported KEX methods, terminated by a NULL-name sentinel. Entries are
 * compiled in or out depending on the crypto backend's capabilities. The
 * order here is the order in which kex_alg_list() reports them.
 */
static const struct kexalg kexalgs[] = {
#ifdef WITH_OPENSSL
	{ KEX_DH1, KEX_DH_GRP1_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA1, KEX_DH_GRP14_SHA1, 0, SSH_DIGEST_SHA1 },
	{ KEX_DH14_SHA256, KEX_DH_GRP14_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_DH16_SHA512, KEX_DH_GRP16_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DH18_SHA512, KEX_DH_GRP18_SHA512, 0, SSH_DIGEST_SHA512 },
	{ KEX_DHGEX_SHA1, KEX_DH_GEX_SHA1, 0, SSH_DIGEST_SHA1 },
#ifdef HAVE_EVP_SHA256
	{ KEX_DHGEX_SHA256, KEX_DH_GEX_SHA256, 0, SSH_DIGEST_SHA256 },
#endif /* HAVE_EVP_SHA256 */
#ifdef OPENSSL_HAS_ECC
	{ KEX_ECDH_SHA2_NISTP256, KEX_ECDH_SHA2,
	    NID_X9_62_prime256v1, SSH_DIGEST_SHA256 },
	{ KEX_ECDH_SHA2_NISTP384, KEX_ECDH_SHA2, NID_secp384r1,
	    SSH_DIGEST_SHA384 },
# ifdef OPENSSL_HAS_NISTP521
	{ KEX_ECDH_SHA2_NISTP521, KEX_ECDH_SHA2, NID_secp521r1,
	    SSH_DIGEST_SHA512 },
# endif /* OPENSSL_HAS_NISTP521 */
#endif /* OPENSSL_HAS_ECC */
#endif /* WITH_OPENSSL */
#if defined(HAVE_EVP_SHA256) || !defined(WITH_OPENSSL)
	{ KEX_CURVE25519_SHA256, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
	{ KEX_CURVE25519_SHA256_OLD, KEX_C25519_SHA256, 0, SSH_DIGEST_SHA256 },
#ifdef USE_SNTRUP761X25519
	{ KEX_SNTRUP761X25519_SHA512, KEX_KEM_SNTRUP761X25519_SHA512, 0,
	    SSH_DIGEST_SHA512 },
#endif
#endif /* HAVE_EVP_SHA256 || !WITH_OPENSSL */
	{ NULL, 0, -1, -1},	/* sentinel: iteration stops at name == NULL */
};
122 
123 char *
124 kex_alg_list(char sep)
125 {
126 	char *ret = NULL, *tmp;
127 	size_t nlen, rlen = 0;
128 	const struct kexalg *k;
129 
130 	for (k = kexalgs; k->name != NULL; k++) {
131 		if (ret != NULL)
132 			ret[rlen++] = sep;
133 		nlen = strlen(k->name);
134 		if ((tmp = realloc(ret, rlen + nlen + 2)) == NULL) {
135 			free(ret);
136 			return NULL;
137 		}
138 		ret = tmp;
139 		memcpy(ret + rlen, k->name, nlen + 1);
140 		rlen += nlen;
141 	}
142 	return ret;
143 }
144 
145 static const struct kexalg *
146 kex_alg_by_name(const char *name)
147 {
148 	const struct kexalg *k;
149 
150 	for (k = kexalgs; k->name != NULL; k++) {
151 		if (strcmp(k->name, name) == 0)
152 			return k;
153 	}
154 	return NULL;
155 }
156 
/*
 * Validate a comma-separated KEX method name list.
 * Returns 1 if every name (up to the first empty element) is a supported
 * method, 0 if the list is NULL/empty, contains an unsupported name, or
 * memory could not be allocated.
 */
int
kex_names_valid(const char *names)
{
	char *copy, *cursor, *name;
	int ok = 1;

	if (names == NULL || *names == '\0')
		return 0;
	if ((copy = strdup(names)) == NULL)
		return 0;

	/* strsep() mutates its input, so tokenise a private copy. */
	cursor = copy;
	while ((name = strsep(&cursor, ",")) != NULL && *name != '\0') {
		if (kex_alg_by_name(name) == NULL) {
			error("Unsupported KEX algorithm \"%.100s\"", name);
			ok = 0;
			break;
		}
	}
	if (ok)
		debug3("kex names ok: [%s]", names);
	free(copy);
	return ok;
}
179 
/*
 * Concatenate two comma-separated algorithm lists, a followed by b,
 * skipping every name of b that is already present in the result built so
 * far (so duplicates within b are dropped too).
 *
 * A NULL argument is treated as the empty list. Processing of b stops at
 * its first empty element (e.g. "x,,y" contributes only "x"), matching
 * kex_names_valid().
 *
 * Returns a newly allocated string the caller must free, or NULL on
 * allocation failure or when b is longer than 1MB.
 */
char *
kex_names_cat(const char *a, const char *b)
{
	char *ret = NULL, *tmp = NULL, *cp, *p, *m;
	size_t len;

	/* strdup(NULL) is undefined: normalise NULL inputs to "". */
	if (a == NULL)
		a = "";
	if (b == NULL)
		b = "";
	if (*a == '\0')
		return strdup(b);
	if (*b == '\0')
		return strdup(a);
	if (strlen(b) > 1024*1024)
		return NULL;
	/* Worst case: all of a, a comma, all of b, and the NUL. */
	len = strlen(a) + strlen(b) + 2;
	if ((tmp = cp = strdup(b)) == NULL ||
	    (ret = calloc(1, len)) == NULL) {
		free(tmp);
		return NULL;
	}
	strlcpy(ret, a, len);
	for ((p = strsep(&cp, ",")); p && *p != '\0'; (p = strsep(&cp, ","))) {
		if ((m = match_list(ret, p, NULL)) != NULL) {
			free(m);
			continue; /* Algorithm already present */
		}
		if (strlcat(ret, ",", len) >= len ||
		    strlcat(ret, p, len) >= len) {
			free(tmp);
			free(ret);
			return NULL; /* Shouldn't happen */
		}
	}
	free(tmp);
	return ret;
}
218 
/*
 * Assemble a list of algorithms from a default list and a string from a
 * configuration file. The user-provided string may begin with '+' to
 * indicate that it should be appended to the default, '-' that the
 * specified names should be removed, or '^' that they should be placed
 * at the head. Otherwise it replaces the default outright.
 *
 * listp: in/out. On entry *listp is the user string (heap-allocated, since
 *        it is freed here) or NULL/"" to select `def`. On success *listp is
 *        replaced by a newly allocated list. On failure the original
 *        string has been freed and *listp is NULL.
 * def:   default algorithm list.
 * all:   every supported algorithm; user patterns (other than in the '-'
 *        case) are expanded against this list.
 *
 * Returns 0 on success, SSH_ERR_INVALID_ARGUMENT for NULL arguments, a
 * negated ('!') pattern, or an empty result, and SSH_ERR_ALLOC_FAIL on
 * allocation failure.
 */
int
kex_assemble_names(char **listp, const char *def, const char *all)
{
	char *cp, *tmp, *patterns;
	char *list = NULL, *ret = NULL, *matching = NULL, *opatterns = NULL;
	int r = SSH_ERR_INTERNAL_ERROR;

	if (listp == NULL || def == NULL || all == NULL)
		return SSH_ERR_INVALID_ARGUMENT;

	/* No user setting: use a copy of the default as-is. */
	if (*listp == NULL || **listp == '\0') {
		if ((*listp = strdup(def)) == NULL)
			return SSH_ERR_ALLOC_FAIL;
		return 0;
	}

	/* Take ownership of the user string; *listp stays NULL until success. */
	list = *listp;
	*listp = NULL;
	if (*list == '+') {
		/* Append names to default list */
		if ((tmp = kex_names_cat(def, list + 1)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		list = tmp;
	} else if (*list == '-') {
		/* Remove names from default list */
		if ((*listp = match_filter_denylist(def, list + 1)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		/* filtering has already been done */
		/* NOTE(review): this result is not checked for emptiness. */
		return 0;
	} else if (*list == '^') {
		/* Place names at head of default list */
		if ((tmp = kex_names_cat(list + 1, def)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(list);
		list = tmp;
	} else {
		/* Explicit list, overrides default - just use "list" as is */
	}

	/*
	 * The supplied names may be a pattern-list. For the -list case,
	 * the patterns are applied above. For the +list, ^list and explicit
	 * list cases we need to do it now.
	 */
	ret = NULL;
	if ((patterns = opatterns = strdup(list)) == NULL) {
		r = SSH_ERR_ALLOC_FAIL;
		goto fail;
	}
	/* Apply positive (i.e. non-negated) patterns from the list */
	while ((cp = strsep(&patterns, ",")) != NULL) {
		if (*cp == '!') {
			/* negated matches are not supported here */
			r = SSH_ERR_INVALID_ARGUMENT;
			goto fail;
		}
		/* Expand this pattern against `all`, then merge without dups. */
		free(matching);
		if ((matching = match_filter_allowlist(all, cp)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		if ((tmp = kex_names_cat(ret, matching)) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto fail;
		}
		free(ret);
		ret = tmp;
	}
	if (ret == NULL || *ret == '\0') {
		/* An empty name-list is an error */
		/* XXX better error code? */
		r = SSH_ERR_INVALID_ARGUMENT;
		goto fail;
	}

	/* success */
	*listp = ret;
	ret = NULL;
	r = 0;

	/* Shared cleanup for both success and failure paths. */
 fail:
	free(matching);
	free(opatterns);
	free(list);
	free(ret);
	return r;
}
321 
322 /*
323  * Fill out a proposal array with dynamically allocated values, which may
324  * be modified as required for compatibility reasons.
325  * Any of the options may be NULL, in which case the default is used.
326  * Array contents must be freed by calling kex_proposal_free_entries.
327  */
328 void
329 kex_proposal_populate_entries(struct ssh *ssh, char *prop[PROPOSAL_MAX],
330     const char *kexalgos, const char *ciphers, const char *macs,
331     const char *comp, const char *hkalgs)
332 {
333 	const char *defpropserver[PROPOSAL_MAX] = { KEX_SERVER };
334 	const char *defpropclient[PROPOSAL_MAX] = { KEX_CLIENT };
335 	const char **defprop = ssh->kex->server ? defpropserver : defpropclient;
336 	u_int i;
337 
338 	if (prop == NULL)
339 		fatal_f("proposal missing");
340 
341 	for (i = 0; i < PROPOSAL_MAX; i++) {
342 		switch(i) {
343 		case PROPOSAL_KEX_ALGS:
344 			prop[i] = compat_kex_proposal(ssh,
345 			    kexalgos ? kexalgos : defprop[i]);
346 			break;
347 		case PROPOSAL_ENC_ALGS_CTOS:
348 		case PROPOSAL_ENC_ALGS_STOC:
349 			prop[i] = xstrdup(ciphers ? ciphers : defprop[i]);
350 			break;
351 		case PROPOSAL_MAC_ALGS_CTOS:
352 		case PROPOSAL_MAC_ALGS_STOC:
353 			prop[i]  = xstrdup(macs ? macs : defprop[i]);
354 			break;
355 		case PROPOSAL_COMP_ALGS_CTOS:
356 		case PROPOSAL_COMP_ALGS_STOC:
357 			prop[i] = xstrdup(comp ? comp : defprop[i]);
358 			break;
359 		case PROPOSAL_SERVER_HOST_KEY_ALGS:
360 			prop[i] = xstrdup(hkalgs ? hkalgs : defprop[i]);
361 			break;
362 		default:
363 			prop[i] = xstrdup(defprop[i]);
364 		}
365 	}
366 }
367 
368 void
369 kex_proposal_free_entries(char *prop[PROPOSAL_MAX])
370 {
371 	u_int i;
372 
373 	for (i = 0; i < PROPOSAL_MAX; i++)
374 		free(prop[i]);
375 }
376 
377 /* put algorithm proposal into buffer */
378 int
379 kex_prop2buf(struct sshbuf *b, char *proposal[PROPOSAL_MAX])
380 {
381 	u_int i;
382 	int r;
383 
384 	sshbuf_reset(b);
385 
386 	/*
387 	 * add a dummy cookie, the cookie will be overwritten by
388 	 * kex_send_kexinit(), each time a kexinit is set
389 	 */
390 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
391 		if ((r = sshbuf_put_u8(b, 0)) != 0)
392 			return r;
393 	}
394 	for (i = 0; i < PROPOSAL_MAX; i++) {
395 		if ((r = sshbuf_put_cstring(b, proposal[i])) != 0)
396 			return r;
397 	}
398 	if ((r = sshbuf_put_u8(b, 0)) != 0 ||	/* first_kex_packet_follows */
399 	    (r = sshbuf_put_u32(b, 0)) != 0)	/* uint32 reserved */
400 		return r;
401 	return 0;
402 }
403 
404 /* parse buffer and return algorithm proposal */
405 int
406 kex_buf2prop(struct sshbuf *raw, int *first_kex_follows, char ***propp)
407 {
408 	struct sshbuf *b = NULL;
409 	u_char v;
410 	u_int i;
411 	char **proposal = NULL;
412 	int r;
413 
414 	*propp = NULL;
415 	if ((proposal = calloc(PROPOSAL_MAX, sizeof(char *))) == NULL)
416 		return SSH_ERR_ALLOC_FAIL;
417 	if ((b = sshbuf_fromb(raw)) == NULL) {
418 		r = SSH_ERR_ALLOC_FAIL;
419 		goto out;
420 	}
421 	if ((r = sshbuf_consume(b, KEX_COOKIE_LEN)) != 0) { /* skip cookie */
422 		error_fr(r, "consume cookie");
423 		goto out;
424 	}
425 	/* extract kex init proposal strings */
426 	for (i = 0; i < PROPOSAL_MAX; i++) {
427 		if ((r = sshbuf_get_cstring(b, &(proposal[i]), NULL)) != 0) {
428 			error_fr(r, "parse proposal %u", i);
429 			goto out;
430 		}
431 		debug2("%s: %s", proposal_names[i], proposal[i]);
432 	}
433 	/* first kex follows / reserved */
434 	if ((r = sshbuf_get_u8(b, &v)) != 0 ||	/* first_kex_follows */
435 	    (r = sshbuf_get_u32(b, &i)) != 0) {	/* reserved */
436 		error_fr(r, "parse");
437 		goto out;
438 	}
439 	if (first_kex_follows != NULL)
440 		*first_kex_follows = v;
441 	debug2("first_kex_follows %d ", v);
442 	debug2("reserved %u ", i);
443 	r = 0;
444 	*propp = proposal;
445  out:
446 	if (r != 0 && proposal != NULL)
447 		kex_prop_free(proposal);
448 	sshbuf_free(b);
449 	return r;
450 }
451 
452 void
453 kex_prop_free(char **proposal)
454 {
455 	u_int i;
456 
457 	if (proposal == NULL)
458 		return;
459 	for (i = 0; i < PROPOSAL_MAX; i++)
460 		free(proposal[i]);
461 	free(proposal);
462 }
463 
464 int
465 kex_protocol_error(int type, u_int32_t seq, struct ssh *ssh)
466 {
467 	int r;
468 
469 	error("kex protocol error: type %d seq %u", type, seq);
470 	if ((r = sshpkt_start(ssh, SSH2_MSG_UNIMPLEMENTED)) != 0 ||
471 	    (r = sshpkt_put_u32(ssh, seq)) != 0 ||
472 	    (r = sshpkt_send(ssh)) != 0)
473 		return r;
474 	return 0;
475 }
476 
477 static void
478 kex_reset_dispatch(struct ssh *ssh)
479 {
480 	ssh_dispatch_range(ssh, SSH2_MSG_TRANSPORT_MIN,
481 	    SSH2_MSG_TRANSPORT_MAX, &kex_protocol_error);
482 }
483 
484 static int
485 kex_send_ext_info(struct ssh *ssh)
486 {
487 	int r;
488 	char *algs;
489 
490 	debug("Sending SSH2_MSG_EXT_INFO");
491 	if ((algs = sshkey_alg_list(0, 1, 1, ',')) == NULL)
492 		return SSH_ERR_ALLOC_FAIL;
493 	/* XXX filter algs list by allowed pubkey/hostbased types */
494 	if ((r = sshpkt_start(ssh, SSH2_MSG_EXT_INFO)) != 0 ||
495 	    (r = sshpkt_put_u32(ssh, 2)) != 0 ||
496 	    (r = sshpkt_put_cstring(ssh, "server-sig-algs")) != 0 ||
497 	    (r = sshpkt_put_cstring(ssh, algs)) != 0 ||
498 	    (r = sshpkt_put_cstring(ssh,
499 	    "publickey-hostbound@openssh.com")) != 0 ||
500 	    (r = sshpkt_put_cstring(ssh, "0")) != 0 ||
501 	    (r = sshpkt_send(ssh)) != 0) {
502 		error_fr(r, "compose");
503 		goto out;
504 	}
505 	/* success */
506 	r = 0;
507  out:
508 	free(algs);
509 	return r;
510 }
511 
512 int
513 kex_send_newkeys(struct ssh *ssh)
514 {
515 	int r;
516 
517 	kex_reset_dispatch(ssh);
518 	if ((r = sshpkt_start(ssh, SSH2_MSG_NEWKEYS)) != 0 ||
519 	    (r = sshpkt_send(ssh)) != 0)
520 		return r;
521 	debug("SSH2_MSG_NEWKEYS sent");
522 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_input_newkeys);
523 	if (ssh->kex->ext_info_c && (ssh->kex->flags & KEX_INITIAL) != 0)
524 		if ((r = kex_send_ext_info(ssh)) != 0)
525 			return r;
526 	debug("expecting SSH2_MSG_NEWKEYS");
527 	return 0;
528 }
529 
530 int
531 kex_input_ext_info(int type, u_int32_t seq, struct ssh *ssh)
532 {
533 	struct kex *kex = ssh->kex;
534 	u_int32_t i, ninfo;
535 	char *name;
536 	u_char *val;
537 	size_t vlen;
538 	int r;
539 
540 	debug("SSH2_MSG_EXT_INFO received");
541 	ssh_dispatch_set(ssh, SSH2_MSG_EXT_INFO, &kex_protocol_error);
542 	if ((r = sshpkt_get_u32(ssh, &ninfo)) != 0)
543 		return r;
544 	if (ninfo >= 1024) {
545 		error("SSH2_MSG_EXT_INFO with too many entries, expected "
546 		    "<=1024, received %u", ninfo);
547 		return SSH_ERR_INVALID_FORMAT;
548 	}
549 	for (i = 0; i < ninfo; i++) {
550 		if ((r = sshpkt_get_cstring(ssh, &name, NULL)) != 0)
551 			return r;
552 		if ((r = sshpkt_get_string(ssh, &val, &vlen)) != 0) {
553 			free(name);
554 			return r;
555 		}
556 		if (strcmp(name, "server-sig-algs") == 0) {
557 			/* Ensure no \0 lurking in value */
558 			if (memchr(val, '\0', vlen) != NULL) {
559 				error_f("nul byte in %s", name);
560 				return SSH_ERR_INVALID_FORMAT;
561 			}
562 			debug_f("%s=<%s>", name, val);
563 			kex->server_sig_algs = val;
564 			val = NULL;
565 		} else if (strcmp(name,
566 		    "publickey-hostbound@openssh.com") == 0) {
567 			/* XXX refactor */
568 			/* Ensure no \0 lurking in value */
569 			if (memchr(val, '\0', vlen) != NULL) {
570 				error_f("nul byte in %s", name);
571 				return SSH_ERR_INVALID_FORMAT;
572 			}
573 			debug_f("%s=<%s>", name, val);
574 			if (strcmp(val, "0") == 0)
575 				kex->flags |= KEX_HAS_PUBKEY_HOSTBOUND;
576 			else {
577 				debug_f("unsupported version of %s extension",
578 				    name);
579 			}
580 		} else
581 			debug_f("%s (unrecognised)", name);
582 		free(name);
583 		free(val);
584 	}
585 	return sshpkt_get_end(ssh);
586 }
587 
588 static int
589 kex_input_newkeys(int type, u_int32_t seq, struct ssh *ssh)
590 {
591 	struct kex *kex = ssh->kex;
592 	int r;
593 
594 	debug("SSH2_MSG_NEWKEYS received");
595 	ssh_dispatch_set(ssh, SSH2_MSG_NEWKEYS, &kex_protocol_error);
596 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
597 	if ((r = sshpkt_get_end(ssh)) != 0)
598 		return r;
599 	if ((r = ssh_set_newkeys(ssh, MODE_IN)) != 0)
600 		return r;
601 	kex->done = 1;
602 	kex->flags &= ~KEX_INITIAL;
603 	sshbuf_reset(kex->peer);
604 	/* sshbuf_reset(kex->my); */
605 	kex->flags &= ~KEX_INIT_SENT;
606 	free(kex->name);
607 	kex->name = NULL;
608 	return 0;
609 }
610 
611 int
612 kex_send_kexinit(struct ssh *ssh)
613 {
614 	u_char *cookie;
615 	struct kex *kex = ssh->kex;
616 	int r;
617 
618 	if (kex == NULL) {
619 		error_f("no kex");
620 		return SSH_ERR_INTERNAL_ERROR;
621 	}
622 	if (kex->flags & KEX_INIT_SENT)
623 		return 0;
624 	kex->done = 0;
625 
626 	/* generate a random cookie */
627 	if (sshbuf_len(kex->my) < KEX_COOKIE_LEN) {
628 		error_f("bad kex length: %zu < %d",
629 		    sshbuf_len(kex->my), KEX_COOKIE_LEN);
630 		return SSH_ERR_INVALID_FORMAT;
631 	}
632 	if ((cookie = sshbuf_mutable_ptr(kex->my)) == NULL) {
633 		error_f("buffer error");
634 		return SSH_ERR_INTERNAL_ERROR;
635 	}
636 	arc4random_buf(cookie, KEX_COOKIE_LEN);
637 
638 	if ((r = sshpkt_start(ssh, SSH2_MSG_KEXINIT)) != 0 ||
639 	    (r = sshpkt_putb(ssh, kex->my)) != 0 ||
640 	    (r = sshpkt_send(ssh)) != 0) {
641 		error_fr(r, "compose reply");
642 		return r;
643 	}
644 	debug("SSH2_MSG_KEXINIT sent");
645 	kex->flags |= KEX_INIT_SENT;
646 	return 0;
647 }
648 
649 int
650 kex_input_kexinit(int type, u_int32_t seq, struct ssh *ssh)
651 {
652 	struct kex *kex = ssh->kex;
653 	const u_char *ptr;
654 	u_int i;
655 	size_t dlen;
656 	int r;
657 
658 	debug("SSH2_MSG_KEXINIT received");
659 	if (kex == NULL) {
660 		error_f("no kex");
661 		return SSH_ERR_INTERNAL_ERROR;
662 	}
663 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, NULL);
664 	ptr = sshpkt_ptr(ssh, &dlen);
665 	if ((r = sshbuf_put(kex->peer, ptr, dlen)) != 0)
666 		return r;
667 
668 	/* discard packet */
669 	for (i = 0; i < KEX_COOKIE_LEN; i++) {
670 		if ((r = sshpkt_get_u8(ssh, NULL)) != 0) {
671 			error_fr(r, "discard cookie");
672 			return r;
673 		}
674 	}
675 	for (i = 0; i < PROPOSAL_MAX; i++) {
676 		if ((r = sshpkt_get_string(ssh, NULL, NULL)) != 0) {
677 			error_fr(r, "discard proposal");
678 			return r;
679 		}
680 	}
681 	/*
682 	 * XXX RFC4253 sec 7: "each side MAY guess" - currently no supported
683 	 * KEX method has the server move first, but a server might be using
684 	 * a custom method or one that we otherwise don't support. We should
685 	 * be prepared to remember first_kex_follows here so we can eat a
686 	 * packet later.
687 	 * XXX2 - RFC4253 is kind of ambiguous on what first_kex_follows means
688 	 * for cases where the server *doesn't* go first. I guess we should
689 	 * ignore it when it is set for these cases, which is what we do now.
690 	 */
691 	if ((r = sshpkt_get_u8(ssh, NULL)) != 0 ||	/* first_kex_follows */
692 	    (r = sshpkt_get_u32(ssh, NULL)) != 0 ||	/* reserved */
693 	    (r = sshpkt_get_end(ssh)) != 0)
694 			return r;
695 
696 	if (!(kex->flags & KEX_INIT_SENT))
697 		if ((r = kex_send_kexinit(ssh)) != 0)
698 			return r;
699 	if ((r = kex_choose_conf(ssh)) != 0)
700 		return r;
701 
702 	if (kex->kex_type < KEX_MAX && kex->kex[kex->kex_type] != NULL)
703 		return (kex->kex[kex->kex_type])(ssh);
704 
705 	error_f("unknown kex type %u", kex->kex_type);
706 	return SSH_ERR_INTERNAL_ERROR;
707 }
708 
709 struct kex *
710 kex_new(void)
711 {
712 	struct kex *kex;
713 
714 	if ((kex = calloc(1, sizeof(*kex))) == NULL ||
715 	    (kex->peer = sshbuf_new()) == NULL ||
716 	    (kex->my = sshbuf_new()) == NULL ||
717 	    (kex->client_version = sshbuf_new()) == NULL ||
718 	    (kex->server_version = sshbuf_new()) == NULL ||
719 	    (kex->session_id = sshbuf_new()) == NULL) {
720 		kex_free(kex);
721 		return NULL;
722 	}
723 	return kex;
724 }
725 
726 void
727 kex_free_newkeys(struct newkeys *newkeys)
728 {
729 	if (newkeys == NULL)
730 		return;
731 	if (newkeys->enc.key) {
732 		explicit_bzero(newkeys->enc.key, newkeys->enc.key_len);
733 		free(newkeys->enc.key);
734 		newkeys->enc.key = NULL;
735 	}
736 	if (newkeys->enc.iv) {
737 		explicit_bzero(newkeys->enc.iv, newkeys->enc.iv_len);
738 		free(newkeys->enc.iv);
739 		newkeys->enc.iv = NULL;
740 	}
741 	free(newkeys->enc.name);
742 	explicit_bzero(&newkeys->enc, sizeof(newkeys->enc));
743 	free(newkeys->comp.name);
744 	explicit_bzero(&newkeys->comp, sizeof(newkeys->comp));
745 	mac_clear(&newkeys->mac);
746 	if (newkeys->mac.key) {
747 		explicit_bzero(newkeys->mac.key, newkeys->mac.key_len);
748 		free(newkeys->mac.key);
749 		newkeys->mac.key = NULL;
750 	}
751 	free(newkeys->mac.name);
752 	explicit_bzero(&newkeys->mac, sizeof(newkeys->mac));
753 	freezero(newkeys, sizeof(*newkeys));
754 }
755 
756 void
757 kex_free(struct kex *kex)
758 {
759 	u_int mode;
760 
761 	if (kex == NULL)
762 		return;
763 
764 #ifdef WITH_OPENSSL
765 	DH_free(kex->dh);
766 #ifdef OPENSSL_HAS_ECC
767 	EC_KEY_free(kex->ec_client_key);
768 #endif /* OPENSSL_HAS_ECC */
769 #endif /* WITH_OPENSSL */
770 	for (mode = 0; mode < MODE_MAX; mode++) {
771 		kex_free_newkeys(kex->newkeys[mode]);
772 		kex->newkeys[mode] = NULL;
773 	}
774 	sshbuf_free(kex->peer);
775 	sshbuf_free(kex->my);
776 	sshbuf_free(kex->client_version);
777 	sshbuf_free(kex->server_version);
778 	sshbuf_free(kex->client_pub);
779 	sshbuf_free(kex->session_id);
780 	sshbuf_free(kex->initial_sig);
781 	sshkey_free(kex->initial_hostkey);
782 	free(kex->failed_choice);
783 	free(kex->hostkey_alg);
784 	free(kex->name);
785 	free(kex);
786 }
787 
788 int
789 kex_ready(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
790 {
791 	int r;
792 
793 	if ((r = kex_prop2buf(ssh->kex->my, proposal)) != 0)
794 		return r;
795 	ssh->kex->flags = KEX_INITIAL;
796 	kex_reset_dispatch(ssh);
797 	ssh_dispatch_set(ssh, SSH2_MSG_KEXINIT, &kex_input_kexinit);
798 	return 0;
799 }
800 
801 int
802 kex_setup(struct ssh *ssh, char *proposal[PROPOSAL_MAX])
803 {
804 	int r;
805 
806 	if ((r = kex_ready(ssh, proposal)) != 0)
807 		return r;
808 	if ((r = kex_send_kexinit(ssh)) != 0) {		/* we start */
809 		kex_free(ssh->kex);
810 		ssh->kex = NULL;
811 		return r;
812 	}
813 	return 0;
814 }
815 
816 /*
817  * Request key re-exchange, returns 0 on success or a ssherr.h error
818  * code otherwise. Must not be called if KEX is incomplete or in-progress.
819  */
820 int
821 kex_start_rekex(struct ssh *ssh)
822 {
823 	if (ssh->kex == NULL) {
824 		error_f("no kex");
825 		return SSH_ERR_INTERNAL_ERROR;
826 	}
827 	if (ssh->kex->done == 0) {
828 		error_f("requested twice");
829 		return SSH_ERR_INTERNAL_ERROR;
830 	}
831 	ssh->kex->done = 0;
832 	return kex_send_kexinit(ssh);
833 }
834 
835 static int
836 choose_enc(struct sshenc *enc, char *client, char *server)
837 {
838 	char *name = match_list(client, server, NULL);
839 
840 	if (name == NULL)
841 		return SSH_ERR_NO_CIPHER_ALG_MATCH;
842 	if ((enc->cipher = cipher_by_name(name)) == NULL) {
843 		error_f("unsupported cipher %s", name);
844 		free(name);
845 		return SSH_ERR_INTERNAL_ERROR;
846 	}
847 	enc->name = name;
848 	enc->enabled = 0;
849 	enc->iv = NULL;
850 	enc->iv_len = cipher_ivlen(enc->cipher);
851 	enc->key = NULL;
852 	enc->key_len = cipher_keylen(enc->cipher);
853 	enc->block_size = cipher_blocksize(enc->cipher);
854 	return 0;
855 }
856 
857 static int
858 choose_mac(struct ssh *ssh, struct sshmac *mac, char *client, char *server)
859 {
860 	char *name = match_list(client, server, NULL);
861 
862 	if (name == NULL)
863 		return SSH_ERR_NO_MAC_ALG_MATCH;
864 	if (mac_setup(mac, name) < 0) {
865 		error_f("unsupported MAC %s", name);
866 		free(name);
867 		return SSH_ERR_INTERNAL_ERROR;
868 	}
869 	mac->name = name;
870 	mac->key = NULL;
871 	mac->enabled = 0;
872 	return 0;
873 }
874 
875 static int
876 choose_comp(struct sshcomp *comp, char *client, char *server)
877 {
878 	char *name = match_list(client, server, NULL);
879 
880 	if (name == NULL)
881 		return SSH_ERR_NO_COMPRESS_ALG_MATCH;
882 #ifdef WITH_ZLIB
883 	if (strcmp(name, "zlib@openssh.com") == 0) {
884 		comp->type = COMP_DELAYED;
885 	} else if (strcmp(name, "zlib") == 0) {
886 		comp->type = COMP_ZLIB;
887 	} else
888 #endif	/* WITH_ZLIB */
889 	if (strcmp(name, "none") == 0) {
890 		comp->type = COMP_NONE;
891 	} else {
892 		error_f("unsupported compression scheme %s", name);
893 		free(name);
894 		return SSH_ERR_INTERNAL_ERROR;
895 	}
896 	comp->name = name;
897 	return 0;
898 }
899 
900 static int
901 choose_kex(struct kex *k, char *client, char *server)
902 {
903 	const struct kexalg *kexalg;
904 
905 	k->name = match_list(client, server, NULL);
906 
907 	debug("kex: algorithm: %s", k->name ? k->name : "(no match)");
908 	if (k->name == NULL)
909 		return SSH_ERR_NO_KEX_ALG_MATCH;
910 	if ((kexalg = kex_alg_by_name(k->name)) == NULL) {
911 		error_f("unsupported KEX method %s", k->name);
912 		return SSH_ERR_INTERNAL_ERROR;
913 	}
914 	k->kex_type = kexalg->type;
915 	k->hash_alg = kexalg->hash_alg;
916 	k->ec_nid = kexalg->ec_nid;
917 	return 0;
918 }
919 
920 static int
921 choose_hostkeyalg(struct kex *k, char *client, char *server)
922 {
923 	free(k->hostkey_alg);
924 	k->hostkey_alg = match_list(client, server, NULL);
925 
926 	debug("kex: host key algorithm: %s",
927 	    k->hostkey_alg ? k->hostkey_alg : "(no match)");
928 	if (k->hostkey_alg == NULL)
929 		return SSH_ERR_NO_HOSTKEY_ALG_MATCH;
930 	k->hostkey_type = sshkey_type_from_name(k->hostkey_alg);
931 	if (k->hostkey_type == KEY_UNSPEC) {
932 		error_f("unsupported hostkey algorithm %s", k->hostkey_alg);
933 		return SSH_ERR_INTERNAL_ERROR;
934 	}
935 	k->hostkey_nid = sshkey_ecdsa_nid_from_name(k->hostkey_alg);
936 	return 0;
937 }
938 
939 static int
940 proposals_match(char *my[PROPOSAL_MAX], char *peer[PROPOSAL_MAX])
941 {
942 	static int check[] = {
943 		PROPOSAL_KEX_ALGS, PROPOSAL_SERVER_HOST_KEY_ALGS, -1
944 	};
945 	int *idx;
946 	char *p;
947 
948 	for (idx = &check[0]; *idx != -1; idx++) {
949 		if ((p = strchr(my[*idx], ',')) != NULL)
950 			*p = '\0';
951 		if ((p = strchr(peer[*idx], ',')) != NULL)
952 			*p = '\0';
953 		if (strcmp(my[*idx], peer[*idx]) != 0) {
954 			debug2("proposal mismatch: my %s peer %s",
955 			    my[*idx], peer[*idx]);
956 			return (0);
957 		}
958 	}
959 	debug2("proposals match");
960 	return (1);
961 }
962 
/* Return 1 if any name in proposal also appears in algs, otherwise 0. */
static int
has_any_alg(const char *proposal, const char *algs)
{
	char *common = match_list(proposal, algs, NULL);
	int found = (common != NULL);

	free(common);
	return found;
}
974 
/*
 * Negotiate algorithms for this key exchange from our KEXINIT (kex->my)
 * and the peer's (kex->peer). Each slot is resolved by matching the
 * client's list against the server's. NOTE(review): which side's
 * preference wins is decided by match_list(client, server); RFC 4253 7.1
 * expects the client's order.
 *
 * On success this fills in:
 *  - the KEX method (kex->name, kex_type, hash_alg, ec_nid);
 *  - the host key algorithm;
 *  - freshly allocated kex->newkeys[mode] with cipher/MAC/compression;
 *  - kex->we_need and kex->dh_need.
 * On the initial server-side exchange it also records whether the client
 * offered ext-info-c and rsa-sha2-* host key algorithms.
 *
 * When negotiation of a slot fails, the peer's list for that slot is moved
 * into kex->failed_choice (presumably for error reporting by the caller).
 * Returns 0 or an ssherr.h code.
 */
static int
kex_choose_conf(struct ssh *ssh)
{
	struct kex *kex = ssh->kex;
	struct newkeys *newkeys;
	char **my = NULL, **peer = NULL;
	char **cprop, **sprop;
	int nenc, nmac, ncomp;
	u_int mode, ctos, need, dh_need, authlen;
	int r, first_kex_follows;

	debug2("local %s KEXINIT proposal", kex->server ? "server" : "client");
	if ((r = kex_buf2prop(kex->my, NULL, &my)) != 0)
		goto out;
	debug2("peer %s KEXINIT proposal", kex->server ? "client" : "server");
	if ((r = kex_buf2prop(kex->peer, &first_kex_follows, &peer)) != 0)
		goto out;

	/* Map our/peer proposals onto client/server roles. */
	if (kex->server) {
		cprop=peer;
		sprop=my;
	} else {
		cprop=my;
		sprop=peer;
	}

	/* Check whether client supports ext_info_c */
	if (kex->server && (kex->flags & KEX_INITIAL)) {
		char *ext;

		/* "ext-info-c" is signalled as a pseudo-KEX algorithm name. */
		ext = match_list("ext-info-c", peer[PROPOSAL_KEX_ALGS], NULL);
		kex->ext_info_c = (ext != NULL);
		free(ext);
	}

	/* Check whether client supports rsa-sha2 algorithms */
	if (kex->server && (kex->flags & KEX_INITIAL)) {
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-256,rsa-sha2-256-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_256_SUPPORTED;
		if (has_any_alg(peer[PROPOSAL_SERVER_HOST_KEY_ALGS],
		    "rsa-sha2-512,rsa-sha2-512-cert-v01@openssh.com"))
			kex->flags |= KEX_RSA_SHA2_512_SUPPORTED;
	}

	/* Algorithm Negotiation */
	/* On failure, steal the peer's list so kex_prop_free() skips it. */
	if ((r = choose_kex(kex, cprop[PROPOSAL_KEX_ALGS],
	    sprop[PROPOSAL_KEX_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_KEX_ALGS];
		peer[PROPOSAL_KEX_ALGS] = NULL;
		goto out;
	}
	if ((r = choose_hostkeyalg(kex, cprop[PROPOSAL_SERVER_HOST_KEY_ALGS],
	    sprop[PROPOSAL_SERVER_HOST_KEY_ALGS])) != 0) {
		kex->failed_choice = peer[PROPOSAL_SERVER_HOST_KEY_ALGS];
		peer[PROPOSAL_SERVER_HOST_KEY_ALGS] = NULL;
		goto out;
	}
	for (mode = 0; mode < MODE_MAX; mode++) {
		if ((newkeys = calloc(1, sizeof(*newkeys))) == NULL) {
			r = SSH_ERR_ALLOC_FAIL;
			goto out;
		}
		kex->newkeys[mode] = newkeys;
		/*
		 * This direction is client->server when we are the client
		 * sending (MODE_OUT) or the server receiving (MODE_IN).
		 */
		ctos = (!kex->server && mode == MODE_OUT) ||
		    (kex->server && mode == MODE_IN);
		nenc  = ctos ? PROPOSAL_ENC_ALGS_CTOS  : PROPOSAL_ENC_ALGS_STOC;
		nmac  = ctos ? PROPOSAL_MAC_ALGS_CTOS  : PROPOSAL_MAC_ALGS_STOC;
		ncomp = ctos ? PROPOSAL_COMP_ALGS_CTOS : PROPOSAL_COMP_ALGS_STOC;
		if ((r = choose_enc(&newkeys->enc, cprop[nenc],
		    sprop[nenc])) != 0) {
			kex->failed_choice = peer[nenc];
			peer[nenc] = NULL;
			goto out;
		}
		authlen = cipher_authlen(newkeys->enc.cipher);
		/* ignore mac for authenticated encryption */
		/* (newkeys->mac then stays zeroed from calloc) */
		if (authlen == 0 &&
		    (r = choose_mac(ssh, &newkeys->mac, cprop[nmac],
		    sprop[nmac])) != 0) {
			kex->failed_choice = peer[nmac];
			peer[nmac] = NULL;
			goto out;
		}
		if ((r = choose_comp(&newkeys->comp, cprop[ncomp],
		    sprop[ncomp])) != 0) {
			kex->failed_choice = peer[ncomp];
			peer[ncomp] = NULL;
			goto out;
		}
		debug("kex: %s cipher: %s MAC: %s compression: %s",
		    ctos ? "client->server" : "server->client",
		    newkeys->enc.name,
		    authlen == 0 ? newkeys->mac.name : "<implicit>",
		    newkeys->comp.name);
	}
	/*
	 * need: largest key/block/IV/MAC-key size across both directions,
	 * used as the length of derived key material. dh_need: the same but
	 * using the cipher's security length instead of its key length.
	 */
	need = dh_need = 0;
	for (mode = 0; mode < MODE_MAX; mode++) {
		newkeys = kex->newkeys[mode];
		need = MAXIMUM(need, newkeys->enc.key_len);
		need = MAXIMUM(need, newkeys->enc.block_size);
		need = MAXIMUM(need, newkeys->enc.iv_len);
		need = MAXIMUM(need, newkeys->mac.key_len);
		dh_need = MAXIMUM(dh_need, cipher_seclen(newkeys->enc.cipher));
		dh_need = MAXIMUM(dh_need, newkeys->enc.block_size);
		dh_need = MAXIMUM(dh_need, newkeys->enc.iv_len);
		dh_need = MAXIMUM(dh_need, newkeys->mac.key_len);
	}
	/* XXX should need be rounded up? */
	kex->we_need = need;
	kex->dh_need = dh_need;

	/* ignore the next message if the proposals do not match */
	/* (proposals_match() truncates my/peer in place; both freed below) */
	if (first_kex_follows && !proposals_match(my, peer))
		ssh->dispatch_skip_packets = 1;
	r = 0;
 out:
	kex_prop_free(my);
	kex_prop_free(peer);
	return r;
}
1096 
1097 static int
1098 derive_key(struct ssh *ssh, int id, u_int need, u_char *hash, u_int hashlen,
1099     const struct sshbuf *shared_secret, u_char **keyp)
1100 {
1101 	struct kex *kex = ssh->kex;
1102 	struct ssh_digest_ctx *hashctx = NULL;
1103 	char c = id;
1104 	u_int have;
1105 	size_t mdsz;
1106 	u_char *digest;
1107 	int r;
1108 
1109 	if ((mdsz = ssh_digest_bytes(kex->hash_alg)) == 0)
1110 		return SSH_ERR_INVALID_ARGUMENT;
1111 	if ((digest = calloc(1, ROUNDUP(need, mdsz))) == NULL) {
1112 		r = SSH_ERR_ALLOC_FAIL;
1113 		goto out;
1114 	}
1115 
1116 	/* K1 = HASH(K || H || "A" || session_id) */
1117 	if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1118 	    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1119 	    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1120 	    ssh_digest_update(hashctx, &c, 1) != 0 ||
1121 	    ssh_digest_update_buffer(hashctx, kex->session_id) != 0 ||
1122 	    ssh_digest_final(hashctx, digest, mdsz) != 0) {
1123 		r = SSH_ERR_LIBCRYPTO_ERROR;
1124 		error_f("KEX hash failed");
1125 		goto out;
1126 	}
1127 	ssh_digest_free(hashctx);
1128 	hashctx = NULL;
1129 
1130 	/*
1131 	 * expand key:
1132 	 * Kn = HASH(K || H || K1 || K2 || ... || Kn-1)
1133 	 * Key = K1 || K2 || ... || Kn
1134 	 */
1135 	for (have = mdsz; need > have; have += mdsz) {
1136 		if ((hashctx = ssh_digest_start(kex->hash_alg)) == NULL ||
1137 		    ssh_digest_update_buffer(hashctx, shared_secret) != 0 ||
1138 		    ssh_digest_update(hashctx, hash, hashlen) != 0 ||
1139 		    ssh_digest_update(hashctx, digest, have) != 0 ||
1140 		    ssh_digest_final(hashctx, digest + have, mdsz) != 0) {
1141 			error_f("KDF failed");
1142 			r = SSH_ERR_LIBCRYPTO_ERROR;
1143 			goto out;
1144 		}
1145 		ssh_digest_free(hashctx);
1146 		hashctx = NULL;
1147 	}
1148 #ifdef DEBUG_KEX
1149 	fprintf(stderr, "key '%c'== ", c);
1150 	dump_digest("key", digest, need);
1151 #endif
1152 	*keyp = digest;
1153 	digest = NULL;
1154 	r = 0;
1155  out:
1156 	free(digest);
1157 	ssh_digest_free(hashctx);
1158 	return r;
1159 }
1160 
1161 #define NKEYS	6
1162 int
1163 kex_derive_keys(struct ssh *ssh, u_char *hash, u_int hashlen,
1164     const struct sshbuf *shared_secret)
1165 {
1166 	struct kex *kex = ssh->kex;
1167 	u_char *keys[NKEYS];
1168 	u_int i, j, mode, ctos;
1169 	int r;
1170 
1171 	/* save initial hash as session id */
1172 	if ((kex->flags & KEX_INITIAL) != 0) {
1173 		if (sshbuf_len(kex->session_id) != 0) {
1174 			error_f("already have session ID at kex");
1175 			return SSH_ERR_INTERNAL_ERROR;
1176 		}
1177 		if ((r = sshbuf_put(kex->session_id, hash, hashlen)) != 0)
1178 			return r;
1179 	} else if (sshbuf_len(kex->session_id) == 0) {
1180 		error_f("no session ID in rekex");
1181 		return SSH_ERR_INTERNAL_ERROR;
1182 	}
1183 	for (i = 0; i < NKEYS; i++) {
1184 		if ((r = derive_key(ssh, 'A'+i, kex->we_need, hash, hashlen,
1185 		    shared_secret, &keys[i])) != 0) {
1186 			for (j = 0; j < i; j++)
1187 				free(keys[j]);
1188 			return r;
1189 		}
1190 	}
1191 	for (mode = 0; mode < MODE_MAX; mode++) {
1192 		ctos = (!kex->server && mode == MODE_OUT) ||
1193 		    (kex->server && mode == MODE_IN);
1194 		kex->newkeys[mode]->enc.iv  = keys[ctos ? 0 : 1];
1195 		kex->newkeys[mode]->enc.key = keys[ctos ? 2 : 3];
1196 		kex->newkeys[mode]->mac.key = keys[ctos ? 4 : 5];
1197 	}
1198 	return 0;
1199 }
1200 
1201 int
1202 kex_load_hostkey(struct ssh *ssh, struct sshkey **prvp, struct sshkey **pubp)
1203 {
1204 	struct kex *kex = ssh->kex;
1205 
1206 	*pubp = NULL;
1207 	*prvp = NULL;
1208 	if (kex->load_host_public_key == NULL ||
1209 	    kex->load_host_private_key == NULL) {
1210 		error_f("missing hostkey loader");
1211 		return SSH_ERR_INVALID_ARGUMENT;
1212 	}
1213 	*pubp = kex->load_host_public_key(kex->hostkey_type,
1214 	    kex->hostkey_nid, ssh);
1215 	*prvp = kex->load_host_private_key(kex->hostkey_type,
1216 	    kex->hostkey_nid, ssh);
1217 	if (*pubp == NULL)
1218 		return SSH_ERR_NO_HOSTKEY_LOADED;
1219 	return 0;
1220 }
1221 
1222 int
1223 kex_verify_host_key(struct ssh *ssh, struct sshkey *server_host_key)
1224 {
1225 	struct kex *kex = ssh->kex;
1226 
1227 	if (kex->verify_host_key == NULL) {
1228 		error_f("missing hostkey verifier");
1229 		return SSH_ERR_INVALID_ARGUMENT;
1230 	}
1231 	if (server_host_key->type != kex->hostkey_type ||
1232 	    (kex->hostkey_type == KEY_ECDSA &&
1233 	    server_host_key->ecdsa_nid != kex->hostkey_nid))
1234 		return SSH_ERR_KEY_TYPE_MISMATCH;
1235 	if (kex->verify_host_key(server_host_key, ssh) == -1)
1236 		return  SSH_ERR_SIGNATURE_INVALID;
1237 	return 0;
1238 }
1239 
#if defined(DEBUG_KEX) || defined(DEBUG_KEXDH) || defined(DEBUG_KEXECDH)
/*
 * Debug helper: print the label `msg` on its own line to stderr, then dump
 * `len` bytes of `digest` via sshbuf_dump_data().
 */
void
dump_digest(const char *msg, const u_char *digest, int len)
{
	fputs(msg, stderr);
	fputc('\n', stderr);
	sshbuf_dump_data(digest, len, stderr);
}
#endif
1248 
1249 /*
1250  * Send a plaintext error message to the peer, suffixed by \r\n.
1251  * Only used during banner exchange, and there only for the server.
1252  */
1253 static void
1254 send_error(struct ssh *ssh, char *msg)
1255 {
1256 	char *crnl = "\r\n";
1257 
1258 	if (!ssh->kex->server)
1259 		return;
1260 
1261 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1262 	    msg, strlen(msg)) != strlen(msg) ||
1263 	    atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1264 	    crnl, strlen(crnl)) != strlen(crnl))
1265 		error_f("write: %.100s", strerror(errno));
1266 }
1267 
1268 /*
1269  * Sends our identification string and waits for the peer's. Will block for
1270  * up to timeout_ms (or indefinitely if timeout_ms <= 0).
1271  * Returns on 0 success or a ssherr.h code on failure.
1272  */
1273 int
1274 kex_exchange_identification(struct ssh *ssh, int timeout_ms,
1275     const char *version_addendum)
1276 {
1277 	int remote_major, remote_minor, mismatch, oerrno = 0;
1278 	size_t len, n;
1279 	int r, expect_nl;
1280 	u_char c;
1281 	struct sshbuf *our_version = ssh->kex->server ?
1282 	    ssh->kex->server_version : ssh->kex->client_version;
1283 	struct sshbuf *peer_version = ssh->kex->server ?
1284 	    ssh->kex->client_version : ssh->kex->server_version;
1285 	char *our_version_string = NULL, *peer_version_string = NULL;
1286 	char *cp, *remote_version = NULL;
1287 
1288 	/* Prepare and send our banner */
1289 	sshbuf_reset(our_version);
1290 	if (version_addendum != NULL && *version_addendum == '\0')
1291 		version_addendum = NULL;
1292 	if ((r = sshbuf_putf(our_version, "SSH-%d.%d-%.100s%s%s\r\n",
1293 	    PROTOCOL_MAJOR_2, PROTOCOL_MINOR_2, SSH_VERSION,
1294 	    version_addendum == NULL ? "" : " ",
1295 	    version_addendum == NULL ? "" : version_addendum)) != 0) {
1296 		oerrno = errno;
1297 		error_fr(r, "sshbuf_putf");
1298 		goto out;
1299 	}
1300 
1301 	if (atomicio(vwrite, ssh_packet_get_connection_out(ssh),
1302 	    sshbuf_mutable_ptr(our_version),
1303 	    sshbuf_len(our_version)) != sshbuf_len(our_version)) {
1304 		oerrno = errno;
1305 		debug_f("write: %.100s", strerror(errno));
1306 		r = SSH_ERR_SYSTEM_ERROR;
1307 		goto out;
1308 	}
1309 	if ((r = sshbuf_consume_end(our_version, 2)) != 0) { /* trim \r\n */
1310 		oerrno = errno;
1311 		error_fr(r, "sshbuf_consume_end");
1312 		goto out;
1313 	}
1314 	our_version_string = sshbuf_dup_string(our_version);
1315 	if (our_version_string == NULL) {
1316 		error_f("sshbuf_dup_string failed");
1317 		r = SSH_ERR_ALLOC_FAIL;
1318 		goto out;
1319 	}
1320 	debug("Local version string %.100s", our_version_string);
1321 
1322 	/* Read other side's version identification. */
1323 	for (n = 0; ; n++) {
1324 		if (n >= SSH_MAX_PRE_BANNER_LINES) {
1325 			send_error(ssh, "No SSH identification string "
1326 			    "received.");
1327 			error_f("No SSH version received in first %u lines "
1328 			    "from server", SSH_MAX_PRE_BANNER_LINES);
1329 			r = SSH_ERR_INVALID_FORMAT;
1330 			goto out;
1331 		}
1332 		sshbuf_reset(peer_version);
1333 		expect_nl = 0;
1334 		for (;;) {
1335 			if (timeout_ms > 0) {
1336 				r = waitrfd(ssh_packet_get_connection_in(ssh),
1337 				    &timeout_ms);
1338 				if (r == -1 && errno == ETIMEDOUT) {
1339 					send_error(ssh, "Timed out waiting "
1340 					    "for SSH identification string.");
1341 					error("Connection timed out during "
1342 					    "banner exchange");
1343 					r = SSH_ERR_CONN_TIMEOUT;
1344 					goto out;
1345 				} else if (r == -1) {
1346 					oerrno = errno;
1347 					error_f("%s", strerror(errno));
1348 					r = SSH_ERR_SYSTEM_ERROR;
1349 					goto out;
1350 				}
1351 			}
1352 
1353 			len = atomicio(read, ssh_packet_get_connection_in(ssh),
1354 			    &c, 1);
1355 			if (len != 1 && errno == EPIPE) {
1356 				error_f("Connection closed by remote host");
1357 				r = SSH_ERR_CONN_CLOSED;
1358 				goto out;
1359 			} else if (len != 1) {
1360 				oerrno = errno;
1361 				error_f("read: %.100s", strerror(errno));
1362 				r = SSH_ERR_SYSTEM_ERROR;
1363 				goto out;
1364 			}
1365 			if (c == '\r') {
1366 				expect_nl = 1;
1367 				continue;
1368 			}
1369 			if (c == '\n')
1370 				break;
1371 			if (c == '\0' || expect_nl) {
1372 				error_f("banner line contains invalid "
1373 				    "characters");
1374 				goto invalid;
1375 			}
1376 			if ((r = sshbuf_put_u8(peer_version, c)) != 0) {
1377 				oerrno = errno;
1378 				error_fr(r, "sshbuf_put");
1379 				goto out;
1380 			}
1381 			if (sshbuf_len(peer_version) > SSH_MAX_BANNER_LEN) {
1382 				error_f("banner line too long");
1383 				goto invalid;
1384 			}
1385 		}
1386 		/* Is this an actual protocol banner? */
1387 		if (sshbuf_len(peer_version) > 4 &&
1388 		    memcmp(sshbuf_ptr(peer_version), "SSH-", 4) == 0)
1389 			break;
1390 		/* If not, then just log the line and continue */
1391 		if ((cp = sshbuf_dup_string(peer_version)) == NULL) {
1392 			error_f("sshbuf_dup_string failed");
1393 			r = SSH_ERR_ALLOC_FAIL;
1394 			goto out;
1395 		}
1396 		/* Do not accept lines before the SSH ident from a client */
1397 		if (ssh->kex->server) {
1398 			error_f("client sent invalid protocol identifier "
1399 			    "\"%.256s\"", cp);
1400 			free(cp);
1401 			goto invalid;
1402 		}
1403 		debug_f("banner line %zu: %s", n, cp);
1404 		free(cp);
1405 	}
1406 	peer_version_string = sshbuf_dup_string(peer_version);
1407 	if (peer_version_string == NULL)
1408 		fatal_f("sshbuf_dup_string failed");
1409 	/* XXX must be same size for sscanf */
1410 	if ((remote_version = calloc(1, sshbuf_len(peer_version))) == NULL) {
1411 		error_f("calloc failed");
1412 		r = SSH_ERR_ALLOC_FAIL;
1413 		goto out;
1414 	}
1415 
1416 	/*
1417 	 * Check that the versions match.  In future this might accept
1418 	 * several versions and set appropriate flags to handle them.
1419 	 */
1420 	if (sscanf(peer_version_string, "SSH-%d.%d-%[^\n]\n",
1421 	    &remote_major, &remote_minor, remote_version) != 3) {
1422 		error("Bad remote protocol version identification: '%.100s'",
1423 		    peer_version_string);
1424  invalid:
1425 		send_error(ssh, "Invalid SSH identification string.");
1426 		r = SSH_ERR_INVALID_FORMAT;
1427 		goto out;
1428 	}
1429 	debug("Remote protocol version %d.%d, remote software version %.100s",
1430 	    remote_major, remote_minor, remote_version);
1431 	compat_banner(ssh, remote_version);
1432 
1433 	mismatch = 0;
1434 	switch (remote_major) {
1435 	case 2:
1436 		break;
1437 	case 1:
1438 		if (remote_minor != 99)
1439 			mismatch = 1;
1440 		break;
1441 	default:
1442 		mismatch = 1;
1443 		break;
1444 	}
1445 	if (mismatch) {
1446 		error("Protocol major versions differ: %d vs. %d",
1447 		    PROTOCOL_MAJOR_2, remote_major);
1448 		send_error(ssh, "Protocol major versions differ.");
1449 		r = SSH_ERR_NO_PROTOCOL_VERSION;
1450 		goto out;
1451 	}
1452 
1453 	if (ssh->kex->server && (ssh->compat & SSH_BUG_PROBE) != 0) {
1454 		logit("probed from %s port %d with %s.  Don't panic.",
1455 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1456 		    peer_version_string);
1457 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1458 		goto out;
1459 	}
1460 	if (ssh->kex->server && (ssh->compat & SSH_BUG_SCANNER) != 0) {
1461 		logit("scanned from %s port %d with %s.  Don't panic.",
1462 		    ssh_remote_ipaddr(ssh), ssh_remote_port(ssh),
1463 		    peer_version_string);
1464 		r = SSH_ERR_CONN_CLOSED; /* XXX */
1465 		goto out;
1466 	}
1467 	/* success */
1468 	r = 0;
1469  out:
1470 	free(our_version_string);
1471 	free(peer_version_string);
1472 	free(remote_version);
1473 	if (r == SSH_ERR_SYSTEM_ERROR)
1474 		errno = oerrno;
1475 	return r;
1476 }
1477 
1478