// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2021 Cong Wang <cong.wang@bytedance.com> */

#include <linux/skmsg.h>
#include <linux/bpf.h>
#include <net/sock.h>
#include <net/af_unix.h>

/* True if there is anything to read: data queued on the socket itself,
 * or skbs/msgs queued on the psock by the BPF verdict path.
 */
#define unix_sk_has_data(__sk, __psock)					\
		({	!skb_queue_empty(&__sk->sk_receive_queue) ||	\
			!skb_queue_empty(&__psock->ingress_skb) ||	\
			!list_empty(&__psock->ingress_msg);		\
		})

/* Wait for data on either the socket or the psock queues. Returns nonzero
 * if data is available or the socket is shut down for reads.
 */
static int unix_msg_wait_data(struct sock *sk, struct sk_psock *psock,
			      long timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	struct unix_sock *u = unix_sk(sk);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	if (!unix_sk_has_data(sk, psock)) {
		/* Drop iolock across the sleep so senders can make progress. */
		mutex_unlock(&u->iolock);
		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
		mutex_lock(&u->iolock);
		ret = unix_sk_has_data(sk, psock);
	}
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

static int __unix_recvmsg(struct sock *sk, struct msghdr *msg,
			  size_t len, int flags)
{
	if (sk->sk_type == SOCK_DGRAM)
		return __unix_dgram_recvmsg(sk, msg, len, flags);
	else
		return __unix_stream_recvmsg(sk, msg, len, flags);
}

static int unix_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
			    size_t len, int flags, int *addr_len)
{
	struct unix_sock *u = unix_sk(sk);
	struct sk_psock *psock;
	int copied;

	if (flags & MSG_OOB)
		return -EOPNOTSUPP;

	if (!len)
		return 0;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return __unix_recvmsg(sk, msg, len, flags);

	mutex_lock(&u->iolock);
	/* Only the socket queue has data: use the regular receive path. */
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		mutex_unlock(&u->iolock);
		sk_psock_put(sk, psock);
		return __unix_recvmsg(sk, msg, len, flags);
	}

msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		data = unix_msg_wait_data(sk, psock, timeo);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			/* Data arrived on the socket queue only; fall back
			 * to the regular af_unix receive path.
			 */
			mutex_unlock(&u->iolock);
			sk_psock_put(sk, psock);
			return __unix_recvmsg(sk, msg, len, flags);
		}
		copied = -EAGAIN;
	}
	mutex_unlock(&u->iolock);
	sk_psock_put(sk, psock);
	return copied;
}

static struct proto *unix_dgram_prot_saved __read_mostly;
static DEFINE_SPINLOCK(unix_dgram_prot_lock);
static struct proto unix_dgram_bpf_prot;

static struct proto *unix_stream_prot_saved __read_mostly;
static DEFINE_SPINLOCK(unix_stream_prot_lock);
static struct proto unix_stream_bpf_prot;

static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
{
	*prot = *base;
	prot->close = sock_map_close;
	prot->recvmsg = unix_bpf_recvmsg;
	prot->sock_is_readable = sk_msg_is_readable;
}

static void unix_stream_bpf_rebuild_protos(struct proto *prot,
					   const struct proto *base)
{
	*prot = *base;
	prot->close = sock_map_close;
	prot->recvmsg = unix_bpf_recvmsg;
	prot->sock_is_readable = sk_msg_is_readable;
	prot->unhash = sock_map_unhash;
}
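/* Usage sketch (illustrative, not part of this file): these BPF protos
 * take effect once an AF_UNIX socket is inserted into a sockmap or
 * sockhash, which invokes the *_bpf_update_proto() callbacks below.
 * From user space, with libbpf and an already-created
 * BPF_MAP_TYPE_SOCKMAP whose fd is map_fd (map_fd and key are assumed
 * names):
 *
 *	int pair[2], key = 0;
 *
 *	if (socketpair(AF_UNIX, SOCK_STREAM, 0, pair) < 0)
 *		return -errno;
 *	if (bpf_map_update_elem(map_fd, &key, &pair[0], BPF_ANY) < 0)
 *		return -errno;
 *	// recvmsg() on pair[0] now runs unix_bpf_recvmsg() above.
 *
 * Deleting the map entry (or closing the socket) restores the saved
 * proto through the restore branch of the update_proto callbacks.
 */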
/* The BPF proto templates are rebuilt lazily, once per base proto, using
 * double-checked locking: the acquire/release pair on *_prot_saved orders
 * readers against the rebuild without taking the lock on the fast path.
 */
static void unix_dgram_bpf_check_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&unix_dgram_prot_saved))) {
		spin_lock_bh(&unix_dgram_prot_lock);
		if (likely(ops != unix_dgram_prot_saved)) {
			unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, ops);
			smp_store_release(&unix_dgram_prot_saved, ops);
		}
		spin_unlock_bh(&unix_dgram_prot_lock);
	}
}

static void unix_stream_bpf_check_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&unix_stream_prot_saved))) {
		spin_lock_bh(&unix_stream_prot_lock);
		if (likely(ops != unix_stream_prot_saved)) {
			unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, ops);
			smp_store_release(&unix_stream_prot_saved, ops);
		}
		spin_unlock_bh(&unix_stream_prot_lock);
	}
}

int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	if (sk->sk_type != SOCK_DGRAM)
		return -EOPNOTSUPP;

	if (restore) {
		sk->sk_write_space = psock->saved_write_space;
		sock_replace_proto(sk, psock->sk_proto);
		return 0;
	}

	unix_dgram_bpf_check_needs_rebuild(psock->sk_proto);
	sock_replace_proto(sk, &unix_dgram_bpf_prot);
	return 0;
}

int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	struct sock *sk_pair;

	/* Restore does not decrement the sk_pair reference yet because we must
	 * keep a reference to the socket until after an RCU grace period
	 * and any pending sends have completed.
	 */
	if (restore) {
		sk->sk_write_space = psock->saved_write_space;
		sock_replace_proto(sk, psock->sk_proto);
		return 0;
	}

	/* psock_update_sk_prot can be called multiple times if psock is
	 * added to multiple maps and/or slots in the same map. There is
	 * also an edge case where replacing a psock with itself can trigger
	 * an extra psock_update_sk_prot during the insert process. So it
	 * must be safe to do multiple calls. Here we need to ensure we don't
	 * increment the refcnt through sock_hold many times. There will only
	 * be a single matching destroy operation.
	 */
	if (!psock->sk_pair) {
		sk_pair = unix_peer(sk);
		sock_hold(sk_pair);
		psock->sk_pair = sk_pair;
	}

	unix_stream_bpf_check_needs_rebuild(psock->sk_proto);
	sock_replace_proto(sk, &unix_stream_bpf_prot);
	return 0;
}

/* Pre-build the BPF proto templates from the vanilla AF_UNIX protos at
 * boot, before any socket can be added to a sockmap.
 */
void __init unix_bpf_build_proto(void)
{
	unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, &unix_dgram_proto);
	unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, &unix_stream_proto);
}
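/* For context, a hedged sketch of the BPF side (an illustration, not
 * part of this file): a typical sk_skb verdict program attached to the
 * sockmap redirects each skb to a socket in the map, e.g.:
 *
 *	SEC("sk_skb/stream_verdict")
 *	int unix_verdict(struct __sk_buff *skb)
 *	{
 *		__u32 key = 0;
 *
 *		return bpf_sk_redirect_map(skb, &sock_map, key, BPF_F_INGRESS);
 *	}
 *
 * Redirected data lands on the target psock's ingress queues
 * (psock->ingress_skb/ingress_msg), which unix_bpf_recvmsg() above
 * drains via sk_msg_recvmsg(). "sock_map" and "unix_verdict" are
 * assumed names.
 */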