xref: /linux/drivers/net/wireguard/allowedips.c (revision 40286d6379aacfcc053253ef78dc78b09addffda)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5 
6 #include "allowedips.h"
7 #include "peer.h"
8 
/* Capacity of the on-stack node arrays used by the iterative trie walks in
 * root_free_rcu() and root_remove_peer_lists(); push_rcu() checks against it
 * in DEBUG builds.
 */
enum { MAX_ALLOWEDIPS_DEPTH = 129 };

/* Slab cache backing every allowedips_node allocation in this file; it is
 * created by wg_allowedips_slab_init() and destroyed by
 * wg_allowedips_slab_uninit().
 */
static struct kmem_cache *node_cache;
12 
13 static void swap_endian(u8 *dst, const u8 *src, u8 bits)
14 {
15 	if (bits == 32) {
16 		*(u32 *)dst = be32_to_cpu(*(const __be32 *)src);
17 	} else if (bits == 128) {
18 		((u64 *)dst)[0] = get_unaligned_be64(src);
19 		((u64 *)dst)[1] = get_unaligned_be64(src + 8);
20 	}
21 }
22 
23 static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
24 				 u8 cidr, u8 bits)
25 {
26 	node->cidr = cidr;
27 	node->bit_at_a = cidr / 8U;
28 #ifdef __LITTLE_ENDIAN
29 	node->bit_at_a ^= (bits / 8U - 1U) % 8U;
30 #endif
31 	node->bit_at_b = 7U - (cidr % 8U);
32 	node->bitlen = bits;
33 	memcpy(node->bits, src, bits / 8U);
34 }
35 
36 static inline u8 choose(struct allowedips_node *node, const u8 *key)
37 {
38 	return (key[node->bit_at_a] >> node->bit_at_b) & 1;
39 }
40 
41 static void push_rcu(struct allowedips_node **stack,
42 		     struct allowedips_node __rcu *p, unsigned int *len)
43 {
44 	if (rcu_access_pointer(p)) {
45 		if (WARN_ON(IS_ENABLED(DEBUG) && *len >= MAX_ALLOWEDIPS_DEPTH))
46 			return;
47 		stack[(*len)++] = rcu_dereference_raw(p);
48 	}
49 }
50 
51 static void root_free_rcu(struct rcu_head *rcu)
52 {
53 	struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = {
54 		container_of(rcu, struct allowedips_node, rcu) };
55 	unsigned int len = 1;
56 
57 	while (len > 0 && (node = stack[--len])) {
58 		push_rcu(stack, node->bit[0], &len);
59 		push_rcu(stack, node->bit[1], &len);
60 		kmem_cache_free(node_cache, node);
61 	}
62 }
63 
64 static void root_remove_peer_lists(struct allowedips_node *root)
65 {
66 	struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = { root };
67 	unsigned int len = 1;
68 
69 	while (len > 0 && (node = stack[--len])) {
70 		push_rcu(stack, node->bit[0], &len);
71 		push_rcu(stack, node->bit[1], &len);
72 		if (rcu_access_pointer(node->peer))
73 			list_del(&node->peer_list);
74 	}
75 }
76 
77 static unsigned int fls128(u64 a, u64 b)
78 {
79 	return a ? fls64(a) + 64U : fls64(b);
80 }
81 
82 static u8 common_bits(const struct allowedips_node *node, const u8 *key,
83 		      u8 bits)
84 {
85 	if (bits == 32)
86 		return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key);
87 	else if (bits == 128)
88 		return 128U - fls128(
89 			*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0],
90 			*(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
91 	return 0;
92 }
93 
94 static bool prefix_matches(const struct allowedips_node *node, const u8 *key,
95 			   u8 bits)
96 {
97 	/* This could be much faster if it actually just compared the common
98 	 * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and
99 	 * the rest, but it turns out that common_bits is already super fast on
100 	 * modern processors, even taking into account the unfortunate bswap.
101 	 * So, we just inline it like this instead.
102 	 */
103 	return common_bits(node, key, bits) >= node->cidr;
104 }
105 
106 static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
107 					 const u8 *key)
108 {
109 	struct allowedips_node *node = trie, *found = NULL;
110 
111 	while (node && prefix_matches(node, key, bits)) {
112 		if (rcu_access_pointer(node->peer))
113 			found = node;
114 		if (node->cidr == bits)
115 			break;
116 		node = rcu_dereference_bh(node->bit[choose(node, key)]);
117 	}
118 	return found;
119 }
120 
121 /* Returns a strong reference to a peer */
122 static struct wg_peer *lookup(struct allowedips_node __rcu *root, u8 bits,
123 			      const void *be_ip)
124 {
125 	/* Aligned so it can be passed to fls/fls64 */
126 	u8 ip[16] __aligned(__alignof(u64));
127 	struct allowedips_node *node;
128 	struct wg_peer *peer = NULL;
129 
130 	swap_endian(ip, be_ip, bits);
131 
132 	rcu_read_lock_bh();
133 retry:
134 	node = find_node(rcu_dereference_bh(root), bits, ip);
135 	if (node) {
136 		peer = wg_peer_get_maybe_zero(rcu_dereference_bh(node->peer));
137 		if (!peer)
138 			goto retry;
139 	}
140 	rcu_read_unlock_bh();
141 	return peer;
142 }
143 
144 static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
145 			   u8 cidr, u8 bits, struct allowedips_node **rnode,
146 			   struct mutex *lock)
147 {
148 	struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
149 	struct allowedips_node *parent = NULL;
150 	bool exact = false;
151 
152 	while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
153 		parent = node;
154 		if (parent->cidr == cidr) {
155 			exact = true;
156 			break;
157 		}
158 		node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock));
159 	}
160 	*rnode = parent;
161 	return exact;
162 }
163 
164 static inline void connect_node(struct allowedips_node __rcu **parent, u8 bit, struct allowedips_node *node)
165 {
166 	node->parent_bit_packed = (unsigned long)parent | bit;
167 	rcu_assign_pointer(*parent, node);
168 }
169 
170 static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node)
171 {
172 	u8 bit = choose(parent, node->bits);
173 	connect_node(&parent->bit[bit], bit, node);
174 }
175 
176 static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
177 	       u8 cidr, struct wg_peer *peer, struct mutex *lock)
178 {
179 	struct allowedips_node *node, *parent, *down, *newnode;
180 
181 	if (unlikely(cidr > bits || !peer))
182 		return -EINVAL;
183 
184 	if (!rcu_access_pointer(*trie)) {
185 		node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
186 		if (unlikely(!node))
187 			return -ENOMEM;
188 		RCU_INIT_POINTER(node->peer, peer);
189 		list_add_tail(&node->peer_list, &peer->allowedips_list);
190 		copy_and_assign_cidr(node, key, cidr, bits);
191 		connect_node(trie, 2, node);
192 		return 0;
193 	}
194 	if (node_placement(*trie, key, cidr, bits, &node, lock)) {
195 		rcu_assign_pointer(node->peer, peer);
196 		list_move_tail(&node->peer_list, &peer->allowedips_list);
197 		return 0;
198 	}
199 
200 	newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL);
201 	if (unlikely(!newnode))
202 		return -ENOMEM;
203 	RCU_INIT_POINTER(newnode->peer, peer);
204 	list_add_tail(&newnode->peer_list, &peer->allowedips_list);
205 	copy_and_assign_cidr(newnode, key, cidr, bits);
206 
207 	if (!node) {
208 		down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
209 	} else {
210 		const u8 bit = choose(node, key);
211 		down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock));
212 		if (!down) {
213 			connect_node(&node->bit[bit], bit, newnode);
214 			return 0;
215 		}
216 	}
217 	cidr = min(cidr, common_bits(down, key, bits));
218 	parent = node;
219 
220 	if (newnode->cidr == cidr) {
221 		choose_and_connect_node(newnode, down);
222 		if (!parent)
223 			connect_node(trie, 2, newnode);
224 		else
225 			choose_and_connect_node(parent, newnode);
226 		return 0;
227 	}
228 
229 	node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
230 	if (unlikely(!node)) {
231 		list_del(&newnode->peer_list);
232 		kmem_cache_free(node_cache, newnode);
233 		return -ENOMEM;
234 	}
235 	INIT_LIST_HEAD(&node->peer_list);
236 	copy_and_assign_cidr(node, newnode->bits, cidr, bits);
237 
238 	choose_and_connect_node(node, down);
239 	choose_and_connect_node(node, newnode);
240 	if (!parent)
241 		connect_node(trie, 2, node);
242 	else
243 		choose_and_connect_node(parent, node);
244 	return 0;
245 }
246 
247 static void remove_node(struct allowedips_node *node, struct mutex *lock)
248 {
249 	struct allowedips_node *child, **parent_bit, *parent;
250 	bool free_parent;
251 
252 	list_del_init(&node->peer_list);
253 	RCU_INIT_POINTER(node->peer, NULL);
254 	if (node->bit[0] && node->bit[1])
255 		return;
256 	child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
257 					  lockdep_is_held(lock));
258 	if (child)
259 		child->parent_bit_packed = node->parent_bit_packed;
260 	parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
261 	*parent_bit = child;
262 	parent = (void *)parent_bit -
263 			offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
264 	free_parent = !rcu_access_pointer(node->bit[0]) && !rcu_access_pointer(node->bit[1]) &&
265 			(node->parent_bit_packed & 3) <= 1 && !rcu_access_pointer(parent->peer);
266 	if (free_parent)
267 		child = rcu_dereference_protected(parent->bit[!(node->parent_bit_packed & 1)],
268 						  lockdep_is_held(lock));
269 	kfree_rcu(node, rcu);
270 	if (!free_parent)
271 		return;
272 	if (child)
273 		child->parent_bit_packed = parent->parent_bit_packed;
274 	*(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
275 	kfree_rcu(parent, rcu);
276 }
277 
278 static int remove(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
279 		  u8 cidr, struct wg_peer *peer, struct mutex *lock)
280 {
281 	struct allowedips_node *node;
282 
283 	if (unlikely(cidr > bits))
284 		return -EINVAL;
285 	if (!rcu_access_pointer(*trie) || !node_placement(*trie, key, cidr, bits, &node, lock) ||
286 	    peer != rcu_access_pointer(node->peer))
287 		return 0;
288 
289 	remove_node(node, lock);
290 	return 0;
291 }
292 
293 void wg_allowedips_init(struct allowedips *table)
294 {
295 	table->root4 = table->root6 = NULL;
296 	table->seq = 1;
297 }
298 
299 void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
300 {
301 	struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6;
302 
303 	++table->seq;
304 	RCU_INIT_POINTER(table->root4, NULL);
305 	RCU_INIT_POINTER(table->root6, NULL);
306 	if (rcu_access_pointer(old4)) {
307 		struct allowedips_node *node = rcu_dereference_protected(old4,
308 							lockdep_is_held(lock));
309 
310 		root_remove_peer_lists(node);
311 		call_rcu(&node->rcu, root_free_rcu);
312 	}
313 	if (rcu_access_pointer(old6)) {
314 		struct allowedips_node *node = rcu_dereference_protected(old6,
315 							lockdep_is_held(lock));
316 
317 		root_remove_peer_lists(node);
318 		call_rcu(&node->rcu, root_free_rcu);
319 	}
320 }
321 
322 int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
323 			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
324 {
325 	/* Aligned so it can be passed to fls */
326 	u8 key[4] __aligned(__alignof(u32));
327 
328 	++table->seq;
329 	swap_endian(key, (const u8 *)ip, 32);
330 	return add(&table->root4, 32, key, cidr, peer, lock);
331 }
332 
333 int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
334 			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
335 {
336 	/* Aligned so it can be passed to fls64 */
337 	u8 key[16] __aligned(__alignof(u64));
338 
339 	++table->seq;
340 	swap_endian(key, (const u8 *)ip, 128);
341 	return add(&table->root6, 128, key, cidr, peer, lock);
342 }
343 
344 int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip,
345 			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
346 {
347 	/* Aligned so it can be passed to fls */
348 	u8 key[4] __aligned(__alignof(u32));
349 
350 	++table->seq;
351 	swap_endian(key, (const u8 *)ip, 32);
352 	return remove(&table->root4, 32, key, cidr, peer, lock);
353 }
354 
355 int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip,
356 			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
357 {
358 	/* Aligned so it can be passed to fls64 */
359 	u8 key[16] __aligned(__alignof(u64));
360 
361 	++table->seq;
362 	swap_endian(key, (const u8 *)ip, 128);
363 	return remove(&table->root6, 128, key, cidr, peer, lock);
364 }
365 
366 void wg_allowedips_remove_by_peer(struct allowedips *table,
367 				  struct wg_peer *peer, struct mutex *lock)
368 {
369 	struct allowedips_node *node, *tmp;
370 
371 	if (list_empty(&peer->allowedips_list))
372 		return;
373 	++table->seq;
374 	list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list)
375 		remove_node(node, lock);
376 }
377 
378 int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
379 {
380 	const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
381 	swap_endian(ip, node->bits, node->bitlen);
382 	memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes);
383 	if (node->cidr)
384 		ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);
385 
386 	*cidr = node->cidr;
387 	return node->bitlen == 32 ? AF_INET : AF_INET6;
388 }
389 
390 /* Returns a strong reference to a peer */
391 struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table,
392 					 struct sk_buff *skb)
393 {
394 	if (skb->protocol == htons(ETH_P_IP))
395 		return lookup(table->root4, 32, &ip_hdr(skb)->daddr);
396 	else if (skb->protocol == htons(ETH_P_IPV6))
397 		return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr);
398 	return NULL;
399 }
400 
401 /* Returns a strong reference to a peer */
402 struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
403 					 struct sk_buff *skb)
404 {
405 	if (skb->protocol == htons(ETH_P_IP))
406 		return lookup(table->root4, 32, &ip_hdr(skb)->saddr);
407 	else if (skb->protocol == htons(ETH_P_IPV6))
408 		return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr);
409 	return NULL;
410 }
411 
412 int __init wg_allowedips_slab_init(void)
413 {
414 	node_cache = KMEM_CACHE(allowedips_node, 0);
415 	return node_cache ? 0 : -ENOMEM;
416 }
417 
418 void wg_allowedips_slab_uninit(void)
419 {
420 	rcu_barrier();
421 	kmem_cache_destroy(node_cache);
422 }
423 
424 #include "selftest/allowedips.c"
425