xref: /linux/tools/testing/selftests/arm64/mte/mte_common_util.c (revision c8bfe3fad4f86a029da7157bae9699c816f0c309)
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (C) 2020 ARM Limited
3 
4 #include <fcntl.h>
5 #include <sched.h>
6 #include <signal.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9 #include <unistd.h>
10 
11 #include <linux/auxvec.h>
12 #include <sys/auxv.h>
13 #include <sys/mman.h>
14 #include <sys/prctl.h>
15 
16 #include <asm/hwcap.h>
17 
18 #include "kselftest.h"
19 #include "mte_common_util.h"
20 #include "mte_def.h"
21 
22 #define INIT_BUFFER_SIZE       256
23 
24 struct mte_fault_cxt cur_mte_cxt;
25 static unsigned int mte_cur_mode;
26 static unsigned int mte_cur_pstate_tco;
27 
28 void mte_default_handler(int signum, siginfo_t *si, void *uc)
29 {
30 	unsigned long addr = (unsigned long)si->si_addr;
31 
32 	if (signum == SIGSEGV) {
33 #ifdef DEBUG
34 		ksft_print_msg("INFO: SIGSEGV signal at pc=%lx, fault addr=%lx, si_code=%lx\n",
35 				((ucontext_t *)uc)->uc_mcontext.pc, addr, si->si_code);
36 #endif
37 		if (si->si_code == SEGV_MTEAERR) {
38 			if (cur_mte_cxt.trig_si_code == si->si_code)
39 				cur_mte_cxt.fault_valid = true;
40 			else
41 				ksft_print_msg("Got unexpected SEGV_MTEAERR at pc=$lx, fault addr=%lx\n",
42 					       ((ucontext_t *)uc)->uc_mcontext.pc,
43 					       addr);
44 			return;
45 		}
46 		/* Compare the context for precise error */
47 		else if (si->si_code == SEGV_MTESERR) {
48 			if (cur_mte_cxt.trig_si_code == si->si_code &&
49 			    ((cur_mte_cxt.trig_range >= 0 &&
50 			      addr >= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
51 			      addr <= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)) ||
52 			     (cur_mte_cxt.trig_range < 0 &&
53 			      addr <= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
54 			      addr >= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)))) {
55 				cur_mte_cxt.fault_valid = true;
56 				/* Adjust the pc by 4 */
57 				((ucontext_t *)uc)->uc_mcontext.pc += 4;
58 			} else {
59 				ksft_print_msg("Invalid MTE synchronous exception caught!\n");
60 				exit(1);
61 			}
62 		} else {
63 			ksft_print_msg("Unknown SIGSEGV exception caught!\n");
64 			exit(1);
65 		}
66 	} else if (signum == SIGBUS) {
67 		ksft_print_msg("INFO: SIGBUS signal at pc=%lx, fault addr=%lx, si_code=%lx\n",
68 				((ucontext_t *)uc)->uc_mcontext.pc, addr, si->si_code);
69 		if ((cur_mte_cxt.trig_range >= 0 &&
70 		     addr >= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
71 		     addr <= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)) ||
72 		    (cur_mte_cxt.trig_range < 0 &&
73 		     addr <= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
74 		     addr >= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range))) {
75 			cur_mte_cxt.fault_valid = true;
76 			/* Adjust the pc by 4 */
77 			((ucontext_t *)uc)->uc_mcontext.pc += 4;
78 		}
79 	}
80 }
81 
82 void mte_register_signal(int signal, void (*handler)(int, siginfo_t *, void *))
83 {
84 	struct sigaction sa;
85 
86 	sa.sa_sigaction = handler;
87 	sa.sa_flags = SA_SIGINFO;
88 	sigemptyset(&sa.sa_mask);
89 	sigaction(signal, &sa, NULL);
90 }
91 
/*
 * Yield the CPU once. Presumably called after triggering a fault so that a
 * pending (e.g. asynchronous) fault signal gets a chance to be delivered.
 * NOTE(review): a single sched_yield() gives no delivery guarantee.
 */
