xref: /freebsd/sys/dev/wg/wg_noise.c (revision 744bfb213144c63cbaf38d91a1c4f7aebb9b9fbc)
1 /* SPDX-License-Identifier: ISC
2  *
3  * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
5  * Copyright (c) 2022 The FreeBSD Foundation
6  */
7 
8 #include <sys/param.h>
9 #include <sys/systm.h>
10 #include <sys/ck.h>
11 #include <sys/endian.h>
12 #include <sys/epoch.h>
13 #include <sys/kernel.h>
14 #include <sys/lock.h>
15 #include <sys/malloc.h>
16 #include <sys/mutex.h>
17 #include <sys/refcount.h>
18 #include <sys/rwlock.h>
19 #include <crypto/siphash/siphash.h>
20 
21 #include "crypto.h"
22 #include "wg_noise.h"
23 #include "support.h"
24 
/*
 * Protocol string constants: the Noise handshake pattern name and the
 * WireGuard identifier string.
 */
#define NOISE_HANDSHAKE_NAME	"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
#define NOISE_IDENTIFIER_NAME	"WireGuard v1 zx2c4 Jason@zx2c4.com"

/*
 * Constants for the counter (receive-side replay window).
 *
 * The window is a bitmap of COUNTER_BITS_TOTAL bits held in an array of
 * unsigned long words (kp_backtrack).  COUNTER_BITS is the width of one
 * word and COUNTER_ORDER is log2(COUNTER_BITS), so "nonce >> COUNTER_ORDER"
 * selects a word and "nonce & (COUNTER_BITS - 1)" selects a bit in it.
 * One word's worth of bits is treated as redundant, which leaves
 * COUNTER_WINDOW_SIZE usable positions behind the newest nonce.
 */
#define COUNTER_BITS_TOTAL	8192
#ifdef __LP64__
#define COUNTER_ORDER		6
#define COUNTER_BITS		64
#else
#define COUNTER_ORDER		5
#define COUNTER_BITS		32
#endif
#define COUNTER_REDUNDANT_BITS	COUNTER_BITS
#define COUNTER_WINDOW_SIZE	(COUNTER_BITS_TOTAL - COUNTER_REDUNDANT_BITS)

/*
 * Constants for the keypair.
 *
 * The *_MESSAGES limits are compared against the 64-bit nonce counters.
 * The *_TIME values are handed to noise_timer_expired() as its second
 * argument (presumably seconds -- confirm against its definition).
 */
#define REKEY_AFTER_MESSAGES	(1ull << 60)
#define REJECT_AFTER_MESSAGES	(UINT64_MAX - COUNTER_WINDOW_SIZE - 1)
#define REKEY_AFTER_TIME	120
#define REKEY_AFTER_TIME_RECV	165
#define REJECT_INTERVAL		(1000000000 / 50) /* fifty times per sec */
/* 24 = floor(log2(REJECT_INTERVAL)) */
#define REJECT_INTERVAL_MASK	(~((1ull<<24)-1))
/*
 * Initial value for r_last_sent / r_last_init_recv: an sbintime
 * REKEY_TIMEOUT + 1 seconds before zero.
 */
#define TIMER_RESET		(SBT_1S * -(REKEY_TIMEOUT+1))

/*
 * Hash table geometry.  Sizes are powers of two so that the matching
 * *_MASK can be used to reduce a hash or index to a bucket number.
 * MAX_REMOTE_PER_LOCAL caps l_remote_num in noise_remote_enable().
 */
#define HT_INDEX_SIZE		(1 << 13)
#define HT_INDEX_MASK		(HT_INDEX_SIZE - 1)
#define HT_REMOTE_SIZE		(1 << 11)
#define HT_REMOTE_MASK		(HT_REMOTE_SIZE - 1)
#define MAX_REMOTE_PER_LOCAL	(1 << 20)
56 
/*
 * Entry in a noise_local's index hash table (l_index_hash).
 *
 * It is embedded as the FIRST member of both struct noise_remote (r_index)
 * and struct noise_keypair (kp_index); lookups cast a noise_index pointer
 * back to the containing structure, using i_is_keypair to tell which one.
 */
struct noise_index {
	CK_LIST_ENTRY(noise_index)	 i_entry;	/* bucket linkage */
	uint32_t			 i_local_index;	/* randomly chosen local id */
	uint32_t			 i_remote_index;	/* id supplied by the peer */
	int				 i_is_keypair;	/* non-zero: inside a noise_keypair */
};
63 
/*
 * One established session: a pair of symmetric keys plus the nonce state
 * for each direction.  Reference counted; the final release frees it from
 * an epoch callback (see noise_keypair_put()).
 */
struct noise_keypair {
	struct noise_index		 kp_index;	/* must stay first (cast from noise_index) */
	u_int				 kp_refcnt;
	bool				 kp_can_send;	/* cleared once the key may no longer encrypt */
	bool				 kp_is_initiator; /* we initiated the handshake */
	sbintime_t			 kp_birthdate; /* sbinuptime */
	struct noise_remote		*kp_remote;	/* owning remote (holds a reference) */

	uint8_t				 kp_send[NOISE_SYMMETRIC_KEY_LEN];
	uint8_t				 kp_recv[NOISE_SYMMETRIC_KEY_LEN];

	/* Counter elements */
	struct rwlock			 kp_nonce_lock;	/* guards kp_nonce_recv/kp_backtrack */
	uint64_t			 kp_nonce_send;	/* next nonce to hand out */
	uint64_t			 kp_nonce_recv;	/* highest accepted nonce + 1 */
	unsigned long			 kp_backtrack[COUNTER_BITS_TOTAL / COUNTER_BITS]; /* replay bitmap */

	struct epoch_context		 kp_smr;	/* deferred-free context */
};
83 
/*
 * Transient state of an in-progress handshake.
 */
struct noise_handshake {
	/*
	 * Ephemeral key: our ephemeral secret when we create an initiation,
	 * the peer's ephemeral public key when we consume one.
	 */
	uint8_t	 			 hs_e[NOISE_PUBLIC_KEY_LEN];
	uint8_t	 			 hs_hash[NOISE_HASH_LEN];	/* running handshake hash */
	uint8_t	 			 hs_ck[NOISE_HASH_LEN];	/* chaining key */
};
89 
/*
 * State of a remote's handshake.  Any state other than HANDSHAKE_DEAD means
 * the remote's r_index is currently linked into the index hash table (see
 * noise_remote_index_remove()).
 */
enum noise_handshake_state {
	HANDSHAKE_DEAD,		/* no handshake in progress */
	HANDSHAKE_INITIATOR,	/* we sent an initiation */
	HANDSHAKE_RESPONDER,	/* we accepted an initiation */
};
95 
/*
 * Per-peer state: static public key, handshake state and up to three
 * keypairs (next / current / previous).  Reference counted; the final
 * release frees it from an epoch callback (see noise_remote_put()).
 */
struct noise_remote {
	struct noise_index		 r_index;	/* must stay first (cast from noise_index) */

	CK_LIST_ENTRY(noise_remote) 	 r_entry;	/* linkage in l_remote_hash */
	bool				 r_entry_inserted; /* changed under l_remote_mtx */
	uint8_t				 r_public[NOISE_PUBLIC_KEY_LEN]; /* set once at alloc */

	struct rwlock			 r_handshake_lock; /* taken around the handshake fields below */
	struct noise_handshake		 r_handshake;
	enum noise_handshake_state	 r_handshake_state;
	sbintime_t			 r_last_sent; /* sbinuptime */
	sbintime_t			 r_last_init_recv; /* sbinuptime */
	uint8_t				 r_timestamp[NOISE_TIMESTAMP_LEN]; /* newest accepted initiation timestamp */
	uint8_t				 r_psk[NOISE_SYMMETRIC_KEY_LEN];	/* all-zero when unset */
	uint8_t		 		 r_ss[NOISE_PUBLIC_KEY_LEN];	/* precomputed static-static DH */

	u_int				 r_refcnt;
	struct noise_local		*r_local;	/* owning local (holds a reference) */
	void				*r_arg;		/* opaque caller cookie */

