xref: /linux/drivers/md/dm-vdo/indexer/radix-sort.c (revision 6af58aa3b028e364c0a8f8b6be48fca17e571de3)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright 2023 Red Hat
4  */
5 
6 #include "radix-sort.h"
7 
8 #include <linux/limits.h>
9 #include <linux/types.h>
10 
11 #include "memory-alloc.h"
12 #include "string-utils.h"
13 
14 /*
15  * This implementation allocates one large object to do the sorting, which can be reused as many
16  * times as desired. The amount of memory required is logarithmically proportional to the number of
17  * keys to be sorted.
18  */
19 
/*
 * Piles containing at most this many keys are handled with a simple insertion sort rather than
 * further radix subdivision (see push_bins() and uds_radix_sort()).
 */
#define INSERTION_SORT_THRESHOLD 12

/* Sort keys are pointers to immutable fixed-length arrays of bytes. */
typedef const u8 *sort_key_t;
25 
/*
 * The keys are separated into piles based on the byte in each key at the current offset, so the
 * number of keys with each byte value must be counted.
 */
struct histogram {
	/* The number of non-empty bins */
	u16 used;
	/* The index (key byte) of the first non-empty bin */
	u16 first;
	/* The index (key byte) of the last non-empty bin */
	u16 last;
	/* The number of occurrences of each specific byte value at the current offset */
	u32 size[256];
};
40 
/*
 * A single sorting sub-task: one contiguous run of keys to be ordered by their remaining bytes.
 * Sub-tasks are manually managed on a stack, both for performance and to put a logarithmic bound
 * on the stack space needed.
 */