void mte_wait_after_trig(void)
{
	(void)sched_yield();
}
96 
97 void *mte_insert_tags(void *ptr, size_t size)
98 {
99 	void *tag_ptr;
100 	int align_size;
101 
102 	if (!ptr || (unsigned long)(ptr) & MT_ALIGN_GRANULE) {
103 		ksft_print_msg("FAIL: Addr=%lx: invalid\n", ptr);
104 		return NULL;
105 	}
106 	align_size = MT_ALIGN_UP(size);
107 	tag_ptr = mte_insert_random_tag(ptr);
108 	mte_set_tag_address_range(tag_ptr, align_size);
109 	return tag_ptr;
110 }
111 
112 void mte_clear_tags(void *ptr, size_t size)
113 {
114 	if (!ptr || (unsigned long)(ptr) & MT_ALIGN_GRANULE) {
115 		ksft_print_msg("FAIL: Addr=%lx: invalid\n", ptr);
116 		return;
117 	}
118 	size = MT_ALIGN_UP(size);
119 	ptr = (void *)MT_CLEAR_TAG((unsigned long)ptr);
120 	mte_clear_tag_address_range(ptr, size);
121 }
122 
123 static void *__mte_allocate_memory_range(size_t size, int mem_type, int mapping,
124 					 size_t range_before, size_t range_after,
125 					 bool tags, int fd)
126 {
127 	void *ptr;
128 	int prot_flag, map_flag;
129 	size_t entire_size = size + range_before + range_after;
130 
131 	switch (mem_type) {
132 	case USE_MALLOC:
133 		return malloc(entire_size) + range_before;
134 	case USE_MMAP:
135 	case USE_MPROTECT:
136 		break;
137 	default:
138 		ksft_print_msg("FAIL: Invalid allocate request\n");
139 		return NULL;
140 	}
141 
142 	prot_flag = PROT_READ | PROT_WRITE;
143 	if (mem_type == USE_MMAP)
144 		prot_flag |= PROT_MTE;
145 
146 	map_flag = mapping;
147 	if (fd == -1)
148 		map_flag = MAP_ANONYMOUS | map_flag;
149 	if (!(mapping & MAP_SHARED))
150 		map_flag |= MAP_PRIVATE;
151 	ptr = mmap(NULL, entire_size, prot_flag, map_flag, fd, 0);
152 	if (ptr == MAP_FAILED) {
153 		ksft_print_msg("FAIL: mmap allocation\n");
154 		return NULL;
155 	}
156 	if (mem_type == USE_MPROTECT) {
157 		if (mprotect(ptr, entire_size, prot_flag | PROT_MTE)) {
158 			munmap(ptr, size);
159 			ksft_print_msg("FAIL: mprotect PROT_MTE property\n");
160 			return NULL;
161 		}
162 	}
163 	if (tags)
164 		ptr = mte_insert_tags(ptr + range_before, size);
165 	return ptr;
166 }
167 
/*
 * Allocate a tagged @size byte buffer of @mem_type with @range_before and
 * @range_after extra bytes around it. Only the @size byte buffer gets the
 * random tag. No backing file is used.
 */
void *mte_allocate_memory_tag_range(size_t size, int mem_type, int mapping,
				    size_t range_before, size_t range_after)
{
	const bool want_tags = true;
	const int anon_fd = -1;	/* -1 selects an anonymous mapping */

	return __mte_allocate_memory_range(size, mem_type, mapping,
					   range_before, range_after,
					   want_tags, anon_fd);
}
174 
/*
 * Allocate @size bytes of @mem_type memory with no surrounding ranges and no
 * backing file. When @tags is set the buffer is given a random tag.
 */
