xref: /linux/kernel/bpf/tnum.c (revision eb71ab2bf72260054677e348498ba995a057c463)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* tnum: tracked (or tristate) numbers
3  *
4  * A tnum tracks knowledge about the bits of a value.  Each bit can be either
5  * known (0 or 1), or unknown (x).  Arithmetic operations on tnums will
6  * propagate the unknown bits such that the tnum result represents all the
7  * possible results for possible values of the operands.
8  */
9 #include <linux/kernel.h>
10 #include <linux/tnum.h>
11 #include <linux/swab.h>
12 
/* Build a tnum literal from its known-bits value and unknown-bits mask.
 * Invariant maintained throughout this file: value & mask == 0 (a bit is
 * never both "known 1" and "unknown").
 */
#define TNUM(_v, _m)	(struct tnum){.value = _v, .mask = _m}
/* A completely unknown value: every bit is uncertain */
const struct tnum tnum_unknown = { .value = 0, .mask = -1 };
16 
/* A fully-known constant: every bit is certain and equal to @value. */
struct tnum tnum_const(u64 value)
{
	return (struct tnum){ .value = value, .mask = 0 };
}
21 
/* Smallest tnum containing every value in [min, max].  Bits above the
 * highest bit where min and max differ are known; everything below is
 * unknown.
 */
struct tnum tnum_range(u64 min, u64 max)
{
	u64 diff = min ^ max;
	u8 top = fls64(diff);
	u64 low;

	/* special case, needed because 1ULL << 64 is undefined */
	if (top > 63)
		return tnum_unknown;
	/* e.g. diff = 4 -> top = 3 -> low = (1<<3) - 1 = 7.
	 * diff = 0 -> top = 0 -> low = 0, i.e. the constant min
	 * (since min == max).
	 */
	low = (1ULL << top) - 1;
	return TNUM(min & ~low, low);
}
37 
/* Logical left shift: known and unknown bits shift together. */
struct tnum tnum_lshift(struct tnum a, u8 shift)
{
	a.value <<= shift;
	a.mask <<= shift;
	return a;
}
42 
/* Logical right shift: known and unknown bits shift together. */
struct tnum tnum_rshift(struct tnum a, u8 shift)
{
	a.value >>= shift;
	a.mask >>= shift;
	return a;
}
47 
/* Arithmetic right shift by the minimum possible shift amount.
 *
 * If a.value is negative, arithmetic shifting by the minimum shift
 * yields a larger negative offset than shifting further would.
 * If a.value is nonnegative, arithmetic shifting by the minimum shift
 * yields a larger positive offset than shifting further would.
 */
struct tnum tnum_arshift(struct tnum a, u8 min_shift, u8 insn_bitness)
{
	if (insn_bitness != 32)
		return TNUM((s64)a.value >> min_shift,
			    (s64)a.mask  >> min_shift);

