1 // SPDX-License-Identifier: GPL-2.0 2 #include <kunit/test.h> 3 4 #include "mean_and_variance.h" 5 6 #define MAX_SQR (SQRT_U64_MAX*SQRT_U64_MAX) 7 8 static void mean_and_variance_basic_test(struct kunit *test) 9 { 10 struct mean_and_variance s = {}; 11 12 mean_and_variance_update(&s, 2); 13 mean_and_variance_update(&s, 2); 14 15 KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(s), 2); 16 KUNIT_EXPECT_EQ(test, mean_and_variance_get_variance(s), 0); 17 KUNIT_EXPECT_EQ(test, s.n, 2); 18 19 mean_and_variance_update(&s, 4); 20 mean_and_variance_update(&s, 4); 21 22 KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(s), 3); 23 KUNIT_EXPECT_EQ(test, mean_and_variance_get_variance(s), 1); 24 KUNIT_EXPECT_EQ(test, s.n, 4); 25 } 26 27 /* 28 * Test values computed using a spreadsheet from the psuedocode at the bottom: 29 * https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf 30 */ 31 32 static void mean_and_variance_weighted_test(struct kunit *test) 33 { 34 struct mean_and_variance_weighted s = { }; 35 36 mean_and_variance_weighted_update(&s, 10, false, 2); 37 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), 10); 38 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 0); 39 40 mean_and_variance_weighted_update(&s, 20, true, 2); 41 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), 12); 42 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 18); 43 44 mean_and_variance_weighted_update(&s, 30, true, 2); 45 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), 16); 46 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 72); 47 48 s = (struct mean_and_variance_weighted) { }; 49 50 mean_and_variance_weighted_update(&s, -10, false, 2); 51 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), -10); 52 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 0); 53 54 mean_and_variance_weighted_update(&s, -20, true, 2); 55 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), -12); 56 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 18); 57 58 mean_and_variance_weighted_update(&s, -30, true, 2); 59 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 2), -16); 60 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 2), 72); 61 } 62 63 static void mean_and_variance_weighted_advanced_test(struct kunit *test) 64 { 65 struct mean_and_variance_weighted s = { }; 66 bool initted = false; 67 s64 i; 68 69 for (i = 10; i <= 100; i += 10) { 70 mean_and_variance_weighted_update(&s, i, initted, 8); 71 initted = true; 72 } 73 74 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 8), 11); 75 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 8), 107); 76 77 s = (struct mean_and_variance_weighted) { }; 78 initted = false; 79 80 for (i = -10; i >= -100; i -= 10) { 81 mean_and_variance_weighted_update(&s, i, initted, 8); 82 initted = true; 83 } 84 85 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s, 8), -11); 86 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s, 8), 107); 87 } 88 89 static void do_mean_and_variance_test(struct kunit *test, 90 s64 initial_value, 91 s64 initial_n, 92 s64 n, 93 unsigned weight, 94 s64 *data, 95 s64 *mean, 96 s64 *stddev, 97 s64 *weighted_mean, 98 s64 *weighted_stddev) 99 { 100 struct mean_and_variance mv = {}; 101 struct mean_and_variance_weighted vw = { }; 102 103 for (unsigned i = 0; i < initial_n; i++) { 104 mean_and_variance_update(&mv, initial_value); 105 mean_and_variance_weighted_update(&vw, initial_value, false, weight); 106 107 KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(mv), initial_value); 108 KUNIT_EXPECT_EQ(test, mean_and_variance_get_stddev(mv), 0); 109 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(vw, weight), initial_value); 110 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_stddev(vw, weight),0); 111 } 112 113 for (unsigned i = 0; i < n; i++) { 114 mean_and_variance_update(&mv, data[i]); 115 mean_and_variance_weighted_update(&vw, data[i], true, weight); 116 117 KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(mv), mean[i]); 118 KUNIT_EXPECT_EQ(test, mean_and_variance_get_stddev(mv), stddev[i]); 119 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(vw, weight), weighted_mean[i]); 120 KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_stddev(vw, weight),weighted_stddev[i]); 121 } 122 123 KUNIT_EXPECT_EQ(test, mv.n, initial_n + n); 124 } 125 126 /* Test behaviour with a single outlier, then back to steady state: */ 127 static void mean_and_variance_test_1(struct kunit *test) 128 { 129 s64 d[] = { 100, 10, 10, 10, 10, 10, 10 }; 130 s64 mean[] = { 22, 21, 20, 19, 18, 17, 16 }; 131 s64 stddev[] = { 32, 29, 28, 27, 26, 25, 24 }; 132 s64 weighted_mean[] = { 32, 27, 22, 19, 17, 15, 14 }; 133 s64 weighted_stddev[] = { 38, 35, 31, 27, 24, 21, 18 }; 134 135 do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2, 136 d, mean, stddev, weighted_mean, weighted_stddev); 137 } 138 139 static void mean_and_variance_test_2(struct kunit *test) 140 { 141 s64 d[] = { 100, 10, 10, 10, 10, 10, 10 }; 142 s64 mean[] = { 10, 10, 10, 10, 10, 10, 10 }; 143 s64 stddev[] = { 9, 9, 9, 9, 9, 9, 9 }; 144 s64 weighted_mean[] = { 32, 27, 22, 19, 17, 15, 14 }; 145 s64 weighted_stddev[] = { 38, 35, 31, 27, 24, 21, 18 }; 146 147 do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2, 148 d, mean, stddev, weighted_mean, weighted_stddev); 149 } 150 151 /* Test behaviour where we switch from one steady state to another: */ 152 static void mean_and_variance_test_3(struct kunit *test) 153 { 154 s64 d[] = { 100, 100, 100, 100, 100 }; 155 s64 mean[] = { 22, 32, 40, 46, 50 }; 156 s64 stddev[] = { 32, 39, 42, 44, 45 }; 157 s64 weighted_mean[] = { 32, 49, 61, 71, 78 }; 158 s64 weighted_stddev[] = { 38, 44, 44, 41, 38 }; 159 160 do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2, 161 d, mean, stddev, weighted_mean, weighted_stddev); 162 } 163 164 static void mean_and_variance_test_4(struct kunit *test) 165 { 166 s64 d[] = { 100, 100, 100, 100, 100 }; 167 s64 mean[] = { 10, 11, 12, 13, 14 }; 168 s64 stddev[] = { 9, 13, 15, 17, 19 }; 169 s64 weighted_mean[] = { 32, 49, 61, 71, 78 }; 170 s64 weighted_stddev[] = { 38, 44, 44, 41, 38 }; 171 172 do_mean_and_variance_test(test, 10, 6, ARRAY_SIZE(d), 2, 173 d, mean, stddev, weighted_mean, weighted_stddev); 174 } 175 176 static void mean_and_variance_fast_divpow2(struct kunit *test) 177 { 178 s64 i; 179 u8 d; 180 181 for (i = 0; i < 100; i++) { 182 d = 0; 183 KUNIT_EXPECT_EQ(test, fast_divpow2(i, d), div_u64(i, 1LLU << d)); 184 KUNIT_EXPECT_EQ(test, abs(fast_divpow2(-i, d)), div_u64(i, 1LLU << d)); 185 for (d = 1; d < 32; d++) { 186 KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(i, d)), 187 div_u64(i, 1 << d), "%lld %u", i, d); 188 KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(-i, d)), 189 div_u64(i, 1 << d), "%lld %u", -i, d); 190 } 191 } 192 } 193 194 static void mean_and_variance_u128_basic_test(struct kunit *test) 195 { 196 u128_u a = u64s_to_u128(0, U64_MAX); 197 u128_u a1 = u64s_to_u128(0, 1); 198 u128_u b = u64s_to_u128(1, 0); 199 u128_u c = u64s_to_u128(0, 1LLU << 63); 200 u128_u c2 = u64s_to_u128(U64_MAX, U64_MAX); 201 202 KUNIT_EXPECT_EQ(test, u128_hi(u128_add(a, a1)), 1); 203 KUNIT_EXPECT_EQ(test, u128_lo(u128_add(a, a1)), 0); 204 KUNIT_EXPECT_EQ(test, u128_hi(u128_add(a1, a)), 1); 205 KUNIT_EXPECT_EQ(test, u128_lo(u128_add(a1, a)), 0); 206 207 KUNIT_EXPECT_EQ(test, u128_lo(u128_sub(b, a1)), U64_MAX); 208 KUNIT_EXPECT_EQ(test, u128_hi(u128_sub(b, a1)), 0); 209 210 KUNIT_EXPECT_EQ(test, u128_hi(u128_shl(c, 1)), 1); 211 KUNIT_EXPECT_EQ(test, u128_lo(u128_shl(c, 1)), 0); 212 213 KUNIT_EXPECT_EQ(test, u128_hi(u128_square(U64_MAX)), U64_MAX - 1); 214 KUNIT_EXPECT_EQ(test, u128_lo(u128_square(U64_MAX)), 1); 215 216 KUNIT_EXPECT_EQ(test, u128_lo(u128_div(b, 2)), 1LLU << 63); 217 218 KUNIT_EXPECT_EQ(test, u128_hi(u128_div(c2, 2)), U64_MAX >> 1); 219 KUNIT_EXPECT_EQ(test, u128_lo(u128_div(c2, 2)), U64_MAX); 220 221 KUNIT_EXPECT_EQ(test, u128_hi(u128_div(u128_shl(u64_to_u128(U64_MAX), 32), 2)), U32_MAX >> 1); 222 KUNIT_EXPECT_EQ(test, u128_lo(u128_div(u128_shl(u64_to_u128(U64_MAX), 32), 2)), U64_MAX << 31); 223 } 224 225 static struct kunit_case mean_and_variance_test_cases[] = { 226 KUNIT_CASE(mean_and_variance_fast_divpow2), 227 KUNIT_CASE(mean_and_variance_u128_basic_test), 228 KUNIT_CASE(mean_and_variance_basic_test), 229 KUNIT_CASE(mean_and_variance_weighted_test), 230 KUNIT_CASE(mean_and_variance_weighted_advanced_test), 231 KUNIT_CASE(mean_and_variance_test_1), 232 KUNIT_CASE(mean_and_variance_test_2), 233 KUNIT_CASE(mean_and_variance_test_3), 234 KUNIT_CASE(mean_and_variance_test_4), 235 {} 236 }; 237 238 static struct kunit_suite mean_and_variance_test_suite = { 239 .name = "mean and variance tests", 240 .test_cases = mean_and_variance_test_cases 241 }; 242 243 kunit_test_suite(mean_and_variance_test_suite); 244 245 MODULE_AUTHOR("Daniel B. Hill"); 246 MODULE_LICENSE("GPL"); 247