1 // SPDX-License-Identifier: GPL-2.0
2 /*
3 * Management Component Transport Protocol (MCTP)
4 *
5 * Copyright (c) 2021 Code Construct
6 * Copyright (c) 2021 Google
7 */
8
9 #include <linux/compat.h>
10 #include <linux/if_arp.h>
11 #include <linux/net.h>
12 #include <linux/mctp.h>
13 #include <linux/module.h>
14 #include <linux/socket.h>
15
16 #include <net/mctp.h>
17 #include <net/mctpdevice.h>
18 #include <net/sock.h>
19
20 #define CREATE_TRACE_POINTS
21 #include <trace/events/mctp.h>
22
23 /* socket implementation */
24
/* Per-socket key expiry timer callback; installed by mctp_sk_init(). */
static void mctp_sk_expire_keys(struct timer_list *timer);
26
mctp_release(struct socket * sock)27 static int mctp_release(struct socket *sock)
28 {
29 struct sock *sk = sock->sk;
30
31 if (sk) {
32 sock->sk = NULL;
33 sk->sk_prot->close(sk, 0);
34 }
35
36 return 0;
37 }
38
39 /* Generic sockaddr checks, padding checks only so far */
mctp_sockaddr_is_ok(const struct sockaddr_mctp * addr)40 static bool mctp_sockaddr_is_ok(const struct sockaddr_mctp *addr)
41 {
42 return !addr->__smctp_pad0 && !addr->__smctp_pad1;
43 }
44
mctp_sockaddr_ext_is_ok(const struct sockaddr_mctp_ext * addr)45 static bool mctp_sockaddr_ext_is_ok(const struct sockaddr_mctp_ext *addr)
46 {
47 return !addr->__smctp_pad0[0] &&
48 !addr->__smctp_pad0[1] &&
49 !addr->__smctp_pad0[2];
50 }
51
/* proto_ops->bind: bind the socket to a local EID, network and message
 * type, then insert it into the bind table through proto->hash
 * (mctp_sk_hash()).
 *
 * Requires CAP_NET_BIND_SERVICE. If connect() has already recorded a
 * peer, the bind type must equal the connect() type, and the networks
 * must agree; a bind to MCTP_NET_ANY inherits the connect() network.
 *
 * Returns 0 on success, -EINVAL for a short or malformed address or a
 * mismatch with a prior connect(), -EAFNOSUPPORT for a non-MCTP family,
 * -EACCES without the capability, -EADDRINUSE if already bound, or the
 * error returned by proto->hash.
 *
 * NOTE(review): bind_local_addr/bind_net (and bind_type when no peer is
 * set) are written before the connect()-mismatch and hash failure paths,
 * so a failed bind leaves them modified; they are overwritten by the next
 * bind() attempt.
 */
