]> git.lizzy.rs Git - plan9front.git/blob - sys/src/9/ip/ipmux.c
audiohda: fix syntax error
[plan9front.git] / sys / src / 9 / ip / ipmux.c
1 /*
2  * IP packet filter
3  */
4 #include "u.h"
5 #include "../port/lib.h"
6 #include "mem.h"
7 #include "dat.h"
8 #include "fns.h"
9 #include "../port/error.h"
10
11 #include "ip.h"
12 #include "ipv6.h"
13
14 typedef struct Ipmuxrock  Ipmuxrock;
15 typedef struct Ipmux      Ipmux;
16
17 enum
18 {
19         Tver,
20         Tproto,
21         Tdata,
22         Tiph,
23         Tdst,
24         Tsrc,
25         Tifc,
26 };
27
28 /*
29  *  a node in the decision tree
30  */
31 struct Ipmux
32 {
33         Ipmux   *yes;
34         Ipmux   *no;
35         uchar   type;           /* type of field(Txxxx) */
36         uchar   len;            /* length in bytes of item to compare */
37         uchar   n;              /* number of items val points to */
38         int     off;            /* offset of comparison */
39         uchar   *val;
40         uchar   *mask;
41         uchar   *e;             /* val+n*len*/
42         int     ref;            /* so we can garbage collect */
43         Conv    *conv;
44 };
45
46 /*
47  *  someplace to hold per conversation data
48  */
49 struct Ipmuxrock
50 {
51         Ipmux   *chain;
52 };
53
54 static int      ipmuxsprint(Ipmux*, int, char*, int);
55 static void     ipmuxkick(void *x);
56 static void     ipmuxfree(Ipmux *f);
57
/*
 *  return p advanced past any leading blanks and tabs
 */
static char*
skipwhite(char *p)
{
	for(; *p == ' ' || *p == '\t'; p++)
		;
	return p;
}
65
66 static char*
67 follows(char *p, char c)
68 {
69         char *f;
70
71         f = strchr(p, c);
72         if(f == nil)
73                 return nil;
74         *f++ = 0;
75         f = skipwhite(f);
76         if(*f == 0)
77                 return nil;
78         return f;
79 }
80
81 static Ipmux*
82 parseop(char **pp)
83 {
84         char *p = *pp;
85         int type, off, end, len;
86         Ipmux *f;
87
88         p = skipwhite(p);
89         if(strncmp(p, "ver", 3) == 0){
90                 type = Tver;
91                 off = 0;
92                 len = 1;
93                 p += 3;
94         }
95         else if(strncmp(p, "dst", 3) == 0){
96                 type = Tdst;
97                 off = offsetof(Ip6hdr, dst[0]);
98                 len = IPaddrlen;
99                 p += 3;
100         }
101         else if(strncmp(p, "src", 3) == 0){
102                 type = Tsrc;
103                 off = offsetof(Ip6hdr, src[0]);
104                 len = IPaddrlen;
105                 p += 3;
106         }
107         else if(strncmp(p, "ifc", 3) == 0){
108                 type = Tifc;
109                 off = -IPaddrlen;
110                 len = IPaddrlen;
111                 p += 3;
112         }
113         else if(strncmp(p, "proto", 5) == 0){
114                 type = Tproto;
115                 off = offsetof(Ip6hdr, proto);
116                 len = 1;
117                 p += 5;
118         }
119         else if(strncmp(p, "data", 4) == 0 || strncmp(p, "iph", 3) == 0){
120                 if(strncmp(p, "data", 4) == 0) {
121                         type = Tdata;
122                         p += 4;
123                 }
124                 else {
125                         type = Tiph;
126                         p += 3;
127                 }
128                 p = skipwhite(p);
129                 if(*p != '[')
130                         return nil;
131                 p++;
132                 off = strtoul(p, &p, 0);
133                 if(off < 0)
134                         return nil;
135                 p = skipwhite(p);
136                 if(*p != ':')
137                         end = off;
138                 else {
139                         p++;
140                         p = skipwhite(p);
141                         end = strtoul(p, &p, 0);
142                         if(end < off)
143                                 return nil;
144                         p = skipwhite(p);
145                 }
146                 if(*p != ']')
147                         return nil;
148                 p++;
149                 len = end - off + 1;
150         }
151         else
152                 return nil;
153
154         f = smalloc(sizeof(*f));
155         f->type = type;
156         f->len = len;
157         f->off = off;
158         f->val = nil;
159         f->mask = nil;
160         f->n = 1;
161         f->ref = 1;
162         return f;       
163 }
164
/*
 *  value of one hex digit; anything that is not a hex digit counts as 0
 */
