xref: /freebsd/sys/contrib/zlib/infback.c (revision 7ef62cebc2f965b0f640263e179276928885e33d)
1 /* infback.c -- inflate using a call-back interface
2  * Copyright (C) 1995-2022 Mark Adler
3  * For conditions of distribution and use, see copyright notice in zlib.h
4  */
5 
6 /*
7    This code is largely copied from inflate.c.  Normally either infback.o or
8    inflate.o would be linked into an application--not both.  The interface
9    with inffast.c is retained so that optimized assembler-coded versions of
10    inflate_fast() can be used with either inflate.c or infback.c.
11  */
12 
13 #include "zutil.h"
14 #include "inftrees.h"
15 #include "inflate.h"
16 #include "inffast.h"
17 
18 /* function prototypes */
19 local void fixedtables OF((struct inflate_state FAR *state));
20 
21 /*
22    strm provides memory allocation functions in zalloc and zfree, or
23    Z_NULL to use the library memory allocation functions.
24 
25    windowBits is in the range 8..15, and window is a user-supplied
26    window and output buffer that is 2**windowBits bytes.
27  */
28 int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
29 z_streamp strm;
30 int windowBits;
31 unsigned char FAR *window;
32 const char *version;
33 int stream_size;
34 {
35     struct inflate_state FAR *state;
36 
37     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
38         stream_size != (int)(sizeof(z_stream)))
39         return Z_VERSION_ERROR;
40     if (strm == Z_NULL || window == Z_NULL ||
41         windowBits < 8 || windowBits > 15)
42         return Z_STREAM_ERROR;
43     strm->msg = Z_NULL;                 /* in case we return an error */
44     if (strm->zalloc == (alloc_func)0) {
45 #if defined(Z_SOLO) && !defined(_KERNEL)
46         return Z_STREAM_ERROR;
47 #else
48         strm->zalloc = zcalloc;
49         strm->opaque = (voidpf)0;
50 #endif
51     }
52     if (strm->zfree == (free_func)0)
53 #if defined(Z_SOLO) && !defined(_KERNEL)
54         return Z_STREAM_ERROR;
55 #else
56     strm->zfree = zcfree;
57 #endif
58     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
59                                                sizeof(struct inflate_state));
60     if (state == Z_NULL) return Z_MEM_ERROR;
61     Tracev((stderr, "inflate: allocated\n"));
62     strm->state = (struct internal_state FAR *)state;
63     state->dmax = 32768U;
64     state->wbits = (uInt)windowBits;
65     state->wsize = 1U << windowBits;
66     state->window = window;
67     state->wnext = 0;
68     state->whave = 0;
69     state->sane = 1;
70     return Z_OK;
71 }
72 
73 /*
74    Return state with length and distance decoding tables and index sizes set to
75    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
76    If BUILDFIXED is defined, then instead this routine builds the tables the
77    first time it's called, and returns those tables the first time and
78    thereafter.  This reduces the size of the code by about 2K bytes, in
79    exchange for a little execution time.  However, BUILDFIXED should not be
80    used for threaded applications, since the rewriting of the tables and virgin
81    may not be thread-safe.
82  */
83 local void fixedtables(state)
84 struct inflate_state FAR *state;
85 {
86 #ifdef BUILDFIXED
87     static int virgin = 1;
88     static code *lenfix, *distfix;
89     static code fixed[544];
90 
91     /* build fixed huffman tables if first call (may not be thread safe) */
92     if (virgin) {
93         unsigned sym, bits;
94         static code *next;
95 
96         /* literal/length table */
97         sym = 0;
98         while (sym < 144) state->lens[sym++] = 8;
99         while (sym < 256) state->lens[sym++] = 9;
100         while (sym < 280) state->lens[sym++] = 7;
101         while (sym < 288) state->lens[sym++] = 8;
102         next = fixed;
103         lenfix = next;
104         bits = 9;
105         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
106 
107         /* distance table */
108         sym = 0;
109         while (sym < 32) state->lens[sym++] = 5;
110         distfix = next;
111         bits = 5;
112         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
113 
114         /* do this just once */
115         virgin = 0;
116     }
117 #else /* !BUILDFIXED */
118 #   include "inffixed.h"
119 #endif /* BUILDFIXED */
120     state->lencode = lenfix;
121     state->lenbits = 9;
122     state->distcode = distfix;
123     state->distbits = 5;
124 }
125 
/* Macros for inflateBack().  They read and write that function's local
   variables (strm, state, next, have, put, left, hold, bits, ret, in,
   in_desc, out, out_desc), and some jump to its inf_leave label.  They are
   therefore only usable inside inflateBack(). */