	/* 32-bit: sign-extend within the low word, then zero-extend */
	return TNUM((u32)(((s32)a.value) >> min_shift),
		    (u32)(((s32)a.mask)  >> min_shift));
}
62 
tnum_add(struct tnum a,struct tnum b)63 struct tnum tnum_add(struct tnum a, struct tnum b)
64 {
65 	u64 sm, sv, sigma, chi, mu;
66 
67 	sm = a.mask + b.mask;
68 	sv = a.value + b.value;
69 	sigma = sm + sv;
70 	chi = sigma ^ sv;
71 	mu = chi | a.mask | b.mask;
72 	return TNUM(sv & ~mu, mu);
73 }
74 
tnum_sub(struct tnum a,struct tnum b)75 struct tnum tnum_sub(struct tnum a, struct tnum b)
76 {
77 	u64 dv, alpha, beta, chi, mu;
78 
79 	dv = a.value - b.value;
80 	alpha = dv + a.mask;
81 	beta = dv - b.mask;
82 	chi = alpha ^ beta;
83 	mu = chi | a.mask | b.mask;
84 	return TNUM(dv & ~mu, mu);
85 }
86 
tnum_neg(struct tnum a)87 struct tnum tnum_neg(struct tnum a)
88 {
89 	return tnum_sub(TNUM(0, 0), a);
90 }
91 
tnum_and(struct tnum a,struct tnum b)92 struct tnum tnum_and(struct tnum a, struct tnum b)
93 {
94 	u64 alpha, beta, v;
95 
96 	alpha = a.value | a.mask;
97 	beta = b.value | b.mask;
98 	v = a.value & b.value;
99 	return TNUM(v, alpha & beta & ~v);
100 }
101 
tnum_or(struct tnum a,struct tnum b)102 struct tnum tnum_or(struct tnum a, struct tnum b)
103 {
104 	u64 v, mu;
105 
106 	v = a.value | b.value;
107 	mu = a.mask | b.mask;
108 	return TNUM(v, mu & ~v);
109 }
110 
tnum_xor(struct tnum a,struct tnum b)111 struct tnum tnum_xor(struct tnum a, struct tnum b)
112 {
113 	u64 v, mu;
114 
115 	v = a.value ^ b.value;
116 	mu = a.mask | b.mask;
117 	return TNUM(v & ~mu, mu);
118 }
119 
120 /* Perform long multiplication, iterating through the bits in a using rshift:
121  * - if LSB(a) is a known 0, keep current accumulator
122  * - if LSB(a) is a known 1, add b to current accumulator
123  * - if LSB(a) is unknown, take a union of the above cases.
124  *
125  * For example:
126  *
127  *               acc_0:        acc_1:
128  *
129  *     11 *  ->      11 *  ->      11 *  -> union(0011, 1001) == x0x1
130  *     x1            01            11
131  * ------        ------        ------
132  *     11            11            11
133  *    xx            00            11
134  * ------        ------        ------
135  *   ????          0011          1001
136  */
tnum_mul(struct tnum a,struct tnum b)137 struct tnum tnum_mul(struct tnum a, struct tnum b)
138 {
139 	struct tnum acc = TNUM(0, 0);
140 
141 	while (a.value || a.mask) {
142 		/* LSB of tnum a is a certain 1 */
143 		if (a.value & 1)
144 			acc = tnum_add(acc, b);
145 		/* LSB of tnum a is uncertain */
146 		else if (a.mask & 1) {
147 			/* acc = tnum_union(acc_0, acc_1), where acc_0 and
148 			 * acc_1 are partial accumulators for cases
149 			 * LSB(a) = certain 0 and LSB(a) = certain 1.
150 			 * acc_0 = acc + 0 * b = acc.
151 			 * acc_1 = acc + 1 * b = tnum_add(acc, b).
152 			 */
153 
154 			acc = tnum_union(acc, tnum_add(acc, b));
155 		}
156 		/* Note: no case for LSB is certain 0 */
157 		a = tnum_rshift(a, 1);
158 		b = tnum_lshift(b, 1);
159 	}
160 	return acc;
161 }
162 
tnum_overlap(struct tnum a,struct tnum b)163 bool tnum_overlap(struct tnum a, struct tnum b)
164 {
165 	u64 mu;
166 
167 	mu = ~a.mask & ~b.mask;
168 	return (a.value & mu) == (b.value & mu);
169 }
170 
171 /* Note that if a and b disagree - i.e. one has a 'known 1' where the other has
172  * a 'known 0' - this will return a 'known 1' for that bit.
173  */
tnum_intersect(struct tnum a,struct tnum b)174 struct tnum tnum_intersect(struct tnum a, struct tnum b)
175 {
176 	u64 v, mu;
177 
178 	v = a.value | b.value;
179 	mu = a.mask & b.mask;
180 	return TNUM(v & ~mu, mu);
181 }
182 
183 /* Returns a tnum with the uncertainty from both a and b, and in addition, new
184  * uncertainty at any position that a and b disagree. This represents a
185  * superset of the union of the concrete sets of both a and b. Despite the
186  * overapproximation, it is optimal.
187  */
tnum_union(struct tnum a,struct tnum b)188 struct tnum tnum_union(struct tnum a, struct tnum b)
189 {
190 	u64 v = a.value & b.value;
191 	u64 mu = (a.value ^ b.value) | a.mask | b.mask;
192 
193 	return TNUM(v & ~mu, mu);
194 }
195 
/* Truncate a to its low @size bytes (zero-extending), e.g. size == 4
 * keeps only the low 32 bits.
 */
struct tnum tnum_cast(struct tnum a, u8 size)
{
	u64 keep;

	/* 1ULL << 64 is undefined (cf. the same guard in tnum_range), so
	 * treat size >= 8 as "keep everything".
	 */
	if (size >= 8)
		return a;
	keep = (1ULL << (size * 8)) - 1;
	a.value &= keep;
	a.mask &= keep;
	return a;
}
202 
tnum_is_aligned(struct tnum a,u64 size)203 bool tnum_is_aligned(struct tnum a, u64 size)
204 {
205 	if (!size)
206 		return true;
207 	return !((a.value | a.mask) & (size - 1));
208 }
209 
tnum_in(struct tnum a,struct tnum b)210 bool tnum_in(struct tnum a, struct tnum b)
211 {
212 	if (b.mask & ~a.mask)
213 		return false;
214 	b.value &= ~a.mask;
215 	return a.value == b.value;
216 }
217 
/* Format a as a 64-character binary string ('0', '1', or 'x' per bit,
 * MSB first) into @str, truncated to fit @size bytes including the NUL
 * terminator.  Returns 64, the length the full rendering would need
 * (snprintf-style).
 */
