xref: /linux/lib/sg_split.c (revision 2dbc0838bcf24ca59cabc3130cf3b1d6809cdcd4)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2015 Robert Jarzmik <robert.jarzmik@free.fr>
4  *
5  * Scatterlist splitting helpers.
6  */
7 
8 #include <linux/scatterlist.h>
9 #include <linux/slab.h>
10 
/*
 * Per-output bookkeeping used by sg_split().
 *
 * sg_calculate_split() fills in_sg0, nents, skip_sg0 and length_last_sg; it
 * may be run on the same array more than once (once over the ->length view
 * of the input, once over the sg_dma_len() view), and these fields describe
 * the most recent run. out_sg is allocated by sg_split().
 */
struct sg_splitter {
	/* First input entry that contributes bytes to this split. */
	struct scatterlist *in_sg0;
	/* Number of input entries, starting at in_sg0, spanned by this split. */
	int nents;
	/* Byte offset inside in_sg0 at which this split begins. */
	off_t skip_sg0;
	/* Number of bytes of the last spanned entry that belong to this split. */
	unsigned int length_last_sg;

	/* Output array, sized from nents of the first (physical) pass. */
	struct scatterlist *out_sg;
};
19 
20 static int sg_calculate_split(struct scatterlist *in, int nents, int nb_splits,
21 			      off_t skip, const size_t *sizes,
22 			      struct sg_splitter *splitters, bool mapped)
23 {
24 	int i;
25 	unsigned int sglen;
26 	size_t size = sizes[0], len;
27 	struct sg_splitter *curr = splitters;
28 	struct scatterlist *sg;
29 
30 	for (i = 0; i < nb_splits; i++) {
31 		splitters[i].in_sg0 = NULL;
32 		splitters[i].nents = 0;
33 	}
34 
35 	for_each_sg(in, sg, nents, i) {
36 		sglen = mapped ? sg_dma_len(sg) : sg->length;
37 		if (skip > sglen) {
38 			skip -= sglen;
39 			continue;
40 		}
41 
42 		len = min_t(size_t, size, sglen - skip);
43 		if (!curr->in_sg0) {
44 			curr->in_sg0 = sg;
45 			curr->skip_sg0 = skip;
46 		}
47 		size -= len;
48 		curr->nents++;
49 		curr->length_last_sg = len;
50 
51 		while (!size && (skip + len < sglen) && (--nb_splits > 0)) {
52 			curr++;
53 			size = *(++sizes);
54 			skip += len;
55 			len = min_t(size_t, size, sglen - skip);
56 
57 			curr->in_sg0 = sg;
58 			curr->skip_sg0 = skip;
59 			curr->nents = 1;
60 			curr->length_last_sg = len;
61 			size -= len;
62 		}
63 		skip = 0;
64 
65 		if (!size && --nb_splits > 0) {
66 			curr++;
67 			size = *(++sizes);
68 		}
69 
70 		if (!nb_splits)
71 			break;
72 	}
73 
74 	return (size || !splitters[0].in_sg0) ? -EINVAL : 0;
75 }
76 
77 static void sg_split_phys(struct sg_splitter *splitters, const int nb_splits)
78 {
79 	int i, j;
80 	struct scatterlist *in_sg, *out_sg;
81 	struct sg_splitter *split;
82 
83 	for (i = 0, split = splitters; i < nb_splits; i++, split++) {
84 		in_sg = split->in_sg0;
85 		out_sg = split->out_sg;
86 		for (j = 0; j < split->nents; j++, out_sg++) {
87 			*out_sg = *in_sg;
88 			if (!j) {
89 				out_sg->offset += split->skip_sg0;
90 				out_sg->length -= split->skip_sg0;
91 			} else {
92 				out_sg->offset = 0;
93 			}
94 			sg_dma_address(out_sg) = 0;
95 			sg_dma_len(out_sg) = 0;
96 			in_sg = sg_next(in_sg);
97 		}
98 		out_sg[-1].length = split->length_last_sg;
99 		sg_mark_end(out_sg - 1);
100 	}
101 }
102 
103 static void sg_split_mapped(struct sg_splitter *splitters, const int nb_splits)
104 {
105 	int i, j;
106 	struct scatterlist *in_sg, *out_sg;
107 	struct sg_splitter *split;
108 
109 	for (i = 0, split = splitters; i < nb_splits; i++, split++) {
110 		in_sg = split->in_sg0;
111 		out_sg = split->out_sg;
112 		for (j = 0; j < split->nents; j++, out_sg++) {
113 			sg_dma_address(out_sg) = sg_dma_address(in_sg);
114 			sg_dma_len(out_sg) = sg_dma_len(in_sg);
115 			if (!j) {
116 				sg_dma_address(out_sg) += split->skip_sg0;
117 				sg_dma_len(out_sg) -= split->skip_sg0;
118 			}
119 			in_sg = sg_next(in_sg);
120 		}
121 		sg_dma_len(--out_sg) = split->length_last_sg;
122 	}
123 }
124 
125 /**
126  * sg_split - split a scatterlist into several scatterlists
127  * @in: the input sg list
128  * @in_mapped_nents: the result of a dma_map_sg(in, ...), or 0 if not mapped.
129  * @skip: the number of bytes to skip in the input sg list
130  * @nb_splits: the number of desired sg outputs
131  * @split_sizes: the respective size of each output sg list in bytes
132  * @out: an array where to store the allocated output sg lists
133  * @out_mapped_nents: the resulting sg lists mapped number of sg entries. Might
134  *                    be NULL if sglist not already mapped (in_mapped_nents = 0)
135  * @gfp_mask: the allocation flag
136  *
137  * This function splits the input sg list into nb_splits sg lists, which are
138  * allocated and stored into out.
139  * The @in is split into :
140  *  - @out[0], which covers bytes [@skip .. @skip + @split_sizes[0] - 1] of @in
141  *  - @out[1], which covers bytes [@skip + split_sizes[0] ..
142  *                                 @skip + @split_sizes[0] + @split_sizes[1] -1]
143  * etc ...
144  * It will be the caller's duty to kfree() out array members.
145  *
146  * Returns 0 upon success, or error code
147  */
148 int sg_split(struct scatterlist *in, const int in_mapped_nents,
149 	     const off_t skip, const int nb_splits,
150 	     const size_t *split_sizes,
151 	     struct scatterlist **out, int *out_mapped_nents,
152 	     gfp_t gfp_mask)
153 {
154 	int i, ret;
155 	struct sg_splitter *splitters;
156 
157 	splitters = kcalloc(nb_splits, sizeof(*splitters), gfp_mask);
158 	if (!splitters)
159 		return -ENOMEM;
160 
161 	ret = sg_calculate_split(in, sg_nents(in), nb_splits, skip, split_sizes,
162 			   splitters, false);
163 	if (ret < 0)
164 		goto err;
165 
166 	ret = -ENOMEM;
167 	for (i = 0; i < nb_splits; i++) {
168 		splitters[i].out_sg = kmalloc_array(splitters[i].nents,
169 						    sizeof(struct scatterlist),
170 						    gfp_mask);
171 		if (!splitters[i].out_sg)
172 			goto err;
173 	}
174 
175 	/*
176 	 * The order of these 3 calls is important and should be kept.
177 	 */
178 	sg_split_phys(splitters, nb_splits);
179 	ret = sg_calculate_split(in, in_mapped_nents, nb_splits, skip,
180 				 split_sizes, splitters, true);
181 	if (ret < 0)
182 		goto err;
183 	sg_split_mapped(splitters, nb_splits);
184 
185 	for (i = 0; i < nb_splits; i++) {
186 		out[i] = splitters[i].out_sg;
187 		if (out_mapped_nents)
188 			out_mapped_nents[i] = splitters[i].nents;
189 	}
190 
191 	kfree(splitters);
192 	return 0;
193 
194 err:
195 	for (i = 0; i < nb_splits; i++)
196 		kfree(splitters[i].out_sg);
197 	kfree(splitters);
198 	return ret;
199 }
200 EXPORT_SYMBOL(sg_split);
201