static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
{
	struct sock *sk = sock->sk;
	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
	struct net *net = sock_net(&msk->sk);
	struct sockaddr_mctp *smctp;
	int rc;

	if (addrlen < sizeof(*smctp))
		return -EINVAL;

	if (addr->sa_family != AF_MCTP)
		return -EAFNOSUPPORT;

	if (!capable(CAP_NET_BIND_SERVICE))
		return -EACCES;

	/* it's a valid sockaddr for MCTP, cast and do protocol checks */
	smctp = (struct sockaddr_mctp *)addr;

	if (!mctp_sockaddr_is_ok(smctp))
		return -EINVAL;

	lock_sock(sk);

	/* a socket can only be bound once */
	if (sk_hashed(sk)) {
		rc = -EADDRINUSE;
		goto out_release;
	}

	msk->bind_local_addr = smctp->smctp_addr.s_addr;

	/* MCTP_NET_ANY with a specific EID is resolved to the default net
	 * at bind() time.
	 * For bind_addr=MCTP_ADDR_ANY it is handled specially at route
	 * lookup time.
	 */
	if (smctp->smctp_network == MCTP_NET_ANY &&
	    msk->bind_local_addr != MCTP_ADDR_ANY) {
		msk->bind_net = mctp_default_net(net);
	} else {
		msk->bind_net = smctp->smctp_network;
	}

	/* ignore the IC bit */
	smctp->smctp_type &= 0x7f;

	if (msk->bind_peer_set) {
		if (msk->bind_type != smctp->smctp_type) {
			/* Prior connect() had a different type */
			rc = -EINVAL;
			goto out_release;
		}

		if (msk->bind_net == MCTP_NET_ANY) {
			/* Restrict to the network passed to connect() */
			msk->bind_net = msk->bind_peer_net;
		}

		if (msk->bind_net != msk->bind_peer_net) {
			/* connect() had a different net to bind() */
			rc = -EINVAL;
			goto out_release;
		}
	} else {
		msk->bind_type = smctp->smctp_type;
	}

	/* insert into the bind table; fails on a duplicate bind */
	rc = sk->sk_prot->hash(sk);

out_release:
	release_sock(sk);

	return rc;
}
127
128 /* Used to set a specific peer prior to bind. Not used for outbound
129 * connections (Tag Owner set) since MCTP is a datagram protocol.
130 */
mctp_connect(struct socket * sock,struct sockaddr * addr,int addrlen,int flags)131 static int mctp_connect(struct socket *sock, struct sockaddr *addr,
132 int addrlen, int flags)
133 {
134 struct sock *sk = sock->sk;
135 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
136 struct net *net = sock_net(&msk->sk);
137 struct sockaddr_mctp *smctp;
138 int rc;
139
140 if (addrlen != sizeof(*smctp))
141 return -EINVAL;
142
143 if (addr->sa_family != AF_MCTP)
144 return -EAFNOSUPPORT;
145
146 /* It's a valid sockaddr for MCTP, cast and do protocol checks */
147 smctp = (struct sockaddr_mctp *)addr;
148
149 if (!mctp_sockaddr_is_ok(smctp))
150 return -EINVAL;
151
152 /* Can't bind by tag */
153 if (smctp->smctp_tag)
154 return -EINVAL;
155
156 /* IC bit must be unset */
157 if (smctp->smctp_type & 0x80)
158 return -EINVAL;
159
160 lock_sock(sk);
161
162 if (sk_hashed(sk)) {
163 /* bind() already */
164 rc = -EADDRINUSE;
165 goto out_release;
166 }
167
168 if (msk->bind_peer_set) {
169 /* connect() already */
170 rc = -EADDRINUSE;
171 goto out_release;
172 }
173
174 msk->bind_peer_set = true;
175 msk->bind_peer_addr = smctp->smctp_addr.s_addr;
176 msk->bind_type = smctp->smctp_type;
177 if (smctp->smctp_network == MCTP_NET_ANY)
178 msk->bind_peer_net = mctp_default_net(net);
179 else
180 msk->bind_peer_net = smctp->smctp_network;
181
182 rc = 0;
183
184 out_release:
185 release_sock(sk);
186 return rc;
187 }
188
/* proto_ops->sendmsg: build and transmit one MCTP message.
 *
 * A destination address in msg_name is mandatory. Sending without one
 * (the connect()ed case) is not implemented and gives -EDESTADDRREQ.
 * smctp_tag may contain only the tag value, MCTP_TAG_OWNER and
 * MCTP_TAG_PREALLOC, and MCTP_TAG_PREALLOC requires MCTP_TAG_OWNER.
 * The caller needs CAP_NET_RAW.
 *
 * If MCTP_OPT_ADDR_EXT is enabled and a full struct sockaddr_mctp_ext was
 * supplied, the destination is built from its ifindex and hardware
 * address (direct addressing). Otherwise it comes from a route lookup on
 * (network, EID).
 *
 * The payload is prefixed with one byte holding smctp_type.
 * Returns @len on success, otherwise an error code.
 */
