xref: /linux/tools/testing/selftests/bpf/prog_tests/test_bpf_smc.c (revision c4dde411bc366f568dbe33366253bbfea049e8ea)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <test_progs.h>
3 #include <linux/genetlink.h>
4 #include "network_helpers.h"
5 #include "bpf_smc.skel.h"
6 
7 #ifndef IPPROTO_SMC
8 #define IPPROTO_SMC 256
9 #endif
10 
11 #define CLIENT_IP			"127.0.0.1"
12 #define SERVER_IP			"127.0.1.0"
13 #define SERVER_IP_VIA_RISK_PATH	"127.0.2.0"
14 
15 #define SERVICE_1	80
16 #define SERVICE_2	443
17 #define SERVICE_3	8443
18 
19 #define TEST_NS	"bpf_smc_netns"
20 
21 static struct netns_obj *test_netns;
22 
/* Key used for the smc_policy_ip map: a (source, destination) address pair. */
struct smc_policy_ip_key {
	__u32 sip;	/* source address, filled from inet_addr() by block_link() */
	__u32 dip;	/* destination address, filled the same way */
};
27 
/* Value stored in the smc_policy_ip map for one address pair. */
struct smc_policy_ip_value {
	__u8 mode;	/* block_link() stores 0 here, annotated there as "block" */
};
31 
32 #if defined(__s390x__)
/* s390x has default seid: the UEID helpers are no-ops on this architecture */
static bool setup_ueid(void)
{
	return true;
}

static void cleanup_ueid(void)
{
}
36 #else
/*
 * Generic-netlink commands passed to smc_ueid().
 * NOTE(review): presumably mirrors the kernel's SMC uapi numbering - confirm.
 */
enum {
	SMC_NETLINK_ADD_UEID	= 10,
	SMC_NETLINK_REMOVE_UEID	= 11,
};
41 
/* Attribute types carried by the UEID add/remove requests. */
enum {
	SMC_NLA_EID_TABLE_UNSPEC	= 0,
	SMC_NLA_EID_TABLE_ENTRY		= 1,	/* string */
};
46 
/*
 * Fixed-size buffer used both to build a generic-netlink request and to
 * receive the reply: netlink header, generic-netlink header, then room for
 * attributes.
 */
struct msgtemplate {
	struct nlmsghdr n;	/* netlink message header */
	struct genlmsghdr g;	/* generic-netlink header (cmd, version) */
	char buf[1024];		/* attribute area */
};
52 
/*
 * Helpers for walking a generic-netlink message:
 *  GENLMSG_DATA    - first byte after the netlink and generic-netlink headers
 *  GENLMSG_PAYLOAD - number of bytes following the generic-netlink header
 *  NLA_DATA        - payload of an attribute
 *  NLA_PAYLOAD     - payload size for an attribute of total length @len
 *
 * NLMSG_DATA() yields a void pointer; it is cast to char * before the offset
 * is added so the macro does not rely on void-pointer arithmetic.
 */
