]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ip/pppoe.c
6cd3f257021ef5cb0070bcbd48f8e18579d27da4
[plan9front.git] / sys / src / cmd / ip / pppoe.c
1 /*
2  * User-level PPP over Ethernet (PPPoE) client.
3  * See RFC 2516
4  */
5
6 #include <u.h>
7 #include <libc.h>
8 #include <ip.h>
9
10 void dumppkt(uchar*);
11 uchar *findtag(uchar*, int, int*, int);
12 void hexdump(uchar*, int);
13 int malformed(uchar*, int, int);
14 int pppoe(char*);
15 void execppp(int);
16
17 int alarmed;
18 int debug;
19 int sessid;
20 char *keyspec;
21 int primary;
22 char *pppnetmtpt;
23 char *acname;
24 char *pppname = "/bin/ip/ppp";
25 char *srvname = "";
26 char *wantac;
27 uchar *cookie;
28 int cookielen;
29 uchar etherdst[6];
30 int mtu = 1492;
31
/* print the command synopsis on stderr and exit with status "usage" */
void
usage(void)
{
	char *synopsis;

	synopsis = "usage: pppoe [-Pd] [-A acname] [-S srvname] [-k keyspec] [-m mtu] [-x pppnet] [ether0]\n";
	fprint(2, "%s", synopsis);
	exits("usage");
}
38
39 int
40 catchalarm(void *a, char *msg)
41 {
42         USED(a);
43
44         if(strstr(msg, "alarm")){
45                 alarmed = 1;
46                 return 1;
47         }
48         if(debug)
49                 fprint(2, "note rcved: %s\n", msg);
50         return 0;
51 }
52
53 void
54 main(int argc, char **argv)
55 {
56         int fd;
57         char *dev;
58
59         ARGBEGIN{
60         case 'A':
61                 wantac = EARGF(usage());
62                 break;
63         case 'P':
64                 primary = 1;
65                 break;
66         case 'S':
67                 srvname = EARGF(usage());
68                 break;
69         case 'd':
70                 debug++;
71                 break;
72         case 'm':
73                 mtu = atoi(EARGF(usage()));
74                 break;
75         case 'k':
76                 keyspec = EARGF(usage());
77                 break;
78         case 'x':
79                 pppnetmtpt = EARGF(usage());
80                 break;
81         default:
82                 usage();
83         }ARGEND
84
85         switch(argc){
86         default:
87                 usage();
88         case 0:
89                 dev = "ether0";
90                 break;
91         case 1:
92                 dev = argv[0];
93                 break;
94         }
95
96         fmtinstall('E', eipfmt);
97
98         atnotify(catchalarm, 1);
99         fd = pppoe(dev);
100         execppp(fd);
101 }
102
103 typedef struct Etherhdr Etherhdr;
104 struct Etherhdr {
105         uchar dst[6];
106         uchar src[6];
107         uchar type[2];
108 };
109
110 enum {
111         EtherHdrSz = 6+6+2,
112         EtherMintu = 60,
113
114         EtherPppoeDiscovery = 0x8863,
115         EtherPppoeSession = 0x8864,
116 };
117
118 uchar etherbcast[6] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF};
119
120 int
121 etherhdr(uchar *pkt, uchar *dst, int type)
122 {
123         Etherhdr *eh;
124
125         eh = (Etherhdr*)pkt;
126         memmove(eh->dst, dst, sizeof(eh->dst));
127         hnputs(eh->type, type);
128         return EtherHdrSz;
129 }
130
131 typedef struct Pppoehdr Pppoehdr;
132 struct Pppoehdr {
133         uchar verstype;
134         uchar code;
135         uchar sessid[2];
136         uchar length[2];        /* of payload */
137 };
138
139 enum {
140         PppoeHdrSz = 1+1+2+2,
141         Hdr = EtherHdrSz+PppoeHdrSz,
142 };
143
enum {
	VersType = 0x11,	/* version 1 (high nibble), type 1 (low nibble) */

	/* Discovery codes */
	CodeDiscInit = 0x09,	/* PADI: discovery init */
	CodeDiscOffer = 0x07,	/* PADO: discovery offer */
	CodeDiscReq = 0x19,	/* PADR: discovery request */
	CodeDiscSess = 0x65,	/* PADS: session confirmation */

