xref: /freebsd/sys/contrib/zlib/infback.c (revision 1dfcff294e44d4b45813288ef4095c36abb22f0e)
1 /* infback.c -- inflate using a call-back interface
2  * Copyright (C) 1995-2022 Mark Adler
3  * For conditions of distribution and use, see copyright notice in zlib.h
4  */
5 
6 /*
7    This code is largely copied from inflate.c.  Normally either infback.o or
8    inflate.o would be linked into an application--not both.  The interface
9    with inffast.c is retained so that optimized assembler-coded versions of
10    inflate_fast() can be used with either inflate.c or infback.c.
11  */
12 
13 #include "zutil.h"
14 #include "inftrees.h"
15 #include "inflate.h"
16 #include "inffast.h"
17 
18 /* function prototypes */
19 local void fixedtables OF((struct inflate_state FAR *state));
20 
21 /*
22    strm provides memory allocation functions in zalloc and zfree, or
23    Z_NULL to use the library memory allocation functions.
24 
25    windowBits is in the range 8..15, and window is a user-supplied
26    window and output buffer that is 2**windowBits bytes.
27  */
28 int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
29 z_streamp strm;
30 int windowBits;
31 unsigned char FAR *window;
32 const char *version;
33 int stream_size;
34 {
35     struct inflate_state FAR *state;
36 
37     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
38         stream_size != (int)(sizeof(z_stream)))
39         return Z_VERSION_ERROR;
40     if (strm == Z_NULL || window == Z_NULL ||
41         windowBits < 8 || windowBits > 15)
42         return Z_STREAM_ERROR;
43     strm->msg = Z_NULL;                 /* in case we return an error */
44     if (strm->zalloc == (alloc_func)0) {
45 #if defined(Z_SOLO) && !defined(_KERNEL)
46         return Z_STREAM_ERROR;
47 #else
48         strm->zalloc = zcalloc;
49         strm->opaque = (voidpf)0;
50 #endif
51     }
52     if (strm->zfree == (free_func)0)
53 #if defined(Z_SOLO) && !defined(_KERNEL)
54         return Z_STREAM_ERROR;
55 #else
56     strm->zfree = zcfree;
57 #endif
58     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
59                                                sizeof(struct inflate_state));
60     if (state == Z_NULL) return Z_MEM_ERROR;
61     Tracev((stderr, "inflate: allocated\n"));
62     strm->state = (struct internal_state FAR *)state;
63     state->dmax = 32768U;
64     state->wbits = (uInt)windowBits;
65     state->wsize = 1U << windowBits;
66     state->window = window;
67     state->wnext = 0;
68     state->whave = 0;
69     return Z_OK;
70 }
71 
72 /*
73    Return state with length and distance decoding tables and index sizes set to
74    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
75    If BUILDFIXED is defined, then instead this routine builds the tables the
76    first time it's called, and returns those tables the first time and
77    thereafter.  This reduces the size of the code by about 2K bytes, in
78    exchange for a little execution time.  However, BUILDFIXED should not be
79    used for threaded applications, since the rewriting of the tables and virgin
80    may not be thread-safe.
81  */
82 local void fixedtables(state)
83 struct inflate_state FAR *state;
84 {
85 #ifdef BUILDFIXED
86     static int virgin = 1;
87     static code *lenfix, *distfix;
88     static code fixed[544];
89 
90     /* build fixed huffman tables if first call (may not be thread safe) */
91     if (virgin) {
92         unsigned sym, bits;
93         static code *next;
94 
95         /* literal/length table */
96         sym = 0;
97         while (sym < 144) state->lens[sym++] = 8;
98         while (sym < 256) state->lens[sym++] = 9;
99         while (sym < 280) state->lens[sym++] = 7;
100         while (sym < 288) state->lens[sym++] = 8;
101         next = fixed;
102         lenfix = next;
103         bits = 9;
104         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
105 
106         /* distance table */
107         sym = 0;
108         while (sym < 32) state->lens[sym++] = 5;
109         distfix = next;
110         bits = 5;
111         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
112 
113         /* do this just once */
114         virgin = 0;
115     }
116 #else /* !BUILDFIXED */
117 #   include "inffixed.h"
118 #endif /* BUILDFIXED */
119     state->lencode = lenfix;
120     state->lenbits = 9;
121     state->distcode = distfix;
122     state->distbits = 5;
123 }
124 
/*
   Macros for inflateBack().  They read and write inflateBack()'s locals
   (strm, state, next, have, put, left, hold, bits, ret) and its in/in_desc/
   out/out_desc arguments, and PULL/PULLBYTE/NEEDBITS/ROOM may jump to its
   inf_leave label, so they are only meaningful inside that function.
 */

