xref: /freebsd/contrib/ofed/libibverbs/memory.c (revision e6bfd18d21b225af6a0ed67ceeaf1293b7b9eba5)
1 /*
2  * Copyright (c) 2004, 2005 Topspin Communications.  All rights reserved.
3  * Copyright (c) 2006 Cisco Systems, Inc.  All rights reserved.
4  *
5  * This software is available to you under a choice of one of two
6  * licenses.  You may choose to be licensed under the terms of the GNU
7  * General Public License (GPL) Version 2, available from the file
8  * COPYING in the main directory of this source tree, or the
9  * OpenIB.org BSD license below:
10  *
11  *     Redistribution and use in source and binary forms, with or
12  *     without modification, are permitted provided that the following
13  *     conditions are met:
14  *
15  *      - Redistributions of source code must retain the above
16  *        copyright notice, this list of conditions and the following
17  *        disclaimer.
18  *
19  *      - Redistributions in binary form must reproduce the above
20  *        copyright notice, this list of conditions and the following
21  *        disclaimer in the documentation and/or other materials
22  *        provided with the distribution.
23  *
24  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
25  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
26  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
27  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
28  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
29  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
30  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31  * SOFTWARE.
32  */
33 
34 #include <config.h>
35 
36 #include <errno.h>
37 #include <sys/mman.h>
38 #include <unistd.h>
39 #include <stdlib.h>
40 #include <stdint.h>
41 #include <stdio.h>
42 #include <string.h>
43 #include <dirent.h>
44 #include <limits.h>
45 #include <inttypes.h>
46 
47 #include "ibverbs.h"
48 
/*
 * One node of the red-black tree that tracks address ranges.
 *
 * Nodes are ordered by their start address (see __mm_add()).  Each node
 * describes the inclusive range [start, end] together with a reference
 * count that ibv_madvise_range() raises for MADV_DONTFORK requests and
 * lowers for MADV_DOFORK requests.
 */
struct ibv_mem_node {
	/* Red-black colour used by the rebalancing code. */
	enum {
		IBV_RED,
		IBV_BLACK
	}			color;
	/* Tree links; parent is NULL for the root. */
	struct ibv_mem_node    *parent;
	struct ibv_mem_node    *left;
	struct ibv_mem_node    *right;
	/* First and last address (both inclusive) covered by this node. */
	uintptr_t		start;
	uintptr_t		end;
	/* Reference count for the whole [start, end] range. */
	int			refcnt;
};
59 
/* Root of the range tree; stays NULL until ibv_fork_init() succeeds. */
static struct ibv_mem_node *mm_root = NULL;
/* Held by ibv_madvise_range() while it walks and modifies the tree. */
static pthread_mutex_t mm_mutex = PTHREAD_MUTEX_INITIALIZER;
/* Base page size, filled in from sysconf(_SC_PAGESIZE) by ibv_fork_init(). */
static int page_size = 0;
/* Non-zero once RDMAV_HUGEPAGES_SAFE was seen in the environment. */
static int huge_page_enabled = 0;
/*
 * Set when a dontfork/dofork request arrived before ibv_fork_init();
 * ibv_fork_init() then refuses to initialise and returns EINVAL.
 */
