xref: /linux/tools/testing/selftests/net/ipsec.c (revision f3956ebb3bf06ab2266ad5ee2214aed46405810c)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * ipsec.c - Check xfrm on veth inside a net-ns.
4  * Copyright (c) 2018 Dmitry Safonov
5  */
6 
7 #define _GNU_SOURCE
8 
9 #include <arpa/inet.h>
10 #include <asm/types.h>
11 #include <errno.h>
12 #include <fcntl.h>
13 #include <limits.h>
14 #include <linux/limits.h>
15 #include <linux/netlink.h>
16 #include <linux/random.h>
17 #include <linux/rtnetlink.h>
18 #include <linux/veth.h>
19 #include <linux/xfrm.h>
20 #include <netinet/in.h>
21 #include <net/if.h>
22 #include <sched.h>
23 #include <stdbool.h>
24 #include <stdint.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <sys/mman.h>
29 #include <sys/socket.h>
30 #include <sys/stat.h>
31 #include <sys/syscall.h>
32 #include <sys/types.h>
33 #include <sys/wait.h>
34 #include <time.h>
35 #include <unistd.h>
36 
37 #include "../kselftest.h"
38 
/* Log a message prefixed with the caller's PID and source line number. */
#define printk(fmt, ...)						\
	ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)

/* printk() with ": %m" appended (errno text, assuming a glibc-style printf). */
#define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)

#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
/* Break the build (negative array size) when @condition is true. */
#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))

#define IPV4_STR_SZ	16	/* xxx.xxx.xxx.xxx is longest + \0 */
#define MAX_PAYLOAD	2048
#define XFRM_ALGO_KEY_BUF_SIZE	512
#define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
#define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
#define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */

/* /30 mask for one veth connection */
#define PREFIX_LEN	30
/*
 * Host numbers of the two ends of the nr-th /30 subnet. The argument is
 * parenthesized so expressions such as child_ip(i + 1) expand correctly.
 */
#define child_ip(nr)	(4*(nr) + 1)
#define grchild_ip(nr)	(4*(nr) + 2)

#define VETH_FMT	"ktst-%d"
#define VETH_LEN	12
61 
/* Network-namespace fds; -1 until init_namespaces() opens them. */
static int nsfd_parent	= -1;
static int nsfd_childa	= -1;
static int nsfd_childb	= -1;
static long page_size;

/*
 * ksft_cnt is static in kselftest, so isn't shared with children.
 * We have to send a test result back to parent and count there.
 * results_fd is a pipe with test feedback from children.
 */
static int results_fd[2];

/*
 * Ping parameters used by do_ping(): pause between pings (nanoseconds),
 * per-ping receive timeout (microseconds, passed on to SO_RCVTIMEO),
 * number of pings, and the minimum number that must succeed.
 * Internal linkage: nothing outside this file needs these symbols.
 */
static const unsigned int ping_delay_nsec	= 50 * 1000 * 1000;
static const unsigned int ping_timeout		= 300;
static const unsigned int ping_count		= 100;
static const unsigned int ping_success		= 80;
78 
/*
 * Fill @buflen bytes at @buf with rand() output.
 *
 * Each rand() result is copied with memcpy(), so @buf need not be
 * int-aligned (storing through a cast int pointer is undefined for a
 * misaligned buffer). A trailing partial word receives the leading bytes
 * of one more rand() value. The bytes produced are the same as storing
 * each rand() result in native byte order. @buflen == 0 consumes no
 * random numbers.
 */
static void randomize_buffer(void *buf, size_t buflen)
{
	unsigned char *p = buf;
	size_t words = buflen / sizeof(int);
	size_t leftover = buflen % sizeof(int);

	while (words--) {
		int r = rand();

		memcpy(p, &r, sizeof(r));
		p += sizeof(r);
	}

	if (leftover) {
		int r = rand();

		memcpy(p, &r, leftover);
	}
}
99 
100 static int unshare_open(void)
101 {
102 	const char *netns_path = "/proc/self/ns/net";
103 	int fd;
104 
105 	if (unshare(CLONE_NEWNET) != 0) {
106 		pr_err("unshare()");
107 		return -1;
108 	}
109 
110 	fd = open(netns_path, O_RDONLY);
111 	if (fd <= 0) {
112 		pr_err("open(%s)", netns_path);
113 		return -1;
114 	}
115 
116 	return fd;
117 }
118 
119 static int switch_ns(int fd)
120 {
121 	if (setns(fd, CLONE_NEWNET)) {
122 		pr_err("setns()");
123 		return -1;
124 	}
125 	return 0;
126 }
127 
128 /*
129  * Running the test inside a new parent net namespace to bother less
130  * about cleanup on error-path.
131  */
132 static int init_namespaces(void)
133 {
134 	nsfd_parent = unshare_open();
135 	if (nsfd_parent <= 0)
136 		return -1;
137 
138 	nsfd_childa = unshare_open();
139 	if (nsfd_childa <= 0)
140 		return -1;
141 
142 	if (switch_ns(nsfd_parent))
143 		return -1;
144 
145 	nsfd_childb = unshare_open();
146 	if (nsfd_childb <= 0)
147 		return -1;
148 
149 	if (switch_ns(nsfd_parent))
150 		return -1;
151 	return 0;
152 }
153 
/*
 * Lazily open a netlink socket of protocol @proto.
 *
 * If *@sock already holds an fd (> 0), only advance the sequence counter
 * *@seq_nr. Otherwise open a new socket into *@sock and seed *@seq_nr with
 * a random value. Returns 0 on success, -1 if socket() fails.
 */
static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
{
	if (*sock > 0) {
		/* Bump the counter itself, not the local copy of the pointer. */
		(*seq_nr)++;
		return 0;
	}

	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
	if (*sock <= 0) {
		pr_err("socket(AF_NETLINK)");
		return -1;
	}

	randomize_buffer(seq_nr, sizeof(*seq_nr));

	return 0;
}
171 
/*
 * Return where the next attribute of message @nh starts: the current end
 * of the message rounded up with RTA_ALIGN().
 */
static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
{
	char *tail = (char *)nh + RTA_ALIGN(nh->nlmsg_len);

	return (struct rtattr *)tail;
}
176 
/*
 * Append one attribute (@rta_type, @size bytes copied from @payload) at
 * the aligned end of message @nh, whose buffer is @req_sz bytes long, and
 * grow nlmsg_len to cover it.
 *
 * @payload may be NULL when @size is 0 (as rtattr_begin() does to open a
 * nested attribute). Returns 0, or -1 if the attribute would not fit.
 */
static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/*
	 * NLMSG_ALIGNTO == RTA_ALIGNTO, so RTA_ALIGN(nlmsg_len) is a valid
	 * attribute start even when the previous attribute left nlmsg_len
	 * unaligned.
	 */
	struct rtattr *attr = rtattr_hdr(nh);
	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);

	if (req_sz < nl_size) {
		printk("req buf is too small: %zu < %zu", req_sz, nl_size);
		return -1;
	}
	nh->nlmsg_len = nl_size;

	attr->rta_len = RTA_LENGTH(size);
	attr->rta_type = rta_type;
	/* memcpy() with a NULL source is undefined even for zero bytes. */
	if (size)
		memcpy(RTA_DATA(attr), payload, size);

	return 0;
}
196 
/*
 * Pack an attribute exactly like rtattr_pack() and hand back a pointer to
 * its header, so that the caller can later close it as a nest with
 * rtattr_end(). Returns NULL if the attribute does not fit.
 */
static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	struct rtattr *start = rtattr_hdr(nh);

	return rtattr_pack(nh, req_sz, rta_type, payload, size) ? NULL : start;
}
207 
/* Open a nested attribute @rta_type that carries no payload of its own. */
static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type)
{
	const void *no_payload = NULL;

	return _rtattr_begin(nh, req_sz, rta_type, no_payload, 0);
}
213 
/*
 * Close a nested attribute: set its length to span everything from @attr
 * up to the current end of message @nh.
 */
static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
{
	char *tail = (char *)nh + nh->nlmsg_len;
	char *head = (char *)attr;

	attr->rta_len = tail - head;
}
220 
/*
 * Append the VETH_INFO_PEER nest that describes the peer end of a veth:
 * an ifinfomsg header followed by IFLA_IFNAME (@peer) and IFLA_NET_NS_FD
 * (@ns). Returns 0, or -1 if the @req_sz-byte request buffer is too small.
 */
static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
		const char *peer, int ns)
{
	struct ifinfomsg peer_info = {
		.ifi_family	= AF_UNSPEC,
		.ifi_change	= 0xFFFFFFFF,
	};
	struct rtattr *nest;

	nest = _rtattr_begin(nh, req_sz, VETH_INFO_PEER,
			&peer_info, sizeof(peer_info));
	if (nest == NULL)
		return -1;

	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
		return -1;

	rtattr_end(nh, nest);
	return 0;
}
245 
/*
 * Read the kernel's reply to a request sent with NLM_F_ACK and interpret
 * it as an NLMSG_ERROR acknowledgement.
 *
 * Returns 0 for a positive ack, the (negative) error code carried in the
 * NLMSG_ERROR, or -1 if recv() fails, the reply is too short to hold the
 * fields we read, or it is not an NLMSG_ERROR at all.
 */
