1 /* SPDX-License-Identifier: GPL-2.0 */ 2 #ifndef MEAN_AND_VARIANCE_H_ 3 #define MEAN_AND_VARIANCE_H_ 4 5 #include <linux/types.h> 6 #include <linux/limits.h> 7 #include <linux/math.h> 8 #include <linux/math64.h> 9 10 #define SQRT_U64_MAX 4294967295ULL 11 12 /* 13 * u128_u: u128 user mode, because not all architectures support a real int128 14 * type 15 * 16 * We don't use this version in userspace, because in userspace we link with 17 * Rust and rustc has issues with u128. 18 */ 19 20 #if defined(__SIZEOF_INT128__) && defined(__KERNEL__) && !defined(CONFIG_PARISC) 21 22 typedef struct { 23 unsigned __int128 v; 24 } __aligned(16) u128_u; 25 26 static inline u128_u u64_to_u128(u64 a) 27 { 28 return (u128_u) { .v = a }; 29 } 30 31 static inline u64 u128_lo(u128_u a) 32 { 33 return a.v; 34 } 35 36 static inline u64 u128_hi(u128_u a) 37 { 38 return a.v >> 64; 39 } 40 41 static inline u128_u u128_add(u128_u a, u128_u b) 42 { 43 a.v += b.v; 44 return a; 45 } 46 47 static inline u128_u u128_sub(u128_u a, u128_u b) 48 { 49 a.v -= b.v; 50 return a; 51 } 52 53 static inline u128_u u128_shl(u128_u a, s8 shift) 54 { 55 a.v <<= shift; 56 return a; 57 } 58 59 static inline u128_u u128_square(u64 a) 60 { 61 u128_u b = u64_to_u128(a); 62 63 b.v *= b.v; 64 return b; 65 } 66 67 #else 68 69 typedef struct { 70 u64 hi, lo; 71 } __aligned(16) u128_u; 72 73 /* conversions */ 74 75 static inline u128_u u64_to_u128(u64 a) 76 { 77 return (u128_u) { .lo = a }; 78 } 79 80 static inline u64 u128_lo(u128_u a) 81 { 82 return a.lo; 83 } 84 85 static inline u64 u128_hi(u128_u a) 86 { 87 return a.hi; 88 } 89 90 /* arithmetic */ 91 92 static inline u128_u u128_add(u128_u a, u128_u b) 93 { 94 u128_u c; 95 96 c.lo = a.lo + b.lo; 97 c.hi = a.hi + b.hi + (c.lo < a.lo); 98 return c; 99 } 100 101 static inline u128_u u128_sub(u128_u a, u128_u b) 102 { 103 u128_u c; 104 105 c.lo = a.lo - b.lo; 106 c.hi = a.hi - b.hi - (c.lo > a.lo); 107 return c; 108 } 109 110 static inline u128_u u128_shl(u128_u i, s8 shift) 111 { 112 u128_u r; 113 114 r.lo = i.lo << shift; 115 if (shift < 64) 116 r.hi = (i.hi << shift) | (i.lo >> (64 - shift)); 117 else { 118 r.hi = i.lo << (shift - 64); 119 r.lo = 0; 120 } 121 return r; 122 } 123 124 static inline u128_u u128_square(u64 i) 125 { 126 u128_u r; 127 u64 h = i >> 32, l = i & U32_MAX; 128 129 r = u128_shl(u64_to_u128(h*h), 64); 130 r = u128_add(r, u128_shl(u64_to_u128(h*l), 32)); 131 r = u128_add(r, u128_shl(u64_to_u128(l*h), 32)); 132 r = u128_add(r, u64_to_u128(l*l)); 133 return r; 134 } 135 136 #endif 137 138 static inline u128_u u64s_to_u128(u64 hi, u64 lo) 139 { 140 u128_u c = u64_to_u128(hi); 141 142 c = u128_shl(c, 64); 143 c = u128_add(c, u64_to_u128(lo)); 144 return c; 145 } 146 147 u128_u u128_div(u128_u n, u64 d); 148 149 struct mean_and_variance { 150 s64 n; 151 s64 sum; 152 u128_u sum_squares; 153 }; 154 155 /* expontentially weighted variant */ 156 struct mean_and_variance_weighted { 157 s64 mean; 158 u64 variance; 159 }; 160 161 /** 162 * fast_divpow2() - fast approximation for n / (1 << d) 163 * @n: numerator 164 * @d: the power of 2 denominator. 165 * 166 * note: this rounds towards 0. 167 */ 168 static inline s64 fast_divpow2(s64 n, u8 d) 169 { 170 return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; 171 } 172 173 /** 174 * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1 175 * and return it. 176 * @s1: the mean_and_variance to update. 177 * @v1: the new sample. 178 * 179 * see linked pdf equation 12. 180 */ 181 static inline void 182 mean_and_variance_update(struct mean_and_variance *s, s64 v) 183 { 184 s->n++; 185 s->sum += v; 186 s->sum_squares = u128_add(s->sum_squares, u128_square(abs(v))); 187 } 188 189 s64 mean_and_variance_get_mean(struct mean_and_variance s); 190 u64 mean_and_variance_get_variance(struct mean_and_variance s1); 191 u32 mean_and_variance_get_stddev(struct mean_and_variance s); 192 193 void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, 194 s64 v, bool initted, u8 weight); 195 196 s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s, 197 u8 weight); 198 u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s, 199 u8 weight); 200 u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s, 201 u8 weight); 202 203 #endif // MEAN_AND_VAIRANCE_H_ 204