xref: /linux/net/sunrpc/socklib.c (revision ce3f5bb7504ca802efa710280a4601a06545bd2e)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * linux/net/sunrpc/socklib.c
4  *
5  * Common socket helper routines for RPC client and server
6  *
7  * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de>
8  */
9 
10 #include <linux/compiler.h>
11 #include <linux/netdevice.h>
12 #include <linux/gfp.h>
13 #include <linux/skbuff.h>
14 #include <linux/types.h>
15 #include <linux/pagemap.h>
16 #include <linux/udp.h>
17 #include <linux/sunrpc/msg_prot.h>
18 #include <linux/sunrpc/sched.h>
19 #include <linux/sunrpc/xdr.h>
20 #include <linux/export.h>
21 
22 #include "socklib.h"
23 
/*
 * Helper structure for copying from an sk_buff.
 *
 * Tracks a read cursor into @skb. Each successful read advances @offset
 * and decreases @count by the same amount.
 */
struct xdr_skb_reader {
	struct sk_buff	*skb;		/* source buffer being read */
	unsigned int	offset;		/* byte offset in @skb of the next read */
	bool		need_checksum;	/* if set, fold copied bytes into @csum */
	size_t		count;		/* bytes of @skb not yet read */
	__wsum		csum;		/* running checksum when @need_checksum */
};
34 
35 /**
36  * xdr_skb_read_bits - copy some data bits from skb to internal buffer
37  * @desc: sk_buff copy helper
38  * @to: copy destination
39  * @len: number of bytes to copy
40  *
41  * Possibly called several times to iterate over an sk_buff and copy data out of
42  * it.
43  */
44 static size_t
xdr_skb_read_bits(struct xdr_skb_reader * desc,void * to,size_t len)45 xdr_skb_read_bits(struct xdr_skb_reader *desc, void *to, size_t len)
46 {
47 	len = min(len, desc->count);
48 
49 	if (desc->need_checksum) {
50 		__wsum csum;
51 
52 		csum = skb_copy_and_csum_bits(desc->skb, desc->offset, to, len);
53 		desc->csum = csum_block_add(desc->csum, csum, desc->offset);
54 	} else {
55 		if (unlikely(skb_copy_bits(desc->skb, desc->offset, to, len)))
56 			return 0;
57 	}
58 
59 	desc->count -= len;
60 	desc->offset += len;
61 	return len;
62 }
63 
64 static ssize_t
xdr_partial_copy_from_skb(struct xdr_buf * xdr,struct xdr_skb_reader * desc)65 xdr_partial_copy_from_skb(struct xdr_buf *xdr, struct xdr_skb_reader *desc)
66 {
67 	struct page **ppage = xdr->pages + (xdr->page_base >> PAGE_SHIFT);
68 	unsigned int poff = xdr->page_base & ~PAGE_MASK;
69 	unsigned int pglen = xdr->page_len;
70 	ssize_t copied = 0;
71 	size_t ret;
72 
73 	if (xdr->head[0].iov_len == 0)
74 		return 0;
75 
76 	ret = xdr_skb_read_bits(desc, xdr->head[0].iov_base,
77 				xdr->head[0].iov_len);
78 	if (ret != xdr->head[0].iov_len || !desc->count)
79 		return ret;
80 	copied += ret;
81 
82 	while (pglen) {
83 		unsigned int len = min(PAGE_SIZE - poff, pglen);
84 		char *kaddr;
85 
86 		/* ACL likes to be lazy in allocating pages - ACLs
87 		 * are small by default but can get huge. */
88 		if ((xdr->flags & XDRBUF_SPARSE_PAGES) && *ppage == NULL) {
89 			*ppage = alloc_page(GFP_NOWAIT | __GFP_NOWARN);
90 			if (unlikely(*ppage == NULL)) {
91 				if (copied == 0)
92 					return -ENOMEM;
93 				return copied;
94 			}
95 		}
96 
97 		kaddr = kmap_atomic(*ppage);
98 		ret = xdr_skb_read_bits(desc, kaddr + poff, len);
99 		flush_dcache_page(*ppage);
100 		kunmap_atomic(kaddr);
101 
102 		copied += ret;
103 		if (ret != len || !desc->count)
104 			return copied;
105 		ppage++;
106 		pglen -= len;
107 		poff = 0;
108 	}
109 
110 	if (xdr->tail[0].iov_len) {
111 		copied += xdr_skb_read_bits(desc, xdr->tail[0].iov_base,
112 					xdr->tail[0].iov_len);
113 	}
114 
115 	return copied;
116 }
117 
118 /**
119  * csum_partial_copy_to_xdr - checksum and copy data
120  * @xdr: target XDR buffer
121  * @skb: source skb
122  *
123  * We have set things up such that we perform the checksum of the UDP
124  * packet in parallel with the copies into the RPC client iovec.  -DaveM
125  */
csum_partial_copy_to_xdr(struct xdr_buf * xdr,struct sk_buff * skb)126 int csum_partial_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb)
127 {
128 	struct xdr_skb_reader desc = {
129 		.skb		= skb,
130 		.count		= skb->len - desc.offset,
131 	};
132 
133 	if (skb_csum_unnecessary(skb)) {
134 		if (xdr_partial_copy_from_skb(xdr, &desc) < 0)
135 			return -1;
136 		if (desc.count)
137 			return -1;
138 		return 0;
139 	}
140 
141 	desc.need_checksum = true;
142 	desc.csum = csum_partial(skb->data, desc.offset, skb->csum);
143 	if (xdr_partial_copy_from_skb(xdr, &desc) < 0)
144 		return -1;
145 	if (desc.offset != skb->len) {
146 		__wsum csum2;
147 		csum2 = skb_checksum(skb, desc.offset, skb->len - desc.offset, 0);
148 		desc.csum = csum_block_add(desc.csum, csum2, desc.offset);
149 	}
150 	if (desc.count)
151 		return -1;
152 	if (csum_fold(desc.csum))
153 		return -1;
154 	if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE) &&
155 	    !skb->csum_complete_sw)
156 		netdev_rx_csum_fault(skb->dev, skb);
157 	return 0;
158 }
159 
/*
 * Send the data described by @msg->msg_iter on @sock. The first @seek
 * bytes of the iterator are skipped. Returns whatever sock_sendmsg()
 * returns.
 */
