xref: /linux/net/xfrm/espintcp.c (revision 15a1fbdcfb519c2bd291ed01c6c94e0b89537a77)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <net/tcp.h>
3 #include <net/strparser.h>
4 #include <net/xfrm.h>
5 #include <net/esp.h>
6 #include <net/espintcp.h>
7 #include <linux/skmsg.h>
8 #include <net/inet_common.h>
9 
10 static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
11 			  struct sock *sk)
12 {
13 	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
14 	    !sk_rmem_schedule(sk, skb, skb->truesize)) {
15 		kfree_skb(skb);
16 		return;
17 	}
18 
19 	skb_set_owner_r(skb, sk);
20 
21 	memset(skb->cb, 0, sizeof(skb->cb));
22 	skb_queue_tail(&ctx->ike_queue, skb);
23 	ctx->saved_data_ready(sk);
24 }
25 
26 static void handle_esp(struct sk_buff *skb, struct sock *sk)
27 {
28 	skb_reset_transport_header(skb);
29 	memset(skb->cb, 0, sizeof(skb->cb));
30 
31 	rcu_read_lock();
32 	skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
33 	local_bh_disable();
34 	xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
35 	local_bh_enable();
36 	rcu_read_unlock();
37 }
38 
39 static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
40 {
41 	struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
42 						strp);
43 	struct strp_msg *rxm = strp_msg(skb);
44 	u32 nonesp_marker;
45 	int err;
46 
47 	err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
48 			    sizeof(nonesp_marker));
49 	if (err < 0) {
50 		kfree_skb(skb);
51 		return;
52 	}
53 
54 	/* remove header, leave non-ESP marker/SPI */
55 	if (!__pskb_pull(skb, rxm->offset + 2)) {
56 		kfree_skb(skb);
57 		return;
58 	}
59 
60 	if (pskb_trim(skb, rxm->full_len - 2) != 0) {
61 		kfree_skb(skb);
62 		return;
63 	}
64 
65 	if (nonesp_marker == 0)
66 		handle_nonesp(ctx, skb, strp->sk);
67 	else
68 		handle_esp(skb, strp->sk);
69 }
70 
71 static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
72 {
73 	struct strp_msg *rxm = strp_msg(skb);
74 	__be16 blen;
75 	u16 len;
76 	int err;
77 
78 	if (skb->len < rxm->offset + 2)
79 		return 0;
80 
81 	err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
82 	if (err < 0)
83 		return err;
84 
85 	len = be16_to_cpu(blen);
86 	if (len < 6)
87 		return -EINVAL;
88 
89 	return len;
90 }
91 
92 static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
93 			    int nonblock, int flags, int *addr_len)
94 {
95 	struct espintcp_ctx *ctx = espintcp_getctx(sk);
96 	struct sk_buff *skb;
97 	int err = 0;
98 	int copied;
99 	int off = 0;
100 
101 	flags |= nonblock ? MSG_DONTWAIT : 0;
102 
103 	skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, NULL, &off, &err);
104 	if (!skb)
105 		return err;
106 
107 	copied = len;
108 	if (copied > skb->len)
109 		copied = skb->len;
110 	else if (copied < skb->len)
111 		msg->msg_flags |= MSG_TRUNC;
112 
113 	err = skb_copy_datagram_msg(skb, 0, msg, copied);
114 	if (unlikely(err)) {
115 		kfree_skb(skb);
116 		return err;
117 	}
118 
119 	if (flags & MSG_TRUNC)
120 		copied = skb->len;
121 	kfree_skb(skb);
122 	return copied;
123 }
124 
125 int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
126 {
127 	struct espintcp_ctx *ctx = espintcp_getctx(sk);
128 
129 	if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
130 		return -ENOBUFS;
131 
132 	__skb_queue_tail(&ctx->out_queue, skb);
133 
134 	return 0;
135 }
136 EXPORT_SYMBOL_GPL(espintcp_queue_out);
137 
138 /* espintcp length field is 2B and length includes the length field's size */
139 #define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
140 
141 static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
142 				   int flags)
143 {
144 	do {
145 		int ret;
146 
147 		ret = skb_send_sock_locked(sk, emsg->skb,
148 					   emsg->offset, emsg->len);
149 		if (ret < 0)
150 			return ret;
151 
152 		emsg->len -= ret;
153 		emsg->offset += ret;
154 	} while (emsg->len > 0);
155 
156 	kfree_skb(emsg->skb);
157 	memset(emsg, 0, sizeof(*emsg));
158 
159 	return 0;
160 }
161 
162 static int espintcp_sendskmsg_locked(struct sock *sk,
163 				     struct espintcp_msg *emsg, int flags)
164 {
165 	struct sk_msg *skmsg = &emsg->skmsg;
166 	struct scatterlist *sg;
167 	int done = 0;
168 	int ret;
169 
170 	flags |= MSG_SENDPAGE_NOTLAST;
171 	sg = &skmsg->sg.data[skmsg->sg.start];
172 	do {
173 		size_t size = sg->length - emsg->offset;
174 		int offset = sg->offset + emsg->offset;
175 		struct page *p;
176 
177 		emsg->offset = 0;
178 
179 		if (sg_is_last(sg))
180 			flags &= ~MSG_SENDPAGE_NOTLAST;
181 
182 		p = sg_page(sg);
183 retry:
184 		ret = do_tcp_sendpages(sk, p, offset, size, flags);
185 		if (ret < 0) {
186 			emsg->offset = offset - sg->offset;
187 			skmsg->sg.start += done;
188 			return ret;
189 		}
190 
191 		if (ret != size) {
192 			offset += ret;
193 			size -= ret;
194 			goto retry;
195 		}
196 
197 		done++;
198 		put_page(p);
199 		sk_mem_uncharge(sk, sg->length);
200 		sg = sg_next(sg);
201 	} while (sg);
202 
203 	memset(emsg, 0, sizeof(*emsg));
204 
205 	return 0;
206 }
207 
208 static int espintcp_push_msgs(struct sock *sk)
209 {
210 	struct espintcp_ctx *ctx = espintcp_getctx(sk);
211 	struct espintcp_msg *emsg = &ctx->partial;
212 	int err;
213 
214 	if (!emsg->len)
215 		return 0;
216 
217 	if (ctx->tx_running)
218 		return -EAGAIN;
219 	ctx->tx_running = 1;
220 
221 	if (emsg->skb)
222 		err = espintcp_sendskb_locked(sk, emsg, 0);
223 	else
224 		err = espintcp_sendskmsg_locked(sk, emsg, 0);
225 	if (err == -EAGAIN) {
226 		ctx->tx_running = 0;
227 		return 0;
228 	}
229 	if (!err)
230 		memset(emsg, 0, sizeof(*emsg));
231 
232 	ctx->tx_running = 0;
233 
234 	return err;
235 }
236 
237 int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
238 {
239 	struct espintcp_ctx *ctx = espintcp_getctx(sk);
240 	struct espintcp_msg *emsg = &ctx->partial;
241 	unsigned int len;
242 	int offset;
243 
244 	if (sk->sk_state != TCP_ESTABLISHED) {
245 		kfree_skb(skb);
246 		return -ECONNRESET;
247 	}
248 
249 	offset = skb_transport_offset(skb);
250 	len = skb->len - offset;
251 
252 	espintcp_push_msgs(sk);
253 
254 	if (emsg->len) {
255 		kfree_skb(skb);
256 		return -ENOBUFS;
257 	}
258 
259 	skb_set_owner_w(skb, sk);
260 
261 	emsg->offset = offset;
262 	emsg->len = len;
263 	emsg->skb = skb;
264 
265 	espintcp_push_msgs(sk);
266 
267 	return 0;
268 }
269 EXPORT_SYMBOL_GPL(espintcp_push_skb);
270 
271 static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
272 {
273 	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
274 	struct espintcp_ctx *ctx = espintcp_getctx(sk);
275 	struct espintcp_msg *emsg = &ctx->partial;
276 	struct iov_iter pfx_iter;
277 	struct kvec pfx_iov = {};
278 	size_t msglen = size + 2;
279 	char buf[2] = {0};
280 	int err, end;
281 
282 	if (msg->msg_flags)
283 		return -EOPNOTSUPP;
284 
285 	if (size > MAX_ESPINTCP_MSG)
286 		return -EMSGSIZE;
287 
288 	if (msg->msg_controllen)
289 		return -EOPNOTSUPP;
290 
291 	lock_sock(sk);
292 
293 	err = espintcp_push_msgs(sk);
294 	if (err < 0) {
295 		err = -ENOBUFS;
296 		goto unlock;
297 	}
298 
299 	sk_msg_init(&emsg->skmsg);
300 	while (1) {
301 		/* only -ENOMEM is possible since we don't coalesce */
302 		err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
303 		if (!err)
304 			break;
305 
306 		err = sk_stream_wait_memory(sk, &timeo);
307 		if (err)
308 			goto fail;
309 	}
310 
311 	*((__be16 *)buf) = cpu_to_be16(msglen);
312 	pfx_iov.iov_base = buf;
313 	pfx_iov.iov_len = sizeof(buf);
314 	iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);
315 
316 	err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
317 				       pfx_iov.iov_len);
318 	if (err < 0)
319 		goto fail;
320 
321 	err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
322 	if (err < 0)
323 		goto fail;
324 
325 	end = emsg->skmsg.sg.end;
326 	emsg->len = size;
327 	sk_msg_iter_var_prev(end);
328 	sg_mark_end(sk_msg_elem(&emsg->skmsg, end));
329 
330 	tcp_rate_check_app_limited(sk);
331 
332 	err = espintcp_push_msgs(sk);
333 	/* this message could be partially sent, keep it */
334 	if (err < 0)
335 		goto unlock;
336 	release_sock(sk);
337 
338 	return size;
339 
340 fail:
341 	sk_msg_free(sk, &emsg->skmsg);
342 	memset(emsg, 0, sizeof(*emsg));
343 unlock:
344 	release_sock(sk);
345 	return err;
346 }
347 
348 static struct proto espintcp_prot __ro_after_init;
349 static struct proto_ops espintcp_ops __ro_after_init;
350 
351 static void espintcp_data_ready(struct sock *sk)
352 {
353 	struct espintcp_ctx *ctx = espintcp_getctx(sk);
354 
355 	strp_data_ready(&ctx->strp);
356 }
357 
358 static void espintcp_tx_work(struct work_struct *work)
359 {
360 	struct espintcp_ctx *ctx = container_of(work,
361 						struct espintcp_ctx, work);
362 	struct sock *sk = ctx->strp.sk;
363 
364 	lock_sock(sk);
365 	if (!ctx->tx_running)
366 		espintcp_push_msgs(sk);
367 	release_sock(sk);
368 }
369 
370 static void espintcp_write_space(struct sock *sk)
371 {
372 	struct espintcp_ctx *ctx = espintcp_getctx(sk);
373 
374 	schedule_work(&ctx->work);
375 	ctx->saved_write_space(sk);
376 }
377 
378 static void espintcp_destruct(struct sock *sk)
379 {
380 	struct espintcp_ctx *ctx = espintcp_getctx(sk);
381 
382 	kfree(ctx);
383 }
384 
385 bool tcp_is_ulp_esp(struct sock *sk)
386 {
387 	return sk->sk_prot == &espintcp_prot;
388 }
389 EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
390 
391 static int espintcp_init_sk(struct sock *sk)
392 {
393 	struct inet_connection_sock *icsk = inet_csk(sk);
394 	struct strp_callbacks cb = {
395 		.rcv_msg = espintcp_rcv,
396 		.parse_msg = espintcp_parse,
397 	};
398 	struct espintcp_ctx *ctx;
399 	int err;
400 
401 	/* sockmap is not compatible with espintcp */
402 	if (sk->sk_user_data)
403 		return -EBUSY;
404 
405 	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
406 	if (!ctx)
407 		return -ENOMEM;
408 
409 	err = strp_init(&ctx->strp, sk, &cb);
410 	if (err)
411 		goto free;
412 
413 	__sk_dst_reset(sk);
414 
415 	strp_check_rcv(&ctx->strp);
416 	skb_queue_head_init(&ctx->ike_queue);
417 	skb_queue_head_init(&ctx->out_queue);
418 	sk->sk_prot = &espintcp_prot;
419 	sk->sk_socket->ops = &espintcp_ops;
420 	ctx->saved_data_ready = sk->sk_data_ready;
421 	ctx->saved_write_space = sk->sk_write_space;
422 	sk->sk_data_ready = espintcp_data_ready;
423 	sk->sk_write_space = espintcp_write_space;
424 	sk->sk_destruct = espintcp_destruct;
425 	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
426 	INIT_WORK(&ctx->work, espintcp_tx_work);
427 
428 	/* avoid using task_frag */
429 	sk->sk_allocation = GFP_ATOMIC;
430 
431 	return 0;
432 
433 free:
434 	kfree(ctx);
435 	return err;
436 }
437 
438 static void espintcp_release(struct sock *sk)
439 {
440 	struct espintcp_ctx *ctx = espintcp_getctx(sk);
441 	struct sk_buff_head queue;
442 	struct sk_buff *skb;
443 
444 	__skb_queue_head_init(&queue);
445 	skb_queue_splice_init(&ctx->out_queue, &queue);
446 
447 	while ((skb = __skb_dequeue(&queue)))
448 		espintcp_push_skb(sk, skb);
449 
450 	tcp_release_cb(sk);
451 }
452 
453 static void espintcp_close(struct sock *sk, long timeout)
454 {
455 	struct espintcp_ctx *ctx = espintcp_getctx(sk);
456 	struct espintcp_msg *emsg = &ctx->partial;
457 
458 	strp_stop(&ctx->strp);
459 
460 	sk->sk_prot = &tcp_prot;
461 	barrier();
462 
463 	cancel_work_sync(&ctx->work);
464 	strp_done(&ctx->strp);
465 
466 	skb_queue_purge(&ctx->out_queue);
467 	skb_queue_purge(&ctx->ike_queue);
468 
469 	if (emsg->len) {
470 		if (emsg->skb)
471 			kfree_skb(emsg->skb);
472 		else
473 			sk_msg_free(sk, &emsg->skmsg);
474 	}
475 
476 	tcp_close(sk, timeout);
477 }
478 
479 static __poll_t espintcp_poll(struct file *file, struct socket *sock,
480 			      poll_table *wait)
481 {
482 	__poll_t mask = datagram_poll(file, sock, wait);
483 	struct sock *sk = sock->sk;
484 	struct espintcp_ctx *ctx = espintcp_getctx(sk);
485 
486 	if (!skb_queue_empty(&ctx->ike_queue))
487 		mask |= EPOLLIN | EPOLLRDNORM;
488 
489 	return mask;
490 }
491 
492 static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
493 	.name = "espintcp",
494 	.owner = THIS_MODULE,
495 	.init = espintcp_init_sk,
496 };
497 
498 void __init espintcp_init(void)
499 {
500 	memcpy(&espintcp_prot, &tcp_prot, sizeof(tcp_prot));
501 	memcpy(&espintcp_ops, &inet_stream_ops, sizeof(inet_stream_ops));
502 	espintcp_prot.sendmsg = espintcp_sendmsg;
503 	espintcp_prot.recvmsg = espintcp_recvmsg;
504 	espintcp_prot.close = espintcp_close;
505 	espintcp_prot.release_cb = espintcp_release;
506 	espintcp_ops.poll = espintcp_poll;
507 
508 	tcp_register_ulp(&espintcp_ulp);
509 }
510