static int netlink_check_answer(int sock)
{
	struct nlmsgerror {
		struct nlmsghdr hdr;
		int error;
		struct nlmsghdr orig_msg;
	} answer;
	ssize_t len;

	len = recv(sock, &answer, sizeof(answer), 0);
	if (len < 0) {
		pr_err("recv()");
		return -1;
	}
	/* Only trust the fields that recv() actually filled in. */
	if ((size_t)len < sizeof(answer.hdr)) {
		printk("short netlink reply: %zd bytes", len);
		return -1;
	}
	if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
		return -1;
	}
	if ((size_t)len < sizeof(answer.hdr) + sizeof(answer.error)) {
		printk("short netlink reply: %zd bytes", len);
		return -1;
	}
	if (answer.error) {
		printk("NLMSG_ERROR: %d: %s",
			answer.error, strerror(-answer.error));
		return answer.error;
	}

	return 0;
}
268 
269 static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
270 		const char *peerb, int ns_b)
271 {
272 	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
273 	struct {
274 		struct nlmsghdr		nh;
275 		struct ifinfomsg	info;
276 		char			attrbuf[MAX_PAYLOAD];
277 	} req;
278 	const char veth_type[] = "veth";
279 	struct rtattr *link_info, *info_data;
280 
281 	memset(&req, 0, sizeof(req));
282 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
283 	req.nh.nlmsg_type	= RTM_NEWLINK;
284 	req.nh.nlmsg_flags	= flags;
285 	req.nh.nlmsg_seq	= seq;
286 	req.info.ifi_family	= AF_UNSPEC;
287 	req.info.ifi_change	= 0xFFFFFFFF;
288 
289 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
290 		return -1;
291 
292 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
293 		return -1;
294 
295 	link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
296 	if (!link_info)
297 		return -1;
298 
299 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
300 		return -1;
301 
302 	info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
303 	if (!info_data)
304 		return -1;
305 
306 	if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
307 		return -1;
308 
309 	rtattr_end(&req.nh, info_data);
310 	rtattr_end(&req.nh, link_info);
311 
312 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
313 		pr_err("send()");
314 		return -1;
315 	}
316 	return netlink_check_answer(sock);
317 }
318 
319 static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
320 		struct in_addr addr, uint8_t prefix)
321 {
322 	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
323 	struct {
324 		struct nlmsghdr		nh;
325 		struct ifaddrmsg	info;
326 		char			attrbuf[MAX_PAYLOAD];
327 	} req;
328 
329 	memset(&req, 0, sizeof(req));
330 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
331 	req.nh.nlmsg_type	= RTM_NEWADDR;
332 	req.nh.nlmsg_flags	= flags;
333 	req.nh.nlmsg_seq	= seq;
334 	req.info.ifa_family	= AF_INET;
335 	req.info.ifa_prefixlen	= prefix;
336 	req.info.ifa_index	= if_nametoindex(intf);
337 
338 #ifdef DEBUG
339 	{
340 		char addr_str[IPV4_STR_SZ] = {};
341 
342 		strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
343 
344 		printk("ip addr set %s", addr_str);
345 	}
346 #endif
347 
348 	if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
349 		return -1;
350 
351 	if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
352 		return -1;
353 
354 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
355 		pr_err("send()");
356 		return -1;
357 	}
358 	return netlink_check_answer(sock);
359 }
360 
361 static int link_set_up(int sock, uint32_t seq, const char *intf)
362 {
363 	struct {
364 		struct nlmsghdr		nh;
365 		struct ifinfomsg	info;
366 		char			attrbuf[MAX_PAYLOAD];
367 	} req;
368 
369 	memset(&req, 0, sizeof(req));
370 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
371 	req.nh.nlmsg_type	= RTM_NEWLINK;
372 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
373 	req.nh.nlmsg_seq	= seq;
374 	req.info.ifi_family	= AF_UNSPEC;
375 	req.info.ifi_change	= 0xFFFFFFFF;
376 	req.info.ifi_index	= if_nametoindex(intf);
377 	req.info.ifi_flags	= IFF_UP;
378 	req.info.ifi_change	= IFF_UP;
379 
380 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
381 		pr_err("send()");
382 		return -1;
383 	}
384 	return netlink_check_answer(sock);
385 }
386 
387 static int ip4_route_set(int sock, uint32_t seq, const char *intf,
388 		struct in_addr src, struct in_addr dst)
389 {
390 	struct {
391 		struct nlmsghdr	nh;
392 		struct rtmsg	rt;
393 		char		attrbuf[MAX_PAYLOAD];
394 	} req;
395 	unsigned int index = if_nametoindex(intf);
396 
397 	memset(&req, 0, sizeof(req));
398 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.rt));
399 	req.nh.nlmsg_type	= RTM_NEWROUTE;
400 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
401 	req.nh.nlmsg_seq	= seq;
402 	req.rt.rtm_family	= AF_INET;
403 	req.rt.rtm_dst_len	= 32;
404 	req.rt.rtm_table	= RT_TABLE_MAIN;
405 	req.rt.rtm_protocol	= RTPROT_BOOT;
406 	req.rt.rtm_scope	= RT_SCOPE_LINK;
407 	req.rt.rtm_type		= RTN_UNICAST;
408 
409 	if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
410 		return -1;
411 
412 	if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
413 		return -1;
414 
415 	if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
416 		return -1;
417 
418 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
419 		pr_err("send()");
420 		return -1;
421 	}
422 
423 	return netlink_check_answer(sock);
424 }
425 
426 static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
427 		struct in_addr tunsrc, struct in_addr tundst)
428 {
429 	if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
430 			tunsrc, PREFIX_LEN)) {
431 		printk("Failed to set ipv4 addr");
432 		return -1;
433 	}
434 
435 	if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
436 		printk("Failed to set ipv4 route");
437 		return -1;
438 	}
439 
440 	return 0;
441 }
442 
443 static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
444 {
445 	struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
446 	struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
447 	struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
448 	int route_sock = -1, ret = -1;
449 	uint32_t route_seq;
450 
451 	if (switch_ns(nsfd))
452 		return -1;
453 
454 	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
455 		printk("Failed to open netlink route socket in child");
456 		return -1;
457 	}
458 
459 	if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
460 		printk("Failed to set ipv4 addr");
461 		goto err;
462 	}
463 
464 	if (link_set_up(route_sock, route_seq++, veth)) {
465 		printk("Failed to bring up %s", veth);
466 		goto err;
467 	}
468 
469 	if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
470 		printk("Failed to add tunnel route on %s", veth);
471 		goto err;
472 	}
473 	ret = 0;
474 
475 err:
476 	close(route_sock);
477 	return ret;
478 }
479 
#define ALGO_LEN	64

/* Kind of xfrm test case; desc_name[] below is indexed by these values. */
enum desc_type {
	CREATE_TUNNEL	= 0,
	ALLOCATE_SPI,
	MONITOR_ACQUIRE,
	EXPIRE_STATE,
	EXPIRE_POLICY,
	SPDINFO_ATTRS,
};

/*
 * Printable name of each desc_type. Designated initializers keep each
 * string bound to its enum value; the trailing "" entry still follows
 * SPDINFO_ATTRS. The table is read-only and private to this file.
 */
static const char * const desc_name[] = {
	[CREATE_TUNNEL]		= "create tunnel",
	[ALLOCATE_SPI]		= "alloc spi",
	[MONITOR_ACQUIRE]	= "monitor acquire",
	[EXPIRE_STATE]		= "expire state",
	[EXPIRE_POLICY]		= "expire policy",
	[SPDINFO_ATTRS]		= "spdinfo attributes",
	""
};

/*
 * One xfrm test case: SA protocol plus NUL-terminated algorithm names
 * (auth, encryption, compression, AEAD; which ones are set depends on
 * proto, see xfrm_state_pack_algo()). icv_len is used for AEAD algos.
 */
struct xfrm_desc {
	enum desc_type	type;
	uint8_t		proto;
	char		a_algo[ALGO_LEN];
	char		e_algo[ALGO_LEN];
	char		c_algo[ALGO_LEN];
	char		ae_algo[ALGO_LEN];
	unsigned int	icv_len;
};

/* Message types carried in struct test_desc. */
enum msg_type {
	MSG_ACK		= 0,
	MSG_EXIT,
	MSG_PING,
	MSG_XFRM_PREPARE,
	MSG_XFRM_ADD,
	MSG_XFRM_DEL,
	MSG_XFRM_CLEANUP,
};

/*
 * Message exchanged with write_msg()/read_msg(); must stay within
 * PIPE_BUF so a pipe write is atomic (checked in write_msg()).
 */
struct test_desc {
	enum msg_type type;
	union {
		struct {
			in_addr_t reply_ip;
			unsigned int port;
		} ping;
		struct xfrm_desc xfrm_desc;
	} body;
};

/* Record written to results_fd by write_test_result(). */
struct test_result {
	struct xfrm_desc desc;
	unsigned int res;
};
534 
535 static void write_test_result(unsigned int res, struct xfrm_desc *d)
536 {
537 	struct test_result tr = {};
538 	ssize_t ret;
539 
540 	tr.desc = *d;
541 	tr.res = res;
542 
543 	ret = write(results_fd[1], &tr, sizeof(tr));
544 	if (ret != sizeof(tr))
545 		pr_err("Failed to write the result in pipe %zd", ret);
546 }
547 
548 static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
549 {
550 	ssize_t bytes = write(fd, msg, sizeof(*msg));
551 
552 	/* Make sure that write/read is atomic to a pipe */
553 	BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
554 
555 	if (bytes < 0) {
556 		pr_err("write()");
557 		if (exit_of_fail)
558 			exit(KSFT_FAIL);
559 	}
560 	if (bytes != sizeof(*msg)) {
561 		pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
562 		if (exit_of_fail)
563 			exit(KSFT_FAIL);
564 	}
565 }
566 
567 static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
568 {
569 	ssize_t bytes = read(fd, msg, sizeof(*msg));
570 
571 	if (bytes < 0) {
572 		pr_err("read()");
573 		if (exit_of_fail)
574 			exit(KSFT_FAIL);
575 	}
576 	if (bytes != sizeof(*msg)) {
577 		pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
578 		if (exit_of_fail)
579 			exit(KSFT_FAIL);
580 	}
581 }
582 
/*
 * Create the two UDP sockets used for pinging:
 *  sock[0] - bound to @listen_ip on a kernel-chosen port (returned in
 *            *@server_port, host byte order), with a receive timeout of
 *            @u_timeout microseconds;
 *  sock[1] - unbound socket used for sending.
 * Returns 0, or -1 after logging and closing anything already opened.
 */
