// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

/*
 * Allowed-IPs table: maps address prefixes to peers using one binary prefix
 * trie per key width (table->root4 for 32-bit keys, table->root6 for 128-bit
 * keys). Keys are kept as host-order u32/u64 words (see swap_endian()).
 * Readers (lookup()) walk a trie inside rcu_read_lock_bh(); the functions
 * that modify a trie take a `struct mutex *lock` and pass it to
 * lockdep_is_held(), i.e. the caller is expected to hold that mutex.
 */

#include "allowedips.h"
#include "peer.h"

/*
 * Capacity of the explicit DFS stacks used by root_free_rcu() and
 * root_remove_peer_lists(). 129 is one more than the widest key (128 bits).
 * NOTE(review): presumably derived from the maximum trie depth - confirm.
 */
enum { MAX_ALLOWEDIPS_DEPTH = 129 };

/* Slab cache from which every trie node is allocated. */
static struct kmem_cache *node_cache;

/*
 * Convert the big-endian address of @bits bits at @src into the host-order
 * key layout used by the trie: one u32 for bits == 32, two u64 words (bytes
 * 0-7 first) for bits == 128. Any other width leaves @dst untouched.
 *
 * The 32-bit path dereferences @src as a __be32, so @src must be suitably
 * aligned there; the 128-bit path reads with get_unaligned_be64(). @dst is
 * written through u32/u64 pointers, so it must be aligned for those types
 * (callers declare their key buffers with __aligned()). The conversion is a
 * byte swap or the identity, so wg_allowedips_read_node() also uses it to
 * go back to network byte order.
 */
static void swap_endian(u8 *dst, const u8 *src, u8 bits)
{
	if (bits == 32) {
		*(u32 *)dst = be32_to_cpu(*(const __be32 *)src);
	} else if (bits == 128) {
		((u64 *)dst)[0] = get_unaligned_be64(src);
		((u64 *)dst)[1] = get_unaligned_be64(src + 8);
	}
}

/*
 * Fill in @node's prefix: prefix length, key width, a copy of the first
 * @bits / 8 bytes of the host-order key at @src, and the byte/bit position
 * that choose() uses to test key bit number @cidr (counting from the most
 * significant bit, i.e. the first bit after the prefix).
 *
 * Keys are stored as host-order u32/u64 words, so on little-endian hosts the
 * big-endian byte index is mirrored within the word: XOR with 3 for 32-bit
 * keys, XOR with 7 for 128-bit keys.
 */
static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
				 u8 cidr, u8 bits)
{
	node->cidr = cidr;
	node->bit_at_a = cidr / 8U;
#ifdef __LITTLE_ENDIAN
	node->bit_at_a ^= (bits / 8U - 1U) % 8U;
#endif
	node->bit_at_b = 7U - (cidr % 8U);
	node->bitlen = bits;
	memcpy(node->bits, src, bits / 8U);
}

/* Return key bit number node->cidr (0 or 1): the child of @node to follow. */
static inline u8 choose(struct allowedips_node *node, const u8 *key)
{
	return (key[node->bit_at_a] >> node->bit_at_b) & 1;
}

/*
 * Push @p onto the explicit traversal @stack if it is non-NULL. The overflow
 * check is compiled in only when DEBUG is enabled; otherwise the callers rely
 * on MAX_ALLOWEDIPS_DEPTH being large enough. rcu_dereference_raw() is used
 * because the callers walk either a tree that has already been unlinked from
 * the table (root_free_rcu()) or one protected by the update-side lock
 * (root_remove_peer_lists(), called from wg_allowedips_free()).
 */
static void push_rcu(struct allowedips_node **stack,
		     struct allowedips_node __rcu *p, unsigned int *len)
{
	if (rcu_access_pointer(p)) {
		if (WARN_ON(IS_ENABLED(DEBUG) && *len >= MAX_ALLOWEDIPS_DEPTH))
			return;
		stack[(*len)++] = rcu_dereference_raw(p);
	}
}

/*
 * RCU callback (queued by wg_allowedips_free()) that returns every node of the
 * subtree whose root embeds @rcu to node_cache. It uses an iterative
 * depth-first walk over an on-stack array instead of recursion.
 */
static void root_free_rcu(struct rcu_head *rcu)
{
	struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = {
		container_of(rcu, struct allowedips_node, rcu) };
	unsigned int len = 1;

	while (len > 0 && (node = stack[--len])) {
		push_rcu(stack, node->bit[0], &len);
		push_rcu(stack, node->bit[1], &len);
		kmem_cache_free(node_cache, node);
	}
}

