xref: /linux/kernel/bpf/cnum_defs.h (revision cd5b460ed1eca9e48f3eb07db1ee0a522c0eaa23)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 /* Copyright (c) 2026 Meta Platforms, Inc. and affiliates. */
3 
4 #ifndef T
5 #error "Define T (bit width: 32, 64) before including cnum_defs.h"
6 #endif
7 
8 #include <linux/cnum.h>
9 #include <linux/limits.h>
10 #include <linux/minmax.h>
11 #include <linux/compiler_types.h>
12 
/*
 * Per-width helper macros. With T defined as 32 (or 64) they expand to:
 *   cnum_t -> cnum32,  ut -> u32,  st -> s32,
 *   UT_MAX -> U32_MAX, ST_MAX -> S32_MAX, ST_MIN -> S32_MIN,
 *   EMPTY  -> CNUM32_EMPTY,
 *   FN(x)  -> cnum32_x.
 */

/* Type names. */
#define cnum_t   __PASTE(cnum, T)
#define ut       __PASTE(u, T)
#define st       __PASTE(s, T)

/* Limits of the unsigned / signed type of width T. */
#define UT_MAX   __PASTE(__PASTE(U, T), _MAX)
#define ST_MAX   __PASTE(__PASTE(S, T), _MAX)
#define ST_MIN   __PASTE(__PASTE(S, T), _MIN)

/* The empty cnum constant for width T. */
#define EMPTY    __PASTE(__PASTE(CNUM, T), _EMPTY)

/*
 * Width-specific function name: FN(add) -> cnum32_add. cnum_t is fully
 * expanded (to cnumT) before pasting because __PASTE forwards its arguments
 * to ___PASTE.
 */
#define FN(name) __PASTE(cnum_t, __PASTE(_, name))
21 
22 struct cnum_t FN(from_urange)(ut min, ut max)
23 {
24 	return (struct cnum_t){ .base = min, .size = (ut)max - min };
25 }
26 
27 struct cnum_t FN(from_srange)(st min, st max)
28 {
29 	ut size = (ut)max - (ut)min;
30 	ut base = size == UT_MAX ? 0 : (ut)min;
31 
32 	return (struct cnum_t){ .base = base, .size = size };
33 }
34 
35 /* True if this cnum represents two unsigned ranges. */
36 static inline bool FN(urange_overflow)(struct cnum_t cnum)
37 {
38 	/* Same as cnum.base + cnum.size > UT_MAX but avoids overflow */
39 	return cnum.size > UT_MAX - (ut)cnum.base;
40 }
41 
/*
 * cnum{T}_umin / cnum{T}_umax return the lower / upper bound of the unsigned
 * range represented by this cnum. If the cnum crosses the UT_MAX/0 boundary,
 * the bounds of the unbounded range [0..UT_MAX] are returned instead.
 */
47 ut FN(umin)(struct cnum_t cnum)
48 {
49 	return FN(urange_overflow)(cnum) ? 0 : cnum.base;
50 }
51 
52 ut FN(umax)(struct cnum_t cnum)
53 {
54 	return FN(urange_overflow)(cnum) ? UT_MAX : cnum.base + cnum.size;
55 }
56 
57 /* True if this cnum represents two signed ranges. */
58 static inline bool FN(srange_overflow)(struct cnum_t cnum)
59 {
60 	return FN(contains)(cnum, (ut)ST_MAX) && FN(contains)(cnum, (ut)ST_MIN);
61 }
62 
/*
 * cnum{T}_smin / cnum{T}_smax return the lower / upper bound of the signed
 * range represented by this cnum. If the cnum crosses the ST_MAX/ST_MIN
 * boundary, the bounds of the unbounded range [ST_MIN..ST_MAX] are returned
 * instead.
 */
