1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright 2025 Google LLC */ 3 4 #include <test_progs.h> 5 #include "sk_bypass_prot_mem.skel.h" 6 #include "network_helpers.h" 7 8 #define NR_PAGES 32 9 #define NR_SOCKETS 2 10 #define BUF_TOTAL (NR_PAGES * 4096 / NR_SOCKETS) 11 #define BUF_SINGLE 1024 12 #define NR_SEND (BUF_TOTAL / BUF_SINGLE) 13 14 struct test_case { 15 char name[8]; 16 int family; 17 int type; 18 int (*create_sockets)(struct test_case *test_case, int sk[], int len); 19 long (*get_memory_allocated)(struct test_case *test_case, struct sk_bypass_prot_mem *skel); 20 }; 21 22 static int tcp_create_sockets(struct test_case *test_case, int sk[], int len) 23 { 24 int server, i, err = 0; 25 26 server = start_server(test_case->family, test_case->type, NULL, 0, 0); 27 if (!ASSERT_GE(server, 0, "start_server_str")) 28 return server; 29 30 /* Keep for-loop so we can change NR_SOCKETS easily. */ 31 for (i = 0; i < len; i += 2) { 32 sk[i] = connect_to_fd(server, 0); 33 if (sk[i] < 0) { 34 ASSERT_GE(sk[i], 0, "connect_to_fd"); 35 err = sk[i]; 36 break; 37 } 38 39 sk[i + 1] = accept(server, NULL, NULL); 40 if (sk[i + 1] < 0) { 41 ASSERT_GE(sk[i + 1], 0, "accept"); 42 err = sk[i + 1]; 43 break; 44 } 45 } 46 47 close(server); 48 49 return err; 50 } 51 52 static int udp_create_sockets(struct test_case *test_case, int sk[], int len) 53 { 54 int i, j, err, rcvbuf = BUF_TOTAL; 55 56 /* Keep for-loop so we can change NR_SOCKETS easily. */ 57 for (i = 0; i < len; i += 2) { 58 sk[i] = start_server(test_case->family, test_case->type, NULL, 0, 0); 59 if (sk[i] < 0) { 60 ASSERT_GE(sk[i], 0, "start_server"); 61 return sk[i]; 62 } 63 64 sk[i + 1] = connect_to_fd(sk[i], 0); 65 if (sk[i + 1] < 0) { 66 ASSERT_GE(sk[i + 1], 0, "connect_to_fd"); 67 return sk[i + 1]; 68 } 69 70 err = connect_fd_to_fd(sk[i], sk[i + 1], 0); 71 if (err) { 72 ASSERT_EQ(err, 0, "connect_fd_to_fd"); 73 return err; 74 } 75 76 for (j = 0; j < 2; j++) { 77 err = setsockopt(sk[i + j], SOL_SOCKET, SO_RCVBUF, &rcvbuf, sizeof(int)); 78 if (err) { 79 ASSERT_EQ(err, 0, "setsockopt(SO_RCVBUF)"); 80 return err; 81 } 82 } 83 } 84 85 return 0; 86 } 87 88 static long get_memory_allocated(struct test_case *test_case, 89 bool *activated, long *memory_allocated) 90 { 91 int sk; 92 93 *activated = true; 94 95 /* AF_INET and AF_INET6 share the same memory_allocated. 96 * tcp_init_sock() is called by AF_INET and AF_INET6, 97 * but udp_lib_init_sock() is inline. 98 */ 99 sk = socket(AF_INET, test_case->type, 0); 100 if (!ASSERT_GE(sk, 0, "get_memory_allocated")) 101 return -1; 102 103 close(sk); 104 105 return *memory_allocated; 106 } 107 108 static long tcp_get_memory_allocated(struct test_case *test_case, struct sk_bypass_prot_mem *skel) 109 { 110 return get_memory_allocated(test_case, 111 &skel->bss->tcp_activated, 112 &skel->bss->tcp_memory_allocated); 113 } 114 115 static long udp_get_memory_allocated(struct test_case *test_case, struct sk_bypass_prot_mem *skel) 116 { 117 return get_memory_allocated(test_case, 118 &skel->bss->udp_activated, 119 &skel->bss->udp_memory_allocated); 120 } 121 122 static int check_bypass(struct test_case *test_case, 123 struct sk_bypass_prot_mem *skel, bool bypass) 124 { 125 char buf[BUF_SINGLE] = {}; 126 long memory_allocated[2]; 127 int sk[NR_SOCKETS]; 128 int err, i, j; 129 130 for (i = 0; i < ARRAY_SIZE(sk); i++) 131 sk[i] = -1; 132 133 err = test_case->create_sockets(test_case, sk, ARRAY_SIZE(sk)); 134 if (err) 135 goto close; 136 137 memory_allocated[0] = test_case->get_memory_allocated(test_case, skel); 138 139 /* allocate pages >= NR_PAGES */ 140 for (i = 0; i < ARRAY_SIZE(sk); i++) { 141 for (j = 0; j < NR_SEND; j++) { 142 int bytes = send(sk[i], buf, sizeof(buf), 0); 143 144 /* Avoid too noisy logs when something failed. */ 145 if (bytes != sizeof(buf)) { 146 ASSERT_EQ(bytes, sizeof(buf), "send"); 147 if (bytes < 0) { 148 err = bytes; 149 goto drain; 150 } 151 } 152 } 153 } 154 155 memory_allocated[1] = test_case->get_memory_allocated(test_case, skel); 156 157 if (bypass) 158 ASSERT_LE(memory_allocated[1], memory_allocated[0] + 10, "bypass"); 159 else 160 ASSERT_GT(memory_allocated[1], memory_allocated[0] + NR_PAGES, "no bypass"); 161 162 drain: 163 if (test_case->type == SOCK_DGRAM) { 164 /* UDP starts purging sk->sk_receive_queue after one RCU 165 * grace period, then udp_memory_allocated goes down, 166 * so drain the queue before close(). 167 */ 168 for (i = 0; i < ARRAY_SIZE(sk); i++) { 169 for (j = 0; j < NR_SEND; j++) { 170 int bytes = recv(sk[i], buf, 1, MSG_DONTWAIT | MSG_TRUNC); 171 172 if (bytes == sizeof(buf)) 173 continue; 174 if (bytes != -1 || errno != EAGAIN) 175 PRINT_FAIL("bytes: %d, errno: %s\n", bytes, strerror(errno)); 176 break; 177 } 178 } 179 } 180 181 close: 182 for (i = 0; i < ARRAY_SIZE(sk); i++) { 183 if (sk[i] < 0) 184 break; 185 186 close(sk[i]); 187 } 188 189 return err; 190 } 191 192 static void run_test(struct test_case *test_case) 193 { 194 struct sk_bypass_prot_mem *skel; 195 struct nstoken *nstoken; 196 int cgroup, err; 197 198 skel = sk_bypass_prot_mem__open_and_load(); 199 if (!ASSERT_OK_PTR(skel, "open_and_load")) 200 return; 201 202 skel->bss->nr_cpus = libbpf_num_possible_cpus(); 203 204 err = sk_bypass_prot_mem__attach(skel); 205 if (!ASSERT_OK(err, "attach")) 206 goto destroy_skel; 207 208 cgroup = test__join_cgroup("/sk_bypass_prot_mem"); 209 if (!ASSERT_GE(cgroup, 0, "join_cgroup")) 210 goto destroy_skel; 211 212 err = make_netns("sk_bypass_prot_mem"); 213 if (!ASSERT_EQ(err, 0, "make_netns")) 214 goto close_cgroup; 215 216 nstoken = open_netns("sk_bypass_prot_mem"); 217 if (!ASSERT_OK_PTR(nstoken, "open_netns")) 218 goto remove_netns; 219 220 err = check_bypass(test_case, skel, false); 221 if (!ASSERT_EQ(err, 0, "test_bypass(false)")) 222 goto close_netns; 223 224 err = write_sysctl("/proc/sys/net/core/bypass_prot_mem", "1"); 225 if (!ASSERT_EQ(err, 0, "write_sysctl(1)")) 226 goto close_netns; 227 228 err = check_bypass(test_case, skel, true); 229 if (!ASSERT_EQ(err, 0, "test_bypass(true by sysctl)")) 230 goto close_netns; 231 232 err = write_sysctl("/proc/sys/net/core/bypass_prot_mem", "0"); 233 if (!ASSERT_EQ(err, 0, "write_sysctl(0)")) 234 goto close_netns; 235 236 skel->links.sock_create = bpf_program__attach_cgroup(skel->progs.sock_create, cgroup); 237 if (!ASSERT_OK_PTR(skel->links.sock_create, "attach_cgroup(sock_create)")) 238 goto close_netns; 239 240 err = check_bypass(test_case, skel, true); 241 ASSERT_EQ(err, 0, "test_bypass(true by bpf)"); 242 243 close_netns: 244 close_netns(nstoken); 245 remove_netns: 246 remove_netns("sk_bypass_prot_mem"); 247 close_cgroup: 248 close(cgroup); 249 destroy_skel: 250 sk_bypass_prot_mem__destroy(skel); 251 } 252 253 static struct test_case test_cases[] = { 254 { 255 .name = "TCP ", 256 .family = AF_INET, 257 .type = SOCK_STREAM, 258 .create_sockets = tcp_create_sockets, 259 .get_memory_allocated = tcp_get_memory_allocated, 260 }, 261 { 262 .name = "UDP ", 263 .family = AF_INET, 264 .type = SOCK_DGRAM, 265 .create_sockets = udp_create_sockets, 266 .get_memory_allocated = udp_get_memory_allocated, 267 }, 268 { 269 .name = "TCPv6", 270 .family = AF_INET6, 271 .type = SOCK_STREAM, 272 .create_sockets = tcp_create_sockets, 273 .get_memory_allocated = tcp_get_memory_allocated, 274 }, 275 { 276 .name = "UDPv6", 277 .family = AF_INET6, 278 .type = SOCK_DGRAM, 279 .create_sockets = udp_create_sockets, 280 .get_memory_allocated = udp_get_memory_allocated, 281 }, 282 }; 283 284 void serial_test_sk_bypass_prot_mem(void) 285 { 286 int i; 287 288 for (i = 0; i < ARRAY_SIZE(test_cases); i++) { 289 if (test__start_subtest(test_cases[i].name)) 290 run_test(&test_cases[i]); 291 } 292 } 293