]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ndb/dnstcp.c
merge
[plan9front.git] / sys / src / cmd / ndb / dnstcp.c
1 /*
2  * dnstcp - serve dns via tcp
3  */
4 #include <u.h>
5 #include <libc.h>
6 #include <ip.h>
7 #include "dns.h"
8
9 Cfg cfg;
10
11 char    *caller = "";
12 char    *dbfile;
13 int     debug;
14 char    *logfile = "dns";
15 int     maxage = 60*60;
16 char    mntpt[Maxpath];
17 int     needrefresh;
18 ulong   now;
19 vlong   nowns;
20 int     testing;
21 int     traceactivity;
22 char    *zonerefreshprogram;
23
24 static int      readmsg(int, uchar*, int);
25 static void     reply(int, DNSmsg*, Request*);
26 static void     dnzone(DNSmsg*, DNSmsg*, Request*);
27 static void     getcaller(char*);
28 static void     refreshmain(char*);
29
30 void
31 usage(void)
32 {
33         fprint(2, "usage: %s [-rR] [-f ndb-file] [-x netmtpt] [conndir]\n", argv0);
34         exits("usage");
35 }
36
37 void
38 main(int argc, char *argv[])
39 {
40         volatile int len, rcode;
41         volatile char tname[32];
42         char *volatile err, *volatile ext = "";
43         volatile uchar buf[64*1024], callip[IPaddrlen];
44         volatile DNSmsg reqmsg, repmsg;
45         volatile Request req;
46
47         alarm(2*60*1000);
48         cfg.cachedb = 1;
49         ARGBEGIN{
50         case 'd':
51                 debug++;
52                 break;
53         case 'f':
54                 dbfile = EARGF(usage());
55                 break;
56         case 'r':
57                 cfg.resolver = 1;
58                 break;
59         case 'R':
60                 norecursion = 1;
61                 break;
62         case 'x':
63                 ext = EARGF(usage());
64                 break;
65         default:
66                 usage();
67                 break;
68         }ARGEND
69
70         if(argc > 0)
71                 getcaller(argv[0]);
72
73         cfg.inside = 1;
74         dninit();
75
76         if(*ext == '/')
77                 snprint(mntpt, sizeof mntpt, "%s", ext);
78         else
79                 snprint(mntpt, sizeof mntpt, "/net%s", ext);
80
81         dnslog("dnstcp call from %s", caller);
82         memset(callip, 0, sizeof callip);
83         parseip(callip, caller);
84
85         db2cache(1);
86
87         memset(&req, 0, sizeof req);
88         setjmp(req.mret);
89         req.isslave = 0;
90         procsetname("main loop");
91
92         /* loop on requests */
93         for(;; putactivity(0)){
94                 now = time(nil);
95                 memset(&repmsg, 0, sizeof repmsg);
96                 len = readmsg(0, buf, sizeof buf);
97                 if(len <= 0)
98                         break;
99
100                 getactivity(&req, 0);
101                 req.aborttime = timems() + S2MS(15*Min);
102                 rcode = 0;
103                 memset(&reqmsg, 0, sizeof reqmsg);
104                 err = convM2DNS(buf, len, &reqmsg, &rcode);
105                 if(err){
106                         dnslog("server: input error: %s from %s", err, caller);
107                         free(err);
108                         break;
109                 }
110                 if (rcode == 0)
111                         if(reqmsg.qdcount < 1){
112                                 dnslog("server: no questions from %s", caller);
113                                 break;
114                         } else if(reqmsg.flags & Fresp){
115                                 dnslog("server: reply not request from %s",
116                                         caller);
117                                 break;
118                         } else if((reqmsg.flags & Omask) != Oquery){
119                                 dnslog("server: op %d from %s",
120                                         reqmsg.flags & Omask, caller);
121                                 break;
122                         }
123
124                 if(reqmsg.qd == nil){
125                         dnslog("server: no question RR from %s", caller);
126                         break;
127                 }
128
129                 if(debug)
130                         dnslog("[%d] %d: serve (%s) %d %s %s",
131                                 getpid(), req.id, caller,
132                                 reqmsg.id, reqmsg.qd->owner->name,
133                                 rrname(reqmsg.qd->type, tname, sizeof tname));
134
135                 /* loop through each question */
136                 while(reqmsg.qd)
137                         if(reqmsg.qd->type == Taxfr)
138                                 dnzone(&reqmsg, &repmsg, &req);
139                         else {
140                                 dnserver(&reqmsg, &repmsg, &req, callip, rcode);
141                                 reply(1, &repmsg, &req);
142                                 rrfreelist(repmsg.qd);
143                                 rrfreelist(repmsg.an);
144                                 rrfreelist(repmsg.ns);
145                                 rrfreelist(repmsg.ar);
146                         }
147                 rrfreelist(reqmsg.qd);          /* qd will be nil */
148                 rrfreelist(reqmsg.an);
149                 rrfreelist(reqmsg.ns);
150                 rrfreelist(reqmsg.ar);
151
152                 if(req.isslave){
153                         putactivity(0);
154                         _exits(0);
155                 }
156         }
157         refreshmain(mntpt);
158 }
159
160 static int
161 readmsg(int fd, uchar *buf, int max)
162 {
163         int n;
164         uchar x[2];
165
166         if(readn(fd, x, 2) != 2)
167                 return -1;
168         n = x[0]<<8 | x[1];
169         if(n > max)
170                 return -1;
171         if(readn(fd, buf, n) != n)
172                 return -1;
173         return n;
174 }
175
176 static void
177 reply(int fd, DNSmsg *rep, Request *req)
178 {
179         int len, rv;
180         char tname[32];
181         uchar buf[64*1024];
182         RR *rp;
183
184         if(debug){
185                 dnslog("%d: reply (%s) %s %s %ux",
186                         req->id, caller,
187                         rep->qd->owner->name,
188                         rrname(rep->qd->type, tname, sizeof tname),
189                         rep->flags);
190                 for(rp = rep->an; rp; rp = rp->next)
191                         dnslog("an %R", rp);
192                 for(rp = rep->ns; rp; rp = rp->next)
193                         dnslog("ns %R", rp);
194                 for(rp = rep->ar; rp; rp = rp->next)
195                         dnslog("ar %R", rp);
196         }
197
198
199         len = convDNS2M(rep, buf+2, sizeof(buf) - 2);
200         buf[0] = len>>8;
201         buf[1] = len;
202         rv = write(fd, buf, len+2);
203         if(rv != len+2){
204                 dnslog("[%d] sending reply: %d instead of %d", getpid(), rv,
205                         len+2);
206                 exits(0);
207         }
208 }
209
/*
 *  Hash table for domain names, defined in another file (it is only
 *  declared extern here).  The hash is based only on the
 *  first element of the domain name.
 *  NOTE(review): confirm the sentence above against the hash function
 *  where ht is defined; it may be stale.  dnzone does not look names
 *  up through the hash: it walks every bucket to enumerate a zone.
 */
