1 /* SPDX-License-Identifier: GPL-2.0 */ 2 #ifndef MEAN_AND_VARIANCE_H_ 3 #define MEAN_AND_VARIANCE_H_ 4 5 #include <linux/types.h> 6 #include <linux/limits.h> 7 #include <linux/math.h> 8 #include <linux/math64.h> 9 10 #define SQRT_U64_MAX 4294967295ULL 11 12 /* 13 * u128_u: u128 user mode, because not all architectures support a real int128 14 * type 15 * 16 * We don't use this version in userspace, because in userspace we link with 17 * Rust and rustc has issues with u128. 18 */ 19 20 #if defined(__SIZEOF_INT128__) && defined(__KERNEL__) 21 22 typedef struct { 23 unsigned __int128 v; 24 } __aligned(16) u128_u; 25 26 static inline u128_u u64_to_u128(u64 a) 27 { 28 return (u128_u) { .v = a }; 29 } 30 31 static inline u64 u128_lo(u128_u a) 32 { 33 return a.v; 34 } 35 36 static inline u64 u128_hi(u128_u a) 37 { 38 return a.v >> 64; 39 } 40 41 static inline u128_u u128_add(u128_u a, u128_u b) 42 { 43 a.v += b.v; 44 return a; 45 } 46 47 static inline u128_u u128_sub(u128_u a, u128_u b) 48 { 49 a.v -= b.v; 50 return a; 51 } 52 53 static inline u128_u u128_shl(u128_u a, s8 shift) 54 { 55 a.v <<= shift; 56 return a; 57 } 58 59 static inline u128_u u128_square(u64 a) 60 { 61 u128_u b = u64_to_u128(a); 62 63 b.v *= b.v; 64 return b; 65 } 66 67 #else 68 69 typedef struct { 70 u64 hi, lo; 71 } __aligned(16) u128_u; 72 73 /* conversions */ 74 75 static inline u128_u u64_to_u128(u64 a) 76 { 77 return (u128_u) { .lo = a }; 78 } 79 80 static inline u64 u128_lo(u128_u a) 81 { 82 return a.lo; 83 } 84 85 static inline u64 u128_hi(u128_u a) 86 { 87 return a.hi; 88 } 89 90 /* arithmetic */ 91 92 static inline u128_u u128_add(u128_u a, u128_u b) 93 { 94 u128_u c; 95 96 c.lo = a.lo + b.lo; 97 c.hi = a.hi + b.hi + (c.lo < a.lo); 98 return c; 99 } 100 101 static inline u128_u u128_sub(u128_u a, u128_u b) 102 { 103 u128_u c; 104 105 c.lo = a.lo - b.lo; 106 c.hi = a.hi - b.hi - (c.lo > a.lo); 107 return c; 108 } 109 110 static inline u128_u u128_shl(u128_u i, s8 shift) 111 { 112 u128_u r; 113 114 r.lo = i.lo << shift; 115 if (shift < 64) 116 r.hi = (i.hi << shift) | (i.lo >> (64 - shift)); 117 else { 118 r.hi = i.lo << (shift - 64); 119 r.lo = 0; 120 } 121 return r; 122 } 123 124 static inline u128_u u128_square(u64 i) 125 { 126 u128_u r; 127 u64 h = i >> 32, l = i & U32_MAX; 128 129 r = u128_shl(u64_to_u128(h*h), 64); 130 r = u128_add(r, u128_shl(u64_to_u128(h*l), 32)); 131 r = u128_add(r, u128_shl(u64_to_u128(l*h), 32)); 132 r = u128_add(r, u64_to_u128(l*l)); 133 return r; 134 } 135 136 #endif 137 138 static inline u128_u u64s_to_u128(u64 hi, u64 lo) 139 { 140 u128_u c = u64_to_u128(hi); 141 142 c = u128_shl(c, 64); 143 c = u128_add(c, u64_to_u128(lo)); 144 return c; 145 } 146 147 u128_u u128_div(u128_u n, u64 d); 148 149 struct mean_and_variance { 150 s64 n; 151 s64 sum; 152 u128_u sum_squares; 153 }; 154 155 /* expontentially weighted variant */ 156 struct mean_and_variance_weighted { 157 bool init; 158 u8 weight; /* base 2 logarithim */ 159 s64 mean; 160 u64 variance; 161 }; 162 163 /** 164 * fast_divpow2() - fast approximation for n / (1 << d) 165 * @n: numerator 166 * @d: the power of 2 denominator. 167 * 168 * note: this rounds towards 0. 169 */ 170 static inline s64 fast_divpow2(s64 n, u8 d) 171 { 172 return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; 173 } 174 175 /** 176 * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1 177 * and return it. 178 * @s1: the mean_and_variance to update. 179 * @v1: the new sample. 180 * 181 * see linked pdf equation 12. 182 */ 183 static inline void 184 mean_and_variance_update(struct mean_and_variance *s, s64 v) 185 { 186 s->n++; 187 s->sum += v; 188 s->sum_squares = u128_add(s->sum_squares, u128_square(abs(v))); 189 } 190 191 s64 mean_and_variance_get_mean(struct mean_and_variance s); 192 u64 mean_and_variance_get_variance(struct mean_and_variance s1); 193 u32 mean_and_variance_get_stddev(struct mean_and_variance s); 194 195 void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v); 196 197 s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s); 198 u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s); 199 u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s); 200 201 #endif // MEAN_AND_VAIRANCE_H_ 202