1 /* SPDX-License-Identifier: GPL-2.0-only */ 2 /* Copyright (c) 2026 Meta Platforms, Inc. and affiliates. */ 3 4 #ifndef T 5 #error "Define T (bit width: 32, 64) before including cnum_defs.h" 6 #endif 7 8 #include <linux/cnum.h> 9 #include <linux/kernel.h> 10 #include <linux/limits.h> 11 #include <linux/minmax.h> 12 #include <linux/compiler_types.h> 13 14 #define cnum_t __PASTE(cnum, T) 15 #define ut __PASTE(u, T) 16 #define st __PASTE(s, T) 17 #define UT_MAX __PASTE(__PASTE(U, T), _MAX) 18 #define ST_MAX __PASTE(__PASTE(S, T), _MAX) 19 #define ST_MIN __PASTE(__PASTE(S, T), _MIN) 20 #define EMPTY __PASTE(__PASTE(CNUM, T), _EMPTY) 21 #define FN(name) __PASTE(__PASTE(cnum, T), __PASTE(_, name)) 22 23 struct cnum_t FN(from_urange)(ut min, ut max) 24 { 25 return (struct cnum_t){ .base = min, .size = (ut)max - min }; 26 } 27 28 struct cnum_t FN(from_srange)(st min, st max) 29 { 30 ut size = (ut)max - (ut)min; 31 ut base = size == UT_MAX ? 0 : (ut)min; 32 33 return (struct cnum_t){ .base = base, .size = size }; 34 } 35 36 /* True if this cnum represents two unsigned ranges. */ 37 static inline bool FN(urange_overflow)(struct cnum_t cnum) 38 { 39 /* Same as cnum.base + cnum.size > UT_MAX but avoids overflow */ 40 return cnum.size > UT_MAX - (ut)cnum.base; 41 } 42 43 /* 44 * cnum{T}_umin / cnum{T}_umax query an unsigned range represented by this cnum. 45 * If cnum represents a range crossing the UT_MAX/0 boundary, the unbound range 46 * [0..UT_MAX] is returned. 47 */ 48 ut FN(umin)(struct cnum_t cnum) 49 { 50 return FN(urange_overflow)(cnum) ? 0 : cnum.base; 51 } 52 EXPORT_SYMBOL_GPL(FN(umin)); 53 54 ut FN(umax)(struct cnum_t cnum) 55 { 56 return FN(urange_overflow)(cnum) ? UT_MAX : cnum.base + cnum.size; 57 } 58 EXPORT_SYMBOL_GPL(FN(umax)); 59 60 /* True if this cnum represents two signed ranges. */ 61 static inline bool FN(srange_overflow)(struct cnum_t cnum) 62 { 63 return FN(contains)(cnum, (ut)ST_MAX) && FN(contains)(cnum, (ut)ST_MIN); 64 } 65 66 /* 67 * cnum{T}_smin / cnum{T}_smax query a signed range represented by this cnum. 68 * If cnum represents a range crossing the ST_MAX/ST_MIN boundary, the unbound range 69 * [ST_MIN..ST_MAX] is returned. 70 */ 71 st FN(smin)(struct cnum_t cnum) 72 { 73 return FN(srange_overflow)(cnum) 74 ? ST_MIN 75 : min((st)cnum.base, (st)(cnum.base + cnum.size)); 76 } 77 78 st FN(smax)(struct cnum_t cnum) 79 { 80 return FN(srange_overflow)(cnum) 81 ? ST_MAX 82 : max((st)cnum.base, (st)(cnum.base + cnum.size)); 83 } 84 85 /* 86 * Returns a possibly empty intersection of cnums 'a' and 'b'. 87 * If 'a' and 'b' intersect in two sub-arcs, the function over-approximates 88 * and returns either 'a' or 'b', whichever is smaller. 89 */ 90 struct cnum_t FN(intersect)(struct cnum_t a, struct cnum_t b) 91 { 92 struct cnum_t b1; 93 ut dbase; 94 95 if (FN(is_empty)(a) || FN(is_empty)(b)) 96 return EMPTY; 97 98 if (a.base > b.base) 99 swap(a, b); 100 101 /* 102 * Rotate frame of reference such that a.base is 0. 103 * 'b1' is 'b' in this frame of reference. 104 */ 105 dbase = b.base - a.base; 106 b1 = (struct cnum_t){ dbase, b.size }; 107 if (FN(urange_overflow)(b1)) { 108 if (b1.base <= a.size) { 109 /* 110 * Rotated frame (a.base at origin): 111 * 112 * 0 UT_MAX 113 * |--------------------------------------------| 114 * [=== a ==========================] | 115 * [= b1 tail =] [========= b1 main ==========>] 116 * ^-- b1.base <= a.size 117 * 118 * 'a' and 'b' intersect in two disjoint arcs, 119 * can't represent as single cnum, over-approximate 120 * the result. 121 */ 122 return a.size <= b.size ? a : b; 123 } else { 124 /* 125 * Rotated frame (a.base at origin): 126 * 127 * 0 UT_MAX 128 * |--------------------------------------------| 129 * [=== a =============] | | 130 * [= b1 tail =] [======= b1 main ====>] 131 * ^-- b1.base > a.size 132 * 133 * Only 'b' tail intersects 'a'. 134 */ 135 return (struct cnum_t) { 136 .base = a.base, 137 .size = min(a.size, (ut)(b1.base + b1.size)), 138 }; 139 } 140 } else if (a.size >= b1.base) { 141 /* 142 * Rotated frame (a.base at origin): 143 * 144 * 0 UT_MAX 145 * |--------------------------------------------------| 146 * [=== a ==================================] | 147 * [== b1 =====================] 148 * 149 * 0 UT_MAX 150 * |--------------------------------------------------| 151 * [=== a ==================================] | 152 * [== b1 ====] 153 * ^-- b1.base <= a.size 154 * |<-- a.size - dbase -->| 155 * 156 * 'a' and 'b' intersect as one cnum. 157 */ 158 return (struct cnum_t) { 159 .base = b.base, 160 .size = min((ut)(a.size - dbase), b.size), 161 }; 162 } else { 163 return EMPTY; 164 } 165 } 166 167 void FN(intersect_with)(struct cnum_t *dst, struct cnum_t src) 168 { 169 *dst = FN(intersect)(*dst, src); 170 } 171 172 void FN(intersect_with_urange)(struct cnum_t *dst, ut min, ut max) 173 { 174 FN(intersect_with)(dst, FN(from_urange)(min, max)); 175 } 176 177 void FN(intersect_with_srange)(struct cnum_t *dst, st min, st max) 178 { 179 FN(intersect_with)(dst, FN(from_srange)(min, max)); 180 } 181 182 static inline struct cnum_t FN(normalize)(struct cnum_t cnum) 183 { 184 if (cnum.size == UT_MAX && cnum.base != 0 && cnum.base != (ut)ST_MAX) 185 cnum.base = 0; 186 return cnum; 187 } 188 189 struct cnum_t FN(add)(struct cnum_t a, struct cnum_t b) 190 { 191 if (FN(is_empty)(a) || FN(is_empty)(b)) 192 return EMPTY; 193 if (a.size > UT_MAX - b.size) 194 return (struct cnum_t){ 0, (ut)UT_MAX }; 195 else 196 return FN(normalize)((struct cnum_t){ a.base + b.base, a.size + b.size }); 197 } 198 199 struct cnum_t FN(negate)(struct cnum_t a) 200 { 201 if (FN(is_empty)(a)) 202 return EMPTY; 203 return FN(normalize)((struct cnum_t){ -((ut)a.base + a.size), a.size }); 204 } 205 206 bool FN(is_empty)(struct cnum_t cnum) 207 { 208 return cnum.base == EMPTY.base && cnum.size == EMPTY.size; 209 } 210 211 bool FN(contains)(struct cnum_t cnum, ut v) 212 { 213 if (FN(is_empty)(cnum)) 214 return false; 215 if (FN(urange_overflow)(cnum)) 216 return v >= cnum.base || v <= (ut)cnum.base + cnum.size; 217 else 218 return v >= cnum.base && v <= (ut)cnum.base + cnum.size; 219 } 220 221 bool FN(is_const)(struct cnum_t cnum) 222 { 223 return cnum.size == 0; 224 } 225 226 bool FN(is_subset)(struct cnum_t bigger, struct cnum_t smaller) 227 { 228 if (FN(is_empty(smaller))) 229 return true; 230 if (FN(is_empty(bigger))) 231 return false; 232 /* rotate both arcs such that 'bigger' starts at origin, hence does not overflow */ 233 smaller.base -= bigger.base; 234 bigger.base = 0; 235 if (FN(urange_overflow)(smaller) && bigger.size < UT_MAX) 236 return false; 237 return smaller.base + smaller.size <= bigger.size; 238 } 239 240 #undef EMPTY 241 #undef cnum_t 242 #undef ut 243 #undef st 244 #undef UT_MAX 245 #undef ST_MAX 246 #undef ST_MIN 247 #undef FN 248