static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
		unsigned int *server_port, int sock[2])
{
	struct sockaddr_in server;
	/* SO_RCVTIMEO rejects tv_usec >= 1s, so split larger timeouts. */
	struct timeval t = {
		.tv_sec = u_timeout / 1000000,
		.tv_usec = u_timeout % 1000000,
	};
	socklen_t s_len = sizeof(server);

	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[0] < 0) {
		pr_err("socket()");
		return -1;
	}

	/* Don't hand uninitialized sin_zero bytes to bind(). */
	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= 0;
	memcpy(&server.sin_addr.s_addr, &listen_ip, sizeof(struct in_addr));

	if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
		pr_err("bind()");
		goto err_close_server;
	}

	if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
		pr_err("getsockname()");
		goto err_close_server;
	}

	*server_port = ntohs(server.sin_port);

	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
		pr_err("setsockopt()");
		goto err_close_server;
	}

	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[1] < 0) {
		pr_err("socket()");
		goto err_close_server;
	}

	return 0;

err_close_server:
	close(sock[0]);
	return -1;
}
629 
/*
 * Initiating side of one ping: send the @buf_len bytes of @buf from
 * sock[1] to @dest_ip:@port, then wait on sock[0] for the echo and check
 * that it matches @buf exactly. A recv() failing with EAGAIN (timeout) is
 * not logged. Returns 0 on a correct echo, -1 otherwise.
 */
static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* Echo is received into a plain byte buffer of buf_len bytes. */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	} else if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	} else if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	return 0;
}
666 
/*
 * Echo side of one ping: wait on sock[0] for a datagram, check that it
 * equals the @buf_len bytes of @buf, then send @buf from sock[1] to
 * @dest_ip:@port. A recv() failing with EAGAIN (timeout) is not logged.
 * Returns 0 on success, -1 otherwise.
 */
