xref: /linux/lib/list_sort.c (revision e0bf6c5ca2d3281f231c5f0c9bf145e9513644de)
1 
/*
 * Prefix for this file's pr_*() messages.  Only the CONFIG_TEST_LIST_SORT
 * self-test at the bottom of the file uses pr_*(); list_sort() itself
 * reports through printk_once(), which does not apply this prefix.
 */
#define pr_fmt(fmt) "list_sort_test: " fmt
3 
4 #include <linux/kernel.h>
5 #include <linux/bug.h>
6 #include <linux/compiler.h>
7 #include <linux/export.h>
8 #include <linux/string.h>
9 #include <linux/list_sort.h>
10 #include <linux/list.h>
11 
/*
 * Number of levels in list_sort()'s array of pending sorted runs (that
 * array gets one extra slot as a NULL sentinel).  Lists long enough to
 * need more levels are still sorted correctly, just less efficiently.
 */
enum { MAX_LIST_LENGTH_BITS = 20 };
13 
14 /*
15  * Returns a list organized in an intermediate format suited
16  * to chaining of merge() calls: null-terminated, no reserved or
17  * sentinel head node, "prev" links not maintained.
18  */
19 static struct list_head *merge(void *priv,
20 				int (*cmp)(void *priv, struct list_head *a,
21 					struct list_head *b),
22 				struct list_head *a, struct list_head *b)
23 {
24 	struct list_head head, *tail = &head;
25 
26 	while (a && b) {
27 		/* if equal, take 'a' -- important for sort stability */
28 		if ((*cmp)(priv, a, b) <= 0) {
29 			tail->next = a;
30 			a = a->next;
31 		} else {
32 			tail->next = b;
33 			b = b->next;
34 		}
35 		tail = tail->next;
36 	}
37 	tail->next = a?:b;
38 	return head.next;
39 }
40 
41 /*
42  * Combine final list merge with restoration of standard doubly-linked
43  * list structure.  This approach duplicates code from merge(), but
44  * runs faster than the tidier alternatives of either a separate final
45  * prev-link restoration pass, or maintaining the prev links
46  * throughout.
47  */
48 static void merge_and_restore_back_links(void *priv,
49 				int (*cmp)(void *priv, struct list_head *a,
50 					struct list_head *b),
51 				struct list_head *head,
52 				struct list_head *a, struct list_head *b)
53 {
54 	struct list_head *tail = head;
55 	u8 count = 0;
56 
57 	while (a && b) {
58 		/* if equal, take 'a' -- important for sort stability */
59 		if ((*cmp)(priv, a, b) <= 0) {
60 			tail->next = a;
61 			a->prev = tail;
62 			a = a->next;
63 		} else {
64 			tail->next = b;
65 			b->prev = tail;
66 			b = b->next;
67 		}
68 		tail = tail->next;
69 	}
70 	tail->next = a ? : b;
71 
72 	do {
73 		/*
74 		 * In worst cases this loop may run many iterations.
75 		 * Continue callbacks to the client even though no
76 		 * element comparison is needed, so the client's cmp()
77 		 * routine can invoke cond_resched() periodically.
78 		 */
79 		if (unlikely(!(++count)))
80 			(*cmp)(priv, tail->next, tail->next);
81 
82 		tail->next->prev = tail;
83 		tail = tail->next;
84 	} while (tail->next);
85 
86 	tail->next = head;
87 	head->prev = tail;
88 }
89 
90 /**
91  * list_sort - sort a list
92  * @priv: private data, opaque to list_sort(), passed to @cmp
93  * @head: the list to sort
94  * @cmp: the elements comparison function
95  *
96  * This function implements "merge sort", which has O(nlog(n))
97  * complexity.
98  *
99  * The comparison function @cmp must return a negative value if @a
100  * should sort before @b, and a positive value if @a should sort after
101  * @b. If @a and @b are equivalent, and their original relative
102  * ordering is to be preserved, @cmp must return 0.
103  */
104 void list_sort(void *priv, struct list_head *head,
105 		int (*cmp)(void *priv, struct list_head *a,
106 			struct list_head *b))
107 {
108 	struct list_head *part[MAX_LIST_LENGTH_BITS+1]; /* sorted partial lists
109 						-- last slot is a sentinel */
110 	int lev;  /* index into part[] */
111 	int max_lev = 0;
112 	struct list_head *list;
113 
114 	if (list_empty(head))
115 		return;
116 
117 	memset(part, 0, sizeof(part));
118 
119 	head->prev->next = NULL;
120 	list = head->next;
121 
122 	while (list) {
123 		struct list_head *cur = list;
124 		list = list->next;
125 		cur->next = NULL;
126 
127 		for (lev = 0; part[lev]; lev++) {
128 			cur = merge(priv, cmp, part[lev], cur);
129 			part[lev] = NULL;
130 		}
131 		if (lev > max_lev) {
132 			if (unlikely(lev >= ARRAY_SIZE(part)-1)) {
133 				printk_once(KERN_DEBUG "list too long for efficiency\n");
134 				lev--;
135 			}
136 			max_lev = lev;
137 		}
138 		part[lev] = cur;
139 	}
140 
141 	for (lev = 0; lev < max_lev; lev++)
142 		if (part[lev])
143 			list = merge(priv, cmp, part[lev], list);
144 
145 	merge_and_restore_back_links(priv, cmp, head, part[max_lev], list);
146 }
147 EXPORT_SYMBOL(list_sort);
148 
149 #ifdef CONFIG_TEST_LIST_SORT
150 
151 #include <linux/slab.h>
152 #include <linux/random.h>
153 
/*
 * The pattern of set bits in the list length determines which cases are
 * hit in list_sort(), so the length is spelled out bit by bit:
 * 642 == 0b1010000010.
 */
