xref: /linux/kernel/bpf/cnum_defs.h (revision 79b8ebcbe483fee401e1b91dd32470348d9aa5b8)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 /* Copyright (c) 2026 Meta Platforms, Inc. and affiliates. */
3 
4 #ifndef T
5 #error "Define T (bit width: 32, 64) before including cnum_defs.h"
6 #endif
7 
8 #include <linux/cnum.h>
9 #include <linux/kernel.h>
10 #include <linux/limits.h>
11 #include <linux/minmax.h>
12 #include <linux/compiler_types.h>
13 
/*
 * Width-specific names for this instantiation, built by token pasting on T.
 * For example, with T == 32: cnum_t is cnum32, ut/st are u32/s32, UT_MAX is
 * U32_MAX, ST_MAX/ST_MIN are S32_MAX/S32_MIN, EMPTY is CNUM32_EMPTY and
 * FN(add) is cnum32_add. These are #undef'd at the end of this file,
 * presumably so it can be included again with a different T.
 */
#define cnum_t   __PASTE(cnum, T)
#define ut       __PASTE(u, T)
#define st       __PASTE(s, T)
#define UT_MAX   __PASTE(__PASTE(U, T), _MAX)
#define ST_MAX   __PASTE(__PASTE(S, T), _MAX)
#define ST_MIN   __PASTE(__PASTE(S, T), _MIN)
#define EMPTY    __PASTE(__PASTE(CNUM, T), _EMPTY)
/* FN(name) expands to the function name cnum{T}_name; call it as FN(name)(args). */
#define FN(name) __PASTE(__PASTE(cnum, T), __PASTE(_, name))
22 
23 struct cnum_t FN(from_urange)(ut min, ut max)
24 {
25 	return (struct cnum_t){ .base = min, .size = (ut)max - min };
26 }
27 
28 struct cnum_t FN(from_srange)(st min, st max)
29 {
30 	ut size = (ut)max - (ut)min;
31 	ut base = size == UT_MAX ? 0 : (ut)min;
32 
33 	return (struct cnum_t){ .base = base, .size = size };
34 }
35 
36 /* True if this cnum represents two unsigned ranges. */
37 static inline bool FN(urange_overflow)(struct cnum_t cnum)
38 {
39 	/* Same as cnum.base + cnum.size > UT_MAX but avoids overflow */
40 	return cnum.size > UT_MAX - (ut)cnum.base;
41 }
42 
43 /*
44  * cnum{T}_umin / cnum{T}_umax query an unsigned range represented by this cnum.
45  * If cnum represents a range crossing the UT_MAX/0 boundary, the unbound range
46  * [0..UT_MAX] is returned.
47  */
48 ut FN(umin)(struct cnum_t cnum)
49 {
50 	return FN(urange_overflow)(cnum) ? 0 : cnum.base;
51 }
52 EXPORT_SYMBOL_GPL(FN(umin));
53 
54 ut FN(umax)(struct cnum_t cnum)
55 {
56 	return FN(urange_overflow)(cnum) ? UT_MAX : cnum.base + cnum.size;
57 }
58 EXPORT_SYMBOL_GPL(FN(umax));
59 
60 /* True if this cnum represents two signed ranges. */
61 static inline bool FN(srange_overflow)(struct cnum_t cnum)
62 {
63 	return FN(contains)(cnum, (ut)ST_MAX) && FN(contains)(cnum, (ut)ST_MIN);
64 }
65 
66 /*
67  * cnum{T}_smin / cnum{T}_smax query a signed range represented by this cnum.
68  * If cnum represents a range crossing the ST_MAX/ST_MIN boundary, the unbound range
69  * [ST_MIN..ST_MAX] is returned.
70  */
71 st FN(smin)(struct cnum_t cnum)
72 {
73 	return FN(srange_overflow)(cnum)
74 	       ? ST_MIN
75 	       : min((st)cnum.base, (st)(cnum.base + cnum.size));
76 }
77 
78 st FN(smax)(struct cnum_t cnum)
79 {
80 	return FN(srange_overflow)(cnum)
81 	       ? ST_MAX
82 	       : max((st)cnum.base, (st)(cnum.base + cnum.size));
83 }
84 
/*
 * Returns a possibly empty intersection of cnums 'a' and 'b'.
 * If 'a' and 'b' intersect in two sub-arcs, the function over-approximates
 * and returns either 'a' or 'b', whichever is smaller.
 *
 * Returns EMPTY if either input is EMPTY or if the arcs do not overlap.
 * The inputs are passed by value; they are reordered and rotated locally.
 */
struct cnum_t FN(intersect)(struct cnum_t a, struct cnum_t b)
{
	struct cnum_t b1;
	ut dbase;

	if (FN(is_empty)(a) || FN(is_empty)(b))
		return EMPTY;

	/* Order so that a.base <= b.base; then b.base - a.base cannot wrap. */
	if (a.base > b.base)
		swap(a, b);

