xref: /linux/tools/testing/selftests/net/ipsec.c (revision 9dbbc3b9d09d6deba9f3b9e1d5b355032ed46a75)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * ipsec.c - Check xfrm on veth inside a net-ns.
4  * Copyright (c) 2018 Dmitry Safonov
5  */
6 
7 #define _GNU_SOURCE
8 
9 #include <arpa/inet.h>
10 #include <asm/types.h>
11 #include <errno.h>
12 #include <fcntl.h>
13 #include <limits.h>
14 #include <linux/limits.h>
15 #include <linux/netlink.h>
16 #include <linux/random.h>
17 #include <linux/rtnetlink.h>
18 #include <linux/veth.h>
19 #include <linux/xfrm.h>
20 #include <netinet/in.h>
21 #include <net/if.h>
22 #include <sched.h>
23 #include <stdbool.h>
24 #include <stdint.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <sys/mman.h>
29 #include <sys/socket.h>
30 #include <sys/stat.h>
31 #include <sys/syscall.h>
32 #include <sys/types.h>
33 #include <sys/wait.h>
34 #include <time.h>
35 #include <unistd.h>
36 
37 #include "../kselftest.h"
38 
#define printk(fmt, ...)						\
	ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)

#define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)

#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))

#define IPV4_STR_SZ	16	/* xxx.xxx.xxx.xxx is longest + \0 */
#define MAX_PAYLOAD	2048
#define XFRM_ALGO_KEY_BUF_SIZE	512
#define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
#define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
#define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */

/* /30 mask for one veth connection */
#define PREFIX_LEN	30
/*
 * Host offsets of the first (.1) and second (.2) address inside the nr-th
 * /30 subnet. The argument is parenthesized so that expressions such as
 * child_ip(i + 1) expand to 4*(i + 1) + 1 rather than 4*i + 1 + 1.
 */
#define child_ip(nr)	(4*(nr) + 1)
#define grchild_ip(nr)	(4*(nr) + 2)

#define VETH_FMT	"ktst-%d"
#define VETH_LEN	12

static int nsfd_parent	= -1;
static int nsfd_childa	= -1;
static int nsfd_childb	= -1;
static long page_size;

/*
 * ksft_cnt is static in kselftest, so it isn't shared with children.
 * Test results therefore have to be sent back to the parent and counted
 * there; results_fd is the pipe carrying that feedback from the children.
 */
static int results_fd[2];

const unsigned int ping_delay_nsec	= 50 * 1000 * 1000;
const unsigned int ping_timeout		= 300;
const unsigned int ping_count		= 100;
const unsigned int ping_success		= 80;
78 
/*
 * Fill @buflen bytes at @buf with pseudo-random data from rand().
 *
 * The buffer is filled one int-sized chunk at a time, and a trailing partial
 * chunk receives the leading bytes of one more rand() value. Every store
 * goes through memcpy() on a byte cursor. This means @buf needs no particular
 * alignment, no int lvalue is used to write into storage of another type,
 * and no arithmetic is done on a void pointer.
 */
static void randomize_buffer(void *buf, size_t buflen)
{
	unsigned char *cursor = buf;
	size_t words = buflen / sizeof(int);
	size_t leftover = buflen % sizeof(int);

	while (words--) {
		int r = rand();

		memcpy(cursor, &r, sizeof(r));
		cursor += sizeof(r);
	}

	if (leftover) {
		int r = rand();

		memcpy(cursor, &r, leftover);
	}
}
99 
100 static int unshare_open(void)
101 {
102 	const char *netns_path = "/proc/self/ns/net";
103 	int fd;
104 
105 	if (unshare(CLONE_NEWNET) != 0) {
106 		pr_err("unshare()");
107 		return -1;
108 	}
109 
110 	fd = open(netns_path, O_RDONLY);
111 	if (fd <= 0) {
112 		pr_err("open(%s)", netns_path);
113 		return -1;
114 	}
115 
116 	return fd;
117 }
118 
119 static int switch_ns(int fd)
120 {
121 	if (setns(fd, CLONE_NEWNET)) {
122 		pr_err("setns()");
123 		return -1;
124 	}
125 	return 0;
126 }
127 
128 /*
129  * Running the test inside a new parent net namespace to bother less
130  * about cleanup on error-path.
131  */
132 static int init_namespaces(void)
133 {
134 	nsfd_parent = unshare_open();
135 	if (nsfd_parent <= 0)
136 		return -1;
137 
138 	nsfd_childa = unshare_open();
139 	if (nsfd_childa <= 0)
140 		return -1;
141 
142 	if (switch_ns(nsfd_parent))
143 		return -1;
144 
145 	nsfd_childb = unshare_open();
146 	if (nsfd_childb <= 0)
147 		return -1;
148 
149 	if (switch_ns(nsfd_parent))
150 		return -1;
151 	return 0;
152 }
153 
/*
 * Lazily open a netlink socket of protocol @proto.
 *
 * If *@sock already holds a descriptor (> 0), it is reused and the caller's
 * sequence counter *@seq_nr is advanced by one. Otherwise a new SOCK_CLOEXEC
 * netlink socket is created and *@seq_nr is seeded with a random starting
 * value. Descriptor 0 is treated as "not open".
 *
 * Returns 0 on success, or -1 if socket() fails.
 */
static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
{
	if (*sock > 0) {
		/* Advance the caller's counter, not the local pointer. */
		(*seq_nr)++;
		return 0;
	}

	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
	if (*sock <= 0) {
		pr_err("socket(AF_NETLINK)");
		return -1;
	}

	randomize_buffer(seq_nr, sizeof(*seq_nr));

	return 0;
}
171 
/*
 * Where the next rtattr would be appended to the netlink message @nh: the
 * current message length rounded up with RTA_ALIGN(), measured from @nh.
 */
static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
{
	char *msg_start = (char *)nh;
	size_t tail_off = RTA_ALIGN(nh->nlmsg_len);

	return (struct rtattr *)(msg_start + tail_off);
}
176 
/*
 * Append one rtattr of type @rta_type carrying @size bytes of @payload to
 * the netlink message @nh, whose whole buffer is @req_sz bytes long.
 *
 * On success nh->nlmsg_len grows to cover the new attribute and 0 is
 * returned. If the attribute would not fit in @req_sz, an error is logged,
 * -1 is returned and the message is left untouched. @payload may be NULL
 * when @size is 0; rtattr_begin() relies on this to open a nested
 * attribute.
 */
static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */
	struct rtattr *attr = rtattr_hdr(nh);
	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);

	if (req_sz < nl_size) {
		printk("req buf is too small: %zu < %zu", req_sz, nl_size);
		return -1;
	}
	nh->nlmsg_len = nl_size;

	attr->rta_len = RTA_LENGTH(size);
	attr->rta_type = rta_type;
	/* memcpy() with a NULL source is undefined even for zero length. */
	if (size)
		memcpy(RTA_DATA(attr), payload, size);

	return 0;
}
196 
/*
 * Append an rtattr, with an optional fixed @payload, that will serve as
 * the header of a nested attribute. Returns a pointer to it so that
 * rtattr_end() can fix up its length later, or NULL if it did not fit in
 * @req_sz.
 */
static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	struct rtattr *nested = rtattr_hdr(nh);

	return rtattr_pack(nh, req_sz, rta_type, payload, size) ? NULL : nested;
}
207 
/*
 * Open a nested attribute of type @rta_type that has no payload of its own.
 * Close it with rtattr_end(). Returns NULL if it does not fit.
 */
static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type)
{
	const void *no_payload = NULL;

	return _rtattr_begin(nh, req_sz, rta_type, no_payload, 0);
}
213 
/*
 * Close a nested attribute: set attr->rta_len so that @attr spans everything
 * from itself up to the current end of the message @nh.
 */
static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
{
	char *msg_tail = (char *)nh + nh->nlmsg_len;
	size_t span = msg_tail - (char *)attr;

	attr->rta_len = span;
}
220 
/*
 * Pack the VETH_INFO_PEER nested attribute that describes the far end of a
 * veth pair. It contains an ifinfomsg header followed by IFLA_IFNAME (@peer)
 * and IFLA_NET_NS_FD (@ns, a namespace descriptor).
 * Returns 0, or -1 if @req_sz is exhausted.
 */
static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
		const char *peer, int ns)
{
	struct ifinfomsg peer_info;
	struct rtattr *nest;

	memset(&peer_info, 0, sizeof(peer_info));
	peer_info.ifi_change	= 0xFFFFFFFF;
	peer_info.ifi_family	= AF_UNSPEC;

	nest = _rtattr_begin(nh, req_sz, VETH_INFO_PEER, &peer_info, sizeof(peer_info));
	if (nest == NULL)
		return -1;

	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
		return -1;

	rtattr_end(nh, nest);
	return 0;
}
245 
/*
 * Read the kernel's reply to a request sent with NLM_F_ACK and turn it into
 * a return code.
 *
 * Returns 0 when the reply is an NLMSG_ERROR whose error field is 0. Returns
 * the non-zero error field of a failing NLMSG_ERROR (printed with
 * strerror(-error), i.e. a negative errno). Returns -1 if recv() fails, if
 * the reply is too short to hold a header and an error code, or if it is
 * not an NLMSG_ERROR.
 */