struct task {
	/* Pointer to the first key to sort. */
	sort_key_t *first_key;
	/* Pointer to the last key to sort (inclusive). */
	sort_key_t *last_key;
	/* The offset into the key at which to continue sorting. */
	u16 offset;
	/* The number of bytes remaining in the sort keys from that offset. */
	u16 length;
};
55 
/* All state for one reusable sorter, allocated as a single extended object. */
struct radix_sorter {
	/* The maximum number of keys this sorter was allocated to handle */
	unsigned int count;
	/* Histogram of key byte values at the current offset */
	struct histogram bins;
	/* Working pointers to the end of each pile while keys are being distributed */
	sort_key_t *pile[256];
	/* One past the last usable entry of the task stack */
	struct task *end_of_stack;
	/* Small sub-tasks deferred to insertion sort instead of further radix passes */
	struct task insertion_list[256];
	/* Stack of pending radix sub-tasks; flexible array sized at allocation time */
	struct task stack[];
};
64 
65 /* Compare a segment of two fixed-length keys starting at an offset. */
66 static inline int compare(sort_key_t key1, sort_key_t key2, u16 offset, u16 length)
67 {
68 	return memcmp(&key1[offset], &key2[offset], length);
69 }
70 
71 /* Insert the next unsorted key into an array of sorted keys. */
72 static inline void insert_key(const struct task task, sort_key_t *next)
73 {
74 	/* Pull the unsorted key out, freeing up the array slot. */
75 	sort_key_t unsorted = *next;
76 
77 	/* Compare the key to the preceding sorted entries, shifting down ones that are larger. */
78 	while ((--next >= task.first_key) &&
79 	       (compare(unsorted, next[0], task.offset, task.length) < 0))
80 		next[1] = next[0];
81 
82 	/* Insert the key into the last slot that was cleared, sorting it. */
83 	next[1] = unsorted;
84 }
85 
86 /*
87  * Sort a range of key segments using an insertion sort. This simple sort is faster than the
88  * 256-way radix sort when the number of keys to sort is small.
89  */
90 static inline void insertion_sort(const struct task task)
91 {
92 	sort_key_t *next;
93 
94 	for (next = task.first_key + 1; next <= task.last_key; next++)
95 		insert_key(task, next);
96 }
97 
98 /* Push a sorting task onto a task stack. */
99 static inline void push_task(struct task **stack_pointer, sort_key_t *first_key,
100 			     u32 count, u16 offset, u16 length)
101 {
102 	struct task *task = (*stack_pointer)++;
103 
104 	task->first_key = first_key;
105 	task->last_key = &first_key[count - 1];
106 	task->offset = offset;
107 	task->length = length;
108 }
109 
110 static inline void swap_keys(sort_key_t *a, sort_key_t *b)
111 {
112 	sort_key_t c = *a;
113 	*a = *b;
114 	*b = c;
115 }
116 
117 /*
118  * Count the number of times each byte value appears in the arrays of keys to sort at the current
119  * offset, keeping track of the number of non-empty bins, and the index of the first and last
120  * non-empty bin.
121  */
122 static inline void measure_bins(const struct task task, struct histogram *bins)
123 {
124 	sort_key_t *key_ptr;
125 
126 	/*
127 	 * Subtle invariant: bins->used and bins->size[] are zero because the sorting code clears
128 	 * it all out as it goes. Even though this structure is re-used, we don't need to pay to
129 	 * zero it before starting a new tally.
130 	 */
131 	bins->first = U8_MAX;
132 	bins->last = 0;
133 
134 	for (key_ptr = task.first_key; key_ptr <= task.last_key; key_ptr++) {
135 		/* Increment the count for the byte in the key at the current offset. */
136 		u8 bin = (*key_ptr)[task.offset];
137 		u32 size = ++bins->size[bin];
138 
139 		/* Track non-empty bins. */
140 		if (size == 1) {
141 			bins->used += 1;
142 			if (bin < bins->first)
143 				bins->first = bin;
144 
145 			if (bin > bins->last)
146 				bins->last = bin;
147 		}
148 	}
149 }
150 
151 /*
152  * Convert the bin sizes to pointers to where each pile goes.
153  *
154  *   pile[0] = first_key + bin->size[0],
155  *   pile[1] = pile[0]  + bin->size[1], etc.
156  *
157  * After the keys are moved to the appropriate pile, we'll need to sort each of the piles by the
158  * next radix position. A new task is put on the stack for each pile containing lots of keys, or a
159  * new task is put on the list for each pile containing few keys.
160  *
161  * @stack: pointer the top of the stack
162  * @end_of_stack: the end of the stack
163  * @list: pointer the head of the list
164  * @pile: array for pointers to the end of each pile
165  * @bins: the histogram of the sizes of each pile
166  * @first_key: the first key of the stack
167  * @offset: the next radix position to sort by
168  * @length: the number of bytes remaining in the sort keys
169  *
170  * Return: UDS_SUCCESS or an error code
171  */
172 static inline int push_bins(struct task **stack, struct task *end_of_stack,
173 			    struct task **list, sort_key_t *pile[],
174 			    struct histogram *bins, sort_key_t *first_key,
175 			    u16 offset, u16 length)
176 {
177 	sort_key_t *pile_start = first_key;
178 	int bin;
179 
180 	for (bin = bins->first; ; bin++) {
181 		u32 size = bins->size[bin];
182 
183 		/* Skip empty piles. */
184 		if (size == 0)
185 			continue;
186 
187 		/* There's no need to sort empty keys. */
188 		if (length > 0) {
189 			if (size > INSERTION_SORT_THRESHOLD) {
190 				if (*stack >= end_of_stack)
191 					return UDS_BAD_STATE;
192 
193 				push_task(stack, pile_start, size, offset, length);
194 			} else if (size > 1) {
195 				push_task(list, pile_start, size, offset, length);
196 			}
197 		}
198 
199 		pile_start += size;
200 		pile[bin] = pile_start;
201 		if (--bins->used == 0)
202 			break;
203 	}
204 
205 	return UDS_SUCCESS;
206 }
207 
208 int uds_make_radix_sorter(unsigned int count, struct radix_sorter **sorter)
209 {
210 	int result;
211 	unsigned int stack_size = count / INSERTION_SORT_THRESHOLD;
212 	struct radix_sorter *radix_sorter;
213 
214 	result = vdo_allocate_extended(stack_size, stack, __func__, &radix_sorter);
215 	if (result != VDO_SUCCESS)
216 		return result;
217 
218 	radix_sorter->count = count;
219 	radix_sorter->end_of_stack = radix_sorter->stack + stack_size;
220 	*sorter = radix_sorter;
221 	return UDS_SUCCESS;
222 }
223 
/* Release a sorter created by uds_make_radix_sorter(); a NULL sorter is harmless. */
void uds_free_radix_sorter(struct radix_sorter *sorter)
{
	vdo_free(sorter);
}
228 
/*
 * Sort pointers to fixed-length keys (arrays of bytes) using a radix sort. The sort implementation
 * is unstable, so the relative ordering of equal keys is not preserved.
 *
 * @sorter: a sorter from uds_make_radix_sorter(), sized for at least @count keys
 * @keys: the array of key pointers to sort in place
 * @count: the number of keys
 * @length: the fixed length in bytes of every key
 *
 * Return: UDS_SUCCESS, UDS_INVALID_ARGUMENT if @count exceeds the sorter's capacity, or
 *         UDS_BAD_STATE if the task stack overflows.
 */