	/*
	 * Rotate frame of reference such that a.base is 0.
	 * 'b1' is 'b' in this frame of reference.
	 * In this frame 'a' is the plain interval [0..a.size] and never wraps.
	 */
	dbase = b.base - a.base;
	b1 = (struct cnum_t){ dbase, b.size };
	if (FN(urange_overflow)(b1)) {
		if (b1.base <= a.size) {
			/*
			 * Rotated frame (a.base at origin):
			 *
			 * 0                                       UT_MAX
			 * |--------------------------------------------|
			 * [=== a ==========================]           |
			 * [= b1 tail =]  [========= b1 main ==========>]
			 *                 ^-- b1.base <= a.size
			 *
			 * 'a' and 'b' intersect in two disjoint arcs,
			 * can't represent as single cnum, over-approximate
			 * the result.
			 *
			 * Both 'a' and 'b' contain the true intersection, so
			 * either is sound; the smaller one is the tighter
			 * choice. If 'b' is full-size, a.size <= b.size holds
			 * and 'a' (the exact answer) is returned.
			 */
			return a.size <= b.size ? a : b;
		} else {
			/*
			 * Rotated frame (a.base at origin):
			 *
			 * 0                                       UT_MAX
			 * |--------------------------------------------|
			 * [=== a =============]  |                     |
			 * [= b1 tail =]          [======= b1 main ====>]
			 *                         ^-- b1.base > a.size
			 *
			 * Only 'b' tail intersects 'a'.
			 * The tail is [0..b1.base + b1.size] (wrapped sum), so
			 * the result starts at a.base and ends at whichever of
			 * 'a' or the tail ends first.
			 */
			return (struct cnum_t) {
				.base = a.base,
				.size = min(a.size, (ut)(b1.base + b1.size)),
			};
		}
	} else if (a.size >= b1.base) {
		/*
		 * Rotated frame (a.base at origin):
		 *
		 * 0                                             UT_MAX
		 * |--------------------------------------------------|
		 * [=== a ==================================]         |
		 *                   [== b1 =====================]
		 *
		 * 0                                             UT_MAX
		 * |--------------------------------------------------|
		 * [=== a ==================================]         |
		 *                   [== b1 ====]
		 *                   ^-- b1.base <= a.size
		 *                   |<-- a.size - dbase -->|
		 *
		 * 'a' and 'b' intersect as one cnum.
		 */
		return (struct cnum_t) {
			.base = b.base,
			.size = min((ut)(a.size - dbase), b.size),
		};
	} else {
		/* 'b1' starts after 'a' ends and does not wrap back: disjoint. */
		return EMPTY;
	}
}
166 
167 void FN(intersect_with)(struct cnum_t *dst, struct cnum_t src)
168 {
169 	*dst = FN(intersect)(*dst, src);
170 }
171 
172 void FN(intersect_with_urange)(struct cnum_t *dst, ut min, ut max)
173 {
174 	FN(intersect_with)(dst, FN(from_urange)(min, max));
175 }
176 
177 void FN(intersect_with_srange)(struct cnum_t *dst, st min, st max)
178 {
179 	FN(intersect_with)(dst, FN(from_srange)(min, max));
180 }
181 
182 static inline struct cnum_t FN(normalize)(struct cnum_t cnum)
183 {
184 	if (cnum.size == UT_MAX && cnum.base != 0 && cnum.base != (ut)ST_MAX)
185 		cnum.base = 0;
186 	return cnum;
187 }
188 
189 struct cnum_t FN(add)(struct cnum_t a, struct cnum_t b)
190 {
191 	if (FN(is_empty)(a) || FN(is_empty)(b))
192 		return EMPTY;
193 	if (a.size > UT_MAX - b.size)
194 		return (struct cnum_t){ 0, (ut)UT_MAX };
195 	else
196 		return FN(normalize)((struct cnum_t){ a.base + b.base, a.size + b.size });
197 }
198 
199 struct cnum_t FN(negate)(struct cnum_t a)
200 {
201 	if (FN(is_empty)(a))
202 		return EMPTY;
203 	return FN(normalize)((struct cnum_t){ -((ut)a.base + a.size), a.size });
204 }
205 
206 bool FN(is_empty)(struct cnum_t cnum)
207 {
208 	return cnum.base == EMPTY.base && cnum.size == EMPTY.size;
209 }
210 
211 bool FN(contains)(struct cnum_t cnum, ut v)
212 {
213 	if (FN(is_empty)(cnum))
214 		return false;
215 	if (FN(urange_overflow)(cnum))
216 		return v >= cnum.base || v <= (ut)cnum.base + cnum.size;
217 	else
218 		return v >= cnum.base && v <= (ut)cnum.base + cnum.size;
219 }
220 
221 bool FN(is_const)(struct cnum_t cnum)
222 {
223 	return cnum.size == 0;
224 }
225 
226 bool FN(is_subset)(struct cnum_t bigger, struct cnum_t smaller)
227 {
228 	if (FN(is_empty(smaller)))
229 		return true;
230 	if (FN(is_empty(bigger)))
231 		return false;
232 	/* rotate both arcs such that 'bigger' starts at origin, hence does not overflow */
233 	smaller.base -= bigger.base;
234 	bigger.base = 0;
235 	if (FN(urange_overflow)(smaller) && bigger.size < UT_MAX)
236 		return false;
237 	return smaller.base + smaller.size <= bigger.size;
238 }
239 
/*
 * Drop the width-specific helper macros defined at the top of this file.
 * T itself is not undefined here and is left to the including file.
 */
#undef EMPTY
#undef cnum_t
#undef ut
#undef st
#undef UT_MAX
#undef ST_MAX
#undef ST_MIN
#undef FN
248