/* xref: /freebsd/sys/contrib/zstd/lib/common/entropy_common.c (revision 43a5ec4eb41567cc92586503212743d89686d78f) */
/* ******************************************************************
 * Common functions of the New Generation Entropy library
 * Copyright (c) 2016-2020, Yann Collet, Facebook, Inc.
 *
 *  You can contact the author at :
 *  - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy
 *  - Public forum : https://groups.google.com/forum/#!forum/lz4c
 *
 * This source code is licensed under both the BSD-style license (found in the
 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
 * in the COPYING file in the root directory of this source tree).
 * You may select, at your option, one of the above-listed licenses.
****************************************************************** */
14 
/* *************************************
*  Dependencies
***************************************/
18 #include "mem.h"
19 #include "error_private.h"       /* ERR_*, ERROR */
20 #define FSE_STATIC_LINKING_ONLY  /* FSE_MIN_TABLELOG */
21 #include "fse.h"
22 #define HUF_STATIC_LINKING_ONLY  /* HUF_TABLELOG_ABSOLUTEMAX */
23 #include "huf.h"
24 
25 
26 /*===   Version   ===*/
27 unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; }
28 
29 
30 /*===   Error Management   ===*/
/* FSE and HUF report errors through the shared ERR_* helpers of
 * error_private.h; these wrappers just forward to them. */

/* Non-zero when `code` is an error code rather than a size. */
unsigned FSE_isError(size_t code)
{
    return ERR_isError(code);
}

/* Human-readable description of an FSE error `code`. */
const char* FSE_getErrorName(size_t code)
{
    return ERR_getErrorName(code);
}

/* Non-zero when `code` is an error code rather than a size. */
unsigned HUF_isError(size_t code)
{
    return ERR_isError(code);
}

/* Human-readable description of a HUF error `code`. */
const char* HUF_getErrorName(size_t code)
{
    return ERR_getErrorName(code);
}
36 
37 
/*-**************************************************************
*  FSE NCount decoding
****************************************************************/
41 static U32 FSE_ctz(U32 val)
42 {
43     assert(val != 0);
44     {
45 #   if defined(_MSC_VER)   /* Visual */
46         unsigned long r=0;
47         return _BitScanForward(&r, val) ? (unsigned)r : 0;
48 #   elif defined(__GNUC__) && (__GNUC__ >= 3)   /* GCC Intrinsic */
49         return __builtin_ctz(val);
50 #   elif defined(__ICCARM__)    /* IAR Intrinsic */
51         return __CTZ(val);
52 #   else   /* Software version */
53         U32 count = 0;
54         while ((val & 1) == 0) {
55             val >>= 1;
56             ++count;
57         }
58         return count;
59 #   endif
60     }
61 }
62 
/* FSE_readNCount_body() :
 * Decodes an FSE normalized-count header ("NCount") from `headerBuffer`.
 * @param normalizedCounter : output; entries [0, *maxSVPtr] are zeroed first,
 *                            then the decoded count of each symbol is stored
 *                            (a stored count may be -1, see `count--` below).
 * @param maxSVPtr          : in  : largest symbol value the caller can accept;
 *                            out : largest symbol value present in the header.
 * @param tableLogPtr       : out : decoded table log
 *                                  (low 4 header bits + FSE_MIN_TABLELOG).
 * @param headerBuffer      : input bytes.
 * @param hbSize            : number of readable bytes at headerBuffer.
 * @return : number of bytes consumed from headerBuffer, or an error code
 *           (tableLog_tooLarge, corruption_detected, maxSymbolValue_tooSmall).
 * The main decoder always performs 4-byte little-endian reads and requires
 * hbSize >= 8. Shorter inputs are decoded from a zero-padded 8-byte copy, and
 * the result is rejected if it claims more bytes than were really provided.
 */
