xref: /linux/tools/testing/selftests/net/ipsec.c (revision 35c2c39832e569449b9192fa1afbbc4c66227af7)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * ipsec.c - Check xfrm on veth inside a net-ns.
4  * Copyright (c) 2018 Dmitry Safonov
5  */
6 
7 #define _GNU_SOURCE
8 
9 #include <arpa/inet.h>
10 #include <asm/types.h>
11 #include <errno.h>
12 #include <fcntl.h>
13 #include <limits.h>
14 #include <linux/limits.h>
15 #include <linux/netlink.h>
16 #include <linux/random.h>
17 #include <linux/rtnetlink.h>
18 #include <linux/veth.h>
19 #include <linux/xfrm.h>
20 #include <netinet/in.h>
21 #include <net/if.h>
22 #include <sched.h>
23 #include <stdbool.h>
24 #include <stdint.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <sys/mman.h>
29 #include <sys/socket.h>
30 #include <sys/stat.h>
31 #include <sys/syscall.h>
32 #include <sys/types.h>
33 #include <sys/wait.h>
34 #include <time.h>
35 #include <unistd.h>
36 
37 #include "kselftest.h"
38 
/*
 * Log a message through ksft_print_msg(), prefixed with "<pid>[<line>] ".
 * __LINE__ expands to an int constant, so it is printed with %d.
 */
#define printk(fmt, ...)						\
	ksft_print_msg("%d[%d] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)

/* Like printk(), but appends the description of the current errno (%m). */
#define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)

/* Compile-time assertion: a true @condition yields a negative array size. */
#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))

/* Fallback in case no included header provided offsetof(). */
#ifndef offsetof
#define offsetof(TYPE, MEMBER)	__builtin_offsetof(TYPE, MEMBER)
#endif
49 
#define IPV4_STR_SZ	16	/* "xxx.xxx.xxx.xxx" is the longest form, plus '\0' */
#define MAX_PAYLOAD	2048
#define XFRM_ALGO_KEY_BUF_SIZE	512
#define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
#define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
#define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */

/* /30 mask for one veth connection */
#define PREFIX_LEN	30
/*
 * The two usable host numbers of the nr-th /30 subnet (4*nr is the network
 * number, 4*nr + 3 the broadcast).  The argument is parenthesized so that
 * an expression such as child_ip(i + 1) expands as intended.
 */
#define child_ip(nr)	(4*(nr) + 1)
#define grchild_ip(nr)	(4*(nr) + 2)

#define VETH_FMT	"ktst-%d"
#define VETH_LEN	12
64 
/* Network namespace fds filled in by init_namespaces(); -1 until then. */
static int nsfd_parent	= -1;
static int nsfd_childa	= -1;
static int nsfd_childb	= -1;
/* Length of each ping payload used by the ping helpers; assigned elsewhere. */
static long page_size;

/*
 * ksft_cnt is static in kselftest, so isn't shared with children.
 * We have to send a test result back to parent and count there.
 * results_fd is a pipe with test feedback from children.
 */
static int results_fd[2];

/*
 * Ping tunables.  They are file-local and read-only, so they get internal
 * linkage instead of being exported as globals.
 */
static const unsigned int ping_delay_nsec	= 50 * 1000 * 1000;	/* pause between rounds */
static const unsigned int ping_timeout		= 300;	/* SO_RCVTIMEO, microseconds */
static const unsigned int ping_count		= 100;	/* rounds per do_ping() */
static const unsigned int ping_success		= 80;	/* rounds that must succeed */
81 
/* Key length to use for one xfrm algorithm name. */
struct xfrm_key_entry {
	char algo_name[35];
	int key_len;	/* in bits, e.g. 128 for hmac(md5) */
};

/*
 * Read-only lookup table consulted by xfrm_fill_key().  Lengths are in bits
 * and include any salt/nonce part (e.g. rfc4106 with AES-128: 128 + 32).
 * NOTE(review): the "-128"/"-192"/"-256" suffixes appear to distinguish key
 * sizes of the same algorithm; confirm how callers map them to kernel names.
 */
static const struct xfrm_key_entry xfrm_key_entries[] = {
	{"digest_null", 0},
	{"ecb(cipher_null)", 0},
	{"cbc(des)", 64},
	{"hmac(md5)", 128},
	{"cmac(aes)", 128},
	{"xcbc(aes)", 128},
	{"cbc(cast5)", 128},
	{"cbc(serpent)", 128},
	{"hmac(sha1)", 160},
	{"cbc(des3_ede)", 192},
	{"hmac(sha256)", 256},
	{"cbc(aes)", 256},
	{"cbc(camellia)", 256},
	{"cbc(twofish)", 256},
	{"rfc3686(ctr(aes))", 288},
	{"hmac(sha384)", 384},
	{"cbc(blowfish)", 448},
	{"hmac(sha512)", 512},
	{"rfc4106(gcm(aes))-128", 160},
	{"rfc4543(gcm(aes))-128", 160},
	{"rfc4309(ccm(aes))-128", 152},
	{"rfc4106(gcm(aes))-192", 224},
	{"rfc4543(gcm(aes))-192", 224},
	{"rfc4309(ccm(aes))-192", 216},
	{"rfc4106(gcm(aes))-256", 288},
	{"rfc4543(gcm(aes))-256", 288},
	{"rfc4309(ccm(aes))-256", 280},
	{"rfc7539(chacha20,poly1305)-128", 0}
};
117 
/*
 * Fill @buflen bytes at @buf with rand() output.
 *
 * Each whole int-sized chunk gets one rand() value, and a trailing partial
 * chunk gets the leading bytes of one more.  Values are copied with memcpy(),
 * so @buf needs no particular alignment and is never written through an
 * (int *) lvalue, which would break strict aliasing for char buffers.
 */
static void randomize_buffer(void *buf, size_t buflen)
{
	unsigned char *p = buf;
	size_t words = buflen / sizeof(int);
	size_t leftover = buflen % sizeof(int);

	while (words--) {
		int r = rand();

		memcpy(p, &r, sizeof(r));
		p += sizeof(r);
	}

	if (leftover) {
		int r = rand();

		memcpy(p, &r, leftover);
	}
}
138 
139 static int unshare_open(void)
140 {
141 	const char *netns_path = "/proc/self/ns/net";
142 	int fd;
143 
144 	if (unshare(CLONE_NEWNET) != 0) {
145 		pr_err("unshare()");
146 		return -1;
147 	}
148 
149 	fd = open(netns_path, O_RDONLY);
150 	if (fd <= 0) {
151 		pr_err("open(%s)", netns_path);
152 		return -1;
153 	}
154 
155 	return fd;
156 }
157 
158 static int switch_ns(int fd)
159 {
160 	if (setns(fd, CLONE_NEWNET)) {
161 		pr_err("setns()");
162 		return -1;
163 	}
164 	return 0;
165 }
166 
167 /*
168  * Running the test inside a new parent net namespace to bother less
169  * about cleanup on error-path.
170  */
171 static int init_namespaces(void)
172 {
173 	nsfd_parent = unshare_open();
174 	if (nsfd_parent <= 0)
175 		return -1;
176 
177 	nsfd_childa = unshare_open();
178 	if (nsfd_childa <= 0)
179 		return -1;
180 
181 	if (switch_ns(nsfd_parent))
182 		return -1;
183 
184 	nsfd_childb = unshare_open();
185 	if (nsfd_childb <= 0)
186 		return -1;
187 
188 	if (switch_ns(nsfd_parent))
189 		return -1;
190 	return 0;
191 }
192 
/*
 * Ensure *@sock is an open netlink socket of protocol @proto.
 *
 * If *@sock is already open (> 0) it is reused and the caller's sequence
 * number *@seq_nr is advanced by one.  Otherwise a new socket is created
 * and *@seq_nr is seeded with a random value.
 *
 * Returns 0 on success, -1 if socket() fails (error logged).
 */
static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
{
	if (*sock > 0) {
		/* Advance the caller's counter, not the local pointer copy. */
		(*seq_nr)++;
		return 0;
	}

	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
	if (*sock <= 0) {
		pr_err("socket(AF_NETLINK)");
		return -1;
	}

	randomize_buffer(seq_nr, sizeof(*seq_nr));

	return 0;
}
210 
/* Address at which the next attribute appended to @nh would start. */
static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
{
	char *msg_start = (char *)nh;
	size_t used = RTA_ALIGN(nh->nlmsg_len);

	return (struct rtattr *)(msg_start + used);
}
215 
/*
 * Append one attribute of type @rta_type carrying @size bytes of @payload
 * to message @nh.  @req_sz is the size of the whole buffer starting at @nh.
 * A NULL @payload writes only the attribute header (data is not copied).
 * Returns 0, or -1 (logged) if the attribute would not fit.
 */
static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/* NLMSG_ALIGNTO == RTA_ALIGNTO, so the aligned message end is a valid attr start. */
	struct rtattr *attr = rtattr_hdr(nh);
	size_t attr_len = RTA_LENGTH(size);
	size_t new_len = RTA_ALIGN(nh->nlmsg_len) + attr_len;

	if (new_len > req_sz) {
		printk("req buf is too small: %zu < %zu", req_sz, new_len);
		return -1;
	}

	nh->nlmsg_len = new_len;
	attr->rta_type = rta_type;
	attr->rta_len = attr_len;
	if (payload != NULL)
		memcpy(RTA_DATA(attr), payload, size);

	return 0;
}
236 
/*
 * Append an attribute like rtattr_pack() and return a pointer to its header,
 * so that nested attributes can follow it and rtattr_end() can later fix up
 * its length.  Returns NULL if the attribute does not fit in @req_sz.
 */