#define GENLMSG_DATA(glh)	((void *)((char *)NLMSG_DATA(glh) + GENL_HDRLEN))
#define GENLMSG_PAYLOAD(glh)	(NLMSG_PAYLOAD(glh, 0) - GENL_HDRLEN)
#define NLA_DATA(na)		((void *)((char *)(na) + NLA_HDRLEN))
#define NLA_PAYLOAD(len)	((len) - NLA_HDRLEN)
57 
58 #define SMC_GENL_FAMILY_NAME	"SMC_GEN_NETLINK"
59 #define SMC_BPFTEST_UEID	"SMC-BPFTEST-UEID"
60 
61 static uint16_t smc_nl_family_id = -1;
62 
63 static int send_cmd(int fd, __u16 nlmsg_type, __u32 nlmsg_pid,
64 		    __u16 nlmsg_flags, __u8 genl_cmd, __u16 nla_type,
65 		    void *nla_data, int nla_len)
66 {
67 	struct nlattr *na;
68 	struct sockaddr_nl nladdr;
69 	int r, buflen;
70 	char *buf;
71 
72 	struct msgtemplate msg = {0};
73 
74 	msg.n.nlmsg_len = NLMSG_LENGTH(GENL_HDRLEN);
75 	msg.n.nlmsg_type = nlmsg_type;
76 	msg.n.nlmsg_flags = nlmsg_flags;
77 	msg.n.nlmsg_seq = 0;
78 	msg.n.nlmsg_pid = nlmsg_pid;
79 	msg.g.cmd = genl_cmd;
80 	msg.g.version = 1;
81 	na = (struct nlattr *)GENLMSG_DATA(&msg);
82 	na->nla_type = nla_type;
83 	na->nla_len = nla_len + 1 + NLA_HDRLEN;
84 	memcpy(NLA_DATA(na), nla_data, nla_len);
85 	msg.n.nlmsg_len += NLMSG_ALIGN(na->nla_len);
86 
87 	buf = (char *)&msg;
88 	buflen = msg.n.nlmsg_len;
89 	memset(&nladdr, 0, sizeof(nladdr));
90 	nladdr.nl_family = AF_NETLINK;
91 
92 	while ((r = sendto(fd, buf, buflen, 0, (struct sockaddr *)&nladdr,
93 			   sizeof(nladdr))) < buflen) {
94 		if (r > 0) {
95 			buf += r;
96 			buflen -= r;
97 		} else if (errno != EAGAIN) {
98 			return -1;
99 		}
100 	}
101 	return 0;
102 }
103 
104 static bool get_smc_nl_family_id(void)
105 {
106 	struct sockaddr_nl nl_src;
107 	struct msgtemplate msg;
108 	struct nlattr *nl;
109 	int fd, ret;
110 	pid_t pid;
111 
112 	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_GENERIC);
113 	if (!ASSERT_OK_FD(fd, "nl_family socket"))
114 		return false;
115 
116 	pid = getpid();
117 
118 	memset(&nl_src, 0, sizeof(nl_src));
119 	nl_src.nl_family = AF_NETLINK;
120 	nl_src.nl_pid = pid;
121 
122 	ret = bind(fd, (struct sockaddr *)&nl_src, sizeof(nl_src));
123 	if (!ASSERT_OK(ret, "nl_family bind"))
124 		goto fail;
125 
126 	ret = send_cmd(fd, GENL_ID_CTRL, pid,
127 		       NLM_F_REQUEST, CTRL_CMD_GETFAMILY,
128 		       CTRL_ATTR_FAMILY_NAME, (void *)SMC_GENL_FAMILY_NAME,
129 		       strlen(SMC_GENL_FAMILY_NAME));
130 	if (!ASSERT_OK(ret, "nl_family query"))
131 		goto fail;
132 
133 	ret = recv(fd, &msg, sizeof(msg), 0);
134 	if (msg.n.nlmsg_type == NLMSG_ERROR)
135 		goto fail;
136 	if (!ASSERT_FALSE(ret < 0 || !NLMSG_OK(&msg.n, ret),
137 			  "nl_family response"))
138 		goto fail;
139 
140 	nl = (struct nlattr *)GENLMSG_DATA(&msg);
141 	nl = (struct nlattr *)((char *)nl + NLA_ALIGN(nl->nla_len));
142 	if (!ASSERT_EQ(nl->nla_type, CTRL_ATTR_FAMILY_ID, "nl_family nla type"))
143 		goto fail;
144 
145 	smc_nl_family_id = *(uint16_t *)NLA_DATA(nl);
146 	close(fd);
147 	return true;
148 fail:
149 	close(fd);
150 	return false;
151 }
152 
153 static bool smc_ueid(int op)
154 {
155 	struct sockaddr_nl nl_src;
156 	struct msgtemplate msg;
157 	struct nlmsgerr *err;
158 	char test_ueid[32];
159 	int fd, ret;
160 	pid_t pid;
161 
162 	/* UEID required */
163 	memset(test_ueid, '\x20', sizeof(test_ueid));
164 	memcpy(test_ueid, SMC_BPFTEST_UEID, strlen(SMC_BPFTEST_UEID));
165 	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_GENERIC);
166 	if (!ASSERT_OK_FD(fd, "ueid socket"))
167 		return false;
168 
169 	pid = getpid();
170 	memset(&nl_src, 0, sizeof(nl_src));
171 	nl_src.nl_family = AF_NETLINK;
172 	nl_src.nl_pid = pid;
173 
174 	ret = bind(fd, (struct sockaddr *)&nl_src, sizeof(nl_src));
175 	if (!ASSERT_OK(ret, "ueid bind"))
176 		goto fail;
177 
178 	ret = send_cmd(fd, smc_nl_family_id, pid,
179 		       NLM_F_REQUEST | NLM_F_ACK, op, SMC_NLA_EID_TABLE_ENTRY,
180 		       (void *)test_ueid, sizeof(test_ueid));
181 	if (!ASSERT_OK(ret, "ueid cmd"))
182 		goto fail;
183 
184 	ret = recv(fd, &msg, sizeof(msg), 0);
185 	if (!ASSERT_FALSE(ret < 0 ||
186 			  !NLMSG_OK(&msg.n, ret), "ueid response"))
187 		goto fail;
188 
189 	if (msg.n.nlmsg_type == NLMSG_ERROR) {
190 		err = NLMSG_DATA(&msg);
191 		switch (op) {
192 		case SMC_NETLINK_REMOVE_UEID:
193 			if (!ASSERT_FALSE((err->error && err->error != -ENOENT),
194 					  "ueid remove"))
195 				goto fail;
196 			break;
197 		case SMC_NETLINK_ADD_UEID:
198 			if (!ASSERT_OK(err->error, "ueid add"))
199 				goto fail;
200 			break;
201 		default:
202 			break;
203 		}
204 	}
205 	close(fd);
206 	return true;
207 fail:
208 	close(fd);
209 	return false;
210 }
211 
212 static bool setup_ueid(void)
213 {
214 	/* get smc nl id */
215 	if (!get_smc_nl_family_id())
216 		return false;
217 	/* clear old ueid for bpftest */
218 	smc_ueid(SMC_NETLINK_REMOVE_UEID);
219 	/* smc-loopback required ueid */
220 	return smc_ueid(SMC_NETLINK_ADD_UEID);
221 }
222 
223 static void cleanup_ueid(void)
224 {
225 	smc_ueid(SMC_NETLINK_REMOVE_UEID);
226 }
227 #endif /* __s390x__ */
228 
229 static bool setup_netns(void)
230 {
231 	test_netns = netns_new(TEST_NS, true);
232 	if (!ASSERT_OK_PTR(test_netns, "open net namespace"))
233 		goto fail_netns;
234 
235 	SYS(fail_ip, "ip addr add 127.0.1.0/8 dev lo");
236 	SYS(fail_ip, "ip addr add 127.0.2.0/8 dev lo");
237 
238 	return true;
239 fail_ip:
240 	netns_free(test_netns);
241 fail_netns:
242 	return false;
243 }
244 
245 static void cleanup_netns(void)
246 {
247 	netns_free(test_netns);
248 }
249 
/*
 * Prepare everything the test needs: the UEID first, then the network
 * namespace. If the namespace cannot be set up the UEID is removed again.
 * Returns true only when both steps succeeded.
 */
