1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (c) 2020 Cloudflare 3 /* 4 * Tests for sockmap/sockhash holding kTLS sockets. 5 */ 6 #include <error.h> 7 #include <netinet/tcp.h> 8 #include <linux/tls.h> 9 #include "test_progs.h" 10 #include "sockmap_helpers.h" 11 #include "test_skmsg_load_helpers.skel.h" 12 #include "test_sockmap_ktls.skel.h" 13 14 #define MAX_TEST_NAME 80 15 #define TCP_ULP 31 16 17 static int init_ktls_pairs(int c, int p) 18 { 19 int err; 20 struct tls12_crypto_info_aes_gcm_128 crypto_rx; 21 struct tls12_crypto_info_aes_gcm_128 crypto_tx; 22 23 err = setsockopt(c, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls")); 24 if (!ASSERT_OK(err, "setsockopt(TCP_ULP)")) 25 goto out; 26 27 err = setsockopt(p, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls")); 28 if (!ASSERT_OK(err, "setsockopt(TCP_ULP)")) 29 goto out; 30 31 memset(&crypto_rx, 0, sizeof(crypto_rx)); 32 memset(&crypto_tx, 0, sizeof(crypto_tx)); 33 crypto_rx.info.version = TLS_1_2_VERSION; 34 crypto_tx.info.version = TLS_1_2_VERSION; 35 crypto_rx.info.cipher_type = TLS_CIPHER_AES_GCM_128; 36 crypto_tx.info.cipher_type = TLS_CIPHER_AES_GCM_128; 37 38 err = setsockopt(c, SOL_TLS, TLS_TX, &crypto_tx, sizeof(crypto_tx)); 39 if (!ASSERT_OK(err, "setsockopt(TLS_TX)")) 40 goto out; 41 42 err = setsockopt(p, SOL_TLS, TLS_RX, &crypto_rx, sizeof(crypto_rx)); 43 if (!ASSERT_OK(err, "setsockopt(TLS_RX)")) 44 goto out; 45 return 0; 46 out: 47 return -1; 48 } 49 50 static int create_ktls_pairs(int family, int sotype, int *c, int *p) 51 { 52 int err; 53 54 err = create_pair(family, sotype, c, p); 55 if (!ASSERT_OK(err, "create_pair()")) 56 return -1; 57 58 err = init_ktls_pairs(*c, *p); 59 if (!ASSERT_OK(err, "init_ktls_pairs(c, p)")) 60 return -1; 61 return 0; 62 } 63 64 static int tcp_server(int family) 65 { 66 int err, s; 67 68 s = socket(family, SOCK_STREAM, 0); 69 if (!ASSERT_GE(s, 0, "socket")) 70 return -1; 71 72 err = listen(s, SOMAXCONN); 73 if (!ASSERT_OK(err, "listen")) 74 return -1; 75 76 return s; 77 } 78 79 static int disconnect(int fd) 80 { 81 struct sockaddr unspec = { AF_UNSPEC }; 82 83 return connect(fd, &unspec, sizeof(unspec)); 84 } 85 86 /* Disconnect (unhash) a kTLS socket after removing it from sockmap. */ 87 static void test_sockmap_ktls_disconnect_after_delete(int family, int map) 88 { 89 struct sockaddr_storage addr = {0}; 90 socklen_t len = sizeof(addr); 91 int err, cli, srv, zero = 0; 92 93 srv = tcp_server(family); 94 if (srv == -1) 95 return; 96 97 err = getsockname(srv, (struct sockaddr *)&addr, &len); 98 if (!ASSERT_OK(err, "getsockopt")) 99 goto close_srv; 100 101 cli = socket(family, SOCK_STREAM, 0); 102 if (!ASSERT_GE(cli, 0, "socket")) 103 goto close_srv; 104 105 err = connect(cli, (struct sockaddr *)&addr, len); 106 if (!ASSERT_OK(err, "connect")) 107 goto close_cli; 108 109 err = bpf_map_update_elem(map, &zero, &cli, 0); 110 if (!ASSERT_OK(err, "bpf_map_update_elem")) 111 goto close_cli; 112 113 err = setsockopt(cli, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls")); 114 if (!ASSERT_OK(err, "setsockopt(TCP_ULP)")) 115 goto close_cli; 116 117 err = bpf_map_delete_elem(map, &zero); 118 if (!ASSERT_OK(err, "bpf_map_delete_elem")) 119 goto close_cli; 120 121 err = disconnect(cli); 122 ASSERT_OK(err, "disconnect"); 123 124 close_cli: 125 close(cli); 126 close_srv: 127 close(srv); 128 } 129 130 static void test_sockmap_ktls_update_fails_when_sock_has_ulp(int family, int map) 131 { 132 struct sockaddr_storage addr = {}; 133 socklen_t len = sizeof(addr); 134 struct sockaddr_in6 *v6; 135 struct sockaddr_in *v4; 136 int err, s, zero = 0; 137 138 switch (family) { 139 case AF_INET: 140 v4 = (struct sockaddr_in *)&addr; 141 v4->sin_family = AF_INET; 142 break; 143 case AF_INET6: 144 v6 = (struct sockaddr_in6 *)&addr; 145 v6->sin6_family = AF_INET6; 146 break; 147 default: 148 PRINT_FAIL("unsupported socket family %d", family); 149 return; 150 } 151 152 s = socket(family, SOCK_STREAM, 0); 153 if (!ASSERT_GE(s, 0, "socket")) 154 return; 155 156 err = bind(s, (struct sockaddr *)&addr, len); 157 if (!ASSERT_OK(err, "bind")) 158 goto close; 159 160 err = getsockname(s, (struct sockaddr *)&addr, &len); 161 if (!ASSERT_OK(err, "getsockname")) 162 goto close; 163 164 err = connect(s, (struct sockaddr *)&addr, len); 165 if (!ASSERT_OK(err, "connect")) 166 goto close; 167 168 /* save sk->sk_prot and set it to tls_prots */ 169 err = setsockopt(s, IPPROTO_TCP, TCP_ULP, "tls", strlen("tls")); 170 if (!ASSERT_OK(err, "setsockopt(TCP_ULP)")) 171 goto close; 172 173 /* sockmap update should not affect saved sk_prot */ 174 err = bpf_map_update_elem(map, &zero, &s, BPF_ANY); 175 if (!ASSERT_ERR(err, "sockmap update elem")) 176 goto close; 177 178 /* call sk->sk_prot->setsockopt to dispatch to saved sk_prot */ 179 err = setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &zero, sizeof(zero)); 180 ASSERT_OK(err, "setsockopt(TCP_NODELAY)"); 181 182 close: 183 close(s); 184 } 185 186 static const char *fmt_test_name(const char *subtest_name, int family, 187 enum bpf_map_type map_type) 188 { 189 const char *map_type_str = BPF_MAP_TYPE_SOCKMAP ? "SOCKMAP" : "SOCKHASH"; 190 const char *family_str = AF_INET ? "IPv4" : "IPv6"; 191 static char test_name[MAX_TEST_NAME]; 192 193 snprintf(test_name, MAX_TEST_NAME, 194 "sockmap_ktls %s %s %s", 195 subtest_name, family_str, map_type_str); 196 197 return test_name; 198 } 199 200 static void test_sockmap_ktls_offload(int family, int sotype) 201 { 202 int err; 203 int c = 0, p = 0, sent, recvd; 204 char msg[12] = "hello world\0"; 205 char rcv[13]; 206 207 err = create_ktls_pairs(family, sotype, &c, &p); 208 if (!ASSERT_OK(err, "create_ktls_pairs()")) 209 goto out; 210 211 sent = send(c, msg, sizeof(msg), 0); 212 if (!ASSERT_OK(err, "send(msg)")) 213 goto out; 214 215 recvd = recv(p, rcv, sizeof(rcv), 0); 216 if (!ASSERT_OK(err, "recv(msg)") || 217 !ASSERT_EQ(recvd, sent, "length mismatch")) 218 goto out; 219 220 ASSERT_OK(memcmp(msg, rcv, sizeof(msg)), "data mismatch"); 221 222 out: 223 if (c) 224 close(c); 225 if (p) 226 close(p); 227 } 228 229 static void test_sockmap_ktls_tx_cork(int family, int sotype, bool push) 230 { 231 int err, off; 232 int i, j; 233 int start_push = 0, push_len = 0; 234 int c = 0, p = 0, one = 1, sent, recvd; 235 int prog_fd, map_fd; 236 char msg[12] = "hello world\0"; 237 char rcv[20] = {0}; 238 struct test_sockmap_ktls *skel; 239 240 skel = test_sockmap_ktls__open_and_load(); 241 if (!ASSERT_TRUE(skel, "open ktls skel")) 242 return; 243 244 err = create_pair(family, sotype, &c, &p); 245 if (!ASSERT_OK(err, "create_pair()")) 246 goto out; 247 248 prog_fd = bpf_program__fd(skel->progs.prog_sk_policy); 249 map_fd = bpf_map__fd(skel->maps.sock_map); 250 251 err = bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0); 252 if (!ASSERT_OK(err, "bpf_prog_attach sk msg")) 253 goto out; 254 255 err = bpf_map_update_elem(map_fd, &one, &c, BPF_NOEXIST); 256 if (!ASSERT_OK(err, "bpf_map_update_elem(c)")) 257 goto out; 258 259 err = init_ktls_pairs(c, p); 260 if (!ASSERT_OK(err, "init_ktls_pairs(c, p)")) 261 goto out; 262 263 skel->bss->cork_byte = sizeof(msg); 264 if (push) { 265 start_push = 1; 266 push_len = 2; 267 } 268 skel->bss->push_start = start_push; 269 skel->bss->push_end = push_len; 270 271 off = sizeof(msg) / 2; 272 sent = send(c, msg, off, 0); 273 if (!ASSERT_EQ(sent, off, "send(msg)")) 274 goto out; 275 276 recvd = recv_timeout(p, rcv, sizeof(rcv), MSG_DONTWAIT, 1); 277 if (!ASSERT_EQ(-1, recvd, "expected no data")) 278 goto out; 279 280 /* send remaining msg */ 281 sent = send(c, msg + off, sizeof(msg) - off, 0); 282 if (!ASSERT_EQ(sent, sizeof(msg) - off, "send remaining data")) 283 goto out; 284 285 recvd = recv_timeout(p, rcv, sizeof(rcv), MSG_DONTWAIT, 1); 286 if (!ASSERT_OK(err, "recv(msg)") || 287 !ASSERT_EQ(recvd, sizeof(msg) + push_len, "check length mismatch")) 288 goto out; 289 290 for (i = 0, j = 0; i < recvd;) { 291 /* skip checking the data that has been pushed in */ 292 if (i >= start_push && i <= start_push + push_len - 1) { 293 i++; 294 continue; 295 } 296 if (!ASSERT_EQ(rcv[i], msg[j], "data mismatch")) 297 goto out; 298 i++; 299 j++; 300 } 301 out: 302 if (c) 303 close(c); 304 if (p) 305 close(p); 306 test_sockmap_ktls__destroy(skel); 307 } 308 309 static void run_tests(int family, enum bpf_map_type map_type) 310 { 311 int map; 312 313 map = bpf_map_create(map_type, NULL, sizeof(int), sizeof(int), 1, NULL); 314 if (!ASSERT_GE(map, 0, "bpf_map_create")) 315 return; 316 317 if (test__start_subtest(fmt_test_name("disconnect_after_delete", family, map_type))) 318 test_sockmap_ktls_disconnect_after_delete(family, map); 319 if (test__start_subtest(fmt_test_name("update_fails_when_sock_has_ulp", family, map_type))) 320 test_sockmap_ktls_update_fails_when_sock_has_ulp(family, map); 321 322 close(map); 323 } 324 325 static void run_ktls_test(int family, int sotype) 326 { 327 if (test__start_subtest("tls simple offload")) 328 test_sockmap_ktls_offload(family, sotype); 329 if (test__start_subtest("tls tx cork")) 330 test_sockmap_ktls_tx_cork(family, sotype, false); 331 if (test__start_subtest("tls tx cork with push")) 332 test_sockmap_ktls_tx_cork(family, sotype, true); 333 } 334 335 void test_sockmap_ktls(void) 336 { 337 run_tests(AF_INET, BPF_MAP_TYPE_SOCKMAP); 338 run_tests(AF_INET, BPF_MAP_TYPE_SOCKHASH); 339 run_tests(AF_INET6, BPF_MAP_TYPE_SOCKMAP); 340 run_tests(AF_INET6, BPF_MAP_TYPE_SOCKHASH); 341 run_ktls_test(AF_INET, SOCK_STREAM); 342 run_ktls_test(AF_INET6, SOCK_STREAM); 343 } 344