static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	struct rtattr *ret = rtattr_hdr(nh);

	if (rtattr_pack(nh, req_sz, rta_type, payload, size))
		return NULL;

	return ret;
}
247 
/* Open a nested attribute of type @rta_type that has no payload of its own. */
static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type)
{
	const void *no_payload = NULL;

	return _rtattr_begin(nh, req_sz, rta_type, no_payload, 0);
}
253 
/* Close nested attribute @attr: its length now reaches the current end of @nh. */
static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
{
	const char *attr_start = (const char *)attr;
	const char *msg_end = (const char *)nh + nh->nlmsg_len;

	attr->rta_len = msg_end - attr_start;
}
260 
/*
 * Append the VETH_INFO_PEER nest that describes the peer end of a veth pair:
 * an ifinfomsg header, the peer name (IFLA_IFNAME) and the fd of the
 * namespace it is placed in (IFLA_NET_NS_FD).
 * Returns 0, or -1 if the request buffer of @req_sz bytes is too small.
 */
static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
		const char *peer, int ns)
{
	struct ifinfomsg peer_info = {
		.ifi_family	= AF_UNSPEC,
		.ifi_change	= 0xFFFFFFFF,
	};
	struct rtattr *nest;

	nest = _rtattr_begin(nh, req_sz, VETH_INFO_PEER,
			     &peer_info, sizeof(peer_info));
	if (nest == NULL)
		return -1;

	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
		return -1;

	rtattr_end(nh, nest);
	return 0;
}
285 
/*
 * Read the kernel's reply to a request sent with NLM_F_ACK.
 *
 * Returns 0 for a positive ack, the (negative) error code from an
 * NLMSG_ERROR reply, or -1 if recv() fails, the reply is too short to hold
 * an error code, or the reply is not an NLMSG_ERROR message.
 */
static int netlink_check_answer(int sock)
{
	struct nlmsgerror {
		struct nlmsghdr hdr;
		int error;
		struct nlmsghdr orig_msg;
	} answer;
	ssize_t len;

	len = recv(sock, &answer, sizeof(answer), 0);
	if (len < 0) {
		pr_err("recv()");
		return -1;
	}
	/* Never interpret bytes the kernel did not write. */
	if (len < (ssize_t)(sizeof(answer.hdr) + sizeof(answer.error))) {
		printk("netlink reply too short: %zd bytes", len);
		return -1;
	}
	if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
		return -1;
	}
	if (answer.error) {
		printk("NLMSG_ERROR: %d: %s",
			answer.error, strerror(-answer.error));
		return answer.error;
	}

	return 0;
}
308 
309 static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
310 		const char *peerb, int ns_b)
311 {
312 	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
313 	struct {
314 		struct nlmsghdr		nh;
315 		struct ifinfomsg	info;
316 		char			attrbuf[MAX_PAYLOAD];
317 	} req;
318 	const char veth_type[] = "veth";
319 	struct rtattr *link_info, *info_data;
320 
321 	memset(&req, 0, sizeof(req));
322 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
323 	req.nh.nlmsg_type	= RTM_NEWLINK;
324 	req.nh.nlmsg_flags	= flags;
325 	req.nh.nlmsg_seq	= seq;
326 	req.info.ifi_family	= AF_UNSPEC;
327 	req.info.ifi_change	= 0xFFFFFFFF;
328 
329 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
330 		return -1;
331 
332 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
333 		return -1;
334 
335 	link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
336 	if (!link_info)
337 		return -1;
338 
339 	if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
340 		return -1;
341 
342 	info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
343 	if (!info_data)
344 		return -1;
345 
346 	if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
347 		return -1;
348 
349 	rtattr_end(&req.nh, info_data);
350 	rtattr_end(&req.nh, link_info);
351 
352 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
353 		pr_err("send()");
354 		return -1;
355 	}
356 	return netlink_check_answer(sock);
357 }
358 
359 static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
360 		struct in_addr addr, uint8_t prefix)
361 {
362 	uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
363 	struct {
364 		struct nlmsghdr		nh;
365 		struct ifaddrmsg	info;
366 		char			attrbuf[MAX_PAYLOAD];
367 	} req;
368 
369 	memset(&req, 0, sizeof(req));
370 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
371 	req.nh.nlmsg_type	= RTM_NEWADDR;
372 	req.nh.nlmsg_flags	= flags;
373 	req.nh.nlmsg_seq	= seq;
374 	req.info.ifa_family	= AF_INET;
375 	req.info.ifa_prefixlen	= prefix;
376 	req.info.ifa_index	= if_nametoindex(intf);
377 
378 #ifdef DEBUG
379 	{
380 		char addr_str[IPV4_STR_SZ] = {};
381 
382 		strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
383 
384 		printk("ip addr set %s", addr_str);
385 	}
386 #endif
387 
388 	if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
389 		return -1;
390 
391 	if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
392 		return -1;
393 
394 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
395 		pr_err("send()");
396 		return -1;
397 	}
398 	return netlink_check_answer(sock);
399 }
400 
401 static int link_set_up(int sock, uint32_t seq, const char *intf)
402 {
403 	struct {
404 		struct nlmsghdr		nh;
405 		struct ifinfomsg	info;
406 		char			attrbuf[MAX_PAYLOAD];
407 	} req;
408 
409 	memset(&req, 0, sizeof(req));
410 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
411 	req.nh.nlmsg_type	= RTM_NEWLINK;
412 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
413 	req.nh.nlmsg_seq	= seq;
414 	req.info.ifi_family	= AF_UNSPEC;
415 	req.info.ifi_change	= 0xFFFFFFFF;
416 	req.info.ifi_index	= if_nametoindex(intf);
417 	req.info.ifi_flags	= IFF_UP;
418 	req.info.ifi_change	= IFF_UP;
419 
420 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
421 		pr_err("send()");
422 		return -1;
423 	}
424 	return netlink_check_answer(sock);
425 }
426 
427 static int ip4_route_set(int sock, uint32_t seq, const char *intf,
428 		struct in_addr src, struct in_addr dst)
429 {
430 	struct {
431 		struct nlmsghdr	nh;
432 		struct rtmsg	rt;
433 		char		attrbuf[MAX_PAYLOAD];
434 	} req;
435 	unsigned int index = if_nametoindex(intf);
436 
437 	memset(&req, 0, sizeof(req));
438 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.rt));
439 	req.nh.nlmsg_type	= RTM_NEWROUTE;
440 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
441 	req.nh.nlmsg_seq	= seq;
442 	req.rt.rtm_family	= AF_INET;
443 	req.rt.rtm_dst_len	= 32;
444 	req.rt.rtm_table	= RT_TABLE_MAIN;
445 	req.rt.rtm_protocol	= RTPROT_BOOT;
446 	req.rt.rtm_scope	= RT_SCOPE_LINK;
447 	req.rt.rtm_type		= RTN_UNICAST;
448 
449 	if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
450 		return -1;
451 
452 	if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
453 		return -1;
454 
455 	if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
456 		return -1;
457 
458 	if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
459 		pr_err("send()");
460 		return -1;
461 	}
462 
463 	return netlink_check_answer(sock);
464 }
465 
466 static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
467 		struct in_addr tunsrc, struct in_addr tundst)
468 {
469 	if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
470 			tunsrc, PREFIX_LEN)) {
471 		printk("Failed to set ipv4 addr");
472 		return -1;
473 	}
474 
475 	if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
476 		printk("Failed to set ipv4 route");
477 		return -1;
478 	}
479 
480 	return 0;
481 }
482 
483 static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
484 {
485 	struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
486 	struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
487 	struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
488 	int route_sock = -1, ret = -1;
489 	uint32_t route_seq;
490 
491 	if (switch_ns(nsfd))
492 		return -1;
493 
494 	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
495 		printk("Failed to open netlink route socket in child");
496 		return -1;
497 	}
498 
499 	if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
500 		printk("Failed to set ipv4 addr");
501 		goto err;
502 	}
503 
504 	if (link_set_up(route_sock, route_seq++, veth)) {
505 		printk("Failed to bring up %s", veth);
506 		goto err;
507 	}
508 
509 	if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
510 		printk("Failed to add tunnel route on %s", veth);
511 		goto err;
512 	}
513 	ret = 0;
514 
515 err:
516 	close(route_sock);
517 	return ret;
518 }
519 
#define ALGO_LEN	64

/* Kind of xfrm scenario described by a struct xfrm_desc. */
enum desc_type {
	CREATE_TUNNEL	= 0,
	ALLOCATE_SPI,
	MONITOR_ACQUIRE,
	EXPIRE_STATE,
	EXPIRE_POLICY,
	SPDINFO_ATTRS,
};

/*
 * Printable name of each enum desc_type, indexed by the enumerator.
 * Designated initializers keep each string tied to its enumerator; the
 * trailing "" entry follows SPDINFO_ATTRS as before.
 */
static const char *desc_name[] = {
	[CREATE_TUNNEL]		= "create tunnel",
	[ALLOCATE_SPI]		= "alloc spi",
	[MONITOR_ACQUIRE]	= "monitor acquire",
	[EXPIRE_STATE]		= "expire state",
	[EXPIRE_POLICY]		= "expire policy",
	[SPDINFO_ATTRS]		= "spdinfo attributes",
	""
};

