1 2 #define pr_fmt(fmt) "list_sort_test: " fmt 3 4 #include <linux/kernel.h> 5 #include <linux/module.h> 6 #include <linux/list_sort.h> 7 #include <linux/slab.h> 8 #include <linux/list.h> 9 10 #define MAX_LIST_LENGTH_BITS 20 11 12 /* 13 * Returns a list organized in an intermediate format suited 14 * to chaining of merge() calls: null-terminated, no reserved or 15 * sentinel head node, "prev" links not maintained. 16 */ 17 static struct list_head *merge(void *priv, 18 int (*cmp)(void *priv, struct list_head *a, 19 struct list_head *b), 20 struct list_head *a, struct list_head *b) 21 { 22 struct list_head head, *tail = &head; 23 24 while (a && b) { 25 /* if equal, take 'a' -- important for sort stability */ 26 if ((*cmp)(priv, a, b) <= 0) { 27 tail->next = a; 28 a = a->next; 29 } else { 30 tail->next = b; 31 b = b->next; 32 } 33 tail = tail->next; 34 } 35 tail->next = a?:b; 36 return head.next; 37 } 38 39 /* 40 * Combine final list merge with restoration of standard doubly-linked 41 * list structure. This approach duplicates code from merge(), but 42 * runs faster than the tidier alternatives of either a separate final 43 * prev-link restoration pass, or maintaining the prev links 44 * throughout. 45 */ 46 static void merge_and_restore_back_links(void *priv, 47 int (*cmp)(void *priv, struct list_head *a, 48 struct list_head *b), 49 struct list_head *head, 50 struct list_head *a, struct list_head *b) 51 { 52 struct list_head *tail = head; 53 u8 count = 0; 54 55 while (a && b) { 56 /* if equal, take 'a' -- important for sort stability */ 57 if ((*cmp)(priv, a, b) <= 0) { 58 tail->next = a; 59 a->prev = tail; 60 a = a->next; 61 } else { 62 tail->next = b; 63 b->prev = tail; 64 b = b->next; 65 } 66 tail = tail->next; 67 } 68 tail->next = a ? : b; 69 70 do { 71 /* 72 * In worst cases this loop may run many iterations. 
73 * Continue callbacks to the client even though no 74 * element comparison is needed, so the client's cmp() 75 * routine can invoke cond_resched() periodically. 76 */ 77 if (unlikely(!(++count))) 78 (*cmp)(priv, tail->next, tail->next); 79 80 tail->next->prev = tail; 81 tail = tail->next; 82 } while (tail->next); 83 84 tail->next = head; 85 head->prev = tail; 86 } 87 88 /** 89 * list_sort - sort a list 90 * @priv: private data, opaque to list_sort(), passed to @cmp 91 * @head: the list to sort 92 * @cmp: the elements comparison function 93 * 94 * This function implements "merge sort", which has O(nlog(n)) 95 * complexity. 96 * 97 * The comparison function @cmp must return a negative value if @a 98 * should sort before @b, and a positive value if @a should sort after 99 * @b. If @a and @b are equivalent, and their original relative 100 * ordering is to be preserved, @cmp must return 0. 101 */ 102 void list_sort(void *priv, struct list_head *head, 103 int (*cmp)(void *priv, struct list_head *a, 104 struct list_head *b)) 105 { 106 struct list_head *part[MAX_LIST_LENGTH_BITS+1]; /* sorted partial lists 107 -- last slot is a sentinel */ 108 int lev; /* index into part[] */ 109 int max_lev = 0; 110 struct list_head *list; 111 112 if (list_empty(head)) 113 return; 114 115 memset(part, 0, sizeof(part)); 116 117 head->prev->next = NULL; 118 list = head->next; 119 120 while (list) { 121 struct list_head *cur = list; 122 list = list->next; 123 cur->next = NULL; 124 125 for (lev = 0; part[lev]; lev++) { 126 cur = merge(priv, cmp, part[lev], cur); 127 part[lev] = NULL; 128 } 129 if (lev > max_lev) { 130 if (unlikely(lev >= ARRAY_SIZE(part)-1)) { 131 printk_once(KERN_DEBUG "list too long for efficiency\n"); 132 lev--; 133 } 134 max_lev = lev; 135 } 136 part[lev] = cur; 137 } 138 139 for (lev = 0; lev < max_lev; lev++) 140 if (part[lev]) 141 list = merge(priv, cmp, part[lev], list); 142 143 merge_and_restore_back_links(priv, cmp, head, part[max_lev], list); 144 } 145 
EXPORT_SYMBOL(list_sort);

#ifdef CONFIG_TEST_LIST_SORT

#include <linux/random.h>

/*
 * The pattern of set bits in the list length determines which cases
 * are hit in list_sort().
 */
#define TEST_LIST_LEN (512+128+2) /* not including head */

#define TEST_POISON1 0xDEADBEEF
#define TEST_POISON2 0xA324354C

/*
 * Test element.  The embedded list node sits between two poison words,
 * and check() verifies both of them.
 */
struct debug_el {
	unsigned int poison1;
	struct list_head list;
	unsigned int poison2;
	int value;		/* sort key */
	unsigned serial;	/* original insertion index, also index into elts[] */
};

/* Array, containing pointers to all elements in the test list */
static struct debug_el **elts __initdata;