static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
{
	DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
	int rc, addrlen = msg->msg_namelen;
	struct sock *sk = sock->sk;
	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
	struct mctp_skb_cb *cb;
	struct sk_buff *skb = NULL;
	struct mctp_dst dst;
	int hlen;

	if (addr) {
		const u8 tagbits = MCTP_TAG_MASK | MCTP_TAG_OWNER |
			MCTP_TAG_PREALLOC;

		if (addrlen < sizeof(struct sockaddr_mctp))
			return -EINVAL;
		if (addr->smctp_family != AF_MCTP)
			return -EINVAL;
		if (!mctp_sockaddr_is_ok(addr))
			return -EINVAL;
		/* only the tag value, owner and prealloc bits are allowed */
		if (addr->smctp_tag & ~tagbits)
			return -EINVAL;
		/* can't preallocate a non-owned tag */
		if (addr->smctp_tag & MCTP_TAG_PREALLOC &&
		    !(addr->smctp_tag & MCTP_TAG_OWNER))
			return -EINVAL;

	} else {
		/* TODO: connect()ed sockets */
		return -EDESTADDRREQ;
	}

	if (!capable(CAP_NET_RAW))
		return -EACCES;

	/* Resolve the wildcard network. This writes the resolved value back
	 * into the msg_name buffer; it is read again below for cb->net.
	 */
	if (addr->smctp_network == MCTP_NET_ANY)
		addr->smctp_network = mctp_default_net(sock_net(sk));

	/* direct addressing */
	if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
		DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
				 extaddr, msg->msg_name);

		if (!mctp_sockaddr_ext_is_ok(extaddr))
			return -EINVAL;

		rc = mctp_dst_from_extaddr(&dst, sock_net(sk),
					   extaddr->smctp_ifindex,
					   extaddr->smctp_halen,
					   extaddr->smctp_haddr);
		if (rc)
			return rc;

	} else {
		rc = mctp_route_lookup(sock_net(sk), addr->smctp_network,
				       addr->smctp_addr.s_addr, &dst);
		if (rc)
			return rc;
	}

	/* From here on, dst is released on every exit path. */

	/* headroom: link-layer reserve plus the MCTP header */
	hlen = LL_RESERVED_SPACE(dst.dev->dev) + sizeof(struct mctp_hdr);

	/* the extra byte holds the message type */
	skb = sock_alloc_send_skb(sk, hlen + 1 + len,
				  msg->msg_flags & MSG_DONTWAIT, &rc);
	if (!skb)
		goto err_release_dst;

	skb_reserve(skb, hlen);

	/* set type as first byte in payload */
	*(u8 *)skb_put(skb, 1) = addr->smctp_type;

	rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
	if (rc < 0)
		goto err_free;

	/* set up cb */
	cb = __mctp_cb(skb);
	cb->net = addr->smctp_network;

	/* The skb is not freed here after this call, whatever rc is;
	 * presumably mctp_local_output() takes ownership of it.
	 */
	rc = mctp_local_output(sk, &dst, skb, addr->smctp_addr.s_addr,
			       addr->smctp_tag);

	mctp_dst_release(&dst);
	return rc ? : len;

err_free:
	kfree_skb(skb);
err_release_dst:
	mctp_dst_release(&dst);
	return rc;
}
282
/* proto_ops->recvmsg: receive one MCTP message.
 *
 * Only MSG_DONTWAIT, MSG_TRUNC and MSG_PEEK are accepted in @flags. The
 * first byte of the queued skb is the message type. It is not copied to
 * the caller's buffer; it is reported in smctp_type instead. An empty skb
 * yields 0.
 *
 * If @len is smaller than the payload, the copy is truncated and MSG_TRUNC
 * is set in msg->msg_flags. With MSG_TRUNC in @flags, the full payload
 * length is returned rather than the number of bytes copied.
 *
 * When msg_name is supplied, it is filled with the network, source EID,
 * message type and tag (tag value plus the TO flag from the MCTP header).
 * If MCTP_OPT_ADDR_EXT is set, it also gets the ifindex and hardware
 * address recorded in the skb control block, as a struct sockaddr_mctp_ext.
 */
static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
			int flags)
{
	DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
	struct sock *sk = sock->sk;
	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
	struct sk_buff *skb;
	size_t msglen;
	u8 type;
	int rc;

	if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK))
		return -EOPNOTSUPP;

	skb = skb_recv_datagram(sk, flags, &rc);
	if (!skb)
		return rc;

	/* no type byte: nothing to deliver */
	if (!skb->len) {
		rc = 0;
		goto out_free;
	}

	/* extract message type, remove from data */
	type = *((u8 *)skb->data);
	msglen = skb->len - 1;

	if (len < msglen)
		msg->msg_flags |= MSG_TRUNC;
	else
		len = msglen;

	/* copy from offset 1 to skip the type byte */
	rc = skb_copy_datagram_msg(skb, 1, msg, len);
	if (rc < 0)
		goto out_free;

	sock_recv_cmsgs(msg, sk, skb);

	if (addr) {
		struct mctp_skb_cb *cb = mctp_cb(skb);
		/* TODO: expand mctp_skb_cb for header fields? */
		struct mctp_hdr *hdr = mctp_hdr(skb);

		addr = msg->msg_name;
		addr->smctp_family = AF_MCTP;
		addr->__smctp_pad0 = 0;
		addr->smctp_network = cb->net;
		addr->smctp_addr.s_addr = hdr->src;
		addr->smctp_type = type;
		addr->smctp_tag = hdr->flags_seq_tag &
					(MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
		addr->__smctp_pad1 = 0;
		msg->msg_namelen = sizeof(*addr);

		if (msk->addr_ext) {
			DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae,
					 msg->msg_name);
			msg->msg_namelen = sizeof(*ae);
			ae->smctp_ifindex = cb->ifindex;
			ae->smctp_halen = cb->halen;
			/* zero padding and the unused tail of haddr */
			memset(ae->__smctp_pad0, 0x0, sizeof(ae->__smctp_pad0));
			memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr));
			memcpy(ae->smctp_haddr, cb->haddr, cb->halen);
		}
	}

	rc = len;

	if (flags & MSG_TRUNC)
		rc = msglen;

