1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2021 Cong Wang <cong.wang@bytedance.com> */ 3 4 #include <linux/skmsg.h> 5 #include <linux/bpf.h> 6 #include <net/sock.h> 7 #include <net/af_unix.h> 8 9 #define unix_sk_has_data(__sk, __psock) \ 10 ({ !skb_queue_empty(&__sk->sk_receive_queue) || \ 11 !skb_queue_empty(&__psock->ingress_skb) || \ 12 !list_empty(&__psock->ingress_msg); \ 13 }) 14 15 static int unix_msg_wait_data(struct sock *sk, struct sk_psock *psock, 16 long timeo) 17 { 18 DEFINE_WAIT_FUNC(wait, woken_wake_function); 19 struct unix_sock *u = unix_sk(sk); 20 int ret = 0; 21 22 if (sk->sk_shutdown & RCV_SHUTDOWN) 23 return 1; 24 25 if (!timeo) 26 return ret; 27 28 add_wait_queue(sk_sleep(sk), &wait); 29 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 30 if (!unix_sk_has_data(sk, psock)) { 31 mutex_unlock(&u->iolock); 32 wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); 33 mutex_lock(&u->iolock); 34 ret = unix_sk_has_data(sk, psock); 35 } 36 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 37 remove_wait_queue(sk_sleep(sk), &wait); 38 return ret; 39 } 40 41 static int __unix_recvmsg(struct sock *sk, struct msghdr *msg, 42 size_t len, int flags) 43 { 44 if (sk->sk_type == SOCK_DGRAM) 45 return __unix_dgram_recvmsg(sk, msg, len, flags); 46 else 47 return __unix_stream_recvmsg(sk, msg, len, flags); 48 } 49 50 static int unix_bpf_recvmsg(struct sock *sk, struct msghdr *msg, 51 size_t len, int nonblock, int flags, 52 int *addr_len) 53 { 54 struct unix_sock *u = unix_sk(sk); 55 struct sk_psock *psock; 56 int copied; 57 58 psock = sk_psock_get(sk); 59 if (unlikely(!psock)) 60 return __unix_recvmsg(sk, msg, len, flags); 61 62 mutex_lock(&u->iolock); 63 if (!skb_queue_empty(&sk->sk_receive_queue) && 64 sk_psock_queue_empty(psock)) { 65 mutex_unlock(&u->iolock); 66 sk_psock_put(sk, psock); 67 return __unix_recvmsg(sk, msg, len, flags); 68 } 69 70 msg_bytes_ready: 71 copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 72 if (!copied) { 73 long timeo; 74 int data; 75 76 timeo = sock_rcvtimeo(sk, nonblock); 77 data = unix_msg_wait_data(sk, psock, timeo); 78 if (data) { 79 if (!sk_psock_queue_empty(psock)) 80 goto msg_bytes_ready; 81 mutex_unlock(&u->iolock); 82 sk_psock_put(sk, psock); 83 return __unix_recvmsg(sk, msg, len, flags); 84 } 85 copied = -EAGAIN; 86 } 87 mutex_unlock(&u->iolock); 88 sk_psock_put(sk, psock); 89 return copied; 90 } 91 92 static struct proto *unix_dgram_prot_saved __read_mostly; 93 static DEFINE_SPINLOCK(unix_dgram_prot_lock); 94 static struct proto unix_dgram_bpf_prot; 95 96 static struct proto *unix_stream_prot_saved __read_mostly; 97 static DEFINE_SPINLOCK(unix_stream_prot_lock); 98 static struct proto unix_stream_bpf_prot; 99 100 static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 101 { 102 *prot = *base; 103 prot->close = sock_map_close; 104 prot->recvmsg = unix_bpf_recvmsg; 105 prot->sock_is_readable = sk_msg_is_readable; 106 } 107 108 static void unix_stream_bpf_rebuild_protos(struct proto *prot, 109 const struct proto *base) 110 { 111 *prot = *base; 112 prot->close = sock_map_close; 113 prot->recvmsg = unix_bpf_recvmsg; 114 prot->sock_is_readable = sk_msg_is_readable; 115 prot->unhash = sock_map_unhash; 116 } 117 118 static void unix_dgram_bpf_check_needs_rebuild(struct proto *ops) 119 { 120 if (unlikely(ops != smp_load_acquire(&unix_dgram_prot_saved))) { 121 spin_lock_bh(&unix_dgram_prot_lock); 122 if (likely(ops != unix_dgram_prot_saved)) { 123 unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, ops); 124 smp_store_release(&unix_dgram_prot_saved, ops); 125 } 126 spin_unlock_bh(&unix_dgram_prot_lock); 127 } 128 } 129 130 static void unix_stream_bpf_check_needs_rebuild(struct proto *ops) 131 { 132 if (unlikely(ops != smp_load_acquire(&unix_stream_prot_saved))) { 133 spin_lock_bh(&unix_stream_prot_lock); 134 if (likely(ops != unix_stream_prot_saved)) { 135 unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, ops); 136 smp_store_release(&unix_stream_prot_saved, ops); 137 } 138 spin_unlock_bh(&unix_stream_prot_lock); 139 } 140 } 141 142 int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 143 { 144 if (sk->sk_type != SOCK_DGRAM) 145 return -EOPNOTSUPP; 146 147 if (restore) { 148 sk->sk_write_space = psock->saved_write_space; 149 WRITE_ONCE(sk->sk_prot, psock->sk_proto); 150 return 0; 151 } 152 153 unix_dgram_bpf_check_needs_rebuild(psock->sk_proto); 154 WRITE_ONCE(sk->sk_prot, &unix_dgram_bpf_prot); 155 return 0; 156 } 157 158 int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 159 { 160 if (restore) { 161 sk->sk_write_space = psock->saved_write_space; 162 WRITE_ONCE(sk->sk_prot, psock->sk_proto); 163 return 0; 164 } 165 166 unix_stream_bpf_check_needs_rebuild(psock->sk_proto); 167 WRITE_ONCE(sk->sk_prot, &unix_stream_bpf_prot); 168 return 0; 169 } 170 171 void __init unix_bpf_build_proto(void) 172 { 173 unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, &unix_dgram_proto); 174 unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, &unix_stream_proto); 175 176 } 177