]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/9nfs/server.c
dc: fix off by one in stack overflow check (thanks BurnZeZ)
[plan9front.git] / sys / src / cmd / 9nfs / server.c
1 #include "all.h"
2 #include <ndb.h>
3
4 static int      alarmflag;
5
6 static int      Iconv(Fmt*);
7 static void     cachereply(Rpccall*, void*, int);
8 static int      replycache(int, Rpccall*, long (*)(int, void*, long));
9 static void     udpserver(int, Progmap*);
10 static void     tcpserver(int, Progmap*);
11 static void     getendpoints(Udphdr*, char*);
12 static long     readtcp(int, void*, long);
13 static long     writetcp(int, void*, long);
14 static int      servemsg(int, long (*)(int, void*, long), long (*)(int, void*, long),
15                 int, Progmap*);
16 void    (*rpcalarm)(void);
17 int     rpcdebug;
18 int     rejectall;
19 int     p9debug;
20
21 int     nocache;
22
23 uchar   buf[9000];
24 uchar   rbuf[9000];
25 uchar   resultbuf[9000];
26
27 static int tcp;
28
29 char *commonopts = "[-9CDrtv]";                 /* for usage() messages */
30
31 /*
32  * this recognises common, nominally rcp-related options.
33  * they may not take arguments.
34  */
35 int
36 argopt(int c)
37 {
38         switch(c){
39         case '9':
40                 ++p9debug;
41                 return 0;
42         case 'C':
43                 ++nocache;
44                 return 0;
45         case 'D':
46                 ++rpcdebug;
47                 return 0;
48         case 'r':
49                 ++rejectall;
50                 return 0;
51         case 't':
52                 tcp = 1;
53                 return 0;
54         case 'v':
55                 ++chatty;
56                 return 0;
57         default:
58                 return -1;
59         }
60 }
61
62 /*
63  * all option parsing is now done in (*pg->init)(), which can call back
64  * here to argopt for common options.
 *
 * server installs the %I, %F and %D print formats, detaches from its
 * parent (which exits), forks a ticker process that sets alarmflag
 * every 30 seconds, calls each progmap init routine with argc/argv,
 * and then serves RPCs on myport: over tcp if -t was given, else udp.
 * Both serving loops end in exits(), so server does not return in
 * normal operation.
65  */
66 void
67 server(int argc, char **argv, int myport, Progmap *progmap)
68 {
69         Progmap *pg;
70
71         fmtinstall('I', Iconv);
72         fmtinstall('F', fcallfmt);
73         fmtinstall('D', dirfmt);
74
        /*
         * detach: the original process exits; the child continues with
         * private copies of its environment, namespace and fd table,
         * a new note group, and is not waited for.
         */
75         switch(rfork(RFNOWAIT|RFENVG|RFNAMEG|RFNOTEG|RFFDG|RFPROC)){
76         case -1:
77                 panic("fork");
78         default:
79                 _exits(0);
80         case 0:
81                 break;
82         }
83
        /*
         * ticker: shares memory (RFMEM) with the server so its writes
         * to alarmflag are seen by servemsg().  The parent has no
         * matching case and simply falls out of the switch.
         */
84         switch(rfork(RFMEM|RFPROC)){
85         case 0:
86                 for(;;){
87                         sleep(30*1000);
88                         alarmflag = 1;
89                 }
90         case -1:
91                 sysfatal("rfork: %r");
92         }
93
94         for(pg=progmap; pg->init; pg++)
95                 (*pg->init)(argc, argv);
96         if(tcp)
97                 tcpserver(myport, progmap);
98         else
99                 udpserver(myport, progmap);
100 }
101
102 static void
103 udpserver(int myport, Progmap *progmap)
104 {
105         char service[128];
106         char data[128];
107         char devdir[40];
108         int ctlfd, datafd;
109
110         snprint(service, sizeof service, "udp!*!%d", myport);
111         ctlfd = announce(service, devdir);
112         if(ctlfd < 0)
113                 panic("can't announce %s: %r\n", service);
114         if(fprint(ctlfd, "headers") < 0)
115                 panic("can't set header mode: %r\n");
116
117         snprint(data, sizeof data, "%s/data", devdir);
118         datafd = open(data, ORDWR);
119         if(datafd < 0)
120                 panic("can't open udp data: %r\n");
121         close(ctlfd);
122
123         chatsrv(0);
124         clog("%s: listening to port %d\n", argv0, myport);
125         while (servemsg(datafd, read, write, myport, progmap) >= 0)
126                 continue;
127         exits(0);
128 }
129
130 static void
131 tcpserver(int myport, Progmap *progmap)
132 {
133         char adir[40];
134         char ldir[40];
135         char ds[40];
136         int actl, lctl, data;
137
138         snprint(ds, sizeof ds, "tcp!*!%d", myport);
139         chatsrv(0);
140         actl = -1;
141         for(;;){
142                 if(actl < 0){
143                         actl = announce(ds, adir);
144                         if(actl < 0){
145                                 clog("%s: listening to tcp port %d\n",
146                                         argv0, myport);
147                                 clog("announcing: %r");
148                                 break;
149                         }
150                 }
151                 lctl = listen(adir, ldir);
152                 if(lctl < 0){
153                         close(actl);
154                         actl = -1;
155                         continue;
156                 }
157                 switch(fork()){
158                 case -1:
159                         clog("%s!%d: %r\n", argv0, myport);
160                         /* fall through */
161                 default:
162                         close(lctl);
163                         continue;
164                 case 0:
165                         close(actl);
166                         data = accept(lctl, ldir);
167                         close(lctl);
168                         if(data < 0)
169                                 exits(0);
170
171                         /* pretend it's udp; fill in Udphdr */
172                         getendpoints((Udphdr*)buf, ldir);
173
174                         while (servemsg(data, readtcp, writetcp, myport,
175                             progmap) >= 0)
176                                 continue;
177                         close(data);
178                         exits(0);
179                 }
180         }
181         exits(0);
182 }
183
/*
 * servemsg handles one RPC message.  It reads the message into buf with
 * readmsg and decodes it with rpcM2S.  A retransmission is answered from
 * the reply cache when possible.  Otherwise the call is dispatched to the
 * progmap procedure matching its program, version and procedure numbers,
 * and the encoded reply is sent with writemsg and offered to the reply
 * cache.
 * Returns 0 to keep serving (including after an interrupted read, an
 * undecodable message or a non-CALL message) and -1 after a read error
 * or EOF.
 * NOTE(review): the rpcdebug==1 trace reads the caller's address from
 * buf[12..15] and its port from buf[32..33]; this presumably matches the
 * Udphdr layout declared in all.h - TODO confirm.
 */
