// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */

#include "allowedips.h"
#include "peer.h"

static void swap_endian(u8 *dst, const u8 *src, u8 bits)
{
	if (bits == 32) {
		*(u32 *)dst = be32_to_cpu(*(const __be32 *)src);
	} else if (bits == 128) {
		((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]);
		((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]);
	}
}

static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
				 u8 cidr, u8 bits)
{
	node->cidr = cidr;
	node->bit_at_a = cidr / 8U;
#ifdef __LITTLE_ENDIAN
	node->bit_at_a ^= (bits / 8U - 1U) % 8U;
#endif
	node->bit_at_b = 7U - (cidr % 8U);
	node->bitlen = bits;
	memcpy(node->bits, src, bits / 8U);
}
#define CHOOSE_NODE(parent, key) \
	parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
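
/* Worked example of the indexing above (illustrative only): for an IPv4
 * node with cidr == 25, the bit that selects the child is bit 25 of the
 * address, counting from the most significant bit. copy_and_assign_cidr()
 * computes bit_at_a = 25 / 8 = 3 and bit_at_b = 7 - (25 % 8) = 6. On
 * little-endian machines, bit_at_a is additionally XORed with
 * (bits / 8 - 1) % 8 (3 for IPv4, 7 for IPv6), remapping the byte index
 * into the native-order words produced by swap_endian(). CHOOSE_NODE()
 * then evaluates (key[bit_at_a] >> 6) & 1 to pick bit[0] or bit[1].
 */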

static void push_rcu(struct allowedips_node **stack,
		     struct allowedips_node __rcu *p, unsigned int *len)
{
	if (rcu_access_pointer(p)) {
		WARN_ON(IS_ENABLED(DEBUG) && *len >= 128);
		stack[(*len)++] = rcu_dereference_raw(p);
	}
}

static void root_free_rcu(struct rcu_head *rcu)
{
	struct allowedips_node *node, *stack[128] = {
		container_of(rcu, struct allowedips_node, rcu) };
	unsigned int len = 1;

	while (len > 0 && (node = stack[--len])) {
		push_rcu(stack, node->bit[0], &len);
		push_rcu(stack, node->bit[1], &len);
		kfree(node);
	}
}

static void root_remove_peer_lists(struct allowedips_node *root)
{
	struct allowedips_node *node, *stack[128] = { root };
	unsigned int len = 1;

	while (len > 0 && (node = stack[--len])) {
		push_rcu(stack, node->bit[0], &len);
		push_rcu(stack, node->bit[1], &len);
		if (rcu_access_pointer(node->peer))
			list_del(&node->peer_list);
	}
}

static void walk_remove_by_peer(struct allowedips_node __rcu **top,
				struct wg_peer *peer, struct mutex *lock)
{
#define REF(p) rcu_access_pointer(p)
#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
#define PUSH(p) ({						\
		WARN_ON(IS_ENABLED(DEBUG) && len >= 128);	\
		stack[len++] = p;				\
	})

	struct allowedips_node __rcu **stack[128], **nptr;
	struct allowedips_node *node, *prev;
	unsigned int len;

	if (unlikely(!peer || !REF(*top)))
		return;

	for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
		nptr = stack[len - 1];
		node = DEREF(nptr);
		if (!node) {
			--len;
			continue;
		}
		if (!prev || REF(prev->bit[0]) == node ||
		    REF(prev->bit[1]) == node) {
			if (REF(node->bit[0]))
				PUSH(&node->bit[0]);
			else if (REF(node->bit[1]))
				PUSH(&node->bit[1]);
		} else if (REF(node->bit[0]) == prev) {
			if (REF(node->bit[1]))
				PUSH(&node->bit[1]);
		} else {
			if (rcu_dereference_protected(node->peer,
				lockdep_is_held(lock)) == peer) {
				RCU_INIT_POINTER(node->peer, NULL);
				list_del_init(&node->peer_list);
				if (!node->bit[0] || !node->bit[1]) {
					rcu_assign_pointer(*nptr, DEREF(
					       &node->bit[!REF(node->bit[0])]));
					kfree_rcu(node, rcu);
					node = DEREF(nptr);
				}
			}
			--len;
		}
	}

#undef REF
#undef DEREF
#undef PUSH
}

static unsigned int fls128(u64 a, u64 b)
{
	return a ? fls64(a) + 64U : fls64(b);
}

static u8 common_bits(const struct allowedips_node *node, const u8 *key,
		      u8 bits)
{
	if (bits == 32)
		return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key);
	else if (bits == 128)
		return 128U - fls128(
			*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0],
			*(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
	return 0;
}

static bool prefix_matches(const struct allowedips_node *node, const u8 *key,
			   u8 bits)
{
	/* This could be much faster if it actually just compared the common
	 * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and
	 * the rest, but it turns out that common_bits is already super fast on
	 * modern processors, even taking into account the unfortunate bswap.
	 * So, we just inline it like this instead.
	 */
	return common_bits(node, key, bits) >= node->cidr;
}
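
/* Worked example (illustrative only): with node->bits holding 10.0.0.0 and
 * key holding 10.0.129.5 (both already in native order from swap_endian()),
 * the 32-bit XOR is 0x00008105, fls() of that is 16, and common_bits()
 * returns 32 - 16 = 16: the two values share a 16-bit prefix. Hence
 * prefix_matches() is true for this key on any such node with cidr <= 16.
 */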

static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
					 const u8 *key)
{
	struct allowedips_node *node = trie, *found = NULL;

	while (node && prefix_matches(node, key, bits)) {
		if (rcu_access_pointer(node->peer))
			found = node;
		if (node->cidr == bits)
			break;
		node = rcu_dereference_bh(CHOOSE_NODE(node, key));
	}
	return found;
}

/* Returns a strong reference to a peer */
static struct wg_peer *lookup(struct allowedips_node __rcu *root, u8 bits,
			      const void *be_ip)
{
	/* Aligned so it can be passed to fls/fls64 */
	u8 ip[16] __aligned(__alignof(u64));
	struct allowedips_node *node;
	struct wg_peer *peer = NULL;

	swap_endian(ip, be_ip, bits);

	rcu_read_lock_bh();
retry:
	node = find_node(rcu_dereference_bh(root), bits, ip);
	if (node) {
		peer = wg_peer_get_maybe_zero(rcu_dereference_bh(node->peer));
		if (!peer)
			goto retry;
	}
	rcu_read_unlock_bh();
	return peer;
}

static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
			   u8 cidr, u8 bits, struct allowedips_node **rnode,
			   struct mutex *lock)
{
	struct allowedips_node *node = rcu_dereference_protected(trie,
						lockdep_is_held(lock));
	struct allowedips_node *parent = NULL;
	bool exact = false;

	while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
		parent = node;
		if (parent->cidr == cidr) {
			exact = true;
			break;
		}
		node = rcu_dereference_protected(CHOOSE_NODE(parent, key),
						 lockdep_is_held(lock));
	}
	*rnode = parent;
	return exact;
}

static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
	       u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	struct allowedips_node *node, *parent, *down, *newnode;

	if (unlikely(cidr > bits || !peer))
		return -EINVAL;

	if (!rcu_access_pointer(*trie)) {
		node = kzalloc(sizeof(*node), GFP_KERNEL);
		if (unlikely(!node))
			return -ENOMEM;
		RCU_INIT_POINTER(node->peer, peer);
		list_add_tail(&node->peer_list, &peer->allowedips_list);
		copy_and_assign_cidr(node, key, cidr, bits);
		rcu_assign_pointer(*trie, node);
		return 0;
	}
	if (node_placement(*trie, key, cidr, bits, &node, lock)) {
		rcu_assign_pointer(node->peer, peer);
		list_move_tail(&node->peer_list, &peer->allowedips_list);
		return 0;
	}

	newnode = kzalloc(sizeof(*newnode), GFP_KERNEL);
	if (unlikely(!newnode))
		return -ENOMEM;
	RCU_INIT_POINTER(newnode->peer, peer);
	list_add_tail(&newnode->peer_list, &peer->allowedips_list);
	copy_and_assign_cidr(newnode, key, cidr, bits);

	if (!node) {
		down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
	} else {
		down = rcu_dereference_protected(CHOOSE_NODE(node, key),
						 lockdep_is_held(lock));
		if (!down) {
			rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
			return 0;
		}
	}
	cidr = min(cidr, common_bits(down, key, bits));
	parent = node;

	if (newnode->cidr == cidr) {
		rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
		if (!parent)
			rcu_assign_pointer(*trie, newnode);
		else
			rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
					   newnode);
	} else {
		node = kzalloc(sizeof(*node), GFP_KERNEL);
		if (unlikely(!node)) {
			list_del(&newnode->peer_list);
			kfree(newnode);
			return -ENOMEM;
		}
		INIT_LIST_HEAD(&node->peer_list);
		copy_and_assign_cidr(node, newnode->bits, cidr, bits);

		rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
		rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
		if (!parent)
			rcu_assign_pointer(*trie, node);
		else
			rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
					   node);
	}
	return 0;
}
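
/* Summary of the insertion cases above (illustrative, not normative): an
 * empty trie takes the new node as its root; an exact prefix match merely
 * reassigns the node's peer; otherwise the new node is linked in below the
 * closest matching ancestor, allocating a peerless intermediate node when
 * neither prefix contains the other. E.g. inserting 10.0.0.0/24 and then
 * 10.0.1.0/24 creates an intermediate 10.0.0.0/23 node whose two children
 * are the inserted /24s.
 */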

void wg_allowedips_init(struct allowedips *table)
{
	table->root4 = table->root6 = NULL;
	table->seq = 1;
}

void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
{
	struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6;

	++table->seq;
	RCU_INIT_POINTER(table->root4, NULL);
	RCU_INIT_POINTER(table->root6, NULL);
	if (rcu_access_pointer(old4)) {
		struct allowedips_node *node = rcu_dereference_protected(old4,
							lockdep_is_held(lock));

		root_remove_peer_lists(node);
		call_rcu(&node->rcu, root_free_rcu);
	}
	if (rcu_access_pointer(old6)) {
		struct allowedips_node *node = rcu_dereference_protected(old6,
							lockdep_is_held(lock));

		root_remove_peer_lists(node);
		call_rcu(&node->rcu, root_free_rcu);
	}
}

int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	/* Aligned so it can be passed to fls */
	u8 key[4] __aligned(__alignof(u32));

	++table->seq;
	swap_endian(key, (const u8 *)ip, 32);
	return add(&table->root4, 32, key, cidr, peer, lock);
}

int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
			    u8 cidr, struct wg_peer *peer, struct mutex *lock)
{
	/* Aligned so it can be passed to fls64 */
	u8 key[16] __aligned(__alignof(u64));

	++table->seq;
	swap_endian(key, (const u8 *)ip, 128);
	return add(&table->root6, 128, key, cidr, peer, lock);
}

void wg_allowedips_remove_by_peer(struct allowedips *table,
				  struct wg_peer *peer, struct mutex *lock)
{
	++table->seq;
	walk_remove_by_peer(&table->root4, peer, lock);
	walk_remove_by_peer(&table->root6, peer, lock);
}

int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
{
	const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
	swap_endian(ip, node->bits, node->bitlen);
	memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes);
	if (node->cidr)
		ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);

	*cidr = node->cidr;
	return node->bitlen == 32 ? AF_INET : AF_INET6;
}

/* Returns a strong reference to a peer */
struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table,
					 struct sk_buff *skb)
{
	if (skb->protocol == htons(ETH_P_IP))
		return lookup(table->root4, 32, &ip_hdr(skb)->daddr);
	else if (skb->protocol == htons(ETH_P_IPV6))
		return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr);
	return NULL;
}

/* Returns a strong reference to a peer */
struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
					 struct sk_buff *skb)
{
	if (skb->protocol == htons(ETH_P_IP))
		return lookup(table->root4, 32, &ip_hdr(skb)->saddr);
	else if (skb->protocol == htons(ETH_P_IPV6))
		return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr);
	return NULL;
}

#include "selftest/allowedips.c"
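
/* Illustrative caller sketch (not part of this file; `table`, `lock`, `peer`,
 * `ip4`, and `skb` are assumed to exist, and `lock` must be the mutex that
 * guards all mutations of `table`):
 *
 *	wg_allowedips_init(&table);
 *	mutex_lock(&lock);
 *	err = wg_allowedips_insert_v4(&table, &ip4, 24, peer, &lock);
 *	mutex_unlock(&lock);
 *
 *	routed_peer = wg_allowedips_lookup_dst(&table, skb);
 *	if (routed_peer)
 *		wg_peer_put(routed_peer); // lookups return a strong reference
 */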