xref: /linux/tools/testing/selftests/net/ipsec.c (revision eb01fe7abbe2d0b38824d2a93fdb4cc3eaf2ccc1)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * ipsec.c - Check xfrm on veth inside a net-ns.
4  * Copyright (c) 2018 Dmitry Safonov
5  */
6 
7 #define _GNU_SOURCE
8 
9 #include <arpa/inet.h>
10 #include <asm/types.h>
11 #include <errno.h>
12 #include <fcntl.h>
13 #include <limits.h>
14 #include <linux/limits.h>
15 #include <linux/netlink.h>
16 #include <linux/random.h>
17 #include <linux/rtnetlink.h>
18 #include <linux/veth.h>
19 #include <linux/xfrm.h>
20 #include <netinet/in.h>
21 #include <net/if.h>
22 #include <sched.h>
23 #include <stdbool.h>
24 #include <stdint.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <sys/mman.h>
29 #include <sys/socket.h>
30 #include <sys/stat.h>
31 #include <sys/syscall.h>
32 #include <sys/types.h>
33 #include <sys/wait.h>
34 #include <time.h>
35 #include <unistd.h>
36 
37 #include "../kselftest.h"
38 
/* Log with "<pid>[<source line>] " prefix through the kselftest logger. */
#define printk(fmt, ...)						\
	ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)

/* Like printk(), but appends ": <strerror(errno)>" (glibc %m). */
#define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)

/* Compile-time assertion: negative array size when @condition is true. */
#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))

#define IPV4_STR_SZ	16	/* xxx.xxx.xxx.xxx is longest + \0 */
#define MAX_PAYLOAD	2048
#define XFRM_ALGO_KEY_BUF_SIZE	512
#define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
#define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
#define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */

/* /30 mask for one veth connection */
#define PREFIX_LEN	30
/* Host numbers of the two ends of the nr-th /30; argument parenthesized. */
#define child_ip(nr)	(4*(nr) + 1)
#define grchild_ip(nr)	(4*(nr) + 2)

#define VETH_FMT	"ktst-%d"
#define VETH_LEN	12

/* Number of entries in xfrm_key_entries[]. */
#define XFRM_ALGO_NR_KEYS 29
62 
/* Net-ns fds filled in by init_namespaces(); -1 until then. */
static int nsfd_childb	= -1;
static int nsfd_childa	= -1;
static int nsfd_parent	= -1;
static long page_size;

/*
 * Children cannot bump kselftest's counters: ksft_cnt is static in
 * kselftest and therefore not shared across fork().  Each child writes
 * its result into this pipe instead and the parent does the counting.
 */
static int results_fd[2];

/* Pause between two pings, handed to nanosleep() as tv_nsec. */
const unsigned int ping_delay_nsec	= 50U * 1000U * 1000U;
/* Receive timeout of a ping socket, in microseconds (SO_RCVTIMEO). */
const unsigned int ping_timeout		= 300;
/* Pings attempted per run, and how many must succeed to pass. */
const unsigned int ping_count		= 100;
const unsigned int ping_success		= 80;
79 
/* Key length (in bits, as stored into alg_key_len) for an xfrm algorithm. */
struct xfrm_key_entry {
	char algo_name[35];
	int key_len;
};

struct xfrm_key_entry xfrm_key_entries[] = {
	{ .algo_name = "digest_null",			.key_len = 0 },
	{ .algo_name = "ecb(cipher_null)",		.key_len = 0 },
	{ .algo_name = "cbc(des)",			.key_len = 64 },
	{ .algo_name = "hmac(md5)",			.key_len = 128 },
	{ .algo_name = "cmac(aes)",			.key_len = 128 },
	{ .algo_name = "xcbc(aes)",			.key_len = 128 },
	{ .algo_name = "cbc(cast5)",			.key_len = 128 },
	{ .algo_name = "cbc(serpent)",			.key_len = 128 },
	{ .algo_name = "hmac(sha1)",			.key_len = 160 },
	{ .algo_name = "hmac(rmd160)",			.key_len = 160 },
	{ .algo_name = "cbc(des3_ede)",			.key_len = 192 },
	{ .algo_name = "hmac(sha256)",			.key_len = 256 },
	{ .algo_name = "cbc(aes)",			.key_len = 256 },
	{ .algo_name = "cbc(camellia)",			.key_len = 256 },
	{ .algo_name = "cbc(twofish)",			.key_len = 256 },
	{ .algo_name = "rfc3686(ctr(aes))",		.key_len = 288 },
	{ .algo_name = "hmac(sha384)",			.key_len = 384 },
	{ .algo_name = "cbc(blowfish)",			.key_len = 448 },
	{ .algo_name = "hmac(sha512)",			.key_len = 512 },
	{ .algo_name = "rfc4106(gcm(aes))-128",		.key_len = 160 },
	{ .algo_name = "rfc4543(gcm(aes))-128",		.key_len = 160 },
	{ .algo_name = "rfc4309(ccm(aes))-128",		.key_len = 152 },
	{ .algo_name = "rfc4106(gcm(aes))-192",		.key_len = 224 },
	{ .algo_name = "rfc4543(gcm(aes))-192",		.key_len = 224 },
	{ .algo_name = "rfc4309(ccm(aes))-192",		.key_len = 216 },
	{ .algo_name = "rfc4106(gcm(aes))-256",		.key_len = 288 },
	{ .algo_name = "rfc4543(gcm(aes))-256",		.key_len = 288 },
	{ .algo_name = "rfc4309(ccm(aes))-256",		.key_len = 280 },
	{ .algo_name = "rfc7539(chacha20,poly1305)-128",	.key_len = 0 },
};
116 
/*
 * Fill @buflen bytes at @buf with pseudo-random data from rand().
 *
 * The buffer is filled one int-sized chunk at a time; a trailing partial
 * chunk receives the leading bytes of one extra rand() value.  Chunks are
 * copied with memcpy() so @buf may have any alignment and any declared
 * type (callers pass e.g. a char key buffer or a uint32_t), avoiding the
 * misaligned / strict-aliasing int stores of a cast int pointer.
 */
static void randomize_buffer(void *buf, size_t buflen)
{
	unsigned char *p = buf;
	size_t words = buflen / sizeof(int);
	size_t leftover = buflen % sizeof(int);

	while (words--) {
		int r = rand();

		memcpy(p, &r, sizeof(r));
		p += sizeof(r);
	}

	if (leftover) {
		int r = rand();

		memcpy(p, &r, leftover);
	}
}
137 
138 static int unshare_open(void)
139 {
140 	const char *netns_path = "/proc/self/ns/net";
141 	int fd;
142 
143 	if (unshare(CLONE_NEWNET) != 0) {
144 		pr_err("unshare()");
145 		return -1;
146 	}
147 
148 	fd = open(netns_path, O_RDONLY);
149 	if (fd <= 0) {
150 		pr_err("open(%s)", netns_path);
151 		return -1;
152 	}
153 
154 	return fd;
155 }
156 
157 static int switch_ns(int fd)
158 {
159 	if (setns(fd, CLONE_NEWNET)) {
160 		pr_err("setns()");
161 		return -1;
162 	}
163 	return 0;
164 }
165 
166 /*
167  * Running the test inside a new parent net namespace to bother less
168  * about cleanup on error-path.
169  */
170 static int init_namespaces(void)
171 {
172 	nsfd_parent = unshare_open();
173 	if (nsfd_parent <= 0)
174 		return -1;
175 
176 	nsfd_childa = unshare_open();
177 	if (nsfd_childa <= 0)
178 		return -1;
179 
180 	if (switch_ns(nsfd_parent))
181 		return -1;
182 
183 	nsfd_childb = unshare_open();
184 	if (nsfd_childb <= 0)
185 		return -1;
186 
187 	if (switch_ns(nsfd_parent))
188 		return -1;
189 	return 0;
190 }
191 
/*
 * Lazily open a netlink socket of protocol @proto.
 *
 * If *@sock already holds an open socket (> 0), only advance the sequence
 * number *@seq_nr.  Otherwise open a new SOCK_RAW | SOCK_CLOEXEC netlink
 * socket into *@sock and seed *@seq_nr with a random value.
 * Returns 0 on success, -1 on failure.
 */
static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
{
	if (*sock > 0) {
		/* Bump the caller's sequence number, not the local pointer. */
		(*seq_nr)++;
		return 0;
	}

	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
	if (*sock <= 0) {
		pr_err("socket(AF_NETLINK)");
		return -1;
	}

	randomize_buffer(seq_nr, sizeof(*seq_nr));

	return 0;
}
209 
/* Address at which the next rtattr of @nh starts: RTA_ALIGN(nlmsg_len) bytes in. */
static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
{
	char *base = (char *)nh;
	size_t offset = RTA_ALIGN(nh->nlmsg_len);

	return (struct rtattr *)(base + offset);
}
214 
/*
 * Append one attribute (@rta_type, @size bytes copied from @payload) to the
 * netlink message @nh, whose buffer is @req_sz bytes in total, and grow
 * nh->nlmsg_len to cover it.  @payload may be NULL when @size is 0 (used to
 * open a nested attribute).  Returns 0 on success, -1 if it does not fit.
 */
