xref: /linux/kernel/bpf/cnum.c (revision 256f0071f9b61ae5028f749449fd3fdad015889d)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (c) 2026 Meta Platforms, Inc. and affiliates. */
3 
4 #include <linux/bits.h>
5 
6 #define T 32
7 #include "cnum_defs.h"
8 #undef T
9 
10 #define T 64
11 #include "cnum_defs.h"
12 #undef T
13 
14 struct cnum32 cnum32_from_cnum64(struct cnum64 cnum)
15 {
16 	if (cnum64_is_empty(cnum))
17 		return CNUM32_EMPTY;
18 
19 	if (cnum.size >= U32_MAX)
20 		return (struct cnum32){ .base = 0, .size = U32_MAX };
21 	else
22 		return (struct cnum32){ .base = (u32)cnum.base, .size = cnum.size };
23 }
24 
25 /*
26  * Suppose 'a' and 'b' are laid out as follows:
27  *
28  *                                                          64-bit number axis --->
29  *
30  * N*2^32                   (N+1)*2^32                (N+2)*2^32                (N+3)*2^32
31  * ||------|---|=====|-------||----------|=====|-------||----------|=====|----|--||
32  *         |   |< b >|                   |< b >|                   |< b >|    |
33  *         |   |                                                         |    |
34  *         |<--+--------------------------- a ---------------------------+--->|
35  *             |                                                         |
36  *             |<-------------------------- t -------------------------->|
37  *
 * In such a case it is possible to infer a tighter representation t
 * such that every v ∈ a with (u32)v ∈ b also satisfies v ∈ t.
40  */
41 struct cnum64 cnum64_cnum32_intersect(struct cnum64 a, struct cnum32 b)
42 {
43 	/*
44 	 * To simplify reasoning, rotate the circles so that [virtual] a1 starts
45 	 * at u32 boundary, b1 represents b in this new frame of reference.
46 	 */
47 	struct cnum32 b1 = { b.base - (u32)a.base, b.size };
48 	struct cnum64 t = a;
49 	u64 d, b1_max;
50 
51 	if (cnum64_is_empty(a) || cnum32_is_empty(b))
52 		return CNUM64_EMPTY;
53 
54 	if (cnum32_urange_overflow(b1)) {
55 		b1_max = (u32)b1.base + (u32)b1.size; /* overflow here is fine and necessary */
56 		if ((u32)a.size > b1_max && (u32)a.size < b1.base) {
57 			/*
58 			 * N*2^32                   (N+1)*2^32
59 			 * ||=====|------------|=====||=====|---------|---|=====||
60 			 *  |b1 ->|            |<- b1||b1 ->|         |   |<- b1|
61 			 *  |<----------------- a1 ------------------>|
62 			 *  |<-------------- t ------------>|<-- d -->| (after adjustment)
63 			 *                                  ^
64 			 *                                b1_max
65 			 */
66 			d = (u32)a.size - b1_max;
67 			t.size -= d;
68 		} else {
69 			/*
70 			 * No adjustments possible in the following cases:
71 			 *
72 			 * ||=====|------------|=====||===|=|-------------|=|===||
73 			 *  |b1 ->|            |<- b1||b1 +>|             |<+ b1|
74 			 *  |<----------------- a1 ------>|                 |
75 			 *  |<----------------- (or) a1 ------------------->|
76 			 */
77 		}
78 	} else {
79 		if (t.size < b1.base)
80 			/*
81 			 * N*2^32                   (N+1)*2^32
82 			 * ||----------|--|=======|--||------>
83 			 *  |<-- a1 -->|  |<- b ->|
84 			 */
85 			return CNUM64_EMPTY;
86 		/*
87 		 * N*2^32                   (N+1)*2^32
88 		 * ||-------------|========|-||-----| -------|========|-||
89 		 *  |             |<- b1 ->|        |        |<- b1 ->|
90 		 *  |<------------+ a1 ------------>|
91 		 *                |<------ t ------>| (after adjustment)
92 		 */
93 		t.base += b1.base;
94 		t.size -= b1.base;
95 		b1_max = b1.base + b1.size;
96 		d = 0;
97 		if ((u32)a.size < b1.base)
98 			/*
99 			 * N*2^32                   (N+1)*2^32
100 			 * ||-------------|========|-||------|-------|========|-||
101 			 *  |             |<- b1 ->|         |       |<- b1 ->|
102 			 *  |<------------+-- a1 --+-------->|
103 			 *                |<- t  ->|<-- d -->| (after adjustment)
104 			 */
105 			d = (u32)a.size + (BIT_ULL(32) - b1_max);
106 		else if ((u32)a.size >= b1_max)
107 			/*
108 			 * N*2^32                   (N+1)*2^32
109 			 * ||--|========|------------||--|========|-------|-----||
110 			 *  |  |<- b1 ->|                |<- b1 ->|       |
111 			 *  |<-+------------------ a1 ------------+------>|
112 			 *     |<-------------- t --------------->|<- d ->| (after adjustment)
113 			 */
114 			d = (u32)a.size - b1_max;
115 		if (t.size < d)
116 			return CNUM64_EMPTY;
117 		t.size -= d;
118 	}
119 	return t;
120 }
121