1 /* SPDX-License-Identifier: GPL-2.0 */ 2 #ifndef MEAN_AND_VARIANCE_H_ 3 #define MEAN_AND_VARIANCE_H_ 4 5 #include <linux/types.h> 6 #include <linux/limits.h> 7 #include <linux/math.h> 8 #include <linux/math64.h> 9 10 #define SQRT_U64_MAX 4294967295ULL 11 12 /* 13 * u128_u: u128 user mode, because not all architectures support a real int128 14 * type 15 */ 16 17 #ifdef __SIZEOF_INT128__ 18 19 typedef struct { 20 unsigned __int128 v; 21 } __aligned(16) u128_u; 22 23 static inline u128_u u64_to_u128(u64 a) 24 { 25 return (u128_u) { .v = a }; 26 } 27 28 static inline u64 u128_lo(u128_u a) 29 { 30 return a.v; 31 } 32 33 static inline u64 u128_hi(u128_u a) 34 { 35 return a.v >> 64; 36 } 37 38 static inline u128_u u128_add(u128_u a, u128_u b) 39 { 40 a.v += b.v; 41 return a; 42 } 43 44 static inline u128_u u128_sub(u128_u a, u128_u b) 45 { 46 a.v -= b.v; 47 return a; 48 } 49 50 static inline u128_u u128_shl(u128_u a, s8 shift) 51 { 52 a.v <<= shift; 53 return a; 54 } 55 56 static inline u128_u u128_square(u64 a) 57 { 58 u128_u b = u64_to_u128(a); 59 60 b.v *= b.v; 61 return b; 62 } 63 64 #else 65 66 typedef struct { 67 u64 hi, lo; 68 } __aligned(16) u128_u; 69 70 /* conversions */ 71 72 static inline u128_u u64_to_u128(u64 a) 73 { 74 return (u128_u) { .lo = a }; 75 } 76 77 static inline u64 u128_lo(u128_u a) 78 { 79 return a.lo; 80 } 81 82 static inline u64 u128_hi(u128_u a) 83 { 84 return a.hi; 85 } 86 87 /* arithmetic */ 88 89 static inline u128_u u128_add(u128_u a, u128_u b) 90 { 91 u128_u c; 92 93 c.lo = a.lo + b.lo; 94 c.hi = a.hi + b.hi + (c.lo < a.lo); 95 return c; 96 } 97 98 static inline u128_u u128_sub(u128_u a, u128_u b) 99 { 100 u128_u c; 101 102 c.lo = a.lo - b.lo; 103 c.hi = a.hi - b.hi - (c.lo > a.lo); 104 return c; 105 } 106 107 static inline u128_u u128_shl(u128_u i, s8 shift) 108 { 109 u128_u r; 110 111 r.lo = i.lo << shift; 112 if (shift < 64) 113 r.hi = (i.hi << shift) | (i.lo >> (64 - shift)); 114 else { 115 r.hi = i.lo << (shift - 64); 116 r.lo = 0; 117 } 118 return r; 119 } 120 121 static inline u128_u u128_square(u64 i) 122 { 123 u128_u r; 124 u64 h = i >> 32, l = i & U32_MAX; 125 126 r = u128_shl(u64_to_u128(h*h), 64); 127 r = u128_add(r, u128_shl(u64_to_u128(h*l), 32)); 128 r = u128_add(r, u128_shl(u64_to_u128(l*h), 32)); 129 r = u128_add(r, u64_to_u128(l*l)); 130 return r; 131 } 132 133 #endif 134 135 static inline u128_u u64s_to_u128(u64 hi, u64 lo) 136 { 137 u128_u c = u64_to_u128(hi); 138 139 c = u128_shl(c, 64); 140 c = u128_add(c, u64_to_u128(lo)); 141 return c; 142 } 143 144 u128_u u128_div(u128_u n, u64 d); 145 146 struct mean_and_variance { 147 s64 n; 148 s64 sum; 149 u128_u sum_squares; 150 }; 151 152 /* expontentially weighted variant */ 153 struct mean_and_variance_weighted { 154 bool init; 155 u8 weight; /* base 2 logarithim */ 156 s64 mean; 157 u64 variance; 158 }; 159 160 /** 161 * fast_divpow2() - fast approximation for n / (1 << d) 162 * @n: numerator 163 * @d: the power of 2 denominator. 164 * 165 * note: this rounds towards 0. 166 */ 167 static inline s64 fast_divpow2(s64 n, u8 d) 168 { 169 return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; 170 } 171 172 /** 173 * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1 174 * and return it. 175 * @s1: the mean_and_variance to update. 176 * @v1: the new sample. 177 * 178 * see linked pdf equation 12. 179 */ 180 static inline void 181 mean_and_variance_update(struct mean_and_variance *s, s64 v) 182 { 183 s->n++; 184 s->sum += v; 185 s->sum_squares = u128_add(s->sum_squares, u128_square(abs(v))); 186 } 187 188 s64 mean_and_variance_get_mean(struct mean_and_variance s); 189 u64 mean_and_variance_get_variance(struct mean_and_variance s1); 190 u32 mean_and_variance_get_stddev(struct mean_and_variance s); 191 192 void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v); 193 194 s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s); 195 u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s); 196 u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s); 197 198 #endif // MEAN_AND_VAIRANCE_H_ 199