1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */ 3 4 #include <linux/skmsg.h> 5 #include <net/sock.h> 6 #include <net/udp.h> 7 #include <net/inet_common.h> 8 9 #include "udp_impl.h" 10 11 static struct proto *udpv6_prot_saved __read_mostly; 12 13 static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 14 int flags, int *addr_len) 15 { 16 #if IS_ENABLED(CONFIG_IPV6) 17 if (sk->sk_family == AF_INET6) 18 return udpv6_prot_saved->recvmsg(sk, msg, len, flags, addr_len); 19 #endif 20 return udp_prot.recvmsg(sk, msg, len, flags, addr_len); 21 } 22 23 static bool udp_sk_has_data(struct sock *sk) 24 { 25 return !skb_queue_empty(&udp_sk(sk)->reader_queue) || 26 !skb_queue_empty(&sk->sk_receive_queue); 27 } 28 29 static bool psock_has_data(struct sk_psock *psock) 30 { 31 return !skb_queue_empty(&psock->ingress_skb) || 32 !sk_psock_queue_empty(psock); 33 } 34 35 #define udp_msg_has_data(__sk, __psock) \ 36 ({ udp_sk_has_data(__sk) || psock_has_data(__psock); }) 37 38 static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, 39 long timeo) 40 { 41 DEFINE_WAIT_FUNC(wait, woken_wake_function); 42 int ret = 0; 43 44 if (sk->sk_shutdown & RCV_SHUTDOWN) 45 return 1; 46 47 if (!timeo) 48 return ret; 49 50 add_wait_queue(sk_sleep(sk), &wait); 51 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 52 ret = udp_msg_has_data(sk, psock); 53 if (!ret) { 54 wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); 55 ret = udp_msg_has_data(sk, psock); 56 } 57 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 58 remove_wait_queue(sk_sleep(sk), &wait); 59 return ret; 60 } 61 62 static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 63 int flags, int *addr_len) 64 { 65 struct sk_psock *psock; 66 int copied, ret; 67 68 if (unlikely(flags & MSG_ERRQUEUE)) 69 return inet_recv_error(sk, msg, len, addr_len); 70 71 psock = sk_psock_get(sk); 72 if (unlikely(!psock)) 73 return sk_udp_recvmsg(sk, msg, len, flags, addr_len); 74 75 if (!psock_has_data(psock)) { 76 ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); 77 goto out; 78 } 79 80 msg_bytes_ready: 81 copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 82 if (!copied) { 83 long timeo; 84 int data; 85 86 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 87 data = udp_msg_wait_data(sk, psock, timeo); 88 if (data) { 89 if (psock_has_data(psock)) 90 goto msg_bytes_ready; 91 ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); 92 goto out; 93 } 94 copied = -EAGAIN; 95 } 96 ret = copied; 97 out: 98 sk_psock_put(sk, psock); 99 return ret; 100 } 101 102 enum { 103 UDP_BPF_IPV4, 104 UDP_BPF_IPV6, 105 UDP_BPF_NUM_PROTS, 106 }; 107 108 static DEFINE_SPINLOCK(udpv6_prot_lock); 109 static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS]; 110 111 static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 112 { 113 *prot = *base; 114 prot->close = sock_map_close; 115 prot->recvmsg = udp_bpf_recvmsg; 116 prot->sock_is_readable = sk_msg_is_readable; 117 } 118 119 static void udp_bpf_check_v6_needs_rebuild(struct proto *ops) 120 { 121 if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) { 122 spin_lock_bh(&udpv6_prot_lock); 123 if (likely(ops != udpv6_prot_saved)) { 124 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops); 125 smp_store_release(&udpv6_prot_saved, ops); 126 } 127 spin_unlock_bh(&udpv6_prot_lock); 128 } 129 } 130 131 static int __init udp_bpf_v4_build_proto(void) 132 { 133 udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot); 134 return 0; 135 } 136 late_initcall(udp_bpf_v4_build_proto); 137 138 int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 139 { 140 int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; 141 142 if (restore) { 143 sk->sk_write_space = psock->saved_write_space; 144 WRITE_ONCE(sk->sk_prot, psock->sk_proto); 145 return 0; 146 } 147 148 if (sk->sk_family == AF_INET6) 149 udp_bpf_check_v6_needs_rebuild(psock->sk_proto); 150 151 WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]); 152 return 0; 153 } 154 EXPORT_SYMBOL_GPL(udp_bpf_update_proto); 155