	/* Session codes */
	CodeSession = 0x00,	/* session data */
};
156
157 int
158 pppoehdr(uchar *pkt, int code, int sessid)
159 {
160         Pppoehdr *ph;
161
162         ph = (Pppoehdr*)pkt;
163         ph->verstype = VersType;
164         ph->code = code;
165         hnputs(ph->sessid, sessid);
166         return PppoeHdrSz;
167 }
168
169 typedef struct Taghdr Taghdr;
170 struct Taghdr {
171         uchar type[2];
172         uchar length[2];        /* of value */
173 };
174
175 enum {
176         TagEnd = 0x0000,                /* end of tag list */
177         TagSrvName = 0x0101,    /* service name */
178         TagAcName = 0x0102,     /* access concentrator name */
179         TagHostUniq = 0x0103,   /* nonce */
180         TagAcCookie = 0x0104,   /* a.c. cookie */
181         TagVendSpec = 0x0105,   /* vendor specific */
182         TagRelaySessId = 0x0110,        /* relay session id */
183         TagSrvNameErr = 0x0201, /* service name error (ascii) */
184         TagAcSysErr = 0x0202,   /* a.c. system error */
185 };
186
187 int
188 tag(uchar *pkt, int type, void *value, int nvalue)
189 {
190         Taghdr *h;
191
192         h = (Taghdr*)pkt;
193         hnputs(h->type, type);
194         hnputs(h->length, nvalue);
195         memmove(pkt+4, value, nvalue);
196         return 4+nvalue;
197 }
198
199 /* PPPoE Active Discovery Initiation */
200 int
201 padi(uchar *pkt)
202 {
203         int sz, tagoff;
204         uchar *length;
205
206         sz = 0;
207         sz += etherhdr(pkt+sz, etherbcast, EtherPppoeDiscovery);
208         sz += pppoehdr(pkt+sz, CodeDiscInit, 0x0000);
209         length = pkt+sz-2;
210         tagoff = sz;
211         sz += tag(pkt+sz, TagSrvName, srvname, strlen(srvname));
212         hnputs(length, sz-tagoff);
213         return sz;
214 }
215
216 /* PPPoE Active Discovery Request */
217 int
218 padr(uchar *pkt)
219 {
220         int sz, tagoff;
221         uchar *length;
222
223         sz = 0;
224         sz += etherhdr(pkt+sz, etherdst, EtherPppoeDiscovery);
225         sz += pppoehdr(pkt+sz, CodeDiscReq, 0x0000);
226         length = pkt+sz-2;
227         tagoff = sz;
228         sz += tag(pkt+sz, TagSrvName, srvname, strlen(srvname));
229         sz += tag(pkt+sz, TagAcName, acname, strlen(acname));
230         if(cookie)
231                 sz += tag(pkt+sz, TagAcCookie, cookie, cookielen);
232         hnputs(length, sz-tagoff);
233         return sz;
234 }
235
236 void
237 ewrite(int fd, void *buf, int nbuf)
238 {
239         char e[ERRMAX], path[64];
240
241         if(write(fd, buf, nbuf) != nbuf){
242                 rerrstr(e, sizeof e);
243                 strcpy(path, "unknown");
244                 fd2path(fd, path, sizeof path);
245                 sysfatal("write %d to %s: %s", nbuf, path, e);
246         }
247 }
248
249 void*
250 emalloc(long n)
251 {
252         void *v;
253
254         v = malloc(n);
255         if(v == nil)
256                 sysfatal("out of memory");
257         return v;
258 }
259
260 int
261 aread(int timeout, int fd, void *buf, int nbuf)
262 {
263         int n;
264
265         alarmed = 0;
266         alarm(timeout);
267         n = read(fd, buf, nbuf);
268         alarm(0);
269         if(alarmed)
270                 return -1;
271         if(n < 0)
272                 sysfatal("read: %r");
273         if(n == 0)
274                 sysfatal("short read");
275         return n;
276 }
277
278 int
279 pktread(int timeout, int fd, void *buf, int nbuf, int (*want)(uchar*))
280 {
281         int n, t2;
282         n = -1;
283         for(t2=timeout; t2<16000; t2*=2){
284                 while((n = aread(t2, fd, buf, nbuf)) > 0){
285                         if(malformed(buf, n, EtherPppoeDiscovery)){
286                                 if(debug)
287                                         fprint(2, "dropping pkt: %r\n");
288                                 continue;
289                         }
290                         if(debug)
291                                 dumppkt(buf);
292                         if(!want(buf)){
293                                 if(debug)
294                                         fprint(2, "dropping unwanted pkt: %r\n");
295                                 continue;
296                         }
297                         break;
298                 }
299                 if(n > 0)
300                         break;
301         }
302         return n;
303 }
304
/*
 * Record reason as the error string (so callers can print it with %r)
 * and return 0, letting predicates write "return bad(...)".
 * reason is passed as an argument, not as the format, so a '%' in it
 * cannot be interpreted as a conversion.
 */