/*
 * Unlink every node in the subtree rooted at @root that carries a peer from
 * that peer's allowedips_list. wg_allowedips_free() calls this with the update
 * lock held, before handing the same nodes to root_free_rcu().
 */
static void root_remove_peer_lists(struct allowedips_node *root)
{
	struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = { root };
	unsigned int len = 1;

	while (len > 0 && (node = stack[--len])) {
		push_rcu(stack, node->bit[0], &len);
		push_rcu(stack, node->bit[1], &len);
		if (rcu_access_pointer(node->peer))
			list_del(&node->peer_list);
	}
}

/*
 * fls() for a 128-bit value split into high word @a and low word @b.
 * Returns the 1-based position of the most significant set bit (1..128), or
 * 0 if the value is zero.
 */
static unsigned int fls128(u64 a, u64 b)
{
	return a ? fls64(a) + 64U : fls64(b);
}

/*
 * Count the leading bits (0..@bits) that @node's stored key and @key have in
 * common, for 32- or 128-bit keys; returns 0 for any other width. Both keys
 * are read through u32/u64 pointers, hence the alignment requirement on key
 * buffers.
 */
static u8 common_bits(const struct allowedips_node *node, const u8 *key,
		      u8 bits)
{
	if (bits == 32)
		return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key);
	else if (bits == 128)
		return 128U - fls128(
			*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0],
			*(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
	return 0;
}

/*
 * Return true when @key falls inside @node's prefix, i.e. its first
 * node->cidr bits equal those stored in @node.
 */
static bool prefix_matches(const struct allowedips_node *node, const u8 *key,
			   u8 bits)
{
	/* This could be much faster if it actually just compared the common
	 * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and
	 * the rest, but it turns out that common_bits is already super fast on
	 * modern processors, even taking into account the unfortunate bswap.
	 * So, we just inline it like this instead.
	 */
	return common_bits(node, key, bits) >= node->cidr;
}

/*
 * Longest-prefix match for @key. The walk descends from @trie while each
 * node's prefix matches and remembers the deepest node that carries a peer.
 * A node whose cidr equals @bits covers the whole key, so the walk stops
 * there without choosing a child.
 * Uses rcu_dereference_bh(), so it must run inside rcu_read_lock_bh().
 * Returns NULL when no matching prefix has a peer.
 */
static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
					 const u8 *key)
{
	struct allowedips_node *node = trie, *found = NULL;

	while (node && prefix_matches(node, key, bits)) {
		if (rcu_access_pointer(node->peer))
			found = node;
		if (node->cidr == bits)
			break;
		node = rcu_dereference_bh(node->bit[choose(node, key)]);
	}
	return found;
}

/*
 * Returns a strong reference to a peer, or NULL if no prefix matches.
 *
 * @be_ip is a big-endian address of @bits bits. If the matched node's peer
 * cannot be referenced (wg_peer_get_maybe_zero() returns NULL), the lookup
 * restarts from the root inside the same read-side critical section.
 * NOTE(review): the retry only ends once find_node() stops returning a node
 * whose peer cannot be referenced - presumably that peer is being removed
 * concurrently; confirm.
 */
static struct wg_peer *lookup(struct allowedips_node __rcu *root, u8 bits,
			      const void *be_ip)
{
	/* Aligned so it can be passed to fls/fls64 */
	u8 ip[16] __aligned(__alignof(u64));
	struct allowedips_node *node;
	struct wg_peer *peer = NULL;

	swap_endian(ip, be_ip, bits);

	rcu_read_lock_bh();
retry:
	node = find_node(rcu_dereference_bh(root), bits, ip);
	if (node) {
		peer = wg_peer_get_maybe_zero(rcu_dereference_bh(node->peer));
		if (!peer)
			goto retry;
	}
	rcu_read_unlock_bh();
	return peer;
}

/*
 * Locate where the prefix (@key, @cidr) belongs. The walk descends while the
 * current node is no longer than @cidr and its prefix matches @key.
 * *@rnode receives the last such node, or NULL if even the root does not
 * qualify. Returns true iff that node's cidr equals @cidr, i.e. a node with
 * exactly this prefix already exists.
 * Must be called with @lock held.
 */
static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
			   u8 cidr, u8 bits, struct allowedips_node **rnode,
			   struct mutex *lock)
{
	struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
	struct allowedips_node *parent = NULL;
	bool exact = false;