void *mte_allocate_memory(size_t size, int mem_type, int mapping, bool tags)
{
	const size_t no_extra = 0;
	const int anon_fd = -1;	/* -1 selects an anonymous mapping */

	return __mte_allocate_memory_range(size, mem_type, mapping,
					   no_extra, no_extra, tags, anon_fd);
}
179 
180 void *mte_allocate_file_memory(size_t size, int mem_type, int mapping, bool tags, int fd)
181 {
182 	int index;
183 	char buffer[INIT_BUFFER_SIZE];
184 
185 	if (mem_type != USE_MPROTECT && mem_type != USE_MMAP) {
186 		ksft_print_msg("FAIL: Invalid mmap file request\n");
187 		return NULL;
188 	}
189 	/* Initialize the file for mappable size */
190 	lseek(fd, 0, SEEK_SET);
191 	for (index = INIT_BUFFER_SIZE; index < size; index += INIT_BUFFER_SIZE) {
192 		if (write(fd, buffer, INIT_BUFFER_SIZE) != INIT_BUFFER_SIZE) {
193 			perror("initialising buffer");
194 			return NULL;
195 		}
196 	}
197 	index -= INIT_BUFFER_SIZE;
198 	if (write(fd, buffer, size - index) != size - index) {
199 		perror("initialising buffer");
200 		return NULL;
201 	}
202 	return __mte_allocate_memory_range(size, mem_type, mapping, 0, 0, tags, fd);
203 }
204 
205 void *mte_allocate_file_memory_tag_range(size_t size, int mem_type, int mapping,
206 					 size_t range_before, size_t range_after, int fd)
207 {
208 	int index;
209 	char buffer[INIT_BUFFER_SIZE];
210 	int map_size = size + range_before + range_after;
211 
212 	if (mem_type != USE_MPROTECT && mem_type != USE_MMAP) {
213 		ksft_print_msg("FAIL: Invalid mmap file request\n");
214 		return NULL;
215 	}
216 	/* Initialize the file for mappable size */
217 	lseek(fd, 0, SEEK_SET);
218 	for (index = INIT_BUFFER_SIZE; index < map_size; index += INIT_BUFFER_SIZE)
219 		if (write(fd, buffer, INIT_BUFFER_SIZE) != INIT_BUFFER_SIZE) {
220 			perror("initialising buffer");
221 			return NULL;
222 		}
223 	index -= INIT_BUFFER_SIZE;
224 	if (write(fd, buffer, map_size - index) != map_size - index) {
225 		perror("initialising buffer");
226 		return NULL;
227 	}
228 	return __mte_allocate_memory_range(size, mem_type, mapping, range_before,
229 					   range_after, true, fd);
230 }
231 
232 static void __mte_free_memory_range(void *ptr, size_t size, int mem_type,
233 				    size_t range_before, size_t range_after, bool tags)
234 {
235 	switch (mem_type) {
236 	case USE_MALLOC:
237 		free(ptr - range_before);
238 		break;
239 	case USE_MMAP:
240 	case USE_MPROTECT:
241 		if (tags)
242 			mte_clear_tags(ptr, size);
243 		munmap(ptr - range_before, size + range_before + range_after);
244 		break;
245 	default:
246 		ksft_print_msg("FAIL: Invalid free request\n");
247 		break;
248 	}
249 }
250 
/*
 * Free a buffer obtained from one of the *_tag_range() allocators. Those
 * always tag the buffer, so its tags are always cleared here.
 */
void mte_free_memory_tag_range(void *ptr, size_t size, int mem_type,
			       size_t range_before, size_t range_after)
{
	const bool clear_tags = true;

	__mte_free_memory_range(ptr, size, mem_type, range_before, range_after,
				clear_tags);
}
256 
/*
 * Free a buffer allocated without surrounding ranges. @tags must match the
 * value used at allocation time so that tags are cleared when needed.
 */