static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */
	struct rtattr *attr = rtattr_hdr(nh);
	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);

	if (req_sz < nl_size) {
		printk("req buf is too small: %zu < %zu", req_sz, nl_size);
		return -1;
	}
	nh->nlmsg_len = nl_size;

	attr->rta_len = RTA_LENGTH(size);
	attr->rta_type = rta_type;
	/* memcpy() with a NULL source is undefined even for zero bytes. */
	if (size)
		memcpy(RTA_DATA(attr), payload, size);

	return 0;
}
234 
/*
 * Append an attribute carrying @size bytes of @payload and return a pointer
 * to it, so that it can later be closed as a nest with rtattr_end().
 * Returns NULL when the attribute does not fit into @req_sz bytes.
 */
static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/* Sample the position first: packing moves nlmsg_len past it. */
	struct rtattr *start = rtattr_hdr(nh);

	return rtattr_pack(nh, req_sz, rta_type, payload, size) ? NULL : start;
}
245 
/* Open an empty nested attribute of @rta_type; NULL if it doesn't fit. */
static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type)
{
	const void *no_payload = NULL;

	return _rtattr_begin(nh, req_sz, rta_type, no_payload, 0);
}
251 
/*
 * Close a nest opened at @attr: its rta_len becomes the distance from
 * @attr to the current end of message @nh.
 */
static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
{
	char *start = (char *)attr;
	char *end = (char *)nh + nh->nlmsg_len;

	attr->rta_len = end - start;
}
258 
/*
 * Pack the VETH_INFO_PEER nest describing the peer end of a veth pair:
 * an ifinfomsg header, the peer's IFLA_IFNAME and the netns fd @ns it is
 * placed into (IFLA_NET_NS_FD).  Returns 0 or -1 if @req_sz is too small.
 */
static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
		const char *peer, int ns)
{
	struct ifinfomsg peer_info;
	struct rtattr *nest;

	memset(&peer_info, 0, sizeof(peer_info));
	peer_info.ifi_change	= 0xFFFFFFFF;
	peer_info.ifi_family	= AF_UNSPEC;

	nest = _rtattr_begin(nh, req_sz, VETH_INFO_PEER, &peer_info, sizeof(peer_info));
	if (nest == NULL)
		return -1;

	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
		return -1;

	rtattr_end(nh, nest);
	return 0;
}
283 
/*
 * Read the reply to an NLM_F_ACK request from @sock.
 *
 * Returns 0 for a successful ack, the (negative) error code carried in an
 * NLMSG_ERROR reply, or -1 if recv() fails, the reply is too short to hold
 * a header plus error code, or it is not an NLMSG_ERROR message.
 */
static int netlink_check_answer(int sock)
{
	struct nlmsgerror {
		struct nlmsghdr hdr;
		int error;
		struct nlmsghdr orig_msg;
	} answer;
	ssize_t len;

	len = recv(sock, &answer, sizeof(answer), 0);
	if (len < 0) {
		pr_err("recv()");
		return -1;
	}
	/* Never look at fields the kernel did not actually fill in. */
	if ((size_t)len < sizeof(answer.hdr) + sizeof(answer.error)) {
		printk("short netlink answer: %zd bytes", len);
		return -1;
	}

	if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
		return -1;
	}
	if (answer.error) {
		printk("NLMSG_ERROR: %d: %s",
			answer.error, strerror(-answer.error));
		return answer.error;
	}

	return 0;
}
306 
307 static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
308 		const char *peerb, int ns_b)
309 {
310 	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
311 	struct {
312 		struct nlmsghdr		nh;
313 		struct ifinfomsg	info;
314 		char			attrbuf[MAX_PAYLOAD];
315 	} req;
316 	const char veth_type[] = "veth";
317 	struct rtattr *link_info, *info_data;
318 
319 	memset(&req, 0, sizeof(req));
320 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
321 	req.nh.nlmsg_type	= RTM_NEWLINK;
322 	req.nh.nlmsg_flags	= flags;
323 	req.nh.nlmsg_seq	= seq;
324 	req.info.ifi_family	= AF_UNSPEC;
325 	req.info.ifi_change	= 0xFFFFFFFF;
326 
327 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
328 		return -1;
329 
330 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
331 		return -1;
332 
333 	link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
334 	if (!link_info)
335 		return -1;
336 
337 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
338 		return -1;
339 
340 	info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
341 	if (!info_data)
342 		return -1;
343 
344 	if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
345 		return -1;
346 
347 	rtattr_end(&req.nh, info_data);
348 	rtattr_end(&req.nh, link_info);
349 
350 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
351 		pr_err("send()");
352 		return -1;
353 	}
354 	return netlink_check_answer(sock);
355 }
356 
357 static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
358 		struct in_addr addr, uint8_t prefix)
359 {
360 	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
361 	struct {
362 		struct nlmsghdr		nh;
363 		struct ifaddrmsg	info;
364 		char			attrbuf[MAX_PAYLOAD];
365 	} req;
366 
367 	memset(&req, 0, sizeof(req));
368 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
369 	req.nh.nlmsg_type	= RTM_NEWADDR;
370 	req.nh.nlmsg_flags	= flags;
371 	req.nh.nlmsg_seq	= seq;
372 	req.info.ifa_family	= AF_INET;
373 	req.info.ifa_prefixlen	= prefix;
374 	req.info.ifa_index	= if_nametoindex(intf);
375 
376 #ifdef DEBUG
377 	{
378 		char addr_str[IPV4_STR_SZ] = {};
379 
380 		strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
381 
382 		printk("ip addr set %s", addr_str);
383 	}
384 #endif
385 
386 	if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
387 		return -1;
388 
389 	if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
390 		return -1;
391 
392 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
393 		pr_err("send()");
394 		return -1;
395 	}
396 	return netlink_check_answer(sock);
397 }
398 
399 static int link_set_up(int sock, uint32_t seq, const char *intf)
400 {
401 	struct {
402 		struct nlmsghdr		nh;
403 		struct ifinfomsg	info;
404 		char			attrbuf[MAX_PAYLOAD];
405 	} req;
406 
407 	memset(&req, 0, sizeof(req));
408 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
409 	req.nh.nlmsg_type	= RTM_NEWLINK;
410 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
411 	req.nh.nlmsg_seq	= seq;
412 	req.info.ifi_family	= AF_UNSPEC;
413 	req.info.ifi_change	= 0xFFFFFFFF;
414 	req.info.ifi_index	= if_nametoindex(intf);
415 	req.info.ifi_flags	= IFF_UP;
416 	req.info.ifi_change	= IFF_UP;
417 
418 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
419 		pr_err("send()");
420 		return -1;
421 	}
422 	return netlink_check_answer(sock);
423 }
424 
425 static int ip4_route_set(int sock, uint32_t seq, const char *intf,
426 		struct in_addr src, struct in_addr dst)
427 {
428 	struct {
429 		struct nlmsghdr	nh;
430 		struct rtmsg	rt;
431 		char		attrbuf[MAX_PAYLOAD];
432 	} req;
433 	unsigned int index = if_nametoindex(intf);
434 
435 	memset(&req, 0, sizeof(req));
436 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.rt));
437 	req.nh.nlmsg_type	= RTM_NEWROUTE;
438 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
439 	req.nh.nlmsg_seq	= seq;
440 	req.rt.rtm_family	= AF_INET;
441 	req.rt.rtm_dst_len	= 32;
442 	req.rt.rtm_table	= RT_TABLE_MAIN;
443 	req.rt.rtm_protocol	= RTPROT_BOOT;
444 	req.rt.rtm_scope	= RT_SCOPE_LINK;
445 	req.rt.rtm_type		= RTN_UNICAST;
446 
447 	if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
448 		return -1;
449 
450 	if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
451 		return -1;
452 
453 	if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
454 		return -1;
455 
456 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
457 		pr_err("send()");
458 		return -1;
459 	}
460 
461 	return netlink_check_answer(sock);
462 }
463 
464 static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
465 		struct in_addr tunsrc, struct in_addr tundst)
466 {
467 	if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
468 			tunsrc, PREFIX_LEN)) {
469 		printk("Failed to set ipv4 addr");
470 		return -1;
471 	}
472 
473 	if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
474 		printk("Failed to set ipv4 route");
475 		return -1;
476 	}
477 
478 	return 0;
479 }
480 
481 static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
482 {
483 	struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
484 	struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
485 	struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
486 	int route_sock = -1, ret = -1;
487 	uint32_t route_seq;
488 
489 	if (switch_ns(nsfd))
490 		return -1;
491 
492 	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
493 		printk("Failed to open netlink route socket in child");
494 		return -1;
495 	}
496 
497 	if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
498 		printk("Failed to set ipv4 addr");
499 		goto err;
500 	}
501 
502 	if (link_set_up(route_sock, route_seq++, veth)) {
503 		printk("Failed to bring up %s", veth);
504 		goto err;
505 	}
506 
507 	if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
508 		printk("Failed to add tunnel route on %s", veth);
509 		goto err;
510 	}
511 	ret = 0;
512 
513 err:
514 	close(route_sock);
515 	return ret;
516 }
517 
#define ALGO_LEN	64

/* Kind of xfrm test a struct xfrm_desc describes; indexes desc_name[]. */
enum desc_type {
	CREATE_TUNNEL	= 0,
	ALLOCATE_SPI	= 1,
	MONITOR_ACQUIRE	= 2,
	EXPIRE_STATE	= 3,
	EXPIRE_POLICY	= 4,
	SPDINFO_ATTRS	= 5,
};