FORCE_INLINE_TEMPLATE
size_t FSE_readNCount_body(short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
                           const void* headerBuffer, size_t hbSize)
{
    const BYTE* const istart = (const BYTE*) headerBuffer;
    const BYTE* const iend = istart + hbSize;
    const BYTE* ip = istart;
    int nbBits;
    int remaining;
    int threshold;
    U32 bitStream;
    int bitCount;
    unsigned charnum = 0;
    unsigned const maxSV1 = *maxSVPtr + 1;
    int previous0 = 0;

    if (hbSize < 8) {
        /* This function only works when hbSize >= 8 */
        char buffer[8] = {0};
        ZSTD_memcpy(buffer, headerBuffer, hbSize);
        {   size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr,
                                                    buffer, sizeof(buffer));
            if (FSE_isError(countSize)) return countSize;
            /* the zero padding must not have been consumed as real header bytes */
            if (countSize > hbSize) return ERROR(corruption_detected);
            return countSize;
    }   }
    assert(hbSize >= 8);

    /* init */
    ZSTD_memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0]));   /* all symbols not present in NCount have a frequency of 0 */
    bitStream = MEM_readLE32(ip);
    nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG;   /* extract tableLog */
    if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
    bitStream >>= 4;
    bitCount = 4;
    *tableLogPtr = nbBits;
    /* Mass still to be distributed among symbols; the +1 bias means a
     * well-formed header ends with remaining == 1 (checked after the loop). */
    remaining = (1<<nbBits)+1;
    /* Each count is read with nbBits-1 or nbBits bits; threshold == 1 << (nbBits-1). */
    threshold = 1<<nbBits;
    nbBits++;

    for (;;) {
        if (previous0) {
            /* Count the number of repeats. Each time the
             * 2-bit repeat code is 0b11 there is another
             * repeat.
             * Avoid UB by setting the high bit to 1.
             */
            int repeats = FSE_ctz(~bitStream | 0x80000000) >> 1;
            /* Fast-forward 12 repeat codes (24 bits = 3 bytes) at a time,
             * each 0b11 code accounting for 3 zero-count symbols. */
            while (repeats >= 12) {
                charnum += 3 * 12;
                if (LIKELY(ip <= iend-7)) {
                    ip += 3;
                } else {
                    /* Near the end: pin ip to iend-4 (last in-bounds 4-byte
                     * read) and carry the skipped distance in bitCount. */
                    bitCount -= (int)(8 * (iend - 7 - ip));
                    bitCount &= 31;
                    ip = iend - 4;
                }
                bitStream = MEM_readLE32(ip) >> bitCount;
                repeats = FSE_ctz(~bitStream | 0x80000000) >> 1;
            }
            charnum += 3 * repeats;
            bitStream >>= 2 * repeats;
            bitCount += 2 * repeats;

            /* Add the final repeat which isn't 0b11. */
            assert((bitStream & 3) < 3);
            charnum += bitStream & 3;
            bitCount += 2;

            /* This is an error, but break and return an error
             * at the end, because returning out of a loop makes
             * it harder for the compiler to optimize.
             */
            if (charnum >= maxSV1) break;

            /* We don't need to set the normalized count to 0
             * because we already memset the whole buffer to 0.
             */

            /* Advance to the next unread byte, or pin ip to iend-4 and keep
             * the remaining offset in bitCount when too close to the end. */
            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                assert((bitCount >> 3) <= 3); /* For first condition to work */
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
        }
        {
            int const max = (2*threshold-1) - remaining;
            int count;

            /* Values below `max` are coded on nbBits-1 bits; the others
             * take nbBits bits and are folded back by subtracting `max`. */
            if ((bitStream & (threshold-1)) < (U32)max) {
                count = bitStream & (threshold-1);
                bitCount += nbBits-1;
            } else {
                count = bitStream & (2*threshold-1);
                if (count >= threshold) count -= max;
                bitCount += nbBits;
            }

            count--;   /* extra accuracy */
            /* When it matters (small blocks), this is a
             * predictable branch, because we don't use -1.
             */
            if (count >= 0) {
                remaining -= count;
            } else {
                assert(count == -1);
                remaining += count;
            }
            normalizedCounter[charnum++] = (short)count;
            /* a zero count is followed by a run of repeat codes (handled above) */
            previous0 = !count;

            assert(threshold > 1);
            if (remaining < threshold) {
                /* This branch can be folded into the
                 * threshold update condition because we
                 * know that threshold > 1.
                 */
                if (remaining <= 1) break;
                /* shrink the field width to what `remaining` still needs */
                nbBits = BIT_highbit32(remaining) + 1;
                threshold = 1 << (nbBits - 1);
            }
            if (charnum >= maxSV1) break;

            /* Same advance / end-of-buffer pinning as in the repeat path. */
            if (LIKELY(ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
                ip += bitCount>>3;
                bitCount &= 7;
            } else {
                bitCount -= (int)(8 * (iend - 4 - ip));
                bitCount &= 31;
                ip = iend - 4;
            }
            bitStream = MEM_readLE32(ip) >> bitCount;
    }   }
    if (remaining != 1) return ERROR(corruption_detected);
    /* Only possible when there are too many zeros. */
    if (charnum > maxSV1) return ERROR(maxSymbolValue_tooSmall);
    /* reject headers whose bits extend past the 32-bit window read at ip */
    if (bitCount > 32) return ERROR(corruption_detected);
    *maxSVPtr = charnum-1;

    /* round the partially consumed byte up: header size is in whole bytes */
    ip += (bitCount+7)>>3;
    return ip-istart;
}
210 
/* Generic (non-BMI2) out-of-line instantiation of FSE_readNCount_body(),
 * so the force-inlined template body is not inlined into the dispatcher. */
