1edc6741cSLorenz Bauer // SPDX-License-Identifier: GPL-2.0 2edc6741cSLorenz Bauer /* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */ 3edc6741cSLorenz Bauer 4edc6741cSLorenz Bauer #include <linux/skmsg.h> 5edc6741cSLorenz Bauer #include <net/sock.h> 6edc6741cSLorenz Bauer #include <net/udp.h> 71f5be6b3SCong Wang #include <net/inet_common.h> 81f5be6b3SCong Wang 91f5be6b3SCong Wang #include "udp_impl.h" 101f5be6b3SCong Wang 111f5be6b3SCong Wang static struct proto *udpv6_prot_saved __read_mostly; 121f5be6b3SCong Wang 131f5be6b3SCong Wang static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 141f5be6b3SCong Wang int noblock, int flags, int *addr_len) 151f5be6b3SCong Wang { 161f5be6b3SCong Wang #if IS_ENABLED(CONFIG_IPV6) 171f5be6b3SCong Wang if (sk->sk_family == AF_INET6) 181f5be6b3SCong Wang return udpv6_prot_saved->recvmsg(sk, msg, len, noblock, flags, 191f5be6b3SCong Wang addr_len); 201f5be6b3SCong Wang #endif 211f5be6b3SCong Wang return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len); 221f5be6b3SCong Wang } 231f5be6b3SCong Wang 241f5be6b3SCong Wang static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 251f5be6b3SCong Wang int nonblock, int flags, int *addr_len) 261f5be6b3SCong Wang { 271f5be6b3SCong Wang struct sk_psock *psock; 281f5be6b3SCong Wang int copied, ret; 291f5be6b3SCong Wang 301f5be6b3SCong Wang if (unlikely(flags & MSG_ERRQUEUE)) 311f5be6b3SCong Wang return inet_recv_error(sk, msg, len, addr_len); 321f5be6b3SCong Wang 331f5be6b3SCong Wang psock = sk_psock_get(sk); 341f5be6b3SCong Wang if (unlikely(!psock)) 351f5be6b3SCong Wang return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); 361f5be6b3SCong Wang 371f5be6b3SCong Wang lock_sock(sk); 381f5be6b3SCong Wang if (sk_psock_queue_empty(psock)) { 391f5be6b3SCong Wang ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); 401f5be6b3SCong Wang goto out; 411f5be6b3SCong Wang } 421f5be6b3SCong Wang 431f5be6b3SCong Wang msg_bytes_ready: 441f5be6b3SCong Wang copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 451f5be6b3SCong Wang if (!copied) { 461f5be6b3SCong Wang long timeo; 47*c49661aaSCong Wang int data; 481f5be6b3SCong Wang 491f5be6b3SCong Wang timeo = sock_rcvtimeo(sk, nonblock); 50*c49661aaSCong Wang data = sk_msg_wait_data(sk, psock, timeo); 511f5be6b3SCong Wang if (data) { 521f5be6b3SCong Wang if (!sk_psock_queue_empty(psock)) 531f5be6b3SCong Wang goto msg_bytes_ready; 541f5be6b3SCong Wang ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len); 551f5be6b3SCong Wang goto out; 561f5be6b3SCong Wang } 571f5be6b3SCong Wang copied = -EAGAIN; 581f5be6b3SCong Wang } 591f5be6b3SCong Wang ret = copied; 601f5be6b3SCong Wang out: 611f5be6b3SCong Wang release_sock(sk); 621f5be6b3SCong Wang sk_psock_put(sk, psock); 631f5be6b3SCong Wang return ret; 641f5be6b3SCong Wang } 65edc6741cSLorenz Bauer 66edc6741cSLorenz Bauer enum { 67edc6741cSLorenz Bauer UDP_BPF_IPV4, 68edc6741cSLorenz Bauer UDP_BPF_IPV6, 69edc6741cSLorenz Bauer UDP_BPF_NUM_PROTS, 70edc6741cSLorenz Bauer }; 71edc6741cSLorenz Bauer 72edc6741cSLorenz Bauer static DEFINE_SPINLOCK(udpv6_prot_lock); 73edc6741cSLorenz Bauer static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS]; 74edc6741cSLorenz Bauer 75edc6741cSLorenz Bauer static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 76edc6741cSLorenz Bauer { 77edc6741cSLorenz Bauer *prot = *base; 78edc6741cSLorenz Bauer prot->unhash = sock_map_unhash; 79edc6741cSLorenz Bauer prot->close = sock_map_close; 801f5be6b3SCong Wang prot->recvmsg = udp_bpf_recvmsg; 81edc6741cSLorenz Bauer } 82edc6741cSLorenz Bauer 837b219da4SLorenz Bauer static void udp_bpf_check_v6_needs_rebuild(struct proto *ops) 84edc6741cSLorenz Bauer { 857b219da4SLorenz Bauer if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) { 86edc6741cSLorenz Bauer spin_lock_bh(&udpv6_prot_lock); 87edc6741cSLorenz Bauer if (likely(ops != udpv6_prot_saved)) { 88edc6741cSLorenz Bauer udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops); 89edc6741cSLorenz Bauer smp_store_release(&udpv6_prot_saved, ops); 90edc6741cSLorenz Bauer } 91edc6741cSLorenz Bauer spin_unlock_bh(&udpv6_prot_lock); 92edc6741cSLorenz Bauer } 93edc6741cSLorenz Bauer } 94edc6741cSLorenz Bauer 95edc6741cSLorenz Bauer static int __init udp_bpf_v4_build_proto(void) 96edc6741cSLorenz Bauer { 97edc6741cSLorenz Bauer udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot); 98edc6741cSLorenz Bauer return 0; 99edc6741cSLorenz Bauer } 100edc6741cSLorenz Bauer core_initcall(udp_bpf_v4_build_proto); 101edc6741cSLorenz Bauer 10251e0158aSCong Wang int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 103edc6741cSLorenz Bauer { 104edc6741cSLorenz Bauer int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; 1058a59f9d1SCong Wang 1068a59f9d1SCong Wang if (restore) { 1078a59f9d1SCong Wang sk->sk_write_space = psock->saved_write_space; 1088a59f9d1SCong Wang WRITE_ONCE(sk->sk_prot, psock->sk_proto); 1098a59f9d1SCong Wang return 0; 1108a59f9d1SCong Wang } 111edc6741cSLorenz Bauer 1127b219da4SLorenz Bauer if (sk->sk_family == AF_INET6) 1137b219da4SLorenz Bauer udp_bpf_check_v6_needs_rebuild(psock->sk_proto); 114edc6741cSLorenz Bauer 1158a59f9d1SCong Wang WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]); 1168a59f9d1SCong Wang return 0; 117edc6741cSLorenz Bauer } 1188a59f9d1SCong Wang EXPORT_SYMBOL_GPL(udp_bpf_update_proto); 119