static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* Incoming ping is received into a plain byte buffer of buf_len bytes. */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	}
	if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	}
	if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	return 0;
}
705 
706 typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
707 		char *buf, size_t buf_len);
708 static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
709 		bool init_side, int d_port, in_addr_t to, ping_f func)
710 {
711 	struct test_desc msg;
712 	unsigned int s_port, i, ping_succeeded = 0;
713 	int ping_sock[2];
714 	char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
715 
716 	if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
717 		printk("Failed to init ping");
718 		return -1;
719 	}
720 
721 	memset(&msg, 0, sizeof(msg));
722 	msg.type		= MSG_PING;
723 	msg.body.ping.port	= s_port;
724 	memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
725 
726 	write_msg(cmd_fd, &msg, 0);
727 	if (init_side) {
728 		/* The other end sends ip to ping */
729 		read_msg(cmd_fd, &msg, 0);
730 		if (msg.type != MSG_PING)
731 			return -1;
732 		to = msg.body.ping.reply_ip;
733 		d_port = msg.body.ping.port;
734 	}
735 
736 	for (i = 0; i < ping_count ; i++) {
737 		struct timespec sleep_time = {
738 			.tv_sec = 0,
739 			.tv_nsec = ping_delay_nsec,
740 		};
741 
742 		ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
743 		nanosleep(&sleep_time, 0);
744 	}
745 
746 	close(ping_sock[0]);
747 	close(ping_sock[1]);
748 
749 	strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
750 	strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
751 
752 	if (ping_succeeded < ping_success) {
753 		printk("ping (%s) %s->%s failed %u/%u times",
754 			init_side ? "send" : "reply", from_str, to_str,
755 			ping_count - ping_succeeded, ping_count);
756 		return -1;
757 	}
758 
759 #ifdef DEBUG
760 	printk("ping (%s) %s->%s succeeded %u/%u times",
761 		init_side ? "send" : "reply", from_str, to_str,
762 		ping_succeeded, ping_count);
763 #endif
764 
765 	return 0;
766 }
767 
768 static int xfrm_fill_key(char *name, char *buf,
769 		size_t buf_len, unsigned int *key_len)
770 {
771 	/* TODO: use set/map instead */
772 	if (strncmp(name, "digest_null", ALGO_LEN) == 0)
773 		*key_len = 0;
774 	else if (strncmp(name, "ecb(cipher_null)", ALGO_LEN) == 0)
775 		*key_len = 0;
776 	else if (strncmp(name, "cbc(des)", ALGO_LEN) == 0)
777 		*key_len = 64;
778 	else if (strncmp(name, "hmac(md5)", ALGO_LEN) == 0)
779 		*key_len = 128;
780 	else if (strncmp(name, "cmac(aes)", ALGO_LEN) == 0)
781 		*key_len = 128;
782 	else if (strncmp(name, "xcbc(aes)", ALGO_LEN) == 0)
783 		*key_len = 128;
784 	else if (strncmp(name, "cbc(cast5)", ALGO_LEN) == 0)
785 		*key_len = 128;
786 	else if (strncmp(name, "cbc(serpent)", ALGO_LEN) == 0)
787 		*key_len = 128;
788 	else if (strncmp(name, "hmac(sha1)", ALGO_LEN) == 0)
789 		*key_len = 160;
790 	else if (strncmp(name, "hmac(rmd160)", ALGO_LEN) == 0)
791 		*key_len = 160;
792 	else if (strncmp(name, "cbc(des3_ede)", ALGO_LEN) == 0)
793 		*key_len = 192;
794 	else if (strncmp(name, "hmac(sha256)", ALGO_LEN) == 0)
795 		*key_len = 256;
796 	else if (strncmp(name, "cbc(aes)", ALGO_LEN) == 0)
797 		*key_len = 256;
798 	else if (strncmp(name, "cbc(camellia)", ALGO_LEN) == 0)
799 		*key_len = 256;
800 	else if (strncmp(name, "cbc(twofish)", ALGO_LEN) == 0)
801 		*key_len = 256;
802 	else if (strncmp(name, "rfc3686(ctr(aes))", ALGO_LEN) == 0)
803 		*key_len = 288;
804 	else if (strncmp(name, "hmac(sha384)", ALGO_LEN) == 0)
805 		*key_len = 384;
806 	else if (strncmp(name, "cbc(blowfish)", ALGO_LEN) == 0)
807 		*key_len = 448;
808 	else if (strncmp(name, "hmac(sha512)", ALGO_LEN) == 0)
809 		*key_len = 512;
810 	else if (strncmp(name, "rfc4106(gcm(aes))-128", ALGO_LEN) == 0)
811 		*key_len = 160;
812 	else if (strncmp(name, "rfc4543(gcm(aes))-128", ALGO_LEN) == 0)
813 		*key_len = 160;
814 	else if (strncmp(name, "rfc4309(ccm(aes))-128", ALGO_LEN) == 0)
815 		*key_len = 152;
816 	else if (strncmp(name, "rfc4106(gcm(aes))-192", ALGO_LEN) == 0)
817 		*key_len = 224;
818 	else if (strncmp(name, "rfc4543(gcm(aes))-192", ALGO_LEN) == 0)
819 		*key_len = 224;
820 	else if (strncmp(name, "rfc4309(ccm(aes))-192", ALGO_LEN) == 0)
821 		*key_len = 216;
822 	else if (strncmp(name, "rfc4106(gcm(aes))-256", ALGO_LEN) == 0)
823 		*key_len = 288;
824 	else if (strncmp(name, "rfc4543(gcm(aes))-256", ALGO_LEN) == 0)
825 		*key_len = 288;
826 	else if (strncmp(name, "rfc4309(ccm(aes))-256", ALGO_LEN) == 0)
827 		*key_len = 280;
828 	else if (strncmp(name, "rfc7539(chacha20,poly1305)-128", ALGO_LEN) == 0)
829 		*key_len = 0;
830 
831 	if (*key_len > buf_len) {
832 		printk("Can't pack a key - too big for buffer");
833 		return -1;
834 	}
835 
836 	randomize_buffer(buf, *key_len);
837 
838 	return 0;
839 }
840 
841 static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
842 		struct xfrm_desc *desc)
843 {
844 	struct {
845 		union {
846 			struct xfrm_algo	alg;
847 			struct xfrm_algo_aead	aead;
848 			struct xfrm_algo_auth	auth;
849 		} u;
850 		char buf[XFRM_ALGO_KEY_BUF_SIZE];
851 	} alg = {};
852 	size_t alen, elen, clen, aelen;
853 	unsigned short type;
854 
855 	alen = strlen(desc->a_algo);
856 	elen = strlen(desc->e_algo);
857 	clen = strlen(desc->c_algo);
858 	aelen = strlen(desc->ae_algo);
859 
860 	/* Verify desc */
861 	switch (desc->proto) {
862 	case IPPROTO_AH:
863 		if (!alen || elen || clen || aelen) {
864 			printk("BUG: buggy ah desc");
865 			return -1;
866 		}
867 		strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
868 		if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
869 				sizeof(alg.buf), &alg.u.alg.alg_key_len))
870 			return -1;
871 		type = XFRMA_ALG_AUTH;
872 		break;
873 	case IPPROTO_COMP:
874 		if (!clen || elen || alen || aelen) {
875 			printk("BUG: buggy comp desc");
876 			return -1;
877 		}
878 		strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
879 		if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
880 				sizeof(alg.buf), &alg.u.alg.alg_key_len))
881 			return -1;
882 		type = XFRMA_ALG_COMP;
883 		break;
884 	case IPPROTO_ESP:
885 		if (!((alen && elen) ^ aelen) || clen) {
886 			printk("BUG: buggy esp desc");
887 			return -1;
888 		}
889 		if (aelen) {
890 			alg.u.aead.alg_icv_len = desc->icv_len;
891 			strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
892 			if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
893 						sizeof(alg.buf), &alg.u.aead.alg_key_len))
894 				return -1;
895 			type = XFRMA_ALG_AEAD;
896 		} else {
897 
898 			strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
899 			type = XFRMA_ALG_CRYPT;
900 			if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
901 						sizeof(alg.buf), &alg.u.alg.alg_key_len))
902 				return -1;
903 			if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
904 				return -1;
905 
906 			strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
907 			type = XFRMA_ALG_AUTH;
908 			if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
909 						sizeof(alg.buf), &alg.u.alg.alg_key_len))
910 				return -1;
911 		}
912 		break;
913 	default:
914 		printk("BUG: unknown proto in desc");
915 		return -1;
916 	}
917 
918 	if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
919 		return -1;
920 
921 	return 0;
922 }
923 
/* SPI derived from the host part of @src (inet_lnaof()), in network order. */
static inline uint32_t gen_spi(struct in_addr src)
{
	in_addr_t host_part = inet_lnaof(src);

	return htonl(host_part);
}
928 
929 static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
930 		struct in_addr src, struct in_addr dst,
931 		struct xfrm_desc *desc)
932 {
933 	struct {
934 		struct nlmsghdr		nh;
935 		struct xfrm_usersa_info	info;
936 		char			attrbuf[MAX_PAYLOAD];
937 	} req;
938 
939 	memset(&req, 0, sizeof(req));
940 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
941 	req.nh.nlmsg_type	= XFRM_MSG_NEWSA;
942 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
943 	req.nh.nlmsg_seq	= seq;
944 
945 	/* Fill selector. */
946 	memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
947 	memcpy(&req.info.sel.saddr, &src, sizeof(src));
948 	req.info.sel.family		= AF_INET;
949 	req.info.sel.prefixlen_d	= PREFIX_LEN;
950 	req.info.sel.prefixlen_s	= PREFIX_LEN;
951 
952 	/* Fill id */
953 	memcpy(&req.info.id.daddr, &dst, sizeof(dst));
954 	/* Note: zero-spi cannot be deleted */
955 	req.info.id.spi = spi;
956 	req.info.id.proto	= desc->proto;
957 
958 	memcpy(&req.info.saddr, &src, sizeof(src));
959 
960 	/* Fill lifteme_cfg */
961 	req.info.lft.soft_byte_limit	= XFRM_INF;
962 	req.info.lft.hard_byte_limit	= XFRM_INF;
963 	req.info.lft.soft_packet_limit	= XFRM_INF;
964 	req.info.lft.hard_packet_limit	= XFRM_INF;
965 
966 	req.info.family		= AF_INET;
967 	req.info.mode		= XFRM_MODE_TUNNEL;
968 
969 	if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
970 		return -1;
971 
972 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
973 		pr_err("send()");
974 		return -1;
975 	}
976 
977 	return netlink_check_answer(xfrm_sock);
978 }
979 
980 static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
981 		struct in_addr src, struct in_addr dst,
982 		struct xfrm_desc *desc)
983 {
984 	if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
985 		return false;
986 
987 	if (memcmp(&info->sel.saddr, &src, sizeof(src)))
988 		return false;
989 
990 	if (info->sel.family != AF_INET					||
991 			info->sel.prefixlen_d != PREFIX_LEN		||
992 			info->sel.prefixlen_s != PREFIX_LEN)
993 		return false;
994 
995 	if (info->id.spi != spi || info->id.proto != desc->proto)
996 		return false;
997 
998 	if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
999 		return false;
1000 
1001 	if (memcmp(&info->saddr, &src, sizeof(src)))
1002 		return false;
1003 
1004 	if (info->lft.soft_byte_limit != XFRM_INF			||
1005 			info->lft.hard_byte_limit != XFRM_INF		||
1006 			info->lft.soft_packet_limit != XFRM_INF		||
1007 			info->lft.hard_packet_limit != XFRM_INF)
1008 		return false;
1009 
1010 	if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
1011 		return false;
1012 
1013 	/* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
1014 
1015 	return true;
1016 }
1017 
1018 static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1019 		struct in_addr src, struct in_addr dst,
1020 		struct xfrm_desc *desc)
1021 {
1022 	struct {
1023 		struct nlmsghdr		nh;
1024 		char			attrbuf[MAX_PAYLOAD];
1025 	} req;
1026 	struct {
1027 		struct nlmsghdr		nh;
1028 		union {
1029 			struct xfrm_usersa_info	info;
1030 			int error;
1031 		};
1032 		char			attrbuf[MAX_PAYLOAD];
1033 	} answer;
1034 	struct xfrm_address_filter filter = {};
1035 	bool found = false;
1036 
1037 
1038 	memset(&req, 0, sizeof(req));
1039 	req.nh.nlmsg_len	= NLMSG_LENGTH(0);
1040 	req.nh.nlmsg_type	= XFRM_MSG_GETSA;
1041 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_DUMP;
1042 	req.nh.nlmsg_seq	= seq;
1043 
1044 	/*
1045 	 * Add dump filter by source address as there may be other tunnels
1046 	 * in this netns (if tests run in parallel).
1047 	 */
1048 	filter.family = AF_INET;
1049 	filter.splen = 0x1f;	/* 0xffffffff mask see addr_match() */
1050 	memcpy(&filter.saddr, &src, sizeof(src));
1051 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1052 				&filter, sizeof(filter)))
1053 		return -1;
1054 
1055 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1056 		pr_err("send()");
1057 		return -1;
1058 	}
1059 
1060 	while (1) {
1061 		if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1062 			pr_err("recv()");
1063 			return -1;
1064 		}
1065 		if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1066 			printk("NLMSG_ERROR: %d: %s",
1067 				answer.error, strerror(-answer.error));
1068 			return -1;
1069 		} else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1070 			if (found)
1071 				return 0;
1072 			printk("didn't find allocated xfrm state in dump");
1073 			return -1;
1074 		} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1075 			if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1076 				found = true;
1077 		}
1078 	}
1079 }
1080 
/*
 * Install the SA pair for @src <-> @dst (one state per direction, both
 * using SPI gen_spi(@src)) and confirm each appears in a GETSA dump.
 * Both dump checks run even if the first one fails.
 * @tunsrc/@tundst are accepted but not used here.
 * *@seq is advanced once per request issued.
 * Returns 0, or -1 after logging.
 */
static int xfrm_set(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst,
		struct xfrm_desc *desc)
{
	uint32_t spi = gen_spi(src);
	int bad;

	if (xfrm_state_add(xfrm_sock, (*seq)++, spi, src, dst, desc) ||
	    xfrm_state_add(xfrm_sock, (*seq)++, spi, dst, src, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	bad = xfrm_state_check(xfrm_sock, (*seq)++, spi, src, dst, desc);
	bad |= xfrm_state_check(xfrm_sock, (*seq)++, spi, dst, src, desc);
	if (bad) {
		printk("Failed to check xfrm state");
		return -1;
	}

	return 0;
}
1110 
1111 static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1112 		struct in_addr src, struct in_addr dst, uint8_t dir,
1113 		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1114 {
1115 	struct {
1116 		struct nlmsghdr			nh;
1117 		struct xfrm_userpolicy_info	info;
1118 		char				attrbuf[MAX_PAYLOAD];
1119 	} req;
1120 	struct xfrm_user_tmpl tmpl;
1121 
1122 	memset(&req, 0, sizeof(req));
1123 	memset(&tmpl, 0, sizeof(tmpl));
1124 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
1125 	req.nh.nlmsg_type	= XFRM_MSG_NEWPOLICY;
1126 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1127 	req.nh.nlmsg_seq	= seq;
1128 
1129 	/* Fill selector. */
1130 	memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1131 	memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1132 	req.info.sel.family		= AF_INET;
1133 	req.info.sel.prefixlen_d	= PREFIX_LEN;
1134 	req.info.sel.prefixlen_s	= PREFIX_LEN;
1135 
1136 	/* Fill lifteme_cfg */
1137 	req.info.lft.soft_byte_limit	= XFRM_INF;
1138 	req.info.lft.hard_byte_limit	= XFRM_INF;
1139 	req.info.lft.soft_packet_limit	= XFRM_INF;
1140 	req.info.lft.hard_packet_limit	= XFRM_INF;
1141 
1142 	req.info.dir = dir;
1143 
1144 	/* Fill tmpl */
1145 	memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1146 	/* Note: zero-spi cannot be deleted */
1147 	tmpl.id.spi = spi;
1148 	tmpl.id.proto	= proto;
1149 	tmpl.family	= AF_INET;
1150 	memcpy(&tmpl.saddr, &src, sizeof(src));
1151 	tmpl.mode	= XFRM_MODE_TUNNEL;
1152 	tmpl.aalgos = (~(uint32_t)0);
1153 	tmpl.ealgos = (~(uint32_t)0);
1154 	tmpl.calgos = (~(uint32_t)0);
1155 
1156 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1157 		return -1;
1158 
1159 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1160 		pr_err("send()");
1161 		return -1;
1162 	}
1163 
1164 	return netlink_check_answer(xfrm_sock);
1165 }
1166 
/*
 * Install the outbound (src -> dst) and inbound (dst -> src) policies for
 * one tunnel; both templates carry the SPI derived from @src. The inbound
 * policy is only attempted once the outbound one succeeded.
 *
 * Returns 0 on success, -1 if either policy could not be added.
 */
static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int failed;