#define TEST_LIST_LEN ((1 << 9) + (1 << 7) + (1 << 1)) /* not including head */

/* Guard words stored on either side of each test element's list_head. */
#define TEST_POISON1 0xDEADBEEF
#define TEST_POISON2 0xA324354C
162 
/*
 * One element of the self-test list.  The embedded list_head is bracketed
 * by two poison words that check() verifies, so corruption of the memory
 * around the list links is detected.
 */
struct debug_el {
	unsigned int poison1;	/* must remain TEST_POISON1 */
	struct list_head list;
	unsigned int poison2;	/* must remain TEST_POISON2 */
	int value;		/* sort key; duplicates are created on purpose */
	unsigned serial;	/* creation index: elts[] slot and stability check */
};

/*
 * Pointers to all elements in the test list, indexed by serial.  check()
 * uses it to confirm that every node it is handed is one the test
 * allocated (no "phantom" elements).
 */
static struct debug_el **elts __initdata;
173 
174 static int __init check(struct debug_el *ela, struct debug_el *elb)
175 {
176 	if (ela->serial >= TEST_LIST_LEN) {
177 		pr_err("error: incorrect serial %d\n", ela->serial);
178 		return -EINVAL;
179 	}
180 	if (elb->serial >= TEST_LIST_LEN) {
181 		pr_err("error: incorrect serial %d\n", elb->serial);
182 		return -EINVAL;
183 	}
184 	if (elts[ela->serial] != ela || elts[elb->serial] != elb) {
185 		pr_err("error: phantom element\n");
186 		return -EINVAL;
187 	}
188 	if (ela->poison1 != TEST_POISON1 || ela->poison2 != TEST_POISON2) {
189 		pr_err("error: bad poison: %#x/%#x\n",
190 			ela->poison1, ela->poison2);
191 		return -EINVAL;
192 	}
193 	if (elb->poison1 != TEST_POISON1 || elb->poison2 != TEST_POISON2) {
194 		pr_err("error: bad poison: %#x/%#x\n",
195 			elb->poison1, elb->poison2);
196 		return -EINVAL;
197 	}
198 	return 0;
199 }
200 
201 static int __init cmp(void *priv, struct list_head *a, struct list_head *b)
202 {
203 	struct debug_el *ela, *elb;
204 
205 	ela = container_of(a, struct debug_el, list);
206 	elb = container_of(b, struct debug_el, list);
207 
208 	check(ela, elb);
209 	return ela->value - elb->value;
210 }
211 
212 static int __init list_sort_test(void)
213 {
214 	int i, count = 1, err = -ENOMEM;
215 	struct debug_el *el;
216 	struct list_head *cur;
217 	LIST_HEAD(head);
218 
219 	pr_debug("start testing list_sort()\n");
220 
221 	elts = kcalloc(TEST_LIST_LEN, sizeof(*elts), GFP_KERNEL);
222 	if (!elts) {
223 		pr_err("error: cannot allocate memory\n");
224 		return err;
225 	}
226 
227 	for (i = 0; i < TEST_LIST_LEN; i++) {
228 		el = kmalloc(sizeof(*el), GFP_KERNEL);
229 		if (!el) {
230 			pr_err("error: cannot allocate memory\n");
231 			goto exit;
232 		}
233 		 /* force some equivalencies */
234 		el->value = prandom_u32() % (TEST_LIST_LEN / 3);
235 		el->serial = i;
236 		el->poison1 = TEST_POISON1;
237 		el->poison2 = TEST_POISON2;
238 		elts[i] = el;
239 		list_add_tail(&el->list, &head);
240 	}
241 
242 	list_sort(NULL, &head, cmp);
243 
244 	err = -EINVAL;
245 	for (cur = head.next; cur->next != &head; cur = cur->next) {
246 		struct debug_el *el1;
247 		int cmp_result;
248 
249 		if (cur->next->prev != cur) {
250 			pr_err("error: list is corrupted\n");
251 			goto exit;
252 		}
253 
254 		cmp_result = cmp(NULL, cur, cur->next);
255 		if (cmp_result > 0) {
256 			pr_err("error: list is not sorted\n");
257 			goto exit;
258 		}
259 
260 		el = container_of(cur, struct debug_el, list);
261 		el1 = container_of(cur->next, struct debug_el, list);
262 		if (cmp_result == 0 && el->serial >= el1->serial) {
263 			pr_err("error: order of equivalent elements not "
264 				"preserved\n");
265 			goto exit;
266 		}
267 
268 		if (check(el, el1)) {
269 			pr_err("error: element check failed\n");
270 			goto exit;
271 		}
272 		count++;
273 	}
274 	if (head.prev != cur) {
275 		pr_err("error: list is corrupted\n");
276 		goto exit;
277 	}
278 
279 
280 	if (count != TEST_LIST_LEN) {
281 		pr_err("error: bad list length %d", count);
282 		goto exit;
283 	}
284 
285 	err = 0;
286 exit:
287 	for (i = 0; i < TEST_LIST_LEN; i++)
288 		kfree(elts[i]);
289 	kfree(elts);
290 	return err;
291 }
292 module_init(list_sort_test);
293 #endif /* CONFIG_TEST_LIST_SORT */
294