]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ndb/dnstcp.c
merge
[plan9front.git] / sys / src / cmd / ndb / dnstcp.c
1 /*
2  * dnstcp - serve dns via tcp
3  */
4 #include <u.h>
5 #include <libc.h>
6 #include <ip.h>
7 #include "dns.h"
8
/* configuration record; main() sets cachedb and inside, and resolver when -r is given */
9 Cfg cfg;
10
/* remote address text read from conndir/remote by getcaller(); stays "" when unknown */
11 char    *caller = "";
/* argument of the -f option (ndb file); nil unless -f was given */
12 char    *dbfile;
/* incremented once per -d; main() resets it to 0 unless it reached at least 2 */
13 int     debug;
14 uchar   ipaddr[IPaddrlen];      /* my ip address */
/* presumably the log name used by dnslog() - TODO confirm against the shared dns code */
15 char    *logfile = "dns";
/* presumably seconds (one hour); not referenced in this file - TODO confirm */
16 int     maxage = 60*60;
/* "/net" followed by the -x argument; passed to myipaddr() and refreshmain() */
17 char    mntpt[Maxpath];
/* not referenced in this file; presumably needed by the shared dns code */
18 int     needrefresh;
/* set from time(nil) at the top of every pass of the request loop in main() */
19 ulong   now;
/* the four below are not referenced in this file; presumably needed by the shared dns code */
20 vlong   nowns;
21 int     testing;
22 int     traceactivity;
23 char    *zonerefreshprogram;
24
25 static int      readmsg(int, uchar*, int);
26 static void     reply(int, DNSmsg*, Request*);
27 static void     dnzone(DNSmsg*, DNSmsg*, Request*);
28 static void     getcaller(char*);
29 static void     refreshmain(char*);
30
31 void
32 usage(void)
33 {
34         fprint(2, "usage: %s [-rR] [-f ndb-file] [-x netmtpt] [conndir]\n", argv0);
35         exits("usage");
36 }
37
38 void
39 main(int argc, char *argv[])
40 {
41         volatile int len, rcode;
42         volatile char tname[32];
43         char *volatile err, *volatile ext = "";
44         volatile uchar buf[64*1024], callip[IPaddrlen];
45         volatile DNSmsg reqmsg, repmsg;
46         volatile Request req;
47
48         alarm(2*60*1000);
49         cfg.cachedb = 1;
50         ARGBEGIN{
51         case 'd':
52                 debug++;
53                 break;
54         case 'f':
55                 dbfile = EARGF(usage());
56                 break;
57         case 'r':
58                 cfg.resolver = 1;
59                 break;
60         case 'R':
61                 norecursion = 1;
62                 break;
63         case 'x':
64                 ext = EARGF(usage());
65                 break;
66         default:
67                 usage();
68                 break;
69         }ARGEND
70
71         if(debug < 2)
72                 debug = 0;
73
74         if(argc > 0)
75                 getcaller(argv[0]);
76
77         cfg.inside = 1;
78         dninit();
79
80         snprint(mntpt, sizeof mntpt, "/net%s", ext);
81         if(myipaddr(ipaddr, mntpt) < 0)
82                 sysfatal("can't read my ip address");
83         dnslog("dnstcp call from %s to %I", caller, ipaddr);
84         memset(callip, 0, sizeof callip);
85         parseip(callip, caller);
86
87         db2cache(1);
88
89         memset(&req, 0, sizeof req);
90         setjmp(req.mret);
91         req.isslave = 0;
92         procsetname("main loop");
93
94         /* loop on requests */
95         for(;; putactivity(0)){
96                 now = time(nil);
97                 memset(&repmsg, 0, sizeof repmsg);
98                 len = readmsg(0, buf, sizeof buf);
99                 if(len <= 0)
100                         break;
101
102                 getactivity(&req, 0);
103                 req.aborttime = timems() + S2MS(15*Min);
104                 rcode = 0;
105                 memset(&reqmsg, 0, sizeof reqmsg);
106                 err = convM2DNS(buf, len, &reqmsg, &rcode);
107                 if(err){
108                         dnslog("server: input error: %s from %s", err, caller);
109                         free(err);
110                         break;
111                 }
112                 if (rcode == 0)
113                         if(reqmsg.qdcount < 1){
114                                 dnslog("server: no questions from %s", caller);
115                                 break;
116                         } else if(reqmsg.flags & Fresp){
117                                 dnslog("server: reply not request from %s",
118                                         caller);
119                                 break;
120                         } else if((reqmsg.flags & Omask) != Oquery){
121                                 dnslog("server: op %d from %s",
122                                         reqmsg.flags & Omask, caller);
123                                 break;
124                         }
125                 if(debug)
126                         dnslog("[%d] %d: serve (%s) %d %s %s",
127                                 getpid(), req.id, caller,
128                                 reqmsg.id, reqmsg.qd->owner->name,
129                                 rrname(reqmsg.qd->type, tname, sizeof tname));
130
131                 /* loop through each question */
132                 while(reqmsg.qd)
133                         if(reqmsg.qd->type == Taxfr)
134                                 dnzone(&reqmsg, &repmsg, &req);
135                         else {
136                                 dnserver(&reqmsg, &repmsg, &req, callip, rcode);
137                                 reply(1, &repmsg, &req);
138                                 rrfreelist(repmsg.qd);
139                                 rrfreelist(repmsg.an);
140                                 rrfreelist(repmsg.ns);
141                                 rrfreelist(repmsg.ar);
142                         }
143                 rrfreelist(reqmsg.qd);          /* qd will be nil */
144                 rrfreelist(reqmsg.an);
145                 rrfreelist(reqmsg.ns);
146                 rrfreelist(reqmsg.ar);
147
148                 if(req.isslave){
149                         putactivity(0);
150                         _exits(0);
151                 }
152         }
153         refreshmain(mntpt);
154 }
155
156 static int
157 readmsg(int fd, uchar *buf, int max)
158 {
159         int n;
160         uchar x[2];
161
162         if(readn(fd, x, 2) != 2)
163                 return -1;
164         n = x[0]<<8 | x[1];
165         if(n > max)
166                 return -1;
167         if(readn(fd, buf, n) != n)
168                 return -1;
169         return n;
170 }
171
172 static void
173 reply(int fd, DNSmsg *rep, Request *req)
174 {
175         int len, rv;
176         char tname[32];
177         uchar buf[64*1024];
178         RR *rp;
179
180         if(debug){
181                 dnslog("%d: reply (%s) %s %s %ux",
182                         req->id, caller,
183                         rep->qd->owner->name,
184                         rrname(rep->qd->type, tname, sizeof tname),
185                         rep->flags);
186                 for(rp = rep->an; rp; rp = rp->next)
187                         dnslog("an %R", rp);
188                 for(rp = rep->ns; rp; rp = rp->next)
189                         dnslog("ns %R", rp);
190                 for(rp = rep->ar; rp; rp = rp->next)
191                         dnslog("ar %R", rp);
192         }
193
194
195         len = convDNS2M(rep, buf+2, sizeof(buf) - 2);
196         buf[0] = len>>8;
197         buf[1] = len;
198         rv = write(fd, buf, len+2);
199         if(rv != len+2){
200                 dnslog("[%d] sending reply: %d instead of %d", getpid(), rv,
201                         len+2);
202                 exits(0);
203         }
204 }
205
206 /*
207  *  Hash table for domain names.  The hash is based only on the
208  *  first element of the domain name.
209  */
210 extern DN       *ht[HTLEN];
211
/*
 * Count the dot-separated elements of name: one more than the
 * number of '.' characters, so the empty string counts as 1.
 */