extern DN	*ht[HTLEN];
215
/*
 * Count the dot-separated elements of a domain name: one more than
 * the number of '.' characters it contains (so "" counts as 1).
 */
static int
numelem(char *name)
{
	int count, i;

	count = 1;
	for(i = 0; name[i] != '\0'; i++)
		if(name[i] == '.')
			count++;
	return count;
}
227
228 int
229 inzone(DN *dp, char *name, int namelen, int depth)
230 {
231         int n;
232
233         if(dp->name == nil)
234                 return 0;
235         if(numelem(dp->name) != depth)
236                 return 0;
237         n = strlen(dp->name);
238         if(n < namelen)
239                 return 0;
240         if(cistrcmp(name, dp->name + n - namelen) != 0)
241                 return 0;
242         if(n > namelen && dp->name[n - namelen - 1] != '.')
243                 return 0;
244         return 1;
245 }
246
247 static void
248 dnzone(DNSmsg *reqp, DNSmsg *repp, Request *req)
249 {
250         DN *dp, *ndp;
251         RR r, *rp;
252         int h, depth, found, nlen;
253
254         memset(repp, 0, sizeof(*repp));
255         repp->id = reqp->id;
256         repp->qd = reqp->qd;
257         reqp->qd = reqp->qd->next;
258         repp->qd->next = 0;
259         repp->flags = Fauth | Fresp | Oquery;
260         if(!norecursion)
261                 repp->flags |= Fcanrec;
262         dp = repp->qd->owner;
263
264         /* send the soa */
265         repp->an = rrlookup(dp, Tsoa, NOneg);
266         reply(1, repp, req);
267         if(repp->an == 0)
268                 goto out;
269         rrfreelist(repp->an);
270         repp->an = nil;
271
272         nlen = strlen(dp->name);
273
274         /* construct a breadth-first search of the name space (hard with a hash) */
275         repp->an = &r;
276         for(depth = numelem(dp->name); ; depth++){
277                 found = 0;
278                 for(h = 0; h < HTLEN; h++)
279                         for(ndp = ht[h]; ndp; ndp = ndp->next)
280                                 if(inzone(ndp, dp->name, nlen, depth)){
281                                         for(rp = ndp->rr; rp; rp = rp->next){
282                                                 /*
283                                                  * there shouldn't be negatives,
284                                                  * but just in case.
285                                                  * don't send any soa's,
286                                                  * ns's are enough.
287                                                  */
288                                                 if (rp->negative ||
289                                                     rp->type == Tsoa)
290                                                         continue;
291                                                 r = *rp;
292                                                 r.next = 0;
293                                                 reply(1, repp, req);
294                                         }
295                                         found = 1;
296                                 }
297                 if(!found)
298                         break;
299         }
300
301         /* resend the soa */
302         repp->an = rrlookup(dp, Tsoa, NOneg);
303         reply(1, repp, req);
304         rrfreelist(repp->an);
305         repp->an = nil;
306 out:
307         rrfree(repp->qd);
308         repp->qd = nil;
309 }
310
311 static void
312 getcaller(char *dir)
313 {
314         int fd, n;
315         static char remote[128];
316
317         snprint(remote, sizeof(remote), "%s/remote", dir);
318         fd = open(remote, OREAD);
319         if(fd < 0)
320                 return;
321         n = read(fd, remote, sizeof remote - 1);
322         close(fd);
323         if(n <= 0)
324                 return;
325         if(remote[n-1] == '\n')
326                 n--;
327         remote[n] = 0;
328         caller = remote;
329 }
330
331 static void
332 refreshmain(char *net)
333 {
334         int fd;
335         char file[128];
336
337         snprint(file, sizeof(file), "%s/dns", net);
338         if(debug)
339                 dnslog("refreshing %s", file);
340         fd = open(file, ORDWR);
341         if(fd < 0)
342                 dnslog("can't refresh %s", file);
343         else {
344                 fprint(fd, "refresh");
345                 close(fd);
346         }
347 }
348
/*
 *  the following functions vary between the dns programs
 *  (dns, dnsdebug, and this one, dnstcp); each supplies its own version.
 */