/*
 * One xfrm configuration to test.  @proto is IPPROTO_AH, IPPROTO_ESP or
 * IPPROTO_COMP; which algorithm names must be set for each protocol is
 * checked by xfrm_state_pack_algo().
 */
struct xfrm_desc {
	enum desc_type	type;
	uint8_t		proto;
	char		a_algo[ALGO_LEN];	/* authentication */
	char		e_algo[ALGO_LEN];	/* encryption */
	char		c_algo[ALGO_LEN];	/* compression */
	char		ae_algo[ALGO_LEN];	/* AEAD */
	unsigned int	icv_len;		/* AEAD ICV length */
};
548 
549 enum msg_type {
550 	MSG_ACK		= 0,
551 	MSG_EXIT,
552 	MSG_PING,
553 	MSG_XFRM_PREPARE,
554 	MSG_XFRM_ADD,
555 	MSG_XFRM_DEL,
556 	MSG_XFRM_CLEANUP,
557 };
558 
559 struct test_desc {
560 	enum msg_type type;
561 	union {
562 		struct {
563 			in_addr_t reply_ip;
564 			unsigned int port;
565 		} ping;
566 		struct xfrm_desc xfrm_desc;
567 	} body;
568 };
569 
570 struct test_result {
571 	struct xfrm_desc desc;
572 	unsigned int res;
573 };
574 
575 static void write_test_result(unsigned int res, struct xfrm_desc *d)
576 {
577 	struct test_result tr = {};
578 	ssize_t ret;
579 
580 	tr.desc = *d;
581 	tr.res = res;
582 
583 	ret = write(results_fd[1], &tr, sizeof(tr));
584 	if (ret != sizeof(tr))
585 		pr_err("Failed to write the result in pipe %zd", ret);
586 }
587 
588 static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
589 {
590 	ssize_t bytes = write(fd, msg, sizeof(*msg));
591 
592 	/* Make sure that write/read is atomic to a pipe */
593 	BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
594 
595 	if (bytes < 0) {
596 		pr_err("write()");
597 		if (exit_of_fail)
598 			exit(KSFT_FAIL);
599 	}
600 	if (bytes != sizeof(*msg)) {
601 		pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
602 		if (exit_of_fail)
603 			exit(KSFT_FAIL);
604 	}
605 }
606 
607 static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
608 {
609 	ssize_t bytes = read(fd, msg, sizeof(*msg));
610 
611 	if (bytes < 0) {
612 		pr_err("read()");
613 		if (exit_of_fail)
614 			exit(KSFT_FAIL);
615 	}
616 	if (bytes != sizeof(*msg)) {
617 		pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
618 		if (exit_of_fail)
619 			exit(KSFT_FAIL);
620 	}
621 }
622 
/*
 * Create the UDP sockets for a ping exchange.
 *
 * sock[0] is bound to @listen_ip on a kernel-chosen port, which is returned
 * in *@server_port, and gets a receive timeout of @u_timeout microseconds.
 * sock[1] is an unbound socket used for sending.
 * Returns 0, or -1 with no socket left open.
 */
static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
		unsigned int *server_port, int sock[2])
{
	/* Fully initialized, so sin_zero carries no stack garbage. */
	struct sockaddr_in server = {
		.sin_family	= AF_INET,
		.sin_port	= 0,
		.sin_addr	= listen_ip,
	};
	/* Split into seconds so timeouts >= 1s keep tv_usec in range. */
	struct timeval t = {
		.tv_sec		= u_timeout / 1000000,
		.tv_usec	= u_timeout % 1000000,
	};
	socklen_t s_len = sizeof(server);

	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[0] < 0) {
		pr_err("socket()");
		return -1;
	}

	if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
		pr_err("bind()");
		goto err_close_server;
	}

	if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
		pr_err("getsockname()");
		goto err_close_server;
	}

	*server_port = ntohs(server.sin_port);

	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
		pr_err("setsockopt()");
		goto err_close_server;
	}

	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[1] < 0) {
		pr_err("socket()");
		goto err_close_server;
	}

	return 0;

err_close_server:
	close(sock[0]);
	return -1;
}
669 
/*
 * One ping round trip from the initiating side: send @buf_len bytes of @buf
 * from sock[1] to @dest_ip:@port, then wait on sock[0] for an identical echo.
 * Returns 0 if the echo matches, -1 otherwise.  A receive timeout (EAGAIN)
 * is not logged.
 */
static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server = {
		.sin_family	= AF_INET,
		.sin_port	= htons(port),
		.sin_addr	= { .s_addr = dest_ip },
	};
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	char sock_buf[buf_len];	/* buf_len bytes, not buf_len pointers */
	ssize_t r_bytes, s_bytes;

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	} else if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	} else if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	return 0;
}
706 
/*
 * Answer one ping from the replying side: wait on sock[0] for @buf_len bytes
 * equal to @buf, then send @buf from sock[1] to @dest_ip:@port.
 * Returns 0 on success, -1 otherwise.  A receive timeout (EAGAIN) is not
 * logged.
 */
