xref: /freebsd/sys/contrib/zlib/infback.c (revision 397e83df75e0fcd0d3fcb95ae4d794cb7600fc89)
1 /* infback.c -- inflate using a call-back interface
2  * Copyright (C) 1995-2022 Mark Adler
3  * For conditions of distribution and use, see copyright notice in zlib.h
4  */
5 
6 /*
7    This code is largely copied from inflate.c.  Normally either infback.o or
8    inflate.o would be linked into an application--not both.  The interface
9    with inffast.c is retained so that optimized assembler-coded versions of
10    inflate_fast() can be used with either inflate.c or infback.c.
11  */
12 
13 #include "zutil.h"
14 #include "inftrees.h"
15 #include "inflate.h"
16 #include "inffast.h"
17 
18 /*
19    strm provides memory allocation functions in zalloc and zfree, or
20    Z_NULL to use the library memory allocation functions.
21 
22    windowBits is in the range 8..15, and window is a user-supplied
23    window and output buffer that is 2**windowBits bytes.
24  */
25 int ZEXPORT inflateBackInit_(z_streamp strm, int windowBits,
26                              unsigned char FAR *window, const char *version,
27                              int stream_size) {
28     struct inflate_state FAR *state;
29 
30     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
31         stream_size != (int)(sizeof(z_stream)))
32         return Z_VERSION_ERROR;
33     if (strm == Z_NULL || window == Z_NULL ||
34         windowBits < 8 || windowBits > 15)
35         return Z_STREAM_ERROR;
36     strm->msg = Z_NULL;                 /* in case we return an error */
37     if (strm->zalloc == (alloc_func)0) {
38 #if defined(Z_SOLO) && !defined(_KERNEL)
39         return Z_STREAM_ERROR;
40 #else
41         strm->zalloc = zcalloc;
42         strm->opaque = (voidpf)0;
43 #endif
44     }
45     if (strm->zfree == (free_func)0)
46 #if defined(Z_SOLO) && !defined(_KERNEL)
47         return Z_STREAM_ERROR;
48 #else
49     strm->zfree = zcfree;
50 #endif
51     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
52                                                sizeof(struct inflate_state));
53     if (state == Z_NULL) return Z_MEM_ERROR;
54     Tracev((stderr, "inflate: allocated\n"));
55     strm->state = (struct internal_state FAR *)state;
56     state->dmax = 32768U;
57     state->wbits = (uInt)windowBits;
58     state->wsize = 1U << windowBits;
59     state->window = window;
60     state->wnext = 0;
61     state->whave = 0;
62     state->sane = 1;
63     return Z_OK;
64 }
65 
66 /*
67    Return state with length and distance decoding tables and index sizes set to
68    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
69    If BUILDFIXED is defined, then instead this routine builds the tables the
70    first time it's called, and returns those tables the first time and
71    thereafter.  This reduces the size of the code by about 2K bytes, in
72    exchange for a little execution time.  However, BUILDFIXED should not be
73    used for threaded applications, since the rewriting of the tables and virgin
74    may not be thread-safe.
75  */
76 local void fixedtables(struct inflate_state FAR *state) {
77 #ifdef BUILDFIXED
78     static int virgin = 1;
79     static code *lenfix, *distfix;
80     static code fixed[544];
81 
82     /* build fixed huffman tables if first call (may not be thread safe) */
83     if (virgin) {
84         unsigned sym, bits;
85         static code *next;
86 
87         /* literal/length table */
88         sym = 0;
89         while (sym < 144) state->lens[sym++] = 8;
90         while (sym < 256) state->lens[sym++] = 9;
91         while (sym < 280) state->lens[sym++] = 7;
92         while (sym < 288) state->lens[sym++] = 8;
93         next = fixed;
94         lenfix = next;
95         bits = 9;
96         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
97 
98         /* distance table */
99         sym = 0;
100         while (sym < 32) state->lens[sym++] = 5;
101         distfix = next;
102         bits = 5;
103         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
104 
105         /* do this just once */
106         virgin = 0;
107     }
108 #else /* !BUILDFIXED */
109 #   include "inffixed.h"
110 #endif /* BUILDFIXED */
111     state->lencode = lenfix;
112     state->lenbits = 9;
113     state->distcode = distfix;
114     state->distbits = 5;
115 }
116 
117 /* Macros for inflateBack(): */
118 
119 /* Load returned state from inflate_fast() */
120 #define LOAD() \
121     do { \
122         put = strm->next_out; \
123         left = strm->avail_out; \
124         next = strm->next_in; \
125         have = strm->avail_in; \
126         hold = state->hold; \
127         bits = state->bits; \
128     } while (0)
129 
130 /* Set state from registers for inflate_fast() */
131 #define RESTORE() \
132     do { \
133         strm->next_out = put; \
134         strm->avail_out = left; \
135         strm->next_in = next; \
136         strm->avail_in = have; \
137         state->hold = hold; \
138         state->bits = bits; \
139     } while (0)
140 
141 /* Clear the input bit accumulator */
142 #define INITBITS() \
143     do { \
144         hold = 0; \
145         bits = 0; \
146     } while (0)
147 
148 /* Assure that some input is available.  If input is requested, but denied,
149    then return a Z_BUF_ERROR from inflateBack(). */
150 #define PULL() \
151     do { \
152         if (have == 0) { \
153             have = in(in_desc, &next); \
154             if (have == 0) { \
155                 next = Z_NULL; \
156                 ret = Z_BUF_ERROR; \
157                 goto inf_leave; \
158             } \
159         } \
160     } while (0)
161 
162 /* Get a byte of input into the bit accumulator, or return from inflateBack()
163    with an error if there is no input available. */
164 #define PULLBYTE() \
165     do { \
166         PULL(); \
167         have--; \
168         hold += (unsigned long)(*next++) << bits; \
169         bits += 8; \
170     } while (0)
171 
172 /* Assure that there are at least n bits in the bit accumulator.  If there is
173    not enough available input to do that, then return from inflateBack() with
174    an error. */
175 #define NEEDBITS(n) \
176     do { \
177         while (bits < (unsigned)(n)) \
178             PULLBYTE(); \
179     } while (0)
180 
181 /* Return the low n bits of the bit accumulator (n < 16) */
182 #define BITS(n) \
183     ((unsigned)hold & ((1U << (n)) - 1))
184 
185 /* Remove n bits from the bit accumulator */
186 #define DROPBITS(n) \
187     do { \
188         hold >>= (n); \
189         bits -= (unsigned)(n); \
190     } while (0)
191 
192 /* Remove zero to seven bits as needed to go to a byte boundary */
193 #define BYTEBITS() \
194     do { \
195         hold >>= bits & 7; \
196         bits -= bits & 7; \
197     } while (0)
198 
199 /* Assure that some output space is available, by writing out the window
200    if it's full.  If the write fails, return from inflateBack() with a
201    Z_BUF_ERROR. */
202 #define ROOM() \
203     do { \
204         if (left == 0) { \
205             put = state->window; \
206             left = state->wsize; \
207             state->whave = left; \
208             if (out(out_desc, put, left)) { \
209                 ret = Z_BUF_ERROR; \
210                 goto inf_leave; \
211             } \
212         } \
213     } while (0)
214 
/*
   Decode deflate data using call-backs for input and output.  Decoding
   starts at a block header (state->mode = TYPE); no zlib or gzip wrapper is
   parsed here.  Decoding stops after the block marked as last.

   strm provides the memory allocation functions and window buffer on input,
   and provides information on the unused input on return.  For Z_DATA_ERROR
   returns, strm->msg will also provide an error message.

   in() and out() are the call-back input and output functions.  When
   inflateBack() needs more input, it calls in().  When inflateBack() has
   filled the window with output, or when it completes with data in the
   window, it calls out() to write out the data.  The application must not
   change the provided input until in() is called again or inflateBack()
   returns.  The application must not change the window/output buffer until
   inflateBack() returns.

   in() and out() are called with a descriptor parameter provided in the
   inflateBack() call.  This parameter can be a structure that provides the
   information required to do the read or write, as well as accumulated
   information on the input and output such as totals and check values.

   in() should return zero on failure.  out() should return non-zero on
   failure.  If either in() or out() fails, then inflateBack() returns
   Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
   was in() or out() that caused the error.  Otherwise, inflateBack()
   returns Z_STREAM_END on success or Z_DATA_ERROR for a deflate format
   error.  inflateBack() can also return Z_STREAM_ERROR if the input
   parameters are not correct, i.e. strm is Z_NULL or the state was not
   initialized.  (inflateBack() itself allocates no memory and never returns
   Z_MEM_ERROR; that code can only come from inflateBackInit_().)
 */