184 static int
185 servemsg(int fd, long (*readmsg)(int, void*, long), long (*writemsg)(int, void*, long),
186                 int myport, Progmap * progmap)
187 {
188         int i, n, nreply;
189         Rpccall rcall, rreply;
190         int vlo, vhi;
191         Progmap *pg;
192         Procmap *pp;
193         char errbuf[ERRMAX];
194
        /* alarmflag is set every 30s by the ticker process forked in server() */
195         if(alarmflag){
196                 alarmflag = 0;
197                 if(rpcalarm)
198                         (*rpcalarm)();
199         }
200         n = (*readmsg)(fd, buf, sizeof buf);
201         if(n < 0){
202                 errstr(errbuf, sizeof errbuf);
                /* an interrupted read is not fatal: let the caller keep serving */
203                 if(strcmp(errbuf, "interrupted") == 0)
204                         return 0;
205                 clog("port %d: error: %s\n", myport, errbuf);
206                 return -1;
207         }
208         if(n == 0){
209                 clog("port %d: EOF\n", myport);
210                 return -1;
211         }
212         if(rpcdebug == 1)
213                 fprint(2, "%s: rpc from %d.%d.%d.%d/%d\n",
214                         argv0, buf[12], buf[13], buf[14], buf[15],
215                         (buf[32]<<8)|buf[33]);
216         i = rpcM2S(buf, &rcall, n);
217         if(i != 0){
218                 clog("udp port %d: message format error %d\n",
219                         myport, i);
220                 return 0;
221         }
222         if(rpcdebug > 1)
223                 rpcprint(2, &rcall);
224         if(rcall.mtype != CALL)
225                 return 0;
        /* a retransmitted call gets the cached reply, not a second execution */
226         if(replycache(fd, &rcall, writemsg))
227                 return 0;
        /* every reply echoes the call's endpoints and xid; nreply stays 0 unless a procedure runs */
228         nreply = 0;
229         rreply.host = rcall.host;
230         rreply.port = rcall.port;
231         rreply.lhost = rcall.lhost;
232         rreply.lport = rcall.lport;
233         rreply.xid = rcall.xid;
234         rreply.mtype = REPLY;
        /* only rpc version 2 is accepted */
235         if(rcall.rpcvers != 2){
236                 rreply.stat = MSG_DENIED;
237                 rreply.rstat = RPC_MISMATCH;
238                 rreply.rlow = 2;
239                 rreply.rhigh = 2;
240                 goto send_reply;
241         }
        /* -r: deny every call */
242         if(rejectall){
243                 rreply.stat = MSG_DENIED;
244                 rreply.rstat = AUTH_ERROR;
245                 rreply.authstat = AUTH_TOOWEAK;
246                 goto send_reply;
247         }
        /* i = number of argument bytes following the decoded call header */
248         i = n - (((uchar *)rcall.args) - buf);
249         if(rpcdebug > 1)
250                 fprint(2, "arg size = %d\n", i);
251         rreply.stat = MSG_ACCEPTED;
252         rreply.averf.flavor = 0;
253         rreply.averf.count = 0;
254         rreply.results = resultbuf;
        /* find prog+vers, tracking the versions of prog we do serve for PROG_MISMATCH */
255         vlo = 0x7fffffff;
256         vhi = -1;
257         for(pg=progmap; pg->pmap; pg++){
258                 if(pg->progno != rcall.prog)
259                         continue;
260                 if(pg->vers == rcall.vers)
261                         break;
262                 if(pg->vers < vlo)
263                         vlo = pg->vers;
264                 if(pg->vers > vhi)
265                         vhi = pg->vers;
266         }
267         if(pg->pmap == 0){
268                 if(vhi < 0)
269                         rreply.astat = PROG_UNAVAIL;
270                 else{
271                         rreply.astat = PROG_MISMATCH;
272                         rreply.plow = vlo;
273                         rreply.phigh = vhi;
274                 }
275                 goto send_reply;
276         }
        /* run the procedure; nreply is passed to rpcS2M (presumably the result length) */
277         for(pp = pg->pmap; pp->procp; pp++)
278                 if(rcall.proc == pp->procno){
279                         if(rpcdebug > 1)
280                                 fprint(2, "process %d\n", pp->procno);
281                         rreply.astat = SUCCESS;
282                         nreply = (*pp->procp)(i, &rcall, &rreply);
283                         goto send_reply;
284                 }
285         rreply.astat = PROC_UNAVAIL;
286 send_reply:
        /* a negative nreply suppresses the reply and its caching */
287         if(nreply >= 0){
288                 i = rpcS2M(&rreply, nreply, rbuf);
289                 if(rpcdebug > 1)
290                         rpcprint(2, &rreply);
291                 (*writemsg)(fd, rbuf, i);
292                 cachereply(&rreply, rbuf, i);
293         }
294         return 0;
295 }
296
297 static void
298 getendpoint(char *dir, char *file, uchar *addr, uchar *port)
299 {
300         int fd, n;
301         char buf[128];
302         char *sys, *serv;
303
304         sys = serv = 0;
305
306         snprint(buf, sizeof buf, "%s/%s", dir, file);
307         fd = open(buf, OREAD);
308         if(fd >= 0){
309                 n = read(fd, buf, sizeof(buf)-1);
310                 if(n>0){
311                         buf[n-1] = 0;
312                         serv = strchr(buf, '!');
313                         if(serv){
314                                 *serv++ = 0;
315                                 serv = strdup(serv);
316                         }
317                         sys = strdup(buf);
318                 }
319                 close(fd);
320         }
321         if(serv == 0)
322                 serv = strdup("unknown");
323         if(sys == 0)
324                 sys = strdup("unknown");
325         parseip(addr, sys);
326         n = atoi(serv);
327         hnputs(port, n);
328 }
329
330 /* set Udphdr values from protocol dir local & remote files */
331 static void
332 getendpoints(Udphdr *ep, char *dir)
333 {
334         getendpoint(dir, "local", ep->laddr, ep->lport);
335         getendpoint(dir, "remote", ep->raddr, ep->rport);
336 }
337
338 static long
339 readtcp(int fd, void *vbuf, long blen)
340 {
341         uchar mk[4];
342         int n, m, sofar;
343         ulong done;
344         char *buf;
345
346         buf = vbuf;
347         buf += Udphdrsize;
348         blen -= Udphdrsize;
349
350         done = 0;
351         for(sofar = 0; !done; sofar += n){
352                 m = readn(fd, mk, 4);
353                 if(m < 4)
354                         return 0;
355                 done = (mk[0]<<24)|(mk[1]<<16)|(mk[2]<<8)|mk[3];
356                 m = done & 0x7fffffff;
357                 done &= 0x80000000;
358                 if(m > blen-sofar)
359                         return -1;
360                 n = readn(fd, buf+sofar, m);
361                 if(m != n)
362                         return 0;
363         }
364         return sofar + Udphdrsize;
365 }
366
367 static long
368 writetcp(int fd, void *vbuf, long len)
369 {
370         char *buf;
371
372         buf = vbuf;
373         buf += Udphdrsize;
374         len -= Udphdrsize;
375
376         buf -= 4;
377         buf[0] = 0x80 | (len>>24);
378         buf[1] = len>>16;
379         buf[2] = len>>8;
380         buf[3] = len;
381         len += 4;
382         return write(fd, buf, len);
383 }
384 /*
385  *long
386  *niwrite(int fd, void *buf, long count)
387  *{
388  *      char errbuf[ERRLEN];
389  *      long n;
390  *
391  *      for(;;){
392  *              n = write(fd, buf, count);
393  *              if(n < 0){
394  *                      errstr(errbuf);
395  *                      if(strcmp(errbuf, "interrupted") == 0)
396  *                              continue;
397  *                      clog("niwrite error: %s\n", errbuf);
398  *                      werrstr(errbuf);
399  *              }
400  *              break;
401  *      }
402  *      return n;
403  *}
404  */
/*
 * niwrite writes n bytes of buf to fd and returns write's result.
 * (An earlier version suspended the alarm around the write; it is now
 * a plain write.)
 */