static int
htoi(char x)
{
	if('0' <= x && x <= '9')
		return x - '0';
	if('a' <= x && x <= 'f')
		return x - 'a' + 10;
	if('A' <= x && x <= 'F')
		return x - 'A' + 10;
	return 0;
}

/*
 *  value of the two hex digits at p (high nibble first)
 */
static int
hextoi(char *p)
{
	return htoi(p[0])*16 + htoi(p[1]);
}
184
185 static void
186 parseval(uchar *v, char *p, int len)
187 {
188         while(*p && len-- > 0){
189                 *v++ = hextoi(p);
190                 p += 2;
191         }
192 }
193
194 static Ipmux*
195 parsemux(char *p)
196 {
197         int n;
198         Ipmux *f;
199         char *val;
200         char *mask;
201         char *vals[20];
202         uchar *v;
203
204         /* parse operand */
205         f = parseop(&p);
206         if(f == nil)
207                 return nil;
208
209         /* find value */
210         val = follows(p, '=');
211         if(val == nil)
212                 goto parseerror;
213
214         /* parse mask */
215         mask = follows(p, '&');
216         if(mask != nil){
217                 switch(f->type){
218                 case Tsrc:
219                 case Tdst:
220                 case Tifc:
221                         f->mask = smalloc(f->len);
222                         parseipmask(f->mask, mask, 0);
223                         break;
224                 case Tdata:
225                 case Tiph:
226                         f->mask = smalloc(f->len);
227                         parseval(f->mask, mask, f->len);
228                         break;
229                 default:
230                         goto parseerror;
231                 }
232         } else if(f->type == Tver){
233                 f->mask = smalloc(f->len);
234                 f->mask[0] = 0xF0;
235         }
236
237         /* parse vals */
238         f->n = getfields(val, vals, nelem(vals), 1, "|");
239         if(f->n == 0)
240                 goto parseerror;
241         f->val = smalloc(f->n*f->len);
242         v = f->val;
243         for(n = 0; n < f->n; n++){
244                 switch(f->type){
245                 case Tver:
246                         if(f->n != 1)
247                                 goto parseerror;
248                         if(strcmp(vals[n], "6") == 0)
249                                 *v = IP_VER6;
250                         else if(strcmp(vals[n], "4") == 0)
251                                 *v = IP_VER4;
252                         else
253                                 goto parseerror;
254                         break;
255                 case Tsrc:
256                 case Tdst:
257                 case Tifc:
258                         if(parseip(v, vals[n]) == -1)
259                                 goto parseerror;
260                         break;
261                 case Tproto:
262                 case Tdata:
263                 case Tiph:
264                         parseval(v, vals[n], f->len);
265                         break;
266                 }
267                 v += f->len;
268         }
269         f->e = f->val + f->n*f->len;
270         return f;
271
272 parseerror:
273         ipmuxfree(f);
274         return nil;
275 }
276
277 /*
278  *  Compare relative ordering of two ipmuxs.  This doesn't compare the
279  *  values, just the fields being looked at.  
280  *
281  *  returns:    <0 if a is a more specific match
282  *               0 if a and b are matching on the same fields
283  *              >0 if b is a more specific match
284  */
285 static int
286 ipmuxcmp(Ipmux *a, Ipmux *b)
287 {
288         int n;
289
290         /* compare types, lesser ones are more important */
291         n = a->type - b->type;
292         if(n != 0)
293                 return n;
294
295         /* compare offsets, call earlier ones more specific */
296         n = a->off - b->off;
297         if(n != 0)
298                 return n;
299
300         /* compare match lengths, longer ones are more specific */
301         n = b->len - a->len;
302         if(n != 0)
303                 return n;
304
305         /*
306          *  if we get here we have two entries matching
307          *  the same bytes of the record.  Now check
308          *  the mask for equality.  Longer masks are
309          *  more specific.
310          */
311         if(a->mask != nil && b->mask == nil)
312                 return -1;
313         if(a->mask == nil && b->mask != nil)
314                 return 1;
315         if(a->mask != nil && b->mask != nil){
316                 n = memcmp(b->mask, a->mask, a->len);
317                 if(n != 0)
318                         return n;
319         }
320         return 0;
321 }
322
323 /*
324  *  Compare the values of two ipmuxs.  We're assuming that ipmuxcmp
325  *  returned 0 comparing them.
326  */
327 static int
328 ipmuxvalcmp(Ipmux *a, Ipmux *b)
329 {
330         int n;
331
332         n = b->len*b->n - a->len*a->n;
333         if(n != 0)
334                 return n;
335         return memcmp(a->val, b->val, a->len*a->n);
336
337
338 /*
339  *  add onto an existing ipmux chain in the canonical comparison
340  *  order
341  */
342 static void
343 ipmuxchain(Ipmux **l, Ipmux *f)
344 {
345         for(; *l; l = &(*l)->yes)
346                 if(ipmuxcmp(f, *l) < 0)
347                         break;
348         f->yes = *l;
349         *l = f;
350 }
351
352 /*
353  *  copy a tree
354  */
355 static Ipmux*
356 ipmuxcopy(Ipmux *f)
357 {
358         Ipmux *nf;
359
360         if(f == nil)
361                 return nil;
362         nf = smalloc(sizeof *nf);
363         *nf = *f;
364         nf->no = ipmuxcopy(f->no);
365         nf->yes = ipmuxcopy(f->yes);
366         if(f->mask != nil){
367                 nf->mask = smalloc(f->len);
368                 memmove(nf->mask, f->mask, f->len);
369         }
370         nf->val = smalloc(f->n*f->len);
371         nf->e = nf->val + f->len*f->n;
372         memmove(nf->val, f->val, f->n*f->len);
373         return nf;
374 }
375
376 static void
377 ipmuxfree(Ipmux *f)
378 {
379         if(f == nil)
380                 return;
381         free(f->val);
382         free(f->mask);
383         free(f);
384 }
385
386 static void
387 ipmuxtreefree(Ipmux *f)
388 {
389         if(f == nil)
390                 return;
391         ipmuxfree(f->no);
392         ipmuxfree(f->yes);
393         ipmuxfree(f);
394 }
395
396 /*
397  *  merge two trees
398  */
399 static Ipmux*
400 ipmuxmerge(Ipmux *a, Ipmux *b)
401 {
402         int n;
403         Ipmux *f;
404
405         if(a == nil)
406                 return b;
407         if(b == nil)
408                 return a;
409         n = ipmuxcmp(a, b);
410         if(n < 0){
411                 f = ipmuxcopy(b);
412                 a->yes = ipmuxmerge(a->yes, b);
413                 a->no = ipmuxmerge(a->no, f);
414                 return a;
415         }
416         if(n > 0){
417                 f = ipmuxcopy(a);
418                 b->yes = ipmuxmerge(b->yes, a);
419                 b->no = ipmuxmerge(b->no, f);
420                 return b;
421         }
422         if(ipmuxvalcmp(a, b) == 0){
423                 a->yes = ipmuxmerge(a->yes, b->yes);
424                 a->no = ipmuxmerge(a->no, b->no);
425                 a->ref++;
426                 ipmuxfree(b);
427                 return a;
428         }
429         a->no = ipmuxmerge(a->no, b);
430         return a;
431 }
432
433 /*
434  *  remove a chain from a demux tree.  This is like merging accept that
435  *  we remove instead of insert.
436  */
437 static int
438 ipmuxremove(Ipmux **l, Ipmux *f)
439 {
440         int n, rv;
441         Ipmux *ft;
442
443         if(f == nil)
444                 return 0;               /* we've removed it all */
445         if(*l == nil)
446                 return -1;
447
448         ft = *l;
449         n = ipmuxcmp(ft, f);
450         if(n < 0){
451                 /* *l is maching an earlier field, descend both paths */
452                 rv = ipmuxremove(&ft->yes, f);
453                 rv += ipmuxremove(&ft->no, f);
454                 return rv;
455         }
456         if(n > 0){
457                 /* f represents an earlier field than *l, this should be impossible */
458                 return -1;
459         }
460
461         /* if we get here f and *l are comparing the same fields */
462         if(ipmuxvalcmp(ft, f) != 0){
463                 /* different values mean mutually exclusive */
464                 return ipmuxremove(&ft->no, f);
465         }
466
467         ipmuxremove(&ft->no, f->no);
468
469         /* we found a match */
470         if(--(ft->ref) == 0){
471                 /*
472                  *  a dead node implies the whole yes side is also dead.
473                  *  since our chain is constrained to be on that side,
474                  *  we're done.
475                  */
476                 ipmuxtreefree(ft->yes);
477                 *l = ft->no;
478                 ipmuxfree(ft);
479                 return 0;
480         }
481
482         /*
483          *  free the rest of the chain.  it is constrained to match the
484          *  yes side.
485          */
486         return ipmuxremove(&ft->yes, f->yes);
487 }
488
489 /*
490  * convert to ipv4 filter
491  */
492 static Ipmux*
493 ipmuxconv4(Ipmux *f)
494 {
495         int i, n;
496
497         if(f == nil)
498                 return nil;
499
500         switch(f->type){
501         case Tproto:
502                 f->off = offsetof(Ip4hdr, proto);
503                 break;
504         case Tdst:
505                 f->off = offsetof(Ip4hdr, dst[0]);
506                 if(0){
507         case Tsrc:
508                 f->off = offsetof(Ip4hdr, src[0]);
509                 }
510                 if(f->len != IPaddrlen)
511                         break;
512                 n = 0;
513                 for(i = 0; i < f->n; i++){
514                         if(isv4(f->val + i*IPaddrlen)){
515                                 memmove(f->val + n*IPv4addrlen, f->val + i*IPaddrlen + IPv4off, IPv4addrlen);
516                                 n++;
517                         }
518                 }
519                 if(n == 0){
520                         ipmuxtreefree(f);
521                         return nil;
522                 }
523                 f->n = n;
524                 f->len = IPv4addrlen;
525                 if(f->mask != nil)
526                         memmove(f->mask, f->mask+IPv4off, IPv4addrlen);
527         }
528         f->e = f->val + f->n*f->len;
529
530         f->yes = ipmuxconv4(f->yes);
531         f->no = ipmuxconv4(f->no);
532
533         return f;
534 }
535
536 /*
537  *  connection request is a semi separated list of filters
538  *  e.g. ver=4;proto=17;data[0:4]=11aa22bb;ifc=135.104.9.2&255.255.255.0
539  *
540  *  there's no protection against overlapping specs.
541  */
542 static char*
543 ipmuxconnect(Conv *c, char **argv, int argc)
544 {
545         int i, n;
546         char *field[10];
547         Ipmux *mux, *chain;
548         Ipmuxrock *r;
549         Fs *f;
550
551         f = c->p->f;
552
553         if(argc != 2)
554                 return Ebadarg;
555
556         n = getfields(argv[1], field, nelem(field), 1, ";");
557         if(n <= 0)
558                 return Ebadarg;
559
560         chain = nil;
561         mux = nil;
562         for(i = 0; i < n; i++){
563                 mux = parsemux(field[i]);
564                 if(mux == nil){
565                         ipmuxtreefree(chain);
566                         return Ebadarg;
567                 }
568                 ipmuxchain(&chain, mux);
569         }
570         if(chain == nil)
571                 return Ebadarg;
572         mux->conv = c;
573
574         if(chain->type != Tver) {
575                 char ver6[] = "ver=6";
576                 mux = parsemux(ver6);
577                 mux->yes = chain;
578                 mux->no = ipmuxcopy(chain);
579                 chain = mux;
580         }
581         if(*chain->val == IP_VER4)
582                 chain->yes = ipmuxconv4(chain->yes);
583         else
584                 chain->no = ipmuxconv4(chain->no);
585
586         /* save a copy of the chain so we can later remove it */
587         mux = ipmuxcopy(chain);
588         r = (Ipmuxrock*)(c->ptcl);
589         r->chain = chain;
590
591         /* add the chain to the protocol demultiplexor tree */
592         wlock(f);
593         f->ipmux->priv = ipmuxmerge(f->ipmux->priv, mux);
594         wunlock(f);
595
596         Fsconnected(c, nil);
597         return nil;
598 }
599
600 static int
601 ipmuxstate(Conv *c, char *state, int n)
602 {
603         Ipmuxrock *r;
604         
605         r = (Ipmuxrock*)(c->ptcl);
606         return ipmuxsprint(r->chain, 0, state, n);
607 }
608
609 static void
610 ipmuxcreate(Conv *c)
611 {
612         Ipmuxrock *r;
613
614         c->rq = qopen(64*1024, Qmsg, 0, c);
615         c->wq = qopen(64*1024, Qkick, ipmuxkick, c);
616         r = (Ipmuxrock*)(c->ptcl);
617         r->chain = nil;
618 }
619
620 static char*
621 ipmuxannounce(Conv*, char**, int)
622 {
623         return "ipmux does not support announce";
624 }
625
626 static void
627 ipmuxclose(Conv *c)
628 {
629         Ipmuxrock *r;
630         Fs *f = c->p->f;
631
632         r = (Ipmuxrock*)(c->ptcl);
633
634         qclose(c->rq);
635         qclose(c->wq);
636         qclose(c->eq);
637         ipmove(c->laddr, IPnoaddr);
638         ipmove(c->raddr, IPnoaddr);
639         c->lport = 0;
640         c->rport = 0;
641
642         wlock(f);
643         ipmuxremove(&(c->p->priv), r->chain);
644         wunlock(f);
645         ipmuxtreefree(r->chain);
646         r->chain = nil;
647 }
648
649 /*
650  *  takes a fully formed ip packet and just passes it down
651  *  the stack
652  */
653 static void
654 ipmuxkick(void *x)
655 {
656         Conv *c = x;
657         Block *bp;
658
659         bp = qget(c->wq);
660         if(bp != nil) {
661                 Ip4hdr *ih4 = (Ip4hdr*)(bp->rp);
662
663                 if((ih4->vihl & 0xF0) != IP_VER6)
664                         ipoput4(c->p->f, bp, 0, ih4->ttl, ih4->tos, nil);
665                 else
666                         ipoput6(c->p->f, bp, 0, ((Ip6hdr*)ih4)->ttl, 0, nil);
667         }
668 }
669
670 static int
671 maskmemcmp(uchar *m, uchar *v, uchar *c, int n)
672 {
673         int i;
674
675         if(m == nil)
676                 return memcmp(v, c, n) != 0;
677
678         for(i = 0; i < n; i++)
679                 if((v[i] & m[i]) != c[i])
680                         return 1;
681         return 0;
682 }
683
684 static void
685 ipmuxiput(Proto *p, Ipifc *ifc, Block *bp)
686 {
687         Fs *f = p->f;
688         Conv *c;
689         Iplifc *lifc;
690         Ipmux *mux;
691         uchar *v;
692         Ip4hdr *ip4;
693         Ip6hdr *ip6;
694         int off, hl;
695
696         ip4 = (Ip4hdr*)bp->rp;
697         if((ip4->vihl & 0xF0) == IP_VER4) {
698                 hl = (ip4->vihl&0x0F)<<2;
699                 ip6 = nil;
700         } else {
701                 hl = IP6HDR;
702                 ip6 = (Ip6hdr*)ip4;
703         }
704
705         if(p->priv == nil)
706                 goto nomatch;
707
708         c = nil;
709         lifc = nil;
710
711         /* run the filter */
712         rlock(f);
713         mux = f->ipmux->priv;
714         while(mux != nil){
715                 switch(mux->type){
716                 case Tifc:
717                         if(mux->len != IPaddrlen)
718                                 goto no;
719                         for(lifc = ifc->lifc; lifc != nil; lifc = lifc->next)
720                                 for(v = mux->val; v < mux->e; v += IPaddrlen)
721                                         if(maskmemcmp(mux->mask, lifc->local, v, IPaddrlen) == 0)
722                                                 goto yes;
723                         goto no;
724                 case Tdata:
725                         off = hl;
726                         break;
727                 default:
728                         off = 0;
729                         break;
730                 }
731                 off += mux->off;
732                 if(off < 0 || off + mux->len > BLEN(bp))
733                         goto no;
734                 for(v = mux->val; v < mux->e; v += mux->len)
735                         if(maskmemcmp(mux->mask, bp->rp + off, v, mux->len) == 0)
736                                 goto yes;
737 no:
738                 mux = mux->no;
739                 continue;
740 yes:
741                 if(mux->conv != nil)
742                         c = mux->conv;
743                 mux = mux->yes;
744         }
745         runlock(f);
746
747         if(c != nil){
748                 /* tack on interface address */
749                 bp = padblock(bp, IPaddrlen);
750                 if(lifc == nil)
751                         lifc = ifc->lifc;
752                 ipmove(bp->rp, lifc != nil ? lifc->local : IPnoaddr);
753                 qpass(c->rq, concatblock(bp));
754                 return;
755         }
756
757 nomatch:
758         /* doesn't match any filter, hand it to the specific protocol handler */
759         if(ip6 != nil)
760                 p = f->t2p[ip6->proto];
761         else
762                 p = f->t2p[ip4->proto];
763         if(p != nil && p->rcv != nil){
764                 (*p->rcv)(p, ifc, bp);
765                 return;
766         }
767         freeblist(bp);
768 }
769
770 static int
771 ipmuxsprint(Ipmux *mux, int level, char *buf, int len)
772 {
773         int i, j, n;
774         uchar *v;
775
776         n = 0;
777         for(i = 0; i < level; i++)
778                 n += snprint(buf+n, len-n, " ");
779         if(mux == nil){
780                 n += snprint(buf+n, len-n, "\n");
781                 return n;
782         }
783         n += snprint(buf+n, len-n, "%s[%d:%d]", 
784                 mux->type == Tdata ? "data": "iph",
785                 mux->off, mux->off+mux->len-1);
786         if(mux->mask != nil){
787                 n += snprint(buf+n, len-n, "&");
788                 for(i = 0; i < mux->len; i++)
789                         n += snprint(buf+n, len - n, "%2.2ux", mux->mask[i]);
790         }
791         n += snprint(buf+n, len-n, "=");
792         v = mux->val;
793         for(j = 0; j < mux->n; j++){
794                 for(i = 0; i < mux->len; i++)
795                         n += snprint(buf+n, len - n, "%2.2ux", *v++);
796                 n += snprint(buf+n, len-n, "|");
797         }
798         n += snprint(buf+n, len-n, "\n");
799         level++;
800         n += ipmuxsprint(mux->no, level, buf+n, len-n);
801         n += ipmuxsprint(mux->yes, level, buf+n, len-n);
802         return n;
803 }
804
805 static int
806 ipmuxstats(Proto *p, char *buf, int len)
807 {
808         int n;
809         Fs *f = p->f;
810
811         rlock(f);
812         n = ipmuxsprint(p->priv, 0, buf, len);
813         runlock(f);
814
815         return n;
816 }
817
818 void
819 ipmuxinit(Fs *f)
820 {
821         Proto *ipmux;
822
823         ipmux = smalloc(sizeof(Proto));
824         ipmux->priv = nil;
825         ipmux->name = "ipmux";
826         ipmux->connect = ipmuxconnect;
827         ipmux->announce = ipmuxannounce;
828         ipmux->state = ipmuxstate;
829         ipmux->create = ipmuxcreate;
830         ipmux->close = ipmuxclose;
831         ipmux->rcv = ipmuxiput;
832         ipmux->ctl = nil;
833         ipmux->advise = nil;
834         ipmux->stats = ipmuxstats;
835         ipmux->ipproto = -1;
836         ipmux->nc = 64;
837         ipmux->ptclsize = sizeof(Ipmuxrock);
838
839         f->ipmux = ipmux;                       /* hack for Fsrcvpcol */
840
841         Fsproto(f, ipmux);
842 }