1 // SPDX-License-Identifier: GPL-2.0-or-later 2 3 #define _GNU_SOURCE 4 5 #include <assert.h> 6 #include <fcntl.h> 7 #include <limits.h> 8 #include <sched.h> 9 #include <stdlib.h> 10 #include <sys/mount.h> 11 #include <sys/stat.h> 12 #include <sys/wait.h> 13 #include <linux/nsfs.h> 14 #include <linux/stat.h> 15 16 #include "statmount.h" 17 #include "../utils.h" 18 #include "kselftest.h" 19 20 #define NSID_PASS 0 21 #define NSID_FAIL 1 22 #define NSID_SKIP 2 23 #define NSID_ERROR 3 24 25 static void handle_result(int ret, const char *testname) 26 { 27 if (ret == NSID_PASS) 28 ksft_test_result_pass("%s\n", testname); 29 else if (ret == NSID_FAIL) 30 ksft_test_result_fail("%s\n", testname); 31 else if (ret == NSID_ERROR) 32 ksft_exit_fail_msg("%s\n", testname); 33 else 34 ksft_test_result_skip("%s\n", testname); 35 } 36 37 static int get_mnt_ns_id(const char *mnt_ns, uint64_t *mnt_ns_id) 38 { 39 int fd = open(mnt_ns, O_RDONLY); 40 41 if (fd < 0) { 42 ksft_print_msg("failed to open for ns %s: %s\n", 43 mnt_ns, strerror(errno)); 44 sleep(60); 45 return NSID_ERROR; 46 } 47 48 if (ioctl(fd, NS_GET_MNTNS_ID, mnt_ns_id) < 0) { 49 ksft_print_msg("failed to get the nsid for ns %s: %s\n", 50 mnt_ns, strerror(errno)); 51 return NSID_ERROR; 52 } 53 close(fd); 54 return NSID_PASS; 55 } 56 57 static int setup_namespace(void) 58 { 59 if (setup_userns() != 0) 60 return NSID_ERROR; 61 62 return NSID_PASS; 63 } 64 65 static int _test_statmount_mnt_ns_id(void) 66 { 67 struct statmount sm; 68 uint64_t mnt_ns_id; 69 uint64_t root_id; 70 int ret; 71 72 ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id); 73 if (ret != NSID_PASS) 74 return ret; 75 76 root_id = get_unique_mnt_id("/"); 77 if (!root_id) 78 return NSID_ERROR; 79 80 ret = statmount(root_id, 0, 0, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), 0); 81 if (ret == -1) { 82 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno)); 83 return NSID_ERROR; 84 } 85 86 if (sm.size != sizeof(sm)) { 87 ksft_print_msg("unexpected size: %u != %u\n", sm.size, 88 (uint32_t)sizeof(sm)); 89 return NSID_FAIL; 90 } 91 if (sm.mask != STATMOUNT_MNT_NS_ID) { 92 ksft_print_msg("statmount mnt ns id unavailable\n"); 93 return NSID_SKIP; 94 } 95 96 if (sm.mnt_ns_id != mnt_ns_id) { 97 ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n", 98 (unsigned long long)sm.mnt_ns_id, 99 (unsigned long long)mnt_ns_id); 100 return NSID_FAIL; 101 } 102 103 return NSID_PASS; 104 } 105 106 static int _test_statmount_mnt_ns_id_by_fd(void) 107 { 108 struct statmount sm; 109 uint64_t mnt_ns_id; 110 int ret, fd, mounted = 1, status = NSID_ERROR; 111 char mnt[] = "/statmount.fd.XXXXXX"; 112 113 ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id); 114 if (ret != NSID_PASS) 115 return ret; 116 117 if (!mkdtemp(mnt)) { 118 ksft_print_msg("statmount by fd mnt ns id mkdtemp: %s\n", strerror(errno)); 119 return NSID_ERROR; 120 } 121 122 if (mount(mnt, mnt, NULL, MS_BIND, 0)) { 123 ksft_print_msg("statmount by fd mnt ns id mount: %s\n", strerror(errno)); 124 status = NSID_ERROR; 125 goto err; 126 } 127 128 fd = open(mnt, O_PATH); 129 if (fd < 0) { 130 ksft_print_msg("statmount by fd mnt ns id open: %s\n", strerror(errno)); 131 goto err; 132 } 133 134 ret = statmount(0, 0, fd, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), STATMOUNT_BY_FD); 135 if (ret == -1) { 136 ksft_print_msg("statmount mnt ns id statmount: %s\n", strerror(errno)); 137 status = NSID_ERROR; 138 goto out; 139 } 140 141 if (sm.size != sizeof(sm)) { 142 ksft_print_msg("unexpected size: %u != %u\n", sm.size, 143 (uint32_t)sizeof(sm)); 144 status = NSID_FAIL; 145 goto out; 146 } 147 if (sm.mask != STATMOUNT_MNT_NS_ID) { 148 ksft_print_msg("statmount mnt ns id unavailable\n"); 149 status = NSID_SKIP; 150 goto out; 151 } 152 153 if (sm.mnt_ns_id != mnt_ns_id) { 154 ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n", 155 (unsigned long long)sm.mnt_ns_id, 156 (unsigned long long)mnt_ns_id); 157 status = NSID_FAIL; 158 goto out; 159 } 160 161 mounted = 0; 162 if (umount2(mnt, MNT_DETACH)) { 163 ksft_print_msg("statmount by fd mnt ns id umount2: %s\n", strerror(errno)); 164 goto out; 165 } 166 167 ret = statmount(0, 0, fd, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), STATMOUNT_BY_FD); 168 if (ret == -1) { 169 ksft_print_msg("statmount mnt ns id statmount: %s\n", strerror(errno)); 170 status = NSID_ERROR; 171 goto out; 172 } 173 174 if (sm.size != sizeof(sm)) { 175 ksft_print_msg("unexpected size: %u != %u\n", sm.size, 176 (uint32_t)sizeof(sm)); 177 status = NSID_FAIL; 178 goto out; 179 } 180 181 if (sm.mask == STATMOUNT_MNT_NS_ID) { 182 ksft_print_msg("unexpected STATMOUNT_MNT_NS_ID in mask\n"); 183 status = NSID_FAIL; 184 goto out; 185 } 186 187 status = NSID_PASS; 188 out: 189 close(fd); 190 if (mounted) 191 umount2(mnt, MNT_DETACH); 192 err: 193 rmdir(mnt); 194 return status; 195 } 196 197 198 static void test_statmount_mnt_ns_id(void) 199 { 200 pid_t pid; 201 int ret; 202 203 pid = fork(); 204 if (pid < 0) 205 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno)); 206 207 /* We're the original pid, wait for the result. */ 208 if (pid != 0) { 209 ret = wait_for_pid(pid); 210 handle_result(ret, "test statmount ns id"); 211 return; 212 } 213 214 ret = setup_namespace(); 215 if (ret != NSID_PASS) 216 exit(ret); 217 ret = _test_statmount_mnt_ns_id(); 218 if (ret != NSID_PASS) 219 exit(ret); 220 ret = _test_statmount_mnt_ns_id_by_fd(); 221 exit(ret); 222 } 223 224 static int validate_external_listmount(pid_t pid, uint64_t child_nr_mounts) 225 { 226 uint64_t list[256]; 227 uint64_t mnt_ns_id; 228 uint64_t nr_mounts; 229 char buf[256]; 230 int ret; 231 232 /* Get the mount ns id for our child. */ 233 snprintf(buf, sizeof(buf), "/proc/%lu/ns/mnt", (unsigned long)pid); 234 ret = get_mnt_ns_id(buf, &mnt_ns_id); 235 236 nr_mounts = listmount(LSMT_ROOT, mnt_ns_id, 0, list, 256, 0); 237 if (nr_mounts == (uint64_t)-1) { 238 ksft_print_msg("listmount: %s\n", strerror(errno)); 239 return NSID_ERROR; 240 } 241 242 if (nr_mounts != child_nr_mounts) { 243 ksft_print_msg("listmount results is %zi != %zi\n", nr_mounts, 244 child_nr_mounts); 245 return NSID_FAIL; 246 } 247 248 /* Validate that all of our entries match our mnt_ns_id. */ 249 for (int i = 0; i < nr_mounts; i++) { 250 struct statmount sm; 251 252 ret = statmount(list[i], mnt_ns_id, 0, STATMOUNT_MNT_NS_ID, &sm, 253 sizeof(sm), 0); 254 if (ret < 0) { 255 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno)); 256 return NSID_ERROR; 257 } 258 259 if (sm.mask != STATMOUNT_MNT_NS_ID) { 260 ksft_print_msg("statmount mnt ns id unavailable\n"); 261 return NSID_SKIP; 262 } 263 264 if (sm.mnt_ns_id != mnt_ns_id) { 265 ksft_print_msg("listmount gave us the wrong ns id: 0x%llx != 0x%llx\n", 266 (unsigned long long)sm.mnt_ns_id, 267 (unsigned long long)mnt_ns_id); 268 return NSID_FAIL; 269 } 270 } 271 272 return NSID_PASS; 273 } 274 275 static void test_listmount_ns(void) 276 { 277 uint64_t nr_mounts; 278 char pval; 279 int child_ready_pipe[2]; 280 int parent_ready_pipe[2]; 281 pid_t pid; 282 int ret, child_ret; 283 284 if (pipe(child_ready_pipe) < 0) 285 ksft_exit_fail_msg("failed to create the child pipe: %s\n", 286 strerror(errno)); 287 if (pipe(parent_ready_pipe) < 0) 288 ksft_exit_fail_msg("failed to create the parent pipe: %s\n", 289 strerror(errno)); 290 291 pid = fork(); 292 if (pid < 0) 293 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno)); 294 295 if (pid == 0) { 296 char cval; 297 uint64_t list[256]; 298 299 close(child_ready_pipe[0]); 300 close(parent_ready_pipe[1]); 301 302 ret = setup_namespace(); 303 if (ret != NSID_PASS) 304 exit(ret); 305 306 nr_mounts = listmount(LSMT_ROOT, 0, 0, list, 256, 0); 307 if (nr_mounts == (uint64_t)-1) { 308 ksft_print_msg("listmount: %s\n", strerror(errno)); 309 exit(NSID_FAIL); 310 } 311 312 /* 313 * Tell our parent how many mounts we have, and then wait for it 314 * to tell us we're done. 315 */ 316 if (write(child_ready_pipe[1], &nr_mounts, sizeof(nr_mounts)) != 317 sizeof(nr_mounts)) 318 ret = NSID_ERROR; 319 if (read(parent_ready_pipe[0], &cval, sizeof(cval)) != sizeof(cval)) 320 ret = NSID_ERROR; 321 exit(NSID_PASS); 322 } 323 324 close(child_ready_pipe[1]); 325 close(parent_ready_pipe[0]); 326 327 /* Wait until the child has created everything. */ 328 if (read(child_ready_pipe[0], &nr_mounts, sizeof(nr_mounts)) != 329 sizeof(nr_mounts)) 330 ret = NSID_ERROR; 331 332 ret = validate_external_listmount(pid, nr_mounts); 333 334 if (write(parent_ready_pipe[1], &pval, sizeof(pval)) != sizeof(pval)) 335 ret = NSID_ERROR; 336 337 child_ret = wait_for_pid(pid); 338 if (child_ret != NSID_PASS) 339 ret = child_ret; 340 handle_result(ret, "test listmount ns id"); 341 } 342 343 int main(void) 344 { 345 int ret; 346 347 ksft_print_header(); 348 ret = statmount(0, 0, 0, 0, NULL, 0, 0); 349 assert(ret == -1); 350 if (errno == ENOSYS) 351 ksft_exit_skip("statmount() syscall not supported\n"); 352 353 ksft_set_plan(2); 354 test_statmount_mnt_ns_id(); 355 test_listmount_ns(); 356 357 if (ksft_get_fail_cnt() + ksft_get_error_cnt() > 0) 358 ksft_exit_fail(); 359 else 360 ksft_exit_pass(); 361 } 362