68 st FN(smin)(struct cnum_t cnum)
69 {
70 	return FN(srange_overflow)(cnum)
71 	       ? ST_MIN
72 	       : min((st)cnum.base, (st)(cnum.base + cnum.size));
73 }
74 
75 st FN(smax)(struct cnum_t cnum)
76 {
77 	return FN(srange_overflow)(cnum)
78 	       ? ST_MAX
79 	       : max((st)cnum.base, (st)(cnum.base + cnum.size));
80 }
81 
82 /*
83  * Returns a possibly empty intersection of cnums 'a' and 'b'.
84  * If 'a' and 'b' intersect in two sub-arcs, the function over-approximates
85  * and returns either 'a' or 'b', whichever is smaller.
86  */
87 struct cnum_t FN(intersect)(struct cnum_t a, struct cnum_t b)
88 {
89 	struct cnum_t b1;
90 	ut dbase;
91 
92 	if (FN(is_empty)(a) || FN(is_empty)(b))
93 		return EMPTY;
94 
95 	if (a.base > b.base)
96 		swap(a, b);
97 
98 	/*
99 	 * Rotate frame of reference such that a.base is 0.
100 	 * 'b1' is 'b' in this frame of reference.
101 	 */
102 	dbase = b.base - a.base;
103 	b1 = (struct cnum_t){ dbase, b.size };
104 	if (FN(urange_overflow)(b1)) {
105 		if (b1.base <= a.size) {
106 			/*
107 			 * Rotated frame (a.base at origin):
108 			 *
109 			 * 0                                       UT_MAX
110 			 * |--------------------------------------------|
111 			 * [=== a ==========================]           |
112 			 * [= b1 tail =]  [========= b1 main ==========>]
113 			 *                 ^-- b1.base <= a.size
114 			 *
115 			 * 'a' and 'b' intersect in two disjoint arcs,
116 			 * can't represent as single cnum, over-approximate
117 			 * the result.
118 			 */
119 			return a.size <= b.size ? a : b;
120 		} else {
121 			/*
122 			 * Rotated frame (a.base at origin):
123 			 *
124 			 * 0                                       UT_MAX
125 			 * |--------------------------------------------|
126 			 * [=== a =============]  |                     |
127 			 * [= b1 tail =]          [======= b1 main ====>]
128 			 *                         ^-- b1.base > a.size
129 			 *
130 			 * Only 'b' tail intersects 'a'.
131 			 */
132 			return (struct cnum_t) {
133 				.base = a.base,
134 				.size = min(a.size, (ut)(b1.base + b1.size)),
135 			};
136 		}
137 	} else if (a.size >= b1.base) {
138 		/*
139 		 * Rotated frame (a.base at origin):
140 		 *
141 		 * 0                                             UT_MAX
142 		 * |--------------------------------------------------|
143 		 * [=== a ==================================]         |
144 		 *                   [== b1 =====================]
145 		 *
146 		 * 0                                             UT_MAX
147 		 * |--------------------------------------------------|
148 		 * [=== a ==================================]         |
149 		 *                   [== b1 ====]
150 		 *                   ^-- b1.base <= a.size
151 		 *                   |<-- a.size - dbase -->|
152 		 *
153 		 * 'a' and 'b' intersect as one cnum.
154 		 */
155 		return (struct cnum_t) {
156 			.base = b.base,
157 			.size = min((ut)(a.size - dbase), b.size),
158 		};
159 	} else {
160 		return EMPTY;
161 	}
162 }
163 
164 void FN(intersect_with)(struct cnum_t *dst, struct cnum_t src)
165 {
166 	*dst = FN(intersect)(*dst, src);
167 }
168 
169 void FN(intersect_with_urange)(struct cnum_t *dst, ut min, ut max)
170 {
171 	FN(intersect_with)(dst, FN(from_urange)(min, max));
172 }
173 
174 void FN(intersect_with_srange)(struct cnum_t *dst, st min, st max)
175 {
176 	FN(intersect_with)(dst, FN(from_srange)(min, max));
177 }
178 
179 static inline struct cnum_t FN(normalize)(struct cnum_t cnum)
180 {
181 	if (cnum.size == UT_MAX && cnum.base != 0 && cnum.base != (ut)ST_MAX)
182 		cnum.base = 0;
183 	return cnum;
184 }
185 
186 struct cnum_t FN(add)(struct cnum_t a, struct cnum_t b)
187 {
188 	if (FN(is_empty)(a) || FN(is_empty)(b))
189 		return EMPTY;
190 	if (a.size > UT_MAX - b.size)
191 		return (struct cnum_t){ 0, (ut)UT_MAX };
192 	else
193 		return FN(normalize)((struct cnum_t){ a.base + b.base, a.size + b.size });
194 }
195 
196 struct cnum_t FN(negate)(struct cnum_t a)
197 {
198 	if (FN(is_empty)(a))
199 		return EMPTY;
200 	return FN(normalize)((struct cnum_t){ -((ut)a.base + a.size), a.size });
201 }
202 
203 bool FN(is_empty)(struct cnum_t cnum)
204 {
205 	return cnum.base == EMPTY.base && cnum.size == EMPTY.size;
206 }
207 
208 bool FN(contains)(struct cnum_t cnum, ut v)
209 {
210 	if (FN(is_empty)(cnum))
211 		return false;
212 	if (FN(urange_overflow)(cnum))
213 		return v >= cnum.base || v <= (ut)cnum.base + cnum.size;
214 	else
215 		return v >= cnum.base && v <= (ut)cnum.base + cnum.size;
216 }
217 
218 bool FN(is_const)(struct cnum_t cnum)
219 {
220 	return cnum.size == 0;
221 }
222 
223 bool FN(is_subset)(struct cnum_t bigger, struct cnum_t smaller)
224 {
225 	if (FN(is_empty(smaller)))
226 		return true;
227 	if (FN(is_empty(bigger)))
228 		return false;
229 	/* rotate both arcs such that 'bigger' starts at origin, hence does not overflow */
230 	smaller.base -= bigger.base;
231 	bigger.base = 0;
232 	if (FN(urange_overflow)(smaller) && bigger.size < UT_MAX)
233 		return false;
234 	return smaller.base + smaller.size <= bigger.size;
235 }
236 
/*
 * Drop the per-width helper macros defined at the top of this file so that
 * it can be included again (e.g. with a different T) without redefinition
 * clashes. T itself is not undefined here.
 */
#undef FN
#undef EMPTY
#undef ST_MIN
#undef ST_MAX
#undef UT_MAX
#undef st
#undef ut
#undef cnum_t
244 #undef FN
245