static int too_late = 0;
65 
66 static unsigned long smaps_page_size(FILE *file)
67 {
68 	int n;
69 	unsigned long size = page_size;
70 	char buf[1024];
71 
72 	while (fgets(buf, sizeof(buf), file) != NULL) {
73 		if (!strstr(buf, "KernelPageSize:"))
74 			continue;
75 
76 		n = sscanf(buf, "%*s %lu", &size);
77 		if (n < 1)
78 			continue;
79 
80 		/* page size is printed in Kb */
81 		size = size * 1024;
82 
83 		break;
84 	}
85 
86 	return size;
87 }
88 
89 static unsigned long get_page_size(void *base)
90 {
91 	unsigned long ret = page_size;
92 	pid_t pid;
93 	FILE *file;
94 	char buf[1024];
95 
96 	pid = getpid();
97 	snprintf(buf, sizeof(buf), "/proc/%d/smaps", pid);
98 
99 	file = fopen(buf, "r" STREAM_CLOEXEC);
100 	if (!file)
101 		goto out;
102 
103 	while (fgets(buf, sizeof(buf), file) != NULL) {
104 		int n;
105 		uintptr_t range_start, range_end;
106 
107 		n = sscanf(buf, "%" SCNxPTR "-%" SCNxPTR, &range_start, &range_end);
108 
109 		if (n < 2)
110 			continue;
111 
112 		if ((uintptr_t) base >= range_start && (uintptr_t) base < range_end) {
113 			ret = smaps_page_size(file);
114 			break;
115 		}
116 	}
117 
118 	fclose(file);
119 
120 out:
121 	return ret;
122 }
123 
124 int ibv_fork_init(void)
125 {
126 	void *tmp, *tmp_aligned;
127 	int ret;
128 	unsigned long size;
129 
130 	if (getenv("RDMAV_HUGEPAGES_SAFE"))
131 		huge_page_enabled = 1;
132 
133 	if (mm_root)
134 		return 0;
135 
136 	if (too_late)
137 		return EINVAL;
138 
139 	page_size = sysconf(_SC_PAGESIZE);
140 	if (page_size < 0)
141 		return errno;
142 
143 	if (posix_memalign(&tmp, page_size, page_size))
144 		return ENOMEM;
145 
146 	if (huge_page_enabled) {
147 		size = get_page_size(tmp);
148 		tmp_aligned = (void *) ((uintptr_t) tmp & ~(size - 1));
149 	} else {
150 		size = page_size;
151 		tmp_aligned = tmp;
152 	}
153 
154 	ret = madvise(tmp_aligned, size, MADV_DONTFORK) ||
155 	      madvise(tmp_aligned, size, MADV_DOFORK);
156 
157 	free(tmp);
158 
159 	if (ret)
160 		return ENOSYS;
161 
162 	mm_root = malloc(sizeof *mm_root);
163 	if (!mm_root)
164 		return ENOMEM;
165 
166 	mm_root->parent = NULL;
167 	mm_root->left   = NULL;
168 	mm_root->right  = NULL;
169 	mm_root->color  = IBV_BLACK;
170 	mm_root->start  = 0;
171 	mm_root->end    = UINTPTR_MAX;
172 	mm_root->refcnt = 0;
173 
174 	return 0;
175 }
176 
177 static struct ibv_mem_node *__mm_prev(struct ibv_mem_node *node)
178 {
179 	if (node->left) {
180 		node = node->left;
181 		while (node->right)
182 			node = node->right;
183 	} else {
184 		while (node->parent && node == node->parent->left)
185 			node = node->parent;
186 
187 		node = node->parent;
188 	}
189 
190 	return node;
191 }
192 
193 static struct ibv_mem_node *__mm_next(struct ibv_mem_node *node)
194 {
195 	if (node->right) {
196 		node = node->right;
197 		while (node->left)
198 			node = node->left;
199 	} else {
200 		while (node->parent && node == node->parent->right)
201 			node = node->parent;
202 
203 		node = node->parent;
204 	}
205 
206 	return node;
207 }
208 
209 static void __mm_rotate_right(struct ibv_mem_node *node)
210 {
211 	struct ibv_mem_node *tmp;
212 
213 	tmp = node->left;
214 
215 	node->left = tmp->right;
216 	if (node->left)
217 		node->left->parent = node;
218 
219 	if (node->parent) {
220 		if (node->parent->right == node)
221 			node->parent->right = tmp;
222 		else
223 			node->parent->left = tmp;
224 	} else
225 		mm_root = tmp;
226 
227 	tmp->parent = node->parent;
228 
229 	tmp->right = node;
230 	node->parent = tmp;
231 }
232 
233 static void __mm_rotate_left(struct ibv_mem_node *node)
234 {
235 	struct ibv_mem_node *tmp;
236 
237 	tmp = node->right;
238 
239 	node->right = tmp->left;
240 	if (node->right)
241 		node->right->parent = node;
242 
243 	if (node->parent) {
244 		if (node->parent->right == node)
245 			node->parent->right = tmp;
246 		else
247 			node->parent->left = tmp;
248 	} else
249 		mm_root = tmp;
250 
251 	tmp->parent = node->parent;
252 
253 	tmp->left = node;
254 	node->parent = tmp;
255 }
256 
#if 0
/*
 * Debugging aid (compiled out): check the red-black invariants of the
 * subtree rooted at node.
 *
 * Returns 0 if the subtree is invalid, i.e. if the left and right
 * subtrees report different heights or a red node has a red child.
 * Otherwise returns a positive height value: 1 for an empty subtree,
 * the children's height for a red node, and that height plus one for
 * a black node.
 */
