]> git.lizzy.rs Git - plan9front.git/blob - sys/src/libflate/inflate.c
Import sources from 2011-03-30 iso image - lib
[plan9front.git] / sys / src / libflate / inflate.c
1 #include <u.h>
2 #include <libc.h>
3 #include <flate.h>
4
/*
 * sizes and limits shared by the decoder
 */
enum {
	HistorySize	= 32*1024,	/* bytes in the sliding window, History.his */
	BufSize		= 4*1024,
	MaxHuffBits	= 17,		/* entries in Huff's per-code-length arrays; lengths must be below this */
	Nlitlen		= 288,		/* literal/length symbols */
	Noff		= 32,		/* distance (offset) symbols */
	Nclen		= 19,		/* code-length symbols */
	LenShift	= 10,		/* code = len<<LenShift|code */
	LitlenBits	= 7,		/* index bits of the litlen flat table */
	OffBits		= 6,		/* index bits of the offset flat table */
	ClenBits	= 6,		/* index bits of the code-length flat table */
	MaxFlatBits	= LitlenBits,	/* largest flat table; sizes Huff.flat */
	MaxLeaf		= Nlitlen	/* largest alphabet; sizes Huff.decode */
};
19
20 typedef struct Input    Input;
21 typedef struct History  History;
22 typedef struct Huff     Huff;
23
24 struct Input
25 {
26         int     error;          /* first error encountered, or FlateOk */
27         void    *wr;
28         int     (*w)(void*, void*, int);
29         void    *getr;
30         int     (*get)(void*);
31         ulong   sreg;
32         int     nbits;
33 };
34
35 struct History
36 {
37         uchar   his[HistorySize];
38         uchar   *cp;            /* current pointer in history */
39         int     full;           /* his has been filled up at least once */
40 };
41
42 struct Huff
43 {
44         int     maxbits;        /* max bits for any code */
45         int     minbits;        /* min bits to get before looking in flat */
46         int     flatmask;       /* bits used in "flat" fast decoding table */
47         ulong   flat[1<<MaxFlatBits];
48         ulong   maxcode[MaxHuffBits];
49         ulong   last[MaxHuffBits];
50         ulong   decode[MaxLeaf];
51 };
52
53 /* litlen code words 257-285 extra bits */
54 static int litlenextra[Nlitlen-257] =
55 {
56 /* 257 */       0, 0, 0,
57 /* 260 */       0, 0, 0, 0, 0, 1, 1, 1, 1, 2,
58 /* 270 */       2, 2, 2, 3, 3, 3, 3, 4, 4, 4,
59 /* 280 */       4, 5, 5, 5, 5, 0, 0, 0
60 };
61
62 static int litlenbase[Nlitlen-257];
63
64 /* offset code word extra bits */
65 static int offextra[Noff] =
66 {
67         0,  0,  0,  0,  1,  1,  2,  2,  3,  3,
68         4,  4,  5,  5,  6,  6,  7,  7,  8,  8,
69         9,  9,  10, 10, 11, 11, 12, 12, 13, 13,
70         0,  0,
71 };
72 static int offbase[Noff];
73
74 /* order code lengths */
75 static int clenorder[Nclen] =
76 {
77         16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
78 };
79
80 /* for static huffman tables */
81 static  Huff    litlentab;
82 static  Huff    offtab;
83 static  uchar   revtab[256];
84
85 static int      uncblock(Input *in, History*);
86 static int      fixedblock(Input *in, History*);
87 static int      dynamicblock(Input *in, History*);
88 static int      sregfill(Input *in, int n);
89 static int      sregunget(Input *in);
90 static int      decode(Input*, History*, Huff*, Huff*);
91 static int      hufftab(Huff*, char*, int, int);
92 static int      hdecsym(Input *in, Huff *h, int b);
93
94 int
95 inflateinit(void)
96 {
97         char *len;
98         int i, j, base;
99
100         /* byte reverse table */
101         for(i=0; i<256; i++)
102                 for(j=0; j<8; j++)
103                         if(i & (1<<j))
104                                 revtab[i] |= 0x80 >> j;
105
106         for(i=257,base=3; i<Nlitlen; i++) {
107                 litlenbase[i-257] = base;
108                 base += 1<<litlenextra[i-257];
109         }
110         /* strange table entry in spec... */
111         litlenbase[285-257]--;
112
113         for(i=0,base=1; i<Noff; i++) {
114                 offbase[i] = base;
115                 base += 1<<offextra[i];
116         }
117
118         len = malloc(MaxLeaf);
119         if(len == nil)
120                 return FlateNoMem;
121
122         /* static Litlen bit lengths */
123         for(i=0; i<144; i++)
124                 len[i] = 8;
125         for(i=144; i<256; i++)
126                 len[i] = 9;
127         for(i=256; i<280; i++)
128                 len[i] = 7;
129         for(i=280; i<Nlitlen; i++)
130                 len[i] = 8;
131
132         if(!hufftab(&litlentab, len, Nlitlen, MaxFlatBits))
133                 return FlateInternal;
134
135         /* static Offset bit lengths */
136         for(i=0; i<Noff; i++)
137                 len[i] = 5;
138
139         if(!hufftab(&offtab, len, Noff, MaxFlatBits))
140                 return FlateInternal;
141         free(len);
142
143         return FlateOk;
144 }
145
146 int
147 inflate(void *wr, int (*w)(void*, void*, int), void *getr, int (*get)(void*))
148 {
149         History *his;
150         Input in;
151         int final, type;
152
153         his = malloc(sizeof(History));
154         if(his == nil)
155                 return FlateNoMem;
156         his->cp = his->his;
157         his->full = 0;
158         in.getr = getr;
159         in.get = get;
160         in.wr = wr;
161         in.w = w;
162         in.nbits = 0;
163         in.sreg = 0;
164         in.error = FlateOk;
165
166         do {
167                 if(!sregfill(&in, 3))
168                         goto bad;
169                 final = in.sreg & 0x1;
170                 type = (in.sreg>>1) & 0x3;
171                 in.sreg >>= 3;
172                 in.nbits -= 3;
173                 switch(type) {
174                 default:
175                         in.error = FlateCorrupted;
176                         goto bad;
177                 case 0:
178                         /* uncompressed */
179                         if(!uncblock(&in, his))
180                                 goto bad;
181                         break;
182                 case 1:
183                         /* fixed huffman */
184                         if(!fixedblock(&in, his))
185                                 goto bad;
186                         break;
187                 case 2:
188                         /* dynamic huffman */
189                         if(!dynamicblock(&in, his))
190                                 goto bad;
191                         break;
192                 }
193         } while(!final);
194
195         if(his->cp != his->his && (*w)(wr, his->his, his->cp - his->his) != his->cp - his->his) {
196                 in.error = FlateOutputFail;
197                 goto bad;
198         }
199
200         if(!sregunget(&in))
201                 goto bad;
202
203         free(his);
204         if(in.error != FlateOk)
205                 return FlateInternal;
206         return FlateOk;
207
208 bad:
209         free(his);
210         if(in.error == FlateOk)
211                 return FlateInternal;
212         return in.error;
213 }
214
215 static int
216 uncblock(Input *in, History *his)
217 {
218         int len, nlen, c;
219         uchar *hs, *hp, *he;
220
221         if(!sregunget(in))
222                 return 0;
223         len = (*in->get)(in->getr);
224         len |= (*in->get)(in->getr)<<8;
225         nlen = (*in->get)(in->getr);
226         nlen |= (*in->get)(in->getr)<<8;
227         if(len != (~nlen&0xffff)) {
228                 in->error = FlateCorrupted;
229                 return 0;
230         }
231
232         hp = his->cp;
233         hs = his->his;
234         he = hs + HistorySize;
235
236         while(len > 0) {
237                 c = (*in->get)(in->getr);
238                 if(c < 0)
239                         return 0;
240                 *hp++ = c;
241                 if(hp == he) {
242                         his->full = 1;
243                         if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
244                                 in->error = FlateOutputFail;
245                                 return 0;
246                         }
247                         hp = hs;
248                 }
249                 len--;
250         }
251
252         his->cp = hp;
253
254         return 1;
255 }
256
257 static int
258 fixedblock(Input *in, History *his)
259 {
260         return decode(in, his, &litlentab, &offtab);
261 }
262
263 static int
264 dynamicblock(Input *in, History *his)
265 {
266         Huff *lentab, *offtab;
267         char *len;
268         int i, j, n, c, nlit, ndist, nclen, res, nb;
269
270         if(!sregfill(in, 14))
271                 return 0;
272         nlit = (in->sreg&0x1f) + 257;
273         ndist = ((in->sreg>>5) & 0x1f) + 1;
274         nclen = ((in->sreg>>10) & 0xf) + 4;
275         in->sreg >>= 14;
276         in->nbits -= 14;
277
278         if(nlit > Nlitlen || ndist > Noff || nlit < 257) {
279                 in->error = FlateCorrupted;
280                 return 0;
281         }
282
283         /* huff table header */
284         len = malloc(Nlitlen+Noff);
285         lentab = malloc(sizeof(Huff));
286         offtab = malloc(sizeof(Huff));
287         if(len == nil || lentab == nil || offtab == nil){
288                 in->error = FlateNoMem;
289                 goto bad;
290         }
291         for(i=0; i < Nclen; i++)
292                 len[i] = 0;
293         for(i=0; i<nclen; i++) {
294                 if(!sregfill(in, 3))
295                         goto bad;
296                 len[clenorder[i]] = in->sreg & 0x7;
297                 in->sreg >>= 3;
298                 in->nbits -= 3;
299         }
300
301         if(!hufftab(lentab, len, Nclen, ClenBits)){
302                 in->error = FlateCorrupted;
303                 goto bad;
304         }
305
306         n = nlit+ndist;
307         for(i=0; i<n;) {
308                 nb = lentab->minbits;
309                 for(;;){
310                         if(in->nbits<nb && !sregfill(in, nb))
311                                 goto bad;
312                         c = lentab->flat[in->sreg & lentab->flatmask];
313                         nb = c & 0xff;
314                         if(nb > in->nbits){
315                                 if(nb != 0xff)
316                                         continue;
317                                 c = hdecsym(in, lentab, c);
318                                 if(c < 0)
319                                         goto bad;
320                         }else{
321                                 c >>= 8;
322                                 in->sreg >>= nb;
323                                 in->nbits -= nb;
324                         }
325                         break;
326                 }
327
328                 if(c < 16) {
329                         j = 1;
330                 } else if(c == 16) {
331                         if(in->nbits<2 && !sregfill(in, 2))
332                                 goto bad;
333                         j = (in->sreg&0x3)+3;
334                         in->sreg >>= 2;
335                         in->nbits -= 2;
336                         if(i == 0) {
337                                 in->error = FlateCorrupted;
338                                 goto bad;
339                         }
340                         c = len[i-1];
341                 } else if(c == 17) {
342                         if(in->nbits<3 && !sregfill(in, 3))
343                                 goto bad;
344                         j = (in->sreg&0x7)+3;
345                         in->sreg >>= 3;
346                         in->nbits -= 3;
347                         c = 0;
348                 } else if(c == 18) {
349                         if(in->nbits<7 && !sregfill(in, 7))
350                                 goto bad;
351                         j = (in->sreg&0x7f)+11;
352                         in->sreg >>= 7;
353                         in->nbits -= 7;
354                         c = 0;
355                 } else {
356                         in->error = FlateCorrupted;
357                         goto bad;
358                 }
359
360                 if(i+j > n) {
361                         in->error = FlateCorrupted;
362                         goto bad;
363                 }
364
365                 while(j) {
366                         len[i] = c;
367                         i++;
368                         j--;
369                 }
370         }
371
372         if(!hufftab(lentab, len, nlit, LitlenBits)
373         || !hufftab(offtab, &len[nlit], ndist, OffBits)){
374                 in->error = FlateCorrupted;
375                 goto bad;
376         }
377
378         res = decode(in, his, lentab, offtab);
379
380         free(len);
381         free(lentab);
382         free(offtab);
383
384         return res;
385
386 bad:
387         free(len);
388         free(lentab);
389         free(offtab);
390         return 0;
391 }
392
393 static int
394 decode(Input *in, History *his, Huff *litlentab, Huff *offtab)
395 {
396         int len, off;
397         uchar *hs, *hp, *hq, *he;
398         int c;
399         int nb;
400
401         hs = his->his;
402         he = hs + HistorySize;
403         hp = his->cp;
404
405         for(;;) {
406                 nb = litlentab->minbits;
407                 for(;;){
408                         if(in->nbits<nb && !sregfill(in, nb))
409                                 return 0;
410                         c = litlentab->flat[in->sreg & litlentab->flatmask];
411                         nb = c & 0xff;
412                         if(nb > in->nbits){
413                                 if(nb != 0xff)
414                                         continue;
415                                 c = hdecsym(in, litlentab, c);
416                                 if(c < 0)
417                                         return 0;
418                         }else{
419                                 c >>= 8;
420                                 in->sreg >>= nb;
421                                 in->nbits -= nb;
422                         }
423                         break;
424                 }
425
426                 if(c < 256) {
427                         /* literal */
428                         *hp++ = c;
429                         if(hp == he) {
430                                 his->full = 1;
431                                 if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
432                                         in->error = FlateOutputFail;
433                                         return 0;
434                                 }
435                                 hp = hs;
436                         }
437                         continue;
438                 }
439
440                 if(c == 256)
441                         break;
442
443                 if(c > 285) {
444                         in->error = FlateCorrupted;
445                         return 0;
446                 }
447
448                 c -= 257;
449                 nb = litlenextra[c];
450                 if(in->nbits < nb && !sregfill(in, nb))
451                         return 0;
452                 len = litlenbase[c] + (in->sreg & ((1<<nb)-1));
453                 in->sreg >>= nb;
454                 in->nbits -= nb;
455
456                 /* get offset */
457                 nb = offtab->minbits;
458                 for(;;){
459                         if(in->nbits<nb && !sregfill(in, nb))
460                                 return 0;
461                         c = offtab->flat[in->sreg & offtab->flatmask];
462                         nb = c & 0xff;
463                         if(nb > in->nbits){
464                                 if(nb != 0xff)
465                                         continue;
466                                 c = hdecsym(in, offtab, c);
467                                 if(c < 0)
468                                         return 0;
469                         }else{
470                                 c >>= 8;
471                                 in->sreg >>= nb;
472                                 in->nbits -= nb;
473                         }
474                         break;
475                 }
476
477                 if(c > 29) {
478                         in->error = FlateCorrupted;
479                         return 0;
480                 }
481
482                 nb = offextra[c];
483                 if(in->nbits < nb && !sregfill(in, nb))
484                         return 0;
485
486                 off = offbase[c] + (in->sreg & ((1<<nb)-1));
487                 in->sreg >>= nb;
488                 in->nbits -= nb;
489
490                 hq = hp - off;
491                 if(hq < hs) {
492                         if(!his->full) {
493                                 in->error = FlateCorrupted;
494                                 return 0;
495                         }
496                         hq += HistorySize;
497                 }
498
499                 /* slow but correct */
500                 while(len) {
501                         *hp = *hq;
502                         hq++;
503                         hp++;
504                         if(hq >= he)
505                                 hq = hs;
506                         if(hp == he) {
507                                 his->full = 1;
508                                 if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
509                                         in->error = FlateOutputFail;
510                                         return 0;
511                                 }
512                                 hp = hs;
513                         }
514                         len--;
515                 }
516
517         }
518
519         his->cp = hp;
520
521         return 1;
522 }
523
524 static int
525 revcode(int c, int b)
526 {
527         /* shift encode up so it starts on bit 15 then reverse */
528         c <<= (16-b);
529         c = revtab[c>>8] | (revtab[c&0xff]<<8);
530         return c;
531 }
532
533 /*
534  * construct the huffman decoding arrays and a fast lookup table.
535  * the fast lookup is a table indexed by the next flatbits bits,
536  * which returns the symbol matched and the number of bits consumed,
537  * or the minimum number of bits needed and 0xff if more than flatbits
538  * bits are needed.
539  *
540  * flatbits can be longer than the smallest huffman code,
541  * because shorter codes are assigned smaller lexical prefixes.
542  * this means assuming zeros for the next few bits will give a
543  * conservative answer, in the sense that it will either give the
544  * correct answer, or return the minimum number of bits which
545  * are needed for an answer.
546  */
547 static int
548 hufftab(Huff *h, char *hb, int maxleaf, int flatbits)
549 {
550         ulong bitcount[MaxHuffBits];
551         ulong c, fc, ec, mincode, code, nc[MaxHuffBits];
552         int i, b, minbits, maxbits;
553
554         for(i = 0; i < MaxHuffBits; i++)
555                 bitcount[i] = 0;
556         maxbits = -1;
557         minbits = MaxHuffBits + 1;
558         for(i=0; i < maxleaf; i++){
559                 b = hb[i];
560                 if(b){
561                         bitcount[b]++;
562                         if(b < minbits)
563                                 minbits = b;
564                         if(b > maxbits)
565                                 maxbits = b;
566                 }
567         }
568
569         h->maxbits = maxbits;
570         if(maxbits <= 0){
571                 h->maxbits = 0;
572                 h->minbits = 0;
573                 h->flatmask = 0;
574                 return 1;
575         }
576         code = 0;
577         c = 0;
578         for(b = 0; b <= maxbits; b++){
579                 h->last[b] = c;
580                 c += bitcount[b];
581                 mincode = code << 1;
582                 nc[b] = mincode;
583                 code = mincode + bitcount[b];
584                 if(code > (1 << b))
585                         return 0;
586                 h->maxcode[b] = code - 1;
587                 h->last[b] += code - 1;
588         }
589
590         if(flatbits > maxbits)
591                 flatbits = maxbits;
592         h->flatmask = (1 << flatbits) - 1;
593         if(minbits > flatbits)
594                 minbits = flatbits;
595         h->minbits = minbits;
596
597         b = 1 << flatbits;
598         for(i = 0; i < b; i++)
599                 h->flat[i] = ~0;
600
601         /*
602          * initialize the flat table to include the minimum possible
603          * bit length for each code prefix
604          */
605         for(b = maxbits; b > flatbits; b--){
606                 code = h->maxcode[b];
607                 if(code == -1)
608                         break;
609                 mincode = code + 1 - bitcount[b];
610                 mincode >>= b - flatbits;
611                 code >>= b - flatbits;
612                 for(; mincode <= code; mincode++)
613                         h->flat[revcode(mincode, flatbits)] = (b << 8) | 0xff;
614         }
615
616         for(i = 0; i < maxleaf; i++){
617                 b = hb[i];
618                 if(b <= 0)
619                         continue;
620                 c = nc[b]++;
621                 if(b <= flatbits){
622                         code = (i << 8) | b;
623                         ec = (c + 1) << (flatbits - b);
624                         if(ec > (1<<flatbits))
625                                 return 0;       /* this is actually an internal error */
626                         for(fc = c << (flatbits - b); fc < ec; fc++)
627                                 h->flat[revcode(fc, flatbits)] = code;
628                 }
629                 if(b > minbits){
630                         c = h->last[b] - c;
631                         if(c >= maxleaf)
632                                 return 0;
633                         h->decode[c] = i;
634                 }
635         }
636         return 1;
637 }
638
639 static int
640 hdecsym(Input *in, Huff *h, int nb)
641 {
642         long c;
643
644         if((nb & 0xff) == 0xff)
645                 nb = nb >> 8;
646         else
647                 nb = nb & 0xff;
648         for(; nb <= h->maxbits; nb++){
649                 if(in->nbits<nb && !sregfill(in, nb))
650                         return -1;
651                 c = revtab[in->sreg&0xff]<<8;
652                 c |= revtab[(in->sreg>>8)&0xff];
653                 c >>= (16-nb);
654                 if(c <= h->maxcode[nb]){
655                         in->sreg >>= nb;
656                         in->nbits -= nb;
657                         return h->decode[h->last[nb] - c];
658                 }
659         }
660         in->error = FlateCorrupted;
661         return -1;
662 }
663
664 static int
665 sregfill(Input *in, int n)
666 {
667         int c;
668
669         while(n > in->nbits) {
670                 c = (*in->get)(in->getr);
671                 if(c < 0){
672                         in->error = FlateInputFail;
673                         return 0;
674                 }
675                 in->sreg |= c<<in->nbits;
676                 in->nbits += 8;
677         }
678         return 1;
679 }
680
681 static int
682 sregunget(Input *in)
683 {
684         if(in->nbits >= 8) {
685                 in->error = FlateInternal;
686                 return 0;
687         }
688
689         /* throw other bits on the floor */
690         in->nbits = 0;
691         in->sreg = 0;
692         return 1;
693 }