xref: /linux/net/core/skmsg.c (revision a136678c0bdbb650daff5df5eec1dab960e074a7)
1604326b4SDaniel Borkmann // SPDX-License-Identifier: GPL-2.0
2604326b4SDaniel Borkmann /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3604326b4SDaniel Borkmann 
4604326b4SDaniel Borkmann #include <linux/skmsg.h>
5604326b4SDaniel Borkmann #include <linux/skbuff.h>
6604326b4SDaniel Borkmann #include <linux/scatterlist.h>
7604326b4SDaniel Borkmann 
8604326b4SDaniel Borkmann #include <net/sock.h>
9604326b4SDaniel Borkmann #include <net/tcp.h>
10604326b4SDaniel Borkmann 
11604326b4SDaniel Borkmann static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
12604326b4SDaniel Borkmann {
13604326b4SDaniel Borkmann 	if (msg->sg.end > msg->sg.start &&
14604326b4SDaniel Borkmann 	    elem_first_coalesce < msg->sg.end)
15604326b4SDaniel Borkmann 		return true;
16604326b4SDaniel Borkmann 
17604326b4SDaniel Borkmann 	if (msg->sg.end < msg->sg.start &&
18604326b4SDaniel Borkmann 	    (elem_first_coalesce > msg->sg.start ||
19604326b4SDaniel Borkmann 	     elem_first_coalesce < msg->sg.end))
20604326b4SDaniel Borkmann 		return true;
21604326b4SDaniel Borkmann 
22604326b4SDaniel Borkmann 	return false;
23604326b4SDaniel Borkmann }
24604326b4SDaniel Borkmann 
25604326b4SDaniel Borkmann int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
26604326b4SDaniel Borkmann 		 int elem_first_coalesce)
27604326b4SDaniel Borkmann {
28604326b4SDaniel Borkmann 	struct page_frag *pfrag = sk_page_frag(sk);
29604326b4SDaniel Borkmann 	int ret = 0;
30604326b4SDaniel Borkmann 
31604326b4SDaniel Borkmann 	len -= msg->sg.size;
32604326b4SDaniel Borkmann 	while (len > 0) {
33604326b4SDaniel Borkmann 		struct scatterlist *sge;
34604326b4SDaniel Borkmann 		u32 orig_offset;
35604326b4SDaniel Borkmann 		int use, i;
36604326b4SDaniel Borkmann 
37604326b4SDaniel Borkmann 		if (!sk_page_frag_refill(sk, pfrag))
38604326b4SDaniel Borkmann 			return -ENOMEM;
39604326b4SDaniel Borkmann 
40604326b4SDaniel Borkmann 		orig_offset = pfrag->offset;
41604326b4SDaniel Borkmann 		use = min_t(int, len, pfrag->size - orig_offset);
42604326b4SDaniel Borkmann 		if (!sk_wmem_schedule(sk, use))
43604326b4SDaniel Borkmann 			return -ENOMEM;
44604326b4SDaniel Borkmann 
45604326b4SDaniel Borkmann 		i = msg->sg.end;
46604326b4SDaniel Borkmann 		sk_msg_iter_var_prev(i);
47604326b4SDaniel Borkmann 		sge = &msg->sg.data[i];
48604326b4SDaniel Borkmann 
49604326b4SDaniel Borkmann 		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
50604326b4SDaniel Borkmann 		    sg_page(sge) == pfrag->page &&
51604326b4SDaniel Borkmann 		    sge->offset + sge->length == orig_offset) {
52604326b4SDaniel Borkmann 			sge->length += use;
53604326b4SDaniel Borkmann 		} else {
54604326b4SDaniel Borkmann 			if (sk_msg_full(msg)) {
55604326b4SDaniel Borkmann 				ret = -ENOSPC;
56604326b4SDaniel Borkmann 				break;
57604326b4SDaniel Borkmann 			}
58604326b4SDaniel Borkmann 
59604326b4SDaniel Borkmann 			sge = &msg->sg.data[msg->sg.end];
60604326b4SDaniel Borkmann 			sg_unmark_end(sge);
61604326b4SDaniel Borkmann 			sg_set_page(sge, pfrag->page, use, orig_offset);
62604326b4SDaniel Borkmann 			get_page(pfrag->page);
63604326b4SDaniel Borkmann 			sk_msg_iter_next(msg, end);
64604326b4SDaniel Borkmann 		}
65604326b4SDaniel Borkmann 
66604326b4SDaniel Borkmann 		sk_mem_charge(sk, use);
67604326b4SDaniel Borkmann 		msg->sg.size += use;
68604326b4SDaniel Borkmann 		pfrag->offset += use;
69604326b4SDaniel Borkmann 		len -= use;
70604326b4SDaniel Borkmann 	}
71604326b4SDaniel Borkmann 
72604326b4SDaniel Borkmann 	return ret;
73604326b4SDaniel Borkmann }
74604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_alloc);
75604326b4SDaniel Borkmann 
76d829e9c4SDaniel Borkmann int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
77d829e9c4SDaniel Borkmann 		 u32 off, u32 len)
78d829e9c4SDaniel Borkmann {
79d829e9c4SDaniel Borkmann 	int i = src->sg.start;
80d829e9c4SDaniel Borkmann 	struct scatterlist *sge = sk_msg_elem(src, i);
81d829e9c4SDaniel Borkmann 	u32 sge_len, sge_off;
82d829e9c4SDaniel Borkmann 
83d829e9c4SDaniel Borkmann 	if (sk_msg_full(dst))
84d829e9c4SDaniel Borkmann 		return -ENOSPC;
85d829e9c4SDaniel Borkmann 
86d829e9c4SDaniel Borkmann 	while (off) {
87d829e9c4SDaniel Borkmann 		if (sge->length > off)
88d829e9c4SDaniel Borkmann 			break;
89d829e9c4SDaniel Borkmann 		off -= sge->length;
90d829e9c4SDaniel Borkmann 		sk_msg_iter_var_next(i);
91d829e9c4SDaniel Borkmann 		if (i == src->sg.end && off)
92d829e9c4SDaniel Borkmann 			return -ENOSPC;
93d829e9c4SDaniel Borkmann 		sge = sk_msg_elem(src, i);
94d829e9c4SDaniel Borkmann 	}
95d829e9c4SDaniel Borkmann 
96d829e9c4SDaniel Borkmann 	while (len) {
97d829e9c4SDaniel Borkmann 		sge_len = sge->length - off;
98d829e9c4SDaniel Borkmann 		sge_off = sge->offset + off;
99d829e9c4SDaniel Borkmann 		if (sge_len > len)
100d829e9c4SDaniel Borkmann 			sge_len = len;
101d829e9c4SDaniel Borkmann 		off = 0;
102d829e9c4SDaniel Borkmann 		len -= sge_len;
103d829e9c4SDaniel Borkmann 		sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
104d829e9c4SDaniel Borkmann 		sk_mem_charge(sk, sge_len);
105d829e9c4SDaniel Borkmann 		sk_msg_iter_var_next(i);
106d829e9c4SDaniel Borkmann 		if (i == src->sg.end && len)
107d829e9c4SDaniel Borkmann 			return -ENOSPC;
108d829e9c4SDaniel Borkmann 		sge = sk_msg_elem(src, i);
109d829e9c4SDaniel Borkmann 	}
110d829e9c4SDaniel Borkmann 
111d829e9c4SDaniel Borkmann 	return 0;
112d829e9c4SDaniel Borkmann }
113d829e9c4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_clone);
114d829e9c4SDaniel Borkmann 
115604326b4SDaniel Borkmann void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
116604326b4SDaniel Borkmann {
117604326b4SDaniel Borkmann 	int i = msg->sg.start;
118604326b4SDaniel Borkmann 
119604326b4SDaniel Borkmann 	do {
120604326b4SDaniel Borkmann 		struct scatterlist *sge = sk_msg_elem(msg, i);
121604326b4SDaniel Borkmann 
122604326b4SDaniel Borkmann 		if (bytes < sge->length) {
123604326b4SDaniel Borkmann 			sge->length -= bytes;
124604326b4SDaniel Borkmann 			sge->offset += bytes;
125604326b4SDaniel Borkmann 			sk_mem_uncharge(sk, bytes);
126604326b4SDaniel Borkmann 			break;
127604326b4SDaniel Borkmann 		}
128604326b4SDaniel Borkmann 
129604326b4SDaniel Borkmann 		sk_mem_uncharge(sk, sge->length);
130604326b4SDaniel Borkmann 		bytes -= sge->length;
131604326b4SDaniel Borkmann 		sge->length = 0;
132604326b4SDaniel Borkmann 		sge->offset = 0;
133604326b4SDaniel Borkmann 		sk_msg_iter_var_next(i);
134604326b4SDaniel Borkmann 	} while (bytes && i != msg->sg.end);
135604326b4SDaniel Borkmann 	msg->sg.start = i;
136604326b4SDaniel Borkmann }
137604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_return_zero);
138604326b4SDaniel Borkmann 
139604326b4SDaniel Borkmann void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
140604326b4SDaniel Borkmann {
141604326b4SDaniel Borkmann 	int i = msg->sg.start;
142604326b4SDaniel Borkmann 
143604326b4SDaniel Borkmann 	do {
144604326b4SDaniel Borkmann 		struct scatterlist *sge = &msg->sg.data[i];
145604326b4SDaniel Borkmann 		int uncharge = (bytes < sge->length) ? bytes : sge->length;
146604326b4SDaniel Borkmann 
147604326b4SDaniel Borkmann 		sk_mem_uncharge(sk, uncharge);
148604326b4SDaniel Borkmann 		bytes -= uncharge;
149604326b4SDaniel Borkmann 		sk_msg_iter_var_next(i);
150604326b4SDaniel Borkmann 	} while (i != msg->sg.end);
151604326b4SDaniel Borkmann }
152604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_return);
153604326b4SDaniel Borkmann 
154604326b4SDaniel Borkmann static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
155604326b4SDaniel Borkmann 			    bool charge)
156604326b4SDaniel Borkmann {
157604326b4SDaniel Borkmann 	struct scatterlist *sge = sk_msg_elem(msg, i);
158604326b4SDaniel Borkmann 	u32 len = sge->length;
159604326b4SDaniel Borkmann 
160604326b4SDaniel Borkmann 	if (charge)
161604326b4SDaniel Borkmann 		sk_mem_uncharge(sk, len);
162604326b4SDaniel Borkmann 	if (!msg->skb)
163604326b4SDaniel Borkmann 		put_page(sg_page(sge));
164604326b4SDaniel Borkmann 	memset(sge, 0, sizeof(*sge));
165604326b4SDaniel Borkmann 	return len;
166604326b4SDaniel Borkmann }
167604326b4SDaniel Borkmann 
168604326b4SDaniel Borkmann static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
169604326b4SDaniel Borkmann 			 bool charge)
170604326b4SDaniel Borkmann {
171604326b4SDaniel Borkmann 	struct scatterlist *sge = sk_msg_elem(msg, i);
172604326b4SDaniel Borkmann 	int freed = 0;
173604326b4SDaniel Borkmann 
174604326b4SDaniel Borkmann 	while (msg->sg.size) {
175604326b4SDaniel Borkmann 		msg->sg.size -= sge->length;
176604326b4SDaniel Borkmann 		freed += sk_msg_free_elem(sk, msg, i, charge);
177604326b4SDaniel Borkmann 		sk_msg_iter_var_next(i);
178604326b4SDaniel Borkmann 		sk_msg_check_to_free(msg, i, msg->sg.size);
179604326b4SDaniel Borkmann 		sge = sk_msg_elem(msg, i);
180604326b4SDaniel Borkmann 	}
181604326b4SDaniel Borkmann 	if (msg->skb)
182604326b4SDaniel Borkmann 		consume_skb(msg->skb);
183604326b4SDaniel Borkmann 	sk_msg_init(msg);
184604326b4SDaniel Borkmann 	return freed;
185604326b4SDaniel Borkmann }
186604326b4SDaniel Borkmann 
187604326b4SDaniel Borkmann int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
188604326b4SDaniel Borkmann {
189604326b4SDaniel Borkmann 	return __sk_msg_free(sk, msg, msg->sg.start, false);
190604326b4SDaniel Borkmann }
191604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);
192604326b4SDaniel Borkmann 
193604326b4SDaniel Borkmann int sk_msg_free(struct sock *sk, struct sk_msg *msg)
194604326b4SDaniel Borkmann {
195604326b4SDaniel Borkmann 	return __sk_msg_free(sk, msg, msg->sg.start, true);
196604326b4SDaniel Borkmann }
197604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_free);
198604326b4SDaniel Borkmann 
199604326b4SDaniel Borkmann static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
200604326b4SDaniel Borkmann 				  u32 bytes, bool charge)
201604326b4SDaniel Borkmann {
202604326b4SDaniel Borkmann 	struct scatterlist *sge;
203604326b4SDaniel Borkmann 	u32 i = msg->sg.start;
204604326b4SDaniel Borkmann 
205604326b4SDaniel Borkmann 	while (bytes) {
206604326b4SDaniel Borkmann 		sge = sk_msg_elem(msg, i);
207604326b4SDaniel Borkmann 		if (!sge->length)
208604326b4SDaniel Borkmann 			break;
209604326b4SDaniel Borkmann 		if (bytes < sge->length) {
210604326b4SDaniel Borkmann 			if (charge)
211604326b4SDaniel Borkmann 				sk_mem_uncharge(sk, bytes);
212604326b4SDaniel Borkmann 			sge->length -= bytes;
213604326b4SDaniel Borkmann 			sge->offset += bytes;
214604326b4SDaniel Borkmann 			msg->sg.size -= bytes;
215604326b4SDaniel Borkmann 			break;
216604326b4SDaniel Borkmann 		}
217604326b4SDaniel Borkmann 
218604326b4SDaniel Borkmann 		msg->sg.size -= sge->length;
219604326b4SDaniel Borkmann 		bytes -= sge->length;
220604326b4SDaniel Borkmann 		sk_msg_free_elem(sk, msg, i, charge);
221604326b4SDaniel Borkmann 		sk_msg_iter_var_next(i);
222604326b4SDaniel Borkmann 		sk_msg_check_to_free(msg, i, bytes);
223604326b4SDaniel Borkmann 	}
224604326b4SDaniel Borkmann 	msg->sg.start = i;
225604326b4SDaniel Borkmann }
226604326b4SDaniel Borkmann 
227604326b4SDaniel Borkmann void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
228604326b4SDaniel Borkmann {
229604326b4SDaniel Borkmann 	__sk_msg_free_partial(sk, msg, bytes, true);
230604326b4SDaniel Borkmann }
231604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_free_partial);
232604326b4SDaniel Borkmann 
233604326b4SDaniel Borkmann void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
234604326b4SDaniel Borkmann 				  u32 bytes)
235604326b4SDaniel Borkmann {
236604326b4SDaniel Borkmann 	__sk_msg_free_partial(sk, msg, bytes, false);
237604326b4SDaniel Borkmann }
238604326b4SDaniel Borkmann 
239604326b4SDaniel Borkmann void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
240604326b4SDaniel Borkmann {
241604326b4SDaniel Borkmann 	int trim = msg->sg.size - len;
242604326b4SDaniel Borkmann 	u32 i = msg->sg.end;
243604326b4SDaniel Borkmann 
244604326b4SDaniel Borkmann 	if (trim <= 0) {
245604326b4SDaniel Borkmann 		WARN_ON(trim < 0);
246604326b4SDaniel Borkmann 		return;
247604326b4SDaniel Borkmann 	}
248604326b4SDaniel Borkmann 
249604326b4SDaniel Borkmann 	sk_msg_iter_var_prev(i);
250604326b4SDaniel Borkmann 	msg->sg.size = len;
251604326b4SDaniel Borkmann 	while (msg->sg.data[i].length &&
252604326b4SDaniel Borkmann 	       trim >= msg->sg.data[i].length) {
253604326b4SDaniel Borkmann 		trim -= msg->sg.data[i].length;
254604326b4SDaniel Borkmann 		sk_msg_free_elem(sk, msg, i, true);
255604326b4SDaniel Borkmann 		sk_msg_iter_var_prev(i);
256604326b4SDaniel Borkmann 		if (!trim)
257604326b4SDaniel Borkmann 			goto out;
258604326b4SDaniel Borkmann 	}
259604326b4SDaniel Borkmann 
260604326b4SDaniel Borkmann 	msg->sg.data[i].length -= trim;
261604326b4SDaniel Borkmann 	sk_mem_uncharge(sk, trim);
262604326b4SDaniel Borkmann out:
263604326b4SDaniel Borkmann 	/* If we trim data before curr pointer update copybreak and current
264604326b4SDaniel Borkmann 	 * so that any future copy operations start at new copy location.
265604326b4SDaniel Borkmann 	 * However trimed data that has not yet been used in a copy op
266604326b4SDaniel Borkmann 	 * does not require an update.
267604326b4SDaniel Borkmann 	 */
268604326b4SDaniel Borkmann 	if (msg->sg.curr >= i) {
269604326b4SDaniel Borkmann 		msg->sg.curr = i;
270604326b4SDaniel Borkmann 		msg->sg.copybreak = msg->sg.data[i].length;
271604326b4SDaniel Borkmann 	}
272604326b4SDaniel Borkmann 	sk_msg_iter_var_next(i);
273604326b4SDaniel Borkmann 	msg->sg.end = i;
274604326b4SDaniel Borkmann }
275604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_trim);
276604326b4SDaniel Borkmann 
277604326b4SDaniel Borkmann int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
278604326b4SDaniel Borkmann 			      struct sk_msg *msg, u32 bytes)
279604326b4SDaniel Borkmann {
280604326b4SDaniel Borkmann 	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
281604326b4SDaniel Borkmann 	const int to_max_pages = MAX_MSG_FRAGS;
282604326b4SDaniel Borkmann 	struct page *pages[MAX_MSG_FRAGS];
283604326b4SDaniel Borkmann 	ssize_t orig, copied, use, offset;
284604326b4SDaniel Borkmann 
285604326b4SDaniel Borkmann 	orig = msg->sg.size;
286604326b4SDaniel Borkmann 	while (bytes > 0) {
287604326b4SDaniel Borkmann 		i = 0;
288604326b4SDaniel Borkmann 		maxpages = to_max_pages - num_elems;
289604326b4SDaniel Borkmann 		if (maxpages == 0) {
290604326b4SDaniel Borkmann 			ret = -EFAULT;
291604326b4SDaniel Borkmann 			goto out;
292604326b4SDaniel Borkmann 		}
293604326b4SDaniel Borkmann 
294604326b4SDaniel Borkmann 		copied = iov_iter_get_pages(from, pages, bytes, maxpages,
295604326b4SDaniel Borkmann 					    &offset);
296604326b4SDaniel Borkmann 		if (copied <= 0) {
297604326b4SDaniel Borkmann 			ret = -EFAULT;
298604326b4SDaniel Borkmann 			goto out;
299604326b4SDaniel Borkmann 		}
300604326b4SDaniel Borkmann 
301604326b4SDaniel Borkmann 		iov_iter_advance(from, copied);
302604326b4SDaniel Borkmann 		bytes -= copied;
303604326b4SDaniel Borkmann 		msg->sg.size += copied;
304604326b4SDaniel Borkmann 
305604326b4SDaniel Borkmann 		while (copied) {
306604326b4SDaniel Borkmann 			use = min_t(int, copied, PAGE_SIZE - offset);
307604326b4SDaniel Borkmann 			sg_set_page(&msg->sg.data[msg->sg.end],
308604326b4SDaniel Borkmann 				    pages[i], use, offset);
309604326b4SDaniel Borkmann 			sg_unmark_end(&msg->sg.data[msg->sg.end]);
310604326b4SDaniel Borkmann 			sk_mem_charge(sk, use);
311604326b4SDaniel Borkmann 
312604326b4SDaniel Borkmann 			offset = 0;
313604326b4SDaniel Borkmann 			copied -= use;
314604326b4SDaniel Borkmann 			sk_msg_iter_next(msg, end);
315604326b4SDaniel Borkmann 			num_elems++;
316604326b4SDaniel Borkmann 			i++;
317604326b4SDaniel Borkmann 		}
318604326b4SDaniel Borkmann 		/* When zerocopy is mixed with sk_msg_*copy* operations we
319604326b4SDaniel Borkmann 		 * may have a copybreak set in this case clear and prefer
320604326b4SDaniel Borkmann 		 * zerocopy remainder when possible.
321604326b4SDaniel Borkmann 		 */
322604326b4SDaniel Borkmann 		msg->sg.copybreak = 0;
323604326b4SDaniel Borkmann 		msg->sg.curr = msg->sg.end;
324604326b4SDaniel Borkmann 	}
325604326b4SDaniel Borkmann out:
326604326b4SDaniel Borkmann 	/* Revert iov_iter updates, msg will need to use 'trim' later if it
327604326b4SDaniel Borkmann 	 * also needs to be cleared.
328604326b4SDaniel Borkmann 	 */
329604326b4SDaniel Borkmann 	if (ret)
330604326b4SDaniel Borkmann 		iov_iter_revert(from, msg->sg.size - orig);
331604326b4SDaniel Borkmann 	return ret;
332604326b4SDaniel Borkmann }
333604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
334604326b4SDaniel Borkmann 
335604326b4SDaniel Borkmann int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
336604326b4SDaniel Borkmann 			     struct sk_msg *msg, u32 bytes)
337604326b4SDaniel Borkmann {
338604326b4SDaniel Borkmann 	int ret = -ENOSPC, i = msg->sg.curr;
339604326b4SDaniel Borkmann 	struct scatterlist *sge;
340604326b4SDaniel Borkmann 	u32 copy, buf_size;
341604326b4SDaniel Borkmann 	void *to;
342604326b4SDaniel Borkmann 
343604326b4SDaniel Borkmann 	do {
344604326b4SDaniel Borkmann 		sge = sk_msg_elem(msg, i);
345604326b4SDaniel Borkmann 		/* This is possible if a trim operation shrunk the buffer */
346604326b4SDaniel Borkmann 		if (msg->sg.copybreak >= sge->length) {
347604326b4SDaniel Borkmann 			msg->sg.copybreak = 0;
348604326b4SDaniel Borkmann 			sk_msg_iter_var_next(i);
349604326b4SDaniel Borkmann 			if (i == msg->sg.end)
350604326b4SDaniel Borkmann 				break;
351604326b4SDaniel Borkmann 			sge = sk_msg_elem(msg, i);
352604326b4SDaniel Borkmann 		}
353604326b4SDaniel Borkmann 
354604326b4SDaniel Borkmann 		buf_size = sge->length - msg->sg.copybreak;
355604326b4SDaniel Borkmann 		copy = (buf_size > bytes) ? bytes : buf_size;
356604326b4SDaniel Borkmann 		to = sg_virt(sge) + msg->sg.copybreak;
357604326b4SDaniel Borkmann 		msg->sg.copybreak += copy;
358604326b4SDaniel Borkmann 		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
359604326b4SDaniel Borkmann 			ret = copy_from_iter_nocache(to, copy, from);
360604326b4SDaniel Borkmann 		else
361604326b4SDaniel Borkmann 			ret = copy_from_iter(to, copy, from);
362604326b4SDaniel Borkmann 		if (ret != copy) {
363604326b4SDaniel Borkmann 			ret = -EFAULT;
364604326b4SDaniel Borkmann 			goto out;
365604326b4SDaniel Borkmann 		}
366604326b4SDaniel Borkmann 		bytes -= copy;
367604326b4SDaniel Borkmann 		if (!bytes)
368604326b4SDaniel Borkmann 			break;
369604326b4SDaniel Borkmann 		msg->sg.copybreak = 0;
370604326b4SDaniel Borkmann 		sk_msg_iter_var_next(i);
371604326b4SDaniel Borkmann 	} while (i != msg->sg.end);
372604326b4SDaniel Borkmann out:
373604326b4SDaniel Borkmann 	msg->sg.curr = i;
374604326b4SDaniel Borkmann 	return ret;
375604326b4SDaniel Borkmann }
376604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
377604326b4SDaniel Borkmann 
378604326b4SDaniel Borkmann static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
379604326b4SDaniel Borkmann {
380604326b4SDaniel Borkmann 	struct sock *sk = psock->sk;
381604326b4SDaniel Borkmann 	int copied = 0, num_sge;
382604326b4SDaniel Borkmann 	struct sk_msg *msg;
383604326b4SDaniel Borkmann 
384604326b4SDaniel Borkmann 	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
385604326b4SDaniel Borkmann 	if (unlikely(!msg))
386604326b4SDaniel Borkmann 		return -EAGAIN;
387604326b4SDaniel Borkmann 	if (!sk_rmem_schedule(sk, skb, skb->len)) {
388604326b4SDaniel Borkmann 		kfree(msg);
389604326b4SDaniel Borkmann 		return -EAGAIN;
390604326b4SDaniel Borkmann 	}
391604326b4SDaniel Borkmann 
392604326b4SDaniel Borkmann 	sk_msg_init(msg);
393604326b4SDaniel Borkmann 	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
394604326b4SDaniel Borkmann 	if (unlikely(num_sge < 0)) {
395604326b4SDaniel Borkmann 		kfree(msg);
396604326b4SDaniel Borkmann 		return num_sge;
397604326b4SDaniel Borkmann 	}
398604326b4SDaniel Borkmann 
399604326b4SDaniel Borkmann 	sk_mem_charge(sk, skb->len);
400604326b4SDaniel Borkmann 	copied = skb->len;
401604326b4SDaniel Borkmann 	msg->sg.start = 0;
402604326b4SDaniel Borkmann 	msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
403604326b4SDaniel Borkmann 	msg->skb = skb;
404604326b4SDaniel Borkmann 
405604326b4SDaniel Borkmann 	sk_psock_queue_msg(psock, msg);
406552de910SJohn Fastabend 	sk_psock_data_ready(sk, psock);
407604326b4SDaniel Borkmann 	return copied;
408604326b4SDaniel Borkmann }
409604326b4SDaniel Borkmann 
410604326b4SDaniel Borkmann static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
411604326b4SDaniel Borkmann 			       u32 off, u32 len, bool ingress)
412604326b4SDaniel Borkmann {
413604326b4SDaniel Borkmann 	if (ingress)
414604326b4SDaniel Borkmann 		return sk_psock_skb_ingress(psock, skb);
415604326b4SDaniel Borkmann 	else
416604326b4SDaniel Borkmann 		return skb_send_sock_locked(psock->sk, skb, off, len);
417604326b4SDaniel Borkmann }
418604326b4SDaniel Borkmann 
419604326b4SDaniel Borkmann static void sk_psock_backlog(struct work_struct *work)
420604326b4SDaniel Borkmann {
421604326b4SDaniel Borkmann 	struct sk_psock *psock = container_of(work, struct sk_psock, work);
422604326b4SDaniel Borkmann 	struct sk_psock_work_state *state = &psock->work_state;
423604326b4SDaniel Borkmann 	struct sk_buff *skb;
424604326b4SDaniel Borkmann 	bool ingress;
425604326b4SDaniel Borkmann 	u32 len, off;
426604326b4SDaniel Borkmann 	int ret;
427604326b4SDaniel Borkmann 
428604326b4SDaniel Borkmann 	/* Lock sock to avoid losing sk_socket during loop. */
429604326b4SDaniel Borkmann 	lock_sock(psock->sk);
430604326b4SDaniel Borkmann 	if (state->skb) {
431604326b4SDaniel Borkmann 		skb = state->skb;
432604326b4SDaniel Borkmann 		len = state->len;
433604326b4SDaniel Borkmann 		off = state->off;
434604326b4SDaniel Borkmann 		state->skb = NULL;
435604326b4SDaniel Borkmann 		goto start;
436604326b4SDaniel Borkmann 	}
437604326b4SDaniel Borkmann 
438604326b4SDaniel Borkmann 	while ((skb = skb_dequeue(&psock->ingress_skb))) {
439604326b4SDaniel Borkmann 		len = skb->len;
440604326b4SDaniel Borkmann 		off = 0;
441604326b4SDaniel Borkmann start:
442604326b4SDaniel Borkmann 		ingress = tcp_skb_bpf_ingress(skb);
443604326b4SDaniel Borkmann 		do {
444604326b4SDaniel Borkmann 			ret = -EIO;
445604326b4SDaniel Borkmann 			if (likely(psock->sk->sk_socket))
446604326b4SDaniel Borkmann 				ret = sk_psock_handle_skb(psock, skb, off,
447604326b4SDaniel Borkmann 							  len, ingress);
448604326b4SDaniel Borkmann 			if (ret <= 0) {
449604326b4SDaniel Borkmann 				if (ret == -EAGAIN) {
450604326b4SDaniel Borkmann 					state->skb = skb;
451604326b4SDaniel Borkmann 					state->len = len;
452604326b4SDaniel Borkmann 					state->off = off;
453604326b4SDaniel Borkmann 					goto end;
454604326b4SDaniel Borkmann 				}
455604326b4SDaniel Borkmann 				/* Hard errors break pipe and stop xmit. */
456604326b4SDaniel Borkmann 				sk_psock_report_error(psock, ret ? -ret : EPIPE);
457604326b4SDaniel Borkmann 				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
458604326b4SDaniel Borkmann 				kfree_skb(skb);
459604326b4SDaniel Borkmann 				goto end;
460604326b4SDaniel Borkmann 			}
461604326b4SDaniel Borkmann 			off += ret;
462604326b4SDaniel Borkmann 			len -= ret;
463604326b4SDaniel Borkmann 		} while (len);
464604326b4SDaniel Borkmann 
465604326b4SDaniel Borkmann 		if (!ingress)
466604326b4SDaniel Borkmann 			kfree_skb(skb);
467604326b4SDaniel Borkmann 	}
468604326b4SDaniel Borkmann end:
469604326b4SDaniel Borkmann 	release_sock(psock->sk);
470604326b4SDaniel Borkmann }
471604326b4SDaniel Borkmann 
472604326b4SDaniel Borkmann struct sk_psock *sk_psock_init(struct sock *sk, int node)
473604326b4SDaniel Borkmann {
474604326b4SDaniel Borkmann 	struct sk_psock *psock = kzalloc_node(sizeof(*psock),
475604326b4SDaniel Borkmann 					      GFP_ATOMIC | __GFP_NOWARN,
476604326b4SDaniel Borkmann 					      node);
477604326b4SDaniel Borkmann 	if (!psock)
478604326b4SDaniel Borkmann 		return NULL;
479604326b4SDaniel Borkmann 
480604326b4SDaniel Borkmann 	psock->sk = sk;
481604326b4SDaniel Borkmann 	psock->eval =  __SK_NONE;
482604326b4SDaniel Borkmann 
483604326b4SDaniel Borkmann 	INIT_LIST_HEAD(&psock->link);
484604326b4SDaniel Borkmann 	spin_lock_init(&psock->link_lock);
485604326b4SDaniel Borkmann 
486604326b4SDaniel Borkmann 	INIT_WORK(&psock->work, sk_psock_backlog);
487604326b4SDaniel Borkmann 	INIT_LIST_HEAD(&psock->ingress_msg);
488604326b4SDaniel Borkmann 	skb_queue_head_init(&psock->ingress_skb);
489604326b4SDaniel Borkmann 
490604326b4SDaniel Borkmann 	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
491604326b4SDaniel Borkmann 	refcount_set(&psock->refcnt, 1);
492604326b4SDaniel Borkmann 
493604326b4SDaniel Borkmann 	rcu_assign_sk_user_data(sk, psock);
494604326b4SDaniel Borkmann 	sock_hold(sk);
495604326b4SDaniel Borkmann 
496604326b4SDaniel Borkmann 	return psock;
497604326b4SDaniel Borkmann }
498604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_psock_init);
499604326b4SDaniel Borkmann 
500604326b4SDaniel Borkmann struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
501604326b4SDaniel Borkmann {
502604326b4SDaniel Borkmann 	struct sk_psock_link *link;
503604326b4SDaniel Borkmann 
504604326b4SDaniel Borkmann 	spin_lock_bh(&psock->link_lock);
505604326b4SDaniel Borkmann 	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
506604326b4SDaniel Borkmann 					list);
507604326b4SDaniel Borkmann 	if (link)
508604326b4SDaniel Borkmann 		list_del(&link->list);
509604326b4SDaniel Borkmann 	spin_unlock_bh(&psock->link_lock);
510604326b4SDaniel Borkmann 	return link;
511604326b4SDaniel Borkmann }
512604326b4SDaniel Borkmann 
513604326b4SDaniel Borkmann void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
514604326b4SDaniel Borkmann {
515604326b4SDaniel Borkmann 	struct sk_msg *msg, *tmp;
516604326b4SDaniel Borkmann 
517604326b4SDaniel Borkmann 	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
518604326b4SDaniel Borkmann 		list_del(&msg->list);
519604326b4SDaniel Borkmann 		sk_msg_free(psock->sk, msg);
520604326b4SDaniel Borkmann 		kfree(msg);
521604326b4SDaniel Borkmann 	}
522604326b4SDaniel Borkmann }
523604326b4SDaniel Borkmann 
524604326b4SDaniel Borkmann static void sk_psock_zap_ingress(struct sk_psock *psock)
525604326b4SDaniel Borkmann {
526604326b4SDaniel Borkmann 	__skb_queue_purge(&psock->ingress_skb);
527604326b4SDaniel Borkmann 	__sk_psock_purge_ingress_msg(psock);
528604326b4SDaniel Borkmann }
529604326b4SDaniel Borkmann 
530604326b4SDaniel Borkmann static void sk_psock_link_destroy(struct sk_psock *psock)
531604326b4SDaniel Borkmann {
532604326b4SDaniel Borkmann 	struct sk_psock_link *link, *tmp;
533604326b4SDaniel Borkmann 
534604326b4SDaniel Borkmann 	list_for_each_entry_safe(link, tmp, &psock->link, list) {
535604326b4SDaniel Borkmann 		list_del(&link->list);
536604326b4SDaniel Borkmann 		sk_psock_free_link(link);
537604326b4SDaniel Borkmann 	}
538604326b4SDaniel Borkmann }
539604326b4SDaniel Borkmann 
540604326b4SDaniel Borkmann static void sk_psock_destroy_deferred(struct work_struct *gc)
541604326b4SDaniel Borkmann {
542604326b4SDaniel Borkmann 	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);
543604326b4SDaniel Borkmann 
544604326b4SDaniel Borkmann 	/* No sk_callback_lock since already detached. */
545604326b4SDaniel Borkmann 	if (psock->parser.enabled)
546604326b4SDaniel Borkmann 		strp_done(&psock->parser.strp);
547604326b4SDaniel Borkmann 
548604326b4SDaniel Borkmann 	cancel_work_sync(&psock->work);
549604326b4SDaniel Borkmann 
550604326b4SDaniel Borkmann 	psock_progs_drop(&psock->progs);
551604326b4SDaniel Borkmann 
552604326b4SDaniel Borkmann 	sk_psock_link_destroy(psock);
553604326b4SDaniel Borkmann 	sk_psock_cork_free(psock);
554604326b4SDaniel Borkmann 	sk_psock_zap_ingress(psock);
555604326b4SDaniel Borkmann 
556604326b4SDaniel Borkmann 	if (psock->sk_redir)
557604326b4SDaniel Borkmann 		sock_put(psock->sk_redir);
558604326b4SDaniel Borkmann 	sock_put(psock->sk);
559604326b4SDaniel Borkmann 	kfree(psock);
560604326b4SDaniel Borkmann }
561604326b4SDaniel Borkmann 
562604326b4SDaniel Borkmann void sk_psock_destroy(struct rcu_head *rcu)
563604326b4SDaniel Borkmann {
564604326b4SDaniel Borkmann 	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);
565604326b4SDaniel Borkmann 
566604326b4SDaniel Borkmann 	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
567604326b4SDaniel Borkmann 	schedule_work(&psock->gc);
568604326b4SDaniel Borkmann }
569604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_psock_destroy);
570604326b4SDaniel Borkmann 
571604326b4SDaniel Borkmann void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
572604326b4SDaniel Borkmann {
573604326b4SDaniel Borkmann 	rcu_assign_sk_user_data(sk, NULL);
574604326b4SDaniel Borkmann 	sk_psock_cork_free(psock);
575*a136678cSJohn Fastabend 	sk_psock_zap_ingress(psock);
576604326b4SDaniel Borkmann 	sk_psock_restore_proto(sk, psock);
577604326b4SDaniel Borkmann 
578604326b4SDaniel Borkmann 	write_lock_bh(&sk->sk_callback_lock);
579604326b4SDaniel Borkmann 	if (psock->progs.skb_parser)
580604326b4SDaniel Borkmann 		sk_psock_stop_strp(sk, psock);
581604326b4SDaniel Borkmann 	write_unlock_bh(&sk->sk_callback_lock);
582604326b4SDaniel Borkmann 	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
583604326b4SDaniel Borkmann 
584604326b4SDaniel Borkmann 	call_rcu_sched(&psock->rcu, sk_psock_destroy);
585604326b4SDaniel Borkmann }
586604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_psock_drop);
587604326b4SDaniel Borkmann 
588604326b4SDaniel Borkmann static int sk_psock_map_verd(int verdict, bool redir)
589604326b4SDaniel Borkmann {
590604326b4SDaniel Borkmann 	switch (verdict) {
591604326b4SDaniel Borkmann 	case SK_PASS:
592604326b4SDaniel Borkmann 		return redir ? __SK_REDIRECT : __SK_PASS;
593604326b4SDaniel Borkmann 	case SK_DROP:
594604326b4SDaniel Borkmann 	default:
595604326b4SDaniel Borkmann 		break;
596604326b4SDaniel Borkmann 	}
597604326b4SDaniel Borkmann 
598604326b4SDaniel Borkmann 	return __SK_DROP;
599604326b4SDaniel Borkmann }
600604326b4SDaniel Borkmann 
601604326b4SDaniel Borkmann int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
602604326b4SDaniel Borkmann 			 struct sk_msg *msg)
603604326b4SDaniel Borkmann {
604604326b4SDaniel Borkmann 	struct bpf_prog *prog;
605604326b4SDaniel Borkmann 	int ret;
606604326b4SDaniel Borkmann 
607604326b4SDaniel Borkmann 	preempt_disable();
608604326b4SDaniel Borkmann 	rcu_read_lock();
609604326b4SDaniel Borkmann 	prog = READ_ONCE(psock->progs.msg_parser);
610604326b4SDaniel Borkmann 	if (unlikely(!prog)) {
611604326b4SDaniel Borkmann 		ret = __SK_PASS;
612604326b4SDaniel Borkmann 		goto out;
613604326b4SDaniel Borkmann 	}
614604326b4SDaniel Borkmann 
615604326b4SDaniel Borkmann 	sk_msg_compute_data_pointers(msg);
616604326b4SDaniel Borkmann 	msg->sk = sk;
617604326b4SDaniel Borkmann 	ret = BPF_PROG_RUN(prog, msg);
618604326b4SDaniel Borkmann 	ret = sk_psock_map_verd(ret, msg->sk_redir);
619604326b4SDaniel Borkmann 	psock->apply_bytes = msg->apply_bytes;
620604326b4SDaniel Borkmann 	if (ret == __SK_REDIRECT) {
621604326b4SDaniel Borkmann 		if (psock->sk_redir)
622604326b4SDaniel Borkmann 			sock_put(psock->sk_redir);
623604326b4SDaniel Borkmann 		psock->sk_redir = msg->sk_redir;
624604326b4SDaniel Borkmann 		if (!psock->sk_redir) {
625604326b4SDaniel Borkmann 			ret = __SK_DROP;
626604326b4SDaniel Borkmann 			goto out;
627604326b4SDaniel Borkmann 		}
628604326b4SDaniel Borkmann 		sock_hold(psock->sk_redir);
629604326b4SDaniel Borkmann 	}
630604326b4SDaniel Borkmann out:
631604326b4SDaniel Borkmann 	rcu_read_unlock();
632604326b4SDaniel Borkmann 	preempt_enable();
633604326b4SDaniel Borkmann 	return ret;
634604326b4SDaniel Borkmann }
635604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
636604326b4SDaniel Borkmann 
637604326b4SDaniel Borkmann static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
638604326b4SDaniel Borkmann 			    struct sk_buff *skb)
639604326b4SDaniel Borkmann {
640604326b4SDaniel Borkmann 	int ret;
641604326b4SDaniel Borkmann 
642604326b4SDaniel Borkmann 	skb->sk = psock->sk;
643604326b4SDaniel Borkmann 	bpf_compute_data_end_sk_skb(skb);
644604326b4SDaniel Borkmann 	preempt_disable();
645604326b4SDaniel Borkmann 	ret = BPF_PROG_RUN(prog, skb);
646604326b4SDaniel Borkmann 	preempt_enable();
647604326b4SDaniel Borkmann 	/* strparser clones the skb before handing it to a upper layer,
648604326b4SDaniel Borkmann 	 * meaning skb_orphan has been called. We NULL sk on the way out
649604326b4SDaniel Borkmann 	 * to ensure we don't trigger a BUG_ON() in skb/sk operations
650604326b4SDaniel Borkmann 	 * later and because we are not charging the memory of this skb
651604326b4SDaniel Borkmann 	 * to any socket yet.
652604326b4SDaniel Borkmann 	 */
653604326b4SDaniel Borkmann 	skb->sk = NULL;
654604326b4SDaniel Borkmann 	return ret;
655604326b4SDaniel Borkmann }
656604326b4SDaniel Borkmann 
657604326b4SDaniel Borkmann static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
658604326b4SDaniel Borkmann {
659604326b4SDaniel Borkmann 	struct sk_psock_parser *parser;
660604326b4SDaniel Borkmann 
661604326b4SDaniel Borkmann 	parser = container_of(strp, struct sk_psock_parser, strp);
662604326b4SDaniel Borkmann 	return container_of(parser, struct sk_psock, parser);
663604326b4SDaniel Borkmann }
664604326b4SDaniel Borkmann 
665604326b4SDaniel Borkmann static void sk_psock_verdict_apply(struct sk_psock *psock,
666604326b4SDaniel Borkmann 				   struct sk_buff *skb, int verdict)
667604326b4SDaniel Borkmann {
668604326b4SDaniel Borkmann 	struct sk_psock *psock_other;
669604326b4SDaniel Borkmann 	struct sock *sk_other;
670604326b4SDaniel Borkmann 	bool ingress;
671604326b4SDaniel Borkmann 
672604326b4SDaniel Borkmann 	switch (verdict) {
67351199405SJohn Fastabend 	case __SK_PASS:
67451199405SJohn Fastabend 		sk_other = psock->sk;
67551199405SJohn Fastabend 		if (sock_flag(sk_other, SOCK_DEAD) ||
67651199405SJohn Fastabend 		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
67751199405SJohn Fastabend 			goto out_free;
67851199405SJohn Fastabend 		}
67951199405SJohn Fastabend 		if (atomic_read(&sk_other->sk_rmem_alloc) <=
68051199405SJohn Fastabend 		    sk_other->sk_rcvbuf) {
68151199405SJohn Fastabend 			struct tcp_skb_cb *tcp = TCP_SKB_CB(skb);
68251199405SJohn Fastabend 
68351199405SJohn Fastabend 			tcp->bpf.flags |= BPF_F_INGRESS;
68451199405SJohn Fastabend 			skb_queue_tail(&psock->ingress_skb, skb);
68551199405SJohn Fastabend 			schedule_work(&psock->work);
68651199405SJohn Fastabend 			break;
68751199405SJohn Fastabend 		}
68851199405SJohn Fastabend 		goto out_free;
689604326b4SDaniel Borkmann 	case __SK_REDIRECT:
690604326b4SDaniel Borkmann 		sk_other = tcp_skb_bpf_redirect_fetch(skb);
691604326b4SDaniel Borkmann 		if (unlikely(!sk_other))
692604326b4SDaniel Borkmann 			goto out_free;
693604326b4SDaniel Borkmann 		psock_other = sk_psock(sk_other);
694604326b4SDaniel Borkmann 		if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
695604326b4SDaniel Borkmann 		    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
696604326b4SDaniel Borkmann 			goto out_free;
697604326b4SDaniel Borkmann 		ingress = tcp_skb_bpf_ingress(skb);
698604326b4SDaniel Borkmann 		if ((!ingress && sock_writeable(sk_other)) ||
699604326b4SDaniel Borkmann 		    (ingress &&
700604326b4SDaniel Borkmann 		     atomic_read(&sk_other->sk_rmem_alloc) <=
701604326b4SDaniel Borkmann 		     sk_other->sk_rcvbuf)) {
702604326b4SDaniel Borkmann 			if (!ingress)
703604326b4SDaniel Borkmann 				skb_set_owner_w(skb, sk_other);
704604326b4SDaniel Borkmann 			skb_queue_tail(&psock_other->ingress_skb, skb);
705604326b4SDaniel Borkmann 			schedule_work(&psock_other->work);
706604326b4SDaniel Borkmann 			break;
707604326b4SDaniel Borkmann 		}
708604326b4SDaniel Borkmann 		/* fall-through */
709604326b4SDaniel Borkmann 	case __SK_DROP:
710604326b4SDaniel Borkmann 		/* fall-through */
711604326b4SDaniel Borkmann 	default:
712604326b4SDaniel Borkmann out_free:
713604326b4SDaniel Borkmann 		kfree_skb(skb);
714604326b4SDaniel Borkmann 	}
715604326b4SDaniel Borkmann }
716604326b4SDaniel Borkmann 
717604326b4SDaniel Borkmann static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
718604326b4SDaniel Borkmann {
719604326b4SDaniel Borkmann 	struct sk_psock *psock = sk_psock_from_strp(strp);
720604326b4SDaniel Borkmann 	struct bpf_prog *prog;
721604326b4SDaniel Borkmann 	int ret = __SK_DROP;
722604326b4SDaniel Borkmann 
723604326b4SDaniel Borkmann 	rcu_read_lock();
724604326b4SDaniel Borkmann 	prog = READ_ONCE(psock->progs.skb_verdict);
725604326b4SDaniel Borkmann 	if (likely(prog)) {
726604326b4SDaniel Borkmann 		skb_orphan(skb);
727604326b4SDaniel Borkmann 		tcp_skb_bpf_redirect_clear(skb);
728604326b4SDaniel Borkmann 		ret = sk_psock_bpf_run(psock, prog, skb);
729604326b4SDaniel Borkmann 		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
730604326b4SDaniel Borkmann 	}
731604326b4SDaniel Borkmann 	rcu_read_unlock();
732604326b4SDaniel Borkmann 	sk_psock_verdict_apply(psock, skb, ret);
733604326b4SDaniel Borkmann }
734604326b4SDaniel Borkmann 
/* strparser read_sock_done callback: nothing to do for a psock, so the
 * error code passed in is handed back unchanged. @strp is unused.
 */
