1 // SPDX-License-Identifier: GPL-2.0 2 /* 3 * Management Component Transport Protocol (MCTP) 4 * 5 * Copyright (c) 2021 Code Construct 6 * Copyright (c) 2021 Google 7 */ 8 9 #include <linux/if_arp.h> 10 #include <linux/net.h> 11 #include <linux/mctp.h> 12 #include <linux/module.h> 13 #include <linux/socket.h> 14 15 #include <net/mctp.h> 16 #include <net/mctpdevice.h> 17 #include <net/sock.h> 18 19 #define CREATE_TRACE_POINTS 20 #include <trace/events/mctp.h> 21 22 /* socket implementation */ 23 24 static int mctp_release(struct socket *sock) 25 { 26 struct sock *sk = sock->sk; 27 28 if (sk) { 29 sock->sk = NULL; 30 sk->sk_prot->close(sk, 0); 31 } 32 33 return 0; 34 } 35 36 static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) 37 { 38 struct sock *sk = sock->sk; 39 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 40 struct sockaddr_mctp *smctp; 41 int rc; 42 43 if (addrlen < sizeof(*smctp)) 44 return -EINVAL; 45 46 if (addr->sa_family != AF_MCTP) 47 return -EAFNOSUPPORT; 48 49 if (!capable(CAP_NET_BIND_SERVICE)) 50 return -EACCES; 51 52 /* it's a valid sockaddr for MCTP, cast and do protocol checks */ 53 smctp = (struct sockaddr_mctp *)addr; 54 55 lock_sock(sk); 56 57 /* TODO: allow rebind */ 58 if (sk_hashed(sk)) { 59 rc = -EADDRINUSE; 60 goto out_release; 61 } 62 msk->bind_net = smctp->smctp_network; 63 msk->bind_addr = smctp->smctp_addr.s_addr; 64 msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */ 65 66 rc = sk->sk_prot->hash(sk); 67 68 out_release: 69 release_sock(sk); 70 71 return rc; 72 } 73 74 static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) 75 { 76 DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); 77 const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr); 78 int rc, addrlen = msg->msg_namelen; 79 struct sock *sk = sock->sk; 80 struct mctp_skb_cb *cb; 81 struct mctp_route *rt; 82 struct sk_buff *skb; 83 84 if (addr) { 85 if (addrlen < sizeof(struct sockaddr_mctp)) 86 return -EINVAL; 87 if (addr->smctp_family != AF_MCTP) 88 return -EINVAL; 89 if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER)) 90 return -EINVAL; 91 92 } else { 93 /* TODO: connect()ed sockets */ 94 return -EDESTADDRREQ; 95 } 96 97 if (!capable(CAP_NET_RAW)) 98 return -EACCES; 99 100 if (addr->smctp_network == MCTP_NET_ANY) 101 addr->smctp_network = mctp_default_net(sock_net(sk)); 102 103 rt = mctp_route_lookup(sock_net(sk), addr->smctp_network, 104 addr->smctp_addr.s_addr); 105 if (!rt) 106 return -EHOSTUNREACH; 107 108 skb = sock_alloc_send_skb(sk, hlen + 1 + len, 109 msg->msg_flags & MSG_DONTWAIT, &rc); 110 if (!skb) 111 return rc; 112 113 skb_reserve(skb, hlen); 114 115 /* set type as fist byte in payload */ 116 *(u8 *)skb_put(skb, 1) = addr->smctp_type; 117 118 rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len); 119 if (rc < 0) { 120 kfree_skb(skb); 121 return rc; 122 } 123 124 /* set up cb */ 125 cb = __mctp_cb(skb); 126 cb->net = addr->smctp_network; 127 128 rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr, 129 addr->smctp_tag); 130 131 return rc ? : len; 132 } 133 134 static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, 135 int flags) 136 { 137 DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); 138 struct sock *sk = sock->sk; 139 struct sk_buff *skb; 140 size_t msglen; 141 u8 type; 142 int rc; 143 144 if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK)) 145 return -EOPNOTSUPP; 146 147 skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc); 148 if (!skb) 149 return rc; 150 151 if (!skb->len) { 152 rc = 0; 153 goto out_free; 154 } 155 156 /* extract message type, remove from data */ 157 type = *((u8 *)skb->data); 158 msglen = skb->len - 1; 159 160 if (len < msglen) 161 msg->msg_flags |= MSG_TRUNC; 162 else 163 len = msglen; 164 165 rc = skb_copy_datagram_msg(skb, 1, msg, len); 166 if (rc < 0) 167 goto out_free; 168 169 sock_recv_ts_and_drops(msg, sk, skb); 170 171 if (addr) { 172 struct mctp_skb_cb *cb = mctp_cb(skb); 173 /* TODO: expand mctp_skb_cb for header fields? */ 174 struct mctp_hdr *hdr = mctp_hdr(skb); 175 176 addr = msg->msg_name; 177 addr->smctp_family = AF_MCTP; 178 addr->smctp_network = cb->net; 179 addr->smctp_addr.s_addr = hdr->src; 180 addr->smctp_type = type; 181 addr->smctp_tag = hdr->flags_seq_tag & 182 (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); 183 msg->msg_namelen = sizeof(*addr); 184 } 185 186 rc = len; 187 188 if (flags & MSG_TRUNC) 189 rc = msglen; 190 191 out_free: 192 skb_free_datagram(sk, skb); 193 return rc; 194 } 195 196 static int mctp_setsockopt(struct socket *sock, int level, int optname, 197 sockptr_t optval, unsigned int optlen) 198 { 199 return -EINVAL; 200 } 201 202 static int mctp_getsockopt(struct socket *sock, int level, int optname, 203 char __user *optval, int __user *optlen) 204 { 205 return -EINVAL; 206 } 207 208 static const struct proto_ops mctp_dgram_ops = { 209 .family = PF_MCTP, 210 .release = mctp_release, 211 .bind = mctp_bind, 212 .connect = sock_no_connect, 213 .socketpair = sock_no_socketpair, 214 .accept = sock_no_accept, 215 .getname = sock_no_getname, 216 .poll = datagram_poll, 217 .ioctl = sock_no_ioctl, 218 .gettstamp = sock_gettstamp, 219 .listen = sock_no_listen, 220 .shutdown = sock_no_shutdown, 221 .setsockopt = mctp_setsockopt, 222 .getsockopt = mctp_getsockopt, 223 .sendmsg = mctp_sendmsg, 224 .recvmsg = mctp_recvmsg, 225 .mmap = sock_no_mmap, 226 .sendpage = sock_no_sendpage, 227 }; 228 229 static void mctp_sk_expire_keys(struct timer_list *timer) 230 { 231 struct mctp_sock *msk = container_of(timer, struct mctp_sock, 232 key_expiry); 233 struct net *net = sock_net(&msk->sk); 234 unsigned long next_expiry, flags; 235 struct mctp_sk_key *key; 236 struct hlist_node *tmp; 237 bool next_expiry_valid = false; 238 239 spin_lock_irqsave(&net->mctp.keys_lock, flags); 240 241 hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { 242 spin_lock(&key->lock); 243 244 if (!time_after_eq(key->expiry, jiffies)) { 245 trace_mctp_key_release(key, MCTP_TRACE_KEY_TIMEOUT); 246 key->valid = false; 247 hlist_del_rcu(&key->hlist); 248 hlist_del_rcu(&key->sklist); 249 spin_unlock(&key->lock); 250 mctp_key_unref(key); 251 continue; 252 } 253 254 if (next_expiry_valid) { 255 if (time_before(key->expiry, next_expiry)) 256 next_expiry = key->expiry; 257 } else { 258 next_expiry = key->expiry; 259 next_expiry_valid = true; 260 } 261 spin_unlock(&key->lock); 262 } 263 264 spin_unlock_irqrestore(&net->mctp.keys_lock, flags); 265 266 if (next_expiry_valid) 267 mod_timer(timer, next_expiry); 268 } 269 270 static int mctp_sk_init(struct sock *sk) 271 { 272 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 273 274 INIT_HLIST_HEAD(&msk->keys); 275 timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0); 276 return 0; 277 } 278 279 static void mctp_sk_close(struct sock *sk, long timeout) 280 { 281 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 282 283 del_timer_sync(&msk->key_expiry); 284 sk_common_release(sk); 285 } 286 287 static int mctp_sk_hash(struct sock *sk) 288 { 289 struct net *net = sock_net(sk); 290 291 mutex_lock(&net->mctp.bind_lock); 292 sk_add_node_rcu(sk, &net->mctp.binds); 293 mutex_unlock(&net->mctp.bind_lock); 294 295 return 0; 296 } 297 298 static void mctp_sk_unhash(struct sock *sk) 299 { 300 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 301 struct net *net = sock_net(sk); 302 struct mctp_sk_key *key; 303 struct hlist_node *tmp; 304 unsigned long flags; 305 306 /* remove from any type-based binds */ 307 mutex_lock(&net->mctp.bind_lock); 308 sk_del_node_init_rcu(sk); 309 mutex_unlock(&net->mctp.bind_lock); 310 311 /* remove tag allocations */ 312 spin_lock_irqsave(&net->mctp.keys_lock, flags); 313 hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { 314 hlist_del(&key->sklist); 315 hlist_del(&key->hlist); 316 317 trace_mctp_key_release(key, MCTP_TRACE_KEY_CLOSED); 318 319 spin_lock(&key->lock); 320 if (key->reasm_head) 321 kfree_skb(key->reasm_head); 322 key->reasm_head = NULL; 323 key->reasm_dead = true; 324 key->valid = false; 325 spin_unlock(&key->lock); 326 327 /* key is no longer on the lookup lists, unref */ 328 mctp_key_unref(key); 329 } 330 spin_unlock_irqrestore(&net->mctp.keys_lock, flags); 331 } 332 333 static struct proto mctp_proto = { 334 .name = "MCTP", 335 .owner = THIS_MODULE, 336 .obj_size = sizeof(struct mctp_sock), 337 .init = mctp_sk_init, 338 .close = mctp_sk_close, 339 .hash = mctp_sk_hash, 340 .unhash = mctp_sk_unhash, 341 }; 342 343 static int mctp_pf_create(struct net *net, struct socket *sock, 344 int protocol, int kern) 345 { 346 const struct proto_ops *ops; 347 struct proto *proto; 348 struct sock *sk; 349 int rc; 350 351 if (protocol) 352 return -EPROTONOSUPPORT; 353 354 /* only datagram sockets are supported */ 355 if (sock->type != SOCK_DGRAM) 356 return -ESOCKTNOSUPPORT; 357 358 proto = &mctp_proto; 359 ops = &mctp_dgram_ops; 360 361 sock->state = SS_UNCONNECTED; 362 sock->ops = ops; 363 364 sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern); 365 if (!sk) 366 return -ENOMEM; 367 368 sock_init_data(sock, sk); 369 370 rc = 0; 371 if (sk->sk_prot->init) 372 rc = sk->sk_prot->init(sk); 373 374 if (rc) 375 goto err_sk_put; 376 377 return 0; 378 379 err_sk_put: 380 sock_orphan(sk); 381 sock_put(sk); 382 return rc; 383 } 384 385 static struct net_proto_family mctp_pf = { 386 .family = PF_MCTP, 387 .create = mctp_pf_create, 388 .owner = THIS_MODULE, 389 }; 390 391 static __init int mctp_init(void) 392 { 393 int rc; 394 395 /* ensure our uapi tag definitions match the header format */ 396 BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO); 397 BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK); 398 399 pr_info("mctp: management component transport protocol core\n"); 400 401 rc = sock_register(&mctp_pf); 402 if (rc) 403 return rc; 404 405 rc = proto_register(&mctp_proto, 0); 406 if (rc) 407 goto err_unreg_sock; 408 409 rc = mctp_routes_init(); 410 if (rc) 411 goto err_unreg_proto; 412 413 rc = mctp_neigh_init(); 414 if (rc) 415 goto err_unreg_proto; 416 417 mctp_device_init(); 418 419 return 0; 420 421 err_unreg_proto: 422 proto_unregister(&mctp_proto); 423 err_unreg_sock: 424 sock_unregister(PF_MCTP); 425 426 return rc; 427 } 428 429 static __exit void mctp_exit(void) 430 { 431 mctp_device_exit(); 432 mctp_neigh_exit(); 433 mctp_routes_exit(); 434 proto_unregister(&mctp_proto); 435 sock_unregister(PF_MCTP); 436 } 437 438 subsys_initcall(mctp_init); 439 module_exit(mctp_exit); 440 441 MODULE_DESCRIPTION("MCTP core"); 442 MODULE_LICENSE("GPL v2"); 443 MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>"); 444 445 MODULE_ALIAS_NETPROTO(PF_MCTP); 446