xref: /linux/net/core/skmsg.c (revision ca2f5f21dbbd5e3a00cd3e97f728aa2ca0b2e011)
1604326b4SDaniel Borkmann // SPDX-License-Identifier: GPL-2.0
2604326b4SDaniel Borkmann /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3604326b4SDaniel Borkmann 
4604326b4SDaniel Borkmann #include <linux/skmsg.h>
5604326b4SDaniel Borkmann #include <linux/skbuff.h>
6604326b4SDaniel Borkmann #include <linux/scatterlist.h>
7604326b4SDaniel Borkmann 
8604326b4SDaniel Borkmann #include <net/sock.h>
9604326b4SDaniel Borkmann #include <net/tcp.h>
10604326b4SDaniel Borkmann 
11604326b4SDaniel Borkmann static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
12604326b4SDaniel Borkmann {
13604326b4SDaniel Borkmann 	if (msg->sg.end > msg->sg.start &&
14604326b4SDaniel Borkmann 	    elem_first_coalesce < msg->sg.end)
15604326b4SDaniel Borkmann 		return true;
16604326b4SDaniel Borkmann 
17604326b4SDaniel Borkmann 	if (msg->sg.end < msg->sg.start &&
18604326b4SDaniel Borkmann 	    (elem_first_coalesce > msg->sg.start ||
19604326b4SDaniel Borkmann 	     elem_first_coalesce < msg->sg.end))
20604326b4SDaniel Borkmann 		return true;
21604326b4SDaniel Borkmann 
22604326b4SDaniel Borkmann 	return false;
23604326b4SDaniel Borkmann }
24604326b4SDaniel Borkmann 
/* sk_msg_alloc - grow @msg so its total payload can reach @len bytes.
 *
 * @len is the desired total size; only the delta over the current
 * msg->sg.size is allocated.  Pages come from the socket's page_frag
 * allocator and every used chunk is charged to @sk's send-side memory
 * accounting.  @elem_first_coalesce marks the first ring element that
 * is allowed to be extended in place instead of consuming a new slot.
 *
 * Returns 0 on success, -ENOMEM when page refill or wmem scheduling
 * fails, -ENOSPC when the sg ring is full.  On error the bytes already
 * placed into @msg are kept (caller may trim them back).
 */
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)
{
	struct page_frag *pfrag = sk_page_frag(sk);
	int ret = 0;

	len -= msg->sg.size;	/* only the missing part is allocated */
	while (len > 0) {
		struct scatterlist *sge;
		u32 orig_offset;
		int use, i;

		if (!sk_page_frag_refill(sk, pfrag))
			return -ENOMEM;

		orig_offset = pfrag->offset;
		use = min_t(int, len, pfrag->size - orig_offset);
		if (!sk_wmem_schedule(sk, use))
			return -ENOMEM;

		/* Look at the last used element for tail coalescing. */
		i = msg->sg.end;
		sk_msg_iter_var_prev(i);
		sge = &msg->sg.data[i];

		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
		    sg_page(sge) == pfrag->page &&
		    sge->offset + sge->length == orig_offset) {
			/* New chunk is physically contiguous with the tail. */
			sge->length += use;
		} else {
			if (sk_msg_full(msg)) {
				ret = -ENOSPC;
				break;
			}

			sge = &msg->sg.data[msg->sg.end];
			sg_unmark_end(sge);
			sg_set_page(sge, pfrag->page, use, orig_offset);
			get_page(pfrag->page);	/* msg holds its own page ref */
			sk_msg_iter_next(msg, end);
		}

		sk_mem_charge(sk, use);
		msg->sg.size += use;
		pfrag->offset += use;
		len -= use;
	}

	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);
