xref: /freebsd/sys/dev/wg/if_wg.c (revision 7121e9414f294d116caeadd07ebd969136d3a631)
1 /* SPDX-License-Identifier: ISC
2  *
3  * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
5  * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
6  * Copyright (c) 2021 Kyle Evans <kevans@FreeBSD.org>
7  * Copyright (c) 2022 The FreeBSD Foundation
8  */
9 
10 #include "opt_inet.h"
11 #include "opt_inet6.h"
12 
13 #include <sys/param.h>
14 #include <sys/systm.h>
15 #include <sys/counter.h>
16 #include <sys/gtaskqueue.h>
17 #include <sys/jail.h>
18 #include <sys/kernel.h>
19 #include <sys/lock.h>
20 #include <sys/mbuf.h>
21 #include <sys/module.h>
22 #include <sys/nv.h>
23 #include <sys/priv.h>
24 #include <sys/protosw.h>
25 #include <sys/rmlock.h>
26 #include <sys/rwlock.h>
27 #include <sys/smp.h>
28 #include <sys/socket.h>
29 #include <sys/socketvar.h>
30 #include <sys/sockio.h>
31 #include <sys/sysctl.h>
32 #include <sys/sx.h>
33 #include <machine/_inttypes.h>
34 #include <net/bpf.h>
35 #include <net/ethernet.h>
36 #include <net/if.h>
37 #include <net/if_clone.h>
38 #include <net/if_types.h>
39 #include <net/if_var.h>
40 #include <net/netisr.h>
41 #include <net/radix.h>
42 #include <netinet/in.h>
43 #include <netinet6/in6_var.h>
44 #include <netinet/ip.h>
45 #include <netinet/ip6.h>
46 #include <netinet/ip_icmp.h>
47 #include <netinet/icmp6.h>
48 #include <netinet/udp_var.h>
49 #include <netinet6/nd6.h>
50 
51 #include "wg_noise.h"
52 #include "wg_cookie.h"
53 #include "version.h"
54 #include "if_wg.h"
55 
56 #define DEFAULT_MTU		(ETHERMTU - 80)
57 #define MAX_MTU			(IF_MAXMTU - 80)
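/*
 * The 80 bytes reserved above cover the worst-case per-packet tunnel
 * overhead: an IPv6 header (40) + UDP header (8) + struct wg_pkt_data
 * header (16) + NOISE_AUTHTAG_LEN (16).
 */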
58 
59 #define MAX_STAGED_PKT		128
60 #define MAX_QUEUED_PKT		1024
61 #define MAX_QUEUED_PKT_MASK	(MAX_QUEUED_PKT - 1)
62 
63 #define MAX_QUEUED_HANDSHAKES	4096
64 
65 #define REKEY_TIMEOUT_JITTER	334 /* 1/3 sec, round for arc4random_uniform */
66 #define MAX_TIMER_HANDSHAKES	(90 / REKEY_TIMEOUT)
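/*
 * Assuming the protocol's REKEY_TIMEOUT of 5 seconds (defined in wg_noise.h,
 * not shown here), MAX_TIMER_HANDSHAKES evaluates to 18 timer-driven retries
 * before wg_timers_run_retry_handshake() gives up.
 */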
67 #define NEW_HANDSHAKE_TIMEOUT	(REKEY_TIMEOUT + KEEPALIVE_TIMEOUT)
68 #define UNDERLOAD_TIMEOUT	1
69 
70 #define DPRINTF(sc, ...) if (if_getflags(sc->sc_ifp) & IFF_DEBUG) if_printf(sc->sc_ifp, ##__VA_ARGS__)
71 
72 /* First byte indicating packet type on the wire */
73 #define WG_PKT_INITIATION htole32(1)
74 #define WG_PKT_RESPONSE htole32(2)
75 #define WG_PKT_COOKIE htole32(3)
76 #define WG_PKT_DATA htole32(4)
77 
78 #define WG_PKT_PADDING		16
79 #define WG_KEY_SIZE		32
80 
81 struct wg_pkt_initiation {
82 	uint32_t		t;
83 	uint32_t		s_idx;
84 	uint8_t			ue[NOISE_PUBLIC_KEY_LEN];
85 	uint8_t			es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN];
86 	uint8_t			ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN];
87 	struct cookie_macs	m;
88 };
89 
90 struct wg_pkt_response {
91 	uint32_t		t;
92 	uint32_t		s_idx;
93 	uint32_t		r_idx;
94 	uint8_t			ue[NOISE_PUBLIC_KEY_LEN];
95 	uint8_t			en[0 + NOISE_AUTHTAG_LEN];
96 	struct cookie_macs	m;
97 };
98 
99 struct wg_pkt_cookie {
100 	uint32_t		t;
101 	uint32_t		r_idx;
102 	uint8_t			nonce[COOKIE_NONCE_SIZE];
103 	uint8_t			ec[COOKIE_ENCRYPTED_SIZE];
104 };
105 
106 struct wg_pkt_data {
107 	uint32_t		t;
108 	uint32_t		r_idx;
109 	uint64_t		nonce;
110 	uint8_t			buf[];
111 };
112 
113 struct wg_endpoint {
114 	union {
115 		struct sockaddr		r_sa;
116 		struct sockaddr_in	r_sin;
117 #ifdef INET6
118 		struct sockaddr_in6	r_sin6;
119 #endif
120 	} e_remote;
121 	union {
122 		struct in_addr		l_in;
123 #ifdef INET6
124 		struct in6_pktinfo	l_pktinfo6;
125 #define l_in6 l_pktinfo6.ipi6_addr
126 #endif
127 	} e_local;
128 };
129 
130 struct aip_addr {
131 	uint8_t		length;
132 	union {
133 		uint8_t		bytes[16];
134 		uint32_t	ip;
135 		uint32_t	ip6[4];
136 		struct in_addr	in;
137 		struct in6_addr	in6;
138 	};
139 };
140 
141 struct wg_aip {
142 	struct radix_node	 a_nodes[2];
143 	LIST_ENTRY(wg_aip)	 a_entry;
144 	struct aip_addr		 a_addr;
145 	struct aip_addr		 a_mask;
146 	struct wg_peer		*a_peer;
147 	sa_family_t		 a_af;
148 };
149 
150 struct wg_packet {
151 	STAILQ_ENTRY(wg_packet)	 p_serial;
152 	STAILQ_ENTRY(wg_packet)	 p_parallel;
153 	struct wg_endpoint	 p_endpoint;
154 	struct noise_keypair	*p_keypair;
155 	uint64_t		 p_nonce;
156 	struct mbuf		*p_mbuf;
157 	int			 p_mtu;
158 	sa_family_t		 p_af;
159 	enum wg_ring_state {
160 		WG_PACKET_UNCRYPTED,
161 		WG_PACKET_CRYPTED,
162 		WG_PACKET_DEAD,
163 	}			 p_state;
164 };
165 
166 STAILQ_HEAD(wg_packet_list, wg_packet);
167 
168 struct wg_queue {
169 	struct mtx		 q_mtx;
170 	struct wg_packet_list	 q_queue;
171 	size_t			 q_len;
172 };
173 
174 struct wg_peer {
175 	TAILQ_ENTRY(wg_peer)		 p_entry;
176 	uint64_t			 p_id;
177 	struct wg_softc			*p_sc;
178 
179 	struct noise_remote		*p_remote;
180 	struct cookie_maker		 p_cookie;
181 
182 	struct rwlock			 p_endpoint_lock;
183 	struct wg_endpoint		 p_endpoint;
184 
185 	struct wg_queue	 		 p_stage_queue;
186 	struct wg_queue	 		 p_encrypt_serial;
187 	struct wg_queue	 		 p_decrypt_serial;
188 
189 	bool				 p_enabled;
190 	bool				 p_need_another_keepalive;
191 	uint16_t			 p_persistent_keepalive_interval;
192 	struct callout			 p_new_handshake;
193 	struct callout			 p_send_keepalive;
194 	struct callout			 p_retry_handshake;
195 	struct callout			 p_zero_key_material;
196 	struct callout			 p_persistent_keepalive;
197 
198 	struct mtx			 p_handshake_mtx;
199 	struct timespec			 p_handshake_complete;	/* nanotime */
200 	int				 p_handshake_retries;
201 
202 	struct grouptask		 p_send;
203 	struct grouptask		 p_recv;
204 
205 	counter_u64_t			 p_tx_bytes;
206 	counter_u64_t			 p_rx_bytes;
207 
208 	LIST_HEAD(, wg_aip)		 p_aips;
209 	size_t				 p_aips_num;
210 };
211 
212 struct wg_socket {
213 	struct socket	*so_so4;
214 	struct socket	*so_so6;
215 	uint32_t	 so_user_cookie;
216 	int		 so_fibnum;
217 	in_port_t	 so_port;
218 };
219 
220 struct wg_softc {
221 	LIST_ENTRY(wg_softc)	 sc_entry;
222 	if_t			 sc_ifp;
223 	int			 sc_flags;
224 
225 	struct ucred		*sc_ucred;
226 	struct wg_socket	 sc_socket;
227 
228 	TAILQ_HEAD(,wg_peer)	 sc_peers;
229 	size_t			 sc_peers_num;
230 
231 	struct noise_local	*sc_local;
232 	struct cookie_checker	 sc_cookie;
233 
234 	struct radix_node_head	*sc_aip4;
235 	struct radix_node_head	*sc_aip6;
236 
237 	struct grouptask	 sc_handshake;
238 	struct wg_queue		 sc_handshake_queue;
239 
240 	struct grouptask	*sc_encrypt;
241 	struct grouptask	*sc_decrypt;
242 	struct wg_queue		 sc_encrypt_parallel;
243 	struct wg_queue		 sc_decrypt_parallel;
244 	u_int			 sc_encrypt_last_cpu;
245 	u_int			 sc_decrypt_last_cpu;
246 
247 	struct sx		 sc_lock;
248 };
249 
250 #define	WGF_DYING	0x0001
251 
252 #define MAX_LOOPS	8
253 #define MTAG_WGLOOP	0x77676c70 /* wglp */
254 
255 #define	GROUPTASK_DRAIN(gtask)			\
256 	gtaskqueue_drain((gtask)->gt_taskqueue, &(gtask)->gt_task)
257 
258 #define BPF_MTAP2_AF(ifp, m, af) do { \
259 		uint32_t __bpf_tap_af = (af); \
260 		BPF_MTAP2(ifp, &__bpf_tap_af, sizeof(__bpf_tap_af), m); \
261 	} while (0)
262 
263 static int clone_count;
264 static uma_zone_t wg_packet_zone;
265 static volatile unsigned long peer_counter = 0;
266 static const char wgname[] = "wg";
267 static unsigned wg_osd_jail_slot;
268 
269 static struct sx wg_sx;
270 SX_SYSINIT(wg_sx, &wg_sx, "wg_sx");
271 
272 static LIST_HEAD(, wg_softc) wg_list = LIST_HEAD_INITIALIZER(wg_list);
273 
274 static TASKQGROUP_DEFINE(wg_tqg, mp_ncpus, 1);
275 
276 MALLOC_DEFINE(M_WG, "WG", "wireguard");
277 
278 VNET_DEFINE_STATIC(struct if_clone *, wg_cloner);
279 
280 #define	V_wg_cloner	VNET(wg_cloner)
281 #define	WG_CAPS		IFCAP_LINKSTATE
282 
283 struct wg_timespec64 {
284 	uint64_t	tv_sec;
285 	uint64_t	tv_nsec;
286 };
287 
288 static int wg_socket_init(struct wg_softc *, in_port_t);
289 static int wg_socket_bind(struct socket **, struct socket **, in_port_t *);
290 static void wg_socket_set(struct wg_softc *, struct socket *, struct socket *);
291 static void wg_socket_uninit(struct wg_softc *);
292 static int wg_socket_set_sockopt(struct socket *, struct socket *, int, void *, size_t);
293 static int wg_socket_set_cookie(struct wg_softc *, uint32_t);
294 static int wg_socket_set_fibnum(struct wg_softc *, int);
295 static int wg_send(struct wg_softc *, struct wg_endpoint *, struct mbuf *);
296 static void wg_timers_enable(struct wg_peer *);
297 static void wg_timers_disable(struct wg_peer *);
298 static void wg_timers_set_persistent_keepalive(struct wg_peer *, uint16_t);
299 static void wg_timers_get_last_handshake(struct wg_peer *, struct wg_timespec64 *);
300 static void wg_timers_event_data_sent(struct wg_peer *);
301 static void wg_timers_event_data_received(struct wg_peer *);
302 static void wg_timers_event_any_authenticated_packet_sent(struct wg_peer *);
303 static void wg_timers_event_any_authenticated_packet_received(struct wg_peer *);
304 static void wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *);
305 static void wg_timers_event_handshake_initiated(struct wg_peer *);
306 static void wg_timers_event_handshake_complete(struct wg_peer *);
307 static void wg_timers_event_session_derived(struct wg_peer *);
308 static void wg_timers_event_want_initiation(struct wg_peer *);
309 static void wg_timers_run_send_initiation(struct wg_peer *, bool);
310 static void wg_timers_run_retry_handshake(void *);
311 static void wg_timers_run_send_keepalive(void *);
312 static void wg_timers_run_new_handshake(void *);
313 static void wg_timers_run_zero_key_material(void *);
314 static void wg_timers_run_persistent_keepalive(void *);
315 static int wg_aip_add(struct wg_softc *, struct wg_peer *, sa_family_t, const void *, uint8_t);
316 static struct wg_peer *wg_aip_lookup(struct wg_softc *, sa_family_t, void *);
317 static void wg_aip_remove_all(struct wg_softc *, struct wg_peer *);
318 static struct wg_peer *wg_peer_create(struct wg_softc *,
319     const uint8_t [WG_KEY_SIZE], int *);
320 static void wg_peer_free_deferred(struct noise_remote *);
321 static void wg_peer_destroy(struct wg_peer *);
322 static void wg_peer_destroy_all(struct wg_softc *);
323 static void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t);
324 static void wg_send_initiation(struct wg_peer *);
325 static void wg_send_response(struct wg_peer *);
326 static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct wg_endpoint *);
327 static void wg_peer_set_endpoint(struct wg_peer *, struct wg_endpoint *);
328 static void wg_peer_clear_src(struct wg_peer *);
329 static void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *);
330 static void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *, size_t);
331 static void wg_send_keepalive(struct wg_peer *);
332 static void wg_handshake(struct wg_softc *, struct wg_packet *);
333 static void wg_encrypt(struct wg_softc *, struct wg_packet *);
334 static void wg_decrypt(struct wg_softc *, struct wg_packet *);
335 static void wg_softc_handshake_receive(struct wg_softc *);
336 static void wg_softc_decrypt(struct wg_softc *);
337 static void wg_softc_encrypt(struct wg_softc *);
338 static void wg_encrypt_dispatch(struct wg_softc *);
339 static void wg_decrypt_dispatch(struct wg_softc *);
340 static void wg_deliver_out(struct wg_peer *);
341 static void wg_deliver_in(struct wg_peer *);
342 static struct wg_packet *wg_packet_alloc(struct mbuf *);
343 static void wg_packet_free(struct wg_packet *);
344 static void wg_queue_init(struct wg_queue *, const char *);
345 static void wg_queue_deinit(struct wg_queue *);
346 static size_t wg_queue_len(struct wg_queue *);
347 static int wg_queue_enqueue_handshake(struct wg_queue *, struct wg_packet *);
348 static struct wg_packet *wg_queue_dequeue_handshake(struct wg_queue *);
349 static void wg_queue_push_staged(struct wg_queue *, struct wg_packet *);
350 static void wg_queue_enlist_staged(struct wg_queue *, struct wg_packet_list *);
351 static void wg_queue_delist_staged(struct wg_queue *, struct wg_packet_list *);
352 static void wg_queue_purge(struct wg_queue *);
353 static int wg_queue_both(struct wg_queue *, struct wg_queue *, struct wg_packet *);
354 static struct wg_packet *wg_queue_dequeue_serial(struct wg_queue *);
355 static struct wg_packet *wg_queue_dequeue_parallel(struct wg_queue *);
356 static bool wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *);
357 static void wg_peer_send_staged(struct wg_peer *);
358 static int wg_clone_create(struct if_clone *ifc, char *name, size_t len,
359 	struct ifc_data *ifd, if_t *ifpp);
360 static void wg_qflush(if_t);
361 static inline int determine_af_and_pullup(struct mbuf **m, sa_family_t *af);
362 static int wg_xmit(if_t, struct mbuf *, sa_family_t, uint32_t);
363 static int wg_transmit(if_t, struct mbuf *);
364 static int wg_output(if_t, struct mbuf *, const struct sockaddr *, struct route *);
365 static int wg_clone_destroy(struct if_clone *ifc, if_t ifp,
366 	uint32_t flags);
367 static bool wgc_privileged(struct wg_softc *);
368 static int wgc_get(struct wg_softc *, struct wg_data_io *);
369 static int wgc_set(struct wg_softc *, struct wg_data_io *);
370 static int wg_up(struct wg_softc *);
371 static void wg_down(struct wg_softc *);
372 static void wg_reassign(if_t, struct vnet *, char *unused);
373 static void wg_init(void *);
374 static int wg_ioctl(if_t, u_long, caddr_t);
375 static void vnet_wg_init(const void *);
376 static void vnet_wg_uninit(const void *);
377 static int wg_module_init(void);
378 static void wg_module_deinit(void);
379 
380 /* TODO Peer */
381 static struct wg_peer *
382 wg_peer_create(struct wg_softc *sc, const uint8_t pub_key[WG_KEY_SIZE],
383     int *errp)
384 {
385 	struct wg_peer *peer;
386 
387 	sx_assert(&sc->sc_lock, SX_XLOCKED);
388 
389 	peer = malloc(sizeof(*peer), M_WG, M_WAITOK | M_ZERO);
390 
391 	peer->p_remote = noise_remote_alloc(sc->sc_local, peer, pub_key);
392 	if ((*errp = noise_remote_enable(peer->p_remote)) != 0) {
393 		noise_remote_free(peer->p_remote, NULL);
394 		free(peer, M_WG);
395 		return (NULL);
396 	}
397 
398 	peer->p_id = peer_counter++;
399 	peer->p_sc = sc;
400 	peer->p_tx_bytes = counter_u64_alloc(M_WAITOK);
401 	peer->p_rx_bytes = counter_u64_alloc(M_WAITOK);
402 
403 	cookie_maker_init(&peer->p_cookie, pub_key);
404 
405 	rw_init(&peer->p_endpoint_lock, "wg_peer_endpoint");
406 
407 	wg_queue_init(&peer->p_stage_queue, "stageq");
408 	wg_queue_init(&peer->p_encrypt_serial, "txq");
409 	wg_queue_init(&peer->p_decrypt_serial, "rxq");
410 
411 	peer->p_enabled = false;
412 	peer->p_need_another_keepalive = false;
413 	peer->p_persistent_keepalive_interval = 0;
414 	callout_init(&peer->p_new_handshake, true);
415 	callout_init(&peer->p_send_keepalive, true);
416 	callout_init(&peer->p_retry_handshake, true);
417 	callout_init(&peer->p_persistent_keepalive, true);
418 	callout_init(&peer->p_zero_key_material, true);
419 
420 	mtx_init(&peer->p_handshake_mtx, "peer handshake", NULL, MTX_DEF);
421 	bzero(&peer->p_handshake_complete, sizeof(peer->p_handshake_complete));
422 	peer->p_handshake_retries = 0;
423 
424 	GROUPTASK_INIT(&peer->p_send, 0, (gtask_fn_t *)wg_deliver_out, peer);
425 	taskqgroup_attach(qgroup_wg_tqg, &peer->p_send, peer, NULL, NULL, "wg send");
426 	GROUPTASK_INIT(&peer->p_recv, 0, (gtask_fn_t *)wg_deliver_in, peer);
427 	taskqgroup_attach(qgroup_wg_tqg, &peer->p_recv, peer, NULL, NULL, "wg recv");
428 
429 	LIST_INIT(&peer->p_aips);
430 	peer->p_aips_num = 0;
431 
432 	TAILQ_INSERT_TAIL(&sc->sc_peers, peer, p_entry);
433 	sc->sc_peers_num++;
434 
435 	if (if_getlinkstate(sc->sc_ifp) == LINK_STATE_UP)
436 		wg_timers_enable(peer);
437 
438 	DPRINTF(sc, "Peer %" PRIu64 " created\n", peer->p_id);
439 	return (peer);
440 }
441 
442 static void
443 wg_peer_free_deferred(struct noise_remote *r)
444 {
445 	struct wg_peer *peer = noise_remote_arg(r);
446 
447 	/* While there are no references remaining, we may still have
448 	 * p_{send,recv} executing (think empty queue, but wg_deliver_{in,out}
449 	 * needs to check the queue). We should wait for them and then free. */
450 	GROUPTASK_DRAIN(&peer->p_recv);
451 	GROUPTASK_DRAIN(&peer->p_send);
452 	taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv);
453 	taskqgroup_detach(qgroup_wg_tqg, &peer->p_send);
454 
455 	wg_queue_deinit(&peer->p_decrypt_serial);
456 	wg_queue_deinit(&peer->p_encrypt_serial);
457 	wg_queue_deinit(&peer->p_stage_queue);
458 
459 	counter_u64_free(peer->p_tx_bytes);
460 	counter_u64_free(peer->p_rx_bytes);
461 	rw_destroy(&peer->p_endpoint_lock);
462 	mtx_destroy(&peer->p_handshake_mtx);
463 
464 	cookie_maker_free(&peer->p_cookie);
465 
466 	free(peer, M_WG);
467 }
468 
469 static void
470 wg_peer_destroy(struct wg_peer *peer)
471 {
472 	struct wg_softc *sc = peer->p_sc;
473 	sx_assert(&sc->sc_lock, SX_XLOCKED);
474 
475 	/* Disable remote and timers. This will prevent any new handshakes
476 	 * occurring. */
477 	noise_remote_disable(peer->p_remote);
478 	wg_timers_disable(peer);
479 
480 	/* Now we can remove all allowed IPs so no more packets will be routed
481 	 * to the peer. */
482 	wg_aip_remove_all(sc, peer);
483 
484 	/* Remove peer from the interface, then free. Some references may still
485 	 * exist to p_remote, so noise_remote_free will wait until they're all
486 	 * put to call wg_peer_free_deferred. */
487 	sc->sc_peers_num--;
488 	TAILQ_REMOVE(&sc->sc_peers, peer, p_entry);
489 	DPRINTF(sc, "Peer %" PRIu64 " destroyed\n", peer->p_id);
490 	noise_remote_free(peer->p_remote, wg_peer_free_deferred);
491 }
492 
493 static void
494 wg_peer_destroy_all(struct wg_softc *sc)
495 {
496 	struct wg_peer *peer, *tpeer;
497 	TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer)
498 		wg_peer_destroy(peer);
499 }
500 
501 static void
502 wg_peer_set_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
503 {
504 	MPASS(e->e_remote.r_sa.sa_family != 0);
505 	if (memcmp(e, &peer->p_endpoint, sizeof(*e)) == 0)
506 		return;
507 
508 	rw_wlock(&peer->p_endpoint_lock);
509 	peer->p_endpoint = *e;
510 	rw_wunlock(&peer->p_endpoint_lock);
511 }
512 
513 static void
514 wg_peer_clear_src(struct wg_peer *peer)
515 {
516 	rw_wlock(&peer->p_endpoint_lock);
517 	bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local));
518 	rw_wunlock(&peer->p_endpoint_lock);
519 }
520 
521 static void
522 wg_peer_get_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
523 {
524 	rw_rlock(&peer->p_endpoint_lock);
525 	*e = peer->p_endpoint;
526 	rw_runlock(&peer->p_endpoint_lock);
527 }
528 
529 /* Allowed IP */
530 static int
531 wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void *addr, uint8_t cidr)
532 {
533 	struct radix_node_head	*root;
534 	struct radix_node	*node;
535 	struct wg_aip		*aip;
536 	int			 ret = 0;
537 
538 	aip = malloc(sizeof(*aip), M_WG, M_WAITOK | M_ZERO);
539 	aip->a_peer = peer;
540 	aip->a_af = af;
541 
542 	switch (af) {
543 #ifdef INET
544 	case AF_INET:
545 		if (cidr > 32) cidr = 32;
546 		root = sc->sc_aip4;
547 		aip->a_addr.in = *(const struct in_addr *)addr;
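		/*
		 * Worked example: cidr = 24 yields a_mask.ip = htonl(0xffffff00);
		 * the "&=" below then clears the host bits of the stored address.
		 */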
548 		aip->a_mask.ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);
549 		aip->a_addr.ip &= aip->a_mask.ip;
550 		aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
551 		break;
552 #endif
553 #ifdef INET6
554 	case AF_INET6:
555 		if (cidr > 128) cidr = 128;
556 		root = sc->sc_aip6;
557 		aip->a_addr.in6 = *(const struct in6_addr *)addr;
558 		in6_prefixlen2mask(&aip->a_mask.in6, cidr);
559 		for (int i = 0; i < 4; i++)
560 			aip->a_addr.ip6[i] &= aip->a_mask.ip6[i];
561 		aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
562 		break;
563 #endif
564 	default:
565 		free(aip, M_WG);
566 		return (EAFNOSUPPORT);
567 	}
568 
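	/*
	 * Insert the prefix into the radix tree. If an identical prefix is
	 * already present, rnh_addaddr() fails, the existing node is found via
	 * rnh_lookup(), and that entry is re-pointed at this peer instead.
	 */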
569 	RADIX_NODE_HEAD_LOCK(root);
570 	node = root->rnh_addaddr(&aip->a_addr, &aip->a_mask, &root->rh, aip->a_nodes);
571 	if (node == aip->a_nodes) {
572 		LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
573 		peer->p_aips_num++;
574 	} else if (!node)
575 		node = root->rnh_lookup(&aip->a_addr, &aip->a_mask, &root->rh);
576 	if (!node) {
577 		free(aip, M_WG);
578 		ret = ENOMEM;
579 	} else if (node != aip->a_nodes) {
580 		free(aip, M_WG);
581 		aip = (struct wg_aip *)node;
582 		if (aip->a_peer != peer) {
583 			LIST_REMOVE(aip, a_entry);
584 			aip->a_peer->p_aips_num--;
585 			aip->a_peer = peer;
586 			LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
587 			aip->a_peer->p_aips_num++;
588 		}
589 	}
590 	RADIX_NODE_HEAD_UNLOCK(root);
591 	return (ret);
592 }
593 
594 static struct wg_peer *
595 wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a)
596 {
597 	struct radix_node_head	*root;
598 	struct radix_node	*node;
599 	struct wg_peer		*peer;
600 	struct aip_addr		 addr;
601 	RADIX_NODE_HEAD_RLOCK_TRACKER;
602 
603 	switch (af) {
604 	case AF_INET:
605 		root = sc->sc_aip4;
606 		memcpy(&addr.in, a, sizeof(addr.in));
607 		addr.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
608 		break;
609 	case AF_INET6:
610 		root = sc->sc_aip6;
611 		memcpy(&addr.in6, a, sizeof(addr.in6));
612 		addr.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
613 		break;
614 	default:
615 		return NULL;
616 	}
617 
618 	RADIX_NODE_HEAD_RLOCK(root);
619 	node = root->rnh_matchaddr(&addr, &root->rh);
620 	if (node != NULL) {
621 		peer = ((struct wg_aip *)node)->a_peer;
622 		noise_remote_ref(peer->p_remote);
623 	} else {
624 		peer = NULL;
625 	}
626 	RADIX_NODE_HEAD_RUNLOCK(root);
627 
628 	return (peer);
629 }
630 
631 static void
632 wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer)
633 {
634 	struct wg_aip		*aip, *taip;
635 
636 	RADIX_NODE_HEAD_LOCK(sc->sc_aip4);
637 	LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
638 		if (aip->a_af == AF_INET) {
639 			if (sc->sc_aip4->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip4->rh) == NULL)
640 				panic("failed to delete aip %p", aip);
641 			LIST_REMOVE(aip, a_entry);
642 			peer->p_aips_num--;
643 			free(aip, M_WG);
644 		}
645 	}
646 	RADIX_NODE_HEAD_UNLOCK(sc->sc_aip4);
647 
648 	RADIX_NODE_HEAD_LOCK(sc->sc_aip6);
649 	LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
650 		if (aip->a_af == AF_INET6) {
651 			if (sc->sc_aip6->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip6->rh) == NULL)
652 				panic("failed to delete aip %p", aip);
653 			LIST_REMOVE(aip, a_entry);
654 			peer->p_aips_num--;
655 			free(aip, M_WG);
656 		}
657 	}
658 	RADIX_NODE_HEAD_UNLOCK(sc->sc_aip6);
659 
660 	if (!LIST_EMPTY(&peer->p_aips) || peer->p_aips_num != 0)
661 		panic("wg_aip_remove_all could not delete all %p", peer);
662 }
663 
664 static int
665 wg_socket_init(struct wg_softc *sc, in_port_t port)
666 {
667 	struct ucred *cred = sc->sc_ucred;
668 	struct socket *so4 = NULL, *so6 = NULL;
669 	int rc;
670 
671 	sx_assert(&sc->sc_lock, SX_XLOCKED);
672 
673 	if (!cred)
674 		return (EBUSY);
675 
676 	/*
677 	 * For socket creation, we use the creds of the thread that created the
678 	 * tunnel rather than the current thread to maintain the semantics that
679 	 * WireGuard has on Linux with network namespaces -- that the sockets
680 	 * are created in their home vnet so that they can be configured and
681 	 * functionally attached to a foreign vnet as the jail's only interface
682 	 * to the network.
683 	 */
684 #ifdef INET
685 	rc = socreate(AF_INET, &so4, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
686 	if (rc)
687 		goto out;
688 
689 	rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc);
690 	/*
691 	 * udp_set_kernel_tunneling can only fail if there is already a tunneling function set.
692 	 * This should never happen with a new socket.
693 	 */
694 	MPASS(rc == 0);
695 #endif
696 
697 #ifdef INET6
698 	rc = socreate(AF_INET6, &so6, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
699 	if (rc)
700 		goto out;
701 	rc = udp_set_kernel_tunneling(so6, wg_input, NULL, sc);
702 	MPASS(rc == 0);
703 #endif
704 
705 	if (sc->sc_socket.so_user_cookie) {
706 		rc = wg_socket_set_sockopt(so4, so6, SO_USER_COOKIE, &sc->sc_socket.so_user_cookie, sizeof(sc->sc_socket.so_user_cookie));
707 		if (rc)
708 			goto out;
709 	}
710 	rc = wg_socket_set_sockopt(so4, so6, SO_SETFIB, &sc->sc_socket.so_fibnum, sizeof(sc->sc_socket.so_fibnum));
711 	if (rc)
712 		goto out;
713 
714 	rc = wg_socket_bind(&so4, &so6, &port);
715 	if (!rc) {
716 		sc->sc_socket.so_port = port;
717 		wg_socket_set(sc, so4, so6);
718 	}
719 out:
720 	if (rc) {
721 		if (so4 != NULL)
722 			soclose(so4);
723 		if (so6 != NULL)
724 			soclose(so6);
725 	}
726 	return (rc);
727 }
728 
729 static int wg_socket_set_sockopt(struct socket *so4, struct socket *so6, int name, void *val, size_t len)
730 {
731 	int ret4 = 0, ret6 = 0;
732 	struct sockopt sopt = {
733 		.sopt_dir = SOPT_SET,
734 		.sopt_level = SOL_SOCKET,
735 		.sopt_name = name,
736 		.sopt_val = val,
737 		.sopt_valsize = len
738 	};
739 
740 	if (so4)
741 		ret4 = sosetopt(so4, &sopt);
742 	if (so6)
743 		ret6 = sosetopt(so6, &sopt);
744 	return (ret4 ?: ret6);
745 }
746 
747 static int wg_socket_set_cookie(struct wg_softc *sc, uint32_t user_cookie)
748 {
749 	struct wg_socket *so = &sc->sc_socket;
750 	int ret;
751 
752 	sx_assert(&sc->sc_lock, SX_XLOCKED);
753 	ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_USER_COOKIE, &user_cookie, sizeof(user_cookie));
754 	if (!ret)
755 		so->so_user_cookie = user_cookie;
756 	return (ret);
757 }
758 
759 static int wg_socket_set_fibnum(struct wg_softc *sc, int fibnum)
760 {
761 	struct wg_socket *so = &sc->sc_socket;
762 	int ret;
763 
764 	sx_assert(&sc->sc_lock, SX_XLOCKED);
765 
766 	ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_SETFIB, &fibnum, sizeof(fibnum));
767 	if (!ret)
768 		so->so_fibnum = fibnum;
769 	return (ret);
770 }
771 
772 static void
773 wg_socket_uninit(struct wg_softc *sc)
774 {
775 	wg_socket_set(sc, NULL, NULL);
776 }
777 
778 static void
779 wg_socket_set(struct wg_softc *sc, struct socket *new_so4, struct socket *new_so6)
780 {
781 	struct wg_socket *so = &sc->sc_socket;
782 	struct socket *so4, *so6;
783 
784 	sx_assert(&sc->sc_lock, SX_XLOCKED);
785 
786 	so4 = atomic_load_ptr(&so->so_so4);
787 	so6 = atomic_load_ptr(&so->so_so6);
788 	atomic_store_ptr(&so->so_so4, new_so4);
789 	atomic_store_ptr(&so->so_so6, new_so6);
790 
791 	if (!so4 && !so6)
792 		return;
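	/*
	 * Readers (e.g. wg_send()) load the old sockets with atomic_load_ptr()
	 * under the net epoch, so wait for the epoch to drain before closing
	 * them.
	 */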
793 	NET_EPOCH_WAIT();
794 	if (so4)
795 		soclose(so4);
796 	if (so6)
797 		soclose(so6);
798 }
799 
800 static int
801 wg_socket_bind(struct socket **in_so4, struct socket **in_so6, in_port_t *requested_port)
802 {
803 	struct socket *so4 = *in_so4, *so6 = *in_so6;
804 	int ret4 = 0, ret6 = 0;
805 	in_port_t port = *requested_port;
806 	struct sockaddr_in sin = {
807 		.sin_len = sizeof(struct sockaddr_in),
808 		.sin_family = AF_INET,
809 		.sin_port = htons(port)
810 	};
811 	struct sockaddr_in6 sin6 = {
812 		.sin6_len = sizeof(struct sockaddr_in6),
813 		.sin6_family = AF_INET6,
814 		.sin6_port = htons(port)
815 	};
816 
817 	if (so4) {
818 		ret4 = sobind(so4, (struct sockaddr *)&sin, curthread);
819 		if (ret4 && ret4 != EADDRNOTAVAIL)
820 			return (ret4);
821 		if (!ret4 && !sin.sin_port) {
822 			struct sockaddr_in bound_sin =
823 			    { .sin_len = sizeof(bound_sin) };
824 			int ret;
825 
826 			ret = sosockaddr(so4, (struct sockaddr *)&bound_sin);
827 			if (ret)
828 				return (ret);
829 			port = ntohs(bound_sin.sin_port);
830 			sin6.sin6_port = bound_sin.sin_port;
831 		}
832 	}
833 
834 	if (so6) {
835 		ret6 = sobind(so6, (struct sockaddr *)&sin6, curthread);
836 		if (ret6 && ret6 != EADDRNOTAVAIL)
837 			return (ret6);
838 		if (!ret6 && !sin6.sin6_port) {
839 			struct sockaddr_in6 bound_sin6 =
840 			    { .sin6_len = sizeof(bound_sin6) };
841 			int ret;
842 
843 			ret = sosockaddr(so6, (struct sockaddr *)&bound_sin6);
844 			if (ret)
845 				return (ret);
846 			port = ntohs(bound_sin6.sin6_port);
847 		}
848 	}
849 
850 	if (ret4 && ret6)
851 		return (ret4);
852 	*requested_port = port;
853 	if (ret4 && !ret6 && so4) {
854 		soclose(so4);
855 		*in_so4 = NULL;
856 	} else if (ret6 && !ret4 && so6) {
857 		soclose(so6);
858 		*in_so6 = NULL;
859 	}
860 	return (0);
861 }
862 
863 static int
864 wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m)
865 {
866 	struct epoch_tracker et;
867 	struct sockaddr *sa;
868 	struct wg_socket *so = &sc->sc_socket;
869 	struct socket *so4, *so6;
870 	struct mbuf *control = NULL;
871 	int ret = 0;
872 	size_t len = m->m_pkthdr.len;
873 
874 	/* Get local control address before locking */
875 	if (e->e_remote.r_sa.sa_family == AF_INET) {
876 		if (e->e_local.l_in.s_addr != INADDR_ANY)
877 			control = sbcreatecontrol((caddr_t)&e->e_local.l_in,
878 			    sizeof(struct in_addr), IP_SENDSRCADDR,
879 			    IPPROTO_IP, M_NOWAIT);
880 #ifdef INET6
881 	} else if (e->e_remote.r_sa.sa_family == AF_INET6) {
882 		if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6))
883 			control = sbcreatecontrol((caddr_t)&e->e_local.l_pktinfo6,
884 			    sizeof(struct in6_pktinfo), IPV6_PKTINFO,
885 			    IPPROTO_IPV6, M_NOWAIT);
886 #endif
887 	} else {
888 		m_freem(m);
889 		return (EAFNOSUPPORT);
890 	}
891 
892 	/* Get remote address */
893 	sa = &e->e_remote.r_sa;
894 
895 	NET_EPOCH_ENTER(et);
896 	so4 = atomic_load_ptr(&so->so_so4);
897 	so6 = atomic_load_ptr(&so->so_so6);
898 	if (e->e_remote.r_sa.sa_family == AF_INET && so4 != NULL)
899 		ret = sosend(so4, sa, NULL, m, control, 0, curthread);
900 	else if (e->e_remote.r_sa.sa_family == AF_INET6 && so6 != NULL)
901 		ret = sosend(so6, sa, NULL, m, control, 0, curthread);
902 	else {
903 		ret = ENOTCONN;
904 		m_freem(control);
905 		m_freem(m);
906 	}
907 	NET_EPOCH_EXIT(et);
908 	if (ret == 0) {
909 		if_inc_counter(sc->sc_ifp, IFCOUNTER_OPACKETS, 1);
910 		if_inc_counter(sc->sc_ifp, IFCOUNTER_OBYTES, len);
911 	}
912 	return (ret);
913 }
914 
915 static void
916 wg_send_buf(struct wg_softc *sc, struct wg_endpoint *e, uint8_t *buf, size_t len)
917 {
918 	struct mbuf	*m;
919 	int		 ret = 0;
920 	bool		 retried = false;
921 
922 retry:
923 	m = m_get2(len, M_NOWAIT, MT_DATA, M_PKTHDR);
924 	if (!m) {
925 		ret = ENOMEM;
926 		goto out;
927 	}
928 	m_copyback(m, 0, len, buf);
929 
930 	if (ret == 0) {
931 		ret = wg_send(sc, e, m);
932 		/* Retry if we couldn't bind to e->e_local */
933 		if (ret == EADDRNOTAVAIL && !retried) {
934 			bzero(&e->e_local, sizeof(e->e_local));
935 			retried = true;
936 			goto retry;
937 		}
938 	} else {
939 		ret = wg_send(sc, e, m);
940 	}
941 out:
942 	if (ret)
943 		DPRINTF(sc, "Unable to send packet: %d\n", ret);
944 }
945 
946 /* Timers */
947 static void
948 wg_timers_enable(struct wg_peer *peer)
949 {
950 	atomic_store_bool(&peer->p_enabled, true);
951 	wg_timers_run_persistent_keepalive(peer);
952 }
953 
954 static void
955 wg_timers_disable(struct wg_peer *peer)
956 {
957 	/* By setting p_enabled = false, then calling NET_EPOCH_WAIT, we can be
958 	 * sure no new handshakes are created after the wait. This is because
959 	 * all callout_resets (scheduling the callout) are guarded by
960 	 * p_enabled. We can be sure all sections that read p_enabled and then
961 	 * optionally call callout_reset are finished as they are surrounded by
962 	 * NET_EPOCH_{ENTER,EXIT}.
963 	 *
964 	 * However, as new callouts may be scheduled during NET_EPOCH_WAIT (but
965 	 * not after), we stop all callouts leaving no callouts active.
966 	 *
967 	 * We should also pull NET_EPOCH_WAIT out of the FOREACH(peer) loops, but the
968 	 * performance impact is acceptable for the time being. */
969 	atomic_store_bool(&peer->p_enabled, false);
970 	NET_EPOCH_WAIT();
971 	atomic_store_bool(&peer->p_need_another_keepalive, false);
972 
973 	callout_stop(&peer->p_new_handshake);
974 	callout_stop(&peer->p_send_keepalive);
975 	callout_stop(&peer->p_retry_handshake);
976 	callout_stop(&peer->p_persistent_keepalive);
977 	callout_stop(&peer->p_zero_key_material);
978 }
979 
980 static void
981 wg_timers_set_persistent_keepalive(struct wg_peer *peer, uint16_t interval)
982 {
983 	struct epoch_tracker et;
984 	if (interval != peer->p_persistent_keepalive_interval) {
985 		atomic_store_16(&peer->p_persistent_keepalive_interval, interval);
986 		NET_EPOCH_ENTER(et);
987 		if (atomic_load_bool(&peer->p_enabled))
988 			wg_timers_run_persistent_keepalive(peer);
989 		NET_EPOCH_EXIT(et);
990 	}
991 }
992 
993 static void
994 wg_timers_get_last_handshake(struct wg_peer *peer, struct wg_timespec64 *time)
995 {
996 	mtx_lock(&peer->p_handshake_mtx);
997 	time->tv_sec = peer->p_handshake_complete.tv_sec;
998 	time->tv_nsec = peer->p_handshake_complete.tv_nsec;
999 	mtx_unlock(&peer->p_handshake_mtx);
1000 }
1001 
1002 static void
1003 wg_timers_event_data_sent(struct wg_peer *peer)
1004 {
1005 	struct epoch_tracker et;
1006 	NET_EPOCH_ENTER(et);
1007 	if (atomic_load_bool(&peer->p_enabled) &&
1008 	    !callout_pending(&peer->p_new_handshake))
1009 		callout_reset(&peer->p_new_handshake, MSEC_2_TICKS(
1010 		    NEW_HANDSHAKE_TIMEOUT * 1000 +
1011 		    arc4random_uniform(REKEY_TIMEOUT_JITTER)),
1012 		    wg_timers_run_new_handshake, peer);
1013 	NET_EPOCH_EXIT(et);
1014 }
1015 
1016 static void
1017 wg_timers_event_data_received(struct wg_peer *peer)
1018 {
1019 	struct epoch_tracker et;
1020 	NET_EPOCH_ENTER(et);
1021 	if (atomic_load_bool(&peer->p_enabled)) {
1022 		if (!callout_pending(&peer->p_send_keepalive))
1023 			callout_reset(&peer->p_send_keepalive,
1024 			    MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1025 			    wg_timers_run_send_keepalive, peer);
1026 		else
1027 			atomic_store_bool(&peer->p_need_another_keepalive,
1028 			    true);
1029 	}
1030 	NET_EPOCH_EXIT(et);
1031 }
1032 
1033 static void
1034 wg_timers_event_any_authenticated_packet_sent(struct wg_peer *peer)
1035 {
1036 	callout_stop(&peer->p_send_keepalive);
1037 }
1038 
1039 static void
1040 wg_timers_event_any_authenticated_packet_received(struct wg_peer *peer)
1041 {
1042 	callout_stop(&peer->p_new_handshake);
1043 }
1044 
1045 static void
1046 wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *peer)
1047 {
1048 	struct epoch_tracker et;
1049 	uint16_t interval;
1050 	NET_EPOCH_ENTER(et);
1051 	interval = atomic_load_16(&peer->p_persistent_keepalive_interval);
1052 	if (atomic_load_bool(&peer->p_enabled) && interval > 0)
1053 		callout_reset(&peer->p_persistent_keepalive,
1054 		     MSEC_2_TICKS(interval * 1000),
1055 		     wg_timers_run_persistent_keepalive, peer);
1056 	NET_EPOCH_EXIT(et);
1057 }
1058 
1059 static void
1060 wg_timers_event_handshake_initiated(struct wg_peer *peer)
1061 {
1062 	struct epoch_tracker et;
1063 	NET_EPOCH_ENTER(et);
1064 	if (atomic_load_bool(&peer->p_enabled))
1065 		callout_reset(&peer->p_retry_handshake, MSEC_2_TICKS(
1066 		    REKEY_TIMEOUT * 1000 +
1067 		    arc4random_uniform(REKEY_TIMEOUT_JITTER)),
1068 		    wg_timers_run_retry_handshake, peer);
1069 	NET_EPOCH_EXIT(et);
1070 }
1071 
1072 static void
1073 wg_timers_event_handshake_complete(struct wg_peer *peer)
1074 {
1075 	struct epoch_tracker et;
1076 	NET_EPOCH_ENTER(et);
1077 	if (atomic_load_bool(&peer->p_enabled)) {
1078 		mtx_lock(&peer->p_handshake_mtx);
1079 		callout_stop(&peer->p_retry_handshake);
1080 		peer->p_handshake_retries = 0;
1081 		getnanotime(&peer->p_handshake_complete);
1082 		mtx_unlock(&peer->p_handshake_mtx);
1083 		wg_timers_run_send_keepalive(peer);
1084 	}
1085 	NET_EPOCH_EXIT(et);
1086 }
1087 
1088 static void
1089 wg_timers_event_session_derived(struct wg_peer *peer)
1090 {
1091 	struct epoch_tracker et;
1092 	NET_EPOCH_ENTER(et);
1093 	if (atomic_load_bool(&peer->p_enabled))
1094 		callout_reset(&peer->p_zero_key_material,
1095 		    MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1096 		    wg_timers_run_zero_key_material, peer);
1097 	NET_EPOCH_EXIT(et);
1098 }
1099 
1100 static void
1101 wg_timers_event_want_initiation(struct wg_peer *peer)
1102 {
1103 	struct epoch_tracker et;
1104 	NET_EPOCH_ENTER(et);
1105 	if (atomic_load_bool(&peer->p_enabled))
1106 		wg_timers_run_send_initiation(peer, false);
1107 	NET_EPOCH_EXIT(et);
1108 }
1109 
1110 static void
1111 wg_timers_run_send_initiation(struct wg_peer *peer, bool is_retry)
1112 {
1113 	if (!is_retry)
1114 		peer->p_handshake_retries = 0;
1115 	if (noise_remote_initiation_expired(peer->p_remote) == ETIMEDOUT)
1116 		wg_send_initiation(peer);
1117 }
1118 
1119 static void
1120 wg_timers_run_retry_handshake(void *_peer)
1121 {
1122 	struct epoch_tracker et;
1123 	struct wg_peer *peer = _peer;
1124 
1125 	mtx_lock(&peer->p_handshake_mtx);
1126 	if (peer->p_handshake_retries <= MAX_TIMER_HANDSHAKES) {
1127 		peer->p_handshake_retries++;
1128 		mtx_unlock(&peer->p_handshake_mtx);
1129 
1130 		DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1131 		    "after %d seconds, retrying (try %d)\n", peer->p_id,
1132 		    REKEY_TIMEOUT, peer->p_handshake_retries + 1);
1133 		wg_peer_clear_src(peer);
1134 		wg_timers_run_send_initiation(peer, true);
1135 	} else {
1136 		mtx_unlock(&peer->p_handshake_mtx);
1137 
1138 		DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1139 		    "after %d retries, giving up\n", peer->p_id,
1140 		    MAX_TIMER_HANDSHAKES + 2);
1141 
1142 		callout_stop(&peer->p_send_keepalive);
1143 		wg_queue_purge(&peer->p_stage_queue);
1144 		NET_EPOCH_ENTER(et);
1145 		if (atomic_load_bool(&peer->p_enabled) &&
1146 		    !callout_pending(&peer->p_zero_key_material))
1147 			callout_reset(&peer->p_zero_key_material,
1148 			    MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1149 			    wg_timers_run_zero_key_material, peer);
1150 		NET_EPOCH_EXIT(et);
1151 	}
1152 }
1153 
1154 static void
1155 wg_timers_run_send_keepalive(void *_peer)
1156 {
1157 	struct epoch_tracker et;
1158 	struct wg_peer *peer = _peer;
1159 
1160 	wg_send_keepalive(peer);
1161 	NET_EPOCH_ENTER(et);
1162 	if (atomic_load_bool(&peer->p_enabled) &&
1163 	    atomic_load_bool(&peer->p_need_another_keepalive)) {
1164 		atomic_store_bool(&peer->p_need_another_keepalive, false);
1165 		callout_reset(&peer->p_send_keepalive,
1166 		    MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1167 		    wg_timers_run_send_keepalive, peer);
1168 	}
1169 	NET_EPOCH_EXIT(et);
1170 }
1171 
1172 static void
1173 wg_timers_run_new_handshake(void *_peer)
1174 {
1175 	struct wg_peer *peer = _peer;
1176 
1177 	DPRINTF(peer->p_sc, "Retrying handshake with peer %" PRIu64 " because we "
1178 	    "stopped hearing back after %d seconds\n",
1179 	    peer->p_id, NEW_HANDSHAKE_TIMEOUT);
1180 
1181 	wg_peer_clear_src(peer);
1182 	wg_timers_run_send_initiation(peer, false);
1183 }
1184 
1185 static void
1186 wg_timers_run_zero_key_material(void *_peer)
1187 {
1188 	struct wg_peer *peer = _peer;
1189 
1190 	DPRINTF(peer->p_sc, "Zeroing out keys for peer %" PRIu64 ", since we "
1191 	    "haven't received a new one in %d seconds\n",
1192 	    peer->p_id, REJECT_AFTER_TIME * 3);
1193 	noise_remote_keypairs_clear(peer->p_remote);
1194 }
1195 
1196 static void
1197 wg_timers_run_persistent_keepalive(void *_peer)
1198 {
1199 	struct wg_peer *peer = _peer;
1200 
1201 	if (atomic_load_16(&peer->p_persistent_keepalive_interval) > 0)
1202 		wg_send_keepalive(peer);
1203 }
1204 
1205 /* TODO Handshake */
1206 static void
1207 wg_peer_send_buf(struct wg_peer *peer, uint8_t *buf, size_t len)
1208 {
1209 	struct wg_endpoint endpoint;
1210 
1211 	counter_u64_add(peer->p_tx_bytes, len);
1212 	wg_timers_event_any_authenticated_packet_traversal(peer);
1213 	wg_timers_event_any_authenticated_packet_sent(peer);
1214 	wg_peer_get_endpoint(peer, &endpoint);
1215 	wg_send_buf(peer->p_sc, &endpoint, buf, len);
1216 }
1217 
1218 static void
1219 wg_send_initiation(struct wg_peer *peer)
1220 {
1221 	struct wg_pkt_initiation pkt;
1222 
1223 	if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue,
1224 	    pkt.es, pkt.ets) != 0)
1225 		return;
1226 
1227 	DPRINTF(peer->p_sc, "Sending handshake initiation to peer %" PRIu64 "\n", peer->p_id);
1228 
1229 	pkt.t = WG_PKT_INITIATION;
1230 	cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1231 	    sizeof(pkt) - sizeof(pkt.m));
1232 	wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
1233 	wg_timers_event_handshake_initiated(peer);
1234 }
1235 
1236 static void
1237 wg_send_response(struct wg_peer *peer)
1238 {
1239 	struct wg_pkt_response pkt;
1240 
1241 	if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx,
1242 	    pkt.ue, pkt.en) != 0)
1243 		return;
1244 
1245 	DPRINTF(peer->p_sc, "Sending handshake response to peer %" PRIu64 "\n", peer->p_id);
1246 
1247 	wg_timers_event_session_derived(peer);
1248 	pkt.t = WG_PKT_RESPONSE;
1249 	cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1250 	     sizeof(pkt)-sizeof(pkt.m));
1251 	wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt));
1252 }
1253 
1254 static void
1255 wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx,
1256     struct wg_endpoint *e)
1257 {
1258 	struct wg_pkt_cookie	pkt;
1259 
1260 	DPRINTF(sc, "Sending cookie response for denied handshake message\n");
1261 
1262 	pkt.t = WG_PKT_COOKIE;
1263 	pkt.r_idx = idx;
1264 
1265 	cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce,
1266 	    pkt.ec, &e->e_remote.r_sa);
1267 	wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt));
1268 }
1269 
1270 static void
1271 wg_send_keepalive(struct wg_peer *peer)
1272 {
1273 	struct wg_packet *pkt;
1274 	struct mbuf *m;
1275 
1276 	if (wg_queue_len(&peer->p_stage_queue) > 0)
1277 		goto send;
1278 	if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL)
1279 		return;
1280 	if ((pkt = wg_packet_alloc(m)) == NULL) {
1281 		m_freem(m);
1282 		return;
1283 	}
1284 	wg_queue_push_staged(&peer->p_stage_queue, pkt);
1285 	DPRINTF(peer->p_sc, "Sending keepalive packet to peer %" PRIu64 "\n", peer->p_id);
1286 send:
1287 	wg_peer_send_staged(peer);
1288 }
1289 
1290 static void
1291 wg_handshake(struct wg_softc *sc, struct wg_packet *pkt)
1292 {
1293 	struct wg_pkt_initiation	*init;
1294 	struct wg_pkt_response		*resp;
1295 	struct wg_pkt_cookie		*cook;
1296 	struct wg_endpoint		*e;
1297 	struct wg_peer			*peer;
1298 	struct mbuf			*m;
1299 	struct noise_remote		*remote = NULL;
1300 	int				 res;
1301 	bool				 underload = false;
1302 	static sbintime_t		 wg_last_underload; /* sbinuptime */
1303 
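	/*
	 * We are considered under load if the handshake queue is at least 1/8
	 * full, and we keep treating ourselves as loaded for UNDERLOAD_TIMEOUT
	 * seconds afterwards; cookie_checker_validate_macs() uses this flag to
	 * demand a cookie round trip (the EAGAIN case below).
	 */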
1304 	underload = wg_queue_len(&sc->sc_handshake_queue) >= MAX_QUEUED_HANDSHAKES / 8;
1305 	if (underload) {
1306 		wg_last_underload = getsbinuptime();
1307 	} else if (wg_last_underload) {
1308 		underload = wg_last_underload + UNDERLOAD_TIMEOUT * SBT_1S > getsbinuptime();
1309 		if (!underload)
1310 			wg_last_underload = 0;
1311 	}
1312 
1313 	m = pkt->p_mbuf;
1314 	e = &pkt->p_endpoint;
1315 
1316 	if ((pkt->p_mbuf = m = m_pullup(m, m->m_pkthdr.len)) == NULL)
1317 		goto error;
1318 
1319 	switch (*mtod(m, uint32_t *)) {
1320 	case WG_PKT_INITIATION:
1321 		init = mtod(m, struct wg_pkt_initiation *);
1322 
1323 		res = cookie_checker_validate_macs(&sc->sc_cookie, &init->m,
1324 				init, sizeof(*init) - sizeof(init->m),
1325 				underload, &e->e_remote.r_sa,
1326 				if_getvnet(sc->sc_ifp));
1327 
1328 		if (res == EINVAL) {
1329 			DPRINTF(sc, "Invalid initiation MAC\n");
1330 			goto error;
1331 		} else if (res == ECONNREFUSED) {
1332 			DPRINTF(sc, "Handshake ratelimited\n");
1333 			goto error;
1334 		} else if (res == EAGAIN) {
1335 			wg_send_cookie(sc, &init->m, init->s_idx, e);
1336 			goto error;
1337 		} else if (res != 0) {
1338 			panic("unexpected response: %d\n", res);
1339 		}
1340 
1341 		if (noise_consume_initiation(sc->sc_local, &remote,
1342 		    init->s_idx, init->ue, init->es, init->ets) != 0) {
1343 			DPRINTF(sc, "Invalid handshake initiation\n");
1344 			goto error;
1345 		}
1346 
1347 		peer = noise_remote_arg(remote);
1348 
1349 		DPRINTF(sc, "Receiving handshake initiation from peer %" PRIu64 "\n", peer->p_id);
1350 
1351 		wg_peer_set_endpoint(peer, e);
1352 		wg_send_response(peer);
1353 		break;
1354 	case WG_PKT_RESPONSE:
1355 		resp = mtod(m, struct wg_pkt_response *);
1356 
1357 		res = cookie_checker_validate_macs(&sc->sc_cookie, &resp->m,
1358 				resp, sizeof(*resp) - sizeof(resp->m),
1359 				underload, &e->e_remote.r_sa,
1360 				if_getvnet(sc->sc_ifp));
1361 
1362 		if (res == EINVAL) {
1363 			DPRINTF(sc, "Invalid response MAC\n");
1364 			goto error;
1365 		} else if (res == ECONNREFUSED) {
1366 			DPRINTF(sc, "Handshake ratelimited\n");
1367 			goto error;
1368 		} else if (res == EAGAIN) {
1369 			wg_send_cookie(sc, &resp->m, resp->s_idx, e);
1370 			goto error;
1371 		} else if (res != 0) {
1372 			panic("unexpected response: %d\n", res);
1373 		}
1374 
1375 		if (noise_consume_response(sc->sc_local, &remote,
1376 		    resp->s_idx, resp->r_idx, resp->ue, resp->en) != 0) {
1377 			DPRINTF(sc, "Invalid handshake response\n");
1378 			goto error;
1379 		}
1380 
1381 		peer = noise_remote_arg(remote);
1382 		DPRINTF(sc, "Receiving handshake response from peer %" PRIu64 "\n", peer->p_id);
1383 
1384 		wg_peer_set_endpoint(peer, e);
1385 		wg_timers_event_session_derived(peer);
1386 		wg_timers_event_handshake_complete(peer);
1387 		break;
1388 	case WG_PKT_COOKIE:
1389 		cook = mtod(m, struct wg_pkt_cookie *);
1390 
1391 		if ((remote = noise_remote_index(sc->sc_local, cook->r_idx)) == NULL) {
1392 			DPRINTF(sc, "Unknown cookie index\n");
1393 			goto error;
1394 		}
1395 
1396 		peer = noise_remote_arg(remote);
1397 
1398 		if (cookie_maker_consume_payload(&peer->p_cookie,
1399 		    cook->nonce, cook->ec) == 0) {
1400 			DPRINTF(sc, "Receiving cookie response\n");
1401 		} else {
1402 			DPRINTF(sc, "Could not decrypt cookie response\n");
1403 			goto error;
1404 		}
1405 
1406 		goto not_authenticated;
1407 	default:
1408 		panic("invalid packet in handshake queue");
1409 	}
1410 
1411 	wg_timers_event_any_authenticated_packet_received(peer);
1412 	wg_timers_event_any_authenticated_packet_traversal(peer);
1413 
1414 not_authenticated:
1415 	counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len);
1416 	if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
1417 	if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len);
1418 error:
1419 	if (remote != NULL)
1420 		noise_remote_put(remote);
1421 	wg_packet_free(pkt);
1422 }
1423 
1424 static void
1425 wg_softc_handshake_receive(struct wg_softc *sc)
1426 {
1427 	struct wg_packet *pkt;
1428 	while ((pkt = wg_queue_dequeue_handshake(&sc->sc_handshake_queue)) != NULL)
1429 		wg_handshake(sc, pkt);
1430 }
1431 
1432 static void
1433 wg_mbuf_reset(struct mbuf *m)
1434 {
1435 
1436 	struct m_tag *t, *tmp;
1437 
1438 	/*
1439 	 * We want to reset the mbuf to a newly allocated state, containing
1440 	 * just the packet contents. Unfortunately FreeBSD doesn't seem to
1441 	 * offer this anywhere, so we have to make it up as we go. If we can
1442 	 * get this in kern/kern_mbuf.c, that would be best.
1443 	 *
1444 	 * Notice: this may break things unexpectedly but it is better to fail
1445 	 *         closed in the extreme case than leak information in every
1446 	 *         case.
1447 	 *
1448 	 * With that said, all this attempts to do is remove any extraneous
1449 	 * information that could be present.
1450 	 */
1451 
1452 	M_ASSERTPKTHDR(m);
1453 
1454 	m->m_flags &= ~(M_BCAST|M_MCAST|M_VLANTAG|M_PROMISC|M_PROTOFLAGS);
1455 
1456 	M_HASHTYPE_CLEAR(m);
1457 #ifdef NUMA
1458         m->m_pkthdr.numa_domain = M_NODOM;
1459 #endif
1460 	SLIST_FOREACH_SAFE(t, &m->m_pkthdr.tags, m_tag_link, tmp) {
1461 		if ((t->m_tag_id != 0 || t->m_tag_cookie != MTAG_WGLOOP) &&
1462 		    t->m_tag_id != PACKET_TAG_MACLABEL)
1463 			m_tag_delete(m, t);
1464 	}
1465 
1466 	KASSERT((m->m_pkthdr.csum_flags & CSUM_SND_TAG) == 0,
1467 	    ("%s: mbuf %p has a send tag", __func__, m));
1468 
1469 	m->m_pkthdr.csum_flags = 0;
1470 	m->m_pkthdr.PH_per.sixtyfour[0] = 0;
1471 	m->m_pkthdr.PH_loc.sixtyfour[0] = 0;
1472 }
1473 
1474 static inline unsigned int
1475 calculate_padding(struct wg_packet *pkt)
1476 {
1477 	unsigned int padded_size, last_unit = pkt->p_mbuf->m_pkthdr.len;
1478 
1479 	/* Keepalive packets don't set p_mtu, but also have a length of zero. */
1480 	if (__predict_false(pkt->p_mtu == 0)) {
1481 		padded_size = (last_unit + (WG_PKT_PADDING - 1)) &
1482 		    ~(WG_PKT_PADDING - 1);
1483 		return (padded_size - last_unit);
1484 	}
1485 
1486 	if (__predict_false(last_unit > pkt->p_mtu))
1487 		last_unit %= pkt->p_mtu;
1488 
1489 	padded_size = (last_unit + (WG_PKT_PADDING - 1)) & ~(WG_PKT_PADDING - 1);
1490 	if (pkt->p_mtu < padded_size)
1491 		padded_size = pkt->p_mtu;
1492 	return (padded_size - last_unit);
1493 }
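/*
 * Worked example of the padding math above: with WG_PKT_PADDING == 16, a
 * 100-byte payload and p_mtu = 1420 (DEFAULT_MTU) gives padded_size = 112,
 * so 12 bytes of padding are appended; a keepalive (p_mtu == 0, zero-length
 * mbuf) rounds to 0 and gets no padding.
 */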
1494 
1495 static void
1496 wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt)
1497 {
1498 	static const uint8_t	 padding[WG_PKT_PADDING] = { 0 };
1499 	struct wg_pkt_data	*data;
1500 	struct wg_peer		*peer;
1501 	struct noise_remote	*remote;
1502 	struct mbuf		*m;
1503 	uint32_t		 idx;
1504 	unsigned int		 padlen;
1505 	enum wg_ring_state	 state = WG_PACKET_DEAD;
1506 
1507 	remote = noise_keypair_remote(pkt->p_keypair);
1508 	peer = noise_remote_arg(remote);
1509 	m = pkt->p_mbuf;
1510 
1511 	/* Pad the packet */
1512 	padlen = calculate_padding(pkt);
1513 	if (padlen != 0 && !m_append(m, padlen, padding))
1514 		goto out;
1515 
1516 	/* Do encryption */
1517 	if (noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, m) != 0)
1518 		goto out;
1519 
1520 	/* Put header into packet */
1521 	M_PREPEND(m, sizeof(struct wg_pkt_data), M_NOWAIT);
1522 	if (m == NULL)
1523 		goto out;
1524 	data = mtod(m, struct wg_pkt_data *);
1525 	data->t = WG_PKT_DATA;
1526 	data->r_idx = idx;
1527 	data->nonce = htole64(pkt->p_nonce);
1528 
1529 	wg_mbuf_reset(m);
1530 	state = WG_PACKET_CRYPTED;
1531 out:
1532 	pkt->p_mbuf = m;
1533 	atomic_store_rel_int(&pkt->p_state, state);
1534 	GROUPTASK_ENQUEUE(&peer->p_send);
1535 	noise_remote_put(remote);
1536 }
1537 
1538 static void
1539 wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt)
1540 {
1541 	struct wg_peer		*peer, *allowed_peer;
1542 	struct noise_remote	*remote;
1543 	struct mbuf		*m;
1544 	int			 len;
1545 	enum wg_ring_state	 state = WG_PACKET_DEAD;
1546 
1547 	remote = noise_keypair_remote(pkt->p_keypair);
1548 	peer = noise_remote_arg(remote);
1549 	m = pkt->p_mbuf;
1550 
1551 	/* Read nonce and then adjust to remove the header. */
1552 	pkt->p_nonce = le64toh(mtod(m, struct wg_pkt_data *)->nonce);
1553 	m_adj(m, sizeof(struct wg_pkt_data));
1554 
1555 	if (noise_keypair_decrypt(pkt->p_keypair, pkt->p_nonce, m) != 0)
1556 		goto out;
1557 
1558 	/* A packet with length 0 is a keepalive packet */
1559 	if (__predict_false(m->m_pkthdr.len == 0)) {
1560 		DPRINTF(sc, "Receiving keepalive packet from peer "
1561 		    "%" PRIu64 "\n", peer->p_id);
1562 		state = WG_PACKET_CRYPTED;
1563 		goto out;
1564 	}
1565 
1566 	/*
1567 	 * We can let the network stack handle the intricate validation of the
1568 	 * IP header; we just worry about the sizeof and the version, so we can
1569 	 * read the source address in wg_aip_lookup.
1570 	 */
1571 
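	/*
	 * The inner IP/IPv6 total length is authoritative: any WG_PKT_PADDING
	 * bytes beyond it are trimmed off the tail below (m_adj() with a
	 * negative count trims from the end of the mbuf chain).
	 */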
1572 	if (determine_af_and_pullup(&m, &pkt->p_af) == 0) {
1573 		if (pkt->p_af == AF_INET) {
1574 			struct ip *ip = mtod(m, struct ip *);
1575 			allowed_peer = wg_aip_lookup(sc, AF_INET, &ip->ip_src);
1576 			len = ntohs(ip->ip_len);
1577 			if (len >= sizeof(struct ip) && len < m->m_pkthdr.len)
1578 				m_adj(m, len - m->m_pkthdr.len);
1579 		} else if (pkt->p_af == AF_INET6) {
1580 			struct ip6_hdr *ip6 = mtod(m, struct ip6_hdr *);
1581 			allowed_peer = wg_aip_lookup(sc, AF_INET6, &ip6->ip6_src);
1582 			len = ntohs(ip6->ip6_plen) + sizeof(struct ip6_hdr);
1583 			if (len < m->m_pkthdr.len)
1584 				m_adj(m, len - m->m_pkthdr.len);
1585 		} else
1586 			panic("determine_af_and_pullup returned unexpected value");
1587 	} else {
1588 		DPRINTF(sc, "Packet is neither ipv4 nor ipv6 from peer %" PRIu64 "\n", peer->p_id);
1589 		goto out;
1590 	}
1591 
1592 	/* We only want to compare the address, not dereference, so drop the ref. */
1593 	if (allowed_peer != NULL)
1594 		noise_remote_put(allowed_peer->p_remote);
1595 
1596 	if (__predict_false(peer != allowed_peer)) {
1597 		DPRINTF(sc, "Packet has unallowed src IP from peer %" PRIu64 "\n", peer->p_id);
1598 		goto out;
1599 	}
1600 
1601 	wg_mbuf_reset(m);
1602 	state = WG_PACKET_CRYPTED;
1603 out:
1604 	pkt->p_mbuf = m;
1605 	atomic_store_rel_int(&pkt->p_state, state);
1606 	GROUPTASK_ENQUEUE(&peer->p_recv);
1607 	noise_remote_put(remote);
1608 }
1609 
1610 static void
1611 wg_softc_decrypt(struct wg_softc *sc)
1612 {
1613 	struct wg_packet *pkt;
1614 
1615 	while ((pkt = wg_queue_dequeue_parallel(&sc->sc_decrypt_parallel)) != NULL)
1616 		wg_decrypt(sc, pkt);
1617 }
1618 
1619 static void
1620 wg_softc_encrypt(struct wg_softc *sc)
1621 {
1622 	struct wg_packet *pkt;
1623 
1624 	while ((pkt = wg_queue_dequeue_parallel(&sc->sc_encrypt_parallel)) != NULL)
1625 		wg_encrypt(sc, pkt);
1626 }
1627 
1628 static void
1629 wg_encrypt_dispatch(struct wg_softc *sc)
1630 {
1631 	/*
1632 	 * The update to encrypt_last_cpu is racey such that we may
1633 	 * The update to encrypt_last_cpu is racy such that we may
1634 	 * the race doesn't really matter.
1635 	 */
1636 	u_int cpu = (sc->sc_encrypt_last_cpu + 1) % mp_ncpus;
1637 	sc->sc_encrypt_last_cpu = cpu;
1638 	GROUPTASK_ENQUEUE(&sc->sc_encrypt[cpu]);
1639 }
1640 
1641 static void
1642 wg_decrypt_dispatch(struct wg_softc *sc)
1643 {
1644 	u_int cpu = (sc->sc_decrypt_last_cpu + 1) % mp_ncpus;
1645 	sc->sc_decrypt_last_cpu = cpu;
1646 	GROUPTASK_ENQUEUE(&sc->sc_decrypt[cpu]);
1647 }
1648 
1649 static void
1650 wg_deliver_out(struct wg_peer *peer)
1651 {
1652 	struct wg_endpoint	 endpoint;
1653 	struct wg_softc		*sc = peer->p_sc;
1654 	struct wg_packet	*pkt;
1655 	struct mbuf		*m;
1656 	int			 rc, len;
1657 
1658 	wg_peer_get_endpoint(peer, &endpoint);
1659 
1660 	while ((pkt = wg_queue_dequeue_serial(&peer->p_encrypt_serial)) != NULL) {
1661 		if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1662 			goto error;
1663 
1664 		m = pkt->p_mbuf;
1665 		pkt->p_mbuf = NULL;
1666 
1667 		len = m->m_pkthdr.len;
1668 
1669 		wg_timers_event_any_authenticated_packet_traversal(peer);
1670 		wg_timers_event_any_authenticated_packet_sent(peer);
1671 		rc = wg_send(sc, &endpoint, m);
1672 		if (rc == 0) {
1673 			if (len > (sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN))
1674 				wg_timers_event_data_sent(peer);
1675 			counter_u64_add(peer->p_tx_bytes, len);
1676 		} else if (rc == EADDRNOTAVAIL) {
1677 			wg_peer_clear_src(peer);
1678 			wg_peer_get_endpoint(peer, &endpoint);
1679 			goto error;
1680 		} else {
1681 			goto error;
1682 		}
1683 		wg_packet_free(pkt);
1684 		if (noise_keep_key_fresh_send(peer->p_remote))
1685 			wg_timers_event_want_initiation(peer);
1686 		continue;
1687 error:
1688 		if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
1689 		wg_packet_free(pkt);
1690 	}
1691 }
1692 
1693 #ifdef DEV_NETMAP
1694 /*
1695  * Hand a packet to the netmap RX ring, via netmap's
1696  * freebsd_generic_rx_handler().
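 *
 * The tunnel carries raw IP packets, so a synthetic Ethernet header is
 * prepended here only so that netmap's Ethernet-framed RX path can consume
 * the mbuf; the MAC addresses below are placeholders.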
1697  */
1698 static void
1699 wg_deliver_netmap(if_t ifp, struct mbuf *m, int af)
1700 {
1701 	struct ether_header *eh;
1702 
1703 	M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
1704 	if (__predict_false(m == NULL)) {
1705 		if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
1706 		return;
1707 	}
1708 
1709 	eh = mtod(m, struct ether_header *);
1710 	eh->ether_type = af == AF_INET ?
1711 	    htons(ETHERTYPE_IP) : htons(ETHERTYPE_IPV6);
1712 	memcpy(eh->ether_shost, "\x02\x02\x02\x02\x02\x02", ETHER_ADDR_LEN);
1713 	memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
1714 	if_input(ifp, m);
1715 }
1716 #endif
1717 
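/*
 * Drain the peer's serial decrypt queue: enforce the nonce/replay check,
 * update timers and counters, then hand the inner packet to the network
 * stack (or to netmap when the interface is in netmap mode).
 */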
1718 static void
1719 wg_deliver_in(struct wg_peer *peer)
1720 {
1721 	struct wg_softc		*sc = peer->p_sc;
1722 	if_t			 ifp = sc->sc_ifp;
1723 	struct wg_packet	*pkt;
1724 	struct mbuf		*m;
1725 	struct epoch_tracker	 et;
1726 	int			 af;
1727 
1728 	while ((pkt = wg_queue_dequeue_serial(&peer->p_decrypt_serial)) != NULL) {
1729 		if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1730 			goto error;
1731 
1732 		m = pkt->p_mbuf;
1733 		if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0)
1734 			goto error;
1735 
1736 		if (noise_keypair_received_with(pkt->p_keypair) == ECONNRESET)
1737 			wg_timers_event_handshake_complete(peer);
1738 
1739 		wg_timers_event_any_authenticated_packet_received(peer);
1740 		wg_timers_event_any_authenticated_packet_traversal(peer);
1741 		wg_peer_set_endpoint(peer, &pkt->p_endpoint);
1742 
1743 		counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len +
1744 		    sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1745 		if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
1746 		if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len +
1747 		    sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1748 
1749 		if (m->m_pkthdr.len == 0)
1750 			goto done;
1751 
1752 		af = pkt->p_af;
1753 		MPASS(af == AF_INET || af == AF_INET6);
1754 		pkt->p_mbuf = NULL;
1755 
1756 		m->m_pkthdr.rcvif = ifp;
1757 
1758 		NET_EPOCH_ENTER(et);
1759 		BPF_MTAP2_AF(ifp, m, af);
1760 
1761 		CURVNET_SET(if_getvnet(ifp));
1762 		M_SETFIB(m, if_getfib(ifp));
1763 #ifdef DEV_NETMAP
1764 		if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
1765 			wg_deliver_netmap(ifp, m, af);
1766 		else
1767 #endif
1768 		if (af == AF_INET)
1769 			netisr_dispatch(NETISR_IP, m);
1770 		else if (af == AF_INET6)
1771 			netisr_dispatch(NETISR_IPV6, m);
1772 		CURVNET_RESTORE();
1773 		NET_EPOCH_EXIT(et);
1774 
1775 		wg_timers_event_data_received(peer);
1776 
1777 done:
1778 		if (noise_keep_key_fresh_recv(peer->p_remote))
1779 			wg_timers_event_want_initiation(peer);
1780 		wg_packet_free(pkt);
1781 		continue;
1782 error:
1783 		if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
1784 		wg_packet_free(pkt);
1785 	}
1786 }
1787 
1788 static struct wg_packet *
1789 wg_packet_alloc(struct mbuf *m)
1790 {
1791 	struct wg_packet *pkt;
1792 
1793 	if ((pkt = uma_zalloc(wg_packet_zone, M_NOWAIT | M_ZERO)) == NULL)
1794 		return (NULL);
1795 	pkt->p_mbuf = m;
1796 	return (pkt);
1797 }
1798 
1799 static void
1800 wg_packet_free(struct wg_packet *pkt)
1801 {
1802 	if (pkt->p_keypair != NULL)
1803 		noise_keypair_put(pkt->p_keypair);
1804 	if (pkt->p_mbuf != NULL)
1805 		m_freem(pkt->p_mbuf);
1806 	uma_zfree(wg_packet_zone, pkt);
1807 }
1808 
1809 static void
1810 wg_queue_init(struct wg_queue *queue, const char *name)
1811 {
1812 	mtx_init(&queue->q_mtx, name, NULL, MTX_DEF);
1813 	STAILQ_INIT(&queue->q_queue);
1814 	queue->q_len = 0;
1815 }
1816 
1817 static void
1818 wg_queue_deinit(struct wg_queue *queue)
1819 {
1820 	wg_queue_purge(queue);
1821 	mtx_destroy(&queue->q_mtx);
1822 }
1823 
1824 static size_t
1825 wg_queue_len(struct wg_queue *queue)
1826 {
1827 	return (queue->q_len);
1828 }
1829 
1830 static int
1831 wg_queue_enqueue_handshake(struct wg_queue *hs, struct wg_packet *pkt)
1832 {
1833 	int ret = 0;
1834 	mtx_lock(&hs->q_mtx);
1835 	if (hs->q_len < MAX_QUEUED_HANDSHAKES) {
1836 		STAILQ_INSERT_TAIL(&hs->q_queue, pkt, p_parallel);
1837 		hs->q_len++;
1838 	} else {
1839 		ret = ENOBUFS;
1840 	}
1841 	mtx_unlock(&hs->q_mtx);
1842 	if (ret != 0)
1843 		wg_packet_free(pkt);
1844 	return (ret);
1845 }
1846 
1847 static struct wg_packet *
1848 wg_queue_dequeue_handshake(struct wg_queue *hs)
1849 {
1850 	struct wg_packet *pkt;
1851 	mtx_lock(&hs->q_mtx);
1852 	if ((pkt = STAILQ_FIRST(&hs->q_queue)) != NULL) {
1853 		STAILQ_REMOVE_HEAD(&hs->q_queue, p_parallel);
1854 		hs->q_len--;
1855 	}
1856 	mtx_unlock(&hs->q_mtx);
1857 	return (pkt);
1858 }
1859 
1860 static void
1861 wg_queue_push_staged(struct wg_queue *staged, struct wg_packet *pkt)
1862 {
1863 	struct wg_packet *old = NULL;
1864 
1865 	mtx_lock(&staged->q_mtx);
1866 	if (staged->q_len >= MAX_STAGED_PKT) {
1867 		old = STAILQ_FIRST(&staged->q_queue);
1868 		STAILQ_REMOVE_HEAD(&staged->q_queue, p_parallel);
1869 		staged->q_len--;
1870 	}
1871 	STAILQ_INSERT_TAIL(&staged->q_queue, pkt, p_parallel);
1872 	staged->q_len++;
1873 	mtx_unlock(&staged->q_mtx);
1874 
1875 	if (old != NULL)
1876 		wg_packet_free(old);
1877 }
1878 
1879 static void
1880 wg_queue_enlist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1881 {
1882 	struct wg_packet *pkt, *tpkt;
1883 	STAILQ_FOREACH_SAFE(pkt, list, p_parallel, tpkt)
1884 		wg_queue_push_staged(staged, pkt);
1885 }
1886 
1887 static void
1888 wg_queue_delist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1889 {
1890 	STAILQ_INIT(list);
1891 	mtx_lock(&staged->q_mtx);
1892 	STAILQ_CONCAT(list, &staged->q_queue);
1893 	staged->q_len = 0;
1894 	mtx_unlock(&staged->q_mtx);
1895 }
1896 
1897 static void
1898 wg_queue_purge(struct wg_queue *staged)
1899 {
1900 	struct wg_packet_list list;
1901 	struct wg_packet *pkt, *tpkt;
1902 	wg_queue_delist_staged(staged, &list);
1903 	STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt)
1904 		wg_packet_free(pkt);
1905 }
1906 
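/*
 * Each packet is placed on two queues at once: a per-peer serial queue that
 * preserves transmit/delivery order, and a per-device parallel queue from
 * which the crypto tasks pull work on any CPU.  Ordering is recovered on the
 * serial side, since wg_queue_dequeue_serial() only pops the head once its
 * state has left WG_PACKET_UNCRYPTED.
 */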
1907 static int
1908 wg_queue_both(struct wg_queue *parallel, struct wg_queue *serial, struct wg_packet *pkt)
1909 {
1910 	pkt->p_state = WG_PACKET_UNCRYPTED;
1911 
1912 	mtx_lock(&serial->q_mtx);
1913 	if (serial->q_len < MAX_QUEUED_PKT) {
1914 		serial->q_len++;
1915 		STAILQ_INSERT_TAIL(&serial->q_queue, pkt, p_serial);
1916 	} else {
1917 		mtx_unlock(&serial->q_mtx);
1918 		wg_packet_free(pkt);
1919 		return (ENOBUFS);
1920 	}
1921 	mtx_unlock(&serial->q_mtx);
1922 
1923 	mtx_lock(&parallel->q_mtx);
1924 	if (parallel->q_len < MAX_QUEUED_PKT) {
1925 		parallel->q_len++;
1926 		STAILQ_INSERT_TAIL(&parallel->q_queue, pkt, p_parallel);
1927 	} else {
1928 		mtx_unlock(&parallel->q_mtx);
1929 		pkt->p_state = WG_PACKET_DEAD;
1930 		return (ENOBUFS);
1931 	}
1932 	mtx_unlock(&parallel->q_mtx);
1933 
1934 	return (0);
1935 }
1936 
1937 static struct wg_packet *
1938 wg_queue_dequeue_serial(struct wg_queue *serial)
1939 {
1940 	struct wg_packet *pkt = NULL;
1941 	mtx_lock(&serial->q_mtx);
1942 	if (serial->q_len > 0 && STAILQ_FIRST(&serial->q_queue)->p_state != WG_PACKET_UNCRYPTED) {
1943 		serial->q_len--;
1944 		pkt = STAILQ_FIRST(&serial->q_queue);
1945 		STAILQ_REMOVE_HEAD(&serial->q_queue, p_serial);
1946 	}
1947 	mtx_unlock(&serial->q_mtx);
1948 	return (pkt);
1949 }
1950 
1951 static struct wg_packet *
1952 wg_queue_dequeue_parallel(struct wg_queue *parallel)
1953 {
1954 	struct wg_packet *pkt = NULL;
1955 	mtx_lock(&parallel->q_mtx);
1956 	if (parallel->q_len > 0) {
1957 		parallel->q_len--;
1958 		pkt = STAILQ_FIRST(&parallel->q_queue);
1959 		STAILQ_REMOVE_HEAD(&parallel->q_queue, p_parallel);
1960 	}
1961 	mtx_unlock(&parallel->q_mtx);
1962 	return (pkt);
1963 }
1964 
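/*
 * UDP tunnel input callback.  The datagram is classified by its leading type
 * word: handshake messages (initiation/response/cookie) are queued for the
 * handshake task, while data packets are matched to a keypair by receiver
 * index and queued for decryption.
 */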
1965 static bool
1966 wg_input(struct mbuf *m, int offset, struct inpcb *inpcb,
1967     const struct sockaddr *sa, void *_sc)
1968 {
1969 #ifdef INET
1970 	const struct sockaddr_in	*sin;
1971 #endif
1972 #ifdef INET6
1973 	const struct sockaddr_in6	*sin6;
1974 #endif
1975 	struct noise_remote		*remote;
1976 	struct wg_pkt_data		*data;
1977 	struct wg_packet		*pkt;
1978 	struct wg_peer			*peer;
1979 	struct wg_softc			*sc = _sc;
1980 	struct mbuf			*defragged;
1981 
1982 	defragged = m_defrag(m, M_NOWAIT);
1983 	if (defragged)
1984 		m = defragged;
1985 	m = m_unshare(m, M_NOWAIT);
1986 	if (!m) {
1987 		if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
1988 		return true;
1989 	}
1990 
1991 	/* Caller provided us with `sa`, no need for this header. */
1992 	m_adj(m, offset + sizeof(struct udphdr));
1993 
1994 	/* Pullup enough to read packet type */
1995 	if ((m = m_pullup(m, sizeof(uint32_t))) == NULL) {
1996 		if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
1997 		return true;
1998 	}
1999 
2000 	if ((pkt = wg_packet_alloc(m)) == NULL) {
2001 		if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2002 		m_freem(m);
2003 		return true;
2004 	}
2005 
2006 	/* Save send/recv address and port for later. */
2007 	switch (sa->sa_family) {
2008 #ifdef INET
2009 	case AF_INET:
2010 		sin = (const struct sockaddr_in *)sa;
2011 		pkt->p_endpoint.e_remote.r_sin = sin[0];
2012 		pkt->p_endpoint.e_local.l_in = sin[1].sin_addr;
2013 		break;
2014 #endif
2015 #ifdef INET6
2016 	case AF_INET6:
2017 		sin6 = (const struct sockaddr_in6 *)sa;
2018 		pkt->p_endpoint.e_remote.r_sin6 = sin6[0];
2019 		pkt->p_endpoint.e_local.l_in6 = sin6[1].sin6_addr;
2020 		break;
2021 #endif
2022 	default:
2023 		goto error;
2024 	}
2025 
2026 	if ((m->m_pkthdr.len == sizeof(struct wg_pkt_initiation) &&
2027 		*mtod(m, uint32_t *) == WG_PKT_INITIATION) ||
2028 	    (m->m_pkthdr.len == sizeof(struct wg_pkt_response) &&
2029 		*mtod(m, uint32_t *) == WG_PKT_RESPONSE) ||
2030 	    (m->m_pkthdr.len == sizeof(struct wg_pkt_cookie) &&
2031 		*mtod(m, uint32_t *) == WG_PKT_COOKIE)) {
2032 
2033 		if (wg_queue_enqueue_handshake(&sc->sc_handshake_queue, pkt) != 0) {
2034 			if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2035 			DPRINTF(sc, "Dropping handshake packet\n");
2036 		}
2037 		GROUPTASK_ENQUEUE(&sc->sc_handshake);
2038 	} else if (m->m_pkthdr.len >= sizeof(struct wg_pkt_data) +
2039 	    NOISE_AUTHTAG_LEN && *mtod(m, uint32_t *) == WG_PKT_DATA) {
2040 
2041 		/* Pullup whole header to read r_idx below. */
2042 		if ((pkt->p_mbuf = m_pullup(m, sizeof(struct wg_pkt_data))) == NULL)
2043 			goto error;
2044 
2045 		data = mtod(pkt->p_mbuf, struct wg_pkt_data *);
2046 		if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL)
2047 			goto error;
2048 
2049 		remote = noise_keypair_remote(pkt->p_keypair);
2050 		peer = noise_remote_arg(remote);
2051 		if (wg_queue_both(&sc->sc_decrypt_parallel, &peer->p_decrypt_serial, pkt) != 0)
2052 			if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2053 		wg_decrypt_dispatch(sc);
2054 		noise_remote_put(remote);
2055 	} else {
2056 		goto error;
2057 	}
2058 	return true;
2059 error:
2060 	if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1);
2061 	wg_packet_free(pkt);
2062 	return true;
2063 }
2064 
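/*
 * Take everything currently staged for the peer, assign a nonce from the
 * current keypair to each packet, and push the batch onto the encrypt
 * queues.  If no valid keypair (or nonce space) is available, the packets
 * are re-staged and a new handshake is requested.
 */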
2065 static void
2066 wg_peer_send_staged(struct wg_peer *peer)
2067 {
2068 	struct wg_packet_list	 list;
2069 	struct noise_keypair	*keypair;
2070 	struct wg_packet	*pkt, *tpkt;
2071 	struct wg_softc		*sc = peer->p_sc;
2072 
2073 	wg_queue_delist_staged(&peer->p_stage_queue, &list);
2074 
2075 	if (STAILQ_EMPTY(&list))
2076 		return;
2077 
2078 	if ((keypair = noise_keypair_current(peer->p_remote)) == NULL)
2079 		goto error;
2080 
2081 	STAILQ_FOREACH(pkt, &list, p_parallel) {
2082 		if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0)
2083 			goto error_keypair;
2084 	}
2085 	STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt) {
2086 		pkt->p_keypair = noise_keypair_ref(keypair);
2087 		if (wg_queue_both(&sc->sc_encrypt_parallel, &peer->p_encrypt_serial, pkt) != 0)
2088 			if_inc_counter(sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
2089 	}
2090 	wg_encrypt_dispatch(sc);
2091 	noise_keypair_put(keypair);
2092 	return;
2093 
2094 error_keypair:
2095 	noise_keypair_put(keypair);
2096 error:
2097 	wg_queue_enlist_staged(&peer->p_stage_queue, &list);
2098 	wg_timers_event_want_initiation(peer);
2099 }
2100 
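/*
 * Count an output error and, where the address family is known, bounce an
 * ICMP/ICMPv6 unreachable back at the sender.  The mbuf is consumed either
 * by the ICMP code or freed here.
 */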
2101 static inline void
2102 xmit_err(if_t ifp, struct mbuf *m, struct wg_packet *pkt, sa_family_t af)
2103 {
2104 	if_inc_counter(ifp, IFCOUNTER_OERRORS, 1);
2105 	switch (af) {
2106 #ifdef INET
2107 	case AF_INET:
2108 		icmp_error(m, ICMP_UNREACH, ICMP_UNREACH_HOST, 0, 0);
2109 		if (pkt)
2110 			pkt->p_mbuf = NULL;
2111 		m = NULL;
2112 		break;
2113 #endif
2114 #ifdef INET6
2115 	case AF_INET6:
2116 		icmp6_error(m, ICMP6_DST_UNREACH, 0, 0);
2117 		if (pkt)
2118 			pkt->p_mbuf = NULL;
2119 		m = NULL;
2120 		break;
2121 #endif
2122 	}
2123 	if (pkt)
2124 		wg_packet_free(pkt);
2125 	else if (m)
2126 		m_freem(m);
2127 }
2128 
2129 static int
2130 wg_xmit(if_t ifp, struct mbuf *m, sa_family_t af, uint32_t mtu)
2131 {
2132 	struct wg_packet	*pkt = NULL;
2133 	struct wg_softc		*sc = if_getsoftc(ifp);
2134 	struct wg_peer		*peer;
2135 	int			 rc = 0;
2136 	sa_family_t		 peer_af;
2137 
2138 	/* Work around lifetime issue in the ipv6 mld code. */
2139 	if (__predict_false((if_getflags(ifp) & IFF_DYING) || !sc)) {
2140 		rc = ENXIO;
2141 		goto err_xmit;
2142 	}
2143 
2144 	if ((pkt = wg_packet_alloc(m)) == NULL) {
2145 		rc = ENOBUFS;
2146 		goto err_xmit;
2147 	}
2148 	pkt->p_mtu = mtu;
2149 	pkt->p_af = af;
2150 
2151 	if (af == AF_INET) {
2152 		peer = wg_aip_lookup(sc, AF_INET, &mtod(m, struct ip *)->ip_dst);
2153 	} else if (af == AF_INET6) {
2154 		peer = wg_aip_lookup(sc, AF_INET6, &mtod(m, struct ip6_hdr *)->ip6_dst);
2155 	} else {
2156 		rc = EAFNOSUPPORT;
2157 		goto err_xmit;
2158 	}
2159 
2160 	BPF_MTAP2_AF(ifp, m, pkt->p_af);
2161 
2162 	if (__predict_false(peer == NULL)) {
2163 		rc = ENETUNREACH;
2164 		goto err_xmit;
2165 	}
2166 
2167 	if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP, MAX_LOOPS))) {
2168 		DPRINTF(sc, "Packet looped\n");
2169 		rc = ELOOP;
2170 		goto err_peer;
2171 	}
2172 
2173 	peer_af = peer->p_endpoint.e_remote.r_sa.sa_family;
2174 	if (__predict_false(peer_af != AF_INET && peer_af != AF_INET6)) {
2175 		DPRINTF(sc, "No valid endpoint has been configured or "
2176 			    "discovered for peer %" PRIu64 "\n", peer->p_id);
2177 		rc = EHOSTUNREACH;
2178 		goto err_peer;
2179 	}
2180 
2181 	wg_queue_push_staged(&peer->p_stage_queue, pkt);
2182 	wg_peer_send_staged(peer);
2183 	noise_remote_put(peer->p_remote);
2184 	return (0);
2185 
2186 err_peer:
2187 	noise_remote_put(peer->p_remote);
2188 err_xmit:
2189 	xmit_err(ifp, m, pkt, af);
2190 	return (rc);
2191 }
2192 
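/*
 * Peek at the IP version nibble to decide the address family, pulling up
 * enough of the mbuf chain that the relevant header is contiguous.
 */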
2193 static inline int
2194 determine_af_and_pullup(struct mbuf **m, sa_family_t *af)
2195 {
2196 	u_char ipv;
2197 	if ((*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2198 		*m = m_pullup(*m, sizeof(struct ip6_hdr));
2199 	else if ((*m)->m_pkthdr.len >= sizeof(struct ip))
2200 		*m = m_pullup(*m, sizeof(struct ip));
2201 	else
2202 		return (EAFNOSUPPORT);
2203 	if (*m == NULL)
2204 		return (ENOBUFS);
2205 	ipv = mtod(*m, struct ip *)->ip_v;
2206 	if (ipv == 4)
2207 		*af = AF_INET;
2208 	else if (ipv == 6 && (*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2209 		*af = AF_INET6;
2210 	else
2211 		return (EAFNOSUPPORT);
2212 	return (0);
2213 }
2214 
2215 static int
2216 determine_ethertype_and_pullup(struct mbuf **m, int *etp)
2217 {
2218 	struct ether_header *eh;
2219 
2220 	*m = m_pullup(*m, sizeof(struct ether_header));
2221 	if (__predict_false(*m == NULL))
2222 		return (ENOBUFS);
2223 	eh = mtod(*m, struct ether_header *);
2224 	*etp = ntohs(eh->ether_type);
2225 	if (*etp != ETHERTYPE_IP && *etp != ETHERTYPE_IPV6)
2226 		return (EAFNOSUPPORT);
2227 	return (0);
2228 }
2229 
2230 /*
2231  * This should only be invoked by netmap, via nm_os_generic_xmit_frame(), to
2232  * transmit packets from the netmap TX ring.
2233  */
2234 static int
2235 wg_transmit(if_t ifp, struct mbuf *m)
2236 {
2237 	sa_family_t af;
2238 	int et, ret;
2239 	struct mbuf *defragged;
2240 
2241 	KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2242 	    ("%s: ifp %p is not in netmap mode", __func__, ifp));
2243 
2244 	defragged = m_defrag(m, M_NOWAIT);
2245 	if (defragged)
2246 		m = defragged;
2247 	m = m_unshare(m, M_NOWAIT);
2248 	if (!m) {
2249 		xmit_err(ifp, m, NULL, AF_UNSPEC);
2250 		return (ENOBUFS);
2251 	}
2252 
2253 	ret = determine_ethertype_and_pullup(&m, &et);
2254 	if (ret) {
2255 		xmit_err(ifp, m, NULL, AF_UNSPEC);
2256 		return (ret);
2257 	}
2258 	m_adj(m, sizeof(struct ether_header));
2259 
2260 	ret = determine_af_and_pullup(&m, &af);
2261 	if (ret) {
2262 		xmit_err(ifp, m, NULL, AF_UNSPEC);
2263 		return (ret);
2264 	}
2265 
2266 	/*
2267 	 * netmap only gets to see transient errors, since it handles errors by
2268 	 * refusing to advance the transmit ring and retrying later.
2269 	 */
2270 	ret = wg_xmit(ifp, m, af, if_getmtu(ifp));
2271 	if (ret == ENOBUFS)
2272 		return (ret);
2273 	return (0);
2274 }
2275 
2276 #ifdef DEV_NETMAP
2277 /*
2278  * This should only be invoked by netmap, via nm_os_send_up(), to process
2279  * packets from the host TX ring.
2280  */
2281 static void
2282 wg_if_input(if_t ifp, struct mbuf *m)
2283 {
2284 	int et;
2285 
2286 	KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2287 	    ("%s: ifp %p is not in netmap mode", __func__, ifp));
2288 
2289 	if (determine_ethertype_and_pullup(&m, &et) != 0) {
2290 		if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2291 		m_freem(m);
2292 		return;
2293 	}
2294 	CURVNET_SET(if_getvnet(ifp));
2295 	switch (et) {
2296 	case ETHERTYPE_IP:
2297 		m_adj(m, sizeof(struct ether_header));
2298 		netisr_dispatch(NETISR_IP, m);
2299 		break;
2300 	case ETHERTYPE_IPV6:
2301 		m_adj(m, sizeof(struct ether_header));
2302 		netisr_dispatch(NETISR_IPV6, m);
2303 		break;
2304 	default:
2305 		__assert_unreachable();
2306 	}
2307 	CURVNET_RESTORE();
2308 }
2309 
2310 /*
2311  * Deliver a packet to the host RX ring.  Because the interface is in netmap
2312  * mode, the if_transmit() call should pass the packet to netmap_transmit().
2313  */
2314 static int
2315 wg_xmit_netmap(if_t ifp, struct mbuf *m, int af)
2316 {
2317 	struct ether_header *eh;
2318 
2319 	if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP,
2320 	    MAX_LOOPS))) {
2321 		printf("%s:%d\n", __func__, __LINE__);
2322 		if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2323 		m_freem(m);
2324 		return (ELOOP);
2325 	}
2326 
2327 	M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
2328 	if (__predict_false(m == NULL)) {
2329 		if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
2330 		return (ENOBUFS);
2331 	}
2332 
2333 	eh = mtod(m, struct ether_header *);
2334 	eh->ether_type = af == AF_INET ?
2335 	    htons(ETHERTYPE_IP) : htons(ETHERTYPE_IPV6);
2336 	memcpy(eh->ether_shost, "\x06\x06\x06\x06\x06\x06", ETHER_ADDR_LEN);
2337 	memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
2338 	return (if_transmit(ifp, m));
2339 }
2340 #endif /* DEV_NETMAP */
2341 
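/*
 * if_output entry point.  BPF injections carry the address family in
 * sa_data rather than in a route, so they are decoded specially before the
 * packet is parsed and handed to wg_xmit().
 */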
2342 static int
2343 wg_output(if_t ifp, struct mbuf *m, const struct sockaddr *dst, struct route *ro)
2344 {
2345 	sa_family_t parsed_af;
2346 	uint32_t af, mtu;
2347 	int ret;
2348 	struct mbuf *defragged;
2349 
2350 	/* BPF writes need to be handled specially. */
2351 	if (dst->sa_family == AF_UNSPEC || dst->sa_family == pseudo_AF_HDRCMPLT)
2352 		memcpy(&af, dst->sa_data, sizeof(af));
2353 	else
2354 		af = RO_GET_FAMILY(ro, dst);
2355 	if (af == AF_UNSPEC) {
2356 		xmit_err(ifp, m, NULL, af);
2357 		return (EAFNOSUPPORT);
2358 	}
2359 
2360 #ifdef DEV_NETMAP
2361 	if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
2362 		return (wg_xmit_netmap(ifp, m, af));
2363 #endif
2364 
2365 	defragged = m_defrag(m, M_NOWAIT);
2366 	if (defragged)
2367 		m = defragged;
2368 	m = m_unshare(m, M_NOWAIT);
2369 	if (!m) {
2370 		xmit_err(ifp, m, NULL, AF_UNSPEC);
2371 		return (ENOBUFS);
2372 	}
2373 
2374 	ret = determine_af_and_pullup(&m, &parsed_af);
2375 	if (ret) {
2376 		xmit_err(ifp, m, NULL, AF_UNSPEC);
2377 		return (ret);
2378 	}
2379 
2380 	MPASS(parsed_af == af);
2381 	mtu = (ro != NULL && ro->ro_mtu > 0) ? ro->ro_mtu : if_getmtu(ifp);
2382 	return (wg_xmit(ifp, m, parsed_af, mtu));
2383 }
2384 
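/*
 * Create or update a single peer from its nvlist description (public key,
 * optional preshared key, endpoint, persistent keepalive and allowed IPs).
 * A "remove" flag deletes the peer instead.
 */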
2385 static int
2386 wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
2387 {
2388 	uint8_t			 public[WG_KEY_SIZE];
2389 	const void *pub_key, *preshared_key = NULL;
2390 	const struct sockaddr *endpoint;
2391 	int err;
2392 	size_t size;
2393 	struct noise_remote *remote;
2394 	struct wg_peer *peer = NULL;
2395 	bool need_cleanup = false;
2396 
2397 	sx_assert(&sc->sc_lock, SX_XLOCKED);
2398 
2399 	if (!nvlist_exists_binary(nvl, "public-key")) {
2400 		return (EINVAL);
2401 	}
2402 	pub_key = nvlist_get_binary(nvl, "public-key", &size);
2403 	if (size != WG_KEY_SIZE) {
2404 		return (EINVAL);
2405 	}
2406 	if (noise_local_keys(sc->sc_local, public, NULL) == 0 &&
2407 	    bcmp(public, pub_key, WG_KEY_SIZE) == 0) {
2408 		return (0); /* Silently ignored; not actually a failure. */
2409 	}
2410 	if ((remote = noise_remote_lookup(sc->sc_local, pub_key)) != NULL)
2411 		peer = noise_remote_arg(remote);
2412 	if (nvlist_exists_bool(nvl, "remove") &&
2413 		nvlist_get_bool(nvl, "remove")) {
2414 		if (remote != NULL) {
2415 			wg_peer_destroy(peer);
2416 			noise_remote_put(remote);
2417 		}
2418 		return (0);
2419 	}
2420 	if (nvlist_exists_bool(nvl, "replace-allowedips") &&
2421 		nvlist_get_bool(nvl, "replace-allowedips") &&
2422 	    peer != NULL) {
2423 
2424 		wg_aip_remove_all(sc, peer);
2425 	}
2426 	if (peer == NULL) {
2427 		peer = wg_peer_create(sc, pub_key, &err);
2428 		if (peer == NULL)
2429 			goto out;
2430 		need_cleanup = true;
2431 	}
2432 	if (nvlist_exists_binary(nvl, "endpoint")) {
2433 		endpoint = nvlist_get_binary(nvl, "endpoint", &size);
2434 		if (size > sizeof(peer->p_endpoint.e_remote)) {
2435 			err = EINVAL;
2436 			goto out;
2437 		}
2438 		memcpy(&peer->p_endpoint.e_remote, endpoint, size);
2439 	}
2440 	if (nvlist_exists_binary(nvl, "preshared-key")) {
2441 		preshared_key = nvlist_get_binary(nvl, "preshared-key", &size);
2442 		if (size != WG_KEY_SIZE) {
2443 			err = EINVAL;
2444 			goto out;
2445 		}
2446 		noise_remote_set_psk(peer->p_remote, preshared_key);
2447 	}
2448 	if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) {
2449 		uint64_t pki = nvlist_get_number(nvl, "persistent-keepalive-interval");
2450 		if (pki > UINT16_MAX) {
2451 			err = EINVAL;
2452 			goto out;
2453 		}
2454 		wg_timers_set_persistent_keepalive(peer, pki);
2455 	}
2456 	if (nvlist_exists_nvlist_array(nvl, "allowed-ips")) {
2457 		const void *addr;
2458 		uint64_t cidr;
2459 		const nvlist_t * const * aipl;
2460 		size_t allowedip_count;
2461 
2462 		aipl = nvlist_get_nvlist_array(nvl, "allowed-ips", &allowedip_count);
2463 		for (size_t idx = 0; idx < allowedip_count; idx++) {
2464 			if (!nvlist_exists_number(aipl[idx], "cidr"))
2465 				continue;
2466 			cidr = nvlist_get_number(aipl[idx], "cidr");
2467 			if (nvlist_exists_binary(aipl[idx], "ipv4")) {
2468 				addr = nvlist_get_binary(aipl[idx], "ipv4", &size);
2469 				if (addr == NULL || cidr > 32 || size != sizeof(struct in_addr)) {
2470 					err = EINVAL;
2471 					goto out;
2472 				}
2473 				if ((err = wg_aip_add(sc, peer, AF_INET, addr, cidr)) != 0)
2474 					goto out;
2475 			} else if (nvlist_exists_binary(aipl[idx], "ipv6")) {
2476 				addr = nvlist_get_binary(aipl[idx], "ipv6", &size);
2477 				if (addr == NULL || cidr > 128 || size != sizeof(struct in6_addr)) {
2478 					err = EINVAL;
2479 					goto out;
2480 				}
2481 				if ((err = wg_aip_add(sc, peer, AF_INET6, addr, cidr)) != 0)
2482 					goto out;
2483 			} else {
2484 				continue;
2485 			}
2486 		}
2487 	}
2488 	if (remote != NULL)
2489 		noise_remote_put(remote);
2490 	return (0);
2491 out:
2492 	if (need_cleanup) /* If we fail, only destroy if it was new. */
2493 		wg_peer_destroy(peer);
2494 	if (remote != NULL)
2495 		noise_remote_put(remote);
2496 	return (err);
2497 }
2498 
2499 static int
2500 wgc_set(struct wg_softc *sc, struct wg_data_io *wgd)
2501 {
2502 	uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE];
2503 	if_t ifp;
2504 	void *nvlpacked;
2505 	nvlist_t *nvl;
2506 	ssize_t size;
2507 	int err;
2508 
2509 	ifp = sc->sc_ifp;
2510 	if (wgd->wgd_size == 0 || wgd->wgd_data == NULL)
2511 		return (EFAULT);
2512 
2513 	/* Can nvlists be streamed in? It's not nice to impose arbitrary limits like that but
2514 	 * there needs to be _some_ limitation. */
2515 	if (wgd->wgd_size >= UINT32_MAX / 2)
2516 		return (E2BIG);
2517 
2518 	nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK | M_ZERO);
2519 
2520 	err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size);
2521 	if (err)
2522 		goto out;
2523 	nvl = nvlist_unpack(nvlpacked, wgd->wgd_size, 0);
2524 	if (nvl == NULL) {
2525 		err = EBADMSG;
2526 		goto out;
2527 	}
2528 	sx_xlock(&sc->sc_lock);
2529 	if (nvlist_exists_bool(nvl, "replace-peers") &&
2530 		nvlist_get_bool(nvl, "replace-peers"))
2531 		wg_peer_destroy_all(sc);
2532 	if (nvlist_exists_number(nvl, "listen-port")) {
2533 		uint64_t new_port = nvlist_get_number(nvl, "listen-port");
2534 		if (new_port > UINT16_MAX) {
2535 			err = EINVAL;
2536 			goto out_locked;
2537 		}
2538 		if (new_port != sc->sc_socket.so_port) {
2539 			if ((if_getdrvflags(ifp) & IFF_DRV_RUNNING) != 0) {
2540 				if ((err = wg_socket_init(sc, new_port)) != 0)
2541 					goto out_locked;
2542 			} else
2543 				sc->sc_socket.so_port = new_port;
2544 		}
2545 	}
2546 	if (nvlist_exists_binary(nvl, "private-key")) {
2547 		const void *key = nvlist_get_binary(nvl, "private-key", &size);
2548 		if (size != WG_KEY_SIZE) {
2549 			err = EINVAL;
2550 			goto out_locked;
2551 		}
2552 
2553 		if (noise_local_keys(sc->sc_local, NULL, private) != 0 ||
2554 		    timingsafe_bcmp(private, key, WG_KEY_SIZE) != 0) {
2555 			struct wg_peer *peer;
2556 
2557 			if (curve25519_generate_public(public, key)) {
2558 				/* Peer conflict: remove conflicting peer. */
2559 				struct noise_remote *remote;
2560 				if ((remote = noise_remote_lookup(sc->sc_local,
2561 				    public)) != NULL) {
2562 					peer = noise_remote_arg(remote);
2563 					wg_peer_destroy(peer);
2564 					noise_remote_put(remote);
2565 				}
2566 			}
2567 
2568 			/*
2569 			 * Set the private key and invalidate all existing
2570 			 * handshakes.
2571 			 */
2572 			/* Note: we might be removing the private key. */
2573 			noise_local_private(sc->sc_local, key);
2574 			if (noise_local_keys(sc->sc_local, NULL, NULL) == 0)
2575 				cookie_checker_update(&sc->sc_cookie, public);
2576 			else
2577 				cookie_checker_update(&sc->sc_cookie, NULL);
2578 		}
2579 	}
2580 	if (nvlist_exists_number(nvl, "user-cookie")) {
2581 		uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie");
2582 		if (user_cookie > UINT32_MAX) {
2583 			err = EINVAL;
2584 			goto out_locked;
2585 		}
2586 		err = wg_socket_set_cookie(sc, user_cookie);
2587 		if (err)
2588 			goto out_locked;
2589 	}
2590 	if (nvlist_exists_nvlist_array(nvl, "peers")) {
2591 		size_t peercount;
2592 		const nvlist_t * const*nvl_peers;
2593 
2594 		nvl_peers = nvlist_get_nvlist_array(nvl, "peers", &peercount);
2595 		for (int i = 0; i < peercount; i++) {
2596 			err = wg_peer_add(sc, nvl_peers[i]);
2597 			if (err != 0)
2598 				goto out_locked;
2599 		}
2600 	}
2601 
2602 out_locked:
2603 	sx_xunlock(&sc->sc_lock);
2604 	nvlist_destroy(nvl);
2605 out:
2606 	zfree(nvlpacked, M_TEMP);
2607 	return (err);
2608 }
2609 
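/*
 * Export the current configuration as a packed nvlist.  When called with
 * wgd_size == 0 the required buffer size is reported back instead of data,
 * so userspace can size its buffer and retry.
 */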
2610 static int
2611 wgc_get(struct wg_softc *sc, struct wg_data_io *wgd)
2612 {
2613 	uint8_t public_key[WG_KEY_SIZE] = { 0 };
2614 	uint8_t private_key[WG_KEY_SIZE] = { 0 };
2615 	uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN] = { 0 };
2616 	nvlist_t *nvl, *nvl_peer, *nvl_aip, **nvl_peers, **nvl_aips;
2617 	size_t size, peer_count, aip_count, i, j;
2618 	struct wg_timespec64 ts64;
2619 	struct wg_peer *peer;
2620 	struct wg_aip *aip;
2621 	void *packed;
2622 	int err = 0;
2623 
2624 	nvl = nvlist_create(0);
2625 	if (!nvl)
2626 		return (ENOMEM);
2627 
2628 	sx_slock(&sc->sc_lock);
2629 
2630 	if (sc->sc_socket.so_port != 0)
2631 		nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port);
2632 	if (sc->sc_socket.so_user_cookie != 0)
2633 		nvlist_add_number(nvl, "user-cookie", sc->sc_socket.so_user_cookie);
2634 	if (noise_local_keys(sc->sc_local, public_key, private_key) == 0) {
2635 		nvlist_add_binary(nvl, "public-key", public_key, WG_KEY_SIZE);
2636 		if (wgc_privileged(sc))
2637 			nvlist_add_binary(nvl, "private-key", private_key, WG_KEY_SIZE);
2638 		explicit_bzero(private_key, sizeof(private_key));
2639 	}
2640 	peer_count = sc->sc_peers_num;
2641 	if (peer_count) {
2642 		nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2643 		i = 0;
2644 		TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2645 			if (i >= peer_count)
2646 				panic("peers changed from under us");
2647 
2648 			nvl_peers[i++] = nvl_peer = nvlist_create(0);
2649 			if (!nvl_peer) {
2650 				err = ENOMEM;
2651 				goto err_peer;
2652 			}
2653 
2654 			(void)noise_remote_keys(peer->p_remote, public_key, preshared_key);
2655 			nvlist_add_binary(nvl_peer, "public-key", public_key, sizeof(public_key));
2656 			if (wgc_privileged(sc))
2657 				nvlist_add_binary(nvl_peer, "preshared-key", preshared_key, sizeof(preshared_key));
2658 			explicit_bzero(preshared_key, sizeof(preshared_key));
2659 			if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET)
2660 				nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in));
2661 			else if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET6)
2662 				nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in6));
2663 			wg_timers_get_last_handshake(peer, &ts64);
2664 			nvlist_add_binary(nvl_peer, "last-handshake-time", &ts64, sizeof(ts64));
2665 			nvlist_add_number(nvl_peer, "persistent-keepalive-interval", peer->p_persistent_keepalive_interval);
2666 			nvlist_add_number(nvl_peer, "rx-bytes", counter_u64_fetch(peer->p_rx_bytes));
2667 			nvlist_add_number(nvl_peer, "tx-bytes", counter_u64_fetch(peer->p_tx_bytes));
2668 
2669 			aip_count = peer->p_aips_num;
2670 			if (aip_count) {
2671 				nvl_aips = mallocarray(aip_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2672 				j = 0;
2673 				LIST_FOREACH(aip, &peer->p_aips, a_entry) {
2674 					if (j >= aip_count)
2675 						panic("aips changed from under us");
2676 
2677 					nvl_aips[j++] = nvl_aip = nvlist_create(0);
2678 					if (!nvl_aip) {
2679 						err = ENOMEM;
2680 						goto err_aip;
2681 					}
2682 					if (aip->a_af == AF_INET) {
2683 						nvlist_add_binary(nvl_aip, "ipv4", &aip->a_addr.in, sizeof(aip->a_addr.in));
2684 						nvlist_add_number(nvl_aip, "cidr", bitcount32(aip->a_mask.ip));
2685 					}
2686 #ifdef INET6
2687 					else if (aip->a_af == AF_INET6) {
2688 						nvlist_add_binary(nvl_aip, "ipv6", &aip->a_addr.in6, sizeof(aip->a_addr.in6));
2689 						nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&aip->a_mask.in6, NULL));
2690 					}
2691 #endif
2692 				}
2693 				nvlist_add_nvlist_array(nvl_peer, "allowed-ips", (const nvlist_t *const *)nvl_aips, aip_count);
2694 			err_aip:
2695 				for (j = 0; j < aip_count; ++j)
2696 					nvlist_destroy(nvl_aips[j]);
2697 				free(nvl_aips, M_NVLIST);
2698 				if (err)
2699 					goto err_peer;
2700 			}
2701 		}
2702 		nvlist_add_nvlist_array(nvl, "peers", (const nvlist_t * const *)nvl_peers, peer_count);
2703 	err_peer:
2704 		for (i = 0; i < peer_count; ++i)
2705 			nvlist_destroy(nvl_peers[i]);
2706 		free(nvl_peers, M_NVLIST);
2707 		if (err) {
2708 			sx_sunlock(&sc->sc_lock);
2709 			goto err;
2710 		}
2711 	}
2712 	sx_sunlock(&sc->sc_lock);
2713 	packed = nvlist_pack(nvl, &size);
2714 	if (!packed) {
2715 		err = ENOMEM;
2716 		goto err;
2717 	}
2718 	if (!wgd->wgd_size) {
2719 		wgd->wgd_size = size;
2720 		goto out;
2721 	}
2722 	if (wgd->wgd_size < size) {
2723 		err = ENOSPC;
2724 		goto out;
2725 	}
2726 	err = copyout(packed, wgd->wgd_data, size);
2727 	wgd->wgd_size = size;
2728 
2729 out:
2730 	zfree(packed, M_NVLIST);
2731 err:
2732 	nvlist_destroy(nvl);
2733 	return (err);
2734 }
2735 
2736 static int
2737 wg_ioctl(if_t ifp, u_long cmd, caddr_t data)
2738 {
2739 	struct wg_data_io *wgd = (struct wg_data_io *)data;
2740 	struct ifreq *ifr = (struct ifreq *)data;
2741 	struct wg_softc *sc;
2742 	int ret = 0;
2743 
2744 	sx_slock(&wg_sx);
2745 	sc = if_getsoftc(ifp);
2746 	if (!sc) {
2747 		ret = ENXIO;
2748 		goto out;
2749 	}
2750 
2751 	switch (cmd) {
2752 	case SIOCSWG:
2753 		ret = priv_check(curthread, PRIV_NET_WG);
2754 		if (ret == 0)
2755 			ret = wgc_set(sc, wgd);
2756 		break;
2757 	case SIOCGWG:
2758 		ret = wgc_get(sc, wgd);
2759 		break;
2760 	/* Interface IOCTLs */
2761 	case SIOCSIFADDR:
2762 		/*
2763 		 * This differs from *BSD norms, but is more uniform with how
2764 		 * WireGuard behaves elsewhere.
2765 		 */
2766 		break;
2767 	case SIOCSIFFLAGS:
2768 		if (if_getflags(ifp) & IFF_UP)
2769 			ret = wg_up(sc);
2770 		else
2771 			wg_down(sc);
2772 		break;
2773 	case SIOCSIFMTU:
2774 		if (ifr->ifr_mtu <= 0 || ifr->ifr_mtu > MAX_MTU)
2775 			ret = EINVAL;
2776 		else
2777 			if_setmtu(ifp, ifr->ifr_mtu);
2778 		break;
2779 	case SIOCADDMULTI:
2780 	case SIOCDELMULTI:
2781 		break;
2782 	case SIOCGTUNFIB:
2783 		ifr->ifr_fib = sc->sc_socket.so_fibnum;
2784 		break;
2785 	case SIOCSTUNFIB:
2786 		ret = priv_check(curthread, PRIV_NET_WG);
2787 		if (ret)
2788 			break;
2789 		ret = priv_check(curthread, PRIV_NET_SETIFFIB);
2790 		if (ret)
2791 			break;
2792 		sx_xlock(&sc->sc_lock);
2793 		ret = wg_socket_set_fibnum(sc, ifr->ifr_fib);
2794 		sx_xunlock(&sc->sc_lock);
2795 		break;
2796 	default:
2797 		ret = ENOTTY;
2798 	}
2799 
2800 out:
2801 	sx_sunlock(&wg_sx);
2802 	return (ret);
2803 }
2804 
2805 static int
2806 wg_up(struct wg_softc *sc)
2807 {
2808 	if_t ifp = sc->sc_ifp;
2809 	struct wg_peer *peer;
2810 	int rc = EBUSY;
2811 
2812 	sx_xlock(&sc->sc_lock);
2813 	/* Jail's being removed, no more wg_up(). */
2814 	if ((sc->sc_flags & WGF_DYING) != 0)
2815 		goto out;
2816 
2817 	/* Silent success if we're already running. */
2818 	rc = 0;
2819 	if (if_getdrvflags(ifp) & IFF_DRV_RUNNING)
2820 		goto out;
2821 	if_setdrvflagbits(ifp, IFF_DRV_RUNNING, 0);
2822 
2823 	rc = wg_socket_init(sc, sc->sc_socket.so_port);
2824 	if (rc == 0) {
2825 		TAILQ_FOREACH(peer, &sc->sc_peers, p_entry)
2826 			wg_timers_enable(peer);
2827 		if_link_state_change(sc->sc_ifp, LINK_STATE_UP);
2828 	} else {
2829 		if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2830 		DPRINTF(sc, "Unable to initialize sockets: %d\n", rc);
2831 	}
2832 out:
2833 	sx_xunlock(&sc->sc_lock);
2834 	return (rc);
2835 }
2836 
2837 static void
2838 wg_down(struct wg_softc *sc)
2839 {
2840 	if_t ifp = sc->sc_ifp;
2841 	struct wg_peer *peer;
2842 
2843 	sx_xlock(&sc->sc_lock);
2844 	if (!(if_getdrvflags(ifp) & IFF_DRV_RUNNING)) {
2845 		sx_xunlock(&sc->sc_lock);
2846 		return;
2847 	}
2848 	if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2849 
2850 	TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2851 		wg_queue_purge(&peer->p_stage_queue);
2852 		wg_timers_disable(peer);
2853 	}
2854 
2855 	wg_queue_purge(&sc->sc_handshake_queue);
2856 
2857 	TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2858 		noise_remote_handshake_clear(peer->p_remote);
2859 		noise_remote_keypairs_clear(peer->p_remote);
2860 	}
2861 
2862 	if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
2863 	wg_socket_uninit(sc);
2864 
2865 	sx_xunlock(&sc->sc_lock);
2866 }
2867 
2868 static int
2869 wg_clone_create(struct if_clone *ifc, char *name, size_t len,
2870     struct ifc_data *ifd, struct ifnet **ifpp)
2871 {
2872 	struct wg_softc *sc;
2873 	if_t ifp;
2874 
2875 	sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO);
2876 
2877 	sc->sc_local = noise_local_alloc(sc);
2878 
2879 	sc->sc_encrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2880 
2881 	sc->sc_decrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2882 
2883 	if (!rn_inithead((void **)&sc->sc_aip4, offsetof(struct aip_addr, in) * NBBY))
2884 		goto free_decrypt;
2885 
2886 	if (!rn_inithead((void **)&sc->sc_aip6, offsetof(struct aip_addr, in6) * NBBY))
2887 		goto free_aip4;
2888 
2889 	atomic_add_int(&clone_count, 1);
2890 	ifp = sc->sc_ifp = if_alloc(IFT_WIREGUARD);
2891 
2892 	sc->sc_ucred = crhold(curthread->td_ucred);
2893 	sc->sc_socket.so_fibnum = curthread->td_proc->p_fibnum;
2894 	sc->sc_socket.so_port = 0;
2895 
2896 	TAILQ_INIT(&sc->sc_peers);
2897 	sc->sc_peers_num = 0;
2898 
2899 	cookie_checker_init(&sc->sc_cookie);
2900 
2901 	RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip4);
2902 	RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip6);
2903 
2904 	GROUPTASK_INIT(&sc->sc_handshake, 0, (gtask_fn_t *)wg_softc_handshake_receive, sc);
2905 	taskqgroup_attach(qgroup_wg_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation");
2906 	wg_queue_init(&sc->sc_handshake_queue, "hsq");
2907 
2908 	for (int i = 0; i < mp_ncpus; i++) {
2909 		GROUPTASK_INIT(&sc->sc_encrypt[i], 0,
2910 		     (gtask_fn_t *)wg_softc_encrypt, sc);
2911 		taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_encrypt[i], sc, i, NULL, NULL, "wg encrypt");
2912 		GROUPTASK_INIT(&sc->sc_decrypt[i], 0,
2913 		    (gtask_fn_t *)wg_softc_decrypt, sc);
2914 		taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_decrypt[i], sc, i, NULL, NULL, "wg decrypt");
2915 	}
2916 
2917 	wg_queue_init(&sc->sc_encrypt_parallel, "encp");
2918 	wg_queue_init(&sc->sc_decrypt_parallel, "decp");
2919 
2920 	sx_init(&sc->sc_lock, "wg softc lock");
2921 
2922 	if_setsoftc(ifp, sc);
2923 	if_setcapabilities(ifp, WG_CAPS);
2924 	if_setcapenable(ifp, WG_CAPS);
2925 	if_initname(ifp, wgname, ifd->unit);
2926 
2927 	if_setmtu(ifp, DEFAULT_MTU);
2928 	if_setflags(ifp, IFF_NOARP | IFF_MULTICAST);
2929 	if_setinitfn(ifp, wg_init);
2930 	if_setreassignfn(ifp, wg_reassign);
2931 	if_setqflushfn(ifp, wg_qflush);
2932 	if_settransmitfn(ifp, wg_transmit);
2933 #ifdef DEV_NETMAP
2934 	if_setinputfn(ifp, wg_if_input);
2935 #endif
2936 	if_setoutputfn(ifp, wg_output);
2937 	if_setioctlfn(ifp, wg_ioctl);
2938 	if_attach(ifp);
2939 	bpfattach(ifp, DLT_NULL, sizeof(uint32_t));
2940 #ifdef INET6
2941 	ND_IFINFO(ifp)->flags &= ~ND6_IFF_AUTO_LINKLOCAL;
2942 	ND_IFINFO(ifp)->flags |= ND6_IFF_NO_DAD;
2943 #endif
2944 	sx_xlock(&wg_sx);
2945 	LIST_INSERT_HEAD(&wg_list, sc, sc_entry);
2946 	sx_xunlock(&wg_sx);
2947 	*ifpp = ifp;
2948 	return (0);
2949 free_aip4:
2950 	RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
2951 	free(sc->sc_aip4, M_RTABLE);
2952 free_decrypt:
2953 	free(sc->sc_decrypt, M_WG);
2954 	free(sc->sc_encrypt, M_WG);
2955 	noise_local_free(sc->sc_local, NULL);
2956 	free(sc, M_WG);
2957 	return (ENOMEM);
2958 }
2959 
2960 static void
2961 wg_clone_deferred_free(struct noise_local *l)
2962 {
2963 	struct wg_softc *sc = noise_local_arg(l);
2964 
2965 	free(sc, M_WG);
2966 	atomic_add_int(&clone_count, -1);
2967 }
2968 
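/*
 * Interface teardown: detach from the global list, close the socket, wait
 * out the network epoch and drain the taskqueues before freeing peers and
 * the per-CPU crypto tasks.  The softc itself is freed later, from
 * wg_clone_deferred_free(), once the last noise_local reference drops.
 */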
2969 static int
2970 wg_clone_destroy(struct if_clone *ifc, if_t ifp, uint32_t flags)
2971 {
2972 	struct wg_softc *sc = if_getsoftc(ifp);
2973 	struct ucred *cred;
2974 
2975 	sx_xlock(&wg_sx);
2976 	if_setsoftc(ifp, NULL);
2977 	sx_xlock(&sc->sc_lock);
2978 	sc->sc_flags |= WGF_DYING;
2979 	cred = sc->sc_ucred;
2980 	sc->sc_ucred = NULL;
2981 	sx_xunlock(&sc->sc_lock);
2982 	LIST_REMOVE(sc, sc_entry);
2983 	sx_xunlock(&wg_sx);
2984 
2985 	if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
2986 	CURVNET_SET(if_getvnet(sc->sc_ifp));
2987 	if_purgeaddrs(sc->sc_ifp);
2988 	CURVNET_RESTORE();
2989 
2990 	sx_xlock(&sc->sc_lock);
2991 	wg_socket_uninit(sc);
2992 	sx_xunlock(&sc->sc_lock);
2993 
2994 	/*
2995 	 * No guarantees that all traffic has passed until the epoch has
2996 	 * elapsed with the socket closed.
2997 	 */
2998 	NET_EPOCH_WAIT();
2999 
3000 	taskqgroup_drain_all(qgroup_wg_tqg);
3001 	sx_xlock(&sc->sc_lock);
3002 	wg_peer_destroy_all(sc);
3003 	NET_EPOCH_DRAIN_CALLBACKS();
3004 	sx_xunlock(&sc->sc_lock);
3005 	sx_destroy(&sc->sc_lock);
3006 	taskqgroup_detach(qgroup_wg_tqg, &sc->sc_handshake);
3007 	for (int i = 0; i < mp_ncpus; i++) {
3008 		taskqgroup_detach(qgroup_wg_tqg, &sc->sc_encrypt[i]);
3009 		taskqgroup_detach(qgroup_wg_tqg, &sc->sc_decrypt[i]);
3010 	}
3011 	free(sc->sc_encrypt, M_WG);
3012 	free(sc->sc_decrypt, M_WG);
3013 	wg_queue_deinit(&sc->sc_handshake_queue);
3014 	wg_queue_deinit(&sc->sc_encrypt_parallel);
3015 	wg_queue_deinit(&sc->sc_decrypt_parallel);
3016 
3017 	RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
3018 	RADIX_NODE_HEAD_DESTROY(sc->sc_aip6);
3019 	rn_detachhead((void **)&sc->sc_aip4);
3020 	rn_detachhead((void **)&sc->sc_aip6);
3021 
3022 	cookie_checker_free(&sc->sc_cookie);
3023 
3024 	if (cred != NULL)
3025 		crfree(cred);
3026 	bpfdetach(sc->sc_ifp);
3027 	if_detach(sc->sc_ifp);
3028 	if_free(sc->sc_ifp);
3029 
3030 	noise_local_free(sc->sc_local, wg_clone_deferred_free);
3031 
3032 	return (0);
3033 }
3034 
3035 static void
3036 wg_qflush(if_t ifp __unused)
3037 {
3038 }
3039 
3040 /*
3041  * Privileged information (private-key, preshared-key) are only exported for
3042  * root and jailed root by default.
3043  */
3044 static bool
3045 wgc_privileged(struct wg_softc *sc)
3046 {
3047 	struct thread *td;
3048 
3049 	td = curthread;
3050 	return (priv_check(td, PRIV_NET_WG) == 0);
3051 }
3052 
3053 static void
3054 wg_reassign(if_t ifp, struct vnet *new_vnet __unused,
3055     char *unused __unused)
3056 {
3057 	struct wg_softc *sc;
3058 
3059 	sc = if_getsoftc(ifp);
3060 	wg_down(sc);
3061 }
3062 
3063 static void
3064 wg_init(void *xsc)
3065 {
3066 	struct wg_softc *sc;
3067 
3068 	sc = xsc;
3069 	wg_up(sc);
3070 }
3071 
3072 static void
3073 vnet_wg_init(const void *unused __unused)
3074 {
3075 	struct if_clone_addreq req = {
3076 		.create_f = wg_clone_create,
3077 		.destroy_f = wg_clone_destroy,
3078 		.flags = IFC_F_AUTOUNIT,
3079 	};
3080 	V_wg_cloner = ifc_attach_cloner(wgname, &req);
3081 }
3082 VNET_SYSINIT(vnet_wg_init, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3083 	     vnet_wg_init, NULL);
3084 
3085 static void
3086 vnet_wg_uninit(const void *unused __unused)
3087 {
3088 	if (V_wg_cloner)
3089 		ifc_detach_cloner(V_wg_cloner);
3090 }
3091 VNET_SYSUNINIT(vnet_wg_uninit, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3092 	       vnet_wg_uninit, NULL);
3093 
3094 static int
3095 wg_prison_remove(void *obj, void *data __unused)
3096 {
3097 	const struct prison *pr = obj;
3098 	struct wg_softc *sc;
3099 
3100 	/*
3101 	 * Do a pass through all if_wg interfaces and release creds on any from
3102 	 * the jail that are supposed to be going away.  This will, in turn, let
3103 	 * the jail die so that we don't end up with Schrödinger's jail.
3104 	 */
3105 	sx_slock(&wg_sx);
3106 	LIST_FOREACH(sc, &wg_list, sc_entry) {
3107 		sx_xlock(&sc->sc_lock);
3108 		if (!(sc->sc_flags & WGF_DYING) && sc->sc_ucred && sc->sc_ucred->cr_prison == pr) {
3109 			struct ucred *cred = sc->sc_ucred;
3110 			DPRINTF(sc, "Creating jail is exiting\n");
3111 			if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
3112 			wg_socket_uninit(sc);
3113 			sc->sc_ucred = NULL;
3114 			crfree(cred);
3115 			sc->sc_flags |= WGF_DYING;
3116 		}
3117 		sx_xunlock(&sc->sc_lock);
3118 	}
3119 	sx_sunlock(&wg_sx);
3120 
3121 	return (0);
3122 }
3123 
3124 #ifdef SELFTESTS
3125 #include "selftest/allowedips.c"
3126 static bool wg_run_selftests(void)
3127 {
3128 	bool ret = true;
3129 	ret &= wg_allowedips_selftest();
3130 	ret &= noise_counter_selftest();
3131 	ret &= cookie_selftest();
3132 	return ret;
3133 }
3134 #else
3135 static inline bool wg_run_selftests(void) { return true; }
3136 #endif
3137 
3138 static int
3139 wg_module_init(void)
3140 {
3141 	int ret;
3142 	osd_method_t methods[PR_MAXMETHOD] = {
3143 		[PR_METHOD_REMOVE] = wg_prison_remove,
3144 	};
3145 
3146 	wg_packet_zone = uma_zcreate("wg packet", sizeof(struct wg_packet),
3147 	     NULL, NULL, NULL, NULL, 0, 0);
3148 
3149 	ret = crypto_init();
3150 	if (ret != 0)
3151 		return (ret);
3152 	ret = cookie_init();
3153 	if (ret != 0)
3154 		return (ret);
3155 
3156 	wg_osd_jail_slot = osd_jail_register(NULL, methods);
3157 
3158 	if (!wg_run_selftests())
3159 		return (ENOTRECOVERABLE);
3160 
3161 	return (0);
3162 }
3163 
3164 static void
3165 wg_module_deinit(void)
3166 {
3167 	VNET_ITERATOR_DECL(vnet_iter);
3168 	VNET_LIST_RLOCK();
3169 	VNET_FOREACH(vnet_iter) {
3170 		struct if_clone *clone = VNET_VNET(vnet_iter, wg_cloner);
3171 		if (clone) {
3172 			ifc_detach_cloner(clone);
3173 			VNET_VNET(vnet_iter, wg_cloner) = NULL;
3174 		}
3175 	}
3176 	VNET_LIST_RUNLOCK();
3177 	NET_EPOCH_WAIT();
3178 	MPASS(LIST_EMPTY(&wg_list));
3179 	if (wg_osd_jail_slot != 0)
3180 		osd_jail_deregister(wg_osd_jail_slot);
3181 	cookie_deinit();
3182 	crypto_deinit();
3183 	if (wg_packet_zone != NULL)
3184 		uma_zdestroy(wg_packet_zone);
3185 }
3186 
3187 static int
3188 wg_module_event_handler(module_t mod, int what, void *arg)
3189 {
3190 	switch (what) {
3191 		case MOD_LOAD:
3192 			return wg_module_init();
3193 		case MOD_UNLOAD:
3194 			wg_module_deinit();
3195 			break;
3196 		default:
3197 			return (EOPNOTSUPP);
3198 	}
3199 	return (0);
3200 }
3201 
3202 static moduledata_t wg_moduledata = {
3203 	"if_wg",
3204 	wg_module_event_handler,
3205 	NULL
3206 };
3207 
3208 DECLARE_MODULE(if_wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
3209 MODULE_VERSION(if_wg, WIREGUARD_VERSION);
3210 MODULE_DEPEND(if_wg, crypto, 1, 1, 1);
3211