1 /* SPDX-License-Identifier: ISC
2 *
3 * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4 * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
5 * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
6 * Copyright (c) 2021 Kyle Evans <kevans@FreeBSD.org>
7 * Copyright (c) 2022 The FreeBSD Foundation
8 */
9
10 #include "opt_inet.h"
11 #include "opt_inet6.h"
12
13 #include <sys/param.h>
14 #include <sys/systm.h>
15 #include <sys/counter.h>
16 #include <sys/gtaskqueue.h>
17 #include <sys/jail.h>
18 #include <sys/kernel.h>
19 #include <sys/lock.h>
20 #include <sys/mbuf.h>
21 #include <sys/module.h>
22 #include <sys/nv.h>
23 #include <sys/priv.h>
24 #include <sys/protosw.h>
25 #include <sys/rmlock.h>
26 #include <sys/rwlock.h>
27 #include <sys/smp.h>
28 #include <sys/socket.h>
29 #include <sys/socketvar.h>
30 #include <sys/sockio.h>
31 #include <sys/sysctl.h>
32 #include <sys/sx.h>
33 #include <machine/_inttypes.h>
34 #include <net/bpf.h>
35 #include <net/ethernet.h>
36 #include <net/if.h>
37 #include <net/if_clone.h>
38 #include <net/if_types.h>
39 #include <net/if_var.h>
40 #include <net/netisr.h>
41 #include <net/radix.h>
42 #include <netinet/in.h>
43 #include <netinet6/in6_var.h>
44 #include <netinet/ip.h>
45 #include <netinet/ip6.h>
46 #include <netinet/ip_icmp.h>
47 #include <netinet/icmp6.h>
48 #include <netinet/udp_var.h>
49 #include <netinet6/nd6.h>
50
51 #include "wg_noise.h"
52 #include "wg_cookie.h"
53 #include "version.h"
54 #include "if_wg.h"
55
56 #define DEFAULT_MTU (ETHERMTU - 80)
57 #define MAX_MTU (IF_MAXMTU - 80)
58
59 #define MAX_STAGED_PKT 128
60 #define MAX_QUEUED_PKT 1024
61 #define MAX_QUEUED_PKT_MASK (MAX_QUEUED_PKT - 1)
62
63 #define MAX_QUEUED_HANDSHAKES 4096
64
65 #define REKEY_TIMEOUT_JITTER 334 /* 1/3 sec, round for arc4random_uniform */
66 #define MAX_TIMER_HANDSHAKES (90 / REKEY_TIMEOUT)
67 #define NEW_HANDSHAKE_TIMEOUT (REKEY_TIMEOUT + KEEPALIVE_TIMEOUT)
68 #define UNDERLOAD_TIMEOUT 1
69
70 #define DPRINTF(sc, ...) if (if_getflags(sc->sc_ifp) & IFF_DEBUG) if_printf(sc->sc_ifp, ##__VA_ARGS__)
71
72 /* First 32-bit field (little-endian) indicating packet type on the wire */
73 #define WG_PKT_INITIATION htole32(1)
74 #define WG_PKT_RESPONSE htole32(2)
75 #define WG_PKT_COOKIE htole32(3)
76 #define WG_PKT_DATA htole32(4)
77
78 #define WG_PKT_PADDING 16
79 #define WG_KEY_SIZE 32
80
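/*
 * Wire-format messages. These structs mirror the four WireGuard message
 * types (handshake initiation, handshake response, cookie reply and
 * transport data); the leading 't' field holds one of the WG_PKT_*
 * values defined above.
 */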
81 struct wg_pkt_initiation {
82 uint32_t t;
83 uint32_t s_idx;
84 uint8_t ue[NOISE_PUBLIC_KEY_LEN];
85 uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN];
86 uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN];
87 struct cookie_macs m;
88 };
89
90 struct wg_pkt_response {
91 uint32_t t;
92 uint32_t s_idx;
93 uint32_t r_idx;
94 uint8_t ue[NOISE_PUBLIC_KEY_LEN];
95 uint8_t en[0 + NOISE_AUTHTAG_LEN];
96 struct cookie_macs m;
97 };
98
99 struct wg_pkt_cookie {
100 uint32_t t;
101 uint32_t r_idx;
102 uint8_t nonce[COOKIE_NONCE_SIZE];
103 uint8_t ec[COOKIE_ENCRYPTED_SIZE];
104 };
105
106 struct wg_pkt_data {
107 uint32_t t;
108 uint32_t r_idx;
109 uint64_t nonce;
110 uint8_t buf[];
111 };
112
113 struct wg_endpoint {
114 union {
115 struct sockaddr r_sa;
116 struct sockaddr_in r_sin;
117 #ifdef INET6
118 struct sockaddr_in6 r_sin6;
119 #endif
120 } e_remote;
121 union {
122 struct in_addr l_in;
123 #ifdef INET6
124 struct in6_pktinfo l_pktinfo6;
125 #define l_in6 l_pktinfo6.ipi6_addr
126 #endif
127 } e_local;
128 };
129
130 struct aip_addr {
131 uint8_t length;
132 union {
133 uint8_t bytes[16];
134 uint32_t ip;
135 uint32_t ip6[4];
136 struct in_addr in;
137 struct in6_addr in6;
138 };
139 };
140
141 struct wg_aip {
142 struct radix_node a_nodes[2];
143 LIST_ENTRY(wg_aip) a_entry;
144 struct aip_addr a_addr;
145 struct aip_addr a_mask;
146 struct wg_peer *a_peer;
147 sa_family_t a_af;
148 };
149
150 struct wg_packet {
151 STAILQ_ENTRY(wg_packet) p_serial;
152 STAILQ_ENTRY(wg_packet) p_parallel;
153 struct wg_endpoint p_endpoint;
154 struct noise_keypair *p_keypair;
155 uint64_t p_nonce;
156 struct mbuf *p_mbuf;
157 int p_mtu;
158 sa_family_t p_af;
159 enum wg_ring_state {
160 WG_PACKET_UNCRYPTED,
161 WG_PACKET_CRYPTED,
162 WG_PACKET_DEAD,
163 } p_state;
164 };
165
166 STAILQ_HEAD(wg_packet_list, wg_packet);
167
168 struct wg_queue {
169 struct mtx q_mtx;
170 struct wg_packet_list q_queue;
171 size_t q_len;
172 };
173
174 struct wg_peer {
175 TAILQ_ENTRY(wg_peer) p_entry;
176 uint64_t p_id;
177 struct wg_softc *p_sc;
178
179 struct noise_remote *p_remote;
180 struct cookie_maker p_cookie;
181
182 struct rwlock p_endpoint_lock;
183 struct wg_endpoint p_endpoint;
184
185 struct wg_queue p_stage_queue;
186 struct wg_queue p_encrypt_serial;
187 struct wg_queue p_decrypt_serial;
188
189 bool p_enabled;
190 bool p_need_another_keepalive;
191 uint16_t p_persistent_keepalive_interval;
192 struct callout p_new_handshake;
193 struct callout p_send_keepalive;
194 struct callout p_retry_handshake;
195 struct callout p_zero_key_material;
196 struct callout p_persistent_keepalive;
197
198 struct mtx p_handshake_mtx;
199 struct timespec p_handshake_complete; /* nanotime */
200 int p_handshake_retries;
201
202 struct grouptask p_send;
203 struct grouptask p_recv;
204
205 counter_u64_t p_tx_bytes;
206 counter_u64_t p_rx_bytes;
207
208 LIST_HEAD(, wg_aip) p_aips;
209 size_t p_aips_num;
210 };
211
212 struct wg_socket {
213 struct socket *so_so4;
214 struct socket *so_so6;
215 uint32_t so_user_cookie;
216 int so_fibnum;
217 in_port_t so_port;
218 };
219
220 struct wg_softc {
221 LIST_ENTRY(wg_softc) sc_entry;
222 if_t sc_ifp;
223 int sc_flags;
224
225 struct ucred *sc_ucred;
226 struct wg_socket sc_socket;
227
228 TAILQ_HEAD(,wg_peer) sc_peers;
229 size_t sc_peers_num;
230
231 struct noise_local *sc_local;
232 struct cookie_checker sc_cookie;
233
234 struct radix_node_head *sc_aip4;
235 struct radix_node_head *sc_aip6;
236
237 struct grouptask sc_handshake;
238 struct wg_queue sc_handshake_queue;
239
240 struct grouptask *sc_encrypt;
241 struct grouptask *sc_decrypt;
242 struct wg_queue sc_encrypt_parallel;
243 struct wg_queue sc_decrypt_parallel;
244 u_int sc_encrypt_last_cpu;
245 u_int sc_decrypt_last_cpu;
246
247 struct sx sc_lock;
248 };
249
250 #define WGF_DYING 0x0001
251
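/*
 * Loop detection: packets are tagged with the 'wglp' mbuf tag and dropped
 * once they have re-entered wg interfaces more than MAX_LOOPS times, so a
 * tunnel that routes back into itself cannot loop forever.
 */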
252 #define MAX_LOOPS 8
253 #define MTAG_WGLOOP 0x77676c70 /* wglp */
254
255 #define GROUPTASK_DRAIN(gtask) \
256 gtaskqueue_drain((gtask)->gt_taskqueue, &(gtask)->gt_task)
257
258 #define BPF_MTAP2_AF(ifp, m, af) do { \
259 uint32_t __bpf_tap_af = (af); \
260 BPF_MTAP2(ifp, &__bpf_tap_af, sizeof(__bpf_tap_af), m); \
261 } while (0)
262
263 static int clone_count;
264 static uma_zone_t wg_packet_zone;
265 static volatile unsigned long peer_counter = 0;
266 static const char wgname[] = "wg";
267 static unsigned wg_osd_jail_slot;
268
269 static struct sx wg_sx;
270 SX_SYSINIT(wg_sx, &wg_sx, "wg_sx");
271
272 static LIST_HEAD(, wg_softc) wg_list = LIST_HEAD_INITIALIZER(wg_list);
273
274 static TASKQGROUP_DEFINE(wg_tqg, mp_ncpus, 1);
275
276 MALLOC_DEFINE(M_WG, "WG", "wireguard");
277
278 VNET_DEFINE_STATIC(struct if_clone *, wg_cloner);
279
280 #define V_wg_cloner VNET(wg_cloner)
281 #define WG_CAPS IFCAP_LINKSTATE
282
283 struct wg_timespec64 {
284 uint64_t tv_sec;
285 uint64_t tv_nsec;
286 };
287
288 static int wg_socket_init(struct wg_softc *, in_port_t);
289 static int wg_socket_bind(struct socket **, struct socket **, in_port_t *);
290 static void wg_socket_set(struct wg_softc *, struct socket *, struct socket *);
291 static void wg_socket_uninit(struct wg_softc *);
292 static int wg_socket_set_sockopt(struct socket *, struct socket *, int, void *, size_t);
293 static int wg_socket_set_cookie(struct wg_softc *, uint32_t);
294 static int wg_socket_set_fibnum(struct wg_softc *, int);
295 static int wg_send(struct wg_softc *, struct wg_endpoint *, struct mbuf *);
296 static void wg_timers_enable(struct wg_peer *);
297 static void wg_timers_disable(struct wg_peer *);
298 static void wg_timers_set_persistent_keepalive(struct wg_peer *, uint16_t);
299 static void wg_timers_get_last_handshake(struct wg_peer *, struct wg_timespec64 *);
300 static void wg_timers_event_data_sent(struct wg_peer *);
301 static void wg_timers_event_data_received(struct wg_peer *);
302 static void wg_timers_event_any_authenticated_packet_sent(struct wg_peer *);
303 static void wg_timers_event_any_authenticated_packet_received(struct wg_peer *);
304 static void wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *);
305 static void wg_timers_event_handshake_initiated(struct wg_peer *);
306 static void wg_timers_event_handshake_complete(struct wg_peer *);
307 static void wg_timers_event_session_derived(struct wg_peer *);
308 static void wg_timers_event_want_initiation(struct wg_peer *);
309 static void wg_timers_run_send_initiation(struct wg_peer *, bool);
310 static void wg_timers_run_retry_handshake(void *);
311 static void wg_timers_run_send_keepalive(void *);
312 static void wg_timers_run_new_handshake(void *);
313 static void wg_timers_run_zero_key_material(void *);
314 static void wg_timers_run_persistent_keepalive(void *);
315 static int wg_aip_add(struct wg_softc *, struct wg_peer *, sa_family_t,
316 const void *, uint8_t);
317 static int wg_aip_del(struct wg_softc *, struct wg_peer *, sa_family_t,
318 const void *, uint8_t);
319 static struct wg_peer *wg_aip_lookup(struct wg_softc *, sa_family_t, void *);
320 static void wg_aip_remove_all(struct wg_softc *, struct wg_peer *);
321 static struct wg_peer *wg_peer_create(struct wg_softc *,
322 const uint8_t [WG_KEY_SIZE], int *);
323 static void wg_peer_free_deferred(struct noise_remote *);
324 static void wg_peer_destroy(struct wg_peer *);
325 static void wg_peer_destroy_all(struct wg_softc *);
326 static void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t);
327 static void wg_send_initiation(struct wg_peer *);
328 static void wg_send_response(struct wg_peer *);
329 static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct wg_endpoint *);
330 static void wg_peer_set_endpoint(struct wg_peer *, struct wg_endpoint *);
331 static void wg_peer_clear_src(struct wg_peer *);
332 static void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *);
333 static void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *, size_t);
334 static void wg_send_keepalive(struct wg_peer *);
335 static void wg_handshake(struct wg_softc *, struct wg_packet *);
336 static void wg_encrypt(struct wg_softc *, struct wg_packet *);
337 static void wg_decrypt(struct wg_softc *, struct wg_packet *);
338 static void wg_softc_handshake_receive(struct wg_softc *);
339 static void wg_softc_decrypt(struct wg_softc *);
340 static void wg_softc_encrypt(struct wg_softc *);
341 static void wg_encrypt_dispatch(struct wg_softc *);
342 static void wg_decrypt_dispatch(struct wg_softc *);
343 static void wg_deliver_out(struct wg_peer *);
344 static void wg_deliver_in(struct wg_peer *);
345 static struct wg_packet *wg_packet_alloc(struct mbuf *);
346 static void wg_packet_free(struct wg_packet *);
347 static void wg_queue_init(struct wg_queue *, const char *);
348 static void wg_queue_deinit(struct wg_queue *);
349 static size_t wg_queue_len(struct wg_queue *);
350 static int wg_queue_enqueue_handshake(struct wg_queue *, struct wg_packet *);
351 static struct wg_packet *wg_queue_dequeue_handshake(struct wg_queue *);
352 static void wg_queue_push_staged(struct wg_queue *, struct wg_packet *);
353 static void wg_queue_enlist_staged(struct wg_queue *, struct wg_packet_list *);
354 static void wg_queue_delist_staged(struct wg_queue *, struct wg_packet_list *);
355 static void wg_queue_purge(struct wg_queue *);
356 static int wg_queue_both(struct wg_queue *, struct wg_queue *, struct wg_packet *);
357 static struct wg_packet *wg_queue_dequeue_serial(struct wg_queue *);
358 static struct wg_packet *wg_queue_dequeue_parallel(struct wg_queue *);
359 static bool wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *);
360 static void wg_peer_send_staged(struct wg_peer *);
361 static int wg_clone_create(struct if_clone *ifc, char *name, size_t len,
362 struct ifc_data *ifd, if_t *ifpp);
363 static void wg_qflush(if_t);
364 static inline int determine_af_and_pullup(struct mbuf **m, sa_family_t *af);
365 static int wg_xmit(if_t, struct mbuf *, sa_family_t, uint32_t);
366 static int wg_transmit(if_t, struct mbuf *);
367 static int wg_output(if_t, struct mbuf *, const struct sockaddr *, struct route *);
368 static int wg_clone_destroy(struct if_clone *ifc, if_t ifp,
369 uint32_t flags);
370 static bool wgc_privileged(struct wg_softc *);
371 static int wgc_get(struct wg_softc *, struct wg_data_io *);
372 static int wgc_set(struct wg_softc *, struct wg_data_io *);
373 static int wg_up(struct wg_softc *);
374 static void wg_down(struct wg_softc *);
375 static void wg_reassign(if_t, struct vnet *, char *unused);
376 static void wg_init(void *);
377 static int wg_ioctl(if_t, u_long, caddr_t);
378 static void vnet_wg_init(const void *);
379 static void vnet_wg_uninit(const void *);
380 static int wg_module_init(void);
381 static void wg_module_deinit(void);
382
383 /* TODO Peer */
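/*
 * wg_peer_create() allocates and initializes a peer: its noise_remote and
 * cookie maker, per-peer queues, callouts and send/receive grouptasks.
 * The peer is then linked onto the softc's peer list. Requires sc_lock
 * held exclusively.
 */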
384 static struct wg_peer *
385 wg_peer_create(struct wg_softc *sc, const uint8_t pub_key[WG_KEY_SIZE],
386 int *errp)
387 {
388 struct wg_peer *peer;
389
390 sx_assert(&sc->sc_lock, SX_XLOCKED);
391
392 peer = malloc(sizeof(*peer), M_WG, M_WAITOK | M_ZERO);
393
394 peer->p_remote = noise_remote_alloc(sc->sc_local, peer, pub_key);
395 if ((*errp = noise_remote_enable(peer->p_remote)) != 0) {
396 noise_remote_free(peer->p_remote, NULL);
397 free(peer, M_WG);
398 return (NULL);
399 }
400
401 peer->p_id = peer_counter++;
402 peer->p_sc = sc;
403 peer->p_tx_bytes = counter_u64_alloc(M_WAITOK);
404 peer->p_rx_bytes = counter_u64_alloc(M_WAITOK);
405
406 cookie_maker_init(&peer->p_cookie, pub_key);
407
408 rw_init(&peer->p_endpoint_lock, "wg_peer_endpoint");
409
410 wg_queue_init(&peer->p_stage_queue, "stageq");
411 wg_queue_init(&peer->p_encrypt_serial, "txq");
412 wg_queue_init(&peer->p_decrypt_serial, "rxq");
413
414 peer->p_enabled = false;
415 peer->p_need_another_keepalive = false;
416 peer->p_persistent_keepalive_interval = 0;
417 callout_init(&peer->p_new_handshake, true);
418 callout_init(&peer->p_send_keepalive, true);
419 callout_init(&peer->p_retry_handshake, true);
420 callout_init(&peer->p_persistent_keepalive, true);
421 callout_init(&peer->p_zero_key_material, true);
422
423 mtx_init(&peer->p_handshake_mtx, "peer handshake", NULL, MTX_DEF);
424 bzero(&peer->p_handshake_complete, sizeof(peer->p_handshake_complete));
425 peer->p_handshake_retries = 0;
426
427 GROUPTASK_INIT(&peer->p_send, 0, (gtask_fn_t *)wg_deliver_out, peer);
428 taskqgroup_attach(qgroup_wg_tqg, &peer->p_send, peer, NULL, NULL, "wg send");
429 GROUPTASK_INIT(&peer->p_recv, 0, (gtask_fn_t *)wg_deliver_in, peer);
430 taskqgroup_attach(qgroup_wg_tqg, &peer->p_recv, peer, NULL, NULL, "wg recv");
431
432 LIST_INIT(&peer->p_aips);
433 peer->p_aips_num = 0;
434
435 TAILQ_INSERT_TAIL(&sc->sc_peers, peer, p_entry);
436 sc->sc_peers_num++;
437
438 if (if_getlinkstate(sc->sc_ifp) == LINK_STATE_UP)
439 wg_timers_enable(peer);
440
441 DPRINTF(sc, "Peer %" PRIu64 " created\n", peer->p_id);
442 return (peer);
443 }
444
445 static void
446 wg_peer_free_deferred(struct noise_remote *r)
447 {
448 struct wg_peer *peer = noise_remote_arg(r);
449
450 /* While there are no references remaining, we may still have
451 * p_{send,recv} executing (think empty queue, but wg_deliver_{in,out}
452 * needs to check the queue). We should wait for them and then free. */
453 GROUPTASK_DRAIN(&peer->p_recv);
454 GROUPTASK_DRAIN(&peer->p_send);
455 taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv);
456 taskqgroup_detach(qgroup_wg_tqg, &peer->p_send);
457
458 wg_queue_deinit(&peer->p_decrypt_serial);
459 wg_queue_deinit(&peer->p_encrypt_serial);
460 wg_queue_deinit(&peer->p_stage_queue);
461
462 counter_u64_free(peer->p_tx_bytes);
463 counter_u64_free(peer->p_rx_bytes);
464 rw_destroy(&peer->p_endpoint_lock);
465 mtx_destroy(&peer->p_handshake_mtx);
466
467 cookie_maker_free(&peer->p_cookie);
468
469 free(peer, M_WG);
470 }
471
472 static void
473 wg_peer_destroy(struct wg_peer *peer)
474 {
475 struct wg_softc *sc = peer->p_sc;
476 sx_assert(&sc->sc_lock, SX_XLOCKED);
477
478 /* Disable remote and timers. This will prevent any new handshakes
479 * occurring. */
480 noise_remote_disable(peer->p_remote);
481 wg_timers_disable(peer);
482
483 /* Now we can remove all allowed IPs so no more packets will be routed
484 * to the peer. */
485 wg_aip_remove_all(sc, peer);
486
487 /* Remove peer from the interface, then free. Some references may still
488 * exist to p_remote, so noise_remote_free will wait until they're all
489 * put before calling wg_peer_free_deferred. */
490 sc->sc_peers_num--;
491 TAILQ_REMOVE(&sc->sc_peers, peer, p_entry);
492 DPRINTF(sc, "Peer %" PRIu64 " destroyed\n", peer->p_id);
493 noise_remote_free(peer->p_remote, wg_peer_free_deferred);
494 }
495
496 static void
497 wg_peer_destroy_all(struct wg_softc *sc)
498 {
499 struct wg_peer *peer, *tpeer;
500 TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer)
501 wg_peer_destroy(peer);
502 }
503
504 static void
505 wg_peer_set_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
506 {
507 MPASS(e->e_remote.r_sa.sa_family != 0);
508 if (memcmp(e, &peer->p_endpoint, sizeof(*e)) == 0)
509 return;
510
511 rw_wlock(&peer->p_endpoint_lock);
512 peer->p_endpoint = *e;
513 rw_wunlock(&peer->p_endpoint_lock);
514 }
515
516 static void
517 wg_peer_clear_src(struct wg_peer *peer)
518 {
519 rw_wlock(&peer->p_endpoint_lock);
520 bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local));
521 rw_wunlock(&peer->p_endpoint_lock);
522 }
523
524 static void
525 wg_peer_get_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
526 {
527 rw_rlock(&peer->p_endpoint_lock);
528 *e = peer->p_endpoint;
529 rw_runlock(&peer->p_endpoint_lock);
530 }
531
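/*
 * wg_aip_addrinfo() canonicalizes an address/CIDR pair into the key and
 * mask used by the radix tree: the mask is derived from the prefix length
 * and the address is truncated to its network part, so lookups and direct
 * comparisons are exact.
 */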
532 static int
533 wg_aip_addrinfo(struct wg_aip *aip, const void *baddr, uint8_t cidr)
534 {
535 struct aip_addr *addr, *mask;
536
537 addr = &aip->a_addr;
538 mask = &aip->a_mask;
539 switch (aip->a_af) {
540 #ifdef INET
541 case AF_INET:
542 if (cidr > 32) cidr = 32;
543 addr->in = *(const struct in_addr *)baddr;
544 mask->ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);
545 addr->ip &= mask->ip;
546 addr->length = mask->length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
547 break;
548 #endif
549 #ifdef INET6
550 case AF_INET6:
551 if (cidr > 128) cidr = 128;
552 addr->in6 = *(const struct in6_addr *)baddr;
553 in6_prefixlen2mask(&mask->in6, cidr);
554 for (int i = 0; i < 4; i++)
555 addr->ip6[i] &= mask->ip6[i];
556 addr->length = mask->length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
557 break;
558 #endif
559 default:
560 return (EAFNOSUPPORT);
561 }
562
563 return (0);
564 }
565
566 /* Allowed IP */
567 static int
568 wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af,
569 const void *baddr, uint8_t cidr)
570 {
571 struct radix_node_head *root = NULL;
572 struct radix_node *node;
573 struct wg_aip *aip;
574 int ret = 0;
575
576 aip = malloc(sizeof(*aip), M_WG, M_WAITOK | M_ZERO);
577 aip->a_peer = peer;
578 aip->a_af = af;
579
580 ret = wg_aip_addrinfo(aip, baddr, cidr);
581 if (ret != 0) {
582 free(aip, M_WG);
583 return (ret);
584 }
585
586 root = af == AF_INET ? sc->sc_aip4 : sc->sc_aip6;
587 MPASS(root != NULL);
588 RADIX_NODE_HEAD_LOCK(root);
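/*
 * rnh_addaddr() either inserts our new node or returns NULL when an entry
 * for this prefix already exists (or memory is exhausted). In the
 * existing-entry case we look the node up and, if it belongs to another
 * peer, move it over to this peer so the most recent configuration wins.
 */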
589 node = root->rnh_addaddr(&aip->a_addr, &aip->a_mask, &root->rh, aip->a_nodes);
590 if (node == aip->a_nodes) {
591 LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
592 peer->p_aips_num++;
593 } else if (!node)
594 node = root->rnh_lookup(&aip->a_addr, &aip->a_mask, &root->rh);
595 if (!node) {
596 free(aip, M_WG);
597 ret = ENOMEM;
598 } else if (node != aip->a_nodes) {
599 free(aip, M_WG);
600 aip = (struct wg_aip *)node;
601 if (aip->a_peer != peer) {
602 LIST_REMOVE(aip, a_entry);
603 aip->a_peer->p_aips_num--;
604 aip->a_peer = peer;
605 LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
606 aip->a_peer->p_aips_num++;
607 }
608 }
609 RADIX_NODE_HEAD_UNLOCK(root);
610 return (ret);
611 }
612
613 static int
614 wg_aip_del(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af,
615 const void *baddr, uint8_t cidr)
616 {
617 struct radix_node_head *root = NULL;
618 struct radix_node *dnode __diagused, *node;
619 struct wg_aip *aip, addr;
620 int ret = 0;
621
622 /*
623 * We need to be sure that all padding is cleared, just as it is when
624 * new AllowedIPs are added above, since we want to do a direct comparison.
625 */
626 memset(&addr, 0, sizeof(addr));
627 addr.a_af = af;
628
629 ret = wg_aip_addrinfo(&addr, baddr, cidr);
630 if (ret != 0)
631 return (ret);
632
633 root = af == AF_INET ? sc->sc_aip4 : sc->sc_aip6;
634
635 MPASS(root != NULL);
636 RADIX_NODE_HEAD_LOCK(root);
637
638 node = root->rnh_lookup(&addr.a_addr, &addr.a_mask, &root->rh);
639 if (node == NULL) {
640 RADIX_NODE_HEAD_UNLOCK(root);
641 return (0);
642 }
643
644 aip = (struct wg_aip *)node;
645 if (aip->a_peer != peer) {
646 /*
647 * They could have specified an allowed-ip that belonged to a
648 * different peer, in which case our job is done because that
649 * AllowedIP is not associated with this peer anyway.
650 */
651 RADIX_NODE_HEAD_UNLOCK(root);
652 return (0);
653 }
654
655 dnode = root->rnh_deladdr(&aip->a_addr, &aip->a_mask, &root->rh);
656 MPASS(dnode == node);
657 RADIX_NODE_HEAD_UNLOCK(root);
658
659 LIST_REMOVE(aip, a_entry);
660 peer->p_aips_num--;
661 free(aip, M_WG);
662 return (0);
663 }
664
665 static struct wg_peer *
666 wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a)
667 {
668 struct radix_node_head *root;
669 struct radix_node *node;
670 struct wg_peer *peer;
671 struct aip_addr addr;
672 RADIX_NODE_HEAD_RLOCK_TRACKER;
673
674 switch (af) {
675 case AF_INET:
676 root = sc->sc_aip4;
677 memcpy(&addr.in, a, sizeof(addr.in));
678 addr.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
679 break;
680 case AF_INET6:
681 root = sc->sc_aip6;
682 memcpy(&addr.in6, a, sizeof(addr.in6));
683 addr.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
684 break;
685 default:
686 return NULL;
687 }
688
689 RADIX_NODE_HEAD_RLOCK(root);
690 node = root->rnh_matchaddr(&addr, &root->rh);
691 if (node != NULL) {
692 peer = ((struct wg_aip *)node)->a_peer;
693 noise_remote_ref(peer->p_remote);
694 } else {
695 peer = NULL;
696 }
697 RADIX_NODE_HEAD_RUNLOCK(root);
698
699 return (peer);
700 }
701
702 static void
703 wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer)
704 {
705 struct wg_aip *aip, *taip;
706
707 RADIX_NODE_HEAD_LOCK(sc->sc_aip4);
708 LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
709 if (aip->a_af == AF_INET) {
710 if (sc->sc_aip4->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip4->rh) == NULL)
711 panic("failed to delete aip %p", aip);
712 LIST_REMOVE(aip, a_entry);
713 peer->p_aips_num--;
714 free(aip, M_WG);
715 }
716 }
717 RADIX_NODE_HEAD_UNLOCK(sc->sc_aip4);
718
719 RADIX_NODE_HEAD_LOCK(sc->sc_aip6);
720 LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
721 if (aip->a_af == AF_INET6) {
722 if (sc->sc_aip6->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip6->rh) == NULL)
723 panic("failed to delete aip %p", aip);
724 LIST_REMOVE(aip, a_entry);
725 peer->p_aips_num--;
726 free(aip, M_WG);
727 }
728 }
729 RADIX_NODE_HEAD_UNLOCK(sc->sc_aip6);
730
731 if (!LIST_EMPTY(&peer->p_aips) || peer->p_aips_num != 0)
732 panic("wg_aip_remove_all could not delete all %p", peer);
733 }
734
735 static int
736 wg_socket_init(struct wg_softc *sc, in_port_t port)
737 {
738 struct ucred *cred = sc->sc_ucred;
739 struct socket *so4 = NULL, *so6 = NULL;
740 int rc;
741
742 sx_assert(&sc->sc_lock, SX_XLOCKED);
743
744 if (!cred)
745 return (EBUSY);
746
747 /*
748 * For socket creation, we use the creds of the thread that created the
749 * tunnel rather than the current thread to maintain the semantics that
750 * WireGuard has on Linux with network namespaces -- that the sockets
751 * are created in their home vnet so that they can be configured and
752 * functionally attached to a foreign vnet as the jail's only interface
753 * to the network.
754 */
755 #ifdef INET
756 rc = socreate(AF_INET, &so4, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
757 if (rc)
758 goto out;
759
760 rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc);
761 /*
762 * udp_set_kernel_tunneling can only fail if there is already a tunneling function set.
763 * This should never happen with a new socket.
764 */
765 MPASS(rc == 0);
766 #endif
767
768 #ifdef INET6
769 rc = socreate(AF_INET6, &so6, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
770 if (rc)
771 goto out;
772 rc = udp_set_kernel_tunneling(so6, wg_input, NULL, sc);
773 MPASS(rc == 0);
774 #endif
775
776 if (sc->sc_socket.so_user_cookie) {
777 rc = wg_socket_set_sockopt(so4, so6, SO_USER_COOKIE, &sc->sc_socket.so_user_cookie, sizeof(sc->sc_socket.so_user_cookie));
778 if (rc)
779 goto out;
780 }
781 rc = wg_socket_set_sockopt(so4, so6, SO_SETFIB, &sc->sc_socket.so_fibnum, sizeof(sc->sc_socket.so_fibnum));
782 if (rc)
783 goto out;
784
785 rc = wg_socket_bind(&so4, &so6, &port);
786 if (!rc) {
787 sc->sc_socket.so_port = port;
788 wg_socket_set(sc, so4, so6);
789 }
790 out:
791 if (rc) {
792 if (so4 != NULL)
793 soclose(so4);
794 if (so6 != NULL)
795 soclose(so6);
796 }
797 return (rc);
798 }
799
800 static int wg_socket_set_sockopt(struct socket *so4, struct socket *so6, int name, void *val, size_t len)
801 {
802 int ret4 = 0, ret6 = 0;
803 struct sockopt sopt = {
804 .sopt_dir = SOPT_SET,
805 .sopt_level = SOL_SOCKET,
806 .sopt_name = name,
807 .sopt_val = val,
808 .sopt_valsize = len
809 };
810
811 if (so4)
812 ret4 = sosetopt(so4, &sopt);
813 if (so6)
814 ret6 = sosetopt(so6, &sopt);
815 return (ret4 ?: ret6);
816 }
817
818 static int wg_socket_set_cookie(struct wg_softc *sc, uint32_t user_cookie)
819 {
820 struct wg_socket *so = &sc->sc_socket;
821 int ret;
822
823 sx_assert(&sc->sc_lock, SX_XLOCKED);
824 ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_USER_COOKIE, &user_cookie, sizeof(user_cookie));
825 if (!ret)
826 so->so_user_cookie = user_cookie;
827 return (ret);
828 }
829
830 static int wg_socket_set_fibnum(struct wg_softc *sc, int fibnum)
831 {
832 struct wg_socket *so = &sc->sc_socket;
833 int ret;
834
835 sx_assert(&sc->sc_lock, SX_XLOCKED);
836
837 ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_SETFIB, &fibnum, sizeof(fibnum));
838 if (!ret)
839 so->so_fibnum = fibnum;
840 return (ret);
841 }
842
843 static void
844 wg_socket_uninit(struct wg_softc *sc)
845 {
846 wg_socket_set(sc, NULL, NULL);
847 }
848
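/*
 * wg_socket_set() publishes the new sockets, then waits for the network
 * epoch before closing the old ones, so a concurrent wg_send() that loaded
 * the old pointers inside an epoch section never uses a freed socket.
 */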
849 static void
850 wg_socket_set(struct wg_softc *sc, struct socket *new_so4, struct socket *new_so6)
851 {
852 struct wg_socket *so = &sc->sc_socket;
853 struct socket *so4, *so6;
854
855 sx_assert(&sc->sc_lock, SX_XLOCKED);
856
857 so4 = atomic_load_ptr(&so->so_so4);
858 so6 = atomic_load_ptr(&so->so_so6);
859 atomic_store_ptr(&so->so_so4, new_so4);
860 atomic_store_ptr(&so->so_so6, new_so6);
861
862 if (!so4 && !so6)
863 return;
864 NET_EPOCH_WAIT();
865 if (so4)
866 soclose(so4);
867 if (so6)
868 soclose(so6);
869 }
870
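/*
 * wg_socket_bind() binds the IPv4 and IPv6 sockets to the same UDP port.
 * When port 0 is requested, the port the kernel picks for the IPv4 socket
 * is reused for the IPv6 bind so a single listen port serves both
 * families.
 */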
871 static int
872 wg_socket_bind(struct socket **in_so4, struct socket **in_so6, in_port_t *requested_port)
873 {
874 struct socket *so4 = *in_so4, *so6 = *in_so6;
875 int ret4 = 0, ret6 = 0;
876 in_port_t port = *requested_port;
877 struct sockaddr_in sin = {
878 .sin_len = sizeof(struct sockaddr_in),
879 .sin_family = AF_INET,
880 .sin_port = htons(port)
881 };
882 struct sockaddr_in6 sin6 = {
883 .sin6_len = sizeof(struct sockaddr_in6),
884 .sin6_family = AF_INET6,
885 .sin6_port = htons(port)
886 };
887
888 if (so4) {
889 ret4 = sobind(so4, (struct sockaddr *)&sin, curthread);
890 if (ret4 && ret4 != EADDRNOTAVAIL)
891 return (ret4);
892 if (!ret4 && !sin.sin_port) {
893 struct sockaddr_in bound_sin =
894 { .sin_len = sizeof(bound_sin) };
895 int ret;
896
897 ret = sosockaddr(so4, (struct sockaddr *)&bound_sin);
898 if (ret)
899 return (ret);
900 port = ntohs(bound_sin.sin_port);
901 sin6.sin6_port = bound_sin.sin_port;
902 }
903 }
904
905 if (so6) {
906 ret6 = sobind(so6, (struct sockaddr *)&sin6, curthread);
907 if (ret6 && ret6 != EADDRNOTAVAIL)
908 return (ret6);
909 if (!ret6 && !sin6.sin6_port) {
910 struct sockaddr_in6 bound_sin6 =
911 { .sin6_len = sizeof(bound_sin6) };
912 int ret;
913
914 ret = sosockaddr(so6, (struct sockaddr *)&bound_sin6);
915 if (ret)
916 return (ret);
917 port = ntohs(bound_sin6.sin6_port);
918 }
919 }
920
921 if (ret4 && ret6)
922 return (ret4);
923 *requested_port = port;
924 if (ret4 && !ret6 && so4) {
925 soclose(so4);
926 *in_so4 = NULL;
927 } else if (ret6 && !ret4 && so6) {
928 soclose(so6);
929 *in_so6 = NULL;
930 }
931 return (0);
932 }
933
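/*
 * wg_send() transmits one UDP datagram to the peer's current endpoint,
 * attaching IP_SENDSRCADDR/IPV6_PKTINFO control data so the datagram is
 * sent from the local address the peer last used to reach us.
 */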
934 static int
935 wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m)
936 {
937 struct epoch_tracker et;
938 struct sockaddr *sa;
939 struct wg_socket *so = &sc->sc_socket;
940 struct socket *so4, *so6;
941 struct mbuf *control = NULL;
942 int ret = 0;
943 size_t len = m->m_pkthdr.len;
944
945 /* Get local control address before locking */
946 if (e->e_remote.r_sa.sa_family == AF_INET) {
947 if (e->e_local.l_in.s_addr != INADDR_ANY)
948 control = sbcreatecontrol((caddr_t)&e->e_local.l_in,
949 sizeof(struct in_addr), IP_SENDSRCADDR,
950 IPPROTO_IP, M_NOWAIT);
951 #ifdef INET6
952 } else if (e->e_remote.r_sa.sa_family == AF_INET6) {
953 if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6))
954 control = sbcreatecontrol((caddr_t)&e->e_local.l_pktinfo6,
955 sizeof(struct in6_pktinfo), IPV6_PKTINFO,
956 IPPROTO_IPV6, M_NOWAIT);
957 #endif
958 } else {
959 m_freem(m);
960 return (EAFNOSUPPORT);
961 }
962
963 /* Get remote address */
964 sa = &e->e_remote.r_sa;
965
966 NET_EPOCH_ENTER(et);
967 so4 = atomic_load_ptr(&so->so_so4);
968 so6 = atomic_load_ptr(&so->so_so6);
969 if (e->e_remote.r_sa.sa_family == AF_INET && so4 != NULL)
970 ret = sosend(so4, sa, NULL, m, control, 0, curthread);
971 else if (e->e_remote.r_sa.sa_family == AF_INET6 && so6 != NULL)
972 ret = sosend(so6, sa, NULL, m, control, 0, curthread);
973 else {
974 ret = ENOTCONN;
975 m_freem(control);
976 m_freem(m);
977 }
978 NET_EPOCH_EXIT(et);
979 if (ret == 0) {
980 if_inc_counter(sc->sc_ifp, IFCOUNTER_OPACKETS, 1);
981 if_inc_counter(sc->sc_ifp, IFCOUNTER_OBYTES, len);
982 }
983 return (ret);
984 }
985
986 static void
987 wg_send_buf(struct wg_softc *sc, struct wg_endpoint *e, uint8_t *buf, size_t len)
988 {
989 struct mbuf *m;
990 int ret = 0;
991 bool retried = false;
992
993 retry:
994 m = m_get2(len, M_NOWAIT, MT_DATA, M_PKTHDR);
995 if (!m) {
996 ret = ENOMEM;
997 goto out;
998 }
999 m_copyback(m, 0, len, buf);
1000
1001 if (ret == 0) {
1002 ret = wg_send(sc, e, m);
1003 /* Retry if we couldn't bind to e->e_local */
1004 if (ret == EADDRNOTAVAIL && !retried) {
1005 bzero(&e->e_local, sizeof(e->e_local));
1006 retried = true;
1007 goto retry;
1008 }
1009 } else {
1010 ret = wg_send(sc, e, m);
1011 }
1012 out:
1013 if (ret)
1014 DPRINTF(sc, "Unable to send packet: %d\n", ret);
1015 }
1016
1017 /* Timers */
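/*
 * The callouts below implement the WireGuard timer state machine. They are
 * only armed while p_enabled is true, and wg_timers_disable() relies on
 * NET_EPOCH_WAIT() to guarantee that no epoch section can re-arm them once
 * the flag has been cleared.
 */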
1018 static void
1019 wg_timers_enable(struct wg_peer *peer)
1020 {
1021 atomic_store_bool(&peer->p_enabled, true);
1022 wg_timers_run_persistent_keepalive(peer);
1023 }
1024
1025 static void
1026 wg_timers_disable(struct wg_peer *peer)
1027 {
1028 /* By setting p_enabled = false, then calling NET_EPOCH_WAIT, we can be
1029 * sure no new handshakes are created after the wait. This is because
1030 * all callout_resets (scheduling the callout) are guarded by
1031 * p_enabled. We can be sure all sections that read p_enabled and then
1032 * optionally call callout_reset are finished as they are surrounded by
1033 * NET_EPOCH_{ENTER,EXIT}.
1034 *
1035 * However, as new callouts may be scheduled during NET_EPOCH_WAIT (but
1036 * not after), we stop all callouts leaving no callouts active.
1037 *
1038 * We should also pull NET_EPOCH_WAIT out of the FOREACH(peer) loops, but the
1039 * performance impact is acceptable for the time being. */
1040 atomic_store_bool(&peer->p_enabled, false);
1041 NET_EPOCH_WAIT();
1042 atomic_store_bool(&peer->p_need_another_keepalive, false);
1043
1044 callout_stop(&peer->p_new_handshake);
1045 callout_stop(&peer->p_send_keepalive);
1046 callout_stop(&peer->p_retry_handshake);
1047 callout_stop(&peer->p_persistent_keepalive);
1048 callout_stop(&peer->p_zero_key_material);
1049 }
1050
1051 static void
1052 wg_timers_set_persistent_keepalive(struct wg_peer *peer, uint16_t interval)
1053 {
1054 struct epoch_tracker et;
1055 if (interval != peer->p_persistent_keepalive_interval) {
1056 atomic_store_16(&peer->p_persistent_keepalive_interval, interval);
1057 NET_EPOCH_ENTER(et);
1058 if (atomic_load_bool(&peer->p_enabled))
1059 wg_timers_run_persistent_keepalive(peer);
1060 NET_EPOCH_EXIT(et);
1061 }
1062 }
1063
1064 static void
1065 wg_timers_get_last_handshake(struct wg_peer *peer, struct wg_timespec64 *time)
1066 {
1067 mtx_lock(&peer->p_handshake_mtx);
1068 time->tv_sec = peer->p_handshake_complete.tv_sec;
1069 time->tv_nsec = peer->p_handshake_complete.tv_nsec;
1070 mtx_unlock(&peer->p_handshake_mtx);
1071 }
1072
1073 static void
1074 wg_timers_event_data_sent(struct wg_peer *peer)
1075 {
1076 struct epoch_tracker et;
1077 NET_EPOCH_ENTER(et);
1078 if (atomic_load_bool(&peer->p_enabled) &&
1079 !callout_pending(&peer->p_new_handshake))
1080 callout_reset(&peer->p_new_handshake, MSEC_2_TICKS(
1081 NEW_HANDSHAKE_TIMEOUT * 1000 +
1082 arc4random_uniform(REKEY_TIMEOUT_JITTER)),
1083 wg_timers_run_new_handshake, peer);
1084 NET_EPOCH_EXIT(et);
1085 }
1086
1087 static void
1088 wg_timers_event_data_received(struct wg_peer *peer)
1089 {
1090 struct epoch_tracker et;
1091 NET_EPOCH_ENTER(et);
1092 if (atomic_load_bool(&peer->p_enabled)) {
1093 if (!callout_pending(&peer->p_send_keepalive))
1094 callout_reset(&peer->p_send_keepalive,
1095 MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1096 wg_timers_run_send_keepalive, peer);
1097 else
1098 atomic_store_bool(&peer->p_need_another_keepalive,
1099 true);
1100 }
1101 NET_EPOCH_EXIT(et);
1102 }
1103
1104 static void
1105 wg_timers_event_any_authenticated_packet_sent(struct wg_peer *peer)
1106 {
1107 callout_stop(&peer->p_send_keepalive);
1108 }
1109
1110 static void
1111 wg_timers_event_any_authenticated_packet_received(struct wg_peer *peer)
1112 {
1113 callout_stop(&peer->p_new_handshake);
1114 }
1115
1116 static void
1117 wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *peer)
1118 {
1119 struct epoch_tracker et;
1120 uint16_t interval;
1121 NET_EPOCH_ENTER(et);
1122 interval = atomic_load_16(&peer->p_persistent_keepalive_interval);
1123 if (atomic_load_bool(&peer->p_enabled) && interval > 0)
1124 callout_reset(&peer->p_persistent_keepalive,
1125 MSEC_2_TICKS(interval * 1000),
1126 wg_timers_run_persistent_keepalive, peer);
1127 NET_EPOCH_EXIT(et);
1128 }
1129
1130 static void
1131 wg_timers_event_handshake_initiated(struct wg_peer *peer)
1132 {
1133 struct epoch_tracker et;
1134 NET_EPOCH_ENTER(et);
1135 if (atomic_load_bool(&peer->p_enabled))
1136 callout_reset(&peer->p_retry_handshake, MSEC_2_TICKS(
1137 REKEY_TIMEOUT * 1000 +
1138 arc4random_uniform(REKEY_TIMEOUT_JITTER)),
1139 wg_timers_run_retry_handshake, peer);
1140 NET_EPOCH_EXIT(et);
1141 }
1142
1143 static void
1144 wg_timers_event_handshake_complete(struct wg_peer *peer)
1145 {
1146 struct epoch_tracker et;
1147 NET_EPOCH_ENTER(et);
1148 if (atomic_load_bool(&peer->p_enabled)) {
1149 mtx_lock(&peer->p_handshake_mtx);
1150 callout_stop(&peer->p_retry_handshake);
1151 peer->p_handshake_retries = 0;
1152 getnanotime(&peer->p_handshake_complete);
1153 mtx_unlock(&peer->p_handshake_mtx);
1154 wg_timers_run_send_keepalive(peer);
1155 }
1156 NET_EPOCH_EXIT(et);
1157 }
1158
1159 static void
1160 wg_timers_event_session_derived(struct wg_peer *peer)
1161 {
1162 struct epoch_tracker et;
1163 NET_EPOCH_ENTER(et);
1164 if (atomic_load_bool(&peer->p_enabled))
1165 callout_reset(&peer->p_zero_key_material,
1166 MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1167 wg_timers_run_zero_key_material, peer);
1168 NET_EPOCH_EXIT(et);
1169 }
1170
1171 static void
1172 wg_timers_event_want_initiation(struct wg_peer *peer)
1173 {
1174 struct epoch_tracker et;
1175 NET_EPOCH_ENTER(et);
1176 if (atomic_load_bool(&peer->p_enabled))
1177 wg_timers_run_send_initiation(peer, false);
1178 NET_EPOCH_EXIT(et);
1179 }
1180
1181 static void
1182 wg_timers_run_send_initiation(struct wg_peer *peer, bool is_retry)
1183 {
1184 if (!is_retry)
1185 peer->p_handshake_retries = 0;
1186 if (noise_remote_initiation_expired(peer->p_remote) == ETIMEDOUT)
1187 wg_send_initiation(peer);
1188 }
1189
1190 static void
1191 wg_timers_run_retry_handshake(void *_peer)
1192 {
1193 struct epoch_tracker et;
1194 struct wg_peer *peer = _peer;
1195
1196 mtx_lock(&peer->p_handshake_mtx);
1197 if (peer->p_handshake_retries <= MAX_TIMER_HANDSHAKES) {
1198 peer->p_handshake_retries++;
1199 mtx_unlock(&peer->p_handshake_mtx);
1200
1201 DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1202 "after %d seconds, retrying (try %d)\n", peer->p_id,
1203 REKEY_TIMEOUT, peer->p_handshake_retries + 1);
1204 wg_peer_clear_src(peer);
1205 wg_timers_run_send_initiation(peer, true);
1206 } else {
1207 mtx_unlock(&peer->p_handshake_mtx);
1208
1209 DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1210 "after %d retries, giving up\n", peer->p_id,
1211 MAX_TIMER_HANDSHAKES + 2);
1212
1213 callout_stop(&peer->p_send_keepalive);
1214 wg_queue_purge(&peer->p_stage_queue);
1215 NET_EPOCH_ENTER(et);
1216 if (atomic_load_bool(&peer->p_enabled) &&
1217 !callout_pending(&peer->p_zero_key_material))
1218 callout_reset(&peer->p_zero_key_material,
1219 MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1220 wg_timers_run_zero_key_material, peer);
1221 NET_EPOCH_EXIT(et);
1222 }
1223 }
1224
1225 static void
1226 wg_timers_run_send_keepalive(void *_peer)
1227 {
1228 struct epoch_tracker et;
1229 struct wg_peer *peer = _peer;
1230
1231 wg_send_keepalive(peer);
1232 NET_EPOCH_ENTER(et);
1233 if (atomic_load_bool(&peer->p_enabled) &&
1234 atomic_load_bool(&peer->p_need_another_keepalive)) {
1235 atomic_store_bool(&peer->p_need_another_keepalive, false);
1236 callout_reset(&peer->p_send_keepalive,
1237 MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1238 wg_timers_run_send_keepalive, peer);
1239 }
1240 NET_EPOCH_EXIT(et);
1241 }
1242
1243 static void
1244 wg_timers_run_new_handshake(void *_peer)
1245 {
1246 struct wg_peer *peer = _peer;
1247
1248 DPRINTF(peer->p_sc, "Retrying handshake with peer %" PRIu64 " because we "
1249 "stopped hearing back after %d seconds\n",
1250 peer->p_id, NEW_HANDSHAKE_TIMEOUT);
1251
1252 wg_peer_clear_src(peer);
1253 wg_timers_run_send_initiation(peer, false);
1254 }
1255
1256 static void
1257 wg_timers_run_zero_key_material(void *_peer)
1258 {
1259 struct wg_peer *peer = _peer;
1260
1261 DPRINTF(peer->p_sc, "Zeroing out keys for peer %" PRIu64 ", since we "
1262 "haven't received a new one in %d seconds\n",
1263 peer->p_id, REJECT_AFTER_TIME * 3);
1264 noise_remote_keypairs_clear(peer->p_remote);
1265 }
1266
1267 static void
1268 wg_timers_run_persistent_keepalive(void *_peer)
1269 {
1270 struct wg_peer *peer = _peer;
1271
1272 if (atomic_load_16(&peer->p_persistent_keepalive_interval) > 0)
1273 wg_send_keepalive(peer);
1274 }
1275
1276 /* TODO Handshake */
1277 static void
1278 wg_peer_send_buf(struct wg_peer *peer, uint8_t *buf, size_t len)
1279 {
1280 struct wg_endpoint endpoint;
1281
1282 counter_u64_add(peer->p_tx_bytes, len);
1283 wg_timers_event_any_authenticated_packet_traversal(peer);
1284 wg_timers_event_any_authenticated_packet_sent(peer);
1285 wg_peer_get_endpoint(peer, &endpoint);
1286 wg_send_buf(peer->p_sc, &endpoint, buf, len);
1287 }
1288
1289 static void
1290 wg_send_initiation(struct wg_peer *peer)
1291 {
1292 struct wg_pkt_initiation pkt;
1293
1294 if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue,
1295 pkt.es, pkt.ets) != 0)
1296 return;
1297
1298 DPRINTF(peer->p_sc, "Sending handshake initiation to peer %" PRIu64 "\n", peer->p_id);
1299
1300 pkt.t = WG_PKT_INITIATION;
1301 cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1302 sizeof(pkt) - sizeof(pkt.m));
1303 wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
1304 wg_timers_event_handshake_initiated(peer);
1305 }
1306
1307 static void
1308 wg_send_response(struct wg_peer *peer)
1309 {
1310 struct wg_pkt_response pkt;
1311
1312 if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx,
1313 pkt.ue, pkt.en) != 0)
1314 return;
1315
1316 DPRINTF(peer->p_sc, "Sending handshake response to peer %" PRIu64 "\n", peer->p_id);
1317
1318 wg_timers_event_session_derived(peer);
1319 pkt.t = WG_PKT_RESPONSE;
1320 cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1321 sizeof(pkt)-sizeof(pkt.m));
1322 wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt));
1323 }
1324
1325 static void
1326 wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx,
1327 struct wg_endpoint *e)
1328 {
1329 struct wg_pkt_cookie pkt;
1330
1331 DPRINTF(sc, "Sending cookie response for denied handshake message\n");
1332
1333 pkt.t = WG_PKT_COOKIE;
1334 pkt.r_idx = idx;
1335
1336 cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce,
1337 pkt.ec, &e->e_remote.r_sa);
1338 wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt));
1339 }
1340
1341 static void
1342 wg_send_keepalive(struct wg_peer *peer)
1343 {
1344 struct wg_packet *pkt;
1345 struct mbuf *m;
1346
1347 if (wg_queue_len(&peer->p_stage_queue) > 0)
1348 goto send;
1349 if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL)
1350 return;
1351 if ((pkt = wg_packet_alloc(m)) == NULL) {
1352 m_freem(m);
1353 return;
1354 }
1355 wg_queue_push_staged(&peer->p_stage_queue, pkt);
1356 DPRINTF(peer->p_sc, "Sending keepalive packet to peer %" PRIu64 "\n", peer->p_id);
1357 send:
1358 wg_peer_send_staged(peer);
1359 }
1360
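/*
 * wg_handshake() consumes one queued handshake message: the cookie MACs are
 * validated first (with rate limiting when the handshake queue is under
 * load), then the corresponding noise_consume_* routine verifies the
 * message before the peer's endpoint and timers are updated.
 */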
1361 static void
1362 wg_handshake(struct wg_softc *sc, struct wg_packet *pkt)
1363 {
1364 struct wg_pkt_initiation *init;
1365 struct wg_pkt_response *resp;
1366 struct wg_pkt_cookie *cook;
1367 struct wg_endpoint *e;
1368 struct wg_peer *peer;
1369 struct mbuf *m;
1370 struct noise_remote *remote = NULL;
1371 int res;
1372 bool underload = false;
1373 static sbintime_t wg_last_underload; /* sbinuptime */
1374
1375 underload = wg_queue_len(&sc->sc_handshake_queue) >= MAX_QUEUED_HANDSHAKES / 8;
1376 if (underload) {
1377 wg_last_underload = getsbinuptime();
1378 } else if (wg_last_underload) {
1379 underload = wg_last_underload + UNDERLOAD_TIMEOUT * SBT_1S > getsbinuptime();
1380 if (!underload)
1381 wg_last_underload = 0;
1382 }
1383
1384 m = pkt->p_mbuf;
1385 e = &pkt->p_endpoint;
1386
1387 if ((pkt->p_mbuf = m = m_pullup(m, m->m_pkthdr.len)) == NULL)
1388 goto error;
1389
1390 switch (*mtod(m, uint32_t *)) {
1391 case WG_PKT_INITIATION:
1392 init = mtod(m, struct wg_pkt_initiation *);
1393
1394 res = cookie_checker_validate_macs(&sc->sc_cookie, &init->m,
1395 init, sizeof(*init) - sizeof(init->m),
1396 underload, &e->e_remote.r_sa,
1397 if_getvnet(sc->sc_ifp));
1398
1399 if (res == EINVAL) {
1400 DPRINTF(sc, "Invalid initiation MAC\n");
1401 goto error;
1402 } else if (res == ECONNREFUSED) {
1403 DPRINTF(sc, "Handshake ratelimited\n");
1404 goto error;
1405 } else if (res == EAGAIN) {
1406 wg_send_cookie(sc, &init->m, init->s_idx, e);
1407 goto error;
1408 } else if (res != 0) {
1409 panic("unexpected response: %d\n", res);
1410 }
1411
1412 if (noise_consume_initiation(sc->sc_local, &remote,
1413 init->s_idx, init->ue, init->es, init->ets) != 0) {
1414 DPRINTF(sc, "Invalid handshake initiation\n");
1415 goto error;
1416 }
1417
1418 peer = noise_remote_arg(remote);
1419
1420 DPRINTF(sc, "Receiving handshake initiation from peer %" PRIu64 "\n", peer->p_id);
1421
1422 wg_peer_set_endpoint(peer, e);
1423 wg_send_response(peer);
1424 break;
1425 case WG_PKT_RESPONSE:
1426 resp = mtod(m, struct wg_pkt_response *);
1427
1428 res = cookie_checker_validate_macs(&sc->sc_cookie, &resp->m,
1429 resp, sizeof(*resp) - sizeof(resp->m),
1430 underload, &e->e_remote.r_sa,
1431 if_getvnet(sc->sc_ifp));
1432
1433 if (res == EINVAL) {
1434 DPRINTF(sc, "Invalid response MAC\n");
1435 goto error;
1436 } else if (res == ECONNREFUSED) {
1437 DPRINTF(sc, "Handshake ratelimited\n");
1438 goto error;
1439 } else if (res == EAGAIN) {
1440 wg_send_cookie(sc, &resp->m, resp->s_idx, e);
1441 goto error;
1442 } else if (res != 0) {
1443 panic("unexpected response: %d\n", res);
1444 }
1445
1446 if (noise_consume_response(sc->sc_local, &remote,
1447 resp->s_idx, resp->r_idx, resp->ue, resp->en) != 0) {
1448 DPRINTF(sc, "Invalid handshake response\n");
1449 goto error;
1450 }
1451
1452 peer = noise_remote_arg(remote);
1453 DPRINTF(sc, "Receiving handshake response from peer %" PRIu64 "\n", peer->p_id);
1454
1455 wg_peer_set_endpoint(peer, e);
1456 wg_timers_event_session_derived(peer);
1457 wg_timers_event_handshake_complete(peer);
1458 break;
1459 case WG_PKT_COOKIE:
1460 cook = mtod(m, struct wg_pkt_cookie *);
1461
1462 if ((remote = noise_remote_index(sc->sc_local, cook->r_idx)) == NULL) {
1463 DPRINTF(sc, "Unknown cookie index\n");
1464 goto error;
1465 }
1466
1467 peer = noise_remote_arg(remote);
1468
1469 if (cookie_maker_consume_payload(&peer->p_cookie,
1470 cook->nonce, cook->ec) == 0) {
1471 DPRINTF(sc, "Receiving cookie response\n");
1472 } else {
1473 DPRINTF(sc, "Could not decrypt cookie response\n");
1474 goto error;
1475 }
1476
1477 goto not_authenticated;
1478 default:
1479 panic("invalid packet in handshake queue");
1480 }
1481
1482 wg_timers_event_any_authenticated_packet_received(peer);
1483 wg_timers_event_any_authenticated_packet_traversal(peer);
1484
1485 not_authenticated:
1486 counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len);
1487 if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
1488 if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len);
1489 error:
1490 if (remote != NULL)
1491 noise_remote_put(remote);
1492 wg_packet_free(pkt);
1493 }
1494
1495 static void
1496 wg_softc_handshake_receive(struct wg_softc *sc)
1497 {
1498 struct wg_packet *pkt;
1499 while ((pkt = wg_queue_dequeue_handshake(&sc->sc_handshake_queue)) != NULL)
1500 wg_handshake(sc, pkt);
1501 }
1502
1503 static void
1504 wg_mbuf_reset(struct mbuf *m)
1505 {
1506
1507 struct m_tag *t, *tmp;
1508
1509 /*
1510 * We want to reset the mbuf to a newly allocated state, containing
1511 * just the packet contents. Unfortunately FreeBSD doesn't seem to
1512 * offer this anywhere, so we have to make it up as we go. If we can
1513 * get this in kern/kern_mbuf.c, that would be best.
1514 *
1515 * Notice: this may break things unexpectedly but it is better to fail
1516 * closed in the extreme case than leak information in every
1517 * case.
1518 *
1519 * With that said, all this attempts to do is remove any extraneous
1520 * information that could be present.
1521 */
1522
1523 M_ASSERTPKTHDR(m);
1524
1525 m->m_flags &= ~(M_BCAST|M_MCAST|M_VLANTAG|M_PROMISC|M_PROTOFLAGS);
1526
1527 M_HASHTYPE_CLEAR(m);
1528 #ifdef NUMA
1529 m->m_pkthdr.numa_domain = M_NODOM;
1530 #endif
1531 SLIST_FOREACH_SAFE(t, &m->m_pkthdr.tags, m_tag_link, tmp) {
1532 if ((t->m_tag_id != 0 || t->m_tag_cookie != MTAG_WGLOOP) &&
1533 t->m_tag_id != PACKET_TAG_MACLABEL)
1534 m_tag_delete(m, t);
1535 }
1536
1537 KASSERT((m->m_pkthdr.csum_flags & CSUM_SND_TAG) == 0,
1538 ("%s: mbuf %p has a send tag", __func__, m));
1539
1540 m->m_pkthdr.csum_flags = 0;
1541 m->m_pkthdr.PH_per.sixtyfour[0] = 0;
1542 m->m_pkthdr.PH_loc.sixtyfour[0] = 0;
1543 }
1544
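/*
 * calculate_padding() returns how many zero bytes must be appended so the
 * plaintext length becomes a multiple of WG_PKT_PADDING (16 bytes), while
 * never padding beyond the MTU recorded for the packet.
 */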
1545 static inline unsigned int
1546 calculate_padding(struct wg_packet *pkt)
1547 {
1548 unsigned int padded_size, last_unit = pkt->p_mbuf->m_pkthdr.len;
1549
1550 /* Keepalive packets don't set p_mtu, but also have a length of zero. */
1551 if (__predict_false(pkt->p_mtu == 0)) {
1552 padded_size = (last_unit + (WG_PKT_PADDING - 1)) &
1553 ~(WG_PKT_PADDING - 1);
1554 return (padded_size - last_unit);
1555 }
1556
1557 if (__predict_false(last_unit > pkt->p_mtu))
1558 last_unit %= pkt->p_mtu;
1559
1560 padded_size = (last_unit + (WG_PKT_PADDING - 1)) & ~(WG_PKT_PADDING - 1);
1561 if (pkt->p_mtu < padded_size)
1562 padded_size = pkt->p_mtu;
1563 return (padded_size - last_unit);
1564 }
1565
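/*
 * wg_encrypt() pads and seals one packet with its keypair, prepends the
 * transport-data header, and publishes the result to the peer's serial
 * send queue via p_state before kicking the p_send grouptask.
 */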
1566 static void
1567 wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt)
1568 {
1569 static const uint8_t padding[WG_PKT_PADDING] = { 0 };
1570 struct wg_pkt_data *data;
1571 struct wg_peer *peer;
1572 struct noise_remote *remote;
1573 struct mbuf *m;
1574 uint32_t idx;
1575 unsigned int padlen;
1576 enum wg_ring_state state = WG_PACKET_DEAD;
1577
1578 remote = noise_keypair_remote(pkt->p_keypair);
1579 peer = noise_remote_arg(remote);
1580 m = pkt->p_mbuf;
1581
1582 /* Pad the packet */
1583 padlen = calculate_padding(pkt);
1584 if (padlen != 0 && !m_append(m, padlen, padding))
1585 goto out;
1586
1587 /* Do encryption */
1588 if (noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, m) != 0)
1589 goto out;
1590
1591 /* Put header into packet */
1592 M_PREPEND(m, sizeof(struct wg_pkt_data), M_NOWAIT);
1593 if (m == NULL)
1594 goto out;
1595 data = mtod(m, struct wg_pkt_data *);
1596 data->t = WG_PKT_DATA;
1597 data->r_idx = idx;
1598 data->nonce = htole64(pkt->p_nonce);
1599
1600 wg_mbuf_reset(m);
1601 state = WG_PACKET_CRYPTED;
1602 out:
1603 pkt->p_mbuf = m;
1604 atomic_store_rel_int(&pkt->p_state, state);
1605 GROUPTASK_ENQUEUE(&peer->p_send);
1606 noise_remote_put(remote);
1607 }
1608
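/*
 * wg_decrypt() opens one transport packet, then enforces cryptokey routing:
 * the inner source address must map back to the sending peer in the
 * allowed-IPs table, otherwise the packet is dropped.
 */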
1609 static void
1610 wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt)
1611 {
1612 struct wg_peer *peer, *allowed_peer;
1613 struct noise_remote *remote;
1614 struct mbuf *m;
1615 int len;
1616 enum wg_ring_state state = WG_PACKET_DEAD;
1617
1618 remote = noise_keypair_remote(pkt->p_keypair);
1619 peer = noise_remote_arg(remote);
1620 m = pkt->p_mbuf;
1621
1622 /* Read nonce and then adjust to remove the header. */
1623 pkt->p_nonce = le64toh(mtod(m, struct wg_pkt_data *)->nonce);
1624 m_adj(m, sizeof(struct wg_pkt_data));
1625
1626 if (noise_keypair_decrypt(pkt->p_keypair, pkt->p_nonce, m) != 0)
1627 goto out;
1628
1629 /* A packet with length 0 is a keepalive packet */
1630 if (__predict_false(m->m_pkthdr.len == 0)) {
1631 DPRINTF(sc, "Receiving keepalive packet from peer "
1632 "%" PRIu64 "\n", peer->p_id);
1633 state = WG_PACKET_CRYPTED;
1634 goto out;
1635 }
1636
1637 /*
1638 * We can let the network stack handle the intricate validation of the
1639 * IP header; here we only check the size and the version, so that we can
1640 * read the source address in wg_aip_lookup.
1641 */
1642
1643 if (determine_af_and_pullup(&m, &pkt->p_af) == 0) {
1644 if (pkt->p_af == AF_INET) {
1645 struct ip *ip = mtod(m, struct ip *);
1646 allowed_peer = wg_aip_lookup(sc, AF_INET, &ip->ip_src);
1647 len = ntohs(ip->ip_len);
1648 if (len >= sizeof(struct ip) && len < m->m_pkthdr.len)
1649 m_adj(m, len - m->m_pkthdr.len);
1650 } else if (pkt->p_af == AF_INET6) {
1651 struct ip6_hdr *ip6 = mtod(m, struct ip6_hdr *);
1652 allowed_peer = wg_aip_lookup(sc, AF_INET6, &ip6->ip6_src);
1653 len = ntohs(ip6->ip6_plen) + sizeof(struct ip6_hdr);
1654 if (len < m->m_pkthdr.len)
1655 m_adj(m, len - m->m_pkthdr.len);
1656 } else
1657 panic("determine_af_and_pullup returned unexpected value");
1658 } else {
1659 DPRINTF(sc, "Packet is neither ipv4 nor ipv6 from peer %" PRIu64 "\n", peer->p_id);
1660 goto out;
1661 }
1662
1663 /* We only want to compare the peer pointers, not dereference them, so drop the ref. */
1664 if (allowed_peer != NULL)
1665 noise_remote_put(allowed_peer->p_remote);
1666
1667 if (__predict_false(peer != allowed_peer)) {
1668 DPRINTF(sc, "Packet has unallowed src IP from peer %" PRIu64 "\n", peer->p_id);
1669 goto out;
1670 }
1671
1672 wg_mbuf_reset(m);
1673 state = WG_PACKET_CRYPTED;
1674 out:
1675 pkt->p_mbuf = m;
1676 atomic_store_rel_int(&pkt->p_state, state);
1677 GROUPTASK_ENQUEUE(&peer->p_recv);
1678 noise_remote_put(remote);
1679 }
1680
1681 static void
1682 wg_softc_decrypt(struct wg_softc *sc)
1683 {
1684 struct wg_packet *pkt;
1685
1686 while ((pkt = wg_queue_dequeue_parallel(&sc->sc_decrypt_parallel)) != NULL)
1687 wg_decrypt(sc, pkt);
1688 }
1689
1690 static void
1691 wg_softc_encrypt(struct wg_softc *sc)
1692 {
1693 struct wg_packet *pkt;
1694
1695 while ((pkt = wg_queue_dequeue_parallel(&sc->sc_encrypt_parallel)) != NULL)
1696 wg_encrypt(sc, pkt);
1697 }
1698
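/*
 * Parallel crypto work is distributed round-robin across the per-CPU
 * encrypt/decrypt grouptasks; packet ordering is restored afterwards by the
 * per-peer serial queues.
 */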
1699 static void
1700 wg_encrypt_dispatch(struct wg_softc *sc)
1701 {
1702 /*
1703 * The update to encrypt_last_cpu is racy, in that we may
1704 * reschedule the task for the same CPU multiple times, but
1705 * the race doesn't really matter.
1706 */
1707 u_int cpu = (sc->sc_encrypt_last_cpu + 1) % mp_ncpus;
1708 sc->sc_encrypt_last_cpu = cpu;
1709 GROUPTASK_ENQUEUE(&sc->sc_encrypt[cpu]);
1710 }
1711
1712 static void
1713 wg_decrypt_dispatch(struct wg_softc *sc)
1714 {
1715 u_int cpu = (sc->sc_decrypt_last_cpu + 1) % mp_ncpus;
1716 sc->sc_decrypt_last_cpu = cpu;
1717 GROUPTASK_ENQUEUE(&sc->sc_decrypt[cpu]);
1718 }
1719
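/*
 * wg_deliver_out() drains the peer's serial encrypt queue in order, sending
 * each successfully encrypted packet to the current endpoint and updating
 * timers and byte counters along the way.
 */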
1720 static void
1721 wg_deliver_out(struct wg_peer *peer)
1722 {
1723 struct wg_endpoint endpoint;
1724 struct wg_softc *sc = peer->p_sc;
1725 struct wg_packet *pkt;
1726 struct mbuf *m;
1727 int rc, len;
1728
1729 wg_peer_get_endpoint(peer, &endpoint);
1730
1731 while ((pkt = wg_queue_dequeue_serial(&peer->p_encrypt_serial)) != NULL) {
1732 if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1733 goto error;
1734
1735 m = pkt->p_mbuf;
1736 pkt->p_mbuf = NULL;
1737
1738 len = m->m_pkthdr.len;
1739
1740 wg_timers_event_any_authenticated_packet_traversal(peer);
1741 wg_timers_event_any_authenticated_packet_sent(peer);
1742 rc = wg_send(sc, &endpoint, m);
1743 if (rc == 0) {
1744 if (len > (sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN))
1745 wg_timers_event_data_sent(peer);
1746 counter_u64_add(peer->p_tx_bytes, len);
1747 } else if (rc == EADDRNOTAVAIL) {
1748 wg_peer_clear_src(peer);
1749 wg_peer_get_endpoint(peer, &endpoint);
1750 goto error;
1751 } else {
1752 goto error;
1753 }
1754 wg_packet_free(pkt);
1755 if (noise_keep_key_fresh_send(peer->p_remote))
1756 wg_timers_event_want_initiation(peer);
1757 continue;
1758 error:
1759 if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
1760 wg_packet_free(pkt);
1761 }
1762 }
1763
1764 #ifdef DEV_NETMAP
1765 /*
1766 * Hand a packet to the netmap RX ring, via netmap's
1767 * freebsd_generic_rx_handler().
1768 */
1769 static void
1770 wg_deliver_netmap(if_t ifp, struct mbuf *m, int af)
1771 {
1772 struct ether_header *eh;
1773
1774 M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
1775 if (__predict_false(m == NULL)) {
1776 if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
1777 return;
1778 }
1779
1780 eh = mtod(m, struct ether_header *);
1781 eh->ether_type = af == AF_INET ?
1782 htons(ETHERTYPE_IP) : htons(ETHERTYPE_IPV6);
1783 memcpy(eh->ether_shost, "\x02\x02\x02\x02\x02\x02", ETHER_ADDR_LEN);
1784 memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
1785 if_input(ifp, m);
1786 }
1787 #endif
1788
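/*
 * Serial (in-order) delivery of decrypted packets up the stack: replay
 * protection via the nonce window, timer and counter updates, then
 * dispatch through netisr (or the netmap RX ring when netmap is enabled).
 */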
1789 static void
1790 wg_deliver_in(struct wg_peer *peer)
1791 {
1792 struct wg_softc *sc = peer->p_sc;
1793 if_t ifp = sc->sc_ifp;
1794 struct wg_packet *pkt;
1795 struct mbuf *m;
1796 struct epoch_tracker et;
1797 int af;
1798
1799 while ((pkt = wg_queue_dequeue_serial(&peer->p_decrypt_serial)) != NULL) {
1800 if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1801 goto error;
1802
1803 m = pkt->p_mbuf;
1804 if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0)
1805 goto error;
1806
1807 if (noise_keypair_received_with(pkt->p_keypair) == ECONNRESET)
1808 wg_timers_event_handshake_complete(peer);
1809
1810 wg_timers_event_any_authenticated_packet_received(peer);
1811 wg_timers_event_any_authenticated_packet_traversal(peer);
1812 wg_peer_set_endpoint(peer, &pkt->p_endpoint);
1813
1814 counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len +
1815 sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1816 if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
1817 if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len +
1818 sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1819
1820 if (m->m_pkthdr.len == 0)
1821 goto done;
1822
1823 af = pkt->p_af;
1824 MPASS(af == AF_INET || af == AF_INET6);
1825 pkt->p_mbuf = NULL;
1826
1827 m->m_pkthdr.rcvif = ifp;
1828
1829 NET_EPOCH_ENTER(et);
1830 BPF_MTAP2_AF(ifp, m, af);
1831
1832 CURVNET_SET(if_getvnet(ifp));
1833 M_SETFIB(m, if_getfib(ifp));
1834 #ifdef DEV_NETMAP
1835 if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
1836 wg_deliver_netmap(ifp, m, af);
1837 else
1838 #endif
1839 if (af == AF_INET)
1840 netisr_dispatch(NETISR_IP, m);
1841 else if (af == AF_INET6)
1842 netisr_dispatch(NETISR_IPV6, m);
1843 CURVNET_RESTORE();
1844 NET_EPOCH_EXIT(et);
1845
1846 wg_timers_event_data_received(peer);
1847
1848 done:
1849 if (noise_keep_key_fresh_recv(peer->p_remote))
1850 wg_timers_event_want_initiation(peer);
1851 wg_packet_free(pkt);
1852 continue;
1853 error:
1854 if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
1855 wg_packet_free(pkt);
1856 }
1857 }
1858
1859 static struct wg_packet *
1860 wg_packet_alloc(struct mbuf *m)
1861 {
1862 struct wg_packet *pkt;
1863
1864 if ((pkt = uma_zalloc(wg_packet_zone, M_NOWAIT | M_ZERO)) == NULL)
1865 return (NULL);
1866 pkt->p_mbuf = m;
1867 return (pkt);
1868 }
1869
1870 static void
1871 wg_packet_free(struct wg_packet *pkt)
1872 {
1873 if (pkt->p_keypair != NULL)
1874 noise_keypair_put(pkt->p_keypair);
1875 if (pkt->p_mbuf != NULL)
1876 m_freem(pkt->p_mbuf);
1877 uma_zfree(wg_packet_zone, pkt);
1878 }
1879
1880 static void
1881 wg_queue_init(struct wg_queue *queue, const char *name)
1882 {
1883 mtx_init(&queue->q_mtx, name, NULL, MTX_DEF);
1884 STAILQ_INIT(&queue->q_queue);
1885 queue->q_len = 0;
1886 }
1887
1888 static void
1889 wg_queue_deinit(struct wg_queue *queue)
1890 {
1891 wg_queue_purge(queue);
1892 mtx_destroy(&queue->q_mtx);
1893 }
1894
1895 static size_t
1896 wg_queue_len(struct wg_queue *queue)
1897 {
1898 return (queue->q_len);
1899 }
1900
1901 static int
1902 wg_queue_enqueue_handshake(struct wg_queue *hs, struct wg_packet *pkt)
1903 {
1904 int ret = 0;
1905 mtx_lock(&hs->q_mtx);
1906 if (hs->q_len < MAX_QUEUED_HANDSHAKES) {
1907 STAILQ_INSERT_TAIL(&hs->q_queue, pkt, p_parallel);
1908 hs->q_len++;
1909 } else {
1910 ret = ENOBUFS;
1911 }
1912 mtx_unlock(&hs->q_mtx);
1913 if (ret != 0)
1914 wg_packet_free(pkt);
1915 return (ret);
1916 }
1917
1918 static struct wg_packet *
1919 wg_queue_dequeue_handshake(struct wg_queue *hs)
1920 {
1921 struct wg_packet *pkt;
1922 mtx_lock(&hs->q_mtx);
1923 if ((pkt = STAILQ_FIRST(&hs->q_queue)) != NULL) {
1924 STAILQ_REMOVE_HEAD(&hs->q_queue, p_parallel);
1925 hs->q_len--;
1926 }
1927 mtx_unlock(&hs->q_mtx);
1928 return (pkt);
1929 }
1930
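/*
 * The staged queue behaves like a ring: when it is full, the oldest
 * staged packet is dropped to make room for the new one.
 */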
1931 static void
1932 wg_queue_push_staged(struct wg_queue *staged, struct wg_packet *pkt)
1933 {
1934 struct wg_packet *old = NULL;
1935
1936 mtx_lock(&staged->q_mtx);
1937 if (staged->q_len >= MAX_STAGED_PKT) {
1938 old = STAILQ_FIRST(&staged->q_queue);
1939 STAILQ_REMOVE_HEAD(&staged->q_queue, p_parallel);
1940 staged->q_len--;
1941 }
1942 STAILQ_INSERT_TAIL(&staged->q_queue, pkt, p_parallel);
1943 staged->q_len++;
1944 mtx_unlock(&staged->q_mtx);
1945
1946 if (old != NULL)
1947 wg_packet_free(old);
1948 }
1949
1950 static void
1951 wg_queue_enlist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1952 {
1953 struct wg_packet *pkt, *tpkt;
1954 STAILQ_FOREACH_SAFE(pkt, list, p_parallel, tpkt)
1955 wg_queue_push_staged(staged, pkt);
1956 }
1957
1958 static void
1959 wg_queue_delist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1960 {
1961 STAILQ_INIT(list);
1962 mtx_lock(&staged->q_mtx);
1963 STAILQ_CONCAT(list, &staged->q_queue);
1964 staged->q_len = 0;
1965 mtx_unlock(&staged->q_mtx);
1966 }
1967
1968 static void
1969 wg_queue_purge(struct wg_queue *staged)
1970 {
1971 struct wg_packet_list list;
1972 struct wg_packet *pkt, *tpkt;
1973 wg_queue_delist_staged(staged, &list);
1974 STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt)
1975 wg_packet_free(pkt);
1976 }
1977
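/*
 * Place a packet on both the softc-wide parallel queue (consumed by the
 * per-CPU crypto tasks in any order) and the per-peer serial queue
 * (consumed in order by wg_deliver_in/wg_deliver_out).  If the serial
 * queue is full the packet is freed; if only the parallel queue is full
 * it is marked WG_PACKET_DEAD and reaped later from the serial queue.
 */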
1978 static int
1979 wg_queue_both(struct wg_queue *parallel, struct wg_queue *serial, struct wg_packet *pkt)
1980 {
1981 pkt->p_state = WG_PACKET_UNCRYPTED;
1982
1983 mtx_lock(&serial->q_mtx);
1984 if (serial->q_len < MAX_QUEUED_PKT) {
1985 serial->q_len++;
1986 STAILQ_INSERT_TAIL(&serial->q_queue, pkt, p_serial);
1987 } else {
1988 mtx_unlock(&serial->q_mtx);
1989 wg_packet_free(pkt);
1990 return (ENOBUFS);
1991 }
1992 mtx_unlock(&serial->q_mtx);
1993
1994 mtx_lock(&parallel->q_mtx);
1995 if (parallel->q_len < MAX_QUEUED_PKT) {
1996 parallel->q_len++;
1997 STAILQ_INSERT_TAIL(&parallel->q_queue, pkt, p_parallel);
1998 } else {
1999 mtx_unlock(&parallel->q_mtx);
2000 pkt->p_state = WG_PACKET_DEAD;
2001 return (ENOBUFS);
2002 }
2003 mtx_unlock(&parallel->q_mtx);
2004
2005 return (0);
2006 }
2007
2008 static struct wg_packet *
2009 wg_queue_dequeue_serial(struct wg_queue *serial)
2010 {
2011 struct wg_packet *pkt = NULL;
2012 mtx_lock(&serial->q_mtx);
2013 if (serial->q_len > 0 && STAILQ_FIRST(&serial->q_queue)->p_state != WG_PACKET_UNCRYPTED) {
2014 serial->q_len--;
2015 pkt = STAILQ_FIRST(&serial->q_queue);
2016 STAILQ_REMOVE_HEAD(&serial->q_queue, p_serial);
2017 }
2018 mtx_unlock(&serial->q_mtx);
2019 return (pkt);
2020 }
2021
2022 static struct wg_packet *
2023 wg_queue_dequeue_parallel(struct wg_queue *parallel)
2024 {
2025 struct wg_packet *pkt = NULL;
2026 mtx_lock(&parallel->q_mtx);
2027 if (parallel->q_len > 0) {
2028 parallel->q_len--;
2029 pkt = STAILQ_FIRST(&parallel->q_queue);
2030 STAILQ_REMOVE_HEAD(&parallel->q_queue, p_parallel);
2031 }
2032 mtx_unlock(&parallel->q_mtx);
2033 return (pkt);
2034 }
2035
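/*
 * UDP tunnel input callback.  The first 32 bits of the payload select the
 * message type: handshake messages (initiation/response/cookie) are queued
 * for the handshake task, while data messages are matched to a keypair by
 * receiver index and fanned out to the decryption queues.
 */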
2036 static bool
2037 wg_input(struct mbuf *m, int offset, struct inpcb *inpcb,
2038 const struct sockaddr *sa, void *_sc)
2039 {
2040 #ifdef INET
2041 const struct sockaddr_in *sin;
2042 #endif
2043 #ifdef INET6
2044 const struct sockaddr_in6 *sin6;
2045 #endif
2046 struct noise_remote *remote;
2047 struct wg_pkt_data *data;
2048 struct wg_packet *pkt;
2049 struct wg_peer *peer;
2050 struct wg_softc *sc = _sc;
2051 struct mbuf *defragged;
2052
2053 defragged = m_defrag(m, M_NOWAIT);
2054 if (defragged)
2055 m = defragged;
2056 m = m_unshare(m, M_NOWAIT);
2057 if (!m) {
2058 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2059 return true;
2060 }
2061
2062 /* Caller provided us with `sa`, no need for this header. */
2063 m_adj(m, offset + sizeof(struct udphdr));
2064
2065 /* Pullup enough to read packet type */
2066 if ((m = m_pullup(m, sizeof(uint32_t))) == NULL) {
2067 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2068 return true;
2069 }
2070
2071 if ((pkt = wg_packet_alloc(m)) == NULL) {
2072 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2073 m_freem(m);
2074 return true;
2075 }
2076
2077 /* Save send/recv address and port for later. */
2078 switch (sa->sa_family) {
2079 #ifdef INET
2080 case AF_INET:
2081 sin = (const struct sockaddr_in *)sa;
2082 pkt->p_endpoint.e_remote.r_sin = sin[0];
2083 pkt->p_endpoint.e_local.l_in = sin[1].sin_addr;
2084 break;
2085 #endif
2086 #ifdef INET6
2087 case AF_INET6:
2088 sin6 = (const struct sockaddr_in6 *)sa;
2089 pkt->p_endpoint.e_remote.r_sin6 = sin6[0];
2090 pkt->p_endpoint.e_local.l_in6 = sin6[1].sin6_addr;
2091 break;
2092 #endif
2093 default:
2094 goto error;
2095 }
2096
2097 if ((m->m_pkthdr.len == sizeof(struct wg_pkt_initiation) &&
2098 *mtod(m, uint32_t *) == WG_PKT_INITIATION) ||
2099 (m->m_pkthdr.len == sizeof(struct wg_pkt_response) &&
2100 *mtod(m, uint32_t *) == WG_PKT_RESPONSE) ||
2101 (m->m_pkthdr.len == sizeof(struct wg_pkt_cookie) &&
2102 *mtod(m, uint32_t *) == WG_PKT_COOKIE)) {
2103
2104 if (wg_queue_enqueue_handshake(&sc->sc_handshake_queue, pkt) != 0) {
2105 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2106 DPRINTF(sc, "Dropping handshake packet\n");
2107 }
2108 GROUPTASK_ENQUEUE(&sc->sc_handshake);
2109 } else if (m->m_pkthdr.len >= sizeof(struct wg_pkt_data) +
2110 NOISE_AUTHTAG_LEN && *mtod(m, uint32_t *) == WG_PKT_DATA) {
2111
2112 /* Pullup whole header to read r_idx below. */
2113 if ((pkt->p_mbuf = m_pullup(m, sizeof(struct wg_pkt_data))) == NULL)
2114 goto error;
2115
2116 data = mtod(pkt->p_mbuf, struct wg_pkt_data *);
2117 if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL)
2118 goto error;
2119
2120 remote = noise_keypair_remote(pkt->p_keypair);
2121 peer = noise_remote_arg(remote);
2122 if (wg_queue_both(&sc->sc_decrypt_parallel, &peer->p_decrypt_serial, pkt) != 0)
2123 if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2124 wg_decrypt_dispatch(sc);
2125 noise_remote_put(remote);
2126 } else {
2127 goto error;
2128 }
2129 return true;
2130 error:
2131 if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1);
2132 wg_packet_free(pkt);
2133 return true;
2134 }
2135
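/*
 * Drain the peer's staged queue onto the encryption queues using the
 * current keypair.  If there is no valid keypair, or the nonce space is
 * exhausted, the packets are re-staged and a new handshake is requested.
 */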
2136 static void
2137 wg_peer_send_staged(struct wg_peer *peer)
2138 {
2139 struct wg_packet_list list;
2140 struct noise_keypair *keypair;
2141 struct wg_packet *pkt, *tpkt;
2142 struct wg_softc *sc = peer->p_sc;
2143
2144 wg_queue_delist_staged(&peer->p_stage_queue, &list);
2145
2146 if (STAILQ_EMPTY(&list))
2147 return;
2148
2149 if ((keypair = noise_keypair_current(peer->p_remote)) == NULL)
2150 goto error;
2151
2152 STAILQ_FOREACH(pkt, &list, p_parallel) {
2153 if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0)
2154 goto error_keypair;
2155 }
2156 STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt) {
2157 pkt->p_keypair = noise_keypair_ref(keypair);
2158 if (wg_queue_both(&sc->sc_encrypt_parallel, &peer->p_encrypt_serial, pkt) != 0)
2159 if_inc_counter(sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
2160 }
2161 wg_encrypt_dispatch(sc);
2162 noise_keypair_put(keypair);
2163 return;
2164
2165 error_keypair:
2166 noise_keypair_put(keypair);
2167 error:
2168 wg_queue_enlist_staged(&peer->p_stage_queue, &list);
2169 wg_timers_event_want_initiation(peer);
2170 }
2171
2172 static inline void
2173 xmit_err(if_t ifp, struct mbuf *m, struct wg_packet *pkt, sa_family_t af)
2174 {
2175 if_inc_counter(ifp, IFCOUNTER_OERRORS, 1);
2176 switch (af) {
2177 #ifdef INET
2178 case AF_INET:
2179 icmp_error(m, ICMP_UNREACH, ICMP_UNREACH_HOST, 0, 0);
2180 if (pkt)
2181 pkt->p_mbuf = NULL;
2182 m = NULL;
2183 break;
2184 #endif
2185 #ifdef INET6
2186 case AF_INET6:
2187 icmp6_error(m, ICMP6_DST_UNREACH, 0, 0);
2188 if (pkt)
2189 pkt->p_mbuf = NULL;
2190 m = NULL;
2191 break;
2192 #endif
2193 }
2194 if (pkt)
2195 wg_packet_free(pkt);
2196 else if (m)
2197 m_freem(m);
2198 }
2199
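/*
 * Common transmit path: look up the destination peer in the allowed-IPs
 * table, guard against tunnel nesting loops and missing endpoints, then
 * stage the packet and kick wg_peer_send_staged().
 */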
2200 static int
2201 wg_xmit(if_t ifp, struct mbuf *m, sa_family_t af, uint32_t mtu)
2202 {
2203 struct wg_packet *pkt = NULL;
2204 struct wg_softc *sc = if_getsoftc(ifp);
2205 struct wg_peer *peer;
2206 int rc = 0;
2207 sa_family_t peer_af;
2208
2209 /* Work around lifetime issue in the ipv6 mld code. */
2210 if (__predict_false((if_getflags(ifp) & IFF_DYING) || !sc)) {
2211 rc = ENXIO;
2212 goto err_xmit;
2213 }
2214
2215 if ((pkt = wg_packet_alloc(m)) == NULL) {
2216 rc = ENOBUFS;
2217 goto err_xmit;
2218 }
2219 pkt->p_mtu = mtu;
2220 pkt->p_af = af;
2221
2222 if (af == AF_INET) {
2223 peer = wg_aip_lookup(sc, AF_INET, &mtod(m, struct ip *)->ip_dst);
2224 } else if (af == AF_INET6) {
2225 peer = wg_aip_lookup(sc, AF_INET6, &mtod(m, struct ip6_hdr *)->ip6_dst);
2226 } else {
2227 rc = EAFNOSUPPORT;
2228 goto err_xmit;
2229 }
2230
2231 BPF_MTAP2_AF(ifp, m, pkt->p_af);
2232
2233 if (__predict_false(peer == NULL)) {
2234 rc = ENETUNREACH;
2235 goto err_xmit;
2236 }
2237
2238 if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP, MAX_LOOPS))) {
2239 DPRINTF(sc, "Packet looped\n");
2240 rc = ELOOP;
2241 goto err_peer;
2242 }
2243
2244 peer_af = peer->p_endpoint.e_remote.r_sa.sa_family;
2245 if (__predict_false(peer_af != AF_INET && peer_af != AF_INET6)) {
2246 DPRINTF(sc, "No valid endpoint has been configured or "
2247 "discovered for peer %" PRIu64 "\n", peer->p_id);
2248 rc = EHOSTUNREACH;
2249 goto err_peer;
2250 }
2251
2252 wg_queue_push_staged(&peer->p_stage_queue, pkt);
2253 wg_peer_send_staged(peer);
2254 noise_remote_put(peer->p_remote);
2255 return (0);
2256
2257 err_peer:
2258 noise_remote_put(peer->p_remote);
2259 err_xmit:
2260 xmit_err(ifp, m, pkt, af);
2261 return (rc);
2262 }
2263
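/*
 * Decide the address family from the IP version nibble and pull up enough
 * of the header for the caller to read addresses directly from the mbuf.
 */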
2264 static inline int
2265 determine_af_and_pullup(struct mbuf **m, sa_family_t *af)
2266 {
2267 u_char ipv;
2268 if ((*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2269 *m = m_pullup(*m, sizeof(struct ip6_hdr));
2270 else if ((*m)->m_pkthdr.len >= sizeof(struct ip))
2271 *m = m_pullup(*m, sizeof(struct ip));
2272 else
2273 return (EAFNOSUPPORT);
2274 if (*m == NULL)
2275 return (ENOBUFS);
2276 ipv = mtod(*m, struct ip *)->ip_v;
2277 if (ipv == 4)
2278 *af = AF_INET;
2279 else if (ipv == 6 && (*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2280 *af = AF_INET6;
2281 else
2282 return (EAFNOSUPPORT);
2283 return (0);
2284 }
2285
2286 static int
2287 determine_ethertype_and_pullup(struct mbuf **m, int *etp)
2288 {
2289 struct ether_header *eh;
2290
2291 *m = m_pullup(*m, sizeof(struct ether_header));
2292 if (__predict_false(*m == NULL))
2293 return (ENOBUFS);
2294 eh = mtod(*m, struct ether_header *);
2295 *etp = ntohs(eh->ether_type);
2296 if (*etp != ETHERTYPE_IP && *etp != ETHERTYPE_IPV6)
2297 return (EAFNOSUPPORT);
2298 return (0);
2299 }
2300
2301 /*
2302 * This should only be invoked by netmap, via nm_os_generic_xmit_frame(), to
2303 * transmit packets from the netmap TX ring.
2304 */
2305 static int
2306 wg_transmit(if_t ifp, struct mbuf *m)
2307 {
2308 sa_family_t af;
2309 int et, ret;
2310 struct mbuf *defragged;
2311
2312 KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2313 ("%s: ifp %p is not in netmap mode", __func__, ifp));
2314
2315 defragged = m_defrag(m, M_NOWAIT);
2316 if (defragged)
2317 m = defragged;
2318 m = m_unshare(m, M_NOWAIT);
2319 if (!m) {
2320 xmit_err(ifp, m, NULL, AF_UNSPEC);
2321 return (ENOBUFS);
2322 }
2323
2324 ret = determine_ethertype_and_pullup(&m, &et);
2325 if (ret) {
2326 xmit_err(ifp, m, NULL, AF_UNSPEC);
2327 return (ret);
2328 }
2329 m_adj(m, sizeof(struct ether_header));
2330
2331 ret = determine_af_and_pullup(&m, &af);
2332 if (ret) {
2333 xmit_err(ifp, m, NULL, AF_UNSPEC);
2334 return (ret);
2335 }
2336
2337 /*
2338 * netmap only gets to see transient errors, since it handles errors by
2339 * refusing to advance the transmit ring and retrying later.
2340 */
2341 ret = wg_xmit(ifp, m, af, if_getmtu(ifp));
2342 if (ret == ENOBUFS)
2343 return (ret);
2344 return (0);
2345 }
2346
2347 #ifdef DEV_NETMAP
2348 /*
2349 * This should only be invoked by netmap, via nm_os_send_up(), to process
2350 * packets from the host TX ring.
2351 */
2352 static void
2353 wg_if_input(if_t ifp, struct mbuf *m)
2354 {
2355 int et;
2356
2357 KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2358 ("%s: ifp %p is not in netmap mode", __func__, ifp));
2359
2360 if (determine_ethertype_and_pullup(&m, &et) != 0) {
2361 if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2362 m_freem(m);
2363 return;
2364 }
2365 CURVNET_SET(if_getvnet(ifp));
2366 switch (et) {
2367 case ETHERTYPE_IP:
2368 m_adj(m, sizeof(struct ether_header));
2369 netisr_dispatch(NETISR_IP, m);
2370 break;
2371 case ETHERTYPE_IPV6:
2372 m_adj(m, sizeof(struct ether_header));
2373 netisr_dispatch(NETISR_IPV6, m);
2374 break;
2375 default:
2376 __assert_unreachable();
2377 }
2378 CURVNET_RESTORE();
2379 }
2380
2381 /*
2382 * Deliver a packet to the host RX ring. Because the interface is in netmap
2383 * mode, the if_transmit() call should pass the packet to netmap_transmit().
2384 */
2385 static int
2386 wg_xmit_netmap(if_t ifp, struct mbuf *m, int af)
2387 {
2388 struct ether_header *eh;
2389
2390 if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP,
2391 MAX_LOOPS))) {
2392 if_printf(ifp, "%s: packet loop detected\n", __func__);
2393 if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2394 m_freem(m);
2395 return (ELOOP);
2396 }
2397
2398 M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
2399 if (__predict_false(m == NULL)) {
2400 if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
2401 return (ENOBUFS);
2402 }
2403
2404 eh = mtod(m, struct ether_header *);
2405 eh->ether_type = af == AF_INET ?
2406 htons(ETHERTYPE_IP) : htons(ETHERTYPE_IPV6);
2407 memcpy(eh->ether_shost, "\x06\x06\x06\x06\x06\x06", ETHER_ADDR_LEN);
2408 memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
2409 return (if_transmit(ifp, m));
2410 }
2411 #endif /* DEV_NETMAP */
2412
2413 static int
2414 wg_output(if_t ifp, struct mbuf *m, const struct sockaddr *dst, struct route *ro)
2415 {
2416 sa_family_t parsed_af;
2417 uint32_t af, mtu;
2418 int ret;
2419 struct mbuf *defragged;
2420
2421 /* BPF writes need to be handled specially. */
2422 if (dst->sa_family == AF_UNSPEC || dst->sa_family == pseudo_AF_HDRCMPLT)
2423 memcpy(&af, dst->sa_data, sizeof(af));
2424 else
2425 af = RO_GET_FAMILY(ro, dst);
2426 if (af == AF_UNSPEC) {
2427 xmit_err(ifp, m, NULL, af);
2428 return (EAFNOSUPPORT);
2429 }
2430
2431 #ifdef DEV_NETMAP
2432 if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
2433 return (wg_xmit_netmap(ifp, m, af));
2434 #endif
2435
2436 defragged = m_defrag(m, M_NOWAIT);
2437 if (defragged)
2438 m = defragged;
2439 m = m_unshare(m, M_NOWAIT);
2440 if (!m) {
2441 xmit_err(ifp, m, NULL, AF_UNSPEC);
2442 return (ENOBUFS);
2443 }
2444
2445 ret = determine_af_and_pullup(&m, &parsed_af);
2446 if (ret) {
2447 xmit_err(ifp, m, NULL, AF_UNSPEC);
2448 return (ret);
2449 }
2450
2451 MPASS(parsed_af == af);
2452 mtu = (ro != NULL && ro->ro_mtu > 0) ? ro->ro_mtu : if_getmtu(ifp);
2453 return (wg_xmit(ifp, m, parsed_af, mtu));
2454 }
2455
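/*
 * Apply a single peer description from an unpacked nvlist: create or
 * remove the peer, and update its endpoint, preshared key, persistent
 * keepalive interval and allowed-IPs entries as requested.
 */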
2456 static int
2457 wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
2458 {
2459 uint8_t public[WG_KEY_SIZE];
2460 const void *pub_key, *preshared_key = NULL;
2461 const struct sockaddr *endpoint;
2462 int err;
2463 size_t size;
2464 struct noise_remote *remote;
2465 struct wg_peer *peer = NULL;
2466 bool need_cleanup = false;
2467
2468 sx_assert(&sc->sc_lock, SX_XLOCKED);
2469
2470 if (!nvlist_exists_binary(nvl, "public-key")) {
2471 return (EINVAL);
2472 }
2473 pub_key = nvlist_get_binary(nvl, "public-key", &size);
2474 if (size != WG_KEY_SIZE) {
2475 return (EINVAL);
2476 }
2477 if (noise_local_keys(sc->sc_local, public, NULL) == 0 &&
2478 bcmp(public, pub_key, WG_KEY_SIZE) == 0) {
2479 return (0); /* Silently ignored; not actually a failure. */
2480 }
2481 if ((remote = noise_remote_lookup(sc->sc_local, pub_key)) != NULL)
2482 peer = noise_remote_arg(remote);
2483 if (nvlist_exists_bool(nvl, "remove") &&
2484 nvlist_get_bool(nvl, "remove")) {
2485 if (remote != NULL) {
2486 wg_peer_destroy(peer);
2487 noise_remote_put(remote);
2488 }
2489 return (0);
2490 }
2491 if (nvlist_exists_bool(nvl, "replace-allowedips") &&
2492 nvlist_get_bool(nvl, "replace-allowedips") &&
2493 peer != NULL) {
2494
2495 wg_aip_remove_all(sc, peer);
2496 }
2497 if (peer == NULL) {
2498 peer = wg_peer_create(sc, pub_key, &err);
2499 if (peer == NULL)
2500 goto out;
2501 need_cleanup = true;
2502 }
2503 if (nvlist_exists_binary(nvl, "endpoint")) {
2504 endpoint = nvlist_get_binary(nvl, "endpoint", &size);
2505 if (size > sizeof(peer->p_endpoint.e_remote)) {
2506 err = EINVAL;
2507 goto out;
2508 }
2509 memcpy(&peer->p_endpoint.e_remote, endpoint, size);
2510 }
2511 if (nvlist_exists_binary(nvl, "preshared-key")) {
2512 preshared_key = nvlist_get_binary(nvl, "preshared-key", &size);
2513 if (size != WG_KEY_SIZE) {
2514 err = EINVAL;
2515 goto out;
2516 }
2517 noise_remote_set_psk(peer->p_remote, preshared_key);
2518 }
2519 if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) {
2520 uint64_t pki = nvlist_get_number(nvl, "persistent-keepalive-interval");
2521 if (pki > UINT16_MAX) {
2522 err = EINVAL;
2523 goto out;
2524 }
2525 wg_timers_set_persistent_keepalive(peer, pki);
2526 }
2527 if (nvlist_exists_nvlist_array(nvl, "allowed-ips")) {
2528 const void *addr;
2529 uint64_t cidr;
2530 const nvlist_t * const * aipl;
2531 size_t allowedip_count;
2532
2533 aipl = nvlist_get_nvlist_array(nvl, "allowed-ips", &allowedip_count);
2534 for (size_t idx = 0; idx < allowedip_count; idx++) {
2535 sa_family_t ipaf;
2536 int ipflags;
2537
2538 if (!nvlist_exists_number(aipl[idx], "cidr"))
2539 continue;
2540
2541 ipaf = AF_UNSPEC;
2542 ipflags = 0;
2543 if (nvlist_exists_number(aipl[idx], "flags"))
2544 ipflags = nvlist_get_number(aipl[idx], "flags");
2545 if ((ipflags & ~WGALLOWEDIP_VALID_FLAGS) != 0) {
2546 err = EOPNOTSUPP;
2547 goto out;
2548 }
2549 cidr = nvlist_get_number(aipl[idx], "cidr");
2550 if (nvlist_exists_binary(aipl[idx], "ipv4")) {
2551 addr = nvlist_get_binary(aipl[idx], "ipv4", &size);
2552 if (addr == NULL || cidr > 32 || size != sizeof(struct in_addr)) {
2553 err = EINVAL;
2554 goto out;
2555 }
2556
2557 ipaf = AF_INET;
2558 } else if (nvlist_exists_binary(aipl[idx], "ipv6")) {
2559 addr = nvlist_get_binary(aipl[idx], "ipv6", &size);
2560 if (addr == NULL || cidr > 128 || size != sizeof(struct in6_addr)) {
2561 err = EINVAL;
2562 goto out;
2563 }
2564
2565 ipaf = AF_INET6;
2566 } else {
2567 continue;
2568 }
2569
2570 MPASS(ipaf != AF_UNSPEC);
2571 if ((ipflags & WGALLOWEDIP_REMOVE_ME) != 0) {
2572 err = wg_aip_del(sc, peer, ipaf, addr, cidr);
2573 } else {
2574 err = wg_aip_add(sc, peer, ipaf, addr, cidr);
2575 }
2576
2577 if (err != 0)
2578 goto out;
2579 }
2580 }
2581 if (remote != NULL)
2582 noise_remote_put(remote);
2583 return (0);
2584 out:
2585 if (need_cleanup) /* If we fail, only destroy if it was new. */
2586 wg_peer_destroy(peer);
2587 if (remote != NULL)
2588 noise_remote_put(remote);
2589 return (err);
2590 }
2591
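/*
 * SIOCSWG handler: copy in and unpack the nvlist from userspace, then
 * apply interface-wide settings (listen port, private key, user cookie)
 * before iterating over the supplied peers.
 */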
2592 static int
2593 wgc_set(struct wg_softc *sc, struct wg_data_io *wgd)
2594 {
2595 uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE];
2596 if_t ifp;
2597 void *nvlpacked;
2598 nvlist_t *nvl;
2599 ssize_t size;
2600 int err;
2601
2602 ifp = sc->sc_ifp;
2603 if (wgd->wgd_size == 0 || wgd->wgd_data == NULL)
2604 return (EFAULT);
2605
2606 /* Can nvlists be streamed in? It's not nice to impose arbitrary limits like that,
2607  * but there needs to be _some_ limitation. */
2608 if (wgd->wgd_size >= UINT32_MAX / 2)
2609 return (E2BIG);
2610
2611 nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK | M_ZERO);
2612
2613 err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size);
2614 if (err)
2615 goto out;
2616 nvl = nvlist_unpack(nvlpacked, wgd->wgd_size, 0);
2617 if (nvl == NULL) {
2618 err = EBADMSG;
2619 goto out;
2620 }
2621 sx_xlock(&sc->sc_lock);
2622 if (nvlist_exists_bool(nvl, "replace-peers") &&
2623 nvlist_get_bool(nvl, "replace-peers"))
2624 wg_peer_destroy_all(sc);
2625 if (nvlist_exists_number(nvl, "listen-port")) {
2626 uint64_t new_port = nvlist_get_number(nvl, "listen-port");
2627 if (new_port > UINT16_MAX) {
2628 err = EINVAL;
2629 goto out_locked;
2630 }
2631 if (new_port != sc->sc_socket.so_port) {
2632 if ((if_getdrvflags(ifp) & IFF_DRV_RUNNING) != 0) {
2633 if ((err = wg_socket_init(sc, new_port)) != 0)
2634 goto out_locked;
2635 } else
2636 sc->sc_socket.so_port = new_port;
2637 }
2638 }
2639 if (nvlist_exists_binary(nvl, "private-key")) {
2640 const void *key = nvlist_get_binary(nvl, "private-key", &size);
2641 if (size != WG_KEY_SIZE) {
2642 err = EINVAL;
2643 goto out_locked;
2644 }
2645
2646 if (noise_local_keys(sc->sc_local, NULL, private) != 0 ||
2647 timingsafe_bcmp(private, key, WG_KEY_SIZE) != 0) {
2648 struct wg_peer *peer;
2649
2650 if (curve25519_generate_public(public, key)) {
2651 /* Peer conflict: remove conflicting peer. */
2652 struct noise_remote *remote;
2653 if ((remote = noise_remote_lookup(sc->sc_local,
2654 public)) != NULL) {
2655 peer = noise_remote_arg(remote);
2656 wg_peer_destroy(peer);
2657 noise_remote_put(remote);
2658 }
2659 }
2660
2661 /*
2662 * Set the private key and invalidate all existing
2663 * handshakes.
2664 */
2665 /* Note: we might be removing the private key. */
2666 noise_local_private(sc->sc_local, key);
2667 if (noise_local_keys(sc->sc_local, NULL, NULL) == 0)
2668 cookie_checker_update(&sc->sc_cookie, public);
2669 else
2670 cookie_checker_update(&sc->sc_cookie, NULL);
2671 }
2672 }
2673 if (nvlist_exists_number(nvl, "user-cookie")) {
2674 uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie");
2675 if (user_cookie > UINT32_MAX) {
2676 err = EINVAL;
2677 goto out_locked;
2678 }
2679 err = wg_socket_set_cookie(sc, user_cookie);
2680 if (err)
2681 goto out_locked;
2682 }
2683 if (nvlist_exists_nvlist_array(nvl, "peers")) {
2684 size_t peercount;
2685 const nvlist_t * const*nvl_peers;
2686
2687 nvl_peers = nvlist_get_nvlist_array(nvl, "peers", &peercount);
2688 for (size_t i = 0; i < peercount; i++) {
2689 err = wg_peer_add(sc, nvl_peers[i]);
2690 if (err != 0)
2691 goto out_locked;
2692 }
2693 }
2694
2695 out_locked:
2696 sx_xunlock(&sc->sc_lock);
2697 nvlist_destroy(nvl);
2698 out:
2699 zfree(nvlpacked, M_TEMP);
2700 return (err);
2701 }
2702
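/*
 * SIOCGWG handler: serialize interface and peer state into a packed
 * nvlist.  A caller passing wgd_size == 0 only learns the required buffer
 * size; private and preshared keys are included only for privileged
 * callers.
 */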
2703 static int
2704 wgc_get(struct wg_softc *sc, struct wg_data_io *wgd)
2705 {
2706 uint8_t public_key[WG_KEY_SIZE] = { 0 };
2707 uint8_t private_key[WG_KEY_SIZE] = { 0 };
2708 uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN] = { 0 };
2709 nvlist_t *nvl, *nvl_peer, *nvl_aip, **nvl_peers, **nvl_aips;
2710 size_t size, peer_count, aip_count, i, j;
2711 struct wg_timespec64 ts64;
2712 struct wg_peer *peer;
2713 struct wg_aip *aip;
2714 void *packed;
2715 int err = 0;
2716
2717 nvl = nvlist_create(0);
2718 if (!nvl)
2719 return (ENOMEM);
2720
2721 sx_slock(&sc->sc_lock);
2722
2723 if (sc->sc_socket.so_port != 0)
2724 nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port);
2725 if (sc->sc_socket.so_user_cookie != 0)
2726 nvlist_add_number(nvl, "user-cookie", sc->sc_socket.so_user_cookie);
2727 if (noise_local_keys(sc->sc_local, public_key, private_key) == 0) {
2728 nvlist_add_binary(nvl, "public-key", public_key, WG_KEY_SIZE);
2729 if (wgc_privileged(sc))
2730 nvlist_add_binary(nvl, "private-key", private_key, WG_KEY_SIZE);
2731 explicit_bzero(private_key, sizeof(private_key));
2732 }
2733 peer_count = sc->sc_peers_num;
2734 if (peer_count) {
2735 nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2736 i = 0;
2737 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2738 if (i >= peer_count)
2739 panic("peers changed from under us");
2740
2741 nvl_peers[i++] = nvl_peer = nvlist_create(0);
2742 if (!nvl_peer) {
2743 err = ENOMEM;
2744 goto err_peer;
2745 }
2746
2747 (void)noise_remote_keys(peer->p_remote, public_key, preshared_key);
2748 nvlist_add_binary(nvl_peer, "public-key", public_key, sizeof(public_key));
2749 if (wgc_privileged(sc))
2750 nvlist_add_binary(nvl_peer, "preshared-key", preshared_key, sizeof(preshared_key));
2751 explicit_bzero(preshared_key, sizeof(preshared_key));
2752 if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET)
2753 nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in));
2754 else if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET6)
2755 nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in6));
2756 wg_timers_get_last_handshake(peer, &ts64);
2757 nvlist_add_binary(nvl_peer, "last-handshake-time", &ts64, sizeof(ts64));
2758 nvlist_add_number(nvl_peer, "persistent-keepalive-interval", peer->p_persistent_keepalive_interval);
2759 nvlist_add_number(nvl_peer, "rx-bytes", counter_u64_fetch(peer->p_rx_bytes));
2760 nvlist_add_number(nvl_peer, "tx-bytes", counter_u64_fetch(peer->p_tx_bytes));
2761
2762 aip_count = peer->p_aips_num;
2763 if (aip_count) {
2764 nvl_aips = mallocarray(aip_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2765 j = 0;
2766 LIST_FOREACH(aip, &peer->p_aips, a_entry) {
2767 if (j >= aip_count)
2768 panic("aips changed from under us");
2769
2770 nvl_aips[j++] = nvl_aip = nvlist_create(0);
2771 if (!nvl_aip) {
2772 err = ENOMEM;
2773 goto err_aip;
2774 }
2775 if (aip->a_af == AF_INET) {
2776 nvlist_add_binary(nvl_aip, "ipv4", &aip->a_addr.in, sizeof(aip->a_addr.in));
2777 nvlist_add_number(nvl_aip, "cidr", bitcount32(aip->a_mask.ip));
2778 }
2779 #ifdef INET6
2780 else if (aip->a_af == AF_INET6) {
2781 nvlist_add_binary(nvl_aip, "ipv6", &aip->a_addr.in6, sizeof(aip->a_addr.in6));
2782 nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&aip->a_mask.in6, NULL));
2783 }
2784 #endif
2785 }
2786 nvlist_add_nvlist_array(nvl_peer, "allowed-ips", (const nvlist_t *const *)nvl_aips, aip_count);
2787 err_aip:
2788 for (j = 0; j < aip_count; ++j)
2789 nvlist_destroy(nvl_aips[j]);
2790 free(nvl_aips, M_NVLIST);
2791 if (err)
2792 goto err_peer;
2793 }
2794 }
2795 nvlist_add_nvlist_array(nvl, "peers", (const nvlist_t * const *)nvl_peers, peer_count);
2796 err_peer:
2797 for (i = 0; i < peer_count; ++i)
2798 nvlist_destroy(nvl_peers[i]);
2799 free(nvl_peers, M_NVLIST);
2800 if (err) {
2801 sx_sunlock(&sc->sc_lock);
2802 goto err;
2803 }
2804 }
2805 sx_sunlock(&sc->sc_lock);
2806 packed = nvlist_pack(nvl, &size);
2807 if (!packed) {
2808 err = ENOMEM;
2809 goto err;
2810 }
2811 if (!wgd->wgd_size) {
2812 wgd->wgd_size = size;
2813 goto out;
2814 }
2815 if (wgd->wgd_size < size) {
2816 err = ENOSPC;
2817 goto out;
2818 }
2819 err = copyout(packed, wgd->wgd_data, size);
2820 wgd->wgd_size = size;
2821
2822 out:
2823 zfree(packed, M_NVLIST);
2824 err:
2825 nvlist_destroy(nvl);
2826 return (err);
2827 }
2828
2829 static int
2830 wg_ioctl(if_t ifp, u_long cmd, caddr_t data)
2831 {
2832 struct wg_data_io *wgd = (struct wg_data_io *)data;
2833 struct ifreq *ifr = (struct ifreq *)data;
2834 struct wg_softc *sc;
2835 int ret = 0;
2836
2837 sx_slock(&wg_sx);
2838 sc = if_getsoftc(ifp);
2839 if (!sc) {
2840 ret = ENXIO;
2841 goto out;
2842 }
2843
2844 switch (cmd) {
2845 case SIOCSWG:
2846 ret = priv_check(curthread, PRIV_NET_WG);
2847 if (ret == 0)
2848 ret = wgc_set(sc, wgd);
2849 break;
2850 case SIOCGWG:
2851 ret = wgc_get(sc, wgd);
2852 break;
2853 /* Interface IOCTLs */
2854 case SIOCSIFADDR:
2855 /*
2856 * This differs from *BSD norms, but is more uniform with how
2857 * WireGuard behaves elsewhere.
2858 */
2859 break;
2860 case SIOCSIFFLAGS:
2861 if (if_getflags(ifp) & IFF_UP)
2862 ret = wg_up(sc);
2863 else
2864 wg_down(sc);
2865 break;
2866 case SIOCSIFMTU:
2867 if (ifr->ifr_mtu <= 0 || ifr->ifr_mtu > MAX_MTU)
2868 ret = EINVAL;
2869 else
2870 if_setmtu(ifp, ifr->ifr_mtu);
2871 break;
2872 case SIOCADDMULTI:
2873 case SIOCDELMULTI:
2874 break;
2875 case SIOCGTUNFIB:
2876 ifr->ifr_fib = sc->sc_socket.so_fibnum;
2877 break;
2878 case SIOCSTUNFIB:
2879 ret = priv_check(curthread, PRIV_NET_WG);
2880 if (ret)
2881 break;
2882 ret = priv_check(curthread, PRIV_NET_SETIFFIB);
2883 if (ret)
2884 break;
2885 sx_xlock(&sc->sc_lock);
2886 ret = wg_socket_set_fibnum(sc, ifr->ifr_fib);
2887 sx_xunlock(&sc->sc_lock);
2888 break;
2889 default:
2890 ret = ENOTTY;
2891 }
2892
2893 out:
2894 sx_sunlock(&wg_sx);
2895 return (ret);
2896 }
2897
2898 static int
2899 wg_up(struct wg_softc *sc)
2900 {
2901 if_t ifp = sc->sc_ifp;
2902 struct wg_peer *peer;
2903 int rc = EBUSY;
2904
2905 sx_xlock(&sc->sc_lock);
2906 /* Jail's being removed, no more wg_up(). */
2907 if ((sc->sc_flags & WGF_DYING) != 0)
2908 goto out;
2909
2910 /* Silent success if we're already running. */
2911 rc = 0;
2912 if (if_getdrvflags(ifp) & IFF_DRV_RUNNING)
2913 goto out;
2914 if_setdrvflagbits(ifp, IFF_DRV_RUNNING, 0);
2915
2916 rc = wg_socket_init(sc, sc->sc_socket.so_port);
2917 if (rc == 0) {
2918 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry)
2919 wg_timers_enable(peer);
2920 if_link_state_change(sc->sc_ifp, LINK_STATE_UP);
2921 } else {
2922 if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2923 DPRINTF(sc, "Unable to initialize sockets: %d\n", rc);
2924 }
2925 out:
2926 sx_xunlock(&sc->sc_lock);
2927 return (rc);
2928 }
2929
2930 static void
2931 wg_down(struct wg_softc *sc)
2932 {
2933 if_t ifp = sc->sc_ifp;
2934 struct wg_peer *peer;
2935
2936 sx_xlock(&sc->sc_lock);
2937 if (!(if_getdrvflags(ifp) & IFF_DRV_RUNNING)) {
2938 sx_xunlock(&sc->sc_lock);
2939 return;
2940 }
2941 if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2942
2943 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2944 wg_queue_purge(&peer->p_stage_queue);
2945 wg_timers_disable(peer);
2946 }
2947
2948 wg_queue_purge(&sc->sc_handshake_queue);
2949
2950 TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2951 noise_remote_handshake_clear(peer->p_remote);
2952 noise_remote_keypairs_clear(peer->p_remote);
2953 }
2954
2955 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
2956 wg_socket_uninit(sc);
2957
2958 sx_xunlock(&sc->sc_lock);
2959 }
2960
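/*
 * Create a new wg(4) interface: allocate the softc, the IPv4/IPv6
 * allowed-IPs radix heads and per-CPU encrypt/decrypt grouptasks, then
 * configure and attach the ifnet.
 */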
2961 static int
2962 wg_clone_create(struct if_clone *ifc, char *name, size_t len,
2963 struct ifc_data *ifd, struct ifnet **ifpp)
2964 {
2965 struct wg_softc *sc;
2966 if_t ifp;
2967
2968 sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO);
2969
2970 sc->sc_local = noise_local_alloc(sc);
2971
2972 sc->sc_encrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2973
2974 sc->sc_decrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2975
2976 if (!rn_inithead((void **)&sc->sc_aip4, offsetof(struct aip_addr, in) * NBBY))
2977 goto free_decrypt;
2978
2979 if (!rn_inithead((void **)&sc->sc_aip6, offsetof(struct aip_addr, in6) * NBBY))
2980 goto free_aip4;
2981
2982 atomic_add_int(&clone_count, 1);
2983 ifp = sc->sc_ifp = if_alloc(IFT_WIREGUARD);
2984
2985 sc->sc_ucred = crhold(curthread->td_ucred);
2986 sc->sc_socket.so_fibnum = curthread->td_proc->p_fibnum;
2987 sc->sc_socket.so_port = 0;
2988
2989 TAILQ_INIT(&sc->sc_peers);
2990 sc->sc_peers_num = 0;
2991
2992 cookie_checker_init(&sc->sc_cookie);
2993
2994 RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip4);
2995 RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip6);
2996
2997 GROUPTASK_INIT(&sc->sc_handshake, 0, (gtask_fn_t *)wg_softc_handshake_receive, sc);
2998 taskqgroup_attach(qgroup_wg_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation");
2999 wg_queue_init(&sc->sc_handshake_queue, "hsq");
3000
3001 for (int i = 0; i < mp_ncpus; i++) {
3002 GROUPTASK_INIT(&sc->sc_encrypt[i], 0,
3003 (gtask_fn_t *)wg_softc_encrypt, sc);
3004 taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_encrypt[i], sc, i, NULL, NULL, "wg encrypt");
3005 GROUPTASK_INIT(&sc->sc_decrypt[i], 0,
3006 (gtask_fn_t *)wg_softc_decrypt, sc);
3007 taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_decrypt[i], sc, i, NULL, NULL, "wg decrypt");
3008 }
3009
3010 wg_queue_init(&sc->sc_encrypt_parallel, "encp");
3011 wg_queue_init(&sc->sc_decrypt_parallel, "decp");
3012
3013 sx_init(&sc->sc_lock, "wg softc lock");
3014
3015 if_setsoftc(ifp, sc);
3016 if_setcapabilities(ifp, WG_CAPS);
3017 if_setcapenable(ifp, WG_CAPS);
3018 if_initname(ifp, wgname, ifd->unit);
3019
3020 if_setmtu(ifp, DEFAULT_MTU);
3021 if_setflags(ifp, IFF_NOARP | IFF_MULTICAST);
3022 if_setinitfn(ifp, wg_init);
3023 if_setreassignfn(ifp, wg_reassign);
3024 if_setqflushfn(ifp, wg_qflush);
3025 if_settransmitfn(ifp, wg_transmit);
3026 #ifdef DEV_NETMAP
3027 if_setinputfn(ifp, wg_if_input);
3028 #endif
3029 if_setoutputfn(ifp, wg_output);
3030 if_setioctlfn(ifp, wg_ioctl);
3031 if_attach(ifp);
3032 bpfattach(ifp, DLT_NULL, sizeof(uint32_t));
3033 #ifdef INET6
3034 ND_IFINFO(ifp)->flags &= ~ND6_IFF_AUTO_LINKLOCAL;
3035 ND_IFINFO(ifp)->flags |= ND6_IFF_NO_DAD;
3036 #endif
3037 sx_xlock(&wg_sx);
3038 LIST_INSERT_HEAD(&wg_list, sc, sc_entry);
3039 sx_xunlock(&wg_sx);
3040 *ifpp = ifp;
3041 return (0);
3042 free_aip4:
3043 RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
3044 free(sc->sc_aip4, M_RTABLE);
3045 free_decrypt:
3046 free(sc->sc_decrypt, M_WG);
3047 free(sc->sc_encrypt, M_WG);
3048 noise_local_free(sc->sc_local, NULL);
3049 free(sc, M_WG);
3050 return (ENOMEM);
3051 }
3052
3053 static void
3054 wg_clone_deferred_free(struct noise_local *l)
3055 {
3056 struct wg_softc *sc = noise_local_arg(l);
3057
3058 free(sc, M_WG);
3059 atomic_add_int(&clone_count, -1);
3060 }
3061
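/*
 * Destroy a wg(4) interface.  The ordering matters: detach the softc and
 * close the socket first, wait out the net epoch and drain the taskqueue
 * group, and only then tear down peers, queues and the radix heads before
 * the deferred free of the softc runs.
 */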
3062 static int
3063 wg_clone_destroy(struct if_clone *ifc, if_t ifp, uint32_t flags)
3064 {
3065 struct wg_softc *sc = if_getsoftc(ifp);
3066 struct ucred *cred;
3067
3068 sx_xlock(&wg_sx);
3069 if_setsoftc(ifp, NULL);
3070 sx_xlock(&sc->sc_lock);
3071 sc->sc_flags |= WGF_DYING;
3072 cred = sc->sc_ucred;
3073 sc->sc_ucred = NULL;
3074 sx_xunlock(&sc->sc_lock);
3075 LIST_REMOVE(sc, sc_entry);
3076 sx_xunlock(&wg_sx);
3077
3078 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
3079 CURVNET_SET(if_getvnet(sc->sc_ifp));
3080 if_purgeaddrs(sc->sc_ifp);
3081 CURVNET_RESTORE();
3082
3083 sx_xlock(&sc->sc_lock);
3084 wg_socket_uninit(sc);
3085 sx_xunlock(&sc->sc_lock);
3086
3087 /*
3088 * There is no guarantee that all traffic has passed until the epoch has
3089 * elapsed with the socket closed.
3090 */
3091 NET_EPOCH_WAIT();
3092
3093 taskqgroup_drain_all(qgroup_wg_tqg);
3094 sx_xlock(&sc->sc_lock);
3095 wg_peer_destroy_all(sc);
3096 NET_EPOCH_DRAIN_CALLBACKS();
3097 sx_xunlock(&sc->sc_lock);
3098 sx_destroy(&sc->sc_lock);
3099 taskqgroup_detach(qgroup_wg_tqg, &sc->sc_handshake);
3100 for (int i = 0; i < mp_ncpus; i++) {
3101 taskqgroup_detach(qgroup_wg_tqg, &sc->sc_encrypt[i]);
3102 taskqgroup_detach(qgroup_wg_tqg, &sc->sc_decrypt[i]);
3103 }
3104 free(sc->sc_encrypt, M_WG);
3105 free(sc->sc_decrypt, M_WG);
3106 wg_queue_deinit(&sc->sc_handshake_queue);
3107 wg_queue_deinit(&sc->sc_encrypt_parallel);
3108 wg_queue_deinit(&sc->sc_decrypt_parallel);
3109
3110 RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
3111 RADIX_NODE_HEAD_DESTROY(sc->sc_aip6);
3112 rn_detachhead((void **)&sc->sc_aip4);
3113 rn_detachhead((void **)&sc->sc_aip6);
3114
3115 cookie_checker_free(&sc->sc_cookie);
3116
3117 if (cred != NULL)
3118 crfree(cred);
3119 bpfdetach(sc->sc_ifp);
3120 if_detach(sc->sc_ifp);
3121 if_free(sc->sc_ifp);
3122
3123 noise_local_free(sc->sc_local, wg_clone_deferred_free);
3124
3125 return (0);
3126 }
3127
3128 static void
3129 wg_qflush(if_t ifp __unused)
3130 {
3131 }
3132
3133 /*
3134 * Privileged information (private-key, preshared-key) is only exported for
3135 * root and jailed root by default.
3136 */
3137 static bool
3138 wgc_privileged(struct wg_softc *sc)
3139 {
3140 struct thread *td;
3141
3142 td = curthread;
3143 return (priv_check(td, PRIV_NET_WG) == 0);
3144 }
3145
3146 static void
3147 wg_reassign(if_t ifp, struct vnet *new_vnet __unused,
3148 char *unused __unused)
3149 {
3150 struct wg_softc *sc;
3151
3152 sc = if_getsoftc(ifp);
3153 wg_down(sc);
3154 }
3155
3156 static void
3157 wg_init(void *xsc)
3158 {
3159 struct wg_softc *sc;
3160
3161 sc = xsc;
3162 wg_up(sc);
3163 }
3164
3165 static void
3166 vnet_wg_init(const void *unused __unused)
3167 {
3168 struct if_clone_addreq req = {
3169 .create_f = wg_clone_create,
3170 .destroy_f = wg_clone_destroy,
3171 .flags = IFC_F_AUTOUNIT,
3172 };
3173 V_wg_cloner = ifc_attach_cloner(wgname, &req);
3174 }
3175 VNET_SYSINIT(vnet_wg_init, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3176 vnet_wg_init, NULL);
3177
3178 static void
3179 vnet_wg_uninit(const void *unused __unused)
3180 {
3181 if (V_wg_cloner)
3182 ifc_detach_cloner(V_wg_cloner);
3183 }
3184 VNET_SYSUNINIT(vnet_wg_uninit, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3185 vnet_wg_uninit, NULL);
3186
3187 static int
3188 wg_prison_remove(void *obj, void *data __unused)
3189 {
3190 const struct prison *pr = obj;
3191 struct wg_softc *sc;
3192
3193 /*
3194 * Do a pass through all if_wg interfaces and release creds on any from
3195 * the jail that are supposed to be going away. This will, in turn, let
3196 * the jail die so that we don't end up with Schrödinger's jail.
3197 */
3198 sx_slock(&wg_sx);
3199 LIST_FOREACH(sc, &wg_list, sc_entry) {
3200 sx_xlock(&sc->sc_lock);
3201 if (!(sc->sc_flags & WGF_DYING) && sc->sc_ucred && sc->sc_ucred->cr_prison == pr) {
3202 struct ucred *cred = sc->sc_ucred;
3203 DPRINTF(sc, "Creating jail exiting\n");
3204 if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
3205 wg_socket_uninit(sc);
3206 sc->sc_ucred = NULL;
3207 crfree(cred);
3208 sc->sc_flags |= WGF_DYING;
3209 }
3210 sx_xunlock(&sc->sc_lock);
3211 }
3212 sx_sunlock(&wg_sx);
3213
3214 return (0);
3215 }
3216
3217 #ifdef SELFTESTS
3218 #include "selftest/allowedips.c"
3219 static bool wg_run_selftests(void)
3220 {
3221 bool ret = true;
3222 ret &= wg_allowedips_selftest();
3223 ret &= noise_counter_selftest();
3224 ret &= cookie_selftest();
3225 return ret;
3226 }
3227 #else
3228 static inline bool wg_run_selftests(void) { return true; }
3229 #endif
3230
3231 static int
3232 wg_module_init(void)
3233 {
3234 int ret;
3235 osd_method_t methods[PR_MAXMETHOD] = {
3236 [PR_METHOD_REMOVE] = wg_prison_remove,
3237 };
3238
3239 wg_packet_zone = uma_zcreate("wg packet", sizeof(struct wg_packet),
3240 NULL, NULL, NULL, NULL, 0, 0);
3241
3242 ret = crypto_init();
3243 if (ret != 0)
3244 return (ret);
3245 ret = cookie_init();
3246 if (ret != 0)
3247 return (ret);
3248
3249 wg_osd_jail_slot = osd_jail_register(NULL, methods);
3250
3251 if (!wg_run_selftests())
3252 return (ENOTRECOVERABLE);
3253
3254 return (0);
3255 }
3256
3257 static void
3258 wg_module_deinit(void)
3259 {
3260 VNET_ITERATOR_DECL(vnet_iter);
3261 VNET_LIST_RLOCK();
3262 VNET_FOREACH(vnet_iter) {
3263 struct if_clone *clone = VNET_VNET(vnet_iter, wg_cloner);
3264 if (clone) {
3265 ifc_detach_cloner(clone);
3266 VNET_VNET(vnet_iter, wg_cloner) = NULL;
3267 }
3268 }
3269 VNET_LIST_RUNLOCK();
3270 NET_EPOCH_WAIT();
3271 MPASS(LIST_EMPTY(&wg_list));
3272 if (wg_osd_jail_slot != 0)
3273 osd_jail_deregister(wg_osd_jail_slot);
3274 cookie_deinit();
3275 crypto_deinit();
3276 if (wg_packet_zone != NULL)
3277 uma_zdestroy(wg_packet_zone);
3278 }
3279
3280 static int
3281 wg_module_event_handler(module_t mod, int what, void *arg)
3282 {
3283 switch (what) {
3284 case MOD_LOAD:
3285 return wg_module_init();
3286 case MOD_UNLOAD:
3287 wg_module_deinit();
3288 break;
3289 default:
3290 return (EOPNOTSUPP);
3291 }
3292 return (0);
3293 }
3294
3295 static moduledata_t wg_moduledata = {
3296 "if_wg",
3297 wg_module_event_handler,
3298 NULL
3299 };
3300
3301 DECLARE_MODULE(if_wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
3302 MODULE_VERSION(if_wg, WIREGUARD_VERSION);
3303 MODULE_DEPEND(if_wg, crypto, 1, 1, 1);
3304