int ZEXPORT inflateBack(z_streamp strm, in_func in, void FAR *in_desc,
                        out_func out, void FAR *out_desc) {
    struct inflate_state FAR *state;
    z_const unsigned char FAR *next;    /* next input */
    unsigned char FAR *put;     /* next output */
    unsigned have, left;        /* available input and output */
    unsigned long hold;         /* bit buffer */
    unsigned bits;              /* bits in bit buffer */
    unsigned copy;              /* number of stored or match bytes to copy */
    unsigned char FAR *from;    /* where to copy match bytes from */
    code here;                  /* current decoding table entry */
    code last;                  /* parent table entry */
    unsigned len;               /* length to copy for repeats, bits to drop */
    int ret;                    /* return code */
    static const unsigned short order[19] = /* permutation of code lengths */
        {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};

    /* Check that the strm exists and that the state was initialized */
    if (strm == Z_NULL || strm->state == Z_NULL)
        return Z_STREAM_ERROR;
    state = (struct inflate_state FAR *)strm->state;

    /* Reset the state: start at a block header with an empty bit buffer and
       an empty window.  A Z_NULL next_in is treated as "no input yet", so
       the first PULL() will call in(). */
    strm->msg = Z_NULL;
    state->mode = TYPE;
    state->last = 0;
    state->whave = 0;
    next = strm->next_in;
    have = next != Z_NULL ? strm->avail_in : 0;
    hold = 0;
    bits = 0;
    put = state->window;
    left = state->wsize;

    /* Inflate until end of block marked as last */
    for (;;)
        switch (state->mode) {
        case TYPE:
            /* determine and dispatch block type */
            if (state->last) {
                /* previous block was the last one: drop the padding up to
                   the next byte boundary so the unused input returned in
                   strm starts on a whole byte */
                BYTEBITS();
                state->mode = DONE;
                break;
            }
            NEEDBITS(3);
            state->last = BITS(1);      /* 1-bit "last block" flag */
            DROPBITS(1);
            switch (BITS(2)) {          /* 2-bit block type */
            case 0:                             /* stored block */
                Tracev((stderr, "inflate:     stored block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = STORED;
                break;
            case 1:                             /* fixed block */
                fixedtables(state);
                Tracev((stderr, "inflate:     fixed codes block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = LEN;              /* decode codes */
                break;
            case 2:                             /* dynamic block */
                Tracev((stderr, "inflate:     dynamic codes block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = TABLE;
                break;
            case 3:
                strm->msg = (char *)"invalid block type";
                state->mode = BAD;
            }
            DROPBITS(2);
            break;

        case STORED:
            /* get and verify stored block length */
            BYTEBITS();                         /* go to byte boundary */
            NEEDBITS(32);
            /* low 16 bits are the length, high 16 bits must be its one's
               complement */
            if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
                strm->msg = (char *)"invalid stored block lengths";
                state->mode = BAD;
                break;
            }
            state->length = (unsigned)hold & 0xffff;
            Tracev((stderr, "inflate:       stored length %u\n",
                    state->length));
            /* all 32 header bits are consumed; the payload is copied
               byte-wise straight from next, bypassing the bit buffer */
            INITBITS();

            /* copy stored block from input to output */
            while (state->length != 0) {
                copy = state->length;
                PULL();
                ROOM();
                /* copy no more than is available on either side */
                if (copy > have) copy = have;
                if (copy > left) copy = left;
                zmemcpy(put, next, copy);
                have -= copy;
                next += copy;
                left -= copy;
                put += copy;
                state->length -= copy;
            }
            Tracev((stderr, "inflate:       stored end\n"));
            state->mode = TYPE;
            break;

        case TABLE:
            /* get dynamic table entries descriptor: number of
               literal/length codes (257..288), distance codes (1..32), and
               code length codes (4..19) */
            NEEDBITS(14);
            state->nlen = BITS(5) + 257;
            DROPBITS(5);
            state->ndist = BITS(5) + 1;
            DROPBITS(5);
            state->ncode = BITS(4) + 4;
            DROPBITS(4);
#ifndef PKZIP_BUG_WORKAROUND
            if (state->nlen > 286 || state->ndist > 30) {
                strm->msg = (char *)"too many length or distance symbols";
                state->mode = BAD;
                break;
            }
#endif
            Tracev((stderr, "inflate:       table sizes ok\n"));

            /* get code length code lengths (not a typo) */
            state->have = 0;
            while (state->have < state->ncode) {
                NEEDBITS(3);
                state->lens[order[state->have++]] = (unsigned short)BITS(3);
                DROPBITS(3);
            }
            /* code lengths not transmitted are zero */
            while (state->have < 19)
                state->lens[order[state->have++]] = 0;
            state->next = state->codes;
            state->lencode = (code const FAR *)(state->next);
            state->lenbits = 7;
            ret = inflate_table(CODES, state->lens, 19, &(state->next),
                                &(state->lenbits), state->work);
            if (ret) {
                strm->msg = (char *)"invalid code lengths set";
                state->mode = BAD;
                break;
            }
            Tracev((stderr, "inflate:       code lengths ok\n"));

            /* get length and distance code code lengths */
            state->have = 0;
            while (state->have < state->nlen + state->ndist) {
                /* pull bytes until the whole code for this entry is
                   buffered */
                for (;;) {
                    here = state->lencode[BITS(state->lenbits)];
                    if ((unsigned)(here.bits) <= bits) break;
                    PULLBYTE();
                }
                if (here.val < 16) {
                    /* 0..15: a literal code length */
                    DROPBITS(here.bits);
                    state->lens[state->have++] = here.val;
                }
                else {
                    /* 16: repeat previous length 3..6 times;
                       17: zeros 3..10 times; 18: zeros 11..138 times */
                    if (here.val == 16) {
                        NEEDBITS(here.bits + 2);
                        DROPBITS(here.bits);
                        if (state->have == 0) {
                            strm->msg = (char *)"invalid bit length repeat";
                            state->mode = BAD;
                            break;
                        }
                        len = (unsigned)(state->lens[state->have - 1]);
                        copy = 3 + BITS(2);
                        DROPBITS(2);
                    }
                    else if (here.val == 17) {
                        NEEDBITS(here.bits + 3);
                        DROPBITS(here.bits);
                        len = 0;
                        copy = 3 + BITS(3);
                        DROPBITS(3);
                    }
                    else {
                        NEEDBITS(here.bits + 7);
                        DROPBITS(here.bits);
                        len = 0;
                        copy = 11 + BITS(7);
                        DROPBITS(7);
                    }
                    /* a repeat must not run past the declared code count */
                    if (state->have + copy > state->nlen + state->ndist) {
                        strm->msg = (char *)"invalid bit length repeat";
                        state->mode = BAD;
                        break;
                    }
                    while (copy--)
                        state->lens[state->have++] = (unsigned short)len;
                }
            }

            /* handle error breaks in while */
            if (state->mode == BAD) break;

            /* check for end-of-block code (better have one) */
            if (state->lens[256] == 0) {
                strm->msg = (char *)"invalid code -- missing end-of-block";
                state->mode = BAD;
                break;
            }

            /* build code tables -- note: do not change the lenbits or distbits
               values here (9 and 6) without reading the comments in inftrees.h
               concerning the ENOUGH constants, which depend on those values */
            state->next = state->codes;
            state->lencode = (code const FAR *)(state->next);
            state->lenbits = 9;
            ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
                                &(state->lenbits), state->work);
            if (ret) {
                strm->msg = (char *)"invalid literal/lengths set";
                state->mode = BAD;
                break;
            }
            /* distance table is built right after the length table in codes[] */
            state->distcode = (code const FAR *)(state->next);
            state->distbits = 6;
            ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
                            &(state->next), &(state->distbits), state->work);
            if (ret) {
                strm->msg = (char *)"invalid distances set";
                state->mode = BAD;
                break;
            }
            Tracev((stderr, "inflate:       codes ok\n"));
            state->mode = LEN;
                /* fallthrough */

        case LEN:
            /* use inflate_fast() if we have enough input and output */
            if (have >= 6 && left >= 258) {
                RESTORE();
                /* until the window has been flushed once, only the bytes
                   written so far (wsize - left) are valid history */
                if (state->whave < state->wsize)
                    state->whave = state->wsize - left;
                inflate_fast(strm, state->wsize);
                LOAD();
                break;
            }

            /* get a literal, length, or end-of-block code */
            for (;;) {
                here = state->lencode[BITS(state->lenbits)];
                if ((unsigned)(here.bits) <= bits) break;
                PULLBYTE();
            }
            /* an entry with a non-zero op whose high nibble is clear links
               to a second-level table: index it with the bits after the
               root bits */
            if (here.op && (here.op & 0xf0) == 0) {
                last = here;
                for (;;) {
                    here = state->lencode[last.val +
                            (BITS(last.bits + last.op) >> last.bits)];
                    if ((unsigned)(last.bits + here.bits) <= bits) break;
                    PULLBYTE();
                }
                DROPBITS(last.bits);
            }
            DROPBITS(here.bits);
            state->length = (unsigned)here.val;

            /* process literal */
            if (here.op == 0) {
                Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
                        "inflate:         literal '%c'\n" :
                        "inflate:         literal 0x%02x\n", here.val));
                ROOM();
                *put++ = (unsigned char)(state->length);
                left--;
                state->mode = LEN;
                break;
            }

            /* process end of block */
            if (here.op & 32) {
                Tracevv((stderr, "inflate:         end of block\n"));
                state->mode = TYPE;
                break;
            }

            /* invalid code */
            if (here.op & 64) {
                strm->msg = (char *)"invalid literal/length code";
                state->mode = BAD;
                break;
            }

            /* length code -- get extra bits, if any */
            state->extra = (unsigned)(here.op) & 15;
            if (state->extra != 0) {
                NEEDBITS(state->extra);
                state->length += BITS(state->extra);
                DROPBITS(state->extra);
            }
            Tracevv((stderr, "inflate:         length %u\n", state->length));

            /* get distance code */
            for (;;) {
                here = state->distcode[BITS(state->distbits)];
                if ((unsigned)(here.bits) <= bits) break;
                PULLBYTE();
            }
            /* second-level distance table lookup, as for lengths above */
            if ((here.op & 0xf0) == 0) {
                last = here;
                for (;;) {
                    here = state->distcode[last.val +
                            (BITS(last.bits + last.op) >> last.bits)];
                    if ((unsigned)(last.bits + here.bits) <= bits) break;
                    PULLBYTE();
                }
                DROPBITS(last.bits);
            }
            DROPBITS(here.bits);
            if (here.op & 64) {
                strm->msg = (char *)"invalid distance code";
                state->mode = BAD;
                break;
            }
            state->offset = (unsigned)here.val;

            /* get distance extra bits, if any */
            state->extra = (unsigned)(here.op) & 15;
            if (state->extra != 0) {
                NEEDBITS(state->extra);
                state->offset += BITS(state->extra);
                DROPBITS(state->extra);
            }
            /* reject distances reaching back before the start of the
               output: before the first window flush only wsize - left bytes
               exist; after it the whole window is valid history */
            if (state->offset > state->wsize - (state->whave < state->wsize ?
                                                left : 0)) {
                strm->msg = (char *)"invalid distance too far back";
                state->mode = BAD;
                break;
            }
            Tracevv((stderr, "inflate:         distance %u\n", state->offset));

            /* copy match from window to output */
            do {
                ROOM();
                copy = state->wsize - state->offset;
                if (copy < left) {
                    /* the match starts before the beginning of the window
                       buffer: take it from the tail of the window (output of
                       the previous pass), up to the end of the window */
                    from = put + copy;
                    copy = left - copy;
                }
                else {
                    from = put - state->offset;
                    copy = left;
                }
                if (copy > state->length) copy = state->length;
                state->length -= copy;
                left -= copy;
                /* byte-by-byte so overlapping matches (offset < length)
                   replicate correctly */
                do {
                    *put++ = *from++;
                } while (--copy);
            } while (state->length != 0);
            break;

        case DONE:
            /* inflate stream terminated properly */
            ret = Z_STREAM_END;
            goto inf_leave;

        case BAD:
            ret = Z_DATA_ERROR;
            goto inf_leave;

        default:
            /* can't happen, but makes compilers happy */
            ret = Z_STREAM_ERROR;
            goto inf_leave;
        }

    /* Write leftover output and return unused input */
  inf_leave:
    /* flush whatever is in the window; an out() failure here only changes
       the result when decoding had otherwise succeeded */
    if (left < state->wsize) {
        if (out(out_desc, state->window, state->wsize - left) &&
            ret == Z_STREAM_END)
            ret = Z_BUF_ERROR;
    }
    strm->next_in = next;
    strm->avail_in = have;
    return ret;
}
620 
621 int ZEXPORT inflateBackEnd(z_streamp strm) {
622     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
623         return Z_STREAM_ERROR;
624     ZFREE(strm, strm->state);
625     strm->state = Z_NULL;
626     Tracev((stderr, "inflate: end\n"));
627     return Z_OK;
628 }
629