xref: /freebsd/sys/contrib/openzfs/module/zstd/lib/common/entropy_common.c (revision 61145dc2b94f12f6a47344fb9aac702321880e43)
1 // SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0-only
2 /* ******************************************************************
3  * Common functions of New Generation Entropy library
4  * Copyright (c) 2016-2020, Yann Collet, Facebook, Inc.
5  *
6  *  You can contact the author at :
7  *  - FSE+HUF source repository : https://github.com/Cyan4973/FiniteStateEntropy
8  *  - Public forum : https://groups.google.com/forum/#!forum/lz4c
9  *
10  * This source code is licensed under both the BSD-style license (found in the
11  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
12  * in the COPYING file in the root directory of this source tree).
13  * You may select, at your option, one of the above-listed licenses.
14 ****************************************************************** */
15 
16 /* *************************************
17 *  Dependencies
18 ***************************************/
19 #include "mem.h"
20 #include "error_private.h"       /* ERR_*, ERROR */
21 #define FSE_STATIC_LINKING_ONLY  /* FSE_MIN_TABLELOG */
22 #include "fse.h"
23 #define HUF_STATIC_LINKING_ONLY  /* HUF_TABLELOG_ABSOLUTEMAX */
24 #include "huf.h"
25 
26 
27 /*===   Version   ===*/
FSE_versionNumber(void)28 unsigned FSE_versionNumber(void) { return FSE_VERSION_NUMBER; }
29 
30 
31 /*===   Error Management   ===*/
/* FSE_isError() :
 * Tell whether `code`, a size_t result from an FSE function, is an error code.
 * Simply forwards to ERR_isError(). */
unsigned FSE_isError(size_t code)
{
    return ERR_isError(code);
}
/* FSE_getErrorName() :
 * Return the string that ERR_getErrorName() associates with `code`. */
const char* FSE_getErrorName(size_t code)
{
    return ERR_getErrorName(code);
}
34 
/* HUF_isError() :
 * Tell whether `code`, a size_t result from a HUF function, is an error code.
 * Shares the ERR_isError() test used by FSE_isError(). */
unsigned HUF_isError(size_t code)
{
    return ERR_isError(code);
}
/* HUF_getErrorName() :
 * Return the string that ERR_getErrorName() associates with `code`. */