long
niwrite(int fd, void *buf, long n)
{
	return write(fd, buf, n);
}
416
417 typedef struct Namecache        Namecache;
418 struct Namecache {
419         char dom[256];
420         ulong ipaddr;
421         Namecache *next;
422 };
423
424 Namecache *dnscache;
425
426 static Namecache*
427 domlookupl(void *name, int len)
428 {
429         Namecache *n, **ln;
430
431         if(len >= sizeof(n->dom))
432                 return nil;
433
434         for(ln=&dnscache, n=*ln; n; ln=&(*ln)->next, n=*ln) {
435                 if(strncmp(n->dom, name, len) == 0 && n->dom[len] == 0) {
436                         *ln = n->next;
437                         n->next = dnscache;
438                         dnscache = n;
439                         return n;
440                 }
441         }
442         return nil;
443 }
444
445 static Namecache*
446 iplookup(ulong ip)
447 {
448         Namecache *n, **ln;
449
450         for(ln=&dnscache, n=*ln; n; ln=&(*ln)->next, n=*ln) {
451                 if(n->ipaddr == ip) {
452                         *ln = n->next;
453                         n->next = dnscache;
454                         dnscache = n;
455                         return n;
456                 }
457         }
458         return nil;
459 }
460
461 static Namecache*
462 addcacheentry(void *name, int len, ulong ip)
463 {
464         Namecache *n;
465
466         if(len >= sizeof(n->dom))
467                 return nil;
468
469         n = malloc(sizeof(*n));
470         if(n == nil)
471                 return nil;
472         strncpy(n->dom, name, len);
473         n->dom[len] = 0;
474         n->ipaddr = ip;
475         n->next = dnscache;
476         dnscache = n;
477         return nil;
478 }
479
480 int
481 getdnsdom(ulong ip, char *name, int len)
482 {
483         char buf[128];
484         Namecache *nc;
485         char *p;
486
487         if(nc=iplookup(ip)) {
488                 strncpy(name, nc->dom, len);
489                 name[len-1] = 0;
490                 return 0;
491         }
492         clog("getdnsdom: %I\n", ip);
493         snprint(buf, sizeof buf, "%I", ip);
494         p = csgetvalue("/net", "ip", buf, "dom", nil);
495         if(p == nil)
496                 return -1;
497         strncpy(name, p, len-1);
498         name[len] = 0;
499         free(p);
500         addcacheentry(name, strlen(name), ip);
501         return 0;
502 }
503
504 int
505 getdom(ulong ip, char *dom, int len)
506 {
507         int i;
508         static char *prefix[] = { "", "gate-", "fddi-", "u-", 0 };
509         char **pr;
510
511         if(getdnsdom(ip, dom, len)<0)
512                 return -1;
513
514         for(pr=prefix; *pr; pr++){
515                 i = strlen(*pr);
516                 if(strncmp(dom, *pr, i) == 0) {
517                         memmove(dom, dom+i, len-i);
518                         break;
519                 }
520         }
521         return 0;
522 }
523
524 #define MAXCACHE        64
525
526 static Rpccache *head, *tail;
527 static int      ncache;
528
/*
 * cachereply stores a copy of the encoded reply (len bytes at buf) at
 * the front of the reply cache, keyed by rp's host, port and xid, so
 * replycache can answer a retransmitted call without running the
 * procedure again.  Does nothing when nocache (-C) is set.  When the
 * cache already holds MAXCACHE entries, the tail (least recently used)
 * entry is freed first.  A malloc failure is logged and the reply is
 * not cached.
 */