static inline int xprt_sendmsg(struct socket *sock, struct msghdr *msg,
			       size_t seek)
{
	struct iov_iter *iter = &msg->msg_iter;

	if (seek > 0)
		iov_iter_advance(iter, seek);
	return sock_sendmsg(sock, msg);
}
167 
/*
 * Send the single buffer @vec on @sock, starting @seek bytes into it.
 * Returns the result of xprt_sendmsg().
 */
static int xprt_send_kvec(struct socket *sock, struct msghdr *msg,
			  struct kvec *vec, size_t seek)
{
	size_t total = vec->iov_len;

	iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, 1, total);
	return xprt_sendmsg(sock, msg, seek);
}
174 
/*
 * Send the page data of @xdr on @sock, starting @base bytes into it.
 *
 * The iterator is built over @xdr->bvec, which begins at the start of the
 * first page, not at @xdr->page_base. For that reason page_base is added
 * both to the iterator length and to the number of bytes skipped.
 */
static int xprt_send_pagedata(struct socket *sock, struct msghdr *msg,
			      struct xdr_buf *xdr, size_t base)
{
	unsigned int nsegs = xdr_buf_pagecount(xdr);
	size_t span = xdr->page_len + xdr->page_base;
	size_t skip = base + xdr->page_base;

	iov_iter_bvec(&msg->msg_iter, ITER_SOURCE, xdr->bvec, nsegs, span);
	return xprt_sendmsg(sock, msg, skip);
}
182 
183 /* Common case:
184  *  - stream transport
185  *  - sending from byte 0 of the message
186  *  - the message is wholly contained in @xdr's head iovec
187  */
xprt_send_rm_and_kvec(struct socket * sock,struct msghdr * msg,rpc_fraghdr marker,struct kvec * vec,size_t base)188 static int xprt_send_rm_and_kvec(struct socket *sock, struct msghdr *msg,
189 				 rpc_fraghdr marker, struct kvec *vec,
190 				 size_t base)
191 {
192 	struct kvec iov[2] = {
193 		[0] = {
194 			.iov_base	= &marker,
195 			.iov_len	= sizeof(marker)
196 		},
197 		[1] = *vec,
198 	};
199 	size_t len = iov[0].iov_len + iov[1].iov_len;
200 
201 	iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, iov, 2, len);
202 	return xprt_sendmsg(sock, msg, base);
203 }
204 
/**
 * xprt_sock_sendmsg - write an xdr_buf directly to a socket
 * @sock: open socket to send on
 * @msg: socket message metadata
 * @xdr: xdr_buf containing this request
 * @base: starting position in the buffer
 * @marker: stream record marker field
 * @sent_p: return the total number of bytes successfully queued for sending
 *
 * The data is treated as one sequence made of:
 *   - the record marker (only when @marker is non-zero),
 *   - @xdr's head,
 *   - its page data,
 *   - its tail.
 * Sending starts @base bytes into that sequence. MSG_MORE is set in @msg
 * while more segments remain and is cleared before the last segment is sent.
 *
 * Return values:
 *   On success, returns zero and fills in @sent_p. A short write by the
 *   socket also returns zero; @sent_p then reports fewer bytes than were
 *   requested.
 *   %-ENOTSOCK if @sock is NULL.
 *   Otherwise, the negative error returned by the socket layer. In that
 *   case @sent_p holds the bytes queued by earlier segments of this call.
 */