/* Reload the local registers after inflate_fast() has updated strm/state */
#define LOAD() \
    do { \
        hold = state->hold; \
        bits = state->bits; \
        next = strm->next_in; \
        have = strm->avail_in; \
        put = strm->next_out; \
        left = strm->avail_out; \
    } while (0)

/* Publish the local registers into strm/state before calling inflate_fast() */
#define RESTORE() \
    do { \
        state->hold = hold; \
        state->bits = bits; \
        strm->next_in = next; \
        strm->avail_in = have; \
        strm->next_out = put; \
        strm->avail_out = left; \
    } while (0)

/* Empty the bit accumulator */
#define INITBITS() \
    do { \
        bits = 0; \
        hold = 0; \
    } while (0)

/* Make sure at least one input byte is buffered, asking in() for more when
   none is left.  If in() supplies nothing, next is set to Z_NULL (so the
   caller can tell that in() failed) and inflateBack() leaves with
   Z_BUF_ERROR. */
#define PULL() \
    do { \
        if (have == 0 && (have = in(in_desc, &next)) == 0) { \
            next = Z_NULL; \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)

/* Append one input byte above the bits already held in the accumulator,
   pulling more input first if needed. */
#define PULLBYTE() \
    do { \
        PULL(); \
        hold += (unsigned long)(*next) << bits; \
        bits += 8; \
        next++; \
        have--; \
    } while (0)

/* Pull whole bytes until the accumulator holds at least n bits */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) { \
            PULLBYTE(); \
        } \
    } while (0)

/* Value of the low n bits of the accumulator (n < 16) */
#define BITS(n) \
    ((unsigned)hold & ((1U << (n)) - 1))

/* Discard the low n bits of the accumulator */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* Discard the 0..7 bits left over before the next byte boundary */
#define BYTEBITS() DROPBITS(bits & 7)

/* Make sure there is room for output.  When the window is full, mark it as
   holding a full window of history, hand all of it to out(), and start
   refilling it from the beginning.  If out() reports failure, leave
   inflateBack() with Z_BUF_ERROR. */
#define ROOM() \
    do { \
        if (left == 0) { \
            state->whave = state->wsize; \
            put = state->window; \
            left = state->wsize; \
            if (out(out_desc, put, left) != 0) { \
                ret = Z_BUF_ERROR; \
                goto inf_leave; \
            } \
        } \
    } while (0)