529 static void
530 cachereply(Rpccall *rp, void *buf, int len)
531 {
532         Rpccache *cp;
533
534         if(nocache)
535                 return;
536
        /*
         * evict the oldest entry; assumes the list has at least two
         * entries here (tail->prev non-nil), true while MAXCACHE > 1
         */
537         if(ncache >= MAXCACHE){
538                 if(rpcdebug)
539                         fprint(2, "%s: drop  %I/%ld, xid %uld, len %d\n",
540                                 argv0, tail->host,
541                                 tail->port, tail->xid, tail->n);
542                 tail = tail->prev;
543                 free(tail->next);
544                 tail->next = 0;
545                 --ncache;
546         }
        /* the -4 presumably accounts for a 4-byte data[] array at the end of Rpccache - TODO confirm in all.h */
547         cp = malloc(sizeof(Rpccache)+len-4);
548         if(cp == 0){
549                 clog("cachereply: malloc %d failed\n", len);
550                 return;
551         }
        /* link in at the head; an empty list also gets a new tail */
552         ++ncache;
553         cp->prev = 0;
554         cp->next = head;
555         if(head)
556                 head->prev = cp;
557         else
558                 tail = cp;
559         head = cp;
560         cp->host = rp->host;
561         cp->port = rp->port;
562         cp->xid = rp->xid;
563         cp->n = len;
564         memmove(cp->data, buf, len);
565         if(rpcdebug)
566                 fprint(2, "%s: cache %I/%ld, xid %uld, len %d\n",
567                         argv0, cp->host, cp->port, cp->xid, cp->n);
568 }
569
/*
 * replycache looks for a cached reply whose host, port and xid match
 * the call rp.  On a hit it moves the entry to the front of the list,
 * resends the cached bytes with writemsg (whose result is ignored) and
 * returns 1.  On a miss it returns 0.
 */
