xref: /freebsd/sys/contrib/zstd/lib/dictBuilder/cover.c (revision 0c16b53773565120a8f80a31a0af2ef56ccd368e)
1*0c16b537SWarner Losh /*
2*0c16b537SWarner Losh  * Copyright (c) 2016-present, Yann Collet, Facebook, Inc.
3*0c16b537SWarner Losh  * All rights reserved.
4*0c16b537SWarner Losh  *
5*0c16b537SWarner Losh  * This source code is licensed under both the BSD-style license (found in the
6*0c16b537SWarner Losh  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7*0c16b537SWarner Losh  * in the COPYING file in the root directory of this source tree).
8*0c16b537SWarner Losh  * You may select, at your option, one of the above-listed licenses.
9*0c16b537SWarner Losh  */
10*0c16b537SWarner Losh 
11*0c16b537SWarner Losh /* *****************************************************************************
12*0c16b537SWarner Losh  * Constructs a dictionary using a heuristic based on the following paper:
13*0c16b537SWarner Losh  *
14*0c16b537SWarner Losh  * Liao, Petri, Moffat, Wirth
15*0c16b537SWarner Losh  * Effective Construction of Relative Lempel-Ziv Dictionaries
16*0c16b537SWarner Losh  * Published in WWW 2016.
17*0c16b537SWarner Losh  *
18*0c16b537SWarner Losh  * Adapted from code originally written by @ot (Giuseppe Ottaviano).
19*0c16b537SWarner Losh  ******************************************************************************/
20*0c16b537SWarner Losh 
21*0c16b537SWarner Losh /*-*************************************
22*0c16b537SWarner Losh *  Dependencies
23*0c16b537SWarner Losh ***************************************/
24*0c16b537SWarner Losh #include <stdio.h>  /* fprintf */
25*0c16b537SWarner Losh #include <stdlib.h> /* malloc, free, qsort */
26*0c16b537SWarner Losh #include <string.h> /* memset */
27*0c16b537SWarner Losh #include <time.h>   /* clock */
28*0c16b537SWarner Losh 
29*0c16b537SWarner Losh #include "mem.h" /* read */
30*0c16b537SWarner Losh #include "pool.h"
31*0c16b537SWarner Losh #include "threading.h"
32*0c16b537SWarner Losh #include "zstd_internal.h" /* includes zstd.h */
33*0c16b537SWarner Losh #ifndef ZDICT_STATIC_LINKING_ONLY
34*0c16b537SWarner Losh #define ZDICT_STATIC_LINKING_ONLY
35*0c16b537SWarner Losh #endif
36*0c16b537SWarner Losh #include "zdict.h"
37*0c16b537SWarner Losh 
38*0c16b537SWarner Losh /*-*************************************
39*0c16b537SWarner Losh *  Constants
40*0c16b537SWarner Losh ***************************************/
/* Exclusive upper bound on the total size of all samples:
 * (U32)-1 bytes when size_t is 8 bytes wide, 1 GB otherwise. */
#define COVER_MAX_SAMPLES_SIZE (sizeof(size_t) == 8 ? ((U32)-1) : ((U32)1 GB))
42*0c16b537SWarner Losh 
43*0c16b537SWarner Losh /*-*************************************
44*0c16b537SWarner Losh *  Console display
45*0c16b537SWarner Losh ***************************************/
/* Verbosity used by DISPLAYLEVEL() and DISPLAYUPDATE():
 * 0 : no display;   1: errors;   2: default;  3: details;  4: debug */
static int g_displayLevel = 2;

/* Minimum number of clock() ticks between two progress updates
 * (15/100 of a second). */
static const clock_t refreshRate = CLOCKS_PER_SEC * 15 / 100;
/* clock() value recorded when the last progress update was printed. */
static clock_t g_time = 0;

/* Prints a printf-style message to stderr and flushes it right away. */
#define DISPLAY(...)                                                           \
  {                                                                            \
    fprintf(stderr, __VA_ARGS__);                                              \
    fflush(stderr);                                                            \
  }

/* Prints the message only when `displayLevel` is at least `l`. */
#define LOCALDISPLAYLEVEL(displayLevel, l, ...)                                \
  if (displayLevel >= l) {                                                     \
    DISPLAY(__VA_ARGS__);                                                      \
  }
/* Same as LOCALDISPLAYLEVEL() using the file-wide g_displayLevel. */
#define DISPLAYLEVEL(l, ...) LOCALDISPLAYLEVEL(g_displayLevel, l, __VA_ARGS__)

/* Rate-limited variant for progress output: the message is printed when
 * `displayLevel` is at least `l` and either more than refreshRate ticks have
 * elapsed since the last update or `displayLevel` is 4 or more. */
#define LOCALDISPLAYUPDATE(displayLevel, l, ...)                               \
  if (displayLevel >= l) {                                                     \
    if ((clock() - g_time > refreshRate) || (displayLevel >= 4)) {             \
      g_time = clock();                                                        \
      DISPLAY(__VA_ARGS__);                                                    \
    }                                                                          \
  }
