xref: /linux/drivers/iommu/generic_pt/pt_log2.h (revision ce5cfb0fa20dc6454da039612e34325b7b4a8243)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 /*
3  * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
4  *
5  * Helper macros for working with log2 values
6  *
7  */
8 #ifndef __GENERIC_PT_LOG2_H
9 #define __GENERIC_PT_LOG2_H
10 #include <linux/bitops.h>
11 #include <linux/limits.h>
12 
/*
 * Compute a == 2^a_lg2 as a value of @type.
 * a_lg2 must be less than the bit width of @type: the shift is performed in
 * @type (or int, if @type is narrower), and shifting by the full width of the
 * promoted type or more is undefined behavior in C.
 */
#define log2_to_int_t(type, a_lg2) ((type)(((type)1) << (a_lg2)))
static_assert(log2_to_int_t(unsigned int, 0) == 1);

/*
 * Compute a - 1 (aka all low bits set), i.e. a mask of the low a_lg2 bits.
 * The same a_lg2 < width restriction as log2_to_int_t() applies.
 */
#define log2_to_max_int_t(type, a_lg2) ((type)(log2_to_int_t(type, a_lg2) - 1))
static_assert(log2_to_max_int_t(unsigned int, 3) == 7);
19 
/*
 * Compute a / b, where b == 2^b_lg2, in the arithmetic of @type.
 * @a is converted to @type as a whole before the shift, so an expression
 * argument such as "x + y" or "x ^ y" is not cast piecemeal.
 */
#define log2_div_t(type, a, b_lg2) ((type)(((type)(a)) >> (b_lg2)))
static_assert(log2_div_t(unsigned int, 4, 2) == 1);

/*
 * Compute:
 *   a / c == b / c
 * where c == 2^c_lg2, aka the bits at or above c_lg2 are equal once a and b
 * are converted to @type.
 */
#define log2_div_eq_t(type, a, b, c_lg2) \
	(log2_div_t(type, ((a) ^ (b)), c_lg2) == 0)
static_assert(log2_div_eq_t(unsigned int, 1, 1, 2));
32 
33 /* Compute a % b */
34 #define log2_mod_t(type, a, b_lg2) \
35 	((type)(((type)a) & log2_to_max_int_t(type, b_lg2)))
36 static_assert(log2_mod_t(unsigned int, 1, 2) == 1);
37 
38 /*
39  * Compute:
40  *   a % b == b - 1
41  * aka the low bits are all 1s
42  */
43 #define log2_mod_eq_max_t(type, a, b_lg2) \
44 	(log2_mod_t(type, a, b_lg2) == log2_to_max_int_t(type, b_lg2))
45 static_assert(log2_mod_eq_max_t(unsigned int, 3, 2));
46 
47 /*
48  * Return a value such that:
49  *    a / b == ret / b
50  *    ret % b == val
51  * aka set the low bits to val. val must be < b
52  */
53 #define log2_set_mod_t(type, a, val, b_lg2) \
54 	((((type)(a)) & (~log2_to_max_int_t(type, b_lg2))) | ((type)(val)))
55 static_assert(log2_set_mod_t(unsigned int, 3, 1, 2) == 1);
56 
57 /* Return a value such that:
58  *    a / b == ret / b
59  *    ret % b == b - 1
60  * aka set the low bits to all 1s
61  */
62 #define log2_set_mod_max_t(type, a, b_lg2) \
63 	(((type)(a)) | log2_to_max_int_t(type, b_lg2))
64 static_assert(log2_set_mod_max_t(unsigned int, 2, 2) == 3);
65 
/*
 * Compute a * b, where b == 2^b_lg2, in the arithmetic of @type; the result
 * is truncated to @type. @a is converted to @type as a whole before the shift.
 */
#define log2_mul_t(type, a, b_lg2) ((type)(((type)(a)) << (b_lg2)))
static_assert(log2_mul_t(unsigned int, 2, 2) == 8);
69 
/*
 * Call fn##32() when @type is 4 bytes wide, otherwise fn##64(). For the
 * 32-bit variant the whole argument @a is converted to u32 (not just its
 * first operand when @a is an expression).
 */
#define _dispatch_sz(type, fn, a) \
	(sizeof(type) == 4 ? fn##32((u32)(a)) : fn##64(a))
72 
73 /*
74  * Return the highest value such that:
75  *    fls_t(u32, 0) == 0
76  *    fls_t(u3, 1) == 1
77  *    a >= log2_to_int(ret - 1)
78  * aka find last set bit
79  */
fls32(u32 a)80 static inline unsigned int fls32(u32 a)
81 {
82 	return fls(a);
83 }
84 #define fls_t(type, a) _dispatch_sz(type, fls, a)
85 
86 /*
87  * Return the highest value such that:
88  *    ffs_t(u32, 0) == UNDEFINED
89  *    ffs_t(u32, 1) == 0
90  *    log_mod(a, ret) == 0
91  * aka find first set bit
92  */
__ffs32(u32 a)93 static inline unsigned int __ffs32(u32 a)
94 {
95 	return __ffs(a);
96 }
97 #define ffs_t(type, a) _dispatch_sz(type, __ffs, a)
98 
99 /*
100  * Return the highest value such that:
101  *    ffz_t(u32, U32_MAX) == UNDEFINED
102  *    ffz_t(u32, 0) == 0
103  *    ffz_t(u32, 1) == 1
104  *    log_mod(a, ret) == log_to_max_int(ret)
105  * aka find first zero bit
106  */
ffz32(u32 a)107 static inline unsigned int ffz32(u32 a)
108 {
109 	return ffz(a);
110 }
ffz64(u64 a)111 static inline unsigned int ffz64(u64 a)
112 {
113 	if (sizeof(u64) == sizeof(unsigned long))
114 		return ffz(a);
115 
116 	if ((u32)a == U32_MAX)
117 		return ffz32(a >> 32) + 32;
118 	return ffz32(a);
119 }
120 #define ffz_t(type, a) _dispatch_sz(type, ffz, a)
121 
122 #endif
123