	failed = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
				 XFRM_POLICY_OUT, tunsrc, tundst, proto);
	if (!failed)
		failed = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src,
					 XFRM_POLICY_IN, tunsrc, tundst, proto);

	if (failed) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	return 0;
}
1185 
1186 static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1187 		struct in_addr src, struct in_addr dst, uint8_t dir,
1188 		struct in_addr tunsrc, struct in_addr tundst)
1189 {
1190 	struct {
1191 		struct nlmsghdr			nh;
1192 		struct xfrm_userpolicy_id	id;
1193 		char				attrbuf[MAX_PAYLOAD];
1194 	} req;
1195 
1196 	memset(&req, 0, sizeof(req));
1197 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1198 	req.nh.nlmsg_type	= XFRM_MSG_DELPOLICY;
1199 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1200 	req.nh.nlmsg_seq	= seq;
1201 
1202 	/* Fill id */
1203 	memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1204 	memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1205 	req.id.sel.family		= AF_INET;
1206 	req.id.sel.prefixlen_d		= PREFIX_LEN;
1207 	req.id.sel.prefixlen_s		= PREFIX_LEN;
1208 	req.id.dir = dir;
1209 
1210 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1211 		pr_err("send()");
1212 		return -1;
1213 	}
1214 
1215 	return netlink_check_answer(xfrm_sock);
1216 }
1217 
/*
 * Remove the outbound (src -> dst) and inbound (dst -> src) policies that
 * xfrm_prepare() installs, stopping at the first failure.
 *
 * Returns 0 on success, -1 on failure.
 */
static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst)
{
	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst)) {
		printk("Failed to delete xfrm policy");
		return -1;
	}

	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
				XFRM_POLICY_IN, tunsrc, tundst)) {
		printk("Failed to delete xfrm policy");
		return -1;
	}

	return 0;
}
1236 
1237 static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1238 		struct in_addr src, struct in_addr dst, uint8_t proto)
1239 {
1240 	struct {
1241 		struct nlmsghdr		nh;
1242 		struct xfrm_usersa_id	id;
1243 		char			attrbuf[MAX_PAYLOAD];
1244 	} req;
1245 	xfrm_address_t saddr = {};
1246 
1247 	memset(&req, 0, sizeof(req));
1248 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1249 	req.nh.nlmsg_type	= XFRM_MSG_DELSA;
1250 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1251 	req.nh.nlmsg_seq	= seq;
1252 
1253 	memcpy(&req.id.daddr, &dst, sizeof(dst));
1254 	req.id.family		= AF_INET;
1255 	req.id.proto		= proto;
1256 	/* Note: zero-spi cannot be deleted */
1257 	req.id.spi = spi;
1258 
1259 	memcpy(&saddr, &src, sizeof(src));
1260 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1261 		return -1;
1262 
1263 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1264 		pr_err("send()");
1265 		return -1;
1266 	}
1267 
1268 	return netlink_check_answer(xfrm_sock);
1269 }
1270 
/*
 * Remove both xfrm states that xfrm_set() adds (src -> dst, then
 * dst -> src), each keyed by the SPI derived from @src. The second removal
 * is only attempted if the first succeeded. @tunsrc and @tundst are unused.
 *
 * Returns 0 on success, -1 on the first failure.
 */
static int xfrm_delete(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	if (!xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), src, dst, proto) &&
	    !xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), dst, src, proto))
		return 0;

	printk("Failed to remove xfrm state");
	return -1;
}
1287 
1288 static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1289 		uint32_t spi, uint8_t proto)
1290 {
1291 	struct {
1292 		struct nlmsghdr			nh;
1293 		struct xfrm_userspi_info	spi;
1294 	} req;
1295 	struct {
1296 		struct nlmsghdr			nh;
1297 		union {
1298 			struct xfrm_usersa_info	info;
1299 			int error;
1300 		};
1301 	} answer;
1302 
1303 	memset(&req, 0, sizeof(req));
1304 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.spi));
1305 	req.nh.nlmsg_type	= XFRM_MSG_ALLOCSPI;
1306 	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1307 	req.nh.nlmsg_seq	= (*seq)++;
1308 
1309 	req.spi.info.family	= AF_INET;
1310 	req.spi.min		= spi;
1311 	req.spi.max		= spi;
1312 	req.spi.info.id.proto	= proto;
1313 
1314 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1315 		pr_err("send()");
1316 		return KSFT_FAIL;
1317 	}
1318 
1319 	if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1320 		pr_err("recv()");
1321 		return KSFT_FAIL;
1322 	} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1323 		uint32_t new_spi = htonl(answer.info.id.spi);
1324 
1325 		if (new_spi != spi) {
1326 			printk("allocated spi is different from requested: %#x != %#x",
1327 					new_spi, spi);
1328 			return KSFT_FAIL;
1329 		}
1330 		return KSFT_PASS;
1331 	} else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1332 		printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1333 		return KSFT_FAIL;
1334 	}
1335 
1336 	printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1337 	return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1338 }
1339 
/*
 * Open a netlink socket of protocol @proto with netlink_sock() (which is
 * also handed @seq), then bind it to the multicast group bitmask @groups.
 * It then sanity-checks the bound address with getsockname().
 *
 * Returns 0 with *@sock open and bound. Returns -1 on failure. If the
 * failure happens after the socket was opened, the socket is closed and
 * *@sock is reset to -1, so callers never hold a closed descriptor.
 */
static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
{
	struct sockaddr_nl snl = {};
	socklen_t addr_len;

	snl.nl_family = AF_NETLINK;
	snl.nl_groups = groups;

	if (netlink_sock(sock, seq, proto)) {
		printk("Failed to open xfrm netlink socket");
		return -1;
	}

	if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
		pr_err("bind()");
		goto out_close;
	}

	addr_len = sizeof(snl);
	if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
		pr_err("getsockname()");
		goto out_close;
	}
	if (addr_len != sizeof(snl)) {
		/* socklen_t is unsigned: print it as such. */
		printk("Wrong address length %u", (unsigned int)addr_len);
		goto out_close;
	}
	if (snl.nl_family != AF_NETLINK) {
		printk("Wrong address family %d", snl.nl_family);
		goto out_close;
	}
	return 0;