	struct mtx			 r_keypair_mtx;	/* serialises writers of the slots below */
	struct noise_keypair		*r_next, *r_current, *r_previous;

	struct epoch_context		 r_smr;		/* deferred-free context */
	void				(*r_cleanup)(struct noise_remote *); /* run just before free */
};
122 
/*
 * Per-interface state: the local static identity and the hash tables used
 * to find remotes by public key and handshakes/keypairs by local index.
 */
struct noise_local {
	struct rwlock			 l_identity_lock; /* guards the three identity fields */
	bool				 l_has_identity;
	uint8_t				 l_public[NOISE_PUBLIC_KEY_LEN];
	uint8_t				 l_private[NOISE_PUBLIC_KEY_LEN];

	u_int				 l_refcnt;
	uint8_t				 l_hash_key[SIPHASH_KEY_LENGTH]; /* random key for the remote hash */
	void				*l_arg;		/* opaque caller cookie */
	void				(*l_cleanup)(struct noise_local *); /* run just before free */

	struct mtx			 l_remote_mtx;	/* serialises writers of l_remote_hash */
	size_t				 l_remote_num;	/* number of inserted remotes */
	CK_LIST_HEAD(,noise_remote)	 l_remote_hash[HT_REMOTE_SIZE];

	struct mtx			 l_index_mtx;	/* serialises writers of l_index_hash */
	CK_LIST_HEAD(,noise_index)	 l_index_hash[HT_INDEX_SIZE];
};
141 
/* Forward declarations of file-local helpers. */

/* Static-static shared secret */
static void	noise_precompute_ss(struct noise_local *, struct noise_remote *);

/* Handshake index table management */
static void	noise_remote_index_insert(struct noise_local *, struct noise_remote *);
static struct noise_remote *
		noise_remote_index_lookup(struct noise_local *, uint32_t, bool);
static int	noise_remote_index_remove(struct noise_local *, struct noise_remote *);
static void	noise_remote_expire_current(struct noise_remote *);

/* Keypair creation and teardown */
static void	noise_add_new_keypair(struct noise_local *, struct noise_remote *, struct noise_keypair *);
static int	noise_begin_session(struct noise_remote *);
static void	noise_keypair_drop(struct noise_keypair *);

/* Cryptographic building blocks of the handshake */
static void	noise_kdf(uint8_t *, uint8_t *, uint8_t *, const uint8_t *,
		    size_t, size_t, size_t, size_t,
		    const uint8_t [NOISE_HASH_LEN]);
static int	noise_mix_dh(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_SYMMETRIC_KEY_LEN],
		    const uint8_t [NOISE_PUBLIC_KEY_LEN],
		    const uint8_t [NOISE_PUBLIC_KEY_LEN]);
static int	noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t [NOISE_SYMMETRIC_KEY_LEN],
		    const uint8_t [NOISE_PUBLIC_KEY_LEN]);
static void	noise_mix_hash(uint8_t [NOISE_HASH_LEN], const uint8_t *, size_t);
static void	noise_mix_psk(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN],
		    uint8_t [NOISE_SYMMETRIC_KEY_LEN], const uint8_t [NOISE_SYMMETRIC_KEY_LEN]);
static void	noise_param_init(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN],
		    const uint8_t [NOISE_PUBLIC_KEY_LEN]);
static void	noise_msg_encrypt(uint8_t *, const uint8_t *, size_t,
		    uint8_t [NOISE_SYMMETRIC_KEY_LEN], uint8_t [NOISE_HASH_LEN]);
static int	noise_msg_decrypt(uint8_t *, const uint8_t *, size_t,
		    uint8_t [NOISE_SYMMETRIC_KEY_LEN], uint8_t [NOISE_HASH_LEN]);
static void	noise_msg_ephemeral(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN],
		    const uint8_t [NOISE_PUBLIC_KEY_LEN]);
/* Time and hashing utilities */
static void	noise_tai64n_now(uint8_t [NOISE_TIMESTAMP_LEN]);
static int	noise_timer_expired(sbintime_t, uint32_t, uint32_t);
static uint64_t siphash24(const uint8_t [SIPHASH_KEY_LENGTH], const void *, size_t);