75604326b4SDaniel Borkmann 
/* sk_msg_clone - share @len bytes starting at @off of @src into @dst.
 *
 * Pages are referenced (via sk_msg_page_add or tail coalescing), not
 * copied; the cloned bytes are charged to @sk.  Chunks that are
 * virtually contiguous with the last element of @dst are merged to
 * save sg slots.
 *
 * Returns 0 on success, -ENOSPC when @src runs out of data before
 * @off/@len are satisfied or when @dst has no free slot left.
 */
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len)
{
	int i = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, i);
	struct scatterlist *sgd = NULL;
	u32 sge_len, sge_off;

	/* Skip whole source elements that lie entirely before @off. */
	while (off) {
		if (sge->length > off)
			break;
		off -= sge->length;
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && off)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	while (len) {
		sge_len = sge->length - off;
		if (sge_len > len)
			sge_len = len;

		/* NOTE(review): only end - 1 is considered, so a coalesce
		 * candidate is missed when dst's ring has wrapped such that
		 * dst->sg.end == 0 — appears to be only a missed
		 * optimization, not a correctness issue; confirm.
		 */
		if (dst->sg.end)
			sgd = sk_msg_elem(dst, dst->sg.end - 1);

		if (sgd &&
		    (sg_page(sge) == sg_page(sgd)) &&
		    (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
			/* Contiguous with dst's tail: extend in place. */
			sgd->length += sge_len;
			dst->sg.size += sge_len;
		} else if (!sk_msg_full(dst)) {
			sge_off = sge->offset + off;
			sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
		} else {
			return -ENOSPC;
		}

		off = 0;	/* offset only applies to the first chunk */
		len -= sge_len;
		sk_mem_charge(sk, sge_len);
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && len)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);
126d829e9c4SDaniel Borkmann 
127604326b4SDaniel Borkmann void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
128604326b4SDaniel Borkmann {
129604326b4SDaniel Borkmann 	int i = msg->sg.start;
130604326b4SDaniel Borkmann 
131604326b4SDaniel Borkmann 	do {
132604326b4SDaniel Borkmann 		struct scatterlist *sge = sk_msg_elem(msg, i);
133604326b4SDaniel Borkmann 
134604326b4SDaniel Borkmann 		if (bytes < sge->length) {
135604326b4SDaniel Borkmann 			sge->length -= bytes;
136604326b4SDaniel Borkmann 			sge->offset += bytes;
137604326b4SDaniel Borkmann 			sk_mem_uncharge(sk, bytes);
138604326b4SDaniel Borkmann 			break;
139604326b4SDaniel Borkmann 		}
140604326b4SDaniel Borkmann 
141604326b4SDaniel Borkmann 		sk_mem_uncharge(sk, sge->length);
142604326b4SDaniel Borkmann 		bytes -= sge->length;
143604326b4SDaniel Borkmann 		sge->length = 0;
144604326b4SDaniel Borkmann 		sge->offset = 0;
145604326b4SDaniel Borkmann 		sk_msg_iter_var_next(i);
146604326b4SDaniel Borkmann 	} while (bytes && i != msg->sg.end);
147604326b4SDaniel Borkmann 	msg->sg.start = i;
148604326b4SDaniel Borkmann }
149604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_return_zero);
150604326b4SDaniel Borkmann 
151604326b4SDaniel Borkmann void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
152604326b4SDaniel Borkmann {
153604326b4SDaniel Borkmann 	int i = msg->sg.start;
154604326b4SDaniel Borkmann 
155604326b4SDaniel Borkmann 	do {
156604326b4SDaniel Borkmann 		struct scatterlist *sge = &msg->sg.data[i];
157604326b4SDaniel Borkmann 		int uncharge = (bytes < sge->length) ? bytes : sge->length;
158604326b4SDaniel Borkmann 
159604326b4SDaniel Borkmann 		sk_mem_uncharge(sk, uncharge);
160604326b4SDaniel Borkmann 		bytes -= uncharge;
161604326b4SDaniel Borkmann 		sk_msg_iter_var_next(i);
162604326b4SDaniel Borkmann 	} while (i != msg->sg.end);
163604326b4SDaniel Borkmann }
164604326b4SDaniel Borkmann EXPORT_SYMBOL_GPL(sk_msg_return);
165604326b4SDaniel Borkmann 
166604326b4SDaniel Borkmann static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
167604326b4SDaniel Borkmann 			    bool charge)
168604326b4SDaniel Borkmann {
169604326b4SDaniel Borkmann 	struct scatterlist *sge = sk_msg_elem(msg, i);
170604326b4SDaniel Borkmann 	u32 len = sge->length;
171604326b4SDaniel Borkmann 
172604326b4SDaniel Borkmann 	if (charge)
173604326b4SDaniel Borkmann 		sk_mem_uncharge(sk, len);
174604326b4SDaniel Borkmann 	if (!msg->skb)
175604326b4SDaniel Borkmann 		put_page(sg_page(sge));
176604326b4SDaniel Borkmann 	memset(sge, 0, sizeof(*sge));
177604326b4SDaniel Borkmann 	return len;
178604326b4SDaniel Borkmann }
179604326b4SDaniel Borkmann 
/* Free every element of @msg starting at ring index @i, uncharging the
 * bytes from @sk when @charge is true.  Any backing skb is consumed and
 * @msg is reset to the empty state.  Returns the total bytes freed.
 */
static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
			 bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	int freed = 0;

	while (msg->sg.size) {
		msg->sg.size -= sge->length;
		freed += sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, msg->sg.size);
		sge = sk_msg_elem(msg, i);
	}
	consume_skb(msg->skb);
	sk_msg_init(msg);
	return freed;
}
197604326b4SDaniel Borkmann 
/* Free all of @msg without uncharging the freed bytes from @sk's
 * memory accounting.  Returns bytes freed.
 */
int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);
203604326b4SDaniel Borkmann 
/* Free all of @msg, uncharging the freed bytes from @sk's memory
 * accounting.  Returns bytes freed.
 */
int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);
209604326b4SDaniel Borkmann 
/* Free up to @bytes from the front of @msg, advancing msg->sg.start
 * past fully freed elements.  A partially consumed element is shrunk
 * in place.  When @charge is true the freed bytes are uncharged from
 * @sk.  Stops early if a zero-length element is reached.
 */