out_close:
	close(*sock);
	*sock = -1;
	return -1;
}
1378 
1379 static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
1380 {
1381 	struct {
1382 		struct nlmsghdr nh;
1383 		union {
1384 			struct xfrm_user_acquire acq;
1385 			int error;
1386 		};
1387 		char attrbuf[MAX_PAYLOAD];
1388 	} req;
1389 	struct xfrm_user_tmpl xfrm_tmpl = {};
1390 	int xfrm_listen = -1, ret = KSFT_FAIL;
1391 	uint32_t seq_listen;
1392 
1393 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
1394 		return KSFT_FAIL;
1395 
1396 	memset(&req, 0, sizeof(req));
1397 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.acq));
1398 	req.nh.nlmsg_type	= XFRM_MSG_ACQUIRE;
1399 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1400 	req.nh.nlmsg_seq	= (*seq)++;
1401 
1402 	req.acq.policy.sel.family	= AF_INET;
1403 	req.acq.aalgos	= 0xfeed;
1404 	req.acq.ealgos	= 0xbaad;
1405 	req.acq.calgos	= 0xbabe;
1406 
1407 	xfrm_tmpl.family = AF_INET;
1408 	xfrm_tmpl.id.proto = IPPROTO_ESP;
1409 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
1410 		goto out_close;
1411 
1412 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1413 		pr_err("send()");
1414 		goto out_close;
1415 	}
1416 
1417 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1418 		pr_err("recv()");
1419 		goto out_close;
1420 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1421 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1422 		goto out_close;
1423 	}
1424 
1425 	if (req.error) {
1426 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1427 		ret = req.error;
1428 		goto out_close;
1429 	}
1430 
1431 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1432 		pr_err("recv()");
1433 		goto out_close;
1434 	}
1435 
1436 	if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
1437 			|| req.acq.calgos != 0xbabe) {
1438 		printk("xfrm_user_acquire has changed  %x %x %x",
1439 				req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
1440 		goto out_close;
1441 	}
1442 
1443 	ret = KSFT_PASS;
1444 out_close:
1445 	close(xfrm_listen);
1446 	return ret;
1447 }
1448 
1449 static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
1450 		unsigned int nr, struct xfrm_desc *desc)
1451 {
1452 	struct {
1453 		struct nlmsghdr nh;
1454 		union {
1455 			struct xfrm_user_expire expire;
1456 			int error;
1457 		};
1458 	} req;
1459 	struct in_addr src, dst;
1460 	int xfrm_listen = -1, ret = KSFT_FAIL;
1461 	uint32_t seq_listen;
1462 
1463 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1464 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1465 
1466 	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
1467 		printk("Failed to add xfrm state");
1468 		return KSFT_FAIL;
1469 	}
1470 
1471 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1472 		return KSFT_FAIL;
1473 
1474 	memset(&req, 0, sizeof(req));
1475 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1476 	req.nh.nlmsg_type	= XFRM_MSG_EXPIRE;
1477 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1478 	req.nh.nlmsg_seq	= (*seq)++;
1479 
1480 	memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
1481 	req.expire.state.id.spi		= gen_spi(src);
1482 	req.expire.state.id.proto	= desc->proto;
1483 	req.expire.state.family		= AF_INET;
1484 	req.expire.hard			= 0xff;
1485 
1486 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1487 		pr_err("send()");
1488 		goto out_close;
1489 	}
1490 
1491 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1492 		pr_err("recv()");
1493 		goto out_close;
1494 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1495 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1496 		goto out_close;
1497 	}
1498 
1499 	if (req.error) {
1500 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1501 		ret = req.error;
1502 		goto out_close;
1503 	}
1504 
1505 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1506 		pr_err("recv()");
1507 		goto out_close;
1508 	}
1509 
1510 	if (req.expire.hard != 0x1) {
1511 		printk("expire.hard is not set: %x", req.expire.hard);
1512 		goto out_close;
1513 	}
1514 
1515 	ret = KSFT_PASS;
1516 out_close:
1517 	close(xfrm_listen);
1518 	return ret;
1519 }
1520 
1521 static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
1522 		unsigned int nr, struct xfrm_desc *desc)
1523 {
1524 	struct {
1525 		struct nlmsghdr nh;
1526 		union {
1527 			struct xfrm_user_polexpire expire;
1528 			int error;
1529 		};
1530 	} req;
1531 	struct in_addr src, dst, tunsrc, tundst;
1532 	int xfrm_listen = -1, ret = KSFT_FAIL;
1533 	uint32_t seq_listen;
1534 
1535 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1536 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1537 	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1538 	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1539 
1540 	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1541 				XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
1542 		printk("Failed to add xfrm policy");
1543 		return KSFT_FAIL;
1544 	}
1545 
1546 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1547 		return KSFT_FAIL;
1548 
1549 	memset(&req, 0, sizeof(req));
1550 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1551 	req.nh.nlmsg_type	= XFRM_MSG_POLEXPIRE;
1552 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1553 	req.nh.nlmsg_seq	= (*seq)++;
1554 
1555 	/* Fill selector. */
1556 	memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
1557 	memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
1558 	req.expire.pol.sel.family	= AF_INET;
1559 	req.expire.pol.sel.prefixlen_d	= PREFIX_LEN;
1560 	req.expire.pol.sel.prefixlen_s	= PREFIX_LEN;
1561 	req.expire.pol.dir		= XFRM_POLICY_OUT;
1562 	req.expire.hard			= 0xff;
1563 
1564 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1565 		pr_err("send()");
1566 		goto out_close;
1567 	}
1568 
1569 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1570 		pr_err("recv()");
1571 		goto out_close;
1572 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1573 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1574 		goto out_close;
1575 	}
1576 
1577 	if (req.error) {
1578 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1579 		ret = req.error;
1580 		goto out_close;
1581 	}
1582 
1583 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1584 		pr_err("recv()");
1585 		goto out_close;
1586 	}
1587 
1588 	if (req.expire.hard != 0x1) {
1589 		printk("expire.hard is not set: %x", req.expire.hard);
1590 		goto out_close;
1591 	}
1592 
1593 	ret = KSFT_PASS;
1594 out_close:
1595 	close(xfrm_listen);
1596 	return ret;
1597 }
1598 
1599 static int xfrm_spdinfo_set_thresh(int xfrm_sock, uint32_t *seq,
1600 		unsigned thresh4_l, unsigned thresh4_r,
1601 		unsigned thresh6_l, unsigned thresh6_r,
1602 		bool add_bad_attr)
1603 
1604 {
1605 	struct {
1606 		struct nlmsghdr		nh;
1607 		union {
1608 			uint32_t	unused;
1609 			int		error;
1610 		};
1611 		char			attrbuf[MAX_PAYLOAD];
1612 	} req;
1613 	struct xfrmu_spdhthresh thresh;
1614 
1615 	memset(&req, 0, sizeof(req));
1616 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1617 	req.nh.nlmsg_type	= XFRM_MSG_NEWSPDINFO;
1618 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1619 	req.nh.nlmsg_seq	= (*seq)++;
1620 
1621 	thresh.lbits = thresh4_l;
1622 	thresh.rbits = thresh4_r;
1623 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV4_HTHRESH, &thresh, sizeof(thresh)))
1624 		return -1;
1625 
1626 	thresh.lbits = thresh6_l;
1627 	thresh.rbits = thresh6_r;
1628 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV6_HTHRESH, &thresh, sizeof(thresh)))
1629 		return -1;
1630 
1631 	if (add_bad_attr) {
1632 		BUILD_BUG_ON(XFRMA_IF_ID <= XFRMA_SPD_MAX + 1);
1633 		if (rtattr_pack(&req.nh, sizeof(req), XFRMA_IF_ID, NULL, 0)) {
1634 			pr_err("adding attribute failed: no space");
1635 			return -1;
1636 		}
1637 	}
1638 
1639 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1640 		pr_err("send()");
1641 		return -1;
1642 	}
1643 
1644 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1645 		pr_err("recv()");
1646 		return -1;
1647 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1648 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1649 		return -1;
1650 	}
1651 
1652 	if (req.error) {
1653 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1654 		return -1;
1655 	}
1656 
1657 	return 0;
1658 }
1659 
1660 static int xfrm_spdinfo_attrs(int xfrm_sock, uint32_t *seq)
1661 {
1662 	struct {
1663 		struct nlmsghdr			nh;
1664 		union {
1665 			uint32_t	unused;
1666 			int		error;
1667 		};
1668 		char			attrbuf[MAX_PAYLOAD];
1669 	} req;
1670 
1671 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 31, 120, 16, false)) {
1672 		pr_err("Can't set SPD HTHRESH");
1673 		return KSFT_FAIL;
1674 	}
1675 
1676 	memset(&req, 0, sizeof(req));
1677 
1678 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1679 	req.nh.nlmsg_type	= XFRM_MSG_GETSPDINFO;
1680 	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1681 	req.nh.nlmsg_seq	= (*seq)++;
1682 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1683 		pr_err("send()");
1684 		return KSFT_FAIL;
1685 	}
1686 
1687 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1688 		pr_err("recv()");
1689 		return KSFT_FAIL;
1690 	} else if (req.nh.nlmsg_type == XFRM_MSG_NEWSPDINFO) {
1691 		size_t len = NLMSG_PAYLOAD(&req.nh, sizeof(req.unused));
1692 		struct rtattr *attr = (void *)req.attrbuf;
1693 		int got_thresh = 0;
1694 
1695 		for (; RTA_OK(attr, len); attr = RTA_NEXT(attr, len)) {
1696 			if (attr->rta_type == XFRMA_SPD_IPV4_HTHRESH) {
1697 				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1698 
1699 				got_thresh++;
1700 				if (t->lbits != 32 || t->rbits != 31) {
1701 					pr_err("thresh differ: %u, %u",
1702 							t->lbits, t->rbits);
1703 					return KSFT_FAIL;
1704 				}
1705 			}
1706 			if (attr->rta_type == XFRMA_SPD_IPV6_HTHRESH) {
1707 				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1708 
1709 				got_thresh++;
1710 				if (t->lbits != 120 || t->rbits != 16) {
1711 					pr_err("thresh differ: %u, %u",
1712 							t->lbits, t->rbits);
1713 					return KSFT_FAIL;
1714 				}
1715 			}
1716 		}
1717 		if (got_thresh != 2) {
1718 			pr_err("only %d thresh returned by XFRM_MSG_GETSPDINFO", got_thresh);
1719 			return KSFT_FAIL;
1720 		}
1721 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1722 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1723 		return KSFT_FAIL;
1724 	} else {
1725 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1726 		return -1;
1727 	}
1728 
1729 	/* Restore the default */
1730 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, false)) {
1731 		pr_err("Can't restore SPD HTHRESH");
1732 		return KSFT_FAIL;
1733 	}
1734 
1735 	/*
1736 	 * At this moment xfrm uses nlmsg_parse_deprecated(), which
1737 	 * implies NL_VALIDATE_LIBERAL - ignoring attributes with
1738 	 * (type > maxtype). nla_parse_depricated_strict() would enforce
1739 	 * it. Or even stricter nla_parse().
1740 	 * Right now it's not expected to fail, but to be ignored.
1741 	 */
1742 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, true))
1743 		return KSFT_PASS;
1744 
1745 	return KSFT_PASS;
1746 }
1747 
1748 static int child_serv(int xfrm_sock, uint32_t *seq,
1749 		unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
1750 {
1751 	struct in_addr src, dst, tunsrc, tundst;
1752 	struct test_desc msg;
1753 	int ret = KSFT_FAIL;
1754 
1755 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1756 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1757 	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1758 	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1759 
1760 	/* UDP pinging without xfrm */
1761 	if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
1762 		printk("ping failed before setting xfrm");
1763 		return KSFT_FAIL;
1764 	}
1765 
1766 	memset(&msg, 0, sizeof(msg));
1767 	msg.type = MSG_XFRM_PREPARE;
1768 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1769 	write_msg(cmd_fd, &msg, 1);
1770 
1771 	if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1772 		printk("failed to prepare xfrm");
1773 		goto cleanup;
1774 	}
1775 
1776 	memset(&msg, 0, sizeof(msg));
1777 	msg.type = MSG_XFRM_ADD;
1778 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1779 	write_msg(cmd_fd, &msg, 1);
1780 	if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1781 		printk("failed to set xfrm");
1782 		goto delete;
1783 	}
1784 
1785 	/* UDP pinging with xfrm tunnel */
1786 	if (do_ping(cmd_fd, buf, page_size, tunsrc,
1787 				true, 0, 0, udp_ping_send)) {
1788 		printk("ping failed for xfrm");
1789 		goto delete;
1790 	}
1791 
1792 	ret = KSFT_PASS;
1793 delete:
1794 	/* xfrm delete */
1795 	memset(&msg, 0, sizeof(msg));
1796 	msg.type = MSG_XFRM_DEL;
1797 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1798 	write_msg(cmd_fd, &msg, 1);
1799 
1800 	if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1801 		printk("failed ping to remove xfrm");
1802 		ret = KSFT_FAIL;
1803 	}
1804 
1805 cleanup:
1806 	memset(&msg, 0, sizeof(msg));
1807 	msg.type = MSG_XFRM_CLEANUP;
1808 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1809 	write_msg(cmd_fd, &msg, 1);
1810 	if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1811 		printk("failed ping to cleanup xfrm");
1812 		ret = KSFT_FAIL;
1813 	}
1814 	return ret;
1815 }
1816 
1817 static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
1818 {
1819 	struct xfrm_desc desc;
1820 	struct test_desc msg;
1821 	int xfrm_sock = -1;
1822 	uint32_t seq;
1823 
1824 	if (switch_ns(nsfd_childa))
1825 		exit(KSFT_FAIL);
1826 
1827 	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1828 		printk("Failed to open xfrm netlink socket");
1829 		exit(KSFT_FAIL);
1830 	}
1831 
1832 	/* Check that seq sock is ready, just for sure. */
1833 	memset(&msg, 0, sizeof(msg));
1834 	msg.type = MSG_ACK;
1835 	write_msg(cmd_fd, &msg, 1);
1836 	read_msg(cmd_fd, &msg, 1);
1837 	if (msg.type != MSG_ACK) {
1838 		printk("Ack failed");
1839 		exit(KSFT_FAIL);
1840 	}
1841 
1842 	for (;;) {
1843 		ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
1844 		int ret;
1845 
1846 		if (received == 0) /* EOF */
1847 			break;
1848 
1849 		if (received != sizeof(desc)) {
1850 			pr_err("read() returned %zd", received);
1851 			exit(KSFT_FAIL);
1852 		}
1853 
1854 		switch (desc.type) {
1855 		case CREATE_TUNNEL:
1856 			ret = child_serv(xfrm_sock, &seq, nr,
1857 					 cmd_fd, buf, &desc);
1858 			break;
1859 		case ALLOCATE_SPI:
1860 			ret = xfrm_state_allocspi(xfrm_sock, &seq,
1861 						  -1, desc.proto);
1862 			break;
1863 		case MONITOR_ACQUIRE:
1864 			ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
1865 			break;
1866 		case EXPIRE_STATE:
1867 			ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
1868 			break;
1869 		case EXPIRE_POLICY:
1870 			ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
1871 			break;
1872 		case SPDINFO_ATTRS:
1873 			ret = xfrm_spdinfo_attrs(xfrm_sock, &seq);
1874 			break;
1875 		default:
1876 			printk("Unknown desc type %d", desc.type);
1877 			exit(KSFT_FAIL);
1878 		}
1879 		write_test_result(ret, &desc);
1880 	}
1881 
1882 	close(xfrm_sock);
1883 
1884 	msg.type = MSG_EXIT;
1885 	write_msg(cmd_fd, &msg, 1);
1886 	exit(KSFT_PASS);
1887 }
1888 
1889 static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
1890 		struct test_desc *msg, int xfrm_sock, uint32_t *seq)
1891 {
1892 	struct in_addr src, dst, tunsrc, tundst;
1893 	bool tun_reply;
1894 	struct xfrm_desc *desc = &msg->body.xfrm_desc;
1895 
1896 	src = inet_makeaddr(INADDR_B, grchild_ip(nr));
1897 	dst = inet_makeaddr(INADDR_B, child_ip(nr));
1898 	tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
1899 	tundst = inet_makeaddr(INADDR_A, child_ip(nr));
1900 
1901 	switch (msg->type) {
1902 	case MSG_EXIT:
1903 		exit(KSFT_PASS);
1904 	case MSG_ACK:
1905 		write_msg(cmd_fd, msg, 1);
1906 		break;
1907 	case MSG_PING:
1908 		tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
1909 		/* UDP pinging without xfrm */
1910 		if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
1911 				false, msg->body.ping.port,
1912 				msg->body.ping.reply_ip, udp_ping_reply)) {
1913 			printk("ping failed before setting xfrm");
1914 		}
1915 		break;
1916 	case MSG_XFRM_PREPARE:
1917 		if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
1918 					desc->proto)) {
1919 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1920 			printk("failed to prepare xfrm");
1921 		}
1922 		break;
1923 	case MSG_XFRM_ADD:
1924 		if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1925 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1926 			printk("failed to set xfrm");
1927 		}
1928 		break;
1929 	case MSG_XFRM_DEL:
1930 		if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
1931 					desc->proto)) {
1932 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1933 			printk("failed to remove xfrm");
1934 		}
1935 		break;
1936 	case MSG_XFRM_CLEANUP:
1937 		if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1938 			printk("failed to cleanup xfrm");
1939 		}
1940 		break;
1941 	default:
1942 		printk("got unknown msg type %d", msg->type);
1943 	}
1944 }
1945 
1946 static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1947 {
1948 	struct test_desc msg;
1949 	int xfrm_sock = -1;
1950 	uint32_t seq;
1951 
1952 	if (switch_ns(nsfd_childb))
1953 		exit(KSFT_FAIL);
1954 
1955 	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1956 		printk("Failed to open xfrm netlink socket");
1957 		exit(KSFT_FAIL);
1958 	}
1959 
1960 	do {
1961 		read_msg(cmd_fd, &msg, 1);
1962 		grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1963 	} while (1);
1964 
1965 	close(xfrm_sock);
1966 	exit(KSFT_FAIL);
1967 }
1968 
/*
 * Set up and fork the process pair for test slot @nr.
 *
 * init_child() is first called for both namespaces with the @veth name:
 * nsfd_childa with (child_ip(nr), grchild_ip(nr)), and nsfd_childb with the
 * two IPs swapped. Then the process forks:
 *  - The calling process switches back to nsfd_parent and returns the
 *    result of switch_ns().
 *  - The forked child does the following:
 *     1. Close test_desc_fd[1]; child_f() reads test_desc_fd[0].
 *     2. Map one MAP_SHARED anonymous page of page_size bytes and pass it
 *        to randomize_buffer().
 *     3. Create a SOCK_SEQPACKET socketpair and fork again. The child runs
 *        child_f() on cmd_sock[1] and the grandchild runs grand_child_f()
 *        on cmd_sock[0]. Both share the mapped page.
 *
 * NOTE(review): every error path after the first fork() returns -1 from
 * inside a forked process, not from the original caller. Confirm that the
 * caller (main) handles this and does not keep running the parent's logic
 * in the child.
 */