/* Reload the local registers from strm/state after inflate_fast() ran */
#define LOAD() \
    do { \
        hold = state->hold; \
        bits = state->bits; \
        next = strm->next_in; \
        have = strm->avail_in; \
        put = strm->next_out; \
        left = strm->avail_out; \
    } while (0)

/* Publish the local registers to strm/state for inflate_fast() */
#define RESTORE() \
    do { \
        state->hold = hold; \
        state->bits = bits; \
        strm->next_in = next; \
        strm->avail_in = have; \
        strm->next_out = put; \
        strm->avail_out = left; \
    } while (0)

/* Empty the bit accumulator */
#define INITBITS() \
    do { \
        bits = 0; \
        hold = 0; \
    } while (0)

/* Make sure have > 0.  When the current input is used up, ask in() for
   more.  If in() delivers nothing, set next to Z_NULL and leave
   inflateBack() with Z_BUF_ERROR.  Clearing next lets the caller tell an
   input failure from an output failure. */
#define PULL() \
    do { \
        if (have == 0 && (have = in(in_desc, &next)) == 0) { \
            next = Z_NULL; \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)

/* Place the next input byte above the bits already in the accumulator.
   Leaves through PULL() if no input can be obtained. */
#define PULLBYTE() \
    do { \
        PULL(); \
        hold += (unsigned long)*next << bits; \
        next++; \
        have--; \
        bits += 8; \
    } while (0)

/* Pull bytes until the accumulator holds at least n bits.  Leaves through
   PULL() if input runs out first.  n is re-evaluated on every pass. */
#define NEEDBITS(n) \
    do { \
        while (bits < (unsigned)(n)) { \
            PULLBYTE(); \
        } \
    } while (0)

/* Value of the low n bits of the accumulator (n < 16) */
#define BITS(n) \
    ((unsigned)hold & ((1U << (n)) - 1U))

/* Discard the low n bits of the accumulator */
#define DROPBITS(n) \
    do { \
        hold >>= (n); \
        bits -= (unsigned)(n); \
    } while (0)

/* Discard the 0..7 low bits needed to realign the accumulator to a byte
   boundary (bits becomes a multiple of 8) */
#define BYTEBITS() \
    do { \
        hold >>= bits & 7; \
        bits &= ~7U; \
    } while (0)

/* Make sure at least one output byte fits.  When the window is full, it is
   passed to out() in its entirety and output restarts at its front.  From
   then on the whole window holds valid history (whave = wsize).  If out()
   reports failure, leave inflateBack() with Z_BUF_ERROR. */
#define ROOM() \
    do { \
        if (left != 0) \
            break; \
        put = state->window; \
        left = state->wsize; \
        state->whave = left; \
        if (out(out_desc, put, left)) { \
            ret = Z_BUF_ERROR; \
            goto inf_leave; \
        } \
    } while (0)
222     } while (0)
223 
/*
   Decode one complete deflate stream using call-back input and output.

   strm provides the window buffer (set up by inflateBackInit_()) on input.
   On return, it provides information on the unused input.  For Z_DATA_ERROR
   returns, strm->msg also holds an error message.

   in() and out() are the call-back input and output functions.  When
   inflateBack() needs more input, it calls in().  When inflateBack() has
   filled the window with output, or when it completes with data in the
   window, it calls out() to write out the data.  The application must not
   change the provided input until in() is called again or inflateBack()
   returns.  The application must not change the window/output buffer until
   inflateBack() returns.

   in() and out() are called with a descriptor parameter provided in the
   inflateBack() call.  This parameter can be a structure that provides the
   information required to do the read or write, as well as accumulated
   information on the input and output such as totals and check values.

   in() should return zero on failure.  out() should return non-zero on
   failure.  If either in() or out() fails, then inflateBack() returns
   Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
   was in() or out() that caused the error.  Any output still pending in the
   window is passed to out() before returning, even after an in() failure or
   a data error.

   Otherwise, inflateBack() returns Z_STREAM_END on success or Z_DATA_ERROR
   for a deflate format error.  It returns Z_STREAM_ERROR if strm is Z_NULL
   or the state was not initialized.  This function allocates no memory and
   never returns Z_MEM_ERROR; that code can only come from
   inflateBackInit_().
 */