static int netlink_check_answer(int sock)
{
	struct nlmsgerror {
		struct nlmsghdr hdr;
		int error;
		struct nlmsghdr orig_msg;
	} answer;
	ssize_t len;

	len = recv(sock, &answer, sizeof(answer), 0);
	if (len < 0) {
		pr_err("recv()");
		return -1;
	}
	/* Never look at header/error fields that recv() did not fill in. */
	if ((size_t)len < sizeof(answer.hdr) + sizeof(answer.error)) {
		printk("netlink answer is too short: %zd", len);
		return -1;
	}
	if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
		return -1;
	}
	if (answer.error) {
		printk("NLMSG_ERROR: %d: %s",
			answer.error, strerror(-answer.error));
		return answer.error;
	}

	return 0;
}
268 
269 static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
270 		const char *peerb, int ns_b)
271 {
272 	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
273 	struct {
274 		struct nlmsghdr		nh;
275 		struct ifinfomsg	info;
276 		char			attrbuf[MAX_PAYLOAD];
277 	} req;
278 	const char veth_type[] = "veth";
279 	struct rtattr *link_info, *info_data;
280 
281 	memset(&req, 0, sizeof(req));
282 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
283 	req.nh.nlmsg_type	= RTM_NEWLINK;
284 	req.nh.nlmsg_flags	= flags;
285 	req.nh.nlmsg_seq	= seq;
286 	req.info.ifi_family	= AF_UNSPEC;
287 	req.info.ifi_change	= 0xFFFFFFFF;
288 
289 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
290 		return -1;
291 
292 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
293 		return -1;
294 
295 	link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
296 	if (!link_info)
297 		return -1;
298 
299 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
300 		return -1;
301 
302 	info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
303 	if (!info_data)
304 		return -1;
305 
306 	if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
307 		return -1;
308 
309 	rtattr_end(&req.nh, info_data);
310 	rtattr_end(&req.nh, link_info);
311 
312 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
313 		pr_err("send()");
314 		return -1;
315 	}
316 	return netlink_check_answer(sock);
317 }
318 
319 static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
320 		struct in_addr addr, uint8_t prefix)
321 {
322 	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
323 	struct {
324 		struct nlmsghdr		nh;
325 		struct ifaddrmsg	info;
326 		char			attrbuf[MAX_PAYLOAD];
327 	} req;
328 
329 	memset(&req, 0, sizeof(req));
330 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
331 	req.nh.nlmsg_type	= RTM_NEWADDR;
332 	req.nh.nlmsg_flags	= flags;
333 	req.nh.nlmsg_seq	= seq;
334 	req.info.ifa_family	= AF_INET;
335 	req.info.ifa_prefixlen	= prefix;
336 	req.info.ifa_index	= if_nametoindex(intf);
337 
338 #ifdef DEBUG
339 	{
340 		char addr_str[IPV4_STR_SZ] = {};
341 
342 		strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
343 
344 		printk("ip addr set %s", addr_str);
345 	}
346 #endif
347 
348 	if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
349 		return -1;
350 
351 	if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
352 		return -1;
353 
354 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
355 		pr_err("send()");
356 		return -1;
357 	}
358 	return netlink_check_answer(sock);
359 }
360 
361 static int link_set_up(int sock, uint32_t seq, const char *intf)
362 {
363 	struct {
364 		struct nlmsghdr		nh;
365 		struct ifinfomsg	info;
366 		char			attrbuf[MAX_PAYLOAD];
367 	} req;
368 
369 	memset(&req, 0, sizeof(req));
370 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
371 	req.nh.nlmsg_type	= RTM_NEWLINK;
372 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
373 	req.nh.nlmsg_seq	= seq;
374 	req.info.ifi_family	= AF_UNSPEC;
375 	req.info.ifi_change	= 0xFFFFFFFF;
376 	req.info.ifi_index	= if_nametoindex(intf);
377 	req.info.ifi_flags	= IFF_UP;
378 	req.info.ifi_change	= IFF_UP;
379 
380 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
381 		pr_err("send()");
382 		return -1;
383 	}
384 	return netlink_check_answer(sock);
385 }
386 
387 static int ip4_route_set(int sock, uint32_t seq, const char *intf,
388 		struct in_addr src, struct in_addr dst)
389 {
390 	struct {
391 		struct nlmsghdr	nh;
392 		struct rtmsg	rt;
393 		char		attrbuf[MAX_PAYLOAD];
394 	} req;
395 	unsigned int index = if_nametoindex(intf);
396 
397 	memset(&req, 0, sizeof(req));
398 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.rt));
399 	req.nh.nlmsg_type	= RTM_NEWROUTE;
400 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
401 	req.nh.nlmsg_seq	= seq;
402 	req.rt.rtm_family	= AF_INET;
403 	req.rt.rtm_dst_len	= 32;
404 	req.rt.rtm_table	= RT_TABLE_MAIN;
405 	req.rt.rtm_protocol	= RTPROT_BOOT;
406 	req.rt.rtm_scope	= RT_SCOPE_LINK;
407 	req.rt.rtm_type		= RTN_UNICAST;
408 
409 	if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
410 		return -1;
411 
412 	if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
413 		return -1;
414 
415 	if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
416 		return -1;
417 
418 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
419 		pr_err("send()");
420 		return -1;
421 	}
422 
423 	return netlink_check_answer(sock);
424 }
425 
426 static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
427 		struct in_addr tunsrc, struct in_addr tundst)
428 {
429 	if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
430 			tunsrc, PREFIX_LEN)) {
431 		printk("Failed to set ipv4 addr");
432 		return -1;
433 	}
434 
435 	if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
436 		printk("Failed to set ipv4 route");
437 		return -1;
438 	}
439 
440 	return 0;
441 }
442 
443 static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
444 {
445 	struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
446 	struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
447 	struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
448 	int route_sock = -1, ret = -1;
449 	uint32_t route_seq;
450 
451 	if (switch_ns(nsfd))
452 		return -1;
453 
454 	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
455 		printk("Failed to open netlink route socket in child");
456 		return -1;
457 	}
458 
459 	if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
460 		printk("Failed to set ipv4 addr");
461 		goto err;
462 	}
463 
464 	if (link_set_up(route_sock, route_seq++, veth)) {
465 		printk("Failed to bring up %s", veth);
466 		goto err;
467 	}
468 
469 	if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
470 		printk("Failed to add tunnel route on %s", veth);
471 		goto err;
472 	}
473 	ret = 0;
474 
475 err:
476 	close(route_sock);
477 	return ret;
478 }
479 
#define ALGO_LEN	64

/* Kind of xfrm test case; desc_name[] below is laid out in the same order. */
enum desc_type {
	CREATE_TUNNEL	= 0,
	ALLOCATE_SPI	= 1,
	MONITOR_ACQUIRE	= 2,
	EXPIRE_STATE	= 3,
	EXPIRE_POLICY	= 4,
};

/* Human-readable name of each desc_type, indexed by the enum value. */
const char *desc_name[] = {
	[CREATE_TUNNEL]		= "create tunnel",
	[ALLOCATE_SPI]		= "alloc spi",
	[MONITOR_ACQUIRE]	= "monitor acquire",
	[EXPIRE_STATE]		= "expire state",
	[EXPIRE_POLICY]		= "expire policy",
};

/*
 * One xfrm test case: the test kind, the IPsec protocol number and the
 * algorithm names to use. Which *_algo strings must be non-empty depends on
 * @proto (validated by xfrm_state_pack_algo()); the rest stay empty.
 */