out_free:
	skb_free_datagram(sk, skb);
	return rc;
}
358
/* We're done with the key: invalidate it, stop reassembly, and remove it
 * from its lists (key->hlist and key->sklist).
 *
 * The caller must hold net->mctp.keys_lock and key->lock. @flags is the
 * value saved by the caller's spin_lock_irqsave() on key->lock. key->lock
 * is released here; keys_lock is not. @reason is passed to the
 * mctp_key_release tracepoint.
 */
static void __mctp_key_remove(struct mctp_sk_key *key, struct net *net,
			      unsigned long flags, unsigned long reason)
__releases(&key->lock)
__must_hold(&net->mctp.keys_lock)
{
	struct sk_buff *skb;

	trace_mctp_key_release(key, reason);
	/* detach the reassembly skb; it is freed after the key lock drops */
	skb = key->reasm_head;
	key->reasm_head = NULL;
	key->reasm_dead = true;
	key->valid = false;
	mctp_dev_release_key(key->dev, key);
	spin_unlock_irqrestore(&key->lock, flags);

	/* only drop the lists' reference if the key is still linked */
	if (!hlist_unhashed(&key->hlist)) {
		hlist_del_init(&key->hlist);
		hlist_del_init(&key->sklist);
		/* unref for the lists */
		mctp_key_unref(key);
	}

	kfree_skb(skb);
}
385
/* setsockopt() handler. Only level SOL_MCTP is accepted (-EINVAL
 * otherwise). The only option is MCTP_OPT_ADDR_EXT, an int-sized flag.
 * It enables struct sockaddr_mctp_ext handling in sendmsg() and recvmsg().
 * Unknown options give -ENOPROTOOPT.
 */
static int mctp_setsockopt(struct socket *sock, int level, int optname,
			   sockptr_t optval, unsigned int optlen)
{
	struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
	int val;

	if (level != SOL_MCTP)
		return -EINVAL;

	switch (optname) {
	case MCTP_OPT_ADDR_EXT:
		if (optlen != sizeof(int))
			return -EINVAL;
		if (copy_from_sockptr(&val, optval, sizeof(int)))
			return -EFAULT;
		msk->addr_ext = val;
		return 0;
	default:
		return -ENOPROTOOPT;
	}
}
406
/* getsockopt() handler. Only level SOL_MCTP is accepted (-EINVAL
 * otherwise). MCTP_OPT_ADDR_EXT is reported as an int holding 0 or 1.
 * The caller's *optlen must be exactly sizeof(int). Unknown options give
 * -ENOPROTOOPT.
 */
static int mctp_getsockopt(struct socket *sock, int level, int optname,
			   char __user *optval, int __user *optlen)
{
	struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
	int len, val;

	if (level != SOL_MCTP)
		return -EINVAL;

	if (get_user(len, optlen))
		return -EFAULT;

	switch (optname) {
	case MCTP_OPT_ADDR_EXT:
		if (len != sizeof(int))
			return -EINVAL;
		val = msk->addr_ext ? 1 : 0;
		return copy_to_user(optval, &val, len) ? -EFAULT : 0;
	default:
		return -ENOPROTOOPT;
	}
}
430
/* Helpers for reading/writing the tag-control ioctl argument, handling
 * compatibility between the original and v2 structure layouts, and doing
 * some basic API error checking.
 */
/* Read a tag-control ioctl argument from userspace into *ctl.
 *
 * With @tagv2, @arg points to a struct mctp_ioc_tag_ctl2, which is copied
 * as-is. Otherwise @arg points to a struct mctp_ioc_tag_ctl. Its fields
 * are transferred, and the fields it lacks get defaults: net =
 * MCTP_INITIAL_DEFAULT_NET and local_addr = MCTP_ADDR_ANY.
 *
 * Returns 0 on success or -EFAULT if the copy fails. Returns -EINVAL if
 * flags is non-zero, or if local_addr is neither MCTP_ADDR_ANY nor
 * MCTP_ADDR_NULL.
 */
