xref: /illumos-gate/usr/src/contrib/zlib/infback.c (revision ce17336ed725d3b7fdff67bf0a0ee2b55018fec6)
1 /* infback.c -- inflate using a call-back interface
2  * Copyright (C) 1995-2022 Mark Adler
3  * For conditions of distribution and use, see copyright notice in zlib.h
4  */
5 
6 /*
7    This code is largely copied from inflate.c.  Normally either infback.o or
8    inflate.o would be linked into an application--not both.  The interface
9    with inffast.c is retained so that optimized assembler-coded versions of
10    inflate_fast() can be used with either inflate.c or infback.c.
11  */
12 
13 #include "zutil.h"
14 #include "inftrees.h"
15 #include "inflate.h"
16 #include "inffast.h"
17 
18 /* function prototypes */
19 local void fixedtables OF((struct inflate_state FAR *state));
20 
21 /*
22    strm provides memory allocation functions in zalloc and zfree, or
23    Z_NULL to use the library memory allocation functions.
24 
25    windowBits is in the range 8..15, and window is a user-supplied
26    window and output buffer that is 2**windowBits bytes.
27  */
28 int ZEXPORT inflateBackInit_(z_streamp strm, int windowBits,
29     unsigned char FAR *window, const char *version, int stream_size)
30 {
31     struct inflate_state FAR *state;
32 
33     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
34         stream_size != (int)(sizeof(z_stream)))
35         return Z_VERSION_ERROR;
36     if (strm == Z_NULL || window == Z_NULL ||
37         windowBits < 8 || windowBits > 15)
38         return Z_STREAM_ERROR;
39     strm->msg = Z_NULL;                 /* in case we return an error */
40     if (strm->zalloc == (alloc_func)0) {
41 #ifdef Z_SOLO
42         return Z_STREAM_ERROR;
43 #else
44         strm->zalloc = zcalloc;
45         strm->opaque = (voidpf)0;
46 #endif
47     }
48     if (strm->zfree == (free_func)0)
49 #ifdef Z_SOLO
50         return Z_STREAM_ERROR;
51 #else
52         strm->zfree = zcfree;
53 #endif
54     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
55                                                sizeof(struct inflate_state));
56     if (state == Z_NULL) return Z_MEM_ERROR;
57     Tracev((stderr, "inflate: allocated\n"));
58     strm->state = (struct internal_state FAR *)state;
59     state->dmax = 32768U;
60     state->wbits = (uInt)windowBits;
61     state->wsize = 1U << windowBits;
62     state->window = window;
63     state->wnext = 0;
64     state->whave = 0;
65     return Z_OK;
66 }
67 
68 /*
69    Return state with length and distance decoding tables and index sizes set to
70    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
71    If BUILDFIXED is defined, then instead this routine builds the tables the
72    first time it's called, and returns those tables the first time and
73    thereafter.  This reduces the size of the code by about 2K bytes, in
74    exchange for a little execution time.  However, BUILDFIXED should not be
75    used for threaded applications, since the rewriting of the tables and virgin
76    may not be thread-safe.
77  */
78 local void fixedtables(struct inflate_state FAR *state)
79 {
80 #ifdef BUILDFIXED
81     static int virgin = 1;
82     static code *lenfix, *distfix;
83     static code fixed[544];
84 
85     /* build fixed huffman tables if first call (may not be thread safe) */
86     if (virgin) {
87         unsigned sym, bits;
88         static code *next;
89 
90         /* literal/length table */
91         sym = 0;
92         while (sym < 144) state->lens[sym++] = 8;
93         while (sym < 256) state->lens[sym++] = 9;
94         while (sym < 280) state->lens[sym++] = 7;
95         while (sym < 288) state->lens[sym++] = 8;
96         next = fixed;
97         lenfix = next;
98         bits = 9;
99         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
100 
101         /* distance table */
102         sym = 0;
103         while (sym < 32) state->lens[sym++] = 5;
104         distfix = next;
105         bits = 5;
106         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
107 
108         /* do this just once */
109         virgin = 0;
110     }
111 #else /* !BUILDFIXED */
112 #   include "inffixed.h"
113 #endif /* BUILDFIXED */
114     state->lencode = lenfix;
115     state->lenbits = 9;
116     state->distcode = distfix;
117     state->distbits = 5;
118 }
119 
/*
   Helper macros for inflateBack().  They are not self-contained: they read
   and update inflateBack()'s locals (strm, state, next, have, put, left,
   hold, bits, ret), call its in()/out() callbacks with their descriptors,
   and on failure jump to the inf_leave label in that function.
 */

