xref: /linux/tools/testing/selftests/bpf/prog_tests/sockmap_ktls.c (revision 4c0a42c50021ee509f159c1f8a22efb35987c941)
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2020 Cloudflare
3 /*
4  * Tests for sockmap/sockhash holding kTLS sockets.
5  */
6 #include <error.h>
7 #include <netinet/tcp.h>
8 #include <linux/tls.h>
9 #include "test_progs.h"
10 #include "sockmap_helpers.h"
11 #include "test_skmsg_load_helpers.skel.h"
12 #include "test_sockmap_ktls.skel.h"
13 
/* Size of the buffer fmt_test_name() formats into, including the NUL. */
#define MAX_TEST_NAME 80

/* TCP-level option that attaches an upper-layer protocol (e.g. "tls"). */
#ifndef TCP_ULP
#define TCP_ULP 31
#endif
16 
17 static int init_ktls_pairs(int c, int p)
18 {
19 	int err;
20 	struct tls12_crypto_info_aes_gcm_128 crypto_rx;
21 	struct tls12_crypto_info_aes_gcm_128 crypto_tx;
22 
23 	err = setsockopt(c, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls"));
24 	if (!ASSERT_OK(err, "setsockopt(TCP_ULP)"))
25 		goto out;
26 
27 	err = setsockopt(p, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls"));
28 	if (!ASSERT_OK(err, "setsockopt(TCP_ULP)"))
29 		goto out;
30 
31 	memset(&crypto_rx, 0, sizeof(crypto_rx));
32 	memset(&crypto_tx, 0, sizeof(crypto_tx));
33 	crypto_rx.info.version = TLS_1_2_VERSION;
34 	crypto_tx.info.version = TLS_1_2_VERSION;
35 	crypto_rx.info.cipher_type = TLS_CIPHER_AES_GCM_128;
36 	crypto_tx.info.cipher_type = TLS_CIPHER_AES_GCM_128;
37 
38 	err = setsockopt(c, SOL_TLS, TLS_TX, &crypto_tx, sizeof(crypto_tx));
39 	if (!ASSERT_OK(err, "setsockopt(TLS_TX)"))
40 		goto out;
41 
42 	err = setsockopt(p, SOL_TLS, TLS_RX, &crypto_rx, sizeof(crypto_rx));
43 	if (!ASSERT_OK(err, "setsockopt(TLS_RX)"))
44 		goto out;
45 	return 0;
46 out:
47 	return -1;
48 }
49 
/*
 * Create a connected (@family, @sotype) pair with create_pair() and switch
 * it to kTLS with init_ktls_pairs(). The two ends are stored in *c and *p.
 *
 * Returns 0 on success, -1 on failure. This function never closes *c or *p
 * itself; the caller is responsible for them on every path.
 */
static int create_ktls_pairs(int family, int sotype, int *c, int *p)
{
	if (!ASSERT_OK(create_pair(family, sotype, c, p), "create_pair()"))
		return -1;

	return ASSERT_OK(init_ktls_pairs(*c, *p), "init_ktls_pairs(c, p)") ? 0 : -1;
}
63 
/*
 * Create a TCP socket of @family and put it into the listening state.
 * No bind() is done first, so listen() autobinds the socket to the wildcard
 * address on an ephemeral port; the caller in this file recovers the
 * address with getsockname().
 *
 * Returns the listening fd, or -1 on failure. No descriptor is leaked on
 * the failure paths.
 */
static int tcp_server(int family)
{
	int err, s;

	s = socket(family, SOCK_STREAM, 0);
	if (!ASSERT_GE(s, 0, "socket"))
		return -1;

	err = listen(s, SOMAXCONN);
	if (!ASSERT_OK(err, "listen")) {
		/* Don't leak the socket when listen() fails. */
		close(s);
		return -1;
	}

	return s;
}
78 
/*
 * Disconnect @fd by connect()ing it to an address whose family is
 * AF_UNSPEC (see connect(2)).
 *
 * Returns connect()'s result: 0 on success, -1 with errno set on failure.
 */
static int disconnect(int fd)
{
	const struct sockaddr unspec_addr = { .sa_family = AF_UNSPEC };

	return connect(fd, &unspec_addr, sizeof(unspec_addr));
}
85 
86 /* Disconnect (unhash) a kTLS socket after removing it from sockmap. */
87 static void test_sockmap_ktls_disconnect_after_delete(int family, int map)
88 {
89 	struct sockaddr_storage addr = {0};
90 	socklen_t len = sizeof(addr);
91 	int err, cli, srv, zero = 0;
92 
93 	srv = tcp_server(family);
94 	if (srv == -1)
95 		return;
96 
97 	err = getsockname(srv, (struct sockaddr *)&addr, &len);
98 	if (!ASSERT_OK(err, "getsockopt"))
99 		goto close_srv;
100 
101 	cli = socket(family, SOCK_STREAM, 0);
102 	if (!ASSERT_GE(cli, 0, "socket"))
103 		goto close_srv;
104 
105 	err = connect(cli, (struct sockaddr *)&addr, len);
106 	if (!ASSERT_OK(err, "connect"))
107 		goto close_cli;
108 
109 	err = bpf_map_update_elem(map, &zero, &cli, 0);
110 	if (!ASSERT_OK(err, "bpf_map_update_elem"))
111 		goto close_cli;
112 
113 	err = setsockopt(cli, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls"));
114 	if (!ASSERT_OK(err, "setsockopt(TCP_ULP)"))
115 		goto close_cli;
116 
117 	err = bpf_map_delete_elem(map, &zero);
118 	if (!ASSERT_OK(err, "bpf_map_delete_elem"))
119 		goto close_cli;
120 
121 	err = disconnect(cli);
122 	ASSERT_OK(err, "disconnect");
123 
124 close_cli:
125 	close(cli);
126 close_srv:
127 	close(srv);
128 }
129 
130 static void test_sockmap_ktls_update_fails_when_sock_has_ulp(int family, int map)
131 {
132 	struct sockaddr_storage addr = {};
133 	socklen_t len = sizeof(addr);
134 	struct sockaddr_in6 *v6;
135 	struct sockaddr_in *v4;
136 	int err, s, zero = 0;
137 
138 	switch (family) {
139 	case AF_INET:
140 		v4 = (struct sockaddr_in *)&addr;
141 		v4->sin_family = AF_INET;
142 		break;
143 	case AF_INET6:
144 		v6 = (struct sockaddr_in6 *)&addr;
145 		v6->sin6_family = AF_INET6;
146 		break;
147 	default:
148 		PRINT_FAIL("unsupported socket family %d", family);
149 		return;
150 	}
151 
152 	s = socket(family, SOCK_STREAM, 0);
153 	if (!ASSERT_GE(s, 0, "socket"))
154 		return;
155 
156 	err = bind(s, (struct sockaddr *)&addr, len);
157 	if (!ASSERT_OK(err, "bind"))
158 		goto close;
159 
160 	err = getsockname(s, (struct sockaddr *)&addr, &len);
161 	if (!ASSERT_OK(err, "getsockname"))
162 		goto close;
163 
164 	err = connect(s, (struct sockaddr *)&addr, len);
165 	if (!ASSERT_OK(err, "connect"))
166 		goto close;
167 
168 	/* save sk->sk_prot and set it to tls_prots */
169 	err = setsockopt(s, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls"));
170 	if (!ASSERT_OK(err, "setsockopt(TCP_ULP)"))
171 		goto close;
172 
173 	/* sockmap update should not affect saved sk_prot */
174 	err = bpf_map_update_elem(map, &zero, &s, BPF_ANY);
175 	if (!ASSERT_ERR(err, "sockmap update elem"))
176 		goto close;
177 
178 	/* call sk->sk_prot->setsockopt to dispatch to saved sk_prot */
179 	err = setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &zero, sizeof(zero));
180 	ASSERT_OK(err, "setsockopt(TCP_NODELAY)");
181 
182 close:
183 	close(s);
184 }
185 
186 static const char *fmt_test_name(const char *subtest_name, int family,
187 				 enum bpf_map_type map_type)
188 {
189 	const char *map_type_str = BPF_MAP_TYPE_SOCKMAP ? "SOCKMAP" : "SOCKHASH";
190 	const char *family_str = AF_INET ? "IPv4" : "IPv6";
191 	static char test_name[MAX_TEST_NAME];
192 
193 	snprintf(test_name, MAX_TEST_NAME,
194 		 "sockmap_ktls %s %s %s",
195 		 subtest_name, family_str, map_type_str);
196 
197 	return test_name;
198 }
199 
/*
 * Send one message over a fresh kTLS pair and check it arrives intact.
 *
 * The message goes from the TLS_TX end (c) to the TLS_RX end (p). The
 * received length must equal the sent length and the bytes must match.
 */
static void test_sockmap_ktls_offload(int family, int sotype)
{
	int err;
	int c = 0, p = 0, sent, recvd;
	char msg[12] = "hello world\0";
	char rcv[13];

	err = create_ktls_pairs(family, sotype, &c, &p);
	if (!ASSERT_OK(err, "create_ktls_pairs()"))
		goto out;

	/* Check the byte counts send()/recv() actually returned. */
	sent = send(c, msg, sizeof(msg), 0);
	if (!ASSERT_EQ(sent, sizeof(msg), "send(msg)"))
		goto out;

	recvd = recv(p, rcv, sizeof(rcv), 0);
	if (!ASSERT_GE(recvd, 0, "recv(msg)") ||
	    !ASSERT_EQ(recvd, sent, "length mismatch"))
		goto out;

	ASSERT_OK(memcmp(msg, rcv, sizeof(msg)), "data mismatch");

out:
	if (c)
		close(c);
	if (p)
		close(p);
}
228 
229 static void test_sockmap_ktls_tx_cork(int family, int sotype, bool push)
230 {
231 	int err, off;
232 	int i, j;
233 	int start_push = 0, push_len = 0;
234 	int c = 0, p = 0, one = 1, sent, recvd;
235 	int prog_fd, map_fd;
236 	char msg[12] = "hello world\0";
237 	char rcv[20] = {0};
238 	struct test_sockmap_ktls *skel;
239 
240 	skel = test_sockmap_ktls__open_and_load();
241 	if (!ASSERT_TRUE(skel, "open ktls skel"))
242 		return;
243 
244 	err = create_pair(family, sotype, &c, &p);
245 	if (!ASSERT_OK(err, "create_pair()"))
246 		goto out;
247 
248 	prog_fd = bpf_program__fd(skel->progs.prog_sk_policy);
249 	map_fd = bpf_map__fd(skel->maps.sock_map);
250 
251 	err = bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
252 	if (!ASSERT_OK(err, "bpf_prog_attach sk msg"))
253 		goto out;
254 
255 	err = bpf_map_update_elem(map_fd, &one, &c, BPF_NOEXIST);
256 	if (!ASSERT_OK(err, "bpf_map_update_elem(c)"))
257 		goto out;
258 
259 	err = init_ktls_pairs(c, p);
260 	if (!ASSERT_OK(err, "init_ktls_pairs(c, p)"))
261 		goto out;
262 
263 	skel->bss->cork_byte = sizeof(msg);
264 	if (push) {
265 		start_push = 1;
266 		push_len = 2;
267 	}
268 	skel->bss->push_start = start_push;
269 	skel->bss->push_end = push_len;
270 
271 	off = sizeof(msg) / 2;
272 	sent = send(c, msg, off, 0);
273 	if (!ASSERT_EQ(sent, off, "send(msg)"))
274 		goto out;
275 
276 	recvd = recv_timeout(p, rcv, sizeof(rcv), MSG_DONTWAIT, 1);
277 	if (!ASSERT_EQ(-1, recvd, "expected no data"))
278 		goto out;
279 
280 	/* send remaining msg */
281 	sent = send(c, msg + off, sizeof(msg) - off, 0);
282 	if (!ASSERT_EQ(sent, sizeof(msg) - off, "send remaining data"))
283 		goto out;
284 
285 	recvd = recv_timeout(p, rcv, sizeof(rcv), MSG_DONTWAIT, 1);
286 	if (!ASSERT_OK(err, "recv(msg)") ||
287 	    !ASSERT_EQ(recvd, sizeof(msg) + push_len, "check length mismatch"))
288 		goto out;
289 
290 	for (i = 0, j = 0; i < recvd;) {
291 		/* skip checking the data that has been pushed in */
292 		if (i >= start_push && i <= start_push + push_len - 1) {
293 			i++;
294 			continue;
295 		}
296 		if (!ASSERT_EQ(rcv[i], msg[j], "data mismatch"))
297 			goto out;
298 		i++;
299 		j++;
300 	}
301 out:
302 	if (c)
303 		close(c);
304 	if (p)
305 		close(p);
306 	test_sockmap_ktls__destroy(skel);
307 }
308 
309 static void run_tests(int family, enum bpf_map_type map_type)
310 {
311 	int map;
312 
313 	map = bpf_map_create(map_type, NULL, sizeof(int), sizeof(int), 1, NULL);
314 	if (!ASSERT_GE(map, 0, "bpf_map_create"))
315 		return;
316 
317 	if (test__start_subtest(fmt_test_name("disconnect_after_delete", family, map_type)))
318 		test_sockmap_ktls_disconnect_after_delete(family, map);
319 	if (test__start_subtest(fmt_test_name("update_fails_when_sock_has_ulp", family, map_type)))
320 		test_sockmap_ktls_update_fails_when_sock_has_ulp(family, map);
321 
322 	close(map);
323 }
324 
/*
 * Run the kTLS data-path subtests for @family / @sotype, in this order:
 * - plain send/recv;
 * - the sk_msg cork test without data push;
 * - the sk_msg cork test with data push.
 *
 * NOTE: these subtest names do not encode @family, so the AF_INET and
 * AF_INET6 calls from test_sockmap_ktls() report identical subtest names.
 */
static void run_ktls_test(int family, int sotype)
{
	static const struct {
		const char *name;
		bool push;
	} cork_subtests[] = {
		{ "tls tx cork", false },
		{ "tls tx cork with push", true },
	};
	size_t i;

	if (test__start_subtest("tls simple offload"))
		test_sockmap_ktls_offload(family, sotype);

	for (i = 0; i < sizeof(cork_subtests) / sizeof(cork_subtests[0]); i++) {
		if (test__start_subtest(cork_subtests[i].name))
			test_sockmap_ktls_tx_cork(family, sotype, cork_subtests[i].push);
	}
}
334 
335 void test_sockmap_ktls(void)
336 {
337 	run_tests(AF_INET, BPF_MAP_TYPE_SOCKMAP);
338 	run_tests(AF_INET, BPF_MAP_TYPE_SOCKHASH);
339 	run_tests(AF_INET6, BPF_MAP_TYPE_SOCKMAP);
340 	run_tests(AF_INET6, BPF_MAP_TYPE_SOCKHASH);
341 	run_ktls_test(AF_INET, SOCK_STREAM);
342 	run_ktls_test(AF_INET6, SOCK_STREAM);
343 }
344