/* Printable name for each enum desc_type value, followed by "". */
const char *desc_name[] = {
	[CREATE_TUNNEL]		= "create tunnel",
	[ALLOCATE_SPI]		= "alloc spi",
	[MONITOR_ACQUIRE]	= "monitor acquire",
	[EXPIRE_STATE]		= "expire state",
	[EXPIRE_POLICY]		= "expire policy",
	[SPDINFO_ATTRS]		= "spdinfo attributes",
	"",
};

/* One xfrm test case: what to test, which IPsec proto and algorithms. */
struct xfrm_desc {
	enum desc_type	type;
	uint8_t		proto;
	char		a_algo[ALGO_LEN];	/* authentication */
	char		e_algo[ALGO_LEN];	/* encryption */
	char		c_algo[ALGO_LEN];	/* compression */
	char		ae_algo[ALGO_LEN];	/* AEAD */
	unsigned int	icv_len;
};

/* Commands exchanged over the test pipes. */
enum msg_type {
	MSG_ACK			= 0,
	MSG_EXIT		= 1,
	MSG_PING		= 2,
	MSG_XFRM_PREPARE	= 3,
	MSG_XFRM_ADD		= 4,
	MSG_XFRM_DEL		= 5,
	MSG_XFRM_CLEANUP	= 6,
};

/* A message on the command pipe; body is interpreted according to type. */
struct test_desc {
	enum msg_type type;
	union {
		struct {
			in_addr_t reply_ip;
			unsigned int port;
		} ping;
		struct xfrm_desc xfrm_desc;
	} body;
};

/* One test outcome sent back to the parent over results_fd. */
struct test_result {
	struct xfrm_desc desc;
	unsigned int res;
};
572 
573 static void write_test_result(unsigned int res, struct xfrm_desc *d)
574 {
575 	struct test_result tr = {};
576 	ssize_t ret;
577 
578 	tr.desc = *d;
579 	tr.res = res;
580 
581 	ret = write(results_fd[1], &tr, sizeof(tr));
582 	if (ret != sizeof(tr))
583 		pr_err("Failed to write the result in pipe %zd", ret);
584 }
585 
586 static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
587 {
588 	ssize_t bytes = write(fd, msg, sizeof(*msg));
589 
590 	/* Make sure that write/read is atomic to a pipe */
591 	BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
592 
593 	if (bytes < 0) {
594 		pr_err("write()");
595 		if (exit_of_fail)
596 			exit(KSFT_FAIL);
597 	}
598 	if (bytes != sizeof(*msg)) {
599 		pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
600 		if (exit_of_fail)
601 			exit(KSFT_FAIL);
602 	}
603 }
604 
605 static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
606 {
607 	ssize_t bytes = read(fd, msg, sizeof(*msg));
608 
609 	if (bytes < 0) {
610 		pr_err("read()");
611 		if (exit_of_fail)
612 			exit(KSFT_FAIL);
613 	}
614 	if (bytes != sizeof(*msg)) {
615 		pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
616 		if (exit_of_fail)
617 			exit(KSFT_FAIL);
618 	}
619 }
620 
/*
 * Create the two UDP sockets used for pinging:
 *  sock[0] - bound to @listen_ip on a kernel-chosen port (returned in
 *            *@server_port, host byte order), receive timeout @u_timeout
 *            microseconds;
 *  sock[1] - unbound, used for sending.
 * Returns 0 on success, -1 on failure (nothing left open).
 */
static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
		unsigned int *server_port, int sock[2])
{
	struct sockaddr_in server;
	/* Split so that tv_usec stays within [0, 999999] for long timeouts. */
	struct timeval t = {
		.tv_sec = u_timeout / 1000000,
		.tv_usec = u_timeout % 1000000,
	};
	socklen_t s_len = sizeof(server);

	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[0] < 0) {
		pr_err("socket()");
		return -1;
	}

	/* Zero everything, including sin_zero, before binding. */
	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= 0;	/* let the kernel pick a free port */
	server.sin_addr		= listen_ip;

	if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
		pr_err("bind()");
		goto err_close_server;
	}

	if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
		pr_err("getsockname()");
		goto err_close_server;
	}

	*server_port = ntohs(server.sin_port);

	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
		pr_err("setsockopt()");
		goto err_close_server;
	}

	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[1] < 0) {
		pr_err("socket()");
		goto err_close_server;
	}

	return 0;

err_close_server:
	close(sock[0]);
	return -1;
}
667 
/*
 * One ping from the initiating side: send @buf (@buf_len bytes) from
 * sock[1] to @dest_ip:@port, then wait on sock[0] for an identical echo.
 * Returns 0 if the echo arrived intact, -1 otherwise (a receive timeout,
 * EAGAIN, fails silently).
 */
static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* Byte buffer for the echo (was mistakenly an array of char *). */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	} else if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	} else if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	return 0;
}
704 
/*
 * One ping from the replying side: wait on sock[0] for a datagram equal to
 * @buf (@buf_len bytes), then echo @buf from sock[1] to @dest_ip:@port.
 * Returns 0 on success, -1 otherwise (a receive timeout, EAGAIN, fails
 * silently).
 */
