xref: /linux/fs/bcachefs/mean_and_variance.h (revision 7a92fc8b4d20680e4c20289a670d8fca2d1f2c1b)
1 /* SPDX-License-Identifier: GPL-2.0 */
2 #ifndef MEAN_AND_VARIANCE_H_
3 #define MEAN_AND_VARIANCE_H_
4 
5 #include <linux/types.h>
6 #include <linux/limits.h>
7 #include <linux/math.h>
8 #include <linux/math64.h>
9 
/*
 * floor(sqrt(U64_MAX)) == 2^32 - 1: the largest u64 whose square still fits
 * in a u64.
 */
#define SQRT_U64_MAX 4294967295ULL
11 
/*
 * u128_u: u128 "user mode" type, because not all architectures support a real
 * int128 type. When the compiler lacks __int128, it falls back to a struct of
 * two u64 halves.
 */
16 
17 #ifdef __SIZEOF_INT128__
18 
/*
 * Native representation, used when the compiler provides __int128.
 * Presumably wrapped in a struct so that both configurations expose the same
 * opaque type and callers must go through the u128_* helpers.
 */
typedef struct {
	unsigned __int128 v;
} __aligned(16) u128_u;
22 
23 static inline u128_u u64_to_u128(u64 a)
24 {
25 	return (u128_u) { .v = a };
26 }
27 
28 static inline u64 u128_lo(u128_u a)
29 {
30 	return a.v;
31 }
32 
33 static inline u64 u128_hi(u128_u a)
34 {
35 	return a.v >> 64;
36 }
37 
38 static inline u128_u u128_add(u128_u a, u128_u b)
39 {
40 	a.v += b.v;
41 	return a;
42 }
43 
44 static inline u128_u u128_sub(u128_u a, u128_u b)
45 {
46 	a.v -= b.v;
47 	return a;
48 }
49 
50 static inline u128_u u128_shl(u128_u a, s8 shift)
51 {
52 	a.v <<= shift;
53 	return a;
54 }
55 
56 static inline u128_u u128_square(u64 a)
57 {
58 	u128_u b = u64_to_u128(a);
59 
60 	b.v *= b.v;
61 	return b;
62 }
63 
64 #else
65 
/*
 * Fallback representation for compilers without __int128: two 64-bit halves,
 * combined by the helpers below with explicit carry/borrow handling.
 */
typedef struct {
	u64 hi, lo;
} __aligned(16) u128_u;
69 
70 /* conversions */
71 
72 static inline u128_u u64_to_u128(u64 a)
73 {
74 	return (u128_u) { .lo = a };
75 }
76 
77 static inline u64 u128_lo(u128_u a)
78 {
79 	return a.lo;
80 }
81 
82 static inline u64 u128_hi(u128_u a)
83 {
84 	return a.hi;
85 }
86 
87 /* arithmetic */
88 
89 static inline u128_u u128_add(u128_u a, u128_u b)
90 {
91 	u128_u c;
92 
93 	c.lo = a.lo + b.lo;
94 	c.hi = a.hi + b.hi + (c.lo < a.lo);
95 	return c;
96 }
97 
98 static inline u128_u u128_sub(u128_u a, u128_u b)
99 {
100 	u128_u c;
101 
102 	c.lo = a.lo - b.lo;
103 	c.hi = a.hi - b.hi - (c.lo > a.lo);
104 	return c;
105 }
106 
107 static inline u128_u u128_shl(u128_u i, s8 shift)
108 {
109 	u128_u r;
110 
111 	r.lo = i.lo << shift;
112 	if (shift < 64)
113 		r.hi = (i.hi << shift) | (i.lo >> (64 - shift));
114 	else {
115 		r.hi = i.lo << (shift - 64);
116 		r.lo = 0;
117 	}
118 	return r;
119 }
120 
121 static inline u128_u u128_square(u64 i)
122 {
123 	u128_u r;
124 	u64  h = i >> 32, l = i & U32_MAX;
125 
126 	r =             u128_shl(u64_to_u128(h*h), 64);
127 	r = u128_add(r, u128_shl(u64_to_u128(h*l), 32));
128 	r = u128_add(r, u128_shl(u64_to_u128(l*h), 32));
129 	r = u128_add(r,          u64_to_u128(l*l));
130 	return r;
131 }
132 
133 #endif
134 
135 static inline u128_u u64s_to_u128(u64 hi, u64 lo)
136 {
137 	u128_u c = u64_to_u128(hi);
138 
139 	c = u128_shl(c, 64);
140 	c = u128_add(c, u64_to_u128(lo));
141 	return c;
142 }
143 
/*
 * Out-of-line division of the 128-bit @n by the 64-bit @d.
 * NOTE(review): presumably returns n / d truncated; the behaviour for d == 0
 * is not visible from this header, so confirm it in the implementation.
 */
u128_u u128_div(u128_u n, u64 d);
145 
/*
 * Running totals for an unweighted mean and variance; samples are added with
 * mean_and_variance_update().
 */
struct mean_and_variance {
	s64	n;		/* number of samples accumulated */
	s64	sum;		/* sum of all samples */
	u128_u	sum_squares;	/* sum of squared samples; each square of a
				 * 64-bit magnitude needs up to 128 bits */
};
151 
/* exponentially weighted variant */
struct mean_and_variance_weighted {
	bool	init;	/* NOTE(review): presumably set once the first sample
			 * has been folded in; confirm in the implementation */
	u8	weight;	/* base 2 logarithm */
	s64	mean;
	u64	variance;
};
159 
160 /**
161  * fast_divpow2() - fast approximation for n / (1 << d)
162  * @n: numerator
163  * @d: the power of 2 denominator.
164  *
165  * note: this rounds towards 0.
166  */
167 static inline s64 fast_divpow2(s64 n, u8 d)
168 {
169 	return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d;
170 }
171 
172 /**
173  * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1
174  * and return it.
175  * @s1: the mean_and_variance to update.
176  * @v1: the new sample.
177  *
178  * see linked pdf equation 12.
179  */
180 static inline void
181 mean_and_variance_update(struct mean_and_variance *s, s64 v)
182 {
183 	s->n++;
184 	s->sum += v;
185 	s->sum_squares = u128_add(s->sum_squares, u128_square(abs(v)));
186 }
187 
188 s64 mean_and_variance_get_mean(struct mean_and_variance s);
189 u64 mean_and_variance_get_variance(struct mean_and_variance s1);
190 u32 mean_and_variance_get_stddev(struct mean_and_variance s);
191 
/*
 * Out-of-line update of the exponentially weighted accumulator @s with
 * sample @v.
 * NOTE(review): the exact weighting rule (presumably driven by @s->weight) is
 * not visible from this header; confirm in the implementation.
 */
void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v);

/*
 * Accessors for struct mean_and_variance_weighted, implemented out of line;
 * each takes the accumulator by value.
 * NOTE(review): by name these return the weighted mean, variance and standard
 * deviation; confirm in the implementation.
 */
s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
197 
#endif // MEAN_AND_VARIANCE_H_
199