1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */ 3 4 #include <linux/skmsg.h> 5 #include <net/sock.h> 6 #include <net/udp.h> 7 #include <net/inet_common.h> 8 #include <asm/ioctls.h> 9 10 #include "udp_impl.h" 11 12 static struct proto *udpv6_prot_saved __read_mostly; 13 14 static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 15 int flags, int *addr_len) 16 { 17 #if IS_ENABLED(CONFIG_IPV6) 18 if (sk->sk_family == AF_INET6) 19 return udpv6_prot_saved->recvmsg(sk, msg, len, flags, addr_len); 20 #endif 21 return udp_prot.recvmsg(sk, msg, len, flags, addr_len); 22 } 23 24 static bool udp_sk_has_data(struct sock *sk) 25 { 26 return !skb_queue_empty(&udp_sk(sk)->reader_queue) || 27 !skb_queue_empty(&sk->sk_receive_queue); 28 } 29 30 static bool psock_has_data(struct sk_psock *psock) 31 { 32 return !skb_queue_empty(&psock->ingress_skb) || 33 !sk_psock_queue_empty(psock); 34 } 35 36 #define udp_msg_has_data(__sk, __psock) \ 37 ({ udp_sk_has_data(__sk) || psock_has_data(__psock); }) 38 39 static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, 40 long timeo) 41 { 42 DEFINE_WAIT_FUNC(wait, woken_wake_function); 43 int ret = 0; 44 45 if (sk->sk_shutdown & RCV_SHUTDOWN) 46 return 1; 47 48 if (!timeo) 49 return ret; 50 51 add_wait_queue(sk_sleep(sk), &wait); 52 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 53 ret = udp_msg_has_data(sk, psock); 54 if (!ret) { 55 wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); 56 ret = udp_msg_has_data(sk, psock); 57 } 58 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 59 remove_wait_queue(sk_sleep(sk), &wait); 60 return ret; 61 } 62 63 static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 64 int flags, int *addr_len) 65 { 66 struct sk_psock *psock; 67 int copied, ret; 68 69 if (unlikely(flags & MSG_ERRQUEUE)) 70 return inet_recv_error(sk, msg, len, addr_len); 71 72 if (!len) 73 return 0; 74 75 psock = sk_psock_get(sk); 76 if (unlikely(!psock)) 77 return sk_udp_recvmsg(sk, msg, len, flags, addr_len); 78 79 if (!psock_has_data(psock)) { 80 ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); 81 goto out; 82 } 83 84 msg_bytes_ready: 85 copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 86 if (!copied) { 87 long timeo; 88 int data; 89 90 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 91 data = udp_msg_wait_data(sk, psock, timeo); 92 if (data) { 93 if (psock_has_data(psock)) 94 goto msg_bytes_ready; 95 ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); 96 goto out; 97 } 98 copied = -EAGAIN; 99 } 100 ret = copied; 101 out: 102 sk_psock_put(sk, psock); 103 return ret; 104 } 105 106 enum { 107 UDP_BPF_IPV4, 108 UDP_BPF_IPV6, 109 UDP_BPF_NUM_PROTS, 110 }; 111 112 static DEFINE_SPINLOCK(udpv6_prot_lock); 113 static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS]; 114 115 static int udp_bpf_ioctl(struct sock *sk, int cmd, int *karg) 116 { 117 if (cmd != SIOCINQ) 118 return udp_ioctl(sk, cmd, karg); 119 120 /* Since we don't hold a lock, sk_receive_queue may contain data. 121 * BPF might only be processing this data at the moment. We only 122 * care about the data in the ingress_msg here. 123 */ 124 *karg = sk_msg_first_len(sk); 125 return 0; 126 } 127 128 static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 129 { 130 *prot = *base; 131 prot->close = sock_map_close; 132 prot->recvmsg = udp_bpf_recvmsg; 133 prot->sock_is_readable = sk_msg_is_readable; 134 prot->ioctl = udp_bpf_ioctl; 135 } 136 137 static void udp_bpf_check_v6_needs_rebuild(struct proto *ops) 138 { 139 if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) { 140 spin_lock_bh(&udpv6_prot_lock); 141 if (likely(ops != udpv6_prot_saved)) { 142 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops); 143 smp_store_release(&udpv6_prot_saved, ops); 144 } 145 spin_unlock_bh(&udpv6_prot_lock); 146 } 147 } 148 149 static int __init udp_bpf_v4_build_proto(void) 150 { 151 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot); 152 return 0; 153 } 154 late_initcall(udp_bpf_v4_build_proto); 155 156 int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 157 { 158 int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; 159 160 if (restore) { 161 sk->sk_write_space = psock->saved_write_space; 162 sock_replace_proto(sk, psock->sk_proto); 163 return 0; 164 } 165 166 if (sk->sk_family == AF_INET6) 167 udp_bpf_check_v6_needs_rebuild(psock->sk_proto); 168 169 sock_replace_proto(sk, &udp_bpf_prots[family]); 170 return 0; 171 } 172 EXPORT_SYMBOL_GPL(udp_bpf_update_proto); 173