xref: /freebsd/sys/dev/wg/wg_noise.c (revision 357378bbdedf24ce2b90e9bd831af4a9db3ec70a)
1 /* SPDX-License-Identifier: ISC
2  *
3  * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
5  * Copyright (c) 2022 The FreeBSD Foundation
6  */
7 
8 #include <sys/param.h>
9 #include <sys/systm.h>
10 #include <sys/ck.h>
11 #include <sys/endian.h>
12 #include <sys/epoch.h>
13 #include <sys/kernel.h>
14 #include <sys/lock.h>
15 #include <sys/malloc.h>
16 #include <sys/mutex.h>
17 #include <sys/refcount.h>
18 #include <sys/rwlock.h>
19 #include <crypto/siphash/siphash.h>
20 
21 #include "crypto.h"
22 #include "wg_noise.h"
23 
/* Protocol string constants */
#define NOISE_HANDSHAKE_NAME	"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
#define NOISE_IDENTIFIER_NAME	"WireGuard v1 zx2c4 Jason@zx2c4.com"

/*
 * Constants for the receive counter.
 *
 * The backtrack bitmap is COUNTER_BITS_TOTAL bits wide and is stored in
 * words of COUNTER_BITS bits each; COUNTER_ORDER is log2(COUNTER_BITS).
 * One word's worth of bits is treated as redundant, so the usable window
 * is COUNTER_BITS smaller than the bitmap.
 */
#define COUNTER_BITS_TOTAL	8192
#if defined(__LP64__)
#define COUNTER_ORDER		6
#define COUNTER_BITS		64
#else
#define COUNTER_ORDER		5
#define COUNTER_BITS		32
#endif
#define COUNTER_REDUNDANT_BITS	COUNTER_BITS
#define COUNTER_WINDOW_SIZE	(COUNTER_BITS_TOTAL - COUNTER_REDUNDANT_BITS)

/* Constants for the keypair: message-count and time limits. */
#define REKEY_AFTER_MESSAGES	(1ull << 60)
#define REJECT_AFTER_MESSAGES	(UINT64_MAX - COUNTER_WINDOW_SIZE - 1)
#define REKEY_AFTER_TIME	120
#define REKEY_AFTER_TIME_RECV	165
#define REJECT_INTERVAL		(1000000000 / 50) /* fifty times per sec */
/* 24 = floor(log2(REJECT_INTERVAL)) */
#define REJECT_INTERVAL_MASK	(~((1ull << 24) - 1))
/* Timestamp lying REKEY_TIMEOUT + 1 seconds before uptime zero. */
#define TIMER_RESET		(SBT_1S * -(REKEY_TIMEOUT + 1))

/* Hash table geometry; sizes are powers of two so the masks select a bucket. */
#define HT_INDEX_SIZE		(1 << 13)
#define HT_INDEX_MASK		(HT_INDEX_SIZE - 1)
#define HT_REMOTE_SIZE		(1 << 11)
#define HT_REMOTE_MASK		(HT_REMOTE_SIZE - 1)
/* Upper bound on remotes enabled on one local (see noise_remote_enable()). */
#define MAX_REMOTE_PER_LOCAL	(1 << 20)
55 
56 struct noise_index {
57 	CK_LIST_ENTRY(noise_index)	 i_entry;
58 	uint32_t			 i_local_index;
59 	uint32_t			 i_remote_index;
60 	int				 i_is_keypair;
61 };
62 
63 struct noise_keypair {
64 	struct noise_index		 kp_index;
65 	u_int				 kp_refcnt;
66 	bool				 kp_can_send;
67 	bool				 kp_is_initiator;
68 	sbintime_t			 kp_birthdate; /* sbinuptime */
69 	struct noise_remote		*kp_remote;
70 
71 	uint8_t				 kp_send[NOISE_SYMMETRIC_KEY_LEN];
72 	uint8_t				 kp_recv[NOISE_SYMMETRIC_KEY_LEN];
73 
74 	/* Counter elements */
75 	struct rwlock			 kp_nonce_lock;
76 	uint64_t			 kp_nonce_send;
77 	uint64_t			 kp_nonce_recv;
78 	unsigned long			 kp_backtrack[COUNTER_BITS_TOTAL / COUNTER_BITS];
79 
80 	struct epoch_context		 kp_smr;
81 };
82 
83 struct noise_handshake {
84 	uint8_t	 			 hs_e[NOISE_PUBLIC_KEY_LEN];
85 	uint8_t	 			 hs_hash[NOISE_HASH_LEN];
86 	uint8_t	 			 hs_ck[NOISE_HASH_LEN];
87 };
88 
/*
 * Role of a remote in its current handshake.  While the state is anything
 * other than HANDSHAKE_DEAD, the remote's r_index is linked into the
 * local's index table (see noise_remote_index_remove()).
 */
