]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ndb/dnstcp.c
ndb/dnstcp: -x specifies the mountpoint
[plan9front.git] / sys / src / cmd / ndb / dnstcp.c
1 /*
2  * dnstcp - serve dns via tcp
3  */
4 #include <u.h>
5 #include <libc.h>
6 #include <ip.h>
7 #include "dns.h"
8
9 Cfg cfg;
10
11 char    *caller = "";
12 char    *dbfile;
13 int     debug;
14 uchar   ipaddr[IPaddrlen];      /* my ip address */
15 char    *logfile = "dns";
16 int     maxage = 60*60;
17 char    mntpt[Maxpath];
18 int     needrefresh;
19 ulong   now;
20 vlong   nowns;
21 int     testing;
22 int     traceactivity;
23 char    *zonerefreshprogram;
24
25 static int      readmsg(int, uchar*, int);
26 static void     reply(int, DNSmsg*, Request*);
27 static void     dnzone(DNSmsg*, DNSmsg*, Request*);
28 static void     getcaller(char*);
29 static void     refreshmain(char*);
30
31 void
32 usage(void)
33 {
34         fprint(2, "usage: %s [-rR] [-f ndb-file] [-x netmtpt] [conndir]\n", argv0);
35         exits("usage");
36 }
37
38 void
39 main(int argc, char *argv[])
40 {
41         volatile int len, rcode;
42         volatile char tname[32];
43         char *volatile err, *volatile ext = "";
44         volatile uchar buf[64*1024], callip[IPaddrlen];
45         volatile DNSmsg reqmsg, repmsg;
46         volatile Request req;
47
48         alarm(2*60*1000);
49         cfg.cachedb = 1;
50         ARGBEGIN{
51         case 'd':
52                 debug++;
53                 break;
54         case 'f':
55                 dbfile = EARGF(usage());
56                 break;
57         case 'r':
58                 cfg.resolver = 1;
59                 break;
60         case 'R':
61                 norecursion = 1;
62                 break;
63         case 'x':
64                 ext = EARGF(usage());
65                 break;
66         default:
67                 usage();
68                 break;
69         }ARGEND
70
71         if(argc > 0)
72                 getcaller(argv[0]);
73
74         cfg.inside = 1;
75         dninit();
76
77         if(*ext == '/')
78                 snprint(mntpt, sizeof mntpt, "%s", ext);
79         else
80                 snprint(mntpt, sizeof mntpt, "/net%s", ext);
81
82         if(myipaddr(ipaddr, mntpt) < 0)
83                 sysfatal("can't read my ip address");
84         dnslog("dnstcp call from %s to %I", caller, ipaddr);
85         memset(callip, 0, sizeof callip);
86         parseip(callip, caller);
87
88         db2cache(1);
89
90         memset(&req, 0, sizeof req);
91         setjmp(req.mret);
92         req.isslave = 0;
93         procsetname("main loop");
94
95         /* loop on requests */
96         for(;; putactivity(0)){
97                 now = time(nil);
98                 memset(&repmsg, 0, sizeof repmsg);
99                 len = readmsg(0, buf, sizeof buf);
100                 if(len <= 0)
101                         break;
102
103                 getactivity(&req, 0);
104                 req.aborttime = timems() + S2MS(15*Min);
105                 rcode = 0;
106                 memset(&reqmsg, 0, sizeof reqmsg);
107                 err = convM2DNS(buf, len, &reqmsg, &rcode);
108                 if(err){
109                         dnslog("server: input error: %s from %s", err, caller);
110                         free(err);
111                         break;
112                 }
113                 if (rcode == 0)
114                         if(reqmsg.qdcount < 1){
115                                 dnslog("server: no questions from %s", caller);
116                                 break;
117                         } else if(reqmsg.flags & Fresp){
118                                 dnslog("server: reply not request from %s",
119                                         caller);
120                                 break;
121                         } else if((reqmsg.flags & Omask) != Oquery){
122                                 dnslog("server: op %d from %s",
123                                         reqmsg.flags & Omask, caller);
124                                 break;
125                         }
126
127                 if(reqmsg.qd == nil){
128                         dnslog("server: no question RR from %s", caller);
129                         break;
130                 }
131
132                 if(debug)
133                         dnslog("[%d] %d: serve (%s) %d %s %s",
134                                 getpid(), req.id, caller,
135                                 reqmsg.id, reqmsg.qd->owner->name,
136                                 rrname(reqmsg.qd->type, tname, sizeof tname));
137
138                 /* loop through each question */
139                 while(reqmsg.qd)
140                         if(reqmsg.qd->type == Taxfr)
141                                 dnzone(&reqmsg, &repmsg, &req);
142                         else {
143                                 dnserver(&reqmsg, &repmsg, &req, callip, rcode);
144                                 reply(1, &repmsg, &req);
145                                 rrfreelist(repmsg.qd);
146                                 rrfreelist(repmsg.an);
147                                 rrfreelist(repmsg.ns);
148                                 rrfreelist(repmsg.ar);
149                         }
150                 rrfreelist(reqmsg.qd);          /* qd will be nil */
151                 rrfreelist(reqmsg.an);
152                 rrfreelist(reqmsg.ns);
153                 rrfreelist(reqmsg.ar);
154
155                 if(req.isslave){
156                         putactivity(0);
157                         _exits(0);
158                 }
159         }
160         refreshmain(mntpt);
161 }
162
163 static int
164 readmsg(int fd, uchar *buf, int max)
165 {
166         int n;
167         uchar x[2];
168
169         if(readn(fd, x, 2) != 2)
170                 return -1;
171         n = x[0]<<8 | x[1];
172         if(n > max)
173                 return -1;
174         if(readn(fd, buf, n) != n)
175                 return -1;
176         return n;
177 }
178
179 static void
180 reply(int fd, DNSmsg *rep, Request *req)
181 {
182         int len, rv;
183         char tname[32];
184         uchar buf[64*1024];
185         RR *rp;
186
187         if(debug){
188                 dnslog("%d: reply (%s) %s %s %ux",
189                         req->id, caller,
190                         rep->qd->owner->name,
191                         rrname(rep->qd->type, tname, sizeof tname),
192                         rep->flags);
193                 for(rp = rep->an; rp; rp = rp->next)
194                         dnslog("an %R", rp);
195                 for(rp = rep->ns; rp; rp = rp->next)
196                         dnslog("ns %R", rp);
197                 for(rp = rep->ar; rp; rp = rp->next)
198                         dnslog("ar %R", rp);
199         }
200
201
202         len = convDNS2M(rep, buf+2, sizeof(buf) - 2);
203         buf[0] = len>>8;
204         buf[1] = len;
205         rv = write(fd, buf, len+2);
206         if(rv != len+2){
207                 dnslog("[%d] sending reply: %d instead of %d", getpid(), rv,
208                         len+2);
209                 exits(0);
210         }
211 }
212
213 /*
214  *  Hash table for domain names.  The hash is based only on the
215  *  first element of the domain name.
216  */
217 extern DN       *ht[HTLEN];
218
/*
 * Count the elements of a domain name: one more than the number of
 * '.' characters, so the empty string counts as one element.
 */