/* Reload the local registers after inflate_fast() has updated strm/state */
#define LOAD() \
    do { \
        hold = state->hold; \
        bits = state->bits; \
        next = strm->next_in; \
        have = strm->avail_in; \
        put = strm->next_out; \
        left = strm->avail_out; \
    } while (0)

/* Publish the local registers to strm/state before calling inflate_fast() */
#define RESTORE() \
    do { \
        state->hold = hold; \
        state->bits = bits; \
        strm->next_in = next; \
        strm->avail_in = have; \
        strm->next_out = put; \
        strm->avail_out = left; \
    } while (0)

/* Empty the input bit accumulator */
#define INITBITS() \
    do { \
        bits = 0; \
        hold = 0; \
    } while (0)

/* Ensure at least one input byte is available, asking in() for more only
   when the current input is used up.  If in() supplies nothing, leave
   inflateBack() with Z_BUF_ERROR and next set to Z_NULL. */
#define PULL() \
    do { \
        if (have == 0 && (have = in(in_desc, &next)) == 0) { \
            next = Z_NULL; \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)

/* Shift one input byte into the top of the bit accumulator, or leave
   inflateBack() with an error if no input can be obtained. */
#define PULLBYTE() \
    do { \
        PULL(); \
        hold += (unsigned long)(*next) << bits; \
        next++; \
        have--; \
        bits += 8; \
    } while (0)

/* Ensure the bit accumulator holds at least n bits, pulling whole bytes as
   needed; leaves inflateBack() with an error if input runs out. */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) \
            PULLBYTE(); \
    } while (0)

/* Value of the low n bits of the bit accumulator (n < 16) */
#define BITS(n) \
    (((unsigned)hold) & ((1U << (n)) - 1))

/* Discard the low n bits of the bit accumulator */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* Discard the zero to seven bits that remain before the next byte boundary */
#define BYTEBITS() \
    do { \
        hold >>= bits & 7; \
        bits &= ~7U; \
    } while (0)

/* Ensure there is room for output: once the window is full, hand the whole
   window to out() and start again at its beginning.  If out() reports
   failure, leave inflateBack() with Z_BUF_ERROR. */
