xref: /linux/net/mctp/af_mctp.c (revision e0c1b49f5b674cca7b10549c53b3791d0bbc90a8)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Management Component Transport Protocol (MCTP)
4  *
5  * Copyright (c) 2021 Code Construct
6  * Copyright (c) 2021 Google
7  */
8 
9 #include <linux/if_arp.h>
10 #include <linux/net.h>
11 #include <linux/mctp.h>
12 #include <linux/module.h>
13 #include <linux/socket.h>
14 
15 #include <net/mctp.h>
16 #include <net/mctpdevice.h>
17 #include <net/sock.h>
18 
19 #define CREATE_TRACE_POINTS
20 #include <trace/events/mctp.h>
21 
22 /* socket implementation */
23 
24 static int mctp_release(struct socket *sock)
25 {
26 	struct sock *sk = sock->sk;
27 
28 	if (sk) {
29 		sock->sk = NULL;
30 		sk->sk_prot->close(sk, 0);
31 	}
32 
33 	return 0;
34 }
35 
36 static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
37 {
38 	struct sock *sk = sock->sk;
39 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
40 	struct sockaddr_mctp *smctp;
41 	int rc;
42 
43 	if (addrlen < sizeof(*smctp))
44 		return -EINVAL;
45 
46 	if (addr->sa_family != AF_MCTP)
47 		return -EAFNOSUPPORT;
48 
49 	if (!capable(CAP_NET_BIND_SERVICE))
50 		return -EACCES;
51 
52 	/* it's a valid sockaddr for MCTP, cast and do protocol checks */
53 	smctp = (struct sockaddr_mctp *)addr;
54 
55 	lock_sock(sk);
56 
57 	/* TODO: allow rebind */
58 	if (sk_hashed(sk)) {
59 		rc = -EADDRINUSE;
60 		goto out_release;
61 	}
62 	msk->bind_net = smctp->smctp_network;
63 	msk->bind_addr = smctp->smctp_addr.s_addr;
64 	msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */
65 
66 	rc = sk->sk_prot->hash(sk);
67 
68 out_release:
69 	release_sock(sk);
70 
71 	return rc;
72 }
73 
74 static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
75 {
76 	DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
77 	const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
78 	int rc, addrlen = msg->msg_namelen;
79 	struct sock *sk = sock->sk;
80 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
81 	struct mctp_skb_cb *cb;
82 	struct mctp_route *rt;
83 	struct sk_buff *skb;
84 
85 	if (addr) {
86 		if (addrlen < sizeof(struct sockaddr_mctp))
87 			return -EINVAL;
88 		if (addr->smctp_family != AF_MCTP)
89 			return -EINVAL;
90 		if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER))
91 			return -EINVAL;
92 
93 	} else {
94 		/* TODO: connect()ed sockets */
95 		return -EDESTADDRREQ;
96 	}
97 
98 	if (!capable(CAP_NET_RAW))
99 		return -EACCES;
100 
101 	if (addr->smctp_network == MCTP_NET_ANY)
102 		addr->smctp_network = mctp_default_net(sock_net(sk));
103 
104 	skb = sock_alloc_send_skb(sk, hlen + 1 + len,
105 				  msg->msg_flags & MSG_DONTWAIT, &rc);
106 	if (!skb)
107 		return rc;
108 
109 	skb_reserve(skb, hlen);
110 
111 	/* set type as fist byte in payload */
112 	*(u8 *)skb_put(skb, 1) = addr->smctp_type;
113 
114 	rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
115 	if (rc < 0)
116 		goto err_free;
117 
118 	/* set up cb */
119 	cb = __mctp_cb(skb);
120 	cb->net = addr->smctp_network;
121 
122 	/* direct addressing */
123 	if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
124 		DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
125 				 extaddr, msg->msg_name);
126 
127 		if (extaddr->smctp_halen > sizeof(cb->haddr)) {
128 			rc = -EINVAL;
129 			goto err_free;
130 		}
131 
132 		cb->ifindex = extaddr->smctp_ifindex;
133 		cb->halen = extaddr->smctp_halen;
134 		memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen);
135 
136 		rt = NULL;
137 	} else {
138 		rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
139 				       addr->smctp_addr.s_addr);
140 		if (!rt) {
141 			rc = -EHOSTUNREACH;
142 			goto err_free;
143 		}
144 	}
145 
146 	rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
147 			       addr->smctp_tag);
148 
149 	return rc ? : len;
150 
151 err_free:
152 	kfree_skb(skb);
153 	return rc;
154 }
155 
156 static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
157 			int flags)
158 {
159 	DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
160 	struct sock *sk = sock->sk;
161 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
162 	struct sk_buff *skb;
163 	size_t msglen;
164 	u8 type;
165 	int rc;
166 
167 	if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK))
168 		return -EOPNOTSUPP;
169 
170 	skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc);
171 	if (!skb)
172 		return rc;
173 
174 	if (!skb->len) {
175 		rc = 0;
176 		goto out_free;
177 	}
178 
179 	/* extract message type, remove from data */
180 	type = *((u8 *)skb->data);
181 	msglen = skb->len - 1;
182 
183 	if (len < msglen)
184 		msg->msg_flags |= MSG_TRUNC;
185 	else
186 		len = msglen;
187 
188 	rc = skb_copy_datagram_msg(skb, 1, msg, len);
189 	if (rc < 0)
190 		goto out_free;
191 
192 	sock_recv_ts_and_drops(msg, sk, skb);
193 
194 	if (addr) {
195 		struct mctp_skb_cb *cb = mctp_cb(skb);
196 		/* TODO: expand mctp_skb_cb for header fields? */
197 		struct mctp_hdr *hdr = mctp_hdr(skb);
198 
199 		addr = msg->msg_name;
200 		addr->smctp_family = AF_MCTP;
201 		addr->smctp_network = cb->net;
202 		addr->smctp_addr.s_addr = hdr->src;
203 		addr->smctp_type = type;
204 		addr->smctp_tag = hdr->flags_seq_tag &
205 					(MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
206 		msg->msg_namelen = sizeof(*addr);
207 
208 		if (msk->addr_ext) {
209 			DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae,
210 					 msg->msg_name);
211 			msg->msg_namelen = sizeof(*ae);
212 			ae->smctp_ifindex = cb->ifindex;
213 			ae->smctp_halen = cb->halen;
214 			memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr));
215 			memcpy(ae->smctp_haddr, cb->haddr, cb->halen);
216 		}
217 	}
218 
219 	rc = len;
220 
221 	if (flags & MSG_TRUNC)
222 		rc = msglen;
223 
224 out_free:
225 	skb_free_datagram(sk, skb);
226 	return rc;
227 }
228 
229 static int mctp_setsockopt(struct socket *sock, int level, int optname,
230 			   sockptr_t optval, unsigned int optlen)
231 {
232 	struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
233 	int val;
234 
235 	if (level != SOL_MCTP)
236 		return -EINVAL;
237 
238 	if (optname == MCTP_OPT_ADDR_EXT) {
239 		if (optlen != sizeof(int))
240 			return -EINVAL;
241 		if (copy_from_sockptr(&val, optval, sizeof(int)))
242 			return -EFAULT;
243 		msk->addr_ext = val;
244 		return 0;
245 	}
246 
247 	return -ENOPROTOOPT;
248 }
249 
250 static int mctp_getsockopt(struct socket *sock, int level, int optname,
251 			   char __user *optval, int __user *optlen)
252 {
253 	struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
254 	int len, val;
255 
256 	if (level != SOL_MCTP)
257 		return -EINVAL;
258 
259 	if (get_user(len, optlen))
260 		return -EFAULT;
261 
262 	if (optname == MCTP_OPT_ADDR_EXT) {
263 		if (len != sizeof(int))
264 			return -EINVAL;
265 		val = !!msk->addr_ext;
266 		if (copy_to_user(optval, &val, len))
267 			return -EFAULT;
268 		return 0;
269 	}
270 
271 	return -EINVAL;
272 }
273 
274 static const struct proto_ops mctp_dgram_ops = {
275 	.family		= PF_MCTP,
276 	.release	= mctp_release,
277 	.bind		= mctp_bind,
278 	.connect	= sock_no_connect,
279 	.socketpair	= sock_no_socketpair,
280 	.accept		= sock_no_accept,
281 	.getname	= sock_no_getname,
282 	.poll		= datagram_poll,
283 	.ioctl		= sock_no_ioctl,
284 	.gettstamp	= sock_gettstamp,
285 	.listen		= sock_no_listen,
286 	.shutdown	= sock_no_shutdown,
287 	.setsockopt	= mctp_setsockopt,
288 	.getsockopt	= mctp_getsockopt,
289 	.sendmsg	= mctp_sendmsg,
290 	.recvmsg	= mctp_recvmsg,
291 	.mmap		= sock_no_mmap,
292 	.sendpage	= sock_no_sendpage,
293 };
294 
295 static void mctp_sk_expire_keys(struct timer_list *timer)
296 {
297 	struct mctp_sock *msk = container_of(timer, struct mctp_sock,
298 					     key_expiry);
299 	struct net *net = sock_net(&msk->sk);
300 	unsigned long next_expiry, flags;
301 	struct mctp_sk_key *key;
302 	struct hlist_node *tmp;
303 	bool next_expiry_valid = false;
304 
305 	spin_lock_irqsave(&net->mctp.keys_lock, flags);
306 
307 	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
308 		spin_lock(&key->lock);
309 
310 		if (!time_after_eq(key->expiry, jiffies)) {
311 			trace_mctp_key_release(key, MCTP_TRACE_KEY_TIMEOUT);
312 			key->valid = false;
313 			hlist_del_rcu(&key->hlist);
314 			hlist_del_rcu(&key->sklist);
315 			spin_unlock(&key->lock);
316 			mctp_key_unref(key);
317 			continue;
318 		}
319 
320 		if (next_expiry_valid) {
321 			if (time_before(key->expiry, next_expiry))
322 				next_expiry = key->expiry;
323 		} else {
324 			next_expiry = key->expiry;
325 			next_expiry_valid = true;
326 		}
327 		spin_unlock(&key->lock);
328 	}
329 
330 	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
331 
332 	if (next_expiry_valid)
333 		mod_timer(timer, next_expiry);
334 }
335 
336 static int mctp_sk_init(struct sock *sk)
337 {
338 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
339 
340 	INIT_HLIST_HEAD(&msk->keys);
341 	timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0);
342 	return 0;
343 }
344 
345 static void mctp_sk_close(struct sock *sk, long timeout)
346 {
347 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
348 
349 	del_timer_sync(&msk->key_expiry);
350 	sk_common_release(sk);
351 }
352 
353 static int mctp_sk_hash(struct sock *sk)
354 {
355 	struct net *net = sock_net(sk);
356 
357 	mutex_lock(&net->mctp.bind_lock);
358 	sk_add_node_rcu(sk, &net->mctp.binds);
359 	mutex_unlock(&net->mctp.bind_lock);
360 
361 	return 0;
362 }
363 
/*
 * proto .unhash: take the socket off the per-net bind list and tear down
 * every tag key it owns.
 *
 * Each key is unlinked through both its sklist (this socket's list) and
 * hlist linkages while net->mctp.keys_lock is held. Under the key's own
 * lock, any pending reassembly is then freed and the key is marked
 * reasm_dead and invalid. Finally the reference held by the lists is
 * dropped.
 *
 * NOTE(review): mctp_sk_expire_keys() unlinks keys with hlist_del_rcu(),
 * whereas this uses plain hlist_del(). Confirm whether any RCU readers walk
 * these lists; if so, the _rcu variants are needed here too.
 */