static int
numelem(char *name)
{
	char *p;
	int dots;

	dots = 0;
	p = name;
	while(*p != '\0'){
		if(*p == '.')
			dots++;
		p++;
	}
	return dots + 1;
}
223
224 int
225 inzone(DN *dp, char *name, int namelen, int depth)
226 {
227         int n;
228
229         if(dp->name == nil)
230                 return 0;
231         if(numelem(dp->name) != depth)
232                 return 0;
233         n = strlen(dp->name);
234         if(n < namelen)
235                 return 0;
236         if(strcmp(name, dp->name + n - namelen) != 0)
237                 return 0;
238         if(n > namelen && dp->name[n - namelen - 1] != '.')
239                 return 0;
240         return 1;
241 }
242
243 static void
244 dnzone(DNSmsg *reqp, DNSmsg *repp, Request *req)
245 {
246         DN *dp, *ndp;
247         RR r, *rp;
248         int h, depth, found, nlen;
249
250         memset(repp, 0, sizeof(*repp));
251         repp->id = reqp->id;
252         repp->qd = reqp->qd;
253         reqp->qd = reqp->qd->next;
254         repp->qd->next = 0;
255         repp->flags = Fauth | Fresp | Oquery;
256         if(!norecursion)
257                 repp->flags |= Fcanrec;
258         dp = repp->qd->owner;
259
260         /* send the soa */
261         repp->an = rrlookup(dp, Tsoa, NOneg);
262         reply(1, repp, req);
263         if(repp->an == 0)
264                 goto out;
265         rrfreelist(repp->an);
266         repp->an = nil;
267
268         nlen = strlen(dp->name);
269
270         /* construct a breadth-first search of the name space (hard with a hash) */
271         repp->an = &r;
272         for(depth = numelem(dp->name); ; depth++){
273                 found = 0;
274                 for(h = 0; h < HTLEN; h++)
275                         for(ndp = ht[h]; ndp; ndp = ndp->next)
276                                 if(inzone(ndp, dp->name, nlen, depth)){
277                                         for(rp = ndp->rr; rp; rp = rp->next){
278                                                 /*
279                                                  * there shouldn't be negatives,
280                                                  * but just in case.
281                                                  * don't send any soa's,
282                                                  * ns's are enough.
283                                                  */
284                                                 if (rp->negative ||
285                                                     rp->type == Tsoa)
286                                                         continue;
287                                                 r = *rp;
288                                                 r.next = 0;
289                                                 reply(1, repp, req);
290                                         }
291                                         found = 1;
292                                 }
293                 if(!found)
294                         break;
295         }
296
297         /* resend the soa */
298         repp->an = rrlookup(dp, Tsoa, NOneg);
299         reply(1, repp, req);
300         rrfreelist(repp->an);
301         repp->an = nil;
302 out:
303         rrfree(repp->qd);
304         repp->qd = nil;
305 }
306
307 static void
308 getcaller(char *dir)
309 {
310         int fd, n;
311         static char remote[128];
312
313         snprint(remote, sizeof(remote), "%s/remote", dir);
314         fd = open(remote, OREAD);
315         if(fd < 0)
316                 return;
317         n = read(fd, remote, sizeof remote - 1);
318         close(fd);
319         if(n <= 0)
320                 return;
321         if(remote[n-1] == '\n')
322                 n--;
323         remote[n] = 0;
324         caller = remote;
325 }
326
327 static void
328 refreshmain(char *net)
329 {
330         int fd;
331         char file[128];
332
333         snprint(file, sizeof(file), "%s/dns", net);
334         if(debug)
335                 dnslog("refreshing %s", file);
336         fd = open(file, ORDWR);
337         if(fd < 0)
338                 dnslog("can't refresh %s", file);
339         else {
340                 fprint(fd, "refresh");
341                 close(fd);
342         }
343 }
344
345 /*
346  *  the following varies between dnsdebug and dns
347  */
348 void
349 logreply(int id, uchar *addr, DNSmsg *mp)
350 {
351         RR *rp;
352
353         dnslog("%d: rcvd %I flags:%s%s%s%s%s", id, addr,
354                 mp->flags & Fauth? " auth": "",
355                 mp->flags & Ftrunc? " trunc": "",
356                 mp->flags & Frecurse? " rd": "",
357                 mp->flags & Fcanrec? " ra": "",
358                 (mp->flags & (Fauth|Rmask)) == (Fauth|Rname)? " nx": "");
359         for(rp = mp->qd; rp != nil; rp = rp->next)
360                 dnslog("%d: rcvd %I qd %s", id, addr, rp->owner->name);
361         for(rp = mp->an; rp != nil; rp = rp->next)
362                 dnslog("%d: rcvd %I an %R", id, addr, rp);
363         for(rp = mp->ns; rp != nil; rp = rp->next)
364                 dnslog("%d: rcvd %I ns %R", id, addr, rp);
365         for(rp = mp->ar; rp != nil; rp = rp->next)
366                 dnslog("%d: rcvd %I ar %R", id, addr, rp);
367 }
368
369 void
370 logsend(int id, int subid, uchar *addr, char *sname, char *rname, int type)
371 {
372         char buf[12];
373
374         dnslog("%d.%d: sending to %I/%s %s %s",
375                 id, subid, addr, sname, rname, rrname(type, buf, sizeof buf));
376 }
377
378 RR*
379 getdnsservers(int class)
380 {
381         return dnsservers(class);
382 }