static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server = {
		.sin_family	= AF_INET,
		.sin_port	= htons(port),
		.sin_addr	= { .s_addr = dest_ip },
	};
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	char sock_buf[buf_len];	/* buf_len bytes, not buf_len pointers */
	ssize_t r_bytes, s_bytes;

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	}
	if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	}
	if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	return 0;
}
745 
746 typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
747 		char *buf, size_t buf_len);
748 static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
749 		bool init_side, int d_port, in_addr_t to, ping_f func)
750 {
751 	struct test_desc msg;
752 	unsigned int s_port, i, ping_succeeded = 0;
753 	int ping_sock[2];
754 	char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
755 
756 	if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
757 		printk("Failed to init ping");
758 		return -1;
759 	}
760 
761 	memset(&msg, 0, sizeof(msg));
762 	msg.type		= MSG_PING;
763 	msg.body.ping.port	= s_port;
764 	memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
765 
766 	write_msg(cmd_fd, &msg, 0);
767 	if (init_side) {
768 		/* The other end sends ip to ping */
769 		read_msg(cmd_fd, &msg, 0);
770 		if (msg.type != MSG_PING)
771 			return -1;
772 		to = msg.body.ping.reply_ip;
773 		d_port = msg.body.ping.port;
774 	}
775 
776 	for (i = 0; i < ping_count ; i++) {
777 		struct timespec sleep_time = {
778 			.tv_sec = 0,
779 			.tv_nsec = ping_delay_nsec,
780 		};
781 
782 		ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
783 		nanosleep(&sleep_time, 0);
784 	}
785 
786 	close(ping_sock[0]);
787 	close(ping_sock[1]);
788 
789 	strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
790 	strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
791 
792 	if (ping_succeeded < ping_success) {
793 		printk("ping (%s) %s->%s failed %u/%u times",
794 			init_side ? "send" : "reply", from_str, to_str,
795 			ping_count - ping_succeeded, ping_count);
796 		return -1;
797 	}
798 
799 #ifdef DEBUG
800 	printk("ping (%s) %s->%s succeeded %u/%u times",
801 		init_side ? "send" : "reply", from_str, to_str,
802 		ping_succeeded, ping_count);
803 #endif
804 
805 	return 0;
806 }
807 
808 static int xfrm_fill_key(char *name, char *buf,
809 		size_t buf_len, unsigned int *key_len)
810 {
811 	int i;
812 
813 	for (i = 0; i < ARRAY_SIZE(xfrm_key_entries); i++) {
814 		if (strncmp(name, xfrm_key_entries[i].algo_name, ALGO_LEN) == 0)
815 			*key_len = xfrm_key_entries[i].key_len;
816 	}
817 
818 	if (*key_len > buf_len) {
819 		printk("Can't pack a key - too big for buffer");
820 		return -1;
821 	}
822 
823 	randomize_buffer(buf, *key_len);
824 
825 	return 0;
826 }
827 
828 static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
829 		struct xfrm_desc *desc)
830 {
831 	union {
832 		union {
833 			struct xfrm_algo	alg;
834 			struct xfrm_algo_aead	aead;
835 			struct xfrm_algo_auth	auth;
836 		} u;
837 		struct {
838 			unsigned char __offset_to_FAM[offsetof(struct xfrm_algo_auth, alg_key)];
839 			char buf[XFRM_ALGO_KEY_BUF_SIZE];
840 		};
841 	} alg = {};
842 	size_t alen, elen, clen, aelen;
843 	unsigned short type;
844 
845 	alen = strlen(desc->a_algo);
846 	elen = strlen(desc->e_algo);
847 	clen = strlen(desc->c_algo);
848 	aelen = strlen(desc->ae_algo);
849 
850 	/* Verify desc */
851 	switch (desc->proto) {
852 	case IPPROTO_AH:
853 		if (!alen || elen || clen || aelen) {
854 			printk("BUG: buggy ah desc");
855 			return -1;
856 		}
857 		strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
858 		if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
859 				sizeof(alg.buf), &alg.u.alg.alg_key_len))
860 			return -1;
861 		type = XFRMA_ALG_AUTH;
862 		break;
863 	case IPPROTO_COMP:
864 		if (!clen || elen || alen || aelen) {
865 			printk("BUG: buggy comp desc");
866 			return -1;
867 		}
868 		strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
869 		if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
870 				sizeof(alg.buf), &alg.u.alg.alg_key_len))
871 			return -1;
872 		type = XFRMA_ALG_COMP;
873 		break;
874 	case IPPROTO_ESP:
875 		if (!((alen && elen) ^ aelen) || clen) {
876 			printk("BUG: buggy esp desc");
877 			return -1;
878 		}
879 		if (aelen) {
880 			alg.u.aead.alg_icv_len = desc->icv_len;
881 			strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
882 			if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
883 						sizeof(alg.buf), &alg.u.aead.alg_key_len))
884 				return -1;
885 			type = XFRMA_ALG_AEAD;
886 		} else {
887 
888 			strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
889 			type = XFRMA_ALG_CRYPT;
890 			if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
891 						sizeof(alg.buf), &alg.u.alg.alg_key_len))
892 				return -1;
893 			if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
894 				return -1;
895 
896 			strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
897 			type = XFRMA_ALG_AUTH;
898 			if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
899 						sizeof(alg.buf), &alg.u.alg.alg_key_len))
900 				return -1;
901 		}
902 		break;
903 	default:
904 		printk("BUG: unknown proto in desc");
905 		return -1;
906 	}
907 
908 	if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
909 		return -1;
910 
911 	return 0;
912 }
913 
/* SPI derived from @src: its host part (per inet_lnaof()) in network order. */
static inline uint32_t gen_spi(struct in_addr src)
{
	in_addr_t host_part = inet_lnaof(src);

	return htonl(host_part);
}
918 
919 static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
920 		struct in_addr src, struct in_addr dst,
921 		struct xfrm_desc *desc)
922 {
923 	struct {
924 		struct nlmsghdr		nh;
925 		struct xfrm_usersa_info	info;
926 		char			attrbuf[MAX_PAYLOAD];
927 	} req;
928 
929 	memset(&req, 0, sizeof(req));
930 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
931 	req.nh.nlmsg_type	= XFRM_MSG_NEWSA;
932 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
933 	req.nh.nlmsg_seq	= seq;
934 
935 	/* Fill selector. */
936 	memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
937 	memcpy(&req.info.sel.saddr, &src, sizeof(src));
938 	req.info.sel.family		= AF_INET;
939 	req.info.sel.prefixlen_d	= PREFIX_LEN;
940 	req.info.sel.prefixlen_s	= PREFIX_LEN;
941 
942 	/* Fill id */
943 	memcpy(&req.info.id.daddr, &dst, sizeof(dst));
944 	/* Note: zero-spi cannot be deleted */
945 	req.info.id.spi = spi;
946 	req.info.id.proto	= desc->proto;
947 
948 	memcpy(&req.info.saddr, &src, sizeof(src));
949 
950 	/* Fill lifteme_cfg */
951 	req.info.lft.soft_byte_limit	= XFRM_INF;
952 	req.info.lft.hard_byte_limit	= XFRM_INF;
953 	req.info.lft.soft_packet_limit	= XFRM_INF;
954 	req.info.lft.hard_packet_limit	= XFRM_INF;
955 
956 	req.info.family		= AF_INET;
957 	req.info.mode		= XFRM_MODE_TUNNEL;
958 
959 	if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
960 		return -1;
961 
962 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
963 		pr_err("send()");
964 		return -1;
965 	}
966 
967 	return netlink_check_answer(xfrm_sock);
968 }
969 
970 static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
971 		struct in_addr src, struct in_addr dst,
972 		struct xfrm_desc *desc)
973 {
974 	if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
975 		return false;
976 
977 	if (memcmp(&info->sel.saddr, &src, sizeof(src)))
978 		return false;
979 
980 	if (info->sel.family != AF_INET					||
981 			info->sel.prefixlen_d != PREFIX_LEN		||
982 			info->sel.prefixlen_s != PREFIX_LEN)
983 		return false;
984 
985 	if (info->id.spi != spi || info->id.proto != desc->proto)
986 		return false;
987 
988 	if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
989 		return false;
990 
991 	if (memcmp(&info->saddr, &src, sizeof(src)))
992 		return false;
993 
994 	if (info->lft.soft_byte_limit != XFRM_INF			||
995 			info->lft.hard_byte_limit != XFRM_INF		||
996 			info->lft.soft_packet_limit != XFRM_INF		||
997 			info->lft.hard_packet_limit != XFRM_INF)
998 		return false;
999 
1000 	if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
1001 		return false;
1002 
1003 	/* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
1004 
1005 	return true;
1006 }
1007 
1008 static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1009 		struct in_addr src, struct in_addr dst,
1010 		struct xfrm_desc *desc)
1011 {
1012 	struct {
1013 		struct nlmsghdr		nh;
1014 		char			attrbuf[MAX_PAYLOAD];
1015 	} req;
1016 	struct {
1017 		struct nlmsghdr		nh;
1018 		union {
1019 			struct xfrm_usersa_info	info;
1020 			int error;
1021 		};
1022 		char			attrbuf[MAX_PAYLOAD];
1023 	} answer;
1024 	struct xfrm_address_filter filter = {};
1025 	bool found = false;
1026 
1027 
1028 	memset(&req, 0, sizeof(req));
1029 	req.nh.nlmsg_len	= NLMSG_LENGTH(0);
1030 	req.nh.nlmsg_type	= XFRM_MSG_GETSA;
1031 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_DUMP;
1032 	req.nh.nlmsg_seq	= seq;
1033 
1034 	/*
1035 	 * Add dump filter by source address as there may be other tunnels
1036 	 * in this netns (if tests run in parallel).
1037 	 */
1038 	filter.family = AF_INET;
1039 	filter.splen = 0x1f;	/* 0xffffffff mask see addr_match() */
1040 	memcpy(&filter.saddr, &src, sizeof(src));
1041 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1042 				&filter, sizeof(filter)))
1043 		return -1;
1044 
1045 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1046 		pr_err("send()");
1047 		return -1;
1048 	}
1049 
1050 	while (1) {
1051 		if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1052 			pr_err("recv()");
1053 			return -1;
1054 		}
1055 		if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1056 			printk("NLMSG_ERROR: %d: %s",
1057 				answer.error, strerror(-answer.error));
1058 			return -1;
1059 		} else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1060 			if (found)
1061 				return 0;
1062 			printk("didn't find allocated xfrm state in dump");
1063 			return -1;
1064 		} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1065 			if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1066 				found = true;
1067 		}
1068 	}
1069 }
1070 
/*
 * Install xfrm states for @src -> @dst and @dst -> @src, both with SPI
 * gen_spi(@src), then verify each appears in a state dump (both checks run,
 * so every failure gets logged).  @tunsrc and @tundst are unused here.
 * Each request takes the next sequence number from *@seq.  Returns 0 or -1.
 */
static int xfrm_set(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst,
		struct xfrm_desc *desc)
{
	const uint32_t spi = gen_spi(src);
	int bad;

	if (xfrm_state_add(xfrm_sock, (*seq)++, spi, src, dst, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	if (xfrm_state_add(xfrm_sock, (*seq)++, spi, dst, src, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	/* Check dumps for XFRM_MSG_GETSA */
	bad = xfrm_state_check(xfrm_sock, (*seq)++, spi, src, dst, desc);
	bad |= xfrm_state_check(xfrm_sock, (*seq)++, spi, dst, src, desc);
	if (bad) {
		printk("Failed to check xfrm state");
		return -1;
	}

	return 0;
}
1100 
1101 static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1102 		struct in_addr src, struct in_addr dst, uint8_t dir,
1103 		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1104 {
1105 	struct {
1106 		struct nlmsghdr			nh;
1107 		struct xfrm_userpolicy_info	info;
1108 		char				attrbuf[MAX_PAYLOAD];
1109 	} req;
1110 	struct xfrm_user_tmpl tmpl;
1111 
1112 	memset(&req, 0, sizeof(req));
1113 	memset(&tmpl, 0, sizeof(tmpl));
1114 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
1115 	req.nh.nlmsg_type	= XFRM_MSG_NEWPOLICY;
1116 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1117 	req.nh.nlmsg_seq	= seq;
1118 
1119 	/* Fill selector. */
1120 	memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1121 	memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1122 	req.info.sel.family		= AF_INET;
1123 	req.info.sel.prefixlen_d	= PREFIX_LEN;
1124 	req.info.sel.prefixlen_s	= PREFIX_LEN;
1125 
1126 	/* Fill lifteme_cfg */
1127 	req.info.lft.soft_byte_limit	= XFRM_INF;
1128 	req.info.lft.hard_byte_limit	= XFRM_INF;
1129 	req.info.lft.soft_packet_limit	= XFRM_INF;
1130 	req.info.lft.hard_packet_limit	= XFRM_INF;
1131 
1132 	req.info.dir = dir;
1133 
1134 	/* Fill tmpl */
1135 	memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1136 	/* Note: zero-spi cannot be deleted */
1137 	tmpl.id.spi = spi;
1138 	tmpl.id.proto	= proto;
1139 	tmpl.family	= AF_INET;
1140 	memcpy(&tmpl.saddr, &src, sizeof(src));
1141 	tmpl.mode	= XFRM_MODE_TUNNEL;
1142 	tmpl.aalgos = (~(uint32_t)0);
1143 	tmpl.ealgos = (~(uint32_t)0);
1144 	tmpl.calgos = (~(uint32_t)0);
1145 
1146 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1147 		return -1;
1148 
1149 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1150 		pr_err("send()");
1151 		return -1;
1152 	}
1153 
1154 	return netlink_check_answer(xfrm_sock);
1155 }
1156 
/*
 * xfrm_prepare - install the OUT (src -> dst) and IN (dst -> src) policies
 * for one endpoint. Each request consumes one value from *seq. The IN
 * policy is only attempted once the OUT policy was accepted.
 * Returns 0 on success, -1 if either policy could not be added.
 */
static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int failed;

	failed = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
				 XFRM_POLICY_OUT, tunsrc, tundst, proto);
	if (!failed)
		failed = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src),
					 dst, src, XFRM_POLICY_IN,
					 tunsrc, tundst, proto);

	if (failed) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	return 0;
}
1175 
1176 static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1177 		struct in_addr src, struct in_addr dst, uint8_t dir,
1178 		struct in_addr tunsrc, struct in_addr tundst)
1179 {
1180 	struct {
1181 		struct nlmsghdr			nh;
1182 		struct xfrm_userpolicy_id	id;
1183 		char				attrbuf[MAX_PAYLOAD];
1184 	} req;
1185 
1186 	memset(&req, 0, sizeof(req));
1187 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1188 	req.nh.nlmsg_type	= XFRM_MSG_DELPOLICY;
1189 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1190 	req.nh.nlmsg_seq	= seq;
1191 
1192 	/* Fill id */
1193 	memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1194 	memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1195 	req.id.sel.family		= AF_INET;
1196 	req.id.sel.prefixlen_d		= PREFIX_LEN;
1197 	req.id.sel.prefixlen_s		= PREFIX_LEN;
1198 	req.id.dir = dir;
1199 
1200 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1201 		pr_err("send()");
1202 		return -1;
1203 	}
1204 
1205 	return netlink_check_answer(xfrm_sock);
1206 }
1207 
/*
 * xfrm_cleanup - remove the OUT (src -> dst) and IN (dst -> src) policies
 * installed by xfrm_prepare(). Each request consumes one value from *seq;
 * the IN policy is only attempted once the OUT policy was removed.
 * Returns 0 on success, -1 if either deletion failed.
 */
