]> git.lizzy.rs Git - zlib.git/blob - infback.c
zlib 1.2.0
[zlib.git] / infback.c
1 /* infback.c -- inflate using a call-back interface
2  * Copyright (C) 1995-2003 Mark Adler
3  * For conditions of distribution and use, see copyright notice in zlib.h
4  */
5
6 /*
7    This code is largely copied from inflate.c.  Normally either infback.o or
8    inflate.o would be linked into an application--not both.  The interface
9    with inffast.c is retained so that optimized assembler-coded versions of
10    inflate_fast() can be used with either inflate.c or infback.c.
11  */
12
13 #include "zutil.h"
14 #include "inftrees.h"
15 #include "inflate.h"
16 #include "inffast.h"
17
18 /* function prototypes */
19 local void fixedtables OF((struct inflate_state FAR *state));
20
21 /*
22    strm provides memory allocation functions in zalloc and zfree, or
23    Z_NULL to use the library memory allocation functions.
24
25    windowBits is in the range 8..15, and window is a user-supplied
26    window and output buffer that is 2**windowBits bytes.
27  */
28 int ZEXPORT inflateBackInit_(strm, windowBits, window, version, stream_size)
29 z_stream FAR *strm;
30 int windowBits;
31 unsigned char FAR *window;
32 const char *version;
33 int stream_size;
34 {
35     struct inflate_state FAR *state;
36
37     if (version == Z_NULL || version[0] != ZLIB_VERSION[0] ||
38         stream_size != (int)(sizeof(z_stream)))
39         return Z_VERSION_ERROR;
40     if (strm == Z_NULL || window == Z_NULL ||
41         windowBits < 8 || windowBits > 15)
42         return Z_STREAM_ERROR;
43     strm->msg = Z_NULL;                 /* in case we return an error */
44     if (strm->zalloc == Z_NULL) {
45         strm->zalloc = zcalloc;
46         strm->opaque = (voidpf)0;
47     }
48     if (strm->zfree == Z_NULL) strm->zfree = zcfree;
49     state = (struct inflate_state FAR *)ZALLOC(strm, 1,
50                                                sizeof(struct inflate_state));
51     if (state == Z_NULL) return Z_MEM_ERROR;
52     Tracev((stderr, "inflate: allocated\n"));
53     strm->state = (voidpf)state;
54     state->wbits = windowBits;
55     state->wsize = 1U << windowBits;
56     state->window = window;
57     state->write = 0;
58     return Z_OK;
59 }
60
61 /*
62    Return state with length and distance decoding tables and index sizes set to
63    fixed code decoding.  Normally this returns fixed tables from inffixed.h.
64    If BUILDFIXED is defined, then instead this routine builds the tables the
65    first time it's called, and returns those tables the first time and
66    thereafter.  This reduces the size of the code by about 2K bytes, in
67    exchange for a little execution time.  However, BUILDFIXED should not be
68    used for threaded applications, since the rewriting of the tables and virgin
69    may not be thread-safe.
70  */
71 local void fixedtables(state)
72 struct inflate_state FAR *state;
73 {
74 #ifdef BUILDFIXED
75     static int virgin = 1;
76     static code *lenfix, *distfix;
77     static code fixed[544];
78
79     /* build fixed huffman tables if first call (may not be thread safe) */
80     if (virgin) {
81         unsigned sym, bits;
82         static code *next;
83
84         /* literal/length table */
85         sym = 0;
86         while (sym < 144) state->lens[sym++] = 8;
87         while (sym < 256) state->lens[sym++] = 9;
88         while (sym < 280) state->lens[sym++] = 7;
89         while (sym < 288) state->lens[sym++] = 8;
90         next = fixed;
91         lenfix = next;
92         bits = 9;
93         inflate_table(LENS, state->lens, 288, &(next), &(bits), state->work);
94
95         /* distance table */
96         sym = 0;
97         while (sym < 32) state->lens[sym++] = 5;
98         distfix = next;
99         bits = 5;
100         inflate_table(DISTS, state->lens, 32, &(next), &(bits), state->work);
101
102         /* do this just once */
103         virgin = 0;
104     }
105 #else /* !BUILDFIXED */
106 #   include "inffixed.h"
107 #endif /* BUILDFIXED */
108     state->lencode = lenfix;
109     state->lenbits = 9;
110     state->distcode = distfix;
111     state->distbits = 5;
112 }
113
114 /* Macros for inflateBack(): */
115
116 /* Load returned state from inflate_fast() */
117 #define LOAD() \
118     do { \
119         put = strm->next_out; \
120         left = strm->avail_out; \
121         next = strm->next_in; \
122         have = strm->avail_in; \
123         hold = state->hold; \
124         bits = state->bits; \
125     } while (0)
126
127 /* Set state from registers for inflate_fast() */
128 #define RESTORE() \
129     do { \
130         strm->next_out = put; \
131         strm->avail_out = left; \
132         strm->next_in = next; \
133         strm->avail_in = have; \
134         state->hold = hold; \
135         state->bits = bits; \
136     } while (0)
137
138 /* Clear the input bit accumulator */
139 #define INITBITS() \
140     do { \
141         hold = 0; \
142         bits = 0; \
143     } while (0)
144
145 /* Assure that some input is available.  If input is requested, but denied,
146    then return a Z_BUF_ERROR from inflateBack(). */
147 #define PULL() \
148     do { \
149         if (have == 0) { \
150             have = in(in_desc, &next); \
151             if (have == 0) { \
152                 next = Z_NULL; \
153                 ret = Z_BUF_ERROR; \
154                 goto leave; \
155             } \
156         } \
157     } while (0)
158
159 /* Get a byte of input into the bit accumulator, or return from inflateBack()
160    with an error if there is no input available. */
161 #define PULLBYTE() \
162     do { \
163         PULL(); \
164         have--; \
165         hold += (unsigned long)(*next++) << bits; \
166         bits += 8; \
167     } while (0)
168
169 /* Assure that there are at least n bits in the bit accumulator.  If there is
170    not enough available input to do that, then return from inflateBack() with
171    an error. */
172 #define NEEDBITS(n) \
173     do { \
174         while (bits < (unsigned)(n)) \
175             PULLBYTE(); \
176     } while (0)
177
178 /* Return the low n bits of the bit accumulator (n < 16) */
179 #define BITS(n) \
180     ((unsigned)hold & ((1U << (n)) - 1))
181
182 /* Remove n bits from the bit accumulator */
183 #define DROPBITS(n) \
184     do { \
185         hold >>= (n); \
186         bits -= (unsigned)(n); \
187     } while (0)
188
189 /* Remove zero to seven bits as needed to go to a byte boundary */
190 #define BYTEBITS() \
191     do { \
192         hold >>= bits & 7; \
193         bits -= bits & 7; \
194     } while (0)
195
196 /* Assure that some output space is available, by writing out the window
197    if it's full.  If the write fails, return from inflateBack() with a
198    Z_BUF_ERROR. */
199 #define ROOM() \
200     do { \
201         if (left == 0) { \
202             put = state->window; \
203             left = state->wsize; \
204             if (out(out_desc, put, left)) { \
205                 ret = Z_BUF_ERROR; \
206                 goto leave; \
207             } \
208         } \
209     } while (0)
210
211 /*
212    strm provides the memory allocation functions and window buffer on input,
213    and provides information on the unused input on return.  For Z_DATA_ERROR
214    returns, strm will also provide an error message.
215
216    in() and out() are the call-back input and output functions.  When
217    inflateBack() needs more input, it calls in().  When inflateBack() has
218    filled the window with output, or when it completes with data in the
219    window, it called out() to write out the data.  The application must not
220    change the provided input until in() is called again or inflateBack()
221    returns.  The application must not change the window/output buffer until
222    inflateBack() returns.
223
224    in() and out() are called with a descriptor parameter provided in the
225    inflateBack() call.  This parameter can be a structure that provides the
226    information required to do the read or write, as well as accumulated
227    information on the input and output such as totals and check values.
228
229    in() should return zero on failure.  out() should return non-zero on
230    failure.  If either in() or out() fails, than inflateBack() returns a
231    Z_BUF_ERROR.  strm->next_in can be checked for Z_NULL to see whether it
232    was in() or out() that caused in the error.  Otherwise,  inflateBack()
233    returns Z_STREAM_END on success, Z_DATA_ERROR for an deflate format
234    error, or Z_MEM_ERROR if it could not allocate memory for the state.
235    inflateBack() can also return Z_STREAM_ERROR if the input parameters
236    are not correct, i.e. strm is Z_NULL or the state was not initialized.
237  */
238 int ZEXPORT inflateBack(strm, in, in_desc, out, out_desc)
239 z_stream FAR *strm;
240 in_func in;
241 void FAR *in_desc;
242 out_func out;
243 void FAR *out_desc;
244 {
245     struct inflate_state FAR *state;
246     unsigned char *next, *put;  /* next input and output */
247     unsigned have, left;        /* available input and output */
248     unsigned long hold;         /* bit buffer */
249     unsigned bits;              /* bits in bit buffer */
250     unsigned copy;              /* number of stored or match bytes to copy */
251     unsigned char *from;        /* where to copy match bytes from */
252     code this;                  /* current decoding table entry */
253     code last;                  /* parent table entry */
254     unsigned len;               /* length to copy for repeats, bits to drop */
255     int ret;                    /* return code */
256     static const unsigned short order[19] = /* permutation of code lengths */
257         {16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
258
259     /* Check that the strm exists and that the state was initialized */
260     if (strm == Z_NULL || strm->state == Z_NULL)
261         return Z_STREAM_ERROR;
262     state = (struct inflate_state FAR *)strm->state;
263
264     /* Reset the state */
265     strm->msg = Z_NULL;
266     state->mode = TYPE;
267     state->last = 0;
268     next = strm->next_in;
269     have = next != Z_NULL ? strm->avail_in : 0;
270     hold = 0;
271     bits = 0;
272     put = state->window;
273     left = state->wsize;
274
275     /* Inflate until end of block marked as last */
276     for (;;)
277         switch (state->mode) {
278         case TYPE:
279             /* determine and dispatch block type */
280             if (state->last) {
281                 BYTEBITS();
282                 state->mode = DONE;
283                 break;
284             }
285             NEEDBITS(3);
286             state->last = BITS(1);
287             DROPBITS(1);
288             switch (BITS(2)) {
289             case 0:                             /* stored block */
290                 Tracev((stderr, "inflate:     stored block%s\n",
291                         state->last ? " (last)" : ""));
292                 state->mode = STORED;
293                 break;
294             case 1:                             /* fixed block */
295                 fixedtables(state);
296                 Tracev((stderr, "inflate:     fixed codes block%s\n",
297                         state->last ? " (last)" : ""));
298                 state->mode = LEN;              /* decode codes */
299                 break;
300             case 2:                             /* dynamic block */
301                 Tracev((stderr, "inflate:     dynamic codes block%s\n",
302                         state->last ? " (last)" : ""));
303                 state->mode = TABLE;
304                 break;
305             case 3:
306                 strm->msg = (char *)"invalid block type";
307                 state->mode = BAD;
308             }
309             DROPBITS(2);
310             break;
311
312         case STORED:
313             /* get and verify stored block length */
314             BYTEBITS();                         /* go to byte boundary */
315             NEEDBITS(32);
316             if ((hold & 0xffff) != ((hold >> 16) ^ 0xffff)) {
317                 strm->msg = (char *)"invalid stored block lengths";
318                 state->mode = BAD;
319                 break;
320             }
321             state->length = (unsigned)hold & 0xffff;
322             Tracev((stderr, "inflate:       stored length %u\n",
323                     state->length));
324             INITBITS();
325
326             /* copy stored block from input to output */
327             while (state->length != 0) {
328                 copy = state->length;
329                 PULL();
330                 ROOM();
331                 if (copy > have) copy = have;
332                 if (copy > left) copy = left;
333                 zmemcpy(put, next, copy);
334                 have -= copy;
335                 next += copy;
336                 left -= copy;
337                 put += copy;
338                 state->length -= copy;
339             }
340             Tracev((stderr, "inflate:       stored end\n"));
341             state->mode = TYPE;
342             break;
343
344         case TABLE:
345             /* get dynamic table entries descriptor */
346             NEEDBITS(14);
347             state->nlen = BITS(5) + 257;
348             DROPBITS(5);
349             state->ndist = BITS(5) + 1;
350             DROPBITS(5);
351             state->ncode = BITS(4) + 4;
352             DROPBITS(4);
353 #ifndef PKZIP_BUG_WORKAROUND
354             if (state->nlen > 286 || state->ndist > 30) {
355                 strm->msg = (char *)"too many length or distance symbols";
356                 state->mode = BAD;
357                 break;
358             }
359 #endif
360             Tracev((stderr, "inflate:       table sizes ok\n"));
361
362             /* get code length code lengths (not a typo) */
363             state->have = 0;
364             while (state->have < state->ncode) {
365                 NEEDBITS(3);
366                 state->lens[order[state->have++]] = (unsigned short)BITS(3);
367                 DROPBITS(3);
368             }
369             while (state->have < 19)
370                 state->lens[order[state->have++]] = 0;
371             state->next = state->codes;
372             state->lencode = (code const FAR *)(state->next);
373             state->lenbits = 7;
374             ret = inflate_table(CODES, state->lens, 19, &(state->next),
375                                 &(state->lenbits), state->work);
376             if (ret) {
377                 strm->msg = (char *)"invalid code lengths set";
378                 state->mode = BAD;
379                 break;
380             }
381             Tracev((stderr, "inflate:       code lengths ok\n"));
382
383             /* get length and distance code code lengths */
384             state->have = 0;
385             while (state->have < state->nlen + state->ndist) {
386                 for (;;) {
387                     this = state->lencode[BITS(state->lenbits)];
388                     if ((unsigned)(this.bits) <= bits) break;
389                     PULLBYTE();
390                 }
391                 if (this.val < 16) {
392                     NEEDBITS(this.bits);
393                     DROPBITS(this.bits);
394                     state->lens[state->have++] = this.val;
395                 }
396                 else {
397                     if (this.val == 16) {
398                         NEEDBITS(this.bits + 2);
399                         DROPBITS(this.bits);
400                         if (state->have == 0) {
401                             strm->msg = (char *)"invalid bit length repeat";
402                             state->mode = BAD;
403                             break;
404                         }
405                         len = (unsigned)(state->lens[state->have - 1]);
406                         copy = 3 + BITS(2);
407                         DROPBITS(2);
408                     }
409                     else if (this.val == 17) {
410                         NEEDBITS(this.bits + 3);
411                         DROPBITS(this.bits);
412                         len = 0;
413                         copy = 3 + BITS(3);
414                         DROPBITS(3);
415                     }
416                     else {
417                         NEEDBITS(this.bits + 7);
418                         DROPBITS(this.bits);
419                         len = 0;
420                         copy = 11 + BITS(7);
421                         DROPBITS(7);
422                     }
423                     if (state->have + copy > state->nlen + state->ndist) {
424                         strm->msg = (char *)"invalid bit length repeat";
425                         state->mode = BAD;
426                         break;
427                     }
428                     while (copy--)
429                         state->lens[state->have++] = (unsigned short)len;
430                 }
431             }
432
433             /* build code tables */
434             state->next = state->codes;
435             state->lencode = (code const FAR *)(state->next);
436             state->lenbits = 9;
437             ret = inflate_table(LENS, state->lens, state->nlen, &(state->next),
438                                 &(state->lenbits), state->work);
439             if (ret) {
440                 strm->msg = (char *)"invalid literal/lengths set";
441                 state->mode = BAD;
442                 break;
443             }
444             state->distcode = (code const FAR *)(state->next);
445             state->distbits = 6;
446             ret = inflate_table(DISTS, state->lens + state->nlen, state->ndist,
447                             &(state->next), &(state->distbits), state->work);
448             if (ret) {
449                 strm->msg = (char *)"invalid distances set";
450                 state->mode = BAD;
451                 break;
452             }
453             Tracev((stderr, "inflate:       codes ok\n"));
454             state->mode = LEN;
455
456         case LEN:
457             /* use inflate_fast() if we have enough input and output */
458             if (have >= 6 && left >= 258) {
459                 RESTORE();
460                 inflate_fast(strm, state->wsize);
461                 LOAD();
462                 break;
463             }
464
465             /* get a literal, length, or end-of-block code */
466             for (;;) {
467                 this = state->lencode[BITS(state->lenbits)];
468                 if ((unsigned)(this.bits) <= bits) break;
469                 PULLBYTE();
470             }
471             if (this.op && (this.op & 0xf0) == 0) {
472                 last = this;
473                 for (;;) {
474                     this = state->lencode[last.val +
475                             (BITS(last.bits + last.op) >> last.bits)];
476                     if ((unsigned)(last.bits + this.bits) <= bits) break;
477                     PULLBYTE();
478                 }
479                 DROPBITS(last.bits);
480             }
481             DROPBITS(this.bits);
482             state->length = (unsigned)this.val;
483
484             /* process literal */
485             if (this.op == 0) {
486                 Tracevv((stderr, this.val >= 0x20 && this.val < 0x7f ?
487                         "inflate:         literal '%c'\n" :
488                         "inflate:         literal 0x%02x\n", this.val));
489                 ROOM();
490                 *put++ = (unsigned char)(state->length);
491                 left--;
492                 state->mode = LEN;
493                 break;
494             }
495
496             /* process end of block */
497             if (this.op & 32) {
498                 Tracevv((stderr, "inflate:         end of block\n"));
499                 state->mode = TYPE;
500                 break;
501             }
502
503             /* invalid code */
504             if (this.op & 64) {
505                 strm->msg = (char *)"invalid literal/length code";
506                 state->mode = BAD;
507                 break;
508             }
509
510             /* length code -- get extra bits, if any */
511             state->extra = (unsigned)(this.op) & 15;
512             if (state->extra != 0) {
513                 NEEDBITS(state->extra);
514                 state->length += BITS(state->extra);
515                 DROPBITS(state->extra);
516             }
517             Tracevv((stderr, "inflate:         length %u\n", state->length));
518
519             /* get distance code */
520             for (;;) {
521                 this = state->distcode[BITS(state->distbits)];
522                 if ((unsigned)(this.bits) <= bits) break;
523                 PULLBYTE();
524             }
525             if ((this.op & 0xf0) == 0) {
526                 last = this;
527                 for (;;) {
528                     this = state->distcode[last.val +
529                             (BITS(last.bits + last.op) >> last.bits)];
530                     if ((unsigned)(last.bits + this.bits) <= bits) break;
531                     PULLBYTE();
532                 }
533                 DROPBITS(last.bits);
534             }
535             DROPBITS(this.bits);
536             if (this.op & 64) {
537                 strm->msg = (char *)"invalid distance code";
538                 state->mode = BAD;
539                 break;
540             }
541             state->offset = (unsigned)this.val;
542
543             /* get distance extra bits, if any */
544             state->extra = (unsigned)(this.op) & 15;
545             if (state->extra != 0) {
546                 NEEDBITS(state->extra);
547                 state->offset += BITS(state->extra);
548                 DROPBITS(state->extra);
549             }
550             if (state->offset > state->wsize) {
551                 strm->msg = (char *)"invalid distance too far back";
552                 state->mode = BAD;
553                 break;
554             }
555             Tracevv((stderr, "inflate:         distance %u\n", state->offset));
556
557             /* copy match from window to output */
558             do {
559                 ROOM();
560                 copy = state->wsize - state->offset;
561                 if (copy < left) {
562                     from = put + copy;
563                     copy = left - copy;
564                 }
565                 else {
566                     from = put - state->offset;
567                     copy = left;
568                 }
569                 if (copy > state->length) copy = state->length;
570                 state->length -= copy;
571                 left -= copy;
572                 do {
573                     *put++ = *from++;
574                 } while (--copy);
575             } while (state->length != 0);
576             break;
577
578         case DONE:
579             /* inflate stream terminated properly -- write leftover output */
580             ret = Z_STREAM_END;
581             if (left < state->wsize) {
582                 if (out(out_desc, state->window, state->wsize - left))
583                     ret = Z_BUF_ERROR;
584             }
585             goto leave;
586
587         case BAD:
588             ret = Z_DATA_ERROR;
589             goto leave;
590
591         default:                /* can't happen, but makes compilers happy */
592             ret = Z_STREAM_ERROR;
593             goto leave;
594         }
595
596     /* Return unused input */
597   leave:
598     strm->next_in = next;
599     strm->avail_in = have;
600     return ret;
601 }
602
603 int ZEXPORT inflateBackEnd(strm)
604 z_stream FAR *strm;
605 {
606     struct inflate_state FAR *state;
607
608     if (strm == Z_NULL || strm->state == Z_NULL || strm->zfree == Z_NULL)
609         return Z_STREAM_ERROR;
610     state = (struct inflate_state FAR *)strm->state;
611     ZFREE(strm, strm->state);
612     strm->state = Z_NULL;
613     Tracev((stderr, "inflate: end\n"));
614     return Z_OK;
615 }