struct xfrm_desc {
	enum desc_type	type;
	uint8_t		proto;
	char		a_algo[ALGO_LEN];	/* -> XFRMA_ALG_AUTH */
	char		e_algo[ALGO_LEN];	/* -> XFRMA_ALG_CRYPT */
	char		c_algo[ALGO_LEN];	/* -> XFRMA_ALG_COMP */
	char		ae_algo[ALGO_LEN];	/* -> XFRMA_ALG_AEAD */
	unsigned int	icv_len;		/* AEAD ICV length */
};
505 
/* Message kinds carried in struct test_desc (see write_msg()/read_msg()). */
enum msg_type {
	MSG_ACK			= 0,
	MSG_EXIT		= 1,
	MSG_PING		= 2,
	MSG_XFRM_PREPARE	= 3,
	MSG_XFRM_ADD		= 4,
	MSG_XFRM_DEL		= 5,
	MSG_XFRM_CLEANUP	= 6,
};
515 
516 struct test_desc {
517 	enum msg_type type;
518 	union {
519 		struct {
520 			in_addr_t reply_ip;
521 			unsigned int port;
522 		} ping;
523 		struct xfrm_desc xfrm_desc;
524 	} body;
525 };
526 
527 struct test_result {
528 	struct xfrm_desc desc;
529 	unsigned int res;
530 };
531 
532 static void write_test_result(unsigned int res, struct xfrm_desc *d)
533 {
534 	struct test_result tr = {};
535 	ssize_t ret;
536 
537 	tr.desc = *d;
538 	tr.res = res;
539 
540 	ret = write(results_fd[1], &tr, sizeof(tr));
541 	if (ret != sizeof(tr))
542 		pr_err("Failed to write the result in pipe %zd", ret);
543 }
544 
545 static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
546 {
547 	ssize_t bytes = write(fd, msg, sizeof(*msg));
548 
549 	/* Make sure that write/read is atomic to a pipe */
550 	BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
551 
552 	if (bytes < 0) {
553 		pr_err("write()");
554 		if (exit_of_fail)
555 			exit(KSFT_FAIL);
556 	}
557 	if (bytes != sizeof(*msg)) {
558 		pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
559 		if (exit_of_fail)
560 			exit(KSFT_FAIL);
561 	}
562 }
563 
564 static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
565 {
566 	ssize_t bytes = read(fd, msg, sizeof(*msg));
567 
568 	if (bytes < 0) {
569 		pr_err("read()");
570 		if (exit_of_fail)
571 			exit(KSFT_FAIL);
572 	}
573 	if (bytes != sizeof(*msg)) {
574 		pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
575 		if (exit_of_fail)
576 			exit(KSFT_FAIL);
577 	}
578 }
579 
/*
 * Set up the UDP socket pair used by the ping helpers:
 *  - sock[0] is bound to @listen_ip on a kernel-chosen port, which is
 *    returned in *@server_port. It gets a receive timeout of @u_timeout
 *    microseconds.
 *  - sock[1] is an unbound socket used for sending.
 * Returns 0 on success. On failure returns -1 and leaves no socket open.
 */
static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
		unsigned int *server_port, int sock[2])
{
	struct sockaddr_in server;
	/* tv_usec must stay below one second, so split larger timeouts. */
	struct timeval t = {
		.tv_sec		= u_timeout / 1000000,
		.tv_usec	= u_timeout % 1000000,
	};
	socklen_t s_len = sizeof(server);

	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[0] < 0) {
		pr_err("socket()");
		return -1;
	}

	/* Zero sin_zero instead of handing uninitialized bytes to bind(). */
	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= 0;	/* let the kernel pick a free port */
	memcpy(&server.sin_addr.s_addr, &listen_ip, sizeof(struct in_addr));

	if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
		pr_err("bind()");
		goto err_close_server;
	}

	if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
		pr_err("getsockname()");
		goto err_close_server;
	}

	*server_port = ntohs(server.sin_port);

	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
		pr_err("setsockopt()");
		goto err_close_server;
	}

	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[1] < 0) {
		pr_err("socket()");
		goto err_close_server;
	}

	return 0;

err_close_server:
	close(sock[0]);
	return -1;
}
626 
/*
 * Client side of one ping. Sends @buf_len bytes of @buf from sock[1] to
 * @dest_ip:@port, then waits (up to sock[0]'s receive timeout) for the same
 * payload to come back on sock[0].
 *
 * Returns 0 if an identical reply arrived, -1 otherwise. A receive timeout
 * (EAGAIN) fails silently; other errors are logged.
 */
static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* Receive buffer of buf_len bytes (not buf_len pointers). */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	} else if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	} else if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	return 0;
}
663 
/*
 * Server side of one ping. Waits (up to sock[0]'s receive timeout) for
 * @buf_len bytes identical to @buf on sock[0], then sends @buf back from
 * sock[1] to @dest_ip:@port.
 *
 * Returns 0 on success, -1 otherwise. A receive timeout (EAGAIN) fails
 * silently; other errors are logged.
 */
