1 // SPDX-License-Identifier: GPL-2.0-or-later 2 3 #define _GNU_SOURCE 4 5 #include <assert.h> 6 #include <fcntl.h> 7 #include <limits.h> 8 #include <sched.h> 9 #include <stdlib.h> 10 #include <sys/mount.h> 11 #include <sys/stat.h> 12 #include <sys/wait.h> 13 #include <linux/nsfs.h> 14 #include <linux/stat.h> 15 16 #include "statmount.h" 17 #include "../utils.h" 18 #include "../../kselftest.h" 19 20 #define NSID_PASS 0 21 #define NSID_FAIL 1 22 #define NSID_SKIP 2 23 #define NSID_ERROR 3 24 25 static void handle_result(int ret, const char *testname) 26 { 27 if (ret == NSID_PASS) 28 ksft_test_result_pass("%s\n", testname); 29 else if (ret == NSID_FAIL) 30 ksft_test_result_fail("%s\n", testname); 31 else if (ret == NSID_ERROR) 32 ksft_exit_fail_msg("%s\n", testname); 33 else 34 ksft_test_result_skip("%s\n", testname); 35 } 36 37 static inline int wait_for_pid(pid_t pid) 38 { 39 int status, ret; 40 41 again: 42 ret = waitpid(pid, &status, 0); 43 if (ret == -1) { 44 if (errno == EINTR) 45 goto again; 46 47 ksft_print_msg("waitpid returned -1, errno=%d\n", errno); 48 return -1; 49 } 50 51 if (!WIFEXITED(status)) { 52 ksft_print_msg( 53 "waitpid !WIFEXITED, WIFSIGNALED=%d, WTERMSIG=%d\n", 54 WIFSIGNALED(status), WTERMSIG(status)); 55 return -1; 56 } 57 58 ret = WEXITSTATUS(status); 59 return ret; 60 } 61 62 static int get_mnt_ns_id(const char *mnt_ns, uint64_t *mnt_ns_id) 63 { 64 int fd = open(mnt_ns, O_RDONLY); 65 66 if (fd < 0) { 67 ksft_print_msg("failed to open for ns %s: %s\n", 68 mnt_ns, strerror(errno)); 69 sleep(60); 70 return NSID_ERROR; 71 } 72 73 if (ioctl(fd, NS_GET_MNTNS_ID, mnt_ns_id) < 0) { 74 ksft_print_msg("failed to get the nsid for ns %s: %s\n", 75 mnt_ns, strerror(errno)); 76 return NSID_ERROR; 77 } 78 close(fd); 79 return NSID_PASS; 80 } 81 82 static int setup_namespace(void) 83 { 84 if (setup_userns() != 0) 85 return NSID_ERROR; 86 87 return NSID_PASS; 88 } 89 90 static int _test_statmount_mnt_ns_id(void) 91 { 92 struct statmount sm; 93 uint64_t mnt_ns_id; 94 uint64_t root_id; 95 int ret; 96 97 ret = get_mnt_ns_id("/proc/self/ns/mnt", &mnt_ns_id); 98 if (ret != NSID_PASS) 99 return ret; 100 101 root_id = get_unique_mnt_id("/"); 102 if (!root_id) 103 return NSID_ERROR; 104 105 ret = statmount(root_id, 0, STATMOUNT_MNT_NS_ID, &sm, sizeof(sm), 0); 106 if (ret == -1) { 107 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno)); 108 return NSID_ERROR; 109 } 110 111 if (sm.size != sizeof(sm)) { 112 ksft_print_msg("unexpected size: %u != %u\n", sm.size, 113 (uint32_t)sizeof(sm)); 114 return NSID_FAIL; 115 } 116 if (sm.mask != STATMOUNT_MNT_NS_ID) { 117 ksft_print_msg("statmount mnt ns id unavailable\n"); 118 return NSID_SKIP; 119 } 120 121 if (sm.mnt_ns_id != mnt_ns_id) { 122 ksft_print_msg("unexpected mnt ns ID: 0x%llx != 0x%llx\n", 123 (unsigned long long)sm.mnt_ns_id, 124 (unsigned long long)mnt_ns_id); 125 return NSID_FAIL; 126 } 127 128 return NSID_PASS; 129 } 130 131 static void test_statmount_mnt_ns_id(void) 132 { 133 pid_t pid; 134 int ret; 135 136 pid = fork(); 137 if (pid < 0) 138 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno)); 139 140 /* We're the original pid, wait for the result. */ 141 if (pid != 0) { 142 ret = wait_for_pid(pid); 143 handle_result(ret, "test statmount ns id"); 144 return; 145 } 146 147 ret = setup_namespace(); 148 if (ret != NSID_PASS) 149 exit(ret); 150 ret = _test_statmount_mnt_ns_id(); 151 exit(ret); 152 } 153 154 static int validate_external_listmount(pid_t pid, uint64_t child_nr_mounts) 155 { 156 uint64_t list[256]; 157 uint64_t mnt_ns_id; 158 uint64_t nr_mounts; 159 char buf[256]; 160 int ret; 161 162 /* Get the mount ns id for our child. */ 163 snprintf(buf, sizeof(buf), "/proc/%lu/ns/mnt", (unsigned long)pid); 164 ret = get_mnt_ns_id(buf, &mnt_ns_id); 165 166 nr_mounts = listmount(LSMT_ROOT, mnt_ns_id, 0, list, 256, 0); 167 if (nr_mounts == (uint64_t)-1) { 168 ksft_print_msg("listmount: %s\n", strerror(errno)); 169 return NSID_ERROR; 170 } 171 172 if (nr_mounts != child_nr_mounts) { 173 ksft_print_msg("listmount results is %zi != %zi\n", nr_mounts, 174 child_nr_mounts); 175 return NSID_FAIL; 176 } 177 178 /* Validate that all of our entries match our mnt_ns_id. */ 179 for (int i = 0; i < nr_mounts; i++) { 180 struct statmount sm; 181 182 ret = statmount(list[i], mnt_ns_id, STATMOUNT_MNT_NS_ID, &sm, 183 sizeof(sm), 0); 184 if (ret < 0) { 185 ksft_print_msg("statmount mnt ns id: %s\n", strerror(errno)); 186 return NSID_ERROR; 187 } 188 189 if (sm.mask != STATMOUNT_MNT_NS_ID) { 190 ksft_print_msg("statmount mnt ns id unavailable\n"); 191 return NSID_SKIP; 192 } 193 194 if (sm.mnt_ns_id != mnt_ns_id) { 195 ksft_print_msg("listmount gave us the wrong ns id: 0x%llx != 0x%llx\n", 196 (unsigned long long)sm.mnt_ns_id, 197 (unsigned long long)mnt_ns_id); 198 return NSID_FAIL; 199 } 200 } 201 202 return NSID_PASS; 203 } 204 205 static void test_listmount_ns(void) 206 { 207 uint64_t nr_mounts; 208 char pval; 209 int child_ready_pipe[2]; 210 int parent_ready_pipe[2]; 211 pid_t pid; 212 int ret, child_ret; 213 214 if (pipe(child_ready_pipe) < 0) 215 ksft_exit_fail_msg("failed to create the child pipe: %s\n", 216 strerror(errno)); 217 if (pipe(parent_ready_pipe) < 0) 218 ksft_exit_fail_msg("failed to create the parent pipe: %s\n", 219 strerror(errno)); 220 221 pid = fork(); 222 if (pid < 0) 223 ksft_exit_fail_msg("failed to fork: %s\n", strerror(errno)); 224 225 if (pid == 0) { 226 char cval; 227 uint64_t list[256]; 228 229 close(child_ready_pipe[0]); 230 close(parent_ready_pipe[1]); 231 232 ret = setup_namespace(); 233 if (ret != NSID_PASS) 234 exit(ret); 235 236 nr_mounts = listmount(LSMT_ROOT, 0, 0, list, 256, 0); 237 if (nr_mounts == (uint64_t)-1) { 238 ksft_print_msg("listmount: %s\n", strerror(errno)); 239 exit(NSID_FAIL); 240 } 241 242 /* 243 * Tell our parent how many mounts we have, and then wait for it 244 * to tell us we're done. 245 */ 246 if (write(child_ready_pipe[1], &nr_mounts, sizeof(nr_mounts)) != 247 sizeof(nr_mounts)) 248 ret = NSID_ERROR; 249 if (read(parent_ready_pipe[0], &cval, sizeof(cval)) != sizeof(cval)) 250 ret = NSID_ERROR; 251 exit(NSID_PASS); 252 } 253 254 close(child_ready_pipe[1]); 255 close(parent_ready_pipe[0]); 256 257 /* Wait until the child has created everything. */ 258 if (read(child_ready_pipe[0], &nr_mounts, sizeof(nr_mounts)) != 259 sizeof(nr_mounts)) 260 ret = NSID_ERROR; 261 262 ret = validate_external_listmount(pid, nr_mounts); 263 264 if (write(parent_ready_pipe[1], &pval, sizeof(pval)) != sizeof(pval)) 265 ret = NSID_ERROR; 266 267 child_ret = wait_for_pid(pid); 268 if (child_ret != NSID_PASS) 269 ret = child_ret; 270 handle_result(ret, "test listmount ns id"); 271 } 272 273 int main(void) 274 { 275 int ret; 276 277 ksft_print_header(); 278 ret = statmount(0, 0, 0, NULL, 0, 0); 279 assert(ret == -1); 280 if (errno == ENOSYS) 281 ksft_exit_skip("statmount() syscall not supported\n"); 282 283 ksft_set_plan(2); 284 test_statmount_mnt_ns_id(); 285 test_listmount_ns(); 286 287 if (ksft_get_fail_cnt() + ksft_get_error_cnt() > 0) 288 ksft_exit_fail(); 289 else 290 ksft_exit_pass(); 291 } 292