	while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
		parent = node;
		if (parent->cidr == cidr) {
			exact = true;
			break;
		}
		node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock));
	}
	*rnode = parent;
	return exact;
}

/*
 * Link @node into the slot @parent and record a back-pointer to that slot in
 * node->parent_bit_packed. The low two bits of the packed value hold @bit:
 * 0 or 1 for a child slot of another node, 2 for a table root slot
 * (remove_node() uses this to tell them apart). This assumes slot addresses
 * have their two low bits clear. rcu_assign_pointer() orders @node's prior
 * initialisation before it becomes visible to readers.
 */
static inline void connect_node(struct allowedips_node __rcu **parent, u8 bit, struct allowedips_node *node)
{
	node->parent_bit_packed = (unsigned long)parent | bit;
	rcu_assign_pointer(*parent, node);
}

/* Attach @node below @parent, on the side selected by @node's own key. */
static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node)
{
	u8 bit = choose(parent, node->bits);
	connect_node(&parent->bit[bit], bit, node);
}

/*
 * Map the prefix @key/@cidr (@key in host-order key layout) to @peer in
 * *@trie. Must be called with @lock held.
 * Returns 0 on success, -EINVAL if @cidr exceeds @bits or @peer is NULL, and
 * -ENOMEM if a node allocation fails (the trie is then left unchanged).
 */
static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
	       u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	struct allowedips_node *node, *parent, *down, *newnode;

	if (unlikely(cidr > bits || !peer))
		return -EINVAL;

	/* Empty trie: the new node becomes the root (slot tag 2). */
	if (!rcu_access_pointer(*trie)) {
		node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
		if (unlikely(!node))
			return -ENOMEM;
		RCU_INIT_POINTER(node->peer, peer);
		list_add_tail(&node->peer_list, &peer->allowedips_list);
		copy_and_assign_cidr(node, key, cidr, bits);
		connect_node(trie, 2, node);
		return 0;
	}
	/*
	 * A node with exactly this prefix exists (possibly a peer-less branch
	 * node): re-point it at @peer and move it onto @peer's list.
	 */
	if (node_placement(*trie, key, cidr, bits, &node, lock)) {
		rcu_assign_pointer(node->peer, peer);
		list_move_tail(&node->peer_list, &peer->allowedips_list);
		return 0;
	}

	newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL);
	if (unlikely(!newnode))
		return -ENOMEM;
	RCU_INIT_POINTER(newnode->peer, peer);
	list_add_tail(&newnode->peer_list, &peer->allowedips_list);
	copy_and_assign_cidr(newnode, key, cidr, bits);

	/*
	 * Find the existing subtree @down the new prefix must be placed
	 * relative to: the whole trie when node_placement() found no ancestor,
	 * otherwise the ancestor's child on @key's side. An empty child slot
	 * simply takes the new node.
	 */
	if (!node) {
		down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
	} else {
		const u8 bit = choose(node, key);
		down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock));
		if (!down) {
			connect_node(&node->bit[bit], bit, newnode);
			return 0;
		}
	}
	/* Length of the prefix shared by the new prefix and @down. */
	cidr = min(cidr, common_bits(down, key, bits));
	parent = node;

	/*
	 * The new prefix contains @down: put newnode between @parent (or the
	 * root slot) and @down. @down is attached to newnode before newnode
	 * itself is linked into the trie.
	 */
	if (newnode->cidr == cidr) {
		choose_and_connect_node(newnode, down);
		if (!parent)
			connect_node(trie, 2, newnode);
		else
			choose_and_connect_node(parent, newnode);
		return 0;
	}

	/*
	 * The prefixes diverge: create a peer-less branch node for their common
	 * prefix with @down and newnode as its two children. On allocation
	 * failure, undo newnode's list insertion; newnode was never linked into
	 * the trie.
	 */
	node = kmem_cache_zalloc(node_cache, GFP_KERNEL);
	if (unlikely(!node)) {
		list_del(&newnode->peer_list);
		kmem_cache_free(node_cache, newnode);
		return -ENOMEM;
	}
	INIT_LIST_HEAD(&node->peer_list);
	copy_and_assign_cidr(node, newnode->bits, cidr, bits);

	/* Both children are attached before the branch node is published. */
	choose_and_connect_node(node, down);
	choose_and_connect_node(node, newnode);
	if (!parent)
		connect_node(trie, 2, node);
	else
		choose_and_connect_node(parent, node);
	return 0;
}

