1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */ 3 4 #include <linux/skmsg.h> 5 #include <net/sock.h> 6 #include <net/udp.h> 7 #include <net/inet_common.h> 8 9 #include "udp_impl.h" 10 11 static struct proto *udpv6_prot_saved __read_mostly; 12 13 static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 14 int flags, int *addr_len) 15 { 16 #if IS_ENABLED(CONFIG_IPV6) 17 if (sk->sk_family == AF_INET6) 18 return udpv6_prot_saved->recvmsg(sk, msg, len, flags, addr_len); 19 #endif 20 return udp_prot.recvmsg(sk, msg, len, flags, addr_len); 21 } 22 23 static bool udp_sk_has_data(struct sock *sk) 24 { 25 return !skb_queue_empty(&udp_sk(sk)->reader_queue) || 26 !skb_queue_empty(&sk->sk_receive_queue); 27 } 28 29 static bool psock_has_data(struct sk_psock *psock) 30 { 31 return !skb_queue_empty(&psock->ingress_skb) || 32 !sk_psock_queue_empty(psock); 33 } 34 35 #define udp_msg_has_data(__sk, __psock) \ 36 ({ udp_sk_has_data(__sk) || psock_has_data(__psock); }) 37 38 static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, 39 long timeo) 40 { 41 DEFINE_WAIT_FUNC(wait, woken_wake_function); 42 int ret = 0; 43 44 if (sk->sk_shutdown & RCV_SHUTDOWN) 45 return 1; 46 47 if (!timeo) 48 return ret; 49 50 add_wait_queue(sk_sleep(sk), &wait); 51 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 52 ret = udp_msg_has_data(sk, psock); 53 if (!ret) { 54 wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); 55 ret = udp_msg_has_data(sk, psock); 56 } 57 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 58 remove_wait_queue(sk_sleep(sk), &wait); 59 return ret; 60 } 61 62 static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 63 int flags, int *addr_len) 64 { 65 struct sk_psock *psock; 66 int copied, ret; 67 68 if (unlikely(flags & MSG_ERRQUEUE)) 69 return inet_recv_error(sk, msg, len, addr_len); 70 71 if (!len) 72 return 0; 73 74 psock = sk_psock_get(sk); 75 if (unlikely(!psock)) 76 return sk_udp_recvmsg(sk, msg, len, flags, addr_len); 77 78 if (!psock_has_data(psock)) { 79 ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); 80 goto out; 81 } 82 83 msg_bytes_ready: 84 copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 85 if (!copied) { 86 long timeo; 87 int data; 88 89 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 90 data = udp_msg_wait_data(sk, psock, timeo); 91 if (data) { 92 if (psock_has_data(psock)) 93 goto msg_bytes_ready; 94 ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); 95 goto out; 96 } 97 copied = -EAGAIN; 98 } 99 ret = copied; 100 out: 101 sk_psock_put(sk, psock); 102 return ret; 103 } 104 105 enum { 106 UDP_BPF_IPV4, 107 UDP_BPF_IPV6, 108 UDP_BPF_NUM_PROTS, 109 }; 110 111 static DEFINE_SPINLOCK(udpv6_prot_lock); 112 static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS]; 113 114 static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 115 { 116 *prot = *base; 117 prot->close = sock_map_close; 118 prot->recvmsg = udp_bpf_recvmsg; 119 prot->sock_is_readable = sk_msg_is_readable; 120 } 121 122 static void udp_bpf_check_v6_needs_rebuild(struct proto *ops) 123 { 124 if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) { 125 spin_lock_bh(&udpv6_prot_lock); 126 if (likely(ops != udpv6_prot_saved)) { 127 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops); 128 smp_store_release(&udpv6_prot_saved, ops); 129 } 130 spin_unlock_bh(&udpv6_prot_lock); 131 } 132 } 133 134 static int __init udp_bpf_v4_build_proto(void) 135 { 136 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot); 137 return 0; 138 } 139 late_initcall(udp_bpf_v4_build_proto); 140 141 int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 142 { 143 int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; 144 145 if (restore) { 146 sk->sk_write_space = psock->saved_write_space; 147 sock_replace_proto(sk, psock->sk_proto); 148 return 0; 149 } 150 151 if (sk->sk_family == AF_INET6) 152 udp_bpf_check_v6_needs_rebuild(psock->sk_proto); 153 154 sock_replace_proto(sk, &udp_bpf_prots[family]); 155 return 0; 156 } 157 EXPORT_SYMBOL_GPL(udp_bpf_update_proto); 158