xref: /linux/net/ipv4/inet_diag.c (revision 07fdad3a93756b872da7b53647715c48d0f4a2d0)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * inet_diag.c	Module for monitoring INET transport protocols sockets.
4  *
5  * Authors:	Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
6  */
7 
8 #include <linux/kernel.h>
9 #include <linux/module.h>
10 #include <linux/types.h>
11 #include <linux/fcntl.h>
12 #include <linux/random.h>
13 #include <linux/slab.h>
14 #include <linux/cache.h>
15 #include <linux/init.h>
16 #include <linux/time.h>
17 
18 #include <net/icmp.h>
19 #include <net/tcp.h>
20 #include <net/ipv6.h>
21 #include <net/inet_common.h>
22 #include <net/inet_connection_sock.h>
23 #include <net/bpf_sk_storage.h>
24 #include <net/netlink.h>
25 
26 #include <linux/inet.h>
27 #include <linux/stddef.h>
28 
29 #include <linux/inet_diag.h>
30 #include <linux/sock_diag.h>
31 
32 static const struct inet_diag_handler __rcu **inet_diag_table;
33 
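/* Normalized socket identity consumed by the bytecode filter below: a
 * flat copy of one socket's addresses, ports, family, bound device,
 * mark and cgroup id, so inet_diag_bc_run() can evaluate INET_DIAG_BC_*
 * conditions without caring whether the socket is a full, request or
 * timewait one.
 */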
34 struct inet_diag_entry {
35 	const __be32 *saddr;
36 	const __be32 *daddr;
37 	u16 sport;
38 	u16 dport;
39 	u16 family;
40 	u16 userlocks;
41 	u32 ifindex;
42 	u32 mark;
43 #ifdef CONFIG_SOCK_CGROUP_DATA
44 	u64 cgroup_id;
45 #endif
46 };
47 
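/* Look up the diag handler for a transport protocol, autoloading the
 * matching *_diag module on first use.  The handler is pinned with
 * try_module_get() so it cannot be unloaded while a request is being
 * served; inet_diag_unlock_handler() drops that reference.
 */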
48 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
49 {
50 	const struct inet_diag_handler *handler;
51 
52 	if (proto < 0 || proto >= IPPROTO_MAX)
53 		return NULL;
54 
55 	if (!READ_ONCE(inet_diag_table[proto]))
56 		sock_load_diag_module(AF_INET, proto);
57 
58 	rcu_read_lock();
59 	handler = rcu_dereference(inet_diag_table[proto]);
60 	if (handler && !try_module_get(handler->owner))
61 		handler = NULL;
62 	rcu_read_unlock();
63 
64 	return handler;
65 }
66 
67 static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
68 {
69 	module_put(handler->owner);
70 }
71 
72 void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
73 {
74 	r->idiag_family = READ_ONCE(sk->sk_family);
75 
76 	r->id.idiag_sport = htons(READ_ONCE(sk->sk_num));
77 	r->id.idiag_dport = READ_ONCE(sk->sk_dport);
78 	r->id.idiag_if = READ_ONCE(sk->sk_bound_dev_if);
79 	sock_diag_save_cookie(sk, r->id.idiag_cookie);
80 
81 #if IS_ENABLED(CONFIG_IPV6)
82 	if (r->idiag_family == AF_INET6) {
83 		data_race(*(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr);
84 		data_race(*(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr);
85 	} else
86 #endif
87 	{
88 		memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
89 		memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
90 
91 		r->id.idiag_src[0] = READ_ONCE(sk->sk_rcv_saddr);
92 		r->id.idiag_dst[0] = READ_ONCE(sk->sk_daddr);
93 	}
94 }
95 EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill);
96 
97 int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
98 			     struct inet_diag_msg *r, int ext,
99 			     struct user_namespace *user_ns,
100 			     bool net_admin)
101 {
102 	const struct inet_sock *inet = inet_sk(sk);
103 	struct inet_diag_sockopt inet_sockopt;
104 
105 	if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
106 		goto errout;
107 
108 	/* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
109 	 * hence this needs to be included regardless of socket family.
110 	 */
111 	if (ext & (1 << (INET_DIAG_TOS - 1)))
112 		if (nla_put_u8(skb, INET_DIAG_TOS, READ_ONCE(inet->tos)) < 0)
113 			goto errout;
114 
115 #if IS_ENABLED(CONFIG_IPV6)
116 	if (r->idiag_family == AF_INET6) {
117 		if (ext & (1 << (INET_DIAG_TCLASS - 1)))
118 			if (nla_put_u8(skb, INET_DIAG_TCLASS,
119 				       inet6_sk(sk)->tclass) < 0)
120 				goto errout;
121 
122 		if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
123 		    nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
124 			goto errout;
125 	}
126 #endif
127 
128 	if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, READ_ONCE(sk->sk_mark)))
129 		goto errout;
130 
131 	if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) ||
132 	    ext & (1 << (INET_DIAG_TCLASS - 1))) {
133 		u32 classid = 0;
134 
135 #ifdef CONFIG_CGROUP_NET_CLASSID
136 		classid = sock_cgroup_classid(&sk->sk_cgrp_data);
137 #endif
138 		/* Fall back to the socket priority if the class id isn't set.
139 		 * Classful qdiscs use it as a direct reference to the class.
140 		 * For cgroup2 the classid is always zero.
141 		 */
142 		if (!classid)
143 			classid = READ_ONCE(sk->sk_priority);
144 
145 		if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid))
146 			goto errout;
147 	}
148 
149 #ifdef CONFIG_SOCK_CGROUP_DATA
150 	if (nla_put_u64_64bit(skb, INET_DIAG_CGROUP_ID,
151 			      cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)),
152 			      INET_DIAG_PAD))
153 		goto errout;
154 #endif
155 
156 	r->idiag_uid = from_kuid_munged(user_ns, sk_uid(sk));
157 	r->idiag_inode = sock_i_ino(sk);
158 
159 	memset(&inet_sockopt, 0, sizeof(inet_sockopt));
160 	inet_sockopt.recverr	= inet_test_bit(RECVERR, sk);
161 	inet_sockopt.is_icsk	= inet_test_bit(IS_ICSK, sk);
162 	inet_sockopt.freebind	= inet_test_bit(FREEBIND, sk);
163 	inet_sockopt.hdrincl	= inet_test_bit(HDRINCL, sk);
164 	inet_sockopt.mc_loop	= inet_test_bit(MC_LOOP, sk);
165 	inet_sockopt.transparent = inet_test_bit(TRANSPARENT, sk);
166 	inet_sockopt.mc_all	= inet_test_bit(MC_ALL, sk);
167 	inet_sockopt.nodefrag	= inet_test_bit(NODEFRAG, sk);
168 	inet_sockopt.bind_address_no_port = inet_test_bit(BIND_ADDRESS_NO_PORT, sk);
169 	inet_sockopt.recverr_rfc4884 = inet_test_bit(RECVERR_RFC4884, sk);
170 	inet_sockopt.defer_connect = inet_test_bit(DEFER_CONNECT, sk);
171 	if (nla_put(skb, INET_DIAG_SOCKOPT, sizeof(inet_sockopt),
172 		    &inet_sockopt))
173 		goto errout;
174 
175 	return 0;
176 errout:
177 	return 1;
178 }
179 EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
180 
181 static int inet_diag_parse_attrs(const struct nlmsghdr *nlh, int hdrlen,
182 				 struct nlattr **req_nlas)
183 {
184 	struct nlattr *nla;
185 	int remaining;
186 
187 	nlmsg_for_each_attr(nla, nlh, hdrlen, remaining) {
188 		int type = nla_type(nla);
189 
190 		if (type == INET_DIAG_REQ_PROTOCOL && nla_len(nla) != sizeof(u32))
191 			return -EINVAL;
192 
193 		if (type < __INET_DIAG_REQ_MAX)
194 			req_nlas[type] = nla;
195 	}
196 	return 0;
197 }
198 
199 static int inet_diag_get_protocol(const struct inet_diag_req_v2 *req,
200 				  const struct inet_diag_dump_data *data)
201 {
202 	if (data->req_nlas[INET_DIAG_REQ_PROTOCOL])
203 		return nla_get_u32(data->req_nlas[INET_DIAG_REQ_PROTOCOL]);
204 	return req->sdiag_protocol;
205 }
206 
207 #define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))
208 
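/* Fill one netlink message for a full socket.  The resulting layout is,
 * roughly (which attributes appear depends on req->idiag_ext, the
 * handler and the kernel config):
 *
 *	struct nlmsghdr
 *	struct inet_diag_msg		addresses, state, timers, uid, inode
 *	INET_DIAG_SHUTDOWN, _TOS, ...	from inet_diag_msg_attrs_fill()
 *	INET_DIAG_MEMINFO / _SKMEMINFO	memory accounting
 *	INET_DIAG_INFO			protocol specific (e.g. tcp_info)
 *	INET_DIAG_CONG + cc info	congestion control details
 *	INET_DIAG_SK_BPF_STORAGES	optional bpf_sk_storage dump
 */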
209 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
210 		      struct sk_buff *skb, struct netlink_callback *cb,
211 		      const struct inet_diag_req_v2 *req,
212 		      u16 nlmsg_flags, bool net_admin)
213 {
214 	const struct tcp_congestion_ops *ca_ops;
215 	const struct inet_diag_handler *handler;
216 	struct inet_diag_dump_data *cb_data;
217 	int ext = req->idiag_ext;
218 	struct inet_diag_msg *r;
219 	struct nlmsghdr  *nlh;
220 	struct nlattr *attr;
221 	void *info = NULL;
222 	u8 icsk_pending;
223 	int protocol;
224 
225 	cb_data = cb->data;
226 	protocol = inet_diag_get_protocol(req, cb_data);
227 
228 	/* inet_diag_lock_handler() made sure inet_diag_table[] is stable. */
229 	handler = rcu_dereference_protected(inet_diag_table[protocol], 1);
230 	DEBUG_NET_WARN_ON_ONCE(!handler);
231 	if (!handler)
232 		return -ENXIO;
233 
234 	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
235 			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
236 	if (!nlh)
237 		return -EMSGSIZE;
238 
239 	r = nlmsg_data(nlh);
240 	BUG_ON(!sk_fullsock(sk));
241 
242 	inet_diag_msg_common_fill(r, sk);
243 	r->idiag_state = sk->sk_state;
244 	r->idiag_timer = 0;
245 	r->idiag_retrans = 0;
246 	r->idiag_expires = 0;
247 
248 	if (inet_diag_msg_attrs_fill(sk, skb, r, ext,
249 				     sk_user_ns(NETLINK_CB(cb->skb).sk),
250 				     net_admin))
251 		goto errout;
252 
253 	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
254 		struct inet_diag_meminfo minfo = {
255 			.idiag_rmem = sk_rmem_alloc_get(sk),
256 			.idiag_wmem = READ_ONCE(sk->sk_wmem_queued),
257 			.idiag_fmem = READ_ONCE(sk->sk_forward_alloc),
258 			.idiag_tmem = sk_wmem_alloc_get(sk),
259 		};
260 
261 		if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
262 			goto errout;
263 	}
264 
265 	if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
266 		if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
267 			goto errout;
268 
269 	/*
270 	 * RAW sockets might have user-defined protocols assigned,
271 	 * so report the one supplied on socket creation.
272 	 */
273 	if (sk->sk_type == SOCK_RAW) {
274 		if (nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))
275 			goto errout;
276 	}
277 
278 	if (!icsk) {
279 		handler->idiag_get_info(sk, r, NULL);
280 		goto out;
281 	}
282 
283 	icsk_pending = smp_load_acquire(&icsk->icsk_pending);
284 	if (icsk_pending == ICSK_TIME_RETRANS ||
285 	    icsk_pending == ICSK_TIME_REO_TIMEOUT ||
286 	    icsk_pending == ICSK_TIME_LOSS_PROBE) {
287 		r->idiag_timer = 1;
288 		r->idiag_retrans = READ_ONCE(icsk->icsk_retransmits);
289 		r->idiag_expires =
290 			jiffies_delta_to_msecs(icsk_timeout(icsk) - jiffies);
291 	} else if (icsk_pending == ICSK_TIME_PROBE0) {
292 		r->idiag_timer = 4;
293 		r->idiag_retrans = READ_ONCE(icsk->icsk_probes_out);
294 		r->idiag_expires =
295 			jiffies_delta_to_msecs(icsk_timeout(icsk) - jiffies);
296 	} else if (timer_pending(&sk->sk_timer)) {
297 		r->idiag_timer = 2;
298 		r->idiag_retrans = READ_ONCE(icsk->icsk_probes_out);
299 		r->idiag_expires =
300 			jiffies_delta_to_msecs(sk->sk_timer.expires - jiffies);
301 	}
302 
303 	if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
304 		attr = nla_reserve_64bit(skb, INET_DIAG_INFO,
305 					 handler->idiag_info_size,
306 					 INET_DIAG_PAD);
307 		if (!attr)
308 			goto errout;
309 
310 		info = nla_data(attr);
311 	}
312 
313 	if (ext & (1 << (INET_DIAG_CONG - 1))) {
314 		int err = 0;
315 
316 		rcu_read_lock();
317 		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
318 		if (ca_ops)
319 			err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
320 		rcu_read_unlock();
321 		if (err < 0)
322 			goto errout;
323 	}
324 
325 	handler->idiag_get_info(sk, r, info);
326 
327 	if (ext & (1 << (INET_DIAG_INFO - 1)) && handler->idiag_get_aux)
328 		if (handler->idiag_get_aux(sk, net_admin, skb) < 0)
329 			goto errout;
330 
331 	if (sk->sk_state < TCP_TIME_WAIT) {
332 		union tcp_cc_info info;
333 		size_t sz = 0;
334 		int attr;
335 
336 		rcu_read_lock();
337 		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
338 		if (ca_ops && ca_ops->get_info)
339 			sz = ca_ops->get_info(sk, ext, &attr, &info);
340 		rcu_read_unlock();
341 		if (sz && nla_put(skb, attr, sz, &info) < 0)
342 			goto errout;
343 	}
344 
345 	/* Keep this at the end so the dump can be retried with a larger
346 	 * skb, or else fall back to best-effort fitting, which is only
347 	 * done for the first_nlmsg.
348 	 */
349 	if (cb_data->bpf_stg_diag) {
350 		bool first_nlmsg = ((unsigned char *)nlh == skb->data);
351 		unsigned int prev_min_dump_alloc;
352 		unsigned int total_nla_size = 0;
353 		unsigned int msg_len;
354 		int err;
355 
356 		msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh;
357 		err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb,
358 					      INET_DIAG_SK_BPF_STORAGES,
359 					      &total_nla_size);
360 
361 		if (!err)
362 			goto out;
363 
364 		total_nla_size += msg_len;
365 		prev_min_dump_alloc = cb->min_dump_alloc;
366 		if (total_nla_size > prev_min_dump_alloc)
367 			cb->min_dump_alloc = min_t(u32, total_nla_size,
368 						   MAX_DUMP_ALLOC_SIZE);
369 
370 		if (!first_nlmsg)
371 			goto errout;
372 
373 		if (cb->min_dump_alloc > prev_min_dump_alloc)
374 			/* Retry with pskb_expand_head() with
375 			 * __GFP_DIRECT_RECLAIM
376 			 */
377 			goto errout;
378 
379 		WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc);
380 
381 		/* Send what we have for this sk
382 		 * and move on to the next sk in the following
383 		 * dump()
384 		 */
385 	}
386 
387 out:
388 	nlmsg_end(skb, nlh);
389 	return 0;
390 
391 errout:
392 	nlmsg_cancel(skb, nlh);
393 	return -EMSGSIZE;
394 }
395 EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
396 
397 static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
398 			       const struct nlmsghdr *nlh,
399 			       int hdrlen,
400 			       const struct inet_diag_req_v2 *req)
401 {
402 	const struct inet_diag_handler *handler;
403 	struct inet_diag_dump_data dump_data;
404 	int err, protocol;
405 
406 	memset(&dump_data, 0, sizeof(dump_data));
407 	err = inet_diag_parse_attrs(nlh, hdrlen, dump_data.req_nlas);
408 	if (err)
409 		return err;
410 
411 	protocol = inet_diag_get_protocol(req, &dump_data);
412 
413 	handler = inet_diag_lock_handler(protocol);
414 	if (!handler)
415 		return -ENOENT;
416 
417 	if (cmd == SOCK_DIAG_BY_FAMILY) {
418 		struct netlink_callback cb = {
419 			.nlh = nlh,
420 			.skb = in_skb,
421 			.data = &dump_data,
422 		};
423 		err = handler->dump_one(&cb, req);
424 	} else if (cmd == SOCK_DESTROY && handler->destroy) {
425 		err = handler->destroy(in_skb, req);
426 	} else {
427 		err = -EOPNOTSUPP;
428 	}
429 	inet_diag_unlock_handler(handler);
430 
431 	return err;
432 }
433 
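/* Prefix match on big-endian address words: whole 32-bit words are
 * compared with memcmp(), then the remaining high-order bits of the
 * next word are checked under a mask.  For example, an IPv4 /28 gives
 * words = 0 and mask = htonl(0xfffffff0); an IPv6 /64 gives words = 2
 * with no partial word to mask.
 */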
434 static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
435 {
436 	int words = bits >> 5;
437 
438 	bits &= 0x1f;
439 
440 	if (words) {
441 		if (memcmp(a1, a2, words << 2))
442 			return 0;
443 	}
444 	if (bits) {
445 		__be32 w1, w2;
446 		__be32 mask;
447 
448 		w1 = a1[words];
449 		w2 = a2[words];
450 
451 		mask = htonl((0xffffffff) << (32 - bits));
452 
453 		if ((w1 ^ w2) & mask)
454 			return 0;
455 	}
456 
457 	return 1;
458 }
459 
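/* Interpret the (already audited) filter bytecode against one socket.
 * Each struct inet_diag_bc_op carries two byte offsets: "yes" is taken
 * when the condition holds, "no" when it does not.  Walking off the
 * program with len == 0 accepts the socket; overshooting the end
 * rejects it.  Port comparisons keep their operand in the "no" field of
 * the following op.  For example, a filter for "sport == 443" could be
 * laid out roughly as (illustrative only):
 *
 *	op[0] = { .code = INET_DIAG_BC_S_EQ, .yes = 8, .no = 12 }
 *	op[1] = { .no = 443 }		operand slot
 *
 * where yes == 8 jumps exactly to the end (match) and no == 12 jumps
 * four bytes past it (no match).
 */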
460 static int inet_diag_bc_run(const struct nlattr *_bc,
461 			    const struct inet_diag_entry *entry)
462 {
463 	const void *bc = nla_data(_bc);
464 	int len = nla_len(_bc);
465 
466 	while (len > 0) {
467 		int yes = 1;
468 		const struct inet_diag_bc_op *op = bc;
469 
470 		switch (op->code) {
471 		case INET_DIAG_BC_NOP:
472 			break;
473 		case INET_DIAG_BC_JMP:
474 			yes = 0;
475 			break;
476 		case INET_DIAG_BC_S_EQ:
477 			yes = entry->sport == op[1].no;
478 			break;
479 		case INET_DIAG_BC_S_GE:
480 			yes = entry->sport >= op[1].no;
481 			break;
482 		case INET_DIAG_BC_S_LE:
483 			yes = entry->sport <= op[1].no;
484 			break;
485 		case INET_DIAG_BC_D_EQ:
486 			yes = entry->dport == op[1].no;
487 			break;
488 		case INET_DIAG_BC_D_GE:
489 			yes = entry->dport >= op[1].no;
490 			break;
491 		case INET_DIAG_BC_D_LE:
492 			yes = entry->dport <= op[1].no;
493 			break;
494 		case INET_DIAG_BC_AUTO:
495 			yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
496 			break;
497 		case INET_DIAG_BC_S_COND:
498 		case INET_DIAG_BC_D_COND: {
499 			const struct inet_diag_hostcond *cond;
500 			const __be32 *addr;
501 
502 			cond = (const struct inet_diag_hostcond *)(op + 1);
503 			if (cond->port != -1 &&
504 			    cond->port != (op->code == INET_DIAG_BC_S_COND ?
505 					     entry->sport : entry->dport)) {
506 				yes = 0;
507 				break;
508 			}
509 
510 			if (op->code == INET_DIAG_BC_S_COND)
511 				addr = entry->saddr;
512 			else
513 				addr = entry->daddr;
514 
515 			if (cond->family != AF_UNSPEC &&
516 			    cond->family != entry->family) {
517 				if (entry->family == AF_INET6 &&
518 				    cond->family == AF_INET) {
519 					if (addr[0] == 0 && addr[1] == 0 &&
520 					    addr[2] == htonl(0xffff) &&
521 					    bitstring_match(addr + 3,
522 							    cond->addr,
523 							    cond->prefix_len))
524 						break;
525 				}
526 				yes = 0;
527 				break;
528 			}
529 
530 			if (cond->prefix_len == 0)
531 				break;
532 			if (bitstring_match(addr, cond->addr,
533 					    cond->prefix_len))
534 				break;
535 			yes = 0;
536 			break;
537 		}
538 		case INET_DIAG_BC_DEV_COND: {
539 			u32 ifindex;
540 
541 			ifindex = *((const u32 *)(op + 1));
542 			if (ifindex != entry->ifindex)
543 				yes = 0;
544 			break;
545 		}
546 		case INET_DIAG_BC_MARK_COND: {
547 			struct inet_diag_markcond *cond;
548 
549 			cond = (struct inet_diag_markcond *)(op + 1);
550 			if ((entry->mark & cond->mask) != cond->mark)
551 				yes = 0;
552 			break;
553 		}
554 #ifdef CONFIG_SOCK_CGROUP_DATA
555 		case INET_DIAG_BC_CGROUP_COND: {
556 			u64 cgroup_id;
557 
558 			cgroup_id = get_unaligned((const u64 *)(op + 1));
559 			if (cgroup_id != entry->cgroup_id)
560 				yes = 0;
561 			break;
562 		}
563 #endif
564 		}
565 
566 		if (yes) {
567 			len -= op->yes;
568 			bc += op->yes;
569 		} else {
570 			len -= op->no;
571 			bc += op->no;
572 		}
573 	}
574 	return len == 0;
575 }
576 
577 /* This helper is available for all sockets (full, TIME_WAIT and NEW_SYN_RECV).
578  */
579 static void entry_fill_addrs(struct inet_diag_entry *entry,
580 			     const struct sock *sk)
581 {
582 #if IS_ENABLED(CONFIG_IPV6)
583 	if (entry->family == AF_INET6) {
584 		entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
585 		entry->daddr = sk->sk_v6_daddr.s6_addr32;
586 	} else
587 #endif
588 	{
589 		entry->saddr = &sk->sk_rcv_saddr;
590 		entry->daddr = &sk->sk_daddr;
591 	}
592 }
593 
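/* Evaluate the filter for one socket.  Fields that only exist on full
 * sockets are gathered lazily, only when the audited bytecode needs
 * them: the mark is taken from the request or timewait socket where
 * applicable, while userlocks and the cgroup id default to zero for
 * non-full sockets.
 */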
594 int inet_diag_bc_sk(const struct inet_diag_dump_data *cb_data, struct sock *sk)
595 {
596 	const struct nlattr *bc = cb_data->inet_diag_nla_bc;
597 	const struct inet_sock *inet = inet_sk(sk);
598 	struct inet_diag_entry entry;
599 
600 	if (!bc)
601 		return 1;
602 
603 	entry.family = READ_ONCE(sk->sk_family);
604 	entry_fill_addrs(&entry, sk);
605 	entry.sport = READ_ONCE(inet->inet_num);
606 	entry.dport = ntohs(READ_ONCE(inet->inet_dport));
607 	entry.ifindex = READ_ONCE(sk->sk_bound_dev_if);
608 	if (cb_data->userlocks_needed)
609 		entry.userlocks = sk_fullsock(sk) ? READ_ONCE(sk->sk_userlocks) : 0;
610 	if (cb_data->mark_needed) {
611 		if (sk_fullsock(sk))
612 			entry.mark = READ_ONCE(sk->sk_mark);
613 		else if (sk->sk_state == TCP_NEW_SYN_RECV)
614 			entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
615 		else if (sk->sk_state == TCP_TIME_WAIT)
616 			entry.mark = inet_twsk(sk)->tw_mark;
617 		else
618 			entry.mark = 0;
619 	}
620 #ifdef CONFIG_SOCK_CGROUP_DATA
621 	if (cb_data->cgroup_needed)
622 		entry.cgroup_id = sk_fullsock(sk) ?
623 			cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)) : 0;
624 #endif
625 
626 	return inet_diag_bc_run(bc, &entry);
627 }
628 EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
629 
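/* Check that a jump target lands on an instruction boundary: walk the
 * "yes" chain from the start of the program and accept only if the
 * remaining length at some op equals "cc", the target measured from the
 * end of the program.
 */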
630 static int valid_cc(const void *bc, int len, int cc)
631 {
632 	while (len >= 0) {
633 		const struct inet_diag_bc_op *op = bc;
634 
635 		if (cc > len)
636 			return 0;
637 		if (cc == len)
638 			return 1;
639 		if (op->yes < 4 || op->yes & 3)
640 			return 0;
641 		len -= op->yes;
642 		bc  += op->yes;
643 	}
644 	return 0;
645 }
646 
647 /* data is u32 ifindex */
648 static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
649 			  int *min_len)
650 {
651 	/* Check ifindex space. */
652 	*min_len += sizeof(u32);
653 	if (len < *min_len)
654 		return false;
655 
656 	return true;
657 }
658 /* Validate an inet_diag_hostcond. */
659 static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
660 			   int *min_len)
661 {
662 	struct inet_diag_hostcond *cond;
663 	int addr_len;
664 
665 	/* Check hostcond space. */
666 	*min_len += sizeof(struct inet_diag_hostcond);
667 	if (len < *min_len)
668 		return false;
669 	cond = (struct inet_diag_hostcond *)(op + 1);
670 
671 	/* Check address family and address length. */
672 	switch (cond->family) {
673 	case AF_UNSPEC:
674 		addr_len = 0;
675 		break;
676 	case AF_INET:
677 		addr_len = sizeof(struct in_addr);
678 		break;
679 	case AF_INET6:
680 		addr_len = sizeof(struct in6_addr);
681 		break;
682 	default:
683 		return false;
684 	}
685 	*min_len += addr_len;
686 	if (len < *min_len)
687 		return false;
688 
689 	/* Check prefix length (in bits) vs address length (in bytes). */
690 	if (cond->prefix_len > 8 * addr_len)
691 		return false;
692 
693 	return true;
694 }
695 
696 /* Validate a port comparison operator. */
697 static bool valid_port_comparison(const struct inet_diag_bc_op *op,
698 				  int len, int *min_len)
699 {
700 	/* Port comparisons put the port in a follow-on inet_diag_bc_op. */
701 	*min_len += sizeof(struct inet_diag_bc_op);
702 	if (len < *min_len)
703 		return false;
704 	return true;
705 }
706 
707 static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
708 			   int *min_len)
709 {
710 	*min_len += sizeof(struct inet_diag_markcond);
711 	return len >= *min_len;
712 }
713 
714 #ifdef CONFIG_SOCK_CGROUP_DATA
715 static bool valid_cgroupcond(const struct inet_diag_bc_op *op, int len,
716 			     int *min_len)
717 {
718 	*min_len += sizeof(u64);
719 	return len >= *min_len;
720 }
721 #endif
722 
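/* Audit user-supplied filter bytecode before it is ever run: only known
 * op codes are accepted, every payload must fit in the remaining
 * length, branch offsets must be at least the size of the op plus its
 * payload, 4-byte aligned and at most 4 bytes past the end of the
 * program, and "no" branches must land on an instruction boundary
 * (valid_cc()).  Mark conditions additionally require CAP_NET_ADMIN.
 * The *_needed flags recorded here tell inet_diag_bc_sk() which
 * per-socket fields the filter will actually read.
 */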
723 static int inet_diag_bc_audit(struct inet_diag_dump_data *cb_data,
724 			      const struct sk_buff *skb)
725 {
726 	const struct nlattr *attr = cb_data->inet_diag_nla_bc;
727 	const void *bytecode, *bc;
728 	int bytecode_len, len;
729 	bool net_admin;
730 
731 	if (!attr)
732 		return 0;
733 
734 	if (nla_len(attr) < sizeof(struct inet_diag_bc_op))
735 		return -EINVAL;
736 
737 	net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
738 	bytecode = bc = nla_data(attr);
739 	len = bytecode_len = nla_len(attr);
740 
741 	while (len > 0) {
742 		int min_len = sizeof(struct inet_diag_bc_op);
743 		const struct inet_diag_bc_op *op = bc;
744 
745 		switch (op->code) {
746 		case INET_DIAG_BC_S_COND:
747 		case INET_DIAG_BC_D_COND:
748 			if (!valid_hostcond(bc, len, &min_len))
749 				return -EINVAL;
750 			break;
751 		case INET_DIAG_BC_DEV_COND:
752 			if (!valid_devcond(bc, len, &min_len))
753 				return -EINVAL;
754 			break;
755 		case INET_DIAG_BC_S_EQ:
756 		case INET_DIAG_BC_S_GE:
757 		case INET_DIAG_BC_S_LE:
758 		case INET_DIAG_BC_D_EQ:
759 		case INET_DIAG_BC_D_GE:
760 		case INET_DIAG_BC_D_LE:
761 			if (!valid_port_comparison(bc, len, &min_len))
762 				return -EINVAL;
763 			break;
764 		case INET_DIAG_BC_MARK_COND:
765 			if (!net_admin)
766 				return -EPERM;
767 			if (!valid_markcond(bc, len, &min_len))
768 				return -EINVAL;
769 			cb_data->mark_needed = true;
770 			break;
771 #ifdef CONFIG_SOCK_CGROUP_DATA
772 		case INET_DIAG_BC_CGROUP_COND:
773 			if (!valid_cgroupcond(bc, len, &min_len))
774 				return -EINVAL;
775 			cb_data->cgroup_needed = true;
776 			break;
777 #endif
778 		case INET_DIAG_BC_AUTO:
779 			cb_data->userlocks_needed = true;
780 			fallthrough;
781 		case INET_DIAG_BC_JMP:
782 		case INET_DIAG_BC_NOP:
783 			break;
784 		default:
785 			return -EINVAL;
786 		}
787 
788 		if (op->code != INET_DIAG_BC_NOP) {
789 			if (op->no < min_len || op->no > len + 4 || op->no & 3)
790 				return -EINVAL;
791 			if (op->no < len &&
792 			    !valid_cc(bytecode, bytecode_len, len - op->no))
793 				return -EINVAL;
794 		}
795 
796 		if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
797 			return -EINVAL;
798 		bc  += op->yes;
799 		len -= op->yes;
800 	}
801 	return len == 0 ? 0 : -EINVAL;
802 }
803 
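/* One dump pass.  If even the first socket did not fit and
 * inet_sk_diag_fill() raised cb->min_dump_alloc (e.g. because of large
 * bpf_sk_storage payloads), grow the skb with pskb_expand_head() and
 * retry rather than returning an empty message to user space.
 */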
804 static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
805 			    const struct inet_diag_req_v2 *r)
806 {
807 	struct inet_diag_dump_data *cb_data = cb->data;
808 	const struct inet_diag_handler *handler;
809 	u32 prev_min_dump_alloc;
810 	int protocol, err = 0;
811 
812 	protocol = inet_diag_get_protocol(r, cb_data);
813 
814 again:
815 	prev_min_dump_alloc = cb->min_dump_alloc;
816 	handler = inet_diag_lock_handler(protocol);
817 	if (handler) {
818 		handler->dump(skb, cb, r);
819 		inet_diag_unlock_handler(handler);
820 	} else {
821 		err = -ENOENT;
822 	}
823 	/* The skb is not large enough to fit one sk info and
824 	 * inet_sk_diag_fill() has requested a larger skb.
825 	 */
826 	if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) {
827 		err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL);
828 		if (!err)
829 			goto again;
830 	}
831 
832 	return err ? : skb->len;
833 }
834 
835 static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
836 {
837 	return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh));
838 }
839 
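/* Per-dump setup: allocate cb->data (struct inet_diag_dump_data), parse
 * the request attributes, audit any filter bytecode and, if requested,
 * prepare bpf_sk_storage collection.  Everything allocated here is
 * released by inet_diag_dump_done().
 */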
840 static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen)
841 {
842 	const struct nlmsghdr *nlh = cb->nlh;
843 	struct inet_diag_dump_data *cb_data;
844 	struct sk_buff *skb = cb->skb;
845 	struct nlattr *nla;
846 	int err;
847 
848 	cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL);
849 	if (!cb_data)
850 		return -ENOMEM;
851 
852 	err = inet_diag_parse_attrs(nlh, hdrlen, cb_data->req_nlas);
853 	if (err) {
854 		kfree(cb_data);
855 		return err;
856 	}
857 	err = inet_diag_bc_audit(cb_data, skb);
858 	if (err) {
859 		kfree(cb_data);
860 		return err;
861 	}
862 
863 	nla = cb_data->inet_diag_nla_bpf_stgs;
864 	if (nla) {
865 		struct bpf_sk_storage_diag *bpf_stg_diag;
866 
867 		bpf_stg_diag = bpf_sk_storage_diag_alloc(nla);
868 		if (IS_ERR(bpf_stg_diag)) {
869 			kfree(cb_data);
870 			return PTR_ERR(bpf_stg_diag);
871 		}
872 		cb_data->bpf_stg_diag = bpf_stg_diag;
873 	}
874 
875 	cb->data = cb_data;
876 	return 0;
877 }
878 
879 static int inet_diag_dump_start(struct netlink_callback *cb)
880 {
881 	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req_v2));
882 }
883 
884 static int inet_diag_dump_start_compat(struct netlink_callback *cb)
885 {
886 	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req));
887 }
888 
889 static int inet_diag_dump_done(struct netlink_callback *cb)
890 {
891 	struct inet_diag_dump_data *cb_data = cb->data;
892 
893 	bpf_sk_storage_diag_free(cb_data->bpf_stg_diag);
894 	kfree(cb->data);
895 
896 	return 0;
897 }
898 
899 static int inet_diag_type2proto(int type)
900 {
901 	switch (type) {
902 	case TCPDIAG_GETSOCK:
903 		return IPPROTO_TCP;
904 	default:
905 		return 0;
906 	}
907 }
908 
909 static int inet_diag_dump_compat(struct sk_buff *skb,
910 				 struct netlink_callback *cb)
911 {
912 	struct inet_diag_req *rc = nlmsg_data(cb->nlh);
913 	struct inet_diag_req_v2 req;
914 
915 	req.sdiag_family = AF_UNSPEC; /* compatibility */
916 	req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
917 	req.idiag_ext = rc->idiag_ext;
918 	req.pad = 0;
919 	req.idiag_states = rc->idiag_states;
920 	req.id = rc->id;
921 
922 	return __inet_diag_dump(skb, cb, &req);
923 }
924 
925 static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
926 				      const struct nlmsghdr *nlh)
927 {
928 	struct inet_diag_req *rc = nlmsg_data(nlh);
929 	struct inet_diag_req_v2 req;
930 
931 	req.sdiag_family = rc->idiag_family;
932 	req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
933 	req.idiag_ext = rc->idiag_ext;
934 	req.pad = 0;
935 	req.idiag_states = rc->idiag_states;
936 	req.id = rc->id;
937 
938 	return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh,
939 				   sizeof(struct inet_diag_req), &req);
940 }
941 
942 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
943 {
944 	int hdrlen = sizeof(struct inet_diag_req);
945 	struct net *net = sock_net(skb->sk);
946 
947 	if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
948 	    nlmsg_len(nlh) < hdrlen)
949 		return -EINVAL;
950 
951 	if (nlh->nlmsg_flags & NLM_F_DUMP) {
952 		struct netlink_dump_control c = {
953 			.start = inet_diag_dump_start_compat,
954 			.done = inet_diag_dump_done,
955 			.dump = inet_diag_dump_compat,
956 		};
957 		return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
958 	}
959 
960 	return inet_diag_get_exact_compat(skb, nlh);
961 }
962 
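/* Entry point for SOCK_DIAG_BY_FAMILY (and SOCK_DESTROY) requests on
 * AF_INET/AF_INET6.  Requests carrying NLM_F_DUMP iterate over sockets;
 * anything else is an exact lookup.  For example, an ss(8)-style TCP
 * dump sends, roughly (illustrative only):
 *
 *	struct inet_diag_req_v2 req = {
 *		.sdiag_family	= AF_INET,
 *		.sdiag_protocol	= IPPROTO_TCP,
 *		.idiag_ext	= 1 << (INET_DIAG_INFO - 1),
 *		.idiag_states	= ~0U,
 *	};
 *
 * wrapped in an nlmsghdr with nlmsg_type == SOCK_DIAG_BY_FAMILY and
 * NLM_F_REQUEST | NLM_F_DUMP set.
 */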
963 static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
964 {
965 	int hdrlen = sizeof(struct inet_diag_req_v2);
966 	struct net *net = sock_net(skb->sk);
967 
968 	if (nlmsg_len(h) < hdrlen)
969 		return -EINVAL;
970 
971 	if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
972 	    h->nlmsg_flags & NLM_F_DUMP) {
973 		struct netlink_dump_control c = {
974 			.start = inet_diag_dump_start,
975 			.done = inet_diag_dump_done,
976 			.dump = inet_diag_dump,
977 		};
978 		return netlink_dump_start(net->diag_nlsk, skb, h, &c);
979 	}
980 
981 	return inet_diag_cmd_exact(h->nlmsg_type, skb, h, hdrlen,
982 				   nlmsg_data(h));
983 }
984 
985 static
986 int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
987 {
988 	const struct inet_diag_handler *handler;
989 	struct nlmsghdr *nlh;
990 	struct nlattr *attr;
991 	struct inet_diag_msg *r;
992 	void *info = NULL;
993 	int err = 0;
994 
995 	nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0);
996 	if (!nlh)
997 		return -ENOMEM;
998 
999 	r = nlmsg_data(nlh);
1000 	memset(r, 0, sizeof(*r));
1001 	inet_diag_msg_common_fill(r, sk);
1002 	if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM)
1003 		r->id.idiag_sport = inet_sk(sk)->inet_sport;
1004 	r->idiag_state = sk->sk_state;
1005 
1006 	if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) {
1007 		nlmsg_cancel(skb, nlh);
1008 		return err;
1009 	}
1010 
1011 	handler = inet_diag_lock_handler(sk->sk_protocol);
1012 	if (!handler) {
1013 		nlmsg_cancel(skb, nlh);
1014 		return -ENOENT;
1015 	}
1016 
1017 	attr = handler->idiag_info_size
1018 		? nla_reserve_64bit(skb, INET_DIAG_INFO,
1019 				    handler->idiag_info_size,
1020 				    INET_DIAG_PAD)
1021 		: NULL;
1022 	if (attr)
1023 		info = nla_data(attr);
1024 
1025 	handler->idiag_get_info(sk, r, info);
1026 	inet_diag_unlock_handler(handler);
1027 
1028 	nlmsg_end(skb, nlh);
1029 	return 0;
1030 }
1031 
1032 static const struct sock_diag_handler inet_diag_handler = {
1033 	.owner = THIS_MODULE,
1034 	.family = AF_INET,
1035 	.dump = inet_diag_handler_cmd,
1036 	.get_info = inet_diag_handler_get_info,
1037 	.destroy = inet_diag_handler_cmd,
1038 };
1039 
1040 static const struct sock_diag_handler inet6_diag_handler = {
1041 	.owner = THIS_MODULE,
1042 	.family = AF_INET6,
1043 	.dump = inet_diag_handler_cmd,
1044 	.get_info = inet_diag_handler_get_info,
1045 	.destroy = inet_diag_handler_cmd,
1046 };
1047 
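/* Protocol modules (tcp_diag, udp_diag, ...) register their handlers
 * here.  The slot is claimed with cmpxchg() so a second registration
 * for the same protocol fails with -EEXIST, and readers dereference the
 * table under RCU (see inet_diag_lock_handler()).
 */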
1048 int inet_diag_register(const struct inet_diag_handler *h)
1049 {
1050 	const __u16 type = h->idiag_type;
1051 
1052 	if (type >= IPPROTO_MAX)
1053 		return -EINVAL;
1054 
1055 	return !cmpxchg((const struct inet_diag_handler **)&inet_diag_table[type],
1056 			NULL, h) ? 0 : -EEXIST;
1057 }
1058 EXPORT_SYMBOL_GPL(inet_diag_register);
1059 
1060 void inet_diag_unregister(const struct inet_diag_handler *h)
1061 {
1062 	const __u16 type = h->idiag_type;
1063 
1064 	if (type >= IPPROTO_MAX)
1065 		return;
1066 
1067 	xchg((const struct inet_diag_handler **)&inet_diag_table[type],
1068 	     NULL);
1069 }
1070 EXPORT_SYMBOL_GPL(inet_diag_unregister);
1071 
1072 static const struct sock_diag_inet_compat inet_diag_compat = {
1073 	.owner	= THIS_MODULE,
1074 	.fn	= inet_diag_rcv_msg_compat,
1075 };
1076 
1077 static int __init inet_diag_init(void)
1078 {
1079 	const int inet_diag_table_size = (IPPROTO_MAX *
1080 					  sizeof(struct inet_diag_handler *));
1081 	int err = -ENOMEM;
1082 
1083 	inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1084 	if (!inet_diag_table)
1085 		goto out;
1086 
1087 	err = sock_diag_register(&inet_diag_handler);
1088 	if (err)
1089 		goto out_free_nl;
1090 
1091 	err = sock_diag_register(&inet6_diag_handler);
1092 	if (err)
1093 		goto out_free_inet;
1094 
1095 	sock_diag_register_inet_compat(&inet_diag_compat);
1096 out:
1097 	return err;
1098 
1099 out_free_inet:
1100 	sock_diag_unregister(&inet_diag_handler);
1101 out_free_nl:
1102 	kfree(inet_diag_table);
1103 	goto out;
1104 }
1105 
1106 static void __exit inet_diag_exit(void)
1107 {
1108 	sock_diag_unregister(&inet6_diag_handler);
1109 	sock_diag_unregister(&inet_diag_handler);
1110 	sock_diag_unregister_inet_compat(&inet_diag_compat);
1111 	kfree(inet_diag_table);
1112 }
1113 
1114 module_init(inet_diag_init);
1115 module_exit(inet_diag_exit);
1116 MODULE_LICENSE("GPL");
1117 MODULE_DESCRIPTION("INET/INET6: socket monitoring via SOCK_DIAG");
1118 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
1119 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);
1120