static int mctp_ioctl_tag_copy_from_user(unsigned long arg,
					 struct mctp_ioc_tag_ctl2 *ctl,
					 bool tagv2)
{
	void __user *uptr = (void __user *)arg;

	if (tagv2) {
		if (copy_from_user(ctl, uptr, sizeof(*ctl)))
			return -EFAULT;
	} else {
		struct mctp_ioc_tag_ctl ctl_compat;

		if (copy_from_user(&ctl_compat, uptr, sizeof(ctl_compat)))
			return -EFAULT;

		/* compat, using defaults for new fields */
		ctl->net = MCTP_INITIAL_DEFAULT_NET;
		ctl->peer_addr = ctl_compat.peer_addr;
		ctl->local_addr = MCTP_ADDR_ANY;
		ctl->flags = ctl_compat.flags;
		ctl->tag = ctl_compat.tag;
	}

	/* any non-zero flags value is rejected */
	if (ctl->flags)
		return -EINVAL;

	/* only a wildcard local address is accepted */
	if (ctl->local_addr != MCTP_ADDR_ANY &&
	    ctl->local_addr != MCTP_ADDR_NULL)
		return -EINVAL;

	return 0;
}
473
/* Write *ctl back to userspace at @arg. With @tagv2 the whole
 * struct mctp_ioc_tag_ctl2 is written. Otherwise only peer_addr, tag and
 * flags are written, using the original struct mctp_ioc_tag_ctl layout.
 * Returns 0, or -EFAULT if the copy fails.
 */
static int mctp_ioctl_tag_copy_to_user(unsigned long arg,
				       struct mctp_ioc_tag_ctl2 *ctl,
				       bool tagv2)
{
	void __user *uptr = (void __user *)arg;
	struct mctp_ioc_tag_ctl ctl_compat;

	if (tagv2)
		return copy_to_user(uptr, ctl, sizeof(*ctl)) ? -EFAULT : 0;

	ctl_compat.peer_addr = ctl->peer_addr;
	ctl_compat.tag = ctl->tag;
	ctl_compat.flags = ctl->flags;

	return copy_to_user(uptr, &ctl_compat, sizeof(ctl_compat)) ?
		-EFAULT : 0;
}
501
/* SIOCMCTPALLOCTAG{,2}: allocate a local tag for (ctl.net, ctl.peer_addr)
 * through mctp_alloc_local_tag(). Its boolean argument is passed as true;
 * presumably this marks the key as manually allocated (key->manual_alloc,
 * which the drop ioctl requires) — confirm.
 *
 * The caller must pass ctl.tag == 0. On success, the tag is returned to
 * userspace with MCTP_TAG_OWNER | MCTP_TAG_PREALLOC set. If that
 * write-back fails, the key is removed again and the copy error is
 * returned.
 */
static int mctp_ioctl_alloctag(struct mctp_sock *msk, bool tagv2,
			       unsigned long arg)
{
	struct net *net = sock_net(&msk->sk);
	struct mctp_sk_key *key = NULL;
	struct mctp_ioc_tag_ctl2 ctl;
	unsigned long flags;
	u8 tag;
	int rc;

	rc = mctp_ioctl_tag_copy_from_user(arg, &ctl, tagv2);
	if (rc)
		return rc;

	if (ctl.tag)
		return -EINVAL;

	key = mctp_alloc_local_tag(msk, ctl.net, MCTP_ADDR_ANY,
				   ctl.peer_addr, true, &tag);
	if (IS_ERR(key))
		return PTR_ERR(key);

	ctl.tag = tag | MCTP_TAG_OWNER | MCTP_TAG_PREALLOC;
	rc = mctp_ioctl_tag_copy_to_user(arg, &ctl, tagv2);
	if (rc) {
		unsigned long fl2;
		/* Unwind our key allocation: the keys list lock needs to be
		 * taken before the individual key locks, and we need a valid
		 * flags value (fl2) to pass to __mctp_key_remove, hence the
		 * second spin_lock_irqsave() rather than a plain spin_lock().
		 */
		spin_lock_irqsave(&net->mctp.keys_lock, flags);
		spin_lock_irqsave(&key->lock, fl2);
		__mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_DROPPED);
		/* drop the reference returned by mctp_alloc_local_tag() */
		mctp_key_unref(key);
		spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
		return rc;
	}

	/* the key stays on the lists; drop only our local reference */
	mctp_key_unref(key);
	return 0;
}
544
/* SIOCMCTPDROPTAG{,2}: release a tag obtained from the alloctag ioctl.
 *
 * ctl.tag must have exactly MCTP_TAG_OWNER | MCTP_TAG_PREALLOC set outside
 * MCTP_TAG_MASK. A peer_addr of MCTP_ADDR_NULL is treated as
 * MCTP_ADDR_ANY. Every manual_alloc key on this socket whose net, peer
 * address and tag all match is removed.
 *
 * Returns 0 if at least one key was removed, -EINVAL if none matched or
 * the tag is malformed, or the error from reading the argument.
 */
