xref: /linux/net/ipv4/inet_diag.c (revision 3ce095c16263630dde46d6051854073edaacf3d7)
/*
 * inet_diag.c	Module for monitoring INET transport protocol sockets.
 *
 * Authors:	Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
 *
 *	This program is free software; you can redistribute it and/or
 *      modify it under the terms of the GNU General Public License
 *      as published by the Free Software Foundation; either version
 *      2 of the License, or (at your option) any later version.
 */

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/types.h>
#include <linux/fcntl.h>
#include <linux/random.h>
#include <linux/slab.h>
#include <linux/cache.h>
#include <linux/init.h>
#include <linux/time.h>

#include <net/icmp.h>
#include <net/tcp.h>
#include <net/ipv6.h>
#include <net/inet_common.h>
#include <net/inet_connection_sock.h>
#include <net/inet_hashtables.h>
#include <net/inet_timewait_sock.h>
#include <net/inet6_hashtables.h>
#include <net/netlink.h>

#include <linux/inet.h>
#include <linux/stddef.h>

#include <linux/inet_diag.h>
#include <linux/sock_diag.h>

static const struct inet_diag_handler **inet_diag_table;

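/* Flattened socket identity (addresses, ports, family) used when running
 * an INET_DIAG_REQ_BYTECODE filter; see inet_diag_bc_run() and
 * entry_fill_addrs() below.
 */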
struct inet_diag_entry {
	const __be32 *saddr;
	const __be32 *daddr;
	u16 sport;
	u16 dport;
	u16 family;
	u16 userlocks;
};

static DEFINE_MUTEX(inet_diag_table_mutex);

static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
{
	if (!inet_diag_table[proto])
		request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
			       NETLINK_SOCK_DIAG, AF_INET, proto);

	mutex_lock(&inet_diag_table_mutex);
	if (!inet_diag_table[proto])
		return ERR_PTR(-ENOENT);

	return inet_diag_table[proto];
}

static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
{
	mutex_unlock(&inet_diag_table_mutex);
}
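
/* Note on the locking contract above: inet_diag_lock_handler() may first
 * try to autoload the protocol-specific diag module (e.g. tcp_diag) via
 * request_module(), and it returns with inet_diag_table_mutex held even
 * in the ERR_PTR(-ENOENT) case.  Callers must therefore always pair it
 * with inet_diag_unlock_handler(), regardless of the result.
 */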

static void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
{
	r->idiag_family = sk->sk_family;

	r->id.idiag_sport = htons(sk->sk_num);
	r->id.idiag_dport = sk->sk_dport;
	r->id.idiag_if = sk->sk_bound_dev_if;
	sock_diag_save_cookie(sk, r->id.idiag_cookie);

#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6) {
		*(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
		*(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
	} else
#endif
	{
		memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
		memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));

		r->id.idiag_src[0] = sk->sk_rcv_saddr;
		r->id.idiag_dst[0] = sk->sk_daddr;
	}
}

static size_t inet_sk_attr_size(void)
{
	return	  nla_total_size(sizeof(struct tcp_info))
		+ nla_total_size(1) /* INET_DIAG_SHUTDOWN */
		+ nla_total_size(1) /* INET_DIAG_TOS */
		+ nla_total_size(1) /* INET_DIAG_TCLASS */
		+ nla_total_size(sizeof(struct inet_diag_meminfo))
		+ nla_total_size(sizeof(struct inet_diag_msg))
		+ nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
		+ nla_total_size(TCP_CA_NAME_MAX)
		+ nla_total_size(sizeof(struct tcpvegas_info))
		+ 64;
}
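
/* Rough upper bound on the attribute space needed for a single reply;
 * used to size the reply skb in inet_diag_dump_one_icsk().  The trailing
 * + 64 is extra headroom.
 */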

int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
		      struct sk_buff *skb, const struct inet_diag_req_v2 *req,
		      struct user_namespace *user_ns,
		      u32 portid, u32 seq, u16 nlmsg_flags,
		      const struct nlmsghdr *unlh)
{
	const struct inet_sock *inet = inet_sk(sk);
	const struct tcp_congestion_ops *ca_ops;
	const struct inet_diag_handler *handler;
	int ext = req->idiag_ext;
	struct inet_diag_msg *r;
	struct nlmsghdr  *nlh;
	struct nlattr *attr;
	void *info = NULL;

	handler = inet_diag_table[req->sdiag_protocol];
	BUG_ON(!handler);

	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
			nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	BUG_ON(!sk_fullsock(sk));

	inet_diag_msg_common_fill(r, sk);
	r->idiag_state = sk->sk_state;
	r->idiag_timer = 0;
	r->idiag_retrans = 0;

	if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
		goto errout;

	/* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
	 * hence this needs to be included regardless of socket family.
	 */
	if (ext & (1 << (INET_DIAG_TOS - 1)))
		if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
			goto errout;

#if IS_ENABLED(CONFIG_IPV6)
	if (r->idiag_family == AF_INET6) {
		if (ext & (1 << (INET_DIAG_TCLASS - 1)))
			if (nla_put_u8(skb, INET_DIAG_TCLASS,
				       inet6_sk(sk)->tclass) < 0)
				goto errout;
	}
#endif

	r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
	r->idiag_inode = sock_i_ino(sk);

	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
		struct inet_diag_meminfo minfo = {
			.idiag_rmem = sk_rmem_alloc_get(sk),
			.idiag_wmem = sk->sk_wmem_queued,
			.idiag_fmem = sk->sk_forward_alloc,
			.idiag_tmem = sk_wmem_alloc_get(sk),
		};

		if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
			goto errout;
	}

	if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
		if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
			goto errout;

	if (!icsk) {
		handler->idiag_get_info(sk, r, NULL);
		goto out;
	}

#define EXPIRES_IN_MS(tmo)  DIV_ROUND_UP((tmo - jiffies) * 1000, HZ)

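	/* idiag_timer codes, as reported to userspace: 0 = no timer active,
	 * 1 = retransmit (including early-retransmit/loss-probe), 2 = keepalive,
	 * 3 = TIME_WAIT (set in inet_twsk_diag_fill()), 4 = zero window probe.
	 */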
	if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
	    icsk->icsk_pending == ICSK_TIME_EARLY_RETRANS ||
	    icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
		r->idiag_timer = 1;
		r->idiag_retrans = icsk->icsk_retransmits;
		r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
	} else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
		r->idiag_timer = 4;
		r->idiag_retrans = icsk->icsk_probes_out;
		r->idiag_expires = EXPIRES_IN_MS(icsk->icsk_timeout);
	} else if (timer_pending(&sk->sk_timer)) {
		r->idiag_timer = 2;
		r->idiag_retrans = icsk->icsk_probes_out;
		r->idiag_expires = EXPIRES_IN_MS(sk->sk_timer.expires);
	} else {
		r->idiag_timer = 0;
		r->idiag_expires = 0;
	}
#undef EXPIRES_IN_MS

	if (ext & (1 << (INET_DIAG_INFO - 1))) {
		attr = nla_reserve(skb, INET_DIAG_INFO,
				   sizeof(struct tcp_info));
		if (!attr)
			goto errout;

		info = nla_data(attr);
	}

	if (ext & (1 << (INET_DIAG_CONG - 1))) {
		int err = 0;

		rcu_read_lock();
		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
		if (ca_ops)
			err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
		rcu_read_unlock();
		if (err < 0)
			goto errout;
	}

	handler->idiag_get_info(sk, r, info);

	if (sk->sk_state < TCP_TIME_WAIT) {
		union tcp_cc_info info;
		size_t sz = 0;
		int attr;

		rcu_read_lock();
		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
		if (ca_ops && ca_ops->get_info)
			sz = ca_ops->get_info(sk, ext, &attr, &info);
		rcu_read_unlock();
		if (sz && nla_put(skb, attr, sz, &info) < 0)
			goto errout;
	}

out:
	nlmsg_end(skb, nlh);
	return 0;

errout:
	nlmsg_cancel(skb, nlh);
	return -EMSGSIZE;
}
EXPORT_SYMBOL_GPL(inet_sk_diag_fill);

static int inet_csk_diag_fill(struct sock *sk,
			      struct sk_buff *skb,
			      const struct inet_diag_req_v2 *req,
			      struct user_namespace *user_ns,
			      u32 portid, u32 seq, u16 nlmsg_flags,
			      const struct nlmsghdr *unlh)
{
	return inet_sk_diag_fill(sk, inet_csk(sk), skb, req,
				 user_ns, portid, seq, nlmsg_flags, unlh);
}

static int inet_twsk_diag_fill(struct sock *sk,
			       struct sk_buff *skb,
			       u32 portid, u32 seq, u16 nlmsg_flags,
			       const struct nlmsghdr *unlh)
{
	struct inet_timewait_sock *tw = inet_twsk(sk);
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;
	long tmo;

	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
			nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	BUG_ON(tw->tw_state != TCP_TIME_WAIT);

	tmo = tw->tw_timer.expires - jiffies;
	if (tmo < 0)
		tmo = 0;

	inet_diag_msg_common_fill(r, sk);
	r->idiag_retrans      = 0;

	r->idiag_state	      = tw->tw_substate;
	r->idiag_timer	      = 3;
	r->idiag_expires      = jiffies_to_msecs(tmo);
	r->idiag_rqueue	      = 0;
	r->idiag_wqueue	      = 0;
	r->idiag_uid	      = 0;
	r->idiag_inode	      = 0;

	nlmsg_end(skb, nlh);
	return 0;
}

static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
			      u32 portid, u32 seq, u16 nlmsg_flags,
			      const struct nlmsghdr *unlh)
{
	struct inet_diag_msg *r;
	struct nlmsghdr *nlh;
	long tmo;

	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
			nlmsg_flags);
	if (!nlh)
		return -EMSGSIZE;

	r = nlmsg_data(nlh);
	inet_diag_msg_common_fill(r, sk);
	r->idiag_state = TCP_SYN_RECV;
	r->idiag_timer = 1;
	r->idiag_retrans = inet_reqsk(sk)->num_retrans;

	/* inet_diag_msg_common_fill() saved the socket cookie via
	 * sock_diag_save_cookie(), which uses sk->sk_cookie; make sure
	 * ir_cookie of a request sock overlays that field.
	 */
	BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
		     offsetof(struct sock, sk_cookie));

	tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
	r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
	r->idiag_rqueue	= 0;
	r->idiag_wqueue	= 0;
	r->idiag_uid	= 0;
	r->idiag_inode	= 0;

	nlmsg_end(skb, nlh);
	return 0;
}

static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
			const struct inet_diag_req_v2 *r,
			struct user_namespace *user_ns,
			u32 portid, u32 seq, u16 nlmsg_flags,
			const struct nlmsghdr *unlh)
{
	if (sk->sk_state == TCP_TIME_WAIT)
		return inet_twsk_diag_fill(sk, skb, portid, seq,
					   nlmsg_flags, unlh);

	if (sk->sk_state == TCP_NEW_SYN_RECV)
		return inet_req_diag_fill(sk, skb, portid, seq,
					  nlmsg_flags, unlh);

	return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
				  nlmsg_flags, unlh);
}

int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
			    struct sk_buff *in_skb,
			    const struct nlmsghdr *nlh,
			    const struct inet_diag_req_v2 *req)
{
	struct net *net = sock_net(in_skb->sk);
	struct sk_buff *rep;
	struct sock *sk;
	int err;

	err = -EINVAL;
	if (req->sdiag_family == AF_INET)
		sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0],
				 req->id.idiag_dport, req->id.idiag_src[0],
				 req->id.idiag_sport, req->id.idiag_if);
#if IS_ENABLED(CONFIG_IPV6)
	else if (req->sdiag_family == AF_INET6)
		sk = inet6_lookup(net, hashinfo,
				  (struct in6_addr *)req->id.idiag_dst,
				  req->id.idiag_dport,
				  (struct in6_addr *)req->id.idiag_src,
				  req->id.idiag_sport,
				  req->id.idiag_if);
#endif
	else
		goto out_nosk;

	err = -ENOENT;
	if (!sk)
		goto out_nosk;

	err = sock_diag_check_cookie(sk, req->id.idiag_cookie);
	if (err)
		goto out;

	rep = nlmsg_new(inet_sk_attr_size(), GFP_KERNEL);
	if (!rep) {
		err = -ENOMEM;
		goto out;
	}

	err = sk_diag_fill(sk, rep, req,
			   sk_user_ns(NETLINK_CB(in_skb).sk),
			   NETLINK_CB(in_skb).portid,
			   nlh->nlmsg_seq, 0, nlh);
	if (err < 0) {
		WARN_ON(err == -EMSGSIZE);
		nlmsg_free(rep);
		goto out;
	}
	err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
			      MSG_DONTWAIT);
	if (err > 0)
		err = 0;

out:
	if (sk)
		sock_gen_put(sk);

out_nosk:
	return err;
}
EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);

static int inet_diag_get_exact(struct sk_buff *in_skb,
			       const struct nlmsghdr *nlh,
			       const struct inet_diag_req_v2 *req)
{
	const struct inet_diag_handler *handler;
	int err;

	handler = inet_diag_lock_handler(req->sdiag_protocol);
	if (IS_ERR(handler))
		err = PTR_ERR(handler);
	else
		err = handler->dump_one(in_skb, nlh, req);
	inet_diag_unlock_handler(handler);

	return err;
}

static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
{
	int words = bits >> 5;

	bits &= 0x1f;

	if (words) {
		if (memcmp(a1, a2, words << 2))
			return 0;
	}
	if (bits) {
		__be32 w1, w2;
		__be32 mask;

		w1 = a1[words];
		w2 = a2[words];

		mask = htonl((0xffffffff) << (32 - bits));

		if ((w1 ^ w2) & mask)
			return 0;
	}

	return 1;
}
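
/* bitstring_match() compares the leading 'bits' bits of two big-endian
 * address arrays: whole 32-bit words via memcmp(), then the remaining
 * bits under a mask.  For example, matching an IPv4 address against a
 * /24 prefix gives words = 0, bits = 24, mask = htonl(0xffffff00).
 */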

static int inet_diag_bc_run(const struct nlattr *_bc,
			    const struct inet_diag_entry *entry)
{
	const void *bc = nla_data(_bc);
	int len = nla_len(_bc);

	while (len > 0) {
		int yes = 1;
		const struct inet_diag_bc_op *op = bc;

		switch (op->code) {
		case INET_DIAG_BC_NOP:
			break;
		case INET_DIAG_BC_JMP:
			yes = 0;
			break;
		case INET_DIAG_BC_S_GE:
			yes = entry->sport >= op[1].no;
			break;
		case INET_DIAG_BC_S_LE:
			yes = entry->sport <= op[1].no;
			break;
		case INET_DIAG_BC_D_GE:
			yes = entry->dport >= op[1].no;
			break;
		case INET_DIAG_BC_D_LE:
			yes = entry->dport <= op[1].no;
			break;
		case INET_DIAG_BC_AUTO:
			yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
			break;
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND: {
			const struct inet_diag_hostcond *cond;
			const __be32 *addr;

			cond = (const struct inet_diag_hostcond *)(op + 1);
			if (cond->port != -1 &&
			    cond->port != (op->code == INET_DIAG_BC_S_COND ?
					     entry->sport : entry->dport)) {
				yes = 0;
				break;
			}

			if (op->code == INET_DIAG_BC_S_COND)
				addr = entry->saddr;
			else
				addr = entry->daddr;

			if (cond->family != AF_UNSPEC &&
			    cond->family != entry->family) {
				if (entry->family == AF_INET6 &&
				    cond->family == AF_INET) {
					if (addr[0] == 0 && addr[1] == 0 &&
					    addr[2] == htonl(0xffff) &&
					    bitstring_match(addr + 3,
							    cond->addr,
							    cond->prefix_len))
						break;
				}
				yes = 0;
				break;
			}

			if (cond->prefix_len == 0)
				break;
			if (bitstring_match(addr, cond->addr,
					    cond->prefix_len))
				break;
			yes = 0;
			break;
		}
		}

		if (yes) {
			len -= op->yes;
			bc += op->yes;
		} else {
			len -= op->no;
			bc += op->no;
		}
	}
	return len == 0;
}
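
/* Bytecode semantics, roughly: the filter is a sequence of 4-byte
 * struct inet_diag_bc_op { code, yes, no } instructions, where 'yes'
 * and 'no' are forward byte offsets taken when the condition holds or
 * fails.  A socket matches iff execution walks off the end exactly
 * (len == 0 above); jumping past the end (typically by len + 4, the
 * largest offset inet_diag_bc_audit() allows) rejects it.
 *
 * A minimal sketch of a "sport >= 1024" filter, as userspace such as
 * ss(8) might encode it (illustrative, not taken from ss sources):
 *
 *	struct inet_diag_bc_op prog[2] = {
 *		{ .code = INET_DIAG_BC_S_GE, .yes = 8, .no = 12 },
 *		{ .no = 1024 },		// port operand lives in op[1].no
 *	};
 *
 * On match, yes = 8 lands exactly at the end (accept); otherwise
 * no = 12 overshoots the 8-byte program by 4 (reject).
 */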

/* This helper is available for all sockets (ESTABLISHED, TIME_WAIT,
 * SYN_RECV).
 */
static void entry_fill_addrs(struct inet_diag_entry *entry,
			     const struct sock *sk)
{
#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6) {
		entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
		entry->daddr = sk->sk_v6_daddr.s6_addr32;
	} else
#endif
	{
		entry->saddr = &sk->sk_rcv_saddr;
		entry->daddr = &sk->sk_daddr;
	}
}

int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
{
	struct inet_sock *inet = inet_sk(sk);
	struct inet_diag_entry entry;

	if (!bc)
		return 1;

	entry.family = sk->sk_family;
	entry_fill_addrs(&entry, sk);
	entry.sport = inet->inet_num;
	entry.dport = ntohs(inet->inet_dport);
	entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;

	return inet_diag_bc_run(bc, &entry);
}
EXPORT_SYMBOL_GPL(inet_diag_bc_sk);

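/* Check that a conditional jump target, expressed as 'cc' bytes before
 * the end of the program, falls on an instruction boundary reachable
 * from the start of the bytecode by following the 'yes' offsets.
 */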
static int valid_cc(const void *bc, int len, int cc)
{
	while (len >= 0) {
		const struct inet_diag_bc_op *op = bc;

		if (cc > len)
			return 0;
		if (cc == len)
			return 1;
		if (op->yes < 4 || op->yes & 3)
			return 0;
		len -= op->yes;
		bc  += op->yes;
	}
	return 0;
}

/* Validate an inet_diag_hostcond. */
static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
			   int *min_len)
{
	struct inet_diag_hostcond *cond;
	int addr_len;

	/* Check hostcond space. */
	*min_len += sizeof(struct inet_diag_hostcond);
	if (len < *min_len)
		return false;
	cond = (struct inet_diag_hostcond *)(op + 1);

	/* Check address family and address length. */
	switch (cond->family) {
	case AF_UNSPEC:
		addr_len = 0;
		break;
	case AF_INET:
		addr_len = sizeof(struct in_addr);
		break;
	case AF_INET6:
		addr_len = sizeof(struct in6_addr);
		break;
	default:
		return false;
	}
	*min_len += addr_len;
	if (len < *min_len)
		return false;

	/* Check prefix length (in bits) vs address length (in bytes). */
	if (cond->prefix_len > 8 * addr_len)
		return false;

	return true;
}

/* Validate a port comparison operator. */
static bool valid_port_comparison(const struct inet_diag_bc_op *op,
				  int len, int *min_len)
{
	/* Port comparisons put the port in a follow-on inet_diag_bc_op. */
	*min_len += sizeof(struct inet_diag_bc_op);
	if (len < *min_len)
		return false;
	return true;
}

static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
{
	const void *bc = bytecode;
	int  len = bytecode_len;

	while (len > 0) {
		int min_len = sizeof(struct inet_diag_bc_op);
		const struct inet_diag_bc_op *op = bc;

		switch (op->code) {
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND:
			if (!valid_hostcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_S_GE:
		case INET_DIAG_BC_S_LE:
		case INET_DIAG_BC_D_GE:
		case INET_DIAG_BC_D_LE:
			if (!valid_port_comparison(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_AUTO:
		case INET_DIAG_BC_JMP:
		case INET_DIAG_BC_NOP:
			break;
		default:
			return -EINVAL;
		}

		if (op->code != INET_DIAG_BC_NOP) {
			if (op->no < min_len || op->no > len + 4 || op->no & 3)
				return -EINVAL;
			if (op->no < len &&
			    !valid_cc(bytecode, bytecode_len, len - op->no))
				return -EINVAL;
		}

		if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
			return -EINVAL;
		bc  += op->yes;
		len -= op->yes;
	}
	return len == 0 ? 0 : -EINVAL;
}
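
/* The audit above ensures the untrusted filter is safe to run: every
 * condition carries its operand (hostcond or follow-on port op), jumps
 * are 4-byte aligned, strictly forward, and overshoot the program end by
 * at most 4 bytes, so inet_diag_bc_run() always terminates within the
 * attribute.
 */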

static int inet_csk_diag_dump(struct sock *sk,
			      struct sk_buff *skb,
			      struct netlink_callback *cb,
			      const struct inet_diag_req_v2 *r,
			      const struct nlattr *bc)
{
	if (!inet_diag_bc_sk(bc, sk))
		return 0;

	return inet_csk_diag_fill(sk, skb, r,
				  sk_user_ns(NETLINK_CB(cb->skb).sk),
				  NETLINK_CB(cb->skb).portid,
				  cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
}

static void twsk_build_assert(void)
{
	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
		     offsetof(struct sock, sk_family));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
		     offsetof(struct inet_sock, inet_num));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
		     offsetof(struct inet_sock, inet_dport));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
		     offsetof(struct inet_sock, inet_rcv_saddr));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
		     offsetof(struct inet_sock, inet_daddr));

#if IS_ENABLED(CONFIG_IPV6)
	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
		     offsetof(struct sock, sk_v6_rcv_saddr));

	BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
		     offsetof(struct sock, sk_v6_daddr));
#endif
}
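
/* These compile-time asserts let the dump code treat time-wait sockets
 * through the generic struct sock/inet_sock accessors (sk_family,
 * inet_num, addresses, ...) used in inet_diag_msg_common_fill() and
 * inet_diag_bc_sk(), without casting back to inet_timewait_sock.
 */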

static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
			       struct netlink_callback *cb,
			       const struct inet_diag_req_v2 *r,
			       const struct nlattr *bc)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct inet_sock *inet = inet_sk(sk);
	struct inet_diag_entry entry;
	int j, s_j, reqnum, s_reqnum;
	struct listen_sock *lopt;
	int err = 0;

	s_j = cb->args[3];
	s_reqnum = cb->args[4];

	if (s_j > 0)
		s_j--;

	entry.family = sk->sk_family;

	spin_lock_bh(&icsk->icsk_accept_queue.syn_wait_lock);

	lopt = icsk->icsk_accept_queue.listen_opt;
	if (!lopt || !listen_sock_qlen(lopt))
		goto out;

	if (bc) {
		entry.sport = inet->inet_num;
		entry.userlocks = sk->sk_userlocks;
	}

	for (j = s_j; j < lopt->nr_table_entries; j++) {
		struct request_sock *req, *head = lopt->syn_table[j];

		reqnum = 0;
		for (req = head; req; reqnum++, req = req->dl_next) {
			struct inet_request_sock *ireq = inet_rsk(req);

			if (reqnum < s_reqnum)
				continue;
			if (r->id.idiag_dport != ireq->ir_rmt_port &&
			    r->id.idiag_dport)
				continue;

			if (bc) {
				/* Note: entry.sport and entry.userlocks are already set */
				entry_fill_addrs(&entry, req_to_sk(req));
				entry.dport = ntohs(ireq->ir_rmt_port);

				if (!inet_diag_bc_run(bc, &entry))
					continue;
			}

			err = inet_req_diag_fill(req_to_sk(req), skb,
						 NETLINK_CB(cb->skb).portid,
						 cb->nlh->nlmsg_seq,
						 NLM_F_MULTI, cb->nlh);
			if (err < 0) {
				cb->args[3] = j + 1;
				cb->args[4] = reqnum;
				goto out;
			}
		}

		s_reqnum = 0;
	}

out:
	spin_unlock_bh(&icsk->icsk_accept_queue.syn_wait_lock);

	return err;
}

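/* Dump walk state kept in the netlink callback between partial dumps:
 * cb->args[0] selects the phase (0 = listening hash, 1 = established/
 * time-wait hash), cb->args[1] is the current bucket, cb->args[2] the
 * index within that bucket, and cb->args[3]/[4] resume a partial
 * request-sock walk inside inet_diag_dump_reqs() above.
 */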
void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
			 struct netlink_callback *cb,
			 const struct inet_diag_req_v2 *r, struct nlattr *bc)
{
	struct net *net = sock_net(skb->sk);
	int i, num, s_i, s_num;

	s_i = cb->args[1];
	s_num = num = cb->args[2];

	if (cb->args[0] == 0) {
		if (!(r->idiag_states & (TCPF_LISTEN | TCPF_SYN_RECV)))
			goto skip_listen_ht;

		for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
			struct inet_listen_hashbucket *ilb;
			struct hlist_nulls_node *node;
			struct sock *sk;

			num = 0;
			ilb = &hashinfo->listening_hash[i];
			spin_lock_bh(&ilb->lock);
			sk_nulls_for_each(sk, node, &ilb->head) {
				struct inet_sock *inet = inet_sk(sk);

				if (!net_eq(sock_net(sk), net))
					continue;

				if (num < s_num) {
					num++;
					continue;
				}

				if (r->sdiag_family != AF_UNSPEC &&
				    sk->sk_family != r->sdiag_family)
					goto next_listen;

				if (r->id.idiag_sport != inet->inet_sport &&
				    r->id.idiag_sport)
					goto next_listen;

				if (!(r->idiag_states & TCPF_LISTEN) ||
				    r->id.idiag_dport ||
				    cb->args[3] > 0)
					goto syn_recv;

				if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) {
					spin_unlock_bh(&ilb->lock);
					goto done;
				}

syn_recv:
				if (!(r->idiag_states & TCPF_SYN_RECV))
					goto next_listen;

				if (inet_diag_dump_reqs(skb, sk, cb, r, bc) < 0) {
					spin_unlock_bh(&ilb->lock);
					goto done;
				}

next_listen:
				cb->args[3] = 0;
				cb->args[4] = 0;
				++num;
			}
			spin_unlock_bh(&ilb->lock);

			s_num = 0;
			cb->args[3] = 0;
			cb->args[4] = 0;
		}
skip_listen_ht:
		cb->args[0] = 1;
		s_i = num = s_num = 0;
	}

	if (!(r->idiag_states & ~(TCPF_LISTEN | TCPF_SYN_RECV)))
		goto out;

	for (i = s_i; i <= hashinfo->ehash_mask; i++) {
		struct inet_ehash_bucket *head = &hashinfo->ehash[i];
		spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
		struct hlist_nulls_node *node;
		struct sock *sk;

		num = 0;

		if (hlist_nulls_empty(&head->chain))
			continue;

		if (i > s_i)
			s_num = 0;

		spin_lock_bh(lock);
		sk_nulls_for_each(sk, node, &head->chain) {
			int state, res;

			if (!net_eq(sock_net(sk), net))
				continue;
			if (num < s_num)
				goto next_normal;
			state = (sk->sk_state == TCP_TIME_WAIT) ?
				inet_twsk(sk)->tw_substate : sk->sk_state;
			if (!(r->idiag_states & (1 << state)))
				goto next_normal;
			if (r->sdiag_family != AF_UNSPEC &&
			    sk->sk_family != r->sdiag_family)
				goto next_normal;
			if (r->id.idiag_sport != htons(sk->sk_num) &&
			    r->id.idiag_sport)
				goto next_normal;
			if (r->id.idiag_dport != sk->sk_dport &&
			    r->id.idiag_dport)
				goto next_normal;
			twsk_build_assert();

			if (!inet_diag_bc_sk(bc, sk))
				goto next_normal;

			res = sk_diag_fill(sk, skb, r,
					   sk_user_ns(NETLINK_CB(cb->skb).sk),
					   NETLINK_CB(cb->skb).portid,
					   cb->nlh->nlmsg_seq, NLM_F_MULTI,
					   cb->nlh);
			if (res < 0) {
				spin_unlock_bh(lock);
				goto done;
			}
next_normal:
			++num;
		}

		spin_unlock_bh(lock);
	}

done:
	cb->args[1] = i;
	cb->args[2] = num;
out:
	;
}
EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);

static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
			    const struct inet_diag_req_v2 *r,
			    struct nlattr *bc)
{
	const struct inet_diag_handler *handler;
	int err = 0;

	handler = inet_diag_lock_handler(r->sdiag_protocol);
	if (!IS_ERR(handler))
		handler->dump(skb, cb, r, bc);
	else
		err = PTR_ERR(handler);
	inet_diag_unlock_handler(handler);

	return err ? : skb->len;
}

static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
	int hdrlen = sizeof(struct inet_diag_req_v2);
	struct nlattr *bc = NULL;

	if (nlmsg_attrlen(cb->nlh, hdrlen))
		bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);

	return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc);
}

static int inet_diag_type2proto(int type)
{
	switch (type) {
	case TCPDIAG_GETSOCK:
		return IPPROTO_TCP;
	case DCCPDIAG_GETSOCK:
		return IPPROTO_DCCP;
	default:
		return 0;
	}
}

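/* The *_compat paths below service the legacy inet_diag request format
 * (struct inet_diag_req with TCPDIAG_GETSOCK/DCCPDIAG_GETSOCK message
 * types, predating sock_diag) by translating each request into a
 * struct inet_diag_req_v2 and reusing the normal handlers.
 */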
static int inet_diag_dump_compat(struct sk_buff *skb,
				 struct netlink_callback *cb)
{
	struct inet_diag_req *rc = nlmsg_data(cb->nlh);
	int hdrlen = sizeof(struct inet_diag_req);
	struct inet_diag_req_v2 req;
	struct nlattr *bc = NULL;

	req.sdiag_family = AF_UNSPEC; /* compatibility */
	req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
	req.idiag_ext = rc->idiag_ext;
	req.idiag_states = rc->idiag_states;
	req.id = rc->id;

	if (nlmsg_attrlen(cb->nlh, hdrlen))
		bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);

	return __inet_diag_dump(skb, cb, &req, bc);
}

static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
				      const struct nlmsghdr *nlh)
{
	struct inet_diag_req *rc = nlmsg_data(nlh);
	struct inet_diag_req_v2 req;

	req.sdiag_family = rc->idiag_family;
	req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
	req.idiag_ext = rc->idiag_ext;
	req.idiag_states = rc->idiag_states;
	req.id = rc->id;

	return inet_diag_get_exact(in_skb, nlh, &req);
}

static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
{
	int hdrlen = sizeof(struct inet_diag_req);
	struct net *net = sock_net(skb->sk);

	if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
	    nlmsg_len(nlh) < hdrlen)
		return -EINVAL;

	if (nlh->nlmsg_flags & NLM_F_DUMP) {
		if (nlmsg_attrlen(nlh, hdrlen)) {
			struct nlattr *attr;

			attr = nlmsg_find_attr(nlh, hdrlen,
					       INET_DIAG_REQ_BYTECODE);
			if (!attr ||
			    nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
			    inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
				return -EINVAL;
		}
		{
			struct netlink_dump_control c = {
				.dump = inet_diag_dump_compat,
			};
			return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
		}
	}

	return inet_diag_get_exact_compat(skb, nlh);
}

static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
{
	int hdrlen = sizeof(struct inet_diag_req_v2);
	struct net *net = sock_net(skb->sk);

	if (nlmsg_len(h) < hdrlen)
		return -EINVAL;

	if (h->nlmsg_flags & NLM_F_DUMP) {
		if (nlmsg_attrlen(h, hdrlen)) {
			struct nlattr *attr;

			attr = nlmsg_find_attr(h, hdrlen,
					       INET_DIAG_REQ_BYTECODE);
			if (!attr ||
			    nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
			    inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
				return -EINVAL;
		}
		{
			struct netlink_dump_control c = {
				.dump = inet_diag_dump,
			};
			return netlink_dump_start(net->diag_nlsk, skb, h, &c);
		}
	}

	return inet_diag_get_exact(skb, h, nlmsg_data(h));
}

static const struct sock_diag_handler inet_diag_handler = {
	.family = AF_INET,
	.dump = inet_diag_handler_dump,
};

static const struct sock_diag_handler inet6_diag_handler = {
	.family = AF_INET6,
	.dump = inet_diag_handler_dump,
};
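
/* Both address families share one entry point: sock_diag dispatches
 * AF_INET and AF_INET6 requests to inet_diag_handler_dump(), and the
 * transport protocol is then selected via req->sdiag_protocol through
 * inet_diag_table[].
 */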

int inet_diag_register(const struct inet_diag_handler *h)
{
	const __u16 type = h->idiag_type;
	int err = -EINVAL;

	if (type >= IPPROTO_MAX)
		goto out;

	mutex_lock(&inet_diag_table_mutex);
	err = -EEXIST;
	if (!inet_diag_table[type]) {
		inet_diag_table[type] = h;
		err = 0;
	}
	mutex_unlock(&inet_diag_table_mutex);
out:
	return err;
}
EXPORT_SYMBOL_GPL(inet_diag_register);

void inet_diag_unregister(const struct inet_diag_handler *h)
{
	const __u16 type = h->idiag_type;

	if (type >= IPPROTO_MAX)
		return;

	mutex_lock(&inet_diag_table_mutex);
	inet_diag_table[type] = NULL;
	mutex_unlock(&inet_diag_table_mutex);
}
EXPORT_SYMBOL_GPL(inet_diag_unregister);

static int __init inet_diag_init(void)
{
	const int inet_diag_table_size = (IPPROTO_MAX *
					  sizeof(struct inet_diag_handler *));
	int err = -ENOMEM;

	inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
	if (!inet_diag_table)
		goto out;

	err = sock_diag_register(&inet_diag_handler);
	if (err)
		goto out_free_nl;

	err = sock_diag_register(&inet6_diag_handler);
	if (err)
		goto out_free_inet;

	sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
out:
	return err;

out_free_inet:
	sock_diag_unregister(&inet_diag_handler);
out_free_nl:
	kfree(inet_diag_table);
	goto out;
}

static void __exit inet_diag_exit(void)
{
	sock_diag_unregister(&inet6_diag_handler);
	sock_diag_unregister(&inet_diag_handler);
	sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
	kfree(inet_diag_table);
}

module_init(inet_diag_init);
module_exit(inet_diag_exit);
MODULE_LICENSE("GPL");
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);