int
bad(char *reason)
{
	werrstr("%s", reason);
	return 0;
}
311
312 void*
313 copy(uchar *s, int len)
314 {
315         uchar *v;
316
317         v = emalloc(len+1);
318         memmove(v, s, len);
319         v[len] = '\0';
320         return v;
321 }
322
323 void
324 clearstate(void)
325 {
326         sessid = -1;
327         free(acname);
328         acname = nil;
329         free(cookie);
330         cookie = nil;
331 }
332
333 int
334 wantoffer(uchar *pkt)
335 {
336         int i, len;
337         uchar *s;
338         Etherhdr *eh;
339         Pppoehdr *ph;
340
341         eh = (Etherhdr*)pkt;
342         ph = (Pppoehdr*)(pkt+EtherHdrSz);
343
344         if(ph->code != CodeDiscOffer)
345                 return bad("not an offer");
346         if(nhgets(ph->sessid) != 0x0000)
347                 return bad("bad session id");
348
349         for(i=0;; i++){
350                 if((s = findtag(pkt, TagSrvName, &len, i)) == nil)
351                         return bad("no matching service name");
352                 if(len == strlen(srvname) && memcmp(s, srvname, len) == 0)
353                         break;
354         }
355
356         if((s = findtag(pkt, TagAcName, &len, 0)) == nil)
357                 return bad("no ac name");
358         acname = copy(s, len);
359         if(wantac && strcmp(acname, wantac) != 0){
360                 free(acname);
361                 acname = nil;
362                 return bad("wrong ac name");
363         }
364
365         if(s = findtag(pkt, TagAcCookie, &len, 0)){
366                 cookie = copy(s, len);
367                 cookielen = len;
368         }
369         memmove(etherdst, eh->src, sizeof etherdst);
370         return 1;
371 }
372
373 int
374 wantsession(uchar *pkt)
375 {
376         int len;
377         uchar *s;
378         Pppoehdr *ph;
379
380         ph = (Pppoehdr*)(pkt+EtherHdrSz);
381
382         if(ph->code != CodeDiscSess)
383                 return bad("not a session confirmation");
384         if(nhgets(ph->sessid) == 0x0000)
385                 return bad("bad session id");
386         if(findtag(pkt, TagSrvName, &len, 0) == nil)
387                 return bad("no service name");
388         if(findtag(pkt, TagSrvNameErr, &len, 0))
389                 return bad("service name error");
390         if(findtag(pkt, TagAcSysErr, &len, 0))
391                 return bad("ac system error");
392
393         /*
394          * rsc said: ``if there is no -S option given, the current code
395          * waits for an offer with service name == "".
396          * that's silly.  it should take the first one it gets.''
397          */
398         if(srvname[0] != '\0') {
399                 if((s = findtag(pkt, TagSrvName, &len, 0)) == nil)
400                         return bad("no matching service name");
401                 if(len != strlen(srvname) || memcmp(s, srvname, len) != 0)
402                         return bad("no matching service name");
403         }
404         sessid = nhgets(ph->sessid);
405         return 1;
406 }
407
408 int
409 pppoe(char *ether)
410 {
411         char buf[64];
412         uchar pkt[1520];
413         int dfd, p[2], n, sfd, sz, timeout;
414         Pppoehdr *ph;
415
416         ph = (Pppoehdr*)(pkt+EtherHdrSz);
417         snprint(buf, sizeof buf, "%s!%d", ether, EtherPppoeDiscovery);
418         if((dfd = dial(buf, nil, nil, nil)) < 0)
419                 sysfatal("dial %s: %r", buf);
420
421         snprint(buf, sizeof buf, "%s!%d", ether, EtherPppoeSession);
422         if((sfd = dial(buf, nil, nil, nil)) < 0)
423                 sysfatal("dial %s: %r", buf);
424
425         for(timeout=250; timeout<16000; timeout*=2){
426                 clearstate();
427                 memset(pkt, 0, sizeof pkt);
428                 sz = padi(pkt);
429                 if(debug)
430                         dumppkt(pkt);
431                 if(sz < EtherMintu)
432                         sz = EtherMintu;
433                 ewrite(dfd, pkt, sz);
434
435                 if(pktread(timeout, dfd, pkt, sizeof pkt, wantoffer) < 0)
436                         continue;
437
438                 memset(pkt, 0, sizeof pkt);
439                 sz = padr(pkt);
440                 if(debug)
441                         dumppkt(pkt);
442                 if(sz < EtherMintu)
443                         sz = EtherMintu;
444                 ewrite(dfd, pkt, sz);
445
446                 if(pktread(timeout, dfd, pkt, sizeof pkt, wantsession) < 0)
447                         continue;
448
449                 break;
450         }
451         if(sessid < 0)
452                 sysfatal("could not establish session");
453
454         rfork(RFNOTEG);
455         if(pipe(p) < 0)
456                 sysfatal("pipe: %r");
457
458         switch(fork()){
459         case -1:
460                 sysfatal("fork: %r");
461         default:
462                 break;
463         case 0:
464                 close(p[1]);
465                 while((n = read(p[0], pkt+Hdr, sizeof pkt-Hdr)) > 0){
466                         etherhdr(pkt, etherdst, EtherPppoeSession);
467                         pppoehdr(pkt+EtherHdrSz, 0x00, sessid);
468                         hnputs(pkt+Hdr-2, n);
469                         sz = Hdr+n;
470                         if(debug > 1){
471                                 dumppkt(pkt);
472                                 hexdump(pkt, sz);
473                         }
474                         if(sz < EtherMintu)
475                                 sz = EtherMintu;
476                         if(write(sfd, pkt, sz) < 0){
477                                 if(debug)
478                                         fprint(2, "write to ether failed: %r");
479                                 _exits(nil);
480                         }
481                 }
482                 _exits(nil);
483         }
484
485         switch(fork()){
486         case -1:
487                 sysfatal("fork: %r");
488         default:
489                 break;
490         case 0:
491                 close(p[1]);
492                 while((n = read(sfd, pkt, sizeof pkt)) > 0){
493                         if(malformed(pkt, n, EtherPppoeSession)
494                         || ph->code != 0x00 || nhgets(ph->sessid) != sessid){
495                                 if(debug)
496                                         fprint(2, "malformed session pkt: %r\n");
497                                 if(debug)
498                                         dumppkt(pkt);
499                                 continue;
500                         }
501                         if(write(p[0], pkt+Hdr, nhgets(ph->length)) < 0){
502                                 if(debug)
503                                         fprint(2, "write to ppp failed: %r\n");
504                                 _exits(nil);
505                         }
506                 }
507                 _exits(nil);
508         }
509         close(p[0]);
510         return p[1];
511 }
512
513 void
514 execppp(int fd)
515 {
516         char *argv[16];
517         int argc;
518         char smtu[10];
519
520         argc = 0;
521         argv[argc++] = pppname;
522         snprint(smtu, sizeof(smtu), "-m%d", mtu);
523         argv[argc++] = smtu;
524         argv[argc++] = "-F";
525         if(debug)
526                 argv[argc++] = "-d";
527         if(primary)
528                 argv[argc++] = "-P";
529         if(pppnetmtpt){
530                 argv[argc++] = "-x";
531                 argv[argc++] = pppnetmtpt;
532         }
533         if(keyspec){
534                 argv[argc++] = "-k";
535                 argv[argc++] = keyspec;
536         }
537         argv[argc] = nil;
538
539         dup(fd, 0);
540         dup(fd, 1);
541         exec(pppname, argv);
542         sysfatal("exec: %r");
543 }
544
545 uchar*
546 findtag(uchar *pkt, int tagtype, int *plen, int skip)
547 {
548         int len, sz, totlen;
549         uchar *tagdat, *v;
550         Etherhdr *eh;
551         Pppoehdr *ph;
552         Taghdr *t;
553
554         eh = (Etherhdr*)pkt;
555         ph = (Pppoehdr*)(pkt+EtherHdrSz);
556         tagdat = pkt+Hdr;
557
558         if(nhgets(eh->type) != EtherPppoeDiscovery)
559                 return nil;
560         totlen = nhgets(ph->length);
561
562         sz = 0;
563         while(sz+4 <= totlen){
564                 t = (Taghdr*)(tagdat+sz);
565                 v = tagdat+sz+4;
566                 len = nhgets(t->length);
567                 if(sz+4+len > totlen)
568                         break;
569                 if(nhgets(t->type) == tagtype && skip-- == 0){
570                         *plen = len;
571                         return v;
572                 }
573                 sz += 2+2+len;
574         }
575         return nil;     
576 }
577
578 void
579 dumptags(uchar *tagdat, int ntagdat)
580 {
581         int i,len, sz;
582         uchar *v;
583         Taghdr *t;
584
585         sz = 0;
586         while(sz+4 <= ntagdat){
587                 t = (Taghdr*)(tagdat+sz);
588                 v = tagdat+sz+2+2;
589                 len = nhgets(t->length);
590                 if(sz+4+len > ntagdat)
591                         break;
592                 fprint(2, "\t0x%x %d: ", nhgets(t->type), len);
593                 switch(nhgets(t->type)){
594                 case TagEnd:
595                         fprint(2, "end of tag list\n");
596                         break;
597                 case TagSrvName:
598                         fprint(2, "service '%.*s'\n", len, (char*)v);
599                         break;
600                 case TagAcName:
601                         fprint(2, "ac '%.*s'\n", len, (char*)v);
602                         break;
603                 case TagHostUniq:
604                         fprint(2, "nonce ");
605                 Hex:
606                         for(i=0; i<len; i++)
607                                 fprint(2, "%.2ux", v[i]);
608                         fprint(2, "\n");
609                         break;
610                 case TagAcCookie:
611                         fprint(2, "ac cookie ");
612                         goto Hex;
613                 case TagVendSpec:
614                         fprint(2, "vend spec ");
615                         goto Hex;
616                 case TagRelaySessId:
617                         fprint(2, "relay ");
618                         goto Hex;
619                 case TagSrvNameErr:
620                         fprint(2, "srverr '%.*s'\n", len, (char*)v);
621                         break;
622                 case TagAcSysErr:
623                         fprint(2, "syserr '%.*s'\n", len, (char*)v);
624                         break;
625                 }
626                 sz += 2+2+len;
627         }
628         if(sz != ntagdat)
629                 fprint(2, "warning: only dumped %d of %d bytes\n", sz, ntagdat);
630 }
631
632 void
633 dumppkt(uchar *pkt)
634 {
635         int et;
636         Etherhdr *eh;
637         Pppoehdr *ph;
638
639         eh = (Etherhdr*)pkt;
640         ph = (Pppoehdr*)(pkt+EtherHdrSz);
641         et = nhgets(eh->type);
642
643         fprint(2, "%E -> %E type 0x%x\n", 
644                 eh->src, eh->dst, et);
645         switch(et){
646         case EtherPppoeDiscovery:
647         case EtherPppoeSession:
648                 fprint(2, "\tvers %d type %d code 0x%x sessid 0x%x length %d\n",
649                         ph->verstype>>4, ph->verstype&15,
650                         ph->code, nhgets(ph->sessid), nhgets(ph->length));
651                 if(et == EtherPppoeDiscovery)
652                         dumptags(pkt+Hdr, nhgets(ph->length));
653         }
654 }
655
656 int
657 malformed(uchar *pkt, int n, int wantet)
658 {
659         int et;
660         Etherhdr *eh;
661         Pppoehdr *ph;
662
663         eh = (Etherhdr*)pkt;
664         ph = (Pppoehdr*)(pkt+EtherHdrSz);
665
666         if(n < Hdr || n < Hdr+nhgets(ph->length)){
667                 werrstr("packet too short %d != %d", n, Hdr+nhgets(ph->length));
668                 return 1;
669         }
670
671         et = nhgets(eh->type);
672         if(et != wantet){
673                 werrstr("wrong ethernet packet type 0x%x != 0x%x", et, wantet);
674                 return 1;
675         }
676
677         return 0;
678 }
679
680 void
681 hexdump(uchar *a, int na)
682 {
683         int i;
684         char buf[80];
685
686         buf[0] = '\0';
687         for(i=0; i<na; i++){
688                 sprint(buf+strlen(buf), " %.2ux", a[i]);
689                 if(i%16 == 7)
690                         sprint(buf+strlen(buf), " --");
691                 if(i%16==15){
692                         sprint(buf+strlen(buf), "\n");
693                         write(2, buf, strlen(buf));
694                         buf[0] = 0;
695                 }
696         }
697         if(i%16){
698                 sprint(buf+strlen(buf), "\n");
699                 write(2, buf, strlen(buf));
700         }
701 }