int tnum_sbin(char *str, size_t size, struct tnum a)
{
	size_t n;

	/* A zero-sized buffer has no room even for the terminator; without
	 * this guard, min(size - 1, 64) wraps to 64 and str[64] would be an
	 * out-of-bounds store.
	 */
	if (!size)
		return 64;

	for (n = 64; n; n--) {
		if (n < size) {
			if (a.mask & 1)
				str[n - 1] = 'x';
			else if (a.value & 1)
				str[n - 1] = '1';
			else
				str[n - 1] = '0';
		}
		a.mask >>= 1;
		a.value >>= 1;
	}
	str[min(size - 1, (size_t)64)] = 0;
	return 64;
}
237 
tnum_subreg(struct tnum a)238 struct tnum tnum_subreg(struct tnum a)
239 {
240 	return tnum_cast(a, 4);
241 }
242 
tnum_clear_subreg(struct tnum a)243 struct tnum tnum_clear_subreg(struct tnum a)
244 {
245 	return tnum_lshift(tnum_rshift(a, 32), 32);
246 }
247 
tnum_with_subreg(struct tnum reg,struct tnum subreg)248 struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg)
249 {
250 	return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg));
251 }
252 
/* Set the low 32 bits of @a to the known constant @value. */
struct tnum tnum_const_subreg(struct tnum a, u32 value)
{
	struct tnum low = tnum_const(value);

	return tnum_with_subreg(a, low);
}
257 
tnum_bswap16(struct tnum a)258 struct tnum tnum_bswap16(struct tnum a)
259 {
260 	return TNUM(swab16(a.value & 0xFFFF), swab16(a.mask & 0xFFFF));
261 }
262 
tnum_bswap32(struct tnum a)263 struct tnum tnum_bswap32(struct tnum a)
264 {
265 	return TNUM(swab32(a.value & 0xFFFFFFFF), swab32(a.mask & 0xFFFFFFFF));
266 }
267 
tnum_bswap64(struct tnum a)268 struct tnum tnum_bswap64(struct tnum a)
269 {
270 	return TNUM(swab64(a.value), swab64(a.mask));
271 }
272 
/* Given tnum t, and a number z such that tmin <= z < tmax, where tmin
 * is the smallest member of the t (= t.value) and tmax is the largest
 * member of t (= t.value | t.mask), returns the smallest member of t
 * larger than z.
 *
 * For example,
 * t      = x11100x0
 * z      = 11110001 (241)
 * result = 11110010 (242)
 *
 * Note: if this function is called with z >= tmax, it just returns
 * early with tmax; if this function is called with z < tmin, the
 * algorithm already returns tmin.
 */
u64 tnum_step(struct tnum t, u64 z)
{
	u64 tmax, j, p, q, r, s, v, u, w, res;
	u8 k;

	tmax = t.value | t.mask;

	/* if z >= largest member of t, return largest member of t */
	if (z >= tmax)
		return tmax;

	/* if z < smallest member of t, return smallest member of t */
	if (z < t.value)
		return t.value;

	/* keep t's known bits, and match all unknown bits to z */
	j = t.value | (z & t.mask);

	if (j > z) {
		/* j is already a member of t greater than z; shrink it to
		 * the smallest such member by zeroing the bits below the
		 * highest position where a known bit forced a 0-to-1 flip.
		 */
		p = ~z & t.value & ~t.mask;
		k = fls64(p); /* k is the most-significant 0-to-1 flip */
		q = U64_MAX << k;
		r = q & z; /* positions > k matched to z */
		s = ~q & t.value; /* positions <= k matched to t.value */
		v = r | s;
		res = v;
	} else {
		/* j <= z because a known bit forced a 1-to-0 flip; build the
		 * next-larger member by carrying 1 into the unknown bits
		 * above that flip position.
		 */
		p = z & ~t.value & ~t.mask;
		k = fls64(p); /* k is the most-significant 1-to-0 flip */
		q = U64_MAX << k;
		r = q & t.mask & z; /* unknown positions > k, matched to z */
		s = q & ~t.mask; /* known positions > k, set to 1 */
		v = r | s;
		/* add 1 to unknown positions > k to make value greater than z */
		u = v + (1ULL << k);
		/* extract bits in unknown positions > k from u, rest from t.value */
		w = (u & t.mask) | t.value;
		res = w;
	}
	return res;
}
328