enum noise_handshake_state {
	HANDSHAKE_DEAD = 0,	/* no handshake in progress */
	HANDSHAKE_INITIATOR,	/* we sent the initiation */
	HANDSHAKE_RESPONDER,	/* we accepted the peer's initiation */
};
94 
95 struct noise_remote {
96 	struct noise_index		 r_index;
97 
98 	CK_LIST_ENTRY(noise_remote) 	 r_entry;
99 	bool				 r_entry_inserted;
100 	uint8_t				 r_public[NOISE_PUBLIC_KEY_LEN];
101 
102 	struct rwlock			 r_handshake_lock;
103 	struct noise_handshake		 r_handshake;
104 	enum noise_handshake_state	 r_handshake_state;
105 	sbintime_t			 r_last_sent; /* sbinuptime */
106 	sbintime_t			 r_last_init_recv; /* sbinuptime */
107 	uint8_t				 r_timestamp[NOISE_TIMESTAMP_LEN];
108 	uint8_t				 r_psk[NOISE_SYMMETRIC_KEY_LEN];
109 	uint8_t		 		 r_ss[NOISE_PUBLIC_KEY_LEN];
110 
111 	u_int				 r_refcnt;
112 	struct noise_local		*r_local;
113 	void				*r_arg;
114 
115 	struct mtx			 r_keypair_mtx;
116 	struct noise_keypair		*r_next, *r_current, *r_previous;
117 
118 	struct epoch_context		 r_smr;
119 	void				(*r_cleanup)(struct noise_remote *);
120 };
121 
122 struct noise_local {
123 	struct rwlock			 l_identity_lock;
124 	bool				 l_has_identity;
125 	uint8_t				 l_public[NOISE_PUBLIC_KEY_LEN];
126 	uint8_t				 l_private[NOISE_PUBLIC_KEY_LEN];
127 
128 	u_int				 l_refcnt;
129 	uint8_t				 l_hash_key[SIPHASH_KEY_LENGTH];
130 	void				*l_arg;
131 	void				(*l_cleanup)(struct noise_local *);
132 
133 	struct mtx			 l_remote_mtx;
134 	size_t				 l_remote_num;
135 	CK_LIST_HEAD(,noise_remote)	 l_remote_hash[HT_REMOTE_SIZE];
136 
137 	struct mtx			 l_index_mtx;
138 	CK_LIST_HEAD(,noise_index)	 l_index_hash[HT_INDEX_SIZE];
139 };
140 
141 static void	noise_precompute_ss(struct noise_local *, struct noise_remote *);
142 
143 static void	noise_remote_index_insert(struct noise_local *, struct noise_remote *);
144 static struct noise_remote *
145 		noise_remote_index_lookup(struct noise_local *, uint32_t, bool);
146 static int	noise_remote_index_remove(struct noise_local *, struct noise_remote *);
147 static void	noise_remote_expire_current(struct noise_remote *);
148 
149 static void	noise_add_new_keypair(struct noise_local *, struct noise_remote *, struct noise_keypair *);
150 static int	noise_begin_session(struct noise_remote *);
151 static void	noise_keypair_drop(struct noise_keypair *);
152 
153 static void	noise_kdf(uint8_t *, uint8_t *, uint8_t *, const uint8_t *,
154 		    size_t, size_t, size_t, size_t,
155 		    const uint8_t [NOISE_HASH_LEN]);
156 static int	noise_mix_dh(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_SYMMETRIC_KEY_LEN],
157 		    const uint8_t [NOISE_PUBLIC_KEY_LEN],
158 		    const uint8_t [NOISE_PUBLIC_KEY_LEN]);
159 static int	noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t [NOISE_SYMMETRIC_KEY_LEN],
160 		    const uint8_t [NOISE_PUBLIC_KEY_LEN]);
161 static void	noise_mix_hash(uint8_t [NOISE_HASH_LEN], const uint8_t *, size_t);
162 static void	noise_mix_psk(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN],
163 		    uint8_t [NOISE_SYMMETRIC_KEY_LEN], const uint8_t [NOISE_SYMMETRIC_KEY_LEN]);
164 static void	noise_param_init(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN],
165 		    const uint8_t [NOISE_PUBLIC_KEY_LEN]);
166 static void	noise_msg_encrypt(uint8_t *, const uint8_t *, size_t,
167 		    uint8_t [NOISE_SYMMETRIC_KEY_LEN], uint8_t [NOISE_HASH_LEN]);
168 static int	noise_msg_decrypt(uint8_t *, const uint8_t *, size_t,
169 		    uint8_t [NOISE_SYMMETRIC_KEY_LEN], uint8_t [NOISE_HASH_LEN]);
170 static void	noise_msg_ephemeral(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN],
171 		    const uint8_t [NOISE_PUBLIC_KEY_LEN]);
172 static void	noise_tai64n_now(uint8_t [NOISE_TIMESTAMP_LEN]);
173 static int	noise_timer_expired(sbintime_t, uint32_t, uint32_t);
174 static uint64_t siphash24(const uint8_t [SIPHASH_KEY_LENGTH], const void *, size_t);
175 
176 MALLOC_DEFINE(M_NOISE, "NOISE", "wgnoise");
177 
178 /* Local configuration */
179 struct noise_local *
180 noise_local_alloc(void *arg)
181 {
182 	struct noise_local *l;
183 	size_t i;
184 
185 	l = malloc(sizeof(*l), M_NOISE, M_WAITOK | M_ZERO);
186 
187 	rw_init(&l->l_identity_lock, "noise_identity");
188 	l->l_has_identity = false;
189 	bzero(l->l_public, NOISE_PUBLIC_KEY_LEN);
190 	bzero(l->l_private, NOISE_PUBLIC_KEY_LEN);
191 
192 	refcount_init(&l->l_refcnt, 1);
193 	arc4random_buf(l->l_hash_key, sizeof(l->l_hash_key));
194 	l->l_arg = arg;
195 	l->l_cleanup = NULL;
196 
197 	mtx_init(&l->l_remote_mtx, "noise_remote", NULL, MTX_DEF);
198 	l->l_remote_num = 0;
199 	for (i = 0; i < HT_REMOTE_SIZE; i++)
200 		CK_LIST_INIT(&l->l_remote_hash[i]);
201 
202 	mtx_init(&l->l_index_mtx, "noise_index", NULL, MTX_DEF);
203 	for (i = 0; i < HT_INDEX_SIZE; i++)
204 		CK_LIST_INIT(&l->l_index_hash[i]);
205 
206 	return (l);
207 }
208 
209 struct noise_local *
210 noise_local_ref(struct noise_local *l)
211 {
212 	refcount_acquire(&l->l_refcnt);
213 	return (l);
214 }
215 
216 void
217 noise_local_put(struct noise_local *l)
218 {
219 	if (refcount_release(&l->l_refcnt)) {
220 		if (l->l_cleanup != NULL)
221 			l->l_cleanup(l);
222 		rw_destroy(&l->l_identity_lock);
223 		mtx_destroy(&l->l_remote_mtx);
224 		mtx_destroy(&l->l_index_mtx);
225 		zfree(l, M_NOISE);
226 	}
227 }
228 
229 void
230 noise_local_free(struct noise_local *l, void (*cleanup)(struct noise_local *))
231 {
232 	l->l_cleanup = cleanup;
233 	noise_local_put(l);
234 }
235 
236 void *
237 noise_local_arg(struct noise_local *l)
238 {
239 	return (l->l_arg);
240 }
241 
242 void
243 noise_local_private(struct noise_local *l, const uint8_t private[NOISE_PUBLIC_KEY_LEN])
244 {
245 	struct epoch_tracker et;
246 	struct noise_remote *r;
247 	size_t i;
248 
249 	rw_wlock(&l->l_identity_lock);
250 	memcpy(l->l_private, private, NOISE_PUBLIC_KEY_LEN);
251 	curve25519_clamp_secret(l->l_private);
252 	l->l_has_identity = curve25519_generate_public(l->l_public, l->l_private);
253 
254 	NET_EPOCH_ENTER(et);
255 	for (i = 0; i < HT_REMOTE_SIZE; i++) {
256 		CK_LIST_FOREACH(r, &l->l_remote_hash[i], r_entry) {
257 			noise_precompute_ss(l, r);
258 			noise_remote_expire_current(r);
259 		}
260 	}
261 	NET_EPOCH_EXIT(et);
262 	rw_wunlock(&l->l_identity_lock);
263 }
264 
265 int
266 noise_local_keys(struct noise_local *l, uint8_t public[NOISE_PUBLIC_KEY_LEN],
267     uint8_t private[NOISE_PUBLIC_KEY_LEN])
268 {
269 	int has_identity;
270 	rw_rlock(&l->l_identity_lock);
271 	if ((has_identity = l->l_has_identity)) {
272 		if (public != NULL)
273 			memcpy(public, l->l_public, NOISE_PUBLIC_KEY_LEN);
274 		if (private != NULL)
275 			memcpy(private, l->l_private, NOISE_PUBLIC_KEY_LEN);
276 	}
277 	rw_runlock(&l->l_identity_lock);
278 	return (has_identity ? 0 : ENXIO);
279 }
280 
281 static void
282 noise_precompute_ss(struct noise_local *l, struct noise_remote *r)
283 {
284 	rw_assert(&l->l_identity_lock, RA_LOCKED);
285 	rw_wlock(&r->r_handshake_lock);
286 	if (!l->l_has_identity ||
287 	    !curve25519(r->r_ss, l->l_private, r->r_public))
288 		bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN);
289 	rw_wunlock(&r->r_handshake_lock);
290 }
291 
292 /* Remote configuration */
293 struct noise_remote *
294 noise_remote_alloc(struct noise_local *l, void *arg,
295     const uint8_t public[NOISE_PUBLIC_KEY_LEN])
296 {
297 	struct noise_remote *r;
298 
299 	r = malloc(sizeof(*r), M_NOISE, M_WAITOK | M_ZERO);
300 	memcpy(r->r_public, public, NOISE_PUBLIC_KEY_LEN);
301 
302 	rw_init(&r->r_handshake_lock, "noise_handshake");
303 	r->r_handshake_state = HANDSHAKE_DEAD;
304 	r->r_last_sent = TIMER_RESET;
305 	r->r_last_init_recv = TIMER_RESET;
306 
307 	rw_rlock(&l->l_identity_lock);
308 	noise_precompute_ss(l, r);
309 	rw_runlock(&l->l_identity_lock);
310 
311 	refcount_init(&r->r_refcnt, 1);
312 	r->r_local = noise_local_ref(l);
313 	r->r_arg = arg;
314 
315 	mtx_init(&r->r_keypair_mtx, "noise_keypair", NULL, MTX_DEF);
316 
317 	return (r);
318 }
319 
320 int
321 noise_remote_enable(struct noise_remote *r)
322 {
323 	struct noise_local *l = r->r_local;
324 	uint64_t idx;
325 	int ret = 0;
326 
327 	/* Insert to hashtable */
328 	idx = siphash24(l->l_hash_key, r->r_public, NOISE_PUBLIC_KEY_LEN) & HT_REMOTE_MASK;
329 
330 	mtx_lock(&l->l_remote_mtx);
331 	if (!r->r_entry_inserted) {
332 		if (l->l_remote_num < MAX_REMOTE_PER_LOCAL) {
333 			r->r_entry_inserted = true;
334 			l->l_remote_num++;
335 			CK_LIST_INSERT_HEAD(&l->l_remote_hash[idx], r, r_entry);
336 		} else {
337 			ret = ENOSPC;
338 		}
339 	}
340 	mtx_unlock(&l->l_remote_mtx);
341 
342 	return ret;
343 }
344 
345 void
346 noise_remote_disable(struct noise_remote *r)
347 {
348 	struct noise_local *l = r->r_local;
349 	/* remove from hashtable */
350 	mtx_lock(&l->l_remote_mtx);
351 	if (r->r_entry_inserted) {
352 		r->r_entry_inserted = false;
353 		CK_LIST_REMOVE(r, r_entry);
354 		l->l_remote_num--;
355 	};
356 	mtx_unlock(&l->l_remote_mtx);
357 }
358 
359 struct noise_remote *
360 noise_remote_lookup(struct noise_local *l, const uint8_t public[NOISE_PUBLIC_KEY_LEN])
361 {
362 	struct epoch_tracker et;
363 	struct noise_remote *r, *ret = NULL;
364 	uint64_t idx;
365 
366 	idx = siphash24(l->l_hash_key, public, NOISE_PUBLIC_KEY_LEN) & HT_REMOTE_MASK;
367 
368 	NET_EPOCH_ENTER(et);
369 	CK_LIST_FOREACH(r, &l->l_remote_hash[idx], r_entry) {
370 		if (timingsafe_bcmp(r->r_public, public, NOISE_PUBLIC_KEY_LEN) == 0) {
371 			if (refcount_acquire_if_not_zero(&r->r_refcnt))
372 				ret = r;
373 			break;
374 		}
375 	}
376 	NET_EPOCH_EXIT(et);
377 	return (ret);
378 }
379 
380 static void
381 noise_remote_index_insert(struct noise_local *l, struct noise_remote *r)
382 {
383 	struct noise_index *i, *r_i = &r->r_index;
384 	struct epoch_tracker et;
385 	uint32_t idx;
386 
387 	noise_remote_index_remove(l, r);
388 
389 	NET_EPOCH_ENTER(et);
390 assign_id:
391 	r_i->i_local_index = arc4random();
392 	idx = r_i->i_local_index & HT_INDEX_MASK;
393 	CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) {
394 		if (i->i_local_index == r_i->i_local_index)
395 			goto assign_id;
396 	}
397 
398 	mtx_lock(&l->l_index_mtx);
399 	CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) {
400 		if (i->i_local_index == r_i->i_local_index) {
401 			mtx_unlock(&l->l_index_mtx);
402 			goto assign_id;
403 		}
404 	}
405 	CK_LIST_INSERT_HEAD(&l->l_index_hash[idx], r_i, i_entry);
406 	mtx_unlock(&l->l_index_mtx);
407 
408 	NET_EPOCH_EXIT(et);
409 }
410 
411 static struct noise_remote *
412 noise_remote_index_lookup(struct noise_local *l, uint32_t idx0, bool lookup_keypair)
413 {
414 	struct epoch_tracker et;
415 	struct noise_index *i;
416 	struct noise_keypair *kp;
417 	struct noise_remote *r, *ret = NULL;
418 	uint32_t idx = idx0 & HT_INDEX_MASK;
419 
420 	NET_EPOCH_ENTER(et);
421 	CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) {
422 		if (i->i_local_index == idx0) {
423 			if (!i->i_is_keypair) {
424 				r = (struct noise_remote *) i;
425 			} else if (lookup_keypair) {
426 				kp = (struct noise_keypair *) i;
427 				r = kp->kp_remote;
428 			} else {
429 				break;
430 			}
431 			if (refcount_acquire_if_not_zero(&r->r_refcnt))
432 				ret = r;
433 			break;
434 		}
435 	}
436 	NET_EPOCH_EXIT(et);
437 	return (ret);
438 }
439 
/*
 * Public index lookup: resolves `idx' to a referenced remote, accepting
 * both handshake and keypair index entries.  Returns NULL on a miss.
 */
