xref: /freebsd/tools/test/arc4random/biastest.c (revision 4a0fc138e5eb343e45388e66698a4765b308a622)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2024 Robert Clausecker <fuz@FreeBSD.org>
5  *
6  * biastest.c -- bias test for arc4random_uniform().
7  *
8  * The default configuration of this test has an upper bound of
9  * (3/4) * UINT32_MAX, which should give a high amount of bias in
10  * an incorrect implementation.  If the range reduction is
11  * implemented correctly, the parameters of the statistic should
12  * closely match the expected values.  If not, they'll differ.
13  *
14  * For memory usage reasons, we use an uchar to track the number of
15  * observations per bucket.  If the number of tries is much larger
16  * than upper_bound, the buckets likely overflow.  This is detected
17  * by the test, but will lead to incorrect results.
18  */
19 
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <signal.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
30 
/* The two phases of the test; both are defined after main(). */
static void	analyze_sample(const unsigned char *, long long, uint32_t);
static void	collect_sample(unsigned char *, long long, uint32_t);

/*
 * State read by the SIGALRM progress handler:
 *   tries       number of samples requested (main() may override it via -n)
 *   tries_done  number of samples drawn so far
 *   complete    set once sampling has finished
 */
static long long tries = 5LL << 32;
static atomic_llong tries_done = 0;
static atomic_bool complete = false;
37 
/*
 * Write the command line synopsis to stderr, using argv0 as the program
 * name, and terminate the process with EXIT_FAILURE.  Never returns.
 */
static void
usage(const char *argv0)
{
	(void)fprintf(stderr,
	    "usage: %s [-n tries] [-t upper_bound]\n",
	    argv0);

	exit(EXIT_FAILURE);
}
44 
45 int
main(int argc,char * argv[])46 main(int argc, char *argv[])
47 {
48 	uint32_t threshold = 3UL << 30;
49 	int ch;
50 	unsigned char *sample;
51 
52 	while (ch = getopt(argc, argv, "n:t:"), ch != EOF)
53 		switch (ch) {
54 		case 'n':
55 			tries = atoll(optarg);
56 			break;
57 
58 		case 't':
59 			threshold = (uint32_t)atoll(optarg);
60 			break;
61 
62 		default:
63 			usage(argv[0]);
64 		}
65 
66 	if (optind != argc)
67 		usage(argv[0]);
68 
69 	if (threshold == 0) {
70 		fprintf(stderr, "threshold must be between 1 and %lu\n", (unsigned long)UINT32_MAX);
71 		exit(EXIT_FAILURE);
72 	}
73 
74 	sample = calloc(threshold, 1);
75 	if (sample == NULL) {
76 		perror("calloc(threshold, 1)");
77 		return (EXIT_FAILURE);
78 	}
79 
80 	collect_sample(sample, tries, threshold);
81 	analyze_sample(sample, tries, threshold);
82 }
83 
84 static void
progress(int signo)85 progress(int signo)
86 {
87 	(void)signo;
88 
89 	if (!complete) {
90 		fprintf(stderr, "\r%10lld of %10lld samples taken (%5.2f%% done)",
91 		    tries_done, tries, (tries_done * 100.0) / tries);
92 
93 		signal(SIGALRM, progress);
94 		alarm(1);
95 	}
96 }
97 
98 static void
collect_sample(unsigned char * sample,long long tries,uint32_t threshold)99 collect_sample(unsigned char *sample, long long tries, uint32_t threshold)
100 {
101 	long long i;
102 	uint32_t x;
103 	bool overflowed = false;
104 
105 	progress(SIGALRM);
106 
107 	for (i = 0; i < tries; i++) {
108 		x = arc4random_uniform(threshold);
109 		tries_done++;
110 		assert(x < threshold);
111 
112 		if (sample[x] == UCHAR_MAX) {
113 			if (!overflowed) {
114 				printf("sample table overflow, results will be incorrect\n");
115 				overflowed = true;
116 			}
117 		} else
118 			sample[x]++;
119 	}
120 
121 	progress(SIGALRM);
122 	complete = true;
123 	fputc('\n', stderr);
124 }
125 
126 static void
analyze_sample(const unsigned char * sample,long long tries,uint32_t threshold)127 analyze_sample(const unsigned char *sample, long long tries,  uint32_t threshold)
128 {
129 	double discrepancy, average, variance, total;
130 	long long histogram[UCHAR_MAX + 1] = { 0 }, sum, n, median;
131 	uint32_t i, i_min, i_max;
132 	int min, max;
133 
134 	printf("distribution properties:\n");
135 
136 	/* find median, average, deviation, smallest, and largest bucket */
137 	total = 0.0;
138 	for (i = 0; i < threshold; i++) {
139 		histogram[sample[i]]++;
140 		total += (double)i * sample[i];
141 	}
142 
143 	average = total / tries;
144 
145 	variance = 0.0;
146 	median = threshold;
147 	n = 0;
148 	i_min = 0;
149 	i_max = 0;
150 	min = sample[i_min];
151 	max = sample[i_max];
152 
153 	for (i = 0; i < threshold; i++) {
154 		discrepancy = i - average;
155 		variance += sample[i] * discrepancy * discrepancy;
156 
157 		n += sample[i];
158 		if (median == threshold && n > tries / 2)
159 			median = i;
160 
161 		if (sample[i] < min) {
162 			i_min = i;
163 			min = sample[i_min];
164 		} else if (sample[i] > max) {
165 			i_max = i;
166 			max = sample[i_max];
167 		}
168 	}
169 
170 	variance /= tries;
171 	assert(median < threshold);
172 
173 	printf("\tthreshold:	%lu\n", (unsigned long)threshold);
174 	printf("\tobservations:	%lld\n", tries);
175 	printf("\tleast common:	%lu (%d observations)\n", (unsigned long)i_min, min);
176 	printf("\tmost common:	%lu (%d observations)\n", (unsigned long)i_max, max);
177 	printf("\tmedian:		%lld (expected %lu)\n", median, (unsigned long)threshold / 2);
178 	printf("\taverage:	%f (expected %f)\n", average, 0.5 * (threshold - 1));
179 	printf("\tdeviation:	%f (expected %f)\n\n", sqrt(variance),
180 	    sqrt(((double)threshold * threshold - 1.0) / 12));
181 
182 	/* build histogram and analyze it */
183 	printf("sample properties:\n");
184 
185 	/* find median, average, and deviation */
186 	average = (double)tries / threshold;
187 
188 	variance = 0.0;
189 	for (i = 0; i < UCHAR_MAX; i++) {
190 		discrepancy = i - average;
191 		variance += histogram[i] * discrepancy * discrepancy;
192 	}
193 
194 	variance /= threshold;
195 
196 	n = 0;
197 	median = UCHAR_MAX + 1;
198 	for (i = 0; i <= UCHAR_MAX; i++) {
199 		n += histogram[i];
200 		if (n >= threshold / 2) {
201 			median = i;
202 			break;
203 		}
204 	}
205 
206 	assert(median <= UCHAR_MAX); /* unreachable */
207 
208 	printf("\tmedian:		%lld\n", median);
209 	printf("\taverage:	%f\n", average);
210 	printf("\tdeviation:	%f (expected %f)\n\n", sqrt(variance), sqrt(average * (1.0 - 1.0 / threshold)));
211 
212 	printf("histogram:\n");
213 	for (i = 0; i < 256; i++)
214 		if (histogram[i] != 0)
215 			printf("\t%3d:\t%lld\n", (int)i, histogram[i]);
216 }
217