xref: /linux/fs/bcachefs/mean_and_variance_test.c (revision 16e5ac127d8d18adf85fe5ba847d77b58d1ed418)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <kunit/test.h>
3 
4 #include "mean_and_variance.h"
5 
/*
 * SQRT_U64_MAX (defined outside this file) multiplied by itself.
 * NOTE(review): nothing in this file references MAX_SQR - confirm whether
 * it is still needed.
 */
#define MAX_SQR (SQRT_U64_MAX*SQRT_U64_MAX)
7 
8 static void mean_and_variance_basic_test(struct kunit *test)
9 {
10 	struct mean_and_variance s = {};
11 
12 	mean_and_variance_update(&s, 2);
13 	mean_and_variance_update(&s, 2);
14 
15 	KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(s), 2);
16 	KUNIT_EXPECT_EQ(test, mean_and_variance_get_variance(s), 0);
17 	KUNIT_EXPECT_EQ(test, s.n, 2);
18 
19 	mean_and_variance_update(&s, 4);
20 	mean_and_variance_update(&s, 4);
21 
22 	KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(s), 3);
23 	KUNIT_EXPECT_EQ(test, mean_and_variance_get_variance(s), 1);
24 	KUNIT_EXPECT_EQ(test, s.n, 4);
25 }
26 
/*
 * Test values computed using a spreadsheet from the pseudocode at the bottom of:
 * https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf
 */
