1 /* SPDX-License-Identifier: ISC
2 *
3 * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4 * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
5 * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
6 * Copyright (c) 2021 Kyle Evans <kevans@FreeBSD.org>
7 * Copyright (c) 2022 The FreeBSD Foundation
8 */
9
10 #include "opt_inet.h"
11 #include "opt_inet6.h"
12
13 #include <sys/param.h>
14 #include <sys/systm.h>
15 #include <sys/counter.h>
16 #include <sys/gtaskqueue.h>
17 #include <sys/jail.h>
18 #include <sys/kernel.h>
19 #include <sys/lock.h>
20 #include <sys/mbuf.h>
21 #include <sys/module.h>
22 #include <sys/nv.h>
23 #include <sys/priv.h>
24 #include <sys/protosw.h>
25 #include <sys/rmlock.h>
26 #include <sys/rwlock.h>
27 #include <sys/smp.h>
28 #include <sys/socket.h>
29 #include <sys/socketvar.h>
30 #include <sys/sockio.h>
31 #include <sys/sysctl.h>
32 #include <sys/sx.h>
33 #include <machine/_inttypes.h>
34 #include <net/bpf.h>
35 #include <net/ethernet.h>
36 #include <net/if.h>
37 #include <net/if_clone.h>
38 #include <net/if_types.h>
39 #include <net/if_var.h>
40 #include <net/netisr.h>
41 #include <net/radix.h>
42 #include <netinet/in.h>
43 #include <netinet6/in6_var.h>
44 #include <netinet/ip.h>
45 #include <netinet/ip6.h>
46 #include <netinet/ip_icmp.h>
47 #include <netinet/icmp6.h>
48 #include <netinet/udp_var.h>
49 #include <netinet6/nd6.h>
50
51 #include "wg_noise.h"
52 #include "wg_cookie.h"
53 #include "version.h"
54 #include "if_wg.h"
55
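/*
 * 80 bytes of encapsulation overhead: an outer IPv6 header (40) plus UDP (8)
 * plus the WireGuard data header (16) and authentication tag (16).
 */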
56 #define DEFAULT_MTU (ETHERMTU - 80)
57 #define MAX_MTU (IF_MAXMTU - 80)
58
59 #define MAX_STAGED_PKT 128
60 #define MAX_QUEUED_PKT 1024
61 #define MAX_QUEUED_PKT_MASK (MAX_QUEUED_PKT - 1)
62
63 #define MAX_QUEUED_HANDSHAKES 4096
64
65 #define REKEY_TIMEOUT_JITTER 334 /* 1/3 sec, round for arc4random_uniform */
66 #define MAX_TIMER_HANDSHAKES (90 / REKEY_TIMEOUT)
67 #define NEW_HANDSHAKE_TIMEOUT (REKEY_TIMEOUT + KEEPALIVE_TIMEOUT)
68 #define UNDERLOAD_TIMEOUT 1
69
70 #define DPRINTF(sc, ...) if (if_getflags(sc->sc_ifp) & IFF_DEBUG) if_printf(sc->sc_ifp, ##__VA_ARGS__)
71
72 /* First byte indicating packet type on the wire */
73 #define WG_PKT_INITIATION htole32(1)
74 #define WG_PKT_RESPONSE htole32(2)
75 #define WG_PKT_COOKIE htole32(3)
76 #define WG_PKT_DATA htole32(4)
77
78 #define WG_PKT_PADDING 16
79 #define WG_KEY_SIZE 32
80
81 struct wg_pkt_initiation {
82 uint32_t t;
83 uint32_t s_idx;
84 uint8_t ue[NOISE_PUBLIC_KEY_LEN];
85 uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN];
86 uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN];
87 struct cookie_macs m;
88 };
89
90 struct wg_pkt_response {
91 uint32_t t;
92 uint32_t s_idx;
93 uint32_t r_idx;
94 uint8_t ue[NOISE_PUBLIC_KEY_LEN];
95 uint8_t en[0 + NOISE_AUTHTAG_LEN];
96 struct cookie_macs m;
97 };
98
99 struct wg_pkt_cookie {
100 uint32_t t;
101 uint32_t r_idx;
102 uint8_t nonce[COOKIE_NONCE_SIZE];
103 uint8_t ec[COOKIE_ENCRYPTED_SIZE];
104 };
105
106 struct wg_pkt_data {
107 uint32_t t;
108 uint32_t r_idx;
109 uint64_t nonce;
110 uint8_t buf[];
111 };
112
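/*
 * Cached UDP endpoint for a peer: the remote address we transmit to, plus the
 * local address (or IPv6 pktinfo) we last received on, so replies leave from
 * the same source address.
 */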
113 struct wg_endpoint {
114 union {
115 struct sockaddr r_sa;
116 struct sockaddr_in r_sin;
117 #ifdef INET6
118 struct sockaddr_in6 r_sin6;
119 #endif
120 } e_remote;
121 union {
122 struct in_addr l_in;
123 #ifdef INET6
124 struct in6_pktinfo l_pktinfo6;
125 #define l_in6 l_pktinfo6.ipi6_addr
126 #endif
127 } e_local;
128 };
129
130 struct aip_addr {
131 uint8_t length;
132 union {
133 uint8_t bytes[16];
134 uint32_t ip;
135 uint32_t ip6[4];
136 struct in_addr in;
137 struct in6_addr in6;
138 };
139 };
140
141 struct wg_aip {
142 struct radix_node a_nodes[2];
143 LIST_ENTRY(wg_aip) a_entry;
144 struct aip_addr a_addr;
145 struct aip_addr a_mask;
146 struct wg_peer *a_peer;
147 sa_family_t a_af;
148 };
149
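/*
 * Per-packet bookkeeping.  Each packet is linked onto two queues at once: a
 * per-peer serial queue that preserves ordering and a per-interface parallel
 * queue that spreads the crypto work across CPUs; p_state records whether the
 * crypto step succeeded.
 */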
150 struct wg_packet {
151 STAILQ_ENTRY(wg_packet) p_serial;
152 STAILQ_ENTRY(wg_packet) p_parallel;
153 struct wg_endpoint p_endpoint;
154 struct noise_keypair *p_keypair;
155 uint64_t p_nonce;
156 struct mbuf *p_mbuf;
157 int p_mtu;
158 sa_family_t p_af;
159 enum wg_ring_state {
160 WG_PACKET_UNCRYPTED,
161 WG_PACKET_CRYPTED,
162 WG_PACKET_DEAD,
163 } p_state;
164 };
165
166 STAILQ_HEAD(wg_packet_list, wg_packet);
167
168 struct wg_queue {
169 struct mtx q_mtx;
170 struct wg_packet_list q_queue;
171 size_t q_len;
172 };
173
174 struct wg_peer {
175 TAILQ_ENTRY(wg_peer) p_entry;
176 uint64_t p_id;
177 struct wg_softc *p_sc;
178
179 struct noise_remote *p_remote;
180 struct cookie_maker p_cookie;
181
182 struct rwlock p_endpoint_lock;
183 struct wg_endpoint p_endpoint;
184
185 struct wg_queue p_stage_queue;
186 struct wg_queue p_encrypt_serial;
187 struct wg_queue p_decrypt_serial;
188
189 bool p_enabled;
190 bool p_need_another_keepalive;
191 uint16_t p_persistent_keepalive_interval;
192 struct callout p_new_handshake;
193 struct callout p_send_keepalive;
194 struct callout p_retry_handshake;
195 struct callout p_zero_key_material;
196 struct callout p_persistent_keepalive;
197
198 struct mtx p_handshake_mtx;
199 struct timespec p_handshake_complete; /* nanotime */
200 int p_handshake_retries;
201
202 struct grouptask p_send;
203 struct grouptask p_recv;
204
205 counter_u64_t p_tx_bytes;
206 counter_u64_t p_rx_bytes;
207
208 LIST_HEAD(, wg_aip) p_aips;
209 size_t p_aips_num;
210 };
211
212 struct wg_socket {
213 struct socket *so_so4;
214 struct socket *so_so6;
215 uint32_t so_user_cookie;
216 int so_fibnum;
217 in_port_t so_port;
218 };
219
220 struct wg_softc {
221 LIST_ENTRY(wg_softc) sc_entry;
222 if_t sc_ifp;
223 int sc_flags;
224
225 struct ucred *sc_ucred;
226 struct wg_socket sc_socket;
227
228 TAILQ_HEAD(,wg_peer) sc_peers;
229 size_t sc_peers_num;
230
231 struct noise_local *sc_local;
232 struct cookie_checker sc_cookie;
233
234 struct radix_node_head *sc_aip4;
235 struct radix_node_head *sc_aip6;
236
237 struct grouptask sc_handshake;
238 struct wg_queue sc_handshake_queue;
239
240 struct grouptask *sc_encrypt;
241 struct grouptask *sc_decrypt;
242 struct wg_queue sc_encrypt_parallel;
243 struct wg_queue sc_decrypt_parallel;
244 u_int sc_encrypt_last_cpu;
245 u_int sc_decrypt_last_cpu;
246
247 struct sx sc_lock;
248 };
249
250 #define WGF_DYING 0x0001
251
252 #define MAX_LOOPS 8
253 #define MTAG_WGLOOP 0x77676c70 /* wglp */
254
255 #define GROUPTASK_DRAIN(gtask) \
256 gtaskqueue_drain((gtask)->gt_taskqueue, &(gtask)->gt_task)
257
258 #define BPF_MTAP2_AF(ifp, m, af) do { \
259 uint32_t __bpf_tap_af = (af); \
260 BPF_MTAP2(ifp, &__bpf_tap_af, sizeof(__bpf_tap_af), m); \
261 } while (0)
262
263 static int clone_count;
264 static uma_zone_t wg_packet_zone;
265 static volatile unsigned long peer_counter = 0;
266 static const char wgname[] = "wg";
267 static unsigned wg_osd_jail_slot;
268
269 static struct sx wg_sx;
270 SX_SYSINIT(wg_sx, &wg_sx, "wg_sx");
271
272 static LIST_HEAD(, wg_softc) wg_list = LIST_HEAD_INITIALIZER(wg_list);
273
274 static TASKQGROUP_DEFINE(wg_tqg, mp_ncpus, 1);
275
276 MALLOC_DEFINE(M_WG, "WG", "wireguard");
277
278 VNET_DEFINE_STATIC(struct if_clone *, wg_cloner);
279
280 #define V_wg_cloner VNET(wg_cloner)
281 #define WG_CAPS IFCAP_LINKSTATE
282
283 struct wg_timespec64 {
284 uint64_t tv_sec;
285 uint64_t tv_nsec;
286 };
287
288 static int wg_socket_init(struct wg_softc *, in_port_t);
289 static int wg_socket_bind(struct socket **, struct socket **, in_port_t *);
290 static void wg_socket_set(struct wg_softc *, struct socket *, struct socket *);
291 static void wg_socket_uninit(struct wg_softc *);
292 static int wg_socket_set_sockopt(struct socket *, struct socket *, int, void *, size_t);
293 static int wg_socket_set_cookie(struct wg_softc *, uint32_t);
294 static int wg_socket_set_fibnum(struct wg_softc *, int);
295 static int wg_send(struct wg_softc *, struct wg_endpoint *, struct mbuf *);
296 static void wg_timers_enable(struct wg_peer *);
297 static void wg_timers_disable(struct wg_peer *);
298 static void wg_timers_set_persistent_keepalive(struct wg_peer *, uint16_t);
299 static void wg_timers_get_last_handshake(struct wg_peer *, struct wg_timespec64 *);
300 static void wg_timers_event_data_sent(struct wg_peer *);
301 static void wg_timers_event_data_received(struct wg_peer *);
302 static void wg_timers_event_any_authenticated_packet_sent(struct wg_peer *);
303 static void wg_timers_event_any_authenticated_packet_received(struct wg_peer *);
304 static void wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *);
305 static void wg_timers_event_handshake_initiated(struct wg_peer *);
306 static void wg_timers_event_handshake_complete(struct wg_peer *);
307 static void wg_timers_event_session_derived(struct wg_peer *);
308 static void wg_timers_event_want_initiation(struct wg_peer *);
309 static void wg_timers_run_send_initiation(struct wg_peer *, bool);
310 static void wg_timers_run_retry_handshake(void *);
311 static void wg_timers_run_send_keepalive(void *);
312 static void wg_timers_run_new_handshake(void *);
313 static void wg_timers_run_zero_key_material(void *);
314 static void wg_timers_run_persistent_keepalive(void *);
315 static int wg_aip_add(struct wg_softc *, struct wg_peer *, sa_family_t,
316 const void *, uint8_t);
317 static int wg_aip_del(struct wg_softc *, struct wg_peer *, sa_family_t,
318 const void *, uint8_t);
319 static struct wg_peer *wg_aip_lookup(struct wg_softc *, sa_family_t, void *);
320 static void wg_aip_remove_all(struct wg_softc *, struct wg_peer *);
321 static struct wg_peer *wg_peer_create(struct wg_softc *,
322 const uint8_t [WG_KEY_SIZE], int *);
323 static void wg_peer_free_deferred(struct noise_remote *);
324 static void wg_peer_destroy(struct wg_peer *);
325 static void wg_peer_destroy_all(struct wg_softc *);
326 static void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t);
327 static void wg_send_initiation(struct wg_peer *);
328 static void wg_send_response(struct wg_peer *);
329 static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct wg_endpoint *);
330 static void wg_peer_set_endpoint(struct wg_peer *, struct wg_endpoint *);
331 static void wg_peer_clear_src(struct wg_peer *);
332 static void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *);
333 static void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *, size_t);
334 static void wg_send_keepalive(struct wg_peer *);
335 static void wg_handshake(struct wg_softc *, struct wg_packet *);
336 static void wg_encrypt(struct wg_softc *, struct wg_packet *);
337 static void wg_decrypt(struct wg_softc *, struct wg_packet *);
338 static void wg_softc_handshake_receive(struct wg_softc *);
339 static void wg_softc_decrypt(struct wg_softc *);
340 static void wg_softc_encrypt(struct wg_softc *);
341 static void wg_encrypt_dispatch(struct wg_softc *);
342 static void wg_decrypt_dispatch(struct wg_softc *);
343 static void wg_deliver_out(struct wg_peer *);
344 static void wg_deliver_in(struct wg_peer *);
345 static struct wg_packet *wg_packet_alloc(struct mbuf *);
346 static void wg_packet_free(struct wg_packet *);
347 static void wg_queue_init(struct wg_queue *, const char *);
348 static void wg_queue_deinit(struct wg_queue *);
349 static size_t wg_queue_len(struct wg_queue *);
350 static int wg_queue_enqueue_handshake(struct wg_queue *, struct wg_packet *);
351 static struct wg_packet *wg_queue_dequeue_handshake(struct wg_queue *);
352 static void wg_queue_push_staged(struct wg_queue *, struct wg_packet *);
353 static void wg_queue_enlist_staged(struct wg_queue *, struct wg_packet_list *);
354 static void wg_queue_delist_staged(struct wg_queue *, struct wg_packet_list *);
355 static void wg_queue_purge(struct wg_queue *);
356 static int wg_queue_both(struct wg_queue *, struct wg_queue *, struct wg_packet *);
357 static struct wg_packet *wg_queue_dequeue_serial(struct wg_queue *);
358 static struct wg_packet *wg_queue_dequeue_parallel(struct wg_queue *);
359 static bool wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *);
360 static void wg_peer_send_staged(struct wg_peer *);
361 static int wg_clone_create(struct if_clone *ifc, char *name, size_t len,
362 struct ifc_data *ifd, if_t *ifpp);
363 static void wg_qflush(if_t);
364 static inline int determine_af_and_pullup(struct mbuf **m, sa_family_t *af);
365 static int wg_xmit(if_t, struct mbuf *, sa_family_t, uint32_t);
366 static int wg_transmit(if_t, struct mbuf *);
367 static int wg_output(if_t, struct mbuf *, const struct sockaddr *, struct route *);
368 static int wg_clone_destroy(struct if_clone *ifc, if_t ifp,
369 uint32_t flags);
370 static bool wgc_privileged(struct wg_softc *);
371 static int wgc_get(struct wg_softc *, struct wg_data_io *);
372 static int wgc_set(struct wg_softc *, struct wg_data_io *);
373 static int wg_up(struct wg_softc *);
374 static void wg_down(struct wg_softc *);
375 static void wg_reassign(if_t, struct vnet *, char *unused);
376 static void wg_init(void *);
377 static int wg_ioctl(if_t, u_long, caddr_t);
378 static void vnet_wg_init(const void *);
379 static void vnet_wg_uninit(const void *);
380 static int wg_module_init(void);
381 static void wg_module_deinit(void);
382
383 /* TODO Peer */
384 static struct wg_peer *
385 wg_peer_create(struct wg_softc *sc, const uint8_t pub_key[WG_KEY_SIZE],
386 int *errp)
387 {
388 struct wg_peer *peer;
389
390 sx_assert(&sc->sc_lock, SX_XLOCKED);
391
392 peer = malloc(sizeof(*peer), M_WG, M_WAITOK | M_ZERO);
393
394 peer->p_remote = noise_remote_alloc(sc->sc_local, peer, pub_key);
395 if ((*errp = noise_remote_enable(peer->p_remote)) != 0) {
396 noise_remote_free(peer->p_remote, NULL);
397 free(peer, M_WG);
398 return (NULL);
399 }
400
401 peer->p_id = peer_counter++;
402 peer->p_sc = sc;
403 peer->p_tx_bytes = counter_u64_alloc(M_WAITOK);
404 peer->p_rx_bytes = counter_u64_alloc(M_WAITOK);
405
406 cookie_maker_init(&peer->p_cookie, pub_key);
407
408 rw_init(&peer->p_endpoint_lock, "wg_peer_endpoint");
409
410 wg_queue_init(&peer->p_stage_queue, "stageq");
411 wg_queue_init(&peer->p_encrypt_serial, "txq");
412 wg_queue_init(&peer->p_decrypt_serial, "rxq");
413
414 peer->p_enabled = false;
415 peer->p_need_another_keepalive = false;
416 peer->p_persistent_keepalive_interval = 0;
417 callout_init(&peer->p_new_handshake, true);
418 callout_init(&peer->p_send_keepalive, true);
419 callout_init(&peer->p_retry_handshake, true);
420 callout_init(&peer->p_persistent_keepalive, true);
421 callout_init(&peer->p_zero_key_material, true);
422
423 mtx_init(&peer->p_handshake_mtx, "peer handshake", NULL, MTX_DEF);
424 bzero(&peer->p_handshake_complete, sizeof(peer->p_handshake_complete));
425 peer->p_handshake_retries = 0;
426
427 GROUPTASK_INIT(&peer->p_send, 0, (gtask_fn_t *)wg_deliver_out, peer);
428 taskqgroup_attach(qgroup_wg_tqg, &peer->p_send, peer, NULL, NULL, "wg send");
429 GROUPTASK_INIT(&peer->p_recv, 0, (gtask_fn_t *)wg_deliver_in, peer);
430 taskqgroup_attach(qgroup_wg_tqg, &peer->p_recv, peer, NULL, NULL, "wg recv");
431
432 LIST_INIT(&peer->p_aips);
433 peer->p_aips_num = 0;
434
435 TAILQ_INSERT_TAIL(&sc->sc_peers, peer, p_entry);
436 sc->sc_peers_num++;
437
438 if (if_getlinkstate(sc->sc_ifp) == LINK_STATE_UP)
439 wg_timers_enable(peer);
440
441 DPRINTF(sc, "Peer %" PRIu64 " created\n", peer->p_id);
442 return (peer);
443 }
444
445 static void
446 wg_peer_free_deferred(struct noise_remote *r)
447 {
448 struct wg_peer *peer = noise_remote_arg(r);
449
450 /* While there are no references remaining, we may still have
451 * p_{send,recv} executing (think empty queue, but wg_deliver_{in,out}
452 * need to check the queue). We should wait for them and then free. */
453 GROUPTASK_DRAIN(&peer->p_recv);
454 GROUPTASK_DRAIN(&peer->p_send);
455 taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv);
456 taskqgroup_detach(qgroup_wg_tqg, &peer->p_send);
457
458 wg_queue_deinit(&peer->p_decrypt_serial);
459 wg_queue_deinit(&peer->p_encrypt_serial);
460 wg_queue_deinit(&peer->p_stage_queue);
461
462 counter_u64_free(peer->p_tx_bytes);
463 counter_u64_free(peer->p_rx_bytes);
464 rw_destroy(&peer->p_endpoint_lock);
465 mtx_destroy(&peer->p_handshake_mtx);
466
467 cookie_maker_free(&peer->p_cookie);
468
469 free(peer, M_WG);
470 }
471
472 static void
473 wg_peer_destroy(struct wg_peer *peer)
474 {
475 struct wg_softc *sc = peer->p_sc;
476 sx_assert(&sc->sc_lock, SX_XLOCKED);
477
478 /* Disable remote and timers. This will prevent any new handshakes
479 * from occurring. */
480 noise_remote_disable(peer->p_remote);
481 wg_timers_disable(peer);
482
483 /* Now we can remove all allowed IPs so no more packets will be routed
484 * to the peer. */
485 wg_aip_remove_all(sc, peer);
486
487 /* Remove peer from the interface, then free. Some references may still
488 * exist to p_remote, so noise_remote_free will wait until they're all
489 * put to call wg_peer_free_deferred. */
490 sc->sc_peers_num--;
491 TAILQ_REMOVE(&sc->sc_peers, peer, p_entry);
492 DPRINTF(sc, "Peer %" PRIu64 " destroyed\n", peer->p_id);
493 noise_remote_free(peer->p_remote, wg_peer_free_deferred);
494 }
495
496 static void
497 wg_peer_destroy_all(struct wg_softc *sc)
498 {
499 struct wg_peer *peer, *tpeer;
500 TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer)
501 wg_peer_destroy(peer);
502 }
503
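/*
 * Adopt the endpoint a packet arrived from, skipping the write lock when the
 * endpoint is unchanged.
 */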
504 static void
505 wg_peer_set_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
506 {
507 MPASS(e->e_remote.r_sa.sa_family != 0);
508 if (memcmp(e, &peer->p_endpoint, sizeof(*e)) == 0)
509 return;
510
511 rw_wlock(&peer->p_endpoint_lock);
512 peer->p_endpoint = *e;
513 rw_wunlock(&peer->p_endpoint_lock);
514 }
515
516 static void
517 wg_peer_clear_src(struct wg_peer *peer)
518 {
519 rw_wlock(&peer->p_endpoint_lock);
520 bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local));
521 rw_wunlock(&peer->p_endpoint_lock);
522 }
523
524 static void
525 wg_peer_get_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
526 {
527 rw_rlock(&peer->p_endpoint_lock);
528 *e = peer->p_endpoint;
529 rw_runlock(&peer->p_endpoint_lock);
530 }
531
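/*
 * Convert an address/CIDR pair into the key/mask format used by the
 * allowed-IPs radix trees; the masked address length doubles as the radix
 * key length.
 */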
532 static int
533 wg_aip_addrinfo(struct wg_aip *aip, const void *baddr, uint8_t cidr)
534 {
535 #if defined(INET) || defined(INET6)
536 struct aip_addr *addr, *mask;
537
538 addr = &aip->a_addr;
539 mask = &aip->a_mask;
540 #endif
541 switch (aip->a_af) {
542 #ifdef INET
543 case AF_INET:
544 if (cidr > 32) cidr = 32;
545 addr->in = *(const struct in_addr *)baddr;
546 mask->ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);
547 addr->ip &= mask->ip;
548 addr->length = mask->length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
549 break;
550 #endif
551 #ifdef INET6
552 case AF_INET6:
553 if (cidr > 128) cidr = 128;
554 addr->in6 = *(const struct in6_addr *)baddr;
555 in6_prefixlen2mask(&mask->in6, cidr);
556 for (int i = 0; i < 4; i++)
557 addr->ip6[i] &= mask->ip6[i];
558 addr->length = mask->length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
559 break;
560 #endif
561 default:
562 return (EAFNOSUPPORT);
563 }
564
565 return (0);
566 }
567
568 /* Allowed IP */
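/*
 * Insert an allowed IP into the per-family radix tree.  If the exact prefix
 * already exists it is re-assigned to this peer rather than duplicated.
 */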
569 static int
570 wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af,
571 const void *baddr, uint8_t cidr)
572 {
573 struct radix_node_head *root = NULL;
574 struct radix_node *node;
575 struct wg_aip *aip;
576 int ret = 0;
577
578 aip = malloc(sizeof(*aip), M_WG, M_WAITOK | M_ZERO);
579 aip->a_peer = peer;
580 aip->a_af = af;
581
582 ret = wg_aip_addrinfo(aip, baddr, cidr);
583 if (ret != 0) {
584 free(aip, M_WG);
585 return (ret);
586 }
587
588 root = af == AF_INET ? sc->sc_aip4 : sc->sc_aip6;
589 MPASS(root != NULL);
590 RADIX_NODE_HEAD_LOCK(root);
591 node = root->rnh_addaddr(&aip->a_addr, &aip->a_mask, &root->rh, aip->a_nodes);
592 if (node == aip->a_nodes) {
593 LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
594 peer->p_aips_num++;
595 } else if (!node)
596 node = root->rnh_lookup(&aip->a_addr, &aip->a_mask, &root->rh);
597 if (!node) {
598 free(aip, M_WG);
599 ret = ENOMEM;
600 } else if (node != aip->a_nodes) {
601 free(aip, M_WG);
602 aip = (struct wg_aip *)node;
603 if (aip->a_peer != peer) {
604 LIST_REMOVE(aip, a_entry);
605 aip->a_peer->p_aips_num--;
606 aip->a_peer = peer;
607 LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
608 aip->a_peer->p_aips_num++;
609 }
610 }
611 RADIX_NODE_HEAD_UNLOCK(root);
612 return (ret);
613 }
614
615 static int
616 wg_aip_del(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af,
617 const void *baddr, uint8_t cidr)
618 {
619 struct radix_node_head *root = NULL;
620 struct radix_node *dnode __diagused, *node;
621 struct wg_aip *aip, addr;
622 int ret = 0;
623
624 /*
625 * We need to be sure that all padding is cleared, as it is above when
626 * new AllowedIPs are added, since we want to do a direct comparison.
627 */
628 memset(&addr, 0, sizeof(addr));
629 addr.a_af = af;
630
631 ret = wg_aip_addrinfo(&addr, baddr, cidr);
632 if (ret != 0)
633 return (ret);
634
635 root = af == AF_INET ? sc->sc_aip4 : sc->sc_aip6;
636
637 MPASS(root != NULL);
638 RADIX_NODE_HEAD_LOCK(root);
639
640 node = root->rnh_lookup(&addr.a_addr, &addr.a_mask, &root->rh);
641 if (node == NULL) {
642 RADIX_NODE_HEAD_UNLOCK(root);
643 return (0);
644 }
645
646 aip = (struct wg_aip *)node;
647 if (aip->a_peer != peer) {
648 /*
649 * They could have specified an allowed-ip that belonged to a
650 * different peer, in which case our job is done because the
651 * AllowedIP has been removed.
652 */
653 RADIX_NODE_HEAD_UNLOCK(root);
654 return (0);
655 }
656
657 dnode = root->rnh_deladdr(&aip->a_addr, &aip->a_mask, &root->rh);
658 MPASS(dnode == node);
659 RADIX_NODE_HEAD_UNLOCK(root);
660
661 LIST_REMOVE(aip, a_entry);
662 peer->p_aips_num--;
663 free(aip, M_WG);
664 return (0);
665 }
666
667 static struct wg_peer *
668 wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a)
669 {
670 struct radix_node_head *root;
671 struct radix_node *node;
672 struct wg_peer *peer;
673 struct aip_addr addr;
674 RADIX_NODE_HEAD_RLOCK_TRACKER;
675
676 switch (af) {
677 case AF_INET:
678 root = sc->sc_aip4;
679 memcpy(&addr.in, a, sizeof(addr.in));
680 addr.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
681 break;
682 case AF_INET6:
683 root = sc->sc_aip6;
684 memcpy(&addr.in6, a, sizeof(addr.in6));
685 addr.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
686 break;
687 default:
688 return NULL;
689 }
690
691 RADIX_NODE_HEAD_RLOCK(root);
692 node = root->rnh_matchaddr(&addr, &root->rh);
693 if (node != NULL) {
694 peer = ((struct wg_aip *)node)->a_peer;
695 noise_remote_ref(peer->p_remote);
696 } else {
697 peer = NULL;
698 }
699 RADIX_NODE_HEAD_RUNLOCK(root);
700
701 return (peer);
702 }
703
704 static void
705 wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer)
706 {
707 struct wg_aip *aip, *taip;
708
709 RADIX_NODE_HEAD_LOCK(sc->sc_aip4);
710 LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
711 if (aip->a_af == AF_INET) {
712 if (sc->sc_aip4->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip4->rh) == NULL)
713 panic("failed to delete aip %p", aip);
714 LIST_REMOVE(aip, a_entry);
715 peer->p_aips_num--;
716 free(aip, M_WG);
717 }
718 }
719 RADIX_NODE_HEAD_UNLOCK(sc->sc_aip4);
720
721 RADIX_NODE_HEAD_LOCK(sc->sc_aip6);
722 LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
723 if (aip->a_af == AF_INET6) {
724 if (sc->sc_aip6->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip6->rh) == NULL)
725 panic("failed to delete aip %p", aip);
726 LIST_REMOVE(aip, a_entry);
727 peer->p_aips_num--;
728 free(aip, M_WG);
729 }
730 }
731 RADIX_NODE_HEAD_UNLOCK(sc->sc_aip6);
732
733 if (!LIST_EMPTY(&peer->p_aips) || peer->p_aips_num != 0)
734 panic("wg_aip_remove_all could not delete all %p", peer);
735 }
736
737 static int
738 wg_socket_init(struct wg_softc *sc, in_port_t port)
739 {
740 struct ucred *cred = sc->sc_ucred;
741 struct socket *so4 = NULL, *so6 = NULL;
742 int rc;
743
744 sx_assert(&sc->sc_lock, SX_XLOCKED);
745
746 if (!cred)
747 return (EBUSY);
748
749 /*
750 * For socket creation, we use the creds of the thread that created the
751 * tunnel rather than the current thread to maintain the semantics that
752 * WireGuard has on Linux with network namespaces -- that the sockets
753 * are created in their home vnet so that they can be configured and
754 * functionally attached to a foreign vnet as the jail's only interface
755 * to the network.
756 */
757 #ifdef INET
758 rc = socreate(AF_INET, &so4, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
759 if (rc)
760 goto out;
761
762 rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc);
763 /*
764 * udp_set_kernel_tunneling can only fail if there is already a tunneling function set.
765 * This should never happen with a new socket.
766 */
767 MPASS(rc == 0);
768 #endif
769
770 #ifdef INET6
771 rc = socreate(AF_INET6, &so6, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
772 if (rc)
773 goto out;
774 rc = udp_set_kernel_tunneling(so6, wg_input, NULL, sc);
775 MPASS(rc == 0);
776 #endif
777
778 if (sc->sc_socket.so_user_cookie) {
779 rc = wg_socket_set_sockopt(so4, so6, SO_USER_COOKIE, &sc->sc_socket.so_user_cookie, sizeof(sc->sc_socket.so_user_cookie));
780 if (rc)
781 goto out;
782 }
783 rc = wg_socket_set_sockopt(so4, so6, SO_SETFIB, &sc->sc_socket.so_fibnum, sizeof(sc->sc_socket.so_fibnum));
784 if (rc)
785 goto out;
786
787 rc = wg_socket_bind(&so4, &so6, &port);
788 if (!rc) {
789 sc->sc_socket.so_port = port;
790 wg_socket_set(sc, so4, so6);
791 }
792 out:
793 if (rc) {
794 if (so4 != NULL)
795 soclose(so4);
796 if (so6 != NULL)
797 soclose(so6);
798 }
799 return (rc);
800 }
801
802 static int wg_socket_set_sockopt(struct socket *so4, struct socket *so6, int name, void *val, size_t len)
803 {
804 int ret4 = 0, ret6 = 0;
805 struct sockopt sopt = {
806 .sopt_dir = SOPT_SET,
807 .sopt_level = SOL_SOCKET,
808 .sopt_name = name,
809 .sopt_val = val,
810 .sopt_valsize = len
811 };
812
813 if (so4)
814 ret4 = sosetopt(so4, &sopt);
815 if (so6)
816 ret6 = sosetopt(so6, &sopt);
817 return (ret4 ?: ret6);
818 }
819
820 static int wg_socket_set_cookie(struct wg_softc *sc, uint32_t user_cookie)
821 {
822 struct wg_socket *so = &sc->sc_socket;
823 int ret;
824
825 sx_assert(&sc->sc_lock, SX_XLOCKED);
826 ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_USER_COOKIE, &user_cookie, sizeof(user_cookie));
827 if (!ret)
828 so->so_user_cookie = user_cookie;
829 return (ret);
830 }
831
832 static int wg_socket_set_fibnum(struct wg_softc *sc, int fibnum)
833 {
834 struct wg_socket *so = &sc->sc_socket;
835 int ret;
836
837 sx_assert(&sc->sc_lock, SX_XLOCKED);
838
839 ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_SETFIB, &fibnum, sizeof(fibnum));
840 if (!ret)
841 so->so_fibnum = fibnum;
842 return (ret);
843 }
844
845 static void
846 wg_socket_uninit(struct wg_softc *sc)
847 {
848 wg_socket_set(sc, NULL, NULL);
849 }
850
851 static void
852 wg_socket_set(struct wg_softc *sc, struct socket *new_so4, struct socket *new_so6)
853 {
854 struct wg_socket *so = &sc->sc_socket;
855 struct socket *so4, *so6;
856
857 sx_assert(&sc->sc_lock, SX_XLOCKED);
858
859 so4 = atomic_load_ptr(&so->so_so4);
860 so6 = atomic_load_ptr(&so->so_so6);
861 atomic_store_ptr(&so->so_so4, new_so4);
862 atomic_store_ptr(&so->so_so6, new_so6);
863
864 if (!so4 && !so6)
865 return;
866 NET_EPOCH_WAIT();
867 if (so4)
868 soclose(so4);
869 if (so6)
870 soclose(so6);
871 }
872
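/*
 * Bind both sockets to the requested port.  When port 0 is requested, the
 * port picked for the IPv4 socket is reused for the IPv6 socket so both
 * listen on the same port; a family that reports EADDRNOTAVAIL is closed and
 * dropped.
 */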
873 static int
874 wg_socket_bind(struct socket **in_so4, struct socket **in_so6, in_port_t *requested_port)
875 {
876 struct socket *so4 = *in_so4, *so6 = *in_so6;
877 int ret4 = 0, ret6 = 0;
878 in_port_t port = *requested_port;
879 struct sockaddr_in sin = {
880 .sin_len = sizeof(struct sockaddr_in),
881 .sin_family = AF_INET,
882 .sin_port = htons(port)
883 };
884 struct sockaddr_in6 sin6 = {
885 .sin6_len = sizeof(struct sockaddr_in6),
886 .sin6_family = AF_INET6,
887 .sin6_port = htons(port)
888 };
889
890 if (so4) {
891 ret4 = sobind(so4, (struct sockaddr *)&sin, curthread);
892 if (ret4 && ret4 != EADDRNOTAVAIL)
893 return (ret4);
894 if (!ret4 && !sin.sin_port) {
895 struct sockaddr_in bound_sin =
896 { .sin_len = sizeof(bound_sin) };
897 int ret;
898
899 ret = sosockaddr(so4, (struct sockaddr *)&bound_sin);
900 if (ret)
901 return (ret);
902 port = ntohs(bound_sin.sin_port);
903 sin6.sin6_port = bound_sin.sin_port;
904 }
905 }
906
907 if (so6) {
908 ret6 = sobind(so6, (struct sockaddr *)&sin6, curthread);
909 if (ret6 && ret6 != EADDRNOTAVAIL)
910 return (ret6);
911 if (!ret6 && !sin6.sin6_port) {
912 struct sockaddr_in6 bound_sin6 =
913 { .sin6_len = sizeof(bound_sin6) };
914 int ret;
915
916 ret = sosockaddr(so6, (struct sockaddr *)&bound_sin6);
917 if (ret)
918 return (ret);
919 port = ntohs(bound_sin6.sin6_port);
920 }
921 }
922
923 if (ret4 && ret6)
924 return (ret4);
925 *requested_port = port;
926 if (ret4 && !ret6 && so4) {
927 soclose(so4);
928 *in_so4 = NULL;
929 } else if (ret6 && !ret4 && so6) {
930 soclose(so6);
931 *in_so6 = NULL;
932 }
933 return (0);
934 }
935
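/*
 * Transmit an mbuf to the given endpoint over the UDP socket matching its
 * address family, attaching IP_SENDSRCADDR/IPV6_PKTINFO control data so the
 * packet originates from the cached local address.
 */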
936 static int
937 wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m)
938 {
939 struct epoch_tracker et;
940 struct sockaddr *sa;
941 struct wg_socket *so = &sc->sc_socket;
942 struct socket *so4, *so6;
943 struct mbuf *control = NULL;
944 int ret = 0;
945 size_t len = m->m_pkthdr.len;
946
947 /* Get local control address before locking */
948 if (e->e_remote.r_sa.sa_family == AF_INET) {
949 if (e->e_local.l_in.s_addr != INADDR_ANY)
950 control = sbcreatecontrol((caddr_t)&e->e_local.l_in,
951 sizeof(struct in_addr), IP_SENDSRCADDR,
952 IPPROTO_IP, M_NOWAIT);
953 #ifdef INET6
954 } else if (e->e_remote.r_sa.sa_family == AF_INET6) {
955 if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6))
956 control = sbcreatecontrol((caddr_t)&e->e_local.l_pktinfo6,
957 sizeof(struct in6_pktinfo), IPV6_PKTINFO,
958 IPPROTO_IPV6, M_NOWAIT);
959 #endif
960 } else {
961 m_freem(m);
962 return (EAFNOSUPPORT);
963 }
964
965 /* Get remote address */
966 sa = &e->e_remote.r_sa;
967
968 NET_EPOCH_ENTER(et);
969 so4 = atomic_load_ptr(&so->so_so4);
970 so6 = atomic_load_ptr(&so->so_so6);
971 if (e->e_remote.r_sa.sa_family == AF_INET && so4 != NULL)
972 ret = sosend(so4, sa, NULL, m, control, 0, curthread);
973 else if (e->e_remote.r_sa.sa_family == AF_INET6 && so6 != NULL)
974 ret = sosend(so6, sa, NULL, m, control, 0, curthread);
975 else {
976 ret = ENOTCONN;
977 m_freem(control);
978 m_freem(m);
979 }
980 NET_EPOCH_EXIT(et);
981 if (ret == 0) {
982 if_inc_counter(sc->sc_ifp, IFCOUNTER_OPACKETS, 1);
983 if_inc_counter(sc->sc_ifp, IFCOUNTER_OBYTES, len);
984 }
985 return (ret);
986 }
987
988 static void
989 wg_send_buf(struct wg_softc *sc, struct wg_endpoint *e, uint8_t *buf, size_t len)
990 {
991 struct mbuf *m;
992 int ret = 0;
993 bool retried = false;
994
995 retry:
996 m = m_get2(len, M_NOWAIT, MT_DATA, M_PKTHDR);
997 if (!m) {
998 ret = ENOMEM;
999 goto out;
1000 }
1001 m_copyback(m, 0, len, buf);
1002
1003 if (ret == 0) {
1004 ret = wg_send(sc, e, m);
1005 /* Retry if we couldn't bind to e->e_local */
1006 if (ret == EADDRNOTAVAIL && !retried) {
1007 bzero(&e->e_local, sizeof(e->e_local));
1008 retried = true;
1009 goto retry;
1010 }
1011 } else {
1012 ret = wg_send(sc, e, m);
1013 }
1014 out:
1015 if (ret)
1016 DPRINTF(sc, "Unable to send packet: %d\n", ret);
1017 }
1018
1019 /* Timers */
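/*
 * The wg_timers_event_*() functions are called from the packet paths and
 * (re)arm callouts; the wg_timers_run_*() functions are the callout bodies.
 * Arming is gated on p_enabled inside the net epoch so that
 * wg_timers_disable() can wait out any in-flight readers.
 */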
1020 static void
1021 wg_timers_enable(struct wg_peer *peer)
1022 {
1023 atomic_store_bool(&peer->p_enabled, true);
1024 wg_timers_run_persistent_keepalive(peer);
1025 }
1026
1027 static void
1028 wg_timers_disable(struct wg_peer *peer)
1029 {
1030 /* By setting p_enabled = false, then calling NET_EPOCH_WAIT, we can be
1031 * sure no new handshakes are created after the wait. This is because
1032 * all callout_resets (scheduling the callout) are guarded by
1033 * p_enabled. We can be sure all sections that read p_enabled and then
1034 * optionally call callout_reset are finished as they are surrounded by
1035 * NET_EPOCH_{ENTER,EXIT}.
1036 *
1037 * However, as new callouts may be scheduled during NET_EPOCH_WAIT (but
1038 * not after), we stop all callouts leaving no callouts active.
1039 *
1040 * We should also pull NET_EPOCH_WAIT out of the FOREACH(peer) loops, but the
1041 * performance impact is acceptable for the time being. */
1042 atomic_store_bool(&peer->p_enabled, false);
1043 NET_EPOCH_WAIT();
1044 atomic_store_bool(&peer->p_need_another_keepalive, false);
1045
1046 callout_stop(&peer->p_new_handshake);
1047 callout_stop(&peer->p_send_keepalive);
1048 callout_stop(&peer->p_retry_handshake);
1049 callout_stop(&peer->p_persistent_keepalive);
1050 callout_stop(&peer->p_zero_key_material);
1051 }
1052
1053 static void
1054 wg_timers_set_persistent_keepalive(struct wg_peer *peer, uint16_t interval)
1055 {
1056 struct epoch_tracker et;
1057 if (interval != peer->p_persistent_keepalive_interval) {
1058 atomic_store_16(&peer->p_persistent_keepalive_interval, interval);
1059 NET_EPOCH_ENTER(et);
1060 if (atomic_load_bool(&peer->p_enabled))
1061 wg_timers_run_persistent_keepalive(peer);
1062 NET_EPOCH_EXIT(et);
1063 }
1064 }
1065
1066 static void
1067 wg_timers_get_last_handshake(struct wg_peer *peer, struct wg_timespec64 *time)
1068 {
1069 mtx_lock(&peer->p_handshake_mtx);
1070 time->tv_sec = peer->p_handshake_complete.tv_sec;
1071 time->tv_nsec = peer->p_handshake_complete.tv_nsec;
1072 mtx_unlock(&peer->p_handshake_mtx);
1073 }
1074
1075 static void
1076 wg_timers_event_data_sent(struct wg_peer *peer)
1077 {
1078 struct epoch_tracker et;
1079 NET_EPOCH_ENTER(et);
1080 if (atomic_load_bool(&peer->p_enabled) &&
1081 !callout_pending(&peer->p_new_handshake))
1082 callout_reset(&peer->p_new_handshake, MSEC_2_TICKS(
1083 NEW_HANDSHAKE_TIMEOUT * 1000 +
1084 arc4random_uniform(REKEY_TIMEOUT_JITTER)),
1085 wg_timers_run_new_handshake, peer);
1086 NET_EPOCH_EXIT(et);
1087 }
1088
1089 static void
1090 wg_timers_event_data_received(struct wg_peer *peer)
1091 {
1092 struct epoch_tracker et;
1093 NET_EPOCH_ENTER(et);
1094 if (atomic_load_bool(&peer->p_enabled)) {
1095 if (!callout_pending(&peer->p_send_keepalive))
1096 callout_reset(&peer->p_send_keepalive,
1097 MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1098 wg_timers_run_send_keepalive, peer);
1099 else
1100 atomic_store_bool(&peer->p_need_another_keepalive,
1101 true);
1102 }
1103 NET_EPOCH_EXIT(et);
1104 }
1105
1106 static void
1107 wg_timers_event_any_authenticated_packet_sent(struct wg_peer *peer)
1108 {
1109 callout_stop(&peer->p_send_keepalive);
1110 }
1111
1112 static void
1113 wg_timers_event_any_authenticated_packet_received(struct wg_peer *peer)
1114 {
1115 callout_stop(&peer->p_new_handshake);
1116 }
1117
1118 static void
1119 wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *peer)
1120 {
1121 struct epoch_tracker et;
1122 uint16_t interval;
1123 NET_EPOCH_ENTER(et);
1124 interval = atomic_load_16(&peer->p_persistent_keepalive_interval);
1125 if (atomic_load_bool(&peer->p_enabled) && interval > 0)
1126 callout_reset(&peer->p_persistent_keepalive,
1127 MSEC_2_TICKS(interval * 1000),
1128 wg_timers_run_persistent_keepalive, peer);
1129 NET_EPOCH_EXIT(et);
1130 }
1131
1132 static void
1133 wg_timers_event_handshake_initiated(struct wg_peer *peer)
1134 {
1135 struct epoch_tracker et;
1136 NET_EPOCH_ENTER(et);
1137 if (atomic_load_bool(&peer->p_enabled))
1138 callout_reset(&peer->p_retry_handshake, MSEC_2_TICKS(
1139 REKEY_TIMEOUT * 1000 +
1140 arc4random_uniform(REKEY_TIMEOUT_JITTER)),
1141 wg_timers_run_retry_handshake, peer);
1142 NET_EPOCH_EXIT(et);
1143 }
1144
1145 static void
1146 wg_timers_event_handshake_complete(struct wg_peer *peer)
1147 {
1148 struct epoch_tracker et;
1149 NET_EPOCH_ENTER(et);
1150 if (atomic_load_bool(&peer->p_enabled)) {
1151 mtx_lock(&peer->p_handshake_mtx);
1152 callout_stop(&peer->p_retry_handshake);
1153 peer->p_handshake_retries = 0;
1154 getnanotime(&peer->p_handshake_complete);
1155 mtx_unlock(&peer->p_handshake_mtx);
1156 wg_timers_run_send_keepalive(peer);
1157 }
1158 NET_EPOCH_EXIT(et);
1159 }
1160
1161 static void
1162 wg_timers_event_session_derived(struct wg_peer *peer)
1163 {
1164 struct epoch_tracker et;
1165 NET_EPOCH_ENTER(et);
1166 if (atomic_load_bool(&peer->p_enabled))
1167 callout_reset(&peer->p_zero_key_material,
1168 MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1169 wg_timers_run_zero_key_material, peer);
1170 NET_EPOCH_EXIT(et);
1171 }
1172
1173 static void
1174 wg_timers_event_want_initiation(struct wg_peer *peer)
1175 {
1176 struct epoch_tracker et;
1177 NET_EPOCH_ENTER(et);
1178 if (atomic_load_bool(&peer->p_enabled))
1179 wg_timers_run_send_initiation(peer, false);
1180 NET_EPOCH_EXIT(et);
1181 }
1182
1183 static void
1184 wg_timers_run_send_initiation(struct wg_peer *peer, bool is_retry)
1185 {
1186 if (!is_retry)
1187 peer->p_handshake_retries = 0;
1188 if (noise_remote_initiation_expired(peer->p_remote) == ETIMEDOUT)
1189 wg_send_initiation(peer);
1190 }
1191
1192 static void
1193 wg_timers_run_retry_handshake(void *_peer)
1194 {
1195 struct epoch_tracker et;
1196 struct wg_peer *peer = _peer;
1197
1198 mtx_lock(&peer->p_handshake_mtx);
1199 if (peer->p_handshake_retries <= MAX_TIMER_HANDSHAKES) {
1200 peer->p_handshake_retries++;
1201 mtx_unlock(&peer->p_handshake_mtx);
1202
1203 DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1204 "after %d seconds, retrying (try %d)\n", peer->p_id,
1205 REKEY_TIMEOUT, peer->p_handshake_retries + 1);
1206 wg_peer_clear_src(peer);
1207 wg_timers_run_send_initiation(peer, true);
1208 } else {
1209 mtx_unlock(&peer->p_handshake_mtx);
1210
1211 DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1212 "after %d retries, giving up\n", peer->p_id,
1213 MAX_TIMER_HANDSHAKES + 2);
1214
1215 callout_stop(&peer->p_send_keepalive);
1216 wg_queue_purge(&peer->p_stage_queue);
1217 NET_EPOCH_ENTER(et);
1218 if (atomic_load_bool(&peer->p_enabled) &&
1219 !callout_pending(&peer->p_zero_key_material))
1220 callout_reset(&peer->p_zero_key_material,
1221 MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1222 wg_timers_run_zero_key_material, peer);
1223 NET_EPOCH_EXIT(et);
1224 }
1225 }
1226
1227 static void
1228 wg_timers_run_send_keepalive(void *_peer)
1229 {
1230 struct epoch_tracker et;
1231 struct wg_peer *peer = _peer;
1232
1233 wg_send_keepalive(peer);
1234 NET_EPOCH_ENTER(et);
1235 if (atomic_load_bool(&peer->p_enabled) &&
1236 atomic_load_bool(&peer->p_need_another_keepalive)) {
1237 atomic_store_bool(&peer->p_need_another_keepalive, false);
1238 callout_reset(&peer->p_send_keepalive,
1239 MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1240 wg_timers_run_send_keepalive, peer);
1241 }
1242 NET_EPOCH_EXIT(et);
1243 }
1244
1245 static void
1246 wg_timers_run_new_handshake(void *_peer)
1247 {
1248 struct wg_peer *peer = _peer;
1249
1250 DPRINTF(peer->p_sc, "Retrying handshake with peer %" PRIu64 " because we "
1251 "stopped hearing back after %d seconds\n",
1252 peer->p_id, NEW_HANDSHAKE_TIMEOUT);
1253
1254 wg_peer_clear_src(peer);
1255 wg_timers_run_send_initiation(peer, false);
1256 }
1257
1258 static void
1259 wg_timers_run_zero_key_material(void *_peer)
1260 {
1261 struct wg_peer *peer = _peer;
1262
1263 DPRINTF(peer->p_sc, "Zeroing out keys for peer %" PRIu64 ", since we "
1264 "haven't received a new one in %d seconds\n",
1265 peer->p_id, REJECT_AFTER_TIME * 3);
1266 noise_remote_keypairs_clear(peer->p_remote);
1267 }
1268
1269 static void
1270 wg_timers_run_persistent_keepalive(void *_peer)
1271 {
1272 struct wg_peer *peer = _peer;
1273
1274 if (atomic_load_16(&peer->p_persistent_keepalive_interval) > 0)
1275 wg_send_keepalive(peer);
1276 }
1277
1278 /* TODO Handshake */
1279 static void
1280 wg_peer_send_buf(struct wg_peer *peer, uint8_t *buf, size_t len)
1281 {
1282 struct wg_endpoint endpoint;
1283
1284 counter_u64_add(peer->p_tx_bytes, len);
1285 wg_timers_event_any_authenticated_packet_traversal(peer);
1286 wg_timers_event_any_authenticated_packet_sent(peer);
1287 wg_peer_get_endpoint(peer, &endpoint);
1288 wg_send_buf(peer->p_sc, &endpoint, buf, len);
1289 }
1290
1291 static void
1292 wg_send_initiation(struct wg_peer *peer)
1293 {
1294 struct wg_pkt_initiation pkt;
1295
1296 if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue,
1297 pkt.es, pkt.ets) != 0)
1298 return;
1299
1300 DPRINTF(peer->p_sc, "Sending handshake initiation to peer %" PRIu64 "\n", peer->p_id);
1301
1302 pkt.t = WG_PKT_INITIATION;
1303 cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1304 sizeof(pkt) - sizeof(pkt.m));
1305 wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
1306 wg_timers_event_handshake_initiated(peer);
1307 }
1308
1309 static void
1310 wg_send_response(struct wg_peer *peer)
1311 {
1312 struct wg_pkt_response pkt;
1313
1314 if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx,
1315 pkt.ue, pkt.en) != 0)
1316 return;
1317
1318 DPRINTF(peer->p_sc, "Sending handshake response to peer %" PRIu64 "\n", peer->p_id);
1319
1320 wg_timers_event_session_derived(peer);
1321 pkt.t = WG_PKT_RESPONSE;
1322 cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1323 sizeof(pkt)-sizeof(pkt.m));
1324 wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt));
1325 }
1326
1327 static void
1328 wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx,
1329 struct wg_endpoint *e)
1330 {
1331 struct wg_pkt_cookie pkt;
1332
1333 DPRINTF(sc, "Sending cookie response for denied handshake message\n");
1334
1335 pkt.t = WG_PKT_COOKIE;
1336 pkt.r_idx = idx;
1337
1338 cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce,
1339 pkt.ec, &e->e_remote.r_sa);
1340 wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt));
1341 }
1342
1343 static void
1344 wg_send_keepalive(struct wg_peer *peer)
1345 {
1346 struct wg_packet *pkt;
1347 struct mbuf *m;
1348
1349 if (wg_queue_len(&peer->p_stage_queue) > 0)
1350 goto send;
1351 if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL)
1352 return;
1353 if ((pkt = wg_packet_alloc(m)) == NULL) {
1354 m_freem(m);
1355 return;
1356 }
1357 wg_queue_push_staged(&peer->p_stage_queue, pkt);
1358 DPRINTF(peer->p_sc, "Sending keepalive packet to peer %" PRIu64 "\n", peer->p_id);
1359 send:
1360 wg_peer_send_staged(peer);
1361 }
1362
1363 static void
1364 wg_handshake(struct wg_softc *sc, struct wg_packet *pkt)
1365 {
1366 struct wg_pkt_initiation *init;
1367 struct wg_pkt_response *resp;
1368 struct wg_pkt_cookie *cook;
1369 struct wg_endpoint *e;
1370 struct wg_peer *peer;
1371 struct mbuf *m;
1372 struct noise_remote *remote = NULL;
1373 int res;
1374 bool underload = false;
1375 static sbintime_t wg_last_underload; /* sbinuptime */
1376
1377 underload = wg_queue_len(&sc->sc_handshake_queue) >= MAX_QUEUED_HANDSHAKES / 8;
1378 if (underload) {
1379 wg_last_underload = getsbinuptime();
1380 } else if (wg_last_underload) {
1381 underload = wg_last_underload + UNDERLOAD_TIMEOUT * SBT_1S > getsbinuptime();
1382 if (!underload)
1383 wg_last_underload = 0;
1384 }
1385
1386 m = pkt->p_mbuf;
1387 e = &pkt->p_endpoint;
1388
1389 if ((pkt->p_mbuf = m = m_pullup(m, m->m_pkthdr.len)) == NULL)
1390 goto error;
1391
1392 switch (*mtod(m, uint32_t *)) {
1393 case WG_PKT_INITIATION:
1394 init = mtod(m, struct wg_pkt_initiation *);
1395
1396 res = cookie_checker_validate_macs(&sc->sc_cookie, &init->m,
1397 init, sizeof(*init) - sizeof(init->m),
1398 underload, &e->e_remote.r_sa,
1399 if_getvnet(sc->sc_ifp));
1400
1401 if (res == EINVAL) {
1402 DPRINTF(sc, "Invalid initiation MAC\n");
1403 goto error;
1404 } else if (res == ECONNREFUSED) {
1405 DPRINTF(sc, "Handshake ratelimited\n");
1406 goto error;
1407 } else if (res == EAGAIN) {
1408 wg_send_cookie(sc, &init->m, init->s_idx, e);
1409 goto error;
1410 } else if (res != 0) {
1411 panic("unexpected response: %d\n", res);
1412 }
1413
1414 if (noise_consume_initiation(sc->sc_local, &remote,
1415 init->s_idx, init->ue, init->es, init->ets) != 0) {
1416 DPRINTF(sc, "Invalid handshake initiation\n");
1417 goto error;
1418 }
1419
1420 peer = noise_remote_arg(remote);
1421
1422 DPRINTF(sc, "Receiving handshake initiation from peer %" PRIu64 "\n", peer->p_id);
1423
1424 wg_peer_set_endpoint(peer, e);
1425 wg_send_response(peer);
1426 break;
1427 case WG_PKT_RESPONSE:
1428 resp = mtod(m, struct wg_pkt_response *);
1429
1430 res = cookie_checker_validate_macs(&sc->sc_cookie, &resp->m,
1431 resp, sizeof(*resp) - sizeof(resp->m),
1432 underload, &e->e_remote.r_sa,
1433 if_getvnet(sc->sc_ifp));
1434
1435 if (res == EINVAL) {
1436 DPRINTF(sc, "Invalid response MAC\n");
1437 goto error;
1438 } else if (res == ECONNREFUSED) {
1439 DPRINTF(sc, "Handshake ratelimited\n");
1440 goto error;
1441 } else if (res == EAGAIN) {
1442 wg_send_cookie(sc, &resp->m, resp->s_idx, e);
1443 goto error;
1444 } else if (res != 0) {
1445 panic("unexpected response: %d\n", res);
1446 }
1447
1448 if (noise_consume_response(sc->sc_local, &remote,
1449 resp->s_idx, resp->r_idx, resp->ue, resp->en) != 0) {
1450 DPRINTF(sc, "Invalid handshake response\n");
1451 goto error;
1452 }
1453
1454 peer = noise_remote_arg(remote);
1455 DPRINTF(sc, "Receiving handshake response from peer %" PRIu64 "\n", peer->p_id);
1456
1457 wg_peer_set_endpoint(peer, e);
1458 wg_timers_event_session_derived(peer);
1459 wg_timers_event_handshake_complete(peer);
1460 break;
1461 case WG_PKT_COOKIE:
1462 cook = mtod(m, struct wg_pkt_cookie *);
1463
1464 if ((remote = noise_remote_index(sc->sc_local, cook->r_idx)) == NULL) {
1465 DPRINTF(sc, "Unknown cookie index\n");
1466 goto error;
1467 }
1468
1469 peer = noise_remote_arg(remote);
1470
1471 if (cookie_maker_consume_payload(&peer->p_cookie,
1472 cook->nonce, cook->ec) == 0) {
1473 DPRINTF(sc, "Receiving cookie response\n");
1474 } else {
1475 DPRINTF(sc, "Could not decrypt cookie response\n");
1476 goto error;
1477 }
1478
1479 goto not_authenticated;
1480 default:
1481 panic("invalid packet in handshake queue");
1482 }
1483
1484 wg_timers_event_any_authenticated_packet_received(peer);
1485 wg_timers_event_any_authenticated_packet_traversal(peer);
1486
1487 not_authenticated:
1488 counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len);
1489 if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
1490 if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len);
1491 error:
1492 if (remote != NULL)
1493 noise_remote_put(remote);
1494 wg_packet_free(pkt);
1495 }
1496
1497 static void
1498 wg_softc_handshake_receive(struct wg_softc *sc)
1499 {
1500 struct wg_packet *pkt;
1501 while ((pkt = wg_queue_dequeue_handshake(&sc->sc_handshake_queue)) != NULL)
1502 wg_handshake(sc, pkt);
1503 }
1504
1505 static void
1506 wg_mbuf_reset(struct mbuf *m)
1507 {
1508
1509 struct m_tag *t, *tmp;
1510
1511 /*
1512 * We want to reset the mbuf to a newly allocated state, containing
1513 * just the packet contents. Unfortunately FreeBSD doesn't seem to
1514 * offer this anywhere, so we have to make it up as we go. If we can
1515 * get this in kern/kern_mbuf.c, that would be best.
1516 *
1517 * Notice: this may break things unexpectedly but it is better to fail
1518 * closed in the extreme case than leak information in every
1519 * case.
1520 *
1521 * With that said, all this attempts to do is remove any extraneous
1522 * information that could be present.
1523 */
1524
1525 M_ASSERTPKTHDR(m);
1526
1527 m->m_flags &= ~(M_BCAST|M_MCAST|M_VLANTAG|M_PROMISC|M_PROTOFLAGS);
1528
1529 M_HASHTYPE_CLEAR(m);
1530 #ifdef NUMA
1531 m->m_pkthdr.numa_domain = M_NODOM;
1532 #endif
1533 SLIST_FOREACH_SAFE(t, &m->m_pkthdr.tags, m_tag_link, tmp) {
1534 if ((t->m_tag_id != 0 || t->m_tag_cookie != MTAG_WGLOOP) &&
1535 t->m_tag_id != PACKET_TAG_MACLABEL)
1536 m_tag_delete(m, t);
1537 }
1538
1539 KASSERT((m->m_pkthdr.csum_flags & CSUM_SND_TAG) == 0,
1540 ("%s: mbuf %p has a send tag", __func__, m));
1541
1542 m->m_pkthdr.csum_flags = 0;
1543 m->m_pkthdr.PH_per.sixtyfour[0] = 0;
1544 m->m_pkthdr.PH_loc.sixtyfour[0] = 0;
1545 }
1546
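/*
 * Pad plaintext to a multiple of WG_PKT_PADDING bytes without exceeding the
 * MTU; keepalives (p_mtu == 0, zero length) need no padding.
 */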
1547 static inline unsigned int
1548 calculate_padding(struct wg_packet *pkt)
1549 {
1550 unsigned int padded_size, last_unit = pkt->p_mbuf->m_pkthdr.len;
1551
1552 /* Keepalive packets don't set p_mtu, but also have a length of zero. */
1553 if (__predict_false(pkt->p_mtu == 0)) {
1554 padded_size = (last_unit + (WG_PKT_PADDING - 1)) &
1555 ~(WG_PKT_PADDING - 1);
1556 return (padded_size - last_unit);
1557 }
1558
1559 if (__predict_false(last_unit > pkt->p_mtu))
1560 last_unit %= pkt->p_mtu;
1561
1562 padded_size = (last_unit + (WG_PKT_PADDING - 1)) & ~(WG_PKT_PADDING - 1);
1563 if (pkt->p_mtu < padded_size)
1564 padded_size = pkt->p_mtu;
1565 return (padded_size - last_unit);
1566 }
1567
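/*
 * Parallel encrypt step: pad, seal with the keypair, prepend the data header,
 * then mark the packet CRYPTED (or DEAD on failure) and kick the peer's
 * serial send task.
 */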
1568 static void
1569 wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt)
1570 {
1571 static const uint8_t padding[WG_PKT_PADDING] = { 0 };
1572 struct wg_pkt_data *data;
1573 struct wg_peer *peer;
1574 struct noise_remote *remote;
1575 struct mbuf *m;
1576 uint32_t idx;
1577 unsigned int padlen;
1578 enum wg_ring_state state = WG_PACKET_DEAD;
1579
1580 remote = noise_keypair_remote(pkt->p_keypair);
1581 peer = noise_remote_arg(remote);
1582 m = pkt->p_mbuf;
1583
1584 /* Pad the packet */
1585 padlen = calculate_padding(pkt);
1586 if (padlen != 0 && !m_append(m, padlen, padding))
1587 goto out;
1588
1589 /* Do encryption */
1590 if (noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, m) != 0)
1591 goto out;
1592
1593 /* Put header into packet */
1594 M_PREPEND(m, sizeof(struct wg_pkt_data), M_NOWAIT);
1595 if (m == NULL)
1596 goto out;
1597 data = mtod(m, struct wg_pkt_data *);
1598 data->t = WG_PKT_DATA;
1599 data->r_idx = idx;
1600 data->nonce = htole64(pkt->p_nonce);
1601
1602 wg_mbuf_reset(m);
1603 state = WG_PACKET_CRYPTED;
1604 out:
1605 pkt->p_mbuf = m;
1606 atomic_store_rel_int(&pkt->p_state, state);
1607 GROUPTASK_ENQUEUE(&peer->p_send);
1608 noise_remote_put(remote);
1609 }
1610
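/*
 * Parallel decrypt step: strip the data header, open with the keypair and
 * check the inner source address against the allowed-IPs table (cryptokey
 * routing) before handing the packet to the peer's serial receive task.
 */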
1611 static void
1612 wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt)
1613 {
1614 struct wg_peer *peer, *allowed_peer;
1615 struct noise_remote *remote;
1616 struct mbuf *m;
1617 int len;
1618 enum wg_ring_state state = WG_PACKET_DEAD;
1619
1620 remote = noise_keypair_remote(pkt->p_keypair);
1621 peer = noise_remote_arg(remote);
1622 m = pkt->p_mbuf;
1623
1624 /* Read nonce and then adjust to remove the header. */
1625 pkt->p_nonce = le64toh(mtod(m, struct wg_pkt_data *)->nonce);
1626 m_adj(m, sizeof(struct wg_pkt_data));
1627
1628 if (noise_keypair_decrypt(pkt->p_keypair, pkt->p_nonce, m) != 0)
1629 goto out;
1630
1631 /* A packet with length 0 is a keepalive packet */
1632 if (__predict_false(m->m_pkthdr.len == 0)) {
1633 DPRINTF(sc, "Receiving keepalive packet from peer "
1634 "%" PRIu64 "\n", peer->p_id);
1635 state = WG_PACKET_CRYPTED;
1636 goto out;
1637 }
1638
1639 /*
1640 * We can let the network stack handle the intricate validation of the
1641 * IP header, we just worry about the sizeof and the version, so we can
1642 * read the source address in wg_aip_lookup.
1643 */
1644
1645 if (determine_af_and_pullup(&m, &pkt->p_af) == 0) {
1646 if (pkt->p_af == AF_INET) {
1647 struct ip *ip = mtod(m, struct ip *);
1648 allowed_peer = wg_aip_lookup(sc, AF_INET, &ip->ip_src);
1649 len = ntohs(ip->ip_len);
1650 if (len >= sizeof(struct ip) && len < m->m_pkthdr.len)
1651 m_adj(m, len - m->m_pkthdr.len);
1652 } else if (pkt->p_af == AF_INET6) {
1653 struct ip6_hdr *ip6 = mtod(m, struct ip6_hdr *);
1654 allowed_peer = wg_aip_lookup(sc, AF_INET6, &ip6->ip6_src);
1655 len = ntohs(ip6->ip6_plen) + sizeof(struct ip6_hdr);
1656 if (len < m->m_pkthdr.len)
1657 m_adj(m, len - m->m_pkthdr.len);
1658 } else
1659 panic("determine_af_and_pullup returned unexpected value");
1660 } else {
1661 DPRINTF(sc, "Packet is neither ipv4 nor ipv6 from peer %" PRIu64 "\n", peer->p_id);
1662 goto out;
1663 }
1664
1665 /* We only want to compare the address, not dereference, so drop the ref. */
1666 if (allowed_peer != NULL)
1667 noise_remote_put(allowed_peer->p_remote);
1668
1669 if (__predict_false(peer != allowed_peer)) {
1670 DPRINTF(sc, "Packet has unallowed src IP from peer %" PRIu64 "\n", peer->p_id);
1671 goto out;
1672 }
1673
1674 wg_mbuf_reset(m);
1675 state = WG_PACKET_CRYPTED;
1676 out:
1677 pkt->p_mbuf = m;
1678 atomic_store_rel_int(&pkt->p_state, state);
1679 GROUPTASK_ENQUEUE(&peer->p_recv);
1680 noise_remote_put(remote);
1681 }
1682
1683 static void
1684 wg_softc_decrypt(struct wg_softc *sc)
1685 {
1686 struct wg_packet *pkt;
1687
1688 while ((pkt = wg_queue_dequeue_parallel(&sc->sc_decrypt_parallel)) != NULL)
1689 wg_decrypt(sc, pkt);
1690 }
1691
1692 static void
1693 wg_softc_encrypt(struct wg_softc *sc)
1694 {
1695 struct wg_packet *pkt;
1696
1697 while ((pkt = wg_queue_dequeue_parallel(&sc->sc_encrypt_parallel)) != NULL)
1698 wg_encrypt(sc, pkt);
1699 }
1700
1701 static void
1702 wg_encrypt_dispatch(struct wg_softc *sc)
1703 {
1704 /*
1705 * The update to encrypt_last_cpu is racy such that we may
1706 * reschedule the task for the same CPU multiple times, but
1707 * the race doesn't really matter.
1708 */
1709 u_int cpu = (sc->sc_encrypt_last_cpu + 1) % mp_ncpus;
1710 sc->sc_encrypt_last_cpu = cpu;
1711 GROUPTASK_ENQUEUE(&sc->sc_encrypt[cpu]);
1712 }
1713
1714 static void
1715 wg_decrypt_dispatch(struct wg_softc *sc)
1716 {
1717 u_int cpu = (sc->sc_decrypt_last_cpu + 1) % mp_ncpus;
1718 sc->sc_decrypt_last_cpu = cpu;
1719 GROUPTASK_ENQUEUE(&sc->sc_decrypt[cpu]);
1720 }
1721
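/*
 * Serial send step: drain the peer's encrypted queue in order, transmitting
 * each packet to the current endpoint and updating timers and counters.
 */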
1722 static void
1723 wg_deliver_out(struct wg_peer *peer)
1724 {
1725 struct wg_endpoint endpoint;
1726 struct wg_softc *sc = peer->p_sc;
1727 struct wg_packet *pkt;
1728 struct mbuf *m;
1729 int rc, len;
1730
1731 wg_peer_get_endpoint(peer, &endpoint);
1732
1733 while ((pkt = wg_queue_dequeue_serial(&peer->p_encrypt_serial)) != NULL) {
1734 if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1735 goto error;
1736
1737 m = pkt->p_mbuf;
1738 pkt->p_mbuf = NULL;
1739
1740 len = m->m_pkthdr.len;
1741
1742 wg_timers_event_any_authenticated_packet_traversal(peer);
1743 wg_timers_event_any_authenticated_packet_sent(peer);
1744 rc = wg_send(sc, &endpoint, m);
1745 if (rc == 0) {
1746 if (len > (sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN))
1747 wg_timers_event_data_sent(peer);
1748 counter_u64_add(peer->p_tx_bytes, len);
1749 } else if (rc == EADDRNOTAVAIL) {
1750 wg_peer_clear_src(peer);
1751 wg_peer_get_endpoint(peer, &endpoint);
1752 goto error;
1753 } else {
1754 goto error;
1755 }
1756 wg_packet_free(pkt);
1757 if (noise_keep_key_fresh_send(peer->p_remote))
1758 wg_timers_event_want_initiation(peer);
1759 continue;
1760 error:
1761 if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
1762 wg_packet_free(pkt);
1763 }
1764 }
1765
1766 #ifdef DEV_NETMAP
1767 /*
1768 * Hand a packet to the netmap RX ring, via netmap's
1769 * freebsd_generic_rx_handler().
1770 */
1771 static void
1772 wg_deliver_netmap(if_t ifp, struct mbuf *m, int af)
1773 {
1774 struct ether_header *eh;
1775
1776 M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
1777 if (__predict_false(m == NULL)) {
1778 if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
1779 return;
1780 }
1781
1782 eh = mtod(m, struct ether_header *);
1783 eh->ether_type = af == AF_INET ?
1784 htons(ETHERTYPE_IP) : htons(ETHERTYPE_IPV6);
1785 memcpy(eh->ether_shost, "\x02\x02\x02\x02\x02\x02", ETHER_ADDR_LEN);
1786 memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
1787 if_input(ifp, m);
1788 }
1789 #endif
1790
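/*
 * Serialized receive path for a peer: packets come off the per-peer serial
 * queue in order.  Each one is checked against the keypair's nonce/replay
 * window, updates the handshake and keepalive timers and the peer endpoint,
 * and bumps the interface counters.  Zero-length payloads are keepalives and
 * stop here; everything else is re-injected into the stack via netisr (or
 * handed to netmap when the interface is in netmap mode).
 */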
1791 static void
1792 wg_deliver_in(struct wg_peer *peer)
1793 {
1794 struct wg_softc *sc = peer->p_sc;
1795 if_t ifp = sc->sc_ifp;
1796 struct wg_packet *pkt;
1797 struct mbuf *m;
1798 struct epoch_tracker et;
1799 int af;
1800
1801 while ((pkt = wg_queue_dequeue_serial(&peer->p_decrypt_serial)) != NULL) {
1802 if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1803 goto error;
1804
1805 m = pkt->p_mbuf;
1806 if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0)
1807 goto error;
1808
1809 if (noise_keypair_received_with(pkt->p_keypair) == ECONNRESET)
1810 wg_timers_event_handshake_complete(peer);
1811
1812 wg_timers_event_any_authenticated_packet_received(peer);
1813 wg_timers_event_any_authenticated_packet_traversal(peer);
1814 wg_peer_set_endpoint(peer, &pkt->p_endpoint);
1815
1816 counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len +
1817 sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1818 if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
1819 if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len +
1820 sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1821
1822 if (m->m_pkthdr.len == 0)
1823 goto done;
1824
1825 af = pkt->p_af;
1826 MPASS(af == AF_INET || af == AF_INET6);
1827 pkt->p_mbuf = NULL;
1828
1829 m->m_pkthdr.rcvif = ifp;
1830
1831 NET_EPOCH_ENTER(et);
1832 BPF_MTAP2_AF(ifp, m, af);
1833
1834 CURVNET_SET(if_getvnet(ifp));
1835 M_SETFIB(m, if_getfib(ifp));
1836 #ifdef DEV_NETMAP
1837 if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
1838 wg_deliver_netmap(ifp, m, af);
1839 else
1840 #endif
1841 if (af == AF_INET)
1842 netisr_dispatch(NETISR_IP, m);
1843 else if (af == AF_INET6)
1844 netisr_dispatch(NETISR_IPV6, m);
1845 CURVNET_RESTORE();
1846 NET_EPOCH_EXIT(et);
1847
1848 wg_timers_event_data_received(peer);
1849
1850 done:
1851 if (noise_keep_key_fresh_recv(peer->p_remote))
1852 wg_timers_event_want_initiation(peer);
1853 wg_packet_free(pkt);
1854 continue;
1855 error:
1856 if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
1857 wg_packet_free(pkt);
1858 }
1859 }
1860
1861 static struct wg_packet *
1862 wg_packet_alloc(struct mbuf *m)
1863 {
1864 struct wg_packet *pkt;
1865
1866 if ((pkt = uma_zalloc(wg_packet_zone, M_NOWAIT | M_ZERO)) == NULL)
1867 return (NULL);
1868 pkt->p_mbuf = m;
1869 return (pkt);
1870 }
1871
1872 static void
1873 wg_packet_free(struct wg_packet *pkt)
1874 {
1875 if (pkt->p_keypair != NULL)
1876 noise_keypair_put(pkt->p_keypair);
1877 if (pkt->p_mbuf != NULL)
1878 m_freem(pkt->p_mbuf);
1879 uma_zfree(wg_packet_zone, pkt);
1880 }
1881
1882 static void
1883 wg_queue_init(struct wg_queue *queue, const char *name)
1884 {
1885 mtx_init(&queue->q_mtx, name, NULL, MTX_DEF);
1886 STAILQ_INIT(&queue->q_queue);
1887 queue->q_len = 0;
1888 }
1889
1890 static void
1891 wg_queue_deinit(struct wg_queue *queue)
1892 {
1893 wg_queue_purge(queue);
1894 mtx_destroy(&queue->q_mtx);
1895 }
1896
1897 static size_t
1898 wg_queue_len(struct wg_queue *queue)
1899 {
1900 return (queue->q_len);
1901 }
1902
1903 static int
1904 wg_queue_enqueue_handshake(struct wg_queue *hs, struct wg_packet *pkt)
1905 {
1906 int ret = 0;
1907 mtx_lock(&hs->q_mtx);
1908 if (hs->q_len < MAX_QUEUED_HANDSHAKES) {
1909 STAILQ_INSERT_TAIL(&hs->q_queue, pkt, p_parallel);
1910 hs->q_len++;
1911 } else {
1912 ret = ENOBUFS;
1913 }
1914 mtx_unlock(&hs->q_mtx);
1915 if (ret != 0)
1916 wg_packet_free(pkt);
1917 return (ret);
1918 }
1919
1920 static struct wg_packet *
1921 wg_queue_dequeue_handshake(struct wg_queue *hs)
1922 {
1923 struct wg_packet *pkt;
1924 mtx_lock(&hs->q_mtx);
1925 if ((pkt = STAILQ_FIRST(&hs->q_queue)) != NULL) {
1926 STAILQ_REMOVE_HEAD(&hs->q_queue, p_parallel);
1927 hs->q_len--;
1928 }
1929 mtx_unlock(&hs->q_mtx);
1930 return (pkt);
1931 }
1932
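/*
 * The per-peer stage queue holds plaintext packets while no usable keypair
 * exists.  It is bounded at MAX_STAGED_PKT; when full, the oldest staged
 * packet is dropped to make room for the new one.
 */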
1933 static void
1934 wg_queue_push_staged(struct wg_queue *staged, struct wg_packet *pkt)
1935 {
1936 struct wg_packet *old = NULL;
1937
1938 mtx_lock(&staged->q_mtx);
1939 if (staged->q_len >= MAX_STAGED_PKT) {
1940 old = STAILQ_FIRST(&staged->q_queue);
1941 STAILQ_REMOVE_HEAD(&staged->q_queue, p_parallel);
1942 staged->q_len--;
1943 }
1944 STAILQ_INSERT_TAIL(&staged->q_queue, pkt, p_parallel);
1945 staged->q_len++;
1946 mtx_unlock(&staged->q_mtx);
1947
1948 if (old != NULL)
1949 wg_packet_free(old);
1950 }
1951
1952 static void
1953 wg_queue_enlist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1954 {
1955 struct wg_packet *pkt, *tpkt;
1956 STAILQ_FOREACH_SAFE(pkt, list, p_parallel, tpkt)
1957 wg_queue_push_staged(staged, pkt);
1958 }
1959
1960 static void
1961 wg_queue_delist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1962 {
1963 STAILQ_INIT(list);
1964 mtx_lock(&staged->q_mtx);
1965 STAILQ_CONCAT(list, &staged->q_queue);
1966 staged->q_len = 0;
1967 mtx_unlock(&staged->q_mtx);
1968 }
1969
1970 static void
1971 wg_queue_purge(struct wg_queue *staged)
1972 {
1973 struct wg_packet_list list;
1974 struct wg_packet *pkt, *tpkt;
1975 wg_queue_delist_staged(staged, &list);
1976 STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt)
1977 wg_packet_free(pkt);
1978 }
1979
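/*
 * Every data packet is placed on two queues at once: the peer's serial queue,
 * which preserves delivery order, and the softc-wide parallel queue, which
 * the per-CPU crypto tasks drain.  If the parallel enqueue fails the packet
 * stays on the serial queue but is marked DEAD so the serial consumer can
 * drop it.
 */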
1980 static int
1981 wg_queue_both(struct wg_queue *parallel, struct wg_queue *serial, struct wg_packet *pkt)
1982 {
1983 pkt->p_state = WG_PACKET_UNCRYPTED;
1984
1985 mtx_lock(&serial->q_mtx);
1986 if (serial->q_len < MAX_QUEUED_PKT) {
1987 serial->q_len++;
1988 STAILQ_INSERT_TAIL(&serial->q_queue, pkt, p_serial);
1989 } else {
1990 mtx_unlock(&serial->q_mtx);
1991 wg_packet_free(pkt);
1992 return (ENOBUFS);
1993 }
1994 mtx_unlock(&serial->q_mtx);
1995
1996 mtx_lock(&parallel->q_mtx);
1997 if (parallel->q_len < MAX_QUEUED_PKT) {
1998 parallel->q_len++;
1999 STAILQ_INSERT_TAIL(&parallel->q_queue, pkt, p_parallel);
2000 } else {
2001 mtx_unlock(&parallel->q_mtx);
2002 pkt->p_state = WG_PACKET_DEAD;
2003 return (ENOBUFS);
2004 }
2005 mtx_unlock(&parallel->q_mtx);
2006
2007 return (0);
2008 }
2009
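/*
 * The serial queue is only popped once the packet at its head has left the
 * UNCRYPTED state, i.e. its crypto work has finished (or it was marked DEAD).
 * This is what keeps packets in order even though encryption and decryption
 * run in parallel across CPUs.
 */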
2010 static struct wg_packet *
2011 wg_queue_dequeue_serial(struct wg_queue *serial)
2012 {
2013 struct wg_packet *pkt = NULL;
2014 mtx_lock(&serial->q_mtx);
2015 if (serial->q_len > 0 && STAILQ_FIRST(&serial->q_queue)->p_state != WG_PACKET_UNCRYPTED) {
2016 serial->q_len--;
2017 pkt = STAILQ_FIRST(&serial->q_queue);
2018 STAILQ_REMOVE_HEAD(&serial->q_queue, p_serial);
2019 }
2020 mtx_unlock(&serial->q_mtx);
2021 return (pkt);
2022 }
2023
2024 static struct wg_packet *
2025 wg_queue_dequeue_parallel(struct wg_queue *parallel)
2026 {
2027 struct wg_packet *pkt = NULL;
2028 mtx_lock(&parallel->q_mtx);
2029 if (parallel->q_len > 0) {
2030 parallel->q_len--;
2031 pkt = STAILQ_FIRST(&parallel->q_queue);
2032 STAILQ_REMOVE_HEAD(&parallel->q_queue, p_parallel);
2033 }
2034 mtx_unlock(&parallel->q_mtx);
2035 return (pkt);
2036 }
2037
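/*
 * UDP tunnel input handler.  The payload is classified by its length and the
 * leading 32-bit type word: initiation, response and cookie messages are
 * queued for the handshake task, while data messages are matched to a keypair
 * by receiver index (r_idx) and pushed onto the decrypt queues of the owning
 * peer.
 */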
2038 static bool
2039 wg_input(struct mbuf *m, int offset, struct inpcb *inpcb,
2040 const struct sockaddr *sa, void *_sc)
2041 {
2042 #ifdef INET
2043 const struct sockaddr_in *sin;
2044 #endif
2045 #ifdef INET6
2046 const struct sockaddr_in6 *sin6;
2047 #endif
2048 struct noise_remote *remote;
2049 struct wg_pkt_data *data;
2050 struct wg_packet *pkt;
2051 struct wg_peer *peer;
2052 struct wg_softc *sc = _sc;
2053 struct mbuf *defragged;
2054
2055 defragged = m_defrag(m, M_NOWAIT);
2056 if (defragged)
2057 m = defragged;
2058 m = m_unshare(m, M_NOWAIT);
2059 if (!m) {
2060 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2061 return true;
2062 }
2063
2064 /* Caller provided us with `sa`, no need for this header. */
2065 m_adj(m, offset + sizeof(struct udphdr));
2066
2067 /* Pullup enough to read packet type */
2068 if ((m = m_pullup(m, sizeof(uint32_t))) == NULL) {
2069 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2070 return true;
2071 }
2072
2073 if ((pkt = wg_packet_alloc(m)) == NULL) {
2074 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2075 m_freem(m);
2076 return true;
2077 }
2078
2079 /* Save send/recv address and port for later. */
2080 switch (sa->sa_family) {
2081 #ifdef INET
2082 case AF_INET:
2083 sin = (const struct sockaddr_in *)sa;
2084 pkt->p_endpoint.e_remote.r_sin = sin[0];
2085 pkt->p_endpoint.e_local.l_in = sin[1].sin_addr;
2086 break;
2087 #endif
2088 #ifdef INET6
2089 case AF_INET6:
2090 sin6 = (const struct sockaddr_in6 *)sa;
2091 pkt->p_endpoint.e_remote.r_sin6 = sin6[0];
2092 pkt->p_endpoint.e_local.l_in6 = sin6[1].sin6_addr;
2093 break;
2094 #endif
2095 default:
2096 goto error;
2097 }
2098
2099 if ((m->m_pkthdr.len == sizeof(struct wg_pkt_initiation) &&
2100 *mtod(m, uint32_t *) == WG_PKT_INITIATION) ||
2101 (m->m_pkthdr.len == sizeof(struct wg_pkt_response) &&
2102 *mtod(m, uint32_t *) == WG_PKT_RESPONSE) ||
2103 (m->m_pkthdr.len == sizeof(struct wg_pkt_cookie) &&
2104 *mtod(m, uint32_t *) == WG_PKT_COOKIE)) {
2105
2106 if (wg_queue_enqueue_handshake(&sc->sc_handshake_queue, pkt) != 0) {
2107 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2108 DPRINTF(sc, "Dropping handshake packet\n");
2109 }
2110 GROUPTASK_ENQUEUE(&sc->sc_handshake);
2111 } else if (m->m_pkthdr.len >= sizeof(struct wg_pkt_data) +
2112 NOISE_AUTHTAG_LEN && *mtod(m, uint32_t *) == WG_PKT_DATA) {
2113
2114 /* Pullup whole header to read r_idx below. */
2115 if ((pkt->p_mbuf = m_pullup(m, sizeof(struct wg_pkt_data))) == NULL)
2116 goto error;
2117
2118 data = mtod(pkt->p_mbuf, struct wg_pkt_data *);
2119 if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL)
2120 goto error;
2121
2122 remote = noise_keypair_remote(pkt->p_keypair);
2123 peer = noise_remote_arg(remote);
2124 if (wg_queue_both(&sc->sc_decrypt_parallel, &peer->p_decrypt_serial, pkt) != 0)
2125 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2126 wg_decrypt_dispatch(sc);
2127 noise_remote_put(remote);
2128 } else {
2129 goto error;
2130 }
2131 return true;
2132 error:
2133 if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1);
2134 wg_packet_free(pkt);
2135 return true;
2136 }
2137
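/*
 * Drain the peer's stage queue: reserve a nonce from the current keypair for
 * every staged packet, then enqueue them for parallel encryption and serial
 * delivery.  If there is no current keypair, or the nonce space is exhausted,
 * the packets are re-staged and a new handshake is requested.
 */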
2138 static void
2139 wg_peer_send_staged(struct wg_peer *peer)
2140 {
2141 struct wg_packet_list list;
2142 struct noise_keypair *keypair;
2143 struct wg_packet *pkt, *tpkt;
2144 struct wg_softc *sc = peer->p_sc;
2145
2146 wg_queue_delist_staged(&peer->p_stage_queue, &list);
2147
2148 if (STAILQ_EMPTY(&list))
2149 return;
2150
2151 if ((keypair = noise_keypair_current(peer->p_remote)) == NULL)
2152 goto error;
2153
2154 STAILQ_FOREACH(pkt, &list, p_parallel) {
2155 if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0)
2156 goto error_keypair;
2157 }
2158 STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt) {
2159 pkt->p_keypair = noise_keypair_ref(keypair);
2160 if (wg_queue_both(&sc->sc_encrypt_parallel, &peer->p_encrypt_serial, pkt) != 0)
2161 if_inc_counter(sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
2162 }
2163 wg_encrypt_dispatch(sc);
2164 noise_keypair_put(keypair);
2165 return;
2166
2167 error_keypair:
2168 noise_keypair_put(keypair);
2169 error:
2170 wg_queue_enlist_staged(&peer->p_stage_queue, &list);
2171 wg_timers_event_want_initiation(peer);
2172 }
2173
2174 static inline void
2175 xmit_err(if_t ifp, struct mbuf *m, struct wg_packet *pkt, sa_family_t af)
2176 {
2177 if_inc_counter(ifp, IFCOUNTER_OERRORS, 1);
2178 switch (af) {
2179 #ifdef INET
2180 case AF_INET:
2181 icmp_error(m, ICMP_UNREACH, ICMP_UNREACH_HOST, 0, 0);
2182 if (pkt)
2183 pkt->p_mbuf = NULL;
2184 m = NULL;
2185 break;
2186 #endif
2187 #ifdef INET6
2188 case AF_INET6:
2189 icmp6_error(m, ICMP6_DST_UNREACH, 0, 0);
2190 if (pkt)
2191 pkt->p_mbuf = NULL;
2192 m = NULL;
2193 break;
2194 #endif
2195 }
2196 if (pkt)
2197 wg_packet_free(pkt);
2198 else if (m)
2199 m_freem(m);
2200 }
2201
2202 static int
2203 wg_xmit(if_t ifp, struct mbuf *m, sa_family_t af, uint32_t mtu)
2204 {
2205 struct wg_packet *pkt = NULL;
2206 struct wg_softc *sc = if_getsoftc(ifp);
2207 struct wg_peer *peer;
2208 int rc = 0;
2209 sa_family_t peer_af;
2210
2211 /* Work around lifetime issue in the ipv6 mld code. */
2212 if (__predict_false((if_getflags(ifp) & IFF_DYING) || !sc)) {
2213 rc = ENXIO;
2214 goto err_xmit;
2215 }
2216
2217 if ((pkt = wg_packet_alloc(m)) == NULL) {
2218 rc = ENOBUFS;
2219 goto err_xmit;
2220 }
2221 pkt->p_mtu = mtu;
2222 pkt->p_af = af;
2223
2224 if (af == AF_INET) {
2225 peer = wg_aip_lookup(sc, AF_INET, &mtod(m, struct ip *)->ip_dst);
2226 } else if (af == AF_INET6) {
2227 peer = wg_aip_lookup(sc, AF_INET6, &mtod(m, struct ip6_hdr *)->ip6_dst);
2228 } else {
2229 rc = EAFNOSUPPORT;
2230 goto err_xmit;
2231 }
2232
2233 BPF_MTAP2_AF(ifp, m, pkt->p_af);
2234
2235 if (__predict_false(peer == NULL)) {
2236 rc = ENETUNREACH;
2237 goto err_xmit;
2238 }
2239
2240 if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP, MAX_LOOPS))) {
2241 DPRINTF(sc, "Packet looped\n");
2242 rc = ELOOP;
2243 goto err_peer;
2244 }
2245
2246 peer_af = peer->p_endpoint.e_remote.r_sa.sa_family;
2247 if (__predict_false(peer_af != AF_INET && peer_af != AF_INET6)) {
2248 DPRINTF(sc, "No valid endpoint has been configured or "
2249 "discovered for peer %" PRIu64 "\n", peer->p_id);
2250 rc = EHOSTUNREACH;
2251 goto err_peer;
2252 }
2253
2254 wg_queue_push_staged(&peer->p_stage_queue, pkt);
2255 wg_peer_send_staged(peer);
2256 noise_remote_put(peer->p_remote);
2257 return (0);
2258
2259 err_peer:
2260 noise_remote_put(peer->p_remote);
2261 err_xmit:
2262 xmit_err(ifp, m, pkt, af);
2263 return (rc);
2264 }
2265
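/*
 * Peek at the IP version nibble to decide the address family, pulling up
 * enough of the mbuf (an ip or ip6 header) so later code can safely read the
 * full header with mtod().
 */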
2266 static inline int
2267 determine_af_and_pullup(struct mbuf **m, sa_family_t *af)
2268 {
2269 u_char ipv;
2270 if ((*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2271 *m = m_pullup(*m, sizeof(struct ip6_hdr));
2272 else if ((*m)->m_pkthdr.len >= sizeof(struct ip))
2273 *m = m_pullup(*m, sizeof(struct ip));
2274 else
2275 return (EAFNOSUPPORT);
2276 if (*m == NULL)
2277 return (ENOBUFS);
2278 ipv = mtod(*m, struct ip *)->ip_v;
2279 if (ipv == 4)
2280 *af = AF_INET;
2281 else if (ipv == 6 && (*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2282 *af = AF_INET6;
2283 else
2284 return (EAFNOSUPPORT);
2285 return (0);
2286 }
2287
2288 static int
2289 determine_ethertype_and_pullup(struct mbuf **m, int *etp)
2290 {
2291 struct ether_header *eh;
2292
2293 *m = m_pullup(*m, sizeof(struct ether_header));
2294 if (__predict_false(*m == NULL))
2295 return (ENOBUFS);
2296 eh = mtod(*m, struct ether_header *);
2297 *etp = ntohs(eh->ether_type);
2298 if (*etp != ETHERTYPE_IP && *etp != ETHERTYPE_IPV6)
2299 return (EAFNOSUPPORT);
2300 return (0);
2301 }
2302
2303 /*
2304 * This should only be invoked by netmap, via nm_os_generic_xmit_frame(), to
2305 * transmit packets from the netmap TX ring.
2306 */
2307 static int
2308 wg_transmit(if_t ifp, struct mbuf *m)
2309 {
2310 sa_family_t af;
2311 int et, ret;
2312 struct mbuf *defragged;
2313
2314 KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2315 ("%s: ifp %p is not in netmap mode", __func__, ifp));
2316
2317 defragged = m_defrag(m, M_NOWAIT);
2318 if (defragged)
2319 m = defragged;
2320 m = m_unshare(m, M_NOWAIT);
2321 if (!m) {
2322 xmit_err(ifp, m, NULL, AF_UNSPEC);
2323 return (ENOBUFS);
2324 }
2325
2326 ret = determine_ethertype_and_pullup(&m, &et);
2327 if (ret) {
2328 xmit_err(ifp, m, NULL, AF_UNSPEC);
2329 return (ret);
2330 }
2331 m_adj(m, sizeof(struct ether_header));
2332
2333 ret = determine_af_and_pullup(&m, &af);
2334 if (ret) {
2335 xmit_err(ifp, m, NULL, AF_UNSPEC);
2336 return (ret);
2337 }
2338
2339 /*
2340 * netmap only gets to see transient errors, since it handles errors by
2341 * refusing to advance the transmit ring and retrying later.
2342 */
2343 ret = wg_xmit(ifp, m, af, if_getmtu(ifp));
2344 if (ret == ENOBUFS)
2345 return (ret);
2346 return (0);
2347 }
2348
2349 #ifdef DEV_NETMAP
2350 /*
2351 * This should only be invoked by netmap, via nm_os_send_up(), to process
2352 * packets from the host TX ring.
2353 */
2354 static void
2355 wg_if_input(if_t ifp, struct mbuf *m)
2356 {
2357 int et;
2358
2359 KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2360 ("%s: ifp %p is not in netmap mode", __func__, ifp));
2361
2362 if (determine_ethertype_and_pullup(&m, &et) != 0) {
2363 if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2364 m_freem(m);
2365 return;
2366 }
2367 CURVNET_SET(if_getvnet(ifp));
2368 switch (et) {
2369 case ETHERTYPE_IP:
2370 m_adj(m, sizeof(struct ether_header));
2371 netisr_dispatch(NETISR_IP, m);
2372 break;
2373 case ETHERTYPE_IPV6:
2374 m_adj(m, sizeof(struct ether_header));
2375 netisr_dispatch(NETISR_IPV6, m);
2376 break;
2377 default:
2378 __assert_unreachable();
2379 }
2380 CURVNET_RESTORE();
2381 }
2382
2383 /*
2384 * Deliver a packet to the host RX ring. Because the interface is in netmap
2385 * mode, the if_transmit() call should pass the packet to netmap_transmit().
2386 */
2387 static int
2388 wg_xmit_netmap(if_t ifp, struct mbuf *m, int af)
2389 {
2390 struct ether_header *eh;
2391
2392 if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP,
2393 MAX_LOOPS))) {
2394 printf("%s:%d\n", __func__, __LINE__);
2395 if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2396 m_freem(m);
2397 return (ELOOP);
2398 }
2399
2400 M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
2401 if (__predict_false(m == NULL)) {
2402 if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
2403 return (ENOBUFS);
2404 }
2405
2406 eh = mtod(m, struct ether_header *);
2407 eh->ether_type = af == AF_INET ?
2408 htons(ETHERTYPE_IP) : htons(ETHERTYPE_IPV6);
2409 memcpy(eh->ether_shost, "\x06\x06\x06\x06\x06\x06", ETHER_ADDR_LEN);
2410 memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
2411 return (if_transmit(ifp, m));
2412 }
2413 #endif /* DEV_NETMAP */
2414
2415 static int
2416 wg_output(if_t ifp, struct mbuf *m, const struct sockaddr *dst, struct route *ro)
2417 {
2418 sa_family_t parsed_af;
2419 uint32_t af, mtu;
2420 int ret;
2421 struct mbuf *defragged;
2422
2423 /* BPF writes need to be handled specially. */
2424 if (dst->sa_family == AF_UNSPEC || dst->sa_family == pseudo_AF_HDRCMPLT)
2425 memcpy(&af, dst->sa_data, sizeof(af));
2426 else
2427 af = RO_GET_FAMILY(ro, dst);
2428 if (af == AF_UNSPEC) {
2429 xmit_err(ifp, m, NULL, af);
2430 return (EAFNOSUPPORT);
2431 }
2432
2433 #ifdef DEV_NETMAP
2434 if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
2435 return (wg_xmit_netmap(ifp, m, af));
2436 #endif
2437
2438 defragged = m_defrag(m, M_NOWAIT);
2439 if (defragged)
2440 m = defragged;
2441 m = m_unshare(m, M_NOWAIT);
2442 if (!m) {
2443 xmit_err(ifp, m, NULL, AF_UNSPEC);
2444 return (ENOBUFS);
2445 }
2446
2447 ret = determine_af_and_pullup(&m, &parsed_af);
2448 if (ret) {
2449 xmit_err(ifp, m, NULL, AF_UNSPEC);
2450 return (ret);
2451 }
2452
2453 MPASS(parsed_af == af);
2454 mtu = (ro != NULL && ro->ro_mtu > 0) ? ro->ro_mtu : if_getmtu(ifp);
2455 return (wg_xmit(ifp, m, parsed_af, mtu));
2456 }
2457
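/*
 * Apply one "peers" array entry from a SIOCSWG request.  Keys handled below:
 * "public-key" (required, WG_KEY_SIZE bytes), "remove", "replace-allowedips",
 * "endpoint" (a struct sockaddr), "preshared-key",
 * "persistent-keepalive-interval", and "allowed-ips", an nvlist array whose
 * entries carry "cidr", optional "flags", and either "ipv4" or "ipv6".
 *
 * A minimal sketch of building such an entry with nvlist(9); illustrative
 * only, with pub and in_addr standing in for caller-provided values, and the
 * surrounding request framing handled by wgc_set()/userspace:
 *
 *	nvlist_t *nvl_peer = nvlist_create(0);
 *	nvlist_add_binary(nvl_peer, "public-key", pub, WG_KEY_SIZE);
 *	nvlist_t *nvl_aip = nvlist_create(0);
 *	nvlist_add_binary(nvl_aip, "ipv4", &in_addr, sizeof(in_addr));
 *	nvlist_add_number(nvl_aip, "cidr", 24);
 *	nvlist_add_nvlist_array(nvl_peer, "allowed-ips",
 *	    (const nvlist_t * const *)&nvl_aip, 1);
 */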
2458 static int
2459 wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
2460 {
2461 uint8_t public[WG_KEY_SIZE];
2462 const void *pub_key, *preshared_key = NULL;
2463 const struct sockaddr *endpoint;
2464 int err;
2465 size_t size;
2466 struct noise_remote *remote;
2467 struct wg_peer *peer = NULL;
2468 bool need_cleanup = false;
2469
2470 sx_assert(&sc->sc_lock, SX_XLOCKED);
2471
2472 if (!nvlist_exists_binary(nvl, "public-key")) {
2473 return (EINVAL);
2474 }
2475 pub_key = nvlist_get_binary(nvl, "public-key", &size);
2476 if (size != WG_KEY_SIZE) {
2477 return (EINVAL);
2478 }
2479 if (noise_local_keys(sc->sc_local, public, NULL) == 0 &&
2480 bcmp(public, pub_key, WG_KEY_SIZE) == 0) {
2481 return (0); /* Silently ignored; not actually a failure. */
2482 }
2483 if ((remote = noise_remote_lookup(sc->sc_local, pub_key)) != NULL)
2484 peer = noise_remote_arg(remote);
2485 if (nvlist_exists_bool(nvl, "remove") &&
2486 nvlist_get_bool(nvl, "remove")) {
2487 if (remote != NULL) {
2488 wg_peer_destroy(peer);
2489 noise_remote_put(remote);
2490 }
2491 return (0);
2492 }
2493 if (nvlist_exists_bool(nvl, "replace-allowedips") &&
2494 nvlist_get_bool(nvl, "replace-allowedips") &&
2495 peer != NULL) {
2496
2497 wg_aip_remove_all(sc, peer);
2498 }
2499 if (peer == NULL) {
2500 peer = wg_peer_create(sc, pub_key, &err);
2501 if (peer == NULL)
2502 goto out;
2503 need_cleanup = true;
2504 }
2505 if (nvlist_exists_binary(nvl, "endpoint")) {
2506 endpoint = nvlist_get_binary(nvl, "endpoint", &size);
2507 if (size > sizeof(peer->p_endpoint.e_remote)) {
2508 err = EINVAL;
2509 goto out;
2510 }
2511 memcpy(&peer->p_endpoint.e_remote, endpoint, size);
2512 }
2513 if (nvlist_exists_binary(nvl, "preshared-key")) {
2514 preshared_key = nvlist_get_binary(nvl, "preshared-key", &size);
2515 if (size != WG_KEY_SIZE) {
2516 err = EINVAL;
2517 goto out;
2518 }
2519 noise_remote_set_psk(peer->p_remote, preshared_key);
2520 }
2521 if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) {
2522 uint64_t pki = nvlist_get_number(nvl, "persistent-keepalive-interval");
2523 if (pki > UINT16_MAX) {
2524 err = EINVAL;
2525 goto out;
2526 }
2527 wg_timers_set_persistent_keepalive(peer, pki);
2528 }
2529 if (nvlist_exists_nvlist_array(nvl, "allowed-ips")) {
2530 const void *addr;
2531 uint64_t cidr;
2532 const nvlist_t * const * aipl;
2533 size_t allowedip_count;
2534
2535 aipl = nvlist_get_nvlist_array(nvl, "allowed-ips", &allowedip_count);
2536 for (size_t idx = 0; idx < allowedip_count; idx++) {
2537 sa_family_t ipaf;
2538 int ipflags;
2539
2540 if (!nvlist_exists_number(aipl[idx], "cidr"))
2541 continue;
2542
2543 ipaf = AF_UNSPEC;
2544 ipflags = 0;
2545 if (nvlist_exists_number(aipl[idx], "flags"))
2546 ipflags = nvlist_get_number(aipl[idx], "flags");
2547 if ((ipflags & ~WGALLOWEDIP_VALID_FLAGS) != 0) {
2548 err = EOPNOTSUPP;
2549 goto out;
2550 }
2551 cidr = nvlist_get_number(aipl[idx], "cidr");
2552 if (nvlist_exists_binary(aipl[idx], "ipv4")) {
2553 addr = nvlist_get_binary(aipl[idx], "ipv4", &size);
2554 if (addr == NULL || cidr > 32 || size != sizeof(struct in_addr)) {
2555 err = EINVAL;
2556 goto out;
2557 }
2558
2559 ipaf = AF_INET;
2560 } else if (nvlist_exists_binary(aipl[idx], "ipv6")) {
2561 addr = nvlist_get_binary(aipl[idx], "ipv6", &size);
2562 if (addr == NULL || cidr > 128 || size != sizeof(struct in6_addr)) {
2563 err = EINVAL;
2564 goto out;
2565 }
2566
2567 ipaf = AF_INET6;
2568 } else {
2569 continue;
2570 }
2571
2572 MPASS(ipaf != AF_UNSPEC);
2573 if ((ipflags & WGALLOWEDIP_REMOVE_ME) != 0) {
2574 err = wg_aip_del(sc, peer, ipaf, addr, cidr);
2575 } else {
2576 err = wg_aip_add(sc, peer, ipaf, addr, cidr);
2577 }
2578
2579 if (err != 0)
2580 goto out;
2581 }
2582 }
2583 if (remote != NULL)
2584 noise_remote_put(remote);
2585 return (0);
2586 out:
2587 if (need_cleanup) /* If we fail, only destroy if it was new. */
2588 wg_peer_destroy(peer);
2589 if (remote != NULL)
2590 noise_remote_put(remote);
2591 return (err);
2592 }
2593
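/*
 * SIOCSWG: the packed nvlist is copied in from userspace, unpacked, and
 * applied under the softc lock.  Top-level keys handled below are
 * "replace-peers", "listen-port", "private-key", "user-cookie" and "peers".
 * Changing the private key invalidates existing handshakes and, if the new
 * key collides with a configured peer's public key, removes that peer.
 *
 * A hedged userspace sketch (the request is addressed to an interface via the
 * name field of struct wg_data_io; see if_wg.h for the full layout):
 *
 *	struct wg_data_io wgd = { 0 };
 *	wgd.wgd_data = packed;		// result of nvlist_pack()
 *	wgd.wgd_size = packed_size;
 *	ioctl(s, SIOCSWG, &wgd);
 */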
2594 static int
2595 wgc_set(struct wg_softc *sc, struct wg_data_io *wgd)
2596 {
2597 uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE];
2598 if_t ifp;
2599 void *nvlpacked;
2600 nvlist_t *nvl;
2601 ssize_t size;
2602 int err;
2603
2604 ifp = sc->sc_ifp;
2605 if (wgd->wgd_size == 0 || wgd->wgd_data == NULL)
2606 return (EFAULT);
2607
2608 /* Can nvlists be streamed in? It's not nice to impose arbitrary limits like that but
2609 * there needs to be _some_ limitation. */
2610 if (wgd->wgd_size >= UINT32_MAX / 2)
2611 return (E2BIG);
2612
2613 nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK | M_ZERO);
2614
2615 err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size);
2616 if (err)
2617 goto out;
2618 nvl = nvlist_unpack(nvlpacked, wgd->wgd_size, 0);
2619 if (nvl == NULL) {
2620 err = EBADMSG;
2621 goto out;
2622 }
2623 sx_xlock(&sc->sc_lock);
2624 if (nvlist_exists_bool(nvl, "replace-peers") &&
2625 nvlist_get_bool(nvl, "replace-peers"))
2626 wg_peer_destroy_all(sc);
2627 if (nvlist_exists_number(nvl, "listen-port")) {
2628 uint64_t new_port = nvlist_get_number(nvl, "listen-port");
2629 if (new_port > UINT16_MAX) {
2630 err = EINVAL;
2631 goto out_locked;
2632 }
2633 if (new_port != sc->sc_socket.so_port) {
2634 if ((if_getdrvflags(ifp) & IFF_DRV_RUNNING) != 0) {
2635 if ((err = wg_socket_init(sc, new_port)) != 0)
2636 goto out_locked;
2637 } else
2638 sc->sc_socket.so_port = new_port;
2639 }
2640 }
2641 if (nvlist_exists_binary(nvl, "private-key")) {
2642 const void *key = nvlist_get_binary(nvl, "private-key", &size);
2643 if (size != WG_KEY_SIZE) {
2644 err = EINVAL;
2645 goto out_locked;
2646 }
2647
2648 if (noise_local_keys(sc->sc_local, NULL, private) != 0 ||
2649 timingsafe_bcmp(private, key, WG_KEY_SIZE) != 0) {
2650 struct wg_peer *peer;
2651
2652 if (curve25519_generate_public(public, key)) {
2653 /* Peer conflict: remove conflicting peer. */
2654 struct noise_remote *remote;
2655 if ((remote = noise_remote_lookup(sc->sc_local,
2656 public)) != NULL) {
2657 peer = noise_remote_arg(remote);
2658 wg_peer_destroy(peer);
2659 noise_remote_put(remote);
2660 }
2661 }
2662
2663 /*
2664 * Set the private key and invalidate all existing
2665 * handshakes.
2666 */
2667 /* Note: we might be removing the private key. */
2668 noise_local_private(sc->sc_local, key);
2669 if (noise_local_keys(sc->sc_local, NULL, NULL) == 0)
2670 cookie_checker_update(&sc->sc_cookie, public);
2671 else
2672 cookie_checker_update(&sc->sc_cookie, NULL);
2673 }
2674 }
2675 if (nvlist_exists_number(nvl, "user-cookie")) {
2676 uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie");
2677 if (user_cookie > UINT32_MAX) {
2678 err = EINVAL;
2679 goto out_locked;
2680 }
2681 err = wg_socket_set_cookie(sc, user_cookie);
2682 if (err)
2683 goto out_locked;
2684 }
2685 if (nvlist_exists_nvlist_array(nvl, "peers")) {
2686 size_t peercount;
2687 const nvlist_t * const*nvl_peers;
2688
2689 nvl_peers = nvlist_get_nvlist_array(nvl, "peers", &peercount);
2690 for (int i = 0; i < peercount; i++) {
2691 err = wg_peer_add(sc, nvl_peers[i]);
2692 if (err != 0)
2693 goto out_locked;
2694 }
2695 }
2696
2697 out_locked:
2698 sx_xunlock(&sc->sc_lock);
2699 nvlist_destroy(nvl);
2700 out:
2701 zfree(nvlpacked, M_TEMP);
2702 return (err);
2703 }
2704
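/*
 * SIOCGWG uses two-pass sizing: a request with wgd_size == 0 just reports the
 * size of the packed result, and a second call with a buffer at least that
 * large receives the data (ENOSPC otherwise).  Private and preshared keys are
 * only included for privileged callers; see wgc_privileged().
 */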
2705 static int
2706 wgc_get(struct wg_softc *sc, struct wg_data_io *wgd)
2707 {
2708 uint8_t public_key[WG_KEY_SIZE] = { 0 };
2709 uint8_t private_key[WG_KEY_SIZE] = { 0 };
2710 uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN] = { 0 };
2711 nvlist_t *nvl, *nvl_peer, *nvl_aip, **nvl_peers, **nvl_aips;
2712 size_t size, peer_count, aip_count, i, j;
2713 struct wg_timespec64 ts64;
2714 struct wg_peer *peer;
2715 struct wg_aip *aip;
2716 void *packed;
2717 int err = 0;
2718
2719 nvl = nvlist_create(0);
2720 if (!nvl)
2721 return (ENOMEM);
2722
2723 sx_slock(&sc->sc_lock);
2724
2725 if (sc->sc_socket.so_port != 0)
2726 nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port);
2727 if (sc->sc_socket.so_user_cookie != 0)
2728 nvlist_add_number(nvl, "user-cookie", sc->sc_socket.so_user_cookie);
2729 if (noise_local_keys(sc->sc_local, public_key, private_key) == 0) {
2730 nvlist_add_binary(nvl, "public-key", public_key, WG_KEY_SIZE);
2731 if (wgc_privileged(sc))
2732 nvlist_add_binary(nvl, "private-key", private_key, WG_KEY_SIZE);
2733 explicit_bzero(private_key, sizeof(private_key));
2734 }
2735 peer_count = sc->sc_peers_num;
2736 if (peer_count) {
2737 nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2738 i = 0;
2739 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2740 if (i >= peer_count)
2741 panic("peers changed from under us");
2742
2743 nvl_peers[i++] = nvl_peer = nvlist_create(0);
2744 if (!nvl_peer) {
2745 err = ENOMEM;
2746 goto err_peer;
2747 }
2748
2749 (void)noise_remote_keys(peer->p_remote, public_key, preshared_key);
2750 nvlist_add_binary(nvl_peer, "public-key", public_key, sizeof(public_key));
2751 if (wgc_privileged(sc))
2752 nvlist_add_binary(nvl_peer, "preshared-key", preshared_key, sizeof(preshared_key));
2753 explicit_bzero(preshared_key, sizeof(preshared_key));
2754 if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET)
2755 nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in));
2756 else if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET6)
2757 nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in6));
2758 wg_timers_get_last_handshake(peer, &ts64);
2759 nvlist_add_binary(nvl_peer, "last-handshake-time", &ts64, sizeof(ts64));
2760 nvlist_add_number(nvl_peer, "persistent-keepalive-interval", peer->p_persistent_keepalive_interval);
2761 nvlist_add_number(nvl_peer, "rx-bytes", counter_u64_fetch(peer->p_rx_bytes));
2762 nvlist_add_number(nvl_peer, "tx-bytes", counter_u64_fetch(peer->p_tx_bytes));
2763
2764 aip_count = peer->p_aips_num;
2765 if (aip_count) {
2766 nvl_aips = mallocarray(aip_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2767 j = 0;
2768 LIST_FOREACH(aip, &peer->p_aips, a_entry) {
2769 if (j >= aip_count)
2770 panic("aips changed from under us");
2771
2772 nvl_aips[j++] = nvl_aip = nvlist_create(0);
2773 if (!nvl_aip) {
2774 err = ENOMEM;
2775 goto err_aip;
2776 }
2777 if (aip->a_af == AF_INET) {
2778 nvlist_add_binary(nvl_aip, "ipv4", &aip->a_addr.in, sizeof(aip->a_addr.in));
2779 nvlist_add_number(nvl_aip, "cidr", bitcount32(aip->a_mask.ip));
2780 }
2781 #ifdef INET6
2782 else if (aip->a_af == AF_INET6) {
2783 nvlist_add_binary(nvl_aip, "ipv6", &aip->a_addr.in6, sizeof(aip->a_addr.in6));
2784 nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&aip->a_mask.in6, NULL));
2785 }
2786 #endif
2787 }
2788 nvlist_add_nvlist_array(nvl_peer, "allowed-ips", (const nvlist_t *const *)nvl_aips, aip_count);
2789 err_aip:
2790 for (j = 0; j < aip_count; ++j)
2791 nvlist_destroy(nvl_aips[j]);
2792 free(nvl_aips, M_NVLIST);
2793 if (err)
2794 goto err_peer;
2795 }
2796 }
2797 nvlist_add_nvlist_array(nvl, "peers", (const nvlist_t * const *)nvl_peers, peer_count);
2798 err_peer:
2799 for (i = 0; i < peer_count; ++i)
2800 nvlist_destroy(nvl_peers[i]);
2801 free(nvl_peers, M_NVLIST);
2802 if (err) {
2803 sx_sunlock(&sc->sc_lock);
2804 goto err;
2805 }
2806 }
2807 sx_sunlock(&sc->sc_lock);
2808 packed = nvlist_pack(nvl, &size);
2809 if (!packed) {
2810 err = ENOMEM;
2811 goto err;
2812 }
2813 if (!wgd->wgd_size) {
2814 wgd->wgd_size = size;
2815 goto out;
2816 }
2817 if (wgd->wgd_size < size) {
2818 err = ENOSPC;
2819 goto out;
2820 }
2821 err = copyout(packed, wgd->wgd_data, size);
2822 wgd->wgd_size = size;
2823
2824 out:
2825 zfree(packed, M_NVLIST);
2826 err:
2827 nvlist_destroy(nvl);
2828 return (err);
2829 }
2830
2831 static int
2832 wg_ioctl(if_t ifp, u_long cmd, caddr_t data)
2833 {
2834 struct wg_data_io *wgd = (struct wg_data_io *)data;
2835 struct ifreq *ifr = (struct ifreq *)data;
2836 struct wg_softc *sc;
2837 int ret = 0;
2838
2839 sx_slock(&wg_sx);
2840 sc = if_getsoftc(ifp);
2841 if (!sc) {
2842 ret = ENXIO;
2843 goto out;
2844 }
2845
2846 switch (cmd) {
2847 case SIOCSWG:
2848 ret = priv_check(curthread, PRIV_NET_WG);
2849 if (ret == 0)
2850 ret = wgc_set(sc, wgd);
2851 break;
2852 case SIOCGWG:
2853 ret = wgc_get(sc, wgd);
2854 break;
2855 /* Interface IOCTLs */
2856 case SIOCSIFADDR:
2857 /*
2858 * This differs from *BSD norms, but is more uniform with how
2859 * WireGuard behaves elsewhere.
2860 */
2861 break;
2862 case SIOCSIFFLAGS:
2863 if (if_getflags(ifp) & IFF_UP)
2864 ret = wg_up(sc);
2865 else
2866 wg_down(sc);
2867 break;
2868 case SIOCSIFMTU:
2869 if (ifr->ifr_mtu <= 0 || ifr->ifr_mtu > MAX_MTU)
2870 ret = EINVAL;
2871 else
2872 if_setmtu(ifp, ifr->ifr_mtu);
2873 break;
2874 case SIOCADDMULTI:
2875 case SIOCDELMULTI:
2876 break;
2877 case SIOCGTUNFIB:
2878 ifr->ifr_fib = sc->sc_socket.so_fibnum;
2879 break;
2880 case SIOCSTUNFIB:
2881 ret = priv_check(curthread, PRIV_NET_WG);
2882 if (ret)
2883 break;
2884 ret = priv_check(curthread, PRIV_NET_SETIFFIB);
2885 if (ret)
2886 break;
2887 sx_xlock(&sc->sc_lock);
2888 ret = wg_socket_set_fibnum(sc, ifr->ifr_fib);
2889 sx_xunlock(&sc->sc_lock);
2890 break;
2891 default:
2892 ret = ENOTTY;
2893 }
2894
2895 out:
2896 sx_sunlock(&wg_sx);
2897 return (ret);
2898 }
2899
2900 static int
2901 wg_up(struct wg_softc *sc)
2902 {
2903 if_t ifp = sc->sc_ifp;
2904 struct wg_peer *peer;
2905 int rc = EBUSY;
2906
2907 sx_xlock(&sc->sc_lock);
2908 /* Jail's being removed, no more wg_up(). */
2909 if ((sc->sc_flags & WGF_DYING) != 0)
2910 goto out;
2911
2912 /* Silent success if we're already running. */
2913 rc = 0;
2914 if (if_getdrvflags(ifp) & IFF_DRV_RUNNING)
2915 goto out;
2916 if_setdrvflagbits(ifp, IFF_DRV_RUNNING, 0);
2917
2918 rc = wg_socket_init(sc, sc->sc_socket.so_port);
2919 if (rc == 0) {
2920 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry)
2921 wg_timers_enable(peer);
2922 if_link_state_change(sc->sc_ifp, LINK_STATE_UP);
2923 } else {
2924 if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2925 DPRINTF(sc, "Unable to initialize sockets: %d\n", rc);
2926 }
2927 out:
2928 sx_xunlock(&sc->sc_lock);
2929 return (rc);
2930 }
2931
2932 static void
2933 wg_down(struct wg_softc *sc)
2934 {
2935 if_t ifp = sc->sc_ifp;
2936 struct wg_peer *peer;
2937
2938 sx_xlock(&sc->sc_lock);
2939 if (!(if_getdrvflags(ifp) & IFF_DRV_RUNNING)) {
2940 sx_xunlock(&sc->sc_lock);
2941 return;
2942 }
2943 if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2944
2945 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2946 wg_queue_purge(&peer->p_stage_queue);
2947 wg_timers_disable(peer);
2948 }
2949
2950 wg_queue_purge(&sc->sc_handshake_queue);
2951
2952 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2953 noise_remote_handshake_clear(peer->p_remote);
2954 noise_remote_keypairs_clear(peer->p_remote);
2955 }
2956
2957 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
2958 wg_socket_uninit(sc);
2959
2960 sx_xunlock(&sc->sc_lock);
2961 }
2962
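/*
 * Cloner create: allocates the softc, one encrypt and one decrypt grouptask
 * per CPU, the v4/v6 allowed-IP radix heads, and the handshake task, then
 * attaches the ifnet with the WireGuard defaults (no ARP, DEFAULT_MTU,
 * DLT_NULL bpf).
 */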
2963 static int
2964 wg_clone_create(struct if_clone *ifc, char *name, size_t len,
2965 struct ifc_data *ifd, struct ifnet **ifpp)
2966 {
2967 struct wg_softc *sc;
2968 if_t ifp;
2969
2970 sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO);
2971
2972 sc->sc_local = noise_local_alloc(sc);
2973
2974 sc->sc_encrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2975
2976 sc->sc_decrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2977
2978 if (!rn_inithead((void **)&sc->sc_aip4, offsetof(struct aip_addr, in) * NBBY))
2979 goto free_decrypt;
2980
2981 if (!rn_inithead((void **)&sc->sc_aip6, offsetof(struct aip_addr, in6) * NBBY))
2982 goto free_aip4;
2983
2984 atomic_add_int(&clone_count, 1);
2985 ifp = sc->sc_ifp = if_alloc(IFT_WIREGUARD);
2986
2987 sc->sc_ucred = crhold(curthread->td_ucred);
2988 sc->sc_socket.so_fibnum = curthread->td_proc->p_fibnum;
2989 sc->sc_socket.so_port = 0;
2990
2991 TAILQ_INIT(&sc->sc_peers);
2992 sc->sc_peers_num = 0;
2993
2994 cookie_checker_init(&sc->sc_cookie);
2995
2996 RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip4);
2997 RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip6);
2998
2999 GROUPTASK_INIT(&sc->sc_handshake, 0, (gtask_fn_t *)wg_softc_handshake_receive, sc);
3000 taskqgroup_attach(qgroup_wg_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation");
3001 wg_queue_init(&sc->sc_handshake_queue, "hsq");
3002
3003 for (int i = 0; i < mp_ncpus; i++) {
3004 GROUPTASK_INIT(&sc->sc_encrypt[i], 0,
3005 (gtask_fn_t *)wg_softc_encrypt, sc);
3006 taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_encrypt[i], sc, i, NULL, NULL, "wg encrypt");
3007 GROUPTASK_INIT(&sc->sc_decrypt[i], 0,
3008 (gtask_fn_t *)wg_softc_decrypt, sc);
3009 taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_decrypt[i], sc, i, NULL, NULL, "wg decrypt");
3010 }
3011
3012 wg_queue_init(&sc->sc_encrypt_parallel, "encp");
3013 wg_queue_init(&sc->sc_decrypt_parallel, "decp");
3014
3015 sx_init(&sc->sc_lock, "wg softc lock");
3016
3017 if_setsoftc(ifp, sc);
3018 if_setcapabilities(ifp, WG_CAPS);
3019 if_setcapenable(ifp, WG_CAPS);
3020 if_initname(ifp, wgname, ifd->unit);
3021
3022 if_setmtu(ifp, DEFAULT_MTU);
3023 if_setflags(ifp, IFF_NOARP | IFF_MULTICAST);
3024 if_setinitfn(ifp, wg_init);
3025 if_setreassignfn(ifp, wg_reassign);
3026 if_setqflushfn(ifp, wg_qflush);
3027 if_settransmitfn(ifp, wg_transmit);
3028 #ifdef DEV_NETMAP
3029 if_setinputfn(ifp, wg_if_input);
3030 #endif
3031 if_setoutputfn(ifp, wg_output);
3032 if_setioctlfn(ifp, wg_ioctl);
3033 if_attach(ifp);
3034 bpfattach(ifp, DLT_NULL, sizeof(uint32_t));
3035 #ifdef INET6
3036 ND_IFINFO(ifp)->flags &= ~ND6_IFF_AUTO_LINKLOCAL;
3037 ND_IFINFO(ifp)->flags |= ND6_IFF_NO_DAD;
3038 #endif
3039 sx_xlock(&wg_sx);
3040 LIST_INSERT_HEAD(&wg_list, sc, sc_entry);
3041 sx_xunlock(&wg_sx);
3042 *ifpp = ifp;
3043 return (0);
3044 free_aip4:
3045 RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
3046 free(sc->sc_aip4, M_RTABLE);
3047 free_decrypt:
3048 free(sc->sc_decrypt, M_WG);
3049 free(sc->sc_encrypt, M_WG);
3050 noise_local_free(sc->sc_local, NULL);
3051 free(sc, M_WG);
3052 return (ENOMEM);
3053 }
3054
3055 static void
3056 wg_clone_deferred_free(struct noise_local *l)
3057 {
3058 struct wg_softc *sc = noise_local_arg(l);
3059
3060 free(sc, M_WG);
3061 atomic_add_int(&clone_count, -1);
3062 }
3063
3064 static int
3065 wg_clone_destroy(struct if_clone *ifc, if_t ifp, uint32_t flags)
3066 {
3067 struct wg_softc *sc = if_getsoftc(ifp);
3068 struct ucred *cred;
3069
3070 sx_xlock(&wg_sx);
3071 if_setsoftc(ifp, NULL);
3072 sx_xlock(&sc->sc_lock);
3073 sc->sc_flags |= WGF_DYING;
3074 cred = sc->sc_ucred;
3075 sc->sc_ucred = NULL;
3076 sx_xunlock(&sc->sc_lock);
3077 LIST_REMOVE(sc, sc_entry);
3078 sx_xunlock(&wg_sx);
3079
3080 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
3081 CURVNET_SET(if_getvnet(sc->sc_ifp));
3082 if_purgeaddrs(sc->sc_ifp);
3083 CURVNET_RESTORE();
3084
3085 sx_xlock(&sc->sc_lock);
3086 wg_socket_uninit(sc);
3087 sx_xunlock(&sc->sc_lock);
3088
3089 /*
3090 * There is no guarantee that all traffic has passed until the epoch
3091 * has elapsed with the socket closed.
3092 */
3093 NET_EPOCH_WAIT();
3094
3095 taskqgroup_drain_all(qgroup_wg_tqg);
3096 sx_xlock(&sc->sc_lock);
3097 wg_peer_destroy_all(sc);
3098 NET_EPOCH_DRAIN_CALLBACKS();
3099 sx_xunlock(&sc->sc_lock);
3100 sx_destroy(&sc->sc_lock);
3101 taskqgroup_detach(qgroup_wg_tqg, &sc->sc_handshake);
3102 for (int i = 0; i < mp_ncpus; i++) {
3103 taskqgroup_detach(qgroup_wg_tqg, &sc->sc_encrypt[i]);
3104 taskqgroup_detach(qgroup_wg_tqg, &sc->sc_decrypt[i]);
3105 }
3106 free(sc->sc_encrypt, M_WG);
3107 free(sc->sc_decrypt, M_WG);
3108 wg_queue_deinit(&sc->sc_handshake_queue);
3109 wg_queue_deinit(&sc->sc_encrypt_parallel);
3110 wg_queue_deinit(&sc->sc_decrypt_parallel);
3111
3112 RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
3113 RADIX_NODE_HEAD_DESTROY(sc->sc_aip6);
3114 rn_detachhead((void **)&sc->sc_aip4);
3115 rn_detachhead((void **)&sc->sc_aip6);
3116
3117 cookie_checker_free(&sc->sc_cookie);
3118
3119 if (cred != NULL)
3120 crfree(cred);
3121 bpfdetach(sc->sc_ifp);
3122 if_detach(sc->sc_ifp);
3123 if_free(sc->sc_ifp);
3124
3125 noise_local_free(sc->sc_local, wg_clone_deferred_free);
3126
3127 return (0);
3128 }
3129
3130 static void
3131 wg_qflush(if_t ifp __unused)
3132 {
3133 }
3134
3135 /*
3136 * Privileged information (private-key, preshared-key) are only exported for
3137 * root and jailed root by default.
3138 */
3139 static bool
3140 wgc_privileged(struct wg_softc *sc)
3141 {
3142 struct thread *td;
3143
3144 td = curthread;
3145 return (priv_check(td, PRIV_NET_WG) == 0);
3146 }
3147
3148 static void
3149 wg_reassign(if_t ifp, struct vnet *new_vnet __unused,
3150 char *unused __unused)
3151 {
3152 struct wg_softc *sc;
3153
3154 sc = if_getsoftc(ifp);
3155 wg_down(sc);
3156 }
3157
3158 static void
3159 wg_init(void *xsc)
3160 {
3161 struct wg_softc *sc;
3162
3163 sc = xsc;
3164 wg_up(sc);
3165 }
3166
3167 static void
3168 vnet_wg_init(const void *unused __unused)
3169 {
3170 struct if_clone_addreq req = {
3171 .create_f = wg_clone_create,
3172 .destroy_f = wg_clone_destroy,
3173 .flags = IFC_F_AUTOUNIT,
3174 };
3175 V_wg_cloner = ifc_attach_cloner(wgname, &req);
3176 }
3177 VNET_SYSINIT(vnet_wg_init, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3178 vnet_wg_init, NULL);
3179
3180 static void
3181 vnet_wg_uninit(const void *unused __unused)
3182 {
3183 if (V_wg_cloner)
3184 ifc_detach_cloner(V_wg_cloner);
3185 }
3186 VNET_SYSUNINIT(vnet_wg_uninit, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3187 vnet_wg_uninit, NULL);
3188
3189 static int
3190 wg_prison_remove(void *obj, void *data __unused)
3191 {
3192 const struct prison *pr = obj;
3193 struct wg_softc *sc;
3194
3195 /*
3196 * Do a pass through all if_wg interfaces and release creds on any from
3197 * the jail that are supposed to be going away. This will, in turn, let
3198 * the jail die so that we don't end up with Schrödinger's jail.
3199 */
3200 sx_slock(&wg_sx);
3201 LIST_FOREACH(sc, &wg_list, sc_entry) {
3202 sx_xlock(&sc->sc_lock);
3203 if (!(sc->sc_flags & WGF_DYING) && sc->sc_ucred && sc->sc_ucred->cr_prison == pr) {
3204 struct ucred *cred = sc->sc_ucred;
3205 DPRINTF(sc, "Creating jail is exiting; tearing down\n");
3206 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
3207 wg_socket_uninit(sc);
3208 sc->sc_ucred = NULL;
3209 crfree(cred);
3210 sc->sc_flags |= WGF_DYING;
3211 }
3212 sx_xunlock(&sc->sc_lock);
3213 }
3214 sx_sunlock(&wg_sx);
3215
3216 return (0);
3217 }
3218
3219 #ifdef SELFTESTS
3220 #include "selftest/allowedips.c"
3221 static bool wg_run_selftests(void)
3222 {
3223 bool ret = true;
3224 ret &= wg_allowedips_selftest();
3225 ret &= noise_counter_selftest();
3226 ret &= cookie_selftest();
3227 return ret;
3228 }
3229 #else
3230 static inline bool wg_run_selftests(void) { return true; }
3231 #endif
3232
3233 static int
3234 wg_module_init(void)
3235 {
3236 int ret;
3237 osd_method_t methods[PR_MAXMETHOD] = {
3238 [PR_METHOD_REMOVE] = wg_prison_remove,
3239 };
3240
3241 wg_packet_zone = uma_zcreate("wg packet", sizeof(struct wg_packet),
3242 NULL, NULL, NULL, NULL, 0, 0);
3243
3244 ret = crypto_init();
3245 if (ret != 0)
3246 return (ret);
3247 ret = cookie_init();
3248 if (ret != 0)
3249 return (ret);
3250
3251 wg_osd_jail_slot = osd_jail_register(NULL, methods);
3252
3253 if (!wg_run_selftests())
3254 return (ENOTRECOVERABLE);
3255
3256 return (0);
3257 }
3258
3259 static void
3260 wg_module_deinit(void)
3261 {
3262 VNET_ITERATOR_DECL(vnet_iter);
3263 VNET_LIST_RLOCK();
3264 VNET_FOREACH(vnet_iter) {
3265 struct if_clone *clone = VNET_VNET(vnet_iter, wg_cloner);
3266 if (clone) {
3267 ifc_detach_cloner(clone);
3268 VNET_VNET(vnet_iter, wg_cloner) = NULL;
3269 }
3270 }
3271 VNET_LIST_RUNLOCK();
3272 NET_EPOCH_WAIT();
3273 MPASS(LIST_EMPTY(&wg_list));
3274 if (wg_osd_jail_slot != 0)
3275 osd_jail_deregister(wg_osd_jail_slot);
3276 cookie_deinit();
3277 crypto_deinit();
3278 if (wg_packet_zone != NULL)
3279 uma_zdestroy(wg_packet_zone);
3280 }
3281
3282 static int
3283 wg_module_event_handler(module_t mod, int what, void *arg)
3284 {
3285 switch (what) {
3286 case MOD_LOAD:
3287 return wg_module_init();
3288 case MOD_UNLOAD:
3289 wg_module_deinit();
3290 break;
3291 default:
3292 return (EOPNOTSUPP);
3293 }
3294 return (0);
3295 }
3296
3297 static moduledata_t wg_moduledata = {
3298 "if_wg",
3299 wg_module_event_handler,
3300 NULL
3301 };
3302
3303 DECLARE_MODULE(if_wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
3304 MODULE_VERSION(if_wg, WIREGUARD_VERSION);
3305 MODULE_DEPEND(if_wg, crypto, 1, 1, 1);
3306