1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */ 3 4 #include <linux/skmsg.h> 5 #include <net/sock.h> 6 #include <net/udp.h> 7 #include <net/inet_common.h> 8 #include <asm/ioctls.h> 9 10 static struct proto *udpv6_prot_saved __read_mostly; 11 12 static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 13 int flags) 14 { 15 #if IS_ENABLED(CONFIG_IPV6) 16 if (sk->sk_family == AF_INET6) 17 return udpv6_prot_saved->recvmsg(sk, msg, len, flags); 18 #endif 19 return udp_prot.recvmsg(sk, msg, len, flags); 20 } 21 22 static bool udp_sk_has_data(struct sock *sk) 23 { 24 return !skb_queue_empty(&udp_sk(sk)->reader_queue) || 25 !skb_queue_empty(&sk->sk_receive_queue); 26 } 27 28 static bool psock_has_data(struct sk_psock *psock) 29 { 30 return !skb_queue_empty(&psock->ingress_skb) || 31 !sk_psock_queue_empty(psock); 32 } 33 34 #define udp_msg_has_data(__sk, __psock) \ 35 ({ udp_sk_has_data(__sk) || psock_has_data(__psock); }) 36 37 static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, 38 long timeo) 39 { 40 DEFINE_WAIT_FUNC(wait, woken_wake_function); 41 int ret = 0; 42 43 if (sk->sk_shutdown & RCV_SHUTDOWN) 44 return 1; 45 46 if (!timeo) 47 return ret; 48 49 add_wait_queue(sk_sleep(sk), &wait); 50 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 51 ret = udp_msg_has_data(sk, psock); 52 if (!ret) { 53 release_sock(sk); 54 wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); 55 lock_sock(sk); 56 ret = udp_msg_has_data(sk, psock); 57 } 58 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 59 remove_wait_queue(sk_sleep(sk), &wait); 60 return ret; 61 } 62 63 static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 64 int flags) 65 { 66 struct sk_psock *psock; 67 int copied, ret; 68 69 if (unlikely(flags & MSG_ERRQUEUE)) 70 return inet_recv_error(sk, msg, len); 71 72 if (!len) 73 return 0; 74 75 psock = sk_psock_get(sk); 76 if (unlikely(!psock)) 77 return sk_udp_recvmsg(sk, msg, len, flags); 78 79 if (!psock_has_data(psock)) { 80 ret = sk_udp_recvmsg(sk, msg, len, flags); 81 goto out; 82 } 83 84 lock_sock(sk); 85 msg_bytes_ready: 86 copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 87 if (!copied) { 88 long timeo; 89 int data; 90 91 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 92 data = udp_msg_wait_data(sk, psock, timeo); 93 if (data) { 94 if (psock_has_data(psock)) 95 goto msg_bytes_ready; 96 97 release_sock(sk); 98 99 ret = sk_udp_recvmsg(sk, msg, len, flags); 100 goto out; 101 } 102 copied = -EAGAIN; 103 } 104 105 release_sock(sk); 106 107 ret = copied; 108 out: 109 sk_psock_put(sk, psock); 110 return ret; 111 } 112 113 enum { 114 UDP_BPF_IPV4, 115 UDP_BPF_IPV6, 116 UDP_BPF_NUM_PROTS, 117 }; 118 119 static DEFINE_SPINLOCK(udpv6_prot_lock); 120 static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS]; 121 122 static int udp_bpf_ioctl(struct sock *sk, int cmd, int *karg) 123 { 124 if (cmd != SIOCINQ) 125 return udp_ioctl(sk, cmd, karg); 126 127 /* Since we don't hold a lock, sk_receive_queue may contain data. 128 * BPF might only be processing this data at the moment. We only 129 * care about the data in the ingress_msg here. 130 */ 131 *karg = sk_msg_first_len(sk); 132 return 0; 133 } 134 135 static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 136 { 137 *prot = *base; 138 prot->close = sock_map_close; 139 prot->recvmsg = udp_bpf_recvmsg; 140 prot->sock_is_readable = sk_msg_is_readable; 141 prot->ioctl = udp_bpf_ioctl; 142 } 143 144 static void udp_bpf_check_v6_needs_rebuild(struct proto *ops) 145 { 146 if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) { 147 spin_lock_bh(&udpv6_prot_lock); 148 if (likely(ops != udpv6_prot_saved)) { 149 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops); 150 smp_store_release(&udpv6_prot_saved, ops); 151 } 152 spin_unlock_bh(&udpv6_prot_lock); 153 } 154 } 155 156 static int __init udp_bpf_v4_build_proto(void) 157 { 158 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot); 159 return 0; 160 } 161 late_initcall(udp_bpf_v4_build_proto); 162 163 int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 164 { 165 int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; 166 167 if (restore) { 168 WRITE_ONCE(sk->sk_write_space, psock->saved_write_space); 169 sock_replace_proto(sk, psock->sk_proto); 170 return 0; 171 } 172 173 if (sk->sk_family == AF_INET6) 174 udp_bpf_check_v6_needs_rebuild(psock->sk_proto); 175 176 sock_replace_proto(sk, &udp_bpf_prots[family]); 177 return 0; 178 } 179 EXPORT_SYMBOL_GPL(udp_bpf_update_proto); 180