static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* Receive buffer of buf_len bytes (not buf_len pointers). */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	}
	if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	}
	if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	return 0;
}
702 
703 typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
704 		char *buf, size_t buf_len);
705 static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
706 		bool init_side, int d_port, in_addr_t to, ping_f func)
707 {
708 	struct test_desc msg;
709 	unsigned int s_port, i, ping_succeeded = 0;
710 	int ping_sock[2];
711 	char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
712 
713 	if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
714 		printk("Failed to init ping");
715 		return -1;
716 	}
717 
718 	memset(&msg, 0, sizeof(msg));
719 	msg.type		= MSG_PING;
720 	msg.body.ping.port	= s_port;
721 	memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
722 
723 	write_msg(cmd_fd, &msg, 0);
724 	if (init_side) {
725 		/* The other end sends ip to ping */
726 		read_msg(cmd_fd, &msg, 0);
727 		if (msg.type != MSG_PING)
728 			return -1;
729 		to = msg.body.ping.reply_ip;
730 		d_port = msg.body.ping.port;
731 	}
732 
733 	for (i = 0; i < ping_count ; i++) {
734 		struct timespec sleep_time = {
735 			.tv_sec = 0,
736 			.tv_nsec = ping_delay_nsec,
737 		};
738 
739 		ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
740 		nanosleep(&sleep_time, 0);
741 	}
742 
743 	close(ping_sock[0]);
744 	close(ping_sock[1]);
745 
746 	strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
747 	strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
748 
749 	if (ping_succeeded < ping_success) {
750 		printk("ping (%s) %s->%s failed %u/%u times",
751 			init_side ? "send" : "reply", from_str, to_str,
752 			ping_count - ping_succeeded, ping_count);
753 		return -1;
754 	}
755 
756 #ifdef DEBUG
757 	printk("ping (%s) %s->%s succeeded %u/%u times",
758 		init_side ? "send" : "reply", from_str, to_str,
759 		ping_succeeded, ping_count);
760 #endif
761 
762 	return 0;
763 }
764 
765 static int xfrm_fill_key(char *name, char *buf,
766 		size_t buf_len, unsigned int *key_len)
767 {
768 	/* TODO: use set/map instead */
769 	if (strncmp(name, "digest_null", ALGO_LEN) == 0)
770 		*key_len = 0;
771 	else if (strncmp(name, "ecb(cipher_null)", ALGO_LEN) == 0)
772 		*key_len = 0;
773 	else if (strncmp(name, "cbc(des)", ALGO_LEN) == 0)
774 		*key_len = 64;
775 	else if (strncmp(name, "hmac(md5)", ALGO_LEN) == 0)
776 		*key_len = 128;
777 	else if (strncmp(name, "cmac(aes)", ALGO_LEN) == 0)
778 		*key_len = 128;
779 	else if (strncmp(name, "xcbc(aes)", ALGO_LEN) == 0)
780 		*key_len = 128;
781 	else if (strncmp(name, "cbc(cast5)", ALGO_LEN) == 0)
782 		*key_len = 128;
783 	else if (strncmp(name, "cbc(serpent)", ALGO_LEN) == 0)
784 		*key_len = 128;
785 	else if (strncmp(name, "hmac(sha1)", ALGO_LEN) == 0)
786 		*key_len = 160;
787 	else if (strncmp(name, "hmac(rmd160)", ALGO_LEN) == 0)
788 		*key_len = 160;
789 	else if (strncmp(name, "cbc(des3_ede)", ALGO_LEN) == 0)
790 		*key_len = 192;
791 	else if (strncmp(name, "hmac(sha256)", ALGO_LEN) == 0)
792 		*key_len = 256;
793 	else if (strncmp(name, "cbc(aes)", ALGO_LEN) == 0)
794 		*key_len = 256;
795 	else if (strncmp(name, "cbc(camellia)", ALGO_LEN) == 0)
796 		*key_len = 256;
797 	else if (strncmp(name, "cbc(twofish)", ALGO_LEN) == 0)
798 		*key_len = 256;
799 	else if (strncmp(name, "rfc3686(ctr(aes))", ALGO_LEN) == 0)
800 		*key_len = 288;
801 	else if (strncmp(name, "hmac(sha384)", ALGO_LEN) == 0)
802 		*key_len = 384;
803 	else if (strncmp(name, "cbc(blowfish)", ALGO_LEN) == 0)
804 		*key_len = 448;
805 	else if (strncmp(name, "hmac(sha512)", ALGO_LEN) == 0)
806 		*key_len = 512;
807 	else if (strncmp(name, "rfc4106(gcm(aes))-128", ALGO_LEN) == 0)
808 		*key_len = 160;
809 	else if (strncmp(name, "rfc4543(gcm(aes))-128", ALGO_LEN) == 0)
810 		*key_len = 160;
811 	else if (strncmp(name, "rfc4309(ccm(aes))-128", ALGO_LEN) == 0)
812 		*key_len = 152;
813 	else if (strncmp(name, "rfc4106(gcm(aes))-192", ALGO_LEN) == 0)
814 		*key_len = 224;
815 	else if (strncmp(name, "rfc4543(gcm(aes))-192", ALGO_LEN) == 0)
816 		*key_len = 224;
817 	else if (strncmp(name, "rfc4309(ccm(aes))-192", ALGO_LEN) == 0)
818 		*key_len = 216;
819 	else if (strncmp(name, "rfc4106(gcm(aes))-256", ALGO_LEN) == 0)
820 		*key_len = 288;
821 	else if (strncmp(name, "rfc4543(gcm(aes))-256", ALGO_LEN) == 0)
822 		*key_len = 288;
823 	else if (strncmp(name, "rfc4309(ccm(aes))-256", ALGO_LEN) == 0)
824 		*key_len = 280;
825 	else if (strncmp(name, "rfc7539(chacha20,poly1305)-128", ALGO_LEN) == 0)
826 		*key_len = 0;
827 
828 	if (*key_len > buf_len) {
829 		printk("Can't pack a key - too big for buffer");
830 		return -1;
831 	}
832 
833 	randomize_buffer(buf, *key_len);
834 
835 	return 0;
836 }
837 
838 static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
839 		struct xfrm_desc *desc)
840 {
841 	struct {
842 		union {
843 			struct xfrm_algo	alg;
844 			struct xfrm_algo_aead	aead;
845 			struct xfrm_algo_auth	auth;
846 		} u;
847 		char buf[XFRM_ALGO_KEY_BUF_SIZE];
848 	} alg = {};
849 	size_t alen, elen, clen, aelen;
850 	unsigned short type;
851 
852 	alen = strlen(desc->a_algo);
853 	elen = strlen(desc->e_algo);
854 	clen = strlen(desc->c_algo);
855 	aelen = strlen(desc->ae_algo);
856 
857 	/* Verify desc */
858 	switch (desc->proto) {
859 	case IPPROTO_AH:
860 		if (!alen || elen || clen || aelen) {
861 			printk("BUG: buggy ah desc");
862 			return -1;
863 		}
864 		strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
865 		if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
866 				sizeof(alg.buf), &alg.u.alg.alg_key_len))
867 			return -1;
868 		type = XFRMA_ALG_AUTH;
869 		break;
870 	case IPPROTO_COMP:
871 		if (!clen || elen || alen || aelen) {
872 			printk("BUG: buggy comp desc");
873 			return -1;
874 		}
875 		strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
876 		if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
877 				sizeof(alg.buf), &alg.u.alg.alg_key_len))
878 			return -1;
879 		type = XFRMA_ALG_COMP;
880 		break;
881 	case IPPROTO_ESP:
882 		if (!((alen && elen) ^ aelen) || clen) {
883 			printk("BUG: buggy esp desc");
884 			return -1;
885 		}
886 		if (aelen) {
887 			alg.u.aead.alg_icv_len = desc->icv_len;
888 			strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
889 			if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
890 						sizeof(alg.buf), &alg.u.aead.alg_key_len))
891 				return -1;
892 			type = XFRMA_ALG_AEAD;
893 		} else {
894 
895 			strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
896 			type = XFRMA_ALG_CRYPT;
897 			if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
898 						sizeof(alg.buf), &alg.u.alg.alg_key_len))
899 				return -1;
900 			if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
901 				return -1;
902 
903 			strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
904 			type = XFRMA_ALG_AUTH;
905 			if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
906 						sizeof(alg.buf), &alg.u.alg.alg_key_len))
907 				return -1;
908 		}
909 		break;
910 	default:
911 		printk("BUG: unknown proto in desc");
912 		return -1;
913 	}
914 
915 	if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
916 		return -1;
917 
918 	return 0;
919 }
920 
/* SPI built from the host part of @src (inet_lnaof()), in network byte order. */
static inline uint32_t gen_spi(struct in_addr src)
{
	in_addr_t host_part = inet_lnaof(src);

	return htonl(host_part);
}
925 
926 static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
927 		struct in_addr src, struct in_addr dst,
928 		struct xfrm_desc *desc)
929 {
930 	struct {
931 		struct nlmsghdr		nh;
932 		struct xfrm_usersa_info	info;
933 		char			attrbuf[MAX_PAYLOAD];
934 	} req;
935 
936 	memset(&req, 0, sizeof(req));
937 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
938 	req.nh.nlmsg_type	= XFRM_MSG_NEWSA;
939 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
940 	req.nh.nlmsg_seq	= seq;
941 
942 	/* Fill selector. */
943 	memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
944 	memcpy(&req.info.sel.saddr, &src, sizeof(src));
945 	req.info.sel.family		= AF_INET;
946 	req.info.sel.prefixlen_d	= PREFIX_LEN;
947 	req.info.sel.prefixlen_s	= PREFIX_LEN;
948 
949 	/* Fill id */
950 	memcpy(&req.info.id.daddr, &dst, sizeof(dst));
951 	/* Note: zero-spi cannot be deleted */
952 	req.info.id.spi = spi;
953 	req.info.id.proto	= desc->proto;
954 
955 	memcpy(&req.info.saddr, &src, sizeof(src));
956 
957 	/* Fill lifteme_cfg */
958 	req.info.lft.soft_byte_limit	= XFRM_INF;
959 	req.info.lft.hard_byte_limit	= XFRM_INF;
960 	req.info.lft.soft_packet_limit	= XFRM_INF;
961 	req.info.lft.hard_packet_limit	= XFRM_INF;
962 
963 	req.info.family		= AF_INET;
964 	req.info.mode		= XFRM_MODE_TUNNEL;
965 
966 	if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
967 		return -1;
968 
969 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
970 		pr_err("send()");
971 		return -1;
972 	}
973 
974 	return netlink_check_answer(xfrm_sock);
975 }
976 
977 static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
978 		struct in_addr src, struct in_addr dst,
979 		struct xfrm_desc *desc)
980 {
981 	if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
982 		return false;
983 
984 	if (memcmp(&info->sel.saddr, &src, sizeof(src)))
985 		return false;
986 
987 	if (info->sel.family != AF_INET					||
988 			info->sel.prefixlen_d != PREFIX_LEN		||
989 			info->sel.prefixlen_s != PREFIX_LEN)
990 		return false;
991 
992 	if (info->id.spi != spi || info->id.proto != desc->proto)
993 		return false;
994 
995 	if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
996 		return false;
997 
998 	if (memcmp(&info->saddr, &src, sizeof(src)))
999 		return false;
1000 
1001 	if (info->lft.soft_byte_limit != XFRM_INF			||
1002 			info->lft.hard_byte_limit != XFRM_INF		||
1003 			info->lft.soft_packet_limit != XFRM_INF		||
1004 			info->lft.hard_packet_limit != XFRM_INF)
1005 		return false;
1006 
1007 	if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
1008 		return false;
1009 
1010 	/* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
1011 
1012 	return true;
1013 }
1014 
1015 static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1016 		struct in_addr src, struct in_addr dst,
1017 		struct xfrm_desc *desc)
1018 {
1019 	struct {
1020 		struct nlmsghdr		nh;
1021 		char			attrbuf[MAX_PAYLOAD];
1022 	} req;
1023 	struct {
1024 		struct nlmsghdr		nh;
1025 		union {
1026 			struct xfrm_usersa_info	info;
1027 			int error;
1028 		};
1029 		char			attrbuf[MAX_PAYLOAD];
1030 	} answer;
1031 	struct xfrm_address_filter filter = {};
1032 	bool found = false;
1033 
1034 
1035 	memset(&req, 0, sizeof(req));
1036 	req.nh.nlmsg_len	= NLMSG_LENGTH(0);
1037 	req.nh.nlmsg_type	= XFRM_MSG_GETSA;
1038 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_DUMP;
1039 	req.nh.nlmsg_seq	= seq;
1040 
1041 	/*
1042 	 * Add dump filter by source address as there may be other tunnels
1043 	 * in this netns (if tests run in parallel).
1044 	 */
1045 	filter.family = AF_INET;
1046 	filter.splen = 0x1f;	/* 0xffffffff mask see addr_match() */
1047 	memcpy(&filter.saddr, &src, sizeof(src));
1048 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1049 				&filter, sizeof(filter)))
1050 		return -1;
1051 
1052 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1053 		pr_err("send()");
1054 		return -1;
1055 	}
1056 
1057 	while (1) {
1058 		if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1059 			pr_err("recv()");
1060 			return -1;
1061 		}
1062 		if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1063 			printk("NLMSG_ERROR: %d: %s",
1064 				answer.error, strerror(-answer.error));
1065 			return -1;
1066 		} else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1067 			if (found)
1068 				return 0;
1069 			printk("didn't find allocated xfrm state in dump");
1070 			return -1;
1071 		} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1072 			if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1073 				found = true;
1074 		}
1075 	}
1076 }
1077 
/*
 * Install the SA pair for @src <-> @dst. Both directions use the same SPI,
 * derived from @src by gen_spi(). Then verify that both SAs appear in an
 * XFRM_MSG_GETSA dump. @tunsrc and @tundst are accepted but unused here.
 * *@seq is advanced once per request. Returns 0 or -1.
 */
