]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ip/pppoe.c
Import sources from 2011-03-30 iso image
[plan9front.git] / sys / src / cmd / ip / pppoe.c
1 /*
2  * User-level PPP over Ethernet (PPPoE) client.
3  * See RFC 2516
4  */
5
6 #include <u.h>
7 #include <libc.h>
8 #include <ip.h>
9
10 void dumppkt(uchar*);
11 uchar *findtag(uchar*, int, int*, int);
12 void hexdump(uchar*, int);
13 int malformed(uchar*, int, int);
14 int pppoe(char*);
15 void execppp(int);
16
17 int alarmed;
18 int debug;
19 int sessid;
20 char *keyspec;
21 int primary;
22 char *pppnetmtpt;
23 char *acname;
24 char *pppname = "/bin/ip/ppp";
25 char *srvname = "";
26 char *wantac;
27 uchar *cookie;
28 int cookielen;
29 uchar etherdst[6];
30 int mtu = 1492;
31
/* Print the command synopsis on standard error and exit with status "usage". */
void
usage(void)
{
	char *synopsis;

	synopsis = "usage: pppoe [-Pd] [-A acname] [-S srvname] [-k keyspec] [-m mtu] [-x pppnet] [ether0]\n";
	fprint(2, "%s", synopsis);
	exits("usage");
}
38
39 int
40 catchalarm(void *a, char *msg)
41 {
42         USED(a);
43
44         if(strstr(msg, "alarm")){
45                 alarmed = 1;
46                 return 1;
47         }
48         if(debug)
49                 fprint(2, "note rcved: %s\n", msg);
50         return 0;
51 }
52
53 void
54 main(int argc, char **argv)
55 {
56         int fd;
57         char *dev;
58
59         ARGBEGIN{
60         case 'A':
61                 wantac = EARGF(usage());
62                 break;
63         case 'P':
64                 primary = 1;
65                 break;
66         case 'S':
67                 srvname = EARGF(usage());
68                 break;
69         case 'd':
70                 debug++;
71                 break;
72         case 'm':
73                 mtu = atoi(EARGF(usage()));
74                 break;
75         case 'k':
76                 keyspec = EARGF(usage());
77                 break;
78         case 'x':
79                 pppnetmtpt = EARGF(usage());
80                 break;
81         default:
82                 usage();
83         }ARGEND
84
85         switch(argc){
86         default:
87                 usage();
88         case 0:
89                 dev = "ether0";
90                 break;
91         case 1:
92                 dev = argv[0];
93                 break;
94         }
95
96         fmtinstall('E', eipfmt);
97
98         atnotify(catchalarm, 1);
99         fd = pppoe(dev);
100         execppp(fd);
101 }
102
103 typedef struct Etherhdr Etherhdr;
104 struct Etherhdr {
105         uchar dst[6];
106         uchar src[6];
107         uchar type[2];
108 };
109
110 enum {
111         EtherHdrSz = 6+6+2,
112         EtherMintu = 60,
113
114         EtherPppoeDiscovery = 0x8863,
115         EtherPppoeSession = 0x8864,
116 };
117
118 uchar etherbcast[6] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF};
119
120 int
121 etherhdr(uchar *pkt, uchar *dst, int type)
122 {
123         Etherhdr *eh;
124
125         eh = (Etherhdr*)pkt;
126         memmove(eh->dst, dst, sizeof(eh->dst));
127         hnputs(eh->type, type);
128         return EtherHdrSz;
129 }
130
131 typedef struct Pppoehdr Pppoehdr;
132 struct Pppoehdr {
133         uchar verstype;
134         uchar code;
135         uchar sessid[2];
136         uchar length[2];        /* of payload */
137 };
138
139 enum {
140         PppoeHdrSz = 1+1+2+2,
141         Hdr = EtherHdrSz+PppoeHdrSz,
142 };
143
144 enum {
145         VersType = 0x11,
146
147         /* Discovery codes */
148         CodeDiscInit = 0x09,    /* discovery init */
149         CodeDiscOffer = 0x07,   /* discovery offer */
150         CodeDiscReq = 0x19,     /* discovery request */
151         CodeDiscSess = 0x65,    /* session confirmation */
152
153         /* Session codes */
154         CodeSession = 0x00,
155 };
156
157 int
158 pppoehdr(uchar *pkt, int code, int sessid)
159 {
160         Pppoehdr *ph;
161
162         ph = (Pppoehdr*)pkt;
163         ph->verstype = VersType;
164         ph->code = code;
165         hnputs(ph->sessid, sessid);
166         return PppoeHdrSz;
167 }
168
169 typedef struct Taghdr Taghdr;
170 struct Taghdr {
171         uchar type[2];
172         uchar length[2];        /* of value */
173 };
174
175 enum {
176         TagEnd = 0x0000,                /* end of tag list */
177         TagSrvName = 0x0101,    /* service name */
178         TagAcName = 0x0102,     /* access concentrator name */
179         TagHostUniq = 0x0103,   /* nonce */
180         TagAcCookie = 0x0104,   /* a.c. cookie */
181         TagVendSpec = 0x0105,   /* vendor specific */
182         TagRelaySessId = 0x0110,        /* relay session id */
183         TagSrvNameErr = 0x0201, /* service name error (ascii) */
184         TagAcSysErr = 0x0202,   /* a.c. system error */
185 };
186
187 int
188 tag(uchar *pkt, int type, void *value, int nvalue)
189 {
190         Taghdr *h;
191
192         h = (Taghdr*)pkt;
193         hnputs(h->type, type);
194         hnputs(h->length, nvalue);
195         memmove(pkt+4, value, nvalue);
196         return 4+nvalue;
197 }
198
199 /* PPPoE Active Discovery Initiation */
200 int
201 padi(uchar *pkt)
202 {
203         int sz, tagoff;
204         uchar *length;
205
206         sz = 0;
207         sz += etherhdr(pkt+sz, etherbcast, EtherPppoeDiscovery);
208         sz += pppoehdr(pkt+sz, CodeDiscInit, 0x0000);
209         length = pkt+sz-2;
210         tagoff = sz;
211         sz += tag(pkt+sz, TagSrvName, srvname, strlen(srvname));
212         hnputs(length, sz-tagoff);
213         return sz;
214 }
215
216 /* PPPoE Active Discovery Request */
217 int
218 padr(uchar *pkt)
219 {
220         int sz, tagoff;
221         uchar *length;
222
223         sz = 0;
224         sz += etherhdr(pkt+sz, etherdst, EtherPppoeDiscovery);
225         sz += pppoehdr(pkt+sz, CodeDiscReq, 0x0000);
226         length = pkt+sz-2;
227         tagoff = sz;
228         sz += tag(pkt+sz, TagSrvName, srvname, strlen(srvname));
229         sz += tag(pkt+sz, TagAcName, acname, strlen(acname));
230         if(cookie)
231                 sz += tag(pkt+sz, TagAcCookie, cookie, cookielen);
232         hnputs(length, sz-tagoff);
233         return sz;
234 }
235
236 void
237 ewrite(int fd, void *buf, int nbuf)
238 {
239         char e[ERRMAX], path[64];
240
241         if(write(fd, buf, nbuf) != nbuf){
242                 rerrstr(e, sizeof e);
243                 strcpy(path, "unknown");
244                 fd2path(fd, path, sizeof path);
245                 sysfatal("write %d to %s: %s", nbuf, path, e);
246         }
247 }
248
249 void*
250 emalloc(long n)
251 {
252         void *v;
253
254         v = malloc(n);
255         if(v == nil)
256                 sysfatal("out of memory");
257         return v;
258 }
259
260 int
261 aread(int timeout, int fd, void *buf, int nbuf)
262 {
263         int n;
264
265         alarmed = 0;
266         alarm(timeout);
267         n = read(fd, buf, nbuf);
268         alarm(0);
269         if(alarmed)
270                 return -1;
271         if(n < 0)
272                 sysfatal("read: %r");
273         if(n == 0)
274                 sysfatal("short read");
275         return n;
276 }
277
278 int
279 pktread(int timeout, int fd, void *buf, int nbuf, int (*want)(uchar*))
280 {
281         int n, t2;
282         n = -1;
283         for(t2=timeout; t2<16000; t2*=2){
284                 while((n = aread(t2, fd, buf, nbuf)) > 0){
285                         if(malformed(buf, n, EtherPppoeDiscovery)){
286                                 if(debug)
287                                         fprint(2, "dropping pkt: %r\n");
288                                 continue;
289                         }
290                         if(debug)
291                                 dumppkt(buf);
292                         if(!want(buf)){
293                                 if(debug)
294                                         fprint(2, "dropping unwanted pkt: %r\n");
295                                 continue;
296                         }
297                         break;
298                 }
299                 if(n > 0)
300                         break;
301         }
302         return n;
303 }
304
/*
 * Set the error string (retrievable with %r) to reason and return 0, so
 * that predicates can write `return bad("...")`.  reason is passed as an
 * argument rather than used as the format, so a '%' in it cannot be
 * interpreted as a format verb.
 */