/*
 * Drop @node's peer (taking @node off the peer's list) and prune the trie
 * where possible. Must be called with @lock held.
 *  - A node with two children stays in place as a peer-less branch point.
 *  - Otherwise its only child (or NULL) replaces it in its parent's slot.
 *  - If @node was a leaf whose parent is a non-root node without a peer, the
 *    parent is spliced out as well and replaced by @node's former sibling.
 * Unlinked nodes are freed after a grace period with kfree_rcu().
 *
 * NOTE(review): nodes come from node_cache (kmem_cache_zalloc()) but are
 * released here with kfree_rcu(), whereas root_free_rcu() uses
 * kmem_cache_free(). This relies on kfree_rcu() accepting kmem_cache objects
 * and on wg_allowedips_slab_uninit()'s teardown waiting for pending
 * kfree_rcu() work. Confirm both hold for the target kernel version.
 */
static void remove_node(struct allowedips_node *node, struct mutex *lock)
{
	struct allowedips_node *child, **parent_bit, *parent;
	bool free_parent;

	list_del_init(&node->peer_list);
	RCU_INIT_POINTER(node->peer, NULL);
	if (node->bit[0] && node->bit[1])
		return;
	/* At most one child exists; pick bit[1] when bit[0] is empty. */
	child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])],
					  lockdep_is_held(lock));
	if (child)
		child->parent_bit_packed = node->parent_bit_packed;
	/* Splice @child (possibly NULL) into @node's former slot. */
	parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL);
	*parent_bit = child;
	/*
	 * parent_bit points at parent->bit[i], so stepping back by
	 * offsetof(bit[i]) recovers the parent node. For a root slot (tag 2)
	 * the result is not a real node; it is only dereferenced below after
	 * the tag has been checked.
	 */
	parent = (void *)parent_bit -
		 offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]);
	/*
	 * The parent can go too when @node was a leaf, the parent is a real
	 * node (tag 0 or 1, tested before parent->peer is read) and it has no
	 * peer. @child then becomes @node's sibling, which takes the parent's
	 * place.
	 */
	free_parent = !rcu_access_pointer(node->bit[0]) && !rcu_access_pointer(node->bit[1]) &&
		      (node->parent_bit_packed & 3) <= 1 && !rcu_access_pointer(parent->peer);
	if (free_parent)
		child = rcu_dereference_protected(parent->bit[!(node->parent_bit_packed & 1)],
						  lockdep_is_held(lock));
	kfree_rcu(node, rcu);
	if (!free_parent)
		return;
	if (child)
		child->parent_bit_packed = parent->parent_bit_packed;
	*(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child;
	kfree_rcu(parent, rcu);
}

/*
 * Remove the prefix @key/@cidr from *@trie, but only if a node with exactly
 * that prefix exists and currently maps to @peer; otherwise nothing changes.
 * Returns -EINVAL if @cidr exceeds @bits, else 0 (also when nothing matched).
 */
static int remove(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
		  u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	struct allowedips_node *node;

	if (unlikely(cidr > bits))
		return -EINVAL;
	if (!rcu_access_pointer(*trie) || !node_placement(*trie, key, cidr, bits, &node, lock) ||
	    peer != rcu_access_pointer(node->peer))
		return 0;

	remove_node(node, lock);
	return 0;
}

/*
 * Initialise an empty table. table->seq starts at 1 and is incremented by
 * the modifying functions in this file. NOTE(review): its readers are not in
 * this file - presumably it lets them detect that the table changed; confirm.
 */
void wg_allowedips_init(struct allowedips *table)
{
	table->root4 = table->root6 = NULL;
	table->seq = 1;
}

/*
 * Empty both tries. Must be called with @lock held. The roots are cleared
 * first and peer-list links are removed immediately. Node memory is released
 * by root_free_rcu() only after an RCU grace period, so concurrent readers
 * can finish their walks.
 */
void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
{
	struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6;

	++table->seq;
	RCU_INIT_POINTER(table->root4, NULL);
	RCU_INIT_POINTER(table->root6, NULL);
	if (rcu_access_pointer(old4)) {
		struct allowedips_node *node = rcu_dereference_protected(old4,
					lockdep_is_held(lock));

		root_remove_peer_lists(node);
		call_rcu(&node->rcu, root_free_rcu);
	}
	if (rcu_access_pointer(old6)) {
		struct allowedips_node *node = rcu_dereference_protected(old6,
					lockdep_is_held(lock));

		root_remove_peer_lists(node);
		call_rcu(&node->rcu, root_free_rcu);
	}
}