570 static int
571 replycache(int fd, Rpccall *rp, long (*writemsg)(int, void*, long))
572 {
573         Rpccache *cp;
574
575         for(cp=head; cp; cp=cp->next)
576                 if(cp->host == rp->host &&
577                    cp->port == rp->port &&
578                    cp->xid == rp->xid)
579                         break;
580         if(cp == 0)
581                 return 0;
        /* an entry with no prev is already the head */
582         if(cp->prev){   /* move to front */
583                 cp->prev->next = cp->next;
584                 if(cp->next)
585                         cp->next->prev = cp->prev;
586                 else
587                         tail = cp->prev;
588                 cp->prev = 0;
589                 cp->next = head;
590                 head->prev = cp;
591                 head = cp;
592         }
593         (*writemsg)(fd, cp->data, cp->n);
594         if(rpcdebug)
595                 fprint(2, "%s: reply %I/%ld, xid %uld, len %d\n",
596                         argv0, cp->host, cp->port, cp->xid, cp->n);
597         return 1;
598 }
599
600 static int
601 Iconv(Fmt *f)
602 {
603         char buf[16];
604         ulong h;
605
606         h = va_arg(f->args, ulong);
607         snprint(buf, sizeof buf, "%ld.%ld.%ld.%ld",
608                 (h>>24)&0xff, (h>>16)&0xff,
609                 (h>>8)&0xff, h&0xff);
610         return fmtstrcpy(f, buf);
611 }