xref: /linux/tools/testing/selftests/bpf/prog_tests/sockmap_ktls.c (revision fcab107abe1ab5be9dbe874baa722372da8f4f73)
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2020 Cloudflare
3 /*
4  * Tests for sockmap/sockhash holding kTLS sockets.
5  */
6 #include <error.h>
7 #include <netinet/tcp.h>
8 #include <linux/tls.h>
9 #include "test_progs.h"
10 #include "sockmap_helpers.h"
11 #include "test_skmsg_load_helpers.skel.h"
12 #include "test_sockmap_ktls.skel.h"
13 
14 #define MAX_TEST_NAME 80
15 #define TCP_ULP 31
16 
17 static int init_ktls_pairs(int c, int p)
18 {
19 	int err;
20 	struct tls12_crypto_info_aes_gcm_128 crypto_rx;
21 	struct tls12_crypto_info_aes_gcm_128 crypto_tx;
22 
23 	err = setsockopt(c, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls"));
24 	if (!ASSERT_OK(err, "setsockopt(TCP_ULP)"))
25 		goto out;
26 
27 	err = setsockopt(p, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls"));
28 	if (!ASSERT_OK(err, "setsockopt(TCP_ULP)"))
29 		goto out;
30 
31 	memset(&crypto_rx, 0, sizeof(crypto_rx));
32 	memset(&crypto_tx, 0, sizeof(crypto_tx));
33 	crypto_rx.info.version = TLS_1_2_VERSION;
34 	crypto_tx.info.version = TLS_1_2_VERSION;
35 	crypto_rx.info.cipher_type = TLS_CIPHER_AES_GCM_128;
36 	crypto_tx.info.cipher_type = TLS_CIPHER_AES_GCM_128;
37 
38 	err = setsockopt(c, SOL_TLS, TLS_TX, &crypto_tx, sizeof(crypto_tx));
39 	if (!ASSERT_OK(err, "setsockopt(TLS_TX)"))
40 		goto out;
41 
42 	err = setsockopt(p, SOL_TLS, TLS_RX, &crypto_rx, sizeof(crypto_rx));
43 	if (!ASSERT_OK(err, "setsockopt(TLS_RX)"))
44 		goto out;
45 	return 0;
46 out:
47 	return -1;
48 }
49 
/*
 * Create a connected (@family, @sotype) pair with create_pair() and enable
 * kTLS on it with init_ktls_pairs() (TX on *c, RX on *p).
 *
 * Returns 0 on success and -1 on failure. If create_pair() succeeded but the
 * kTLS setup failed, *c and *p already hold the fds and the caller is
 * expected to close them.
 */
static int create_ktls_pairs(int family, int sotype, int *c, int *p)
{
	int ret = create_pair(family, sotype, c, p);

	if (!ASSERT_OK(ret, "create_pair()"))
		return -1;

	ret = init_ktls_pairs(*c, *p);
	return ASSERT_OK(ret, "init_ktls_pairs(c, p)") ? 0 : -1;
}
63 
64 static void test_sockmap_ktls_update_fails_when_sock_has_ulp(int family, int map)
65 {
66 	struct sockaddr_storage addr = {};
67 	socklen_t len = sizeof(addr);
68 	struct sockaddr_in6 *v6;
69 	struct sockaddr_in *v4;
70 	int err, s, zero = 0;
71 
72 	switch (family) {
73 	case AF_INET:
74 		v4 = (struct sockaddr_in *)&addr;
75 		v4->sin_family = AF_INET;
76 		break;
77 	case AF_INET6:
78 		v6 = (struct sockaddr_in6 *)&addr;
79 		v6->sin6_family = AF_INET6;
80 		break;
81 	default:
82 		PRINT_FAIL("unsupported socket family %d", family);
83 		return;
84 	}
85 
86 	s = socket(family, SOCK_STREAM, 0);
87 	if (!ASSERT_GE(s, 0, "socket"))
88 		return;
89 
90 	err = bind(s, (struct sockaddr *)&addr, len);
91 	if (!ASSERT_OK(err, "bind"))
92 		goto close;
93 
94 	err = getsockname(s, (struct sockaddr *)&addr, &len);
95 	if (!ASSERT_OK(err, "getsockname"))
96 		goto close;
97 
98 	err = connect(s, (struct sockaddr *)&addr, len);
99 	if (!ASSERT_OK(err, "connect"))
100 		goto close;
101 
102 	/* save sk->sk_prot and set it to tls_prots */
103 	err = setsockopt(s, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls"));
104 	if (!ASSERT_OK(err, "setsockopt(TCP_ULP)"))
105 		goto close;
106 
107 	/* sockmap update should not affect saved sk_prot */
108 	err = bpf_map_update_elem(map, &zero, &s, BPF_ANY);
109 	if (!ASSERT_ERR(err, "sockmap update elem"))
110 		goto close;
111 
112 	/* call sk->sk_prot->setsockopt to dispatch to saved sk_prot */
113 	err = setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &zero, sizeof(zero));
114 	ASSERT_OK(err, "setsockopt(TCP_NODELAY)");
115 
116 close:
117 	close(s);
118 }
119 
120 static const char *fmt_test_name(const char *subtest_name, int family,
121 				 enum bpf_map_type map_type)
122 {
123 	const char *map_type_str = BPF_MAP_TYPE_SOCKMAP ? "SOCKMAP" : "SOCKHASH";
124 	const char *family_str = AF_INET ? "IPv4" : "IPv6";
125 	static char test_name[MAX_TEST_NAME];
126 
127 	snprintf(test_name, MAX_TEST_NAME,
128 		 "sockmap_ktls %s %s %s",
129 		 subtest_name, family_str, map_type_str);
130 
131 	return test_name;
132 }
133 
/*
 * Send one message over a plain kTLS pair (TX on c, RX on p, no sockmap
 * involved) and check that it arrives with the same length and contents.
 */
static void test_sockmap_ktls_offload(int family, int sotype)
{
	int c = -1, p = -1, sent, recvd;
	char msg[12] = "hello world\0";
	char rcv[13];
	int err;

	err = create_ktls_pairs(family, sotype, &c, &p);
	if (!ASSERT_OK(err, "create_ktls_pairs()"))
		goto out;

	/* Check the send()/recv() results themselves, not the stale err. */
	sent = send(c, msg, sizeof(msg), 0);
	if (!ASSERT_EQ(sent, sizeof(msg), "send(msg)"))
		goto out;

	recvd = recv(p, rcv, sizeof(rcv), 0);
	if (!ASSERT_GE(recvd, 0, "recv(msg)") ||
	    !ASSERT_EQ(recvd, sent, "length mismatch"))
		goto out;

	ASSERT_OK(memcmp(msg, rcv, sizeof(msg)), "data mismatch");

out:
	/* fd 0 is a valid descriptor, so -1 marks "not opened". */
	if (c >= 0)
		close(c);
	if (p >= 0)
		close(p);
}
162 
163 static void test_sockmap_ktls_tx_cork(int family, int sotype, bool push)
164 {
165 	int err, off;
166 	int i, j;
167 	int start_push = 0, push_len = 0;
168 	int c = 0, p = 0, one = 1, sent, recvd;
169 	int prog_fd, map_fd;
170 	char msg[12] = "hello world\0";
171 	char rcv[20] = {0};
172 	struct test_sockmap_ktls *skel;
173 
174 	skel = test_sockmap_ktls__open_and_load();
175 	if (!ASSERT_TRUE(skel, "open ktls skel"))
176 		return;
177 
178 	err = create_pair(family, sotype, &c, &p);
179 	if (!ASSERT_OK(err, "create_pair()"))
180 		goto out;
181 
182 	prog_fd = bpf_program__fd(skel->progs.prog_sk_policy);
183 	map_fd = bpf_map__fd(skel->maps.sock_map);
184 
185 	err = bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
186 	if (!ASSERT_OK(err, "bpf_prog_attach sk msg"))
187 		goto out;
188 
189 	err = bpf_map_update_elem(map_fd, &one, &c, BPF_NOEXIST);
190 	if (!ASSERT_OK(err, "bpf_map_update_elem(c)"))
191 		goto out;
192 
193 	err = init_ktls_pairs(c, p);
194 	if (!ASSERT_OK(err, "init_ktls_pairs(c, p)"))
195 		goto out;
196 
197 	skel->bss->cork_byte = sizeof(msg);
198 	if (push) {
199 		start_push = 1;
200 		push_len = 2;
201 	}
202 	skel->bss->push_start = start_push;
203 	skel->bss->push_end = push_len;
204 
205 	off = sizeof(msg) / 2;
206 	sent = send(c, msg, off, 0);
207 	if (!ASSERT_EQ(sent, off, "send(msg)"))
208 		goto out;
209 
210 	recvd = recv_timeout(p, rcv, sizeof(rcv), MSG_DONTWAIT, 1);
211 	if (!ASSERT_EQ(-1, recvd, "expected no data"))
212 		goto out;
213 
214 	/* send remaining msg */
215 	sent = send(c, msg + off, sizeof(msg) - off, 0);
216 	if (!ASSERT_EQ(sent, sizeof(msg) - off, "send remaining data"))
217 		goto out;
218 
219 	recvd = recv_timeout(p, rcv, sizeof(rcv), MSG_DONTWAIT, 1);
220 	if (!ASSERT_OK(err, "recv(msg)") ||
221 	    !ASSERT_EQ(recvd, sizeof(msg) + push_len, "check length mismatch"))
222 		goto out;
223 
224 	for (i = 0, j = 0; i < recvd;) {
225 		/* skip checking the data that has been pushed in */
226 		if (i >= start_push && i <= start_push + push_len - 1) {
227 			i++;
228 			continue;
229 		}
230 		if (!ASSERT_EQ(rcv[i], msg[j], "data mismatch"))
231 			goto out;
232 		i++;
233 		j++;
234 	}
235 out:
236 	if (c)
237 		close(c);
238 	if (p)
239 		close(p);
240 	test_sockmap_ktls__destroy(skel);
241 }
242 
243 static void test_sockmap_ktls_tx_no_buf(int family, int sotype, bool push)
244 {
245 	int c = -1, p = -1, one = 1, two = 2;
246 	struct test_sockmap_ktls *skel;
247 	unsigned char *data = NULL;
248 	struct msghdr msg = {0};
249 	struct iovec iov[2];
250 	int prog_fd, map_fd;
251 	int txrx_buf = 1024;
252 	int iov_length = 8192;
253 	int err;
254 
255 	skel = test_sockmap_ktls__open_and_load();
256 	if (!ASSERT_TRUE(skel, "open ktls skel"))
257 		return;
258 
259 	err = create_pair(family, sotype, &c, &p);
260 	if (!ASSERT_OK(err, "create_pair()"))
261 		goto out;
262 
263 	err = setsockopt(c, SOL_SOCKET, SO_RCVBUFFORCE, &txrx_buf, sizeof(int));
264 	err |= setsockopt(p, SOL_SOCKET, SO_SNDBUFFORCE, &txrx_buf, sizeof(int));
265 	if (!ASSERT_OK(err, "set buf limit"))
266 		goto out;
267 
268 	prog_fd = bpf_program__fd(skel->progs.prog_sk_policy_redir);
269 	map_fd = bpf_map__fd(skel->maps.sock_map);
270 
271 	err = bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
272 	if (!ASSERT_OK(err, "bpf_prog_attach sk msg"))
273 		goto out;
274 
275 	err = bpf_map_update_elem(map_fd, &one, &c, BPF_NOEXIST);
276 	if (!ASSERT_OK(err, "bpf_map_update_elem(c)"))
277 		goto out;
278 
279 	err = bpf_map_update_elem(map_fd, &two, &p, BPF_NOEXIST);
280 	if (!ASSERT_OK(err, "bpf_map_update_elem(p)"))
281 		goto out;
282 
283 	skel->bss->apply_bytes = 1024;
284 
285 	err = init_ktls_pairs(c, p);
286 	if (!ASSERT_OK(err, "init_ktls_pairs(c, p)"))
287 		goto out;
288 
289 	data = calloc(iov_length, sizeof(char));
290 	if (!data)
291 		goto out;
292 
293 	iov[0].iov_base = data;
294 	iov[0].iov_len = iov_length;
295 	iov[1].iov_base = data;
296 	iov[1].iov_len = iov_length;
297 	msg.msg_iov = iov;
298 	msg.msg_iovlen = 2;
299 
300 	for (;;) {
301 		err = sendmsg(c, &msg, MSG_DONTWAIT);
302 		if (err <= 0)
303 			break;
304 	}
305 
306 out:
307 	if (data)
308 		free(data);
309 	if (c != -1)
310 		close(c);
311 	if (p != -1)
312 		close(p);
313 
314 	test_sockmap_ktls__destroy(skel);
315 }
316 
317 static void run_tests(int family, enum bpf_map_type map_type)
318 {
319 	int map;
320 
321 	map = bpf_map_create(map_type, NULL, sizeof(int), sizeof(int), 1, NULL);
322 	if (!ASSERT_GE(map, 0, "bpf_map_create"))
323 		return;
324 
325 	if (test__start_subtest(fmt_test_name("update_fails_when_sock_has_ulp", family, map_type)))
326 		test_sockmap_ktls_update_fails_when_sock_has_ulp(family, map);
327 
328 	close(map);
329 }
330 
/*
 * Run the plain/sockmap kTLS data-path subtests for one (@family, @sotype).
 *
 * The subtest names do not include the family, so each name is registered
 * once per call.
 */
static void run_ktls_test(int family, int sotype)
{
	static const struct {
		const char *name;
		bool push;
	} cork_cases[] = {
		{ "tls tx cork", false },
		{ "tls tx cork with push", true },
	};
	size_t k;

	if (test__start_subtest("tls simple offload"))
		test_sockmap_ktls_offload(family, sotype);

	for (k = 0; k < sizeof(cork_cases) / sizeof(cork_cases[0]); k++) {
		if (test__start_subtest(cork_cases[k].name))
			test_sockmap_ktls_tx_cork(family, sotype, cork_cases[k].push);
	}

	if (test__start_subtest("tls tx egress with no buf"))
		test_sockmap_ktls_tx_no_buf(family, sotype, true);
}
342 
343 void test_sockmap_ktls(void)
344 {
345 	run_tests(AF_INET, BPF_MAP_TYPE_SOCKMAP);
346 	run_tests(AF_INET, BPF_MAP_TYPE_SOCKHASH);
347 	run_tests(AF_INET6, BPF_MAP_TYPE_SOCKMAP);
348 	run_tests(AF_INET6, BPF_MAP_TYPE_SOCKHASH);
349 	run_ktls_test(AF_INET, SOCK_STREAM);
350 	run_ktls_test(AF_INET6, SOCK_STREAM);
351 }
352