]> git.lizzy.rs Git - plan9front.git/blob - sys/src/libflate/inflate.c
amd64: FP: always use enough to fit AVX state and align to 64 bytes
[plan9front.git] / sys / src / libflate / inflate.c
1 #include <u.h>
2 #include <libc.h>
3 #include <flate.h>
4
enum {
	/* output history */
	HistorySize	= 1<<15,	/* 32K back-reference window (History.his) */
	BufSize		= 1<<12,	/* 4K */

	/* huffman alphabets and limits */
	MaxHuffBits	= 17,		/* size of arrays indexed by code length; every length must be below this */
	Nlitlen		= 288,		/* literal/length alphabet size */
	Noff		= 32,		/* offset (distance) alphabet size */
	Nclen		= 19,		/* code length alphabet size */
	LenShift	= 10,		/* code = len<<LenShift|code */

	/* index widths of the flat lookup tables */
	LitlenBits	= 7,		/* literal/length table */
	OffBits		= 6,		/* offset table */
	ClenBits	= 6,		/* code length table */
	MaxFlatBits	= LitlenBits,	/* Huff.flat is sized for the widest of the above */
	MaxLeaf		= Nlitlen	/* largest alphabet a Huff must decode */
};
19
20 typedef struct Input    Input;
21 typedef struct History  History;
22 typedef struct Huff     Huff;
23
24 struct Input
25 {
26         int     error;          /* first error encountered, or FlateOk */
27         void    *wr;
28         int     (*w)(void*, void*, int);
29         void    *getr;
30         int     (*get)(void*);
31         ulong   sreg;
32         int     nbits;
33 };
34
35 struct History
36 {
37         uchar   his[HistorySize];
38         uchar   *cp;            /* current pointer in history */
39         int     full;           /* his has been filled up at least once */
40 };
41
42 struct Huff
43 {
44         int     maxbits;        /* max bits for any code */
45         int     minbits;        /* min bits to get before looking in flat */
46         int     flatmask;       /* bits used in "flat" fast decoding table */
47         ulong   flat[1<<MaxFlatBits];
48         ulong   maxcode[MaxHuffBits];
49         ulong   last[MaxHuffBits];
50         ulong   decode[MaxLeaf];
51         int     maxleaf;
52 };
53
54 /* litlen code words 257-285 extra bits */
55 static int litlenextra[Nlitlen-257] =
56 {
57 /* 257 */       0, 0, 0,
58 /* 260 */       0, 0, 0, 0, 0, 1, 1, 1, 1, 2,
59 /* 270 */       2, 2, 2, 3, 3, 3, 3, 4, 4, 4,
60 /* 280 */       4, 5, 5, 5, 5, 0, 0, 0
61 };
62
63 static int litlenbase[Nlitlen-257];
64
65 /* offset code word extra bits */
66 static int offextra[Noff] =
67 {
68         0,  0,  0,  0,  1,  1,  2,  2,  3,  3,
69         4,  4,  5,  5,  6,  6,  7,  7,  8,  8,
70         9,  9,  10, 10, 11, 11, 12, 12, 13, 13,
71         0,  0,
72 };
73 static int offbase[Noff];
74
75 /* order code lengths */
76 static int clenorder[Nclen] =
77 {
78         16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
79 };
80
81 /* for static huffman tables */
82 static  Huff    litlentab;
83 static  Huff    offtab;
84 static  uchar   revtab[256];
85
86 static int      uncblock(Input *in, History*);
87 static int      fixedblock(Input *in, History*);
88 static int      dynamicblock(Input *in, History*);
89 static int      sregfill(Input *in, int n);
90 static int      sregunget(Input *in);
91 static int      decode(Input*, History*, Huff*, Huff*);
92 static int      hufftab(Huff*, char*, int, int);
93 static int      hdecsym(Input *in, Huff *h, int b);
94
95 int
96 inflateinit(void)
97 {
98         char *len;
99         int i, j, base;
100
101         /* byte reverse table */
102         for(i=0; i<256; i++)
103                 for(j=0; j<8; j++)
104                         if(i & (1<<j))
105                                 revtab[i] |= 0x80 >> j;
106
107         for(i=257,base=3; i<Nlitlen; i++) {
108                 litlenbase[i-257] = base;
109                 base += 1<<litlenextra[i-257];
110         }
111         /* strange table entry in spec... */
112         litlenbase[285-257]--;
113
114         for(i=0,base=1; i<Noff; i++) {
115                 offbase[i] = base;
116                 base += 1<<offextra[i];
117         }
118
119         len = malloc(MaxLeaf);
120         if(len == nil)
121                 return FlateNoMem;
122
123         /* static Litlen bit lengths */
124         for(i=0; i<144; i++)
125                 len[i] = 8;
126         for(i=144; i<256; i++)
127                 len[i] = 9;
128         for(i=256; i<280; i++)
129                 len[i] = 7;
130         for(i=280; i<Nlitlen; i++)
131                 len[i] = 8;
132
133         if(!hufftab(&litlentab, len, Nlitlen, MaxFlatBits))
134                 return FlateInternal;
135
136         /* static Offset bit lengths */
137         for(i=0; i<Noff; i++)
138                 len[i] = 5;
139
140         if(!hufftab(&offtab, len, Noff, MaxFlatBits))
141                 return FlateInternal;
142         free(len);
143
144         return FlateOk;
145 }
146
147 int
148 inflate(void *wr, int (*w)(void*, void*, int), void *getr, int (*get)(void*))
149 {
150         History *his;
151         Input in;
152         int final, type;
153
154         his = malloc(sizeof(History));
155         if(his == nil)
156                 return FlateNoMem;
157         his->cp = his->his;
158         his->full = 0;
159         in.getr = getr;
160         in.get = get;
161         in.wr = wr;
162         in.w = w;
163         in.nbits = 0;
164         in.sreg = 0;
165         in.error = FlateOk;
166
167         do {
168                 if(!sregfill(&in, 3))
169                         goto bad;
170                 final = in.sreg & 0x1;
171                 type = (in.sreg>>1) & 0x3;
172                 in.sreg >>= 3;
173                 in.nbits -= 3;
174                 switch(type) {
175                 default:
176                         in.error = FlateCorrupted;
177                         goto bad;
178                 case 0:
179                         /* uncompressed */
180                         if(!uncblock(&in, his))
181                                 goto bad;
182                         break;
183                 case 1:
184                         /* fixed huffman */
185                         if(!fixedblock(&in, his))
186                                 goto bad;
187                         break;
188                 case 2:
189                         /* dynamic huffman */
190                         if(!dynamicblock(&in, his))
191                                 goto bad;
192                         break;
193                 }
194         } while(!final);
195
196         if(his->cp != his->his && (*w)(wr, his->his, his->cp - his->his) != his->cp - his->his) {
197                 in.error = FlateOutputFail;
198                 goto bad;
199         }
200
201         if(!sregunget(&in))
202                 goto bad;
203
204         free(his);
205         if(in.error != FlateOk)
206                 return FlateInternal;
207         return FlateOk;
208
209 bad:
210         free(his);
211         if(in.error == FlateOk)
212                 return FlateInternal;
213         return in.error;
214 }
215
216 static int
217 uncblock(Input *in, History *his)
218 {
219         int len, nlen, c;
220         uchar *hs, *hp, *he;
221
222         if(!sregunget(in))
223                 return 0;
224         len = (*in->get)(in->getr);
225         len |= (*in->get)(in->getr)<<8;
226         nlen = (*in->get)(in->getr);
227         nlen |= (*in->get)(in->getr)<<8;
228         if(len != (~nlen&0xffff)) {
229                 in->error = FlateCorrupted;
230                 return 0;
231         }
232
233         hp = his->cp;
234         hs = his->his;
235         he = hs + HistorySize;
236
237         while(len > 0) {
238                 c = (*in->get)(in->getr);
239                 if(c < 0)
240                         return 0;
241                 *hp++ = c;
242                 if(hp == he) {
243                         his->full = 1;
244                         if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
245                                 in->error = FlateOutputFail;
246                                 return 0;
247                         }
248                         hp = hs;
249                 }
250                 len--;
251         }
252
253         his->cp = hp;
254
255         return 1;
256 }
257
258 static int
259 fixedblock(Input *in, History *his)
260 {
261         return decode(in, his, &litlentab, &offtab);
262 }
263
264 static int
265 dynamicblock(Input *in, History *his)
266 {
267         Huff *lentab, *offtab;
268         char *len;
269         int i, j, n, c, nlit, ndist, nclen, res, nb;
270
271         if(!sregfill(in, 14))
272                 return 0;
273         nlit = (in->sreg&0x1f) + 257;
274         ndist = ((in->sreg>>5) & 0x1f) + 1;
275         nclen = ((in->sreg>>10) & 0xf) + 4;
276         in->sreg >>= 14;
277         in->nbits -= 14;
278
279         if(nlit > Nlitlen || ndist > Noff || nlit < 257) {
280                 in->error = FlateCorrupted;
281                 return 0;
282         }
283
284         /* huff table header */
285         len = malloc(Nlitlen+Noff);
286         lentab = malloc(sizeof(Huff));
287         offtab = malloc(sizeof(Huff));
288         if(len == nil || lentab == nil || offtab == nil){
289                 in->error = FlateNoMem;
290                 goto bad;
291         }
292         for(i=0; i < Nclen; i++)
293                 len[i] = 0;
294         for(i=0; i<nclen; i++) {
295                 if(!sregfill(in, 3))
296                         goto bad;
297                 len[clenorder[i]] = in->sreg & 0x7;
298                 in->sreg >>= 3;
299                 in->nbits -= 3;
300         }
301
302         if(!hufftab(lentab, len, Nclen, ClenBits)){
303                 in->error = FlateCorrupted;
304                 goto bad;
305         }
306
307         n = nlit+ndist;
308         for(i=0; i<n;) {
309                 nb = lentab->minbits;
310                 for(;;){
311                         if(in->nbits<nb && !sregfill(in, nb))
312                                 goto bad;
313                         c = lentab->flat[in->sreg & lentab->flatmask];
314                         nb = c & 0xff;
315                         if(nb > in->nbits){
316                                 if(nb != 0xff)
317                                         continue;
318                                 c = hdecsym(in, lentab, c);
319                                 if(c < 0)
320                                         goto bad;
321                         }else{
322                                 c >>= 8;
323                                 in->sreg >>= nb;
324                                 in->nbits -= nb;
325                         }
326                         break;
327                 }
328
329                 if(c < 16) {
330                         j = 1;
331                 } else if(c == 16) {
332                         if(in->nbits<2 && !sregfill(in, 2))
333                                 goto bad;
334                         j = (in->sreg&0x3)+3;
335                         in->sreg >>= 2;
336                         in->nbits -= 2;
337                         if(i == 0) {
338                                 in->error = FlateCorrupted;
339                                 goto bad;
340                         }
341                         c = len[i-1];
342                 } else if(c == 17) {
343                         if(in->nbits<3 && !sregfill(in, 3))
344                                 goto bad;
345                         j = (in->sreg&0x7)+3;
346                         in->sreg >>= 3;
347                         in->nbits -= 3;
348                         c = 0;
349                 } else if(c == 18) {
350                         if(in->nbits<7 && !sregfill(in, 7))
351                                 goto bad;
352                         j = (in->sreg&0x7f)+11;
353                         in->sreg >>= 7;
354                         in->nbits -= 7;
355                         c = 0;
356                 } else {
357                         in->error = FlateCorrupted;
358                         goto bad;
359                 }
360
361                 if(i+j > n) {
362                         in->error = FlateCorrupted;
363                         goto bad;
364                 }
365
366                 while(j) {
367                         len[i] = c;
368                         i++;
369                         j--;
370                 }
371         }
372
373         if(!hufftab(lentab, len, nlit, LitlenBits)
374         || !hufftab(offtab, &len[nlit], ndist, OffBits)){
375                 in->error = FlateCorrupted;
376                 goto bad;
377         }
378
379         res = decode(in, his, lentab, offtab);
380
381         free(len);
382         free(lentab);
383         free(offtab);
384
385         return res;
386
387 bad:
388         free(len);
389         free(lentab);
390         free(offtab);
391         return 0;
392 }
393
394 static int
395 decode(Input *in, History *his, Huff *litlentab, Huff *offtab)
396 {
397         int len, off;
398         uchar *hs, *hp, *hq, *he;
399         int c;
400         int nb;
401
402         hs = his->his;
403         he = hs + HistorySize;
404         hp = his->cp;
405
406         for(;;) {
407                 nb = litlentab->minbits;
408                 for(;;){
409                         if(in->nbits<nb && !sregfill(in, nb))
410                                 return 0;
411                         c = litlentab->flat[in->sreg & litlentab->flatmask];
412                         nb = c & 0xff;
413                         if(nb > in->nbits){
414                                 if(nb != 0xff)
415                                         continue;
416                                 c = hdecsym(in, litlentab, c);
417                                 if(c < 0)
418                                         return 0;
419                         }else{
420                                 c >>= 8;
421                                 in->sreg >>= nb;
422                                 in->nbits -= nb;
423                         }
424                         break;
425                 }
426
427                 if(c < 256) {
428                         /* literal */
429                         *hp++ = c;
430                         if(hp == he) {
431                                 his->full = 1;
432                                 if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
433                                         in->error = FlateOutputFail;
434                                         return 0;
435                                 }
436                                 hp = hs;
437                         }
438                         continue;
439                 }
440
441                 if(c == 256)
442                         break;
443
444                 if(c > 285) {
445                         in->error = FlateCorrupted;
446                         return 0;
447                 }
448
449                 c -= 257;
450                 nb = litlenextra[c];
451                 if(in->nbits < nb && !sregfill(in, nb))
452                         return 0;
453                 len = litlenbase[c] + (in->sreg & ((1<<nb)-1));
454                 in->sreg >>= nb;
455                 in->nbits -= nb;
456
457                 /* get offset */
458                 nb = offtab->minbits;
459                 for(;;){
460                         if(in->nbits<nb && !sregfill(in, nb))
461                                 return 0;
462                         c = offtab->flat[in->sreg & offtab->flatmask];
463                         nb = c & 0xff;
464                         if(nb > in->nbits){
465                                 if(nb != 0xff)
466                                         continue;
467                                 c = hdecsym(in, offtab, c);
468                                 if(c < 0)
469                                         return 0;
470                         }else{
471                                 c >>= 8;
472                                 in->sreg >>= nb;
473                                 in->nbits -= nb;
474                         }
475                         break;
476                 }
477
478                 if(c > 29) {
479                         in->error = FlateCorrupted;
480                         return 0;
481                 }
482
483                 nb = offextra[c];
484                 if(in->nbits < nb && !sregfill(in, nb))
485                         return 0;
486
487                 off = offbase[c] + (in->sreg & ((1<<nb)-1));
488                 in->sreg >>= nb;
489                 in->nbits -= nb;
490
491                 hq = hp - off;
492                 if(hq < hs) {
493                         if(!his->full) {
494                                 in->error = FlateCorrupted;
495                                 return 0;
496                         }
497                         hq += HistorySize;
498                 }
499
500                 /* slow but correct */
501                 while(len) {
502                         *hp = *hq;
503                         hq++;
504                         hp++;
505                         if(hq >= he)
506                                 hq = hs;
507                         if(hp == he) {
508                                 his->full = 1;
509                                 if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
510                                         in->error = FlateOutputFail;
511                                         return 0;
512                                 }
513                                 hp = hs;
514                         }
515                         len--;
516                 }
517
518         }
519
520         his->cp = hp;
521
522         return 1;
523 }
524
525 static int
526 revcode(int c, int b)
527 {
528         /* shift encode up so it starts on bit 15 then reverse */
529         c <<= (16-b);
530         c = revtab[c>>8] | (revtab[c&0xff]<<8);
531         return c;
532 }
533
534 /*
535  * construct the huffman decoding arrays and a fast lookup table.
536  * the fast lookup is a table indexed by the next flatbits bits,
537  * which returns the symbol matched and the number of bits consumed,
538  * or the minimum number of bits needed and 0xff if more than flatbits
539  * bits are needed.
540  *
541  * flatbits can be longer than the smallest huffman code,
542  * because shorter codes are assigned smaller lexical prefixes.
543  * this means assuming zeros for the next few bits will give a
544  * conservative answer, in the sense that it will either give the
545  * correct answer, or return the minimum number of bits which
546  * are needed for an answer.
547  */
548 static int
549 hufftab(Huff *h, char *hb, int maxleaf, int flatbits)
550 {
551         ulong bitcount[MaxHuffBits];
552         ulong c, fc, ec, mincode, code, nc[MaxHuffBits];
553         int i, b, minbits, maxbits;
554
555         for(i = 0; i < MaxHuffBits; i++)
556                 bitcount[i] = 0;
557         maxbits = -1;
558         minbits = MaxHuffBits + 1;
559         for(i=0; i < maxleaf; i++){
560                 b = hb[i];
561                 if(b){
562                         bitcount[b]++;
563                         if(b < minbits)
564                                 minbits = b;
565                         if(b > maxbits)
566                                 maxbits = b;
567                 }
568         }
569         if(maxbits <= 0){
570                 h->maxbits = 0;
571                 h->minbits = 0;
572                 h->flatmask = 0;
573                 h->maxleaf = 0;
574                 return 1;
575         }
576         h->maxbits = maxbits;
577         if(maxbits >= MaxHuffBits || minbits <= 0)
578                 return 0;
579         code = 0;
580         c = 0;
581         for(b = 0; b <= maxbits; b++){
582                 h->last[b] = c;
583                 c += bitcount[b];
584                 mincode = code << 1;
585                 nc[b] = mincode;
586                 code = mincode + bitcount[b];
587                 if(code > (1 << b))
588                         return 0;
589                 h->maxcode[b] = code - 1;
590                 h->last[b] += code - 1;
591         }
592
593         if(flatbits > maxbits)
594                 flatbits = maxbits;
595         h->flatmask = (1 << flatbits) - 1;
596         if(minbits > flatbits)
597                 minbits = flatbits;
598         h->minbits = minbits;
599
600         b = 1 << flatbits;
601         for(i = 0; i < b; i++)
602                 h->flat[i] = ~0;
603
604         /*
605          * initialize the flat table to include the minimum possible
606          * bit length for each code prefix
607          */
608         for(b = maxbits; b > flatbits; b--){
609                 code = h->maxcode[b];
610                 if(code == -1)
611                         break;
612                 mincode = code + 1 - bitcount[b];
613                 mincode >>= b - flatbits;
614                 code >>= b - flatbits;
615                 for(; mincode <= code; mincode++)
616                         h->flat[revcode(mincode, flatbits)] = (b << 8) | 0xff;
617         }
618
619         h->maxleaf = maxleaf;
620         for(i = 0; i < maxleaf; i++){
621                 b = hb[i];
622                 if(b <= 0)
623                         continue;
624                 c = nc[b]++;
625                 if(b <= flatbits){
626                         code = (i << 8) | b;
627                         ec = (c + 1) << (flatbits - b);
628                         if(ec > (1<<flatbits))
629                                 return 0;       /* this is actually an internal error */
630                         for(fc = c << (flatbits - b); fc < ec; fc++)
631                                 h->flat[revcode(fc, flatbits)] = code;
632                 }
633                 if(b > minbits){
634                         c = h->last[b] - c;
635                         if(c >= maxleaf)
636                                 return 0;
637                         h->decode[c] = i;
638                 }
639         }
640         return 1;
641 }
642
643 static int
644 hdecsym(Input *in, Huff *h, int nb)
645 {
646         ulong c;
647
648         if((nb & 0xff) == 0xff)
649                 nb = nb >> 8;
650         else
651                 nb = nb & 0xff;
652         for(; nb <= h->maxbits; nb++){
653                 if(in->nbits<nb && !sregfill(in, nb))
654                         return -1;
655                 c = revtab[in->sreg&0xff]<<8;
656                 c |= revtab[(in->sreg>>8)&0xff];
657                 c >>= (16-nb);
658                 if(c <= h->maxcode[nb]){
659                         c = h->last[nb] - c;
660                         if(c >= h->maxleaf)
661                                 break;
662                         in->sreg >>= nb;
663                         in->nbits -= nb;
664                         return h->decode[c];
665                 }
666         }
667         in->error = FlateCorrupted;
668         return -1;
669 }
670
671 static int
672 sregfill(Input *in, int n)
673 {
674         int c;
675
676         while(n > in->nbits) {
677                 c = (*in->get)(in->getr);
678                 if(c < 0){
679                         in->error = FlateInputFail;
680                         return 0;
681                 }
682                 in->sreg |= c<<in->nbits;
683                 in->nbits += 8;
684         }
685         return 1;
686 }
687
688 static int
689 sregunget(Input *in)
690 {
691         if(in->nbits >= 8) {
692                 in->error = FlateInternal;
693                 return 0;
694         }
695
696         /* throw other bits on the floor */
697         in->nbits = 0;
698         in->sreg = 0;
699         return 1;
700 }