int
bad(char *reason)
{
	werrstr("%s", reason);
	return 0;
}
311
312 void*
313 copy(uchar *s, int len)
314 {
315         uchar *v;
316
317         v = emalloc(len+1);
318         memmove(v, s, len);
319         v[len] = '\0';
320         return v;
321 }
322
323 void
324 clearstate(void)
325 {
326         sessid = -1;
327         free(acname);
328         acname = nil;
329         free(cookie);
330         cookie = nil;
331 }
332
333 int
334 wantoffer(uchar *pkt)
335 {
336         int i, len;
337         uchar *s;
338         Etherhdr *eh;
339         Pppoehdr *ph;
340
341         eh = (Etherhdr*)pkt;
342         ph = (Pppoehdr*)(pkt+EtherHdrSz);
343
344         if(ph->code != CodeDiscOffer)
345                 return bad("not an offer");
346         if(nhgets(ph->sessid) != 0x0000)
347                 return bad("bad session id");
348
349         for(i=0;; i++){
350                 if((s = findtag(pkt, TagSrvName, &len, i)) == nil)
351                         return bad("no matching service name");
352                 if(len == strlen(srvname) && memcmp(s, srvname, len) == 0)
353                         break;
354         }
355
356         if((s = findtag(pkt, TagAcName, &len, 0)) == nil)
357                 return bad("no ac name");
358         acname = copy(s, len);
359         if(wantac && strcmp(acname, wantac) != 0){
360                 free(acname);
361                 return bad("wrong ac name");
362         }
363
364         if(s = findtag(pkt, TagAcCookie, &len, 0)){
365                 cookie = copy(s, len);
366                 cookielen = len;
367         }
368         memmove(etherdst, eh->src, sizeof etherdst);
369         return 1;
370 }
371
372 int
373 wantsession(uchar *pkt)
374 {
375         int len;
376         uchar *s;
377         Pppoehdr *ph;
378
379         ph = (Pppoehdr*)(pkt+EtherHdrSz);
380
381         if(ph->code != CodeDiscSess)
382                 return bad("not a session confirmation");
383         if(nhgets(ph->sessid) == 0x0000)
384                 return bad("bad session id");
385         if(findtag(pkt, TagSrvName, &len, 0) == nil)
386                 return bad("no service name");
387         if(findtag(pkt, TagSrvNameErr, &len, 0))
388                 return bad("service name error");
389         if(findtag(pkt, TagAcSysErr, &len, 0))
390                 return bad("ac system error");
391
392         /*
393          * rsc said: ``if there is no -S option given, the current code
394          * waits for an offer with service name == "".
395          * that's silly.  it should take the first one it gets.''
396          */
397         if(srvname[0] != '\0') {
398                 if((s = findtag(pkt, TagSrvName, &len, 0)) == nil)
399                         return bad("no matching service name");
400                 if(len != strlen(srvname) || memcmp(s, srvname, len) != 0)
401                         return bad("no matching service name");
402         }
403         sessid = nhgets(ph->sessid);
404         return 1;
405 }
406
407 int
408 pppoe(char *ether)
409 {
410         char buf[64];
411         uchar pkt[1520];
412         int dfd, p[2], n, sfd, sz, timeout;
413         Pppoehdr *ph;
414
415         ph = (Pppoehdr*)(pkt+EtherHdrSz);
416         snprint(buf, sizeof buf, "%s!%d", ether, EtherPppoeDiscovery);
417         if((dfd = dial(buf, nil, nil, nil)) < 0)
418                 sysfatal("dial %s: %r", buf);
419
420         snprint(buf, sizeof buf, "%s!%d", ether, EtherPppoeSession);
421         if((sfd = dial(buf, nil, nil, nil)) < 0)
422                 sysfatal("dial %s: %r", buf);
423
424         for(timeout=250; timeout<16000; timeout*=2){
425                 clearstate();
426                 memset(pkt, 0, sizeof pkt);
427                 sz = padi(pkt);
428                 if(debug)
429                         dumppkt(pkt);
430                 if(sz < EtherMintu)
431                         sz = EtherMintu;
432                 ewrite(dfd, pkt, sz);
433
434                 if(pktread(timeout, dfd, pkt, sizeof pkt, wantoffer) < 0)
435                         continue;
436
437                 memset(pkt, 0, sizeof pkt);
438                 sz = padr(pkt);
439                 if(debug)
440                         dumppkt(pkt);
441                 if(sz < EtherMintu)
442                         sz = EtherMintu;
443                 ewrite(dfd, pkt, sz);
444
445                 if(pktread(timeout, dfd, pkt, sizeof pkt, wantsession) < 0)
446                         continue;
447
448                 break;
449         }
450         if(sessid < 0)
451                 sysfatal("could not establish session");
452
453         rfork(RFNOTEG);
454         if(pipe(p) < 0)
455                 sysfatal("pipe: %r");
456
457         switch(fork()){
458         case -1:
459                 sysfatal("fork: %r");
460         default:
461                 break;
462         case 0:
463                 close(p[1]);
464                 while((n = read(p[0], pkt+Hdr, sizeof pkt-Hdr)) > 0){
465                         etherhdr(pkt, etherdst, EtherPppoeSession);
466                         pppoehdr(pkt+EtherHdrSz, 0x00, sessid);
467                         hnputs(pkt+Hdr-2, n);
468                         sz = Hdr+n;
469                         if(debug > 1){
470                                 dumppkt(pkt);
471                                 hexdump(pkt, sz);
472                         }
473                         if(sz < EtherMintu)
474                                 sz = EtherMintu;
475                         if(write(sfd, pkt, sz) < 0){
476                                 if(debug)
477                                         fprint(2, "write to ether failed: %r");
478                                 _exits(nil);
479                         }
480                 }
481                 _exits(nil);
482         }
483
484         switch(fork()){
485         case -1:
486                 sysfatal("fork: %r");
487         default:
488                 break;
489         case 0:
490                 close(p[1]);
491                 while((n = read(sfd, pkt, sizeof pkt)) > 0){
492                         if(malformed(pkt, n, EtherPppoeSession)
493                         || ph->code != 0x00 || nhgets(ph->sessid) != sessid){
494                                 if(debug)
495                                         fprint(2, "malformed session pkt: %r\n");
496                                 if(debug)
497                                         dumppkt(pkt);
498                                 continue;
499                         }
500                         if(write(p[0], pkt+Hdr, nhgets(ph->length)) < 0){
501                                 if(debug)
502                                         fprint(2, "write to ppp failed: %r\n");
503                                 _exits(nil);
504                         }
505                 }
506                 _exits(nil);
507         }
508         close(p[0]);
509         return p[1];
510 }
511
512 void
513 execppp(int fd)
514 {
515         char *argv[16];
516         int argc;
517         char smtu[10];
518
519         argc = 0;
520         argv[argc++] = pppname;
521         snprint(smtu, sizeof(smtu), "-m%d", mtu);
522         argv[argc++] = smtu;
523         argv[argc++] = "-F";
524         if(debug)
525                 argv[argc++] = "-d";
526         if(primary)
527                 argv[argc++] = "-P";
528         if(pppnetmtpt){
529                 argv[argc++] = "-x";
530                 argv[argc++] = pppnetmtpt;
531         }
532         if(keyspec){
533                 argv[argc++] = "-k";
534                 argv[argc++] = keyspec;
535         }
536         argv[argc] = nil;
537
538         dup(fd, 0);
539         dup(fd, 1);
540         exec(pppname, argv);
541         sysfatal("exec: %r");
542 }
543
544 uchar*
545 findtag(uchar *pkt, int tagtype, int *plen, int skip)
546 {
547         int len, sz, totlen;
548         uchar *tagdat, *v;
549         Etherhdr *eh;
550         Pppoehdr *ph;
551         Taghdr *t;
552
553         eh = (Etherhdr*)pkt;
554         ph = (Pppoehdr*)(pkt+EtherHdrSz);
555         tagdat = pkt+Hdr;
556
557         if(nhgets(eh->type) != EtherPppoeDiscovery)
558                 return nil;
559         totlen = nhgets(ph->length);
560
561         sz = 0;
562         while(sz+4 <= totlen){
563                 t = (Taghdr*)(tagdat+sz);
564                 v = tagdat+sz+4;
565                 len = nhgets(t->length);
566                 if(sz+4+len > totlen)
567                         break;
568                 if(nhgets(t->type) == tagtype && skip-- == 0){
569                         *plen = len;
570                         return v;
571                 }
572                 sz += 2+2+len;
573         }
574         return nil;     
575 }
576
577 void
578 dumptags(uchar *tagdat, int ntagdat)
579 {
580         int i,len, sz;
581         uchar *v;
582         Taghdr *t;
583
584         sz = 0;
585         while(sz+4 <= ntagdat){
586                 t = (Taghdr*)(tagdat+sz);
587                 v = tagdat+sz+2+2;
588                 len = nhgets(t->length);
589                 if(sz+4+len > ntagdat)
590                         break;
591                 fprint(2, "\t0x%x %d: ", nhgets(t->type), len);
592                 switch(nhgets(t->type)){
593                 case TagEnd:
594                         fprint(2, "end of tag list\n");
595                         break;
596                 case TagSrvName:
597                         fprint(2, "service '%.*s'\n", len, (char*)v);
598                         break;
599                 case TagAcName:
600                         fprint(2, "ac '%.*s'\n", len, (char*)v);
601                         break;
602                 case TagHostUniq:
603                         fprint(2, "nonce ");
604                 Hex:
605                         for(i=0; i<len; i++)
606                                 fprint(2, "%.2ux", v[i]);
607                         fprint(2, "\n");
608                         break;
609                 case TagAcCookie:
610                         fprint(2, "ac cookie ");
611                         goto Hex;
612                 case TagVendSpec:
613                         fprint(2, "vend spec ");
614                         goto Hex;
615                 case TagRelaySessId:
616                         fprint(2, "relay ");
617                         goto Hex;
618                 case TagSrvNameErr:
619                         fprint(2, "srverr '%.*s'\n", len, (char*)v);
620                         break;
621                 case TagAcSysErr:
622                         fprint(2, "syserr '%.*s'\n", len, (char*)v);
623                         break;
624                 }
625                 sz += 2+2+len;
626         }
627         if(sz != ntagdat)
628                 fprint(2, "warning: only dumped %d of %d bytes\n", sz, ntagdat);
629 }
630
631 void
632 dumppkt(uchar *pkt)
633 {
634         int et;
635         Etherhdr *eh;
636         Pppoehdr *ph;
637
638         eh = (Etherhdr*)pkt;
639         ph = (Pppoehdr*)(pkt+EtherHdrSz);
640         et = nhgets(eh->type);
641
642         fprint(2, "%E -> %E type 0x%x\n", 
643                 eh->src, eh->dst, et);
644         switch(et){
645         case EtherPppoeDiscovery:
646         case EtherPppoeSession:
647                 fprint(2, "\tvers %d type %d code 0x%x sessid 0x%x length %d\n",
648                         ph->verstype>>4, ph->verstype&15,
649                         ph->code, nhgets(ph->sessid), nhgets(ph->length));
650                 if(et == EtherPppoeDiscovery)
651                         dumptags(pkt+Hdr, nhgets(ph->length));
652         }
653 }
654
655 int
656 malformed(uchar *pkt, int n, int wantet)
657 {
658         int et;
659         Etherhdr *eh;
660         Pppoehdr *ph;
661
662         eh = (Etherhdr*)pkt;
663         ph = (Pppoehdr*)(pkt+EtherHdrSz);
664
665         if(n < Hdr || n < Hdr+nhgets(ph->length)){
666                 werrstr("packet too short %d != %d", n, Hdr+nhgets(ph->length));
667                 return 1;
668         }
669
670         et = nhgets(eh->type);
671         if(et != wantet){
672                 werrstr("wrong ethernet packet type 0x%x != 0x%x", et, wantet);
673                 return 1;
674         }
675
676         return 0;
677 }
678
679 void
680 hexdump(uchar *a, int na)
681 {
682         int i;
683         char buf[80];
684
685         buf[0] = '\0';
686         for(i=0; i<na; i++){
687                 sprint(buf+strlen(buf), " %.2ux", a[i]);
688                 if(i%16 == 7)
689                         sprint(buf+strlen(buf), " --");
690                 if(i%16==15){
691                         sprint(buf+strlen(buf), "\n");
692                         write(2, buf, strlen(buf));
693                         buf[0] = 0;
694                 }
695         }
696         if(i%16){
697                 sprint(buf+strlen(buf), "\n");
698                 write(2, buf, strlen(buf));
699         }
700 }