static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst)
{
	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst)) {
		printk("Failed to remove xfrm policy");
		return -1;
	}

	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
				XFRM_POLICY_IN, tunsrc, tundst)) {
		printk("Failed to remove xfrm policy");
		return -1;
	}

	return 0;
}
1226 
1227 static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1228 		struct in_addr src, struct in_addr dst, uint8_t proto)
1229 {
1230 	struct {
1231 		struct nlmsghdr		nh;
1232 		struct xfrm_usersa_id	id;
1233 		char			attrbuf[MAX_PAYLOAD];
1234 	} req;
1235 	xfrm_address_t saddr = {};
1236 
1237 	memset(&req, 0, sizeof(req));
1238 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
1239 	req.nh.nlmsg_type	= XFRM_MSG_DELSA;
1240 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1241 	req.nh.nlmsg_seq	= seq;
1242 
1243 	memcpy(&req.id.daddr, &dst, sizeof(dst));
1244 	req.id.family		= AF_INET;
1245 	req.id.proto		= proto;
1246 	/* Note: zero-spi cannot be deleted */
1247 	req.id.spi = spi;
1248 
1249 	memcpy(&saddr, &src, sizeof(src));
1250 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1251 		return -1;
1252 
1253 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1254 		pr_err("send()");
1255 		return -1;
1256 	}
1257 
1258 	return netlink_check_answer(xfrm_sock);
1259 }
1260 
/*
 * xfrm_delete - remove both SAs (src -> dst and dst -> src), each looked
 * up with SPI gen_spi(src). Each request consumes one value from *seq;
 * the second SA is only attempted once the first was removed.
 * @tunsrc, @tundst are unused.
 * Returns 0 on success, -1 if either deletion failed.
 */
static int xfrm_delete(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int failed;

	failed = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), src, dst, proto);
	if (!failed)
		failed = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src),
					dst, src, proto);

	if (failed) {
		printk("Failed to remove xfrm state");
		return -1;
	}

	return 0;
}
1277 
1278 static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1279 		uint32_t spi, uint8_t proto)
1280 {
1281 	struct {
1282 		struct nlmsghdr			nh;
1283 		struct xfrm_userspi_info	spi;
1284 	} req;
1285 	struct {
1286 		struct nlmsghdr			nh;
1287 		union {
1288 			struct xfrm_usersa_info	info;
1289 			int error;
1290 		};
1291 	} answer;
1292 
1293 	memset(&req, 0, sizeof(req));
1294 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.spi));
1295 	req.nh.nlmsg_type	= XFRM_MSG_ALLOCSPI;
1296 	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1297 	req.nh.nlmsg_seq	= (*seq)++;
1298 
1299 	req.spi.info.family	= AF_INET;
1300 	req.spi.min		= spi;
1301 	req.spi.max		= spi;
1302 	req.spi.info.id.proto	= proto;
1303 
1304 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1305 		pr_err("send()");
1306 		return KSFT_FAIL;
1307 	}
1308 
1309 	if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1310 		pr_err("recv()");
1311 		return KSFT_FAIL;
1312 	} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1313 		uint32_t new_spi = htonl(answer.info.id.spi);
1314 
1315 		if (new_spi != spi) {
1316 			printk("allocated spi is different from requested: %#x != %#x",
1317 					new_spi, spi);
1318 			return KSFT_FAIL;
1319 		}
1320 		return KSFT_PASS;
1321 	} else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1322 		printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1323 		return KSFT_FAIL;
1324 	}
1325 
1326 	printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1327 	return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1328 }
1329 
/*
 * netlink_sock_bind - open a netlink socket of @proto with netlink_sock()
 * and bind it to the multicast groups in the @groups bitmask.
 * @sock: receives the socket fd; set to -1 if the socket is closed again
 *	because a later step failed, so callers never hold a stale fd
 * @seq: sequence counter initialised by netlink_sock()
 *
 * After binding, the address is read back with getsockname() and checked
 * for length and family.
 * Returns 0 on success, -1 on failure.
 */
static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
{
	struct sockaddr_nl snl = {};
	socklen_t addr_len;

	snl.nl_family = AF_NETLINK;
	snl.nl_groups = groups;

	if (netlink_sock(sock, seq, proto)) {
		printk("Failed to open xfrm netlink socket");
		return -1;
	}

	if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
		pr_err("bind()");
		goto out_close;
	}

	addr_len = sizeof(snl);
	if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
		pr_err("getsockname()");
		goto out_close;
	}
	if (addr_len != sizeof(snl)) {
		/* socklen_t is unsigned. */
		printk("Wrong address length %u", (unsigned int)addr_len);
		goto out_close;
	}
	if (snl.nl_family != AF_NETLINK) {
		printk("Wrong address family %d", snl.nl_family);
		goto out_close;
	}
	return 0;