/* malloc(9) type used for every allocation made by this file */
MALLOC_DEFINE(M_NOISE, "NOISE", "wgnoise");
178 
179 /* Local configuration */
180 struct noise_local *
181 noise_local_alloc(void *arg)
182 {
183 	struct noise_local *l;
184 	size_t i;
185 
186 	l = malloc(sizeof(*l), M_NOISE, M_WAITOK | M_ZERO);
187 
188 	rw_init(&l->l_identity_lock, "noise_identity");
189 	l->l_has_identity = false;
190 	bzero(l->l_public, NOISE_PUBLIC_KEY_LEN);
191 	bzero(l->l_private, NOISE_PUBLIC_KEY_LEN);
192 
193 	refcount_init(&l->l_refcnt, 1);
194 	arc4random_buf(l->l_hash_key, sizeof(l->l_hash_key));
195 	l->l_arg = arg;
196 	l->l_cleanup = NULL;
197 
198 	mtx_init(&l->l_remote_mtx, "noise_remote", NULL, MTX_DEF);
199 	l->l_remote_num = 0;
200 	for (i = 0; i < HT_REMOTE_SIZE; i++)
201 		CK_LIST_INIT(&l->l_remote_hash[i]);
202 
203 	mtx_init(&l->l_index_mtx, "noise_index", NULL, MTX_DEF);
204 	for (i = 0; i < HT_INDEX_SIZE; i++)
205 		CK_LIST_INIT(&l->l_index_hash[i]);
206 
207 	return (l);
208 }
209 
210 struct noise_local *
211 noise_local_ref(struct noise_local *l)
212 {
213 	refcount_acquire(&l->l_refcnt);
214 	return (l);
215 }
216 
217 void
218 noise_local_put(struct noise_local *l)
219 {
220 	if (refcount_release(&l->l_refcnt)) {
221 		if (l->l_cleanup != NULL)
222 			l->l_cleanup(l);
223 		rw_destroy(&l->l_identity_lock);
224 		mtx_destroy(&l->l_remote_mtx);
225 		mtx_destroy(&l->l_index_mtx);
226 		explicit_bzero(l, sizeof(*l));
227 		free(l, M_NOISE);
228 	}
229 }
230 
231 void
232 noise_local_free(struct noise_local *l, void (*cleanup)(struct noise_local *))
233 {
234 	l->l_cleanup = cleanup;
235 	noise_local_put(l);
236 }
237 
238 void *
239 noise_local_arg(struct noise_local *l)
240 {
241 	return (l->l_arg);
242 }
243 
244 void
245 noise_local_private(struct noise_local *l, const uint8_t private[NOISE_PUBLIC_KEY_LEN])
246 {
247 	struct epoch_tracker et;
248 	struct noise_remote *r;
249 	size_t i;
250 
251 	rw_wlock(&l->l_identity_lock);
252 	memcpy(l->l_private, private, NOISE_PUBLIC_KEY_LEN);
253 	curve25519_clamp_secret(l->l_private);
254 	l->l_has_identity = curve25519_generate_public(l->l_public, l->l_private);
255 
256 	NET_EPOCH_ENTER(et);
257 	for (i = 0; i < HT_REMOTE_SIZE; i++) {
258 		CK_LIST_FOREACH(r, &l->l_remote_hash[i], r_entry) {
259 			noise_precompute_ss(l, r);
260 			noise_remote_expire_current(r);
261 		}
262 	}
263 	NET_EPOCH_EXIT(et);
264 	rw_wunlock(&l->l_identity_lock);
265 }
266 
267 int
268 noise_local_keys(struct noise_local *l, uint8_t public[NOISE_PUBLIC_KEY_LEN],
269     uint8_t private[NOISE_PUBLIC_KEY_LEN])
270 {
271 	int has_identity;
272 	rw_rlock(&l->l_identity_lock);
273 	if ((has_identity = l->l_has_identity)) {
274 		if (public != NULL)
275 			memcpy(public, l->l_public, NOISE_PUBLIC_KEY_LEN);
276 		if (private != NULL)
277 			memcpy(private, l->l_private, NOISE_PUBLIC_KEY_LEN);
278 	}
279 	rw_runlock(&l->l_identity_lock);
280 	return (has_identity ? 0 : ENXIO);
281 }
282 
283 static void
284 noise_precompute_ss(struct noise_local *l, struct noise_remote *r)
285 {
286 	rw_wlock(&r->r_handshake_lock);
287 	if (!l->l_has_identity ||
288 	    !curve25519(r->r_ss, l->l_private, r->r_public))
289 		bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN);
290 	rw_wunlock(&r->r_handshake_lock);
291 }
292 
293 /* Remote configuration */
294 struct noise_remote *
295 noise_remote_alloc(struct noise_local *l, void *arg,
296     const uint8_t public[NOISE_PUBLIC_KEY_LEN])
297 {
298 	struct noise_remote *r;
299 
300 	r = malloc(sizeof(*r), M_NOISE, M_WAITOK | M_ZERO);
301 	memcpy(r->r_public, public, NOISE_PUBLIC_KEY_LEN);
302 
303 	rw_init(&r->r_handshake_lock, "noise_handshake");
304 	r->r_handshake_state = HANDSHAKE_DEAD;
305 	r->r_last_sent = TIMER_RESET;
306 	r->r_last_init_recv = TIMER_RESET;
307 	noise_precompute_ss(l, r);
308 
309 	refcount_init(&r->r_refcnt, 1);
310 	r->r_local = noise_local_ref(l);
311 	r->r_arg = arg;
312 
313 	mtx_init(&r->r_keypair_mtx, "noise_keypair", NULL, MTX_DEF);
314 
315 	return (r);
316 }
317 
318 int
319 noise_remote_enable(struct noise_remote *r)
320 {
321 	struct noise_local *l = r->r_local;
322 	uint64_t idx;
323 	int ret = 0;
324 
325 	/* Insert to hashtable */
326 	idx = siphash24(l->l_hash_key, r->r_public, NOISE_PUBLIC_KEY_LEN) & HT_REMOTE_MASK;
327 
328 	mtx_lock(&l->l_remote_mtx);
329 	if (!r->r_entry_inserted) {
330 		if (l->l_remote_num < MAX_REMOTE_PER_LOCAL) {
331 			r->r_entry_inserted = true;
332 			l->l_remote_num++;
333 			CK_LIST_INSERT_HEAD(&l->l_remote_hash[idx], r, r_entry);
334 		} else {
335 			ret = ENOSPC;
336 		}
337 	}
338 	mtx_unlock(&l->l_remote_mtx);
339 
340 	return ret;
341 }
342 
343 void
344 noise_remote_disable(struct noise_remote *r)
345 {
346 	struct noise_local *l = r->r_local;
347 	/* remove from hashtable */
348 	mtx_lock(&l->l_remote_mtx);
349 	if (r->r_entry_inserted) {
350 		r->r_entry_inserted = false;
351 		CK_LIST_REMOVE(r, r_entry);
352 		l->l_remote_num--;
353 	};
354 	mtx_unlock(&l->l_remote_mtx);
355 }
356 
357 struct noise_remote *
358 noise_remote_lookup(struct noise_local *l, const uint8_t public[NOISE_PUBLIC_KEY_LEN])
359 {
360 	struct epoch_tracker et;
361 	struct noise_remote *r, *ret = NULL;
362 	uint64_t idx;
363 
364 	idx = siphash24(l->l_hash_key, public, NOISE_PUBLIC_KEY_LEN) & HT_REMOTE_MASK;
365 
366 	NET_EPOCH_ENTER(et);
367 	CK_LIST_FOREACH(r, &l->l_remote_hash[idx], r_entry) {
368 		if (timingsafe_bcmp(r->r_public, public, NOISE_PUBLIC_KEY_LEN) == 0) {
369 			if (refcount_acquire_if_not_zero(&r->r_refcnt))
370 				ret = r;
371 			break;
372 		}
373 	}
374 	NET_EPOCH_EXIT(et);
375 	return (ret);
376 }
377 
/*
 * Give the remote's handshake a fresh random local index and link its
 * r_index into the local's index hash table.
 *
 * Any index previously held by the handshake is removed first.  The caller
 * must hold r_handshake_lock for writing (asserted inside
 * noise_remote_index_remove()) and is responsible for setting
 * r_handshake_state afterwards; this function leaves it as
 * noise_remote_index_remove() left it.
 */
static void
noise_remote_index_insert(struct noise_local *l, struct noise_remote *r)
{
	struct noise_index *i, *r_i = &r->r_index;
	struct epoch_tracker et;
	uint32_t idx;

	noise_remote_index_remove(l, r);

	NET_EPOCH_ENTER(et);
assign_id:
	r_i->i_local_index = arc4random();
	idx = r_i->i_local_index & HT_INDEX_MASK;
	/* First pass without the mutex: cheaply reject an id already in use */
	CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) {
		if (i->i_local_index == r_i->i_local_index)
			goto assign_id;
	}

	/*
	 * Second pass with the mutex held: another writer may have inserted
	 * the same id since the unlocked scan, so check again before linking.
	 */
	mtx_lock(&l->l_index_mtx);
	CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) {
		if (i->i_local_index == r_i->i_local_index) {
			mtx_unlock(&l->l_index_mtx);
			goto assign_id;
		}
	}
	CK_LIST_INSERT_HEAD(&l->l_index_hash[idx], r_i, i_entry);
	mtx_unlock(&l->l_index_mtx);

	NET_EPOCH_EXIT(et);
}
408 
409 static struct noise_remote *
410 noise_remote_index_lookup(struct noise_local *l, uint32_t idx0, bool lookup_keypair)
411 {
412 	struct epoch_tracker et;
413 	struct noise_index *i;
414 	struct noise_keypair *kp;
415 	struct noise_remote *r, *ret = NULL;
416 	uint32_t idx = idx0 & HT_INDEX_MASK;
417 
418 	NET_EPOCH_ENTER(et);
419 	CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) {
420 		if (i->i_local_index == idx0) {
421 			if (!i->i_is_keypair) {
422 				r = (struct noise_remote *) i;
423 			} else if (lookup_keypair) {
424 				kp = (struct noise_keypair *) i;
425 				r = kp->kp_remote;
426 			} else {
427 				break;
428 			}
429 			if (refcount_acquire_if_not_zero(&r->r_refcnt))
430 				ret = r;
431 			break;
432 		}
433 	}
434 	NET_EPOCH_EXIT(et);
435 	return (ret);
436 }
437 
/*
 * Public wrapper: resolve local index `idx' to its remote, accepting both
 * handshake and keypair entries.  The returned remote (if any) carries an
 * extra reference.
 */