static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* Byte buffer for the request (was mistakenly an array of char *). */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	}
	if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	}
	if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	return 0;
}
743 
744 typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
745 		char *buf, size_t buf_len);
746 static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
747 		bool init_side, int d_port, in_addr_t to, ping_f func)
748 {
749 	struct test_desc msg;
750 	unsigned int s_port, i, ping_succeeded = 0;
751 	int ping_sock[2];
752 	char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
753 
754 	if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
755 		printk("Failed to init ping");
756 		return -1;
757 	}
758 
759 	memset(&msg, 0, sizeof(msg));
760 	msg.type		= MSG_PING;
761 	msg.body.ping.port	= s_port;
762 	memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
763 
764 	write_msg(cmd_fd, &msg, 0);
765 	if (init_side) {
766 		/* The other end sends ip to ping */
767 		read_msg(cmd_fd, &msg, 0);
768 		if (msg.type != MSG_PING)
769 			return -1;
770 		to = msg.body.ping.reply_ip;
771 		d_port = msg.body.ping.port;
772 	}
773 
774 	for (i = 0; i < ping_count ; i++) {
775 		struct timespec sleep_time = {
776 			.tv_sec = 0,
777 			.tv_nsec = ping_delay_nsec,
778 		};
779 
780 		ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
781 		nanosleep(&sleep_time, 0);
782 	}
783 
784 	close(ping_sock[0]);
785 	close(ping_sock[1]);
786 
787 	strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
788 	strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
789 
790 	if (ping_succeeded < ping_success) {
791 		printk("ping (%s) %s->%s failed %u/%u times",
792 			init_side ? "send" : "reply", from_str, to_str,
793 			ping_count - ping_succeeded, ping_count);
794 		return -1;
795 	}
796 
797 #ifdef DEBUG
798 	printk("ping (%s) %s->%s succeeded %u/%u times",
799 		init_side ? "send" : "reply", from_str, to_str,
800 		ping_succeeded, ping_count);
801 #endif
802 
803 	return 0;
804 }
805 
806 static int xfrm_fill_key(char *name, char *buf,
807 		size_t buf_len, unsigned int *key_len)
808 {
809 	int i;
810 
811 	for (i = 0; i < XFRM_ALGO_NR_KEYS; i++) {
812 		if (strncmp(name, xfrm_key_entries[i].algo_name, ALGO_LEN) == 0)
813 			*key_len = xfrm_key_entries[i].key_len;
814 	}
815 
816 	if (*key_len > buf_len) {
817 		printk("Can't pack a key - too big for buffer");
818 		return -1;
819 	}
820 
821 	randomize_buffer(buf, *key_len);
822 
823 	return 0;
824 }
825 
826 static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
827 		struct xfrm_desc *desc)
828 {
829 	struct {
830 		union {
831 			struct xfrm_algo	alg;
832 			struct xfrm_algo_aead	aead;
833 			struct xfrm_algo_auth	auth;
834 		} u;
835 		char buf[XFRM_ALGO_KEY_BUF_SIZE];
836 	} alg = {};
837 	size_t alen, elen, clen, aelen;
838 	unsigned short type;
839 
840 	alen = strlen(desc->a_algo);
841 	elen = strlen(desc->e_algo);
842 	clen = strlen(desc->c_algo);
843 	aelen = strlen(desc->ae_algo);
844 
845 	/* Verify desc */
846 	switch (desc->proto) {
847 	case IPPROTO_AH:
848 		if (!alen || elen || clen || aelen) {
849 			printk("BUG: buggy ah desc");
850 			return -1;
851 		}
852 		strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
853 		if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
854 				sizeof(alg.buf), &alg.u.alg.alg_key_len))
855 			return -1;
856 		type = XFRMA_ALG_AUTH;
857 		break;
858 	case IPPROTO_COMP:
859 		if (!clen || elen || alen || aelen) {
860 			printk("BUG: buggy comp desc");
861 			return -1;
862 		}
863 		strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
864 		if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
865 				sizeof(alg.buf), &alg.u.alg.alg_key_len))
866 			return -1;
867 		type = XFRMA_ALG_COMP;
868 		break;
869 	case IPPROTO_ESP:
870 		if (!((alen && elen) ^ aelen) || clen) {
871 			printk("BUG: buggy esp desc");
872 			return -1;
873 		}
874 		if (aelen) {
875 			alg.u.aead.alg_icv_len = desc->icv_len;
876 			strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
877 			if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
878 						sizeof(alg.buf), &alg.u.aead.alg_key_len))
879 				return -1;
880 			type = XFRMA_ALG_AEAD;
881 		} else {
882 
883 			strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
884 			type = XFRMA_ALG_CRYPT;
885 			if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
886 						sizeof(alg.buf), &alg.u.alg.alg_key_len))
887 				return -1;
888 			if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
889 				return -1;
890 
891 			strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
892 			type = XFRMA_ALG_AUTH;
893 			if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
894 						sizeof(alg.buf), &alg.u.alg.alg_key_len))
895 				return -1;
896 		}
897 		break;
898 	default:
899 		printk("BUG: unknown proto in desc");
900 		return -1;
901 	}
902 
903 	if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
904 		return -1;
905 
906 	return 0;
907 }
908 
/* SPI for an SA from @src: its host part (inet_lnaof()) in network byte order. */
static inline uint32_t gen_spi(struct in_addr src)
{
	in_addr_t host_part = inet_lnaof(src);

	return htonl(host_part);
}
913 
914 static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
915 		struct in_addr src, struct in_addr dst,
916 		struct xfrm_desc *desc)
917 {
918 	struct {
919 		struct nlmsghdr		nh;
920 		struct xfrm_usersa_info	info;
921 		char			attrbuf[MAX_PAYLOAD];
922 	} req;
923 
924 	memset(&req, 0, sizeof(req));
925 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
926 	req.nh.nlmsg_type	= XFRM_MSG_NEWSA;
927 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
928 	req.nh.nlmsg_seq	= seq;
929 
930 	/* Fill selector. */
931 	memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
932 	memcpy(&req.info.sel.saddr, &src, sizeof(src));
933 	req.info.sel.family		= AF_INET;
934 	req.info.sel.prefixlen_d	= PREFIX_LEN;
935 	req.info.sel.prefixlen_s	= PREFIX_LEN;
936 
937 	/* Fill id */
938 	memcpy(&req.info.id.daddr, &dst, sizeof(dst));
939 	/* Note: zero-spi cannot be deleted */
940 	req.info.id.spi = spi;
941 	req.info.id.proto	= desc->proto;
942 
943 	memcpy(&req.info.saddr, &src, sizeof(src));
944 
945 	/* Fill lifteme_cfg */
946 	req.info.lft.soft_byte_limit	= XFRM_INF;
947 	req.info.lft.hard_byte_limit	= XFRM_INF;
948 	req.info.lft.soft_packet_limit	= XFRM_INF;
949 	req.info.lft.hard_packet_limit	= XFRM_INF;
950 
951 	req.info.family		= AF_INET;
952 	req.info.mode		= XFRM_MODE_TUNNEL;
953 
954 	if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
955 		return -1;
956 
957 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
958 		pr_err("send()");
959 		return -1;
960 	}
961 
962 	return netlink_check_answer(xfrm_sock);
963 }
964 
965 static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
966 		struct in_addr src, struct in_addr dst,
967 		struct xfrm_desc *desc)
968 {
969 	if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
970 		return false;
971 
972 	if (memcmp(&info->sel.saddr, &src, sizeof(src)))
973 		return false;
974 
975 	if (info->sel.family != AF_INET					||
976 			info->sel.prefixlen_d != PREFIX_LEN		||
977 			info->sel.prefixlen_s != PREFIX_LEN)
978 		return false;
979 
980 	if (info->id.spi != spi || info->id.proto != desc->proto)
981 		return false;
982 
983 	if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
984 		return false;
985 
986 	if (memcmp(&info->saddr, &src, sizeof(src)))
987 		return false;
988 
989 	if (info->lft.soft_byte_limit != XFRM_INF			||
990 			info->lft.hard_byte_limit != XFRM_INF		||
991 			info->lft.soft_packet_limit != XFRM_INF		||
992 			info->lft.hard_packet_limit != XFRM_INF)
993 		return false;
994 
995 	if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
996 		return false;
997 
998 	/* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
999 
1000 	return true;
1001 }
1002 
1003 static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1004 		struct in_addr src, struct in_addr dst,
1005 		struct xfrm_desc *desc)
1006 {
1007 	struct {
1008 		struct nlmsghdr		nh;
1009 		char			attrbuf[MAX_PAYLOAD];
1010 	} req;
1011 	struct {
1012 		struct nlmsghdr		nh;
1013 		union {
1014 			struct xfrm_usersa_info	info;
1015 			int error;
1016 		};
1017 		char			attrbuf[MAX_PAYLOAD];
1018 	} answer;
1019 	struct xfrm_address_filter filter = {};
1020 	bool found = false;
1021 
1022 
1023 	memset(&req, 0, sizeof(req));
1024 	req.nh.nlmsg_len	= NLMSG_LENGTH(0);
1025 	req.nh.nlmsg_type	= XFRM_MSG_GETSA;
1026 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_DUMP;
1027 	req.nh.nlmsg_seq	= seq;
1028 
1029 	/*
1030 	 * Add dump filter by source address as there may be other tunnels
1031 	 * in this netns (if tests run in parallel).
1032 	 */
1033 	filter.family = AF_INET;
1034 	filter.splen = 0x1f;	/* 0xffffffff mask see addr_match() */
1035 	memcpy(&filter.saddr, &src, sizeof(src));
1036 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1037 				&filter, sizeof(filter)))
1038 		return -1;
1039 
1040 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1041 		pr_err("send()");
1042 		return -1;
1043 	}
1044 
1045 	while (1) {
1046 		if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1047 			pr_err("recv()");
1048 			return -1;
1049 		}
1050 		if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1051 			printk("NLMSG_ERROR: %d: %s",
1052 				answer.error, strerror(-answer.error));
1053 			return -1;
1054 		} else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1055 			if (found)
1056 				return 0;
1057 			printk("didn't find allocated xfrm state in dump");
1058 			return -1;
1059 		} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1060 			if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1061 				found = true;
1062 		}
1063 	}
1064 }
1065 
/*
 * Install the SA pair @src -> @dst and @dst -> @src (both with SPI
 * gen_spi(@src)) and verify each via an SA dump.  @tunsrc and @tundst are
 * not used here.  *@seq is advanced once per request actually sent.
 * Returns 0 on success, -1 on failure.
 */
static int xfrm_set(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst,
		struct xfrm_desc *desc)
{
	const uint32_t spi = gen_spi(src);
	int bad;

	if (xfrm_state_add(xfrm_sock, (*seq)++, spi, src, dst, desc) ||
	    xfrm_state_add(xfrm_sock, (*seq)++, spi, dst, src, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	/* Check dumps for XFRM_MSG_GETSA; both run even if the first fails. */
	bad = xfrm_state_check(xfrm_sock, (*seq)++, spi, src, dst, desc);
	bad |= xfrm_state_check(xfrm_sock, (*seq)++, spi, dst, src, desc);
	if (bad) {
		printk("Failed to check xfrm state");
		return -1;
	}

	return 0;
}
1095 
1096 static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1097 		struct in_addr src, struct in_addr dst, uint8_t dir,
1098 		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1099 {
1100 	struct {
1101 		struct nlmsghdr			nh;
1102 		struct xfrm_userpolicy_info	info;
1103 		char				attrbuf[MAX_PAYLOAD];
1104 	} req;
1105 	struct xfrm_user_tmpl tmpl;
1106 
1107 	memset(&req, 0, sizeof(req));
1108 	memset(&tmpl, 0, sizeof(tmpl));
1109 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
1110 	req.nh.nlmsg_type	= XFRM_MSG_NEWPOLICY;
1111 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1112 	req.nh.nlmsg_seq	= seq;
1113 
1114 	/* Fill selector. */
1115 	memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1116 	memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1117 	req.info.sel.family		= AF_INET;
1118 	req.info.sel.prefixlen_d	= PREFIX_LEN;
1119 	req.info.sel.prefixlen_s	= PREFIX_LEN;
1120 
1121 	/* Fill lifteme_cfg */
1122 	req.info.lft.soft_byte_limit	= XFRM_INF;
1123 	req.info.lft.hard_byte_limit	= XFRM_INF;
1124 	req.info.lft.soft_packet_limit	= XFRM_INF;
1125 	req.info.lft.hard_packet_limit	= XFRM_INF;
1126 
1127 	req.info.dir = dir;
1128 
1129 	/* Fill tmpl */
1130 	memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1131 	/* Note: zero-spi cannot be deleted */
1132 	tmpl.id.spi = spi;
1133 	tmpl.id.proto	= proto;
1134 	tmpl.family	= AF_INET;
1135 	memcpy(&tmpl.saddr, &src, sizeof(src));
1136 	tmpl.mode	= XFRM_MODE_TUNNEL;
1137 	tmpl.aalgos = (~(uint32_t)0);
1138 	tmpl.ealgos = (~(uint32_t)0);
1139 	tmpl.calgos = (~(uint32_t)0);
1140 
1141 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1142 		return -1;
1143 
1144 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1145 		pr_err("send()");
1146 		return -1;
1147 	}
1148 
1149 	return netlink_check_answer(xfrm_sock);
1150 }
1151 
/*
 * xfrm_prepare - install the OUT (src -> dst) and IN (dst -> src) policies.
 *
 * Both policies carry the SPI derived from @src. The IN policy is only
 * attempted if the OUT one succeeded. Each request consumes one *@seq value.
 * Returns 0 on success, -1 (after logging) if either policy add fails.
 */
static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int failed;