static int
numelem(char *name)
{
        int count;

        count = 1;
        while(*name != '\0')
                if(*name++ == '.')
                        count++;
        return count;
}
230
231 int
232 inzone(DN *dp, char *name, int namelen, int depth)
233 {
234         int n;
235
236         if(dp->name == nil)
237                 return 0;
238         if(numelem(dp->name) != depth)
239                 return 0;
240         n = strlen(dp->name);
241         if(n < namelen)
242                 return 0;
243         if(cistrcmp(name, dp->name + n - namelen) != 0)
244                 return 0;
245         if(n > namelen && dp->name[n - namelen - 1] != '.')
246                 return 0;
247         return 1;
248 }
249
250 static void
251 dnzone(DNSmsg *reqp, DNSmsg *repp, Request *req)
252 {
253         DN *dp, *ndp;
254         RR r, *rp;
255         int h, depth, found, nlen;
256
257         memset(repp, 0, sizeof(*repp));
258         repp->id = reqp->id;
259         repp->qd = reqp->qd;
260         reqp->qd = reqp->qd->next;
261         repp->qd->next = 0;
262         repp->flags = Fauth | Fresp | Oquery;
263         if(!norecursion)
264                 repp->flags |= Fcanrec;
265         dp = repp->qd->owner;
266
267         /* send the soa */
268         repp->an = rrlookup(dp, Tsoa, NOneg);
269         reply(1, repp, req);
270         if(repp->an == 0)
271                 goto out;
272         rrfreelist(repp->an);
273         repp->an = nil;
274
275         nlen = strlen(dp->name);
276
277         /* construct a breadth-first search of the name space (hard with a hash) */
278         repp->an = &r;
279         for(depth = numelem(dp->name); ; depth++){
280                 found = 0;
281                 for(h = 0; h < HTLEN; h++)
282                         for(ndp = ht[h]; ndp; ndp = ndp->next)
283                                 if(inzone(ndp, dp->name, nlen, depth)){
284                                         for(rp = ndp->rr; rp; rp = rp->next){
285                                                 /*
286                                                  * there shouldn't be negatives,
287                                                  * but just in case.
288                                                  * don't send any soa's,
289                                                  * ns's are enough.
290                                                  */
291                                                 if (rp->negative ||
292                                                     rp->type == Tsoa)
293                                                         continue;
294                                                 r = *rp;
295                                                 r.next = 0;
296                                                 reply(1, repp, req);
297                                         }
298                                         found = 1;
299                                 }
300                 if(!found)
301                         break;
302         }
303
304         /* resend the soa */
305         repp->an = rrlookup(dp, Tsoa, NOneg);
306         reply(1, repp, req);
307         rrfreelist(repp->an);
308         repp->an = nil;
309 out:
310         rrfree(repp->qd);
311         repp->qd = nil;
312 }
313
314 static void
315 getcaller(char *dir)
316 {
317         int fd, n;
318         static char remote[128];
319
320         snprint(remote, sizeof(remote), "%s/remote", dir);
321         fd = open(remote, OREAD);
322         if(fd < 0)
323                 return;
324         n = read(fd, remote, sizeof remote - 1);
325         close(fd);
326         if(n <= 0)
327                 return;
328         if(remote[n-1] == '\n')
329                 n--;
330         remote[n] = 0;
331         caller = remote;
332 }
333
334 static void
335 refreshmain(char *net)
336 {
337         int fd;
338         char file[128];
339
340         snprint(file, sizeof(file), "%s/dns", net);
341         if(debug)
342                 dnslog("refreshing %s", file);
343         fd = open(file, ORDWR);
344         if(fd < 0)
345                 dnslog("can't refresh %s", file);
346         else {
347                 fprint(fd, "refresh");
348                 close(fd);
349         }
350 }
351
352 /*
353  *  the following varies between dnsdebug and dns
354  */
355 void
356 logreply(int id, uchar *addr, DNSmsg *mp)
357 {
358         RR *rp;
359
360         dnslog("%d: rcvd %I flags:%s%s%s%s%s", id, addr,
361                 mp->flags & Fauth? " auth": "",
362                 mp->flags & Ftrunc? " trunc": "",
363                 mp->flags & Frecurse? " rd": "",
364                 mp->flags & Fcanrec? " ra": "",
365                 (mp->flags & (Fauth|Rmask)) == (Fauth|Rname)? " nx": "");
366         for(rp = mp->qd; rp != nil; rp = rp->next)
367                 dnslog("%d: rcvd %I qd %s", id, addr, rp->owner->name);
368         for(rp = mp->an; rp != nil; rp = rp->next)
369                 dnslog("%d: rcvd %I an %R", id, addr, rp);
370         for(rp = mp->ns; rp != nil; rp = rp->next)
371                 dnslog("%d: rcvd %I ns %R", id, addr, rp);
372         for(rp = mp->ar; rp != nil; rp = rp->next)
373                 dnslog("%d: rcvd %I ar %R", id, addr, rp);
374 }
375
376 void
377 logsend(int id, int subid, uchar *addr, char *sname, char *rname, int type)
378 {
379         char buf[12];
380
381         dnslog("%d.%d: sending to %I/%s %s %s",
382                 id, subid, addr, sname, rname, rrname(type, buf, sizeof buf));
383 }
384
385 RR*
386 getdnsservers(int class)
387 {
388         return dnsservers(class);
389 }