/* Same as LOCALDISPLAYUPDATE() using the file-wide g_displayLevel. */
#define DISPLAYUPDATE(l, ...) LOCALDISPLAYUPDATE(g_displayLevel, l, __VA_ARGS__)
68*0c16b537SWarner Losh 
69*0c16b537SWarner Losh /*-*************************************
70*0c16b537SWarner Losh * Hash table
71*0c16b537SWarner Losh ***************************************
72*0c16b537SWarner Losh * A small specialized hash map for storing activeDmers.
73*0c16b537SWarner Losh * The map does not resize, so if it becomes full it will loop forever.
74*0c16b537SWarner Losh * Thus, the map must be large enough to store every value.
75*0c16b537SWarner Losh * The map implements linear probing and keeps its load less than 0.5.
76*0c16b537SWarner Losh */
77*0c16b537SWarner Losh 
78*0c16b537SWarner Losh #define MAP_EMPTY_VALUE ((U32)-1)
79*0c16b537SWarner Losh typedef struct COVER_map_pair_t_s {
80*0c16b537SWarner Losh   U32 key;
81*0c16b537SWarner Losh   U32 value;
82*0c16b537SWarner Losh } COVER_map_pair_t;
83*0c16b537SWarner Losh 
84*0c16b537SWarner Losh typedef struct COVER_map_s {
85*0c16b537SWarner Losh   COVER_map_pair_t *data;
86*0c16b537SWarner Losh   U32 sizeLog;
87*0c16b537SWarner Losh   U32 size;
88*0c16b537SWarner Losh   U32 sizeMask;
89*0c16b537SWarner Losh } COVER_map_t;
90*0c16b537SWarner Losh 
91*0c16b537SWarner Losh /**
92*0c16b537SWarner Losh  * Clear the map.
93*0c16b537SWarner Losh  */
94*0c16b537SWarner Losh static void COVER_map_clear(COVER_map_t *map) {
95*0c16b537SWarner Losh   memset(map->data, MAP_EMPTY_VALUE, map->size * sizeof(COVER_map_pair_t));
96*0c16b537SWarner Losh }
97*0c16b537SWarner Losh 
98*0c16b537SWarner Losh /**
99*0c16b537SWarner Losh  * Initializes a map of the given size.
100*0c16b537SWarner Losh  * Returns 1 on success and 0 on failure.
101*0c16b537SWarner Losh  * The map must be destroyed with COVER_map_destroy().
102*0c16b537SWarner Losh  * The map is only guaranteed to be large enough to hold size elements.
103*0c16b537SWarner Losh  */
104*0c16b537SWarner Losh static int COVER_map_init(COVER_map_t *map, U32 size) {
105*0c16b537SWarner Losh   map->sizeLog = ZSTD_highbit32(size) + 2;
106*0c16b537SWarner Losh   map->size = (U32)1 << map->sizeLog;
107*0c16b537SWarner Losh   map->sizeMask = map->size - 1;
108*0c16b537SWarner Losh   map->data = (COVER_map_pair_t *)malloc(map->size * sizeof(COVER_map_pair_t));
109*0c16b537SWarner Losh   if (!map->data) {
110*0c16b537SWarner Losh     map->sizeLog = 0;
111*0c16b537SWarner Losh     map->size = 0;
112*0c16b537SWarner Losh     return 0;
113*0c16b537SWarner Losh   }
114*0c16b537SWarner Losh   COVER_map_clear(map);
115*0c16b537SWarner Losh   return 1;
116*0c16b537SWarner Losh }
117*0c16b537SWarner Losh 
118*0c16b537SWarner Losh /**
119*0c16b537SWarner Losh  * Internal hash function
120*0c16b537SWarner Losh  */
121*0c16b537SWarner Losh static const U32 prime4bytes = 2654435761U;
122*0c16b537SWarner Losh static U32 COVER_map_hash(COVER_map_t *map, U32 key) {
123*0c16b537SWarner Losh   return (key * prime4bytes) >> (32 - map->sizeLog);
124*0c16b537SWarner Losh }
125*0c16b537SWarner Losh 
126*0c16b537SWarner Losh /**
127*0c16b537SWarner Losh  * Helper function that returns the index that a key should be placed into.
128*0c16b537SWarner Losh  */
129*0c16b537SWarner Losh static U32 COVER_map_index(COVER_map_t *map, U32 key) {
130*0c16b537SWarner Losh   const U32 hash = COVER_map_hash(map, key);
131*0c16b537SWarner Losh   U32 i;
132*0c16b537SWarner Losh   for (i = hash;; i = (i + 1) & map->sizeMask) {
133*0c16b537SWarner Losh     COVER_map_pair_t *pos = &map->data[i];
134*0c16b537SWarner Losh     if (pos->value == MAP_EMPTY_VALUE) {
135*0c16b537SWarner Losh       return i;
136*0c16b537SWarner Losh     }
137*0c16b537SWarner Losh     if (pos->key == key) {
138*0c16b537SWarner Losh       return i;
139*0c16b537SWarner Losh     }
140*0c16b537SWarner Losh   }
141*0c16b537SWarner Losh }
142*0c16b537SWarner Losh 
143*0c16b537SWarner Losh /**
144*0c16b537SWarner Losh  * Returns the pointer to the value for key.
145*0c16b537SWarner Losh  * If key is not in the map, it is inserted and the value is set to 0.
146*0c16b537SWarner Losh  * The map must not be full.
147*0c16b537SWarner Losh  */
148*0c16b537SWarner Losh static U32 *COVER_map_at(COVER_map_t *map, U32 key) {
149*0c16b537SWarner Losh   COVER_map_pair_t *pos = &map->data[COVER_map_index(map, key)];
150*0c16b537SWarner Losh   if (pos->value == MAP_EMPTY_VALUE) {
151*0c16b537SWarner Losh     pos->key = key;
152*0c16b537SWarner Losh     pos->value = 0;
153*0c16b537SWarner Losh   }
154*0c16b537SWarner Losh   return &pos->value;
155*0c16b537SWarner Losh }
156*0c16b537SWarner Losh 
157*0c16b537SWarner Losh /**
158*0c16b537SWarner Losh  * Deletes key from the map if present.
159*0c16b537SWarner Losh  */
160*0c16b537SWarner Losh static void COVER_map_remove(COVER_map_t *map, U32 key) {
161*0c16b537SWarner Losh   U32 i = COVER_map_index(map, key);
162*0c16b537SWarner Losh   COVER_map_pair_t *del = &map->data[i];
163*0c16b537SWarner Losh   U32 shift = 1;
164*0c16b537SWarner Losh   if (del->value == MAP_EMPTY_VALUE) {
165*0c16b537SWarner Losh     return;
166*0c16b537SWarner Losh   }
167*0c16b537SWarner Losh   for (i = (i + 1) & map->sizeMask;; i = (i + 1) & map->sizeMask) {
168*0c16b537SWarner Losh     COVER_map_pair_t *const pos = &map->data[i];
169*0c16b537SWarner Losh     /* If the position is empty we are done */
170*0c16b537SWarner Losh     if (pos->value == MAP_EMPTY_VALUE) {
171*0c16b537SWarner Losh       del->value = MAP_EMPTY_VALUE;
172*0c16b537SWarner Losh       return;
173*0c16b537SWarner Losh     }
174*0c16b537SWarner Losh     /* If pos can be moved to del do so */
175*0c16b537SWarner Losh     if (((i - COVER_map_hash(map, pos->key)) & map->sizeMask) >= shift) {
176*0c16b537SWarner Losh       del->key = pos->key;
177*0c16b537SWarner Losh       del->value = pos->value;
178*0c16b537SWarner Losh       del = pos;
179*0c16b537SWarner Losh       shift = 1;
180*0c16b537SWarner Losh     } else {
181*0c16b537SWarner Losh       ++shift;
182*0c16b537SWarner Losh     }
183*0c16b537SWarner Losh   }
184*0c16b537SWarner Losh }
185*0c16b537SWarner Losh 
186*0c16b537SWarner Losh /**
187*0c16b537SWarner Losh  * Destroyes a map that is inited with COVER_map_init().
188*0c16b537SWarner Losh  */
189*0c16b537SWarner Losh static void COVER_map_destroy(COVER_map_t *map) {
190*0c16b537SWarner Losh   if (map->data) {
191*0c16b537SWarner Losh     free(map->data);
192*0c16b537SWarner Losh   }
193*0c16b537SWarner Losh   map->data = NULL;
194*0c16b537SWarner Losh   map->size = 0;
195*0c16b537SWarner Losh }
196*0c16b537SWarner Losh 
197*0c16b537SWarner Losh /*-*************************************
198*0c16b537SWarner Losh * Context
199*0c16b537SWarner Losh ***************************************/
200*0c16b537SWarner Losh 
201*0c16b537SWarner Losh typedef struct {
202*0c16b537SWarner Losh   const BYTE *samples;
203*0c16b537SWarner Losh   size_t *offsets;
204*0c16b537SWarner Losh   const size_t *samplesSizes;
205*0c16b537SWarner Losh   size_t nbSamples;
206*0c16b537SWarner Losh   U32 *suffix;
207*0c16b537SWarner Losh   size_t suffixSize;
208*0c16b537SWarner Losh   U32 *freqs;
209*0c16b537SWarner Losh   U32 *dmerAt;
210*0c16b537SWarner Losh   unsigned d;
211*0c16b537SWarner Losh } COVER_ctx_t;
212*0c16b537SWarner Losh 
213*0c16b537SWarner Losh /* We need a global context for qsort... */
214*0c16b537SWarner Losh static COVER_ctx_t *g_ctx = NULL;
215*0c16b537SWarner Losh 
216*0c16b537SWarner Losh /*-*************************************
217*0c16b537SWarner Losh *  Helper functions
218*0c16b537SWarner Losh ***************************************/
219*0c16b537SWarner Losh 
/**
 * Adds up the first `nbSamples` entries of `samplesSizes`.
 * Returns 0 when `nbSamples` is 0.
 */
static size_t COVER_sum(const size_t *samplesSizes, unsigned nbSamples) {
  size_t total = 0;
  unsigned idx = 0;
  while (idx < nbSamples) {
    total += samplesSizes[idx];
    ++idx;
  }
  return total;
}
231*0c16b537SWarner Losh 
232*0c16b537SWarner Losh /**
233*0c16b537SWarner Losh  * Returns -1 if the dmer at lp is less than the dmer at rp.
234*0c16b537SWarner Losh  * Return 0 if the dmers at lp and rp are equal.
235*0c16b537SWarner Losh  * Returns 1 if the dmer at lp is greater than the dmer at rp.
236*0c16b537SWarner Losh  */
237*0c16b537SWarner Losh static int COVER_cmp(COVER_ctx_t *ctx, const void *lp, const void *rp) {
238*0c16b537SWarner Losh   U32 const lhs = *(U32 const *)lp;
239*0c16b537SWarner Losh   U32 const rhs = *(U32 const *)rp;
240*0c16b537SWarner Losh   return memcmp(ctx->samples + lhs, ctx->samples + rhs, ctx->d);
241*0c16b537SWarner Losh }
242*0c16b537SWarner Losh /**
243*0c16b537SWarner Losh  * Faster version for d <= 8.
244*0c16b537SWarner Losh  */
245*0c16b537SWarner Losh static int COVER_cmp8(COVER_ctx_t *ctx, const void *lp, const void *rp) {
246*0c16b537SWarner Losh   U64 const mask = (ctx->d == 8) ? (U64)-1 : (((U64)1 << (8 * ctx->d)) - 1);
247*0c16b537SWarner Losh   U64 const lhs = MEM_readLE64(ctx->samples + *(U32 const *)lp) & mask;
248*0c16b537SWarner Losh   U64 const rhs = MEM_readLE64(ctx->samples + *(U32 const *)rp) & mask;
249*0c16b537SWarner Losh   if (lhs < rhs) {
250*0c16b537SWarner Losh     return -1;
251*0c16b537SWarner Losh   }
252*0c16b537SWarner Losh   return (lhs > rhs);
253*0c16b537SWarner Losh }
254*0c16b537SWarner Losh 
255*0c16b537SWarner Losh /**
256*0c16b537SWarner Losh  * Same as COVER_cmp() except ties are broken by pointer value
257*0c16b537SWarner Losh  * NOTE: g_ctx must be set to call this function.  A global is required because
258*0c16b537SWarner Losh  * qsort doesn't take an opaque pointer.
259*0c16b537SWarner Losh  */
260*0c16b537SWarner Losh static int COVER_strict_cmp(const void *lp, const void *rp) {
261*0c16b537SWarner Losh   int result = COVER_cmp(g_ctx, lp, rp);
262*0c16b537SWarner Losh   if (result == 0) {
263*0c16b537SWarner Losh     result = lp < rp ? -1 : 1;
264*0c16b537SWarner Losh   }
265*0c16b537SWarner Losh   return result;
266*0c16b537SWarner Losh }
267*0c16b537SWarner Losh /**
268*0c16b537SWarner Losh  * Faster version for d <= 8.
269*0c16b537SWarner Losh  */
270*0c16b537SWarner Losh static int COVER_strict_cmp8(const void *lp, const void *rp) {
271*0c16b537SWarner Losh   int result = COVER_cmp8(g_ctx, lp, rp);
272*0c16b537SWarner Losh   if (result == 0) {
273*0c16b537SWarner Losh     result = lp < rp ? -1 : 1;
274*0c16b537SWarner Losh   }
275*0c16b537SWarner Losh   return result;
276*0c16b537SWarner Losh }
277*0c16b537SWarner Losh 
/**
 * Binary search over the sorted range [first, last).
 * Returns a pointer to the first element that is not less than `value`, or
 * `last` when every element is less than `value`.
 */