	failed = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
				 XFRM_POLICY_OUT, tunsrc, tundst, proto);
	if (!failed)
		failed = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src),
					 dst, src, XFRM_POLICY_IN,
					 tunsrc, tundst, proto);

	if (failed) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	return 0;
}
1170 
1171 static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1172 		struct in_addr src, struct in_addr dst, uint8_t dir,
1173 		struct in_addr tunsrc, struct in_addr tundst)
1174 {
1175 	struct {
1176 		struct nlmsghdr			nh;
1177 		struct xfrm_userpolicy_id	id;
1178 		char				attrbuf[MAX_PAYLOAD];
1179 	} req;
1180 
1181 	memset(&req, 0, sizeof(req));
1182 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1183 	req.nh.nlmsg_type	= XFRM_MSG_DELPOLICY;
1184 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1185 	req.nh.nlmsg_seq	= seq;
1186 
1187 	/* Fill id */
1188 	memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1189 	memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1190 	req.id.sel.family		= AF_INET;
1191 	req.id.sel.prefixlen_d		= PREFIX_LEN;
1192 	req.id.sel.prefixlen_s		= PREFIX_LEN;
1193 	req.id.dir = dir;
1194 
1195 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1196 		pr_err("send()");
1197 		return -1;
1198 	}
1199 
1200 	return netlink_check_answer(xfrm_sock);
1201 }
1202 
/*
 * xfrm_cleanup - remove the policies installed by xfrm_prepare().
 *
 * Deletes the OUT (src -> dst) policy, then the IN (dst -> src) one; the
 * second deletion is skipped if the first fails. Each request consumes one
 * *@seq value. Returns 0 on success, -1 (after logging) on failure.
 */
static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst)
{
	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst)) {
		printk("Failed to delete xfrm policy");
		return -1;
	}

	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
				XFRM_POLICY_IN, tunsrc, tundst)) {
		printk("Failed to delete xfrm policy");
		return -1;
	}

	return 0;
}
1221 
1222 static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1223 		struct in_addr src, struct in_addr dst, uint8_t proto)
1224 {
1225 	struct {
1226 		struct nlmsghdr		nh;
1227 		struct xfrm_usersa_id	id;
1228 		char			attrbuf[MAX_PAYLOAD];
1229 	} req;
1230 	xfrm_address_t saddr = {};
1231 
1232 	memset(&req, 0, sizeof(req));
1233 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1234 	req.nh.nlmsg_type	= XFRM_MSG_DELSA;
1235 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1236 	req.nh.nlmsg_seq	= seq;
1237 
1238 	memcpy(&req.id.daddr, &dst, sizeof(dst));
1239 	req.id.family		= AF_INET;
1240 	req.id.proto		= proto;
1241 	/* Note: zero-spi cannot be deleted */
1242 	req.id.spi = spi;
1243 
1244 	memcpy(&saddr, &src, sizeof(src));
1245 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1246 		return -1;
1247 
1248 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1249 		pr_err("send()");
1250 		return -1;
1251 	}
1252 
1253 	return netlink_check_answer(xfrm_sock);
1254 }
1255 
/*
 * xfrm_delete - remove both SAs of a pair: src -> dst, then dst -> src.
 *
 * Both deletions use the SPI derived from @src. The second is skipped if
 * the first fails. Each request consumes one *@seq value.
 * Returns 0 on success, -1 (after logging) on failure.
 */
static int xfrm_delete(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int failed;

	failed = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src),
				src, dst, proto);
	if (!failed)
		failed = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src),
					dst, src, proto);

	if (failed) {
		printk("Failed to remove xfrm state");
		return -1;
	}

	return 0;
}
1272 
1273 static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1274 		uint32_t spi, uint8_t proto)
1275 {
1276 	struct {
1277 		struct nlmsghdr			nh;
1278 		struct xfrm_userspi_info	spi;
1279 	} req;
1280 	struct {
1281 		struct nlmsghdr			nh;
1282 		union {
1283 			struct xfrm_usersa_info	info;
1284 			int error;
1285 		};
1286 	} answer;
1287 
1288 	memset(&req, 0, sizeof(req));
1289 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.spi));
1290 	req.nh.nlmsg_type	= XFRM_MSG_ALLOCSPI;
1291 	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1292 	req.nh.nlmsg_seq	= (*seq)++;
1293 
1294 	req.spi.info.family	= AF_INET;
1295 	req.spi.min		= spi;
1296 	req.spi.max		= spi;
1297 	req.spi.info.id.proto	= proto;
1298 
1299 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1300 		pr_err("send()");
1301 		return KSFT_FAIL;
1302 	}
1303 
1304 	if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1305 		pr_err("recv()");
1306 		return KSFT_FAIL;
1307 	} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1308 		uint32_t new_spi = htonl(answer.info.id.spi);
1309 
1310 		if (new_spi != spi) {
1311 			printk("allocated spi is different from requested: %#x != %#x",
1312 					new_spi, spi);
1313 			return KSFT_FAIL;
1314 		}
1315 		return KSFT_PASS;
1316 	} else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1317 		printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1318 		return KSFT_FAIL;
1319 	}
1320 
1321 	printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1322 	return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1323 }
1324 
/*
 * netlink_sock_bind - open a netlink socket and bind it to @groups.
 * @sock:	out: the socket fd; set to -1 if binding/verification fails
 * @seq:	out: sequence counter initialised by netlink_sock()
 * @proto:	netlink protocol (e.g. NETLINK_XFRM)
 * @groups:	multicast group mask stored in sockaddr_nl.nl_groups
 *
 * After bind(), getsockname() is used to verify the address length and
 * family. Returns 0 on success, -1 on failure (the socket is closed).
 */
static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
{
	struct sockaddr_nl snl = {};
	socklen_t addr_len;

	snl.nl_family = AF_NETLINK;
	snl.nl_groups = groups;

	if (netlink_sock(sock, seq, proto)) {
		printk("Failed to open xfrm netlink socket");
		return -1;
	}

	if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
		pr_err("bind()");
		goto out_close;
	}

	addr_len = sizeof(snl);
	if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
		pr_err("getsockname()");
		goto out_close;
	}
	if (addr_len != sizeof(snl)) {
		/* socklen_t is unsigned: print it as such. */
		printk("Wrong address length %u", (unsigned int)addr_len);
		goto out_close;
	}
	if (snl.nl_family != AF_NETLINK) {
		printk("Wrong address family %d", snl.nl_family);
		goto out_close;
	}
	return 0;