struct noise_remote *
noise_remote_index(struct noise_local *l, uint32_t idx)
{
	struct noise_remote *r;

	r = noise_remote_index_lookup(l, idx, true);
	return (r);
}
445 
446 static int
447 noise_remote_index_remove(struct noise_local *l, struct noise_remote *r)
448 {
449 	rw_assert(&r->r_handshake_lock, RA_WLOCKED);
450 	if (r->r_handshake_state != HANDSHAKE_DEAD) {
451 		mtx_lock(&l->l_index_mtx);
452 		r->r_handshake_state = HANDSHAKE_DEAD;
453 		CK_LIST_REMOVE(&r->r_index, i_entry);
454 		mtx_unlock(&l->l_index_mtx);
455 		return (1);
456 	}
457 	return (0);
458 }
459 
460 struct noise_remote *
461 noise_remote_ref(struct noise_remote *r)
462 {
463 	refcount_acquire(&r->r_refcnt);
464 	return (r);
465 }
466 
467 static void
468 noise_remote_smr_free(struct epoch_context *smr)
469 {
470 	struct noise_remote *r;
471 	r = __containerof(smr, struct noise_remote, r_smr);
472 	if (r->r_cleanup != NULL)
473 		r->r_cleanup(r);
474 	noise_local_put(r->r_local);
475 	rw_destroy(&r->r_handshake_lock);
476 	mtx_destroy(&r->r_keypair_mtx);
477 	zfree(r, M_NOISE);
478 }
479 
480 void
481 noise_remote_put(struct noise_remote *r)
482 {
483 	if (refcount_release(&r->r_refcnt))
484 		NET_EPOCH_CALL(noise_remote_smr_free, &r->r_smr);
485 }
486 
487 void
488 noise_remote_free(struct noise_remote *r, void (*cleanup)(struct noise_remote *))
489 {
490 	r->r_cleanup = cleanup;
491 	noise_remote_disable(r);
492 
493 	/* now clear all keypairs and handshakes, then put this reference */
494 	noise_remote_handshake_clear(r);
495 	noise_remote_keypairs_clear(r);
496 	noise_remote_put(r);
497 }
498 
499 struct noise_local *
500 noise_remote_local(struct noise_remote *r)
501 {
502 	return (noise_local_ref(r->r_local));
503 }
504 
505 void *
506 noise_remote_arg(struct noise_remote *r)
507 {
508 	return (r->r_arg);
509 }
510 
511 void
512 noise_remote_set_psk(struct noise_remote *r,
513     const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
514 {
515 	rw_wlock(&r->r_handshake_lock);
516 	if (psk == NULL)
517 		bzero(r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
518 	else
519 		memcpy(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
520 	rw_wunlock(&r->r_handshake_lock);
521 }
522 
523 int
524 noise_remote_keys(struct noise_remote *r, uint8_t public[NOISE_PUBLIC_KEY_LEN],
525     uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
526 {
527 	static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN];
528 	int ret;
529 
530 	if (public != NULL)
531 		memcpy(public, r->r_public, NOISE_PUBLIC_KEY_LEN);
532 
533 	rw_rlock(&r->r_handshake_lock);
534 	if (psk != NULL)
535 		memcpy(psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
536 	ret = timingsafe_bcmp(r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN);
537 	rw_runlock(&r->r_handshake_lock);
538 
539 	return (ret ? 0 : ENOENT);
540 }
541 
542 int
543 noise_remote_initiation_expired(struct noise_remote *r)
544 {
545 	int expired;
546 	rw_rlock(&r->r_handshake_lock);
547 	expired = noise_timer_expired(r->r_last_sent, REKEY_TIMEOUT, 0);
548 	rw_runlock(&r->r_handshake_lock);
549 	return (expired);
550 }
551 
552 void
553 noise_remote_handshake_clear(struct noise_remote *r)
554 {
555 	rw_wlock(&r->r_handshake_lock);
556 	if (noise_remote_index_remove(r->r_local, r))
557 		bzero(&r->r_handshake, sizeof(r->r_handshake));
558 	r->r_last_sent = TIMER_RESET;
559 	rw_wunlock(&r->r_handshake_lock);
560 }
561 
562 void
563 noise_remote_keypairs_clear(struct noise_remote *r)
564 {
565 	struct noise_keypair *kp;
566 
567 	mtx_lock(&r->r_keypair_mtx);
568 	kp = atomic_load_ptr(&r->r_next);
569 	atomic_store_ptr(&r->r_next, NULL);
570 	noise_keypair_drop(kp);
571 
572 	kp = atomic_load_ptr(&r->r_current);
573 	atomic_store_ptr(&r->r_current, NULL);
574 	noise_keypair_drop(kp);
575 
576 	kp = atomic_load_ptr(&r->r_previous);
577 	atomic_store_ptr(&r->r_previous, NULL);
578 	noise_keypair_drop(kp);
579 	mtx_unlock(&r->r_keypair_mtx);
580 }
581 
582 static void
583 noise_remote_expire_current(struct noise_remote *r)
584 {
585 	struct epoch_tracker et;
586 	struct noise_keypair *kp;
587 
588 	noise_remote_handshake_clear(r);
589 
590 	NET_EPOCH_ENTER(et);
591 	kp = atomic_load_ptr(&r->r_next);
592 	if (kp != NULL)
593 		atomic_store_bool(&kp->kp_can_send, false);
594 	kp = atomic_load_ptr(&r->r_current);
595 	if (kp != NULL)
596 		atomic_store_bool(&kp->kp_can_send, false);
597 	NET_EPOCH_EXIT(et);
598 }
599 
600 /* Keypair functions */
601 static void
602 noise_add_new_keypair(struct noise_local *l, struct noise_remote *r,
603     struct noise_keypair *kp)
604 {
605 	struct noise_keypair *next, *current, *previous;
606 	struct noise_index *r_i = &r->r_index;
607 
608 	/* Insert into the keypair table */
609 	mtx_lock(&r->r_keypair_mtx);
610 	next = atomic_load_ptr(&r->r_next);
611 	current = atomic_load_ptr(&r->r_current);
612 	previous = atomic_load_ptr(&r->r_previous);
613 
614 	if (kp->kp_is_initiator) {
615 		if (next != NULL) {
616 			atomic_store_ptr(&r->r_next, NULL);
617 			atomic_store_ptr(&r->r_previous, next);
618 			noise_keypair_drop(current);
619 		} else {
620 			atomic_store_ptr(&r->r_previous, current);
621 		}
622 		noise_keypair_drop(previous);
623 		atomic_store_ptr(&r->r_current, kp);
624 	} else {
625 		atomic_store_ptr(&r->r_next, kp);
626 		noise_keypair_drop(next);
627 		atomic_store_ptr(&r->r_previous, NULL);
628 		noise_keypair_drop(previous);
629 
630 	}
631 	mtx_unlock(&r->r_keypair_mtx);
632 
633 	/* Insert into index table */
634 	rw_assert(&r->r_handshake_lock, RA_WLOCKED);
635 
636 	kp->kp_index.i_is_keypair = true;
637 	kp->kp_index.i_local_index = r_i->i_local_index;
638 	kp->kp_index.i_remote_index = r_i->i_remote_index;
639 
640 	mtx_lock(&l->l_index_mtx);
641 	CK_LIST_INSERT_BEFORE(r_i, &kp->kp_index, i_entry);
642 	r->r_handshake_state = HANDSHAKE_DEAD;
643 	CK_LIST_REMOVE(r_i, i_entry);
644 	mtx_unlock(&l->l_index_mtx);
645 
646 	explicit_bzero(&r->r_handshake, sizeof(r->r_handshake));
647 }
648 
649 static int
650 noise_begin_session(struct noise_remote *r)
651 {
652 	struct noise_keypair *kp;
653 
654 	rw_assert(&r->r_handshake_lock, RA_WLOCKED);
655 
656 	if ((kp = malloc(sizeof(*kp), M_NOISE, M_NOWAIT | M_ZERO)) == NULL)
657 		return (ENOSPC);
658 
659 	refcount_init(&kp->kp_refcnt, 1);
660 	kp->kp_can_send = true;
661 	kp->kp_is_initiator = r->r_handshake_state == HANDSHAKE_INITIATOR;
662 	kp->kp_birthdate = getsbinuptime();
663 	kp->kp_remote = noise_remote_ref(r);
664 
665 	if (kp->kp_is_initiator)
666 		noise_kdf(kp->kp_send, kp->kp_recv, NULL, NULL,
667 		    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
668 		    r->r_handshake.hs_ck);
669 	else
670 		noise_kdf(kp->kp_recv, kp->kp_send, NULL, NULL,
671 		    NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
672 		    r->r_handshake.hs_ck);
673 
674 	rw_init(&kp->kp_nonce_lock, "noise_nonce");
675 
676 	noise_add_new_keypair(r->r_local, r, kp);
677 	return (0);
678 }
679 
680 struct noise_keypair *
681 noise_keypair_lookup(struct noise_local *l, uint32_t idx0)
682 {
683 	struct epoch_tracker et;
684 	struct noise_index *i;
685 	struct noise_keypair *kp, *ret = NULL;
686 	uint32_t idx = idx0 & HT_INDEX_MASK;
687 
688 	NET_EPOCH_ENTER(et);
689 	CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) {
690 		if (i->i_local_index == idx0 && i->i_is_keypair) {
691 			kp = (struct noise_keypair *) i;
692 			if (refcount_acquire_if_not_zero(&kp->kp_refcnt))
693 				ret = kp;
694 			break;
695 		}
696 	}
697 	NET_EPOCH_EXIT(et);
698 	return (ret);
699 }
700 
701 struct noise_keypair *
702 noise_keypair_current(struct noise_remote *r)
703 {
704 	struct epoch_tracker et;
705 	struct noise_keypair *kp, *ret = NULL;
706 
707 	NET_EPOCH_ENTER(et);
708 	kp = atomic_load_ptr(&r->r_current);
709 	if (kp != NULL && atomic_load_bool(&kp->kp_can_send)) {
710 		if (noise_timer_expired(kp->kp_birthdate, REJECT_AFTER_TIME, 0))
711 			atomic_store_bool(&kp->kp_can_send, false);
712 		else if (refcount_acquire_if_not_zero(&kp->kp_refcnt))
713 			ret = kp;
714 	}
715 	NET_EPOCH_EXIT(et);
716 	return (ret);
717 }
718 
719 struct noise_keypair *
720 noise_keypair_ref(struct noise_keypair *kp)
721 {
722 	refcount_acquire(&kp->kp_refcnt);
723 	return (kp);
724 }
725 
726 int
727 noise_keypair_received_with(struct noise_keypair *kp)
728 {
729 	struct noise_keypair *old;
730 	struct noise_remote *r = kp->kp_remote;
731 
732 	if (kp != atomic_load_ptr(&r->r_next))
733 		return (0);
734 
735 	mtx_lock(&r->r_keypair_mtx);
736 	if (kp != atomic_load_ptr(&r->r_next)) {
737 		mtx_unlock(&r->r_keypair_mtx);
738 		return (0);
739 	}
740 
741 	old = atomic_load_ptr(&r->r_previous);
742 	atomic_store_ptr(&r->r_previous, atomic_load_ptr(&r->r_current));
743 	noise_keypair_drop(old);
744 	atomic_store_ptr(&r->r_current, kp);
745 	atomic_store_ptr(&r->r_next, NULL);
746 	mtx_unlock(&r->r_keypair_mtx);
747 
748 	return (ECONNRESET);
749 }
750 
751 static void
752 noise_keypair_smr_free(struct epoch_context *smr)
753 {
754 	struct noise_keypair *kp;
755 	kp = __containerof(smr, struct noise_keypair, kp_smr);
756 	noise_remote_put(kp->kp_remote);
757 	rw_destroy(&kp->kp_nonce_lock);
758 	zfree(kp, M_NOISE);
759 }
760 
761 void
762 noise_keypair_put(struct noise_keypair *kp)
763 {
764 	if (refcount_release(&kp->kp_refcnt))
765 		NET_EPOCH_CALL(noise_keypair_smr_free, &kp->kp_smr);
766 }
767 
768 static void
769 noise_keypair_drop(struct noise_keypair *kp)
770 {
771 	struct noise_remote *r;
772 	struct noise_local *l;
773 
774 	if (kp == NULL)
775 		return;
776 
777 	r = kp->kp_remote;
778 	l = r->r_local;
779 
780 	mtx_lock(&l->l_index_mtx);
781 	CK_LIST_REMOVE(&kp->kp_index, i_entry);
782 	mtx_unlock(&l->l_index_mtx);
783 
784 	noise_keypair_put(kp);
785 }
786 
787 struct noise_remote *
788 noise_keypair_remote(struct noise_keypair *kp)
789 {
790 	return (noise_remote_ref(kp->kp_remote));
791 }
792 
793 int
794 noise_keypair_nonce_next(struct noise_keypair *kp, uint64_t *send)
795 {
796 	if (!atomic_load_bool(&kp->kp_can_send))
797 		return (EINVAL);
798 
799 #ifdef __LP64__
800 	*send = atomic_fetchadd_64(&kp->kp_nonce_send, 1);
801 #else
802 	rw_wlock(&kp->kp_nonce_lock);
803 	*send = kp->kp_nonce_send++;
804 	rw_wunlock(&kp->kp_nonce_lock);
805 #endif
806 	if (*send < REJECT_AFTER_MESSAGES)
807 		return (0);
808 	atomic_store_bool(&kp->kp_can_send, false);
809 	return (EINVAL);
810 }
811 
812 int
813 noise_keypair_nonce_check(struct noise_keypair *kp, uint64_t recv)
814 {
815 	unsigned long index, index_current, top, i, bit;
816 	int ret = EEXIST;
817 
818 	rw_wlock(&kp->kp_nonce_lock);
819 
820 	if (__predict_false(kp->kp_nonce_recv >= REJECT_AFTER_MESSAGES + 1 ||
821 			    recv >= REJECT_AFTER_MESSAGES))
822 		goto error;
823 
824 	++recv;
825 
826 	if (__predict_false(recv + COUNTER_WINDOW_SIZE < kp->kp_nonce_recv))
827 		goto error;
828 
829 	index = recv >> COUNTER_ORDER;
830 
831 	if (__predict_true(recv > kp->kp_nonce_recv)) {
832 		index_current = kp->kp_nonce_recv >> COUNTER_ORDER;
833 		top = MIN(index - index_current, COUNTER_BITS_TOTAL / COUNTER_BITS);
834 		for (i = 1; i <= top; i++)
835 			kp->kp_backtrack[
836 			    (i + index_current) &
837 				((COUNTER_BITS_TOTAL / COUNTER_BITS) - 1)] = 0;
838 #ifdef __LP64__
839 		atomic_store_64(&kp->kp_nonce_recv, recv);
840 #else
841 		kp->kp_nonce_recv = recv;
842 #endif
843 	}
844 
845 	index &= (COUNTER_BITS_TOTAL / COUNTER_BITS) - 1;
846 	bit = 1ul << (recv & (COUNTER_BITS - 1));
847 	if (kp->kp_backtrack[index] & bit)
848 		goto error;
849 
850 	kp->kp_backtrack[index] |= bit;
851 	ret = 0;
852 error:
853 	rw_wunlock(&kp->kp_nonce_lock);
854 	return (ret);
855 }
856 
857 int
858 noise_keep_key_fresh_send(struct noise_remote *r)
859 {
860 	struct epoch_tracker et;
861 	struct noise_keypair *current;
862 	int keep_key_fresh;
863 	uint64_t nonce;
864 
865 	NET_EPOCH_ENTER(et);
866 	current = atomic_load_ptr(&r->r_current);
867 	keep_key_fresh = current != NULL && atomic_load_bool(&current->kp_can_send);
868 	if (!keep_key_fresh)
869 		goto out;
870 #ifdef __LP64__
871 	nonce = atomic_load_64(&current->kp_nonce_send);
872 #else
873 	rw_rlock(&current->kp_nonce_lock);
874 	nonce = current->kp_nonce_send;
875 	rw_runlock(&current->kp_nonce_lock);
876 #endif
877 	keep_key_fresh = nonce > REKEY_AFTER_MESSAGES;
878 	if (keep_key_fresh)
879 		goto out;
880 	keep_key_fresh = current->kp_is_initiator && noise_timer_expired(current->kp_birthdate, REKEY_AFTER_TIME, 0);
881 
882 out:
883 	NET_EPOCH_EXIT(et);
884 	return (keep_key_fresh ? ESTALE : 0);
885 }
886 
887 int
888 noise_keep_key_fresh_recv(struct noise_remote *r)
889 {
890 	struct epoch_tracker et;
891 	struct noise_keypair *current;
892 	int keep_key_fresh;
893 
894 	NET_EPOCH_ENTER(et);
895 	current = atomic_load_ptr(&r->r_current);
896 	keep_key_fresh = current != NULL && atomic_load_bool(&current->kp_can_send) &&
897 	    current->kp_is_initiator && noise_timer_expired(current->kp_birthdate,
898 	    REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT, 0);
899 	NET_EPOCH_EXIT(et);
900 
901 	return (keep_key_fresh ? ESTALE : 0);
902 }
903 
904 int
905 noise_keypair_encrypt(struct noise_keypair *kp, uint32_t *r_idx, uint64_t nonce, struct mbuf *m)
906 {
907 	int ret;
908 
909 	ret = chacha20poly1305_encrypt_mbuf(m, nonce, kp->kp_send);
910 	if (ret)
911 		return (ret);
912 
913 	*r_idx = kp->kp_index.i_remote_index;
914 	return (0);
915 }
916 
917 int
918 noise_keypair_decrypt(struct noise_keypair *kp, uint64_t nonce, struct mbuf *m)
919 {
920 	uint64_t cur_nonce;
921 	int ret;
922 
923 #ifdef __LP64__
924 	cur_nonce = atomic_load_64(&kp->kp_nonce_recv);
925 #else
926 	rw_rlock(&kp->kp_nonce_lock);
927 	cur_nonce = kp->kp_nonce_recv;
928 	rw_runlock(&kp->kp_nonce_lock);
929 #endif
930 
931 	if (cur_nonce >= REJECT_AFTER_MESSAGES ||
932 	    noise_timer_expired(kp->kp_birthdate, REJECT_AFTER_TIME, 0))
933 		return (EINVAL);
934 
935 	ret = chacha20poly1305_decrypt_mbuf(m, nonce, kp->kp_recv);
936 	if (ret)
937 		return (ret);
938 
939 	return (0);
940 }
941 
942 /* Handshake functions */
943 int
944 noise_create_initiation(struct noise_remote *r,
945     uint32_t *s_idx,
946     uint8_t ue[NOISE_PUBLIC_KEY_LEN],
947     uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
948     uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
949 {
950 	struct noise_handshake *hs = &r->r_handshake;
951 	struct noise_local *l = r->r_local;
952 	uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
953 	int ret = EINVAL;
954 
955 	rw_rlock(&l->l_identity_lock);
956 	rw_wlock(&r->r_handshake_lock);
957 	if (!l->l_has_identity)
958 		goto error;
959 	if (!noise_timer_expired(r->r_last_sent, REKEY_TIMEOUT, 0))
960 		goto error;
961 	noise_param_init(hs->hs_ck, hs->hs_hash, r->r_public);
962 
963 	/* e */
964 	curve25519_generate_secret(hs->hs_e);
965 	if (curve25519_generate_public(ue, hs->hs_e) == 0)
966 		goto error;
967 	noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue);
968 
969 	/* es */
970 	if (noise_mix_dh(hs->hs_ck, key, hs->hs_e, r->r_public) != 0)
971 		goto error;
972 
973 	/* s */
974 	noise_msg_encrypt(es, l->l_public,
975 	    NOISE_PUBLIC_KEY_LEN, key, hs->hs_hash);
976 
977 	/* ss */
978 	if (noise_mix_ss(hs->hs_ck, key, r->r_ss) != 0)
979 		goto error;
980 
981 	/* {t} */
982 	noise_tai64n_now(ets);
983 	noise_msg_encrypt(ets, ets,
984 	    NOISE_TIMESTAMP_LEN, key, hs->hs_hash);
985 
986 	noise_remote_index_insert(l, r);
987 	r->r_handshake_state = HANDSHAKE_INITIATOR;
988 	r->r_last_sent = getsbinuptime();
989 	*s_idx = r->r_index.i_local_index;
990 	ret = 0;
991 error:
992 	rw_wunlock(&r->r_handshake_lock);
993 	rw_runlock(&l->l_identity_lock);
994 	explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
995 	return (ret);
996 }
997 
998 int
999 noise_consume_initiation(struct noise_local *l, struct noise_remote **rp,
1000     uint32_t s_idx,
1001     uint8_t ue[NOISE_PUBLIC_KEY_LEN],
1002     uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
1003     uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
1004 {
1005 	struct noise_remote *r;
1006 	struct noise_handshake hs;
1007 	uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
1008 	uint8_t r_public[NOISE_PUBLIC_KEY_LEN];
1009 	uint8_t	timestamp[NOISE_TIMESTAMP_LEN];
1010 	int ret = EINVAL;
1011 
1012 	rw_rlock(&l->l_identity_lock);
1013 	if (!l->l_has_identity)
1014 		goto error;
1015 	noise_param_init(hs.hs_ck, hs.hs_hash, l->l_public);
1016 
1017 	/* e */
1018 	noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue);
1019 
1020 	/* es */
1021 	if (noise_mix_dh(hs.hs_ck, key, l->l_private, ue) != 0)
1022 		goto error;
1023 
1024 	/* s */
1025 	if (noise_msg_decrypt(r_public, es,
1026 	    NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
1027 		goto error;
1028 
1029 	/* Lookup the remote we received from */
1030 	if ((r = noise_remote_lookup(l, r_public)) == NULL)
1031 		goto error;
1032 
1033 	/* ss */
1034 	if (noise_mix_ss(hs.hs_ck, key, r->r_ss) != 0)
1035 		goto error_put;
1036 
1037 	/* {t} */
1038 	if (noise_msg_decrypt(timestamp, ets,
1039 	    NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
1040 		goto error_put;
1041 
1042 	memcpy(hs.hs_e, ue, NOISE_PUBLIC_KEY_LEN);
1043 
1044 	/* We have successfully computed the same results, now we ensure that
1045 	 * this is not an initiation replay, or a flood attack */
1046 	rw_wlock(&r->r_handshake_lock);
1047 
1048 	/* Replay */
1049 	if (memcmp(timestamp, r->r_timestamp, NOISE_TIMESTAMP_LEN) > 0)
1050 		memcpy(r->r_timestamp, timestamp, NOISE_TIMESTAMP_LEN);
1051 	else
1052 		goto error_set;
1053 	/* Flood attack */
1054 	if (noise_timer_expired(r->r_last_init_recv, 0, REJECT_INTERVAL))
1055 		r->r_last_init_recv = getsbinuptime();
1056 	else
1057 		goto error_set;
1058 
1059 	/* Ok, we're happy to accept this initiation now */
1060 	noise_remote_index_insert(l, r);
1061 	r->r_index.i_remote_index = s_idx;
1062 	r->r_handshake_state = HANDSHAKE_RESPONDER;
1063 	r->r_handshake = hs;
1064 	*rp = noise_remote_ref(r);
1065 	ret = 0;
1066 error_set:
1067 	rw_wunlock(&r->r_handshake_lock);
1068 error_put:
1069 	noise_remote_put(r);
1070 error:
1071 	rw_runlock(&l->l_identity_lock);
1072 	explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
1073 	explicit_bzero(&hs, sizeof(hs));
1074 	return (ret);
1075 }
1076 
1077 int
1078 noise_create_response(struct noise_remote *r,
1079     uint32_t *s_idx, uint32_t *r_idx,
1080     uint8_t ue[NOISE_PUBLIC_KEY_LEN],
1081     uint8_t en[0 + NOISE_AUTHTAG_LEN])
1082 {
1083 	struct noise_handshake *hs = &r->r_handshake;
1084 	struct noise_local *l = r->r_local;
1085 	uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
1086 	uint8_t e[NOISE_PUBLIC_KEY_LEN];
1087 	int ret = EINVAL;
1088 
1089 	rw_rlock(&l->l_identity_lock);
1090 	rw_wlock(&r->r_handshake_lock);
1091 
1092 	if (r->r_handshake_state != HANDSHAKE_RESPONDER)
1093 		goto error;
1094 
1095 	/* e */
1096 	curve25519_generate_secret(e);
1097 	if (curve25519_generate_public(ue, e) == 0)
1098 		goto error;
1099 	noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue);
1100 
1101 	/* ee */
1102 	if (noise_mix_dh(hs->hs_ck, NULL, e, hs->hs_e) != 0)
1103 		goto error;
1104 
1105 	/* se */
1106 	if (noise_mix_dh(hs->hs_ck, NULL, e, r->r_public) != 0)
1107 		goto error;
1108 
1109 	/* psk */
1110 	noise_mix_psk(hs->hs_ck, hs->hs_hash, key, r->r_psk);
1111 
1112 	/* {} */
1113 	noise_msg_encrypt(en, NULL, 0, key, hs->hs_hash);
1114 
1115 	if ((ret = noise_begin_session(r)) == 0) {
1116 		r->r_last_sent = getsbinuptime();
1117 		*s_idx = r->r_index.i_local_index;
1118 		*r_idx = r->r_index.i_remote_index;
1119 	}
1120 error:
1121 	rw_wunlock(&r->r_handshake_lock);
1122 	rw_runlock(&l->l_identity_lock);
1123 	explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
1124 	explicit_bzero(e, NOISE_PUBLIC_KEY_LEN);
1125 	return (ret);
1126 }
1127 
1128 int
1129 noise_consume_response(struct noise_local *l, struct noise_remote **rp,
1130     uint32_t s_idx, uint32_t r_idx,
1131     uint8_t ue[NOISE_PUBLIC_KEY_LEN],
1132     uint8_t en[0 + NOISE_AUTHTAG_LEN])
1133 {
1134 	uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN];
1135 	uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
1136 	struct noise_handshake hs;
1137 	struct noise_remote *r = NULL;
1138 	int ret = EINVAL;
1139 
1140 	if ((r = noise_remote_index_lookup(l, r_idx, false)) == NULL)
1141 		return (ret);
1142 
1143 	rw_rlock(&l->l_identity_lock);
1144 	if (!l->l_has_identity)
1145 		goto error;
1146 
1147 	rw_rlock(&r->r_handshake_lock);
1148 	if (r->r_handshake_state != HANDSHAKE_INITIATOR) {
1149 		rw_runlock(&r->r_handshake_lock);
1150 		goto error;
1151 	}
1152 	memcpy(preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
1153 	hs = r->r_handshake;
1154 	rw_runlock(&r->r_handshake_lock);
1155 
1156 	/* e */
1157 	noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue);
1158 
1159 	/* ee */
1160 	if (noise_mix_dh(hs.hs_ck, NULL, hs.hs_e, ue) != 0)
1161 		goto error_zero;
1162 
1163 	/* se */
1164 	if (noise_mix_dh(hs.hs_ck, NULL, l->l_private, ue) != 0)
1165 		goto error_zero;
1166 
1167 	/* psk */
1168 	noise_mix_psk(hs.hs_ck, hs.hs_hash, key, preshared_key);
1169 
1170 	/* {} */
1171 	if (noise_msg_decrypt(NULL, en,
1172 	    0 + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
1173 		goto error_zero;
1174 
1175 	rw_wlock(&r->r_handshake_lock);
1176 	if (r->r_handshake_state == HANDSHAKE_INITIATOR &&
1177 	    r->r_index.i_local_index == r_idx) {
1178 		r->r_handshake = hs;
1179 		r->r_index.i_remote_index = s_idx;
1180 		if ((ret = noise_begin_session(r)) == 0)
1181 			*rp = noise_remote_ref(r);
1182 	}
1183 	rw_wunlock(&r->r_handshake_lock);
1184 error_zero:
1185 	explicit_bzero(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
1186 	explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
1187 	explicit_bzero(&hs, sizeof(hs));
1188 error:
1189 	rw_runlock(&l->l_identity_lock);
1190 	noise_remote_put(r);
1191 	return (ret);
1192 }
1193 
1194 static void
1195 hmac(uint8_t *out, const uint8_t *in, const uint8_t *key, const size_t outlen,
1196      const size_t inlen, const size_t keylen)
1197 {
1198 	struct blake2s_state state;
1199 	uint8_t x_key[BLAKE2S_BLOCK_SIZE] __aligned(sizeof(uint32_t)) = { 0 };
1200 	uint8_t i_hash[BLAKE2S_HASH_SIZE] __aligned(sizeof(uint32_t));
1201 	int i;
1202 
1203 	if (keylen > BLAKE2S_BLOCK_SIZE) {
1204 		blake2s_init(&state, BLAKE2S_HASH_SIZE);
1205 		blake2s_update(&state, key, keylen);
1206 		blake2s_final(&state, x_key);
1207 	} else
1208 		memcpy(x_key, key, keylen);
1209 
1210 	for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
1211 		x_key[i] ^= 0x36;
1212 
1213 	blake2s_init(&state, BLAKE2S_HASH_SIZE);
1214 	blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
1215 	blake2s_update(&state, in, inlen);
1216 	blake2s_final(&state, i_hash);
1217 
1218 	for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
1219 		x_key[i] ^= 0x5c ^ 0x36;
1220 
1221 	blake2s_init(&state, BLAKE2S_HASH_SIZE);
1222 	blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
1223 	blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE);
1224 	blake2s_final(&state, i_hash);
1225 
1226 	memcpy(out, i_hash, outlen);
1227 	explicit_bzero(x_key, BLAKE2S_BLOCK_SIZE);
1228 	explicit_bzero(i_hash, BLAKE2S_HASH_SIZE);
1229 }
1230 
1231 /* Handshake helper functions */
1232 static void
1233 noise_kdf(uint8_t *a, uint8_t *b, uint8_t *c, const uint8_t *x,
1234     size_t a_len, size_t b_len, size_t c_len, size_t x_len,
1235     const uint8_t ck[NOISE_HASH_LEN])
1236 {
1237 	uint8_t out[BLAKE2S_HASH_SIZE + 1];
1238 	uint8_t sec[BLAKE2S_HASH_SIZE];
1239 
1240 	/* Extract entropy from "x" into sec */
1241 	hmac(sec, x, ck, BLAKE2S_HASH_SIZE, x_len, NOISE_HASH_LEN);
1242 
1243 	if (a == NULL || a_len == 0)
1244 		goto out;
1245 
1246 	/* Expand first key: key = sec, data = 0x1 */
1247 	out[0] = 1;
1248 	hmac(out, out, sec, BLAKE2S_HASH_SIZE, 1, BLAKE2S_HASH_SIZE);
1249 	memcpy(a, out, a_len);
1250 
1251 	if (b == NULL || b_len == 0)
1252 		goto out;
1253 
1254 	/* Expand second key: key = sec, data = "a" || 0x2 */
1255 	out[BLAKE2S_HASH_SIZE] = 2;
1256 	hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
1257 	memcpy(b, out, b_len);
1258 
1259 	if (c == NULL || c_len == 0)
1260 		goto out;
1261 
1262 	/* Expand third key: key = sec, data = "b" || 0x3 */
1263 	out[BLAKE2S_HASH_SIZE] = 3;
1264 	hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
1265 	memcpy(c, out, c_len);
1266 
1267 out:
1268 	/* Clear sensitive data from stack */
1269 	explicit_bzero(sec, BLAKE2S_HASH_SIZE);
1270 	explicit_bzero(out, BLAKE2S_HASH_SIZE + 1);
1271 }
1272 
1273 static int
1274 noise_mix_dh(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
1275     const uint8_t private[NOISE_PUBLIC_KEY_LEN],
1276     const uint8_t public[NOISE_PUBLIC_KEY_LEN])
1277 {
1278 	uint8_t dh[NOISE_PUBLIC_KEY_LEN];
1279 
1280 	if (!curve25519(dh, private, public))
1281 		return (EINVAL);
1282 	noise_kdf(ck, key, NULL, dh,
1283 	    NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck);
1284 	explicit_bzero(dh, NOISE_PUBLIC_KEY_LEN);
1285 	return (0);
1286 }
1287 
1288 static int
1289 noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
1290     const uint8_t ss[NOISE_PUBLIC_KEY_LEN])
1291 {
1292 	static uint8_t null_point[NOISE_PUBLIC_KEY_LEN];
1293 	if (timingsafe_bcmp(ss, null_point, NOISE_PUBLIC_KEY_LEN) == 0)
1294 		return (ENOENT);
1295 	noise_kdf(ck, key, NULL, ss,
1296 	    NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck);
1297 	return (0);
1298 }
1299 
1300 static void
1301 noise_mix_hash(uint8_t hash[NOISE_HASH_LEN], const uint8_t *src,
1302     size_t src_len)
1303 {
1304 	struct blake2s_state blake;
1305 
1306 	blake2s_init(&blake, NOISE_HASH_LEN);
1307 	blake2s_update(&blake, hash, NOISE_HASH_LEN);
1308 	blake2s_update(&blake, src, src_len);
1309 	blake2s_final(&blake, hash);
1310 }
1311 
1312 static void
1313 noise_mix_psk(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
1314     uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
1315     const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
1316 {
1317 	uint8_t tmp[NOISE_HASH_LEN];
1318 
1319 	noise_kdf(ck, tmp, key, psk,
1320 	    NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN,
1321 	    NOISE_SYMMETRIC_KEY_LEN, ck);
1322 	noise_mix_hash(hash, tmp, NOISE_HASH_LEN);
1323 	explicit_bzero(tmp, NOISE_HASH_LEN);
1324 }
1325 
1326 static void
1327 noise_param_init(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
1328     const uint8_t s[NOISE_PUBLIC_KEY_LEN])
1329 {
1330 	struct blake2s_state blake;
1331 
1332 	blake2s(ck, (uint8_t *)NOISE_HANDSHAKE_NAME, NULL,
1333 	    NOISE_HASH_LEN, strlen(NOISE_HANDSHAKE_NAME), 0);
1334 	blake2s_init(&blake, NOISE_HASH_LEN);
1335 	blake2s_update(&blake, ck, NOISE_HASH_LEN);
1336 	blake2s_update(&blake, (uint8_t *)NOISE_IDENTIFIER_NAME,
1337 	    strlen(NOISE_IDENTIFIER_NAME));
1338 	blake2s_final(&blake, hash);
1339 
1340 	noise_mix_hash(hash, s, NOISE_PUBLIC_KEY_LEN);
1341 }
1342 
1343 static void
1344 noise_msg_encrypt(uint8_t *dst, const uint8_t *src, size_t src_len,
1345     uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN])
1346 {
1347 	/* Nonce always zero for Noise_IK */
1348 	chacha20poly1305_encrypt(dst, src, src_len,
1349 	    hash, NOISE_HASH_LEN, 0, key);
1350 	noise_mix_hash(hash, dst, src_len + NOISE_AUTHTAG_LEN);
1351 }
1352 
1353 static int
1354 noise_msg_decrypt(uint8_t *dst, const uint8_t *src, size_t src_len,
1355     uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN])
1356 {
1357 	/* Nonce always zero for Noise_IK */
1358 	if (!chacha20poly1305_decrypt(dst, src, src_len,
1359 	    hash, NOISE_HASH_LEN, 0, key))
1360 		return (EINVAL);
1361 	noise_mix_hash(hash, src, src_len);
1362 	return (0);
1363 }
1364 
1365 static void
1366 noise_msg_ephemeral(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
1367     const uint8_t src[NOISE_PUBLIC_KEY_LEN])
1368 {
1369 	noise_mix_hash(hash, src, NOISE_PUBLIC_KEY_LEN);
1370 	noise_kdf(ck, NULL, NULL, src, NOISE_HASH_LEN, 0, 0,
1371 		  NOISE_PUBLIC_KEY_LEN, ck);
1372 }
1373 
1374 static void
1375 noise_tai64n_now(uint8_t output[NOISE_TIMESTAMP_LEN])
1376 {
1377 	struct timespec time;
1378 	uint64_t sec;
1379 	uint32_t nsec;
1380 
1381 	getnanotime(&time);
1382 
1383 	/* Round down the nsec counter to limit precise timing leak. */
1384 	time.tv_nsec &= REJECT_INTERVAL_MASK;
1385 
1386 	/* https://cr.yp.to/libtai/tai64.html */
1387 	sec = htobe64(0x400000000000000aULL + time.tv_sec);
1388 	nsec = htobe32(time.tv_nsec);
1389 
1390 	/* memcpy to output buffer, assuming output could be unaligned. */
1391 	memcpy(output, &sec, sizeof(sec));
1392 	memcpy(output + sizeof(sec), &nsec, sizeof(nsec));
1393 }
1394 
1395 static inline int
1396 noise_timer_expired(sbintime_t timer, uint32_t sec, uint32_t nsec)
1397 {
1398 	sbintime_t now = getsbinuptime();
1399 	return (now > (timer + sec * SBT_1S + nstosbt(nsec))) ? ETIMEDOUT : 0;
1400 }
1401 
1402 static uint64_t siphash24(const uint8_t key[SIPHASH_KEY_LENGTH], const void *src, size_t len)
1403 {
1404 	SIPHASH_CTX ctx;
1405 	return (SipHashX(&ctx, 2, 4, key, src, len));
1406 }
1407 
1408 #ifdef SELFTESTS
1409 #include "selftest/counter.c"
1410 #endif /* SELFTESTS */
1411