/* Map the IPv4 prefix @ip/@cidr to @peer; returns 0 or add()'s error. */
int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	/* Aligned so it can be passed to fls */
	u8 key[4] __aligned(__alignof(u32));

	++table->seq;
	swap_endian(key, (const u8 *)ip, 32);
	return add(&table->root4, 32, key, cidr, peer, lock);
}

/* Map the IPv6 prefix @ip/@cidr to @peer; returns 0 or add()'s error. */
int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	/* Aligned so it can be passed to fls64 */
	u8 key[16] __aligned(__alignof(u64));

	++table->seq;
	swap_endian(key, (const u8 *)ip, 128);
	return add(&table->root6, 128, key, cidr, peer, lock);
}

/* Remove the IPv4 prefix @ip/@cidr if it maps to @peer; see remove(). */
int wg_allowedips_remove_v4(struct allowedips *table, const struct in_addr *ip,
			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	/* Aligned so it can be passed to fls */
	u8 key[4] __aligned(__alignof(u32));

	++table->seq;
	swap_endian(key, (const u8 *)ip, 32);
	return remove(&table->root4, 32, key, cidr, peer, lock);
}

/* Remove the IPv6 prefix @ip/@cidr if it maps to @peer; see remove(). */
int wg_allowedips_remove_v6(struct allowedips *table, const struct in6_addr *ip,
			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	/* Aligned so it can be passed to fls64 */
	u8 key[16] __aligned(__alignof(u64));

	++table->seq;
	swap_endian(key, (const u8 *)ip, 128);
	return remove(&table->root6, 128, key, cidr, peer, lock);
}

/*
 * Remove every prefix that maps to @peer by walking the peer's own
 * allowedips_list. The _safe iterator is required because remove_node()
 * unlinks the current entry from that list.
 */
void wg_allowedips_remove_by_peer(struct allowedips *table,
				  struct wg_peer *peer, struct mutex *lock)
{
	struct allowedips_node *node, *tmp;

	if (list_empty(&peer->allowedips_list))
		return;
	++table->seq;
	list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list)
		remove_node(node, lock);
}

/*
 * Export @node's prefix. The address is written to @ip in network byte order
 * (4 bytes for 32-bit nodes, 16 for 128-bit nodes) with every bit after the
 * prefix cleared. The prefix length is stored in *@cidr. Returns AF_INET for
 * 32-bit nodes and AF_INET6 otherwise.
 */
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
{
	const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
	swap_endian(ip, node->bits, node->bitlen);
	/* Zero the bytes that lie wholly after the prefix. */
	memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes);
	/*
	 * (-cidr % 8U) is the number of non-prefix bits in the last prefix
	 * byte; clear them.
	 */
	if (node->cidr)
		ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);

	*cidr = node->cidr;
	return node->bitlen == 32 ? AF_INET : AF_INET6;
}

/*
 * Returns a strong reference to a peer, chosen by the packet's destination
 * address. Returns NULL for packets that are neither IPv4 nor IPv6, or when
 * no prefix matches.
 */
struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table,
					 struct sk_buff *skb)
{
	if (skb->protocol == htons(ETH_P_IP))
		return lookup(table->root4, 32, &ip_hdr(skb)->daddr);
	else if (skb->protocol == htons(ETH_P_IPV6))
		return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr);
	return NULL;
}

/*
 * Returns a strong reference to a peer, chosen by the packet's source
 * address. Returns NULL for packets that are neither IPv4 nor IPv6, or when
 * no prefix matches.
 */
struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
					 struct sk_buff *skb)
{
	if (skb->protocol == htons(ETH_P_IP))
		return lookup(table->root4, 32, &ip_hdr(skb)->saddr);
	else if (skb->protocol == htons(ETH_P_IPV6))
		return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr);
	return NULL;
}

/* Create node_cache. Returns 0 on success or -ENOMEM. */
int __init wg_allowedips_slab_init(void)
{
	node_cache = KMEM_CACHE(allowedips_node, 0);
	return node_cache ? 0 : -ENOMEM;
}

/*
 * Destroy node_cache. rcu_barrier() first waits for already-queued
 * call_rcu() callbacks (root_free_rcu()) so that they finish before the cache
 * goes away.
 * NOTE(review): remove_node() frees nodes with kfree_rcu() rather than
 * call_rcu(). Confirm that, on the target kernel, kmem_cache_destroy() (or
 * this function) also waits for those pending frees.
 */
void wg_allowedips_slab_uninit(void)
{
	rcu_barrier();
	kmem_cache_destroy(node_cache);
}

#include "selftest/allowedips.c"