static bool setup_smc(void)
{
	if (!setup_ueid())
		return false;

	if (setup_netns())
		return true;

	/* namespace setup failed: roll back the UEID */
	cleanup_ueid();
	return false;
}
263 
/*
 * Socket callback installed by run_link() as post_socket_cb: binds @fd to the
 * IPv4 address given as a string in @opts, with port 0.
 *
 * Returns 0 when bind() succeeded, 1 otherwise.
 */
static int set_client_addr_cb(int fd, void *opts)
{
	const char *src = (const char *)opts;
	/* zero-initialize so that the padding bytes (sin_zero) are defined */
	struct sockaddr_in localaddr = {
		.sin_family = AF_INET,
		.sin_port = htons(0),
	};

	localaddr.sin_addr.s_addr = inet_addr(src);
	return !ASSERT_OK(bind(fd, (struct sockaddr *)&localaddr,
			       sizeof(localaddr)), "client bind");
}
274 
275 static void run_link(const char *src, const char *dst, int port)
276 {
277 	struct network_helper_opts opts = {0};
278 	int server, client;
279 
280 	server = start_server_str(AF_INET, SOCK_STREAM, dst, port, NULL);
281 	if (!ASSERT_OK_FD(server, "start service_1"))
282 		return;
283 
284 	opts.proto = IPPROTO_TCP;
285 	opts.post_socket_cb = set_client_addr_cb;
286 	opts.cb_opts = (void *)src;
287 
288 	client = connect_to_fd_opts(server, &opts);
289 	if (!ASSERT_OK_FD(client, "start connect"))
290 		goto fail_client;
291 
292 	close(client);
293 fail_client:
294 	close(server);
295 }
296 
297 static void block_link(int map_fd, const char *src, const char *dst)
298 {
299 	struct smc_policy_ip_value val = { .mode = /* block */ 0 };
300 	struct smc_policy_ip_key key = {
301 		.sip = inet_addr(src),
302 		.dip = inet_addr(dst),
303 	};
304 
305 	bpf_map_update_elem(map_fd, &key, &val, BPF_ANY);
306 }
307 
308 /*
309  * This test describes a real-life service topology as follows:
310  *
311  *                             +-------------> service_1
312  *            link 1           |                     |
313  *   +--------------------> server                   |  link 2
314  *   |                         |                     V
315  *   |                         +-------------> service_2
316  *   |        link 3
317  *  client -------------------> server_via_unsafe_path -> service_3
318  *
319  * Among them,
320  * 1. link-1 is very suitable for using SMC.
 * 2. link-2 is not suitable for using SMC, because this link carries
 *    short-lived connections.
323  * 3. link-3 is also not suitable for using SMC, because the RDMA link is
324  *    unavailable and needs to go through a long timeout before it can fallback
325  *    to TCP.
326  * To achieve this goal, we use a customized SMC ip strategy via smc_hs_ctrl.
327  */
328 static void test_topo(void)
329 {
330 	struct bpf_smc *skel;
331 	int rc, map_fd;
332 
333 	skel = bpf_smc__open_and_load();
334 	if (!ASSERT_OK_PTR(skel, "bpf_smc__open_and_load"))
335 		return;
336 
337 	rc = bpf_smc__attach(skel);
338 	if (!ASSERT_OK(rc, "bpf_smc__attach"))
339 		goto fail;
340 
341 	map_fd = bpf_map__fd(skel->maps.smc_policy_ip);
342 	if (!ASSERT_OK_FD(map_fd, "bpf_map__fd"))
343 		goto fail;
344 
345 	/* Mock the process of transparent replacement, since we will modify
346 	 * protocol to ipproto_smc accropding to it via
347 	 * fmod_ret/update_socket_protocol.
348 	 */
349 	write_sysctl("/proc/sys/net/smc/hs_ctrl", "linkcheck");
350 
351 	/* Configure ip strat */
352 	block_link(map_fd, CLIENT_IP, SERVER_IP_VIA_RISK_PATH);
353 	block_link(map_fd, SERVER_IP, SERVER_IP);
354 
355 	/* should go with smc */
356 	run_link(CLIENT_IP, SERVER_IP, SERVICE_1);
357 	/* should go with smc fallback */
358 	run_link(SERVER_IP, SERVER_IP, SERVICE_2);
359 
360 	ASSERT_EQ(skel->bss->smc_cnt, 2, "smc count");
361 	ASSERT_EQ(skel->bss->fallback_cnt, 1, "fallback count");
362 
363 	/* should go with smc */
364 	run_link(CLIENT_IP, SERVER_IP, SERVICE_2);
365 
366 	ASSERT_EQ(skel->bss->smc_cnt, 3, "smc count");
367 	ASSERT_EQ(skel->bss->fallback_cnt, 1, "fallback count");
368 
369 	/* should go with smc fallback */
370 	run_link(CLIENT_IP, SERVER_IP_VIA_RISK_PATH, SERVICE_3);
371 
372 	ASSERT_EQ(skel->bss->smc_cnt, 4, "smc count");
373 	ASSERT_EQ(skel->bss->fallback_cnt, 2, "fallback count");
374 
375 fail:
376 	bpf_smc__destroy(skel);
377 }
378 
/*
 * Test entry point: set up the SMC environment, run the "topo" subtest if it
 * is selected, then tear the environment down. When the setup fails the test
 * is marked as skipped.
 */
void test_bpf_smc(void)
{
	if (setup_smc()) {
		if (test__start_subtest("topo"))
			test_topo();

		cleanup_ueid();
		cleanup_netns();
		return;
	}

	printf("setup for smc test failed, test SKIP:\n");
	test__skip();
}
392 }
393