static int xfrm_set(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst,
		struct xfrm_desc *desc)
{
	const uint32_t spi = gen_spi(src);
	int failed;

	if (xfrm_state_add(xfrm_sock, (*seq)++, spi, src, dst, desc) ||
	    xfrm_state_add(xfrm_sock, (*seq)++, spi, dst, src, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	/* Check dumps for XFRM_MSG_GETSA in both directions. */
	failed = xfrm_state_check(xfrm_sock, (*seq)++, spi, src, dst, desc);
	failed |= xfrm_state_check(xfrm_sock, (*seq)++, spi, dst, src, desc);
	if (!failed)
		return 0;

	printk("Failed to check xfrm state");
	return -1;
}
1107 
1108 static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1109 		struct in_addr src, struct in_addr dst, uint8_t dir,
1110 		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1111 {
1112 	struct {
1113 		struct nlmsghdr			nh;
1114 		struct xfrm_userpolicy_info	info;
1115 		char				attrbuf[MAX_PAYLOAD];
1116 	} req;
1117 	struct xfrm_user_tmpl tmpl;
1118 
1119 	memset(&req, 0, sizeof(req));
1120 	memset(&tmpl, 0, sizeof(tmpl));
1121 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
1122 	req.nh.nlmsg_type	= XFRM_MSG_NEWPOLICY;
1123 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1124 	req.nh.nlmsg_seq	= seq;
1125 
1126 	/* Fill selector. */
1127 	memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1128 	memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1129 	req.info.sel.family		= AF_INET;
1130 	req.info.sel.prefixlen_d	= PREFIX_LEN;
1131 	req.info.sel.prefixlen_s	= PREFIX_LEN;
1132 
1133 	/* Fill lifteme_cfg */
1134 	req.info.lft.soft_byte_limit	= XFRM_INF;
1135 	req.info.lft.hard_byte_limit	= XFRM_INF;
1136 	req.info.lft.soft_packet_limit	= XFRM_INF;
1137 	req.info.lft.hard_packet_limit	= XFRM_INF;
1138 
1139 	req.info.dir = dir;
1140 
1141 	/* Fill tmpl */
1142 	memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1143 	/* Note: zero-spi cannot be deleted */
1144 	tmpl.id.spi = spi;
1145 	tmpl.id.proto	= proto;
1146 	tmpl.family	= AF_INET;
1147 	memcpy(&tmpl.saddr, &src, sizeof(src));
1148 	tmpl.mode	= XFRM_MODE_TUNNEL;
1149 	tmpl.aalgos = (~(uint32_t)0);
1150 	tmpl.ealgos = (~(uint32_t)0);
1151 	tmpl.calgos = (~(uint32_t)0);
1152 
1153 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1154 		return -1;
1155 
1156 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1157 		pr_err("send()");
1158 		return -1;
1159 	}
1160 
1161 	return netlink_check_answer(xfrm_sock);
1162 }
1163 
/*
 * Install the outbound (src->dst) and then the inbound (dst->src) policy for
 * one endpoint, both keyed on the SPI derived from @src.
 * Returns 0 on success, -1 if either policy could not be added.
 */
static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int failed;

	failed = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
				 XFRM_POLICY_OUT, tunsrc, tundst, proto);
	if (!failed)
		failed = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src),
					 dst, src, XFRM_POLICY_IN,
					 tunsrc, tundst, proto);

	if (failed) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	return 0;
}
1182 
1183 static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1184 		struct in_addr src, struct in_addr dst, uint8_t dir,
1185 		struct in_addr tunsrc, struct in_addr tundst)
1186 {
1187 	struct {
1188 		struct nlmsghdr			nh;
1189 		struct xfrm_userpolicy_id	id;
1190 		char				attrbuf[MAX_PAYLOAD];
1191 	} req;
1192 
1193 	memset(&req, 0, sizeof(req));
1194 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1195 	req.nh.nlmsg_type	= XFRM_MSG_DELPOLICY;
1196 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1197 	req.nh.nlmsg_seq	= seq;
1198 
1199 	/* Fill id */
1200 	memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1201 	memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1202 	req.id.sel.family		= AF_INET;
1203 	req.id.sel.prefixlen_d		= PREFIX_LEN;
1204 	req.id.sel.prefixlen_s		= PREFIX_LEN;
1205 	req.id.dir = dir;
1206 
1207 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1208 		pr_err("send()");
1209 		return -1;
1210 	}
1211 
1212 	return netlink_check_answer(xfrm_sock);
1213 }
1214 
/*
 * Remove the outbound (src->dst) and inbound (dst->src) policies installed
 * by xfrm_prepare(). Stops at the first failure.
 * Returns 0 on success, -1 on failure.
 */
static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst)
{
	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst)) {
		/* This path deletes policies; the old message said "add". */
		printk("Failed to delete outbound xfrm policy");
		return -1;
	}

	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
				XFRM_POLICY_IN, tunsrc, tundst)) {
		printk("Failed to delete inbound xfrm policy");
		return -1;
	}

	return 0;
}
1233 
1234 static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1235 		struct in_addr src, struct in_addr dst, uint8_t proto)
1236 {
1237 	struct {
1238 		struct nlmsghdr		nh;
1239 		struct xfrm_usersa_id	id;
1240 		char			attrbuf[MAX_PAYLOAD];
1241 	} req;
1242 	xfrm_address_t saddr = {};
1243 
1244 	memset(&req, 0, sizeof(req));
1245 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1246 	req.nh.nlmsg_type	= XFRM_MSG_DELSA;
1247 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1248 	req.nh.nlmsg_seq	= seq;
1249 
1250 	memcpy(&req.id.daddr, &dst, sizeof(dst));
1251 	req.id.family		= AF_INET;
1252 	req.id.proto		= proto;
1253 	/* Note: zero-spi cannot be deleted */
1254 	req.id.spi = spi;
1255 
1256 	memcpy(&saddr, &src, sizeof(src));
1257 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1258 		return -1;
1259 
1260 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1261 		pr_err("send()");
1262 		return -1;
1263 	}
1264 
1265 	return netlink_check_answer(xfrm_sock);
1266 }
1267 
/*
 * Remove both states installed by xfrm_set(): first src->dst, then
 * dst->src, each identified by the SPI derived from @src and by @proto.
 * Returns 0 on success, -1 as soon as one removal fails.
 */
static int xfrm_delete(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int failed;

	failed = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), src, dst, proto);
	if (!failed)
		failed = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), dst, src, proto);

	if (failed) {
		printk("Failed to remove xfrm state");
		return -1;
	}

	return 0;
}
1284 
1285 static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1286 		uint32_t spi, uint8_t proto)
1287 {
1288 	struct {
1289 		struct nlmsghdr			nh;
1290 		struct xfrm_userspi_info	spi;
1291 	} req;
1292 	struct {
1293 		struct nlmsghdr			nh;
1294 		union {
1295 			struct xfrm_usersa_info	info;
1296 			int error;
1297 		};
1298 	} answer;
1299 
1300 	memset(&req, 0, sizeof(req));
1301 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.spi));
1302 	req.nh.nlmsg_type	= XFRM_MSG_ALLOCSPI;
1303 	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1304 	req.nh.nlmsg_seq	= (*seq)++;
1305 
1306 	req.spi.info.family	= AF_INET;
1307 	req.spi.min		= spi;
1308 	req.spi.max		= spi;
1309 	req.spi.info.id.proto	= proto;
1310 
1311 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1312 		pr_err("send()");
1313 		return KSFT_FAIL;
1314 	}
1315 
1316 	if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1317 		pr_err("recv()");
1318 		return KSFT_FAIL;
1319 	} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1320 		uint32_t new_spi = htonl(answer.info.id.spi);
1321 
1322 		if (new_spi != spi) {
1323 			printk("allocated spi is different from requested: %#x != %#x",
1324 					new_spi, spi);
1325 			return KSFT_FAIL;
1326 		}
1327 		return KSFT_PASS;
1328 	} else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1329 		printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1330 		return KSFT_FAIL;
1331 	}
1332 
1333 	printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1334 	return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1335 }
1336 
/*
 * Open a netlink socket of protocol @proto (passing @seq to netlink_sock()),
 * bind it to the multicast @groups bitmask, and sanity-check the bound
 * address returned by getsockname().
 * On success returns 0 and *@sock is open. On failure returns -1; if the
 * socket had already been opened it is closed and *@sock is reset to -1, so
 * callers cannot reuse a stale descriptor.
 */