#define ROOM() \
    do { \
        if (left == 0) { \
            state->whave = state->wsize; \
            put = state->window; \
            left = state->wsize; \
            if (out(out_desc, put, left) != 0) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
217 
218 /*
219    strm provides the memory allocation functions and window buffer on input,
220    and provides information on the unused input on return.  For Z_DATA_ERROR
221    returns, strm will also provide an error message.
222 
223    in() and out() are the call-back input and output functions.  When
224    inflateBack() needs more input, it calls in().  When inflateBack() has
225    filled the window with output, or when it completes with data in the
226    window, it calls out() to write out the data.  The application must not
227    change the provided input until in() is called again or inflateBack()
228    returns.  The application must not change the window/output buffer until
229    inflateBack() returns.
230 
231    in() and out() are called with a descriptor parameter provided in the
232    inflateBack() call.  This parameter can be a structure that provides the
233    information required to do the read or write, as well as accumulated
234    information on the input and output such as totals and check values.
235 
236    in() should return zero on failure.  out() should return non-zero on
237    failure.  If either in() or out() fails, than inflateBack() returns a
238    Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
239    was in() or out() that caused in the error.  Otherwise,  inflateBack()
240    returns Z_STREAM_END on success, Z_DATA_ERROR for an deflate format
241    error, or Z_MEM_ERROR if it could not allocate memory for the state.
242    inflateBack() can also return Z_STREAM_ERROR if the input parameters
243    are not correct, i.e. strm is Z_NULL or the state was not initialized.
244  */
245 int ZEXPORT inflateBack(z_streamp strm, in_func in, void FAR *in_desc,
246     out_func out, void FAR *out_desc)
247 {
248     struct inflate_state FAR *state;
249     z_const unsigned char FAR *next;    /* next input */
250     unsigned char FAR *put;     /* next output */
251     unsigned have, left;        /* available input and output */
252     unsigned long hold;         /* bit buffer */
253     unsigned bits;              /* bits in bit buffer */
254     unsigned copy;              /* number of stored or match bytes to copy */
255     unsigned char FAR *from;    /* where to copy match bytes from */
256     code here;                  /* current decoding table entry */
257     code last;                  /* parent table entry */
258     unsigned len;               /* length to copy for repeats, bits to drop */
259     int ret;                    /* return code */
260     static const unsigned short order[19] = /* permutation of code lengths */
261         {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
262 
263     /* Check that the strm exists and that the state was initialized */
264     if (strm == Z_NULL || strm->state == Z_NULL)
265         return Z_STREAM_ERROR;
266     state = (struct inflate_state FAR *)strm->state;
267 
268     /* Reset the state */
269     strm->msg = Z_NULL;
270     state->mode = TYPE;
271     state->last = 0;
272     state->whave = 0;
273     next = strm->next_in;
274     have = next != Z_NULL ? strm->avail_in : 0;
275     hold = 0;
276     bits = 0;
277     put = state->window;
278     left = state->wsize;
279 
280     /* Inflate until end of block marked as last */
281     for (;;)
282         switch (state->mode) {
283         case TYPE:
284             /* determine and dispatch block type */
285             if (state->last) {
286                 BYTEBITS();
287                 state->mode = DONE;
288                 break;
289             }
290             NEEDBITS(3);
291             state->last = BITS(1);
292             DROPBITS(1);
293             switch (BITS(2)) {
294             case 0:                             /* stored block */
295                 Tracev((stderr, "inflate:     stored block%s\n",
296                         state->last ? " (last)" : ""));
297                 state->mode = STORED;
298                 break;
299             case 1:                             /* fixed block */
300                 fixedtables(state);
301                 Tracev((stderr, "inflate:     fixed codes block%s\n",
302                         state->last ? " (last)" : ""));
303                 state->mode = LEN;              /* decode codes */
304                 break;
305             case 2:                             /* dynamic block */
306                 Tracev((stderr, "inflate:     dynamic codes block%s\n",
307                         state->last ? " (last)" : ""));
308                 state->mode = TABLE;
309                 break;
310             case 3:
311                 strm->msg = (char *)"invalid block type";
312                 state->mode = BAD;
313             }
314             DROPBITS(2);
315             break;
316 
317         case STORED:
318             /* get and verify stored block length */
319             BYTEBITS();                         /* go to byte boundary */
320             NEEDBITS(32);
321             if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
322                 strm->msg = (char *)"invalid stored block lengths";
323                 state->mode = BAD;
324                 break;
325             }
326             state->length = (unsigned)hold & 0xffff;
327             Tracev((stderr, "inflate:       stored length %u\n",
328                     state->length));
329             INITBITS();
330 
331             /* copy stored block from input to output */
332             while (state->length != 0) {
333                 copy = state->length;
334                 PULL();
335                 ROOM();
336                 if (copy > have) copy = have;
337                 if (copy > left) copy = left;
338                 zmemcpy(put, next, copy);
339                 have -= copy;
340                 next += copy;
341                 left -= copy;
342                 put += copy;
343                 state->length -= copy;
344             }
345             Tracev((stderr, "inflate:       stored end\n"));
346             state->mode = TYPE;
347             break;
348 
349         case TABLE:
350             /* get dynamic table entries descriptor */
351             NEEDBITS(14);
352             state->nlen = BITS(5) + 257;
353             DROPBITS(5);
354             state->ndist = BITS(5) + 1;
355             DROPBITS(5);
356             state->ncode = BITS(4) + 4;
357             DROPBITS(4);
358 #ifndef PKZIP_BUG_WORKAROUND
359             if (state->nlen > 286 || state->ndist > 30) {
360                 strm->msg = (char *)"too many length or distance symbols";
361                 state->mode = BAD;
362                 break;
363             }
364 #endif
365             Tracev((stderr, "inflate:       table sizes ok\n"));
366 
367             /* get code length code lengths (not a typo) */
368             state->have = 0;
369             while (state->have < state->ncode) {
370                 NEEDBITS(3);
371                 state->lens[order[state->have++]] = (unsigned short)BITS(3);
372                 DROPBITS(3);
373             }
374             while (state->have < 19)
375                 state->lens[order[state->have++]] = 0;
376             state->next = state->codes;
377             state->lencode = (code const FAR *)(state->next);
378             state->lenbits = 7;
379             ret = inflate_table(CODES, state->lens, 19, &(state->next),
380                                 &(state->lenbits), state->work);
381             if (ret) {
382                 strm->msg = (char *)"invalid code lengths set";
383                 state->mode = BAD;
384                 break;
385             }
386             Tracev((stderr, "inflate:       code lengths ok\n"));
387 
388             /* get length and distance code code lengths */
389             state->have = 0;
390             while (state->have < state->nlen + state->ndist) {
391                 for (;;) {
392                     here = state->lencode[BITS(state->lenbits)];
393                     if ((unsigned)(here.bits) <= bits) break;
394                     PULLBYTE();
395                 }
396                 if (here.val < 16) {
397                     DROPBITS(here.bits);
398                     state->lens[state->have++] = here.val;
399                 }
400                 else {
401                     if (here.val == 16) {
402                         NEEDBITS(here.bits + 2);
403                         DROPBITS(here.bits);
404                         if (state->have == 0) {
405                             strm->msg = (char *)"invalid bit length repeat";
406                             state->mode = BAD;
407                             break;
408                         }
409                         len = (unsigned)(state->lens[state->have - 1]);
410                         copy = 3 + BITS(2);
411                         DROPBITS(2);
412                     }
413                     else if (here.val == 17) {
414                         NEEDBITS(here.bits + 3);
415                         DROPBITS(here.bits);
416                         len = 0;
417                         copy = 3 + BITS(3);
418                         DROPBITS(3);
419                     }
420                     else {
421                         NEEDBITS(here.bits + 7);
422                         DROPBITS(here.bits);
423                         len = 0;
424                         copy = 11 + BITS(7);
425                         DROPBITS(7);
426                     }
427                     if (state->have + copy > state->nlen + state->ndist) {
428                         strm->msg = (char *)"invalid bit length repeat";
429                         state->mode = BAD;
430                         break;
431                     }
432                     while (copy--)
433                         state->lens[state->have++] = (unsigned short)len;
434                 }
435             }
436 
437             /* handle error breaks in while */
438             if (state->mode == BAD) break;
439 
440             /* check for end-of-block code (better have one) */
441             if (state->lens[256] == 0) {
442                 strm->msg = (char *)"invalid code -- missing end-of-block";
443                 state->mode = BAD;
444                 break;
445             }
446 
447             /* build code tables -- note: do not change the lenbits or distbits
448                values here (9 and 6) without reading the comments in inftrees.h
449                concerning the ENOUGH constants, which depend on those values */
450             state->next = state->codes;
451             state->lencode = (code const FAR *)(state->next);
452             state->lenbits = 9;
453             ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
454                                 &(state->lenbits), state->work);
455             if (ret) {
456                 strm->msg = (char *)"invalid literal/lengths set";
457                 state->mode = BAD;
458                 break;
459             }
460             state->distcode = (code const FAR *)(state->next);
461             state->distbits = 6;
462             ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
463                             &(state->next), &(state->distbits), state->work);
464             if (ret) {
465                 strm->msg = (char *)"invalid distances set";
466                 state->mode = BAD;
467                 break;
468             }
469             Tracev((stderr, "inflate:       codes ok\n"));
470             state->mode = LEN;
471 	    /* FALLTHROUGH */
472         case LEN:
473             /* use inflate_fast() if we have enough input and output */
474             if (have >= 6 && left >= 258) {
475                 RESTORE();
476                 if (state->whave < state->wsize)
477                     state->whave = state->wsize - left;
478                 inflate_fast(strm, state->wsize);
479                 LOAD();
480                 break;
481             }
482 
483             /* get a literal, length, or end-of-block code */
484             for (;;) {
485                 here = state->lencode[BITS(state->lenbits)];
486                 if ((unsigned)(here.bits) <= bits) break;
487                 PULLBYTE();
488             }
489             if (here.op && (here.op & 0xf0) == 0) {
490                 last = here;
491                 for (;;) {
492                     here = state->lencode[last.val +
493                             (BITS(last.bits + last.op) >> last.bits)];
494                     if ((unsigned)(last.bits + here.bits) <= bits) break;
495                     PULLBYTE();
496                 }
497                 DROPBITS(last.bits);
498             }
499             DROPBITS(here.bits);
500             state->length = (unsigned)here.val;
501 
502             /* process literal */
503             if (here.op == 0) {
504                 Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
505                         "inflate:         literal '%c'\n" :
506                         "inflate:         literal 0x%02x\n", here.val));
507                 ROOM();
508                 *put++ = (unsigned char)(state->length);
509                 left--;
510                 state->mode = LEN;
511                 break;
512             }
513 
514             /* process end of block */
515             if (here.op & 32) {
516                 Tracevv((stderr, "inflate:         end of block\n"));
517                 state->mode = TYPE;
518                 break;
519             }
520 
521             /* invalid code */
522             if (here.op & 64) {
523                 strm->msg = (char *)"invalid literal/length code";
524                 state->mode = BAD;
525                 break;
526             }
527 
528             /* length code -- get extra bits, if any */
529             state->extra = (unsigned)(here.op) & 15;
530             if (state->extra != 0) {
531                 NEEDBITS(state->extra);
532                 state->length += BITS(state->extra);
533                 DROPBITS(state->extra);
534             }
535             Tracevv((stderr, "inflate:         length %u\n", state->length));
536 
537             /* get distance code */
538             for (;;) {
539                 here = state->distcode[BITS(state->distbits)];
540                 if ((unsigned)(here.bits) <= bits) break;
541                 PULLBYTE();
542             }
543             if ((here.op & 0xf0) == 0) {
544                 last = here;
545                 for (;;) {
546                     here = state->distcode[last.val +
547                             (BITS(last.bits + last.op) >> last.bits)];
548                     if ((unsigned)(last.bits + here.bits) <= bits) break;
549                     PULLBYTE();
550                 }
551                 DROPBITS(last.bits);
552             }
553             DROPBITS(here.bits);
554             if (here.op & 64) {
555                 strm->msg = (char *)"invalid distance code";
556                 state->mode = BAD;
557                 break;
558             }
559             state->offset = (unsigned)here.val;
560 
561             /* get distance extra bits, if any */
562             state->extra = (unsigned)(here.op) & 15;
563             if (state->extra != 0) {
564                 NEEDBITS(state->extra);
565                 state->offset += BITS(state->extra);
566                 DROPBITS(state->extra);
567             }
568             if (state->offset > state->wsize - (state->whave < state->wsize ?
569                                                 left : 0)) {
570                 strm->msg = (char *)"invalid distance too far back";
571                 state->mode = BAD;
572                 break;
573             }
574             Tracevv((stderr, "inflate:         distance %u\n", state->offset));
575 
576             /* copy match from window to output */
577             do {
578                 ROOM();
579                 copy = state->wsize - state->offset;
580                 if (copy < left) {
581                     from = put + copy;
582                     copy = left - copy;
583                 }
584                 else {
585                     from = put - state->offset;
586                     copy = left;
587                 }
588                 if (copy > state->length) copy = state->length;
589                 state->length -= copy;
590                 left -= copy;
591                 do {
592                     *put++ = *from++;
593                 } while (--copy);
594             } while (state->length != 0);
595             break;
596 
597         case DONE:
598             /* inflate stream terminated properly -- write leftover output */
599             ret = Z_STREAM_END;
600             if (left < state->wsize) {
601                 if (out(out_desc, state->window, state->wsize - left))
602                     ret = Z_BUF_ERROR;
603             }
604             goto inf_leave;
605 
606         case BAD:
607             ret = Z_DATA_ERROR;
608             goto inf_leave;
609 
610         default:                /* can't happen, but makes compilers happy */
611             ret = Z_STREAM_ERROR;
612             goto inf_leave;
613         }
614 
615     /* Return unused input */
616   inf_leave:
617     strm->next_in = next;
618     strm->avail_in = have;
619     return ret;
620 }
621 
622 int ZEXPORT inflateBackEnd(z_streamp strm)
623 {
624     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
625         return Z_STREAM_ERROR;
626     ZFREE(strm, strm->state);
627     strm->state = Z_NULL;
628     Tracev((stderr, "inflate: end\n"));
629     return Z_OK;
630 }
631