static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
				  u32 bytes, bool charge)
{
	struct scatterlist *sge;
	u32 i = msg->sg.start;

	while (bytes) {
		sge = sk_msg_elem(msg, i);
		if (!sge->length)
			break;
		if (bytes < sge->length) {
			/* Last affected element: trim from the front. */
			if (charge)
				sk_mem_uncharge(sk, bytes);
			sge->length -= bytes;
			sge->offset += bytes;
			msg->sg.size -= bytes;
			break;
		}

		msg->sg.size -= sge->length;
		bytes -= sge->length;
		sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, bytes);
	}
	msg->sg.start = i;
}
237604326b4SDaniel Borkmann 
/* Free @bytes from the front of @msg, uncharging them from @sk. */
void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);
243604326b4SDaniel Borkmann 
/* Free @bytes from the front of @msg without touching @sk's memory
 * accounting.
 */
void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, false);
}
249604326b4SDaniel Borkmann 
/* sk_msg_trim - shrink @msg from the tail down to @len total bytes.
 *
 * Fully trimmed trailing elements are freed (and uncharged), the last
 * partially trimmed element is shortened, and msg->sg.curr/copybreak
 * are kept consistent so future copy operations start at a valid spot.
 * Growing is not supported: a @len larger than the current size warns
 * and returns without changes.
 */
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
	int trim = msg->sg.size - len;
	u32 i = msg->sg.end;

	if (trim <= 0) {
		WARN_ON(trim < 0);	/* caller asked us to grow */
		return;
	}

	/* Walk backwards from the tail, freeing whole elements. */
	sk_msg_iter_var_prev(i);
	msg->sg.size = len;
	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);
		if (!trim)
			goto out;
	}

	/* Shorten the last element by the remaining trim amount. */
	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);
	/* Adjust copybreak if it falls into the trimmed part of last buf */
	if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length)
		msg->sg.copybreak = msg->sg.data[i].length;
out:
	sk_msg_iter_var_next(i);
	msg->sg.end = i;

	/* If we trim data a full sg elem before curr pointer update
	 * copybreak and current so that any future copy operations
	 * start at new copy location.
	 * However trimed data that has not yet been used in a copy op
	 * does not require an update.
	 */
	if (!msg->sg.size) {
		msg->sg.curr = msg->sg.start;
		msg->sg.copybreak = 0;
	} else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >=
		   sk_msg_iter_dist(msg->sg.start, msg->sg.end)) {
		sk_msg_iter_var_prev(i);
		msg->sg.curr = i;
		msg->sg.copybreak = msg->sg.data[i].length;
	}
}
EXPORT_SYMBOL_GPL(sk_msg_trim);
297604326b4SDaniel Borkmann 
/* sk_msg_zerocopy_from_iter - pin user pages from @from into @msg.
 *
 * Appends up to @bytes of user memory to @msg without copying, by
 * grabbing page references via iov_iter_get_pages() and charging the
 * pinned bytes to @sk.  On failure (-EFAULT when the ring is full or
 * the pages cannot be obtained) the iterator is reverted by however
 * much was consumed; already-linked pages stay in @msg and must be
 * removed by the caller with a trim if desired.
 */
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes)
{
	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
	const int to_max_pages = MAX_MSG_FRAGS;
	struct page *pages[MAX_MSG_FRAGS];
	ssize_t orig, copied, use, offset;

	orig = msg->sg.size;
	while (bytes > 0) {
		i = 0;
		maxpages = to_max_pages - num_elems;
		if (maxpages == 0) {
			ret = -EFAULT;
			goto out;
		}

		copied = iov_iter_get_pages(from, pages, bytes, maxpages,
					    &offset);
		if (copied <= 0) {
			ret = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);
		bytes -= copied;
		msg->sg.size += copied;

		/* Link each pinned page into the sg ring. */
		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);
			sg_set_page(&msg->sg.data[msg->sg.end],
				    pages[i], use, offset);
			sg_unmark_end(&msg->sg.data[msg->sg.end]);
			sk_mem_charge(sk, use);

			offset = 0;	/* only the first page has an offset */
			copied -= use;
			sk_msg_iter_next(msg, end);
			num_elems++;
			i++;
		}
		/* When zerocopy is mixed with sk_msg_*copy* operations we
		 * may have a copybreak set in this case clear and prefer
		 * zerocopy remainder when possible.
		 */
		msg->sg.copybreak = 0;
		msg->sg.curr = msg->sg.end;
	}
out:
	/* Revert iov_iter updates, msg will need to use 'trim' later if it
	 * also needs to be cleared.
	 */
	if (ret)
		iov_iter_revert(from, msg->sg.size - orig);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