31 
32 static void mean_and_variance_weighted_test(struct kunit *test)
33 {
34 	struct mean_and_variance_weighted s = { .weight = 2 };
35 
36 	mean_and_variance_weighted_update(&s, 10);
37 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 10);
38 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 0);
39 
40 	mean_and_variance_weighted_update(&s, 20);
41 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 12);
42 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 18);
43 
44 	mean_and_variance_weighted_update(&s, 30);
45 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 16);
46 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 72);
47 
48 	s = (struct mean_and_variance_weighted) { .weight = 2 };
49 
50 	mean_and_variance_weighted_update(&s, -10);
51 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -10);
52 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 0);
53 
54 	mean_and_variance_weighted_update(&s, -20);
55 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -12);
56 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 18);
57 
58 	mean_and_variance_weighted_update(&s, -30);
59 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -16);
60 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 72);
61 }
62 
63 static void mean_and_variance_weighted_advanced_test(struct kunit *test)
64 {
65 	struct mean_and_variance_weighted s = { .weight = 8 };
66 	s64 i;
67 
68 	for (i = 10; i <= 100; i += 10)
69 		mean_and_variance_weighted_update(&s, i);
70 
71 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 11);
72 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 107);
73 
74 	s = (struct mean_and_variance_weighted) { .weight = 8 };
75 
76 	for (i = -10; i >= -100; i -= 10)
77 		mean_and_variance_weighted_update(&s, i);
78 
79 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -11);
80 	KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 107);
81 }
82 
83 static void do_mean_and_variance_test(struct kunit *test,
84 				      s64 initial_value,
85 				      s64 initial_n,
86 				      s64 n,
87 				      unsigned weight,
88 				      s64 *data,
89 				      s64 *mean,
90 				      s64 *stddev,
91 				      s64 *weighted_mean,
92 				      s64 *weighted_stddev)
93 {
94 	struct mean_and_variance mv = {};
95 	struct mean_and_variance_weighted vw = { .weight = weight };
96 
97 	for (unsigned i = 0; i < initial_n; i++) {
98 		mean_and_variance_update(&mv, initial_value);
99 		mean_and_variance_weighted_update(&vw, initial_value);
100 
101 		KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(mv),		initial_value);
102 		KUNIT_EXPECT_EQ(test, mean_and_variance_get_stddev(mv),		0);
103 		KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(vw),	initial_value);
104 		KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_stddev(vw),0);
105 	}
106 
107 	for (unsigned i = 0; i < n; i++) {
108 		mean_and_variance_update(&mv, data[i]);
109 		mean_and_variance_weighted_update(&vw, data[i]);
110 
111 		KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(mv),		mean[i]);
112 		KUNIT_EXPECT_EQ(test, mean_and_variance_get_stddev(mv),		stddev[i]);
113 		KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(vw),	weighted_mean[i]);
114 		KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_stddev(vw),weighted_stddev[i]);
115 	}
116 
117 	KUNIT_EXPECT_EQ(test, mv.n, initial_n + n);
118 }
119 
120 /* Test behaviour with a single outlier, then back to steady state: */
121 static void mean_and_variance_test_1(struct kunit *test)
122 {
123 	s64 d[]			= { 100, 10, 10, 10, 10, 10, 10 };
124 	s64 mean[]		= {  22, 21, 20, 19, 18, 17, 16 };
125 	s64 stddev[]		= {  32, 29, 28, 27, 26, 25, 24 };
126 	s64 weighted_mean[]	= {  32, 27, 22, 19, 17, 15, 14 };
127 	s64 weighted_stddev[]	= {  38, 35, 31, 27, 24, 21, 18 };
128 
129 	do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2,
130 			d, mean, stddev, weighted_mean, weighted_stddev);
131 }
132 
133 static void mean_and_variance_test_2(struct kunit *test)
134 {
135 	s64 d[]			= { 100, 10, 10, 10, 10, 10, 10 };
136 	s64 mean[]		= {  10, 10, 10, 10, 10, 10, 10 };
137 	s64 stddev[]		= {   9,  9,  9,  9,  9,  9,  9 };
138 	s64 weighted_mean[]	= {  32, 27, 22, 19, 17, 15, 14 };
139 	s64 weighted_stddev[]	= {  38, 35, 31, 27, 24, 21, 18 };
140 
141 	do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2,
142 			d, mean, stddev, weighted_mean, weighted_stddev);
143 }
144 
145 /* Test behaviour where we switch from one steady state to another: */
146 static void mean_and_variance_test_3(struct kunit *test)
147 {
148 	s64 d[]			= { 100, 100, 100, 100, 100 };
149 	s64 mean[]		= {  22,  32,  40,  46,  50 };
150 	s64 stddev[]		= {  32,  39,  42,  44,  45 };
151 	s64 weighted_mean[]	= {  32,  49,  61,  71,  78 };
152 	s64 weighted_stddev[]	= {  38,  44,  44,  41,  38 };
153 
154 	do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2,
155 			d, mean, stddev, weighted_mean, weighted_stddev);
156 }
157 
158 static void mean_and_variance_test_4(struct kunit *test)
159 {
160 	s64 d[]			= { 100, 100, 100, 100, 100 };
161 	s64 mean[]		= {  10,  11,  12,  13,  14 };
162 	s64 stddev[]		= {   9,  13,  15,  17,  19 };
163 	s64 weighted_mean[]	= {  32,  49,  61,  71,  78 };
164 	s64 weighted_stddev[]	= {  38,  44,  44,  41,  38 };
165 
166 	do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2,
167 			d, mean, stddev, weighted_mean, weighted_stddev);
168 }
169 
170 static void mean_and_variance_fast_divpow2(struct kunit *test)
171 {
172 	s64 i;
173 	u8 d;
174 
175 	for (i = 0; i < 100; i++) {
176 		d = 0;
177 		KUNIT_EXPECT_EQ(test, fast_divpow2(i, d), div_u64(i, 1LLU << d));
178 		KUNIT_EXPECT_EQ(test, abs(fast_divpow2(-i, d)), div_u64(i, 1LLU << d));
179 		for (d = 1; d < 32; d++) {
180 			KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(i, d)),
181 					    div_u64(i, 1 << d), "%lld %u", i, d);
182 			KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(-i, d)),
183 					    div_u64(i, 1 << d), "%lld %u", -i, d);
184 		}
185 	}
186 }
187 
188 static void mean_and_variance_u128_basic_test(struct kunit *test)
189 {
190 	u128_u a  = u64s_to_u128(0, U64_MAX);
191 	u128_u a1 = u64s_to_u128(0, 1);
192 	u128_u b  = u64s_to_u128(1, 0);
193 	u128_u c  = u64s_to_u128(0, 1LLU << 63);
194 	u128_u c2 = u64s_to_u128(U64_MAX, U64_MAX);
195 
196 	KUNIT_EXPECT_EQ(test, u128_hi(u128_add(a, a1)), 1);
197 	KUNIT_EXPECT_EQ(test, u128_lo(u128_add(a, a1)), 0);
198 	KUNIT_EXPECT_EQ(test, u128_hi(u128_add(a1, a)), 1);
199 	KUNIT_EXPECT_EQ(test, u128_lo(u128_add(a1, a)), 0);
200 
201 	KUNIT_EXPECT_EQ(test, u128_lo(u128_sub(b, a1)), U64_MAX);
202 	KUNIT_EXPECT_EQ(test, u128_hi(u128_sub(b, a1)), 0);
203 
204 	KUNIT_EXPECT_EQ(test, u128_hi(u128_shl(c, 1)), 1);
205 	KUNIT_EXPECT_EQ(test, u128_lo(u128_shl(c, 1)), 0);
206 
207 	KUNIT_EXPECT_EQ(test, u128_hi(u128_square(U64_MAX)), U64_MAX - 1);
208 	KUNIT_EXPECT_EQ(test, u128_lo(u128_square(U64_MAX)), 1);
209 
210 	KUNIT_EXPECT_EQ(test, u128_lo(u128_div(b, 2)), 1LLU << 63);
211 
212 	KUNIT_EXPECT_EQ(test, u128_hi(u128_div(c2, 2)), U64_MAX >> 1);
213 	KUNIT_EXPECT_EQ(test, u128_lo(u128_div(c2, 2)), U64_MAX);
214 
215 	KUNIT_EXPECT_EQ(test, u128_hi(u128_div(u128_shl(u64_to_u128(U64_MAX), 32), 2)), U32_MAX >> 1);
216 	KUNIT_EXPECT_EQ(test, u128_lo(u128_div(u128_shl(u64_to_u128(U64_MAX), 32), 2)), U64_MAX << 31);
217 }
218 
/*
 * Test cases belonging to the suite below; the list is terminated by the
 * empty {} entry.
 */
