]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ip/dhcp6d.c
fix filetype detection by suffix so that multiple dots don't confuse it. (thanks kvik)
[plan9front.git] / sys / src / cmd / ip / dhcp6d.c
1 /* minimal stateless DHCPv6 server for network boot */
2 #include <u.h>
3 #include <libc.h>
4 #include <ip.h>
5 #include <bio.h>
6 #include <ndb.h>
7
/*
 * Eaddrlen is the ethernet address length; the remaining constants
 * are the DHCPv6 msg-type codes, numbered explicitly.
 */
enum {
	Eaddrlen = 6,

	SOLICIT = 1,		/* answered with ADVERTISE */
	ADVERTISE = 2,
	REQUEST = 3,		/* answered with REPLY */
	CONFIRM = 4,
	RENEW = 5,
	REBIND = 6,
	REPLY = 7,
	RELEASE = 8,
	DECLINE = 9,
	RECONFIGURE = 10,
	INFOREQ = 11,		/* information-request, answered with REPLY */
	RELAYFORW = 12,
	RELAYREPL = 13,
};
25
26 typedef struct Req Req;
27 struct Req
28 {
29         int             tra;
30
31         Udphdr          *udp;
32         Ipifc           *ifc;
33
34         uchar           mac[Eaddrlen];
35         uchar           ips[IPaddrlen*8];
36         int             nips;
37
38         Ndb             *db;
39         Ndbtuple        *t;
40
41         struct {
42                 int     t;
43                 uchar   *p;
44                 uchar   *e;
45         } req;
46
47         struct {
48                 int     t;
49                 uchar   *p;
50                 uchar   *e;
51         } resp;
52 };
53
54 typedef struct Otab Otab;
55 struct Otab
56 {
57         int     t;
58         int     (*f)(uchar *, int, Otab*, Req*);
59         char    *q[3];
60         int     done;
61 };
62
63 static Otab otab[];
64 static Ipifc *ipifcs;
65 static ulong starttime;
66 static char *ndbfile;
67 static char *netmtpt = "/net";
68 static int debug;
69
70 static uchar v6loopback[IPaddrlen] = {
71         0, 0, 0, 0,
72         0, 0, 0, 0,
73         0, 0, 0, 0,
74         0, 0, 0, 1
75 };
76
77 /*
78  * open ndbfile as db if not already open.  also check for stale data
79  * and reload as needed.
80  */
81 static Ndb *
82 opendb(void)
83 {
84         static ulong lastcheck;
85         static Ndb *db;
86         ulong now = time(nil);
87
88         /* check no more often than once every minute */
89         if(db == nil) {
90                 db = ndbopen(ndbfile);
91                 if(db != nil)
92                         lastcheck = now;
93         } else if(now >= lastcheck + 60) {
94                 if (ndbchanged(db))
95                         ndbreopen(db);
96                 lastcheck = now;
97         }
98         return db;
99 }
100
101 static Ipifc*
102 findifc(char *net, uchar ip[IPaddrlen])
103 {
104         Ipifc *ifc;
105         Iplifc *lifc;
106
107         ipifcs = readipifc(net, ipifcs, -1);
108         for(ifc = ipifcs; ifc != nil; ifc = ifc->next)
109                 for(lifc = ifc->lifc; lifc != nil; lifc = lifc->next)
110                         if(ipcmp(lifc->ip, ip) == 0)
111                                 return ifc;
112
113         return nil;
114 }
115
116 static Iplifc*
117 localonifc(Ipifc *ifc, uchar ip[IPaddrlen])
118 {
119         Iplifc *lifc;
120         uchar net[IPaddrlen];
121
122         for(lifc = ifc->lifc; lifc != nil; lifc = lifc->next){
123                 maskip(ip, lifc->mask, net);
124                 if(ipcmp(net, lifc->net) == 0)
125                         return lifc;
126         }
127
128         return nil;
129 }
130
131 static int
132 openlisten(char *net)
133 {
134         int fd, cfd;
135         char data[128], devdir[40];
136         Ipifc *ifc;
137         Iplifc *lifc;
138
139         sprint(data, "%s/udp!*!dhcp6s", net);
140         cfd = announce(data, devdir);
141         if(cfd < 0)
142                 sysfatal("can't announce: %r");
143         if(fprint(cfd, "headers") < 0)
144                 sysfatal("can't set header mode: %r");
145
146         ipifcs = readipifc(net, ipifcs, -1);
147         for(ifc = ipifcs; ifc != nil; ifc = ifc->next){
148                 if(strcmp(ifc->dev, "/dev/null") == 0)
149                         continue;
150                 for(lifc = ifc->lifc; lifc != nil; lifc = lifc->next){
151                         if(!ISIPV6LINKLOCAL(lifc->ip))
152                                 continue;
153                         if(fprint(cfd, "addmulti %I ff02::1:2", lifc->ip) < 0)
154                                 fprint(2, "addmulti: %I: %r\n", lifc->ip);
155                 }
156         }
157
158         sprint(data, "%s/data", devdir);
159         fd = open(data, ORDWR);
160         if(fd < 0)
161                 sysfatal("open udp data: %r");
162
163         return fd;
164 }
165
166 static uchar*
167 gettlv(int x, int *plen, uchar *p, uchar *e)
168 {
169         int t;
170         int l;
171
172         if(plen != nil)
173                 *plen = 0;
174         while(p+4 <= e){
175                 t = p[0]<<8 | p[1];
176                 l = p[2]<<8 | p[3];
177                 if(p+4+l > e)
178                         break;
179                 if(t == x){
180                         if(plen != nil)
181                                 *plen = l;
182                         return p+4;
183                 }
184                 p += l+4;
185         }
186         return nil;
187 }
188
189 static int
190 getv6ips(uchar *ip, int n, Ndbtuple *t, char *attr)
191 {
192         int r = 0;
193
194         if(n < IPaddrlen)
195                 return 0;
196         if(*attr == '@')
197                 attr++;
198         for(; t != nil; t = t->entry){
199                 if(strcmp(t->attr, attr) != 0)
200                         continue;
201                 if(parseip(ip, t->val) == -1)
202                         continue;
203                 if(isv4(ip))
204                         continue;
205                 ip += IPaddrlen;
206                 r += IPaddrlen;
207                 if(r >= n)
208                         break;
209         }
210         return r;
211 }
212
213 static int
214 lookupips(uchar *ip, int n, Ndb *db, uchar mac[Eaddrlen])
215 {
216         Ndbtuple *t;
217         Ndbs s;
218         char val[256], *attr;
219         int r;
220
221         /*
222          *  use hardware address to find an ip address
223          */
224         attr = "ether";
225         snprint(val, sizeof val, "%E", mac);
226
227         t = ndbsearch(db, &s, attr, val);
228         r = 0;
229         while(t != nil){
230                 r += getv6ips(ip + r, n - r, t, "ip");
231                 ndbfree(t);
232                 if(r >= n)
233                         break;
234                 t = ndbsnext(&s, attr, val);
235         }
236         return r;
237 }
238
239 static void
240 clearotab(void)
241 {
242         Otab *o;
243
244         for(o = otab; o->t != 0; o++)
245                 o->done = 0;
246 }
247
248 static Otab*
249 findotab(int t)
250 {
251         Otab *o;
252
253         for(o = otab; o->t != 0; o++)
254                 if(o->t == t)
255                         return o;
256         return nil;
257 }
258
259 static int
260 addoption(Req *r, int t)
261 {
262         Otab *o;
263         int n;
264
265         if(r->resp.p+4 > r->resp.e)
266                 return -1;
267         o = findotab(t);
268         if(o == nil || o->f == nil || o->done)
269                 return -1;
270         o->done = 1;
271         n = (*o->f)(r->resp.p+4, r->resp.e - (r->resp.p+4), o, r);
272         if(n < 0 || r->resp.p+4+n > r->resp.e)
273                 return -1;
274         r->resp.p[0] = t>>8, r->resp.p[1] = t;
275         r->resp.p[2] = n>>8, r->resp.p[3] = n;
276         if(debug) fprint(2, "%d(%.*H)\n", t, n, r->resp.p+4);
277         r->resp.p += 4+n;
278         return n;
279 }
280
281 static void
282 usage(void)
283 {
284         fprint(2, "%s [-d]  [-f ndbfile] [-x netmtpt]\n", argv0);
285         exits("usage");
286 }
287
288 void
289 main(int argc, char *argv[])
290 {
291         uchar ibuf[4096], obuf[4096];
292         Req r[1];
293         int fd, n, i;
294
295         fmtinstall('H', encodefmt);
296         fmtinstall('I', eipfmt);
297         fmtinstall('E', eipfmt);
298
299         ARGBEGIN {
300         case 'd':
301                 debug++;
302                 break;
303         case 'f':
304                 ndbfile = EARGF(usage());
305                 break;
306         case 'x':
307                 netmtpt = EARGF(usage());
308                 break;
309         default:
310                 usage();
311         } ARGEND;
312
313         starttime = time(nil) - 946681200UL;
314
315         if(opendb() == nil)
316                 sysfatal("opendb: %r");
317
318         fd = openlisten(netmtpt);
319
320         /* put process in background */
321         if(!debug)
322         switch(rfork(RFNOTEG|RFPROC|RFFDG)) {
323         default:
324                 exits(nil);
325         case -1:
326                 sysfatal("fork: %r");
327         case 0:
328                 break;
329         }
330
331         while((n = read(fd, ibuf, sizeof(ibuf))) > 0){
332                 if(n < Udphdrsize+4)
333                         continue;
334
335                 r->udp = (Udphdr*)ibuf;
336                 if(isv4(r->udp->raddr))
337                         continue;
338                 if((r->ifc = findifc(netmtpt, r->udp->ifcaddr)) == nil)
339                         continue;
340                 if(localonifc(r->ifc, r->udp->raddr) == nil)
341                         continue;
342
343                 memmove(obuf, ibuf, Udphdrsize);
344                 r->req.p = ibuf+Udphdrsize;
345                 r->req.e = ibuf+n;
346                 r->resp.p = obuf+Udphdrsize;
347                 r->resp.e = &obuf[sizeof(obuf)];
348
349                 r->tra = r->req.p[1]<<16 | r->req.p[2]<<8 | r->req.p[3];
350                 r->req.t = r->req.p[0];
351
352                 if(debug)
353                 fprint(2, "%I->%I(%s) typ=%d tra=%x\n",
354                         r->udp->raddr, r->udp->laddr, r->ifc->dev,
355                         r->req.t, r->tra);
356
357                 switch(r->req.t){
358                 default:
359                         continue;
360                 case SOLICIT:
361                         r->resp.t = ADVERTISE;
362                         break;
363                 case REQUEST:
364                 case INFOREQ:
365                         r->resp.t = REPLY;
366                         break;
367                 }
368                 r->resp.p[0] = r->resp.t;
369                 r->resp.p[1] = r->tra>>16;
370                 r->resp.p[2] = r->tra>>8;
371                 r->resp.p[3] = r->tra;
372
373                 r->req.p += 4;
374                 r->resp.p += 4;
375
376                 r->t = nil;
377
378                 clearotab();
379
380                 /* Server Identifier */
381                 if(addoption(r, 2) < 0)
382                         continue;
383
384                 /* Client Identifier */
385                 if(addoption(r, 1) < 0)
386                         continue;
387
388                 if((r->db = opendb()) == nil)
389                         continue;
390                 r->nips = lookupips(r->ips, sizeof(r->ips), r->db, r->mac)/IPaddrlen;
391                 if(debug){
392                         for(i=0; i<r->nips; i++)
393                                 fprint(2, "ip=%I\n", r->ips+i*IPaddrlen);
394                 }
395
396                 addoption(r, 3);
397                 addoption(r, 6);
398
399                 write(fd, obuf, r->resp.p-obuf);
400                 if(debug) fprint(2, "\n");
401         }
402
403         exits(nil);
404 }
405
406 static int
407 oclientid(uchar *w, int n, Otab*, Req *r)
408 {
409         int len;
410         uchar *p;
411
412         if((p = gettlv(1, &len, r->req.p, r->req.e)) == nil)
413                 return -1;
414         if(len < 4+4+Eaddrlen || n < len)
415                 return -1;
416         memmove(r->mac, p+len-Eaddrlen, Eaddrlen);
417         memmove(w, p, len);
418
419         return len;
420 }
421
422 static int
423 oserverid(uchar *w, int n, Otab*, Req *r)
424 {
425         int len;
426         uchar *p;
427
428         if(n < 4+4+Eaddrlen)
429                 return -1;
430         w[0] = 0, w[1] = 1;     /* duid type: link layer address + time*/
431         w[2] = 0, w[3] = 1;     /* hw type: ethernet */
432         w[4] = starttime>>24;
433         w[5] = starttime>>16;
434         w[6] = starttime>>8;
435         w[7] = starttime;
436         myetheraddr(w+8, r->ifc->dev);
437
438         /* check if server id matches from the request */
439         p = gettlv(2, &len, r->req.p, r->req.e);
440         if(p != nil && (len != 4+4+Eaddrlen || memcmp(w, p, 4+4+Eaddrlen) != 0))
441                 return -1;
442
443         return 4+4+Eaddrlen;
444 }
445
446 static int
447 oiana(uchar *w, int n, Otab*, Req *r)
448 {
449         int i, len;
450         uchar *p;
451
452         p = gettlv(3, &len, r->req.p, r->req.e);
453         if(p == nil || len < 3*4)
454                 return -1;
455
456         len = 3*4 + (4+IPaddrlen+2*4)*r->nips;
457         if(n < len)
458                 return -1;
459
460         memmove(w, p, 3*4);
461         w += 3*4;
462
463         for(i = 0; i < r->nips; i++){
464                 w[0] = 0, w[1] = 5;
465                 w[2] = 0, w[3] = IPaddrlen+2*4;
466                 w += 4;
467
468                 memmove(w, r->ips + i*IPaddrlen, IPaddrlen);
469                 w += IPaddrlen;
470
471                 memset(w, 255, 2*4);
472                 w += 2*4;
473         }
474
475         return len;
476 }
477
478 static Ndbtuple*
479 lookup(Req *r, char *av[], int ac)
480 {
481         Ndbtuple *t;
482         char *s;
483
484         if(ac <= 0)
485                 return nil;
486
487         t = nil;
488         if(r->nips > 0){
489                 int i;
490
491                 /* use the target ip's to lookup info if any */
492                 for(i=0; i<r->nips; i++){
493                         s = smprint("%I", &r->ips[i*IPaddrlen]);
494                         t = ndbconcatenate(t, ndbipinfo(r->db, "ip", s, av, ac));
495                         free(s);
496                 }
497         } else {
498                 Iplifc *lifc;
499
500                 /* use the ipv6 networks on the interface */
501                 for(lifc=r->ifc->lifc; lifc!=nil; lifc=lifc->next){
502                         if(isv4(lifc->ip)
503                         || ipcmp(lifc->ip, v6loopback) == 0
504                         || ISIPV6LINKLOCAL(lifc->ip))
505                                 continue;
506                         s = smprint("%I", lifc->net);
507                         t = ndbconcatenate(t, ndbipinfo(r->db, "ip", s, av, ac));
508                         free(s);
509                 }
510         }
511         return t;
512 }
513
514 static int
515 oro(uchar*, int, Otab *o, Req *r)
516 {
517         uchar *p;
518         char *av[100];
519         int i, j, l, ac;
520         Ndbtuple *t;
521
522         p = gettlv(6, &l, r->req.p, r->req.e);
523         if(p == nil || l < 2)
524                 return -1;
525
526         ac = 0;
527         for(i=0; i<l; i+=2){
528                 if((o = findotab(p[i]>>8 | p[i+1])) == nil || o->done)
529                         continue;
530                 for(j=0; j<3 && o->q[j]!=nil && ac<nelem(av); j++)
531                         av[ac++] = o->q[j];
532         }
533
534         r->t = lookup(r, av, ac);
535
536         if(debug){
537                 fprint(2, "ndb(");
538                 for(t = r->t; t != nil; t = t->entry){
539                         fprint(2, "%s=%s ", t->attr, t->val);
540                         if(t->entry != nil && t->entry != t->line)
541                                 fprint(2, "\n");
542                 }
543                 fprint(2, ")\n");
544         }
545
546         /* process the options */
547         for(i=0; i<l; i+=2)
548                 addoption(r, p[i]>>8 | p[i+1]);
549
550         ndbfree(r->t);
551         r->t = nil;
552
553         return -1;
554 }
555
556 static int
557 oservers(uchar *w, int n, Otab *o, Req *r)
558 {
559         return getv6ips(w, n, r->t, o->q[0]);
560 }
561
562 static int
563 odomainlist(uchar *w, int n, Otab *o, Req *q)
564 {
565         char val[256];
566         Ndbtuple *t;
567         int l, r;
568         char *s;
569
570         r = 0;
571         for(t = q->t; t != nil; t = t->entry){
572                 if(strcmp(t->attr, o->q[0]) != 0)
573                         continue;
574                 if(utf2idn(t->val, val, sizeof(val)) <= 0)
575                         continue;
576                 for(s = val; *s != 0; s++){
577                         for(l = 0; *s != 0 && *s != '.'; l++)
578                                 s++;
579                         if(r+1+l > n)
580                                 return -1;
581                         w[r++] = l;
582                         memmove(w+r, s-l, l);
583                         r += l;
584                         if(*s != '.')
585                                 break;
586                 }
587                 if(r >= n)
588                         return -1;
589                 w[r++] = 0;
590         }
591         return r;
592 }
593
594 static int
595 obootfileurl(uchar *w, int n, Otab *, Req *q)
596 {
597         uchar ip[IPaddrlen];
598         Ndbtuple *bootf;
599
600         if((bootf = ndbfindattr(q->t, q->t, "bootf")) == nil)
601                 return -1;
602         if(strstr(bootf->val, "://") != nil)
603                 return snprint((char*)w, n, "%s", bootf->val);
604         else if(getv6ips(ip, sizeof(ip), q->t, "tftp"))
605                 return snprint((char*)w, n, "tftp://[%I]/%s", ip, bootf->val);
606         return -1;
607 }
608
609 static Otab otab[] = {
610         {  1, oclientid, },
611         {  2, oserverid, },
612         {  3, oiana, },
613         {  6, oro, },
614         { 23, oservers, "@dns" },
615         { 24, odomainlist, "dnsdomain" },
616         { 59, obootfileurl, "bootf", "@tftp", },
617         { 0 },
618 };