static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
{
	struct sockaddr_nl snl = {};
	socklen_t addr_len;

	snl.nl_family = AF_NETLINK;
	snl.nl_groups = groups;

	if (netlink_sock(sock, seq, proto)) {
		printk("Failed to open xfrm netlink socket");
		return -1;
	}

	if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
		pr_err("bind()");
		goto out_close;
	}

	addr_len = sizeof(snl);
	if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
		pr_err("getsockname()");
		goto out_close;
	}
	if (addr_len != sizeof(snl)) {
		/* socklen_t is unsigned; print it as such. */
		printk("Wrong address length %u", (unsigned int)addr_len);
		goto out_close;
	}
	if (snl.nl_family != AF_NETLINK) {
		printk("Wrong address family %d", snl.nl_family);
		goto out_close;
	}
	return 0;

out_close:
	close(*sock);
	*sock = -1;
	return -1;
}
1375 
1376 static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
1377 {
1378 	struct {
1379 		struct nlmsghdr nh;
1380 		union {
1381 			struct xfrm_user_acquire acq;
1382 			int error;
1383 		};
1384 		char attrbuf[MAX_PAYLOAD];
1385 	} req;
1386 	struct xfrm_user_tmpl xfrm_tmpl = {};
1387 	int xfrm_listen = -1, ret = KSFT_FAIL;
1388 	uint32_t seq_listen;
1389 
1390 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
1391 		return KSFT_FAIL;
1392 
1393 	memset(&req, 0, sizeof(req));
1394 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.acq));
1395 	req.nh.nlmsg_type	= XFRM_MSG_ACQUIRE;
1396 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1397 	req.nh.nlmsg_seq	= (*seq)++;
1398 
1399 	req.acq.policy.sel.family	= AF_INET;
1400 	req.acq.aalgos	= 0xfeed;
1401 	req.acq.ealgos	= 0xbaad;
1402 	req.acq.calgos	= 0xbabe;
1403 
1404 	xfrm_tmpl.family = AF_INET;
1405 	xfrm_tmpl.id.proto = IPPROTO_ESP;
1406 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
1407 		goto out_close;
1408 
1409 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1410 		pr_err("send()");
1411 		goto out_close;
1412 	}
1413 
1414 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1415 		pr_err("recv()");
1416 		goto out_close;
1417 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1418 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1419 		goto out_close;
1420 	}
1421 
1422 	if (req.error) {
1423 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1424 		ret = req.error;
1425 		goto out_close;
1426 	}
1427 
1428 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1429 		pr_err("recv()");
1430 		goto out_close;
1431 	}
1432 
1433 	if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
1434 			|| req.acq.calgos != 0xbabe) {
1435 		printk("xfrm_user_acquire has changed  %x %x %x",
1436 				req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
1437 		goto out_close;
1438 	}
1439 
1440 	ret = KSFT_PASS;
1441 out_close:
1442 	close(xfrm_listen);
1443 	return ret;
1444 }
1445 
1446 static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
1447 		unsigned int nr, struct xfrm_desc *desc)
1448 {
1449 	struct {
1450 		struct nlmsghdr nh;
1451 		union {
1452 			struct xfrm_user_expire expire;
1453 			int error;
1454 		};
1455 	} req;
1456 	struct in_addr src, dst;
1457 	int xfrm_listen = -1, ret = KSFT_FAIL;
1458 	uint32_t seq_listen;
1459 
1460 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1461 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1462 
1463 	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
1464 		printk("Failed to add xfrm state");
1465 		return KSFT_FAIL;
1466 	}
1467 
1468 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1469 		return KSFT_FAIL;
1470 
1471 	memset(&req, 0, sizeof(req));
1472 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1473 	req.nh.nlmsg_type	= XFRM_MSG_EXPIRE;
1474 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1475 	req.nh.nlmsg_seq	= (*seq)++;
1476 
1477 	memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
1478 	req.expire.state.id.spi		= gen_spi(src);
1479 	req.expire.state.id.proto	= desc->proto;
1480 	req.expire.state.family		= AF_INET;
1481 	req.expire.hard			= 0xff;
1482 
1483 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1484 		pr_err("send()");
1485 		goto out_close;
1486 	}
1487 
1488 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1489 		pr_err("recv()");
1490 		goto out_close;
1491 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1492 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1493 		goto out_close;
1494 	}
1495 
1496 	if (req.error) {
1497 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1498 		ret = req.error;
1499 		goto out_close;
1500 	}
1501 
1502 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1503 		pr_err("recv()");
1504 		goto out_close;
1505 	}
1506 
1507 	if (req.expire.hard != 0x1) {
1508 		printk("expire.hard is not set: %x", req.expire.hard);
1509 		goto out_close;
1510 	}
1511 
1512 	ret = KSFT_PASS;
1513 out_close:
1514 	close(xfrm_listen);
1515 	return ret;
1516 }
1517 
1518 static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
1519 		unsigned int nr, struct xfrm_desc *desc)
1520 {
1521 	struct {
1522 		struct nlmsghdr nh;
1523 		union {
1524 			struct xfrm_user_polexpire expire;
1525 			int error;
1526 		};
1527 	} req;
1528 	struct in_addr src, dst, tunsrc, tundst;
1529 	int xfrm_listen = -1, ret = KSFT_FAIL;
1530 	uint32_t seq_listen;
1531 
1532 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1533 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1534 	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1535 	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1536 
1537 	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1538 				XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
1539 		printk("Failed to add xfrm policy");
1540 		return KSFT_FAIL;
1541 	}
1542 
1543 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1544 		return KSFT_FAIL;
1545 
1546 	memset(&req, 0, sizeof(req));
1547 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1548 	req.nh.nlmsg_type	= XFRM_MSG_POLEXPIRE;
1549 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1550 	req.nh.nlmsg_seq	= (*seq)++;
1551 
1552 	/* Fill selector. */
1553 	memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
1554 	memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
1555 	req.expire.pol.sel.family	= AF_INET;
1556 	req.expire.pol.sel.prefixlen_d	= PREFIX_LEN;
1557 	req.expire.pol.sel.prefixlen_s	= PREFIX_LEN;
1558 	req.expire.pol.dir		= XFRM_POLICY_OUT;
1559 	req.expire.hard			= 0xff;
1560 
1561 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1562 		pr_err("send()");
1563 		goto out_close;
1564 	}
1565 
1566 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1567 		pr_err("recv()");
1568 		goto out_close;
1569 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1570 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1571 		goto out_close;
1572 	}
1573 
1574 	if (req.error) {
1575 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1576 		ret = req.error;
1577 		goto out_close;
1578 	}
1579 
1580 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1581 		pr_err("recv()");
1582 		goto out_close;
1583 	}
1584 
1585 	if (req.expire.hard != 0x1) {
1586 		printk("expire.hard is not set: %x", req.expire.hard);
1587 		goto out_close;
1588 	}
1589 
1590 	ret = KSFT_PASS;
1591 out_close:
1592 	close(xfrm_listen);
1593 	return ret;
1594 }
1595 
1596 static int child_serv(int xfrm_sock, uint32_t *seq,
1597 		unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
1598 {
1599 	struct in_addr src, dst, tunsrc, tundst;
1600 	struct test_desc msg;
1601 	int ret = KSFT_FAIL;
1602 
1603 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1604 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1605 	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1606 	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1607 
1608 	/* UDP pinging without xfrm */
1609 	if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
1610 		printk("ping failed before setting xfrm");
1611 		return KSFT_FAIL;
1612 	}
1613 
1614 	memset(&msg, 0, sizeof(msg));
1615 	msg.type = MSG_XFRM_PREPARE;
1616 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1617 	write_msg(cmd_fd, &msg, 1);
1618 
1619 	if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1620 		printk("failed to prepare xfrm");
1621 		goto cleanup;
1622 	}
1623 
1624 	memset(&msg, 0, sizeof(msg));
1625 	msg.type = MSG_XFRM_ADD;
1626 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1627 	write_msg(cmd_fd, &msg, 1);
1628 	if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1629 		printk("failed to set xfrm");
1630 		goto delete;
1631 	}
1632 
1633 	/* UDP pinging with xfrm tunnel */
1634 	if (do_ping(cmd_fd, buf, page_size, tunsrc,
1635 				true, 0, 0, udp_ping_send)) {
1636 		printk("ping failed for xfrm");
1637 		goto delete;
1638 	}
1639 
1640 	ret = KSFT_PASS;
1641 delete:
1642 	/* xfrm delete */
1643 	memset(&msg, 0, sizeof(msg));
1644 	msg.type = MSG_XFRM_DEL;
1645 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1646 	write_msg(cmd_fd, &msg, 1);
1647 
1648 	if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1649 		printk("failed ping to remove xfrm");
1650 		ret = KSFT_FAIL;
1651 	}
1652 
1653 cleanup:
1654 	memset(&msg, 0, sizeof(msg));
1655 	msg.type = MSG_XFRM_CLEANUP;
1656 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1657 	write_msg(cmd_fd, &msg, 1);
1658 	if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1659 		printk("failed ping to cleanup xfrm");
1660 		ret = KSFT_FAIL;
1661 	}
1662 	return ret;
1663 }
1664 
1665 static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
1666 {
1667 	struct xfrm_desc desc;
1668 	struct test_desc msg;
1669 	int xfrm_sock = -1;
1670 	uint32_t seq;
1671 
1672 	if (switch_ns(nsfd_childa))
1673 		exit(KSFT_FAIL);
1674 
1675 	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1676 		printk("Failed to open xfrm netlink socket");
1677 		exit(KSFT_FAIL);
1678 	}
1679 
1680 	/* Check that seq sock is ready, just for sure. */
1681 	memset(&msg, 0, sizeof(msg));
1682 	msg.type = MSG_ACK;
1683 	write_msg(cmd_fd, &msg, 1);
1684 	read_msg(cmd_fd, &msg, 1);
1685 	if (msg.type != MSG_ACK) {
1686 		printk("Ack failed");
1687 		exit(KSFT_FAIL);
1688 	}
1689 
1690 	for (;;) {
1691 		ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
1692 		int ret;
1693 
1694 		if (received == 0) /* EOF */
1695 			break;
1696 
1697 		if (received != sizeof(desc)) {
1698 			pr_err("read() returned %zd", received);
1699 			exit(KSFT_FAIL);
1700 		}
1701 
1702 		switch (desc.type) {
1703 		case CREATE_TUNNEL:
1704 			ret = child_serv(xfrm_sock, &seq, nr,
1705 					 cmd_fd, buf, &desc);
1706 			break;
1707 		case ALLOCATE_SPI:
1708 			ret = xfrm_state_allocspi(xfrm_sock, &seq,
1709 						  -1, desc.proto);
1710 			break;
1711 		case MONITOR_ACQUIRE:
1712 			ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
1713 			break;
1714 		case EXPIRE_STATE:
1715 			ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
1716 			break;
1717 		case EXPIRE_POLICY:
1718 			ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
1719 			break;
1720 		default:
1721 			printk("Unknown desc type %d", desc.type);
1722 			exit(KSFT_FAIL);
1723 		}
1724 		write_test_result(ret, &desc);
1725 	}
1726 
1727 	close(xfrm_sock);
1728 
1729 	msg.type = MSG_EXIT;
1730 	write_msg(cmd_fd, &msg, 1);
1731 	exit(KSFT_PASS);
1732 }
1733 
1734 static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
1735 		struct test_desc *msg, int xfrm_sock, uint32_t *seq)
1736 {
1737 	struct in_addr src, dst, tunsrc, tundst;
1738 	bool tun_reply;
1739 	struct xfrm_desc *desc = &msg->body.xfrm_desc;
1740 
1741 	src = inet_makeaddr(INADDR_B, grchild_ip(nr));
1742 	dst = inet_makeaddr(INADDR_B, child_ip(nr));
1743 	tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
1744 	tundst = inet_makeaddr(INADDR_A, child_ip(nr));
1745 
1746 	switch (msg->type) {
1747 	case MSG_EXIT:
1748 		exit(KSFT_PASS);
1749 	case MSG_ACK:
1750 		write_msg(cmd_fd, msg, 1);
1751 		break;
1752 	case MSG_PING:
1753 		tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
1754 		/* UDP pinging without xfrm */
1755 		if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
1756 				false, msg->body.ping.port,
1757 				msg->body.ping.reply_ip, udp_ping_reply)) {
1758 			printk("ping failed before setting xfrm");
1759 		}
1760 		break;
1761 	case MSG_XFRM_PREPARE:
1762 		if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
1763 					desc->proto)) {
1764 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1765 			printk("failed to prepare xfrm");
1766 		}
1767 		break;
1768 	case MSG_XFRM_ADD:
1769 		if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1770 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1771 			printk("failed to set xfrm");
1772 		}
1773 		break;
1774 	case MSG_XFRM_DEL:
1775 		if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
1776 					desc->proto)) {
1777 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1778 			printk("failed to remove xfrm");
1779 		}
1780 		break;
1781 	case MSG_XFRM_CLEANUP:
1782 		if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1783 			printk("failed to cleanup xfrm");
1784 		}
1785 		break;
1786 	default:
1787 		printk("got unknown msg type %d", msg->type);
1788 	}
1789 }
1790 
1791 static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1792 {
1793 	struct test_desc msg;
1794 	int xfrm_sock = -1;
1795 	uint32_t seq;
1796 
1797 	if (switch_ns(nsfd_childb))
1798 		exit(KSFT_FAIL);
1799 
1800 	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1801 		printk("Failed to open xfrm netlink socket");
1802 		exit(KSFT_FAIL);
1803 	}
1804 
1805 	do {
1806 		read_msg(cmd_fd, &msg, 1);
1807 		grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1808 	} while (1);
1809 
1810 	close(xfrm_sock);
1811 	exit(KSFT_FAIL);
1812 }
1813 
/*
 * Configure veth @veth for test slot @nr in both child namespaces, then fork
 * the worker processes:
 *   - the calling process switches back to the parent namespace and returns
 *     switch_ns()'s result;
 *   - the first child closes the write end of @test_desc_fd, maps one
 *     shared page of random payload, creates a SOCK_SEQPACKET command pair
 *     and forks again. The first child runs child_f(); the new process runs
 *     grand_child_f(). Both exit() rather than return on their normal paths.
 * Returns -1 on a setup failure.
 * NOTE(review): the error paths after the first fork() return -1 from inside
 * a forked process, so the caller's error handling then also runs in that
 * child. Confirm this is intended.
 */