static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
{
	int cmd_sock[2];
	void *data_map;
	pid_t child;

	if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
		return -1;

	if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
		return -1;

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		/* in parent - selftest */
		/* presumably init_child() left us in a child netns - return to ours */
		return switch_ns(nsfd_parent);
	}

	/* Only the reading end of the test-description fds is needed here. */
	if (close(test_desc_fd[1])) {
		pr_err("close()");
		return -1;
	}

	/* child */
	/* Mapped before the second fork so child and grandchild share it. */
	data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
			MAP_SHARED | MAP_ANONYMOUS, -1, 0);
	if (data_map == MAP_FAILED) {
		pr_err("mmap()");
		return -1;
	}

	randomize_buffer(data_map, page_size);

	/* Command channel between child (cmd_sock[1]) and grandchild (cmd_sock[0]). */
	if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
		pr_err("socketpair()");
		return -1;
	}

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		if (close(cmd_sock[0])) {
			pr_err("close()");
			return -1;
		}
		return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
	}
	if (close(cmd_sock[1])) {
		pr_err("close()");
		return -1;
	}
	return grand_child_f(nr, cmd_sock[0], data_map);
}
2027 
2028 static void exit_usage(char **argv)
2029 {
2030 	printk("Usage: %s [nr_process]", argv[0]);
2031 	exit(KSFT_FAIL);
2032 }
2033 
2034 static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
2035 {
2036 	ssize_t ret;
2037 
2038 	ret = write(test_desc_fd, desc, sizeof(*desc));
2039 
2040 	if (ret == sizeof(*desc))
2041 		return 0;
2042 
2043 	pr_err("Writing test's desc failed %ld", ret);
2044 
2045 	return -1;
2046 }
2047 
2048 static int write_desc(int proto, int test_desc_fd,
2049 		char *a, char *e, char *c, char *ae)
2050 {
2051 	struct xfrm_desc desc = {};
2052 
2053 	desc.type = CREATE_TUNNEL;
2054 	desc.proto = proto;
2055 
2056 	if (a)
2057 		strncpy(desc.a_algo, a, ALGO_LEN - 1);
2058 	if (e)
2059 		strncpy(desc.e_algo, e, ALGO_LEN - 1);
2060 	if (c)
2061 		strncpy(desc.c_algo, c, ALGO_LEN - 1);
2062 	if (ae)
2063 		strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
2064 
2065 	return __write_desc(test_desc_fd, &desc);
2066 }
2067 
/* Protocols that write_test_plan() queues CREATE_TUNNEL cases for, in order. */
int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
/*
 * Authentication algorithms: tested alone with AH, and paired with every
 * e_list entry for ESP.
 */