static int mctp_ioctl_droptag(struct mctp_sock *msk, bool tagv2,
			      unsigned long arg)
{
	struct net *net = sock_net(&msk->sk);
	struct mctp_ioc_tag_ctl2 ctl;
	unsigned long flags, fl2;
	struct mctp_sk_key *key;
	struct hlist_node *tmp;
	int rc;
	u8 tag;

	rc = mctp_ioctl_tag_copy_from_user(arg, &ctl, tagv2);
	if (rc)
		return rc;

	/* Must be a local tag, TO set, preallocated */
	if ((ctl.tag & ~MCTP_TAG_MASK) != (MCTP_TAG_OWNER | MCTP_TAG_PREALLOC))
		return -EINVAL;

	tag = ctl.tag & MCTP_TAG_MASK;
	/* stays -EINVAL unless a matching key is found below */
	rc = -EINVAL;

	if (ctl.peer_addr == MCTP_ADDR_NULL)
		ctl.peer_addr = MCTP_ADDR_ANY;

	spin_lock_irqsave(&net->mctp.keys_lock, flags);
	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
		/* we do an irqsave here, even though we know the irq state,
		 * so we have the flags to pass to __mctp_key_remove
		 */
		spin_lock_irqsave(&key->lock, fl2);
		if (key->manual_alloc &&
		    ctl.net == key->net &&
		    ctl.peer_addr == key->peer_addr &&
		    tag == key->tag) {
			/* releases key->lock */
			__mctp_key_remove(key, net, fl2,
					  MCTP_TRACE_KEY_DROPPED);
			rc = 0;
		} else {
			spin_unlock_irqrestore(&key->lock, fl2);
		}
	}
	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);

	return rc;
}
591
mctp_ioctl(struct socket * sock,unsigned int cmd,unsigned long arg)592 static int mctp_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
593 {
594 struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
595 bool tagv2 = false;
596
597 switch (cmd) {
598 case SIOCMCTPALLOCTAG2:
599 case SIOCMCTPALLOCTAG:
600 tagv2 = cmd == SIOCMCTPALLOCTAG2;
601 return mctp_ioctl_alloctag(msk, tagv2, arg);
602 case SIOCMCTPDROPTAG:
603 case SIOCMCTPDROPTAG2:
604 tagv2 = cmd == SIOCMCTPDROPTAG2;
605 return mctp_ioctl_droptag(msk, tagv2, arg);
606 }
607
608 return -EINVAL;
609 }
610
#ifdef CONFIG_COMPAT
/* proto_ops->compat_ioctl: 32-bit userspace entry point.
 *
 * Both versions of the tag allocation ioctls take a pointer to a
 * tag-control structure: struct mctp_ioc_tag_ctl, or struct
 * mctp_ioc_tag_ctl2 for the "2" commands. Only the user pointer needs
 * converting with compat_ptr() before the native mctp_ioctl() handles the
 * request.
 *
 * NOTE(review): this assumes both structures contain only fixed-width
 * integer members and so have the same layout for 32- and 64-bit callers.
 * Confirm against include/uapi/linux/mctp.h.
 *
 * Any other command returns -ENOIOCTLCMD, leaving it to the generic compat
 * socket ioctl code.
 */
static int mctp_compat_ioctl(struct socket *sock, unsigned int cmd,
			     unsigned long arg)
{
	void __user *argp = compat_ptr(arg);

	switch (cmd) {
	/* These have compatible ptr layouts */
	case SIOCMCTPALLOCTAG:
	case SIOCMCTPDROPTAG:
	case SIOCMCTPALLOCTAG2:
	case SIOCMCTPDROPTAG2:
		return mctp_ioctl(sock, cmd, (unsigned long)argp);
	}

	return -ENOIOCTLCMD;
}
#endif
627
628 static const struct proto_ops mctp_dgram_ops = {
629 .family = PF_MCTP,
630 .release = mctp_release,
631 .bind = mctp_bind,
632 .connect = mctp_connect,
633 .socketpair = sock_no_socketpair,
634 .accept = sock_no_accept,
635 .getname = sock_no_getname,
636 .poll = datagram_poll,
637 .ioctl = mctp_ioctl,
638 .gettstamp = sock_gettstamp,
639 .listen = sock_no_listen,
640 .shutdown = sock_no_shutdown,
641 .setsockopt = mctp_setsockopt,
642 .getsockopt = mctp_getsockopt,
643 .sendmsg = mctp_sendmsg,
644 .recvmsg = mctp_recvmsg,
645 .mmap = sock_no_mmap,
646 #ifdef CONFIG_COMPAT
647 .compat_ioctl = mctp_compat_ioctl,
648 #endif
649 };
650
/* Timer callback for msk->key_expiry.
 *
 * Removes every key on the socket that is not manual_alloc and whose
 * expiry time is before the current jiffies. The timer is then re-armed
 * for the earliest remaining expiry, or left idle if no expiring keys
 * remain. Lock order: net->mctp.keys_lock, then key->lock.
 */
