1 // SPDX-License-Identifier: GPL-2.0-or-later 2 3 #define _GNU_SOURCE 4 5 #include <assert.h> 6 #include <fcntl.h> 7 #include <limits.h> 8 #include <sched.h> 9 #include <stdlib.h> 10 #include <sys/mount.h> 11 #include <sys/stat.h> 12 #include <sys/wait.h> 13 #include <linux/nsfs.h> 14 #include <linux/stat.h> 15 16 #include "statmount.h" 17 #include "../utils.h" 18 #include "kselftest.h" 19 20 #define NSID_PASS 0 21 #define NSID_FAIL 1 22 #define NSID_SKIP 2 23 #define NSID_ERROR 3 24 25 static void handle_result(int ret, const char *testname) 26 { 27 if (ret == NSID_PASS) 28 ksft_test_result_pass("%s\n", testname); 29 else if (ret == NSID_FAIL) 30 ksft_test_result_fail("%s\n", testname); 31 else if (ret == NSID_ERROR) 32 ksft_exit_fail_msg("%s\n", testname); 33 else 34 ksft_test_result_skip("%s\n", testname); 35 } 36 37 static inline int wait_for_pid(pid_t pid) 38 { 39 int status, ret; 40 41 again: 42 ret = waitpid(pid, &status, 0); 43 if (ret == -1) { 44 if (errno == EINTR) 45 goto again; 46 47 ksft_print_msg("waitpid returned -1, errno=%d\n", errno); 48 return -1; 49 } 50 51 if (!WIFEXITED(status)) { 52 ksft_print_msg( 53 "waitpid !WIFEXITED, WIFSIGNALED=%d, WTERMSIG=%d\n", 54 WIFSIGNALED(status), WTERMSIG(status)); 55 return -1; 56 } 57 58 ret = WEXITSTATUS(status); 59 return ret; 60 } 61 62 static int get_mnt_ns_id(const char *mnt_ns, uint64_t *mnt_ns_id) 63 { 64 int fd = open(mnt_ns, O_RDONLY); 65 66 if (fd < 0) { 67 ksft_print_msg("failed to open for ns %s: %s\n", 68 mnt_ns, strerror(errno)); 69 sleep(60); 70 return NSID_ERROR; 71 } 72 73 if (ioctl(fd, NS_GET_MNTNS_ID, mnt_ns_id) < 0) { 74 ksft_print_msg("failed to get the nsid for ns %s: %s\n", 75 mnt_ns, strerror(errno)); 76 return NSID_ERROR; 77 } 78 close(fd); 79 return NSID_PASS; 80 } 81 82 static int setup_namespace(void) 83 { 84 if (setup_userns() != 0) 85 return NSID_ERROR; 86 87 return NSID_PASS; 88 } 89 90 static int _test_statmount_mnt_ns_id(void) 91 { 92 struct statmount sm; 93 uint64_t mnt_ns_id; 94 uint64_t root_id; 95 int ret; 96 97 ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id); 98 if (ret != NSID_PASS) 99 return ret; 100 101 root_id = get_unique_mnt_id("/"); 102 if (!root_id) 103 return NSID_ERROR; 104 105 ret = statmount(root_id, 0, 0, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), 0); 106 if (ret == -1) { 107 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno)); 108 return NSID_ERROR; 109 } 110 111 if (sm.size != sizeof(sm)) { 112 ksft_print_msg("unexpected size: %u != %u\n", sm.size, 113 (uint32_t)sizeof(sm)); 114 return NSID_FAIL; 115 } 116 if (sm.mask != STATMOUNT_MNT_NS_ID) { 117 ksft_print_msg("statmount mnt ns id unavailable\n"); 118 return NSID_SKIP; 119 } 120 121 if (sm.mnt_ns_id != mnt_ns_id) { 122 ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n", 123 (unsigned long long)sm.mnt_ns_id, 124 (unsigned long long)mnt_ns_id); 125 return NSID_FAIL; 126 } 127 128 return NSID_PASS; 129 } 130 131 static int _test_statmount_mnt_ns_id_by_fd(void) 132 { 133 struct statmount sm; 134 uint64_t mnt_ns_id; 135 int ret, fd, mounted = 1, status = NSID_ERROR; 136 char mnt[] = "/statmount.fd.XXXXXX"; 137 138 ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id); 139 if (ret != NSID_PASS) 140 return ret; 141 142 if (!mkdtemp(mnt)) { 143 ksft_print_msg("statmount by fd mnt ns id mkdtemp: %s\n", strerror(errno)); 144 return NSID_ERROR; 145 } 146 147 if (mount(mnt, mnt, NULL, MS_BIND, 0)) { 148 ksft_print_msg("statmount by fd mnt ns id mount: %s\n", strerror(errno)); 149 status = NSID_ERROR; 150 goto err; 151 } 152 153 fd = open(mnt, O_PATH); 154 if (fd < 0) { 155 ksft_print_msg("statmount by fd mnt ns id open: %s\n", strerror(errno)); 156 goto err; 157 } 158 159 ret = statmount(0, 0, fd, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), STATMOUNT_BY_FD); 160 if (ret == -1) { 161 ksft_print_msg("statmount mnt ns id statmount: %s\n", strerror(errno)); 162 status = NSID_ERROR; 163 goto out; 164 } 165 166 if (sm.size != sizeof(sm)) { 167 ksft_print_msg("unexpected size: %u != %u\n", sm.size, 168 (uint32_t)sizeof(sm)); 169 status = NSID_FAIL; 170 goto out; 171 } 172 if (sm.mask != STATMOUNT_MNT_NS_ID) { 173 ksft_print_msg("statmount mnt ns id unavailable\n"); 174 status = NSID_SKIP; 175 goto out; 176 } 177 178 if (sm.mnt_ns_id != mnt_ns_id) { 179 ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n", 180 (unsigned long long)sm.mnt_ns_id, 181 (unsigned long long)mnt_ns_id); 182 status = NSID_FAIL; 183 goto out; 184 } 185 186 mounted = 0; 187 if (umount2(mnt, MNT_DETACH)) { 188 ksft_print_msg("statmount by fd mnt ns id umount2: %s\n", strerror(errno)); 189 goto out; 190 } 191 192 ret = statmount(0, 0, fd, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), STATMOUNT_BY_FD); 193 if (ret == -1) { 194 ksft_print_msg("statmount mnt ns id statmount: %s\n", strerror(errno)); 195 status = NSID_ERROR; 196 goto out; 197 } 198 199 if (sm.size != sizeof(sm)) { 200 ksft_print_msg("unexpected size: %u != %u\n", sm.size, 201 (uint32_t)sizeof(sm)); 202 status = NSID_FAIL; 203 goto out; 204 } 205 206 if (sm.mask == STATMOUNT_MNT_NS_ID) { 207 ksft_print_msg("unexpected STATMOUNT_MNT_NS_ID in mask\n"); 208 status = NSID_FAIL; 209 goto out; 210 } 211 212 status = NSID_PASS; 213 out: 214 close(fd); 215 if (mounted) 216 umount2(mnt, MNT_DETACH); 217 err: 218 rmdir(mnt); 219 return status; 220 } 221 222 223 static void test_statmount_mnt_ns_id(void) 224 { 225 pid_t pid; 226 int ret; 227 228 pid = fork(); 229 if (pid < 0) 230 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno)); 231 232 /* We're the original pid, wait for the result. */ 233 if (pid != 0) { 234 ret = wait_for_pid(pid); 235 handle_result(ret, "test statmount ns id"); 236 return; 237 } 238 239 ret = setup_namespace(); 240 if (ret != NSID_PASS) 241 exit(ret); 242 ret = _test_statmount_mnt_ns_id(); 243 if (ret != NSID_PASS) 244 exit(ret); 245 ret = _test_statmount_mnt_ns_id_by_fd(); 246 exit(ret); 247 } 248 249 static int validate_external_listmount(pid_t pid, uint64_t child_nr_mounts) 250 { 251 uint64_t list[256]; 252 uint64_t mnt_ns_id; 253 uint64_t nr_mounts; 254 char buf[256]; 255 int ret; 256 257 /* Get the mount ns id for our child. */ 258 snprintf(buf, sizeof(buf), "/proc/%lu/ns/mnt", (unsigned long)pid); 259 ret = get_mnt_ns_id(buf, &mnt_ns_id); 260 261 nr_mounts = listmount(LSMT_ROOT, mnt_ns_id, 0, list, 256, 0); 262 if (nr_mounts == (uint64_t)-1) { 263 ksft_print_msg("listmount: %s\n", strerror(errno)); 264 return NSID_ERROR; 265 } 266 267 if (nr_mounts != child_nr_mounts) { 268 ksft_print_msg("listmount results is %zi != %zi\n", nr_mounts, 269 child_nr_mounts); 270 return NSID_FAIL; 271 } 272 273 /* Validate that all of our entries match our mnt_ns_id. */ 274 for (int i = 0; i < nr_mounts; i++) { 275 struct statmount sm; 276 277 ret = statmount(list[i], mnt_ns_id, 0, STATMOUNT_MNT_NS_ID, &sm, 278 sizeof(sm), 0); 279 if (ret < 0) { 280 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno)); 281 return NSID_ERROR; 282 } 283 284 if (sm.mask != STATMOUNT_MNT_NS_ID) { 285 ksft_print_msg("statmount mnt ns id unavailable\n"); 286 return NSID_SKIP; 287 } 288 289 if (sm.mnt_ns_id != mnt_ns_id) { 290 ksft_print_msg("listmount gave us the wrong ns id: 0x%llx != 0x%llx\n", 291 (unsigned long long)sm.mnt_ns_id, 292 (unsigned long long)mnt_ns_id); 293 return NSID_FAIL; 294 } 295 } 296 297 return NSID_PASS; 298 } 299 300 static void test_listmount_ns(void) 301 { 302 uint64_t nr_mounts; 303 char pval; 304 int child_ready_pipe[2]; 305 int parent_ready_pipe[2]; 306 pid_t pid; 307 int ret, child_ret; 308 309 if (pipe(child_ready_pipe) < 0) 310 ksft_exit_fail_msg("failed to create the child pipe: %s\n", 311 strerror(errno)); 312 if (pipe(parent_ready_pipe) < 0) 313 ksft_exit_fail_msg("failed to create the parent pipe: %s\n", 314 strerror(errno)); 315 316 pid = fork(); 317 if (pid < 0) 318 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno)); 319 320 if (pid == 0) { 321 char cval; 322 uint64_t list[256]; 323 324 close(child_ready_pipe[0]); 325 close(parent_ready_pipe[1]); 326 327 ret = setup_namespace(); 328 if (ret != NSID_PASS) 329 exit(ret); 330 331 nr_mounts = listmount(LSMT_ROOT, 0, 0, list, 256, 0); 332 if (nr_mounts == (uint64_t)-1) { 333 ksft_print_msg("listmount: %s\n", strerror(errno)); 334 exit(NSID_FAIL); 335 } 336 337 /* 338 * Tell our parent how many mounts we have, and then wait for it 339 * to tell us we're done. 340 */ 341 if (write(child_ready_pipe[1], &nr_mounts, sizeof(nr_mounts)) != 342 sizeof(nr_mounts)) 343 ret = NSID_ERROR; 344 if (read(parent_ready_pipe[0], &cval, sizeof(cval)) != sizeof(cval)) 345 ret = NSID_ERROR; 346 exit(NSID_PASS); 347 } 348 349 close(child_ready_pipe[1]); 350 close(parent_ready_pipe[0]); 351 352 /* Wait until the child has created everything. */ 353 if (read(child_ready_pipe[0], &nr_mounts, sizeof(nr_mounts)) != 354 sizeof(nr_mounts)) 355 ret = NSID_ERROR; 356 357 ret = validate_external_listmount(pid, nr_mounts); 358 359 if (write(parent_ready_pipe[1], &pval, sizeof(pval)) != sizeof(pval)) 360 ret = NSID_ERROR; 361 362 child_ret = wait_for_pid(pid); 363 if (child_ret != NSID_PASS) 364 ret = child_ret; 365 handle_result(ret, "test listmount ns id"); 366 } 367 368 int main(void) 369 { 370 int ret; 371 372 ksft_print_header(); 373 ret = statmount(0, 0, 0, 0, NULL, 0, 0); 374 assert(ret == -1); 375 if (errno == ENOSYS) 376 ksft_exit_skip("statmount() syscall not supported\n"); 377 378 ksft_set_plan(2); 379 test_statmount_mnt_ns_id(); 380 test_listmount_ns(); 381 382 if (ksft_get_fail_cnt() + ksft_get_error_cnt() > 0) 383 ksft_exit_fail(); 384 else 385 ksft_exit_pass(); 386 } 387