static size_t FSE_readNCount_body_default(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    size_t const result = FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr,
                                              headerBuffer, hbSize);
    return result;
}
218 
#if DYNAMIC_BMI2
/* Instantiation of FSE_readNCount_body() compiled with the "bmi2" target
 * attribute; selected by FSE_readNCount_bmi2() when its bmi2 flag is set. */
static TARGET_ATTRIBUTE("bmi2") size_t FSE_readNCount_body_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    size_t const result = FSE_readNCount_body(normalizedCounter, maxSVPtr, tableLogPtr,
                                              headerBuffer, hbSize);
    return result;
}
#endif
227 
/* FSE_readNCount_bmi2() :
 * Decodes an FSE NCount header (see FSE_readNCount_body()).
 * When built with DYNAMIC_BMI2 and `bmi2` is non-zero, the BMI2-targeted
 * decoder is used; in every other case the generic decoder is used. */
size_t FSE_readNCount_bmi2(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize, int bmi2)
{
#if DYNAMIC_BMI2
    if (!bmi2) {
        return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
    }
    return FSE_readNCount_body_bmi2(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
#else
    (void)bmi2;   /* no BMI2 variant compiled in: the flag is irrelevant */
    return FSE_readNCount_body_default(normalizedCounter, maxSVPtr, tableLogPtr, headerBuffer, hbSize);
#endif
}
240 
/* FSE_readNCount() :
 * Portable entry point for NCount decoding; never requests the BMI2 path. */
size_t FSE_readNCount(
        short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
        const void* headerBuffer, size_t hbSize)
{
    int const useBmi2 = 0;
    return FSE_readNCount_bmi2(normalizedCounter, maxSVPtr, tableLogPtr,
                               headerBuffer, hbSize, useBmi2);
}
247 
248 
249 /*! HUF_readStats() :
250     Read compact Huffman tree, saved by HUF_writeCTable().
251     `huffWeight` is destination buffer.
252     `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32.
253     @return : size read from `src` , or an error Code .
254     Note : Needed by HUF_readCTable() and HUF_readDTableX?() .
255 */
256 size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
257                      U32* nbSymbolsPtr, U32* tableLogPtr,
258                      const void* src, size_t srcSize)
259 {
260     U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32];
261     return HUF_readStats_wksp(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, wksp, sizeof(wksp), /* bmi2 */ 0);
262 }
263 
264 FORCE_INLINE_TEMPLATE size_t
265 HUF_readStats_body(BYTE* huffWeight, size_t hwSize, U32* rankStats,
266                    U32* nbSymbolsPtr, U32* tableLogPtr,
267                    const void* src, size_t srcSize,
268                    void* workSpace, size_t wkspSize,
269                    int bmi2)
270 {
271     U32 weightTotal;
272     const BYTE* ip = (const BYTE*) src;
273     size_t iSize;
274     size_t oSize;
275 
276     if (!srcSize) return ERROR(srcSize_wrong);
277     iSize = ip[0];
278     /* ZSTD_memset(huffWeight, 0, hwSize);   *//* is not necessary, even though some analyzer complain ... */
279 
280     if (iSize >= 128) {  /* special header */
281         oSize = iSize - 127;
282         iSize = ((oSize+1)/2);
283         if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
284         if (oSize >= hwSize) return ERROR(corruption_detected);
285         ip += 1;
286         {   U32 n;
287             for (n=0; n<oSize; n+=2) {
288                 huffWeight[n]   = ip[n/2] >> 4;
289                 huffWeight[n+1] = ip[n/2] & 15;
290     }   }   }
291     else  {   /* header compressed with FSE (normal case) */
292         if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
293         /* max (hwSize-1) values decoded, as last one is implied */
294         oSize = FSE_decompress_wksp_bmi2(huffWeight, hwSize-1, ip+1, iSize, 6, workSpace, wkspSize, bmi2);
295         if (FSE_isError(oSize)) return oSize;
296     }
297 
298     /* collect weight stats */
299     ZSTD_memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32));
300     weightTotal = 0;
301     {   U32 n; for (n=0; n<oSize; n++) {
302             if (huffWeight[n] >= HUF_TABLELOG_MAX) return ERROR(corruption_detected);
303             rankStats[huffWeight[n]]++;
304             weightTotal += (1 << huffWeight[n]) >> 1;
305     }   }
306     if (weightTotal == 0) return ERROR(corruption_detected);
307 
308     /* get last non-null symbol weight (implied, total must be 2^n) */
309     {   U32 const tableLog = BIT_highbit32(weightTotal) + 1;
310         if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
311         *tableLogPtr = tableLog;
312         /* determine last weight */
313         {   U32 const total = 1 << tableLog;
314             U32 const rest = total - weightTotal;
315             U32 const verif = 1 << BIT_highbit32(rest);
316             U32 const lastWeight = BIT_highbit32(rest) + 1;
317             if (verif != rest) return ERROR(corruption_detected);    /* last value must be a clean power of 2 */
318             huffWeight[oSize] = (BYTE)lastWeight;
319             rankStats[lastWeight]++;
320     }   }
321 
322     /* check tree construction validity */
323     if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected);   /* by construction : at least 2 elts of rank 1, must be even */
324 
325     /* results */
326     *nbSymbolsPtr = (U32)(oSize+1);
327     return iSize+1;
328 }
329 
330 /* Avoids the FORCE_INLINE of the _body() function. */
331 static size_t HUF_readStats_body_default(BYTE* huffWeight, size_t hwSize, U32* rankStats,
332                      U32* nbSymbolsPtr, U32* tableLogPtr,
333                      const void* src, size_t srcSize,
334                      void* workSpace, size_t wkspSize)
335 {
336     return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize, 0);
337 }
338 
#if DYNAMIC_BMI2
/* Instantiation of HUF_readStats_body() compiled with the "bmi2" target
 * attribute and the bmi2 flag fixed to 1; selected by HUF_readStats_wksp(). */
static TARGET_ATTRIBUTE("bmi2") size_t HUF_readStats_body_bmi2(BYTE* huffWeight, size_t hwSize, U32* rankStats,
                     U32* nbSymbolsPtr, U32* tableLogPtr,
                     const void* src, size_t srcSize,
                     void* workSpace, size_t wkspSize)
{
    int const useBmi2 = 1;
    return HUF_readStats_body(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr,
                              src, srcSize, workSpace, wkspSize, useBmi2);
}
#endif
348 
349 size_t HUF_readStats_wksp(BYTE* huffWeight, size_t hwSize, U32* rankStats,
350                      U32* nbSymbolsPtr, U32* tableLogPtr,
351                      const void* src, size_t srcSize,
352                      void* workSpace, size_t wkspSize,
353                      int bmi2)
354 {
355 #if DYNAMIC_BMI2
356     if (bmi2) {
357         return HUF_readStats_body_bmi2(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
358     }
359 #endif
360     (void)bmi2;
361     return HUF_readStats_body_default(huffWeight, hwSize, rankStats, nbSymbolsPtr, tableLogPtr, src, srcSize, workSpace, wkspSize);
362 }
363