static void mctp_sk_expire_keys(struct timer_list *timer)
{
	struct mctp_sock *msk = container_of(timer, struct mctp_sock,
					     key_expiry);
	struct net *net = sock_net(&msk->sk);
	unsigned long next_expiry, flags, fl2;
	struct mctp_sk_key *key;
	struct hlist_node *tmp;
	bool next_expiry_valid = false;

	spin_lock_irqsave(&net->mctp.keys_lock, flags);

	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
		/* don't expire. manual_alloc is immutable, no locking
		 * required.
		 */
		if (key->manual_alloc)
			continue;

		spin_lock_irqsave(&key->lock, fl2);
		if (!time_after_eq(key->expiry, jiffies)) {
			/* expired; __mctp_key_remove() drops key->lock */
			__mctp_key_remove(key, net, fl2,
					  MCTP_TRACE_KEY_TIMEOUT);
			continue;
		}

		/* track the earliest expiry among surviving keys */
		if (next_expiry_valid) {
			if (time_before(key->expiry, next_expiry))
				next_expiry = key->expiry;
		} else {
			next_expiry = key->expiry;
			next_expiry_valid = true;
		}
		spin_unlock_irqrestore(&key->lock, fl2);
	}

	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);

	if (next_expiry_valid)
		mod_timer(timer, next_expiry);
}
692
mctp_sk_init(struct sock * sk)693 static int mctp_sk_init(struct sock *sk)
694 {
695 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
696
697 INIT_HLIST_HEAD(&msk->keys);
698 timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0);
699 msk->bind_peer_set = false;
700 return 0;
701 }
702
/* proto->close: @timeout is not used; hand the socket to the common
 * release path.
 */
static void mctp_sk_close(struct sock *sk, long timeout)
{
	(void)timeout;
	sk_common_release(sk);
}
707
/* proto->hash: insert the socket into the per-net bind table.
 *
 * The bucket is derived from the bind type, the local EID, and the
 * remote EID. The remote EID is the connect()ed peer if one is set,
 * otherwise MCTP_ADDR_ANY. A duplicate bind returns -EADDRINUSE. A
 * duplicate means another socket with the same type, local EID, peer
 * (both unset, or both set to the same EID) and network.
 *
 * Insertion happens under net->mctp.bind_lock. Returns 0 on success.
 */
static int mctp_sk_hash(struct sock *sk)
{
	struct net *net = sock_net(sk);
	struct sock *existing;
	struct mctp_sock *msk;
	mctp_eid_t remote;
	u32 hash;
	int rc;

	msk = container_of(sk, struct mctp_sock, sk);

	if (msk->bind_peer_set)
		remote = msk->bind_peer_addr;
	else
		remote = MCTP_ADDR_ANY;
	hash = mctp_bind_hash(msk->bind_type, msk->bind_local_addr, remote);

	mutex_lock(&net->mctp.bind_lock);

	/* Prevent duplicate binds. */
	sk_for_each(existing, &net->mctp.binds[hash]) {
		struct mctp_sock *mex =
			container_of(existing, struct mctp_sock, sk);

		bool same_peer = (mex->bind_peer_set && msk->bind_peer_set &&
				  mex->bind_peer_addr == msk->bind_peer_addr) ||
				 (!mex->bind_peer_set && !msk->bind_peer_set);

		if (mex->bind_type == msk->bind_type &&
		    mex->bind_local_addr == msk->bind_local_addr && same_peer &&
		    mex->bind_net == msk->bind_net) {
			rc = -EADDRINUSE;
			goto out;
		}
	}

	/* Bind lookup runs under RCU, remain live during that. */
	sock_set_flag(sk, SOCK_RCU_FREE);

	sk_add_node_rcu(sk, &net->mctp.binds[hash]);
	rc = 0;

out:
	mutex_unlock(&net->mctp.bind_lock);
	return rc;
}
754
/* proto->unhash: tear down the socket's bindings.
 *
 * It removes the socket from the bind table (under bind_lock), then
 * removes all of its keys and marks it SOCK_DEAD (under keys_lock).
 * Finally it synchronously stops the key expiry timer.
 */
