1 // SPDX-License-Identifier: GPL-2.0-only 2 3 /* 4 * Copyright (c) 2025, Google LLC. 5 * Pasha Tatashin <pasha.tatashin@soleen.com> 6 */ 7 8 #define _GNU_SOURCE 9 10 #include <stdio.h> 11 #include <stdlib.h> 12 #include <string.h> 13 #include <getopt.h> 14 #include <fcntl.h> 15 #include <unistd.h> 16 #include <sys/ioctl.h> 17 #include <sys/syscall.h> 18 #include <sys/mman.h> 19 #include <sys/types.h> 20 #include <sys/stat.h> 21 #include <errno.h> 22 #include <stdarg.h> 23 24 #include "luo_test_utils.h" 25 26 int luo_open_device(void) 27 { 28 return open(LUO_DEVICE, O_RDWR); 29 } 30 31 int luo_create_session(int luo_fd, const char *name) 32 { 33 struct liveupdate_ioctl_create_session arg = { .size = sizeof(arg) }; 34 35 snprintf((char *)arg.name, LIVEUPDATE_SESSION_NAME_LENGTH, "%.*s", 36 LIVEUPDATE_SESSION_NAME_LENGTH - 1, name); 37 38 if (ioctl(luo_fd, LIVEUPDATE_IOCTL_CREATE_SESSION, &arg) < 0) 39 return -errno; 40 41 return arg.fd; 42 } 43 44 int luo_retrieve_session(int luo_fd, const char *name) 45 { 46 struct liveupdate_ioctl_retrieve_session arg = { .size = sizeof(arg) }; 47 48 snprintf((char *)arg.name, LIVEUPDATE_SESSION_NAME_LENGTH, "%.*s", 49 LIVEUPDATE_SESSION_NAME_LENGTH - 1, name); 50 51 if (ioctl(luo_fd, LIVEUPDATE_IOCTL_RETRIEVE_SESSION, &arg) < 0) 52 return -errno; 53 54 return arg.fd; 55 } 56 57 int create_and_preserve_memfd(int session_fd, int token, const char *data) 58 { 59 struct liveupdate_session_preserve_fd arg = { .size = sizeof(arg) }; 60 long page_size = sysconf(_SC_PAGE_SIZE); 61 void *map = MAP_FAILED; 62 int mfd = -1, ret = -1; 63 64 mfd = memfd_create("test_mfd", 0); 65 if (mfd < 0) 66 return -errno; 67 68 if (ftruncate(mfd, page_size) != 0) 69 goto out; 70 71 map = mmap(NULL, page_size, PROT_WRITE, MAP_SHARED, mfd, 0); 72 if (map == MAP_FAILED) 73 goto out; 74 75 snprintf(map, page_size, "%s", data); 76 munmap(map, page_size); 77 78 arg.fd = mfd; 79 arg.token = token; 80 if (ioctl(session_fd, LIVEUPDATE_SESSION_PRESERVE_FD, &arg) < 0) 81 goto out; 82 83 ret = 0; 84 out: 85 if (ret != 0 && errno != 0) 86 ret = -errno; 87 if (mfd >= 0) 88 close(mfd); 89 return ret; 90 } 91 92 int restore_and_verify_memfd(int session_fd, int token, 93 const char *expected_data) 94 { 95 struct liveupdate_session_retrieve_fd arg = { .size = sizeof(arg) }; 96 long page_size = sysconf(_SC_PAGE_SIZE); 97 void *map = MAP_FAILED; 98 int mfd = -1, ret = -1; 99 100 arg.token = token; 101 if (ioctl(session_fd, LIVEUPDATE_SESSION_RETRIEVE_FD, &arg) < 0) 102 return -errno; 103 mfd = arg.fd; 104 105 map = mmap(NULL, page_size, PROT_READ, MAP_SHARED, mfd, 0); 106 if (map == MAP_FAILED) 107 goto out; 108 109 if (expected_data && strcmp(expected_data, map) != 0) { 110 ksft_print_msg("Data mismatch! Expected '%s', Got '%s'\n", 111 expected_data, (char *)map); 112 ret = -EINVAL; 113 goto out_munmap; 114 } 115 116 ret = mfd; 117 out_munmap: 118 munmap(map, page_size); 119 out: 120 if (ret < 0 && errno != 0) 121 ret = -errno; 122 if (ret < 0 && mfd >= 0) 123 close(mfd); 124 return ret; 125 } 126 127 int luo_session_finish(int session_fd) 128 { 129 struct liveupdate_session_finish arg = { .size = sizeof(arg) }; 130 131 if (ioctl(session_fd, LIVEUPDATE_SESSION_FINISH, &arg) < 0) 132 return -errno; 133 134 return 0; 135 } 136 137 void create_state_file(int luo_fd, const char *session_name, int token, 138 int next_stage) 139 { 140 char buf[32]; 141 int state_session_fd; 142 143 state_session_fd = luo_create_session(luo_fd, session_name); 144 if (state_session_fd < 0) 145 fail_exit("luo_create_session for state tracking"); 146 147 snprintf(buf, sizeof(buf), "%d", next_stage); 148 if (create_and_preserve_memfd(state_session_fd, token, buf) < 0) 149 fail_exit("create_and_preserve_memfd for state tracking"); 150 151 /* 152 * DO NOT close session FD, otherwise it is going to be unpreserved 153 */ 154 } 155 156 void restore_and_read_stage(int state_session_fd, int token, int *stage) 157 { 158 char buf[32] = {0}; 159 int mfd; 160 161 mfd = restore_and_verify_memfd(state_session_fd, token, NULL); 162 if (mfd < 0) 163 fail_exit("failed to restore state memfd"); 164 165 if (read(mfd, buf, sizeof(buf) - 1) < 0) 166 fail_exit("failed to read state mfd"); 167 168 *stage = atoi(buf); 169 170 close(mfd); 171 } 172 173 void daemonize_and_wait(void) 174 { 175 pid_t pid; 176 177 ksft_print_msg("[STAGE 1] Forking persistent child to hold sessions...\n"); 178 179 pid = fork(); 180 if (pid < 0) 181 fail_exit("fork failed"); 182 183 if (pid > 0) { 184 ksft_print_msg("[STAGE 1] Child PID: %d. Resources are pinned.\n", pid); 185 ksft_print_msg("[STAGE 1] You may now perform kexec reboot.\n"); 186 exit(EXIT_SUCCESS); 187 } 188 189 /* Detach from terminal so closing the window doesn't kill us */ 190 if (setsid() < 0) 191 fail_exit("setsid failed"); 192 193 close(STDIN_FILENO); 194 close(STDOUT_FILENO); 195 close(STDERR_FILENO); 196 197 /* Change dir to root to avoid locking filesystems */ 198 if (chdir("/") < 0) 199 exit(EXIT_FAILURE); 200 201 while (1) 202 sleep(60); 203 } 204 205 static int parse_stage_args(int argc, char *argv[]) 206 { 207 static struct option long_options[] = { 208 {"stage", required_argument, 0, 's'}, 209 {0, 0, 0, 0} 210 }; 211 int option_index = 0; 212 int stage = 1; 213 int opt; 214 215 optind = 1; 216 while ((opt = getopt_long(argc, argv, "s:", long_options, &option_index)) != -1) { 217 switch (opt) { 218 case 's': 219 stage = atoi(optarg); 220 if (stage != 1 && stage != 2) 221 fail_exit("Invalid stage argument"); 222 break; 223 default: 224 fail_exit("Unknown argument"); 225 } 226 } 227 return stage; 228 } 229 230 int luo_test(int argc, char *argv[], 231 const char *state_session_name, 232 luo_test_stage1_fn stage1, 233 luo_test_stage2_fn stage2) 234 { 235 int target_stage = parse_stage_args(argc, argv); 236 int luo_fd = luo_open_device(); 237 int state_session_fd; 238 int detected_stage; 239 240 if (luo_fd < 0) { 241 ksft_exit_skip("Failed to open %s. Is the luo module loaded?\n", 242 LUO_DEVICE); 243 } 244 245 state_session_fd = luo_retrieve_session(luo_fd, state_session_name); 246 if (state_session_fd == -ENOENT) 247 detected_stage = 1; 248 else if (state_session_fd >= 0) 249 detected_stage = 2; 250 else 251 fail_exit("Failed to check for state session"); 252 253 if (target_stage != detected_stage) { 254 ksft_exit_fail_msg("Stage mismatch Requested --stage %d, but system is in stage %d.\n" 255 "(State session %s: %s)\n", 256 target_stage, detected_stage, state_session_name, 257 (detected_stage == 2) ? "EXISTS" : "MISSING"); 258 } 259 260 if (target_stage == 1) 261 stage1(luo_fd); 262 else 263 stage2(luo_fd, state_session_fd); 264 265 return 0; 266 } 267