xref: /linux/tools/testing/selftests/arm64/mte/mte_common_util.c (revision 964a07426eb8bfc3a6aed8d07eeeca77f44f6d91)
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (C) 2020 ARM Limited
3 
4 #include <fcntl.h>
5 #include <sched.h>
6 #include <signal.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9 #include <time.h>
10 #include <unistd.h>
11 
12 #include <linux/auxvec.h>
13 #include <sys/auxv.h>
14 #include <sys/mman.h>
15 #include <sys/prctl.h>
16 
17 #include <asm/hwcap.h>
18 
19 #include "kselftest.h"
20 #include "mte_common_util.h"
21 #include "mte_def.h"
22 
23 #ifndef SA_EXPOSE_TAGBITS
24 #define SA_EXPOSE_TAGBITS 0x00000800
25 #endif
26 
27 #define INIT_BUFFER_SIZE       256
28 
/* Description of the fault the running test expects; checked by mte_default_handler(). */
29 struct mte_fault_cxt cur_mte_cxt;
/* Set by mte_default_setup() from the HWCAP3_MTE_FAR bit of AT_HWCAP3. */
30 bool mtefar_support;
/* MTE mode recorded by mte_default_setup() and re-applied by mte_restore_setup(). */
31 static unsigned int mte_cur_mode;
/* PSTATE.TCO value recorded by mte_default_setup() and re-applied by mte_restore_setup(). */
32 static unsigned int mte_cur_pstate_tco;
33 
34 void mte_default_handler(int signum, siginfo_t *si, void *uc)
35 {
36 	struct sigaction sa;
37 	unsigned long addr = (unsigned long)si->si_addr;
38 	unsigned char si_tag, si_atag;
39 
40 	sigaction(signum, NULL, &sa);
41 
42 	if (sa.sa_flags & SA_EXPOSE_TAGBITS) {
43 		si_tag = MT_FETCH_TAG(addr);
44 		si_atag = MT_FETCH_ATAG(addr);
45 		addr = MT_CLEAR_TAGS(addr);
46 	} else {
47 		si_tag = 0;
48 		si_atag = 0;
49 	}
50 
51 	if (signum == SIGSEGV) {
52 #ifdef DEBUG
53 		ksft_print_msg("INFO: SIGSEGV signal at pc=%lx, fault addr=%lx, si_code=%lx, si_tag=%x, si_atag=%x\n",
54 				((ucontext_t *)uc)->uc_mcontext.pc, addr, si->si_code, si_tag, si_atag);
55 #endif
56 		if (si->si_code == SEGV_MTEAERR) {
57 			if (cur_mte_cxt.trig_si_code == si->si_code)
58 				cur_mte_cxt.fault_valid = true;
59 			else
60 				ksft_print_msg("Got unexpected SEGV_MTEAERR at pc=%llx, fault addr=%lx\n",
61 					       ((ucontext_t *)uc)->uc_mcontext.pc,
62 					       addr);
63 			return;
64 		}
65 		/* Compare the context for precise error */
66 		else if (si->si_code == SEGV_MTESERR) {
67 			if ((!mtefar_support && si_atag) || (si_atag != MT_FETCH_ATAG(cur_mte_cxt.trig_addr))) {
68 				ksft_print_msg("Invalid MTE synchronous exception caught for address tag! si_tag=%x, si_atag: %x\n", si_tag, si_atag);
69 				exit(KSFT_FAIL);
70 			}
71 
72 			if (cur_mte_cxt.trig_si_code == si->si_code &&
73 			    ((cur_mte_cxt.trig_range >= 0 &&
74 			      addr >= MT_CLEAR_TAGS(cur_mte_cxt.trig_addr) &&
75 			      addr <= (MT_CLEAR_TAGS(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)) ||
76 			     (cur_mte_cxt.trig_range < 0 &&
77 			      addr <= MT_CLEAR_TAGS(cur_mte_cxt.trig_addr) &&
78 			      addr >= (MT_CLEAR_TAGS(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)))) {
79 				cur_mte_cxt.fault_valid = true;
80 				/* Adjust the pc by 4 */
81 				((ucontext_t *)uc)->uc_mcontext.pc += 4;
82 			} else {
83 				ksft_print_msg("Invalid MTE synchronous exception caught!\n");
84 				exit(1);
85 			}
86 		} else {
87 			ksft_print_msg("Unknown SIGSEGV exception caught!\n");
88 			exit(1);
89 		}
90 	} else if (signum == SIGBUS) {
91 		ksft_print_msg("INFO: SIGBUS signal at pc=%llx, fault addr=%lx, si_code=%x\n",
92 				((ucontext_t *)uc)->uc_mcontext.pc, addr, si->si_code);
93 		if ((cur_mte_cxt.trig_range >= 0 &&
94 		     addr >= MT_CLEAR_TAGS(cur_mte_cxt.trig_addr) &&
95 		     addr <= (MT_CLEAR_TAGS(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)) ||
96 		    (cur_mte_cxt.trig_range < 0 &&
97 		     addr <= MT_CLEAR_TAGS(cur_mte_cxt.trig_addr) &&
98 		     addr >= (MT_CLEAR_TAGS(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range))) {
99 			cur_mte_cxt.fault_valid = true;
100 			/* Adjust the pc by 4 */
101 			((ucontext_t *)uc)->uc_mcontext.pc += 4;
102 		}
103 	}
104 }
105 
106 void mte_register_signal(int signal, void (*handler)(int, siginfo_t *, void *),
107 			 bool export_tags)
108 {
109 	struct sigaction sa;
110 
111 	sa.sa_sigaction = handler;
112 	sa.sa_flags = SA_SIGINFO;
113 
114 	if (export_tags && signal == SIGSEGV)
115 		sa.sa_flags |= SA_EXPOSE_TAGBITS;
116 
117 	sigemptyset(&sa.sa_mask);
118 	sigaction(signal, &sa, NULL);
119 }
120 
/*
 * Give up the CPU once via sched_yield().
 *
 * NOTE(review): judging by the name this is called after triggering a tag
 * fault, presumably so a pending asynchronous fault can be delivered -
 * confirm against the callers.
 */
void mte_wait_after_trig(void)
{
	(void)sched_yield();
}
125 
126 void *mte_insert_tags(void *ptr, size_t size)
127 {
128 	void *tag_ptr;
129 	int align_size;
130 
131 	if (!ptr || (unsigned long)(ptr) & MT_ALIGN_GRANULE) {
132 		ksft_print_msg("FAIL: Addr=%p: invalid\n", ptr);
133 		return NULL;
134 	}
135 	align_size = MT_ALIGN_UP(size);
136 	tag_ptr = mte_insert_random_tag(ptr);
137 	mte_set_tag_address_range(tag_ptr, align_size);
138 	return tag_ptr;
139 }
140 
141 void mte_clear_tags(void *ptr, size_t size)
142 {
143 	if (!ptr || (unsigned long)(ptr) & MT_ALIGN_GRANULE) {
144 		ksft_print_msg("FAIL: Addr=%p: invalid\n", ptr);
145 		return;
146 	}
147 	size = MT_ALIGN_UP(size);
148 	ptr = (void *)MT_CLEAR_TAG((unsigned long)ptr);
149 	mte_clear_tag_address_range(ptr, size);
150 }
151 
152 void *mte_insert_atag(void *ptr)
153 {
154 	unsigned char atag;
155 
156 	atag =  mtefar_support ? (random() % MT_ATAG_MASK) + 1 : 0;
157 	return (void *)MT_SET_ATAG((unsigned long)ptr, atag);
158 }
159 
/* Return @ptr after passing its value through MT_CLEAR_ATAG(). */
void *mte_clear_atag(void *ptr)
{
	unsigned long bits = (unsigned long)ptr;

	return (void *)MT_CLEAR_ATAG(bits);
}
164 
165 static void *__mte_allocate_memory_range(size_t size, int mem_type, int mapping,
166 					 size_t range_before, size_t range_after,
167 					 bool tags, int fd)
168 {
169 	void *ptr;
170 	int prot_flag, map_flag;
171 	size_t entire_size = size + range_before + range_after;
172 
173 	switch (mem_type) {
174 	case USE_MALLOC:
175 		return malloc(entire_size) + range_before;
176 	case USE_MMAP:
177 	case USE_MPROTECT:
178 		break;
179 	default:
180 		ksft_print_msg("FAIL: Invalid allocate request\n");
181 		return NULL;
182 	}
183 
184 	prot_flag = PROT_READ | PROT_WRITE;
185 	if (mem_type == USE_MMAP)
186 		prot_flag |= PROT_MTE;
187 
188 	map_flag = mapping;
189 	if (fd == -1)
190 		map_flag = MAP_ANONYMOUS | map_flag;
191 	if (!(mapping & MAP_SHARED))
192 		map_flag |= MAP_PRIVATE;
193 	ptr = mmap(NULL, entire_size, prot_flag, map_flag, fd, 0);
194 	if (ptr == MAP_FAILED) {
195 		ksft_perror("mmap()");
196 		return NULL;
197 	}
198 	if (mem_type == USE_MPROTECT) {
199 		if (mprotect(ptr, entire_size, prot_flag | PROT_MTE)) {
200 			ksft_perror("mprotect(PROT_MTE)");
201 			munmap(ptr, size);
202 			return NULL;
203 		}
204 	}
205 	if (tags)
206 		ptr = mte_insert_tags(ptr + range_before, size);
207 	return ptr;
208 }
209 
/*
 * Allocate tagged, non-file-backed memory of @size bytes surrounded by
 * @range_before and @range_after bytes of padding.
 *
 * Forwards to __mte_allocate_memory_range() with tagging requested and no
 * file descriptor; see that function for @mem_type and @mapping.
 */
void *mte_allocate_memory_tag_range(size_t size, int mem_type, int mapping,
				    size_t range_before, size_t range_after)
{
	const bool want_tags = true;
	const int no_fd = -1;

	return __mte_allocate_memory_range(size, mem_type, mapping,
					   range_before, range_after,
					   want_tags, no_fd);
}
216 
/*
 * Allocate @size bytes of non-file-backed memory without padding.
 *
 * Forwards to __mte_allocate_memory_range() with both padding sizes set to
 * zero and no file descriptor; @tags selects whether the memory is tagged.
 */
void *mte_allocate_memory(size_t size, int mem_type, int mapping, bool tags)
{
	const size_t no_padding = 0;
	const int no_fd = -1;

	return __mte_allocate_memory_range(size, mem_type, mapping,
					   no_padding, no_padding, tags, no_fd);
}
221 
222 void *mte_allocate_file_memory(size_t size, int mem_type, int mapping, bool tags, int fd)
223 {
224 	int index;
225 	char buffer[INIT_BUFFER_SIZE];
226 
227 	if (mem_type != USE_MPROTECT && mem_type != USE_MMAP) {
228 		ksft_print_msg("FAIL: Invalid mmap file request\n");
229 		return NULL;
230 	}
231 	/* Initialize the file for mappable size */
232 	lseek(fd, 0, SEEK_SET);
233 	for (index = INIT_BUFFER_SIZE; index < size; index += INIT_BUFFER_SIZE) {
234 		if (write(fd, buffer, INIT_BUFFER_SIZE) != INIT_BUFFER_SIZE) {
235 			ksft_perror("initialising buffer");
236 			return NULL;
237 		}
238 	}
239 	index -= INIT_BUFFER_SIZE;
240 	if (write(fd, buffer, size - index) != size - index) {
241 		ksft_perror("initialising buffer");
242 		return NULL;
243 	}
244 	return __mte_allocate_memory_range(size, mem_type, mapping, 0, 0, tags, fd);
245 }
246 
247 void *mte_allocate_file_memory_tag_range(size_t size, int mem_type, int mapping,
248 					 size_t range_before, size_t range_after, int fd)
249 {
250 	int index;
251 	char buffer[INIT_BUFFER_SIZE];
252 	int map_size = size + range_before + range_after;
253 
254 	if (mem_type != USE_MPROTECT && mem_type != USE_MMAP) {
255 		ksft_print_msg("FAIL: Invalid mmap file request\n");
256 		return NULL;
257 	}
258 	/* Initialize the file for mappable size */
259 	lseek(fd, 0, SEEK_SET);
260 	for (index = INIT_BUFFER_SIZE; index < map_size; index += INIT_BUFFER_SIZE)
261 		if (write(fd, buffer, INIT_BUFFER_SIZE) != INIT_BUFFER_SIZE) {
262 			ksft_perror("initialising buffer");
263 			return NULL;
264 		}
265 	index -= INIT_BUFFER_SIZE;
266 	if (write(fd, buffer, map_size - index) != map_size - index) {
267 		ksft_perror("initialising buffer");
268 		return NULL;
269 	}
270 	return __mte_allocate_memory_range(size, mem_type, mapping, range_before,
271 					   range_after, true, fd);
272 }
273 
274 static void __mte_free_memory_range(void *ptr, size_t size, int mem_type,
275 				    size_t range_before, size_t range_after, bool tags)
276 {
277 	switch (mem_type) {
278 	case USE_MALLOC:
279 		free(ptr - range_before);
280 		break;
281 	case USE_MMAP:
282 	case USE_MPROTECT:
283 		if (tags)
284 			mte_clear_tags(ptr, size);
285 		munmap(ptr - range_before, size + range_before + range_after);
286 		break;
287 	default:
288 		ksft_print_msg("FAIL: Invalid free request\n");
289 		break;
290 	}
291 }
292 
/*
 * Free a padded region; tag clearing is always requested.
 *
 * Forwards all arguments to __mte_free_memory_range() with tags set to true.
 */
void mte_free_memory_tag_range(void *ptr, size_t size, int mem_type,
			       size_t range_before, size_t range_after)
{
	const bool clear_tags = true;

	__mte_free_memory_range(ptr, size, mem_type, range_before, range_after,
				clear_tags);
}
298 
/*
 * Free an unpadded region.
 *
 * Forwards to __mte_free_memory_range() with both padding sizes set to zero;
 * @tags selects whether tags are cleared first.
 */
void mte_free_memory(void *ptr, size_t size, int mem_type, bool tags)
{
	const size_t no_padding = 0;

	__mte_free_memory_range(ptr, size, mem_type, no_padding, no_padding, tags);
}
303 
304 void mte_initialize_current_context(int mode, uintptr_t ptr, ssize_t range)
305 {
306 	cur_mte_cxt.fault_valid = false;
307 	cur_mte_cxt.trig_addr = ptr;
308 	cur_mte_cxt.trig_range = range;
309 	if (mode == MTE_SYNC_ERR)
310 		cur_mte_cxt.trig_si_code = SEGV_MTESERR;
311 	else if (mode == MTE_ASYNC_ERR)
312 		cur_mte_cxt.trig_si_code = SEGV_MTEAERR;
313 	else
314 		cur_mte_cxt.trig_si_code = 0;
315 }
316 
317 int mte_switch_mode(int mte_option, unsigned long incl_mask)
318 {
319 	unsigned long en = 0;
320 
321 	switch (mte_option) {
322 	case MTE_NONE_ERR:
323 	case MTE_SYNC_ERR:
324 	case MTE_ASYNC_ERR:
325 		break;
326 	default:
327 		ksft_print_msg("FAIL: Invalid MTE option %x\n", mte_option);
328 		return -EINVAL;
329 	}
330 
331 	if (incl_mask & ~MT_INCLUDE_TAG_MASK) {
332 		ksft_print_msg("FAIL: Invalid incl_mask %lx\n", incl_mask);
333 		return -EINVAL;
334 	}
335 
336 	en = PR_TAGGED_ADDR_ENABLE;
337 	switch (mte_option) {
338 	case MTE_SYNC_ERR:
339 		en |= PR_MTE_TCF_SYNC;
340 		break;
341 	case MTE_ASYNC_ERR:
342 		en |= PR_MTE_TCF_ASYNC;
343 		break;
344 	case MTE_NONE_ERR:
345 		en |= PR_MTE_TCF_NONE;
346 		break;
347 	}
348 
349 	en |= (incl_mask << PR_MTE_TAG_SHIFT);
350 	/* Enable address tagging ABI, mte error reporting mode and tag inclusion mask. */
351 	if (prctl(PR_SET_TAGGED_ADDR_CTRL, en, 0, 0, 0) != 0) {
352 		ksft_print_msg("FAIL:prctl PR_SET_TAGGED_ADDR_CTRL for mte mode\n");
353 		return -EINVAL;
354 	}
355 	return 0;
356 }
357 
358 int mte_default_setup(void)
359 {
360 	unsigned long hwcaps2 = getauxval(AT_HWCAP2);
361 	unsigned long hwcaps3 = getauxval(AT_HWCAP3);
362 	unsigned long en = 0;
363 	int ret;
364 
365 	/* To generate random address tag */
366 	srandom(time(NULL));
367 
368 	if (!(hwcaps2 & HWCAP2_MTE))
369 		ksft_exit_skip("MTE features unavailable\n");
370 
371 	mtefar_support = !!(hwcaps3 & HWCAP3_MTE_FAR);
372 
373 	/* Get current mte mode */
374 	ret = prctl(PR_GET_TAGGED_ADDR_CTRL, en, 0, 0, 0);
375 	if (ret < 0) {
376 		ksft_print_msg("FAIL:prctl PR_GET_TAGGED_ADDR_CTRL with error =%d\n", ret);
377 		return KSFT_FAIL;
378 	}
379 	if (ret & PR_MTE_TCF_SYNC)
380 		mte_cur_mode = MTE_SYNC_ERR;
381 	else if (ret & PR_MTE_TCF_ASYNC)
382 		mte_cur_mode = MTE_ASYNC_ERR;
383 	else if (ret & PR_MTE_TCF_NONE)
384 		mte_cur_mode = MTE_NONE_ERR;
385 
386 	mte_cur_pstate_tco = mte_get_pstate_tco();
387 	/* Disable PSTATE.TCO */
388 	mte_disable_pstate_tco();
389 	return 0;
390 }
391 
392 void mte_restore_setup(void)
393 {
394 	mte_switch_mode(mte_cur_mode, MTE_ALLOW_NON_ZERO_TAG);
395 	if (mte_cur_pstate_tco == MT_PSTATE_TCO_EN)
396 		mte_enable_pstate_tco();
397 	else if (mte_cur_pstate_tco == MT_PSTATE_TCO_DIS)
398 		mte_disable_pstate_tco();
399 }
400 
/*
 * Create an anonymous temporary file under /dev/shm.
 *
 * The file is created with mkstemp() and unlinked straight away, so it
 * disappears once the returned descriptor is closed.
 *
 * Returns the open file descriptor, or -1 (after logging) on failure.
 * 0 is deliberately not used as the failure value: it is a valid
 * descriptor number.
 */
int create_temp_file(void)
{
	int fd;
	char filename[] = "/dev/shm/tmp_XXXXXX";

	/* Create a file in the tmpfs filesystem */
	fd = mkstemp(&filename[0]);
	if (fd == -1) {
		ksft_perror(filename);
		ksft_print_msg("FAIL: Unable to open temporary file\n");
		return -1;
	}
	unlink(&filename[0]);
	return fd;
}
416