static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
{
	int cmd_sock[2];
	void *data_map;
	pid_t child;

	/* Per-namespace veth setup; presumably (local, peer) host ids -- see init_child(). */
	if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
		return -1;

	if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
		return -1;

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		/* in parent - selftest */
		return switch_ns(nsfd_parent);
	}

	/* The child only reads test descriptors. */
	if (close(test_desc_fd[1])) {
		pr_err("close()");
		return -1;
	}

	/* child */
	/* MAP_SHARED: the page stays shared with the grandchild after the next fork(). */
	data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
			MAP_SHARED | MAP_ANONYMOUS, -1, 0);
	if (data_map == MAP_FAILED) {
		pr_err("mmap()");
		return -1;
	}

	/* Random contents; passed on as the ping buffer (see child_serv()). */
	randomize_buffer(data_map, page_size);

	/* Command channel: child uses cmd_sock[1], grandchild cmd_sock[0]. */
	if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
		pr_err("socketpair()");
		return -1;
	}

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		if (close(cmd_sock[0])) {
			pr_err("close()");
			return -1;
		}
		return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
	}
	if (close(cmd_sock[1])) {
		pr_err("close()");
		return -1;
	}
	return grand_child_f(nr, cmd_sock[0], data_map);
}
1872 
1873 static void exit_usage(char **argv)
1874 {
1875 	printk("Usage: %s [nr_process]", argv[0]);
1876 	exit(KSFT_FAIL);
1877 }
1878 
1879 static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
1880 {
1881 	ssize_t ret;
1882 
1883 	ret = write(test_desc_fd, desc, sizeof(*desc));
1884 
1885 	if (ret == sizeof(*desc))
1886 		return 0;
1887 
1888 	pr_err("Writing test's desc failed %ld", ret);
1889 
1890 	return -1;
1891 }
1892 
1893 static int write_desc(int proto, int test_desc_fd,
1894 		char *a, char *e, char *c, char *ae)
1895 {
1896 	struct xfrm_desc desc = {};
1897 
1898 	desc.type = CREATE_TUNNEL;
1899 	desc.proto = proto;
1900 
1901 	if (a)
1902 		strncpy(desc.a_algo, a, ALGO_LEN - 1);
1903 	if (e)
1904 		strncpy(desc.e_algo, e, ALGO_LEN - 1);
1905 	if (c)
1906 		strncpy(desc.c_algo, c, ALGO_LEN - 1);
1907 	if (ae)
1908 		strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
1909 
1910 	return __write_desc(test_desc_fd, &desc);
1911 }
1912 
/* Protocols exercised by the CREATE_TUNNEL tests, in test-plan order. */
int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
/* Authentication algorithms: used alone for AH, and paired with e_list for ESP. */
char *ah_list[] = {
	"digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
	"hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
	"xcbc(aes)", "cmac(aes)"
};
/* Compression algorithms for IPPROTO_COMP. */
char *comp_list[] = {
	"deflate",
#if 0
	/* No compression backend implementation */
	"lzs", "lzjh"
#endif
};
/* Encryption algorithms for ESP (each combined with every ah_list entry). */
char *e_list[] = {
	"ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
	"cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
	"cbc(twofish)", "rfc3686(ctr(aes))"
};
/*
 * AEAD algorithms for ESP. All entries are compiled out, so this is an empty
 * array (a GNU C extension) and ARRAY_SIZE(ae_list) is 0.
 */
