xref: /linux/net/ipv6/inet6_hashtables.c (revision 2ed4b46b4fc77749cb0f8dd31a01441b82c8dbaa)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * INET		An implementation of the TCP/IP protocol suite for the LINUX
4  *		operating system.  INET is implemented using the BSD Socket
5  *		interface as the means of communication with the user level.
6  *
7  *		Generic INET6 transport hashtables
8  *
9  * Authors:	Lotsa people, from code originally in tcp, generalised here
10  *		by Arnaldo Carvalho de Melo <acme@mandriva.com>
11  */
12 
13 #include <linux/module.h>
14 #include <linux/random.h>
15 
16 #include <net/addrconf.h>
17 #include <net/hotdata.h>
18 #include <net/inet_connection_sock.h>
19 #include <net/inet_hashtables.h>
20 #include <net/inet6_hashtables.h>
21 #include <net/secure_seq.h>
22 #include <net/ip.h>
23 #include <net/sock_reuseport.h>
24 #include <net/tcp.h>
25 
26 void inet6_init_ehash_secret(void)
27 {
28 	net_get_random_sleepable_once(&inet6_ehash_secret,
29 				      sizeof(inet6_ehash_secret));
30 	net_get_random_sleepable_once(&tcp_ipv6_hash_secret,
31 				      sizeof(tcp_ipv6_hash_secret));
32 }
33 
34 u32 inet6_ehashfn(const struct net *net,
35 		  const struct in6_addr *laddr, const u16 lport,
36 		  const struct in6_addr *faddr, const __be16 fport)
37 {
38 	u32 lhash, fhash;
39 
40 	lhash = (__force u32)laddr->s6_addr32[3];
41 	fhash = __ipv6_addr_jhash(faddr, tcp_ipv6_hash_secret);
42 
43 	return lport + __inet6_ehashfn(lhash, 0, fhash, fport,
44 				       inet6_ehash_secret + net_hash_mix(net));
45 }
46 EXPORT_SYMBOL_GPL(inet6_ehashfn);
47 
48 /*
49  * Sockets in TCP_CLOSE state are _always_ taken out of the hash, so
50  * we need not check it for TCP lookups anymore, thanks Alexey. -DaveM
51  *
52  * The sockhash lock must be held as a reader here.
53  */
54 struct sock *__inet6_lookup_established(const struct net *net,
55 					const struct in6_addr *saddr,
56 					const __be16 sport,
57 					const struct in6_addr *daddr,
58 					const u16 hnum,
59 					const int dif, const int sdif)
60 {
61 	const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
62 	const struct hlist_nulls_node *node;
63 	struct inet_ehash_bucket *head;
64 	struct inet_hashinfo *hashinfo;
65 	unsigned int hash, slot;
66 	struct sock *sk;
67 
68 	hashinfo = net->ipv4.tcp_death_row.hashinfo;
69 	hash = inet6_ehashfn(net, daddr, hnum, saddr, sport);
70 	slot = hash & hashinfo->ehash_mask;
71 	head = &hashinfo->ehash[slot];
72 begin:
73 	sk_nulls_for_each_rcu(sk, node, &head->chain) {
74 		if (sk->sk_hash != hash)
75 			continue;
76 		if (!inet6_match(net, sk, saddr, daddr, ports, dif, sdif))
77 			continue;
78 		if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
79 			goto out;
80 
81 		if (unlikely(!inet6_match(net, sk, saddr, daddr, ports, dif, sdif))) {
82 			sock_gen_put(sk);
83 			goto begin;
84 		}
85 		goto found;
86 	}
87 	if (get_nulls_value(node) != slot)
88 		goto begin;
89 out:
90 	sk = NULL;
91 found:
92 	return sk;
93 }
94 EXPORT_SYMBOL(__inet6_lookup_established);
95 
96 static inline int compute_score(struct sock *sk, const struct net *net,
97 				const unsigned short hnum,
98 				const struct in6_addr *daddr,
99 				const int dif, const int sdif)
100 {
101 	int score = -1;
102 
103 	if (net_eq(sock_net(sk), net) &&
104 	    READ_ONCE(inet_sk(sk)->inet_num) == hnum &&
105 	    sk->sk_family == PF_INET6) {
106 		if (!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, daddr))
107 			return -1;
108 
109 		if (!inet_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
110 			return -1;
111 
112 		score =  sk->sk_bound_dev_if ? 2 : 1;
113 		if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
114 			score++;
115 	}
116 	return score;
117 }
118 
119 /**
120  * inet6_lookup_reuseport() - execute reuseport logic on AF_INET6 socket if necessary.
121  * @net: network namespace.
122  * @sk: AF_INET6 socket, must be in TCP_LISTEN state for TCP or TCP_CLOSE for UDP.
123  * @skb: context for a potential SK_REUSEPORT program.
124  * @doff: header offset.
125  * @saddr: source address.
126  * @sport: source port.
127  * @daddr: destination address.
128  * @hnum: destination port in host byte order.
129  * @ehashfn: hash function used to generate the fallback hash.
130  *
131  * Return: NULL if sk doesn't have SO_REUSEPORT set, otherwise a pointer to
132  *         the selected sock or an error.
133  */
134 struct sock *inet6_lookup_reuseport(const struct net *net, struct sock *sk,
135 				    struct sk_buff *skb, int doff,
136 				    const struct in6_addr *saddr,
137 				    __be16 sport,
138 				    const struct in6_addr *daddr,
139 				    unsigned short hnum,
140 				    inet6_ehashfn_t *ehashfn)
141 {
142 	struct sock *reuse_sk = NULL;
143 	u32 phash;
144 
145 	if (sk->sk_reuseport) {
146 		phash = INDIRECT_CALL_INET(ehashfn, udp6_ehashfn, inet6_ehashfn,
147 					   net, daddr, hnum, saddr, sport);
148 		reuse_sk = reuseport_select_sock(sk, phash, skb, doff);
149 	}
150 	return reuse_sk;
151 }
152 EXPORT_SYMBOL_GPL(inet6_lookup_reuseport);
153 
154 /* called with rcu_read_lock() */
155 static struct sock *inet6_lhash2_lookup(const struct net *net,
156 		struct inet_listen_hashbucket *ilb2,
157 		struct sk_buff *skb, int doff,
158 		const struct in6_addr *saddr,
159 		const __be16 sport, const struct in6_addr *daddr,
160 		const unsigned short hnum, const int dif, const int sdif)
161 {
162 	struct sock *sk, *result = NULL;
163 	struct hlist_nulls_node *node;
164 	int score, hiscore = 0;
165 
166 	sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) {
167 		score = compute_score(sk, net, hnum, daddr, dif, sdif);
168 		if (score > hiscore) {
169 			result = inet6_lookup_reuseport(net, sk, skb, doff,
170 							saddr, sport, daddr, hnum, inet6_ehashfn);
171 			if (result)
172 				return result;
173 
174 			result = sk;
175 			hiscore = score;
176 		}
177 	}
178 
179 	return result;
180 }
181 
182 struct sock *inet6_lookup_run_sk_lookup(const struct net *net,
183 					int protocol,
184 					struct sk_buff *skb, int doff,
185 					const struct in6_addr *saddr,
186 					const __be16 sport,
187 					const struct in6_addr *daddr,
188 					const u16 hnum, const int dif,
189 					inet6_ehashfn_t *ehashfn)
190 {
191 	struct sock *sk, *reuse_sk;
192 	bool no_reuseport;
193 
194 	no_reuseport = bpf_sk_lookup_run_v6(net, protocol, saddr, sport,
195 					    daddr, hnum, dif, &sk);
196 	if (no_reuseport || IS_ERR_OR_NULL(sk))
197 		return sk;
198 
199 	reuse_sk = inet6_lookup_reuseport(net, sk, skb, doff,
200 					  saddr, sport, daddr, hnum, ehashfn);
201 	if (reuse_sk)
202 		sk = reuse_sk;
203 	return sk;
204 }
205 EXPORT_SYMBOL_GPL(inet6_lookup_run_sk_lookup);
206 
207 struct sock *inet6_lookup_listener(const struct net *net,
208 				   struct sk_buff *skb, int doff,
209 				   const struct in6_addr *saddr,
210 				   const __be16 sport,
211 				   const struct in6_addr *daddr,
212 				   const unsigned short hnum,
213 				   const int dif, const int sdif)
214 {
215 	struct inet_listen_hashbucket *ilb2;
216 	struct inet_hashinfo *hashinfo;
217 	struct sock *result = NULL;
218 	unsigned int hash2;
219 
220 	/* Lookup redirect from BPF */
221 	if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
222 		result = inet6_lookup_run_sk_lookup(net, IPPROTO_TCP, skb, doff,
223 						    saddr, sport, daddr, hnum, dif,
224 						    inet6_ehashfn);
225 		if (result)
226 			goto done;
227 	}
228 
229 	hashinfo = net->ipv4.tcp_death_row.hashinfo;
230 	hash2 = ipv6_portaddr_hash(net, daddr, hnum);
231 	ilb2 = inet_lhash2_bucket(hashinfo, hash2);
232 
233 	result = inet6_lhash2_lookup(net, ilb2, skb, doff,
234 				     saddr, sport, daddr, hnum,
235 				     dif, sdif);
236 	if (result)
237 		goto done;
238 
239 	/* Lookup lhash2 with in6addr_any */
240 	hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);
241 	ilb2 = inet_lhash2_bucket(hashinfo, hash2);
242 
243 	result = inet6_lhash2_lookup(net, ilb2, skb, doff,
244 				     saddr, sport, &in6addr_any, hnum,
245 				     dif, sdif);
246 done:
247 	if (IS_ERR(result))
248 		return NULL;
249 	return result;
250 }
251 EXPORT_SYMBOL_GPL(inet6_lookup_listener);
252 
253 struct sock *inet6_lookup(const struct net *net,
254 			  struct sk_buff *skb, int doff,
255 			  const struct in6_addr *saddr, const __be16 sport,
256 			  const struct in6_addr *daddr, const __be16 dport,
257 			  const int dif)
258 {
259 	struct sock *sk;
260 	bool refcounted;
261 
262 	sk = __inet6_lookup(net, skb, doff, saddr, sport, daddr,
263 			    ntohs(dport), dif, 0, &refcounted);
264 	if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
265 		sk = NULL;
266 	return sk;
267 }
268 EXPORT_SYMBOL_GPL(inet6_lookup);
269 
270 static int __inet6_check_established(struct inet_timewait_death_row *death_row,
271 				     struct sock *sk, const __u16 lport,
272 				     struct inet_timewait_sock **twp,
273 				     bool rcu_lookup,
274 				     u32 hash)
275 {
276 	struct inet_hashinfo *hinfo = death_row->hashinfo;
277 	struct inet_sock *inet = inet_sk(sk);
278 	const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr;
279 	const struct in6_addr *saddr = &sk->sk_v6_daddr;
280 	const int dif = sk->sk_bound_dev_if;
281 	struct net *net = sock_net(sk);
282 	const int sdif = l3mdev_master_ifindex_by_index(net, dif);
283 	const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport);
284 	struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash);
285 	struct inet_timewait_sock *tw = NULL;
286 	const struct hlist_nulls_node *node;
287 	struct sock *sk2;
288 	spinlock_t *lock;
289 
290 	if (rcu_lookup) {
291 		sk_nulls_for_each(sk2, node, &head->chain) {
292 			if (sk2->sk_hash != hash ||
293 			    !inet6_match(net, sk2, saddr, daddr,
294 					 ports, dif, sdif))
295 				continue;
296 			if (sk2->sk_state == TCP_TIME_WAIT)
297 				break;
298 			return -EADDRNOTAVAIL;
299 		}
300 		return 0;
301 	}
302 
303 	lock = inet_ehash_lockp(hinfo, hash);
304 	spin_lock(lock);
305 
306 	sk_nulls_for_each(sk2, node, &head->chain) {
307 		if (sk2->sk_hash != hash)
308 			continue;
309 
310 		if (likely(inet6_match(net, sk2, saddr, daddr, ports,
311 				       dif, sdif))) {
312 			if (sk2->sk_state == TCP_TIME_WAIT) {
313 				tw = inet_twsk(sk2);
314 				if (tcp_twsk_unique(sk, sk2, twp))
315 					break;
316 			}
317 			goto not_unique;
318 		}
319 	}
320 
321 	/* Must record num and sport now. Otherwise we will see
322 	 * in hash table socket with a funny identity.
323 	 */
324 	inet->inet_num = lport;
325 	inet->inet_sport = htons(lport);
326 	sk->sk_hash = hash;
327 	WARN_ON(!sk_unhashed(sk));
328 	__sk_nulls_add_node_rcu(sk, &head->chain);
329 	if (tw) {
330 		sk_nulls_del_node_init_rcu((struct sock *)tw);
331 		__NET_INC_STATS(net, LINUX_MIB_TIMEWAITRECYCLED);
332 	}
333 	spin_unlock(lock);
334 	sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
335 
336 	if (twp) {
337 		*twp = tw;
338 	} else if (tw) {
339 		/* Silly. Should hash-dance instead... */
340 		inet_twsk_deschedule_put(tw);
341 	}
342 	return 0;
343 
344 not_unique:
345 	spin_unlock(lock);
346 	return -EADDRNOTAVAIL;
347 }
348 
349 static u64 inet6_sk_port_offset(const struct sock *sk)
350 {
351 	const struct inet_sock *inet = inet_sk(sk);
352 
353 	return secure_ipv6_port_ephemeral(sk->sk_v6_rcv_saddr.s6_addr32,
354 					  sk->sk_v6_daddr.s6_addr32,
355 					  inet->inet_dport);
356 }
357 
358 int inet6_hash_connect(struct inet_timewait_death_row *death_row,
359 		       struct sock *sk)
360 {
361 	const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr;
362 	const struct in6_addr *saddr = &sk->sk_v6_daddr;
363 	const struct inet_sock *inet = inet_sk(sk);
364 	const struct net *net = sock_net(sk);
365 	u64 port_offset = 0;
366 	u32 hash_port0;
367 
368 	if (!inet_sk(sk)->inet_num)
369 		port_offset = inet6_sk_port_offset(sk);
370 
371 	inet6_init_ehash_secret();
372 
373 	hash_port0 = inet6_ehashfn(net, daddr, 0, saddr, inet->inet_dport);
374 
375 	return __inet_hash_connect(death_row, sk, port_offset, hash_port0,
376 				   __inet6_check_established);
377 }
378 EXPORT_SYMBOL_GPL(inet6_hash_connect);
379