const char* HUF_getErrorName(size_t code)
{
    return ERR_getErrorName(code);
}
37 
38 
39 /*-**************************************************************
40 *  FSE NCount encoding-decoding
41 ****************************************************************/
FSE_readNCount(short * normalizedCounter,unsigned * maxSVPtr,unsigned * tableLogPtr,const void * headerBuffer,size_t hbSize)42 size_t FSE_readNCount (short* normalizedCounter, unsigned* maxSVPtr, unsigned* tableLogPtr,
43                  const void* headerBuffer, size_t hbSize)
44 {
45     const BYTE* const istart = (const BYTE*) headerBuffer;
46     const BYTE* const iend = istart + hbSize;
47     const BYTE* ip = istart;
48     int nbBits;
49     int remaining;
50     int threshold;
51     U32 bitStream;
52     int bitCount;
53     unsigned charnum = 0;
54     int previous0 = 0;
55 
56     if (hbSize < 4) {
57         /* This function only works when hbSize >= 4 */
58         char buffer[4];
59         memset(buffer, 0, sizeof(buffer));
60         memcpy(buffer, headerBuffer, hbSize);
61         {   size_t const countSize = FSE_readNCount(normalizedCounter, maxSVPtr, tableLogPtr,
62                                                     buffer, sizeof(buffer));
63             if (FSE_isError(countSize)) return countSize;
64             if (countSize > hbSize) return ERROR(corruption_detected);
65             return countSize;
66     }   }
67     assert(hbSize >= 4);
68 
69     /* init */
70     memset(normalizedCounter, 0, (*maxSVPtr+1) * sizeof(normalizedCounter[0]));   /* all symbols not present in NCount have a frequency of 0 */
71     bitStream = MEM_readLE32(ip);
72     nbBits = (bitStream & 0xF) + FSE_MIN_TABLELOG;   /* extract tableLog */
73     if (nbBits > FSE_TABLELOG_ABSOLUTE_MAX) return ERROR(tableLog_tooLarge);
74     bitStream >>= 4;
75     bitCount = 4;
76     *tableLogPtr = nbBits;
77     remaining = (1<<nbBits)+1;
78     threshold = 1<<nbBits;
79     nbBits++;
80 
81     while ((remaining>1) & (charnum<=*maxSVPtr)) {
82         if (previous0) {
83             unsigned n0 = charnum;
84             while ((bitStream & 0xFFFF) == 0xFFFF) {
85                 n0 += 24;
86                 if (ip < iend-5) {
87                     ip += 2;
88                     bitStream = MEM_readLE32(ip) >> bitCount;
89                 } else {
90                     bitStream >>= 16;
91                     bitCount   += 16;
92             }   }
93             while ((bitStream & 3) == 3) {
94                 n0 += 3;
95                 bitStream >>= 2;
96                 bitCount += 2;
97             }
98             n0 += bitStream & 3;
99             bitCount += 2;
100             if (n0 > *maxSVPtr) return ERROR(maxSymbolValue_tooSmall);
101             while (charnum < n0) normalizedCounter[charnum++] = 0;
102             if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
103                 assert((bitCount >> 3) <= 3); /* For first condition to work */
104                 ip += bitCount>>3;
105                 bitCount &= 7;
106                 bitStream = MEM_readLE32(ip) >> bitCount;
107             } else {
108                 bitStream >>= 2;
109         }   }
110         {   int const max = (2*threshold-1) - remaining;
111             int count;
112 
113             if ((bitStream & (threshold-1)) < (U32)max) {
114                 count = bitStream & (threshold-1);
115                 bitCount += nbBits-1;
116             } else {
117                 count = bitStream & (2*threshold-1);
118                 if (count >= threshold) count -= max;
119                 bitCount += nbBits;
120             }
121 
122             count--;   /* extra accuracy */
123             remaining -= count < 0 ? -count : count;   /* -1 means +1 */
124             normalizedCounter[charnum++] = (short)count;
125             previous0 = !count;
126             while (remaining < threshold) {
127                 nbBits--;
128                 threshold >>= 1;
129             }
130 
131             if ((ip <= iend-7) || (ip + (bitCount>>3) <= iend-4)) {
132                 ip += bitCount>>3;
133                 bitCount &= 7;
134             } else {
135                 bitCount -= (int)(8 * (iend - 4 - ip));
136                 ip = iend - 4;
137             }
138             bitStream = MEM_readLE32(ip) >> (bitCount & 31);
139     }   }   /* while ((remaining>1) & (charnum<=*maxSVPtr)) */
140     if (remaining != 1) return ERROR(corruption_detected);
141     if (bitCount > 32) return ERROR(corruption_detected);
142     *maxSVPtr = charnum-1;
143 
144     ip += (bitCount+7)>>3;
145     return ip-istart;
146 }
147 
148 
149 /*! HUF_readStats() :
150     Read compact Huffman tree, saved by HUF_writeCTable().
151     `huffWeight` is destination buffer.
152     `rankStats` is assumed to be a table of at least HUF_TABLELOG_MAX U32.
153     @return : size read from `src` , or an error Code .
154     Note : Needed by HUF_readCTable() and HUF_readDTableX?() .
155 */
HUF_readStats(BYTE * huffWeight,size_t hwSize,U32 * rankStats,U32 * nbSymbolsPtr,U32 * tableLogPtr,const void * src,size_t srcSize)156 size_t HUF_readStats(BYTE* huffWeight, size_t hwSize, U32* rankStats,
157                      U32* nbSymbolsPtr, U32* tableLogPtr,
158                      const void* src, size_t srcSize)
159 {
160     U32 weightTotal;
161     const BYTE* ip = (const BYTE*) src;
162     size_t iSize;
163     size_t oSize;
164 
165     if (!srcSize) return ERROR(srcSize_wrong);
166     iSize = ip[0];
167     /* memset(huffWeight, 0, hwSize);   *//* is not necessary, even though some analyzer complain ... */
168 
169     if (iSize >= 128) {  /* special header */
170         oSize = iSize - 127;
171         iSize = ((oSize+1)/2);
172         if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
173         if (oSize >= hwSize) return ERROR(corruption_detected);
174         ip += 1;
175         {   U32 n;
176             for (n=0; n<oSize; n+=2) {
177                 huffWeight[n]   = ip[n/2] >> 4;
178                 huffWeight[n+1] = ip[n/2] & 15;
179     }   }   }
180     else  {   /* header compressed with FSE (normal case) */
181         FSE_DTable fseWorkspace[FSE_DTABLE_SIZE_U32(6)];  /* 6 is max possible tableLog for HUF header (maybe even 5, to be tested) */
182         if (iSize+1 > srcSize) return ERROR(srcSize_wrong);
183         oSize = FSE_decompress_wksp(huffWeight, hwSize-1, ip+1, iSize, fseWorkspace, 6);   /* max (hwSize-1) values decoded, as last one is implied */
184         if (FSE_isError(oSize)) return oSize;
185     }
186 
187     /* collect weight stats */
188     memset(rankStats, 0, (HUF_TABLELOG_MAX + 1) * sizeof(U32));
189     weightTotal = 0;
190     {   U32 n; for (n=0; n<oSize; n++) {
191             if (huffWeight[n] >= HUF_TABLELOG_MAX) return ERROR(corruption_detected);
192             rankStats[huffWeight[n]]++;
193             weightTotal += (1 << huffWeight[n]) >> 1;
194     }   }
195     if (weightTotal == 0) return ERROR(corruption_detected);
196 
197     /* get last non-null symbol weight (implied, total must be 2^n) */
198     {   U32 const tableLog = BIT_highbit32(weightTotal) + 1;
199         if (tableLog > HUF_TABLELOG_MAX) return ERROR(corruption_detected);
200         *tableLogPtr = tableLog;
201         /* determine last weight */
202         {   U32 const total = 1 << tableLog;
203             U32 const rest = total - weightTotal;
204             U32 const verif = 1 << BIT_highbit32(rest);
205             U32 const lastWeight = BIT_highbit32(rest) + 1;
206             if (verif != rest) return ERROR(corruption_detected);    /* last value must be a clean power of 2 */
207             huffWeight[oSize] = (BYTE)lastWeight;
208             rankStats[lastWeight]++;
209     }   }
210 
211     /* check tree construction validity */
212     if ((rankStats[1] < 2) || (rankStats[1] & 1)) return ERROR(corruption_detected);   /* by construction : at least 2 elts of rank 1, must be even */
213 
214     /* results */
215     *nbSymbolsPtr = (U32)(oSize+1);
216     return iSize+1;
217 }
218