xref: /linux/drivers/net/wireguard/ratelimiter.c (revision 03ab8e6297acd1bc0eedaa050e2a1635c576fd11)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5 
6 #include "ratelimiter.h"
7 #include <linux/siphash.h>
8 #include <linux/mm.h>
9 #include <linux/slab.h>
10 #include <net/ip.h>
11 
12 static struct kmem_cache *entry_cache;
13 static hsiphash_key_t key;
14 static spinlock_t table_lock = __SPIN_LOCK_UNLOCKED("ratelimiter_table_lock");
15 static DEFINE_MUTEX(init_lock);
16 static u64 init_refcnt; /* Protected by init_lock, hence not atomic. */
17 static atomic_t total_entries = ATOMIC_INIT(0);
18 static unsigned int max_entries, table_size;
19 static void wg_ratelimiter_gc_entries(struct work_struct *);
20 static DECLARE_DEFERRABLE_WORK(gc_work, wg_ratelimiter_gc_entries);
21 static struct hlist_head *table_v4;
22 #if IS_ENABLED(CONFIG_IPV6)
23 static struct hlist_head *table_v6;
24 #endif
25 
26 struct ratelimiter_entry {
27 	u64 last_time_ns, tokens, ip;
28 	void *net;
29 	spinlock_t lock;
30 	struct hlist_node hash;
31 	struct rcu_head rcu;
32 };
33 
34 enum {
35 	PACKETS_PER_SECOND = 20,
36 	PACKETS_BURSTABLE = 5,
37 	PACKET_COST = NSEC_PER_SEC / PACKETS_PER_SECOND,
38 	TOKEN_MAX = PACKET_COST * PACKETS_BURSTABLE
39 };
40 
entry_free(struct rcu_head * rcu)41 static void entry_free(struct rcu_head *rcu)
42 {
43 	kmem_cache_free(entry_cache,
44 			container_of(rcu, struct ratelimiter_entry, rcu));
45 	atomic_dec(&total_entries);
46 }
47 
entry_uninit(struct ratelimiter_entry * entry)48 static void entry_uninit(struct ratelimiter_entry *entry)
49 {
50 	hlist_del_rcu(&entry->hash);
51 	call_rcu(&entry->rcu, entry_free);
52 }
53 
54 /* Calling this function with a NULL work uninits all entries. */
wg_ratelimiter_gc_entries(struct work_struct * work)55 static void wg_ratelimiter_gc_entries(struct work_struct *work)
56 {
57 	const u64 now = ktime_get_coarse_boottime_ns();
58 	struct ratelimiter_entry *entry;
59 	struct hlist_node *temp;
60 	unsigned int i;
61 
62 	for (i = 0; i < table_size; ++i) {
63 		spin_lock(&table_lock);
64 		hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
65 			if (unlikely(!work) ||
66 			    now - entry->last_time_ns > NSEC_PER_SEC)
67 				entry_uninit(entry);
68 		}
69 #if IS_ENABLED(CONFIG_IPV6)
70 		hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
71 			if (unlikely(!work) ||
72 			    now - entry->last_time_ns > NSEC_PER_SEC)
73 				entry_uninit(entry);
74 		}
75 #endif
76 		spin_unlock(&table_lock);
77 		if (likely(work))
78 			cond_resched();
79 	}
80 	if (likely(work))
81 		queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
82 }
83 
wg_ratelimiter_allow(struct sk_buff * skb,struct net * net)84 bool wg_ratelimiter_allow(struct sk_buff *skb, struct net *net)
85 {
86 	/* We only take the bottom half of the net pointer, so that we can hash
87 	 * 3 words in the end. This way, siphash's len param fits into the final
88 	 * u32, and we don't incur an extra round.
89 	 */
90 	const u32 net_word = (unsigned long)net;
91 	struct ratelimiter_entry *entry;
92 	struct hlist_head *bucket;
93 	u64 ip;
94 
95 	if (skb->protocol == htons(ETH_P_IP)) {
96 		ip = (u64 __force)ip_hdr(skb)->saddr;
97 		bucket = &table_v4[hsiphash_2u32(net_word, ip, &key) &
98 				   (table_size - 1)];
99 	}
100 #if IS_ENABLED(CONFIG_IPV6)
101 	else if (skb->protocol == htons(ETH_P_IPV6)) {
102 		/* Only use 64 bits, so as to ratelimit the whole /64. */
103 		memcpy(&ip, &ipv6_hdr(skb)->saddr, sizeof(ip));
104 		bucket = &table_v6[hsiphash_3u32(net_word, ip >> 32, ip, &key) &
105 				   (table_size - 1)];
106 	}
107 #endif
108 	else
109 		return false;
110 	rcu_read_lock();
111 	hlist_for_each_entry_rcu(entry, bucket, hash) {
112 		if (entry->net == net && entry->ip == ip) {
113 			u64 now, tokens;
114 			bool ret;
115 			/* Quasi-inspired by nft_limit.c, but this is actually a
116 			 * slightly different algorithm. Namely, we incorporate
117 			 * the burst as part of the maximum tokens, rather than
118 			 * as part of the rate.
119 			 */
120 			spin_lock(&entry->lock);
121 			now = ktime_get_coarse_boottime_ns();
122 			tokens = min_t(u64, TOKEN_MAX,
123 				       entry->tokens + now -
124 					       entry->last_time_ns);
125 			entry->last_time_ns = now;
126 			ret = tokens >= PACKET_COST;
127 			entry->tokens = ret ? tokens - PACKET_COST : tokens;
128 			spin_unlock(&entry->lock);
129 			rcu_read_unlock();
130 			return ret;
131 		}
132 	}
133 	rcu_read_unlock();
134 
135 	if (atomic_inc_return(&total_entries) > max_entries)
136 		goto err_oom;
137 
138 	entry = kmem_cache_alloc(entry_cache, GFP_KERNEL);
139 	if (unlikely(!entry))
140 		goto err_oom;
141 
142 	entry->net = net;
143 	entry->ip = ip;
144 	INIT_HLIST_NODE(&entry->hash);
145 	spin_lock_init(&entry->lock);
146 	entry->last_time_ns = ktime_get_coarse_boottime_ns();
147 	entry->tokens = TOKEN_MAX - PACKET_COST;
148 	spin_lock(&table_lock);
149 	hlist_add_head_rcu(&entry->hash, bucket);
150 	spin_unlock(&table_lock);
151 	return true;
152 
153 err_oom:
154 	atomic_dec(&total_entries);
155 	return false;
156 }
157 
wg_ratelimiter_init(void)158 int wg_ratelimiter_init(void)
159 {
160 	mutex_lock(&init_lock);
161 	if (++init_refcnt != 1)
162 		goto out;
163 
164 	entry_cache = KMEM_CACHE(ratelimiter_entry, 0);
165 	if (!entry_cache)
166 		goto err;
167 
168 	/* xt_hashlimit.c uses a slightly different algorithm for ratelimiting,
169 	 * but what it shares in common is that it uses a massive hashtable. So,
170 	 * we borrow their wisdom about good table sizes on different systems
171 	 * dependent on RAM. This calculation here comes from there.
172 	 */
173 	table_size = (totalram_pages() > (1U << 30) / PAGE_SIZE) ? 8192 :
174 		max_t(unsigned long, 16, roundup_pow_of_two(
175 			(totalram_pages() << PAGE_SHIFT) /
176 			(1U << 14) / sizeof(struct hlist_head)));
177 	max_entries = table_size * 8;
178 
179 	table_v4 = kvcalloc(table_size, sizeof(*table_v4), GFP_KERNEL);
180 	if (unlikely(!table_v4))
181 		goto err_kmemcache;
182 
183 #if IS_ENABLED(CONFIG_IPV6)
184 	table_v6 = kvcalloc(table_size, sizeof(*table_v6), GFP_KERNEL);
185 	if (unlikely(!table_v6)) {
186 		kvfree(table_v4);
187 		goto err_kmemcache;
188 	}
189 #endif
190 
191 	queue_delayed_work(system_power_efficient_wq, &gc_work, HZ);
192 	get_random_bytes(&key, sizeof(key));
193 out:
194 	mutex_unlock(&init_lock);
195 	return 0;
196 
197 err_kmemcache:
198 	kmem_cache_destroy(entry_cache);
199 err:
200 	--init_refcnt;
201 	mutex_unlock(&init_lock);
202 	return -ENOMEM;
203 }
204 
wg_ratelimiter_uninit(void)205 void wg_ratelimiter_uninit(void)
206 {
207 	mutex_lock(&init_lock);
208 	if (!init_refcnt || --init_refcnt)
209 		goto out;
210 
211 	cancel_delayed_work_sync(&gc_work);
212 	wg_ratelimiter_gc_entries(NULL);
213 	rcu_barrier();
214 	kvfree(table_v4);
215 #if IS_ENABLED(CONFIG_IPV6)
216 	kvfree(table_v6);
217 #endif
218 	kmem_cache_destroy(entry_cache);
219 out:
220 	mutex_unlock(&init_lock);
221 }
222 
223 #include "selftest/ratelimiter.c"
224