out_close:
	close(*sock);
	/* Don't hand a closed fd number back to the caller. */
	*sock = -1;
	return -1;
}
1363 
1364 static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
1365 {
1366 	struct {
1367 		struct nlmsghdr nh;
1368 		union {
1369 			struct xfrm_user_acquire acq;
1370 			int error;
1371 		};
1372 		char attrbuf[MAX_PAYLOAD];
1373 	} req;
1374 	struct xfrm_user_tmpl xfrm_tmpl = {};
1375 	int xfrm_listen = -1, ret = KSFT_FAIL;
1376 	uint32_t seq_listen;
1377 
1378 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
1379 		return KSFT_FAIL;
1380 
1381 	memset(&req, 0, sizeof(req));
1382 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.acq));
1383 	req.nh.nlmsg_type	= XFRM_MSG_ACQUIRE;
1384 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1385 	req.nh.nlmsg_seq	= (*seq)++;
1386 
1387 	req.acq.policy.sel.family	= AF_INET;
1388 	req.acq.aalgos	= 0xfeed;
1389 	req.acq.ealgos	= 0xbaad;
1390 	req.acq.calgos	= 0xbabe;
1391 
1392 	xfrm_tmpl.family = AF_INET;
1393 	xfrm_tmpl.id.proto = IPPROTO_ESP;
1394 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
1395 		goto out_close;
1396 
1397 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1398 		pr_err("send()");
1399 		goto out_close;
1400 	}
1401 
1402 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1403 		pr_err("recv()");
1404 		goto out_close;
1405 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1406 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1407 		goto out_close;
1408 	}
1409 
1410 	if (req.error) {
1411 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1412 		ret = req.error;
1413 		goto out_close;
1414 	}
1415 
1416 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1417 		pr_err("recv()");
1418 		goto out_close;
1419 	}
1420 
1421 	if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
1422 			|| req.acq.calgos != 0xbabe) {
1423 		printk("xfrm_user_acquire has changed  %x %x %x",
1424 				req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
1425 		goto out_close;
1426 	}
1427 
1428 	ret = KSFT_PASS;
1429 out_close:
1430 	close(xfrm_listen);
1431 	return ret;
1432 }
1433 
1434 static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
1435 		unsigned int nr, struct xfrm_desc *desc)
1436 {
1437 	struct {
1438 		struct nlmsghdr nh;
1439 		union {
1440 			struct xfrm_user_expire expire;
1441 			int error;
1442 		};
1443 	} req;
1444 	struct in_addr src, dst;
1445 	int xfrm_listen = -1, ret = KSFT_FAIL;
1446 	uint32_t seq_listen;
1447 
1448 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1449 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1450 
1451 	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
1452 		printk("Failed to add xfrm state");
1453 		return KSFT_FAIL;
1454 	}
1455 
1456 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1457 		return KSFT_FAIL;
1458 
1459 	memset(&req, 0, sizeof(req));
1460 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1461 	req.nh.nlmsg_type	= XFRM_MSG_EXPIRE;
1462 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1463 	req.nh.nlmsg_seq	= (*seq)++;
1464 
1465 	memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
1466 	req.expire.state.id.spi		= gen_spi(src);
1467 	req.expire.state.id.proto	= desc->proto;
1468 	req.expire.state.family		= AF_INET;
1469 	req.expire.hard			= 0xff;
1470 
1471 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1472 		pr_err("send()");
1473 		goto out_close;
1474 	}
1475 
1476 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1477 		pr_err("recv()");
1478 		goto out_close;
1479 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1480 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1481 		goto out_close;
1482 	}
1483 
1484 	if (req.error) {
1485 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1486 		ret = req.error;
1487 		goto out_close;
1488 	}
1489 
1490 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1491 		pr_err("recv()");
1492 		goto out_close;
1493 	}
1494 
1495 	if (req.expire.hard != 0x1) {
1496 		printk("expire.hard is not set: %x", req.expire.hard);
1497 		goto out_close;
1498 	}
1499 
1500 	ret = KSFT_PASS;
1501 out_close:
1502 	close(xfrm_listen);
1503 	return ret;
1504 }
1505 
1506 static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
1507 		unsigned int nr, struct xfrm_desc *desc)
1508 {
1509 	struct {
1510 		struct nlmsghdr nh;
1511 		union {
1512 			struct xfrm_user_polexpire expire;
1513 			int error;
1514 		};
1515 	} req;
1516 	struct in_addr src, dst, tunsrc, tundst;
1517 	int xfrm_listen = -1, ret = KSFT_FAIL;
1518 	uint32_t seq_listen;
1519 
1520 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1521 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1522 	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1523 	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1524 
1525 	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1526 				XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
1527 		printk("Failed to add xfrm policy");
1528 		return KSFT_FAIL;
1529 	}
1530 
1531 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1532 		return KSFT_FAIL;
1533 
1534 	memset(&req, 0, sizeof(req));
1535 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1536 	req.nh.nlmsg_type	= XFRM_MSG_POLEXPIRE;
1537 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1538 	req.nh.nlmsg_seq	= (*seq)++;
1539 
1540 	/* Fill selector. */
1541 	memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
1542 	memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
1543 	req.expire.pol.sel.family	= AF_INET;
1544 	req.expire.pol.sel.prefixlen_d	= PREFIX_LEN;
1545 	req.expire.pol.sel.prefixlen_s	= PREFIX_LEN;
1546 	req.expire.pol.dir		= XFRM_POLICY_OUT;
1547 	req.expire.hard			= 0xff;
1548 
1549 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1550 		pr_err("send()");
1551 		goto out_close;
1552 	}
1553 
1554 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1555 		pr_err("recv()");
1556 		goto out_close;
1557 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1558 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1559 		goto out_close;
1560 	}
1561 
1562 	if (req.error) {
1563 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1564 		ret = req.error;
1565 		goto out_close;
1566 	}
1567 
1568 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1569 		pr_err("recv()");
1570 		goto out_close;
1571 	}
1572 
1573 	if (req.expire.hard != 0x1) {
1574 		printk("expire.hard is not set: %x", req.expire.hard);
1575 		goto out_close;
1576 	}
1577 
1578 	ret = KSFT_PASS;
1579 out_close:
1580 	close(xfrm_listen);
1581 	return ret;
1582 }
1583 
1584 static int xfrm_spdinfo_set_thresh(int xfrm_sock, uint32_t *seq,
1585 		unsigned thresh4_l, unsigned thresh4_r,
1586 		unsigned thresh6_l, unsigned thresh6_r,
1587 		bool add_bad_attr)
1588 
1589 {
1590 	struct {
1591 		struct nlmsghdr		nh;
1592 		union {
1593 			uint32_t	unused;
1594 			int		error;
1595 		};
1596 		char			attrbuf[MAX_PAYLOAD];
1597 	} req;
1598 	struct xfrmu_spdhthresh thresh;
1599 
1600 	memset(&req, 0, sizeof(req));
1601 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1602 	req.nh.nlmsg_type	= XFRM_MSG_NEWSPDINFO;
1603 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1604 	req.nh.nlmsg_seq	= (*seq)++;
1605 
1606 	thresh.lbits = thresh4_l;
1607 	thresh.rbits = thresh4_r;
1608 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV4_HTHRESH, &thresh, sizeof(thresh)))
1609 		return -1;
1610 
1611 	thresh.lbits = thresh6_l;
1612 	thresh.rbits = thresh6_r;
1613 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV6_HTHRESH, &thresh, sizeof(thresh)))
1614 		return -1;
1615 
1616 	if (add_bad_attr) {
1617 		BUILD_BUG_ON(XFRMA_IF_ID <= XFRMA_SPD_MAX + 1);
1618 		if (rtattr_pack(&req.nh, sizeof(req), XFRMA_IF_ID, NULL, 0)) {
1619 			pr_err("adding attribute failed: no space");
1620 			return -1;
1621 		}
1622 	}
1623 
1624 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1625 		pr_err("send()");
1626 		return -1;
1627 	}
1628 
1629 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1630 		pr_err("recv()");
1631 		return -1;
1632 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1633 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1634 		return -1;
1635 	}
1636 
1637 	if (req.error) {
1638 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1639 		return -1;
1640 	}
1641 
1642 	return 0;
1643 }
1644 
1645 static int xfrm_spdinfo_attrs(int xfrm_sock, uint32_t *seq)
1646 {
1647 	struct {
1648 		struct nlmsghdr			nh;
1649 		union {
1650 			uint32_t	unused;
1651 			int		error;
1652 		};
1653 		char			attrbuf[MAX_PAYLOAD];
1654 	} req;
1655 
1656 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 31, 120, 16, false)) {
1657 		pr_err("Can't set SPD HTHRESH");
1658 		return KSFT_FAIL;
1659 	}
1660 
1661 	memset(&req, 0, sizeof(req));
1662 
1663 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1664 	req.nh.nlmsg_type	= XFRM_MSG_GETSPDINFO;
1665 	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1666 	req.nh.nlmsg_seq	= (*seq)++;
1667 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1668 		pr_err("send()");
1669 		return KSFT_FAIL;
1670 	}
1671 
1672 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1673 		pr_err("recv()");
1674 		return KSFT_FAIL;
1675 	} else if (req.nh.nlmsg_type == XFRM_MSG_NEWSPDINFO) {
1676 		size_t len = NLMSG_PAYLOAD(&req.nh, sizeof(req.unused));
1677 		struct rtattr *attr = (void *)req.attrbuf;
1678 		int got_thresh = 0;
1679 
1680 		for (; RTA_OK(attr, len); attr = RTA_NEXT(attr, len)) {
1681 			if (attr->rta_type == XFRMA_SPD_IPV4_HTHRESH) {
1682 				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1683 
1684 				got_thresh++;
1685 				if (t->lbits != 32 || t->rbits != 31) {
1686 					pr_err("thresh differ: %u, %u",
1687 							t->lbits, t->rbits);
1688 					return KSFT_FAIL;
1689 				}
1690 			}
1691 			if (attr->rta_type == XFRMA_SPD_IPV6_HTHRESH) {
1692 				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1693 
1694 				got_thresh++;
1695 				if (t->lbits != 120 || t->rbits != 16) {
1696 					pr_err("thresh differ: %u, %u",
1697 							t->lbits, t->rbits);
1698 					return KSFT_FAIL;
1699 				}
1700 			}
1701 		}
1702 		if (got_thresh != 2) {
1703 			pr_err("only %d thresh returned by XFRM_MSG_GETSPDINFO", got_thresh);
1704 			return KSFT_FAIL;
1705 		}
1706 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1707 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1708 		return KSFT_FAIL;
1709 	} else {
1710 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1711 		return -1;
1712 	}
1713 
1714 	/* Restore the default */
1715 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, false)) {
1716 		pr_err("Can't restore SPD HTHRESH");
1717 		return KSFT_FAIL;
1718 	}
1719 
1720 	/*
1721 	 * At this moment xfrm uses nlmsg_parse_deprecated(), which
1722 	 * implies NL_VALIDATE_LIBERAL - ignoring attributes with
1723 	 * (type > maxtype). nla_parse_depricated_strict() would enforce
1724 	 * it. Or even stricter nla_parse().
1725 	 * Right now it's not expected to fail, but to be ignored.
1726 	 */
1727 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, true))
1728 		return KSFT_PASS;
1729 
1730 	return KSFT_PASS;
1731 }
1732 
1733 static int child_serv(int xfrm_sock, uint32_t *seq,
1734 		unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
1735 {
1736 	struct in_addr src, dst, tunsrc, tundst;
1737 	struct test_desc msg;
1738 	int ret = KSFT_FAIL;
1739 
1740 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1741 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1742 	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1743 	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1744 
1745 	/* UDP pinging without xfrm */
1746 	if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
1747 		printk("ping failed before setting xfrm");
1748 		return KSFT_FAIL;
1749 	}
1750 
1751 	memset(&msg, 0, sizeof(msg));
1752 	msg.type = MSG_XFRM_PREPARE;
1753 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1754 	write_msg(cmd_fd, &msg, 1);
1755 
1756 	if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1757 		printk("failed to prepare xfrm");
1758 		goto cleanup;
1759 	}
1760 
1761 	memset(&msg, 0, sizeof(msg));
1762 	msg.type = MSG_XFRM_ADD;
1763 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1764 	write_msg(cmd_fd, &msg, 1);
1765 	if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1766 		printk("failed to set xfrm");
1767 		goto delete;
1768 	}
1769 
1770 	/* UDP pinging with xfrm tunnel */
1771 	if (do_ping(cmd_fd, buf, page_size, tunsrc,
1772 				true, 0, 0, udp_ping_send)) {
1773 		printk("ping failed for xfrm");
1774 		goto delete;
1775 	}
1776 
1777 	ret = KSFT_PASS;
1778 delete:
1779 	/* xfrm delete */
1780 	memset(&msg, 0, sizeof(msg));
1781 	msg.type = MSG_XFRM_DEL;
1782 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1783 	write_msg(cmd_fd, &msg, 1);
1784 
1785 	if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1786 		printk("failed ping to remove xfrm");
1787 		ret = KSFT_FAIL;
1788 	}
1789 
1790 cleanup:
1791 	memset(&msg, 0, sizeof(msg));
1792 	msg.type = MSG_XFRM_CLEANUP;
1793 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1794 	write_msg(cmd_fd, &msg, 1);
1795 	if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1796 		printk("failed ping to cleanup xfrm");
1797 		ret = KSFT_FAIL;
1798 	}
1799 	return ret;
1800 }
1801 
1802 static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
1803 {
1804 	struct xfrm_desc desc;
1805 	struct test_desc msg;
1806 	int xfrm_sock = -1;
1807 	uint32_t seq;
1808 
1809 	if (switch_ns(nsfd_childa))
1810 		exit(KSFT_FAIL);
1811 
1812 	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1813 		printk("Failed to open xfrm netlink socket");
1814 		exit(KSFT_FAIL);
1815 	}
1816 
1817 	/* Check that seq sock is ready, just for sure. */
1818 	memset(&msg, 0, sizeof(msg));
1819 	msg.type = MSG_ACK;
1820 	write_msg(cmd_fd, &msg, 1);
1821 	read_msg(cmd_fd, &msg, 1);
1822 	if (msg.type != MSG_ACK) {
1823 		printk("Ack failed");
1824 		exit(KSFT_FAIL);
1825 	}
1826 
1827 	for (;;) {
1828 		ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
1829 		int ret;
1830 
1831 		if (received == 0) /* EOF */
1832 			break;
1833 
1834 		if (received != sizeof(desc)) {
1835 			pr_err("read() returned %zd", received);
1836 			exit(KSFT_FAIL);
1837 		}
1838 
1839 		switch (desc.type) {
1840 		case CREATE_TUNNEL:
1841 			ret = child_serv(xfrm_sock, &seq, nr,
1842 					 cmd_fd, buf, &desc);
1843 			break;
1844 		case ALLOCATE_SPI:
1845 			ret = xfrm_state_allocspi(xfrm_sock, &seq,
1846 						  -1, desc.proto);
1847 			break;
1848 		case MONITOR_ACQUIRE:
1849 			ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
1850 			break;
1851 		case EXPIRE_STATE:
1852 			ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
1853 			break;
1854 		case EXPIRE_POLICY:
1855 			ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
1856 			break;
1857 		case SPDINFO_ATTRS:
1858 			ret = xfrm_spdinfo_attrs(xfrm_sock, &seq);
1859 			break;
1860 		default:
1861 			printk("Unknown desc type %d", desc.type);
1862 			exit(KSFT_FAIL);
1863 		}
1864 		write_test_result(ret, &desc);
1865 	}
1866 
1867 	close(xfrm_sock);
1868 
1869 	msg.type = MSG_EXIT;
1870 	write_msg(cmd_fd, &msg, 1);
1871 	exit(KSFT_PASS);
1872 }
1873 
1874 static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
1875 		struct test_desc *msg, int xfrm_sock, uint32_t *seq)
1876 {
1877 	struct in_addr src, dst, tunsrc, tundst;
1878 	bool tun_reply;
1879 	struct xfrm_desc *desc = &msg->body.xfrm_desc;
1880 
1881 	src = inet_makeaddr(INADDR_B, grchild_ip(nr));
1882 	dst = inet_makeaddr(INADDR_B, child_ip(nr));
1883 	tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
1884 	tundst = inet_makeaddr(INADDR_A, child_ip(nr));
1885 
1886 	switch (msg->type) {
1887 	case MSG_EXIT:
1888 		exit(KSFT_PASS);
1889 	case MSG_ACK:
1890 		write_msg(cmd_fd, msg, 1);
1891 		break;
1892 	case MSG_PING:
1893 		tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
1894 		/* UDP pinging without xfrm */
1895 		if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
1896 				false, msg->body.ping.port,
1897 				msg->body.ping.reply_ip, udp_ping_reply)) {
1898 			printk("ping failed before setting xfrm");
1899 		}
1900 		break;
1901 	case MSG_XFRM_PREPARE:
1902 		if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
1903 					desc->proto)) {
1904 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1905 			printk("failed to prepare xfrm");
1906 		}
1907 		break;
1908 	case MSG_XFRM_ADD:
1909 		if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1910 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1911 			printk("failed to set xfrm");
1912 		}
1913 		break;
1914 	case MSG_XFRM_DEL:
1915 		if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
1916 					desc->proto)) {
1917 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1918 			printk("failed to remove xfrm");
1919 		}
1920 		break;
1921 	case MSG_XFRM_CLEANUP:
1922 		if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1923 			printk("failed to cleanup xfrm");
1924 		}
1925 		break;
1926 	default:
1927 		printk("got unknown msg type %d", msg->type);
1928 	}
1929 }
1930 
1931 static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1932 {
1933 	struct test_desc msg;
1934 	int xfrm_sock = -1;
1935 	uint32_t seq;
1936 
1937 	if (switch_ns(nsfd_childb))
1938 		exit(KSFT_FAIL);
1939 
1940 	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1941 		printk("Failed to open xfrm netlink socket");
1942 		exit(KSFT_FAIL);
1943 	}
1944 
1945 	do {
1946 		read_msg(cmd_fd, &msg, 1);
1947 		grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1948 	} while (1);
1949 
1950 	close(xfrm_sock);
1951 	exit(KSFT_FAIL);
1952 }
1953 
1954 static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
1955 {
1956 	int cmd_sock[2];
1957 	void *data_map;
1958 	pid_t child;
1959 
1960 	if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
1961 		return -1;
1962 
1963 	if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
1964 		return -1;
1965 
1966 	child = fork();
1967 	if (child < 0) {
1968 		pr_err("fork()");
1969 		return -1;
1970 	} else if (child) {
1971 		/* in parent - selftest */
1972 		return switch_ns(nsfd_parent);
1973 	}
1974 
1975 	if (close(test_desc_fd[1])) {
1976 		pr_err("close()");
1977 		return -1;
1978 	}
1979 
1980 	/* child */
1981 	data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
1982 			MAP_SHARED | MAP_ANONYMOUS, -1, 0);
1983 	if (data_map == MAP_FAILED) {
1984 		pr_err("mmap()");
1985 		return -1;
1986 	}
1987 
1988 	randomize_buffer(data_map, page_size);
1989 
1990 	if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
1991 		pr_err("socketpair()");
1992 		return -1;
1993 	}
1994 
1995 	child = fork();
1996 	if (child < 0) {
1997 		pr_err("fork()");
1998 		return -1;
1999 	} else if (child) {
2000 		if (close(cmd_sock[0])) {
2001 			pr_err("close()");
2002 			return -1;
2003 		}
2004 		return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
2005 	}
2006 	if (close(cmd_sock[1])) {
2007 		pr_err("close()");
2008 		return -1;
2009 	}
2010 	return grand_child_f(nr, cmd_sock[0], data_map);
2011 }
2012 
2013 static void exit_usage(char **argv)
2014 {
2015 	printk("Usage: %s [nr_process]", argv[0]);
2016 	exit(KSFT_FAIL);
2017 }
2018 
2019 static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
2020 {
2021 	ssize_t ret;
2022 
2023 	ret = write(test_desc_fd, desc, sizeof(*desc));
2024 
2025 	if (ret == sizeof(*desc))
2026 		return 0;
2027 
2028 	pr_err("Writing test's desc failed %ld", ret);
2029 
2030 	return -1;
2031 }
2032 
2033 static int write_desc(int proto, int test_desc_fd,
2034 		char *a, char *e, char *c, char *ae)
2035 {
2036 	struct xfrm_desc desc = {};
2037 
2038 	desc.type = CREATE_TUNNEL;
2039 	desc.proto = proto;
2040 
2041 	if (a)
2042 		strncpy(desc.a_algo, a, ALGO_LEN - 1);
2043 	if (e)
2044 		strncpy(desc.e_algo, e, ALGO_LEN - 1);
2045 	if (c)
2046 		strncpy(desc.c_algo, c, ALGO_LEN - 1);
2047 	if (ae)
2048 		strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
2049 
2050 	return __write_desc(test_desc_fd, &desc);
2051 }
2052 
/* Protocols whose algorithm matrix is written by write_proto_plan(). */
int proto_list[] = {
	IPPROTO_AH,
	IPPROTO_COMP,
	IPPROTO_ESP,
};