static void mctp_sk_unhash(struct sock *sk)
{
	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
	struct net *net = sock_net(sk);
	struct mctp_sk_key *key;
	struct hlist_node *tmp;
	unsigned long flags;

	/* remove from any type-based binds */
	mutex_lock(&net->mctp.bind_lock);
	sk_del_node_init_rcu(sk);
	mutex_unlock(&net->mctp.bind_lock);

	/* remove tag allocations */
	spin_lock_irqsave(&net->mctp.keys_lock, flags);
	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
		hlist_del(&key->sklist);
		hlist_del(&key->hlist);

		trace_mctp_key_release(key, MCTP_TRACE_KEY_CLOSED);

		/* drop any partial reassembly and refuse further fragments */
		spin_lock(&key->lock);
		if (key->reasm_head)
			kfree_skb(key->reasm_head);
		key->reasm_head = NULL;
		key->reasm_dead = true;
		key->valid = false;
		spin_unlock(&key->lock);

		/* key is no longer on the lookup lists, unref */
		mctp_key_unref(key);
	}
	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
}
398 
399 static struct proto mctp_proto = {
400 	.name		= "MCTP",
401 	.owner		= THIS_MODULE,
402 	.obj_size	= sizeof(struct mctp_sock),
403 	.init		= mctp_sk_init,
404 	.close		= mctp_sk_close,
405 	.hash		= mctp_sk_hash,
406 	.unhash		= mctp_sk_unhash,
407 };
408 
409 static int mctp_pf_create(struct net *net, struct socket *sock,
410 			  int protocol, int kern)
411 {
412 	const struct proto_ops *ops;
413 	struct proto *proto;
414 	struct sock *sk;
415 	int rc;
416 
417 	if (protocol)
418 		return -EPROTONOSUPPORT;
419 
420 	/* only datagram sockets are supported */
421 	if (sock->type != SOCK_DGRAM)
422 		return -ESOCKTNOSUPPORT;
423 
424 	proto = &mctp_proto;
425 	ops = &mctp_dgram_ops;
426 
427 	sock->state = SS_UNCONNECTED;
428 	sock->ops = ops;
429 
430 	sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern);
431 	if (!sk)
432 		return -ENOMEM;
433 
434 	sock_init_data(sock, sk);
435 
436 	rc = 0;
437 	if (sk->sk_prot->init)
438 		rc = sk->sk_prot->init(sk);
439 
440 	if (rc)
441 		goto err_sk_put;
442 
443 	return 0;
444 
445 err_sk_put:
446 	sock_orphan(sk);
447 	sock_put(sk);
448 	return rc;
449 }
450 
451 static struct net_proto_family mctp_pf = {
452 	.family = PF_MCTP,
453 	.create = mctp_pf_create,
454 	.owner = THIS_MODULE,
455 };
456 
457 static __init int mctp_init(void)
458 {
459 	int rc;
460 
461 	/* ensure our uapi tag definitions match the header format */
462 	BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO);
463 	BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK);
464 
465 	pr_info("mctp: management component transport protocol core\n");
466 
467 	rc = sock_register(&mctp_pf);
468 	if (rc)
469 		return rc;
470 
471 	rc = proto_register(&mctp_proto, 0);
472 	if (rc)
473 		goto err_unreg_sock;
474 
475 	rc = mctp_routes_init();
476 	if (rc)
477 		goto err_unreg_proto;
478 
479 	rc = mctp_neigh_init();
480 	if (rc)
481 		goto err_unreg_proto;
482 
483 	mctp_device_init();
484 
485 	return 0;
486 
487 err_unreg_proto:
488 	proto_unregister(&mctp_proto);
489 err_unreg_sock:
490 	sock_unregister(PF_MCTP);
491 
492 	return rc;
493 }
494 
/*
 * Module teardown. This is the exact reverse of mctp_init()'s registration
 * order: device, neighbour and route subsystems first, then the proto, and
 * finally the socket family.
 */
static __exit void mctp_exit(void)
{
	mctp_device_exit();
	mctp_neigh_exit();
	mctp_routes_exit();
	proto_unregister(&mctp_proto);
	sock_unregister(PF_MCTP);
}
503 
/*
 * Registered at subsys_initcall level rather than module_init. Presumably
 * this ensures the core is up before MCTP device drivers initialise.
 * TODO confirm.
 */
subsys_initcall(mctp_init);
module_exit(mctp_exit);

MODULE_DESCRIPTION("MCTP core");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>");

/* netproto module alias for PF_MCTP, for on-demand loading by family */
MODULE_ALIAS_NETPROTO(PF_MCTP);
512