int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
z_streamp strm;
in_func in;
void FAR *in_desc;
out_func out;
void FAR *out_desc;
{
    struct inflate_state FAR *state;
    z_const unsigned char FAR *next;    /* next input */
    unsigned char FAR *put;     /* next output */
    unsigned have, left;        /* available input and output */
    unsigned long hold;         /* bit buffer */
    unsigned bits;              /* bits in bit buffer */
    unsigned copy;              /* number of stored or match bytes to copy */
    unsigned char FAR *from;    /* where to copy match bytes from */
    code here;                  /* current decoding table entry */
    code last;                  /* parent table entry */
    unsigned len;               /* length to copy for repeats, bits to drop */
    int ret;                    /* return code */
    static const unsigned short order[19] = /* permutation of code lengths */
        {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};

    /* Check that the strm exists and that the state was initialized */
    if (strm == Z_NULL || strm->state == Z_NULL)
        return Z_STREAM_ERROR;
    state = (struct inflate_state FAR *)strm->state;

    /* Reset the state.  Each call starts at the first block header with an
       empty bit buffer and an empty window.  A Z_NULL next_in counts as
       "no input yet", so in() supplies the first bytes. */
    strm->msg = Z_NULL;
    state->mode = TYPE;
    state->last = 0;
    state->whave = 0;
    next = strm->next_in;
    have = next != Z_NULL ? strm->avail_in : 0;
    hold = 0;
    bits = 0;
    put = state->window;
    left = state->wsize;

    /* Inflate until end of block marked as last */
    for (;;)
        switch (state->mode) {
        case TYPE:
            /* determine and dispatch block type */
            if (state->last) {
                /* previous block was the last one: drop the padding bits up
                   to the next byte boundary and finish */
                BYTEBITS();
                state->mode = DONE;
                break;
            }
            NEEDBITS(3);
            state->last = BITS(1);
            DROPBITS(1);
            switch (BITS(2)) {
            case 0:                             /* stored block */
                Tracev((stderr, "inflate:     stored block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = STORED;
                break;
            case 1:                             /* fixed block */
                fixedtables(state);
                Tracev((stderr, "inflate:     fixed codes block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = LEN;              /* decode codes */
                break;
            case 2:                             /* dynamic block */
                Tracev((stderr, "inflate:     dynamic codes block%s\n",
                        state->last ? " (last)" : ""));
                state->mode = TABLE;
                break;
            case 3:                             /* not a valid type */
                strm->msg = (char *)"invalid block type";
                state->mode = BAD;
            }
            DROPBITS(2);
            break;

        case STORED:
            /* get and verify stored block length */
            BYTEBITS();                         /* go to byte boundary */
            NEEDBITS(32);
            /* the 16-bit length must equal the one's complement of the
               16 bits that follow it */
            if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
                strm->msg = (char *)"invalid stored block lengths";
                state->mode = BAD;
                break;
            }
            state->length = (unsigned)hold & 0xffff;
            Tracev((stderr, "inflate:       stored length %u\n",
                    state->length));
            /* empty the bit buffer: the block data is copied byte-wise
               straight from the input below */
            INITBITS();

            /* copy stored block from input to output */
            while (state->length != 0) {
                copy = state->length;
                PULL();
                ROOM();
                if (copy > have) copy = have;
                if (copy > left) copy = left;
                zmemcpy(put, next, copy);
                have -= copy;
                next += copy;
                left -= copy;
                put += copy;
                state->length -= copy;
            }
            Tracev((stderr, "inflate:       stored end\n"));
            state->mode = TYPE;
            break;

        case TABLE:
            /* get dynamic table entries descriptor */
            NEEDBITS(14);
            state->nlen = BITS(5) + 257;
            DROPBITS(5);
            state->ndist = BITS(5) + 1;
            DROPBITS(5);
            state->ncode = BITS(4) + 4;
            DROPBITS(4);
#ifndef PKZIP_BUG_WORKAROUND
            if (state->nlen > 286 || state->ndist > 30) {
                strm->msg = (char *)"too many length or distance symbols";
                state->mode = BAD;
                break;
            }
#endif
            Tracev((stderr, "inflate:       table sizes ok\n"));

            /* get code length code lengths (not a typo).  They arrive in
               the permuted order given by order[]; lengths not sent are
               zero. */
            state->have = 0;
            while (state->have < state->ncode) {
                NEEDBITS(3);
                state->lens[order[state->have++]] = (unsigned short)BITS(3);
                DROPBITS(3);
            }
            while (state->have < 19)
                state->lens[order[state->have++]] = 0;
            state->next = state->codes;
            state->lencode = (code const FAR *)(state->next);
            state->lenbits = 7;
            ret = inflate_table(CODES, state->lens, 19, &(state->next),
                                &(state->lenbits), state->work);
            if (ret) {
                strm->msg = (char *)"invalid code lengths set";
                state->mode = BAD;
                break;
            }
            Tracev((stderr, "inflate:       code lengths ok\n"));

            /* get length and distance code code lengths, decoded with the
               19-symbol code-length code just built.  Symbols 0..15 are
               literal lengths.  16 repeats the previous length 3..6 times,
               17 writes 3..10 zeros, and the remaining symbol (18) writes
               11..138 zeros. */
            state->have = 0;
            while (state->have < state->nlen + state->ndist) {
                for (;;) {
                    here = state->lencode[BITS(state->lenbits)];
                    if ((unsigned)(here.bits) <= bits) break;
                    PULLBYTE();
                }
                if (here.val < 16) {
                    DROPBITS(here.bits);
                    state->lens[state->have++] = here.val;
                }
                else {
                    if (here.val == 16) {
                        NEEDBITS(here.bits + 2);
                        DROPBITS(here.bits);
                        if (state->have == 0) {
                            /* nothing to repeat yet */
                            strm->msg = (char *)"invalid bit length repeat";
                            state->mode = BAD;
                            break;
                        }
                        len = (unsigned)(state->lens[state->have - 1]);
                        copy = 3 + BITS(2);
                        DROPBITS(2);
                    }
                    else if (here.val == 17) {
                        NEEDBITS(here.bits + 3);
                        DROPBITS(here.bits);
                        len = 0;
                        copy = 3 + BITS(3);
                        DROPBITS(3);
                    }
                    else {
                        NEEDBITS(here.bits + 7);
                        DROPBITS(here.bits);
                        len = 0;
                        copy = 11 + BITS(7);
                        DROPBITS(7);
                    }
                    /* a repeat must not run past the declared symbol count */
                    if (state->have + copy > state->nlen + state->ndist) {
                        strm->msg = (char *)"invalid bit length repeat";
                        state->mode = BAD;
                        break;
                    }
                    while (copy--)
                        state->lens[state->have++] = (unsigned short)len;
                }
            }

            /* handle error breaks in while: the breaks above leave only the
               while loop, not the switch, so the BAD mode is checked here */
            if (state->mode == BAD) break;

            /* check for end-of-block code (better have one) */
            if (state->lens[256] == 0) {
                strm->msg = (char *)"invalid code -- missing end-of-block";
                state->mode = BAD;
                break;
            }

            /* build code tables -- note: do not change the lenbits or distbits
               values here (9 and 6) without reading the comments in inftrees.h
               concerning the ENOUGH constants, which depend on those values */
            state->next = state->codes;
            state->lencode = (code const FAR *)(state->next);
            state->lenbits = 9;
            ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
                                &(state->lenbits), state->work);
            if (ret) {
                strm->msg = (char *)"invalid literal/lengths set";
                state->mode = BAD;
                break;
            }
            state->distcode = (code const FAR *)(state->next);
            state->distbits = 6;
            ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
                            &(state->next), &(state->distbits), state->work);
            if (ret) {
                strm->msg = (char *)"invalid distances set";
                state->mode = BAD;
                break;
            }
            Tracev((stderr, "inflate:       codes ok\n"));
            state->mode = LEN;
            /* fallthrough */

        case LEN:
            /* use inflate_fast() if we have enough input and output.
               NOTE(review): 6 input bytes and 258 output bytes presumably
               cover the worst case of one inflate_fast() step -- confirm
               against inffast.c. */
            if (have >= 6 && left >= 258) {
                RESTORE();
                /* until the window has been flushed once, only the bytes
                   written so far are valid history */
                if (state->whave < state->wsize)
                    state->whave = state->wsize - left;
                inflate_fast(strm, state->wsize);
                LOAD();
                break;
            }

            /* get a literal, length, or end-of-block code */
            for (;;) {
                here = state->lencode[BITS(state->lenbits)];
                if ((unsigned)(here.bits) <= bits) break;
                PULLBYTE();
            }
            /* a nonzero op with a zero high nibble links to a sub-table
               at last.val, indexed by the next last.op bits */
            if (here.op && (here.op & 0xf0) == 0) {
                last = here;
                for (;;) {
                    here = state->lencode[last.val +
                            (BITS(last.bits + last.op) >> last.bits)];
                    if ((unsigned)(last.bits + here.bits) <= bits) break;
                    PULLBYTE();
                }
                DROPBITS(last.bits);
            }
            DROPBITS(here.bits);
            state->length = (unsigned)here.val;

            /* process literal */
            if (here.op == 0) {
                Tracevv((stderr, here.val >= 0x20 && here.val < 0x7f ?
                        "inflate:         literal '%c'\n" :
                        "inflate:         literal 0x%02x\n", here.val));
                ROOM();
                *put++ = (unsigned char)(state->length);
                left--;
                state->mode = LEN;
                break;
            }

            /* process end of block */
            if (here.op & 32) {
                Tracevv((stderr, "inflate:         end of block\n"));
                state->mode = TYPE;
                break;
            }

            /* invalid code */
            if (here.op & 64) {
                strm->msg = (char *)"invalid literal/length code";
                state->mode = BAD;
                break;
            }

            /* length code -- get extra bits, if any */
            state->extra = (unsigned)(here.op) & 15;
            if (state->extra != 0) {
                NEEDBITS(state->extra);
                state->length += BITS(state->extra);
                DROPBITS(state->extra);
            }
            Tracevv((stderr, "inflate:         length %u\n", state->length));

            /* get distance code (same two-level lookup as above) */
            for (;;) {
                here = state->distcode[BITS(state->distbits)];
                if ((unsigned)(here.bits) <= bits) break;
                PULLBYTE();
            }
            if ((here.op & 0xf0) == 0) {
                last = here;
                for (;;) {
                    here = state->distcode[last.val +
                            (BITS(last.bits + last.op) >> last.bits)];
                    if ((unsigned)(last.bits + here.bits) <= bits) break;
                    PULLBYTE();
                }
                DROPBITS(last.bits);
            }
            DROPBITS(here.bits);
            if (here.op & 64) {
                strm->msg = (char *)"invalid distance code";
                state->mode = BAD;
                break;
            }
            state->offset = (unsigned)here.val;

            /* get distance extra bits, if any */
            state->extra = (unsigned)(here.op) & 15;
            if (state->extra != 0) {
                NEEDBITS(state->extra);
                state->offset += BITS(state->extra);
                DROPBITS(state->extra);
            }
            /* before the window is first flushed (whave < wsize), a match
               may only reach back over the wsize - left bytes written so
               far; afterwards, over the whole window */
            if (state->offset > state->wsize - (state->whave < state->wsize ?
                                                left : 0)) {
                strm->msg = (char *)"invalid distance too far back";
                state->mode = BAD;
                break;
            }
            Tracevv((stderr, "inflate:         distance %u\n", state->offset));

            /* copy match from window to output */
            do {
                ROOM();
                copy = state->wsize - state->offset;
                if (copy < left) {
                    /* the source lies before the start of this window pass.
                       Its bytes are still at put + wsize - offset, in the
                       previous pass's data; copy only up to the window's
                       end. */
                    from = put + copy;
                    copy = left - copy;
                }
                else {
                    from = put - state->offset;
                    copy = left;
                }
                if (copy > state->length) copy = state->length;
                state->length -= copy;
                left -= copy;
                /* byte-by-byte on purpose: source and destination may
                   overlap when offset < length */
                do {
                    *put++ = *from++;
                } while (--copy);
            } while (state->length != 0);
            break;

        case DONE:
            /* inflate stream terminated properly */
            ret = Z_STREAM_END;
            goto inf_leave;

        case BAD:
            ret = Z_DATA_ERROR;
            goto inf_leave;

        default:
            /* can't happen, but makes compilers happy */
            ret = Z_STREAM_ERROR;
            goto inf_leave;
        }

    /* Write leftover output and return unused input.  A failed final out()
       turns an otherwise successful Z_STREAM_END into Z_BUF_ERROR; other
       return codes are kept. */
  inf_leave:
    if (left < state->wsize) {
        if (out(out_desc, state->window, state->wsize - left) &&
            ret == Z_STREAM_END)
            ret = Z_BUF_ERROR;
    }
    strm->next_in = next;
    strm->avail_in = have;
    return ret;
}
633 }
634 
635 int ZEXPORT inflateBackEnd(strm)
636 z_streamp strm;
637 {
638     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == (free_func)0)
639         return Z_STREAM_ERROR;
640     ZFREE(strm, strm->state);
641     strm->state = Z_NULL;
642     Tracev((stderr, "inflate: end\n"));
643     return Z_OK;
644 }
645