xref: /linux/net/ipv4/udp_bpf.c (revision c49661aa6f7097047b7e86ad37b1cf308a7a8d4f)
1edc6741cSLorenz Bauer // SPDX-License-Identifier: GPL-2.0
2edc6741cSLorenz Bauer /* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */
3edc6741cSLorenz Bauer 
4edc6741cSLorenz Bauer #include <linux/skmsg.h>
5edc6741cSLorenz Bauer #include <net/sock.h>
6edc6741cSLorenz Bauer #include <net/udp.h>
71f5be6b3SCong Wang #include <net/inet_common.h>
81f5be6b3SCong Wang 
91f5be6b3SCong Wang #include "udp_impl.h"
101f5be6b3SCong Wang 
111f5be6b3SCong Wang static struct proto *udpv6_prot_saved __read_mostly;
121f5be6b3SCong Wang 
131f5be6b3SCong Wang static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
141f5be6b3SCong Wang 			  int noblock, int flags, int *addr_len)
151f5be6b3SCong Wang {
161f5be6b3SCong Wang #if IS_ENABLED(CONFIG_IPV6)
171f5be6b3SCong Wang 	if (sk->sk_family == AF_INET6)
181f5be6b3SCong Wang 		return udpv6_prot_saved->recvmsg(sk, msg, len, noblock, flags,
191f5be6b3SCong Wang 						 addr_len);
201f5be6b3SCong Wang #endif
211f5be6b3SCong Wang 	return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len);
221f5be6b3SCong Wang }
231f5be6b3SCong Wang 
241f5be6b3SCong Wang static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
251f5be6b3SCong Wang 			   int nonblock, int flags, int *addr_len)
261f5be6b3SCong Wang {
271f5be6b3SCong Wang 	struct sk_psock *psock;
281f5be6b3SCong Wang 	int copied, ret;
291f5be6b3SCong Wang 
301f5be6b3SCong Wang 	if (unlikely(flags & MSG_ERRQUEUE))
311f5be6b3SCong Wang 		return inet_recv_error(sk, msg, len, addr_len);
321f5be6b3SCong Wang 
331f5be6b3SCong Wang 	psock = sk_psock_get(sk);
341f5be6b3SCong Wang 	if (unlikely(!psock))
351f5be6b3SCong Wang 		return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
361f5be6b3SCong Wang 
371f5be6b3SCong Wang 	lock_sock(sk);
381f5be6b3SCong Wang 	if (sk_psock_queue_empty(psock)) {
391f5be6b3SCong Wang 		ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
401f5be6b3SCong Wang 		goto out;
411f5be6b3SCong Wang 	}
421f5be6b3SCong Wang 
431f5be6b3SCong Wang msg_bytes_ready:
441f5be6b3SCong Wang 	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
451f5be6b3SCong Wang 	if (!copied) {
461f5be6b3SCong Wang 		long timeo;
47*c49661aaSCong Wang 		int data;
481f5be6b3SCong Wang 
491f5be6b3SCong Wang 		timeo = sock_rcvtimeo(sk, nonblock);
50*c49661aaSCong Wang 		data = sk_msg_wait_data(sk, psock, timeo);
511f5be6b3SCong Wang 		if (data) {
521f5be6b3SCong Wang 			if (!sk_psock_queue_empty(psock))
531f5be6b3SCong Wang 				goto msg_bytes_ready;
541f5be6b3SCong Wang 			ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
551f5be6b3SCong Wang 			goto out;
561f5be6b3SCong Wang 		}
571f5be6b3SCong Wang 		copied = -EAGAIN;
581f5be6b3SCong Wang 	}
591f5be6b3SCong Wang 	ret = copied;
601f5be6b3SCong Wang out:
611f5be6b3SCong Wang 	release_sock(sk);
621f5be6b3SCong Wang 	sk_psock_put(sk, psock);
631f5be6b3SCong Wang 	return ret;
641f5be6b3SCong Wang }
65edc6741cSLorenz Bauer 
/* Slot indices into udp_bpf_prots[]; UDP_BPF_NUM_PROTS is the slot count. */
enum {
	UDP_BPF_IPV4 = 0,
	UDP_BPF_IPV6 = 1,
	UDP_BPF_NUM_PROTS = 2,
};
71edc6741cSLorenz Bauer 
72edc6741cSLorenz Bauer static DEFINE_SPINLOCK(udpv6_prot_lock);
73edc6741cSLorenz Bauer static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];
74edc6741cSLorenz Bauer 
75edc6741cSLorenz Bauer static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
76edc6741cSLorenz Bauer {
77edc6741cSLorenz Bauer 	*prot        = *base;
78edc6741cSLorenz Bauer 	prot->unhash = sock_map_unhash;
79edc6741cSLorenz Bauer 	prot->close  = sock_map_close;
801f5be6b3SCong Wang 	prot->recvmsg = udp_bpf_recvmsg;
81edc6741cSLorenz Bauer }
82edc6741cSLorenz Bauer 
837b219da4SLorenz Bauer static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
84edc6741cSLorenz Bauer {
857b219da4SLorenz Bauer 	if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
86edc6741cSLorenz Bauer 		spin_lock_bh(&udpv6_prot_lock);
87edc6741cSLorenz Bauer 		if (likely(ops != udpv6_prot_saved)) {
88edc6741cSLorenz Bauer 			udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
89edc6741cSLorenz Bauer 			smp_store_release(&udpv6_prot_saved, ops);
90edc6741cSLorenz Bauer 		}
91edc6741cSLorenz Bauer 		spin_unlock_bh(&udpv6_prot_lock);
92edc6741cSLorenz Bauer 	}
93edc6741cSLorenz Bauer }
94edc6741cSLorenz Bauer 
95edc6741cSLorenz Bauer static int __init udp_bpf_v4_build_proto(void)
96edc6741cSLorenz Bauer {
97edc6741cSLorenz Bauer 	udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot);
98edc6741cSLorenz Bauer 	return 0;
99edc6741cSLorenz Bauer }
100edc6741cSLorenz Bauer core_initcall(udp_bpf_v4_build_proto);
101edc6741cSLorenz Bauer 
10251e0158aSCong Wang int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
103edc6741cSLorenz Bauer {
104edc6741cSLorenz Bauer 	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
1058a59f9d1SCong Wang 
1068a59f9d1SCong Wang 	if (restore) {
1078a59f9d1SCong Wang 		sk->sk_write_space = psock->saved_write_space;
1088a59f9d1SCong Wang 		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
1098a59f9d1SCong Wang 		return 0;
1108a59f9d1SCong Wang 	}
111edc6741cSLorenz Bauer 
1127b219da4SLorenz Bauer 	if (sk->sk_family == AF_INET6)
1137b219da4SLorenz Bauer 		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
114edc6741cSLorenz Bauer 
1158a59f9d1SCong Wang 	WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
1168a59f9d1SCong Wang 	return 0;
117edc6741cSLorenz Bauer }
1188a59f9d1SCong Wang EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
119