out_close:
	close(*sock);
	*sock = -1;
	return -1;
}
1368 
1369 static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
1370 {
1371 	struct {
1372 		struct nlmsghdr nh;
1373 		union {
1374 			struct xfrm_user_acquire acq;
1375 			int error;
1376 		};
1377 		char attrbuf[MAX_PAYLOAD];
1378 	} req;
1379 	struct xfrm_user_tmpl xfrm_tmpl = {};
1380 	int xfrm_listen = -1, ret = KSFT_FAIL;
1381 	uint32_t seq_listen;
1382 
1383 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
1384 		return KSFT_FAIL;
1385 
1386 	memset(&req, 0, sizeof(req));
1387 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.acq));
1388 	req.nh.nlmsg_type	= XFRM_MSG_ACQUIRE;
1389 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1390 	req.nh.nlmsg_seq	= (*seq)++;
1391 
1392 	req.acq.policy.sel.family	= AF_INET;
1393 	req.acq.aalgos	= 0xfeed;
1394 	req.acq.ealgos	= 0xbaad;
1395 	req.acq.calgos	= 0xbabe;
1396 
1397 	xfrm_tmpl.family = AF_INET;
1398 	xfrm_tmpl.id.proto = IPPROTO_ESP;
1399 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
1400 		goto out_close;
1401 
1402 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1403 		pr_err("send()");
1404 		goto out_close;
1405 	}
1406 
1407 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1408 		pr_err("recv()");
1409 		goto out_close;
1410 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1411 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1412 		goto out_close;
1413 	}
1414 
1415 	if (req.error) {
1416 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1417 		ret = req.error;
1418 		goto out_close;
1419 	}
1420 
1421 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1422 		pr_err("recv()");
1423 		goto out_close;
1424 	}
1425 
1426 	if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
1427 			|| req.acq.calgos != 0xbabe) {
1428 		printk("xfrm_user_acquire has changed  %x %x %x",
1429 				req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
1430 		goto out_close;
1431 	}
1432 
1433 	ret = KSFT_PASS;
1434 out_close:
1435 	close(xfrm_listen);
1436 	return ret;
1437 }
1438 
1439 static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
1440 		unsigned int nr, struct xfrm_desc *desc)
1441 {
1442 	struct {
1443 		struct nlmsghdr nh;
1444 		union {
1445 			struct xfrm_user_expire expire;
1446 			int error;
1447 		};
1448 	} req;
1449 	struct in_addr src, dst;
1450 	int xfrm_listen = -1, ret = KSFT_FAIL;
1451 	uint32_t seq_listen;
1452 
1453 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1454 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1455 
1456 	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
1457 		printk("Failed to add xfrm state");
1458 		return KSFT_FAIL;
1459 	}
1460 
1461 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1462 		return KSFT_FAIL;
1463 
1464 	memset(&req, 0, sizeof(req));
1465 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1466 	req.nh.nlmsg_type	= XFRM_MSG_EXPIRE;
1467 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1468 	req.nh.nlmsg_seq	= (*seq)++;
1469 
1470 	memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
1471 	req.expire.state.id.spi		= gen_spi(src);
1472 	req.expire.state.id.proto	= desc->proto;
1473 	req.expire.state.family		= AF_INET;
1474 	req.expire.hard			= 0xff;
1475 
1476 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1477 		pr_err("send()");
1478 		goto out_close;
1479 	}
1480 
1481 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1482 		pr_err("recv()");
1483 		goto out_close;
1484 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1485 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1486 		goto out_close;
1487 	}
1488 
1489 	if (req.error) {
1490 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1491 		ret = req.error;
1492 		goto out_close;
1493 	}
1494 
1495 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1496 		pr_err("recv()");
1497 		goto out_close;
1498 	}
1499 
1500 	if (req.expire.hard != 0x1) {
1501 		printk("expire.hard is not set: %x", req.expire.hard);
1502 		goto out_close;
1503 	}
1504 
1505 	ret = KSFT_PASS;
1506 out_close:
1507 	close(xfrm_listen);
1508 	return ret;
1509 }
1510 
1511 static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
1512 		unsigned int nr, struct xfrm_desc *desc)
1513 {
1514 	struct {
1515 		struct nlmsghdr nh;
1516 		union {
1517 			struct xfrm_user_polexpire expire;
1518 			int error;
1519 		};
1520 	} req;
1521 	struct in_addr src, dst, tunsrc, tundst;
1522 	int xfrm_listen = -1, ret = KSFT_FAIL;
1523 	uint32_t seq_listen;
1524 
1525 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1526 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1527 	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1528 	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1529 
1530 	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1531 				XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
1532 		printk("Failed to add xfrm policy");
1533 		return KSFT_FAIL;
1534 	}
1535 
1536 	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1537 		return KSFT_FAIL;
1538 
1539 	memset(&req, 0, sizeof(req));
1540 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
1541 	req.nh.nlmsg_type	= XFRM_MSG_POLEXPIRE;
1542 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1543 	req.nh.nlmsg_seq	= (*seq)++;
1544 
1545 	/* Fill selector. */
1546 	memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
1547 	memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
1548 	req.expire.pol.sel.family	= AF_INET;
1549 	req.expire.pol.sel.prefixlen_d	= PREFIX_LEN;
1550 	req.expire.pol.sel.prefixlen_s	= PREFIX_LEN;
1551 	req.expire.pol.dir		= XFRM_POLICY_OUT;
1552 	req.expire.hard			= 0xff;
1553 
1554 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1555 		pr_err("send()");
1556 		goto out_close;
1557 	}
1558 
1559 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1560 		pr_err("recv()");
1561 		goto out_close;
1562 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1563 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1564 		goto out_close;
1565 	}
1566 
1567 	if (req.error) {
1568 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1569 		ret = req.error;
1570 		goto out_close;
1571 	}
1572 
1573 	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1574 		pr_err("recv()");
1575 		goto out_close;
1576 	}
1577 
1578 	if (req.expire.hard != 0x1) {
1579 		printk("expire.hard is not set: %x", req.expire.hard);
1580 		goto out_close;
1581 	}
1582 
1583 	ret = KSFT_PASS;
1584 out_close:
1585 	close(xfrm_listen);
1586 	return ret;
1587 }
1588 
1589 static int xfrm_spdinfo_set_thresh(int xfrm_sock, uint32_t *seq,
1590 		unsigned thresh4_l, unsigned thresh4_r,
1591 		unsigned thresh6_l, unsigned thresh6_r,
1592 		bool add_bad_attr)
1593 
1594 {
1595 	struct {
1596 		struct nlmsghdr		nh;
1597 		union {
1598 			uint32_t	unused;
1599 			int		error;
1600 		};
1601 		char			attrbuf[MAX_PAYLOAD];
1602 	} req;
1603 	struct xfrmu_spdhthresh thresh;
1604 
1605 	memset(&req, 0, sizeof(req));
1606 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1607 	req.nh.nlmsg_type	= XFRM_MSG_NEWSPDINFO;
1608 	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
1609 	req.nh.nlmsg_seq	= (*seq)++;
1610 
1611 	thresh.lbits = thresh4_l;
1612 	thresh.rbits = thresh4_r;
1613 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV4_HTHRESH, &thresh, sizeof(thresh)))
1614 		return -1;
1615 
1616 	thresh.lbits = thresh6_l;
1617 	thresh.rbits = thresh6_r;
1618 	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV6_HTHRESH, &thresh, sizeof(thresh)))
1619 		return -1;
1620 
1621 	if (add_bad_attr) {
1622 		BUILD_BUG_ON(XFRMA_IF_ID <= XFRMA_SPD_MAX + 1);
1623 		if (rtattr_pack(&req.nh, sizeof(req), XFRMA_IF_ID, NULL, 0)) {
1624 			pr_err("adding attribute failed: no space");
1625 			return -1;
1626 		}
1627 	}
1628 
1629 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1630 		pr_err("send()");
1631 		return -1;
1632 	}
1633 
1634 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1635 		pr_err("recv()");
1636 		return -1;
1637 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1638 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1639 		return -1;
1640 	}
1641 
1642 	if (req.error) {
1643 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1644 		return -1;
1645 	}
1646 
1647 	return 0;
1648 }
1649 
1650 static int xfrm_spdinfo_attrs(int xfrm_sock, uint32_t *seq)
1651 {
1652 	struct {
1653 		struct nlmsghdr			nh;
1654 		union {
1655 			uint32_t	unused;
1656 			int		error;
1657 		};
1658 		char			attrbuf[MAX_PAYLOAD];
1659 	} req;
1660 
1661 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 31, 120, 16, false)) {
1662 		pr_err("Can't set SPD HTHRESH");
1663 		return KSFT_FAIL;
1664 	}
1665 
1666 	memset(&req, 0, sizeof(req));
1667 
1668 	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
1669 	req.nh.nlmsg_type	= XFRM_MSG_GETSPDINFO;
1670 	req.nh.nlmsg_flags	= NLM_F_REQUEST;
1671 	req.nh.nlmsg_seq	= (*seq)++;
1672 	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1673 		pr_err("send()");
1674 		return KSFT_FAIL;
1675 	}
1676 
1677 	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1678 		pr_err("recv()");
1679 		return KSFT_FAIL;
1680 	} else if (req.nh.nlmsg_type == XFRM_MSG_NEWSPDINFO) {
1681 		size_t len = NLMSG_PAYLOAD(&req.nh, sizeof(req.unused));
1682 		struct rtattr *attr = (void *)req.attrbuf;
1683 		int got_thresh = 0;
1684 
1685 		for (; RTA_OK(attr, len); attr = RTA_NEXT(attr, len)) {
1686 			if (attr->rta_type == XFRMA_SPD_IPV4_HTHRESH) {
1687 				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1688 
1689 				got_thresh++;
1690 				if (t->lbits != 32 || t->rbits != 31) {
1691 					pr_err("thresh differ: %u, %u",
1692 							t->lbits, t->rbits);
1693 					return KSFT_FAIL;
1694 				}
1695 			}
1696 			if (attr->rta_type == XFRMA_SPD_IPV6_HTHRESH) {
1697 				struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1698 
1699 				got_thresh++;
1700 				if (t->lbits != 120 || t->rbits != 16) {
1701 					pr_err("thresh differ: %u, %u",
1702 							t->lbits, t->rbits);
1703 					return KSFT_FAIL;
1704 				}
1705 			}
1706 		}
1707 		if (got_thresh != 2) {
1708 			pr_err("only %d thresh returned by XFRM_MSG_GETSPDINFO", got_thresh);
1709 			return KSFT_FAIL;
1710 		}
1711 	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1712 		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1713 		return KSFT_FAIL;
1714 	} else {
1715 		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1716 		return -1;
1717 	}
1718 
1719 	/* Restore the default */
1720 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, false)) {
1721 		pr_err("Can't restore SPD HTHRESH");
1722 		return KSFT_FAIL;
1723 	}
1724 
1725 	/*
1726 	 * At this moment xfrm uses nlmsg_parse_deprecated(), which
1727 	 * implies NL_VALIDATE_LIBERAL - ignoring attributes with
1728 	 * (type > maxtype). nla_parse_depricated_strict() would enforce
1729 	 * it. Or even stricter nla_parse().
1730 	 * Right now it's not expected to fail, but to be ignored.
1731 	 */
1732 	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, true))
1733 		return KSFT_PASS;
1734 
1735 	return KSFT_PASS;
1736 }
1737 
1738 static int child_serv(int xfrm_sock, uint32_t *seq,
1739 		unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
1740 {
1741 	struct in_addr src, dst, tunsrc, tundst;
1742 	struct test_desc msg;
1743 	int ret = KSFT_FAIL;
1744 
1745 	src = inet_makeaddr(INADDR_B, child_ip(nr));
1746 	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1747 	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1748 	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1749 
1750 	/* UDP pinging without xfrm */
1751 	if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
1752 		printk("ping failed before setting xfrm");
1753 		return KSFT_FAIL;
1754 	}
1755 
1756 	memset(&msg, 0, sizeof(msg));
1757 	msg.type = MSG_XFRM_PREPARE;
1758 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1759 	write_msg(cmd_fd, &msg, 1);
1760 
1761 	if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1762 		printk("failed to prepare xfrm");
1763 		goto cleanup;
1764 	}
1765 
1766 	memset(&msg, 0, sizeof(msg));
1767 	msg.type = MSG_XFRM_ADD;
1768 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1769 	write_msg(cmd_fd, &msg, 1);
1770 	if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1771 		printk("failed to set xfrm");
1772 		goto delete;
1773 	}
1774 
1775 	/* UDP pinging with xfrm tunnel */
1776 	if (do_ping(cmd_fd, buf, page_size, tunsrc,
1777 				true, 0, 0, udp_ping_send)) {
1778 		printk("ping failed for xfrm");
1779 		goto delete;
1780 	}
1781 
1782 	ret = KSFT_PASS;
1783 delete:
1784 	/* xfrm delete */
1785 	memset(&msg, 0, sizeof(msg));
1786 	msg.type = MSG_XFRM_DEL;
1787 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1788 	write_msg(cmd_fd, &msg, 1);
1789 
1790 	if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1791 		printk("failed ping to remove xfrm");
1792 		ret = KSFT_FAIL;
1793 	}
1794 
1795 cleanup:
1796 	memset(&msg, 0, sizeof(msg));
1797 	msg.type = MSG_XFRM_CLEANUP;
1798 	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1799 	write_msg(cmd_fd, &msg, 1);
1800 	if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1801 		printk("failed ping to cleanup xfrm");
1802 		ret = KSFT_FAIL;
1803 	}
1804 	return ret;
1805 }
1806 
1807 static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
1808 {
1809 	struct xfrm_desc desc;
1810 	struct test_desc msg;
1811 	int xfrm_sock = -1;
1812 	uint32_t seq;
1813 
1814 	if (switch_ns(nsfd_childa))
1815 		exit(KSFT_FAIL);
1816 
1817 	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1818 		printk("Failed to open xfrm netlink socket");
1819 		exit(KSFT_FAIL);
1820 	}
1821 
1822 	/* Check that seq sock is ready, just for sure. */
1823 	memset(&msg, 0, sizeof(msg));
1824 	msg.type = MSG_ACK;
1825 	write_msg(cmd_fd, &msg, 1);
1826 	read_msg(cmd_fd, &msg, 1);
1827 	if (msg.type != MSG_ACK) {
1828 		printk("Ack failed");
1829 		exit(KSFT_FAIL);
1830 	}
1831 
1832 	for (;;) {
1833 		ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
1834 		int ret;
1835 
1836 		if (received == 0) /* EOF */
1837 			break;
1838 
1839 		if (received != sizeof(desc)) {
1840 			pr_err("read() returned %zd", received);
1841 			exit(KSFT_FAIL);
1842 		}
1843 
1844 		switch (desc.type) {
1845 		case CREATE_TUNNEL:
1846 			ret = child_serv(xfrm_sock, &seq, nr,
1847 					 cmd_fd, buf, &desc);
1848 			break;
1849 		case ALLOCATE_SPI:
1850 			ret = xfrm_state_allocspi(xfrm_sock, &seq,
1851 						  -1, desc.proto);
1852 			break;
1853 		case MONITOR_ACQUIRE:
1854 			ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
1855 			break;
1856 		case EXPIRE_STATE:
1857 			ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
1858 			break;
1859 		case EXPIRE_POLICY:
1860 			ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
1861 			break;
1862 		case SPDINFO_ATTRS:
1863 			ret = xfrm_spdinfo_attrs(xfrm_sock, &seq);
1864 			break;
1865 		default:
1866 			printk("Unknown desc type %d", desc.type);
1867 			exit(KSFT_FAIL);
1868 		}
1869 		write_test_result(ret, &desc);
1870 	}
1871 
1872 	close(xfrm_sock);
1873 
1874 	msg.type = MSG_EXIT;
1875 	write_msg(cmd_fd, &msg, 1);
1876 	exit(KSFT_PASS);
1877 }
1878 
1879 static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
1880 		struct test_desc *msg, int xfrm_sock, uint32_t *seq)
1881 {
1882 	struct in_addr src, dst, tunsrc, tundst;
1883 	bool tun_reply;
1884 	struct xfrm_desc *desc = &msg->body.xfrm_desc;
1885 
1886 	src = inet_makeaddr(INADDR_B, grchild_ip(nr));
1887 	dst = inet_makeaddr(INADDR_B, child_ip(nr));
1888 	tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
1889 	tundst = inet_makeaddr(INADDR_A, child_ip(nr));
1890 
1891 	switch (msg->type) {
1892 	case MSG_EXIT:
1893 		exit(KSFT_PASS);
1894 	case MSG_ACK:
1895 		write_msg(cmd_fd, msg, 1);
1896 		break;
1897 	case MSG_PING:
1898 		tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
1899 		/* UDP pinging without xfrm */
1900 		if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
1901 				false, msg->body.ping.port,
1902 				msg->body.ping.reply_ip, udp_ping_reply)) {
1903 			printk("ping failed before setting xfrm");
1904 		}
1905 		break;
1906 	case MSG_XFRM_PREPARE:
1907 		if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
1908 					desc->proto)) {
1909 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1910 			printk("failed to prepare xfrm");
1911 		}
1912 		break;
1913 	case MSG_XFRM_ADD:
1914 		if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1915 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1916 			printk("failed to set xfrm");
1917 		}
1918 		break;
1919 	case MSG_XFRM_DEL:
1920 		if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
1921 					desc->proto)) {
1922 			xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1923 			printk("failed to remove xfrm");
1924 		}
1925 		break;
1926 	case MSG_XFRM_CLEANUP:
1927 		if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1928 			printk("failed to cleanup xfrm");
1929 		}
1930 		break;
1931 	default:
1932 		printk("got unknown msg type %d", msg->type);
1933 	}
1934 }
1935 
1936 static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1937 {
1938 	struct test_desc msg;
1939 	int xfrm_sock = -1;
1940 	uint32_t seq;
1941 
1942 	if (switch_ns(nsfd_childb))
1943 		exit(KSFT_FAIL);
1944 
1945 	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1946 		printk("Failed to open xfrm netlink socket");
1947 		exit(KSFT_FAIL);
1948 	}
1949 
1950 	do {
1951 		read_msg(cmd_fd, &msg, 1);
1952 		grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1953 	} while (1);
1954 
1955 	close(xfrm_sock);
1956 	exit(KSFT_FAIL);
1957 }
1958 
/*
 * start_child - prepare test slot @nr and fork its two worker processes.
 * @nr: test slot; selects the child_ip(nr)/grchild_ip(nr) address pair
 * @veth: veth name handed to init_child()
 * @test_desc_fd: pipe of struct xfrm_desc; the workers read from [0],
 *	and the forked worker closes its copy of the write end [1]
 *
 * Process layout:
 *   caller     - returns switch_ns(nsfd_parent)
 *   child      - runs child_f() with cmd_sock[1], never returns
 *   grandchild - runs grand_child_f() with cmd_sock[0], never returns
 *
 * The two workers share one MAP_SHARED anonymous page filled by
 * randomize_buffer() and talk over a SOCK_SEQPACKET socketpair.
 *
 * Returns the switch_ns() result in the caller, or -1 on setup failure.
 * NOTE(review): the error paths after the first fork() run inside the
 * forked worker, so that process returns -1 into the caller's code path.
 */