char *ah_list[] = {
	"digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
	"hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
	"xcbc(aes)", "cmac(aes)"
};
/* Compression algorithms tested with IPPROTO_COMP. */
char *comp_list[] = {
	"deflate",
#if 0
	/* No compression backend implementation */
	"lzs", "lzjh"
#endif
};
/* Encryption algorithms, each combined with every ah_list entry for ESP. */
char *e_list[] = {
	"ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
	"cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
	"cbc(twofish)", "rfc3686(ctr(aes))"
};
/*
 * AEAD algorithms for ESP. This list is currently empty (a zero-length
 * array, a GNU extension) because the entries are compiled out.
 */
char *ae_list[] = {
#if 0
	/* not implemented */
	"rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
	"rfc7539esp(chacha20,poly1305)"
#endif
};

/*
 * Number of CREATE_TUNNEL cases write_proto_plan() emits over proto_list:
 * one per AH algorithm, one per compression algorithm, one per (auth, enc)
 * ESP pair, and one per AEAD algorithm. Keep in sync with write_proto_plan().
 */
const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
				+ (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
				+ ARRAY_SIZE(ae_list);
2097 
2098 static int write_proto_plan(int fd, int proto)
2099 {
2100 	unsigned int i;
2101 
2102 	switch (proto) {
2103 	case IPPROTO_AH:
2104 		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2105 			if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
2106 				return -1;
2107 		}
2108 		break;
2109 	case IPPROTO_COMP:
2110 		for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
2111 			if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
2112 				return -1;
2113 		}
2114 		break;
2115 	case IPPROTO_ESP:
2116 		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2117 			int j;
2118 
2119 			for (j = 0; j < ARRAY_SIZE(e_list); j++) {
2120 				if (write_desc(proto, fd, ah_list[i],
2121 							e_list[j], 0, 0))
2122 					return -1;
2123 			}
2124 		}
2125 		for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
2126 			if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
2127 				return -1;
2128 		}
2129 		break;
2130 	default:
2131 		printk("BUG: Specified unknown proto %d", proto);
2132 		return -1;
2133 	}
2134 
2135 	return 0;
2136 }
2137 
2138 /*
2139  * Some structures in xfrm uapi header differ in size between
2140  * 64-bit and 32-bit ABI:
2141  *
2142  *             32-bit UABI               |            64-bit UABI
2143  *  -------------------------------------|-------------------------------------
2144  *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
2145  *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
2146  *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
2147  *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
2148  *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
2149  *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
2150  *
2151  * Check the affected by the UABI difference structures.
2152  * Also, check translation for xfrm_set_spdinfo: it has it's own attributes
2153  * which needs to be correctly copied, but not translated.
2154  */
2155 const unsigned int compat_plan = 5;
2156 static int write_compat_struct_tests(int test_desc_fd)
2157 {
2158 	struct xfrm_desc desc = {};
2159 
2160 	desc.type = ALLOCATE_SPI;
2161 	desc.proto = IPPROTO_AH;
2162 	strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2163 
2164 	if (__write_desc(test_desc_fd, &desc))
2165 		return -1;
2166 
2167 	desc.type = MONITOR_ACQUIRE;
2168 	if (__write_desc(test_desc_fd, &desc))
2169 		return -1;
2170 
2171 	desc.type = EXPIRE_STATE;
2172 	if (__write_desc(test_desc_fd, &desc))
2173 		return -1;
2174 
2175 	desc.type = EXPIRE_POLICY;
2176 	if (__write_desc(test_desc_fd, &desc))
2177 		return -1;
2178 
2179 	desc.type = SPDINFO_ATTRS;
2180 	if (__write_desc(test_desc_fd, &desc))
2181 		return -1;
2182 
2183 	return 0;
2184 }
2185 
2186 static int write_test_plan(int test_desc_fd)
2187 {
2188 	unsigned int i;
2189 	pid_t child;
2190 
2191 	child = fork();
2192 	if (child < 0) {
2193 		pr_err("fork()");
2194 		return -1;
2195 	}
2196 	if (child) {
2197 		if (close(test_desc_fd))
2198 			printk("close(): %m");
2199 		return 0;
2200 	}
2201 
2202 	if (write_compat_struct_tests(test_desc_fd))
2203 		exit(KSFT_FAIL);
2204 
2205 	for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2206 		if (write_proto_plan(test_desc_fd, proto_list[i]))
2207 			exit(KSFT_FAIL);
2208 	}
2209 
2210 	exit(KSFT_PASS);
2211 }
2212 
2213 static int children_cleanup(void)
2214 {
2215 	unsigned ret = KSFT_PASS;
2216 
2217 	while (1) {
2218 		int status;
2219 		pid_t p = wait(&status);
2220 
2221 		if ((p < 0) && errno == ECHILD)
2222 			break;
2223 
2224 		if (p < 0) {
2225 			pr_err("wait()");
2226 			return KSFT_FAIL;
2227 		}
2228 
2229 		if (!WIFEXITED(status)) {
2230 			ret = KSFT_FAIL;
2231 			continue;
2232 		}
2233 
2234 		if (WEXITSTATUS(status) == KSFT_FAIL)
2235 			ret = KSFT_FAIL;
2236 	}
2237 
2238 	return ret;
2239 }
2240 
2241 typedef void (*print_res)(const char *, ...);
2242 
2243 static int check_results(void)
2244 {
2245 	struct test_result tr = {};
2246 	struct xfrm_desc *d = &tr.desc;
2247 	int ret = KSFT_PASS;
2248 
2249 	while (1) {
2250 		ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2251 		print_res result;
2252 
2253 		if (received == 0) /* EOF */
2254 			break;
2255 
2256 		if (received != sizeof(tr)) {
2257 			pr_err("read() returned %zd", received);
2258 			return KSFT_FAIL;
2259 		}
2260 
2261 		switch (tr.res) {
2262 		case KSFT_PASS:
2263 			result = ksft_test_result_pass;
2264 			break;
2265 		case KSFT_FAIL:
2266 		default:
2267 			result = ksft_test_result_fail;
2268 			ret = KSFT_FAIL;
2269 		}
2270 
2271 		result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2272 		       desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2273 		       d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2274 	}
2275 
2276 	return ret;
2277 }
2278 
2279 int main(int argc, char **argv)
2280 {
2281 	unsigned int nr_process = 1;
2282 	int route_sock = -1, ret = KSFT_SKIP;
2283 	int test_desc_fd[2];
2284 	uint32_t route_seq;
2285 	unsigned int i;
2286 
2287 	if (argc > 2)
2288 		exit_usage(argv);
2289 
2290 	if (argc > 1) {
2291 		char *endptr;
2292 
2293 		errno = 0;
2294 		nr_process = strtol(argv[1], &endptr, 10);
2295 		if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
2296 				|| (errno != 0 && nr_process == 0)
2297 				|| (endptr == argv[1]) || (*endptr != '\0')) {
2298 			printk("Failed to parse [nr_process]");
2299 			exit_usage(argv);
2300 		}
2301 
2302 		if (nr_process > MAX_PROCESSES || !nr_process) {
2303 			printk("nr_process should be between [1; %u]",
2304 					MAX_PROCESSES);
2305 			exit_usage(argv);
2306 		}
2307 	}
2308 
2309 	srand(time(NULL));
2310 	page_size = sysconf(_SC_PAGESIZE);
2311 	if (page_size < 1)
2312 		ksft_exit_skip("sysconf(): %m\n");
2313 
2314 	if (pipe2(test_desc_fd, O_DIRECT) < 0)
2315 		ksft_exit_skip("pipe(): %m\n");
2316 
2317 	if (pipe2(results_fd, O_DIRECT) < 0)
2318 		ksft_exit_skip("pipe(): %m\n");
2319 
2320 	if (init_namespaces())
2321 		ksft_exit_skip("Failed to create namespaces\n");
2322 
2323 	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
2324 		ksft_exit_skip("Failed to open netlink route socket\n");
2325 
2326 	for (i = 0; i < nr_process; i++) {
2327 		char veth[VETH_LEN];
2328 
2329 		snprintf(veth, VETH_LEN, VETH_FMT, i);
2330 
2331 		if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
2332 			close(route_sock);
2333 			ksft_exit_fail_msg("Failed to create veth device");
2334 		}
2335 
2336 		if (start_child(i, veth, test_desc_fd)) {
2337 			close(route_sock);
2338 			ksft_exit_fail_msg("Child %u failed to start", i);
2339 		}
2340 	}
2341 
2342 	if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
2343 		ksft_exit_fail_msg("close(): %m");
2344 
2345 	ksft_set_plan(proto_plan + compat_plan);
2346 
2347 	if (write_test_plan(test_desc_fd[1]))
2348 		ksft_exit_fail_msg("Failed to write test plan to pipe");
2349 
2350 	ret = check_results();
2351 
2352 	if (children_cleanup() == KSFT_FAIL)
2353 		exit(KSFT_FAIL);
2354 
2355 	exit(ret);
2356 }
2357