char *ae_list[] = {
#if 0
	/* AEAD tests not implemented */
	"rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
	"rfc7539esp(chacha20,poly1305)"
#endif
};

/*
 * Number of CREATE_TUNNEL tests write_proto_plan() queues:
 *   AH   - one per ah_list entry;
 *   COMP - one per comp_list entry;
 *   ESP  - one per (ah_list x e_list) pair, plus one per ae_list entry.
 * Must stay in sync with write_proto_plan().
 */
const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
				+ (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
				+ ARRAY_SIZE(ae_list);
1942 
1943 static int write_proto_plan(int fd, int proto)
1944 {
1945 	unsigned int i;
1946 
1947 	switch (proto) {
1948 	case IPPROTO_AH:
1949 		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
1950 			if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
1951 				return -1;
1952 		}
1953 		break;
1954 	case IPPROTO_COMP:
1955 		for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
1956 			if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
1957 				return -1;
1958 		}
1959 		break;
1960 	case IPPROTO_ESP:
1961 		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
1962 			int j;
1963 
1964 			for (j = 0; j < ARRAY_SIZE(e_list); j++) {
1965 				if (write_desc(proto, fd, ah_list[i],
1966 							e_list[j], 0, 0))
1967 					return -1;
1968 			}
1969 		}
1970 		for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
1971 			if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
1972 				return -1;
1973 		}
1974 		break;
1975 	default:
1976 		printk("BUG: Specified unknown proto %d", proto);
1977 		return -1;
1978 	}
1979 
1980 	return 0;
1981 }
1982 
1983 /*
1984  * Some structures in xfrm uapi header differ in size between
1985  * 64-bit and 32-bit ABI:
1986  *
1987  *             32-bit UABI               |            64-bit UABI
1988  *  -------------------------------------|-------------------------------------
1989  *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
1990  *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
1991  *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
1992  *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
1993  *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
1994  *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
1995  *
1996  * Check the affected by the UABI difference structures.
1997  */
1998 const unsigned int compat_plan = 4;
1999 static int write_compat_struct_tests(int test_desc_fd)
2000 {
2001 	struct xfrm_desc desc = {};
2002 
2003 	desc.type = ALLOCATE_SPI;
2004 	desc.proto = IPPROTO_AH;
2005 	strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2006 
2007 	if (__write_desc(test_desc_fd, &desc))
2008 		return -1;
2009 
2010 	desc.type = MONITOR_ACQUIRE;
2011 	if (__write_desc(test_desc_fd, &desc))
2012 		return -1;
2013 
2014 	desc.type = EXPIRE_STATE;
2015 	if (__write_desc(test_desc_fd, &desc))
2016 		return -1;
2017 
2018 	desc.type = EXPIRE_POLICY;
2019 	if (__write_desc(test_desc_fd, &desc))
2020 		return -1;
2021 
2022 	return 0;
2023 }
2024 
2025 static int write_test_plan(int test_desc_fd)
2026 {
2027 	unsigned int i;
2028 	pid_t child;
2029 
2030 	child = fork();
2031 	if (child < 0) {
2032 		pr_err("fork()");
2033 		return -1;
2034 	}
2035 	if (child) {
2036 		if (close(test_desc_fd))
2037 			printk("close(): %m");
2038 		return 0;
2039 	}
2040 
2041 	if (write_compat_struct_tests(test_desc_fd))
2042 		exit(KSFT_FAIL);
2043 
2044 	for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2045 		if (write_proto_plan(test_desc_fd, proto_list[i]))
2046 			exit(KSFT_FAIL);
2047 	}
2048 
2049 	exit(KSFT_PASS);
2050 }
2051 
2052 static int children_cleanup(void)
2053 {
2054 	unsigned ret = KSFT_PASS;
2055 
2056 	while (1) {
2057 		int status;
2058 		pid_t p = wait(&status);
2059 
2060 		if ((p < 0) && errno == ECHILD)
2061 			break;
2062 
2063 		if (p < 0) {
2064 			pr_err("wait()");
2065 			return KSFT_FAIL;
2066 		}
2067 
2068 		if (!WIFEXITED(status)) {
2069 			ret = KSFT_FAIL;
2070 			continue;
2071 		}
2072 
2073 		if (WEXITSTATUS(status) == KSFT_FAIL)
2074 			ret = KSFT_FAIL;
2075 	}
2076 
2077 	return ret;
2078 }
2079 
2080 typedef void (*print_res)(const char *, ...);
2081 
2082 static int check_results(void)
2083 {
2084 	struct test_result tr = {};
2085 	struct xfrm_desc *d = &tr.desc;
2086 	int ret = KSFT_PASS;
2087 
2088 	while (1) {
2089 		ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2090 		print_res result;
2091 
2092 		if (received == 0) /* EOF */
2093 			break;
2094 
2095 		if (received != sizeof(tr)) {
2096 			pr_err("read() returned %zd", received);
2097 			return KSFT_FAIL;
2098 		}
2099 
2100 		switch (tr.res) {
2101 		case KSFT_PASS:
2102 			result = ksft_test_result_pass;
2103 			break;
2104 		case KSFT_FAIL:
2105 		default:
2106 			result = ksft_test_result_fail;
2107 			ret = KSFT_FAIL;
2108 		}
2109 
2110 		result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2111 		       desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2112 		       d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2113 	}
2114 
2115 	return ret;
2116 }
2117 
2118 int main(int argc, char **argv)
2119 {
2120 	unsigned int nr_process = 1;
2121 	int route_sock = -1, ret = KSFT_SKIP;
2122 	int test_desc_fd[2];
2123 	uint32_t route_seq;
2124 	unsigned int i;
2125 
2126 	if (argc > 2)
2127 		exit_usage(argv);
2128 
2129 	if (argc > 1) {
2130 		char *endptr;
2131 
2132 		errno = 0;
2133 		nr_process = strtol(argv[1], &endptr, 10);
2134 		if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
2135 				|| (errno != 0 && nr_process == 0)
2136 				|| (endptr == argv[1]) || (*endptr != '\0')) {
2137 			printk("Failed to parse [nr_process]");
2138 			exit_usage(argv);
2139 		}
2140 
2141 		if (nr_process > MAX_PROCESSES || !nr_process) {
2142 			printk("nr_process should be between [1; %u]",
2143 					MAX_PROCESSES);
2144 			exit_usage(argv);
2145 		}
2146 	}
2147 
2148 	srand(time(NULL));
2149 	page_size = sysconf(_SC_PAGESIZE);
2150 	if (page_size < 1)
2151 		ksft_exit_skip("sysconf(): %m\n");
2152 
2153 	if (pipe2(test_desc_fd, O_DIRECT) < 0)
2154 		ksft_exit_skip("pipe(): %m\n");
2155 
2156 	if (pipe2(results_fd, O_DIRECT) < 0)
2157 		ksft_exit_skip("pipe(): %m\n");
2158 
2159 	if (init_namespaces())
2160 		ksft_exit_skip("Failed to create namespaces\n");
2161 
2162 	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
2163 		ksft_exit_skip("Failed to open netlink route socket\n");
2164 
2165 	for (i = 0; i < nr_process; i++) {
2166 		char veth[VETH_LEN];
2167 
2168 		snprintf(veth, VETH_LEN, VETH_FMT, i);
2169 
2170 		if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
2171 			close(route_sock);
2172 			ksft_exit_fail_msg("Failed to create veth device");
2173 		}
2174 
2175 		if (start_child(i, veth, test_desc_fd)) {
2176 			close(route_sock);
2177 			ksft_exit_fail_msg("Child %u failed to start", i);
2178 		}
2179 	}
2180 
2181 	if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
2182 		ksft_exit_fail_msg("close(): %m");
2183 
2184 	ksft_set_plan(proto_plan + compat_plan);
2185 
2186 	if (write_test_plan(test_desc_fd[1]))
2187 		ksft_exit_fail_msg("Failed to write test plan to pipe");
2188 
2189 	ret = check_results();
2190 
2191 	if (children_cleanup() == KSFT_FAIL)
2192 		exit(KSFT_FAIL);
2193 
2194 	exit(ret);
2195 }
2196