static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
{
	int cmd_sock[2];
	void *data_map;
	pid_t child;

	/* Configure both ends of the slot's veth pair (child A, then child B). */
	if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
		return -1;

	if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
		return -1;

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		/* in parent - selftest */
		return switch_ns(nsfd_parent);
	}

	/*
	 * Drop the worker's copy of the write end, so that child_f() sees EOF
	 * once the plan writer closes its end.
	 */
	if (close(test_desc_fd[1])) {
		pr_err("close()");
		return -1;
	}

	/* child */
	/* Mapped before the second fork() so both workers share the page. */
	data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
			MAP_SHARED | MAP_ANONYMOUS, -1, 0);
	if (data_map == MAP_FAILED) {
		pr_err("mmap()");
		return -1;
	}

	randomize_buffer(data_map, page_size);

	if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
		pr_err("socketpair()");
		return -1;
	}

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		/* Child side: keep cmd_sock[1], run the test driver. */
		if (close(cmd_sock[0])) {
			pr_err("close()");
			return -1;
		}
		return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
	}
	/* Grandchild side: keep cmd_sock[0], serve the child's commands. */
	if (close(cmd_sock[1])) {
		pr_err("close()");
		return -1;
	}
	return grand_child_f(nr, cmd_sock[0], data_map);
}
2017 
2018 static void exit_usage(char **argv)
2019 {
2020 	printk("Usage: %s [nr_process]", argv[0]);
2021 	exit(KSFT_FAIL);
2022 }
2023 
2024 static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
2025 {
2026 	ssize_t ret;
2027 
2028 	ret = write(test_desc_fd, desc, sizeof(*desc));
2029 
2030 	if (ret == sizeof(*desc))
2031 		return 0;
2032 
2033 	pr_err("Writing test's desc failed %ld", ret);
2034 
2035 	return -1;
2036 }
2037 
2038 static int write_desc(int proto, int test_desc_fd,
2039 		char *a, char *e, char *c, char *ae)
2040 {
2041 	struct xfrm_desc desc = {};
2042 
2043 	desc.type = CREATE_TUNNEL;
2044 	desc.proto = proto;
2045 
2046 	if (a)
2047 		strncpy(desc.a_algo, a, ALGO_LEN - 1);
2048 	if (e)
2049 		strncpy(desc.e_algo, e, ALGO_LEN - 1);
2050 	if (c)
2051 		strncpy(desc.c_algo, c, ALGO_LEN - 1);
2052 	if (ae)
2053 		strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
2054 
2055 	return __write_desc(test_desc_fd, &desc);
2056 }
2057 
/* Protocols tested by write_test_plan(), one write_proto_plan() call each. */
int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
/* Authentication algorithms: used alone for AH and paired with e_list for ESP. */
char *ah_list[] = {
	"digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
	"hmac(sha384)", "hmac(sha512)", "xcbc(aes)", "cmac(aes)"
};
/* Compression algorithms for IPPROTO_COMP. */
char *comp_list[] = {
	"deflate",
#if 0
	/* No compression backend implementation */
	"lzs", "lzjh"
#endif
};
/* Encryption algorithms, each combined with every ah_list entry for ESP. */
char *e_list[] = {
	"ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
	"cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
	"cbc(twofish)", "rfc3686(ctr(aes))"
};
/*
 * AEAD algorithms for ESP. All entries are disabled, so this array is
 * empty and ARRAY_SIZE(ae_list) is 0.
 */