static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
	return err;
}
739604326b4SDaniel Borkmann 
740604326b4SDaniel Borkmann static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
741604326b4SDaniel Borkmann {
742604326b4SDaniel Borkmann 	struct sk_psock *psock = sk_psock_from_strp(strp);
743604326b4SDaniel Borkmann 	struct bpf_prog *prog;
744604326b4SDaniel Borkmann 	int ret = skb->len;
745604326b4SDaniel Borkmann 
746604326b4SDaniel Borkmann 	rcu_read_lock();
747604326b4SDaniel Borkmann 	prog = READ_ONCE(psock->progs.skb_parser);
748604326b4SDaniel Borkmann 	if (likely(prog))
749604326b4SDaniel Borkmann 		ret = sk_psock_bpf_run(psock, prog, skb);
750604326b4SDaniel Borkmann 	rcu_read_unlock();
751604326b4SDaniel Borkmann 	return ret;
752604326b4SDaniel Borkmann }
753604326b4SDaniel Borkmann 
754604326b4SDaniel Borkmann /* Called with socket lock held. */
755552de910SJohn Fastabend static void sk_psock_strp_data_ready(struct sock *sk)
756604326b4SDaniel Borkmann {
757604326b4SDaniel Borkmann 	struct sk_psock *psock;
758604326b4SDaniel Borkmann 
759604326b4SDaniel Borkmann 	rcu_read_lock();
760604326b4SDaniel Borkmann 	psock = sk_psock(sk);
761604326b4SDaniel Borkmann 	if (likely(psock)) {
762604326b4SDaniel Borkmann 		write_lock_bh(&sk->sk_callback_lock);
763604326b4SDaniel Borkmann 		strp_data_ready(&psock->parser.strp);
764604326b4SDaniel Borkmann 		write_unlock_bh(&sk->sk_callback_lock);
765604326b4SDaniel Borkmann 	}
766604326b4SDaniel Borkmann 	rcu_read_unlock();
767604326b4SDaniel Borkmann }
768604326b4SDaniel Borkmann 
769604326b4SDaniel Borkmann static void sk_psock_write_space(struct sock *sk)
770604326b4SDaniel Borkmann {
771604326b4SDaniel Borkmann 	struct sk_psock *psock;
772604326b4SDaniel Borkmann 	void (*write_space)(struct sock *sk);
773604326b4SDaniel Borkmann 
774604326b4SDaniel Borkmann 	rcu_read_lock();
775604326b4SDaniel Borkmann 	psock = sk_psock(sk);
776604326b4SDaniel Borkmann 	if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)))
777604326b4SDaniel Borkmann 		schedule_work(&psock->work);
778604326b4SDaniel Borkmann 	write_space = psock->saved_write_space;
779604326b4SDaniel Borkmann 	rcu_read_unlock();
780604326b4SDaniel Borkmann 	write_space(sk);
781604326b4SDaniel Borkmann }
782604326b4SDaniel Borkmann 
783604326b4SDaniel Borkmann int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
784604326b4SDaniel Borkmann {
785604326b4SDaniel Borkmann 	static const struct strp_callbacks cb = {
786604326b4SDaniel Borkmann 		.rcv_msg	= sk_psock_strp_read,
787604326b4SDaniel Borkmann 		.read_sock_done	= sk_psock_strp_read_done,
788604326b4SDaniel Borkmann 		.parse_msg	= sk_psock_strp_parse,
789604326b4SDaniel Borkmann 	};
790604326b4SDaniel Borkmann 
791604326b4SDaniel Borkmann 	psock->parser.enabled = false;
792604326b4SDaniel Borkmann 	return strp_init(&psock->parser.strp, sk, &cb);
793604326b4SDaniel Borkmann }
794604326b4SDaniel Borkmann 
795604326b4SDaniel Borkmann void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
796604326b4SDaniel Borkmann {
797604326b4SDaniel Borkmann 	struct sk_psock_parser *parser = &psock->parser;
798604326b4SDaniel Borkmann 
799604326b4SDaniel Borkmann 	if (parser->enabled)
800604326b4SDaniel Borkmann 		return;
801604326b4SDaniel Borkmann 
802604326b4SDaniel Borkmann 	parser->saved_data_ready = sk->sk_data_ready;
803552de910SJohn Fastabend 	sk->sk_data_ready = sk_psock_strp_data_ready;
804604326b4SDaniel Borkmann 	sk->sk_write_space = sk_psock_write_space;
805604326b4SDaniel Borkmann 	parser->enabled = true;
806604326b4SDaniel Borkmann }
807604326b4SDaniel Borkmann 
808604326b4SDaniel Borkmann void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
809604326b4SDaniel Borkmann {
810604326b4SDaniel Borkmann 	struct sk_psock_parser *parser = &psock->parser;
811604326b4SDaniel Borkmann 
812604326b4SDaniel Borkmann 	if (!parser->enabled)
813604326b4SDaniel Borkmann 		return;
814604326b4SDaniel Borkmann 
815604326b4SDaniel Borkmann 	sk->sk_data_ready = parser->saved_data_ready;
816604326b4SDaniel Borkmann 	parser->saved_data_ready = NULL;
817604326b4SDaniel Borkmann 	strp_stop(&parser->strp);
818604326b4SDaniel Borkmann 	parser->enabled = false;
819604326b4SDaniel Borkmann }
820