int xprt_sock_sendmsg(struct socket *sock, struct msghdr *msg,
		      struct xdr_buf *xdr, unsigned int base,
		      rpc_fraghdr marker, unsigned int *sent_p)
{
	unsigned int rmsize = marker ? sizeof(marker) : 0;
	/* Bytes still to be sent, counted from @base. */
	unsigned int remainder = rmsize + xdr->len - base;
	unsigned int want;
	int err = 0;

	*sent_p = 0;

	if (unlikely(!sock))
		return -ENOTSOCK;

	msg->msg_flags |= MSG_MORE;
	/* First segment: the record marker (if any) plus the head iovec. */
	want = xdr->head[0].iov_len + rmsize;
	if (base < want) {
		unsigned int len = want - base;

		remainder -= len;
		if (remainder == 0)
			msg->msg_flags &= ~MSG_MORE;
		if (rmsize)
			err = xprt_send_rm_and_kvec(sock, msg, marker,
						    &xdr->head[0], base);
		else
			err = xprt_send_kvec(sock, msg, &xdr->head[0], base);
		/*
		 * Stop if everything is sent, the send failed, or it was short.
		 * A negative err also compares unequal to the unsigned len.
		 */
		if (remainder == 0 || err != len)
			goto out;
		*sent_p += err;
		base = 0;
	} else {
		/* @base lies past the first segment; rebase it onto the pages. */
		base -= want;
	}

	if (base < xdr->page_len) {
		unsigned int len = xdr->page_len - base;

		remainder -= len;
		if (remainder == 0)
			msg->msg_flags &= ~MSG_MORE;
		err = xprt_send_pagedata(sock, msg, xdr, base);
		if (remainder == 0 || err != len)
			goto out;
		*sent_p += err;
		base = 0;
	} else {
		/* @base lies past the page data; rebase it onto the tail. */
		base -= xdr->page_len;
	}

	/* Nothing of the tail lies at or beyond @base: nothing more to send. */
	if (base >= xdr->tail[0].iov_len)
		return 0;
	msg->msg_flags &= ~MSG_MORE;
	err = xprt_send_kvec(sock, msg, &xdr->tail[0], base);
out:
	/* A positive err is the byte count of the last segment sent. */
	if (err > 0) {
		*sent_p += err;
		err = 0;
	}
	return err;
}
279