1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */ 3 4 #include <linux/skmsg.h> 5 #include <net/sock.h> 6 #include <net/udp.h> 7 #include <net/inet_common.h> 8 #include <asm/ioctls.h> 9 10 static struct proto *udpv6_prot_saved __read_mostly; 11 12 static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 13 int flags) 14 { 15 #if IS_ENABLED(CONFIG_IPV6) 16 if (sk->sk_family == AF_INET6) 17 return udpv6_prot_saved->recvmsg(sk, msg, len, flags); 18 #endif 19 return udp_prot.recvmsg(sk, msg, len, flags); 20 } 21 22 static bool udp_sk_has_data(struct sock *sk) 23 { 24 return !skb_queue_empty(&udp_sk(sk)->reader_queue) || 25 !skb_queue_empty(&sk->sk_receive_queue); 26 } 27 28 static bool psock_has_data(struct sk_psock *psock) 29 { 30 return !skb_queue_empty(&psock->ingress_skb) || 31 !sk_psock_queue_empty(psock); 32 } 33 34 #define udp_msg_has_data(__sk, __psock) \ 35 ({ udp_sk_has_data(__sk) || psock_has_data(__psock); }) 36 37 static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, 38 long timeo) 39 { 40 DEFINE_WAIT_FUNC(wait, woken_wake_function); 41 int ret = 0; 42 43 if (sk->sk_shutdown & RCV_SHUTDOWN) 44 return 1; 45 46 if (!timeo) 47 return ret; 48 49 add_wait_queue(sk_sleep(sk), &wait); 50 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 51 ret = udp_msg_has_data(sk, psock); 52 if (!ret) { 53 wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); 54 ret = udp_msg_has_data(sk, psock); 55 } 56 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 57 remove_wait_queue(sk_sleep(sk), &wait); 58 return ret; 59 } 60 61 static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 62 int flags) 63 { 64 struct sk_psock *psock; 65 int copied, ret; 66 67 if (unlikely(flags & MSG_ERRQUEUE)) 68 return inet_recv_error(sk, msg, len); 69 70 if (!len) 71 return 0; 72 73 psock = sk_psock_get(sk); 74 if (unlikely(!psock)) 75 return sk_udp_recvmsg(sk, msg, len, flags); 76 77 if (!psock_has_data(psock)) { 78 ret = sk_udp_recvmsg(sk, msg, len, flags); 79 goto out; 80 } 81 82 msg_bytes_ready: 83 copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 84 if (!copied) { 85 long timeo; 86 int data; 87 88 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 89 data = udp_msg_wait_data(sk, psock, timeo); 90 if (data) { 91 if (psock_has_data(psock)) 92 goto msg_bytes_ready; 93 ret = sk_udp_recvmsg(sk, msg, len, flags); 94 goto out; 95 } 96 copied = -EAGAIN; 97 } 98 ret = copied; 99 out: 100 sk_psock_put(sk, psock); 101 return ret; 102 } 103 104 enum { 105 UDP_BPF_IPV4, 106 UDP_BPF_IPV6, 107 UDP_BPF_NUM_PROTS, 108 }; 109 110 static DEFINE_SPINLOCK(udpv6_prot_lock); 111 static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS]; 112 113 static int udp_bpf_ioctl(struct sock *sk, int cmd, int *karg) 114 { 115 if (cmd != SIOCINQ) 116 return udp_ioctl(sk, cmd, karg); 117 118 /* Since we don't hold a lock, sk_receive_queue may contain data. 119 * BPF might only be processing this data at the moment. We only 120 * care about the data in the ingress_msg here. 121 */ 122 *karg = sk_msg_first_len(sk); 123 return 0; 124 } 125 126 static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 127 { 128 *prot = *base; 129 prot->close = sock_map_close; 130 prot->recvmsg = udp_bpf_recvmsg; 131 prot->sock_is_readable = sk_msg_is_readable; 132 prot->ioctl = udp_bpf_ioctl; 133 } 134 135 static void udp_bpf_check_v6_needs_rebuild(struct proto *ops) 136 { 137 if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) { 138 spin_lock_bh(&udpv6_prot_lock); 139 if (likely(ops != udpv6_prot_saved)) { 140 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops); 141 smp_store_release(&udpv6_prot_saved, ops); 142 } 143 spin_unlock_bh(&udpv6_prot_lock); 144 } 145 } 146 147 static int __init udp_bpf_v4_build_proto(void) 148 { 149 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot); 150 return 0; 151 } 152 late_initcall(udp_bpf_v4_build_proto); 153 154 int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 155 { 156 int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; 157 158 if (restore) { 159 WRITE_ONCE(sk->sk_write_space, psock->saved_write_space); 160 sock_replace_proto(sk, psock->sk_proto); 161 return 0; 162 } 163 164 if (sk->sk_family == AF_INET6) 165 udp_bpf_check_v6_needs_rebuild(psock->sk_proto); 166 167 sock_replace_proto(sk, &udp_bpf_prots[family]); 168 return 0; 169 } 170 EXPORT_SYMBOL_GPL(udp_bpf_update_proto); 171