xref: /linux/net/mctp/af_mctp.c (revision 5cfe477f6a3f9a4d9b2906d442964f2115b0403f)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Management Component Transport Protocol (MCTP)
4  *
5  * Copyright (c) 2021 Code Construct
6  * Copyright (c) 2021 Google
7  */
8 
9 #include <linux/compat.h>
10 #include <linux/if_arp.h>
11 #include <linux/net.h>
12 #include <linux/mctp.h>
13 #include <linux/module.h>
14 #include <linux/socket.h>
15 
16 #include <net/mctp.h>
17 #include <net/mctpdevice.h>
18 #include <net/sock.h>
19 
20 #define CREATE_TRACE_POINTS
21 #include <trace/events/mctp.h>
22 
23 /* socket implementation */
24 
25 static void mctp_sk_expire_keys(struct timer_list *timer);
26 
27 static int mctp_release(struct socket *sock)
28 {
29 	struct sock *sk = sock->sk;
30 
31 	if (sk) {
32 		sock->sk = NULL;
33 		sk->sk_prot->close(sk, 0);
34 	}
35 
36 	return 0;
37 }
38 
39 /* Generic sockaddr checks, padding checks only so far */
40 static bool mctp_sockaddr_is_ok(const struct sockaddr_mctp *addr)
41 {
42 	return !addr->__smctp_pad0 && !addr->__smctp_pad1;
43 }
44 
45 static bool mctp_sockaddr_ext_is_ok(const struct sockaddr_mctp_ext *addr)
46 {
47 	return !addr->__smctp_pad0[0] &&
48 	       !addr->__smctp_pad0[1] &&
49 	       !addr->__smctp_pad0[2];
50 }
51 
52 static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
53 {
54 	struct sock *sk = sock->sk;
55 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
56 	struct sockaddr_mctp *smctp;
57 	int rc;
58 
59 	if (addrlen < sizeof(*smctp))
60 		return -EINVAL;
61 
62 	if (addr->sa_family != AF_MCTP)
63 		return -EAFNOSUPPORT;
64 
65 	if (!capable(CAP_NET_BIND_SERVICE))
66 		return -EACCES;
67 
68 	/* it's a valid sockaddr for MCTP, cast and do protocol checks */
69 	smctp = (struct sockaddr_mctp *)addr;
70 
71 	if (!mctp_sockaddr_is_ok(smctp))
72 		return -EINVAL;
73 
74 	lock_sock(sk);
75 
76 	/* TODO: allow rebind */
77 	if (sk_hashed(sk)) {
78 		rc = -EADDRINUSE;
79 		goto out_release;
80 	}
81 	msk->bind_net = smctp->smctp_network;
82 	msk->bind_addr = smctp->smctp_addr.s_addr;
83 	msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */
84 
85 	rc = sk->sk_prot->hash(sk);
86 
87 out_release:
88 	release_sock(sk);
89 
90 	return rc;
91 }
92 
93 static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
94 {
95 	DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
96 	const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
97 	int rc, addrlen = msg->msg_namelen;
98 	struct sock *sk = sock->sk;
99 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
100 	struct mctp_skb_cb *cb;
101 	struct mctp_route *rt;
102 	struct sk_buff *skb;
103 
104 	if (addr) {
105 		const u8 tagbits = MCTP_TAG_MASK | MCTP_TAG_OWNER |
106 			MCTP_TAG_PREALLOC;
107 
108 		if (addrlen < sizeof(struct sockaddr_mctp))
109 			return -EINVAL;
110 		if (addr->smctp_family != AF_MCTP)
111 			return -EINVAL;
112 		if (!mctp_sockaddr_is_ok(addr))
113 			return -EINVAL;
114 		if (addr->smctp_tag & ~tagbits)
115 			return -EINVAL;
116 		/* can't preallocate a non-owned tag */
117 		if (addr->smctp_tag & MCTP_TAG_PREALLOC &&
118 		    !(addr->smctp_tag & MCTP_TAG_OWNER))
119 			return -EINVAL;
120 
121 	} else {
122 		/* TODO: connect()ed sockets */
123 		return -EDESTADDRREQ;
124 	}
125 
126 	if (!capable(CAP_NET_RAW))
127 		return -EACCES;
128 
129 	if (addr->smctp_network == MCTP_NET_ANY)
130 		addr->smctp_network = mctp_default_net(sock_net(sk));
131 
132 	skb = sock_alloc_send_skb(sk, hlen + 1 + len,
133 				  msg->msg_flags & MSG_DONTWAIT, &rc);
134 	if (!skb)
135 		return rc;
136 
137 	skb_reserve(skb, hlen);
138 
139 	/* set type as fist byte in payload */
140 	*(u8 *)skb_put(skb, 1) = addr->smctp_type;
141 
142 	rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
143 	if (rc < 0)
144 		goto err_free;
145 
146 	/* set up cb */
147 	cb = __mctp_cb(skb);
148 	cb->net = addr->smctp_network;
149 
150 	/* direct addressing */
151 	if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
152 		DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
153 				 extaddr, msg->msg_name);
154 
155 		if (!mctp_sockaddr_ext_is_ok(extaddr) ||
156 		    extaddr->smctp_halen > sizeof(cb->haddr)) {
157 			rc = -EINVAL;
158 			goto err_free;
159 		}
160 
161 		cb->ifindex = extaddr->smctp_ifindex;
162 		cb->halen = extaddr->smctp_halen;
163 		memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen);
164 
165 		rt = NULL;
166 	} else {
167 		rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
168 				       addr->smctp_addr.s_addr);
169 		if (!rt) {
170 			rc = -EHOSTUNREACH;
171 			goto err_free;
172 		}
173 	}
174 
175 	rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
176 			       addr->smctp_tag);
177 
178 	return rc ? : len;
179 
180 err_free:
181 	kfree_skb(skb);
182 	return rc;
183 }
184 
185 static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
186 			int flags)
187 {
188 	DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
189 	struct sock *sk = sock->sk;
190 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
191 	struct sk_buff *skb;
192 	size_t msglen;
193 	u8 type;
194 	int rc;
195 
196 	if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK))
197 		return -EOPNOTSUPP;
198 
199 	skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc);
200 	if (!skb)
201 		return rc;
202 
203 	if (!skb->len) {
204 		rc = 0;
205 		goto out_free;
206 	}
207 
208 	/* extract message type, remove from data */
209 	type = *((u8 *)skb->data);
210 	msglen = skb->len - 1;
211 
212 	if (len < msglen)
213 		msg->msg_flags |= MSG_TRUNC;
214 	else
215 		len = msglen;
216 
217 	rc = skb_copy_datagram_msg(skb, 1, msg, len);
218 	if (rc < 0)
219 		goto out_free;
220 
221 	sock_recv_ts_and_drops(msg, sk, skb);
222 
223 	if (addr) {
224 		struct mctp_skb_cb *cb = mctp_cb(skb);
225 		/* TODO: expand mctp_skb_cb for header fields? */
226 		struct mctp_hdr *hdr = mctp_hdr(skb);
227 
228 		addr = msg->msg_name;
229 		addr->smctp_family = AF_MCTP;
230 		addr->__smctp_pad0 = 0;
231 		addr->smctp_network = cb->net;
232 		addr->smctp_addr.s_addr = hdr->src;
233 		addr->smctp_type = type;
234 		addr->smctp_tag = hdr->flags_seq_tag &
235 					(MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
236 		addr->__smctp_pad1 = 0;
237 		msg->msg_namelen = sizeof(*addr);
238 
239 		if (msk->addr_ext) {
240 			DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae,
241 					 msg->msg_name);
242 			msg->msg_namelen = sizeof(*ae);
243 			ae->smctp_ifindex = cb->ifindex;
244 			ae->smctp_halen = cb->halen;
245 			memset(ae->__smctp_pad0, 0x0, sizeof(ae->__smctp_pad0));
246 			memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr));
247 			memcpy(ae->smctp_haddr, cb->haddr, cb->halen);
248 		}
249 	}
250 
251 	rc = len;
252 
253 	if (flags & MSG_TRUNC)
254 		rc = msglen;
255 
256 out_free:
257 	skb_free_datagram(sk, skb);
258 	return rc;
259 }
260 
261 /* We're done with the key; invalidate, stop reassembly, and remove from lists.
262  */
263 static void __mctp_key_remove(struct mctp_sk_key *key, struct net *net,
264 			      unsigned long flags, unsigned long reason)
265 __releases(&key->lock)
266 __must_hold(&net->mctp.keys_lock)
267 {
268 	struct sk_buff *skb;
269 
270 	trace_mctp_key_release(key, reason);
271 	skb = key->reasm_head;
272 	key->reasm_head = NULL;
273 	key->reasm_dead = true;
274 	key->valid = false;
275 	mctp_dev_release_key(key->dev, key);
276 	spin_unlock_irqrestore(&key->lock, flags);
277 
278 	hlist_del(&key->hlist);
279 	hlist_del(&key->sklist);
280 
281 	/* unref for the lists */
282 	mctp_key_unref(key);
283 
284 	kfree_skb(skb);
285 }
286 
287 static int mctp_setsockopt(struct socket *sock, int level, int optname,
288 			   sockptr_t optval, unsigned int optlen)
289 {
290 	struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
291 	int val;
292 
293 	if (level != SOL_MCTP)
294 		return -EINVAL;
295 
296 	if (optname == MCTP_OPT_ADDR_EXT) {
297 		if (optlen != sizeof(int))
298 			return -EINVAL;
299 		if (copy_from_sockptr(&val, optval, sizeof(int)))
300 			return -EFAULT;
301 		msk->addr_ext = val;
302 		return 0;
303 	}
304 
305 	return -ENOPROTOOPT;
306 }
307 
308 static int mctp_getsockopt(struct socket *sock, int level, int optname,
309 			   char __user *optval, int __user *optlen)
310 {
311 	struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
312 	int len, val;
313 
314 	if (level != SOL_MCTP)
315 		return -EINVAL;
316 
317 	if (get_user(len, optlen))
318 		return -EFAULT;
319 
320 	if (optname == MCTP_OPT_ADDR_EXT) {
321 		if (len != sizeof(int))
322 			return -EINVAL;
323 		val = !!msk->addr_ext;
324 		if (copy_to_user(optval, &val, len))
325 			return -EFAULT;
326 		return 0;
327 	}
328 
329 	return -EINVAL;
330 }
331 
332 static int mctp_ioctl_alloctag(struct mctp_sock *msk, unsigned long arg)
333 {
334 	struct net *net = sock_net(&msk->sk);
335 	struct mctp_sk_key *key = NULL;
336 	struct mctp_ioc_tag_ctl ctl;
337 	unsigned long flags;
338 	u8 tag;
339 
340 	if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl)))
341 		return -EFAULT;
342 
343 	if (ctl.tag)
344 		return -EINVAL;
345 
346 	if (ctl.flags)
347 		return -EINVAL;
348 
349 	key = mctp_alloc_local_tag(msk, ctl.peer_addr, MCTP_ADDR_ANY,
350 				   true, &tag);
351 	if (IS_ERR(key))
352 		return PTR_ERR(key);
353 
354 	ctl.tag = tag | MCTP_TAG_OWNER | MCTP_TAG_PREALLOC;
355 	if (copy_to_user((void __user *)arg, &ctl, sizeof(ctl))) {
356 		spin_lock_irqsave(&key->lock, flags);
357 		__mctp_key_remove(key, net, flags, MCTP_TRACE_KEY_DROPPED);
358 		mctp_key_unref(key);
359 		return -EFAULT;
360 	}
361 
362 	mctp_key_unref(key);
363 	return 0;
364 }
365 
366 static int mctp_ioctl_droptag(struct mctp_sock *msk, unsigned long arg)
367 {
368 	struct net *net = sock_net(&msk->sk);
369 	struct mctp_ioc_tag_ctl ctl;
370 	unsigned long flags, fl2;
371 	struct mctp_sk_key *key;
372 	struct hlist_node *tmp;
373 	int rc;
374 	u8 tag;
375 
376 	if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl)))
377 		return -EFAULT;
378 
379 	if (ctl.flags)
380 		return -EINVAL;
381 
382 	/* Must be a local tag, TO set, preallocated */
383 	if ((ctl.tag & ~MCTP_TAG_MASK) != (MCTP_TAG_OWNER | MCTP_TAG_PREALLOC))
384 		return -EINVAL;
385 
386 	tag = ctl.tag & MCTP_TAG_MASK;
387 	rc = -EINVAL;
388 
389 	spin_lock_irqsave(&net->mctp.keys_lock, flags);
390 	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
391 		/* we do an irqsave here, even though we know the irq state,
392 		 * so we have the flags to pass to __mctp_key_remove
393 		 */
394 		spin_lock_irqsave(&key->lock, fl2);
395 		if (key->manual_alloc &&
396 		    ctl.peer_addr == key->peer_addr &&
397 		    tag == key->tag) {
398 			__mctp_key_remove(key, net, fl2,
399 					  MCTP_TRACE_KEY_DROPPED);
400 			rc = 0;
401 		} else {
402 			spin_unlock_irqrestore(&key->lock, fl2);
403 		}
404 	}
405 	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
406 
407 	return rc;
408 }
409 
410 static int mctp_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
411 {
412 	struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
413 
414 	switch (cmd) {
415 	case SIOCMCTPALLOCTAG:
416 		return mctp_ioctl_alloctag(msk, arg);
417 	case SIOCMCTPDROPTAG:
418 		return mctp_ioctl_droptag(msk, arg);
419 	}
420 
421 	return -EINVAL;
422 }
423 
#ifdef CONFIG_COMPAT
/* 32-bit compat ioctl. The tag ioctls have compatible pointer layouts, so
 * the user pointer only needs translating before forwarding to mctp_ioctl().
 * Other commands return -ENOIOCTLCMD.
 */