222 
223 /*
224    strm provides the memory allocation functions and window buffer on input,
225    and provides information on the unused input on return.  For Z_DATA_ERROR
226    returns, strm will also provide an error message.
227 
228    in() and out() are the call-back input and output functions.  When
229    inflateBack() needs more input, it calls in().  When inflateBack() has
230    filled the window with output, or when it completes with data in the
231    window, it calls out() to write out the data.  The application must not
232    change the provided input until in() is called again or inflateBack()
233    returns.  The application must not change the window/output buffer until
234    inflateBack() returns.
235 
236    in() and out() are called with a descriptor parameter provided in the
237    inflateBack() call.  This parameter can be a structure that provides the
238    information required to do the read or write, as well as accumulated
239    information on the input and output such as totals and check values.
240 
241    in() should return zero on failure.  out() should return non-zero on
242    failure.  If either in() or out() fails, than inflateBack() returns a
243    Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
244    was in() or out() that caused in the error.  Otherwise,  inflateBack()
245    returns Z_STREAM_END on success, Z_DATA_ERROR for an deflate format
246    error, or Z_MEM_ERROR if it could not allocate memory for the state.
247    inflateBack() can also return Z_STREAM_ERROR if the input parameters
248    are not correct, i.e. strm is Z_NULL or the state was not initialized.
249  */
250 int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
251 z_streamp strm;
252 in_func in;
253 void FAR *in_desc;
254 out_func out;
255 void FAR *out_desc;
256 {
257     struct inflate_state FAR *state;
258     z_const unsigned char FAR *next;    /* next input */
259     unsigned char FAR *put;     /* next output */
260     unsigned have, left;        /* available input and output */
261     unsigned long hold;         /* bit buffer */
262     unsigned bits;              /* bits in bit buffer */
263     unsigned copy;              /* number of stored or match bytes to copy */
264     unsigned char FAR *from;    /* where to copy match bytes from */
265     code here;                  /* current decoding table entry */
266     code last;                  /* parent table entry */
267     unsigned len;               /* length to copy for repeats, bits to drop */
268     int ret;                    /* return code */
269     static const unsigned short order[19] = /* permutation of code lengths */
270         {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
271 
272     /* Check that the strm exists and that the state was initialized */
273     if (strm == Z_NULL || strm->state == Z_NULL)
274         return Z_STREAM_ERROR;
275     state = (struct inflate_state FAR *)strm->state;
276 
277     /* Reset the state */
278     strm->msg = Z_NULL;
279     state->mode = TYPE;
280     state->last = 0;
281     state->whave = 0;
282     next = strm->next_in;
283     have = next != Z_NULL ? strm->avail_in : 0;
284     hold = 0;
285     bits = 0;
286     put = state->window;
287     left = state->wsize;
288 
289     /* Inflate until end of block marked as last */
290     for (;;)
291         switch (state->mode) {
292         case TYPE:
293             /* determine and dispatch block type */
294             if (state->last) {
295                 BYTEBITS();
296                 state->mode = DONE;
297                 break;
298             }
299             NEEDBITS(3);
300             state->last = BITS(1);
301             DROPBITS(1);
302             switch (BITS(2)) {
303             case 0:                             /* stored block */
304                 Tracev((stderr, "inflate:     stored block%s\n",
305                         state->last ? " (last)" : ""));
306                 state->mode = STORED;
307                 break;
308             case 1:                             /* fixed block */
309                 fixedtables(state);
310                 Tracev((stderr, "inflate:     fixed codes block%s\n",
311                         state->last ? " (last)" : ""));
312                 state->mode = LEN;              /* decode codes */
313                 break;
314             case 2:                             /* dynamic block */
315                 Tracev((stderr, "inflate:     dynamic codes block%s\n",
316                         state->last ? " (last)" : ""));
317                 state->mode = TABLE;
318                 break;
319             case 3:
320                 strm->msg = (char *)"invalid block type";
321                 state->mode = BAD;
322             }
323             DROPBITS(2);
324             break;
325 
326         case STORED:
327             /* get and verify stored block length */
328             BYTEBITS();                         /* go to byte boundary */
329             NEEDBITS(32);
330             if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
331                 strm->msg = (char *)"invalid stored block lengths";
332                 state->mode = BAD;
333                 break;
334             }
335             state->length = (unsigned)hold & 0xffff;
336             Tracev((stderr, "inflate:       stored length %u\n",
337                     state->length));
338             INITBITS();
339 
340             /* copy stored block from input to output */
341             while (state->length != 0) {
342                 copy = state->length;
343                 PULL();
344                 ROOM();
345                 if (copy > have) copy = have;
346                 if (copy > left) copy = left;
347                 zmemcpy(put, next, copy);
348                 have -= copy;
349                 next += copy;
350                 left -= copy;
351                 put += copy;
352                 state->length -= copy;
353             }
354             Tracev((stderr, "inflate:       stored end\n"));
355             state->mode = TYPE;
356             break;
357 
358         case TABLE:
359             /* get dynamic table entries descriptor */
360             NEEDBITS(14);
361             state->nlen = BITS(5) + 257;
362             DROPBITS(5);
363             state->ndist = BITS(5) + 1;
364             DROPBITS(5);
365             state->ncode = BITS(4) + 4;
366             DROPBITS(4);
367 #ifndef PKZIP_BUG_WORKAROUND
368             if (state->nlen > 286 || state->ndist > 30) {
369                 strm->msg = (char *)"too many length or distance symbols";
370                 state->mode = BAD;
371                 break;
372             }
373 #endif
374             Tracev((stderr, "inflate:       table sizes ok\n"));
375 
376             /* get code length code lengths (not a typo) */
377             state->have = 0;
378             while (state->have < state->ncode) {
379                 NEEDBITS(3);
380                 state->lens[order[state->have++]] = (unsigned short)BITS(3);
381                 DROPBITS(3);
382             }
383             while (state->have < 19)
384                 state->lens[order[state->have++]] = 0;
385             state->next = state->codes;
386             state->lencode = (code const FAR *)(state->next);
387             state->lenbits = 7;
388             ret = inflate_table(CODES, state->lens, 19, &(state->next),
389                                 &(state->lenbits), state->work);
390             if (ret) {
391                 strm->msg = (char *)"invalid code lengths set";
392                 state->mode = BAD;
393                 break;
394             }
395             Tracev((stderr, "inflate:       code lengths ok\n"));
396 
397             /* get length and distance code code lengths */
398             state->have = 0;
399             while (state->have < state->nlen + state->ndist) {
400                 for (;;) {
401                     here = state->lencode[BITS(state->lenbits)];
402                     if ((unsigned)(here.bits) <= bits) break;
403                     PULLBYTE();
404                 }
405                 if (here.val < 16) {
406                     DROPBITS(here.bits);
407                     state->lens[state->have++] = here.val;
408                 }
409                 else {
410                     if (here.val == 16) {
411                         NEEDBITS(here.bits + 2);
412                         DROPBITS(here.bits);
413                         if (state->have == 0) {
414                             strm->msg = (char *)"invalid bit length repeat";
415                             state->mode = BAD;
416                             break;
417                         }
418                         len = (unsigned)(state->lens[state->have - 1]);
419                         copy = 3 + BITS(2);
420                         DROPBITS(2);
421                     }
422                     else if (here.val == 17) {
423                         NEEDBITS(here.bits + 3);
424                         DROPBITS(here.bits);
425                         len = 0;
426                         copy = 3 + BITS(3);
427                         DROPBITS(3);
428                     }
429                     else {
430                         NEEDBITS(here.bits + 7);
431                         DROPBITS(here.bits);
432                         len = 0;
433                         copy = 11 + BITS(7);
434                         DROPBITS(7);
435                     }
436                     if (state->have + copy > state->nlen + state->ndist) {
437                         strm->msg = (char *)"invalid bit length repeat";
438                         state->mode = BAD;
439                         break;
440                     }
441                     while (copy--)
442                         state->lens[state->have++] = (unsigned short)len;
443                 }
444             }
445 
446             /* handle error breaks in while */
447             if (state->mode == BAD) break;
448 
449             /* check for end-of-block code (better have one) */
450             if (state->lens[256] == 0) {
451                 strm->msg = (char *)"invalid code -- missing end-of-block";
452                 state->mode = BAD;
453                 break;
454             }
455 
456             /* build code tables -- note: do not change the lenbits or distbits
457                values here (9 and 6) without reading the comments in inftrees.h
458                concerning the ENOUGH constants, which depend on those values */
459             state->next = state->codes;
460             state->lencode = (code const FAR *)(state->next);
461             state->lenbits = 9;
462             ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
463                                 &(state->lenbits), state->work);
464             if (ret) {
465                 strm->msg = (char *)"invalid literal/lengths set";
466                 state->mode = BAD;
467                 break;
468             }
469             state->distcode = (code const FAR *)(state->next);
470             state->distbits = 6;
471             ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
472                             &(state->next), &(state->distbits), state->work);
473             if (ret) {
474                 strm->msg = (char *)"invalid distances set";
475                 state->mode = BAD;
476                 break;
477             }
478             Tracev((stderr, "inflate:       codes ok\n"));
479             state->mode = LEN;
480                 /* fallthrough */
481 
482         case LEN:
483             /* use inflate_fast() if we have enough input and output */
484             if (have >= 6 && left >= 258) {
485                 RESTORE();
486                 if (state->whave < state->wsize)
487                     state->whave = state->wsize - left;
488                 inflate_fast(strm, state->wsize);
489                 LOAD();
490                 break;
491             }
492 
493             /* get a literal, length, or end-of-block code */
494             for (;;) {
495                 here = state->lencode[BITS(state->lenbits)];
496                 if ((unsigned)(here.bits) <= bits) break;
497                 PULLBYTE();
498             }
499             if (here.op && (here.op & 0xf0) == 0) {
500                 last = here;
501                 for (;;) {
502                     here = state->lencode[last.val +
503                             (BITS(last.bits + last.op) >> last.bits)];
504                     if ((unsigned)(last.bits + here.bits) <= bits) break;
505                     PULLBYTE();
506                 }
507                 DROPBITS(last.bits);
508             }
509             DROPBITS(here.bits);
510             state->length = (unsigned)here.val;
511 
512             /* process literal */
513             if (here.op == 0) {
514                 Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
515                         "inflate:         literal '%c'\n" :
516                         "inflate:         literal 0x%02x\n", here.val));
517                 ROOM();
518                 *put++ = (unsigned char)(state->length);
519                 left--;
520                 state->mode = LEN;
521                 break;
522             }
523 
524             /* process end of block */
525             if (here.op & 32) {
526                 Tracevv((stderr, "inflate:         end of block\n"));
527                 state->mode = TYPE;
528                 break;
529             }
530 
531             /* invalid code */
532             if (here.op & 64) {
533                 strm->msg = (char *)"invalid literal/length code";
534                 state->mode = BAD;
535                 break;
536             }
537 
538             /* length code -- get extra bits, if any */
539             state->extra = (unsigned)(here.op) & 15;
540             if (state->extra != 0) {
541                 NEEDBITS(state->extra);
542                 state->length += BITS(state->extra);
543                 DROPBITS(state->extra);
544             }
545             Tracevv((stderr, "inflate:         length %u\n", state->length));
546 
547             /* get distance code */
548             for (;;) {
549                 here = state->distcode[BITS(state->distbits)];
550                 if ((unsigned)(here.bits) <= bits) break;
551                 PULLBYTE();
552             }
553             if ((here.op & 0xf0) == 0) {
554                 last = here;
555                 for (;;) {
556                     here = state->distcode[last.val +
557                             (BITS(last.bits + last.op) >> last.bits)];
558                     if ((unsigned)(last.bits + here.bits) <= bits) break;
559                     PULLBYTE();
560                 }
561                 DROPBITS(last.bits);
562             }
563             DROPBITS(here.bits);
564             if (here.op & 64) {
565                 strm->msg = (char *)"invalid distance code";
566                 state->mode = BAD;
567                 break;
568             }
569             state->offset = (unsigned)here.val;
570 
571             /* get distance extra bits, if any */
572             state->extra = (unsigned)(here.op) & 15;
573             if (state->extra != 0) {
574                 NEEDBITS(state->extra);
575                 state->offset += BITS(state->extra);
576                 DROPBITS(state->extra);
577             }
578             if (state->offset > state->wsize - (state->whave < state->wsize ?
579                                                 left : 0)) {
580                 strm->msg = (char *)"invalid distance too far back";
581                 state->mode = BAD;
582                 break;
583             }
584             Tracevv((stderr, "inflate:         distance %u\n", state->offset));
585 
586             /* copy match from window to output */
587             do {
588                 ROOM();
589                 copy = state->wsize - state->offset;
590                 if (copy < left) {
591                     from = put + copy;
592                     copy = left - copy;
593                 }
594                 else {
595                     from = put - state->offset;
596                     copy = left;
597                 }
598                 if (copy > state->length) copy = state->length;
599                 state->length -= copy;
600                 left -= copy;
601                 do {
602                     *put++ = *from++;
603                 } while (--copy);
604             } while (state->length != 0);
605             break;
606 
607         case DONE:
608             /* inflate stream terminated properly -- write leftover output */
609             ret = Z_STREAM_END;
610             if (left < state->wsize) {
611                 if (out(out_desc, state->window, state->wsize - left))
612                     ret = Z_BUF_ERROR;
613             }
614             goto inf_leave;
615 
616         case BAD:
617             ret = Z_DATA_ERROR;
618             goto inf_leave;
619 
620         default:                /* can't happen, but makes compilers happy */
621             ret = Z_STREAM_ERROR;
622             goto inf_leave;
623         }
624 
625     /* Return unused input */
626   inf_leave:
627     strm->next_in = next;
628     strm->avail_in = have;
629     return ret;
630 }
631 
632 int ZEXPORT inflateBackEnd(strm)
633 z_streamp strm;
634 {
635     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
636         return Z_STREAM_ERROR;
637     ZFREE(strm, strm->state);
638     strm->state = Z_NULL;
639     Tracev((stderr, "inflate: end\n"));
640     return Z_OK;
641 }
642