static int verify(struct ibv_mem_node *node)
{
	int hl, hr;

	if (!node)
		return 1;

	/* Both subtrees must be checked and must agree on their height. */
	hl = verify(node->left);
	hr = verify(node->right);

	if (!hl || !hr)
		return 0;
	if (hl != hr)
		return 0;

	if (node->color == IBV_RED) {
		/* A red node may only have black (or absent) children. */
		if (node->left && node->left->color != IBV_BLACK)
			return 0;
		if (node->right && node->right->color != IBV_BLACK)
			return 0;
		return hl;
	}

	return hl + 1;
}
#endif
284 
285 static void __mm_add_rebalance(struct ibv_mem_node *node)
286 {
287 	struct ibv_mem_node *parent, *gp, *uncle;
288 
289 	while (node->parent && node->parent->color == IBV_RED) {
290 		parent = node->parent;
291 		gp     = node->parent->parent;
292 
293 		if (parent == gp->left) {
294 			uncle = gp->right;
295 
296 			if (uncle && uncle->color == IBV_RED) {
297 				parent->color = IBV_BLACK;
298 				uncle->color  = IBV_BLACK;
299 				gp->color     = IBV_RED;
300 
301 				node = gp;
302 			} else {
303 				if (node == parent->right) {
304 					__mm_rotate_left(parent);
305 					node   = parent;
306 					parent = node->parent;
307 				}
308 
309 				parent->color = IBV_BLACK;
310 				gp->color     = IBV_RED;
311 
312 				__mm_rotate_right(gp);
313 			}
314 		} else {
315 			uncle = gp->left;
316 
317 			if (uncle && uncle->color == IBV_RED) {
318 				parent->color = IBV_BLACK;
319 				uncle->color  = IBV_BLACK;
320 				gp->color     = IBV_RED;
321 
322 				node = gp;
323 			} else {
324 				if (node == parent->left) {
325 					__mm_rotate_right(parent);
326 					node   = parent;
327 					parent = node->parent;
328 				}
329 
330 				parent->color = IBV_BLACK;
331 				gp->color     = IBV_RED;
332 
333 				__mm_rotate_left(gp);
334 			}
335 		}
336 	}
337 
338 	mm_root->color = IBV_BLACK;
339 }
340 
341 static void __mm_add(struct ibv_mem_node *new)
342 {
343 	struct ibv_mem_node *node, *parent = NULL;
344 
345 	node = mm_root;
346 	while (node) {
347 		parent = node;
348 		if (node->start < new->start)
349 			node = node->right;
350 		else
351 			node = node->left;
352 	}
353 
354 	if (parent->start < new->start)
355 		parent->right = new;
356 	else
357 		parent->left = new;
358 
359 	new->parent = parent;
360 	new->left   = NULL;
361 	new->right  = NULL;
362 
363 	new->color = IBV_RED;
364 	__mm_add_rebalance(new);
365 }
366 
367 static void __mm_remove(struct ibv_mem_node *node)
368 {
369 	struct ibv_mem_node *child, *parent, *sib, *tmp;
370 	int nodecol;
371 
372 	if (node->left && node->right) {
373 		tmp = node->left;
374 		while (tmp->right)
375 			tmp = tmp->right;
376 
377 		nodecol    = tmp->color;
378 		child      = tmp->left;
379 		tmp->color = node->color;
380 
381 		if (tmp->parent != node) {
382 			parent        = tmp->parent;
383 			parent->right = tmp->left;
384 			if (tmp->left)
385 				tmp->left->parent = parent;
386 
387 			tmp->left   	   = node->left;
388 			node->left->parent = tmp;
389 		} else
390 			parent = tmp;
391 
392 		tmp->right          = node->right;
393 		node->right->parent = tmp;
394 
395 		tmp->parent = node->parent;
396 		if (node->parent) {
397 			if (node->parent->left == node)
398 				node->parent->left = tmp;
399 			else
400 				node->parent->right = tmp;
401 		} else
402 			mm_root = tmp;
403 	} else {
404 		nodecol = node->color;
405 
406 		child  = node->left ? node->left : node->right;
407 		parent = node->parent;
408 
409 		if (child)
410 			child->parent = parent;
411 		if (parent) {
412 			if (parent->left == node)
413 				parent->left = child;
414 			else
415 				parent->right = child;
416 		} else
417 			mm_root = child;
418 	}
419 
420 	free(node);
421 
422 	if (nodecol == IBV_RED)
423 		return;
424 
425 	while ((!child || child->color == IBV_BLACK) && child != mm_root) {
426 		if (parent->left == child) {
427 			sib = parent->right;
428 
429 			if (sib->color == IBV_RED) {
430 				parent->color = IBV_RED;
431 				sib->color    = IBV_BLACK;
432 				__mm_rotate_left(parent);
433 				sib = parent->right;
434 			}
435 
436 			if ((!sib->left  || sib->left->color  == IBV_BLACK) &&
437 			    (!sib->right || sib->right->color == IBV_BLACK)) {
438 				sib->color = IBV_RED;
439 				child  = parent;
440 				parent = child->parent;
441 			} else {
442 				if (!sib->right || sib->right->color == IBV_BLACK) {
443 					if (sib->left)
444 						sib->left->color = IBV_BLACK;
445 					sib->color = IBV_RED;
446 					__mm_rotate_right(sib);
447 					sib = parent->right;
448 				}
449 
450 				sib->color    = parent->color;
451 				parent->color = IBV_BLACK;
452 				if (sib->right)
453 					sib->right->color = IBV_BLACK;
454 				__mm_rotate_left(parent);
455 				child = mm_root;
456 				break;
457 			}
458 		} else {
459 			sib = parent->left;
460 
461 			if (sib->color == IBV_RED) {
462 				parent->color = IBV_RED;
463 				sib->color    = IBV_BLACK;
464 				__mm_rotate_right(parent);
465 				sib = parent->left;
466 			}
467 
468 			if ((!sib->left  || sib->left->color  == IBV_BLACK) &&
469 			    (!sib->right || sib->right->color == IBV_BLACK)) {
470 				sib->color = IBV_RED;
471 				child  = parent;
472 				parent = child->parent;
473 			} else {
474 				if (!sib->left || sib->left->color == IBV_BLACK) {
475 					if (sib->right)
476 						sib->right->color = IBV_BLACK;
477 					sib->color = IBV_RED;
478 					__mm_rotate_left(sib);
479 					sib = parent->left;
480 				}
481 
482 				sib->color    = parent->color;
483 				parent->color = IBV_BLACK;
484 				if (sib->left)
485 					sib->left->color = IBV_BLACK;
486 				__mm_rotate_right(parent);
487 				child = mm_root;
488 				break;
489 			}
490 		}
491 	}
492 
493 	if (child)
494 		child->color = IBV_BLACK;
495 }
496 
497 static struct ibv_mem_node *__mm_find_start(uintptr_t start, uintptr_t end)
498 {
499 	struct ibv_mem_node *node = mm_root;
500 
501 	while (node) {
502 		if (node->start <= start && node->end >= start)
503 			break;
504 
505 		if (node->start < start)
506 			node = node->right;
507 		else
508 			node = node->left;
509 	}
510 
511 	return node;
512 }
513 
514 static struct ibv_mem_node *merge_ranges(struct ibv_mem_node *node,
515 					 struct ibv_mem_node *prev)
516 {
517 	prev->end = node->end;
518 	prev->refcnt = node->refcnt;
519 	__mm_remove(node);
520 
521 	return prev;
522 }
523 
524 static struct ibv_mem_node *split_range(struct ibv_mem_node *node,
525 					uintptr_t cut_line)
526 {
527 	struct ibv_mem_node *new_node = NULL;
528 
529 	new_node = malloc(sizeof *new_node);
530 	if (!new_node)
531 		return NULL;
532 	new_node->start  = cut_line;
533 	new_node->end    = node->end;
534 	new_node->refcnt = node->refcnt;
535 	node->end  = cut_line - 1;
536 	__mm_add(new_node);
537 
538 	return new_node;
539 }
540 
541 static struct ibv_mem_node *get_start_node(uintptr_t start, uintptr_t end,
542 					   int inc)
543 {
544 	struct ibv_mem_node *node, *tmp = NULL;
545 
546 	node = __mm_find_start(start, end);
547 	if (node->start < start)
548 		node = split_range(node, start);
549 	else {
550 		tmp = __mm_prev(node);
551 		if (tmp && tmp->refcnt == node->refcnt + inc)
552 			node = merge_ranges(node, tmp);
553 	}
554 	return node;
555 }
556 
557 /*
558  * This function is called if madvise() fails to undo merging/splitting
559  * operations performed on the node.
560  */
561 static struct ibv_mem_node *undo_node(struct ibv_mem_node *node,
562 				      uintptr_t start, int inc)
563 {
564 	struct ibv_mem_node *tmp = NULL;
565 
566 	/*
567 	 * This condition can be true only if we merged this
568 	 * node with the previous one, so we need to split them.
569 	*/
570 	if (start > node->start) {
571 		tmp = split_range(node, start);
572 		if (tmp) {
573 			node->refcnt += inc;
574 			node = tmp;
575 		} else
576 			return NULL;
577 	}
578 
579 	tmp  =  __mm_prev(node);
580 	if (tmp && tmp->refcnt == node->refcnt)
581 		node = merge_ranges(node, tmp);
582 
583 	tmp  =  __mm_next(node);
584 	if (tmp && tmp->refcnt == node->refcnt)
585 		node = merge_ranges(tmp, node);
586 
587 	return node;
588 }
589 
590 static int ibv_madvise_range(void *base, size_t size, int advice)
591 {
592 	uintptr_t start, end;
593 	struct ibv_mem_node *node, *tmp;
594 	int inc;
595 	int rolling_back = 0;
596 	int ret = 0;
597 	unsigned long range_page_size;
598 
599 	if (!size)
600 		return 0;
601 
602 	if (huge_page_enabled)
603 		range_page_size = get_page_size(base);
604 	else
605 		range_page_size = page_size;
606 
607 	start = (uintptr_t) base & ~(range_page_size - 1);
608 	end   = ((uintptr_t) (base + size + range_page_size - 1) &
609 		 ~(range_page_size - 1)) - 1;
610 
611 	pthread_mutex_lock(&mm_mutex);
612 again:
613 	inc = advice == MADV_DONTFORK ? 1 : -1;
614 
615 	node = get_start_node(start, end, inc);
616 	if (!node) {
617 		ret = -1;
618 		goto out;
619 	}
620 
621 	while (node && node->start <= end) {
622 		if (node->end > end) {
623 			if (!split_range(node, end + 1)) {
624 				ret = -1;
625 				goto out;
626 			}
627 		}
628 
629 		if ((inc == -1 && node->refcnt == 1) ||
630 		    (inc ==  1 && node->refcnt == 0)) {
631 			/*
632 			 * If this is the first time through the loop,
633 			 * and we merged this node with the previous
634 			 * one, then we only want to do the madvise()
635 			 * on start ... node->end (rather than
636 			 * starting at node->start).
637 			 *
638 			 * Otherwise we end up doing madvise() on
639 			 * bigger region than we're being asked to,
640 			 * and that may lead to a spurious failure.
641 			 */
642 			if (start > node->start)
643 				ret = madvise((void *) start, node->end - start + 1,
644 					      advice);
645 			else
646 				ret = madvise((void *) node->start,
647 					      node->end - node->start + 1,
648 					      advice);
649 			if (ret) {
650 				node = undo_node(node, start, inc);
651 
652 				if (rolling_back || !node)
653 					goto out;
654 
655 				/* madvise failed, roll back previous changes */
656 				rolling_back = 1;
657 				advice = advice == MADV_DONTFORK ?
658 					MADV_DOFORK : MADV_DONTFORK;
659 				tmp = __mm_prev(node);
660 				if (!tmp || start > tmp->end)
661 					goto out;
662 				end = tmp->end;
663 				goto again;
664 			}
665 		}
666 
667 		node->refcnt += inc;
668 		node = __mm_next(node);
669 	}
670 
671 	if (node) {
672 		tmp = __mm_prev(node);
673 		if (tmp && node->refcnt == tmp->refcnt)
674 			node = merge_ranges(node, tmp);
675 	}
676 
677 out:
678 	if (rolling_back)
679 		ret = -1;
680 
681 	pthread_mutex_unlock(&mm_mutex);
682 
683 	return ret;
684 }
685 
686 int ibv_dontfork_range(void *base, size_t size)
687 {
688 	if (mm_root)
689 		return ibv_madvise_range(base, size, MADV_DONTFORK);
690 	else {
691 		too_late = 1;
692 		return 0;
693 	}
694 }
695 
696 int ibv_dofork_range(void *base, size_t size)
697 {
698 	if (mm_root)
699 		return ibv_madvise_range(base, size, MADV_DOFORK);
700 	else {
701 		too_late = 1;
702 		return 0;
703 	}
704 }
705