void mte_free_memory(void *ptr, size_t size, int mem_type, bool tags)
{
	const size_t no_extra = 0;

	__mte_free_memory_range(ptr, size, mem_type, no_extra, no_extra, tags);
}
261 
262 void mte_initialize_current_context(int mode, uintptr_t ptr, ssize_t range)
263 {
264 	cur_mte_cxt.fault_valid = false;
265 	cur_mte_cxt.trig_addr = ptr;
266 	cur_mte_cxt.trig_range = range;
267 	if (mode == MTE_SYNC_ERR)
268 		cur_mte_cxt.trig_si_code = SEGV_MTESERR;
269 	else if (mode == MTE_ASYNC_ERR)
270 		cur_mte_cxt.trig_si_code = SEGV_MTEAERR;
271 	else
272 		cur_mte_cxt.trig_si_code = 0;
273 }
274 
275 int mte_switch_mode(int mte_option, unsigned long incl_mask)
276 {
277 	unsigned long en = 0;
278 
279 	switch (mte_option) {
280 	case MTE_NONE_ERR:
281 	case MTE_SYNC_ERR:
282 	case MTE_ASYNC_ERR:
283 		break;
284 	default:
285 		ksft_print_msg("FAIL: Invalid MTE option %x\n", mte_option);
286 		return -EINVAL;
287 	}
288 
289 	if (incl_mask & ~MT_INCLUDE_TAG_MASK) {
290 		ksft_print_msg("FAIL: Invalid incl_mask %lx\n", incl_mask);
291 		return -EINVAL;
292 	}
293 
294 	en = PR_TAGGED_ADDR_ENABLE;
295 	switch (mte_option) {
296 	case MTE_SYNC_ERR:
297 		en |= PR_MTE_TCF_SYNC;
298 		break;
299 	case MTE_ASYNC_ERR:
300 		en |= PR_MTE_TCF_ASYNC;
301 		break;
302 	case MTE_NONE_ERR:
303 		en |= PR_MTE_TCF_NONE;
304 		break;
305 	}
306 
307 	en |= (incl_mask << PR_MTE_TAG_SHIFT);
308 	/* Enable address tagging ABI, mte error reporting mode and tag inclusion mask. */
309 	if (prctl(PR_SET_TAGGED_ADDR_CTRL, en, 0, 0, 0) != 0) {
310 		ksft_print_msg("FAIL:prctl PR_SET_TAGGED_ADDR_CTRL for mte mode\n");
311 		return -EINVAL;
312 	}
313 	return 0;
314 }
315 
316 int mte_default_setup(void)
317 {
318 	unsigned long hwcaps2 = getauxval(AT_HWCAP2);
319 	unsigned long en = 0;
320 	int ret;
321 
322 	if (!(hwcaps2 & HWCAP2_MTE)) {
323 		ksft_print_msg("SKIP: MTE features unavailable\n");
324 		return KSFT_SKIP;
325 	}
326 	/* Get current mte mode */
327 	ret = prctl(PR_GET_TAGGED_ADDR_CTRL, en, 0, 0, 0);
328 	if (ret < 0) {
329 		ksft_print_msg("FAIL:prctl PR_GET_TAGGED_ADDR_CTRL with error =%d\n", ret);
330 		return KSFT_FAIL;
331 	}
332 	if (ret & PR_MTE_TCF_SYNC)
333 		mte_cur_mode = MTE_SYNC_ERR;
334 	else if (ret & PR_MTE_TCF_ASYNC)
335 		mte_cur_mode = MTE_ASYNC_ERR;
336 	else if (ret & PR_MTE_TCF_NONE)
337 		mte_cur_mode = MTE_NONE_ERR;
338 
339 	mte_cur_pstate_tco = mte_get_pstate_tco();
340 	/* Disable PSTATE.TCO */
341 	mte_disable_pstate_tco();
342 	return 0;
343 }
344 
345 void mte_restore_setup(void)
346 {
347 	mte_switch_mode(mte_cur_mode, MTE_ALLOW_NON_ZERO_TAG);
348 	if (mte_cur_pstate_tco == MT_PSTATE_TCO_EN)
349 		mte_enable_pstate_tco();
350 	else if (mte_cur_pstate_tco == MT_PSTATE_TCO_DIS)
351 		mte_disable_pstate_tco();
352 }
353 
/*
 * Create a temporary file under /dev/shm (tmpfs) for file-backed mappings.
 *
 * The file is unlinked immediately, so it has no name and disappears once
 * the descriptor is closed and any mappings of it are gone.
 *
 * Returns the open file descriptor, or -1 on failure. 0 is a valid
 * descriptor (stdin), so it must not double as the error value.
 */
int create_temp_file(void)
{
	char filename[] = "/dev/shm/tmp_XXXXXX";
	int fd;

	/* Create a file in the tmpfs filesystem */
	fd = mkstemp(filename);
	if (fd == -1) {
		perror(filename);
		ksft_print_msg("FAIL: Unable to open temporary file\n");
		return -1;
	}
	unlink(filename);
	return fd;
}
369