static void mctp_sk_unhash(struct sock *sk)
{
	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
	struct net *net = sock_net(sk);
	unsigned long flags, fl2;
	struct mctp_sk_key *key;
	struct hlist_node *tmp;

	/* remove from any type-based binds */
	mutex_lock(&net->mctp.bind_lock);
	sk_del_node_init_rcu(sk);
	mutex_unlock(&net->mctp.bind_lock);

	/* remove tag allocations */
	spin_lock_irqsave(&net->mctp.keys_lock, flags);
	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
		spin_lock_irqsave(&key->lock, fl2);
		/* releases key->lock */
		__mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_CLOSED);
	}
	/* NOTE(review): set while keys_lock is still held, presumably so key
	 * allocation can check it consistently with the key list; confirm in
	 * mctp_alloc_local_tag().
	 */
	sock_set_flag(sk, SOCK_DEAD);
	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);

	/* Since there are no more tag allocations (we have removed all of the
	 * keys), stop any pending expiry events. the timer cannot be re-queued
	 * as the sk is no longer observable
	 */
	timer_delete_sync(&msk->key_expiry);
}
783
mctp_sk_destruct(struct sock * sk)784 static void mctp_sk_destruct(struct sock *sk)
785 {
786 skb_queue_purge(&sk->sk_receive_queue);
787 }
788
789 static struct proto mctp_proto = {
790 .name = "MCTP",
791 .owner = THIS_MODULE,
792 .obj_size = sizeof(struct mctp_sock),
793 .init = mctp_sk_init,
794 .close = mctp_sk_close,
795 .hash = mctp_sk_hash,
796 .unhash = mctp_sk_unhash,
797 };
798
mctp_pf_create(struct net * net,struct socket * sock,int protocol,int kern)799 static int mctp_pf_create(struct net *net, struct socket *sock,
800 int protocol, int kern)
801 {
802 const struct proto_ops *ops;
803 struct proto *proto;
804 struct sock *sk;
805 int rc;
806
807 if (protocol)
808 return -EPROTONOSUPPORT;
809
810 /* only datagram sockets are supported */
811 if (sock->type != SOCK_DGRAM)
812 return -ESOCKTNOSUPPORT;
813
814 proto = &mctp_proto;
815 ops = &mctp_dgram_ops;
816
817 sock->state = SS_UNCONNECTED;
818 sock->ops = ops;
819
820 sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern);
821 if (!sk)
822 return -ENOMEM;
823
824 sock_init_data(sock, sk);
825 sk->sk_destruct = mctp_sk_destruct;
826
827 rc = 0;
828 if (sk->sk_prot->init)
829 rc = sk->sk_prot->init(sk);
830
831 if (rc)
832 goto err_sk_put;
833
834 return 0;
835
836 err_sk_put:
837 sock_orphan(sk);
838 sock_put(sk);
839 return rc;
840 }
841
842 static struct net_proto_family mctp_pf = {
843 .family = PF_MCTP,
844 .create = mctp_pf_create,
845 .owner = THIS_MODULE,
846 };
847
/* Core init. It registers the PF_MCTP socket family and mctp_proto, then
 * initialises routing, neighbour and device support, in that order. On
 * failure, everything already set up is torn down in reverse order.
 */
static __init int mctp_init(void)
{
	int rc;

	/* ensure our uapi tag definitions match the header format */
	BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO);
	BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK);

	pr_info("mctp: management component transport protocol core\n");

	rc = sock_register(&mctp_pf);
	if (rc)
		return rc;

	rc = proto_register(&mctp_proto, 0);
	if (rc)
		goto err_unreg_sock;

	rc = mctp_routes_init();
	if (rc)
		goto err_unreg_proto;

	rc = mctp_neigh_init();
	if (rc)
		goto err_unreg_routes;

	rc = mctp_device_init();
	if (rc)
		goto err_unreg_neigh;

	return 0;

	/* unwind in reverse order of setup */
err_unreg_neigh:
	mctp_neigh_exit();
err_unreg_routes:
	mctp_routes_exit();
err_unreg_proto:
	proto_unregister(&mctp_proto);
err_unreg_sock:
	sock_unregister(PF_MCTP);

	return rc;
}
891
/* Module exit: tear down in the reverse order of mctp_init(). */
static __exit void mctp_exit(void)
{
	mctp_device_exit();
	mctp_neigh_exit();
	mctp_routes_exit();
	proto_unregister(&mctp_proto);
	sock_unregister(PF_MCTP);
}
900
/* Core init runs at subsys_initcall level rather than via module_init(). */
subsys_initcall(mctp_init);
module_exit(mctp_exit);

MODULE_DESCRIPTION("MCTP core");
MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>");

MODULE_ALIAS_NETPROTO(PF_MCTP);

/* With CONFIG_MCTP_TEST, the tests are compiled into this translation
 * unit, presumably so they can call the static functions above.
 */
#if IS_ENABLED(CONFIG_MCTP_TEST)
#include "test/sock-test.c"
#endif
912