/* Auth algorithms: used alone for AH, paired with e_list[] for ESP. */
char *ah_list[] = {
	"digest_null",
	"hmac(md5)",
	"hmac(sha1)",
	"hmac(sha256)",
	"hmac(sha384)",
	"hmac(sha512)",
	"hmac(rmd160)",
	"xcbc(aes)",
	"cmac(aes)",
};

/* Compression algorithms for IPPROTO_COMP. */
char *comp_list[] = {
	"deflate",
#if 0
	/* No compression backend realization */
	"lzs",
	"lzjh",
#endif
};

/* Encryption algorithms for ESP. */
char *e_list[] = {
	"ecb(cipher_null)",
	"cbc(des)",
	"cbc(des3_ede)",
	"cbc(cast5)",
	"cbc(blowfish)",
	"cbc(aes)",
	"cbc(serpent)",
	"cbc(camellia)",
	"cbc(twofish)",
	"rfc3686(ctr(aes))",
};

/* ESP authenticated-encryption algorithms (currently none enabled). */
char *ae_list[] = {
#if 0
	/* not implemented */
	"rfc4106(gcm(aes))",
	"rfc4309(ccm(aes))",
	"rfc4543(gcm(aes))",
	"rfc7539esp(chacha20,poly1305)",
#endif
};
2078 
2079 const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
2080 				+ (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
2081 				+ ARRAY_SIZE(ae_list);
2082 
2083 static int write_proto_plan(int fd, int proto)
2084 {
2085 	unsigned int i;
2086 
2087 	switch (proto) {
2088 	case IPPROTO_AH:
2089 		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2090 			if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
2091 				return -1;
2092 		}
2093 		break;
2094 	case IPPROTO_COMP:
2095 		for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
2096 			if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
2097 				return -1;
2098 		}
2099 		break;
2100 	case IPPROTO_ESP:
2101 		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2102 			int j;
2103 
2104 			for (j = 0; j < ARRAY_SIZE(e_list); j++) {
2105 				if (write_desc(proto, fd, ah_list[i],
2106 							e_list[j], 0, 0))
2107 					return -1;
2108 			}
2109 		}
2110 		for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
2111 			if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
2112 				return -1;
2113 		}
2114 		break;
2115 	default:
2116 		printk("BUG: Specified unknown proto %d", proto);
2117 		return -1;
2118 	}
2119 
2120 	return 0;
2121 }
2122 
2123 /*
2124  * Some structures in xfrm uapi header differ in size between
2125  * 64-bit and 32-bit ABI:
2126  *
2127  *             32-bit UABI               |            64-bit UABI
2128  *  -------------------------------------|-------------------------------------
2129  *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
2130  *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
2131  *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
2132  *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
2133  *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
2134  *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
2135  *
 * Check the structures affected by the UABI difference.
 * Also, check translation for xfrm_set_spdinfo: it has its own attributes
 * which need to be correctly copied, but not translated.
2139  */
2140 const unsigned int compat_plan = 5;
2141 static int write_compat_struct_tests(int test_desc_fd)
2142 {
2143 	struct xfrm_desc desc = {};
2144 
2145 	desc.type = ALLOCATE_SPI;
2146 	desc.proto = IPPROTO_AH;
2147 	strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2148 
2149 	if (__write_desc(test_desc_fd, &desc))
2150 		return -1;
2151 
2152 	desc.type = MONITOR_ACQUIRE;
2153 	if (__write_desc(test_desc_fd, &desc))
2154 		return -1;
2155 
2156 	desc.type = EXPIRE_STATE;
2157 	if (__write_desc(test_desc_fd, &desc))
2158 		return -1;
2159 
2160 	desc.type = EXPIRE_POLICY;
2161 	if (__write_desc(test_desc_fd, &desc))
2162 		return -1;
2163 
2164 	desc.type = SPDINFO_ATTRS;
2165 	if (__write_desc(test_desc_fd, &desc))
2166 		return -1;
2167 
2168 	return 0;
2169 }
2170 
2171 static int write_test_plan(int test_desc_fd)
2172 {
2173 	unsigned int i;
2174 	pid_t child;
2175 
2176 	child = fork();
2177 	if (child < 0) {
2178 		pr_err("fork()");
2179 		return -1;
2180 	}
2181 	if (child) {
2182 		if (close(test_desc_fd))
2183 			printk("close(): %m");
2184 		return 0;
2185 	}
2186 
2187 	if (write_compat_struct_tests(test_desc_fd))
2188 		exit(KSFT_FAIL);
2189 
2190 	for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2191 		if (write_proto_plan(test_desc_fd, proto_list[i]))
2192 			exit(KSFT_FAIL);
2193 	}
2194 
2195 	exit(KSFT_PASS);
2196 }
2197 
2198 static int children_cleanup(void)
2199 {
2200 	unsigned ret = KSFT_PASS;
2201 
2202 	while (1) {
2203 		int status;
2204 		pid_t p = wait(&status);
2205 
2206 		if ((p < 0) && errno == ECHILD)
2207 			break;
2208 
2209 		if (p < 0) {
2210 			pr_err("wait()");
2211 			return KSFT_FAIL;
2212 		}
2213 
2214 		if (!WIFEXITED(status)) {
2215 			ret = KSFT_FAIL;
2216 			continue;
2217 		}
2218 
2219 		if (WEXITSTATUS(status) == KSFT_FAIL)
2220 			ret = KSFT_FAIL;
2221 	}
2222 
2223 	return ret;
2224 }
2225 
2226 typedef void (*print_res)(const char *, ...);
2227 
2228 static int check_results(void)
2229 {
2230 	struct test_result tr = {};
2231 	struct xfrm_desc *d = &tr.desc;
2232 	int ret = KSFT_PASS;
2233 
2234 	while (1) {
2235 		ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2236 		print_res result;
2237 
2238 		if (received == 0) /* EOF */
2239 			break;
2240 
2241 		if (received != sizeof(tr)) {
2242 			pr_err("read() returned %zd", received);
2243 			return KSFT_FAIL;
2244 		}
2245 
2246 		switch (tr.res) {
2247 		case KSFT_PASS:
2248 			result = ksft_test_result_pass;
2249 			break;
2250 		case KSFT_FAIL:
2251 		default:
2252 			result = ksft_test_result_fail;
2253 			ret = KSFT_FAIL;
2254 		}
2255 
2256 		result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2257 		       desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2258 		       d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2259 	}
2260 
2261 	return ret;
2262 }
2263 
2264 int main(int argc, char **argv)
2265 {
2266 	long nr_process = 1;
2267 	int route_sock = -1, ret = KSFT_SKIP;
2268 	int test_desc_fd[2];
2269 	uint32_t route_seq;
2270 	unsigned int i;
2271 
2272 	if (argc > 2)
2273 		exit_usage(argv);
2274 
2275 	if (argc > 1) {
2276 		char *endptr;
2277 
2278 		errno = 0;
2279 		nr_process = strtol(argv[1], &endptr, 10);
2280 		if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
2281 				|| (errno != 0 && nr_process == 0)
2282 				|| (endptr == argv[1]) || (*endptr != '\0')) {
2283 			printk("Failed to parse [nr_process]");
2284 			exit_usage(argv);
2285 		}
2286 
2287 		if (nr_process > MAX_PROCESSES || nr_process < 1) {
2288 			printk("nr_process should be between [1; %u]",
2289 					MAX_PROCESSES);
2290 			exit_usage(argv);
2291 		}
2292 	}
2293 
2294 	srand(time(NULL));
2295 	page_size = sysconf(_SC_PAGESIZE);
2296 	if (page_size < 1)
2297 		ksft_exit_skip("sysconf(): %m\n");
2298 
2299 	if (pipe2(test_desc_fd, O_DIRECT) < 0)
2300 		ksft_exit_skip("pipe(): %m\n");
2301 
2302 	if (pipe2(results_fd, O_DIRECT) < 0)
2303 		ksft_exit_skip("pipe(): %m\n");
2304 
2305 	if (init_namespaces())
2306 		ksft_exit_skip("Failed to create namespaces\n");
2307 
2308 	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
2309 		ksft_exit_skip("Failed to open netlink route socket\n");
2310 
2311 	for (i = 0; i < nr_process; i++) {
2312 		char veth[VETH_LEN];
2313 
2314 		snprintf(veth, VETH_LEN, VETH_FMT, i);
2315 
2316 		if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
2317 			close(route_sock);
2318 			ksft_exit_fail_msg("Failed to create veth device");
2319 		}
2320 
2321 		if (start_child(i, veth, test_desc_fd)) {
2322 			close(route_sock);
2323 			ksft_exit_fail_msg("Child %u failed to start", i);
2324 		}
2325 	}
2326 
2327 	if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
2328 		ksft_exit_fail_msg("close(): %m");
2329 
2330 	ksft_set_plan(proto_plan + compat_plan);
2331 
2332 	if (write_test_plan(test_desc_fd[1]))
2333 		ksft_exit_fail_msg("Failed to write test plan to pipe");
2334 
2335 	ret = check_results();
2336 
2337 	if (children_cleanup() == KSFT_FAIL)
2338 		exit(KSFT_FAIL);
2339 
2340 	exit(ret);
2341 }
2342