/*
 * check - sanity-check a pair of test elements.
 *
 * For each element it verifies three things.  The serial must be within
 * [0, TEST_LIST_LEN).  elts[serial] must point back at the element, which
 * catches pointers to something that is not one of ours.  Both poison words
 * must be intact.
 *
 * Returns 0 if both elements pass, or -EINVAL after logging the first
 * problem found.
 */
static int __init check(struct debug_el *ela, struct debug_el *elb)
{
	if (ela->serial >= TEST_LIST_LEN) {
		pr_err("error: incorrect serial %u\n", ela->serial);
		return -EINVAL;
	}
	if (elb->serial >= TEST_LIST_LEN) {
		pr_err("error: incorrect serial %u\n", elb->serial);
		return -EINVAL;
	}
	if (elts[ela->serial] != ela || elts[elb->serial] != elb) {
		pr_err("error: phantom element\n");
		return -EINVAL;
	}
	if (ela->poison1 != TEST_POISON1 || ela->poison2 != TEST_POISON2) {
		pr_err("error: bad poison: %#x/%#x\n",
			ela->poison1, ela->poison2);
		return -EINVAL;
	}
	if (elb->poison1 != TEST_POISON1 || elb->poison2 != TEST_POISON2) {
		pr_err("error: bad poison: %#x/%#x\n",
			elb->poison1, elb->poison2);
		return -EINVAL;
	}
	return 0;
}

/*
 * cmp - list_sort() comparison callback, ordering elements by ->value.
 *
 * check() runs on every pair handed to us, so corruption seen while
 * sorting is logged immediately.  A comparator has no way to report
 * failure, so the result is not propagated.  list_sort_test() re-checks
 * every adjacent pair after the sort.
 *
 * The subtraction cannot overflow, because values are drawn from
 * [0, TEST_LIST_LEN / 3).
 */
static int __init cmp(void *priv, struct list_head *a, struct list_head *b)
{
	struct debug_el *ela, *elb;

	ela = container_of(a, struct debug_el, list);
	elb = container_of(b, struct debug_el, list);

	check(ela, elb);
	return ela->value - elb->value;
}

/*
 * list_sort_test - build a list of TEST_LIST_LEN elements with random,
 * deliberately repeating keys, sort it, and verify the result.
 *
 * After sorting it checks that:
 *  - every ->prev link is correct, including the first and the last;
 *  - adjacent elements are in non-decreasing ->value order;
 *  - equal keys keep their insertion order (stability);
 *  - every element passes check();
 *  - no element was lost or duplicated.
 *
 * Returns 0 on success, -ENOMEM on allocation failure, and -EINVAL on any
 * verification failure.  All elements are freed on every path.
 */
static int __init list_sort_test(void)
{
	int i, count = 1, err = -ENOMEM;
	struct debug_el *el;
	struct list_head *cur;
	LIST_HEAD(head);

	pr_debug("start testing list_sort()\n");

	/* Zeroed, so the cleanup loop can kfree() unfilled (NULL) slots. */
	elts = kcalloc(TEST_LIST_LEN, sizeof(*elts), GFP_KERNEL);
	if (!elts) {
		pr_err("error: cannot allocate memory\n");
		return err;
	}

	for (i = 0; i < TEST_LIST_LEN; i++) {
		el = kmalloc(sizeof(*el), GFP_KERNEL);
		if (!el) {
			pr_err("error: cannot allocate memory\n");
			goto exit;
		}
		 /* force some equivalencies */
		el->value = prandom_u32() % (TEST_LIST_LEN / 3);
		el->serial = i;
		el->poison1 = TEST_POISON1;
		el->poison2 = TEST_POISON2;
		elts[i] = el;
		list_add_tail(&el->list, &head);
	}

	list_sort(NULL, &head, cmp);

	err = -EINVAL;

	/* The walk below only checks cur->next->prev; the first back-link needs its own check. */
	if (head.next->prev != &head) {
		pr_err("error: list is corrupted\n");
		goto exit;
	}

	for (cur = head.next; cur->next != &head; cur = cur->next) {
		struct debug_el *el1;
		int cmp_result;

		if (cur->next->prev != cur) {
			pr_err("error: list is corrupted\n");
			goto exit;
		}

		cmp_result = cmp(NULL, cur, cur->next);
		if (cmp_result > 0) {
			pr_err("error: list is not sorted\n");
			goto exit;
		}

		el = container_of(cur, struct debug_el, list);
		el1 = container_of(cur->next, struct debug_el, list);
		/* Equal keys must keep their insertion (serial) order. */
		if (cmp_result == 0 && el->serial >= el1->serial) {
			pr_err("error: order of equivalent elements not "
				"preserved\n");
			goto exit;
		}

		if (check(el, el1)) {
			pr_err("error: element check failed\n");
			goto exit;
		}
		count++;
	}
	/* cur is now the last element; the head must point back to it. */
	if (head.prev != cur) {
		pr_err("error: list is corrupted\n");
		goto exit;
	}

	if (count != TEST_LIST_LEN) {
		pr_err("error: bad list length %d\n", count);
		goto exit;
	}

	err = 0;
exit:
	for (i = 0; i < TEST_LIST_LEN; i++)
		kfree(elts[i]);
	kfree(elts);
	return err;
}
module_init(list_sort_test);
#endif /* CONFIG_TEST_LIST_SORT */