struct noise_remote *
noise_remote_index(struct noise_local *l, uint32_t idx)
{

	return (noise_remote_index_lookup(l, idx, true));
}
443 
444 static int
445 noise_remote_index_remove(struct noise_local *l, struct noise_remote *r)
446 {
447 	rw_assert(&r->r_handshake_lock, RA_WLOCKED);
448 	if (r->r_handshake_state != HANDSHAKE_DEAD) {
449 		mtx_lock(&l->l_index_mtx);
450 		r->r_handshake_state = HANDSHAKE_DEAD;
451 		CK_LIST_REMOVE(&r->r_index, i_entry);
452 		mtx_unlock(&l->l_index_mtx);
453 		return (1);
454 	}
455 	return (0);
456 }
457 
458 struct noise_remote *
459 noise_remote_ref(struct noise_remote *r)
460 {
461 	refcount_acquire(&r->r_refcnt);
462 	return (r);
463 }
464 
465 static void
466 noise_remote_smr_free(struct epoch_context *smr)
467 {
468 	struct noise_remote *r;
469 	r = __containerof(smr, struct noise_remote, r_smr);
470 	if (r->r_cleanup != NULL)
471 		r->r_cleanup(r);
472 	noise_local_put(r->r_local);
473 	rw_destroy(&r->r_handshake_lock);
474 	mtx_destroy(&r->r_keypair_mtx);
475 	explicit_bzero(r, sizeof(*r));
476 	free(r, M_NOISE);
477 }
478 
479 void
480 noise_remote_put(struct noise_remote *r)
481 {
482 	if (refcount_release(&r->r_refcnt))
483 		NET_EPOCH_CALL(noise_remote_smr_free, &r->r_smr);
484 }
485 
486 void
487 noise_remote_free(struct noise_remote *r, void (*cleanup)(struct noise_remote *))
488 {
489 	r->r_cleanup = cleanup;
490 	noise_remote_disable(r);
491 
492 	/* now clear all keypairs and handshakes, then put this reference */
493 	noise_remote_handshake_clear(r);
494 	noise_remote_keypairs_clear(r);
495 	noise_remote_put(r);
496 }
497 
498 struct noise_local *
499 noise_remote_local(struct noise_remote *r)
500 {
501 	return (noise_local_ref(r->r_local));
502 }
503 
504 void *
505 noise_remote_arg(struct noise_remote *r)
506 {
507 	return (r->r_arg);
508 }
509 
510 void
511 noise_remote_set_psk(struct noise_remote *r,
512     const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
513 {
514 	rw_wlock(&r->r_handshake_lock);
515 	if (psk == NULL)
516 		bzero(r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
517 	else
518 		memcpy(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
519 	rw_wunlock(&r->r_handshake_lock);
520 }
521 
522 int
523 noise_remote_keys(struct noise_remote *r, uint8_t public[NOISE_PUBLIC_KEY_LEN],
524     uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
525 {
526 	static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN];
527 	int ret;
528 
529 	if (public != NULL)
530 		memcpy(public, r->r_public, NOISE_PUBLIC_KEY_LEN);
531 
532 	rw_rlock(&r->r_handshake_lock);
533 	if (psk != NULL)
534 		memcpy(psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
535 	ret = timingsafe_bcmp(r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN);
536 	rw_runlock(&r->r_handshake_lock);
537 
538 	return (ret ? 0 : ENOENT);
539 }
540 
541 int
542 noise_remote_initiation_expired(struct noise_remote *r)
543 {
544 	int expired;
545 	rw_rlock(&r->r_handshake_lock);
546 	expired = noise_timer_expired(r->r_last_sent, REKEY_TIMEOUT, 0);
547 	rw_runlock(&r->r_handshake_lock);
548 	return (expired);
549 }
550 
551 void
552 noise_remote_handshake_clear(struct noise_remote *r)
553 {
554 	rw_wlock(&r->r_handshake_lock);
555 	if (noise_remote_index_remove(r->r_local, r))
556 		bzero(&r->r_handshake, sizeof(r->r_handshake));
557 	r->r_last_sent = TIMER_RESET;
558 	rw_wunlock(&r->r_handshake_lock);
559 }
560 
561 void
562 noise_remote_keypairs_clear(struct noise_remote *r)
563 {
564 	struct noise_keypair *kp;
565 
566 	mtx_lock(&r->r_keypair_mtx);
567 	kp = ck_pr_load_ptr(&r->r_next);
568 	ck_pr_store_ptr(&r->r_next, NULL);
569 	noise_keypair_drop(kp);
570 
571 	kp = ck_pr_load_ptr(&r->r_current);
572 	ck_pr_store_ptr(&r->r_current, NULL);
573 	noise_keypair_drop(kp);
574 
575 	kp = ck_pr_load_ptr(&r->r_previous);
576 	ck_pr_store_ptr(&r->r_previous, NULL);
577 	noise_keypair_drop(kp);
578 	mtx_unlock(&r->r_keypair_mtx);
579 }
580 
581 static void
582 noise_remote_expire_current(struct noise_remote *r)
583 {
584 	struct epoch_tracker et;
585 	struct noise_keypair *kp;
586 
587 	noise_remote_handshake_clear(r);
588 
589 	NET_EPOCH_ENTER(et);
590 	kp = ck_pr_load_ptr(&r->r_next);
591 	if (kp != NULL)
592 		ck_pr_store_bool(&kp->kp_can_send, false);
593 	kp = ck_pr_load_ptr(&r->r_current);
594 	if (kp != NULL)
595 		ck_pr_store_bool(&kp->kp_can_send, false);
596 	NET_EPOCH_EXIT(et);
597 }
598 
599 /* Keypair functions */
/*
 * Install a freshly derived keypair on remote `r' and move the handshake's
 * index table entry over to it.
 *
 * Slot rotation (under r_keypair_mtx):
 *  - initiator: `kp' becomes current.  If a next keypair was pending it
 *    becomes previous and the old current is dropped; otherwise the old
 *    current becomes previous.  The old previous is dropped either way.
 *  - responder: `kp' becomes next; the old next and old previous are
 *    dropped and the previous slot is cleared.
 *
 * The caller must hold r_handshake_lock for writing.  On return the
 * handshake is dead and its state has been wiped.
 */
static void
noise_add_new_keypair(struct noise_local *l, struct noise_remote *r,
    struct noise_keypair *kp)
{
	struct noise_keypair *next, *current, *previous;
	struct noise_index *r_i = &r->r_index;

	/* Insert into the keypair table */
	mtx_lock(&r->r_keypair_mtx);
	next = ck_pr_load_ptr(&r->r_next);
	current = ck_pr_load_ptr(&r->r_current);
	previous = ck_pr_load_ptr(&r->r_previous);

	if (kp->kp_is_initiator) {
		if (next != NULL) {
			ck_pr_store_ptr(&r->r_next, NULL);
			ck_pr_store_ptr(&r->r_previous, next);
			noise_keypair_drop(current);
		} else {
			ck_pr_store_ptr(&r->r_previous, current);
		}
		noise_keypair_drop(previous);
		ck_pr_store_ptr(&r->r_current, kp);
	} else {
		ck_pr_store_ptr(&r->r_next, kp);
		noise_keypair_drop(next);
		ck_pr_store_ptr(&r->r_previous, NULL);
		noise_keypair_drop(previous);

	}
	mtx_unlock(&r->r_keypair_mtx);

	/* Insert into index table */
	rw_assert(&r->r_handshake_lock, RA_WLOCKED);

	/* The keypair inherits both indices negotiated by the handshake */
	kp->kp_index.i_is_keypair = true;
	kp->kp_index.i_local_index = r_i->i_local_index;
	kp->kp_index.i_remote_index = r_i->i_remote_index;

	/*
	 * Link the keypair's entry next to the handshake's entry before
	 * unlinking the latter, so the local index is never absent from the
	 * bucket while the swap takes place.
	 */
	mtx_lock(&l->l_index_mtx);
	CK_LIST_INSERT_BEFORE(r_i, &kp->kp_index, i_entry);
	r->r_handshake_state = HANDSHAKE_DEAD;
	CK_LIST_REMOVE(r_i, i_entry);
	mtx_unlock(&l->l_index_mtx);

	explicit_bzero(&r->r_handshake, sizeof(r->r_handshake));
}
647 
648 static int
649 noise_begin_session(struct noise_remote *r)
650 {
651 	struct noise_keypair *kp;
652 
653 	rw_assert(&r->r_handshake_lock, RA_WLOCKED);
654 
655 	if ((kp = malloc(sizeof(*kp), M_NOISE, M_NOWAIT | M_ZERO)) == NULL)
656 		return (ENOSPC);
657 
658 	refcount_init(&kp->kp_refcnt, 1);
659 	kp->kp_can_send = true;
660 	kp->kp_is_initiator = r->r_handshake_state == HANDSHAKE_INITIATOR;
661 	kp->kp_birthdate = getsbinuptime();
662 	kp->kp_remote = noise_remote_ref(r);
663 
664 	if (kp->kp_is_initiator)
665 		noise_kdf(kp->kp_send, kp->kp_recv, NULL, NULL,
666 		    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
667 		    r->r_handshake.hs_ck);
668 	else
669 		noise_kdf(kp->kp_recv, kp->kp_send, NULL, NULL,
670 		    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
671 		    r->r_handshake.hs_ck);
672 
673 	rw_init(&kp->kp_nonce_lock, "noise_nonce");
674 
675 	noise_add_new_keypair(r->r_local, r, kp);
676 	return (0);
677 }
678 
679 struct noise_keypair *
680 noise_keypair_lookup(struct noise_local *l, uint32_t idx0)
681 {
682 	struct epoch_tracker et;
683 	struct noise_index *i;
684 	struct noise_keypair *kp, *ret = NULL;
685 	uint32_t idx = idx0 & HT_INDEX_MASK;
686 
687 	NET_EPOCH_ENTER(et);
688 	CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) {
689 		if (i->i_local_index == idx0 && i->i_is_keypair) {
690 			kp = (struct noise_keypair *) i;
691 			if (refcount_acquire_if_not_zero(&kp->kp_refcnt))
692 				ret = kp;
693 			break;
694 		}
695 	}
696 	NET_EPOCH_EXIT(et);
697 	return (ret);
698 }
699 
700 struct noise_keypair *
701 noise_keypair_current(struct noise_remote *r)
702 {
703 	struct epoch_tracker et;
704 	struct noise_keypair *kp, *ret = NULL;
705 
706 	NET_EPOCH_ENTER(et);
707 	kp = ck_pr_load_ptr(&r->r_current);
708 	if (kp != NULL && ck_pr_load_bool(&kp->kp_can_send)) {
709 		if (noise_timer_expired(kp->kp_birthdate, REJECT_AFTER_TIME, 0))
710 			ck_pr_store_bool(&kp->kp_can_send, false);
711 		else if (refcount_acquire_if_not_zero(&kp->kp_refcnt))
712 			ret = kp;
713 	}
714 	NET_EPOCH_EXIT(et);
715 	return (ret);
716 }
717 
718 struct noise_keypair *
719 noise_keypair_ref(struct noise_keypair *kp)
720 {
721 	refcount_acquire(&kp->kp_refcnt);
722 	return (kp);
723 }
724 
/*
 * Note that a packet was received with keypair `kp'.
 *
 * If `kp' is the remote's pending "next" keypair it is promoted: the
 * current keypair becomes previous, the old previous is dropped, `kp'
 * becomes current and the next slot is emptied.
 *
 * Returns ECONNRESET when a promotion took place and 0 when `kp' was not
 * the next keypair (nothing changed).
 */
int
noise_keypair_received_with(struct noise_keypair *kp)
{
	struct noise_keypair *old;
	struct noise_remote *r = kp->kp_remote;

	/* Cheap unlocked test first: most packets do not use the next keypair */
	if (kp != ck_pr_load_ptr(&r->r_next))
		return (0);

	/* Repeat the test under the mutex; the slot may have changed meanwhile */
	mtx_lock(&r->r_keypair_mtx);
	if (kp != ck_pr_load_ptr(&r->r_next)) {
		mtx_unlock(&r->r_keypair_mtx);
		return (0);
	}

	old = ck_pr_load_ptr(&r->r_previous);
	ck_pr_store_ptr(&r->r_previous, ck_pr_load_ptr(&r->r_current));
	noise_keypair_drop(old);
	ck_pr_store_ptr(&r->r_current, kp);
	ck_pr_store_ptr(&r->r_next, NULL);
	mtx_unlock(&r->r_keypair_mtx);

	return (ECONNRESET);
}
749 
750 static void
751 noise_keypair_smr_free(struct epoch_context *smr)
752 {
753 	struct noise_keypair *kp;
754 	kp = __containerof(smr, struct noise_keypair, kp_smr);
755 	noise_remote_put(kp->kp_remote);
756 	rw_destroy(&kp->kp_nonce_lock);
757 	explicit_bzero(kp, sizeof(*kp));
758 	free(kp, M_NOISE);
759 }
760 
761 void
762 noise_keypair_put(struct noise_keypair *kp)
763 {
764 	if (refcount_release(&kp->kp_refcnt))
765 		NET_EPOCH_CALL(noise_keypair_smr_free, &kp->kp_smr);
766 }
767 
768 static void
769 noise_keypair_drop(struct noise_keypair *kp)
770 {
771 	struct noise_remote *r;
772 	struct noise_local *l;
773 
774 	if (kp == NULL)
775 		return;
776 
777 	r = kp->kp_remote;
778 	l = r->r_local;
779 
780 	mtx_lock(&l->l_index_mtx);
781 	CK_LIST_REMOVE(&kp->kp_index, i_entry);
782 	mtx_unlock(&l->l_index_mtx);
783 
784 	noise_keypair_put(kp);
785 }
786 
787 struct noise_remote *
788 noise_keypair_remote(struct noise_keypair *kp)
789 {
790 	return (noise_remote_ref(kp->kp_remote));
791 }
792 
793 int
794 noise_keypair_nonce_next(struct noise_keypair *kp, uint64_t *send)
795 {
796 	if (!ck_pr_load_bool(&kp->kp_can_send))
797 		return (EINVAL);
798 
799 #ifdef __LP64__
800 	*send = ck_pr_faa_64(&kp->kp_nonce_send, 1);
801 #else
802 	rw_wlock(&kp->kp_nonce_lock);
803 	*send = kp->kp_nonce_send++;
804 	rw_wunlock(&kp->kp_nonce_lock);
805 #endif
806 	if (*send < REJECT_AFTER_MESSAGES)
807 		return (0);
808 	ck_pr_store_bool(&kp->kp_can_send, false);
809 	return (EINVAL);
810 }
811 
/*
 * Replay protection: validate received nonce `recv' against the sliding
 * window of `kp' and record it as seen.
 *
 * Returns 0 when the nonce is acceptable (and now marked as received).
 * Returns EEXIST when it is rejected: the counters are exhausted, the nonce
 * is too far behind the newest one, or it has been seen before.
 */
int
noise_keypair_nonce_check(struct noise_keypair *kp, uint64_t recv)
{
	unsigned long index, index_current, top, i, bit;
	int ret = EEXIST;

	rw_wlock(&kp->kp_nonce_lock);

	/* Refuse everything once either counter has reached the hard limit */
	if (__predict_false(kp->kp_nonce_recv >= REJECT_AFTER_MESSAGES + 1 ||
			    recv >= REJECT_AFTER_MESSAGES))
		goto error;

	/* Work with recv + 1 from here on; kp_nonce_recv is stored the same way */
	++recv;

	/* Older than the window can remember */
	if (__predict_false(recv + COUNTER_WINDOW_SIZE < kp->kp_nonce_recv))
		goto error;

	/* Word number of this nonce, before wrapping into the array */
	index = recv >> COUNTER_ORDER;

	if (__predict_true(recv > kp->kp_nonce_recv)) {
		/*
		 * The window moves forward: zero every bitmap word between
		 * the old head and the new one (at most the whole array).
		 */
		index_current = kp->kp_nonce_recv >> COUNTER_ORDER;
		top = MIN(index - index_current, COUNTER_BITS_TOTAL / COUNTER_BITS);
		for (i = 1; i <= top; i++)
			kp->kp_backtrack[
			    (i + index_current) &
				((COUNTER_BITS_TOTAL / COUNTER_BITS) - 1)] = 0;
#ifdef __LP64__
		ck_pr_store_64(&kp->kp_nonce_recv, recv);
#else
		kp->kp_nonce_recv = recv;
#endif
	}

	/* Locate the word (wrapped into the array) and the bit for this nonce */
	index &= (COUNTER_BITS_TOTAL / COUNTER_BITS) - 1;
	bit = 1ul << (recv & (COUNTER_BITS - 1));
	/* Bit already set: this nonce was received before */
	if (kp->kp_backtrack[index] & bit)
		goto error;

	kp->kp_backtrack[index] |= bit;
	ret = 0;
error:
	rw_wunlock(&kp->kp_nonce_lock);
	return (ret);
}
856 
/*
 * Decide, on the send path, whether the current keypair of `r' is getting
 * old enough that a new handshake should be started.
 *
 * Returns ESTALE when the current keypair can still send and either its
 * send nonce has passed REKEY_AFTER_MESSAGES, or we were the initiator and
 * noise_timer_expired() reports it past REKEY_AFTER_TIME.  Returns 0
 * otherwise, including when there is no usable current keypair.
 */
int
noise_keep_key_fresh_send(struct noise_remote *r)
{
	struct epoch_tracker et;
	struct noise_keypair *current;
	int keep_key_fresh;
	uint64_t nonce;

	NET_EPOCH_ENTER(et);
	current = ck_pr_load_ptr(&r->r_current);
	/* Used here as "there is a keypair to examine"; recomputed below */
	keep_key_fresh = current != NULL && ck_pr_load_bool(&current->kp_can_send);
	if (!keep_key_fresh)
		goto out;
	/* Sample the send counter: atomic load on LP64, lock otherwise */
#ifdef __LP64__
	nonce = ck_pr_load_64(&current->kp_nonce_send);
#else
	rw_rlock(&current->kp_nonce_lock);
	nonce = current->kp_nonce_send;
	rw_runlock(&current->kp_nonce_lock);
#endif
	keep_key_fresh = nonce > REKEY_AFTER_MESSAGES;
	if (keep_key_fresh)
		goto out;
	/* The time-based check applies only to the side that initiated */
	keep_key_fresh = current->kp_is_initiator && noise_timer_expired(current->kp_birthdate, REKEY_AFTER_TIME, 0);

out:
	NET_EPOCH_EXIT(et);
	return (keep_key_fresh ? ESTALE : 0);
}
886 
887 int
888 noise_keep_key_fresh_recv(struct noise_remote *r)
889 {
890 	struct epoch_tracker et;
891 	struct noise_keypair *current;
892 	int keep_key_fresh;
893 
894 	NET_EPOCH_ENTER(et);
895 	current = ck_pr_load_ptr(&r->r_current);
896 	keep_key_fresh = current != NULL && ck_pr_load_bool(&current->kp_can_send) &&
897 	    current->kp_is_initiator && noise_timer_expired(current->kp_birthdate,
898 	    REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT, 0);
899 	NET_EPOCH_EXIT(et);
900 
901 	return (keep_key_fresh ? ESTALE : 0);
902 }
903 
904 int
905 noise_keypair_encrypt(struct noise_keypair *kp, uint32_t *r_idx, uint64_t nonce, struct mbuf *m)
906 {
907 	int ret;
908 
909 	ret = chacha20poly1305_encrypt_mbuf(m, nonce, kp->kp_send);
910 	if (ret)
911 		return (ret);
912 
913 	*r_idx = kp->kp_index.i_remote_index;
914 	return (0);
915 }
916 
917 int
918 noise_keypair_decrypt(struct noise_keypair *kp, uint64_t nonce, struct mbuf *m)
919 {
920 	uint64_t cur_nonce;
921 	int ret;
922 
923 #ifdef __LP64__
924 	cur_nonce = ck_pr_load_64(&kp->kp_nonce_recv);
925 #else
926 	rw_rlock(&kp->kp_nonce_lock);
927 	cur_nonce = kp->kp_nonce_recv;
928 	rw_runlock(&kp->kp_nonce_lock);
929 #endif
930 
931 	if (cur_nonce >= REJECT_AFTER_MESSAGES ||
932 	    noise_timer_expired(kp->kp_birthdate, REJECT_AFTER_TIME, 0))
933 		return (EINVAL);
934 
935 	ret = chacha20poly1305_decrypt_mbuf(m, nonce, kp->kp_recv);
936 	if (ret)
937 		return (ret);
938 
939 	return (0);
940 }
941 
942 /* Handshake functions */
/*
 * Build a handshake initiation message for remote `r'.
 *
 * Outputs:
 *   s_idx - the freshly assigned local (sender) index
 *   ue    - our unencrypted ephemeral public key
 *   es    - our static public key, encrypted, followed by its auth tag
 *   ets   - the current TAI64N timestamp, encrypted, followed by its tag
 *
 * Returns 0 on success and EINVAL on failure: no local identity, the
 * noise_timer_expired() check on r_last_sent against REKEY_TIMEOUT not
 * passing, or one of the key computations failing.  The output buffers may
 * have been partially written when EINVAL is returned.
 *
 * Lock order: identity lock (shared) then handshake lock (exclusive).
 */
int
noise_create_initiation(struct noise_remote *r,
    uint32_t *s_idx,
    uint8_t ue[NOISE_PUBLIC_KEY_LEN],
    uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
    uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
{
	struct noise_handshake *hs = &r->r_handshake;
	struct noise_local *l = r->r_local;
	uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
	int ret = EINVAL;

	rw_rlock(&l->l_identity_lock);
	rw_wlock(&r->r_handshake_lock);
	if (!l->l_has_identity)
		goto error;
	/* Rate limit: refuse while the previous initiation is still recent */
	if (!noise_timer_expired(r->r_last_sent, REKEY_TIMEOUT, 0))
		goto error;
	noise_param_init(hs->hs_ck, hs->hs_hash, r->r_public);

	/* e */
	curve25519_generate_secret(hs->hs_e);
	if (curve25519_generate_public(ue, hs->hs_e) == 0)
		goto error;
	noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue);

	/* es */
	if (noise_mix_dh(hs->hs_ck, key, hs->hs_e, r->r_public) != 0)
		goto error;

	/* s */
	noise_msg_encrypt(es, l->l_public,
	    NOISE_PUBLIC_KEY_LEN, key, hs->hs_hash);

	/* ss */
	if (noise_mix_ss(hs->hs_ck, key, r->r_ss) != 0)
		goto error;

	/* {t} */
	/* The timestamp is written to ets and then encrypted in place */
	noise_tai64n_now(ets);
	noise_msg_encrypt(ets, ets,
	    NOISE_TIMESTAMP_LEN, key, hs->hs_hash);

	/*
	 * Register a fresh local index.  The state is set afterwards because
	 * noise_remote_index_insert() first removes any previous index, which
	 * marks the handshake dead.
	 */
	noise_remote_index_insert(l, r);
	r->r_handshake_state = HANDSHAKE_INITIATOR;
	r->r_last_sent = getsbinuptime();
	*s_idx = r->r_index.i_local_index;
	ret = 0;
error:
	rw_wunlock(&r->r_handshake_lock);
	rw_runlock(&l->l_identity_lock);
	/* Wipe the temporary symmetric key on every path */
	explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
	return (ret);
}
997 
/*
 * Process a received handshake initiation message.
 *
 * Inputs are the fields of the message: the sender's index `s_idx', its
 * unencrypted ephemeral public key `ue', its encrypted static public key
 * `es' and the encrypted timestamp `ets'.
 *
 * On success returns 0 and stores in `*rp' the remote that sent the
 * message, with a reference the caller must release; that remote's
 * handshake is now in the responder state with a fresh local index.
 * Returns EINVAL (leaving `*rp' untouched) when there is no local
 * identity, decryption fails, the sender is not a known remote, the
 * timestamp is not newer than the last accepted one, or the flood check
 * does not pass.
 *
 * The handshake is computed in a local copy and only committed to the
 * remote once every check has passed.
 */
int
noise_consume_initiation(struct noise_local *l, struct noise_remote **rp,
    uint32_t s_idx,
    uint8_t ue[NOISE_PUBLIC_KEY_LEN],
    uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
    uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
{
	struct noise_remote *r;
	struct noise_handshake hs;
	uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
	uint8_t r_public[NOISE_PUBLIC_KEY_LEN];
	uint8_t	timestamp[NOISE_TIMESTAMP_LEN];
	int ret = EINVAL;

	rw_rlock(&l->l_identity_lock);
	if (!l->l_has_identity)
		goto error;
	noise_param_init(hs.hs_ck, hs.hs_hash, l->l_public);

	/* e */
	noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue);

	/* es */
	if (noise_mix_dh(hs.hs_ck, key, l->l_private, ue) != 0)
		goto error;

	/* s */
	if (noise_msg_decrypt(r_public, es,
	    NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
		goto error;

	/* Lookup the remote we received from */
	/* A successful lookup takes a reference; it is dropped at error_put */
	if ((r = noise_remote_lookup(l, r_public)) == NULL)
		goto error;

	/* ss */
	if (noise_mix_ss(hs.hs_ck, key, r->r_ss) != 0)
		goto error_put;

	/* {t} */
	if (noise_msg_decrypt(timestamp, ets,
	    NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
		goto error_put;

	/* Keep the peer's ephemeral public key for the response */
	memcpy(hs.hs_e, ue, NOISE_PUBLIC_KEY_LEN);

	/* We have successfully computed the same results, now we ensure that
	 * this is not an initiation replay, or a flood attack */
	rw_wlock(&r->r_handshake_lock);

	/* Replay */
	/* Accept only a timestamp that compares strictly greater (bytewise) */
	if (memcmp(timestamp, r->r_timestamp, NOISE_TIMESTAMP_LEN) > 0)
		memcpy(r->r_timestamp, timestamp, NOISE_TIMESTAMP_LEN);
	else
		goto error_set;
	/* Flood attack */
	if (noise_timer_expired(r->r_last_init_recv, 0, REJECT_INTERVAL))
		r->r_last_init_recv = getsbinuptime();
	else
		goto error_set;

	/* Ok, we're happy to accept this initiation now */
	noise_remote_index_insert(l, r);
	r->r_index.i_remote_index = s_idx;
	r->r_handshake_state = HANDSHAKE_RESPONDER;
	r->r_handshake = hs;
	/* Separate reference for the caller; the lookup one is dropped below */
	*rp = noise_remote_ref(r);
	ret = 0;
error_set:
	rw_wunlock(&r->r_handshake_lock);
error_put:
	noise_remote_put(r);
error:
	rw_runlock(&l->l_identity_lock);
	/* Wipe the temporary key and the local handshake copy on every path */
	explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
	explicit_bzero(&hs, sizeof(hs));
	return (ret);
}
1076 
/*
 * Build a handshake response message for remote `r' and begin the session.
 *
 * The remote must be in the responder state, i.e. an initiation from it
 * has been accepted by noise_consume_initiation().
 *
 * Outputs:
 *   s_idx - our local index for the new session
 *   r_idx - the peer's index taken from its initiation
 *   ue    - our unencrypted ephemeral public key
 *   en    - the auth tag of the empty encrypted payload
 *
 * Returns 0 on success, EINVAL when the handshake is not in the responder
 * state or a key computation fails, or the error from
 * noise_begin_session() when the keypair cannot be created.  `ue' and `en'
 * may have been written even when an error is returned.
 *
 * Lock order: identity lock (shared) then handshake lock (exclusive).
 */
int
noise_create_response(struct noise_remote *r,
    uint32_t *s_idx, uint32_t *r_idx,
    uint8_t ue[NOISE_PUBLIC_KEY_LEN],
    uint8_t en[0 + NOISE_AUTHTAG_LEN])
{
	struct noise_handshake *hs = &r->r_handshake;
	struct noise_local *l = r->r_local;
	uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
	uint8_t e[NOISE_PUBLIC_KEY_LEN];
	int ret = EINVAL;

	rw_rlock(&l->l_identity_lock);
	rw_wlock(&r->r_handshake_lock);

	if (r->r_handshake_state != HANDSHAKE_RESPONDER)
		goto error;

	/* e */
	curve25519_generate_secret(e);
	if (curve25519_generate_public(ue, e) == 0)
		goto error;
	noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue);

	/* ee */
	/* hs_e holds the peer's ephemeral public key saved from its initiation */
	if (noise_mix_dh(hs->hs_ck, NULL, e, hs->hs_e) != 0)
		goto error;

	/* se */
	if (noise_mix_dh(hs->hs_ck, NULL, e, r->r_public) != 0)
		goto error;

	/* psk */
	noise_mix_psk(hs->hs_ck, hs->hs_hash, key, r->r_psk);

	/* {} */
	noise_msg_encrypt(en, NULL, 0, key, hs->hs_hash);

	/* Derive the transport keys and install the new keypair */
	if ((ret = noise_begin_session(r)) == 0) {
		r->r_last_sent = getsbinuptime();
		*s_idx = r->r_index.i_local_index;
		*r_idx = r->r_index.i_remote_index;
	}
error:
	rw_wunlock(&r->r_handshake_lock);
	rw_runlock(&l->l_identity_lock);
	/* Wipe the temporary key and our ephemeral secret on every path */
	explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
	explicit_bzero(e, NOISE_PUBLIC_KEY_LEN);
	return (ret);
}
1127 
1128 int
1129 noise_consume_response(struct noise_local *l, struct noise_remote **rp,
1130     uint32_t s_idx, uint32_t r_idx,
1131     uint8_t ue[NOISE_PUBLIC_KEY_LEN],
1132     uint8_t en[0 + NOISE_AUTHTAG_LEN])
1133 {
1134 	uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN];
1135 	uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
1136 	struct noise_handshake hs;
1137 	struct noise_remote *r = NULL;
1138 	int ret = EINVAL;
1139 
1140 	if ((r = noise_remote_index_lookup(l, r_idx, false)) == NULL)
1141 		return (ret);
1142 
1143 	rw_rlock(&l->l_identity_lock);
1144 	if (!l->l_has_identity)
1145 		goto error;
1146 
1147 	rw_rlock(&r->r_handshake_lock);
1148 	if (r->r_handshake_state != HANDSHAKE_INITIATOR) {
1149 		rw_runlock(&r->r_handshake_lock);
1150 		goto error;
1151 	}
1152 	memcpy(preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
1153 	hs = r->r_handshake;
1154 	rw_runlock(&r->r_handshake_lock);
1155 
1156 	/* e */
1157 	noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue);
1158 
1159 	/* ee */
1160 	if (noise_mix_dh(hs.hs_ck, NULL, hs.hs_e, ue) != 0)
1161 		goto error_zero;
1162 
1163 	/* se */
1164 	if (noise_mix_dh(hs.hs_ck, NULL, l->l_private, ue) != 0)
1165 		goto error_zero;
1166 
1167 	/* psk */
1168 	noise_mix_psk(hs.hs_ck, hs.hs_hash, key, preshared_key);
1169 
1170 	/* {} */
1171 	if (noise_msg_decrypt(NULL, en,
1172 	    0 + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
1173 		goto error_zero;
1174 
1175 	rw_wlock(&r->r_handshake_lock);
1176 	if (r->r_handshake_state == HANDSHAKE_INITIATOR &&
1177 	    r->r_index.i_local_index == r_idx) {
1178 		r->r_handshake = hs;
1179 		r->r_index.i_remote_index = s_idx;
1180 		if ((ret = noise_begin_session(r)) == 0)
1181 			*rp = noise_remote_ref(r);
1182 	}
1183 	rw_wunlock(&r->r_handshake_lock);
1184 error_zero:
1185 	explicit_bzero(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
1186 	explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
1187 	explicit_bzero(&hs, sizeof(hs));
1188 error:
1189 	rw_runlock(&l->l_identity_lock);
1190 	noise_remote_put(r);
1191 	return (ret);
1192 }
1193 
1194 static void
1195 hmac(uint8_t *out, const uint8_t *in, const uint8_t *key, const size_t outlen,
1196      const size_t inlen, const size_t keylen)
1197 {
1198 	struct blake2s_state state;
1199 	uint8_t x_key[BLAKE2S_BLOCK_SIZE] __aligned(sizeof(uint32_t)) = { 0 };
1200 	uint8_t i_hash[BLAKE2S_HASH_SIZE] __aligned(sizeof(uint32_t));
1201 	int i;
1202 
1203 	if (keylen > BLAKE2S_BLOCK_SIZE) {
1204 		blake2s_init(&state, BLAKE2S_HASH_SIZE);
1205 		blake2s_update(&state, key, keylen);
1206 		blake2s_final(&state, x_key);
1207 	} else
1208 		memcpy(x_key, key, keylen);
1209 
1210 	for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
1211 		x_key[i] ^= 0x36;
1212 
1213 	blake2s_init(&state, BLAKE2S_HASH_SIZE);
1214 	blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
1215 	blake2s_update(&state, in, inlen);
1216 	blake2s_final(&state, i_hash);
1217 
1218 	for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
1219 		x_key[i] ^= 0x5c ^ 0x36;
1220 
1221 	blake2s_init(&state, BLAKE2S_HASH_SIZE);
1222 	blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
1223 	blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE);
1224 	blake2s_final(&state, i_hash);
1225 
1226 	memcpy(out, i_hash, outlen);
1227 	explicit_bzero(x_key, BLAKE2S_BLOCK_SIZE);
1228 	explicit_bzero(i_hash, BLAKE2S_HASH_SIZE);
1229 }
1230 
1231 /* Handshake helper functions */
1232 static void
1233 noise_kdf(uint8_t *a, uint8_t *b, uint8_t *c, const uint8_t *x,
1234     size_t a_len, size_t b_len, size_t c_len, size_t x_len,
1235     const uint8_t ck[NOISE_HASH_LEN])
1236 {
1237 	uint8_t out[BLAKE2S_HASH_SIZE + 1];
1238 	uint8_t sec[BLAKE2S_HASH_SIZE];
1239 
1240 	/* Extract entropy from "x" into sec */
1241 	hmac(sec, x, ck, BLAKE2S_HASH_SIZE, x_len, NOISE_HASH_LEN);
1242 
1243 	if (a == NULL || a_len == 0)
1244 		goto out;
1245 
1246 	/* Expand first key: key = sec, data = 0x1 */
1247 	out[0] = 1;
1248 	hmac(out, out, sec, BLAKE2S_HASH_SIZE, 1, BLAKE2S_HASH_SIZE);
1249 	memcpy(a, out, a_len);
1250 
1251 	if (b == NULL || b_len == 0)
1252 		goto out;
1253 
1254 	/* Expand second key: key = sec, data = "a" || 0x2 */
1255 	out[BLAKE2S_HASH_SIZE] = 2;
1256 	hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
1257 	memcpy(b, out, b_len);
1258 
1259 	if (c == NULL || c_len == 0)
1260 		goto out;
1261 
1262 	/* Expand third key: key = sec, data = "b" || 0x3 */
1263 	out[BLAKE2S_HASH_SIZE] = 3;
1264 	hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
1265 	memcpy(c, out, c_len);
1266 
1267 out:
1268 	/* Clear sensitive data from stack */
1269 	explicit_bzero(sec, BLAKE2S_HASH_SIZE);
1270 	explicit_bzero(out, BLAKE2S_HASH_SIZE + 1);
1271 }
1272 
1273 static int
1274 noise_mix_dh(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
1275     const uint8_t private[NOISE_PUBLIC_KEY_LEN],
1276     const uint8_t public[NOISE_PUBLIC_KEY_LEN])
1277 {
1278 	uint8_t dh[NOISE_PUBLIC_KEY_LEN];
1279 
1280 	if (!curve25519(dh, private, public))
1281 		return (EINVAL);
1282 	noise_kdf(ck, key, NULL, dh,
1283 	    NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck);
1284 	explicit_bzero(dh, NOISE_PUBLIC_KEY_LEN);
1285 	return (0);
1286 }
1287 
1288 static int
1289 noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
1290     const uint8_t ss[NOISE_PUBLIC_KEY_LEN])
1291 {
1292 	static uint8_t null_point[NOISE_PUBLIC_KEY_LEN];
1293 	if (timingsafe_bcmp(ss, null_point, NOISE_PUBLIC_KEY_LEN) == 0)
1294 		return (ENOENT);
1295 	noise_kdf(ck, key, NULL, ss,
1296 	    NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck);
1297 	return (0);
1298 }
1299 
1300 static void
1301 noise_mix_hash(uint8_t hash[NOISE_HASH_LEN], const uint8_t *src,
1302     size_t src_len)
1303 {
1304 	struct blake2s_state blake;
1305 
1306 	blake2s_init(&blake, NOISE_HASH_LEN);
1307 	blake2s_update(&blake, hash, NOISE_HASH_LEN);
1308 	blake2s_update(&blake, src, src_len);
1309 	blake2s_final(&blake, hash);
1310 }
1311 
1312 static void
1313 noise_mix_psk(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
1314     uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
1315     const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
1316 {
1317 	uint8_t tmp[NOISE_HASH_LEN];
1318 
1319 	noise_kdf(ck, tmp, key, psk,
1320 	    NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN,
1321 	    NOISE_SYMMETRIC_KEY_LEN, ck);
1322 	noise_mix_hash(hash, tmp, NOISE_HASH_LEN);
1323 	explicit_bzero(tmp, NOISE_HASH_LEN);
1324 }
1325 
1326 static void
1327 noise_param_init(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
1328     const uint8_t s[NOISE_PUBLIC_KEY_LEN])
1329 {
1330 	struct blake2s_state blake;
1331 
1332 	blake2s(ck, (uint8_t *)NOISE_HANDSHAKE_NAME, NULL,
1333 	    NOISE_HASH_LEN, strlen(NOISE_HANDSHAKE_NAME), 0);
1334 	blake2s_init(&blake, NOISE_HASH_LEN);
1335 	blake2s_update(&blake, ck, NOISE_HASH_LEN);
1336 	blake2s_update(&blake, (uint8_t *)NOISE_IDENTIFIER_NAME,
1337 	    strlen(NOISE_IDENTIFIER_NAME));
1338 	blake2s_final(&blake, hash);
1339 
1340 	noise_mix_hash(hash, s, NOISE_PUBLIC_KEY_LEN);
1341 }
1342 
1343 static void
1344 noise_msg_encrypt(uint8_t *dst, const uint8_t *src, size_t src_len,
1345     uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN])
1346 {
1347 	/* Nonce always zero for Noise_IK */
1348 	chacha20poly1305_encrypt(dst, src, src_len,
1349 	    hash, NOISE_HASH_LEN, 0, key);
1350 	noise_mix_hash(hash, dst, src_len + NOISE_AUTHTAG_LEN);
1351 }
1352 
1353 static int
1354 noise_msg_decrypt(uint8_t *dst, const uint8_t *src, size_t src_len,
1355     uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN])
1356 {
1357 	/* Nonce always zero for Noise_IK */
1358 	if (!chacha20poly1305_decrypt(dst, src, src_len,
1359 	    hash, NOISE_HASH_LEN, 0, key))
1360 		return (EINVAL);
1361 	noise_mix_hash(hash, src, src_len);
1362 	return (0);
1363 }
1364 
1365 static void
1366 noise_msg_ephemeral(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
1367     const uint8_t src[NOISE_PUBLIC_KEY_LEN])
1368 {
1369 	noise_mix_hash(hash, src, NOISE_PUBLIC_KEY_LEN);
1370 	noise_kdf(ck, NULL, NULL, src, NOISE_HASH_LEN, 0, 0,
1371 		  NOISE_PUBLIC_KEY_LEN, ck);
1372 }
1373 
1374 static void
1375 noise_tai64n_now(uint8_t output[NOISE_TIMESTAMP_LEN])
1376 {
1377 	struct timespec time;
1378 	uint64_t sec;
1379 	uint32_t nsec;
1380 
1381 	getnanotime(&time);
1382 
1383 	/* Round down the nsec counter to limit precise timing leak. */
1384 	time.tv_nsec &= REJECT_INTERVAL_MASK;
1385 
1386 	/* https://cr.yp.to/libtai/tai64.html */
1387 	sec = htobe64(0x400000000000000aULL + time.tv_sec);
1388 	nsec = htobe32(time.tv_nsec);
1389 
1390 	/* memcpy to output buffer, assuming output could be unaligned. */
1391 	memcpy(output, &sec, sizeof(sec));
1392 	memcpy(output + sizeof(sec), &nsec, sizeof(nsec));
1393 }
1394 
1395 static inline int
1396 noise_timer_expired(sbintime_t timer, uint32_t sec, uint32_t nsec)
1397 {
1398 	sbintime_t now = getsbinuptime();
1399 	return (now > (timer + sec * SBT_1S + nstosbt(nsec))) ? ETIMEDOUT : 0;
1400 }
1401 
1402 static uint64_t siphash24(const uint8_t key[SIPHASH_KEY_LENGTH], const void *src, size_t len)
1403 {
1404 	SIPHASH_CTX ctx;
1405 	return (SipHashX(&ctx, 2, 4, key, src, len));
1406 }
1407 
1408 #ifdef SELFTESTS
1409 #include "selftest/counter.c"
1410 #endif /* SELFTESTS */
1411