int uds_radix_sort(struct radix_sorter *sorter, const unsigned char *keys[],
		   unsigned int count, unsigned short length)
{
	struct task start;
	struct histogram *bins = &sorter->bins;
	sort_key_t **pile = sorter->pile;
	struct task *task_stack = sorter->stack;

	/* All zero-length keys are identical and therefore already sorted. */
	if ((count == 0) || (length == 0))
		return UDS_SUCCESS;

	/* The initial task is to sort the entire length of all the keys. */
	start = (struct task) {
		.first_key = keys,
		.last_key = &keys[count - 1],
		.offset = 0,
		.length = length,
	};

	/* Tiny inputs skip the radix machinery entirely. */
	if (count <= INSERTION_SORT_THRESHOLD) {
		insertion_sort(start);
		return UDS_SUCCESS;
	}

	if (count > sorter->count)
		return UDS_INVALID_ARGUMENT;

	/*
	 * Repeatedly consume a sorting task from the stack and process it, pushing new sub-tasks
	 * onto the stack for each radix-sorted pile. When all tasks and sub-tasks have been
	 * processed, the stack will be empty and all the keys in the starting task will be fully
	 * sorted.
	 */
	for (*task_stack = start; task_stack >= sorter->stack; task_stack--) {
		const struct task task = *task_stack;
		struct task *insertion_task_list;
		int result;
		sort_key_t *fence;
		sort_key_t *end;

		/* Tally the byte at task.offset across all keys in this task. */
		measure_bins(task, bins);

		/*
		 * Now that we know how large each bin is, generate pointers for each of the piles
		 * and push a new task to sort each pile by the next radix byte. Note that
		 * push_bins() may advance task_stack (via the &task_stack argument) with sub-tasks
		 * for the large piles.
		 */
		insertion_task_list = sorter->insertion_list;
		result = push_bins(&task_stack, sorter->end_of_stack,
				   &insertion_task_list, pile, bins, task.first_key,
				   task.offset + 1, task.length - 1);
		if (result != UDS_SUCCESS) {
			/* Restore the zeroed-bins invariant before reporting the error. */
			memset(bins, 0, sizeof(*bins));
			return result;
		}

		/* Now bins->used is zero again. */

		/*
		 * Don't bother processing the last pile: when piles 0..N-1 are all in place, then
		 * pile N must also be in place.
		 */
		end = task.last_key - bins->size[bins->last];
		bins->size[bins->last] = 0;

		for (fence = task.first_key; fence <= end; ) {
			u8 bin;
			/* Hold the key being routed in a local while its slot is recycled. */
			sort_key_t key = *fence;

			/*
			 * The radix byte of the key tells us which pile it belongs in. Swap it for
			 * an unprocessed item just below that pile, and repeat.
			 */
			while (--pile[bin = key[task.offset]] > fence)
				swap_keys(pile[bin], &key);

			/*
			 * The pile reached the fence. Put the key at the bottom of that pile,
			 * completing it, and advance the fence to the next pile.
			 */
			*fence = key;
			fence += bins->size[bin];
			bins->size[bin] = 0;
		}

		/* Now bins->size[] is all zero again. */

		/*
		 * When the number of keys in a task gets small enough, it is faster to use an
		 * insertion sort than to keep subdividing into tiny piles.
		 */
		while (--insertion_task_list >= sorter->insertion_list)
			insertion_sort(*insertion_task_list);
	}

	return UDS_SUCCESS;
}
330