char *ae_list[] = {
#if 0
	/* not implemented */
	"rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
	"rfc7539esp(chacha20,poly1305)"
#endif
};

/*
 * Number of CREATE_TUNNEL tests write_proto_plan() emits over proto_list.
 * The terms are AH, then COMP, then ESP auth x enc, then ESP AEAD.
 */
const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
				+ (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
				+ ARRAY_SIZE(ae_list);
2086 
2087 static int write_proto_plan(int fd, int proto)
2088 {
2089 	unsigned int i;
2090 
2091 	switch (proto) {
2092 	case IPPROTO_AH:
2093 		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2094 			if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
2095 				return -1;
2096 		}
2097 		break;
2098 	case IPPROTO_COMP:
2099 		for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
2100 			if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
2101 				return -1;
2102 		}
2103 		break;
2104 	case IPPROTO_ESP:
2105 		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2106 			int j;
2107 
2108 			for (j = 0; j < ARRAY_SIZE(e_list); j++) {
2109 				if (write_desc(proto, fd, ah_list[i],
2110 							e_list[j], 0, 0))
2111 					return -1;
2112 			}
2113 		}
2114 		for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
2115 			if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
2116 				return -1;
2117 		}
2118 		break;
2119 	default:
2120 		printk("BUG: Specified unknown proto %d", proto);
2121 		return -1;
2122 	}
2123 
2124 	return 0;
2125 }
2126 
2127 /*
2128  * Some structures in xfrm uapi header differ in size between
2129  * 64-bit and 32-bit ABI:
2130  *
2131  *             32-bit UABI               |            64-bit UABI
2132  *  -------------------------------------|-------------------------------------
2133  *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
2134  *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
2135  *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
2136  *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
2137  *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
2138  *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
2139  *
2140  * Check the affected by the UABI difference structures.
2141  * Also, check translation for xfrm_set_spdinfo: it has it's own attributes
2142  * which needs to be correctly copied, but not translated.
2143  */
2144 const unsigned int compat_plan = 5;
2145 static int write_compat_struct_tests(int test_desc_fd)
2146 {
2147 	struct xfrm_desc desc = {};
2148 
2149 	desc.type = ALLOCATE_SPI;
2150 	desc.proto = IPPROTO_AH;
2151 	strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2152 
2153 	if (__write_desc(test_desc_fd, &desc))
2154 		return -1;
2155 
2156 	desc.type = MONITOR_ACQUIRE;
2157 	if (__write_desc(test_desc_fd, &desc))
2158 		return -1;
2159 
2160 	desc.type = EXPIRE_STATE;
2161 	if (__write_desc(test_desc_fd, &desc))
2162 		return -1;
2163 
2164 	desc.type = EXPIRE_POLICY;
2165 	if (__write_desc(test_desc_fd, &desc))
2166 		return -1;
2167 
2168 	desc.type = SPDINFO_ATTRS;
2169 	if (__write_desc(test_desc_fd, &desc))
2170 		return -1;
2171 
2172 	return 0;
2173 }
2174 
2175 static int write_test_plan(int test_desc_fd)
2176 {
2177 	unsigned int i;
2178 	pid_t child;
2179 
2180 	child = fork();
2181 	if (child < 0) {
2182 		pr_err("fork()");
2183 		return -1;
2184 	}
2185 	if (child) {
2186 		if (close(test_desc_fd))
2187 			printk("close(): %m");
2188 		return 0;
2189 	}
2190 
2191 	if (write_compat_struct_tests(test_desc_fd))
2192 		exit(KSFT_FAIL);
2193 
2194 	for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2195 		if (write_proto_plan(test_desc_fd, proto_list[i]))
2196 			exit(KSFT_FAIL);
2197 	}
2198 
2199 	exit(KSFT_PASS);
2200 }
2201 
2202 static int children_cleanup(void)
2203 {
2204 	unsigned ret = KSFT_PASS;
2205 
2206 	while (1) {
2207 		int status;
2208 		pid_t p = wait(&status);
2209 
2210 		if ((p < 0) && errno == ECHILD)
2211 			break;
2212 
2213 		if (p < 0) {
2214 			pr_err("wait()");
2215 			return KSFT_FAIL;
2216 		}
2217 
2218 		if (!WIFEXITED(status)) {
2219 			ret = KSFT_FAIL;
2220 			continue;
2221 		}
2222 
2223 		if (WEXITSTATUS(status) == KSFT_FAIL)
2224 			ret = KSFT_FAIL;
2225 	}
2226 
2227 	return ret;
2228 }
2229 
2230 typedef void (*print_res)(const char *, ...);
2231 
2232 static int check_results(void)
2233 {
2234 	struct test_result tr = {};
2235 	struct xfrm_desc *d = &tr.desc;
2236 	int ret = KSFT_PASS;
2237 
2238 	while (1) {
2239 		ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2240 		print_res result;
2241 
2242 		if (received == 0) /* EOF */
2243 			break;
2244 
2245 		if (received != sizeof(tr)) {
2246 			pr_err("read() returned %zd", received);
2247 			return KSFT_FAIL;
2248 		}
2249 
2250 		switch (tr.res) {
2251 		case KSFT_PASS:
2252 			result = ksft_test_result_pass;
2253 			break;
2254 		case KSFT_FAIL:
2255 		default:
2256 			result = ksft_test_result_fail;
2257 			ret = KSFT_FAIL;
2258 		}
2259 
2260 		result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2261 		       desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2262 		       d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2263 	}
2264 
2265 	return ret;
2266 }
2267 
2268 int main(int argc, char **argv)
2269 {
2270 	long nr_process = 1;
2271 	int route_sock = -1, ret = KSFT_SKIP;
2272 	int test_desc_fd[2];
2273 	uint32_t route_seq;
2274 	unsigned int i;
2275 
2276 	if (argc > 2)
2277 		exit_usage(argv);
2278 
2279 	if (argc > 1) {
2280 		char *endptr;
2281 
2282 		errno = 0;
2283 		nr_process = strtol(argv[1], &endptr, 10);
2284 		if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
2285 				|| (errno != 0 && nr_process == 0)
2286 				|| (endptr == argv[1]) || (*endptr != '\0')) {
2287 			printk("Failed to parse [nr_process]");
2288 			exit_usage(argv);
2289 		}
2290 
2291 		if (nr_process > MAX_PROCESSES || nr_process < 1) {
2292 			printk("nr_process should be between [1; %u]",
2293 					MAX_PROCESSES);
2294 			exit_usage(argv);
2295 		}
2296 	}
2297 
2298 	srand(time(NULL));
2299 	page_size = sysconf(_SC_PAGESIZE);
2300 	if (page_size < 1)
2301 		ksft_exit_skip("sysconf(): %m\n");
2302 
2303 	if (pipe2(test_desc_fd, O_DIRECT) < 0)
2304 		ksft_exit_skip("pipe(): %m\n");
2305 
2306 	if (pipe2(results_fd, O_DIRECT) < 0)
2307 		ksft_exit_skip("pipe(): %m\n");
2308 
2309 	if (init_namespaces())
2310 		ksft_exit_skip("Failed to create namespaces\n");
2311 
2312 	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
2313 		ksft_exit_skip("Failed to open netlink route socket\n");
2314 
2315 	for (i = 0; i < nr_process; i++) {
2316 		char veth[VETH_LEN];
2317 
2318 		snprintf(veth, VETH_LEN, VETH_FMT, i);
2319 
2320 		if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
2321 			close(route_sock);
2322 			ksft_exit_fail_msg("Failed to create veth device");
2323 		}
2324 
2325 		if (start_child(i, veth, test_desc_fd)) {
2326 			close(route_sock);
2327 			ksft_exit_fail_msg("Child %u failed to start", i);
2328 		}
2329 	}
2330 
2331 	if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
2332 		ksft_exit_fail_msg("close(): %m");
2333 
2334 	ksft_set_plan(proto_plan + compat_plan);
2335 
2336 	if (write_test_plan(test_desc_fd[1]))
2337 		ksft_exit_fail_msg("Failed to write test plan to pipe");
2338 
2339 	ret = check_results();
2340 
2341 	if (children_cleanup() == KSFT_FAIL)
2342 		exit(KSFT_FAIL);
2343 
2344 	exit(ret);
2345 }
2346