xref: /linux/net/ipv4/inet_diag.c (revision 7a7c52645ce62314cdd69815e9d8fcb33e0042d5)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * inet_diag.c	Module for monitoring INET transport protocol sockets.
4  *
5  * Authors:	Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
6  */
7 
8 #include <linux/kernel.h>
9 #include <linux/module.h>
10 #include <linux/types.h>
11 #include <linux/fcntl.h>
12 #include <linux/random.h>
13 #include <linux/slab.h>
14 #include <linux/cache.h>
15 #include <linux/init.h>
16 #include <linux/time.h>
17 
18 #include <net/icmp.h>
19 #include <net/tcp.h>
20 #include <net/ipv6.h>
21 #include <net/inet_common.h>
22 #include <net/inet_connection_sock.h>
23 #include <net/bpf_sk_storage.h>
24 #include <net/netlink.h>
25 
26 #include <linux/inet.h>
27 #include <linux/stddef.h>
28 
29 #include <linux/inet_diag.h>
30 #include <linux/sock_diag.h>
31 
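/* Table of registered diag handlers, indexed by IPPROTO_* value.
 * Entries are installed by the protocol diag modules (e.g. tcp_diag)
 * and are RCU-protected.
 */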
32 static const struct inet_diag_handler __rcu **inet_diag_table;
33 
34 struct inet_diag_entry {
35 	const __be32 *saddr;
36 	const __be32 *daddr;
37 	u16 sport;
38 	u16 dport;
39 	u16 family;
40 	u16 userlocks;
41 	u32 ifindex;
42 	u32 mark;
43 #ifdef CONFIG_SOCK_CGROUP_DATA
44 	u64 cgroup_id;
45 #endif
46 };
47 
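/* Look up the diag handler for @proto, loading the corresponding diag
 * module on demand, and pin it with a module reference.  Release it
 * with inet_diag_unlock_handler().
 */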
48 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
49 {
50 	const struct inet_diag_handler *handler;
51 
52 	if (proto < 0 || proto >= IPPROTO_MAX)
53 		return NULL;
54 
55 	if (!READ_ONCE(inet_diag_table[proto]))
56 		sock_load_diag_module(AF_INET, proto);
57 
58 	rcu_read_lock();
59 	handler = rcu_dereference(inet_diag_table[proto]);
60 	if (handler && !try_module_get(handler->owner))
61 		handler = NULL;
62 	rcu_read_unlock();
63 
64 	return handler;
65 }
66 
67 static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
68 {
69 	module_put(handler->owner);
70 }
71 
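/* Fill the fields common to every inet diag message: family, ports,
 * bound device, socket cookie and addresses.
 */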
72 void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
73 {
74 	r->idiag_family = sk->sk_family;
75 
76 	r->id.idiag_sport = htons(sk->sk_num);
77 	r->id.idiag_dport = sk->sk_dport;
78 	r->id.idiag_if = sk->sk_bound_dev_if;
79 	sock_diag_save_cookie(sk, r->id.idiag_cookie);
80 
81 #if IS_ENABLED(CONFIG_IPV6)
82 	if (sk->sk_family == AF_INET6) {
83 		*(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
84 		*(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
85 	} else
86 #endif
87 	{
88 		memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
89 		memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
90 
91 		r->id.idiag_src[0] = sk->sk_rcv_saddr;
92 		r->id.idiag_dst[0] = sk->sk_daddr;
93 	}
94 }
95 EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill);
96 
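/* Emit the per-socket attributes shared by all inet diag modules:
 * shutdown state, TOS/TCLASS, mark, class id, cgroup id, uid, inode
 * and the INET_DIAG_SOCKOPT block.  Returns 0 on success, 1 if the
 * skb ran out of room.
 */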
97 int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
98 			     struct inet_diag_msg *r, int ext,
99 			     struct user_namespace *user_ns,
100 			     bool net_admin)
101 {
102 	const struct inet_sock *inet = inet_sk(sk);
103 	struct inet_diag_sockopt inet_sockopt;
104 
105 	if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
106 		goto errout;
107 
108 	/* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
109 	 * hence this needs to be included regardless of socket family.
110 	 */
111 	if (ext & (1 << (INET_DIAG_TOS - 1)))
112 		if (nla_put_u8(skb, INET_DIAG_TOS, READ_ONCE(inet->tos)) < 0)
113 			goto errout;
114 
115 #if IS_ENABLED(CONFIG_IPV6)
116 	if (r->idiag_family == AF_INET6) {
117 		if (ext & (1 << (INET_DIAG_TCLASS - 1)))
118 			if (nla_put_u8(skb, INET_DIAG_TCLASS,
119 				       inet6_sk(sk)->tclass) < 0)
120 				goto errout;
121 
122 		if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
123 		    nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
124 			goto errout;
125 	}
126 #endif
127 
128 	if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, READ_ONCE(sk->sk_mark)))
129 		goto errout;
130 
131 	if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) ||
132 	    ext & (1 << (INET_DIAG_TCLASS - 1))) {
133 		u32 classid = 0;
134 
135 #ifdef CONFIG_CGROUP_NET_CLASSID
136 		classid = sock_cgroup_classid(&sk->sk_cgrp_data);
137 #endif
138 		/* Fall back to socket priority if the class id isn't set.
139 		 * Classful qdiscs use it as a direct reference to the class.
140 		 * For cgroup2 the classid is always zero.
141 		 */
142 		if (!classid)
143 			classid = READ_ONCE(sk->sk_priority);
144 
145 		if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid))
146 			goto errout;
147 	}
148 
149 #ifdef CONFIG_SOCK_CGROUP_DATA
150 	if (nla_put_u64_64bit(skb, INET_DIAG_CGROUP_ID,
151 			      cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)),
152 			      INET_DIAG_PAD))
153 		goto errout;
154 #endif
155 
156 	r->idiag_uid = from_kuid_munged(user_ns, sk_uid(sk));
157 	r->idiag_inode = sock_i_ino(sk);
158 
159 	memset(&inet_sockopt, 0, sizeof(inet_sockopt));
160 	inet_sockopt.recverr	= inet_test_bit(RECVERR, sk);
161 	inet_sockopt.is_icsk	= inet_test_bit(IS_ICSK, sk);
162 	inet_sockopt.freebind	= inet_test_bit(FREEBIND, sk);
163 	inet_sockopt.hdrincl	= inet_test_bit(HDRINCL, sk);
164 	inet_sockopt.mc_loop	= inet_test_bit(MC_LOOP, sk);
165 	inet_sockopt.transparent = inet_test_bit(TRANSPARENT, sk);
166 	inet_sockopt.mc_all	= inet_test_bit(MC_ALL, sk);
167 	inet_sockopt.nodefrag	= inet_test_bit(NODEFRAG, sk);
168 	inet_sockopt.bind_address_no_port = inet_test_bit(BIND_ADDRESS_NO_PORT, sk);
169 	inet_sockopt.recverr_rfc4884 = inet_test_bit(RECVERR_RFC4884, sk);
170 	inet_sockopt.defer_connect = inet_test_bit(DEFER_CONNECT, sk);
171 	if (nla_put(skb, INET_DIAG_SOCKOPT, sizeof(inet_sockopt),
172 		    &inet_sockopt))
173 		goto errout;
174 
175 	return 0;
176 errout:
177 	return 1;
178 }
179 EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
180 
181 static int inet_diag_parse_attrs(const struct nlmsghdr *nlh, int hdrlen,
182 				 struct nlattr **req_nlas)
183 {
184 	struct nlattr *nla;
185 	int remaining;
186 
187 	nlmsg_for_each_attr(nla, nlh, hdrlen, remaining) {
188 		int type = nla_type(nla);
189 
190 		if (type == INET_DIAG_REQ_PROTOCOL && nla_len(nla) != sizeof(u32))
191 			return -EINVAL;
192 
193 		if (type < __INET_DIAG_REQ_MAX)
194 			req_nlas[type] = nla;
195 	}
196 	return 0;
197 }
198 
199 static int inet_diag_get_protocol(const struct inet_diag_req_v2 *req,
200 				  const struct inet_diag_dump_data *data)
201 {
202 	if (data->req_nlas[INET_DIAG_REQ_PROTOCOL])
203 		return nla_get_u32(data->req_nlas[INET_DIAG_REQ_PROTOCOL]);
204 	return req->sdiag_protocol;
205 }
206 
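/* Upper bound for cb->min_dump_alloc: the largest skb data area that
 * kmalloc() can back once the shared info overhead is accounted for.
 */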
207 #define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))
208 
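/* Build one reply message for @sk: the common inet_diag_msg plus any
 * extension attributes requested in @req->idiag_ext.  If @icsk is NULL,
 * only the protocol's idiag_get_info() is invoked and no timer state
 * is reported.
 */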
209 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
210 		      struct sk_buff *skb, struct netlink_callback *cb,
211 		      const struct inet_diag_req_v2 *req,
212 		      u16 nlmsg_flags, bool net_admin)
213 {
214 	const struct tcp_congestion_ops *ca_ops;
215 	const struct inet_diag_handler *handler;
216 	struct inet_diag_dump_data *cb_data;
217 	int ext = req->idiag_ext;
218 	struct inet_diag_msg *r;
219 	struct nlmsghdr  *nlh;
220 	struct nlattr *attr;
221 	void *info = NULL;
222 	u8 icsk_pending;
223 	int protocol;
224 
225 	cb_data = cb->data;
226 	protocol = inet_diag_get_protocol(req, cb_data);
227 
228 	/* inet_diag_lock_handler() made sure inet_diag_table[] is stable. */
229 	handler = rcu_dereference_protected(inet_diag_table[protocol], 1);
230 	DEBUG_NET_WARN_ON_ONCE(!handler);
231 	if (!handler)
232 		return -ENXIO;
233 
234 	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
235 			cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
236 	if (!nlh)
237 		return -EMSGSIZE;
238 
239 	r = nlmsg_data(nlh);
240 	BUG_ON(!sk_fullsock(sk));
241 
242 	inet_diag_msg_common_fill(r, sk);
243 	r->idiag_state = sk->sk_state;
244 	r->idiag_timer = 0;
245 	r->idiag_retrans = 0;
246 	r->idiag_expires = 0;
247 
248 	if (inet_diag_msg_attrs_fill(sk, skb, r, ext,
249 				     sk_user_ns(NETLINK_CB(cb->skb).sk),
250 				     net_admin))
251 		goto errout;
252 
253 	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
254 		struct inet_diag_meminfo minfo = {
255 			.idiag_rmem = sk_rmem_alloc_get(sk),
256 			.idiag_wmem = READ_ONCE(sk->sk_wmem_queued),
257 			.idiag_fmem = READ_ONCE(sk->sk_forward_alloc),
258 			.idiag_tmem = sk_wmem_alloc_get(sk),
259 		};
260 
261 		if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
262 			goto errout;
263 	}
264 
265 	if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
266 		if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
267 			goto errout;
268 
269 	/*
270 	 * RAW sockets might have user-defined protocols assigned,
271 	 * so report the one supplied on socket creation.
272 	 */
273 	if (sk->sk_type == SOCK_RAW) {
274 		if (nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))
275 			goto errout;
276 	}
277 
278 	if (!icsk) {
279 		handler->idiag_get_info(sk, r, NULL);
280 		goto out;
281 	}
282 
283 	icsk_pending = smp_load_acquire(&icsk->icsk_pending);
284 	if (icsk_pending == ICSK_TIME_RETRANS ||
285 	    icsk_pending == ICSK_TIME_REO_TIMEOUT ||
286 	    icsk_pending == ICSK_TIME_LOSS_PROBE) {
287 		r->idiag_timer = 1;
288 		r->idiag_retrans = READ_ONCE(icsk->icsk_retransmits);
289 		r->idiag_expires =
290 			jiffies_delta_to_msecs(icsk_timeout(icsk) - jiffies);
291 	} else if (icsk_pending == ICSK_TIME_PROBE0) {
292 		r->idiag_timer = 4;
293 		r->idiag_retrans = READ_ONCE(icsk->icsk_probes_out);
294 		r->idiag_expires =
295 			jiffies_delta_to_msecs(icsk_timeout(icsk) - jiffies);
296 	} else if (timer_pending(&sk->sk_timer)) {
297 		r->idiag_timer = 2;
298 		r->idiag_retrans = READ_ONCE(icsk->icsk_probes_out);
299 		r->idiag_expires =
300 			jiffies_delta_to_msecs(sk->sk_timer.expires - jiffies);
301 	}
302 
303 	if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
304 		attr = nla_reserve_64bit(skb, INET_DIAG_INFO,
305 					 handler->idiag_info_size,
306 					 INET_DIAG_PAD);
307 		if (!attr)
308 			goto errout;
309 
310 		info = nla_data(attr);
311 	}
312 
313 	if (ext & (1 << (INET_DIAG_CONG - 1))) {
314 		int err = 0;
315 
316 		rcu_read_lock();
317 		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
318 		if (ca_ops)
319 			err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
320 		rcu_read_unlock();
321 		if (err < 0)
322 			goto errout;
323 	}
324 
325 	handler->idiag_get_info(sk, r, info);
326 
327 	if (ext & (1 << (INET_DIAG_INFO - 1)) && handler->idiag_get_aux)
328 		if (handler->idiag_get_aux(sk, net_admin, skb) < 0)
329 			goto errout;
330 
331 	if (sk->sk_state < TCP_TIME_WAIT) {
332 		union tcp_cc_info info;
333 		size_t sz = 0;
334 		int attr;
335 
336 		rcu_read_lock();
337 		ca_ops = READ_ONCE(icsk->icsk_ca_ops);
338 		if (ca_ops && ca_ops->get_info)
339 			sz = ca_ops->get_info(sk, ext, &attr, &info);
340 		rcu_read_unlock();
341 		if (sz && nla_put(skb, attr, sz, &info) < 0)
342 			goto errout;
343 	}
344 
345 	/* Keep this attribute last so the dump can be retried with a
346 	 * larger skb; otherwise do best-effort fitting, which is only
347 	 * attempted for the first_nlmsg.
348 	 */
349 	if (cb_data->bpf_stg_diag) {
350 		bool first_nlmsg = ((unsigned char *)nlh == skb->data);
351 		unsigned int prev_min_dump_alloc;
352 		unsigned int total_nla_size = 0;
353 		unsigned int msg_len;
354 		int err;
355 
356 		msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh;
357 		err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb,
358 					      INET_DIAG_SK_BPF_STORAGES,
359 					      &total_nla_size);
360 
361 		if (!err)
362 			goto out;
363 
364 		total_nla_size += msg_len;
365 		prev_min_dump_alloc = cb->min_dump_alloc;
366 		if (total_nla_size > prev_min_dump_alloc)
367 			cb->min_dump_alloc = min_t(u32, total_nla_size,
368 						   MAX_DUMP_ALLOC_SIZE);
369 
370 		if (!first_nlmsg)
371 			goto errout;
372 
373 		if (cb->min_dump_alloc > prev_min_dump_alloc)
374 			/* Retry with pskb_expand_head() with
375 			 * __GFP_DIRECT_RECLAIM
376 			 */
377 			goto errout;
378 
379 		WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc);
380 
381 		/* Send what we have for this sk
382 		 * and move on to the next sk in the following
383 		 * dump()
384 		 */
385 	}
386 
387 out:
388 	nlmsg_end(skb, nlh);
389 	return 0;
390 
391 errout:
392 	nlmsg_cancel(skb, nlh);
393 	return -EMSGSIZE;
394 }
395 EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
396 
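/* Handle a request for a single socket: SOCK_DIAG_BY_FAMILY is
 * dispatched to the protocol handler's dump_one(), SOCK_DESTROY to
 * its destroy() callback.
 */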
397 static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
398 			       const struct nlmsghdr *nlh,
399 			       int hdrlen,
400 			       const struct inet_diag_req_v2 *req)
401 {
402 	const struct inet_diag_handler *handler;
403 	struct inet_diag_dump_data dump_data;
404 	int err, protocol;
405 
406 	memset(&dump_data, 0, sizeof(dump_data));
407 	err = inet_diag_parse_attrs(nlh, hdrlen, dump_data.req_nlas);
408 	if (err)
409 		return err;
410 
411 	protocol = inet_diag_get_protocol(req, &dump_data);
412 
413 	handler = inet_diag_lock_handler(protocol);
414 	if (!handler)
415 		return -ENOENT;
416 
417 	if (cmd == SOCK_DIAG_BY_FAMILY) {
418 		struct netlink_callback cb = {
419 			.nlh = nlh,
420 			.skb = in_skb,
421 			.data = &dump_data,
422 		};
423 		err = handler->dump_one(&cb, req);
424 	} else if (cmd == SOCK_DESTROY && handler->destroy) {
425 		err = handler->destroy(in_skb, req);
426 	} else {
427 		err = -EOPNOTSUPP;
428 	}
429 	inet_diag_unlock_handler(handler);
430 
431 	return err;
432 }
433 
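/* Return 1 if the first @bits bits of @a1 and @a2 are equal.
 * Addresses are arrays of big-endian 32-bit words.
 */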
434 static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
435 {
436 	int words = bits >> 5;
437 
438 	bits &= 0x1f;
439 
440 	if (words) {
441 		if (memcmp(a1, a2, words << 2))
442 			return 0;
443 	}
444 	if (bits) {
445 		__be32 w1, w2;
446 		__be32 mask;
447 
448 		w1 = a1[words];
449 		w2 = a2[words];
450 
451 		mask = htonl((0xffffffff) << (32 - bits));
452 
453 		if ((w1 ^ w2) & mask)
454 			return 0;
455 	}
456 
457 	return 1;
458 }
459 
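/* Run the user-supplied filter bytecode against one socket.  Each op is
 * evaluated against @entry and execution then jumps forward by op->yes
 * (condition true) or op->no (condition false) bytes.  The socket
 * matches when execution walks exactly off the end of the program.
 */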
460 static int inet_diag_bc_run(const struct nlattr *_bc,
461 			    const struct inet_diag_entry *entry)
462 {
463 	const void *bc = nla_data(_bc);
464 	int len = nla_len(_bc);
465 
466 	while (len > 0) {
467 		int yes = 1;
468 		const struct inet_diag_bc_op *op = bc;
469 
470 		switch (op->code) {
471 		case INET_DIAG_BC_NOP:
472 			break;
473 		case INET_DIAG_BC_JMP:
474 			yes = 0;
475 			break;
476 		case INET_DIAG_BC_S_EQ:
477 			yes = entry->sport == op[1].no;
478 			break;
479 		case INET_DIAG_BC_S_GE:
480 			yes = entry->sport >= op[1].no;
481 			break;
482 		case INET_DIAG_BC_S_LE:
483 			yes = entry->sport <= op[1].no;
484 			break;
485 		case INET_DIAG_BC_D_EQ:
486 			yes = entry->dport == op[1].no;
487 			break;
488 		case INET_DIAG_BC_D_GE:
489 			yes = entry->dport >= op[1].no;
490 			break;
491 		case INET_DIAG_BC_D_LE:
492 			yes = entry->dport <= op[1].no;
493 			break;
494 		case INET_DIAG_BC_AUTO:
495 			yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
496 			break;
497 		case INET_DIAG_BC_S_COND:
498 		case INET_DIAG_BC_D_COND: {
499 			const struct inet_diag_hostcond *cond;
500 			const __be32 *addr;
501 
502 			cond = (const struct inet_diag_hostcond *)(op + 1);
503 			if (cond->port != -1 &&
504 			    cond->port != (op->code == INET_DIAG_BC_S_COND ?
505 					     entry->sport : entry->dport)) {
506 				yes = 0;
507 				break;
508 			}
509 
510 			if (op->code == INET_DIAG_BC_S_COND)
511 				addr = entry->saddr;
512 			else
513 				addr = entry->daddr;
514 
515 			if (cond->family != AF_UNSPEC &&
516 			    cond->family != entry->family) {
517 				if (entry->family == AF_INET6 &&
518 				    cond->family == AF_INET) {
519 					if (addr[0] == 0 && addr[1] == 0 &&
520 					    addr[2] == htonl(0xffff) &&
521 					    bitstring_match(addr + 3,
522 							    cond->addr,
523 							    cond->prefix_len))
524 						break;
525 				}
526 				yes = 0;
527 				break;
528 			}
529 
530 			if (cond->prefix_len == 0)
531 				break;
532 			if (bitstring_match(addr, cond->addr,
533 					    cond->prefix_len))
534 				break;
535 			yes = 0;
536 			break;
537 		}
538 		case INET_DIAG_BC_DEV_COND: {
539 			u32 ifindex;
540 
541 			ifindex = *((const u32 *)(op + 1));
542 			if (ifindex != entry->ifindex)
543 				yes = 0;
544 			break;
545 		}
546 		case INET_DIAG_BC_MARK_COND: {
547 			struct inet_diag_markcond *cond;
548 
549 			cond = (struct inet_diag_markcond *)(op + 1);
550 			if ((entry->mark & cond->mask) != cond->mark)
551 				yes = 0;
552 			break;
553 		}
554 #ifdef CONFIG_SOCK_CGROUP_DATA
555 		case INET_DIAG_BC_CGROUP_COND: {
556 			u64 cgroup_id;
557 
558 			cgroup_id = get_unaligned((const u64 *)(op + 1));
559 			if (cgroup_id != entry->cgroup_id)
560 				yes = 0;
561 			break;
562 		}
563 #endif
564 		}
565 
566 		if (yes) {
567 			len -= op->yes;
568 			bc += op->yes;
569 		} else {
570 			len -= op->no;
571 			bc += op->no;
572 		}
573 	}
574 	return len == 0;
575 }
576 
577 /* This helper works for all socket kinds: full (e.g. ESTABLISHED),
578  * TIME_WAIT and NEW_SYN_RECV sockets. */
579 static void entry_fill_addrs(struct inet_diag_entry *entry,
580 			     const struct sock *sk)
581 {
582 #if IS_ENABLED(CONFIG_IPV6)
583 	if (sk->sk_family == AF_INET6) {
584 		entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
585 		entry->daddr = sk->sk_v6_daddr.s6_addr32;
586 	} else
587 #endif
588 	{
589 		entry->saddr = &sk->sk_rcv_saddr;
590 		entry->daddr = &sk->sk_daddr;
591 	}
592 }
593 
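/* Evaluate the filter bytecode @bc against @sk.  Works for full,
 * request (TCP_NEW_SYN_RECV) and timewait sockets, picking the mark
 * and other fields from whichever socket variant carries them.
 */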
594 int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
595 {
596 	struct inet_sock *inet = inet_sk(sk);
597 	struct inet_diag_entry entry;
598 
599 	if (!bc)
600 		return 1;
601 
602 	entry.family = sk->sk_family;
603 	entry_fill_addrs(&entry, sk);
604 	entry.sport = inet->inet_num;
605 	entry.dport = ntohs(inet->inet_dport);
606 	entry.ifindex = sk->sk_bound_dev_if;
607 	entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
608 	if (sk_fullsock(sk))
609 		entry.mark = READ_ONCE(sk->sk_mark);
610 	else if (sk->sk_state == TCP_NEW_SYN_RECV)
611 		entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
612 	else if (sk->sk_state == TCP_TIME_WAIT)
613 		entry.mark = inet_twsk(sk)->tw_mark;
614 	else
615 		entry.mark = 0;
616 #ifdef CONFIG_SOCK_CGROUP_DATA
617 	entry.cgroup_id = sk_fullsock(sk) ?
618 		cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)) : 0;
619 #endif
620 
621 	return inet_diag_bc_run(bc, &entry);
622 }
623 EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
624 
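/* Check that a conditional jump target with @cc bytes remaining to the
 * end of the program lands on an instruction boundary, i.e. on the
 * start of an op reachable by following yes offsets from the beginning.
 */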
625 static int valid_cc(const void *bc, int len, int cc)
626 {
627 	while (len >= 0) {
628 		const struct inet_diag_bc_op *op = bc;
629 
630 		if (cc > len)
631 			return 0;
632 		if (cc == len)
633 			return 1;
634 		if (op->yes < 4 || op->yes & 3)
635 			return 0;
636 		len -= op->yes;
637 		bc  += op->yes;
638 	}
639 	return 0;
640 }
641 
642 /* data is u32 ifindex */
643 static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
644 			  int *min_len)
645 {
646 	/* Check ifindex space. */
647 	*min_len += sizeof(u32);
648 	if (len < *min_len)
649 		return false;
650 
651 	return true;
652 }
653 /* Validate an inet_diag_hostcond. */
654 static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
655 			   int *min_len)
656 {
657 	struct inet_diag_hostcond *cond;
658 	int addr_len;
659 
660 	/* Check hostcond space. */
661 	*min_len += sizeof(struct inet_diag_hostcond);
662 	if (len < *min_len)
663 		return false;
664 	cond = (struct inet_diag_hostcond *)(op + 1);
665 
666 	/* Check address family and address length. */
667 	switch (cond->family) {
668 	case AF_UNSPEC:
669 		addr_len = 0;
670 		break;
671 	case AF_INET:
672 		addr_len = sizeof(struct in_addr);
673 		break;
674 	case AF_INET6:
675 		addr_len = sizeof(struct in6_addr);
676 		break;
677 	default:
678 		return false;
679 	}
680 	*min_len += addr_len;
681 	if (len < *min_len)
682 		return false;
683 
684 	/* Check prefix length (in bits) vs address length (in bytes). */
685 	if (cond->prefix_len > 8 * addr_len)
686 		return false;
687 
688 	return true;
689 }
690 
691 /* Validate a port comparison operator. */
692 static bool valid_port_comparison(const struct inet_diag_bc_op *op,
693 				  int len, int *min_len)
694 {
695 	/* Port comparisons put the port in a follow-on inet_diag_bc_op. */
696 	*min_len += sizeof(struct inet_diag_bc_op);
697 	if (len < *min_len)
698 		return false;
699 	return true;
700 }
701 
702 static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
703 			   int *min_len)
704 {
705 	*min_len += sizeof(struct inet_diag_markcond);
706 	return len >= *min_len;
707 }
708 
709 #ifdef CONFIG_SOCK_CGROUP_DATA
710 static bool valid_cgroupcond(const struct inet_diag_bc_op *op, int len,
711 			     int *min_len)
712 {
713 	*min_len += sizeof(u64);
714 	return len >= *min_len;
715 }
716 #endif
717 
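/* Validate filter bytecode received from userspace before running it:
 * every opcode must be known, its operands must fit in the remaining
 * buffer, both jump offsets must be 4-byte aligned and bounded, and
 * mark conditions require CAP_NET_ADMIN.
 */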
718 static int inet_diag_bc_audit(const struct nlattr *attr,
719 			      const struct sk_buff *skb)
720 {
721 	bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
722 	const void *bytecode, *bc;
723 	int bytecode_len, len;
724 
725 	if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
726 		return -EINVAL;
727 
728 	bytecode = bc = nla_data(attr);
729 	len = bytecode_len = nla_len(attr);
730 
731 	while (len > 0) {
732 		int min_len = sizeof(struct inet_diag_bc_op);
733 		const struct inet_diag_bc_op *op = bc;
734 
735 		switch (op->code) {
736 		case INET_DIAG_BC_S_COND:
737 		case INET_DIAG_BC_D_COND:
738 			if (!valid_hostcond(bc, len, &min_len))
739 				return -EINVAL;
740 			break;
741 		case INET_DIAG_BC_DEV_COND:
742 			if (!valid_devcond(bc, len, &min_len))
743 				return -EINVAL;
744 			break;
745 		case INET_DIAG_BC_S_EQ:
746 		case INET_DIAG_BC_S_GE:
747 		case INET_DIAG_BC_S_LE:
748 		case INET_DIAG_BC_D_EQ:
749 		case INET_DIAG_BC_D_GE:
750 		case INET_DIAG_BC_D_LE:
751 			if (!valid_port_comparison(bc, len, &min_len))
752 				return -EINVAL;
753 			break;
754 		case INET_DIAG_BC_MARK_COND:
755 			if (!net_admin)
756 				return -EPERM;
757 			if (!valid_markcond(bc, len, &min_len))
758 				return -EINVAL;
759 			break;
760 #ifdef CONFIG_SOCK_CGROUP_DATA
761 		case INET_DIAG_BC_CGROUP_COND:
762 			if (!valid_cgroupcond(bc, len, &min_len))
763 				return -EINVAL;
764 			break;
765 #endif
766 		case INET_DIAG_BC_AUTO:
767 		case INET_DIAG_BC_JMP:
768 		case INET_DIAG_BC_NOP:
769 			break;
770 		default:
771 			return -EINVAL;
772 		}
773 
774 		if (op->code != INET_DIAG_BC_NOP) {
775 			if (op->no < min_len || op->no > len + 4 || op->no & 3)
776 				return -EINVAL;
777 			if (op->no < len &&
778 			    !valid_cc(bytecode, bytecode_len, len - op->no))
779 				return -EINVAL;
780 		}
781 
782 		if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
783 			return -EINVAL;
784 		bc  += op->yes;
785 		len -= op->yes;
786 	}
787 	return len == 0 ? 0 : -EINVAL;
788 }
789 
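/* Core dump loop: let the protocol handler fill the skb.  If nothing
 * fit because inet_sk_diag_fill() asked for a larger allocation, grow
 * the skb and retry.
 */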
790 static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
791 			    const struct inet_diag_req_v2 *r)
792 {
793 	struct inet_diag_dump_data *cb_data = cb->data;
794 	const struct inet_diag_handler *handler;
795 	u32 prev_min_dump_alloc;
796 	int protocol, err = 0;
797 
798 	protocol = inet_diag_get_protocol(r, cb_data);
799 
800 again:
801 	prev_min_dump_alloc = cb->min_dump_alloc;
802 	handler = inet_diag_lock_handler(protocol);
803 	if (handler) {
804 		handler->dump(skb, cb, r);
805 		inet_diag_unlock_handler(handler);
806 	} else {
807 		err = -ENOENT;
808 	}
809 	/* The skb is not large enough to fit even one sk's info and
810 	 * inet_sk_diag_fill() has requested a larger skb.
811 	 */
812 	if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) {
813 		err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL);
814 		if (!err)
815 			goto again;
816 	}
817 
818 	return err ? : skb->len;
819 }
820 
821 static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
822 {
823 	return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh));
824 }
825 
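/* Per-dump setup: parse the request attributes once, audit any filter
 * bytecode and allocate the bpf_sk_storage diag state if requested.
 * The result is stashed in cb->data and freed by inet_diag_dump_done().
 */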
826 static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen)
827 {
828 	const struct nlmsghdr *nlh = cb->nlh;
829 	struct inet_diag_dump_data *cb_data;
830 	struct sk_buff *skb = cb->skb;
831 	struct nlattr *nla;
832 	int err;
833 
834 	cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL);
835 	if (!cb_data)
836 		return -ENOMEM;
837 
838 	err = inet_diag_parse_attrs(nlh, hdrlen, cb_data->req_nlas);
839 	if (err) {
840 		kfree(cb_data);
841 		return err;
842 	}
843 	nla = cb_data->inet_diag_nla_bc;
844 	if (nla) {
845 		err = inet_diag_bc_audit(nla, skb);
846 		if (err) {
847 			kfree(cb_data);
848 			return err;
849 		}
850 	}
851 
852 	nla = cb_data->inet_diag_nla_bpf_stgs;
853 	if (nla) {
854 		struct bpf_sk_storage_diag *bpf_stg_diag;
855 
856 		bpf_stg_diag = bpf_sk_storage_diag_alloc(nla);
857 		if (IS_ERR(bpf_stg_diag)) {
858 			kfree(cb_data);
859 			return PTR_ERR(bpf_stg_diag);
860 		}
861 		cb_data->bpf_stg_diag = bpf_stg_diag;
862 	}
863 
864 	cb->data = cb_data;
865 	return 0;
866 }
867 
868 static int inet_diag_dump_start(struct netlink_callback *cb)
869 {
870 	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req_v2));
871 }
872 
873 static int inet_diag_dump_start_compat(struct netlink_callback *cb)
874 {
875 	return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req));
876 }
877 
878 static int inet_diag_dump_done(struct netlink_callback *cb)
879 {
880 	struct inet_diag_dump_data *cb_data = cb->data;
881 
882 	bpf_sk_storage_diag_free(cb_data->bpf_stg_diag);
883 	kfree(cb->data);
884 
885 	return 0;
886 }
887 
888 static int inet_diag_type2proto(int type)
889 {
890 	switch (type) {
891 	case TCPDIAG_GETSOCK:
892 		return IPPROTO_TCP;
893 	default:
894 		return 0;
895 	}
896 }
897 
898 static int inet_diag_dump_compat(struct sk_buff *skb,
899 				 struct netlink_callback *cb)
900 {
901 	struct inet_diag_req *rc = nlmsg_data(cb->nlh);
902 	struct inet_diag_req_v2 req;
903 
904 	req.sdiag_family = AF_UNSPEC; /* compatibility */
905 	req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
906 	req.idiag_ext = rc->idiag_ext;
907 	req.pad = 0;
908 	req.idiag_states = rc->idiag_states;
909 	req.id = rc->id;
910 
911 	return __inet_diag_dump(skb, cb, &req);
912 }
913 
914 static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
915 				      const struct nlmsghdr *nlh)
916 {
917 	struct inet_diag_req *rc = nlmsg_data(nlh);
918 	struct inet_diag_req_v2 req;
919 
920 	req.sdiag_family = rc->idiag_family;
921 	req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
922 	req.idiag_ext = rc->idiag_ext;
923 	req.pad = 0;
924 	req.idiag_states = rc->idiag_states;
925 	req.id = rc->id;
926 
927 	return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh,
928 				   sizeof(struct inet_diag_req), &req);
929 }
930 
931 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
932 {
933 	int hdrlen = sizeof(struct inet_diag_req);
934 	struct net *net = sock_net(skb->sk);
935 
936 	if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
937 	    nlmsg_len(nlh) < hdrlen)
938 		return -EINVAL;
939 
940 	if (nlh->nlmsg_flags & NLM_F_DUMP) {
941 		struct netlink_dump_control c = {
942 			.start = inet_diag_dump_start_compat,
943 			.done = inet_diag_dump_done,
944 			.dump = inet_diag_dump_compat,
945 		};
946 		return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
947 	}
948 
949 	return inet_diag_get_exact_compat(skb, nlh);
950 }
951 
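/* Entry point for SOCK_DIAG_BY_FAMILY and SOCK_DESTROY requests: dump
 * requests (NLM_F_DUMP) start a netlink dump, everything else is
 * handled as a single-socket request via inet_diag_cmd_exact().
 */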
952 static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
953 {
954 	int hdrlen = sizeof(struct inet_diag_req_v2);
955 	struct net *net = sock_net(skb->sk);
956 
957 	if (nlmsg_len(h) < hdrlen)
958 		return -EINVAL;
959 
960 	if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
961 	    h->nlmsg_flags & NLM_F_DUMP) {
962 		struct netlink_dump_control c = {
963 			.start = inet_diag_dump_start,
964 			.done = inet_diag_dump_done,
965 			.dump = inet_diag_dump,
966 		};
967 		return netlink_dump_start(net->diag_nlsk, skb, h, &c);
968 	}
969 
970 	return inet_diag_cmd_exact(h->nlmsg_type, skb, h, hdrlen,
971 				   nlmsg_data(h));
972 }
973 
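/* ->get_info callback of the sock_diag handler: fill a minimal
 * SOCK_DIAG_BY_FAMILY message (common fields, protocol and
 * INET_DIAG_INFO) for a single socket.
 */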
974 static
975 int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
976 {
977 	const struct inet_diag_handler *handler;
978 	struct nlmsghdr *nlh;
979 	struct nlattr *attr;
980 	struct inet_diag_msg *r;
981 	void *info = NULL;
982 	int err = 0;
983 
984 	nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0);
985 	if (!nlh)
986 		return -ENOMEM;
987 
988 	r = nlmsg_data(nlh);
989 	memset(r, 0, sizeof(*r));
990 	inet_diag_msg_common_fill(r, sk);
991 	if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM)
992 		r->id.idiag_sport = inet_sk(sk)->inet_sport;
993 	r->idiag_state = sk->sk_state;
994 
995 	if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) {
996 		nlmsg_cancel(skb, nlh);
997 		return err;
998 	}
999 
1000 	handler = inet_diag_lock_handler(sk->sk_protocol);
1001 	if (!handler) {
1002 		nlmsg_cancel(skb, nlh);
1003 		return -ENOENT;
1004 	}
1005 
1006 	attr = handler->idiag_info_size
1007 		? nla_reserve_64bit(skb, INET_DIAG_INFO,
1008 				    handler->idiag_info_size,
1009 				    INET_DIAG_PAD)
1010 		: NULL;
1011 	if (attr)
1012 		info = nla_data(attr);
1013 
1014 	handler->idiag_get_info(sk, r, info);
1015 	inet_diag_unlock_handler(handler);
1016 
1017 	nlmsg_end(skb, nlh);
1018 	return 0;
1019 }
1020 
1021 static const struct sock_diag_handler inet_diag_handler = {
1022 	.owner = THIS_MODULE,
1023 	.family = AF_INET,
1024 	.dump = inet_diag_handler_cmd,
1025 	.get_info = inet_diag_handler_get_info,
1026 	.destroy = inet_diag_handler_cmd,
1027 };
1028 
1029 static const struct sock_diag_handler inet6_diag_handler = {
1030 	.owner = THIS_MODULE,
1031 	.family = AF_INET6,
1032 	.dump = inet_diag_handler_cmd,
1033 	.get_info = inet_diag_handler_get_info,
1034 	.destroy = inet_diag_handler_cmd,
1035 };
1036 
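/* Register a protocol diag handler for its IPPROTO value.  Called by
 * the protocol diag modules at load time.
 */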
1037 int inet_diag_register(const struct inet_diag_handler *h)
1038 {
1039 	const __u16 type = h->idiag_type;
1040 
1041 	if (type >= IPPROTO_MAX)
1042 		return -EINVAL;
1043 
1044 	return !cmpxchg((const struct inet_diag_handler **)&inet_diag_table[type],
1045 			NULL, h) ? 0 : -EEXIST;
1046 }
1047 EXPORT_SYMBOL_GPL(inet_diag_register);
1048 
1049 void inet_diag_unregister(const struct inet_diag_handler *h)
1050 {
1051 	const __u16 type = h->idiag_type;
1052 
1053 	if (type >= IPPROTO_MAX)
1054 		return;
1055 
1056 	xchg((const struct inet_diag_handler **)&inet_diag_table[type],
1057 	     NULL);
1058 }
1059 EXPORT_SYMBOL_GPL(inet_diag_unregister);
1060 
1061 static const struct sock_diag_inet_compat inet_diag_compat = {
1062 	.owner	= THIS_MODULE,
1063 	.fn	= inet_diag_rcv_msg_compat,
1064 };
1065 
1066 static int __init inet_diag_init(void)
1067 {
1068 	const int inet_diag_table_size = (IPPROTO_MAX *
1069 					  sizeof(struct inet_diag_handler *));
1070 	int err = -ENOMEM;
1071 
1072 	inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1073 	if (!inet_diag_table)
1074 		goto out;
1075 
1076 	err = sock_diag_register(&inet_diag_handler);
1077 	if (err)
1078 		goto out_free_nl;
1079 
1080 	err = sock_diag_register(&inet6_diag_handler);
1081 	if (err)
1082 		goto out_free_inet;
1083 
1084 	sock_diag_register_inet_compat(&inet_diag_compat);
1085 out:
1086 	return err;
1087 
1088 out_free_inet:
1089 	sock_diag_unregister(&inet_diag_handler);
1090 out_free_nl:
1091 	kfree(inet_diag_table);
1092 	goto out;
1093 }
1094 
1095 static void __exit inet_diag_exit(void)
1096 {
1097 	sock_diag_unregister(&inet6_diag_handler);
1098 	sock_diag_unregister(&inet_diag_handler);
1099 	sock_diag_unregister_inet_compat(&inet_diag_compat);
1100 	kfree(inet_diag_table);
1101 }
1102 
1103 module_init(inet_diag_init);
1104 module_exit(inet_diag_exit);
1105 MODULE_LICENSE("GPL");
1106 MODULE_DESCRIPTION("INET/INET6: socket monitoring via SOCK_DIAG");
1107 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
1108 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);
1109