1 // SPDX-License-Identifier: GPL-2.0 2 #include <sys/un.h> 3 4 #include "test_progs.h" 5 6 #include "connect_unix_prog.skel.h" 7 #include "sendmsg_unix_prog.skel.h" 8 #include "recvmsg_unix_prog.skel.h" 9 #include "getsockname_unix_prog.skel.h" 10 #include "getpeername_unix_prog.skel.h" 11 #include "network_helpers.h" 12 13 #define SERVUN_ADDRESS "bpf_cgroup_unix_test" 14 #define SERVUN_REWRITE_ADDRESS "bpf_cgroup_unix_test_rewrite" 15 #define SRCUN_ADDRESS "bpf_cgroup_unix_test_src" 16 17 enum sock_addr_test_type { 18 SOCK_ADDR_TEST_BIND, 19 SOCK_ADDR_TEST_CONNECT, 20 SOCK_ADDR_TEST_SENDMSG, 21 SOCK_ADDR_TEST_RECVMSG, 22 SOCK_ADDR_TEST_GETSOCKNAME, 23 SOCK_ADDR_TEST_GETPEERNAME, 24 }; 25 26 typedef void *(*load_fn)(int cgroup_fd); 27 typedef void (*destroy_fn)(void *skel); 28 29 struct sock_addr_test { 30 enum sock_addr_test_type type; 31 const char *name; 32 /* BPF prog properties */ 33 load_fn loadfn; 34 destroy_fn destroyfn; 35 /* Socket properties */ 36 int socket_family; 37 int socket_type; 38 /* IP:port pairs for BPF prog to override */ 39 const char *requested_addr; 40 unsigned short requested_port; 41 const char *expected_addr; 42 unsigned short expected_port; 43 const char *expected_src_addr; 44 }; 45 46 static void *connect_unix_prog_load(int cgroup_fd) 47 { 48 struct connect_unix_prog *skel; 49 50 skel = connect_unix_prog__open_and_load(); 51 if (!ASSERT_OK_PTR(skel, "skel_open")) 52 goto cleanup; 53 54 skel->links.connect_unix_prog = bpf_program__attach_cgroup( 55 skel->progs.connect_unix_prog, cgroup_fd); 56 if (!ASSERT_OK_PTR(skel->links.connect_unix_prog, "prog_attach")) 57 goto cleanup; 58 59 return skel; 60 cleanup: 61 connect_unix_prog__destroy(skel); 62 return NULL; 63 } 64 65 static void connect_unix_prog_destroy(void *skel) 66 { 67 connect_unix_prog__destroy(skel); 68 } 69 70 static void *sendmsg_unix_prog_load(int cgroup_fd) 71 { 72 struct sendmsg_unix_prog *skel; 73 74 skel = sendmsg_unix_prog__open_and_load(); 75 if (!ASSERT_OK_PTR(skel, "skel_open")) 76 goto cleanup; 77 78 skel->links.sendmsg_unix_prog = bpf_program__attach_cgroup( 79 skel->progs.sendmsg_unix_prog, cgroup_fd); 80 if (!ASSERT_OK_PTR(skel->links.sendmsg_unix_prog, "prog_attach")) 81 goto cleanup; 82 83 return skel; 84 cleanup: 85 sendmsg_unix_prog__destroy(skel); 86 return NULL; 87 } 88 89 static void sendmsg_unix_prog_destroy(void *skel) 90 { 91 sendmsg_unix_prog__destroy(skel); 92 } 93 94 static void *recvmsg_unix_prog_load(int cgroup_fd) 95 { 96 struct recvmsg_unix_prog *skel; 97 98 skel = recvmsg_unix_prog__open_and_load(); 99 if (!ASSERT_OK_PTR(skel, "skel_open")) 100 goto cleanup; 101 102 skel->links.recvmsg_unix_prog = bpf_program__attach_cgroup( 103 skel->progs.recvmsg_unix_prog, cgroup_fd); 104 if (!ASSERT_OK_PTR(skel->links.recvmsg_unix_prog, "prog_attach")) 105 goto cleanup; 106 107 return skel; 108 cleanup: 109 recvmsg_unix_prog__destroy(skel); 110 return NULL; 111 } 112 113 static void recvmsg_unix_prog_destroy(void *skel) 114 { 115 recvmsg_unix_prog__destroy(skel); 116 } 117 118 static void *getsockname_unix_prog_load(int cgroup_fd) 119 { 120 struct getsockname_unix_prog *skel; 121 122 skel = getsockname_unix_prog__open_and_load(); 123 if (!ASSERT_OK_PTR(skel, "skel_open")) 124 goto cleanup; 125 126 skel->links.getsockname_unix_prog = bpf_program__attach_cgroup( 127 skel->progs.getsockname_unix_prog, cgroup_fd); 128 if (!ASSERT_OK_PTR(skel->links.getsockname_unix_prog, "prog_attach")) 129 goto cleanup; 130 131 return skel; 132 cleanup: 133 getsockname_unix_prog__destroy(skel); 134 return NULL; 135 } 136 137 static void getsockname_unix_prog_destroy(void *skel) 138 { 139 getsockname_unix_prog__destroy(skel); 140 } 141 142 static void *getpeername_unix_prog_load(int cgroup_fd) 143 { 144 struct getpeername_unix_prog *skel; 145 146 skel = getpeername_unix_prog__open_and_load(); 147 if (!ASSERT_OK_PTR(skel, "skel_open")) 148 goto cleanup; 149 150 skel->links.getpeername_unix_prog = bpf_program__attach_cgroup( 151 skel->progs.getpeername_unix_prog, cgroup_fd); 152 if (!ASSERT_OK_PTR(skel->links.getpeername_unix_prog, "prog_attach")) 153 goto cleanup; 154 155 return skel; 156 cleanup: 157 getpeername_unix_prog__destroy(skel); 158 return NULL; 159 } 160 161 static void getpeername_unix_prog_destroy(void *skel) 162 { 163 getpeername_unix_prog__destroy(skel); 164 } 165 166 static struct sock_addr_test tests[] = { 167 { 168 SOCK_ADDR_TEST_CONNECT, 169 "connect_unix", 170 connect_unix_prog_load, 171 connect_unix_prog_destroy, 172 AF_UNIX, 173 SOCK_STREAM, 174 SERVUN_ADDRESS, 175 0, 176 SERVUN_REWRITE_ADDRESS, 177 0, 178 NULL, 179 }, 180 { 181 SOCK_ADDR_TEST_SENDMSG, 182 "sendmsg_unix", 183 sendmsg_unix_prog_load, 184 sendmsg_unix_prog_destroy, 185 AF_UNIX, 186 SOCK_DGRAM, 187 SERVUN_ADDRESS, 188 0, 189 SERVUN_REWRITE_ADDRESS, 190 0, 191 NULL, 192 }, 193 { 194 SOCK_ADDR_TEST_RECVMSG, 195 "recvmsg_unix-dgram", 196 recvmsg_unix_prog_load, 197 recvmsg_unix_prog_destroy, 198 AF_UNIX, 199 SOCK_DGRAM, 200 SERVUN_REWRITE_ADDRESS, 201 0, 202 SERVUN_REWRITE_ADDRESS, 203 0, 204 SERVUN_ADDRESS, 205 }, 206 { 207 SOCK_ADDR_TEST_RECVMSG, 208 "recvmsg_unix-stream", 209 recvmsg_unix_prog_load, 210 recvmsg_unix_prog_destroy, 211 AF_UNIX, 212 SOCK_STREAM, 213 SERVUN_REWRITE_ADDRESS, 214 0, 215 SERVUN_REWRITE_ADDRESS, 216 0, 217 SERVUN_ADDRESS, 218 }, 219 { 220 SOCK_ADDR_TEST_GETSOCKNAME, 221 "getsockname_unix", 222 getsockname_unix_prog_load, 223 getsockname_unix_prog_destroy, 224 AF_UNIX, 225 SOCK_STREAM, 226 SERVUN_ADDRESS, 227 0, 228 SERVUN_REWRITE_ADDRESS, 229 0, 230 NULL, 231 }, 232 { 233 SOCK_ADDR_TEST_GETPEERNAME, 234 "getpeername_unix", 235 getpeername_unix_prog_load, 236 getpeername_unix_prog_destroy, 237 AF_UNIX, 238 SOCK_STREAM, 239 SERVUN_ADDRESS, 240 0, 241 SERVUN_REWRITE_ADDRESS, 242 0, 243 NULL, 244 }, 245 }; 246 247 typedef int (*info_fn)(int, struct sockaddr *, socklen_t *); 248 249 static int cmp_addr(const struct sockaddr_storage *addr1, socklen_t addr1_len, 250 const struct sockaddr_storage *addr2, socklen_t addr2_len, 251 bool cmp_port) 252 { 253 const struct sockaddr_in *four1, *four2; 254 const struct sockaddr_in6 *six1, *six2; 255 const struct sockaddr_un *un1, *un2; 256 257 if (addr1->ss_family != addr2->ss_family) 258 return -1; 259 260 if (addr1_len != addr2_len) 261 return -1; 262 263 if (addr1->ss_family == AF_INET) { 264 four1 = (const struct sockaddr_in *)addr1; 265 four2 = (const struct sockaddr_in *)addr2; 266 return !((four1->sin_port == four2->sin_port || !cmp_port) && 267 four1->sin_addr.s_addr == four2->sin_addr.s_addr); 268 } else if (addr1->ss_family == AF_INET6) { 269 six1 = (const struct sockaddr_in6 *)addr1; 270 six2 = (const struct sockaddr_in6 *)addr2; 271 return !((six1->sin6_port == six2->sin6_port || !cmp_port) && 272 !memcmp(&six1->sin6_addr, &six2->sin6_addr, 273 sizeof(struct in6_addr))); 274 } else if (addr1->ss_family == AF_UNIX) { 275 un1 = (const struct sockaddr_un *)addr1; 276 un2 = (const struct sockaddr_un *)addr2; 277 return memcmp(un1, un2, addr1_len); 278 } 279 280 return -1; 281 } 282 283 static int cmp_sock_addr(info_fn fn, int sock1, 284 const struct sockaddr_storage *addr2, 285 socklen_t addr2_len, bool cmp_port) 286 { 287 struct sockaddr_storage addr1; 288 socklen_t len1 = sizeof(addr1); 289 290 memset(&addr1, 0, len1); 291 if (fn(sock1, (struct sockaddr *)&addr1, (socklen_t *)&len1) != 0) 292 return -1; 293 294 return cmp_addr(&addr1, len1, addr2, addr2_len, cmp_port); 295 } 296 297 static int cmp_local_addr(int sock1, const struct sockaddr_storage *addr2, 298 socklen_t addr2_len, bool cmp_port) 299 { 300 return cmp_sock_addr(getsockname, sock1, addr2, addr2_len, cmp_port); 301 } 302 303 static int cmp_peer_addr(int sock1, const struct sockaddr_storage *addr2, 304 socklen_t addr2_len, bool cmp_port) 305 { 306 return cmp_sock_addr(getpeername, sock1, addr2, addr2_len, cmp_port); 307 } 308 309 static void test_bind(struct sock_addr_test *test) 310 { 311 struct sockaddr_storage expected_addr; 312 socklen_t expected_addr_len = sizeof(struct sockaddr_storage); 313 int serv = -1, client = -1, err; 314 315 serv = start_server(test->socket_family, test->socket_type, 316 test->requested_addr, test->requested_port, 0); 317 if (!ASSERT_GE(serv, 0, "start_server")) 318 goto cleanup; 319 320 err = make_sockaddr(test->socket_family, 321 test->expected_addr, test->expected_port, 322 &expected_addr, &expected_addr_len); 323 if (!ASSERT_EQ(err, 0, "make_sockaddr")) 324 goto cleanup; 325 326 err = cmp_local_addr(serv, &expected_addr, expected_addr_len, true); 327 if (!ASSERT_EQ(err, 0, "cmp_local_addr")) 328 goto cleanup; 329 330 /* Try to connect to server just in case */ 331 client = connect_to_addr(&expected_addr, expected_addr_len, test->socket_type); 332 if (!ASSERT_GE(client, 0, "connect_to_addr")) 333 goto cleanup; 334 335 cleanup: 336 if (client != -1) 337 close(client); 338 if (serv != -1) 339 close(serv); 340 } 341 342 static void test_connect(struct sock_addr_test *test) 343 { 344 struct sockaddr_storage addr, expected_addr, expected_src_addr; 345 socklen_t addr_len = sizeof(struct sockaddr_storage), 346 expected_addr_len = sizeof(struct sockaddr_storage), 347 expected_src_addr_len = sizeof(struct sockaddr_storage); 348 int serv = -1, client = -1, err; 349 350 serv = start_server(test->socket_family, test->socket_type, 351 test->expected_addr, test->expected_port, 0); 352 if (!ASSERT_GE(serv, 0, "start_server")) 353 goto cleanup; 354 355 err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port, 356 &addr, &addr_len); 357 if (!ASSERT_EQ(err, 0, "make_sockaddr")) 358 goto cleanup; 359 360 client = connect_to_addr(&addr, addr_len, test->socket_type); 361 if (!ASSERT_GE(client, 0, "connect_to_addr")) 362 goto cleanup; 363 364 err = make_sockaddr(test->socket_family, test->expected_addr, test->expected_port, 365 &expected_addr, &expected_addr_len); 366 if (!ASSERT_EQ(err, 0, "make_sockaddr")) 367 goto cleanup; 368 369 if (test->expected_src_addr) { 370 err = make_sockaddr(test->socket_family, test->expected_src_addr, 0, 371 &expected_src_addr, &expected_src_addr_len); 372 if (!ASSERT_EQ(err, 0, "make_sockaddr")) 373 goto cleanup; 374 } 375 376 err = cmp_peer_addr(client, &expected_addr, expected_addr_len, true); 377 if (!ASSERT_EQ(err, 0, "cmp_peer_addr")) 378 goto cleanup; 379 380 if (test->expected_src_addr) { 381 err = cmp_local_addr(client, &expected_src_addr, expected_src_addr_len, false); 382 if (!ASSERT_EQ(err, 0, "cmp_local_addr")) 383 goto cleanup; 384 } 385 cleanup: 386 if (client != -1) 387 close(client); 388 if (serv != -1) 389 close(serv); 390 } 391 392 static void test_xmsg(struct sock_addr_test *test) 393 { 394 struct sockaddr_storage addr, src_addr; 395 socklen_t addr_len = sizeof(struct sockaddr_storage), 396 src_addr_len = sizeof(struct sockaddr_storage); 397 struct msghdr hdr; 398 struct iovec iov; 399 char data = 'a'; 400 int serv = -1, client = -1, err; 401 402 /* Unlike the other tests, here we test that we can rewrite the src addr 403 * with a recvmsg() hook. 404 */ 405 406 serv = start_server(test->socket_family, test->socket_type, 407 test->expected_addr, test->expected_port, 0); 408 if (!ASSERT_GE(serv, 0, "start_server")) 409 goto cleanup; 410 411 client = socket(test->socket_family, test->socket_type, 0); 412 if (!ASSERT_GE(client, 0, "socket")) 413 goto cleanup; 414 415 /* AF_UNIX sockets have to be bound to something to trigger the recvmsg bpf program. */ 416 if (test->socket_family == AF_UNIX) { 417 err = make_sockaddr(AF_UNIX, SRCUN_ADDRESS, 0, &src_addr, &src_addr_len); 418 if (!ASSERT_EQ(err, 0, "make_sockaddr")) 419 goto cleanup; 420 421 err = bind(client, (const struct sockaddr *) &src_addr, src_addr_len); 422 if (!ASSERT_OK(err, "bind")) 423 goto cleanup; 424 } 425 426 err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port, 427 &addr, &addr_len); 428 if (!ASSERT_EQ(err, 0, "make_sockaddr")) 429 goto cleanup; 430 431 if (test->socket_type == SOCK_DGRAM) { 432 memset(&iov, 0, sizeof(iov)); 433 iov.iov_base = &data; 434 iov.iov_len = sizeof(data); 435 436 memset(&hdr, 0, sizeof(hdr)); 437 hdr.msg_name = (void *)&addr; 438 hdr.msg_namelen = addr_len; 439 hdr.msg_iov = &iov; 440 hdr.msg_iovlen = 1; 441 442 err = sendmsg(client, &hdr, 0); 443 if (!ASSERT_EQ(err, sizeof(data), "sendmsg")) 444 goto cleanup; 445 } else { 446 /* Testing with connection-oriented sockets is only valid for 447 * recvmsg() tests. 448 */ 449 if (!ASSERT_EQ(test->type, SOCK_ADDR_TEST_RECVMSG, "recvmsg")) 450 goto cleanup; 451 452 err = connect(client, (const struct sockaddr *)&addr, addr_len); 453 if (!ASSERT_OK(err, "connect")) 454 goto cleanup; 455 456 err = send(client, &data, sizeof(data), 0); 457 if (!ASSERT_EQ(err, sizeof(data), "send")) 458 goto cleanup; 459 460 err = listen(serv, 0); 461 if (!ASSERT_OK(err, "listen")) 462 goto cleanup; 463 464 err = accept(serv, NULL, NULL); 465 if (!ASSERT_GE(err, 0, "accept")) 466 goto cleanup; 467 468 close(serv); 469 serv = err; 470 } 471 472 addr_len = src_addr_len = sizeof(struct sockaddr_storage); 473 474 err = recvfrom(serv, &data, sizeof(data), 0, (struct sockaddr *) &src_addr, &src_addr_len); 475 if (!ASSERT_EQ(err, sizeof(data), "recvfrom")) 476 goto cleanup; 477 478 ASSERT_EQ(data, 'a', "data mismatch"); 479 480 if (test->expected_src_addr) { 481 err = make_sockaddr(test->socket_family, test->expected_src_addr, 0, 482 &addr, &addr_len); 483 if (!ASSERT_EQ(err, 0, "make_sockaddr")) 484 goto cleanup; 485 486 err = cmp_addr(&src_addr, src_addr_len, &addr, addr_len, false); 487 if (!ASSERT_EQ(err, 0, "cmp_addr")) 488 goto cleanup; 489 } 490 491 cleanup: 492 if (client != -1) 493 close(client); 494 if (serv != -1) 495 close(serv); 496 } 497 498 static void test_getsockname(struct sock_addr_test *test) 499 { 500 struct sockaddr_storage expected_addr; 501 socklen_t expected_addr_len = sizeof(struct sockaddr_storage); 502 int serv = -1, err; 503 504 serv = start_server(test->socket_family, test->socket_type, 505 test->requested_addr, test->requested_port, 0); 506 if (!ASSERT_GE(serv, 0, "start_server")) 507 goto cleanup; 508 509 err = make_sockaddr(test->socket_family, 510 test->expected_addr, test->expected_port, 511 &expected_addr, &expected_addr_len); 512 if (!ASSERT_EQ(err, 0, "make_sockaddr")) 513 goto cleanup; 514 515 err = cmp_local_addr(serv, &expected_addr, expected_addr_len, true); 516 if (!ASSERT_EQ(err, 0, "cmp_local_addr")) 517 goto cleanup; 518 519 cleanup: 520 if (serv != -1) 521 close(serv); 522 } 523 524 static void test_getpeername(struct sock_addr_test *test) 525 { 526 struct sockaddr_storage addr, expected_addr; 527 socklen_t addr_len = sizeof(struct sockaddr_storage), 528 expected_addr_len = sizeof(struct sockaddr_storage); 529 int serv = -1, client = -1, err; 530 531 serv = start_server(test->socket_family, test->socket_type, 532 test->requested_addr, test->requested_port, 0); 533 if (!ASSERT_GE(serv, 0, "start_server")) 534 goto cleanup; 535 536 err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port, 537 &addr, &addr_len); 538 if (!ASSERT_EQ(err, 0, "make_sockaddr")) 539 goto cleanup; 540 541 client = connect_to_addr(&addr, addr_len, test->socket_type); 542 if (!ASSERT_GE(client, 0, "connect_to_addr")) 543 goto cleanup; 544 545 err = make_sockaddr(test->socket_family, test->expected_addr, test->expected_port, 546 &expected_addr, &expected_addr_len); 547 if (!ASSERT_EQ(err, 0, "make_sockaddr")) 548 goto cleanup; 549 550 err = cmp_peer_addr(client, &expected_addr, expected_addr_len, true); 551 if (!ASSERT_EQ(err, 0, "cmp_peer_addr")) 552 goto cleanup; 553 554 cleanup: 555 if (client != -1) 556 close(client); 557 if (serv != -1) 558 close(serv); 559 } 560 561 void test_sock_addr(void) 562 { 563 int cgroup_fd = -1; 564 void *skel; 565 566 cgroup_fd = test__join_cgroup("/sock_addr"); 567 if (!ASSERT_GE(cgroup_fd, 0, "join_cgroup")) 568 goto cleanup; 569 570 for (size_t i = 0; i < ARRAY_SIZE(tests); ++i) { 571 struct sock_addr_test *test = &tests[i]; 572 573 if (!test__start_subtest(test->name)) 574 continue; 575 576 skel = test->loadfn(cgroup_fd); 577 if (!skel) 578 continue; 579 580 switch (test->type) { 581 /* Not exercised yet but we leave this code here for when the 582 * INET and INET6 sockaddr tests are migrated to this file in 583 * the future. 584 */ 585 case SOCK_ADDR_TEST_BIND: 586 test_bind(test); 587 break; 588 case SOCK_ADDR_TEST_CONNECT: 589 test_connect(test); 590 break; 591 case SOCK_ADDR_TEST_SENDMSG: 592 case SOCK_ADDR_TEST_RECVMSG: 593 test_xmsg(test); 594 break; 595 case SOCK_ADDR_TEST_GETSOCKNAME: 596 test_getsockname(test); 597 break; 598 case SOCK_ADDR_TEST_GETPEERNAME: 599 test_getpeername(test); 600 break; 601 default: 602 ASSERT_TRUE(false, "Unknown sock addr test type"); 603 break; 604 } 605 606 test->destroyfn(skel); 607 } 608 609 cleanup: 610 if (cgroup_fd >= 0) 611 close(cgroup_fd); 612 } 613