static int mctp_compat_ioctl(struct socket *sock, unsigned int cmd,
			     unsigned long arg)
{
	void __user *argp = compat_ptr(arg);

	if (cmd != SIOCMCTPALLOCTAG && cmd != SIOCMCTPDROPTAG)
		return -ENOIOCTLCMD;

	return mctp_ioctl(sock, cmd, (unsigned long)argp);
}
#endif
440 
441 static const struct proto_ops mctp_dgram_ops = {
442 	.family		= PF_MCTP,
443 	.release	= mctp_release,
444 	.bind		= mctp_bind,
445 	.connect	= sock_no_connect,
446 	.socketpair	= sock_no_socketpair,
447 	.accept		= sock_no_accept,
448 	.getname	= sock_no_getname,
449 	.poll		= datagram_poll,
450 	.ioctl		= mctp_ioctl,
451 	.gettstamp	= sock_gettstamp,
452 	.listen		= sock_no_listen,
453 	.shutdown	= sock_no_shutdown,
454 	.setsockopt	= mctp_setsockopt,
455 	.getsockopt	= mctp_getsockopt,
456 	.sendmsg	= mctp_sendmsg,
457 	.recvmsg	= mctp_recvmsg,
458 	.mmap		= sock_no_mmap,
459 	.sendpage	= sock_no_sendpage,
460 #ifdef CONFIG_COMPAT
461 	.compat_ioctl	= mctp_compat_ioctl,
462 #endif
463 };
464 
465 static void mctp_sk_expire_keys(struct timer_list *timer)
466 {
467 	struct mctp_sock *msk = container_of(timer, struct mctp_sock,
468 					     key_expiry);
469 	struct net *net = sock_net(&msk->sk);
470 	unsigned long next_expiry, flags, fl2;
471 	struct mctp_sk_key *key;
472 	struct hlist_node *tmp;
473 	bool next_expiry_valid = false;
474 
475 	spin_lock_irqsave(&net->mctp.keys_lock, flags);
476 
477 	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
478 		/* don't expire. manual_alloc is immutable, no locking
479 		 * required.
480 		 */
481 		if (key->manual_alloc)
482 			continue;
483 
484 		spin_lock_irqsave(&key->lock, fl2);
485 		if (!time_after_eq(key->expiry, jiffies)) {
486 			__mctp_key_remove(key, net, fl2,
487 					  MCTP_TRACE_KEY_TIMEOUT);
488 			continue;
489 		}
490 
491 		if (next_expiry_valid) {
492 			if (time_before(key->expiry, next_expiry))
493 				next_expiry = key->expiry;
494 		} else {
495 			next_expiry = key->expiry;
496 			next_expiry_valid = true;
497 		}
498 		spin_unlock_irqrestore(&key->lock, fl2);
499 	}
500 
501 	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
502 
503 	if (next_expiry_valid)
504 		mod_timer(timer, next_expiry);
505 }
506 
507 static int mctp_sk_init(struct sock *sk)
508 {
509 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
510 
511 	INIT_HLIST_HEAD(&msk->keys);
512 	timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0);
513 	return 0;
514 }
515 
516 static void mctp_sk_close(struct sock *sk, long timeout)
517 {
518 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
519 
520 	del_timer_sync(&msk->key_expiry);
521 	sk_common_release(sk);
522 }
523 
524 static int mctp_sk_hash(struct sock *sk)
525 {
526 	struct net *net = sock_net(sk);
527 
528 	mutex_lock(&net->mctp.bind_lock);
529 	sk_add_node_rcu(sk, &net->mctp.binds);
530 	mutex_unlock(&net->mctp.bind_lock);
531 
532 	return 0;
533 }
534 
535 static void mctp_sk_unhash(struct sock *sk)
536 {
537 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
538 	struct net *net = sock_net(sk);
539 	unsigned long flags, fl2;
540 	struct mctp_sk_key *key;
541 	struct hlist_node *tmp;
542 
543 	/* remove from any type-based binds */
544 	mutex_lock(&net->mctp.bind_lock);
545 	sk_del_node_init_rcu(sk);
546 	mutex_unlock(&net->mctp.bind_lock);
547 
548 	/* remove tag allocations */
549 	spin_lock_irqsave(&net->mctp.keys_lock, flags);
550 	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
551 		spin_lock_irqsave(&key->lock, fl2);
552 		__mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_CLOSED);
553 	}
554 	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
555 }
556 
557 static struct proto mctp_proto = {
558 	.name		= "MCTP",
559 	.owner		= THIS_MODULE,
560 	.obj_size	= sizeof(struct mctp_sock),
561 	.init		= mctp_sk_init,
562 	.close		= mctp_sk_close,
563 	.hash		= mctp_sk_hash,
564 	.unhash		= mctp_sk_unhash,
565 };
566 
567 static int mctp_pf_create(struct net *net, struct socket *sock,
568 			  int protocol, int kern)
569 {
570 	const struct proto_ops *ops;
571 	struct proto *proto;
572 	struct sock *sk;
573 	int rc;
574 
575 	if (protocol)
576 		return -EPROTONOSUPPORT;
577 
578 	/* only datagram sockets are supported */
579 	if (sock->type != SOCK_DGRAM)
580 		return -ESOCKTNOSUPPORT;
581 
582 	proto = &mctp_proto;
583 	ops = &mctp_dgram_ops;
584 
585 	sock->state = SS_UNCONNECTED;
586 	sock->ops = ops;
587 
588 	sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern);
589 	if (!sk)
590 		return -ENOMEM;
591 
592 	sock_init_data(sock, sk);
593 
594 	rc = 0;
595 	if (sk->sk_prot->init)
596 		rc = sk->sk_prot->init(sk);
597 
598 	if (rc)
599 		goto err_sk_put;
600 
601 	return 0;
602 
603 err_sk_put:
604 	sock_orphan(sk);
605 	sock_put(sk);
606 	return rc;
607 }
608 
609 static struct net_proto_family mctp_pf = {
610 	.family = PF_MCTP,
611 	.create = mctp_pf_create,
612 	.owner = THIS_MODULE,
613 };
614 
615 static __init int mctp_init(void)
616 {
617 	int rc;
618 
619 	/* ensure our uapi tag definitions match the header format */
620 	BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO);
621 	BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK);
622 
623 	pr_info("mctp: management component transport protocol core\n");
624 
625 	rc = sock_register(&mctp_pf);
626 	if (rc)
627 		return rc;
628 
629 	rc = proto_register(&mctp_proto, 0);
630 	if (rc)
631 		goto err_unreg_sock;
632 
633 	rc = mctp_routes_init();
634 	if (rc)
635 		goto err_unreg_proto;
636 
637 	rc = mctp_neigh_init();
638 	if (rc)
639 		goto err_unreg_proto;
640 
641 	mctp_device_init();
642 
643 	return 0;
644 
645 err_unreg_proto:
646 	proto_unregister(&mctp_proto);
647 err_unreg_sock:
648 	sock_unregister(PF_MCTP);
649 
650 	return rc;
651 }
652 
653 static __exit void mctp_exit(void)
654 {
655 	mctp_device_exit();
656 	mctp_neigh_exit();
657 	mctp_routes_exit();
658 	proto_unregister(&mctp_proto);
659 	sock_unregister(PF_MCTP);
660 }
661 
662 subsys_initcall(mctp_init);
663 module_exit(mctp_exit);
664 
665 MODULE_DESCRIPTION("MCTP core");
666 MODULE_LICENSE("GPL v2");
667 MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>");
668 
669 MODULE_ALIAS_NETPROTO(PF_MCTP);
670