static const size_t *COVER_lower_bound(const size_t *first, const size_t *last,
                                       size_t value) {
  size_t remaining = (size_t)(last - first);
  while (remaining > 0) {
    size_t const half = remaining >> 1;
    const size_t *const mid = first + half;
    if (*mid >= value) {
      /* The answer is at mid or before it: keep the lower half. */
      remaining = half;
      continue;
    }
    /* Everything up to and including mid is too small. */
    first = mid + 1;
    remaining -= half + 1;
  }
  return first;
}
298*0c16b537SWarner Losh 
299*0c16b537SWarner Losh /**
300*0c16b537SWarner Losh  * Generic groupBy function.
301*0c16b537SWarner Losh  * Groups an array sorted by cmp into groups with equivalent values.
302*0c16b537SWarner Losh  * Calls grp for each group.
303*0c16b537SWarner Losh  */
304*0c16b537SWarner Losh static void
305*0c16b537SWarner Losh COVER_groupBy(const void *data, size_t count, size_t size, COVER_ctx_t *ctx,
306*0c16b537SWarner Losh               int (*cmp)(COVER_ctx_t *, const void *, const void *),
307*0c16b537SWarner Losh               void (*grp)(COVER_ctx_t *, const void *, const void *)) {
308*0c16b537SWarner Losh   const BYTE *ptr = (const BYTE *)data;
309*0c16b537SWarner Losh   size_t num = 0;
310*0c16b537SWarner Losh   while (num < count) {
311*0c16b537SWarner Losh     const BYTE *grpEnd = ptr + size;
312*0c16b537SWarner Losh     ++num;
313*0c16b537SWarner Losh     while (num < count && cmp(ctx, ptr, grpEnd) == 0) {
314*0c16b537SWarner Losh       grpEnd += size;
315*0c16b537SWarner Losh       ++num;
316*0c16b537SWarner Losh     }
317*0c16b537SWarner Losh     grp(ctx, ptr, grpEnd);
318*0c16b537SWarner Losh     ptr = grpEnd;
319*0c16b537SWarner Losh   }
320*0c16b537SWarner Losh }
321*0c16b537SWarner Losh 
322*0c16b537SWarner Losh /*-*************************************
323*0c16b537SWarner Losh *  Cover functions
324*0c16b537SWarner Losh ***************************************/
325*0c16b537SWarner Losh 
326*0c16b537SWarner Losh /**
327*0c16b537SWarner Losh  * Called on each group of positions with the same dmer.
328*0c16b537SWarner Losh  * Counts the frequency of each dmer and saves it in the suffix array.
329*0c16b537SWarner Losh  * Fills `ctx->dmerAt`.
330*0c16b537SWarner Losh  */
331*0c16b537SWarner Losh static void COVER_group(COVER_ctx_t *ctx, const void *group,
332*0c16b537SWarner Losh                         const void *groupEnd) {
333*0c16b537SWarner Losh   /* The group consists of all the positions with the same first d bytes. */
334*0c16b537SWarner Losh   const U32 *grpPtr = (const U32 *)group;
335*0c16b537SWarner Losh   const U32 *grpEnd = (const U32 *)groupEnd;
336*0c16b537SWarner Losh   /* The dmerId is how we will reference this dmer.
337*0c16b537SWarner Losh    * This allows us to map the whole dmer space to a much smaller space, the
338*0c16b537SWarner Losh    * size of the suffix array.
339*0c16b537SWarner Losh    */
340*0c16b537SWarner Losh   const U32 dmerId = (U32)(grpPtr - ctx->suffix);
341*0c16b537SWarner Losh   /* Count the number of samples this dmer shows up in */
342*0c16b537SWarner Losh   U32 freq = 0;
343*0c16b537SWarner Losh   /* Details */
344*0c16b537SWarner Losh   const size_t *curOffsetPtr = ctx->offsets;
345*0c16b537SWarner Losh   const size_t *offsetsEnd = ctx->offsets + ctx->nbSamples;
346*0c16b537SWarner Losh   /* Once *grpPtr >= curSampleEnd this occurrence of the dmer is in a
347*0c16b537SWarner Losh    * different sample than the last.
348*0c16b537SWarner Losh    */
349*0c16b537SWarner Losh   size_t curSampleEnd = ctx->offsets[0];
350*0c16b537SWarner Losh   for (; grpPtr != grpEnd; ++grpPtr) {
351*0c16b537SWarner Losh     /* Save the dmerId for this position so we can get back to it. */
352*0c16b537SWarner Losh     ctx->dmerAt[*grpPtr] = dmerId;
353*0c16b537SWarner Losh     /* Dictionaries only help for the first reference to the dmer.
354*0c16b537SWarner Losh      * After that zstd can reference the match from the previous reference.
355*0c16b537SWarner Losh      * So only count each dmer once for each sample it is in.
356*0c16b537SWarner Losh      */
357*0c16b537SWarner Losh     if (*grpPtr < curSampleEnd) {
358*0c16b537SWarner Losh       continue;
359*0c16b537SWarner Losh     }
360*0c16b537SWarner Losh     freq += 1;
361*0c16b537SWarner Losh     /* Binary search to find the end of the sample *grpPtr is in.
362*0c16b537SWarner Losh      * In the common case that grpPtr + 1 == grpEnd we can skip the binary
363*0c16b537SWarner Losh      * search because the loop is over.
364*0c16b537SWarner Losh      */
365*0c16b537SWarner Losh     if (grpPtr + 1 != grpEnd) {
366*0c16b537SWarner Losh       const size_t *sampleEndPtr =
367*0c16b537SWarner Losh           COVER_lower_bound(curOffsetPtr, offsetsEnd, *grpPtr);
368*0c16b537SWarner Losh       curSampleEnd = *sampleEndPtr;
369*0c16b537SWarner Losh       curOffsetPtr = sampleEndPtr + 1;
370*0c16b537SWarner Losh     }
371*0c16b537SWarner Losh   }
372*0c16b537SWarner Losh   /* At this point we are never going to look at this segment of the suffix
373*0c16b537SWarner Losh    * array again.  We take advantage of this fact to save memory.
374*0c16b537SWarner Losh    * We store the frequency of the dmer in the first position of the group,
375*0c16b537SWarner Losh    * which is dmerId.
376*0c16b537SWarner Losh    */
377*0c16b537SWarner Losh   ctx->suffix[dmerId] = freq;
378*0c16b537SWarner Losh }
379*0c16b537SWarner Losh 
380*0c16b537SWarner Losh /**
381*0c16b537SWarner Losh  * A segment is a range in the source as well as the score of the segment.
382*0c16b537SWarner Losh  */
383*0c16b537SWarner Losh typedef struct {
384*0c16b537SWarner Losh   U32 begin;
385*0c16b537SWarner Losh   U32 end;
386*0c16b537SWarner Losh   U32 score;
387*0c16b537SWarner Losh } COVER_segment_t;
388*0c16b537SWarner Losh 
389*0c16b537SWarner Losh /**
390*0c16b537SWarner Losh  * Selects the best segment in an epoch.
391*0c16b537SWarner Losh  * Segments of are scored according to the function:
392*0c16b537SWarner Losh  *
393*0c16b537SWarner Losh  * Let F(d) be the frequency of dmer d.
394*0c16b537SWarner Losh  * Let S_i be the dmer at position i of segment S which has length k.
395*0c16b537SWarner Losh  *
396*0c16b537SWarner Losh  *     Score(S) = F(S_1) + F(S_2) + ... + F(S_{k-d+1})
397*0c16b537SWarner Losh  *
398*0c16b537SWarner Losh  * Once the dmer d is in the dictionay we set F(d) = 0.
399*0c16b537SWarner Losh  */
400*0c16b537SWarner Losh static COVER_segment_t COVER_selectSegment(const COVER_ctx_t *ctx, U32 *freqs,
401*0c16b537SWarner Losh                                            COVER_map_t *activeDmers, U32 begin,
402*0c16b537SWarner Losh                                            U32 end,
403*0c16b537SWarner Losh                                            ZDICT_cover_params_t parameters) {
404*0c16b537SWarner Losh   /* Constants */
405*0c16b537SWarner Losh   const U32 k = parameters.k;
406*0c16b537SWarner Losh   const U32 d = parameters.d;
407*0c16b537SWarner Losh   const U32 dmersInK = k - d + 1;
408*0c16b537SWarner Losh   /* Try each segment (activeSegment) and save the best (bestSegment) */
409*0c16b537SWarner Losh   COVER_segment_t bestSegment = {0, 0, 0};
410*0c16b537SWarner Losh   COVER_segment_t activeSegment;
411*0c16b537SWarner Losh   /* Reset the activeDmers in the segment */
412*0c16b537SWarner Losh   COVER_map_clear(activeDmers);
413*0c16b537SWarner Losh   /* The activeSegment starts at the beginning of the epoch. */
414*0c16b537SWarner Losh   activeSegment.begin = begin;
415*0c16b537SWarner Losh   activeSegment.end = begin;
416*0c16b537SWarner Losh   activeSegment.score = 0;
417*0c16b537SWarner Losh   /* Slide the activeSegment through the whole epoch.
418*0c16b537SWarner Losh    * Save the best segment in bestSegment.
419*0c16b537SWarner Losh    */
420*0c16b537SWarner Losh   while (activeSegment.end < end) {
421*0c16b537SWarner Losh     /* The dmerId for the dmer at the next position */
422*0c16b537SWarner Losh     U32 newDmer = ctx->dmerAt[activeSegment.end];
423*0c16b537SWarner Losh     /* The entry in activeDmers for this dmerId */
424*0c16b537SWarner Losh     U32 *newDmerOcc = COVER_map_at(activeDmers, newDmer);
425*0c16b537SWarner Losh     /* If the dmer isn't already present in the segment add its score. */
426*0c16b537SWarner Losh     if (*newDmerOcc == 0) {
427*0c16b537SWarner Losh       /* The paper suggest using the L-0.5 norm, but experiments show that it
428*0c16b537SWarner Losh        * doesn't help.
429*0c16b537SWarner Losh        */
430*0c16b537SWarner Losh       activeSegment.score += freqs[newDmer];
431*0c16b537SWarner Losh     }
432*0c16b537SWarner Losh     /* Add the dmer to the segment */
433*0c16b537SWarner Losh     activeSegment.end += 1;
434*0c16b537SWarner Losh     *newDmerOcc += 1;
435*0c16b537SWarner Losh 
436*0c16b537SWarner Losh     /* If the window is now too large, drop the first position */
437*0c16b537SWarner Losh     if (activeSegment.end - activeSegment.begin == dmersInK + 1) {
438*0c16b537SWarner Losh       U32 delDmer = ctx->dmerAt[activeSegment.begin];
439*0c16b537SWarner Losh       U32 *delDmerOcc = COVER_map_at(activeDmers, delDmer);
440*0c16b537SWarner Losh       activeSegment.begin += 1;
441*0c16b537SWarner Losh       *delDmerOcc -= 1;
442*0c16b537SWarner Losh       /* If this is the last occurence of the dmer, subtract its score */
443*0c16b537SWarner Losh       if (*delDmerOcc == 0) {
444*0c16b537SWarner Losh         COVER_map_remove(activeDmers, delDmer);
445*0c16b537SWarner Losh         activeSegment.score -= freqs[delDmer];
446*0c16b537SWarner Losh       }
447*0c16b537SWarner Losh     }
448*0c16b537SWarner Losh 
449*0c16b537SWarner Losh     /* If this segment is the best so far save it */
450*0c16b537SWarner Losh     if (activeSegment.score > bestSegment.score) {
451*0c16b537SWarner Losh       bestSegment = activeSegment;
452*0c16b537SWarner Losh     }
453*0c16b537SWarner Losh   }
454*0c16b537SWarner Losh   {
455*0c16b537SWarner Losh     /* Trim off the zero frequency head and tail from the segment. */
456*0c16b537SWarner Losh     U32 newBegin = bestSegment.end;
457*0c16b537SWarner Losh     U32 newEnd = bestSegment.begin;
458*0c16b537SWarner Losh     U32 pos;
459*0c16b537SWarner Losh     for (pos = bestSegment.begin; pos != bestSegment.end; ++pos) {
460*0c16b537SWarner Losh       U32 freq = freqs[ctx->dmerAt[pos]];
461*0c16b537SWarner Losh       if (freq != 0) {
462*0c16b537SWarner Losh         newBegin = MIN(newBegin, pos);
463*0c16b537SWarner Losh         newEnd = pos + 1;
464*0c16b537SWarner Losh       }
465*0c16b537SWarner Losh     }
466*0c16b537SWarner Losh     bestSegment.begin = newBegin;
467*0c16b537SWarner Losh     bestSegment.end = newEnd;
468*0c16b537SWarner Losh   }
469*0c16b537SWarner Losh   {
470*0c16b537SWarner Losh     /* Zero out the frequency of each dmer covered by the chosen segment. */
471*0c16b537SWarner Losh     U32 pos;
472*0c16b537SWarner Losh     for (pos = bestSegment.begin; pos != bestSegment.end; ++pos) {
473*0c16b537SWarner Losh       freqs[ctx->dmerAt[pos]] = 0;
474*0c16b537SWarner Losh     }
475*0c16b537SWarner Losh   }
476*0c16b537SWarner Losh   return bestSegment;
477*0c16b537SWarner Losh }
478*0c16b537SWarner Losh 
479*0c16b537SWarner Losh /**
480*0c16b537SWarner Losh  * Check the validity of the parameters.
481*0c16b537SWarner Losh  * Returns non-zero if the parameters are valid and 0 otherwise.
482*0c16b537SWarner Losh  */
483*0c16b537SWarner Losh static int COVER_checkParameters(ZDICT_cover_params_t parameters,
484*0c16b537SWarner Losh                                  size_t maxDictSize) {
485*0c16b537SWarner Losh   /* k and d are required parameters */
486*0c16b537SWarner Losh   if (parameters.d == 0 || parameters.k == 0) {
487*0c16b537SWarner Losh     return 0;
488*0c16b537SWarner Losh   }
489*0c16b537SWarner Losh   /* k <= maxDictSize */
490*0c16b537SWarner Losh   if (parameters.k > maxDictSize) {
491*0c16b537SWarner Losh     return 0;
492*0c16b537SWarner Losh   }
493*0c16b537SWarner Losh   /* d <= k */
494*0c16b537SWarner Losh   if (parameters.d > parameters.k) {
495*0c16b537SWarner Losh     return 0;
496*0c16b537SWarner Losh   }
497*0c16b537SWarner Losh   return 1;
498*0c16b537SWarner Losh }
499*0c16b537SWarner Losh 
500*0c16b537SWarner Losh /**
501*0c16b537SWarner Losh  * Clean up a context initialized with `COVER_ctx_init()`.
502*0c16b537SWarner Losh  */
503*0c16b537SWarner Losh static void COVER_ctx_destroy(COVER_ctx_t *ctx) {
504*0c16b537SWarner Losh   if (!ctx) {
505*0c16b537SWarner Losh     return;
506*0c16b537SWarner Losh   }
507*0c16b537SWarner Losh   if (ctx->suffix) {
508*0c16b537SWarner Losh     free(ctx->suffix);
509*0c16b537SWarner Losh     ctx->suffix = NULL;
510*0c16b537SWarner Losh   }
511*0c16b537SWarner Losh   if (ctx->freqs) {
512*0c16b537SWarner Losh     free(ctx->freqs);
513*0c16b537SWarner Losh     ctx->freqs = NULL;
514*0c16b537SWarner Losh   }
515*0c16b537SWarner Losh   if (ctx->dmerAt) {
516*0c16b537SWarner Losh     free(ctx->dmerAt);
517*0c16b537SWarner Losh     ctx->dmerAt = NULL;
518*0c16b537SWarner Losh   }
519*0c16b537SWarner Losh   if (ctx->offsets) {
520*0c16b537SWarner Losh     free(ctx->offsets);
521*0c16b537SWarner Losh     ctx->offsets = NULL;
522*0c16b537SWarner Losh   }
523*0c16b537SWarner Losh }
524*0c16b537SWarner Losh 
525*0c16b537SWarner Losh /**
526*0c16b537SWarner Losh  * Prepare a context for dictionary building.
527*0c16b537SWarner Losh  * The context is only dependent on the parameter `d` and can used multiple
528*0c16b537SWarner Losh  * times.
529*0c16b537SWarner Losh  * Returns 1 on success or zero on error.
530*0c16b537SWarner Losh  * The context must be destroyed with `COVER_ctx_destroy()`.
531*0c16b537SWarner Losh  */
532*0c16b537SWarner Losh static int COVER_ctx_init(COVER_ctx_t *ctx, const void *samplesBuffer,
533*0c16b537SWarner Losh                           const size_t *samplesSizes, unsigned nbSamples,
534*0c16b537SWarner Losh                           unsigned d) {
535*0c16b537SWarner Losh   const BYTE *const samples = (const BYTE *)samplesBuffer;
536*0c16b537SWarner Losh   const size_t totalSamplesSize = COVER_sum(samplesSizes, nbSamples);
537*0c16b537SWarner Losh   /* Checks */
538*0c16b537SWarner Losh   if (totalSamplesSize < MAX(d, sizeof(U64)) ||
539*0c16b537SWarner Losh       totalSamplesSize >= (size_t)COVER_MAX_SAMPLES_SIZE) {
540*0c16b537SWarner Losh     DISPLAYLEVEL(1, "Total samples size is too large, maximum size is %u MB\n",
541*0c16b537SWarner Losh                  (COVER_MAX_SAMPLES_SIZE >> 20));
542*0c16b537SWarner Losh     return 0;
543*0c16b537SWarner Losh   }
544*0c16b537SWarner Losh   /* Zero the context */
545*0c16b537SWarner Losh   memset(ctx, 0, sizeof(*ctx));
546*0c16b537SWarner Losh   DISPLAYLEVEL(2, "Training on %u samples of total size %u\n", nbSamples,
547*0c16b537SWarner Losh                (U32)totalSamplesSize);
548*0c16b537SWarner Losh   ctx->samples = samples;
549*0c16b537SWarner Losh   ctx->samplesSizes = samplesSizes;
550*0c16b537SWarner Losh   ctx->nbSamples = nbSamples;
551*0c16b537SWarner Losh   /* Partial suffix array */
552*0c16b537SWarner Losh   ctx->suffixSize = totalSamplesSize - MAX(d, sizeof(U64)) + 1;
553*0c16b537SWarner Losh   ctx->suffix = (U32 *)malloc(ctx->suffixSize * sizeof(U32));
554*0c16b537SWarner Losh   /* Maps index to the dmerID */
555*0c16b537SWarner Losh   ctx->dmerAt = (U32 *)malloc(ctx->suffixSize * sizeof(U32));
556*0c16b537SWarner Losh   /* The offsets of each file */
557*0c16b537SWarner Losh   ctx->offsets = (size_t *)malloc((nbSamples + 1) * sizeof(size_t));
558*0c16b537SWarner Losh   if (!ctx->suffix || !ctx->dmerAt || !ctx->offsets) {
559*0c16b537SWarner Losh     DISPLAYLEVEL(1, "Failed to allocate scratch buffers\n");
560*0c16b537SWarner Losh     COVER_ctx_destroy(ctx);
561*0c16b537SWarner Losh     return 0;
562*0c16b537SWarner Losh   }
563*0c16b537SWarner Losh   ctx->freqs = NULL;
564*0c16b537SWarner Losh   ctx->d = d;
565*0c16b537SWarner Losh 
566*0c16b537SWarner Losh   /* Fill offsets from the samlesSizes */
567*0c16b537SWarner Losh   {
568*0c16b537SWarner Losh     U32 i;
569*0c16b537SWarner Losh     ctx->offsets[0] = 0;
570*0c16b537SWarner Losh     for (i = 1; i <= nbSamples; ++i) {
571*0c16b537SWarner Losh       ctx->offsets[i] = ctx->offsets[i - 1] + samplesSizes[i - 1];
572*0c16b537SWarner Losh     }
573*0c16b537SWarner Losh   }
574*0c16b537SWarner Losh   DISPLAYLEVEL(2, "Constructing partial suffix array\n");
575*0c16b537SWarner Losh   {
576*0c16b537SWarner Losh     /* suffix is a partial suffix array.
577*0c16b537SWarner Losh      * It only sorts suffixes by their first parameters.d bytes.
578*0c16b537SWarner Losh      * The sort is stable, so each dmer group is sorted by position in input.
579*0c16b537SWarner Losh      */
580*0c16b537SWarner Losh     U32 i;
581*0c16b537SWarner Losh     for (i = 0; i < ctx->suffixSize; ++i) {
582*0c16b537SWarner Losh       ctx->suffix[i] = i;
583*0c16b537SWarner Losh     }
584*0c16b537SWarner Losh     /* qsort doesn't take an opaque pointer, so pass as a global */
585*0c16b537SWarner Losh     g_ctx = ctx;
586*0c16b537SWarner Losh     qsort(ctx->suffix, ctx->suffixSize, sizeof(U32),
587*0c16b537SWarner Losh           (ctx->d <= 8 ? &COVER_strict_cmp8 : &COVER_strict_cmp));
588*0c16b537SWarner Losh   }
589*0c16b537SWarner Losh   DISPLAYLEVEL(2, "Computing frequencies\n");
590*0c16b537SWarner Losh   /* For each dmer group (group of positions with the same first d bytes):
591*0c16b537SWarner Losh    * 1. For each position we set dmerAt[position] = dmerID.  The dmerID is
592*0c16b537SWarner Losh    *    (groupBeginPtr - suffix).  This allows us to go from position to
593*0c16b537SWarner Losh    *    dmerID so we can look up values in freq.
594*0c16b537SWarner Losh    * 2. We calculate how many samples the dmer occurs in and save it in
595*0c16b537SWarner Losh    *    freqs[dmerId].
596*0c16b537SWarner Losh    */
597*0c16b537SWarner Losh   COVER_groupBy(ctx->suffix, ctx->suffixSize, sizeof(U32), ctx,
598*0c16b537SWarner Losh                 (ctx->d <= 8 ? &COVER_cmp8 : &COVER_cmp), &COVER_group);
599*0c16b537SWarner Losh   ctx->freqs = ctx->suffix;
600*0c16b537SWarner Losh   ctx->suffix = NULL;
601*0c16b537SWarner Losh   return 1;
602*0c16b537SWarner Losh }
603*0c16b537SWarner Losh 
604*0c16b537SWarner Losh /**
605*0c16b537SWarner Losh  * Given the prepared context build the dictionary.
606*0c16b537SWarner Losh  */
607*0c16b537SWarner Losh static size_t COVER_buildDictionary(const COVER_ctx_t *ctx, U32 *freqs,
608*0c16b537SWarner Losh                                     COVER_map_t *activeDmers, void *dictBuffer,
609*0c16b537SWarner Losh                                     size_t dictBufferCapacity,
610*0c16b537SWarner Losh                                     ZDICT_cover_params_t parameters) {
611*0c16b537SWarner Losh   BYTE *const dict = (BYTE *)dictBuffer;
612*0c16b537SWarner Losh   size_t tail = dictBufferCapacity;
613*0c16b537SWarner Losh   /* Divide the data up into epochs of equal size.
614*0c16b537SWarner Losh    * We will select at least one segment from each epoch.
615*0c16b537SWarner Losh    */
616*0c16b537SWarner Losh   const U32 epochs = (U32)(dictBufferCapacity / parameters.k);
617*0c16b537SWarner Losh   const U32 epochSize = (U32)(ctx->suffixSize / epochs);
618*0c16b537SWarner Losh   size_t epoch;
619*0c16b537SWarner Losh   DISPLAYLEVEL(2, "Breaking content into %u epochs of size %u\n", epochs,
620*0c16b537SWarner Losh                epochSize);
621*0c16b537SWarner Losh   /* Loop through the epochs until there are no more segments or the dictionary
622*0c16b537SWarner Losh    * is full.
623*0c16b537SWarner Losh    */
624*0c16b537SWarner Losh   for (epoch = 0; tail > 0; epoch = (epoch + 1) % epochs) {
625*0c16b537SWarner Losh     const U32 epochBegin = (U32)(epoch * epochSize);
626*0c16b537SWarner Losh     const U32 epochEnd = epochBegin + epochSize;
627*0c16b537SWarner Losh     size_t segmentSize;
628*0c16b537SWarner Losh     /* Select a segment */
629*0c16b537SWarner Losh     COVER_segment_t segment = COVER_selectSegment(
630*0c16b537SWarner Losh         ctx, freqs, activeDmers, epochBegin, epochEnd, parameters);
631*0c16b537SWarner Losh     /* If the segment covers no dmers, then we are out of content */
632*0c16b537SWarner Losh     if (segment.score == 0) {
633*0c16b537SWarner Losh       break;
634*0c16b537SWarner Losh     }
635*0c16b537SWarner Losh     /* Trim the segment if necessary and if it is too small then we are done */
636*0c16b537SWarner Losh     segmentSize = MIN(segment.end - segment.begin + parameters.d - 1, tail);
637*0c16b537SWarner Losh     if (segmentSize < parameters.d) {
638*0c16b537SWarner Losh       break;
639*0c16b537SWarner Losh     }
640*0c16b537SWarner Losh     /* We fill the dictionary from the back to allow the best segments to be
641*0c16b537SWarner Losh      * referenced with the smallest offsets.
642*0c16b537SWarner Losh      */
643*0c16b537SWarner Losh     tail -= segmentSize;
644*0c16b537SWarner Losh     memcpy(dict + tail, ctx->samples + segment.begin, segmentSize);
645*0c16b537SWarner Losh     DISPLAYUPDATE(
646*0c16b537SWarner Losh         2, "\r%u%%       ",
647*0c16b537SWarner Losh         (U32)(((dictBufferCapacity - tail) * 100) / dictBufferCapacity));
648*0c16b537SWarner Losh   }
649*0c16b537SWarner Losh   DISPLAYLEVEL(2, "\r%79s\r", "");
650*0c16b537SWarner Losh   return tail;
651*0c16b537SWarner Losh }
652*0c16b537SWarner Losh 
653*0c16b537SWarner Losh ZDICTLIB_API size_t ZDICT_trainFromBuffer_cover(
654*0c16b537SWarner Losh     void *dictBuffer, size_t dictBufferCapacity, const void *samplesBuffer,
655*0c16b537SWarner Losh     const size_t *samplesSizes, unsigned nbSamples,
656*0c16b537SWarner Losh     ZDICT_cover_params_t parameters) {
657*0c16b537SWarner Losh   BYTE *const dict = (BYTE *)dictBuffer;
658*0c16b537SWarner Losh   COVER_ctx_t ctx;
659*0c16b537SWarner Losh   COVER_map_t activeDmers;
660*0c16b537SWarner Losh   /* Checks */
661*0c16b537SWarner Losh   if (!COVER_checkParameters(parameters, dictBufferCapacity)) {
662*0c16b537SWarner Losh     DISPLAYLEVEL(1, "Cover parameters incorrect\n");
663*0c16b537SWarner Losh     return ERROR(GENERIC);
664*0c16b537SWarner Losh   }
665*0c16b537SWarner Losh   if (nbSamples == 0) {
666*0c16b537SWarner Losh     DISPLAYLEVEL(1, "Cover must have at least one input file\n");
667*0c16b537SWarner Losh     return ERROR(GENERIC);
668*0c16b537SWarner Losh   }
669*0c16b537SWarner Losh   if (dictBufferCapacity < ZDICT_DICTSIZE_MIN) {
670*0c16b537SWarner Losh     DISPLAYLEVEL(1, "dictBufferCapacity must be at least %u\n",
671*0c16b537SWarner Losh                  ZDICT_DICTSIZE_MIN);
672*0c16b537SWarner Losh     return ERROR(dstSize_tooSmall);
673*0c16b537SWarner Losh   }
674*0c16b537SWarner Losh   /* Initialize global data */
675*0c16b537SWarner Losh   g_displayLevel = parameters.zParams.notificationLevel;
676*0c16b537SWarner Losh   /* Initialize context and activeDmers */
677*0c16b537SWarner Losh   if (!COVER_ctx_init(&ctx, samplesBuffer, samplesSizes, nbSamples,
678*0c16b537SWarner Losh                       parameters.d)) {
679*0c16b537SWarner Losh     return ERROR(GENERIC);
680*0c16b537SWarner Losh   }
681*0c16b537SWarner Losh   if (!COVER_map_init(&activeDmers, parameters.k - parameters.d + 1)) {
682*0c16b537SWarner Losh     DISPLAYLEVEL(1, "Failed to allocate dmer map: out of memory\n");
683*0c16b537SWarner Losh     COVER_ctx_destroy(&ctx);
684*0c16b537SWarner Losh     return ERROR(GENERIC);
685*0c16b537SWarner Losh   }
686*0c16b537SWarner Losh 
687*0c16b537SWarner Losh   DISPLAYLEVEL(2, "Building dictionary\n");
688*0c16b537SWarner Losh   {
689*0c16b537SWarner Losh     const size_t tail =
690*0c16b537SWarner Losh         COVER_buildDictionary(&ctx, ctx.freqs, &activeDmers, dictBuffer,
691*0c16b537SWarner Losh                               dictBufferCapacity, parameters);
692*0c16b537SWarner Losh     const size_t dictionarySize = ZDICT_finalizeDictionary(
693*0c16b537SWarner Losh         dict, dictBufferCapacity, dict + tail, dictBufferCapacity - tail,
694*0c16b537SWarner Losh         samplesBuffer, samplesSizes, nbSamples, parameters.zParams);
695*0c16b537SWarner Losh     if (!ZSTD_isError(dictionarySize)) {
696*0c16b537SWarner Losh       DISPLAYLEVEL(2, "Constructed dictionary of size %u\n",
697*0c16b537SWarner Losh                    (U32)dictionarySize);
698*0c16b537SWarner Losh     }
699*0c16b537SWarner Losh     COVER_ctx_destroy(&ctx);
700*0c16b537SWarner Losh     COVER_map_destroy(&activeDmers);
701*0c16b537SWarner Losh     return dictionarySize;
702*0c16b537SWarner Losh   }
703*0c16b537SWarner Losh }
704*0c16b537SWarner Losh 
705*0c16b537SWarner Losh /**
706*0c16b537SWarner Losh  * COVER_best_t is used for two purposes:
707*0c16b537SWarner Losh  * 1. Synchronizing threads.
708*0c16b537SWarner Losh  * 2. Saving the best parameters and dictionary.
709*0c16b537SWarner Losh  *
710*0c16b537SWarner Losh  * All of the methods except COVER_best_init() are thread safe if zstd is
711*0c16b537SWarner Losh  * compiled with multithreaded support.
712*0c16b537SWarner Losh  */
713*0c16b537SWarner Losh typedef struct COVER_best_s {
714*0c16b537SWarner Losh   ZSTD_pthread_mutex_t mutex;
715*0c16b537SWarner Losh   ZSTD_pthread_cond_t cond;
716*0c16b537SWarner Losh   size_t liveJobs;
717*0c16b537SWarner Losh   void *dict;
718*0c16b537SWarner Losh   size_t dictSize;
719*0c16b537SWarner Losh   ZDICT_cover_params_t parameters;
720*0c16b537SWarner Losh   size_t compressedSize;
721*0c16b537SWarner Losh } COVER_best_t;
722*0c16b537SWarner Losh 
723*0c16b537SWarner Losh /**
724*0c16b537SWarner Losh  * Initialize the `COVER_best_t`.
725*0c16b537SWarner Losh  */
726*0c16b537SWarner Losh static void COVER_best_init(COVER_best_t *best) {
727*0c16b537SWarner Losh   if (best==NULL) return; /* compatible with init on NULL */
728*0c16b537SWarner Losh   (void)ZSTD_pthread_mutex_init(&best->mutex, NULL);
729*0c16b537SWarner Losh   (void)ZSTD_pthread_cond_init(&best->cond, NULL);
730*0c16b537SWarner Losh   best->liveJobs = 0;
731*0c16b537SWarner Losh   best->dict = NULL;
732*0c16b537SWarner Losh   best->dictSize = 0;
733*0c16b537SWarner Losh   best->compressedSize = (size_t)-1;
734*0c16b537SWarner Losh   memset(&best->parameters, 0, sizeof(best->parameters));
735*0c16b537SWarner Losh }
736*0c16b537SWarner Losh 
737*0c16b537SWarner Losh /**
738*0c16b537SWarner Losh  * Wait until liveJobs == 0.
739*0c16b537SWarner Losh  */
740*0c16b537SWarner Losh static void COVER_best_wait(COVER_best_t *best) {
741*0c16b537SWarner Losh   if (!best) {
742*0c16b537SWarner Losh     return;
743*0c16b537SWarner Losh   }
744*0c16b537SWarner Losh   ZSTD_pthread_mutex_lock(&best->mutex);
745*0c16b537SWarner Losh   while (best->liveJobs != 0) {
746*0c16b537SWarner Losh     ZSTD_pthread_cond_wait(&best->cond, &best->mutex);
747*0c16b537SWarner Losh   }
748*0c16b537SWarner Losh   ZSTD_pthread_mutex_unlock(&best->mutex);
749*0c16b537SWarner Losh }
750*0c16b537SWarner Losh 
751*0c16b537SWarner Losh /**
752*0c16b537SWarner Losh  * Call COVER_best_wait() and then destroy the COVER_best_t.
753*0c16b537SWarner Losh  */
754*0c16b537SWarner Losh static void COVER_best_destroy(COVER_best_t *best) {
755*0c16b537SWarner Losh   if (!best) {
756*0c16b537SWarner Losh     return;
757*0c16b537SWarner Losh   }
758*0c16b537SWarner Losh   COVER_best_wait(best);
759*0c16b537SWarner Losh   if (best->dict) {
760*0c16b537SWarner Losh     free(best->dict);
761*0c16b537SWarner Losh   }
762*0c16b537SWarner Losh   ZSTD_pthread_mutex_destroy(&best->mutex);
763*0c16b537SWarner Losh   ZSTD_pthread_cond_destroy(&best->cond);
764*0c16b537SWarner Losh }
765*0c16b537SWarner Losh 
766*0c16b537SWarner Losh /**
767*0c16b537SWarner Losh  * Called when a thread is about to be launched.
768*0c16b537SWarner Losh  * Increments liveJobs.
769*0c16b537SWarner Losh  */
770*0c16b537SWarner Losh static void COVER_best_start(COVER_best_t *best) {
771*0c16b537SWarner Losh   if (!best) {
772*0c16b537SWarner Losh     return;
773*0c16b537SWarner Losh   }
774*0c16b537SWarner Losh   ZSTD_pthread_mutex_lock(&best->mutex);
775*0c16b537SWarner Losh   ++best->liveJobs;
776*0c16b537SWarner Losh   ZSTD_pthread_mutex_unlock(&best->mutex);
777*0c16b537SWarner Losh }
778*0c16b537SWarner Losh 
779*0c16b537SWarner Losh /**
780*0c16b537SWarner Losh  * Called when a thread finishes executing, both on error or success.
781*0c16b537SWarner Losh  * Decrements liveJobs and signals any waiting threads if liveJobs == 0.
782*0c16b537SWarner Losh  * If this dictionary is the best so far save it and its parameters.
783*0c16b537SWarner Losh  */
784*0c16b537SWarner Losh static void COVER_best_finish(COVER_best_t *best, size_t compressedSize,
785*0c16b537SWarner Losh                               ZDICT_cover_params_t parameters, void *dict,
786*0c16b537SWarner Losh                               size_t dictSize) {
787*0c16b537SWarner Losh   if (!best) {
788*0c16b537SWarner Losh     return;
789*0c16b537SWarner Losh   }
790*0c16b537SWarner Losh   {
791*0c16b537SWarner Losh     size_t liveJobs;
792*0c16b537SWarner Losh     ZSTD_pthread_mutex_lock(&best->mutex);
793*0c16b537SWarner Losh     --best->liveJobs;
794*0c16b537SWarner Losh     liveJobs = best->liveJobs;
795*0c16b537SWarner Losh     /* If the new dictionary is better */
796*0c16b537SWarner Losh     if (compressedSize < best->compressedSize) {
797*0c16b537SWarner Losh       /* Allocate space if necessary */
798*0c16b537SWarner Losh       if (!best->dict || best->dictSize < dictSize) {
799*0c16b537SWarner Losh         if (best->dict) {
800*0c16b537SWarner Losh           free(best->dict);
801*0c16b537SWarner Losh         }
802*0c16b537SWarner Losh         best->dict = malloc(dictSize);
803*0c16b537SWarner Losh         if (!best->dict) {
804*0c16b537SWarner Losh           best->compressedSize = ERROR(GENERIC);
805*0c16b537SWarner Losh           best->dictSize = 0;
806*0c16b537SWarner Losh           return;
807*0c16b537SWarner Losh         }
808*0c16b537SWarner Losh       }
809*0c16b537SWarner Losh       /* Save the dictionary, parameters, and size */
810*0c16b537SWarner Losh       memcpy(best->dict, dict, dictSize);
811*0c16b537SWarner Losh       best->dictSize = dictSize;
812*0c16b537SWarner Losh       best->parameters = parameters;
813*0c16b537SWarner Losh       best->compressedSize = compressedSize;
814*0c16b537SWarner Losh     }
815*0c16b537SWarner Losh     ZSTD_pthread_mutex_unlock(&best->mutex);
816*0c16b537SWarner Losh     if (liveJobs == 0) {
817*0c16b537SWarner Losh       ZSTD_pthread_cond_broadcast(&best->cond);
818*0c16b537SWarner Losh     }
819*0c16b537SWarner Losh   }
820*0c16b537SWarner Losh }
821*0c16b537SWarner Losh 
822*0c16b537SWarner Losh /**
823*0c16b537SWarner Losh  * Parameters for COVER_tryParameters().
824*0c16b537SWarner Losh  */
825*0c16b537SWarner Losh typedef struct COVER_tryParameters_data_s {
826*0c16b537SWarner Losh   const COVER_ctx_t *ctx;
827*0c16b537SWarner Losh   COVER_best_t *best;
828*0c16b537SWarner Losh   size_t dictBufferCapacity;
829*0c16b537SWarner Losh   ZDICT_cover_params_t parameters;
830*0c16b537SWarner Losh } COVER_tryParameters_data_t;
831*0c16b537SWarner Losh 
832*0c16b537SWarner Losh /**
833*0c16b537SWarner Losh  * Tries a set of parameters and upates the COVER_best_t with the results.
834*0c16b537SWarner Losh  * This function is thread safe if zstd is compiled with multithreaded support.
835*0c16b537SWarner Losh  * It takes its parameters as an *OWNING* opaque pointer to support threading.
836*0c16b537SWarner Losh  */
837*0c16b537SWarner Losh static void COVER_tryParameters(void *opaque) {
838*0c16b537SWarner Losh   /* Save parameters as local variables */
839*0c16b537SWarner Losh   COVER_tryParameters_data_t *const data = (COVER_tryParameters_data_t *)opaque;
840*0c16b537SWarner Losh   const COVER_ctx_t *const ctx = data->ctx;
841*0c16b537SWarner Losh   const ZDICT_cover_params_t parameters = data->parameters;
842*0c16b537SWarner Losh   size_t dictBufferCapacity = data->dictBufferCapacity;
843*0c16b537SWarner Losh   size_t totalCompressedSize = ERROR(GENERIC);
844*0c16b537SWarner Losh   /* Allocate space for hash table, dict, and freqs */
845*0c16b537SWarner Losh   COVER_map_t activeDmers;
846*0c16b537SWarner Losh   BYTE *const dict = (BYTE * const)malloc(dictBufferCapacity);
847*0c16b537SWarner Losh   U32 *freqs = (U32 *)malloc(ctx->suffixSize * sizeof(U32));
848*0c16b537SWarner Losh   if (!COVER_map_init(&activeDmers, parameters.k - parameters.d + 1)) {
849*0c16b537SWarner Losh     DISPLAYLEVEL(1, "Failed to allocate dmer map: out of memory\n");
850*0c16b537SWarner Losh     goto _cleanup;
851*0c16b537SWarner Losh   }
852*0c16b537SWarner Losh   if (!dict || !freqs) {
853*0c16b537SWarner Losh     DISPLAYLEVEL(1, "Failed to allocate buffers: out of memory\n");
854*0c16b537SWarner Losh     goto _cleanup;
855*0c16b537SWarner Losh   }
856*0c16b537SWarner Losh   /* Copy the frequencies because we need to modify them */
857*0c16b537SWarner Losh   memcpy(freqs, ctx->freqs, ctx->suffixSize * sizeof(U32));
858*0c16b537SWarner Losh   /* Build the dictionary */
859*0c16b537SWarner Losh   {
860*0c16b537SWarner Losh     const size_t tail = COVER_buildDictionary(ctx, freqs, &activeDmers, dict,
861*0c16b537SWarner Losh                                               dictBufferCapacity, parameters);
862*0c16b537SWarner Losh     dictBufferCapacity = ZDICT_finalizeDictionary(
863*0c16b537SWarner Losh         dict, dictBufferCapacity, dict + tail, dictBufferCapacity - tail,
864*0c16b537SWarner Losh         ctx->samples, ctx->samplesSizes, (unsigned)ctx->nbSamples,
865*0c16b537SWarner Losh         parameters.zParams);
866*0c16b537SWarner Losh     if (ZDICT_isError(dictBufferCapacity)) {
867*0c16b537SWarner Losh       DISPLAYLEVEL(1, "Failed to finalize dictionary\n");
868*0c16b537SWarner Losh       goto _cleanup;
869*0c16b537SWarner Losh     }
870*0c16b537SWarner Losh   }
871*0c16b537SWarner Losh   /* Check total compressed size */
872*0c16b537SWarner Losh   {
873*0c16b537SWarner Losh     /* Pointers */
874*0c16b537SWarner Losh     ZSTD_CCtx *cctx;
875*0c16b537SWarner Losh     ZSTD_CDict *cdict;
876*0c16b537SWarner Losh     void *dst;
877*0c16b537SWarner Losh     /* Local variables */
878*0c16b537SWarner Losh     size_t dstCapacity;
879*0c16b537SWarner Losh     size_t i;
880*0c16b537SWarner Losh     /* Allocate dst with enough space to compress the maximum sized sample */
881*0c16b537SWarner Losh     {
882*0c16b537SWarner Losh       size_t maxSampleSize = 0;
883*0c16b537SWarner Losh       for (i = 0; i < ctx->nbSamples; ++i) {
884*0c16b537SWarner Losh         maxSampleSize = MAX(ctx->samplesSizes[i], maxSampleSize);
885*0c16b537SWarner Losh       }
886*0c16b537SWarner Losh       dstCapacity = ZSTD_compressBound(maxSampleSize);
887*0c16b537SWarner Losh       dst = malloc(dstCapacity);
888*0c16b537SWarner Losh     }
889*0c16b537SWarner Losh     /* Create the cctx and cdict */
890*0c16b537SWarner Losh     cctx = ZSTD_createCCtx();
891*0c16b537SWarner Losh     cdict = ZSTD_createCDict(dict, dictBufferCapacity,
892*0c16b537SWarner Losh                              parameters.zParams.compressionLevel);
893*0c16b537SWarner Losh     if (!dst || !cctx || !cdict) {
894*0c16b537SWarner Losh       goto _compressCleanup;
895*0c16b537SWarner Losh     }
896*0c16b537SWarner Losh     /* Compress each sample and sum their sizes (or error) */
897*0c16b537SWarner Losh     totalCompressedSize = dictBufferCapacity;
898*0c16b537SWarner Losh     for (i = 0; i < ctx->nbSamples; ++i) {
899*0c16b537SWarner Losh       const size_t size = ZSTD_compress_usingCDict(
900*0c16b537SWarner Losh           cctx, dst, dstCapacity, ctx->samples + ctx->offsets[i],
901*0c16b537SWarner Losh           ctx->samplesSizes[i], cdict);
902*0c16b537SWarner Losh       if (ZSTD_isError(size)) {
903*0c16b537SWarner Losh         totalCompressedSize = ERROR(GENERIC);
904*0c16b537SWarner Losh         goto _compressCleanup;
905*0c16b537SWarner Losh       }
906*0c16b537SWarner Losh       totalCompressedSize += size;
907*0c16b537SWarner Losh     }
908*0c16b537SWarner Losh   _compressCleanup:
909*0c16b537SWarner Losh     ZSTD_freeCCtx(cctx);
910*0c16b537SWarner Losh     ZSTD_freeCDict(cdict);
911*0c16b537SWarner Losh     if (dst) {
912*0c16b537SWarner Losh       free(dst);
913*0c16b537SWarner Losh     }
914*0c16b537SWarner Losh   }
915*0c16b537SWarner Losh 
916*0c16b537SWarner Losh _cleanup:
917*0c16b537SWarner Losh   COVER_best_finish(data->best, totalCompressedSize, parameters, dict,
918*0c16b537SWarner Losh                     dictBufferCapacity);
919*0c16b537SWarner Losh   free(data);
920*0c16b537SWarner Losh   COVER_map_destroy(&activeDmers);
921*0c16b537SWarner Losh   if (dict) {
922*0c16b537SWarner Losh     free(dict);
923*0c16b537SWarner Losh   }
924*0c16b537SWarner Losh   if (freqs) {
925*0c16b537SWarner Losh     free(freqs);
926*0c16b537SWarner Losh   }
927*0c16b537SWarner Losh }
928*0c16b537SWarner Losh 
929*0c16b537SWarner Losh ZDICTLIB_API size_t ZDICT_optimizeTrainFromBuffer_cover(
930*0c16b537SWarner Losh     void *dictBuffer, size_t dictBufferCapacity, const void *samplesBuffer,
931*0c16b537SWarner Losh     const size_t *samplesSizes, unsigned nbSamples,
932*0c16b537SWarner Losh     ZDICT_cover_params_t *parameters) {
933*0c16b537SWarner Losh   /* constants */
934*0c16b537SWarner Losh   const unsigned nbThreads = parameters->nbThreads;
935*0c16b537SWarner Losh   const unsigned kMinD = parameters->d == 0 ? 6 : parameters->d;
936*0c16b537SWarner Losh   const unsigned kMaxD = parameters->d == 0 ? 8 : parameters->d;
937*0c16b537SWarner Losh   const unsigned kMinK = parameters->k == 0 ? 50 : parameters->k;
938*0c16b537SWarner Losh   const unsigned kMaxK = parameters->k == 0 ? 2000 : parameters->k;
939*0c16b537SWarner Losh   const unsigned kSteps = parameters->steps == 0 ? 40 : parameters->steps;
940*0c16b537SWarner Losh   const unsigned kStepSize = MAX((kMaxK - kMinK) / kSteps, 1);
941*0c16b537SWarner Losh   const unsigned kIterations =
942*0c16b537SWarner Losh       (1 + (kMaxD - kMinD) / 2) * (1 + (kMaxK - kMinK) / kStepSize);
943*0c16b537SWarner Losh   /* Local variables */
944*0c16b537SWarner Losh   const int displayLevel = parameters->zParams.notificationLevel;
945*0c16b537SWarner Losh   unsigned iteration = 1;
946*0c16b537SWarner Losh   unsigned d;
947*0c16b537SWarner Losh   unsigned k;
948*0c16b537SWarner Losh   COVER_best_t best;
949*0c16b537SWarner Losh   POOL_ctx *pool = NULL;
950*0c16b537SWarner Losh   /* Checks */
951*0c16b537SWarner Losh   if (kMinK < kMaxD || kMaxK < kMinK) {
952*0c16b537SWarner Losh     LOCALDISPLAYLEVEL(displayLevel, 1, "Incorrect parameters\n");
953*0c16b537SWarner Losh     return ERROR(GENERIC);
954*0c16b537SWarner Losh   }
955*0c16b537SWarner Losh   if (nbSamples == 0) {
956*0c16b537SWarner Losh     DISPLAYLEVEL(1, "Cover must have at least one input file\n");
957*0c16b537SWarner Losh     return ERROR(GENERIC);
958*0c16b537SWarner Losh   }
959*0c16b537SWarner Losh   if (dictBufferCapacity < ZDICT_DICTSIZE_MIN) {
960*0c16b537SWarner Losh     DISPLAYLEVEL(1, "dictBufferCapacity must be at least %u\n",
961*0c16b537SWarner Losh                  ZDICT_DICTSIZE_MIN);
962*0c16b537SWarner Losh     return ERROR(dstSize_tooSmall);
963*0c16b537SWarner Losh   }
964*0c16b537SWarner Losh   if (nbThreads > 1) {
965*0c16b537SWarner Losh     pool = POOL_create(nbThreads, 1);
966*0c16b537SWarner Losh     if (!pool) {
967*0c16b537SWarner Losh       return ERROR(memory_allocation);
968*0c16b537SWarner Losh     }
969*0c16b537SWarner Losh   }
970*0c16b537SWarner Losh   /* Initialization */
971*0c16b537SWarner Losh   COVER_best_init(&best);
972*0c16b537SWarner Losh   /* Turn down global display level to clean up display at level 2 and below */
973*0c16b537SWarner Losh   g_displayLevel = displayLevel == 0 ? 0 : displayLevel - 1;
974*0c16b537SWarner Losh   /* Loop through d first because each new value needs a new context */
975*0c16b537SWarner Losh   LOCALDISPLAYLEVEL(displayLevel, 2, "Trying %u different sets of parameters\n",
976*0c16b537SWarner Losh                     kIterations);
977*0c16b537SWarner Losh   for (d = kMinD; d <= kMaxD; d += 2) {
978*0c16b537SWarner Losh     /* Initialize the context for this value of d */
979*0c16b537SWarner Losh     COVER_ctx_t ctx;
980*0c16b537SWarner Losh     LOCALDISPLAYLEVEL(displayLevel, 3, "d=%u\n", d);
981*0c16b537SWarner Losh     if (!COVER_ctx_init(&ctx, samplesBuffer, samplesSizes, nbSamples, d)) {
982*0c16b537SWarner Losh       LOCALDISPLAYLEVEL(displayLevel, 1, "Failed to initialize context\n");
983*0c16b537SWarner Losh       COVER_best_destroy(&best);
984*0c16b537SWarner Losh       POOL_free(pool);
985*0c16b537SWarner Losh       return ERROR(GENERIC);
986*0c16b537SWarner Losh     }
987*0c16b537SWarner Losh     /* Loop through k reusing the same context */
988*0c16b537SWarner Losh     for (k = kMinK; k <= kMaxK; k += kStepSize) {
989*0c16b537SWarner Losh       /* Prepare the arguments */
990*0c16b537SWarner Losh       COVER_tryParameters_data_t *data = (COVER_tryParameters_data_t *)malloc(
991*0c16b537SWarner Losh           sizeof(COVER_tryParameters_data_t));
992*0c16b537SWarner Losh       LOCALDISPLAYLEVEL(displayLevel, 3, "k=%u\n", k);
993*0c16b537SWarner Losh       if (!data) {
994*0c16b537SWarner Losh         LOCALDISPLAYLEVEL(displayLevel, 1, "Failed to allocate parameters\n");
995*0c16b537SWarner Losh         COVER_best_destroy(&best);
996*0c16b537SWarner Losh         COVER_ctx_destroy(&ctx);
997*0c16b537SWarner Losh         POOL_free(pool);
998*0c16b537SWarner Losh         return ERROR(GENERIC);
999*0c16b537SWarner Losh       }
1000*0c16b537SWarner Losh       data->ctx = &ctx;
1001*0c16b537SWarner Losh       data->best = &best;
1002*0c16b537SWarner Losh       data->dictBufferCapacity = dictBufferCapacity;
1003*0c16b537SWarner Losh       data->parameters = *parameters;
1004*0c16b537SWarner Losh       data->parameters.k = k;
1005*0c16b537SWarner Losh       data->parameters.d = d;
1006*0c16b537SWarner Losh       data->parameters.steps = kSteps;
1007*0c16b537SWarner Losh       data->parameters.zParams.notificationLevel = g_displayLevel;
1008*0c16b537SWarner Losh       /* Check the parameters */
1009*0c16b537SWarner Losh       if (!COVER_checkParameters(data->parameters, dictBufferCapacity)) {
1010*0c16b537SWarner Losh         DISPLAYLEVEL(1, "Cover parameters incorrect\n");
1011*0c16b537SWarner Losh         free(data);
1012*0c16b537SWarner Losh         continue;
1013*0c16b537SWarner Losh       }
1014*0c16b537SWarner Losh       /* Call the function and pass ownership of data to it */
1015*0c16b537SWarner Losh       COVER_best_start(&best);
1016*0c16b537SWarner Losh       if (pool) {
1017*0c16b537SWarner Losh         POOL_add(pool, &COVER_tryParameters, data);
1018*0c16b537SWarner Losh       } else {
1019*0c16b537SWarner Losh         COVER_tryParameters(data);
1020*0c16b537SWarner Losh       }
1021*0c16b537SWarner Losh       /* Print status */
1022*0c16b537SWarner Losh       LOCALDISPLAYUPDATE(displayLevel, 2, "\r%u%%       ",
1023*0c16b537SWarner Losh                          (U32)((iteration * 100) / kIterations));
1024*0c16b537SWarner Losh       ++iteration;
1025*0c16b537SWarner Losh     }
1026*0c16b537SWarner Losh     COVER_best_wait(&best);
1027*0c16b537SWarner Losh     COVER_ctx_destroy(&ctx);
1028*0c16b537SWarner Losh   }
1029*0c16b537SWarner Losh   LOCALDISPLAYLEVEL(displayLevel, 2, "\r%79s\r", "");
1030*0c16b537SWarner Losh   /* Fill the output buffer and parameters with output of the best parameters */
1031*0c16b537SWarner Losh   {
1032*0c16b537SWarner Losh     const size_t dictSize = best.dictSize;
1033*0c16b537SWarner Losh     if (ZSTD_isError(best.compressedSize)) {
1034*0c16b537SWarner Losh       const size_t compressedSize = best.compressedSize;
1035*0c16b537SWarner Losh       COVER_best_destroy(&best);
1036*0c16b537SWarner Losh       POOL_free(pool);
1037*0c16b537SWarner Losh       return compressedSize;
1038*0c16b537SWarner Losh     }
1039*0c16b537SWarner Losh     *parameters = best.parameters;
1040*0c16b537SWarner Losh     memcpy(dictBuffer, best.dict, dictSize);
1041*0c16b537SWarner Losh     COVER_best_destroy(&best);
1042*0c16b537SWarner Losh     POOL_free(pool);
1043*0c16b537SWarner Losh     return dictSize;
1044*0c16b537SWarner Losh   }
1045*0c16b537SWarner Losh }
1046