static struct kunit_case mean_and_variance_test_cases[] = {
	KUNIT_CASE(mean_and_variance_fast_divpow2),
	KUNIT_CASE(mean_and_variance_u128_basic_test),
	KUNIT_CASE(mean_and_variance_basic_test),
	KUNIT_CASE(mean_and_variance_weighted_test),
	KUNIT_CASE(mean_and_variance_weighted_advanced_test),
	KUNIT_CASE(mean_and_variance_test_1),
	KUNIT_CASE(mean_and_variance_test_2),
	KUNIT_CASE(mean_and_variance_test_3),
	KUNIT_CASE(mean_and_variance_test_4),
	{}
};
231 
/*
 * Suite descriptor tying the suite name to mean_and_variance_test_cases.
 * NOTE(review): the name contains spaces; the KUnit style guide asks for
 * underscore-separated suite names - confirm before renaming, since the name
 * is visible in test output and to any filters that select suites by name.
 */
static struct kunit_suite mean_and_variance_test_suite = {
	.name		= "mean and variance tests",
	.test_cases	= mean_and_variance_test_cases
};

/* Hand the suite to KUnit. */
kunit_test_suite(mean_and_variance_test_suite);
238 
239 MODULE_AUTHOR("Daniel B. Hill");
240 MODULE_LICENSE("GPL");
241