355604326b4SDaniel Borkmann 
/* sk_msg_memcopy_from_iter - copy @bytes from @from into @msg's pages.
 *
 * Copies into already-allocated sg elements starting at msg->sg.curr /
 * msg->sg.copybreak, advancing both as data lands.  Returns the last
 * copy_from_iter() result (number of bytes of the final chunk) on
 * success, -EFAULT on a short copy, or -ENOSPC if no writable space
 * was found.  msg->sg.curr is always left at the element being filled.
 */
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes)
{
	int ret = -ENOSPC, i = msg->sg.curr;
	struct scatterlist *sge;
	u32 copy, buf_size;
	void *to;

	do {
		sge = sk_msg_elem(msg, i);
		/* This is possible if a trim operation shrunk the buffer */
		if (msg->sg.copybreak >= sge->length) {
			msg->sg.copybreak = 0;
			sk_msg_iter_var_next(i);
			if (i == msg->sg.end)
				break;
			sge = sk_msg_elem(msg, i);
		}

		buf_size = sge->length - msg->sg.copybreak;
		copy = (buf_size > bytes) ? bytes : buf_size;
		to = sg_virt(sge) + msg->sg.copybreak;
		msg->sg.copybreak += copy;
		/* Prefer a non-temporal copy when the NIC won't read the
		 * data back out of cache anyway.
		 */
		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			ret = copy_from_iter_nocache(to, copy, from);
		else
			ret = copy_from_iter(to, copy, from);
		if (ret != copy) {
			ret = -EFAULT;
			goto out;
		}
		bytes -= copy;
		if (!bytes)
			break;
		msg->sg.copybreak = 0;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
out:
	msg->sg.curr = i;
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
398604326b4SDaniel Borkmann 
/* Queue @skb onto @psock's ingress msg list so it can be read via the
 * sockmap data path.  The skb's pages are mapped into a freshly
 * allocated sk_msg (no copy); the skb itself stays owned by the msg
 * and is consumed when the msg is freed.  The skb's length is charged
 * to the socket's receive memory after sk_rmem_schedule() approves it.
 *
 * Returns the number of bytes queued, -EAGAIN on allocation/accounting
 * failure, or a negative errno from skb_to_sgvec().
 */
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sock *sk = psock->sk;
	int copied = 0, num_sge;
	struct sk_msg *msg;

	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
	if (unlikely(!msg))
		return -EAGAIN;
	if (!sk_rmem_schedule(sk, skb, skb->len)) {
		kfree(msg);
		return -EAGAIN;
	}

	sk_msg_init(msg);
	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
	if (unlikely(num_sge < 0)) {
		kfree(msg);
		return num_sge;
	}

	sk_mem_charge(sk, skb->len);
	copied = skb->len;
	msg->sg.start = 0;
	msg->sg.size = copied;
	msg->sg.end = num_sge;
	msg->skb = skb;	/* msg owns the skb from here on */

	sk_psock_queue_msg(psock, msg);
	sk_psock_data_ready(sk, psock);	/* wake up any reader */
	return copied;
}
431604326b4SDaniel Borkmann 
432604326b4SDaniel Borkmann static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
433604326b4SDaniel Borkmann 			       u32 off, u32 len, bool ingress)
434604326b4SDaniel Borkmann {
435604326b4SDaniel Borkmann 	if (ingress)
436604326b4SDaniel Borkmann 		return sk_psock_skb_ingress(psock, skb);
437604326b4SDaniel Borkmann 	else
438604326b4SDaniel Borkmann 		return skb_send_sock_locked(psock->sk, skb, off, len);
439604326b4SDaniel Borkmann }
440604326b4SDaniel Borkmann 
/* Work-queue handler that drains @psock's ingress_skb backlog.
 *
 * If a previous run was interrupted with -EAGAIN, the saved skb and its
 * remaining off/len are resumed first.  Each skb is pushed through
 * sk_psock_handle_skb() until fully consumed; -EAGAIN re-saves progress
 * into work_state for the next run, while any other error (or a socket
 * without sk_socket) reports the error, disables TX, and drops the skb.
 */
static void sk_psock_backlog(struct work_struct *work)
{
	struct sk_psock *psock = container_of(work, struct sk_psock, work);
	struct sk_psock_work_state *state = &psock->work_state;
	struct sk_buff *skb;
	bool ingress;
	u32 len, off;
	int ret;

	/* Lock sock to avoid losing sk_socket during loop. */
	lock_sock(psock->sk);
	if (state->skb) {
		/* Resume the skb left over from a previous -EAGAIN. */
		skb = state->skb;
		len = state->len;
		off = state->off;
		state->skb = NULL;
		goto start;
	}

	while ((skb = skb_dequeue(&psock->ingress_skb))) {
		len = skb->len;
		off = 0;
start:
		ingress = tcp_skb_bpf_ingress(skb);
		do {
			ret = -EIO;
			if (likely(psock->sk->sk_socket))
				ret = sk_psock_handle_skb(psock, skb, off,
							  len, ingress);
			if (ret <= 0) {
				if (ret == -EAGAIN) {
					/* Save progress; retried on the
					 * next work invocation.
					 */
					state->skb = skb;
					state->len = len;
					state->off = off;
					goto end;
				}
				/* Hard errors break pipe and stop xmit. */
				sk_psock_report_error(psock, ret ? -ret : EPIPE);
				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
				kfree_skb(skb);
				goto end;
			}
			off += ret;
			len -= ret;
		} while (len);

		/* Ingress skbs are now owned by the queued msg; egress ones
		 * were fully sent and can be freed here.
		 */
		if (!ingress)
			kfree_skb(skb);
	}
end:
	release_sock(psock->sk);
}
493604326b4SDaniel Borkmann 
/* sk_psock_init - allocate and attach a psock to @sk on NUMA @node.
 *
 * Initializes the link list/lock, backlog work, ingress queues, and
 * sets TX enabled with an initial refcount of one.  The psock is
 * published through sk_user_data and a socket reference is taken.
 * Returns the new psock or NULL on allocation failure.
 */
struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
	struct sk_psock *psock = kzalloc_node(sizeof(*psock),
					      GFP_ATOMIC | __GFP_NOWARN,
					      node);
	if (!psock)
		return NULL;

	psock->sk = sk;
	psock->eval =  __SK_NONE;

	INIT_LIST_HEAD(&psock->link);
	spin_lock_init(&psock->link_lock);

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
	refcount_set(&psock->refcnt, 1);

	/* Publish only after the psock is fully initialized. */
	rcu_assign_sk_user_data_nocopy(sk, psock);
	sock_hold(sk);

	return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);
521604326b4SDaniel Borkmann 
522604326b4SDaniel Borkmann struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
523604326b4SDaniel Borkmann {
524604326b4SDaniel Borkmann 	struct sk_psock_link *link;
525604326b4SDaniel Borkmann 
526604326b4SDaniel Borkmann 	spin_lock_bh(&psock->link_lock);
527604326b4SDaniel Borkmann 	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
528604326b4SDaniel Borkmann 					list);
529604326b4SDaniel Borkmann 	if (link)
530604326b4SDaniel Borkmann 		list_del(&link->list);
531604326b4SDaniel Borkmann 	spin_unlock_bh(&psock->link_lock);
532604326b4SDaniel Borkmann 	return link;
533604326b4SDaniel Borkmann }
534604326b4SDaniel Borkmann 
535604326b4SDaniel Borkmann void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
536604326b4SDaniel Borkmann {
537604326b4SDaniel Borkmann 	struct sk_msg *msg, *tmp;
538604326b4SDaniel Borkmann 
539604326b4SDaniel Borkmann 	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
540604326b4SDaniel Borkmann 		list_del(&msg->list);
541604326b4SDaniel Borkmann 		sk_msg_free(psock->sk, msg);
542604326b4SDaniel Borkmann 		kfree(msg);
543604326b4SDaniel Borkmann 	}
544604326b4SDaniel Borkmann }
545604326b4SDaniel Borkmann 
/* Throw away all pending ingress data: both raw skbs that have not yet
 * been processed and msgs already queued for receive.
 */
static void sk_psock_zap_ingress(struct sk_psock *psock)
{
	__skb_queue_purge(&psock->ingress_skb);
	__sk_psock_purge_ingress_msg(psock);
}
551604326b4SDaniel Borkmann 
552604326b4SDaniel Borkmann static void sk_psock_link_destroy(struct sk_psock *psock)
553604326b4SDaniel Borkmann {
554604326b4SDaniel Borkmann 	struct sk_psock_link *link, *tmp;
555604326b4SDaniel Borkmann 
556604326b4SDaniel Borkmann 	list_for_each_entry_safe(link, tmp, &psock->link, list) {
557604326b4SDaniel Borkmann 		list_del(&link->list);
558604326b4SDaniel Borkmann 		sk_psock_free_link(link);
559604326b4SDaniel Borkmann 	}
560604326b4SDaniel Borkmann }
561604326b4SDaniel Borkmann 
/* Deferred (workqueue) half of psock destruction, scheduled from the
 * RCU callback so we may sleep: finish the strparser, flush backlog
 * work, drop attached programs, purge links/cork/ingress, and release
 * the redirect-socket and socket references before freeing the psock.
 */
static void sk_psock_destroy_deferred(struct work_struct *gc)
{
	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

	/* No sk_callback_lock since already detached. */

	/* Parser has been stopped */
	if (psock->progs.skb_parser)
		strp_done(&psock->parser.strp);

	cancel_work_sync(&psock->work);

	psock_progs_drop(&psock->progs);

	sk_psock_link_destroy(psock);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	if (psock->sk_redir)
		sock_put(psock->sk_redir);
	sock_put(psock->sk);
	kfree(psock);
}
585604326b4SDaniel Borkmann 
/* RCU callback for psock teardown.  Runs in softirq context, so the
 * sleeping parts of destruction are pushed to a workqueue.
 */
void sk_psock_destroy(struct rcu_head *rcu)
{
	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
	schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);
594604326b4SDaniel Borkmann 
/* sk_psock_drop - detach @psock from @sk and schedule its destruction.
 *
 * Frees any corked data and pending ingress, restores the socket's
 * original proto ops, clears sk_user_data and stops the strparser under
 * sk_callback_lock, disables TX, then hands the psock to RCU-deferred
 * teardown once all readers are done.
 */
