1 /* SPDX-License-Identifier: ISC
2 *
3 * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4 * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
5 * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
6 * Copyright (c) 2021 Kyle Evans <kevans@FreeBSD.org>
7 * Copyright (c) 2022 The FreeBSD Foundation
8 */
9
10 #include "opt_inet.h"
11 #include "opt_inet6.h"
12
13 #include <sys/param.h>
14 #include <sys/systm.h>
15 #include <sys/counter.h>
16 #include <sys/gtaskqueue.h>
17 #include <sys/jail.h>
18 #include <sys/kernel.h>
19 #include <sys/lock.h>
20 #include <sys/mbuf.h>
21 #include <sys/module.h>
22 #include <sys/nv.h>
23 #include <sys/priv.h>
24 #include <sys/protosw.h>
25 #include <sys/rmlock.h>
26 #include <sys/rwlock.h>
27 #include <sys/smp.h>
28 #include <sys/socket.h>
29 #include <sys/socketvar.h>
30 #include <sys/sockio.h>
31 #include <sys/sysctl.h>
32 #include <sys/sx.h>
33 #include <machine/_inttypes.h>
34 #include <net/bpf.h>
35 #include <net/ethernet.h>
36 #include <net/if.h>
37 #include <net/if_clone.h>
38 #include <net/if_types.h>
39 #include <net/if_var.h>
40 #include <net/netisr.h>
41 #include <net/radix.h>
42 #include <netinet/in.h>
43 #include <netinet6/in6_var.h>
44 #include <netinet/ip.h>
45 #include <netinet/ip6.h>
46 #include <netinet/ip_icmp.h>
47 #include <netinet/icmp6.h>
48 #include <netinet/udp_var.h>
49 #include <netinet6/nd6.h>
50
51 #include "wg_noise.h"
52 #include "wg_cookie.h"
53 #include "version.h"
54 #include "if_wg.h"
55
56 #define DEFAULT_MTU (ETHERMTU - 80)
57 #define MAX_MTU (IF_MAXMTU - 80)
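/* The 80 bytes of overhead reserved above: 40 (worst-case outer IPv6
 * header) + 8 (UDP) + 16 (wg_pkt_data header) + 16 (authentication tag). */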
58
59 #define MAX_STAGED_PKT 128
60 #define MAX_QUEUED_PKT 1024
61 #define MAX_QUEUED_PKT_MASK (MAX_QUEUED_PKT - 1)
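/* MAX_QUEUED_PKT must stay a power of two for MAX_QUEUED_PKT_MASK to be a
 * valid wrap-around mask. */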
62
63 #define MAX_QUEUED_HANDSHAKES 4096
64
65 #define REKEY_TIMEOUT_JITTER 334 /* 1/3 sec, round for arc4random_uniform */
66 #define MAX_TIMER_HANDSHAKES (90 / REKEY_TIMEOUT)
67 #define NEW_HANDSHAKE_TIMEOUT (REKEY_TIMEOUT + KEEPALIVE_TIMEOUT)
68 #define UNDERLOAD_TIMEOUT 1
69
70 #define DPRINTF(sc, ...) if (if_getflags(sc->sc_ifp) & IFF_DEBUG) if_printf(sc->sc_ifp, ##__VA_ARGS__)
71
72 /* First byte indicating packet type on the wire */
73 #define WG_PKT_INITIATION htole32(1)
74 #define WG_PKT_RESPONSE htole32(2)
75 #define WG_PKT_COOKIE htole32(3)
76 #define WG_PKT_DATA htole32(4)
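/* The type field is encoded little-endian on the wire, hence htole32. */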
77
78 #define WG_PKT_PADDING 16
79 #define WG_KEY_SIZE 32
80
81 struct wg_pkt_initiation {
82 uint32_t t;
83 uint32_t s_idx;
84 uint8_t ue[NOISE_PUBLIC_KEY_LEN];
85 uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN];
86 uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN];
87 struct cookie_macs m;
88 };
89
90 struct wg_pkt_response {
91 uint32_t t;
92 uint32_t s_idx;
93 uint32_t r_idx;
94 uint8_t ue[NOISE_PUBLIC_KEY_LEN];
95 uint8_t en[0 + NOISE_AUTHTAG_LEN];
96 struct cookie_macs m;
97 };
98
99 struct wg_pkt_cookie {
100 uint32_t t;
101 uint32_t r_idx;
102 uint8_t nonce[COOKIE_NONCE_SIZE];
103 uint8_t ec[COOKIE_ENCRYPTED_SIZE];
104 };
105
106 struct wg_pkt_data {
107 uint32_t t;
108 uint32_t r_idx;
109 uint64_t nonce;
110 uint8_t buf[];
111 };
112
113 struct wg_endpoint {
114 union {
115 struct sockaddr r_sa;
116 struct sockaddr_in r_sin;
117 #ifdef INET6
118 struct sockaddr_in6 r_sin6;
119 #endif
120 } e_remote;
121 union {
122 struct in_addr l_in;
123 #ifdef INET6
124 struct in6_pktinfo l_pktinfo6;
125 #define l_in6 l_pktinfo6.ipi6_addr
126 #endif
127 } e_local;
128 };
129
130 struct aip_addr {
131 uint8_t length;
132 union {
133 uint8_t bytes[16];
134 uint32_t ip;
135 uint32_t ip6[4];
136 struct in_addr in;
137 struct in6_addr in6;
138 };
139 };
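/* The radix code consumes length-prefixed keys, so addresses and masks are
 * stored with an explicit length covering the bytes up to and including the
 * address field. */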
140
141 struct wg_aip {
142 struct radix_node a_nodes[2];
143 LIST_ENTRY(wg_aip) a_entry;
144 struct aip_addr a_addr;
145 struct aip_addr a_mask;
146 struct wg_peer *a_peer;
147 sa_family_t a_af;
148 };
149
150 struct wg_packet {
151 STAILQ_ENTRY(wg_packet) p_serial;
152 STAILQ_ENTRY(wg_packet) p_parallel;
153 struct wg_endpoint p_endpoint;
154 struct noise_keypair *p_keypair;
155 uint64_t p_nonce;
156 struct mbuf *p_mbuf;
157 int p_mtu;
158 sa_family_t p_af;
159 enum wg_ring_state {
160 WG_PACKET_UNCRYPTED,
161 WG_PACKET_CRYPTED,
162 WG_PACKET_DEAD,
163 } p_state;
164 };
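/* Each packet is linked on two queues at once: a per-peer serial queue
 * (p_serial), which preserves transmission order, and an interface-wide
 * parallel queue (p_parallel) that the encrypt/decrypt tasks drain from any
 * CPU. */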
165
166 STAILQ_HEAD(wg_packet_list, wg_packet);
167
168 struct wg_queue {
169 struct mtx q_mtx;
170 struct wg_packet_list q_queue;
171 size_t q_len;
172 };
173
174 struct wg_peer {
175 TAILQ_ENTRY(wg_peer) p_entry;
176 uint64_t p_id;
177 struct wg_softc *p_sc;
178
179 struct noise_remote *p_remote;
180 struct cookie_maker p_cookie;
181
182 struct rwlock p_endpoint_lock;
183 struct wg_endpoint p_endpoint;
184
185 struct wg_queue p_stage_queue;
186 struct wg_queue p_encrypt_serial;
187 struct wg_queue p_decrypt_serial;
188
189 bool p_enabled;
190 bool p_need_another_keepalive;
191 uint16_t p_persistent_keepalive_interval;
192 struct callout p_new_handshake;
193 struct callout p_send_keepalive;
194 struct callout p_retry_handshake;
195 struct callout p_zero_key_material;
196 struct callout p_persistent_keepalive;
197
198 struct mtx p_handshake_mtx;
199 struct timespec p_handshake_complete; /* nanotime */
200 int p_handshake_retries;
201
202 struct grouptask p_send;
203 struct grouptask p_recv;
204
205 counter_u64_t p_tx_bytes;
206 counter_u64_t p_rx_bytes;
207
208 LIST_HEAD(, wg_aip) p_aips;
209 size_t p_aips_num;
210 };
211
212 struct wg_socket {
213 struct socket *so_so4;
214 struct socket *so_so6;
215 uint32_t so_user_cookie;
216 int so_fibnum;
217 in_port_t so_port;
218 };
219
220 struct wg_softc {
221 LIST_ENTRY(wg_softc) sc_entry;
222 if_t sc_ifp;
223 int sc_flags;
224
225 struct ucred *sc_ucred;
226 struct wg_socket sc_socket;
227
228 TAILQ_HEAD(,wg_peer) sc_peers;
229 size_t sc_peers_num;
230
231 struct noise_local *sc_local;
232 struct cookie_checker sc_cookie;
233
234 struct radix_node_head *sc_aip4;
235 struct radix_node_head *sc_aip6;
236
237 struct grouptask sc_handshake;
238 struct wg_queue sc_handshake_queue;
239
240 struct grouptask *sc_encrypt;
241 struct grouptask *sc_decrypt;
242 struct wg_queue sc_encrypt_parallel;
243 struct wg_queue sc_decrypt_parallel;
244 u_int sc_encrypt_last_cpu;
245 u_int sc_decrypt_last_cpu;
246
247 struct sx sc_lock;
248 };
249
250 #define WGF_DYING 0x0001
251
252 #define MAX_LOOPS 8
253 #define MTAG_WGLOOP 0x77676c70 /* wglp */
254
255 #define GROUPTASK_DRAIN(gtask) \
256 gtaskqueue_drain((gtask)->gt_taskqueue, &(gtask)->gt_task)
257
258 #define BPF_MTAP2_AF(ifp, m, af) do { \
259 uint32_t __bpf_tap_af = (af); \
260 BPF_MTAP2(ifp, &__bpf_tap_af, sizeof(__bpf_tap_af), m); \
261 } while (0)
262
263 static int clone_count;
264 static uma_zone_t wg_packet_zone;
265 static volatile unsigned long peer_counter = 0;
266 static const char wgname[] = "wg";
267 static unsigned wg_osd_jail_slot;
268
269 static struct sx wg_sx;
270 SX_SYSINIT(wg_sx, &wg_sx, "wg_sx");
271
272 static LIST_HEAD(, wg_softc) wg_list = LIST_HEAD_INITIALIZER(wg_list);
273
274 static TASKQGROUP_DEFINE(wg_tqg, mp_ncpus, 1);
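/* One task queue per CPU; the handshake, crypto, and per-peer delivery
 * grouptasks are attached to this group. */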
275
276 MALLOC_DEFINE(M_WG, "WG", "wireguard");
277
278 VNET_DEFINE_STATIC(struct if_clone *, wg_cloner);
279
280 #define V_wg_cloner VNET(wg_cloner)
281 #define WG_CAPS IFCAP_LINKSTATE
282
283 struct wg_timespec64 {
284 uint64_t tv_sec;
285 uint64_t tv_nsec;
286 };
287
288 static int wg_socket_init(struct wg_softc *, in_port_t);
289 static int wg_socket_bind(struct socket **, struct socket **, in_port_t *);
290 static void wg_socket_set(struct wg_softc *, struct socket *, struct socket *);
291 static void wg_socket_uninit(struct wg_softc *);
292 static int wg_socket_set_sockopt(struct socket *, struct socket *, int, void *, size_t);
293 static int wg_socket_set_cookie(struct wg_softc *, uint32_t);
294 static int wg_socket_set_fibnum(struct wg_softc *, int);
295 static int wg_send(struct wg_softc *, struct wg_endpoint *, struct mbuf *);
296 static void wg_timers_enable(struct wg_peer *);
297 static void wg_timers_disable(struct wg_peer *);
298 static void wg_timers_set_persistent_keepalive(struct wg_peer *, uint16_t);
299 static void wg_timers_get_last_handshake(struct wg_peer *, struct wg_timespec64 *);
300 static void wg_timers_event_data_sent(struct wg_peer *);
301 static void wg_timers_event_data_received(struct wg_peer *);
302 static void wg_timers_event_any_authenticated_packet_sent(struct wg_peer *);
303 static void wg_timers_event_any_authenticated_packet_received(struct wg_peer *);
304 static void wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *);
305 static void wg_timers_event_handshake_initiated(struct wg_peer *);
306 static void wg_timers_event_handshake_complete(struct wg_peer *);
307 static void wg_timers_event_session_derived(struct wg_peer *);
308 static void wg_timers_event_want_initiation(struct wg_peer *);
309 static void wg_timers_run_send_initiation(struct wg_peer *, bool);
310 static void wg_timers_run_retry_handshake(void *);
311 static void wg_timers_run_send_keepalive(void *);
312 static void wg_timers_run_new_handshake(void *);
313 static void wg_timers_run_zero_key_material(void *);
314 static void wg_timers_run_persistent_keepalive(void *);
315 static int wg_aip_add(struct wg_softc *, struct wg_peer *, sa_family_t, const void *, uint8_t);
316 static struct wg_peer *wg_aip_lookup(struct wg_softc *, sa_family_t, void *);
317 static void wg_aip_remove_all(struct wg_softc *, struct wg_peer *);
318 static struct wg_peer *wg_peer_alloc(struct wg_softc *, const uint8_t [WG_KEY_SIZE]);
319 static void wg_peer_free_deferred(struct noise_remote *);
320 static void wg_peer_destroy(struct wg_peer *);
321 static void wg_peer_destroy_all(struct wg_softc *);
322 static void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t);
323 static void wg_send_initiation(struct wg_peer *);
324 static void wg_send_response(struct wg_peer *);
325 static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct wg_endpoint *);
326 static void wg_peer_set_endpoint(struct wg_peer *, struct wg_endpoint *);
327 static void wg_peer_clear_src(struct wg_peer *);
328 static void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *);
329 static void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *, size_t);
330 static void wg_send_keepalive(struct wg_peer *);
331 static void wg_handshake(struct wg_softc *, struct wg_packet *);
332 static void wg_encrypt(struct wg_softc *, struct wg_packet *);
333 static void wg_decrypt(struct wg_softc *, struct wg_packet *);
334 static void wg_softc_handshake_receive(struct wg_softc *);
335 static void wg_softc_decrypt(struct wg_softc *);
336 static void wg_softc_encrypt(struct wg_softc *);
337 static void wg_encrypt_dispatch(struct wg_softc *);
338 static void wg_decrypt_dispatch(struct wg_softc *);
339 static void wg_deliver_out(struct wg_peer *);
340 static void wg_deliver_in(struct wg_peer *);
341 static struct wg_packet *wg_packet_alloc(struct mbuf *);
342 static void wg_packet_free(struct wg_packet *);
343 static void wg_queue_init(struct wg_queue *, const char *);
344 static void wg_queue_deinit(struct wg_queue *);
345 static size_t wg_queue_len(struct wg_queue *);
346 static int wg_queue_enqueue_handshake(struct wg_queue *, struct wg_packet *);
347 static struct wg_packet *wg_queue_dequeue_handshake(struct wg_queue *);
348 static void wg_queue_push_staged(struct wg_queue *, struct wg_packet *);
349 static void wg_queue_enlist_staged(struct wg_queue *, struct wg_packet_list *);
350 static void wg_queue_delist_staged(struct wg_queue *, struct wg_packet_list *);
351 static void wg_queue_purge(struct wg_queue *);
352 static int wg_queue_both(struct wg_queue *, struct wg_queue *, struct wg_packet *);
353 static struct wg_packet *wg_queue_dequeue_serial(struct wg_queue *);
354 static struct wg_packet *wg_queue_dequeue_parallel(struct wg_queue *);
355 static bool wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *);
356 static void wg_peer_send_staged(struct wg_peer *);
357 static int wg_clone_create(struct if_clone *ifc, char *name, size_t len,
358 struct ifc_data *ifd, if_t *ifpp);
359 static void wg_qflush(if_t);
360 static inline int determine_af_and_pullup(struct mbuf **m, sa_family_t *af);
361 static int wg_xmit(if_t, struct mbuf *, sa_family_t, uint32_t);
362 static int wg_transmit(if_t, struct mbuf *);
363 static int wg_output(if_t, struct mbuf *, const struct sockaddr *, struct route *);
364 static int wg_clone_destroy(struct if_clone *ifc, if_t ifp,
365 uint32_t flags);
366 static bool wgc_privileged(struct wg_softc *);
367 static int wgc_get(struct wg_softc *, struct wg_data_io *);
368 static int wgc_set(struct wg_softc *, struct wg_data_io *);
369 static int wg_up(struct wg_softc *);
370 static void wg_down(struct wg_softc *);
371 static void wg_reassign(if_t, struct vnet *, char *unused);
372 static void wg_init(void *);
373 static int wg_ioctl(if_t, u_long, caddr_t);
374 static void vnet_wg_init(const void *);
375 static void vnet_wg_uninit(const void *);
376 static int wg_module_init(void);
377 static void wg_module_deinit(void);
378
379 /* TODO Peer */
380 static struct wg_peer *
381 wg_peer_alloc(struct wg_softc *sc, const uint8_t pub_key[WG_KEY_SIZE])
382 {
383 struct wg_peer *peer;
384
385 sx_assert(&sc->sc_lock, SX_XLOCKED);
386
387 peer = malloc(sizeof(*peer), M_WG, M_WAITOK | M_ZERO);
388 peer->p_remote = noise_remote_alloc(sc->sc_local, peer, pub_key);
389 peer->p_tx_bytes = counter_u64_alloc(M_WAITOK);
390 peer->p_rx_bytes = counter_u64_alloc(M_WAITOK);
391 peer->p_id = peer_counter++;
392 peer->p_sc = sc;
393
394 cookie_maker_init(&peer->p_cookie, pub_key);
395
396 rw_init(&peer->p_endpoint_lock, "wg_peer_endpoint");
397
398 wg_queue_init(&peer->p_stage_queue, "stageq");
399 wg_queue_init(&peer->p_encrypt_serial, "txq");
400 wg_queue_init(&peer->p_decrypt_serial, "rxq");
401
402 peer->p_enabled = false;
403 peer->p_need_another_keepalive = false;
404 peer->p_persistent_keepalive_interval = 0;
405 callout_init(&peer->p_new_handshake, true);
406 callout_init(&peer->p_send_keepalive, true);
407 callout_init(&peer->p_retry_handshake, true);
408 callout_init(&peer->p_persistent_keepalive, true);
409 callout_init(&peer->p_zero_key_material, true);
410
411 mtx_init(&peer->p_handshake_mtx, "peer handshake", NULL, MTX_DEF);
412 bzero(&peer->p_handshake_complete, sizeof(peer->p_handshake_complete));
413 peer->p_handshake_retries = 0;
414
415 GROUPTASK_INIT(&peer->p_send, 0, (gtask_fn_t *)wg_deliver_out, peer);
416 taskqgroup_attach(qgroup_wg_tqg, &peer->p_send, peer, NULL, NULL, "wg send");
417 GROUPTASK_INIT(&peer->p_recv, 0, (gtask_fn_t *)wg_deliver_in, peer);
418 taskqgroup_attach(qgroup_wg_tqg, &peer->p_recv, peer, NULL, NULL, "wg recv");
419
420 LIST_INIT(&peer->p_aips);
421 peer->p_aips_num = 0;
422
423 return (peer);
424 }
425
426 static void
427 wg_peer_free_deferred(struct noise_remote *r)
428 {
429 struct wg_peer *peer = noise_remote_arg(r);
430
431 /* While there are no references remaining, we may still have
432 * p_{send,recv} executing (think empty queue, but wg_deliver_{in,out}
433 * need to check the queue). We should wait for them and then free. */
434 GROUPTASK_DRAIN(&peer->p_recv);
435 GROUPTASK_DRAIN(&peer->p_send);
436 taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv);
437 taskqgroup_detach(qgroup_wg_tqg, &peer->p_send);
438
439 wg_queue_deinit(&peer->p_decrypt_serial);
440 wg_queue_deinit(&peer->p_encrypt_serial);
441 wg_queue_deinit(&peer->p_stage_queue);
442
443 counter_u64_free(peer->p_tx_bytes);
444 counter_u64_free(peer->p_rx_bytes);
445 rw_destroy(&peer->p_endpoint_lock);
446 mtx_destroy(&peer->p_handshake_mtx);
447
448 cookie_maker_free(&peer->p_cookie);
449
450 free(peer, M_WG);
451 }
452
453 static void
454 wg_peer_destroy(struct wg_peer *peer)
455 {
456 struct wg_softc *sc = peer->p_sc;
457 sx_assert(&sc->sc_lock, SX_XLOCKED);
458
459 /* Disable remote and timers. This will prevent any new handshakes
460 * occurring. */
461 noise_remote_disable(peer->p_remote);
462 wg_timers_disable(peer);
463
464 /* Now we can remove all allowed IPs so no more packets will be routed
465 * to the peer. */
466 wg_aip_remove_all(sc, peer);
467
468 /* Remove peer from the interface, then free. Some references may still
469 * exist to p_remote, so noise_remote_free will wait until they're all
470 * put to call wg_peer_free_deferred. */
471 sc->sc_peers_num--;
472 TAILQ_REMOVE(&sc->sc_peers, peer, p_entry);
473 DPRINTF(sc, "Peer %" PRIu64 " destroyed\n", peer->p_id);
474 noise_remote_free(peer->p_remote, wg_peer_free_deferred);
475 }
476
477 static void
478 wg_peer_destroy_all(struct wg_softc *sc)
479 {
480 struct wg_peer *peer, *tpeer;
481 TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer)
482 wg_peer_destroy(peer);
483 }
484
485 static void
486 wg_peer_set_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
487 {
488 MPASS(e->e_remote.r_sa.sa_family != 0);
489 if (memcmp(e, &peer->p_endpoint, sizeof(*e)) == 0)
490 return;
491
492 rw_wlock(&peer->p_endpoint_lock);
493 peer->p_endpoint = *e;
494 rw_wunlock(&peer->p_endpoint_lock);
495 }
496
497 static void
498 wg_peer_clear_src(struct wg_peer *peer)
499 {
500 rw_wlock(&peer->p_endpoint_lock);
501 bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local));
502 rw_wunlock(&peer->p_endpoint_lock);
503 }
504
505 static void
506 wg_peer_get_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
507 {
508 rw_rlock(&peer->p_endpoint_lock);
509 *e = peer->p_endpoint;
510 rw_runlock(&peer->p_endpoint_lock);
511 }
512
513 /* Allowed IP */
514 static int
515 wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void *addr, uint8_t cidr)
516 {
517 struct radix_node_head *root;
518 struct radix_node *node;
519 struct wg_aip *aip;
520 int ret = 0;
521
522 aip = malloc(sizeof(*aip), M_WG, M_WAITOK | M_ZERO);
523 aip->a_peer = peer;
524 aip->a_af = af;
525
526 switch (af) {
527 #ifdef INET
528 case AF_INET:
529 if (cidr > 32) cidr = 32;
530 root = sc->sc_aip4;
531 aip->a_addr.in = *(const struct in_addr *)addr;
532 aip->a_mask.ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);
533 aip->a_addr.ip &= aip->a_mask.ip;
534 aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
535 break;
536 #endif
537 #ifdef INET6
538 case AF_INET6:
539 if (cidr > 128) cidr = 128;
540 root = sc->sc_aip6;
541 aip->a_addr.in6 = *(const struct in6_addr *)addr;
542 in6_prefixlen2mask(&aip->a_mask.in6, cidr);
543 for (int i = 0; i < 4; i++)
544 aip->a_addr.ip6[i] &= aip->a_mask.ip6[i];
545 aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
546 break;
547 #endif
548 default:
549 free(aip, M_WG);
550 return (EAFNOSUPPORT);
551 }
552
553 RADIX_NODE_HEAD_LOCK(root);
554 node = root->rnh_addaddr(&aip->a_addr, &aip->a_mask, &root->rh, aip->a_nodes);
555 if (node == aip->a_nodes) {
556 LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
557 peer->p_aips_num++;
558 } else if (!node)
559 node = root->rnh_lookup(&aip->a_addr, &aip->a_mask, &root->rh);
560 if (!node) {
561 free(aip, M_WG);
562 ret = ENOMEM;
563 } else if (node != aip->a_nodes) {
564 free(aip, M_WG);
565 aip = (struct wg_aip *)node;
566 if (aip->a_peer != peer) {
567 LIST_REMOVE(aip, a_entry);
568 aip->a_peer->p_aips_num--;
569 aip->a_peer = peer;
570 LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
571 aip->a_peer->p_aips_num++;
572 }
573 }
574 RADIX_NODE_HEAD_UNLOCK(root);
575 return (ret);
576 }
577
578 static struct wg_peer *
579 wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a)
580 {
581 struct radix_node_head *root;
582 struct radix_node *node;
583 struct wg_peer *peer;
584 struct aip_addr addr;
585 RADIX_NODE_HEAD_RLOCK_TRACKER;
586
587 switch (af) {
588 case AF_INET:
589 root = sc->sc_aip4;
590 memcpy(&addr.in, a, sizeof(addr.in));
591 addr.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
592 break;
593 case AF_INET6:
594 root = sc->sc_aip6;
595 memcpy(&addr.in6, a, sizeof(addr.in6));
596 addr.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
597 break;
598 default:
599 return NULL;
600 }
601
602 RADIX_NODE_HEAD_RLOCK(root);
603 node = root->rnh_matchaddr(&addr, &root->rh);
604 if (node != NULL) {
605 peer = ((struct wg_aip *)node)->a_peer;
606 noise_remote_ref(peer->p_remote);
607 } else {
608 peer = NULL;
609 }
610 RADIX_NODE_HEAD_RUNLOCK(root);
611
612 return (peer);
613 }
614
615 static void
616 wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer)
617 {
618 struct wg_aip *aip, *taip;
619
620 RADIX_NODE_HEAD_LOCK(sc->sc_aip4);
621 LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
622 if (aip->a_af == AF_INET) {
623 if (sc->sc_aip4->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip4->rh) == NULL)
624 panic("failed to delete aip %p", aip);
625 LIST_REMOVE(aip, a_entry);
626 peer->p_aips_num--;
627 free(aip, M_WG);
628 }
629 }
630 RADIX_NODE_HEAD_UNLOCK(sc->sc_aip4);
631
632 RADIX_NODE_HEAD_LOCK(sc->sc_aip6);
633 LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
634 if (aip->a_af == AF_INET6) {
635 if (sc->sc_aip6->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip6->rh) == NULL)
636 panic("failed to delete aip %p", aip);
637 LIST_REMOVE(aip, a_entry);
638 peer->p_aips_num--;
639 free(aip, M_WG);
640 }
641 }
642 RADIX_NODE_HEAD_UNLOCK(sc->sc_aip6);
643
644 if (!LIST_EMPTY(&peer->p_aips) || peer->p_aips_num != 0)
645 panic("wg_aip_remove_all could not delete all %p", peer);
646 }
647
648 static int
649 wg_socket_init(struct wg_softc *sc, in_port_t port)
650 {
651 struct ucred *cred = sc->sc_ucred;
652 struct socket *so4 = NULL, *so6 = NULL;
653 int rc;
654
655 sx_assert(&sc->sc_lock, SX_XLOCKED);
656
657 if (!cred)
658 return (EBUSY);
659
660 /*
661 * For socket creation, we use the creds of the thread that created the
662 * tunnel rather than the current thread to maintain the semantics that
663 * WireGuard has on Linux with network namespaces -- that the sockets
664 * are created in their home vnet so that they can be configured and
665 * functionally attached to a foreign vnet as the jail's only interface
666 * to the network.
667 */
668 #ifdef INET
669 rc = socreate(AF_INET, &so4, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
670 if (rc)
671 goto out;
672
673 rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc);
674 /*
675 * udp_set_kernel_tunneling can only fail if there is already a tunneling function set.
676 * This should never happen with a new socket.
677 */
678 MPASS(rc == 0);
679 #endif
680
681 #ifdef INET6
682 rc = socreate(AF_INET6, &so6, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
683 if (rc)
684 goto out;
685 rc = udp_set_kernel_tunneling(so6, wg_input, NULL, sc);
686 MPASS(rc == 0);
687 #endif
688
689 if (sc->sc_socket.so_user_cookie) {
690 rc = wg_socket_set_sockopt(so4, so6, SO_USER_COOKIE, &sc->sc_socket.so_user_cookie, sizeof(sc->sc_socket.so_user_cookie));
691 if (rc)
692 goto out;
693 }
694 rc = wg_socket_set_sockopt(so4, so6, SO_SETFIB, &sc->sc_socket.so_fibnum, sizeof(sc->sc_socket.so_fibnum));
695 if (rc)
696 goto out;
697
698 rc = wg_socket_bind(&so4, &so6, &port);
699 if (!rc) {
700 sc->sc_socket.so_port = port;
701 wg_socket_set(sc, so4, so6);
702 }
703 out:
704 if (rc) {
705 if (so4 != NULL)
706 soclose(so4);
707 if (so6 != NULL)
708 soclose(so6);
709 }
710 return (rc);
711 }
712
713 static int wg_socket_set_sockopt(struct socket *so4, struct socket *so6, int name, void *val, size_t len)
714 {
715 int ret4 = 0, ret6 = 0;
716 struct sockopt sopt = {
717 .sopt_dir = SOPT_SET,
718 .sopt_level = SOL_SOCKET,
719 .sopt_name = name,
720 .sopt_val = val,
721 .sopt_valsize = len
722 };
723
724 if (so4)
725 ret4 = sosetopt(so4, &sopt);
726 if (so6)
727 ret6 = sosetopt(so6, &sopt);
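/* The GNU ?: extension: return ret4 if it is nonzero, otherwise ret6. */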
728 return (ret4 ?: ret6);
729 }
730
731 static int wg_socket_set_cookie(struct wg_softc *sc, uint32_t user_cookie)
732 {
733 struct wg_socket *so = &sc->sc_socket;
734 int ret;
735
736 sx_assert(&sc->sc_lock, SX_XLOCKED);
737 ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_USER_COOKIE, &user_cookie, sizeof(user_cookie));
738 if (!ret)
739 so->so_user_cookie = user_cookie;
740 return (ret);
741 }
742
743 static int wg_socket_set_fibnum(struct wg_softc *sc, int fibnum)
744 {
745 struct wg_socket *so = &sc->sc_socket;
746 int ret;
747
748 sx_assert(&sc->sc_lock, SX_XLOCKED);
749
750 ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_SETFIB, &fibnum, sizeof(fibnum));
751 if (!ret)
752 so->so_fibnum = fibnum;
753 return (ret);
754 }
755
756 static void
757 wg_socket_uninit(struct wg_softc *sc)
758 {
759 wg_socket_set(sc, NULL, NULL);
760 }
761
762 static void
763 wg_socket_set(struct wg_softc *sc, struct socket *new_so4, struct socket *new_so6)
764 {
765 struct wg_socket *so = &sc->sc_socket;
766 struct socket *so4, *so6;
767
768 sx_assert(&sc->sc_lock, SX_XLOCKED);
769
770 so4 = atomic_load_ptr(&so->so_so4);
771 so6 = atomic_load_ptr(&so->so_so6);
772 atomic_store_ptr(&so->so_so4, new_so4);
773 atomic_store_ptr(&so->so_so6, new_so6);
774
775 if (!so4 && !so6)
776 return;
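/* wg_send() may still be dereferencing the old socket pointers inside an
 * epoch section; wait for those readers to drain before closing them. */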
777 NET_EPOCH_WAIT();
778 if (so4)
779 soclose(so4);
780 if (so6)
781 soclose(so6);
782 }
783
784 static int
785 wg_socket_bind(struct socket **in_so4, struct socket **in_so6, in_port_t *requested_port)
786 {
787 struct socket *so4 = *in_so4, *so6 = *in_so6;
788 int ret4 = 0, ret6 = 0;
789 in_port_t port = *requested_port;
790 struct sockaddr_in sin = {
791 .sin_len = sizeof(struct sockaddr_in),
792 .sin_family = AF_INET,
793 .sin_port = htons(port)
794 };
795 struct sockaddr_in6 sin6 = {
796 .sin6_len = sizeof(struct sockaddr_in6),
797 .sin6_family = AF_INET6,
798 .sin6_port = htons(port)
799 };
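/* If port is 0, the kernel chooses one when binding the IPv4 socket; the
 * chosen port is copied into sin6 below so both sockets share it. */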
800
801 if (so4) {
802 ret4 = sobind(so4, (struct sockaddr *)&sin, curthread);
803 if (ret4 && ret4 != EADDRNOTAVAIL)
804 return (ret4);
805 if (!ret4 && !sin.sin_port) {
806 struct sockaddr_in bound_sin =
807 { .sin_len = sizeof(bound_sin) };
808 int ret;
809
810 ret = sosockaddr(so4, (struct sockaddr *)&bound_sin);
811 if (ret)
812 return (ret);
813 port = ntohs(bound_sin.sin_port);
814 sin6.sin6_port = bound_sin.sin_port;
815 }
816 }
817
818 if (so6) {
819 ret6 = sobind(so6, (struct sockaddr *)&sin6, curthread);
820 if (ret6 && ret6 != EADDRNOTAVAIL)
821 return (ret6);
822 if (!ret6 && !sin6.sin6_port) {
823 struct sockaddr_in6 bound_sin6 =
824 { .sin6_len = sizeof(bound_sin6) };
825 int ret;
826
827 ret = sosockaddr(so6, (struct sockaddr *)&bound_sin6);
828 if (ret)
829 return (ret);
830 port = ntohs(bound_sin6.sin6_port);
831 }
832 }
833
834 if (ret4 && ret6)
835 return (ret4);
836 *requested_port = port;
837 if (ret4 && !ret6 && so4) {
838 soclose(so4);
839 *in_so4 = NULL;
840 } else if (ret6 && !ret4 && so6) {
841 soclose(so6);
842 *in_so6 = NULL;
843 }
844 return (0);
845 }
846
847 static int
848 wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m)
849 {
850 struct epoch_tracker et;
851 struct sockaddr *sa;
852 struct wg_socket *so = &sc->sc_socket;
853 struct socket *so4, *so6;
854 struct mbuf *control = NULL;
855 int ret = 0;
856 size_t len = m->m_pkthdr.len;
857
858 /* Get local control address before locking */
859 if (e->e_remote.r_sa.sa_family == AF_INET) {
860 if (e->e_local.l_in.s_addr != INADDR_ANY)
861 control = sbcreatecontrol((caddr_t)&e->e_local.l_in,
862 sizeof(struct in_addr), IP_SENDSRCADDR,
863 IPPROTO_IP, M_NOWAIT);
864 #ifdef INET6
865 } else if (e->e_remote.r_sa.sa_family == AF_INET6) {
866 if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6))
867 control = sbcreatecontrol((caddr_t)&e->e_local.l_pktinfo6,
868 sizeof(struct in6_pktinfo), IPV6_PKTINFO,
869 IPPROTO_IPV6, M_NOWAIT);
870 #endif
871 } else {
872 m_freem(m);
873 return (EAFNOSUPPORT);
874 }
875
876 /* Get remote address */
877 sa = &e->e_remote.r_sa;
878
879 NET_EPOCH_ENTER(et);
880 so4 = atomic_load_ptr(&so->so_so4);
881 so6 = atomic_load_ptr(&so->so_so6);
882 if (e->e_remote.r_sa.sa_family == AF_INET && so4 != NULL)
883 ret = sosend(so4, sa, NULL, m, control, 0, curthread);
884 else if (e->e_remote.r_sa.sa_family == AF_INET6 && so6 != NULL)
885 ret = sosend(so6, sa, NULL, m, control, 0, curthread);
886 else {
887 ret = ENOTCONN;
888 m_freem(control);
889 m_freem(m);
890 }
891 NET_EPOCH_EXIT(et);
892 if (ret == 0) {
893 if_inc_counter(sc->sc_ifp, IFCOUNTER_OPACKETS, 1);
894 if_inc_counter(sc->sc_ifp, IFCOUNTER_OBYTES, len);
895 }
896 return (ret);
897 }
898
899 static void
900 wg_send_buf(struct wg_softc *sc, struct wg_endpoint *e, uint8_t *buf, size_t len)
901 {
902 struct mbuf *m;
903 int ret = 0;
904 bool retried = false;
905
906 retry:
907 m = m_get2(len, M_NOWAIT, MT_DATA, M_PKTHDR);
908 if (!m) {
909 ret = ENOMEM;
910 goto out;
911 }
912 m_copyback(m, 0, len, buf);
913
914 if (ret == 0) {
915 ret = wg_send(sc, e, m);
916 /* Retry if we couldn't bind to e->e_local */
917 if (ret == EADDRNOTAVAIL && !retried) {
918 bzero(&e->e_local, sizeof(e->e_local));
919 retried = true;
920 goto retry;
921 }
922 } else {
923 ret = wg_send(sc, e, m);
924 }
925 out:
926 if (ret)
927 DPRINTF(sc, "Unable to send packet: %d\n", ret);
928 }
929
930 /* Timers */
931 static void
932 wg_timers_enable(struct wg_peer *peer)
933 {
934 atomic_store_bool(&peer->p_enabled, true);
935 wg_timers_run_persistent_keepalive(peer);
936 }
937
938 static void
939 wg_timers_disable(struct wg_peer *peer)
940 {
941 /* By setting p_enabled = false, then calling NET_EPOCH_WAIT, we can be
942 * sure no new handshakes are created after the wait. This is because
943 * all callout_resets (scheduling the callout) are guarded by
944 * p_enabled. We can be sure all sections that read p_enabled and then
945 * optionally call callout_reset are finished as they are surrounded by
946 * NET_EPOCH_{ENTER,EXIT}.
947 *
948 * However, as new callouts may be scheduled during NET_EPOCH_WAIT (but
949 * not after), we stop all callouts, leaving none active.
950 *
951 * We should also pull NET_EPOCH_WAIT out of the FOREACH(peer) loops, but the
952 * performance impact is acceptable for the time being. */
953 atomic_store_bool(&peer->p_enabled, false);
954 NET_EPOCH_WAIT();
955 atomic_store_bool(&peer->p_need_another_keepalive, false);
956
957 callout_stop(&peer->p_new_handshake);
958 callout_stop(&peer->p_send_keepalive);
959 callout_stop(&peer->p_retry_handshake);
960 callout_stop(&peer->p_persistent_keepalive);
961 callout_stop(&peer->p_zero_key_material);
962 }
963
964 static void
965 wg_timers_set_persistent_keepalive(struct wg_peer *peer, uint16_t interval)
966 {
967 struct epoch_tracker et;
968 if (interval != peer->p_persistent_keepalive_interval) {
969 atomic_store_16(&peer->p_persistent_keepalive_interval, interval);
970 NET_EPOCH_ENTER(et);
971 if (atomic_load_bool(&peer->p_enabled))
972 wg_timers_run_persistent_keepalive(peer);
973 NET_EPOCH_EXIT(et);
974 }
975 }
976
977 static void
978 wg_timers_get_last_handshake(struct wg_peer *peer, struct wg_timespec64 *time)
979 {
980 mtx_lock(&peer->p_handshake_mtx);
981 time->tv_sec = peer->p_handshake_complete.tv_sec;
982 time->tv_nsec = peer->p_handshake_complete.tv_nsec;
983 mtx_unlock(&peer->p_handshake_mtx);
984 }
985
986 static void
987 wg_timers_event_data_sent(struct wg_peer *peer)
988 {
989 struct epoch_tracker et;
990 NET_EPOCH_ENTER(et);
991 if (atomic_load_bool(&peer->p_enabled) &&
992 !callout_pending(&peer->p_new_handshake))
993 callout_reset(&peer->p_new_handshake, MSEC_2_TICKS(
994 NEW_HANDSHAKE_TIMEOUT * 1000 +
995 arc4random_uniform(REKEY_TIMEOUT_JITTER)),
996 wg_timers_run_new_handshake, peer);
997 NET_EPOCH_EXIT(et);
998 }
999
1000 static void
1001 wg_timers_event_data_received(struct wg_peer *peer)
1002 {
1003 struct epoch_tracker et;
1004 NET_EPOCH_ENTER(et);
1005 if (atomic_load_bool(&peer->p_enabled)) {
1006 if (!callout_pending(&peer->p_send_keepalive))
1007 callout_reset(&peer->p_send_keepalive,
1008 MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1009 wg_timers_run_send_keepalive, peer);
1010 else
1011 atomic_store_bool(&peer->p_need_another_keepalive,
1012 true);
1013 }
1014 NET_EPOCH_EXIT(et);
1015 }
1016
1017 static void
1018 wg_timers_event_any_authenticated_packet_sent(struct wg_peer *peer)
1019 {
1020 callout_stop(&peer->p_send_keepalive);
1021 }
1022
1023 static void
1024 wg_timers_event_any_authenticated_packet_received(struct wg_peer *peer)
1025 {
1026 callout_stop(&peer->p_new_handshake);
1027 }
1028
1029 static void
1030 wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *peer)
1031 {
1032 struct epoch_tracker et;
1033 uint16_t interval;
1034 NET_EPOCH_ENTER(et);
1035 interval = atomic_load_16(&peer->p_persistent_keepalive_interval);
1036 if (atomic_load_bool(&peer->p_enabled) && interval > 0)
1037 callout_reset(&peer->p_persistent_keepalive,
1038 MSEC_2_TICKS(interval * 1000),
1039 wg_timers_run_persistent_keepalive, peer);
1040 NET_EPOCH_EXIT(et);
1041 }
1042
1043 static void
1044 wg_timers_event_handshake_initiated(struct wg_peer *peer)
1045 {
1046 struct epoch_tracker et;
1047 NET_EPOCH_ENTER(et);
1048 if (atomic_load_bool(&peer->p_enabled))
1049 callout_reset(&peer->p_retry_handshake, MSEC_2_TICKS(
1050 REKEY_TIMEOUT * 1000 +
1051 arc4random_uniform(REKEY_TIMEOUT_JITTER)),
1052 wg_timers_run_retry_handshake, peer);
1053 NET_EPOCH_EXIT(et);
1054 }
1055
1056 static void
1057 wg_timers_event_handshake_complete(struct wg_peer *peer)
1058 {
1059 struct epoch_tracker et;
1060 NET_EPOCH_ENTER(et);
1061 if (atomic_load_bool(&peer->p_enabled)) {
1062 mtx_lock(&peer->p_handshake_mtx);
1063 callout_stop(&peer->p_retry_handshake);
1064 peer->p_handshake_retries = 0;
1065 getnanotime(&peer->p_handshake_complete);
1066 mtx_unlock(&peer->p_handshake_mtx);
1067 wg_timers_run_send_keepalive(peer);
1068 }
1069 NET_EPOCH_EXIT(et);
1070 }
1071
1072 static void
1073 wg_timers_event_session_derived(struct wg_peer *peer)
1074 {
1075 struct epoch_tracker et;
1076 NET_EPOCH_ENTER(et);
1077 if (atomic_load_bool(&peer->p_enabled))
1078 callout_reset(&peer->p_zero_key_material,
1079 MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1080 wg_timers_run_zero_key_material, peer);
1081 NET_EPOCH_EXIT(et);
1082 }
1083
1084 static void
1085 wg_timers_event_want_initiation(struct wg_peer *peer)
1086 {
1087 struct epoch_tracker et;
1088 NET_EPOCH_ENTER(et);
1089 if (atomic_load_bool(&peer->p_enabled))
1090 wg_timers_run_send_initiation(peer, false);
1091 NET_EPOCH_EXIT(et);
1092 }
1093
1094 static void
1095 wg_timers_run_send_initiation(struct wg_peer *peer, bool is_retry)
1096 {
1097 if (!is_retry)
1098 peer->p_handshake_retries = 0;
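/* Only fire off an initiation if the previous one has expired, i.e.
 * noise_remote_initiation_expired() returns ETIMEDOUT; this limits us to
 * roughly one initiation per REKEY_TIMEOUT. */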
1099 if (noise_remote_initiation_expired(peer->p_remote) == ETIMEDOUT)
1100 wg_send_initiation(peer);
1101 }
1102
1103 static void
1104 wg_timers_run_retry_handshake(void *_peer)
1105 {
1106 struct epoch_tracker et;
1107 struct wg_peer *peer = _peer;
1108
1109 mtx_lock(&peer->p_handshake_mtx);
1110 if (peer->p_handshake_retries <= MAX_TIMER_HANDSHAKES) {
1111 peer->p_handshake_retries++;
1112 mtx_unlock(&peer->p_handshake_mtx);
1113
1114 DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1115 "after %d seconds, retrying (try %d)\n", peer->p_id,
1116 REKEY_TIMEOUT, peer->p_handshake_retries + 1);
1117 wg_peer_clear_src(peer);
1118 wg_timers_run_send_initiation(peer, true);
1119 } else {
1120 mtx_unlock(&peer->p_handshake_mtx);
1121
1122 DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1123 "after %d retries, giving up\n", peer->p_id,
1124 MAX_TIMER_HANDSHAKES + 2);
1125
1126 callout_stop(&peer->p_send_keepalive);
1127 wg_queue_purge(&peer->p_stage_queue);
1128 NET_EPOCH_ENTER(et);
1129 if (atomic_load_bool(&peer->p_enabled) &&
1130 !callout_pending(&peer->p_zero_key_material))
1131 callout_reset(&peer->p_zero_key_material,
1132 MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1133 wg_timers_run_zero_key_material, peer);
1134 NET_EPOCH_EXIT(et);
1135 }
1136 }
1137
1138 static void
1139 wg_timers_run_send_keepalive(void *_peer)
1140 {
1141 struct epoch_tracker et;
1142 struct wg_peer *peer = _peer;
1143
1144 wg_send_keepalive(peer);
1145 NET_EPOCH_ENTER(et);
1146 if (atomic_load_bool(&peer->p_enabled) &&
1147 atomic_load_bool(&peer->p_need_another_keepalive)) {
1148 atomic_store_bool(&peer->p_need_another_keepalive, false);
1149 callout_reset(&peer->p_send_keepalive,
1150 MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1151 wg_timers_run_send_keepalive, peer);
1152 }
1153 NET_EPOCH_EXIT(et);
1154 }
1155
1156 static void
1157 wg_timers_run_new_handshake(void *_peer)
1158 {
1159 struct wg_peer *peer = _peer;
1160
1161 DPRINTF(peer->p_sc, "Retrying handshake with peer %" PRIu64 " because we "
1162 "stopped hearing back after %d seconds\n",
1163 peer->p_id, NEW_HANDSHAKE_TIMEOUT);
1164
1165 wg_peer_clear_src(peer);
1166 wg_timers_run_send_initiation(peer, false);
1167 }
1168
1169 static void
1170 wg_timers_run_zero_key_material(void *_peer)
1171 {
1172 struct wg_peer *peer = _peer;
1173
1174 DPRINTF(peer->p_sc, "Zeroing out keys for peer %" PRIu64 ", since we "
1175 "haven't received a new one in %d seconds\n",
1176 peer->p_id, REJECT_AFTER_TIME * 3);
1177 noise_remote_keypairs_clear(peer->p_remote);
1178 }
1179
1180 static void
1181 wg_timers_run_persistent_keepalive(void *_peer)
1182 {
1183 struct wg_peer *peer = _peer;
1184
1185 if (atomic_load_16(&peer->p_persistent_keepalive_interval) > 0)
1186 wg_send_keepalive(peer);
1187 }
1188
1189 /* TODO Handshake */
1190 static void
1191 wg_peer_send_buf(struct wg_peer *peer, uint8_t *buf, size_t len)
1192 {
1193 struct wg_endpoint endpoint;
1194
1195 counter_u64_add(peer->p_tx_bytes, len);
1196 wg_timers_event_any_authenticated_packet_traversal(peer);
1197 wg_timers_event_any_authenticated_packet_sent(peer);
1198 wg_peer_get_endpoint(peer, &endpoint);
1199 wg_send_buf(peer->p_sc, &endpoint, buf, len);
1200 }
1201
1202 static void
1203 wg_send_initiation(struct wg_peer *peer)
1204 {
1205 struct wg_pkt_initiation pkt;
1206
1207 if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue,
1208 pkt.es, pkt.ets) != 0)
1209 return;
1210
1211 DPRINTF(peer->p_sc, "Sending handshake initiation to peer %" PRIu64 "\n", peer->p_id);
1212
1213 pkt.t = WG_PKT_INITIATION;
1214 cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1215 sizeof(pkt) - sizeof(pkt.m));
1216 wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
1217 wg_timers_event_handshake_initiated(peer);
1218 }
1219
1220 static void
1221 wg_send_response(struct wg_peer *peer)
1222 {
1223 struct wg_pkt_response pkt;
1224
1225 if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx,
1226 pkt.ue, pkt.en) != 0)
1227 return;
1228
1229 DPRINTF(peer->p_sc, "Sending handshake response to peer %" PRIu64 "\n", peer->p_id);
1230
1231 wg_timers_event_session_derived(peer);
1232 pkt.t = WG_PKT_RESPONSE;
1233 cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1234 sizeof(pkt)-sizeof(pkt.m));
1235 wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt));
1236 }
1237
1238 static void
1239 wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx,
1240 struct wg_endpoint *e)
1241 {
1242 struct wg_pkt_cookie pkt;
1243
1244 DPRINTF(sc, "Sending cookie response for denied handshake message\n");
1245
1246 pkt.t = WG_PKT_COOKIE;
1247 pkt.r_idx = idx;
1248
1249 cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce,
1250 pkt.ec, &e->e_remote.r_sa);
1251 wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt));
1252 }
1253
1254 static void
1255 wg_send_keepalive(struct wg_peer *peer)
1256 {
1257 struct wg_packet *pkt;
1258 struct mbuf *m;
1259
1260 if (wg_queue_len(&peer->p_stage_queue) > 0)
1261 goto send;
1262 if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL)
1263 return;
1264 if ((pkt = wg_packet_alloc(m)) == NULL) {
1265 m_freem(m);
1266 return;
1267 }
1268 wg_queue_push_staged(&peer->p_stage_queue, pkt);
1269 DPRINTF(peer->p_sc, "Sending keepalive packet to peer %" PRIu64 "\n", peer->p_id);
1270 send:
1271 wg_peer_send_staged(peer);
1272 }
1273
1274 static void
1275 wg_handshake(struct wg_softc *sc, struct wg_packet *pkt)
1276 {
1277 struct wg_pkt_initiation *init;
1278 struct wg_pkt_response *resp;
1279 struct wg_pkt_cookie *cook;
1280 struct wg_endpoint *e;
1281 struct wg_peer *peer;
1282 struct mbuf *m;
1283 struct noise_remote *remote = NULL;
1284 int res;
1285 bool underload = false;
1286 static sbintime_t wg_last_underload; /* sbinuptime */
1287
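/* "Underload" engages once the handshake queue is an eighth full and
 * lingers for UNDERLOAD_TIMEOUT seconds; while it holds, initiators must
 * answer a cookie challenge (the EAGAIN cases below). */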
1288 underload = wg_queue_len(&sc->sc_handshake_queue) >= MAX_QUEUED_HANDSHAKES / 8;
1289 if (underload) {
1290 wg_last_underload = getsbinuptime();
1291 } else if (wg_last_underload) {
1292 underload = wg_last_underload + UNDERLOAD_TIMEOUT * SBT_1S > getsbinuptime();
1293 if (!underload)
1294 wg_last_underload = 0;
1295 }
1296
1297 m = pkt->p_mbuf;
1298 e = &pkt->p_endpoint;
1299
1300 if ((pkt->p_mbuf = m = m_pullup(m, m->m_pkthdr.len)) == NULL)
1301 goto error;
1302
1303 switch (*mtod(m, uint32_t *)) {
1304 case WG_PKT_INITIATION:
1305 init = mtod(m, struct wg_pkt_initiation *);
1306
1307 res = cookie_checker_validate_macs(&sc->sc_cookie, &init->m,
1308 init, sizeof(*init) - sizeof(init->m),
1309 underload, &e->e_remote.r_sa,
1310 if_getvnet(sc->sc_ifp));
1311
1312 if (res == EINVAL) {
1313 DPRINTF(sc, "Invalid initiation MAC\n");
1314 goto error;
1315 } else if (res == ECONNREFUSED) {
1316 DPRINTF(sc, "Handshake ratelimited\n");
1317 goto error;
1318 } else if (res == EAGAIN) {
1319 wg_send_cookie(sc, &init->m, init->s_idx, e);
1320 goto error;
1321 } else if (res != 0) {
1322 panic("unexpected response: %d\n", res);
1323 }
1324
1325 if (noise_consume_initiation(sc->sc_local, &remote,
1326 init->s_idx, init->ue, init->es, init->ets) != 0) {
1327 DPRINTF(sc, "Invalid handshake initiation\n");
1328 goto error;
1329 }
1330
1331 peer = noise_remote_arg(remote);
1332
1333 DPRINTF(sc, "Receiving handshake initiation from peer %" PRIu64 "\n", peer->p_id);
1334
1335 wg_peer_set_endpoint(peer, e);
1336 wg_send_response(peer);
1337 break;
1338 case WG_PKT_RESPONSE:
1339 resp = mtod(m, struct wg_pkt_response *);
1340
1341 res = cookie_checker_validate_macs(&sc->sc_cookie, &resp->m,
1342 resp, sizeof(*resp) - sizeof(resp->m),
1343 underload, &e->e_remote.r_sa,
1344 if_getvnet(sc->sc_ifp));
1345
1346 if (res == EINVAL) {
1347 DPRINTF(sc, "Invalid response MAC\n");
1348 goto error;
1349 } else if (res == ECONNREFUSED) {
1350 DPRINTF(sc, "Handshake ratelimited\n");
1351 goto error;
1352 } else if (res == EAGAIN) {
1353 wg_send_cookie(sc, &resp->m, resp->s_idx, e);
1354 goto error;
1355 } else if (res != 0) {
1356 panic("unexpected response: %d\n", res);
1357 }
1358
1359 if (noise_consume_response(sc->sc_local, &remote,
1360 resp->s_idx, resp->r_idx, resp->ue, resp->en) != 0) {
1361 DPRINTF(sc, "Invalid handshake response\n");
1362 goto error;
1363 }
1364
1365 peer = noise_remote_arg(remote);
1366 DPRINTF(sc, "Receiving handshake response from peer %" PRIu64 "\n", peer->p_id);
1367
1368 wg_peer_set_endpoint(peer, e);
1369 wg_timers_event_session_derived(peer);
1370 wg_timers_event_handshake_complete(peer);
1371 break;
1372 case WG_PKT_COOKIE:
1373 cook = mtod(m, struct wg_pkt_cookie *);
1374
1375 if ((remote = noise_remote_index(sc->sc_local, cook->r_idx)) == NULL) {
1376 DPRINTF(sc, "Unknown cookie index\n");
1377 goto error;
1378 }
1379
1380 peer = noise_remote_arg(remote);
1381
1382 if (cookie_maker_consume_payload(&peer->p_cookie,
1383 cook->nonce, cook->ec) == 0) {
1384 DPRINTF(sc, "Receiving cookie response\n");
1385 } else {
1386 DPRINTF(sc, "Could not decrypt cookie response\n");
1387 goto error;
1388 }
1389
1390 goto not_authenticated;
1391 default:
1392 panic("invalid packet in handshake queue");
1393 }
1394
1395 wg_timers_event_any_authenticated_packet_received(peer);
1396 wg_timers_event_any_authenticated_packet_traversal(peer);
1397
1398 not_authenticated:
1399 counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len);
1400 if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
1401 if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len);
1402 error:
1403 if (remote != NULL)
1404 noise_remote_put(remote);
1405 wg_packet_free(pkt);
1406 }
1407
1408 static void
1409 wg_softc_handshake_receive(struct wg_softc *sc)
1410 {
1411 struct wg_packet *pkt;
1412 while ((pkt = wg_queue_dequeue_handshake(&sc->sc_handshake_queue)) != NULL)
1413 wg_handshake(sc, pkt);
1414 }
1415
1416 static void
1417 wg_mbuf_reset(struct mbuf *m)
1418 {
1419
1420 struct m_tag *t, *tmp;
1421
1422 /*
1423 * We want to reset the mbuf to a newly allocated state, containing
1424 * just the packet contents. Unfortunately FreeBSD doesn't seem to
1425 * offer this anywhere, so we have to make it up as we go. If we can
1426 * get this in kern/kern_mbuf.c, that would be best.
1427 *
1428 * Notice: this may break things unexpectedly but it is better to fail
1429 * closed in the extreme case than leak information in every
1430 * case.
1431 *
1432 * With that said, all this attempts to do is remove any extraneous
1433 * information that could be present.
1434 */
1435
1436 M_ASSERTPKTHDR(m);
1437
1438 m->m_flags &= ~(M_BCAST|M_MCAST|M_VLANTAG|M_PROMISC|M_PROTOFLAGS);
1439
1440 M_HASHTYPE_CLEAR(m);
1441 #ifdef NUMA
1442 m->m_pkthdr.numa_domain = M_NODOM;
1443 #endif
1444 SLIST_FOREACH_SAFE(t, &m->m_pkthdr.tags, m_tag_link, tmp) {
1445 if ((t->m_tag_id != 0 || t->m_tag_cookie != MTAG_WGLOOP) &&
1446 t->m_tag_id != PACKET_TAG_MACLABEL)
1447 m_tag_delete(m, t);
1448 }
1449
1450 KASSERT((m->m_pkthdr.csum_flags & CSUM_SND_TAG) == 0,
1451 ("%s: mbuf %p has a send tag", __func__, m));
1452
1453 m->m_pkthdr.csum_flags = 0;
1454 m->m_pkthdr.PH_per.sixtyfour[0] = 0;
1455 m->m_pkthdr.PH_loc.sixtyfour[0] = 0;
1456 }
1457
1458 static inline unsigned int
1459 calculate_padding(struct wg_packet *pkt)
1460 {
1461 unsigned int padded_size, last_unit = pkt->p_mbuf->m_pkthdr.len;
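/* Pad the plaintext up to a multiple of WG_PKT_PADDING (16) bytes without
 * exceeding the MTU; e.g. a 33-byte payload gains 15 bytes of padding. */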
1462
1463 /* Keepalive packets don't set p_mtu, but also have a length of zero. */
1464 if (__predict_false(pkt->p_mtu == 0)) {
1465 padded_size = (last_unit + (WG_PKT_PADDING - 1)) &
1466 ~(WG_PKT_PADDING - 1);
1467 return (padded_size - last_unit);
1468 }
1469
1470 if (__predict_false(last_unit > pkt->p_mtu))
1471 last_unit %= pkt->p_mtu;
1472
1473 padded_size = (last_unit + (WG_PKT_PADDING - 1)) & ~(WG_PKT_PADDING - 1);
1474 if (pkt->p_mtu < padded_size)
1475 padded_size = pkt->p_mtu;
1476 return (padded_size - last_unit);
1477 }
1478
1479 static void
1480 wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt)
1481 {
1482 static const uint8_t padding[WG_PKT_PADDING] = { 0 };
1483 struct wg_pkt_data *data;
1484 struct wg_peer *peer;
1485 struct noise_remote *remote;
1486 struct mbuf *m;
1487 uint32_t idx;
1488 unsigned int padlen;
1489 enum wg_ring_state state = WG_PACKET_DEAD;
1490
1491 remote = noise_keypair_remote(pkt->p_keypair);
1492 peer = noise_remote_arg(remote);
1493 m = pkt->p_mbuf;
1494
1495 /* Pad the packet */
1496 padlen = calculate_padding(pkt);
1497 if (padlen != 0 && !m_append(m, padlen, padding))
1498 goto out;
1499
1500 /* Do encryption */
1501 if (noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, m) != 0)
1502 goto out;
1503
1504 /* Put header into packet */
1505 M_PREPEND(m, sizeof(struct wg_pkt_data), M_NOWAIT);
1506 if (m == NULL)
1507 goto out;
1508 data = mtod(m, struct wg_pkt_data *);
1509 data->t = WG_PKT_DATA;
1510 data->r_idx = idx;
1511 data->nonce = htole64(pkt->p_nonce);
1512
1513 wg_mbuf_reset(m);
1514 state = WG_PACKET_CRYPTED;
1515 out:
1516 pkt->p_mbuf = m;
1517 atomic_store_rel_int(&pkt->p_state, state);
1518 GROUPTASK_ENQUEUE(&peer->p_send);
1519 noise_remote_put(remote);
1520 }
1521
1522 static void
1523 wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt)
1524 {
1525 struct wg_peer *peer, *allowed_peer;
1526 struct noise_remote *remote;
1527 struct mbuf *m;
1528 int len;
1529 enum wg_ring_state state = WG_PACKET_DEAD;
1530
1531 remote = noise_keypair_remote(pkt->p_keypair);
1532 peer = noise_remote_arg(remote);
1533 m = pkt->p_mbuf;
1534
1535 /* Read nonce and then adjust to remove the header. */
1536 pkt->p_nonce = le64toh(mtod(m, struct wg_pkt_data *)->nonce);
1537 m_adj(m, sizeof(struct wg_pkt_data));
1538
1539 if (noise_keypair_decrypt(pkt->p_keypair, pkt->p_nonce, m) != 0)
1540 goto out;
1541
1542 /* A packet with length 0 is a keepalive packet */
1543 if (__predict_false(m->m_pkthdr.len == 0)) {
1544 DPRINTF(sc, "Receiving keepalive packet from peer "
1545 "%" PRIu64 "\n", peer->p_id);
1546 state = WG_PACKET_CRYPTED;
1547 goto out;
1548 }
1549
1550 /*
1551 * We can let the network stack handle the intricate validation of the
1552 * IP header; we just check the size and the version, so we can
1553 * read the source address in wg_aip_lookup.
1554 */
1555
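/* The inner IP header's length field gives the true datagram size; trim
 * off any encryption padding beyond it. */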
1556 if (determine_af_and_pullup(&m, &pkt->p_af) == 0) {
1557 if (pkt->p_af == AF_INET) {
1558 struct ip *ip = mtod(m, struct ip *);
1559 allowed_peer = wg_aip_lookup(sc, AF_INET, &ip->ip_src);
1560 len = ntohs(ip->ip_len);
1561 if (len >= sizeof(struct ip) && len < m->m_pkthdr.len)
1562 m_adj(m, len - m->m_pkthdr.len);
1563 } else if (pkt->p_af == AF_INET6) {
1564 struct ip6_hdr *ip6 = mtod(m, struct ip6_hdr *);
1565 allowed_peer = wg_aip_lookup(sc, AF_INET6, &ip6->ip6_src);
1566 len = ntohs(ip6->ip6_plen) + sizeof(struct ip6_hdr);
1567 if (len < m->m_pkthdr.len)
1568 m_adj(m, len - m->m_pkthdr.len);
1569 } else
1570 panic("determine_af_and_pullup returned unexpected value");
1571 } else {
1572 DPRINTF(sc, "Packet is neither ipv4 nor ipv6 from peer %" PRIu64 "\n", peer->p_id);
1573 goto out;
1574 }
1575
1576 /* We only want to compare the address, not dereference, so drop the ref. */
1577 if (allowed_peer != NULL)
1578 noise_remote_put(allowed_peer->p_remote);
1579
1580 if (__predict_false(peer != allowed_peer)) {
1581 DPRINTF(sc, "Packet has unallowed src IP from peer %" PRIu64 "\n", peer->p_id);
1582 goto out;
1583 }
1584
1585 wg_mbuf_reset(m);
1586 state = WG_PACKET_CRYPTED;
1587 out:
1588 pkt->p_mbuf = m;
1589 atomic_store_rel_int(&pkt->p_state, state);
1590 GROUPTASK_ENQUEUE(&peer->p_recv);
1591 noise_remote_put(remote);
1592 }
1593
1594 static void
1595 wg_softc_decrypt(struct wg_softc *sc)
1596 {
1597 struct wg_packet *pkt;
1598
1599 while ((pkt = wg_queue_dequeue_parallel(&sc->sc_decrypt_parallel)) != NULL)
1600 wg_decrypt(sc, pkt);
1601 }
1602
1603 static void
1604 wg_softc_encrypt(struct wg_softc *sc)
1605 {
1606 struct wg_packet *pkt;
1607
1608 while ((pkt = wg_queue_dequeue_parallel(&sc->sc_encrypt_parallel)) != NULL)
1609 wg_encrypt(sc, pkt);
1610 }
1611
1612 static void
1613 wg_encrypt_dispatch(struct wg_softc *sc)
1614 {
1615 /*
1616 * The update to encrypt_last_cpu is racy, so we may
1617 * reschedule the task for the same CPU multiple times, but
1618 * the race doesn't really matter.
1619 */
1620 u_int cpu = (sc->sc_encrypt_last_cpu + 1) % mp_ncpus;
1621 sc->sc_encrypt_last_cpu = cpu;
1622 GROUPTASK_ENQUEUE(&sc->sc_encrypt[cpu]);
1623 }
1624
1625 static void
1626 wg_decrypt_dispatch(struct wg_softc *sc)
1627 {
1628 u_int cpu = (sc->sc_decrypt_last_cpu + 1) % mp_ncpus;
1629 sc->sc_decrypt_last_cpu = cpu;
1630 GROUPTASK_ENQUEUE(&sc->sc_decrypt[cpu]);
1631 }
1632
1633 static void
1634 wg_deliver_out(struct wg_peer *peer)
1635 {
1636 struct wg_endpoint endpoint;
1637 struct wg_softc *sc = peer->p_sc;
1638 struct wg_packet *pkt;
1639 struct mbuf *m;
1640 int rc, len;
1641
1642 wg_peer_get_endpoint(peer, &endpoint);
1643
1644 while ((pkt = wg_queue_dequeue_serial(&peer->p_encrypt_serial)) != NULL) {
1645 if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1646 goto error;
1647
1648 m = pkt->p_mbuf;
1649 pkt->p_mbuf = NULL;
1650
1651 len = m->m_pkthdr.len;
1652
1653 wg_timers_event_any_authenticated_packet_traversal(peer);
1654 wg_timers_event_any_authenticated_packet_sent(peer);
1655 rc = wg_send(sc, &endpoint, m);
1656 if (rc == 0) {
1657 if (len > (sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN))
1658 wg_timers_event_data_sent(peer);
1659 counter_u64_add(peer->p_tx_bytes, len);
1660 } else if (rc == EADDRNOTAVAIL) {
1661 wg_peer_clear_src(peer);
1662 wg_peer_get_endpoint(peer, &endpoint);
1663 goto error;
1664 } else {
1665 goto error;
1666 }
1667 wg_packet_free(pkt);
1668 if (noise_keep_key_fresh_send(peer->p_remote))
1669 wg_timers_event_want_initiation(peer);
1670 continue;
1671 error:
1672 if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
1673 wg_packet_free(pkt);
1674 }
1675 }
1676
1677 #ifdef DEV_NETMAP
1678 /*
1679 * Hand a packet to the netmap RX ring, via netmap's
1680 * freebsd_generic_rx_handler().
1681 */
1682 static void
1683 wg_deliver_netmap(if_t ifp, struct mbuf *m, int af)
1684 {
1685 struct ether_header *eh;
1686
1687 M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
1688 if (__predict_false(m == NULL)) {
1689 if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
1690 return;
1691 }
1692
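/* Fabricate an Ethernet header, as netmap expects link-layer framing: a
 * locally administered source MAC and the broadcast destination. */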
1693 eh = mtod(m, struct ether_header *);
1694 eh->ether_type = af == AF_INET ?
1695 htons(ETHERTYPE_IP) : htons(ETHERTYPE_IPV6);
1696 memcpy(eh->ether_shost, "\x02\x02\x02\x02\x02\x02", ETHER_ADDR_LEN);
1697 memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
1698 if_input(ifp, m);
1699 }
1700 #endif
1701
1702 static void
1703 wg_deliver_in(struct wg_peer *peer)
1704 {
1705 struct wg_softc *sc = peer->p_sc;
1706 if_t ifp = sc->sc_ifp;
1707 struct wg_packet *pkt;
1708 struct mbuf *m;
1709 struct epoch_tracker et;
1710 int af;
1711
1712 while ((pkt = wg_queue_dequeue_serial(&peer->p_decrypt_serial)) != NULL) {
1713 if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1714 goto error;
1715
1716 m = pkt->p_mbuf;
1717 if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0)
1718 goto error;
1719
1720 if (noise_keypair_received_with(pkt->p_keypair) == ECONNRESET)
1721 wg_timers_event_handshake_complete(peer);
1722
1723 wg_timers_event_any_authenticated_packet_received(peer);
1724 wg_timers_event_any_authenticated_packet_traversal(peer);
1725 wg_peer_set_endpoint(peer, &pkt->p_endpoint);
1726
1727 counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len +
1728 sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1729 if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
1730 if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len +
1731 sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1732
1733 if (m->m_pkthdr.len == 0)
1734 goto done;
1735
1736 af = pkt->p_af;
1737 MPASS(af == AF_INET || af == AF_INET6);
1738 pkt->p_mbuf = NULL;
1739
1740 m->m_pkthdr.rcvif = ifp;
1741
1742 NET_EPOCH_ENTER(et);
1743 BPF_MTAP2_AF(ifp, m, af);
1744
1745 CURVNET_SET(if_getvnet(ifp));
1746 M_SETFIB(m, if_getfib(ifp));
1747 #ifdef DEV_NETMAP
1748 if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
1749 wg_deliver_netmap(ifp, m, af);
1750 else
1751 #endif
1752 if (af == AF_INET)
1753 netisr_dispatch(NETISR_IP, m);
1754 else if (af == AF_INET6)
1755 netisr_dispatch(NETISR_IPV6, m);
1756 CURVNET_RESTORE();
1757 NET_EPOCH_EXIT(et);
1758
1759 wg_timers_event_data_received(peer);
1760
1761 done:
1762 if (noise_keep_key_fresh_recv(peer->p_remote))
1763 wg_timers_event_want_initiation(peer);
1764 wg_packet_free(pkt);
1765 continue;
1766 error:
1767 if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
1768 wg_packet_free(pkt);
1769 }
1770 }
1771
1772 static struct wg_packet *
1773 wg_packet_alloc(struct mbuf *m)
1774 {
1775 struct wg_packet *pkt;
1776
1777 if ((pkt = uma_zalloc(wg_packet_zone, M_NOWAIT | M_ZERO)) == NULL)
1778 return (NULL);
1779 pkt->p_mbuf = m;
1780 return (pkt);
1781 }
1782
1783 static void
1784 wg_packet_free(struct wg_packet *pkt)
1785 {
1786 if (pkt->p_keypair != NULL)
1787 noise_keypair_put(pkt->p_keypair);
1788 if (pkt->p_mbuf != NULL)
1789 m_freem(pkt->p_mbuf);
1790 uma_zfree(wg_packet_zone, pkt);
1791 }
1792
1793 static void
1794 wg_queue_init(struct wg_queue *queue, const char *name)
1795 {
1796 mtx_init(&queue->q_mtx, name, NULL, MTX_DEF);
1797 STAILQ_INIT(&queue->q_queue);
1798 queue->q_len = 0;
1799 }
1800
1801 static void
1802 wg_queue_deinit(struct wg_queue *queue)
1803 {
1804 wg_queue_purge(queue);
1805 mtx_destroy(&queue->q_mtx);
1806 }
1807
1808 static size_t
1809 wg_queue_len(struct wg_queue *queue)
1810 {
1811 return (queue->q_len);
1812 }
1813
1814 static int
1815 wg_queue_enqueue_handshake(struct wg_queue *hs, struct wg_packet *pkt)
1816 {
1817 int ret = 0;
1818 mtx_lock(&hs->q_mtx);
1819 if (hs->q_len < MAX_QUEUED_HANDSHAKES) {
1820 STAILQ_INSERT_TAIL(&hs->q_queue, pkt, p_parallel);
1821 hs->q_len++;
1822 } else {
1823 ret = ENOBUFS;
1824 }
1825 mtx_unlock(&hs->q_mtx);
1826 if (ret != 0)
1827 wg_packet_free(pkt);
1828 return (ret);
1829 }
1830
1831 static struct wg_packet *
1832 wg_queue_dequeue_handshake(struct wg_queue *hs)
1833 {
1834 struct wg_packet *pkt;
1835 mtx_lock(&hs->q_mtx);
1836 if ((pkt = STAILQ_FIRST(&hs->q_queue)) != NULL) {
1837 STAILQ_REMOVE_HEAD(&hs->q_queue, p_parallel);
1838 hs->q_len--;
1839 }
1840 mtx_unlock(&hs->q_mtx);
1841 return (pkt);
1842 }
1843
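/*
 * Stage a packet for a peer that may not yet have a usable keypair.
 * The staged queue behaves like a ring buffer: once MAX_STAGED_PKT is
 * reached, the oldest staged packet is dropped to make room.
 */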
1844 static void
1845 wg_queue_push_staged(struct wg_queue *staged, struct wg_packet *pkt)
1846 {
1847 struct wg_packet *old = NULL;
1848
1849 mtx_lock(&staged->q_mtx);
1850 if (staged->q_len >= MAX_STAGED_PKT) {
1851 old = STAILQ_FIRST(&staged->q_queue);
1852 STAILQ_REMOVE_HEAD(&staged->q_queue, p_parallel);
1853 staged->q_len--;
1854 }
1855 STAILQ_INSERT_TAIL(&staged->q_queue, pkt, p_parallel);
1856 staged->q_len++;
1857 mtx_unlock(&staged->q_mtx);
1858
1859 if (old != NULL)
1860 wg_packet_free(old);
1861 }
1862
1863 static void
1864 wg_queue_enlist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1865 {
1866 struct wg_packet *pkt, *tpkt;
1867 STAILQ_FOREACH_SAFE(pkt, list, p_parallel, tpkt)
1868 wg_queue_push_staged(staged, pkt);
1869 }
1870
1871 static void
1872 wg_queue_delist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1873 {
1874 STAILQ_INIT(list);
1875 mtx_lock(&staged->q_mtx);
1876 STAILQ_CONCAT(list, &staged->q_queue);
1877 staged->q_len = 0;
1878 mtx_unlock(&staged->q_mtx);
1879 }
1880
1881 static void
1882 wg_queue_purge(struct wg_queue *staged)
1883 {
1884 struct wg_packet_list list;
1885 struct wg_packet *pkt, *tpkt;
1886 wg_queue_delist_staged(staged, &list);
1887 STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt)
1888 wg_packet_free(pkt);
1889 }
1890
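/*
 * Enqueue a packet on both the per-interface parallel queue (consumed
 * out of order by the per-CPU crypto tasks) and the per-peer serial
 * queue (consumed in FIFO order at delivery time). This is how
 * per-peer ordering is preserved while crypto runs in parallel. If
 * the parallel queue is full, the packet is marked DEAD and reaped
 * when it surfaces at the head of the serial queue.
 */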
1891 static int
1892 wg_queue_both(struct wg_queue *parallel, struct wg_queue *serial, struct wg_packet *pkt)
1893 {
1894 pkt->p_state = WG_PACKET_UNCRYPTED;
1895
1896 mtx_lock(&serial->q_mtx);
1897 if (serial->q_len < MAX_QUEUED_PKT) {
1898 serial->q_len++;
1899 STAILQ_INSERT_TAIL(&serial->q_queue, pkt, p_serial);
1900 } else {
1901 mtx_unlock(&serial->q_mtx);
1902 wg_packet_free(pkt);
1903 return (ENOBUFS);
1904 }
1905 mtx_unlock(&serial->q_mtx);
1906
1907 mtx_lock(&parallel->q_mtx);
1908 if (parallel->q_len < MAX_QUEUED_PKT) {
1909 parallel->q_len++;
1910 STAILQ_INSERT_TAIL(&parallel->q_queue, pkt, p_parallel);
1911 } else {
1912 mtx_unlock(&parallel->q_mtx);
1913 pkt->p_state = WG_PACKET_DEAD;
1914 return (ENOBUFS);
1915 }
1916 mtx_unlock(&parallel->q_mtx);
1917
1918 return (0);
1919 }
1920
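/*
 * Pop the head of a serial queue, but only once crypto is done with
 * it (its state is no longer UNCRYPTED). Combined with wg_queue_both()
 * above, this keeps delivery strictly in submission order even though
 * encryption and decryption complete out of order across CPUs.
 */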
1921 static struct wg_packet *
1922 wg_queue_dequeue_serial(struct wg_queue *serial)
1923 {
1924 struct wg_packet *pkt = NULL;
1925 mtx_lock(&serial->q_mtx);
1926 if (serial->q_len > 0 && STAILQ_FIRST(&serial->q_queue)->p_state != WG_PACKET_UNCRYPTED) {
1927 serial->q_len--;
1928 pkt = STAILQ_FIRST(&serial->q_queue);
1929 STAILQ_REMOVE_HEAD(&serial->q_queue, p_serial);
1930 }
1931 mtx_unlock(&serial->q_mtx);
1932 return (pkt);
1933 }
1934
1935 static struct wg_packet *
1936 wg_queue_dequeue_parallel(struct wg_queue *parallel)
1937 {
1938 struct wg_packet *pkt = NULL;
1939 mtx_lock(&parallel->q_mtx);
1940 if (parallel->q_len > 0) {
1941 parallel->q_len--;
1942 pkt = STAILQ_FIRST(&parallel->q_queue);
1943 STAILQ_REMOVE_HEAD(&parallel->q_queue, p_parallel);
1944 }
1945 mtx_unlock(&parallel->q_mtx);
1946 return (pkt);
1947 }
1948
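/*
 * UDP tunnel input path. The first 32 bits of the payload select the
 * message type: handshake messages (initiation, response, cookie) are
 * queued for the single handshake task, while data messages are
 * matched to a keypair by receiver index and fanned out to the
 * per-CPU decryption queues.
 */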
1949 static bool
1950 wg_input(struct mbuf *m, int offset, struct inpcb *inpcb,
1951 const struct sockaddr *sa, void *_sc)
1952 {
1953 #ifdef INET
1954 const struct sockaddr_in *sin;
1955 #endif
1956 #ifdef INET6
1957 const struct sockaddr_in6 *sin6;
1958 #endif
1959 struct noise_remote *remote;
1960 struct wg_pkt_data *data;
1961 struct wg_packet *pkt;
1962 struct wg_peer *peer;
1963 struct wg_softc *sc = _sc;
1964 struct mbuf *defragged;
1965
1966 defragged = m_defrag(m, M_NOWAIT);
1967 if (defragged)
1968 m = defragged;
1969 m = m_unshare(m, M_NOWAIT);
1970 if (!m) {
1971 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
1972 return true;
1973 }
1974
1975 /* Caller provided us with `sa`, no need for this header. */
1976 m_adj(m, offset + sizeof(struct udphdr));
1977
1978 /* Pullup enough to read packet type */
1979 if ((m = m_pullup(m, sizeof(uint32_t))) == NULL) {
1980 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
1981 return true;
1982 }
1983
1984 if ((pkt = wg_packet_alloc(m)) == NULL) {
1985 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
1986 m_freem(m);
1987 return true;
1988 }
1989
1990 /* Save the remote and local addresses and ports for later; sa points at a {remote, local} sockaddr pair. */
1991 switch (sa->sa_family) {
1992 #ifdef INET
1993 case AF_INET:
1994 sin = (const struct sockaddr_in *)sa;
1995 pkt->p_endpoint.e_remote.r_sin = sin[0];
1996 pkt->p_endpoint.e_local.l_in = sin[1].sin_addr;
1997 break;
1998 #endif
1999 #ifdef INET6
2000 case AF_INET6:
2001 sin6 = (const struct sockaddr_in6 *)sa;
2002 pkt->p_endpoint.e_remote.r_sin6 = sin6[0];
2003 pkt->p_endpoint.e_local.l_in6 = sin6[1].sin6_addr;
2004 break;
2005 #endif
2006 default:
2007 goto error;
2008 }
2009
2010 if ((m->m_pkthdr.len == sizeof(struct wg_pkt_initiation) &&
2011 *mtod(m, uint32_t *) == WG_PKT_INITIATION) ||
2012 (m->m_pkthdr.len == sizeof(struct wg_pkt_response) &&
2013 *mtod(m, uint32_t *) == WG_PKT_RESPONSE) ||
2014 (m->m_pkthdr.len == sizeof(struct wg_pkt_cookie) &&
2015 *mtod(m, uint32_t *) == WG_PKT_COOKIE)) {
2016
2017 if (wg_queue_enqueue_handshake(&sc->sc_handshake_queue, pkt) != 0) {
2018 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2019 DPRINTF(sc, "Dropping handshake packet\n");
2020 }
2021 GROUPTASK_ENQUEUE(&sc->sc_handshake);
2022 } else if (m->m_pkthdr.len >= sizeof(struct wg_pkt_data) +
2023 NOISE_AUTHTAG_LEN && *mtod(m, uint32_t *) == WG_PKT_DATA) {
2024
2025 /* Pullup whole header to read r_idx below. */
2026 if ((pkt->p_mbuf = m_pullup(m, sizeof(struct wg_pkt_data))) == NULL)
2027 goto error;
2028
2029 data = mtod(pkt->p_mbuf, struct wg_pkt_data *);
2030 if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL)
2031 goto error;
2032
2033 remote = noise_keypair_remote(pkt->p_keypair);
2034 peer = noise_remote_arg(remote);
2035 if (wg_queue_both(&sc->sc_decrypt_parallel, &peer->p_decrypt_serial, pkt) != 0)
2036 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2037 wg_decrypt_dispatch(sc);
2038 noise_remote_put(remote);
2039 } else {
2040 goto error;
2041 }
2042 return true;
2043 error:
2044 if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1);
2045 wg_packet_free(pkt);
2046 return true;
2047 }
2048
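/*
 * Move the peer's staged packets onto the crypto queues. Nonces are
 * reserved for the whole batch up front; if one cannot be reserved
 * (the keypair has reached its sending limit), everything is
 * re-staged and a fresh handshake is requested instead.
 */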
2049 static void
2050 wg_peer_send_staged(struct wg_peer *peer)
2051 {
2052 struct wg_packet_list list;
2053 struct noise_keypair *keypair;
2054 struct wg_packet *pkt, *tpkt;
2055 struct wg_softc *sc = peer->p_sc;
2056
2057 wg_queue_delist_staged(&peer->p_stage_queue, &list);
2058
2059 if (STAILQ_EMPTY(&list))
2060 return;
2061
2062 if ((keypair = noise_keypair_current(peer->p_remote)) == NULL)
2063 goto error;
2064
2065 STAILQ_FOREACH(pkt, &list, p_parallel) {
2066 if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0)
2067 goto error_keypair;
2068 }
2069 STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt) {
2070 pkt->p_keypair = noise_keypair_ref(keypair);
2071 if (wg_queue_both(&sc->sc_encrypt_parallel, &peer->p_encrypt_serial, pkt) != 0)
2072 if_inc_counter(sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
2073 }
2074 wg_encrypt_dispatch(sc);
2075 noise_keypair_put(keypair);
2076 return;
2077
2078 error_keypair:
2079 noise_keypair_put(keypair);
2080 error:
2081 wg_queue_enlist_staged(&peer->p_stage_queue, &list);
2082 wg_timers_event_want_initiation(peer);
2083 }
2084
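/*
 * Common transmit-error path: count the error and, for IP packets,
 * bounce an ICMP/ICMP6 host-unreachable back at the sender. Note that
 * icmp_error()/icmp6_error() consume the mbuf, so ownership is
 * dropped before the packet is freed.
 */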
2085 static inline void
2086 xmit_err(if_t ifp, struct mbuf *m, struct wg_packet *pkt, sa_family_t af)
2087 {
2088 if_inc_counter(ifp, IFCOUNTER_OERRORS, 1);
2089 switch (af) {
2090 #ifdef INET
2091 case AF_INET:
2092 icmp_error(m, ICMP_UNREACH, ICMP_UNREACH_HOST, 0, 0);
2093 if (pkt)
2094 pkt->p_mbuf = NULL;
2095 m = NULL;
2096 break;
2097 #endif
2098 #ifdef INET6
2099 case AF_INET6:
2100 icmp6_error(m, ICMP6_DST_UNREACH, 0, 0);
2101 if (pkt)
2102 pkt->p_mbuf = NULL;
2103 m = NULL;
2104 break;
2105 #endif
2106 }
2107 if (pkt)
2108 wg_packet_free(pkt);
2109 else if (m)
2110 m_freem(m);
2111 }
2112
2113 static int
2114 wg_xmit(if_t ifp, struct mbuf *m, sa_family_t af, uint32_t mtu)
2115 {
2116 struct wg_packet *pkt = NULL;
2117 struct wg_softc *sc = if_getsoftc(ifp);
2118 struct wg_peer *peer;
2119 int rc = 0;
2120 sa_family_t peer_af;
2121
2122 /* Work around lifetime issue in the ipv6 mld code. */
2123 if (__predict_false((if_getflags(ifp) & IFF_DYING) || !sc)) {
2124 rc = ENXIO;
2125 goto err_xmit;
2126 }
2127
2128 if ((pkt = wg_packet_alloc(m)) == NULL) {
2129 rc = ENOBUFS;
2130 goto err_xmit;
2131 }
2132 pkt->p_mtu = mtu;
2133 pkt->p_af = af;
2134
2135 if (af == AF_INET) {
2136 peer = wg_aip_lookup(sc, AF_INET, &mtod(m, struct ip *)->ip_dst);
2137 } else if (af == AF_INET6) {
2138 peer = wg_aip_lookup(sc, AF_INET6, &mtod(m, struct ip6_hdr *)->ip6_dst);
2139 } else {
2140 rc = EAFNOSUPPORT;
2141 goto err_xmit;
2142 }
2143
2144 BPF_MTAP2_AF(ifp, m, pkt->p_af);
2145
2146 if (__predict_false(peer == NULL)) {
2147 rc = ENETUNREACH;
2148 goto err_xmit;
2149 }
2150
2151 if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP, MAX_LOOPS))) {
2152 DPRINTF(sc, "Packet looped");
2153 rc = ELOOP;
2154 goto err_peer;
2155 }
2156
2157 peer_af = peer->p_endpoint.e_remote.r_sa.sa_family;
2158 if (__predict_false(peer_af != AF_INET && peer_af != AF_INET6)) {
2159 DPRINTF(sc, "No valid endpoint has been configured or "
2160 "discovered for peer %" PRIu64 "\n", peer->p_id);
2161 rc = EHOSTUNREACH;
2162 goto err_peer;
2163 }
2164
2165 wg_queue_push_staged(&peer->p_stage_queue, pkt);
2166 wg_peer_send_staged(peer);
2167 noise_remote_put(peer->p_remote);
2168 return (0);
2169
2170 err_peer:
2171 noise_remote_put(peer->p_remote);
2172 err_xmit:
2173 xmit_err(ifp, m, pkt, af);
2174 return (rc);
2175 }
2176
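/*
 * Decide between AF_INET and AF_INET6 by reading the IP version
 * nibble from the packet's first byte, pulling up as much of the
 * header as the packet length allows so the read is safe.
 */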
2177 static inline int
2178 determine_af_and_pullup(struct mbuf **m, sa_family_t *af)
2179 {
2180 u_char ipv;
2181 if ((*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2182 *m = m_pullup(*m, sizeof(struct ip6_hdr));
2183 else if ((*m)->m_pkthdr.len >= sizeof(struct ip))
2184 *m = m_pullup(*m, sizeof(struct ip));
2185 else
2186 return (EAFNOSUPPORT);
2187 if (*m == NULL)
2188 return (ENOBUFS);
2189 ipv = mtod(*m, struct ip *)->ip_v;
2190 if (ipv == 4)
2191 *af = AF_INET;
2192 else if (ipv == 6 && (*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2193 *af = AF_INET6;
2194 else
2195 return (EAFNOSUPPORT);
2196 return (0);
2197 }
2198
2199 static int
2200 determine_ethertype_and_pullup(struct mbuf **m, int *etp)
2201 {
2202 struct ether_header *eh;
2203
2204 *m = m_pullup(*m, sizeof(struct ether_header));
2205 if (__predict_false(*m == NULL))
2206 return (ENOBUFS);
2207 eh = mtod(*m, struct ether_header *);
2208 *etp = ntohs(eh->ether_type);
2209 if (*etp != ETHERTYPE_IP && *etp != ETHERTYPE_IPV6)
2210 return (EAFNOSUPPORT);
2211 return (0);
2212 }
2213
2214 /*
2215 * This should only be invoked by netmap, via nm_os_generic_xmit_frame(), to
2216 * transmit packets from the netmap TX ring.
2217 */
2218 static int
2219 wg_transmit(if_t ifp, struct mbuf *m)
2220 {
2221 sa_family_t af;
2222 int et, ret;
2223 struct mbuf *defragged;
2224
2225 KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2226 ("%s: ifp %p is not in netmap mode", __func__, ifp));
2227
2228 defragged = m_defrag(m, M_NOWAIT);
2229 if (defragged)
2230 m = defragged;
2231 m = m_unshare(m, M_NOWAIT);
2232 if (!m) {
2233 xmit_err(ifp, m, NULL, AF_UNSPEC);
2234 return (ENOBUFS);
2235 }
2236
2237 ret = determine_ethertype_and_pullup(&m, &et);
2238 if (ret) {
2239 xmit_err(ifp, m, NULL, AF_UNSPEC);
2240 return (ret);
2241 }
2242 m_adj(m, sizeof(struct ether_header));
2243
2244 ret = determine_af_and_pullup(&m, &af);
2245 if (ret) {
2246 xmit_err(ifp, m, NULL, AF_UNSPEC);
2247 return (ret);
2248 }
2249
2250 /*
2251 * netmap only gets to see transient errors, since it handles errors by
2252 * refusing to advance the transmit ring and retrying later.
2253 */
2254 ret = wg_xmit(ifp, m, af, if_getmtu(ifp));
2255 if (ret == ENOBUFS)
2256 return (ret);
2257 return (0);
2258 }
2259
2260 #ifdef DEV_NETMAP
2261 /*
2262 * This should only be invoked by netmap, via nm_os_send_up(), to process
2263 * packets from the host TX ring.
2264 */
2265 static void
2266 wg_if_input(if_t ifp, struct mbuf *m)
2267 {
2268 int et;
2269
2270 KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2271 ("%s: ifp %p is not in netmap mode", __func__, ifp));
2272
2273 if (determine_ethertype_and_pullup(&m, &et) != 0) {
2274 if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2275 m_freem(m);
2276 return;
2277 }
2278 CURVNET_SET(if_getvnet(ifp));
2279 switch (et) {
2280 case ETHERTYPE_IP:
2281 m_adj(m, sizeof(struct ether_header));
2282 netisr_dispatch(NETISR_IP, m);
2283 break;
2284 case ETHERTYPE_IPV6:
2285 m_adj(m, sizeof(struct ether_header));
2286 netisr_dispatch(NETISR_IPV6, m);
2287 break;
2288 default:
2289 __assert_unreachable();
2290 }
2291 CURVNET_RESTORE();
2292 }
2293
2294 /*
2295 * Deliver a packet to the host RX ring. Because the interface is in netmap
2296 * mode, the if_transmit() call should pass the packet to netmap_transmit().
2297 */
2298 static int
2299 wg_xmit_netmap(if_t ifp, struct mbuf *m, int af)
2300 {
2301 struct ether_header *eh;
2302
2303 if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP,
2304 MAX_LOOPS))) {
2305 printf("%s:%d\n", __func__, __LINE__);
2306 if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2307 m_freem(m);
2308 return (ELOOP);
2309 }
2310
2311 M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
2312 if (__predict_false(m == NULL)) {
2313 if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
2314 return (ENOBUFS);
2315 }
2316
2317 eh = mtod(m, struct ether_header *);
2318 eh->ether_type = af == AF_INET ?
2319 htons(ETHERTYPE_IP) : htons(ETHERTYPE_IPV6);
2320 memcpy(eh->ether_shost, "\x06\x06\x06\x06\x06\x06", ETHER_ADDR_LEN);
2321 memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
2322 return (if_transmit(ifp, m));
2323 }
2324 #endif /* DEV_NETMAP */
2325
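/*
 * if_output hook. BPF-injected writes smuggle the address family
 * through sa_data (with AF_UNSPEC or pseudo_AF_HDRCMPLT), so unpack
 * that first, then normalize the mbuf and verify the declared family
 * matches the actual header before handing off to wg_xmit().
 */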
2326 static int
2327 wg_output(if_t ifp, struct mbuf *m, const struct sockaddr *dst, struct route *ro)
2328 {
2329 sa_family_t parsed_af;
2330 uint32_t af, mtu;
2331 int ret;
2332 struct mbuf *defragged;
2333
2334 /* BPF writes need to be handled specially. */
2335 if (dst->sa_family == AF_UNSPEC || dst->sa_family == pseudo_AF_HDRCMPLT)
2336 memcpy(&af, dst->sa_data, sizeof(af));
2337 else
2338 af = dst->sa_family;
2339 if (af == AF_UNSPEC) {
2340 xmit_err(ifp, m, NULL, af);
2341 return (EAFNOSUPPORT);
2342 }
2343
2344 #ifdef DEV_NETMAP
2345 if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
2346 return (wg_xmit_netmap(ifp, m, af));
2347 #endif
2348
2349 defragged = m_defrag(m, M_NOWAIT);
2350 if (defragged)
2351 m = defragged;
2352 m = m_unshare(m, M_NOWAIT);
2353 if (!m) {
2354 xmit_err(ifp, m, NULL, AF_UNSPEC);
2355 return (ENOBUFS);
2356 }
2357
2358 ret = determine_af_and_pullup(&m, &parsed_af);
2359 if (ret) {
2360 xmit_err(ifp, m, NULL, AF_UNSPEC);
2361 return (ret);
2362 }
2363 if (parsed_af != af) {
2364 xmit_err(ifp, m, NULL, AF_UNSPEC);
2365 return (EAFNOSUPPORT);
2366 }
2367 mtu = (ro != NULL && ro->ro_mtu > 0) ? ro->ro_mtu : if_getmtu(ifp);
2368 return (wg_xmit(ifp, m, parsed_af, mtu));
2369 }
2370
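/*
 * Apply a single peer's configuration from a request nvlist. The
 * recognized keys are "public-key" (required), "remove",
 * "replace-allowedips", "endpoint", "preshared-key",
 * "persistent-keepalive-interval" and the "allowed-ips" array,
 * matching what userspace configuration tools are expected to pack.
 */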
2371 static int
2372 wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
2373 {
2374 uint8_t public[WG_KEY_SIZE];
2375 const void *pub_key, *preshared_key = NULL;
2376 const struct sockaddr *endpoint;
2377 int err;
2378 size_t size;
2379 struct noise_remote *remote;
2380 struct wg_peer *peer = NULL;
2381 bool need_insert = false;
2382
2383 sx_assert(&sc->sc_lock, SX_XLOCKED);
2384
2385 if (!nvlist_exists_binary(nvl, "public-key")) {
2386 return (EINVAL);
2387 }
2388 pub_key = nvlist_get_binary(nvl, "public-key", &size);
2389 if (size != WG_KEY_SIZE) {
2390 return (EINVAL);
2391 }
2392 if (noise_local_keys(sc->sc_local, public, NULL) == 0 &&
2393 bcmp(public, pub_key, WG_KEY_SIZE) == 0) {
2394 return (0); // Silently ignored; not actually a failure.
2395 }
2396 if ((remote = noise_remote_lookup(sc->sc_local, pub_key)) != NULL)
2397 peer = noise_remote_arg(remote);
2398 if (nvlist_exists_bool(nvl, "remove") &&
2399 nvlist_get_bool(nvl, "remove")) {
2400 if (remote != NULL) {
2401 wg_peer_destroy(peer);
2402 noise_remote_put(remote);
2403 }
2404 return (0);
2405 }
2406 if (nvlist_exists_bool(nvl, "replace-allowedips") &&
2407 nvlist_get_bool(nvl, "replace-allowedips") &&
2408 peer != NULL) {
2409
2410 wg_aip_remove_all(sc, peer);
2411 }
2412 if (peer == NULL) {
2413 peer = wg_peer_alloc(sc, pub_key);
2414 need_insert = true;
2415 }
2416 if (nvlist_exists_binary(nvl, "endpoint")) {
2417 endpoint = nvlist_get_binary(nvl, "endpoint", &size);
2418 if (size > sizeof(peer->p_endpoint.e_remote)) {
2419 err = EINVAL;
2420 goto out;
2421 }
2422 memcpy(&peer->p_endpoint.e_remote, endpoint, size);
2423 }
2424 if (nvlist_exists_binary(nvl, "preshared-key")) {
2425 preshared_key = nvlist_get_binary(nvl, "preshared-key", &size);
2426 if (size != WG_KEY_SIZE) {
2427 err = EINVAL;
2428 goto out;
2429 }
2430 noise_remote_set_psk(peer->p_remote, preshared_key);
2431 }
2432 if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) {
2433 uint64_t pki = nvlist_get_number(nvl, "persistent-keepalive-interval");
2434 if (pki > UINT16_MAX) {
2435 err = EINVAL;
2436 goto out;
2437 }
2438 wg_timers_set_persistent_keepalive(peer, pki);
2439 }
2440 if (nvlist_exists_nvlist_array(nvl, "allowed-ips")) {
2441 const void *addr;
2442 uint64_t cidr;
2443 const nvlist_t * const * aipl;
2444 size_t allowedip_count;
2445
2446 aipl = nvlist_get_nvlist_array(nvl, "allowed-ips", &allowedip_count);
2447 for (size_t idx = 0; idx < allowedip_count; idx++) {
2448 if (!nvlist_exists_number(aipl[idx], "cidr"))
2449 continue;
2450 cidr = nvlist_get_number(aipl[idx], "cidr");
2451 if (nvlist_exists_binary(aipl[idx], "ipv4")) {
2452 addr = nvlist_get_binary(aipl[idx], "ipv4", &size);
2453 if (addr == NULL || cidr > 32 || size != sizeof(struct in_addr)) {
2454 err = EINVAL;
2455 goto out;
2456 }
2457 if ((err = wg_aip_add(sc, peer, AF_INET, addr, cidr)) != 0)
2458 goto out;
2459 } else if (nvlist_exists_binary(aipl[idx], "ipv6")) {
2460 addr = nvlist_get_binary(aipl[idx], "ipv6", &size);
2461 if (addr == NULL || cidr > 128 || size != sizeof(struct in6_addr)) {
2462 err = EINVAL;
2463 goto out;
2464 }
2465 if ((err = wg_aip_add(sc, peer, AF_INET6, addr, cidr)) != 0)
2466 goto out;
2467 } else {
2468 continue;
2469 }
2470 }
2471 }
2472 if (need_insert) {
2473 if ((err = noise_remote_enable(peer->p_remote)) != 0)
2474 goto out;
2475 TAILQ_INSERT_TAIL(&sc->sc_peers, peer, p_entry);
2476 sc->sc_peers_num++;
2477 if (if_getlinkstate(sc->sc_ifp) == LINK_STATE_UP)
2478 wg_timers_enable(peer);
2479 }
2480 if (remote != NULL)
2481 noise_remote_put(remote);
2482 return (0);
2483 out:
2484 if (need_insert) /* If we fail, only destroy if it was new. */
2485 wg_peer_destroy(peer);
2486 if (remote != NULL)
2487 noise_remote_put(remote);
2488 return (err);
2489 }
2490
2491 static int
2492 wgc_set(struct wg_softc *sc, struct wg_data_io *wgd)
2493 {
2494 uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE];
2495 if_t ifp;
2496 void *nvlpacked;
2497 nvlist_t *nvl;
2498 ssize_t size;
2499 int err;
2500
2501 ifp = sc->sc_ifp;
2502 if (wgd->wgd_size == 0 || wgd->wgd_data == NULL)
2503 return (EFAULT);
2504
2505 /* Can nvlists be streamed in? It's not nice to impose arbitrary limits like that but
2506 * there needs to be _some_ limitation. */
2507 if (wgd->wgd_size >= UINT32_MAX / 2)
2508 return (E2BIG);
2509
2510 nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK | M_ZERO);
2511
2512 err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size);
2513 if (err)
2514 goto out;
2515 nvl = nvlist_unpack(nvlpacked, wgd->wgd_size, 0);
2516 if (nvl == NULL) {
2517 err = EBADMSG;
2518 goto out;
2519 }
2520 sx_xlock(&sc->sc_lock);
2521 if (nvlist_exists_bool(nvl, "replace-peers") &&
2522 nvlist_get_bool(nvl, "replace-peers"))
2523 wg_peer_destroy_all(sc);
2524 if (nvlist_exists_number(nvl, "listen-port")) {
2525 uint64_t new_port = nvlist_get_number(nvl, "listen-port");
2526 if (new_port > UINT16_MAX) {
2527 err = EINVAL;
2528 goto out_locked;
2529 }
2530 if (new_port != sc->sc_socket.so_port) {
2531 if ((if_getdrvflags(ifp) & IFF_DRV_RUNNING) != 0) {
2532 if ((err = wg_socket_init(sc, new_port)) != 0)
2533 goto out_locked;
2534 } else
2535 sc->sc_socket.so_port = new_port;
2536 }
2537 }
2538 if (nvlist_exists_binary(nvl, "private-key")) {
2539 const void *key = nvlist_get_binary(nvl, "private-key", &size);
2540 if (size != WG_KEY_SIZE) {
2541 err = EINVAL;
2542 goto out_locked;
2543 }
2544
2545 if (noise_local_keys(sc->sc_local, NULL, private) != 0 ||
2546 timingsafe_bcmp(private, key, WG_KEY_SIZE) != 0) {
2547 struct wg_peer *peer;
2548
2549 if (curve25519_generate_public(public, key)) {
2550 /* Peer conflict: remove conflicting peer. */
2551 struct noise_remote *remote;
2552 if ((remote = noise_remote_lookup(sc->sc_local,
2553 public)) != NULL) {
2554 peer = noise_remote_arg(remote);
2555 wg_peer_destroy(peer);
2556 noise_remote_put(remote);
2557 }
2558 }
2559
2560 /*
2561 * Set the private key and invalidate all existing
2562 * handshakes.
2563 */
2564 /* Note: we might be removing the private key. */
2565 noise_local_private(sc->sc_local, key);
2566 if (noise_local_keys(sc->sc_local, NULL, NULL) == 0)
2567 cookie_checker_update(&sc->sc_cookie, public);
2568 else
2569 cookie_checker_update(&sc->sc_cookie, NULL);
2570 }
2571 }
2572 if (nvlist_exists_number(nvl, "user-cookie")) {
2573 uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie");
2574 if (user_cookie > UINT32_MAX) {
2575 err = EINVAL;
2576 goto out_locked;
2577 }
2578 err = wg_socket_set_cookie(sc, user_cookie);
2579 if (err)
2580 goto out_locked;
2581 }
2582 if (nvlist_exists_nvlist_array(nvl, "peers")) {
2583 size_t peercount;
2584 const nvlist_t * const *nvl_peers;
2585
2586 nvl_peers = nvlist_get_nvlist_array(nvl, "peers", &peercount);
2587 for (int i = 0; i < peercount; i++) {
2588 err = wg_peer_add(sc, nvl_peers[i]);
2589 if (err != 0)
2590 goto out_locked;
2591 }
2592 }
2593
2594 out_locked:
2595 sx_xunlock(&sc->sc_lock);
2596 nvlist_destroy(nvl);
2597 out:
2598 zfree(nvlpacked, M_TEMP);
2599 return (err);
2600 }
2601
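/*
 * Export the device configuration as a packed nvlist. Userspace may
 * pass wgd_size == 0 to learn the required buffer size and then call
 * again with a large enough buffer; private keys and preshared keys
 * are only included for privileged callers.
 */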
2602 static int
2603 wgc_get(struct wg_softc *sc, struct wg_data_io *wgd)
2604 {
2605 uint8_t public_key[WG_KEY_SIZE] = { 0 };
2606 uint8_t private_key[WG_KEY_SIZE] = { 0 };
2607 uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN] = { 0 };
2608 nvlist_t *nvl, *nvl_peer, *nvl_aip, **nvl_peers, **nvl_aips;
2609 size_t size, peer_count, aip_count, i, j;
2610 struct wg_timespec64 ts64;
2611 struct wg_peer *peer;
2612 struct wg_aip *aip;
2613 void *packed;
2614 int err = 0;
2615
2616 nvl = nvlist_create(0);
2617 if (!nvl)
2618 return (ENOMEM);
2619
2620 sx_slock(&sc->sc_lock);
2621
2622 if (sc->sc_socket.so_port != 0)
2623 nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port);
2624 if (sc->sc_socket.so_user_cookie != 0)
2625 nvlist_add_number(nvl, "user-cookie", sc->sc_socket.so_user_cookie);
2626 if (noise_local_keys(sc->sc_local, public_key, private_key) == 0) {
2627 nvlist_add_binary(nvl, "public-key", public_key, WG_KEY_SIZE);
2628 if (wgc_privileged(sc))
2629 nvlist_add_binary(nvl, "private-key", private_key, WG_KEY_SIZE);
2630 explicit_bzero(private_key, sizeof(private_key));
2631 }
2632 peer_count = sc->sc_peers_num;
2633 if (peer_count) {
2634 nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2635 i = 0;
2636 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2637 if (i >= peer_count)
2638 panic("peers changed from under us");
2639
2640 nvl_peers[i++] = nvl_peer = nvlist_create(0);
2641 if (!nvl_peer) {
2642 err = ENOMEM;
2643 goto err_peer;
2644 }
2645
2646 (void)noise_remote_keys(peer->p_remote, public_key, preshared_key);
2647 nvlist_add_binary(nvl_peer, "public-key", public_key, sizeof(public_key));
2648 if (wgc_privileged(sc))
2649 nvlist_add_binary(nvl_peer, "preshared-key", preshared_key, sizeof(preshared_key));
2650 explicit_bzero(preshared_key, sizeof(preshared_key));
2651 if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET)
2652 nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in));
2653 else if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET6)
2654 nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in6));
2655 wg_timers_get_last_handshake(peer, &ts64);
2656 nvlist_add_binary(nvl_peer, "last-handshake-time", &ts64, sizeof(ts64));
2657 nvlist_add_number(nvl_peer, "persistent-keepalive-interval", peer->p_persistent_keepalive_interval);
2658 nvlist_add_number(nvl_peer, "rx-bytes", counter_u64_fetch(peer->p_rx_bytes));
2659 nvlist_add_number(nvl_peer, "tx-bytes", counter_u64_fetch(peer->p_tx_bytes));
2660
2661 aip_count = peer->p_aips_num;
2662 if (aip_count) {
2663 nvl_aips = mallocarray(aip_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2664 j = 0;
2665 LIST_FOREACH(aip, &peer->p_aips, a_entry) {
2666 if (j >= aip_count)
2667 panic("aips changed from under us");
2668
2669 nvl_aips[j++] = nvl_aip = nvlist_create(0);
2670 if (!nvl_aip) {
2671 err = ENOMEM;
2672 goto err_aip;
2673 }
2674 if (aip->a_af == AF_INET) {
2675 nvlist_add_binary(nvl_aip, "ipv4", &aip->a_addr.in, sizeof(aip->a_addr.in));
2676 nvlist_add_number(nvl_aip, "cidr", bitcount32(aip->a_mask.ip));
2677 }
2678 #ifdef INET6
2679 else if (aip->a_af == AF_INET6) {
2680 nvlist_add_binary(nvl_aip, "ipv6", &aip->a_addr.in6, sizeof(aip->a_addr.in6));
2681 nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&aip->a_mask.in6, NULL));
2682 }
2683 #endif
2684 }
2685 nvlist_add_nvlist_array(nvl_peer, "allowed-ips", (const nvlist_t *const *)nvl_aips, aip_count);
2686 err_aip:
2687 for (j = 0; j < aip_count; ++j)
2688 nvlist_destroy(nvl_aips[j]);
2689 free(nvl_aips, M_NVLIST);
2690 if (err)
2691 goto err_peer;
2692 }
2693 }
2694 nvlist_add_nvlist_array(nvl, "peers", (const nvlist_t * const *)nvl_peers, peer_count);
2695 err_peer:
2696 for (i = 0; i < peer_count; ++i)
2697 nvlist_destroy(nvl_peers[i]);
2698 free(nvl_peers, M_NVLIST);
2699 if (err) {
2700 sx_sunlock(&sc->sc_lock);
2701 goto err;
2702 }
2703 }
2704 sx_sunlock(&sc->sc_lock);
2705 packed = nvlist_pack(nvl, &size);
2706 if (!packed) {
2707 err = ENOMEM;
2708 goto err;
2709 }
2710 if (!wgd->wgd_size) {
2711 wgd->wgd_size = size;
2712 goto out;
2713 }
2714 if (wgd->wgd_size < size) {
2715 err = ENOSPC;
2716 goto out;
2717 }
2718 err = copyout(packed, wgd->wgd_data, size);
2719 wgd->wgd_size = size;
2720
2721 out:
2722 zfree(packed, M_NVLIST);
2723 err:
2724 nvlist_destroy(nvl);
2725 return (err);
2726 }
2727
2728 static int
2729 wg_ioctl(if_t ifp, u_long cmd, caddr_t data)
2730 {
2731 struct wg_data_io *wgd = (struct wg_data_io *)data;
2732 struct ifreq *ifr = (struct ifreq *)data;
2733 struct wg_softc *sc;
2734 int ret = 0;
2735
2736 sx_slock(&wg_sx);
2737 sc = if_getsoftc(ifp);
2738 if (!sc) {
2739 ret = ENXIO;
2740 goto out;
2741 }
2742
2743 switch (cmd) {
2744 case SIOCSWG:
2745 ret = priv_check(curthread, PRIV_NET_WG);
2746 if (ret == 0)
2747 ret = wgc_set(sc, wgd);
2748 break;
2749 case SIOCGWG:
2750 ret = wgc_get(sc, wgd);
2751 break;
2752 /* Interface IOCTLs */
2753 case SIOCSIFADDR:
2754 /*
2755 * This differs from *BSD norms, but is more uniform with how
2756 * WireGuard behaves elsewhere.
2757 */
2758 break;
2759 case SIOCSIFFLAGS:
2760 if (if_getflags(ifp) & IFF_UP)
2761 ret = wg_up(sc);
2762 else
2763 wg_down(sc);
2764 break;
2765 case SIOCSIFMTU:
2766 if (ifr->ifr_mtu <= 0 || ifr->ifr_mtu > MAX_MTU)
2767 ret = EINVAL;
2768 else
2769 if_setmtu(ifp, ifr->ifr_mtu);
2770 break;
2771 case SIOCADDMULTI:
2772 case SIOCDELMULTI:
2773 break;
2774 case SIOCGTUNFIB:
2775 ifr->ifr_fib = sc->sc_socket.so_fibnum;
2776 break;
2777 case SIOCSTUNFIB:
2778 ret = priv_check(curthread, PRIV_NET_WG);
2779 if (ret)
2780 break;
2781 ret = priv_check(curthread, PRIV_NET_SETIFFIB);
2782 if (ret)
2783 break;
2784 sx_xlock(&sc->sc_lock);
2785 ret = wg_socket_set_fibnum(sc, ifr->ifr_fib);
2786 sx_xunlock(&sc->sc_lock);
2787 break;
2788 default:
2789 ret = ENOTTY;
2790 }
2791
2792 out:
2793 sx_sunlock(&wg_sx);
2794 return (ret);
2795 }
2796
2797 static int
2798 wg_up(struct wg_softc *sc)
2799 {
2800 if_t ifp = sc->sc_ifp;
2801 struct wg_peer *peer;
2802 int rc = EBUSY;
2803
2804 sx_xlock(&sc->sc_lock);
2805 /* Jail's being removed, no more wg_up(). */
2806 if ((sc->sc_flags & WGF_DYING) != 0)
2807 goto out;
2808
2809 /* Silent success if we're already running. */
2810 rc = 0;
2811 if (if_getdrvflags(ifp) & IFF_DRV_RUNNING)
2812 goto out;
2813 if_setdrvflagbits(ifp, IFF_DRV_RUNNING, 0);
2814
2815 rc = wg_socket_init(sc, sc->sc_socket.so_port);
2816 if (rc == 0) {
2817 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry)
2818 wg_timers_enable(peer);
2819 if_link_state_change(sc->sc_ifp, LINK_STATE_UP);
2820 } else {
2821 if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2822 DPRINTF(sc, "Unable to initialize sockets: %d\n", rc);
2823 }
2824 out:
2825 sx_xunlock(&sc->sc_lock);
2826 return (rc);
2827 }
2828
2829 static void
2830 wg_down(struct wg_softc *sc)
2831 {
2832 if_t ifp = sc->sc_ifp;
2833 struct wg_peer *peer;
2834
2835 sx_xlock(&sc->sc_lock);
2836 if (!(if_getdrvflags(ifp) & IFF_DRV_RUNNING)) {
2837 sx_xunlock(&sc->sc_lock);
2838 return;
2839 }
2840 if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2841
2842 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2843 wg_queue_purge(&peer->p_stage_queue);
2844 wg_timers_disable(peer);
2845 }
2846
2847 wg_queue_purge(&sc->sc_handshake_queue);
2848
2849 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2850 noise_remote_handshake_clear(peer->p_remote);
2851 noise_remote_keypairs_clear(peer->p_remote);
2852 }
2853
2854 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
2855 wg_socket_uninit(sc);
2856
2857 sx_xunlock(&sc->sc_lock);
2858 }
2859
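/*
 * Instantiate a wg interface: allocate the local noise state, the
 * per-CPU encrypt/decrypt grouptasks and the v4/v6 allowed-IP radix
 * heads, then attach the ifnet with the wg output, transmit and ioctl
 * hooks.
 */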
2860 static int
2861 wg_clone_create(struct if_clone *ifc, char *name, size_t len,
2862 struct ifc_data *ifd, struct ifnet **ifpp)
2863 {
2864 struct wg_softc *sc;
2865 if_t ifp;
2866
2867 sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO);
2868
2869 sc->sc_local = noise_local_alloc(sc);
2870
2871 sc->sc_encrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2872
2873 sc->sc_decrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2874
2875 if (!rn_inithead((void **)&sc->sc_aip4, offsetof(struct aip_addr, in) * NBBY))
2876 goto free_decrypt;
2877
2878 if (!rn_inithead((void **)&sc->sc_aip6, offsetof(struct aip_addr, in6) * NBBY))
2879 goto free_aip4;
2880
2881 atomic_add_int(&clone_count, 1);
2882 ifp = sc->sc_ifp = if_alloc(IFT_WIREGUARD);
2883
2884 sc->sc_ucred = crhold(curthread->td_ucred);
2885 sc->sc_socket.so_fibnum = curthread->td_proc->p_fibnum;
2886 sc->sc_socket.so_port = 0;
2887
2888 TAILQ_INIT(&sc->sc_peers);
2889 sc->sc_peers_num = 0;
2890
2891 cookie_checker_init(&sc->sc_cookie);
2892
2893 RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip4);
2894 RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip6);
2895
2896 GROUPTASK_INIT(&sc->sc_handshake, 0, (gtask_fn_t *)wg_softc_handshake_receive, sc);
2897 taskqgroup_attach(qgroup_wg_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation");
2898 wg_queue_init(&sc->sc_handshake_queue, "hsq");
2899
2900 for (int i = 0; i < mp_ncpus; i++) {
2901 GROUPTASK_INIT(&sc->sc_encrypt[i], 0,
2902 (gtask_fn_t *)wg_softc_encrypt, sc);
2903 taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_encrypt[i], sc, i, NULL, NULL, "wg encrypt");
2904 GROUPTASK_INIT(&sc->sc_decrypt[i], 0,
2905 (gtask_fn_t *)wg_softc_decrypt, sc);
2906 taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_decrypt[i], sc, i, NULL, NULL, "wg decrypt");
2907 }
2908
2909 wg_queue_init(&sc->sc_encrypt_parallel, "encp");
2910 wg_queue_init(&sc->sc_decrypt_parallel, "decp");
2911
2912 sx_init(&sc->sc_lock, "wg softc lock");
2913
2914 if_setsoftc(ifp, sc);
2915 if_setcapabilities(ifp, WG_CAPS);
2916 if_setcapenable(ifp, WG_CAPS);
2917 if_initname(ifp, wgname, ifd->unit);
2918
2919 if_setmtu(ifp, DEFAULT_MTU);
2920 if_setflags(ifp, IFF_NOARP | IFF_MULTICAST);
2921 if_setinitfn(ifp, wg_init);
2922 if_setreassignfn(ifp, wg_reassign);
2923 if_setqflushfn(ifp, wg_qflush);
2924 if_settransmitfn(ifp, wg_transmit);
2925 #ifdef DEV_NETMAP
2926 if_setinputfn(ifp, wg_if_input);
2927 #endif
2928 if_setoutputfn(ifp, wg_output);
2929 if_setioctlfn(ifp, wg_ioctl);
2930 if_attach(ifp);
2931 bpfattach(ifp, DLT_NULL, sizeof(uint32_t));
2932 #ifdef INET6
2933 ND_IFINFO(ifp)->flags &= ~ND6_IFF_AUTO_LINKLOCAL;
2934 ND_IFINFO(ifp)->flags |= ND6_IFF_NO_DAD;
2935 #endif
2936 sx_xlock(&wg_sx);
2937 LIST_INSERT_HEAD(&wg_list, sc, sc_entry);
2938 sx_xunlock(&wg_sx);
2939 *ifpp = ifp;
2940 return (0);
2941 free_aip4:
2942 RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
2943 free(sc->sc_aip4, M_RTABLE);
2944 free_decrypt:
2945 free(sc->sc_decrypt, M_WG);
2946 free(sc->sc_encrypt, M_WG);
2947 noise_local_free(sc->sc_local, NULL);
2948 free(sc, M_WG);
2949 return (ENOMEM);
2950 }
2951
2952 static void
2953 wg_clone_deferred_free(struct noise_local *l)
2954 {
2955 struct wg_softc *sc = noise_local_arg(l);
2956
2957 free(sc, M_WG);
2958 atomic_add_int(&clone_count, -1);
2959 }
2960
2961 static int
2962 wg_clone_destroy(struct if_clone *ifc, if_t ifp, uint32_t flags)
2963 {
2964 struct wg_softc *sc = if_getsoftc(ifp);
2965 struct ucred *cred;
2966
2967 sx_xlock(&wg_sx);
2968 if_setsoftc(ifp, NULL);
2969 sx_xlock(&sc->sc_lock);
2970 sc->sc_flags |= WGF_DYING;
2971 cred = sc->sc_ucred;
2972 sc->sc_ucred = NULL;
2973 sx_xunlock(&sc->sc_lock);
2974 LIST_REMOVE(sc, sc_entry);
2975 sx_xunlock(&wg_sx);
2976
2977 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
2978 CURVNET_SET(if_getvnet(sc->sc_ifp));
2979 if_purgeaddrs(sc->sc_ifp);
2980 CURVNET_RESTORE();
2981
2982 sx_xlock(&sc->sc_lock);
2983 wg_socket_uninit(sc);
2984 sx_xunlock(&sc->sc_lock);
2985
2986 /*
2987 * There is no guarantee that all traffic has passed until the epoch
2988 * has elapsed with the socket closed.
2989 */
2990 NET_EPOCH_WAIT();
2991
2992 taskqgroup_drain_all(qgroup_wg_tqg);
2993 sx_xlock(&sc->sc_lock);
2994 wg_peer_destroy_all(sc);
2995 NET_EPOCH_DRAIN_CALLBACKS();
2996 sx_xunlock(&sc->sc_lock);
2997 sx_destroy(&sc->sc_lock);
2998 taskqgroup_detach(qgroup_wg_tqg, &sc->sc_handshake);
2999 for (int i = 0; i < mp_ncpus; i++) {
3000 taskqgroup_detach(qgroup_wg_tqg, &sc->sc_encrypt[i]);
3001 taskqgroup_detach(qgroup_wg_tqg, &sc->sc_decrypt[i]);
3002 }
3003 free(sc->sc_encrypt, M_WG);
3004 free(sc->sc_decrypt, M_WG);
3005 wg_queue_deinit(&sc->sc_handshake_queue);
3006 wg_queue_deinit(&sc->sc_encrypt_parallel);
3007 wg_queue_deinit(&sc->sc_decrypt_parallel);
3008
3009 RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
3010 RADIX_NODE_HEAD_DESTROY(sc->sc_aip6);
3011 rn_detachhead((void **)&sc->sc_aip4);
3012 rn_detachhead((void **)&sc->sc_aip6);
3013
3014 cookie_checker_free(&sc->sc_cookie);
3015
3016 if (cred != NULL)
3017 crfree(cred);
3018 bpfdetach(sc->sc_ifp);
3019 if_detach(sc->sc_ifp);
3020 if_free(sc->sc_ifp);
3021
3022 noise_local_free(sc->sc_local, wg_clone_deferred_free);
3023
3024 return (0);
3025 }
3026
3027 static void
3028 wg_qflush(if_t ifp __unused)
3029 {
3030 }
3031
3032 /*
3033 * Privileged information (private-key, preshared-key) is only exported
3034 * root and jailed root by default.
3035 */
3036 static bool
3037 wgc_privileged(struct wg_softc *sc)
3038 {
3039 struct thread *td;
3040
3041 td = curthread;
3042 return (priv_check(td, PRIV_NET_WG) == 0);
3043 }
3044
3045 static void
3046 wg_reassign(if_t ifp, struct vnet *new_vnet __unused,
3047 char *unused __unused)
3048 {
3049 struct wg_softc *sc;
3050
3051 sc = if_getsoftc(ifp);
3052 wg_down(sc);
3053 }
3054
3055 static void
3056 wg_init(void *xsc)
3057 {
3058 struct wg_softc *sc;
3059
3060 sc = xsc;
3061 wg_up(sc);
3062 }
3063
3064 static void
3065 vnet_wg_init(const void *unused __unused)
3066 {
3067 struct if_clone_addreq req = {
3068 .create_f = wg_clone_create,
3069 .destroy_f = wg_clone_destroy,
3070 .flags = IFC_F_AUTOUNIT,
3071 };
3072 V_wg_cloner = ifc_attach_cloner(wgname, &req);
3073 }
3074 VNET_SYSINIT(vnet_wg_init, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3075 vnet_wg_init, NULL);
3076
3077 static void
3078 vnet_wg_uninit(const void *unused __unused)
3079 {
3080 if (V_wg_cloner)
3081 ifc_detach_cloner(V_wg_cloner);
3082 }
3083 VNET_SYSUNINIT(vnet_wg_uninit, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3084 vnet_wg_uninit, NULL);
3085
3086 static int
3087 wg_prison_remove(void *obj, void *data __unused)
3088 {
3089 const struct prison *pr = obj;
3090 struct wg_softc *sc;
3091
3092 /*
3093 * Do a pass through all if_wg interfaces and release creds on any from
3094 * the jail that are supposed to be going away. This will, in turn, let
3095 * the jail die so that we don't end up with Schrödinger's jail.
3096 */
3097 sx_slock(&wg_sx);
3098 LIST_FOREACH(sc, &wg_list, sc_entry) {
3099 sx_xlock(&sc->sc_lock);
3100 if (!(sc->sc_flags & WGF_DYING) && sc->sc_ucred && sc->sc_ucred->cr_prison == pr) {
3101 struct ucred *cred = sc->sc_ucred;
3102 DPRINTF(sc, "Creating jail exiting\n");
3103 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
3104 wg_socket_uninit(sc);
3105 sc->sc_ucred = NULL;
3106 crfree(cred);
3107 sc->sc_flags |= WGF_DYING;
3108 }
3109 sx_xunlock(&sc->sc_lock);
3110 }
3111 sx_sunlock(&wg_sx);
3112
3113 return (0);
3114 }
3115
3116 #ifdef SELFTESTS
3117 #include "selftest/allowedips.c"
3118 static bool wg_run_selftests(void)
3119 {
3120 bool ret = true;
3121 ret &= wg_allowedips_selftest();
3122 ret &= noise_counter_selftest();
3123 ret &= cookie_selftest();
3124 return ret;
3125 }
3126 #else
3127 static inline bool wg_run_selftests(void) { return true; }
3128 #endif
3129
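/*
 * Module bring-up: create the packet zone, initialize the crypto and
 * cookie subsystems, register the jail-removal hook and refuse to
 * load if the compiled-in selftests fail.
 */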
3130 static int
3131 wg_module_init(void)
3132 {
3133 int ret;
3134 osd_method_t methods[PR_MAXMETHOD] = {
3135 [PR_METHOD_REMOVE] = wg_prison_remove,
3136 };
3137
3138 wg_packet_zone = uma_zcreate("wg packet", sizeof(struct wg_packet),
3139 NULL, NULL, NULL, NULL, 0, 0);
3140
3141 ret = crypto_init();
3142 if (ret != 0)
3143 return (ret);
3144 ret = cookie_init();
3145 if (ret != 0)
3146 return (ret);
3147
3148 wg_osd_jail_slot = osd_jail_register(NULL, methods);
3149
3150 if (!wg_run_selftests())
3151 return (ENOTRECOVERABLE);
3152
3153 return (0);
3154 }
3155
3156 static void
3157 wg_module_deinit(void)
3158 {
3159 VNET_ITERATOR_DECL(vnet_iter);
3160 VNET_LIST_RLOCK();
3161 VNET_FOREACH(vnet_iter) {
3162 struct if_clone *clone = VNET_VNET(vnet_iter, wg_cloner);
3163 if (clone) {
3164 ifc_detach_cloner(clone);
3165 VNET_VNET(vnet_iter, wg_cloner) = NULL;
3166 }
3167 }
3168 VNET_LIST_RUNLOCK();
3169 NET_EPOCH_WAIT();
3170 MPASS(LIST_EMPTY(&wg_list));
3171 if (wg_osd_jail_slot != 0)
3172 osd_jail_deregister(wg_osd_jail_slot);
3173 cookie_deinit();
3174 crypto_deinit();
3175 if (wg_packet_zone != NULL)
3176 uma_zdestroy(wg_packet_zone);
3177 }
3178
3179 static int
3180 wg_module_event_handler(module_t mod, int what, void *arg)
3181 {
3182 switch (what) {
3183 case MOD_LOAD:
3184 return wg_module_init();
3185 case MOD_UNLOAD:
3186 wg_module_deinit();
3187 break;
3188 default:
3189 return (EOPNOTSUPP);
3190 }
3191 return (0);
3192 }
3193
3194 static moduledata_t wg_moduledata = {
3195 "if_wg",
3196 wg_module_event_handler,
3197 NULL
3198 };
3199
3200 DECLARE_MODULE(if_wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
3201 MODULE_VERSION(if_wg, WIREGUARD_VERSION);
3202 MODULE_DEPEND(if_wg, crypto, 1, 1, 1);
3203