1 // SPDX-License-Identifier: GPL-2.0 2 #include <linux/filter.h> 3 #include <linux/mutex.h> 4 #include <linux/socket.h> 5 #include <linux/skbuff.h> 6 #include <net/netlink.h> 7 #include <net/net_namespace.h> 8 #include <linux/module.h> 9 #include <net/sock.h> 10 #include <linux/kernel.h> 11 #include <linux/tcp.h> 12 #include <linux/workqueue.h> 13 #include <linux/nospec.h> 14 #include <linux/cookie.h> 15 #include <linux/inet_diag.h> 16 #include <linux/sock_diag.h> 17 18 static const struct sock_diag_handler __rcu *sock_diag_handlers[AF_MAX]; 19 20 static const struct sock_diag_inet_compat __rcu *inet_rcv_compat; 21 22 static struct workqueue_struct *broadcast_wq; 23 24 DEFINE_COOKIE(sock_cookie); 25 26 u64 __sock_gen_cookie(struct sock *sk) 27 { 28 u64 res = atomic64_read(&sk->sk_cookie); 29 30 if (!res) { 31 u64 new = gen_cookie_next(&sock_cookie); 32 33 atomic64_cmpxchg(&sk->sk_cookie, res, new); 34 35 /* Another thread might have changed sk_cookie before us. */ 36 res = atomic64_read(&sk->sk_cookie); 37 } 38 return res; 39 } 40 41 int sock_diag_check_cookie(struct sock *sk, const __u32 *cookie) 42 { 43 u64 res; 44 45 if (cookie[0] == INET_DIAG_NOCOOKIE && cookie[1] == INET_DIAG_NOCOOKIE) 46 return 0; 47 48 res = sock_gen_cookie(sk); 49 if ((u32)res != cookie[0] || (u32)(res >> 32) != cookie[1]) 50 return -ESTALE; 51 52 return 0; 53 } 54 EXPORT_SYMBOL_GPL(sock_diag_check_cookie); 55 56 void sock_diag_save_cookie(struct sock *sk, __u32 *cookie) 57 { 58 u64 res = sock_gen_cookie(sk); 59 60 cookie[0] = (u32)res; 61 cookie[1] = (u32)(res >> 32); 62 } 63 EXPORT_SYMBOL_GPL(sock_diag_save_cookie); 64 65 int sock_diag_put_meminfo(struct sock *sk, struct sk_buff *skb, int attrtype) 66 { 67 u32 mem[SK_MEMINFO_VARS]; 68 69 sk_get_meminfo(sk, mem); 70 71 return nla_put(skb, attrtype, sizeof(mem), &mem); 72 } 73 EXPORT_SYMBOL_GPL(sock_diag_put_meminfo); 74 75 int sock_diag_put_filterinfo(bool may_report_filterinfo, struct sock *sk, 76 struct sk_buff *skb, int attrtype) 77 { 78 struct sock_fprog_kern *fprog; 79 struct sk_filter *filter; 80 struct nlattr *attr; 81 unsigned int flen; 82 int err = 0; 83 84 if (!may_report_filterinfo) { 85 nla_reserve(skb, attrtype, 0); 86 return 0; 87 } 88 89 rcu_read_lock(); 90 filter = rcu_dereference(sk->sk_filter); 91 if (!filter) 92 goto out; 93 94 fprog = filter->prog->orig_prog; 95 if (!fprog) 96 goto out; 97 98 flen = bpf_classic_proglen(fprog); 99 100 attr = nla_reserve(skb, attrtype, flen); 101 if (attr == NULL) { 102 err = -EMSGSIZE; 103 goto out; 104 } 105 106 memcpy(nla_data(attr), fprog->filter, flen); 107 out: 108 rcu_read_unlock(); 109 return err; 110 } 111 EXPORT_SYMBOL(sock_diag_put_filterinfo); 112 113 struct broadcast_sk { 114 struct sock *sk; 115 struct work_struct work; 116 }; 117 118 static size_t sock_diag_nlmsg_size(void) 119 { 120 return NLMSG_ALIGN(sizeof(struct inet_diag_msg) 121 + nla_total_size(sizeof(u8)) /* INET_DIAG_PROTOCOL */ 122 + nla_total_size_64bit(sizeof(struct tcp_info))); /* INET_DIAG_INFO */ 123 } 124 125 static const struct sock_diag_handler *sock_diag_lock_handler(int family) 126 { 127 const struct sock_diag_handler *handler; 128 129 rcu_read_lock(); 130 handler = rcu_dereference(sock_diag_handlers[family]); 131 if (handler && !try_module_get(handler->owner)) 132 handler = NULL; 133 rcu_read_unlock(); 134 135 return handler; 136 } 137 138 static void sock_diag_unlock_handler(const struct sock_diag_handler *handler) 139 { 140 module_put(handler->owner); 141 } 142 143 static void sock_diag_broadcast_destroy_work(struct work_struct *work) 144 { 145 struct broadcast_sk *bsk = 146 container_of(work, struct broadcast_sk, work); 147 struct sock *sk = bsk->sk; 148 const struct sock_diag_handler *hndl; 149 struct sk_buff *skb; 150 const enum sknetlink_groups group = sock_diag_destroy_group(sk); 151 int err = -1; 152 153 WARN_ON(group == SKNLGRP_NONE); 154 155 skb = nlmsg_new(sock_diag_nlmsg_size(), GFP_KERNEL); 156 if (!skb) 157 goto out; 158 159 hndl = sock_diag_lock_handler(sk->sk_family); 160 if (hndl) { 161 if (hndl->get_info) 162 err = hndl->get_info(skb, sk); 163 sock_diag_unlock_handler(hndl); 164 } 165 if (!err) 166 nlmsg_multicast(sock_net(sk)->diag_nlsk, skb, 0, group, 167 GFP_KERNEL); 168 else 169 kfree_skb(skb); 170 out: 171 sk_destruct(sk); 172 kfree(bsk); 173 } 174 175 void sock_diag_broadcast_destroy(struct sock *sk) 176 { 177 /* Note, this function is often called from an interrupt context. */ 178 struct broadcast_sk *bsk = 179 kmalloc_obj(struct broadcast_sk, GFP_ATOMIC); 180 if (!bsk) 181 return sk_destruct(sk); 182 bsk->sk = sk; 183 INIT_WORK(&bsk->work, sock_diag_broadcast_destroy_work); 184 queue_work(broadcast_wq, &bsk->work); 185 } 186 187 void sock_diag_register_inet_compat(const struct sock_diag_inet_compat *ptr) 188 { 189 xchg(&inet_rcv_compat, RCU_INITIALIZER(ptr)); 190 } 191 EXPORT_SYMBOL_GPL(sock_diag_register_inet_compat); 192 193 void sock_diag_unregister_inet_compat(const struct sock_diag_inet_compat *ptr) 194 { 195 const struct sock_diag_inet_compat *old; 196 197 old = unrcu_pointer(xchg(&inet_rcv_compat, NULL)); 198 WARN_ON_ONCE(old != ptr); 199 } 200 EXPORT_SYMBOL_GPL(sock_diag_unregister_inet_compat); 201 202 int sock_diag_register(const struct sock_diag_handler *hndl) 203 { 204 int family = hndl->family; 205 206 if (family >= AF_MAX) 207 return -EINVAL; 208 209 return !cmpxchg((const struct sock_diag_handler **) 210 &sock_diag_handlers[family], 211 NULL, hndl) ? 0 : -EBUSY; 212 } 213 EXPORT_SYMBOL_GPL(sock_diag_register); 214 215 void sock_diag_unregister(const struct sock_diag_handler *hndl) 216 { 217 int family = hndl->family; 218 219 if (family >= AF_MAX) 220 return; 221 222 xchg((const struct sock_diag_handler **)&sock_diag_handlers[family], 223 NULL); 224 } 225 EXPORT_SYMBOL_GPL(sock_diag_unregister); 226 227 static int __sock_diag_cmd(struct sk_buff *skb, struct nlmsghdr *nlh) 228 { 229 int err; 230 struct sock_diag_req *req = nlmsg_data(nlh); 231 const struct sock_diag_handler *hndl; 232 233 if (nlmsg_len(nlh) < sizeof(*req)) 234 return -EINVAL; 235 236 if (req->sdiag_family >= AF_MAX) 237 return -EINVAL; 238 req->sdiag_family = array_index_nospec(req->sdiag_family, AF_MAX); 239 240 if (!rcu_access_pointer(sock_diag_handlers[req->sdiag_family])) 241 sock_load_diag_module(req->sdiag_family, 0); 242 243 hndl = sock_diag_lock_handler(req->sdiag_family); 244 if (hndl == NULL) 245 return -ENOENT; 246 247 if (nlh->nlmsg_type == SOCK_DIAG_BY_FAMILY) 248 err = hndl->dump(skb, nlh); 249 else if (nlh->nlmsg_type == SOCK_DESTROY && hndl->destroy) 250 err = hndl->destroy(skb, nlh); 251 else 252 err = -EOPNOTSUPP; 253 sock_diag_unlock_handler(hndl); 254 255 return err; 256 } 257 258 static int sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, 259 struct netlink_ext_ack *extack) 260 { 261 const struct sock_diag_inet_compat *ptr; 262 int ret; 263 264 switch (nlh->nlmsg_type) { 265 case TCPDIAG_GETSOCK: 266 if (!rcu_access_pointer(inet_rcv_compat)) 267 sock_load_diag_module(AF_INET, 0); 268 269 rcu_read_lock(); 270 ptr = rcu_dereference(inet_rcv_compat); 271 if (ptr && !try_module_get(ptr->owner)) 272 ptr = NULL; 273 rcu_read_unlock(); 274 275 ret = -EOPNOTSUPP; 276 if (ptr) { 277 ret = ptr->fn(skb, nlh); 278 module_put(ptr->owner); 279 } 280 281 return ret; 282 case SOCK_DIAG_BY_FAMILY: 283 case SOCK_DESTROY: 284 return __sock_diag_cmd(skb, nlh); 285 default: 286 return -EINVAL; 287 } 288 } 289 290 static void sock_diag_rcv(struct sk_buff *skb) 291 { 292 netlink_rcv_skb(skb, &sock_diag_rcv_msg); 293 } 294 295 static int sock_diag_bind(struct net *net, int group) 296 { 297 switch (group) { 298 case SKNLGRP_INET_TCP_DESTROY: 299 case SKNLGRP_INET_UDP_DESTROY: 300 if (!rcu_access_pointer(sock_diag_handlers[AF_INET])) 301 sock_load_diag_module(AF_INET, 0); 302 break; 303 case SKNLGRP_INET6_TCP_DESTROY: 304 case SKNLGRP_INET6_UDP_DESTROY: 305 if (!rcu_access_pointer(sock_diag_handlers[AF_INET6])) 306 sock_load_diag_module(AF_INET6, 0); 307 break; 308 } 309 return 0; 310 } 311 312 int sock_diag_destroy(struct sock *sk, int err) 313 { 314 if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) 315 return -EPERM; 316 317 if (!sk->sk_prot->diag_destroy) 318 return -EOPNOTSUPP; 319 320 return sk->sk_prot->diag_destroy(sk, err); 321 } 322 EXPORT_SYMBOL_GPL(sock_diag_destroy); 323 324 static int __net_init diag_net_init(struct net *net) 325 { 326 struct netlink_kernel_cfg cfg = { 327 .groups = SKNLGRP_MAX, 328 .input = sock_diag_rcv, 329 .bind = sock_diag_bind, 330 .flags = NL_CFG_F_NONROOT_RECV, 331 }; 332 333 net->diag_nlsk = netlink_kernel_create(net, NETLINK_SOCK_DIAG, &cfg); 334 return net->diag_nlsk == NULL ? -ENOMEM : 0; 335 } 336 337 static void __net_exit diag_net_exit(struct net *net) 338 { 339 netlink_kernel_release(net->diag_nlsk); 340 net->diag_nlsk = NULL; 341 } 342 343 static struct pernet_operations diag_net_ops = { 344 .init = diag_net_init, 345 .exit = diag_net_exit, 346 }; 347 348 static int __init sock_diag_init(void) 349 { 350 broadcast_wq = alloc_workqueue("sock_diag_events", WQ_PERCPU, 0); 351 BUG_ON(!broadcast_wq); 352 return register_pernet_subsys(&diag_net_ops); 353 } 354 device_initcall(sock_diag_init); 355