352 void
353 logreply(int id, uchar *addr, DNSmsg *mp)
354 {
355         RR *rp;
356
357         dnslog("%d: rcvd %I flags:%s%s%s%s%s", id, addr,
358                 mp->flags & Fauth? " auth": "",
359                 mp->flags & Ftrunc? " trunc": "",
360                 mp->flags & Frecurse? " rd": "",
361                 mp->flags & Fcanrec? " ra": "",
362                 (mp->flags & (Fauth|Rmask)) == (Fauth|Rname)? " nx": "");
363         for(rp = mp->qd; rp != nil; rp = rp->next)
364                 dnslog("%d: rcvd %I qd %s", id, addr, rp->owner->name);
365         for(rp = mp->an; rp != nil; rp = rp->next)
366                 dnslog("%d: rcvd %I an %R", id, addr, rp);
367         for(rp = mp->ns; rp != nil; rp = rp->next)
368                 dnslog("%d: rcvd %I ns %R", id, addr, rp);
369         for(rp = mp->ar; rp != nil; rp = rp->next)
370                 dnslog("%d: rcvd %I ar %R", id, addr, rp);
371 }
372
373 void
374 logsend(int id, int subid, uchar *addr, char *sname, char *rname, int type)
375 {
376         char buf[12];
377
378         dnslog("%d.%d: sending to %I/%s %s %s",
379                 id, subid, addr, sname, rname, rrname(type, buf, sizeof buf));
380 }
381
382 RR*
383 getdnsservers(int class)
384 {
385         return dnsservers(class);
386 }