void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	write_lock_bh(&sk->sk_callback_lock);
	sk_psock_restore_proto(sk, psock);
	rcu_assign_sk_user_data(sk, NULL);
	if (psock->progs.skb_parser)
		sk_psock_stop_strp(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	call_rcu(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);
611604326b4SDaniel Borkmann 
612604326b4SDaniel Borkmann static int sk_psock_map_verd(int verdict, bool redir)
613604326b4SDaniel Borkmann {
614604326b4SDaniel Borkmann 	switch (verdict) {
615604326b4SDaniel Borkmann 	case SK_PASS:
616604326b4SDaniel Borkmann 		return redir ? __SK_REDIRECT : __SK_PASS;
617604326b4SDaniel Borkmann 	case SK_DROP:
618604326b4SDaniel Borkmann 	default:
619604326b4SDaniel Borkmann 		break;
620604326b4SDaniel Borkmann 	}
621604326b4SDaniel Borkmann 
622604326b4SDaniel Borkmann 	return __SK_DROP;
623604326b4SDaniel Borkmann }
624604326b4SDaniel Borkmann 
/* sk_psock_msg_verdict - run the attached msg_parser program on @msg.
 *
 * Under rcu_read_lock, computes data pointers, runs the BPF program
 * pinned to this CPU, and maps its verdict.  On __SK_REDIRECT the
 * psock takes over the msg's redirect socket reference (dropping any
 * previous one); a missing redirect target downgrades the verdict to
 * __SK_DROP.  With no program attached the msg passes.  Also latches
 * msg->apply_bytes into psock->apply_bytes.
 */
int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg)
{
	struct bpf_prog *prog;
	int ret;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.msg_parser);
	if (unlikely(!prog)) {
		ret = __SK_PASS;
		goto out;
	}

	sk_msg_compute_data_pointers(msg);
	msg->sk = sk;
	ret = bpf_prog_run_pin_on_cpu(prog, msg);
	ret = sk_psock_map_verd(ret, msg->sk_redir);
	psock->apply_bytes = msg->apply_bytes;
	if (ret == __SK_REDIRECT) {
		/* Swap in the new redirect target, holding our own ref. */
		if (psock->sk_redir)
			sock_put(psock->sk_redir);
		psock->sk_redir = msg->sk_redir;
		if (!psock->sk_redir) {
			ret = __SK_DROP;
			goto out;
		}
		sock_hold(psock->sk_redir);
	}
out:
	rcu_read_unlock();
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
658604326b4SDaniel Borkmann 
659604326b4SDaniel Borkmann static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
660604326b4SDaniel Borkmann 			    struct sk_buff *skb)
661604326b4SDaniel Borkmann {
662604326b4SDaniel Borkmann 	int ret;
663604326b4SDaniel Borkmann 
664604326b4SDaniel Borkmann 	skb->sk = psock->sk;
665604326b4SDaniel Borkmann 	bpf_compute_data_end_sk_skb(skb);
6663d9f773cSDavid Miller 	ret = bpf_prog_run_pin_on_cpu(prog, skb);
667604326b4SDaniel Borkmann 	/* strparser clones the skb before handing it to a upper layer,
668604326b4SDaniel Borkmann 	 * meaning skb_orphan has been called. We NULL sk on the way out
669604326b4SDaniel Borkmann 	 * to ensure we don't trigger a BUG_ON() in skb/sk operations
670604326b4SDaniel Borkmann 	 * later and because we are not charging the memory of this skb
671604326b4SDaniel Borkmann 	 * to any socket yet.
672604326b4SDaniel Borkmann 	 */
673604326b4SDaniel Borkmann 	skb->sk = NULL;
674604326b4SDaniel Borkmann 	return ret;
675604326b4SDaniel Borkmann }
676604326b4SDaniel Borkmann 
677604326b4SDaniel Borkmann static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
678604326b4SDaniel Borkmann {
679604326b4SDaniel Borkmann 	struct sk_psock_parser *parser;
680604326b4SDaniel Borkmann 
681604326b4SDaniel Borkmann 	parser = container_of(strp, struct sk_psock_parser, strp);
682604326b4SDaniel Borkmann 	return container_of(parser, struct sk_psock, parser);
683604326b4SDaniel Borkmann }
684604326b4SDaniel Borkmann 
685*ca2f5f21SJohn Fastabend static void sk_psock_skb_redirect(struct sk_psock *psock, struct sk_buff *skb)
686604326b4SDaniel Borkmann {
687604326b4SDaniel Borkmann 	struct sk_psock *psock_other;
688604326b4SDaniel Borkmann 	struct sock *sk_other;
689604326b4SDaniel Borkmann 	bool ingress;
690604326b4SDaniel Borkmann 
691*ca2f5f21SJohn Fastabend 	sk_other = tcp_skb_bpf_redirect_fetch(skb);
692*ca2f5f21SJohn Fastabend 	if (unlikely(!sk_other)) {
693*ca2f5f21SJohn Fastabend 		kfree_skb(skb);
694*ca2f5f21SJohn Fastabend 		return;
695*ca2f5f21SJohn Fastabend 	}
696*ca2f5f21SJohn Fastabend 	psock_other = sk_psock(sk_other);
697*ca2f5f21SJohn Fastabend 	if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
698*ca2f5f21SJohn Fastabend 	    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) {
699*ca2f5f21SJohn Fastabend 		kfree_skb(skb);
700*ca2f5f21SJohn Fastabend 		return;
701*ca2f5f21SJohn Fastabend 	}
702*ca2f5f21SJohn Fastabend 
703*ca2f5f21SJohn Fastabend 	ingress = tcp_skb_bpf_ingress(skb);
704*ca2f5f21SJohn Fastabend 	if ((!ingress && sock_writeable(sk_other)) ||
705*ca2f5f21SJohn Fastabend 	    (ingress &&
706*ca2f5f21SJohn Fastabend 	     atomic_read(&sk_other->sk_rmem_alloc) <=
707*ca2f5f21SJohn Fastabend 	     sk_other->sk_rcvbuf)) {
708*ca2f5f21SJohn Fastabend 		if (!ingress)
709*ca2f5f21SJohn Fastabend 			skb_set_owner_w(skb, sk_other);
710*ca2f5f21SJohn Fastabend 		skb_queue_tail(&psock_other->ingress_skb, skb);
711*ca2f5f21SJohn Fastabend 		schedule_work(&psock_other->work);
712*ca2f5f21SJohn Fastabend 	} else {
713*ca2f5f21SJohn Fastabend 		kfree_skb(skb);
714*ca2f5f21SJohn Fastabend 	}
715*ca2f5f21SJohn Fastabend }
716*ca2f5f21SJohn Fastabend 
717*ca2f5f21SJohn Fastabend static void sk_psock_verdict_apply(struct sk_psock *psock,
718*ca2f5f21SJohn Fastabend 				   struct sk_buff *skb, int verdict)
719*ca2f5f21SJohn Fastabend {
720*ca2f5f21SJohn Fastabend 	struct sock *sk_other;
721*ca2f5f21SJohn Fastabend 
722604326b4SDaniel Borkmann 	switch (verdict) {
72351199405SJohn Fastabend 	case __SK_PASS:
72451199405SJohn Fastabend 		sk_other = psock->sk;
72551199405SJohn Fastabend 		if (sock_flag(sk_other, SOCK_DEAD) ||
72651199405SJohn Fastabend 		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
72751199405SJohn Fastabend 			goto out_free;
72851199405SJohn Fastabend 		}
72951199405SJohn Fastabend 		if (atomic_read(&sk_other->sk_rmem_alloc) <=
73051199405SJohn Fastabend 		    sk_other->sk_rcvbuf) {
73151199405SJohn Fastabend 			struct tcp_skb_cb *tcp = TCP_SKB_CB(skb);
73251199405SJohn Fastabend 
73351199405SJohn Fastabend 			tcp->bpf.flags |= BPF_F_INGRESS;
73451199405SJohn Fastabend 			skb_queue_tail(&psock->ingress_skb, skb);
73551199405SJohn Fastabend 			schedule_work(&psock->work);
73651199405SJohn Fastabend 			break;
73751199405SJohn Fastabend 		}
73851199405SJohn Fastabend 		goto out_free;
739604326b4SDaniel Borkmann 	case __SK_REDIRECT:
740*ca2f5f21SJohn Fastabend 		sk_psock_skb_redirect(psock, skb);
741604326b4SDaniel Borkmann 		break;
742604326b4SDaniel Borkmann 	case __SK_DROP:
743604326b4SDaniel Borkmann 		/* fall-through */
744604326b4SDaniel Borkmann 	default:
745604326b4SDaniel Borkmann out_free:
746604326b4SDaniel Borkmann 		kfree_skb(skb);
747604326b4SDaniel Borkmann 	}
748604326b4SDaniel Borkmann }
749604326b4SDaniel Borkmann 
750604326b4SDaniel Borkmann static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
751604326b4SDaniel Borkmann {
752604326b4SDaniel Borkmann 	struct sk_psock *psock = sk_psock_from_strp(strp);
753604326b4SDaniel Borkmann 	struct bpf_prog *prog;
754604326b4SDaniel Borkmann 	int ret = __SK_DROP;
755604326b4SDaniel Borkmann 
756604326b4SDaniel Borkmann 	rcu_read_lock();
757604326b4SDaniel Borkmann 	prog = READ_ONCE(psock->progs.skb_verdict);
758604326b4SDaniel Borkmann 	if (likely(prog)) {
759604326b4SDaniel Borkmann 		skb_orphan(skb);
760604326b4SDaniel Borkmann 		tcp_skb_bpf_redirect_clear(skb);
761604326b4SDaniel Borkmann 		ret = sk_psock_bpf_run(psock, prog, skb);
762604326b4SDaniel Borkmann 		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
763604326b4SDaniel Borkmann 	}
764604326b4SDaniel Borkmann 	rcu_read_unlock();
765604326b4SDaniel Borkmann 	sk_psock_verdict_apply(psock, skb, ret);
766604326b4SDaniel Borkmann }
767604326b4SDaniel Borkmann 
768604326b4SDaniel Borkmann static int sk_psock_strp_read_done(struct strparser *strp, int err)
769604326b4SDaniel Borkmann {
770604326b4SDaniel Borkmann 	return err;
771604326b4SDaniel Borkmann }
772604326b4SDaniel Borkmann 
773604326b4SDaniel Borkmann static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
774604326b4SDaniel Borkmann {
775604326b4SDaniel Borkmann 	struct sk_psock *psock = sk_psock_from_strp(strp);
776604326b4SDaniel Borkmann 	struct bpf_prog *prog;
777604326b4SDaniel Borkmann 	int ret = skb->len;
778604326b4SDaniel Borkmann 
779604326b4SDaniel Borkmann 	rcu_read_lock();
780604326b4SDaniel Borkmann 	prog = READ_ONCE(psock->progs.skb_parser);
781604326b4SDaniel Borkmann 	if (likely(prog))
782604326b4SDaniel Borkmann 		ret = sk_psock_bpf_run(psock, prog, skb);
783604326b4SDaniel Borkmann 	rcu_read_unlock();
784604326b4SDaniel Borkmann 	return ret;
785604326b4SDaniel Borkmann }
786604326b4SDaniel Borkmann 
787604326b4SDaniel Borkmann /* Called with socket lock held. */
788552de910SJohn Fastabend static void sk_psock_strp_data_ready(struct sock *sk)
789604326b4SDaniel Borkmann {
790604326b4SDaniel Borkmann 	struct sk_psock *psock;
791604326b4SDaniel Borkmann 
792604326b4SDaniel Borkmann 	rcu_read_lock();
793604326b4SDaniel Borkmann 	psock = sk_psock(sk);
794604326b4SDaniel Borkmann 	if (likely(psock)) {
795604326b4SDaniel Borkmann 		write_lock_bh(&sk->sk_callback_lock);
796604326b4SDaniel Borkmann 		strp_data_ready(&psock->parser.strp);
797604326b4SDaniel Borkmann 		write_unlock_bh(&sk->sk_callback_lock);
798604326b4SDaniel Borkmann 	}
799604326b4SDaniel Borkmann 	rcu_read_unlock();
800604326b4SDaniel Borkmann }
801604326b4SDaniel Borkmann 
802604326b4SDaniel Borkmann static void sk_psock_write_space(struct sock *sk)
803604326b4SDaniel Borkmann {
804604326b4SDaniel Borkmann 	struct sk_psock *psock;
8058163999dSJohn Fastabend 	void (*write_space)(struct sock *sk) = NULL;
806604326b4SDaniel Borkmann 
807604326b4SDaniel Borkmann 	rcu_read_lock();
808604326b4SDaniel Borkmann 	psock = sk_psock(sk);
8098163999dSJohn Fastabend 	if (likely(psock)) {
8108163999dSJohn Fastabend 		if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
811604326b4SDaniel Borkmann 			schedule_work(&psock->work);
812604326b4SDaniel Borkmann 		write_space = psock->saved_write_space;
8138163999dSJohn Fastabend 	}
814604326b4SDaniel Borkmann 	rcu_read_unlock();
8158163999dSJohn Fastabend 	if (write_space)
816604326b4SDaniel Borkmann 		write_space(sk);
817604326b4SDaniel Borkmann }
818604326b4SDaniel Borkmann 
819604326b4SDaniel Borkmann int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
820604326b4SDaniel Borkmann {
821604326b4SDaniel Borkmann 	static const struct strp_callbacks cb = {
822604326b4SDaniel Borkmann 		.rcv_msg	= sk_psock_strp_read,
823604326b4SDaniel Borkmann 		.read_sock_done	= sk_psock_strp_read_done,
824604326b4SDaniel Borkmann 		.parse_msg	= sk_psock_strp_parse,
825604326b4SDaniel Borkmann 	};
826604326b4SDaniel Borkmann 
827604326b4SDaniel Borkmann 	psock->parser.enabled = false;
828604326b4SDaniel Borkmann 	return strp_init(&psock->parser.strp, sk, &cb);
829604326b4SDaniel Borkmann }
830604326b4SDaniel Borkmann 
831604326b4SDaniel Borkmann void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
832604326b4SDaniel Borkmann {
833604326b4SDaniel Borkmann 	struct sk_psock_parser *parser = &psock->parser;
834604326b4SDaniel Borkmann 
835604326b4SDaniel Borkmann 	if (parser->enabled)
836604326b4SDaniel Borkmann 		return;
837604326b4SDaniel Borkmann 
838604326b4SDaniel Borkmann 	parser->saved_data_ready = sk->sk_data_ready;
839552de910SJohn Fastabend 	sk->sk_data_ready = sk_psock_strp_data_ready;
840604326b4SDaniel Borkmann 	sk->sk_write_space = sk_psock_write_space;
841604326b4SDaniel Borkmann 	parser->enabled = true;
842604326b4SDaniel Borkmann }
843604326b4SDaniel Borkmann 
844604326b4SDaniel Borkmann void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
845604326b4SDaniel Borkmann {
846604326b4SDaniel Borkmann 	struct sk_psock_parser *parser = &psock->parser;
847604326b4SDaniel Borkmann 
848604326b4SDaniel Borkmann 	if (!parser->enabled)
849604326b4SDaniel Borkmann 		return;
850604326b4SDaniel Borkmann 
851604326b4SDaniel Borkmann 	sk->sk_data_ready = parser->saved_data_ready;
852604326b4SDaniel Borkmann 	parser->saved_data_ready = NULL;
853604326b4SDaniel Borkmann 	strp_stop(&parser->strp);
854604326b4SDaniel Borkmann 	parser->enabled = false;
855604326b4SDaniel Borkmann }
856