]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ndb/dnstcp.c
merge
[plan9front.git] / sys / src / cmd / ndb / dnstcp.c
1 /*
2  * dnstcp - serve dns via tcp
3  */
4 #include <u.h>
5 #include <libc.h>
6 #include <ip.h>
7 #include "dns.h"
8
9 Cfg cfg;
10
11 char    *caller = "";
12 char    *dbfile;
13 int     debug;
14 uchar   ipaddr[IPaddrlen];      /* my ip address */
15 char    *logfile = "dns";
16 int     maxage = 60*60;
17 char    mntpt[Maxpath];
18 int     needrefresh;
19 ulong   now;
20 vlong   nowns;
21 int     testing;
22 int     traceactivity;
23 char    *zonerefreshprogram;
24
25 static int      readmsg(int, uchar*, int);
26 static void     reply(int, DNSmsg*, Request*);
27 static void     dnzone(DNSmsg*, DNSmsg*, Request*);
28 static void     getcaller(char*);
29 static void     refreshmain(char*);
30
31 void
32 usage(void)
33 {
34         fprint(2, "usage: %s [-rR] [-f ndb-file] [-x netmtpt] [conndir]\n", argv0);
35         exits("usage");
36 }
37
38 void
39 main(int argc, char *argv[])
40 {
41         volatile int len, rcode;
42         volatile char tname[32];
43         char *volatile err, *volatile ext = "";
44         volatile uchar buf[64*1024], callip[IPaddrlen];
45         volatile DNSmsg reqmsg, repmsg;
46         volatile Request req;
47
48         alarm(2*60*1000);
49         cfg.cachedb = 1;
50         ARGBEGIN{
51         case 'd':
52                 debug++;
53                 break;
54         case 'f':
55                 dbfile = EARGF(usage());
56                 break;
57         case 'r':
58                 cfg.resolver = 1;
59                 break;
60         case 'R':
61                 norecursion = 1;
62                 break;
63         case 'x':
64                 ext = EARGF(usage());
65                 break;
66         default:
67                 usage();
68                 break;
69         }ARGEND
70
71         if(debug < 2)
72                 debug = 0;
73
74         if(argc > 0)
75                 getcaller(argv[0]);
76
77         cfg.inside = 1;
78         dninit();
79
80         snprint(mntpt, sizeof mntpt, "/net%s", ext);
81         if(myipaddr(ipaddr, mntpt) < 0)
82                 sysfatal("can't read my ip address");
83         dnslog("dnstcp call from %s to %I", caller, ipaddr);
84         memset(callip, 0, sizeof callip);
85         parseip(callip, caller);
86
87         db2cache(1);
88
89         memset(&req, 0, sizeof req);
90         setjmp(req.mret);
91         req.isslave = 0;
92         procsetname("main loop");
93
94         /* loop on requests */
95         for(;; putactivity(0)){
96                 now = time(nil);
97                 memset(&repmsg, 0, sizeof repmsg);
98                 len = readmsg(0, buf, sizeof buf);
99                 if(len <= 0)
100                         break;
101
102                 getactivity(&req, 0);
103                 req.aborttime = timems() + S2MS(15*Min);
104                 rcode = 0;
105                 memset(&reqmsg, 0, sizeof reqmsg);
106                 err = convM2DNS(buf, len, &reqmsg, &rcode);
107                 if(err){
108                         dnslog("server: input error: %s from %s", err, caller);
109                         free(err);
110                         break;
111                 }
112                 if (rcode == 0)
113                         if(reqmsg.qdcount < 1){
114                                 dnslog("server: no questions from %s", caller);
115                                 break;
116                         } else if(reqmsg.flags & Fresp){
117                                 dnslog("server: reply not request from %s",
118                                         caller);
119                                 break;
120                         } else if((reqmsg.flags & Omask) != Oquery){
121                                 dnslog("server: op %d from %s",
122                                         reqmsg.flags & Omask, caller);
123                                 break;
124                         }
125
126                 if(reqmsg.qd == nil){
127                         dnslog("server: no question RR from %s", caller);
128                         break;
129                 }
130
131                 if(debug)
132                         dnslog("[%d] %d: serve (%s) %d %s %s",
133                                 getpid(), req.id, caller,
134                                 reqmsg.id, reqmsg.qd->owner->name,
135                                 rrname(reqmsg.qd->type, tname, sizeof tname));
136
137                 /* loop through each question */
138                 while(reqmsg.qd)
139                         if(reqmsg.qd->type == Taxfr)
140                                 dnzone(&reqmsg, &repmsg, &req);
141                         else {
142                                 dnserver(&reqmsg, &repmsg, &req, callip, rcode);
143                                 reply(1, &repmsg, &req);
144                                 rrfreelist(repmsg.qd);
145                                 rrfreelist(repmsg.an);
146                                 rrfreelist(repmsg.ns);
147                                 rrfreelist(repmsg.ar);
148                         }
149                 rrfreelist(reqmsg.qd);          /* qd will be nil */
150                 rrfreelist(reqmsg.an);
151                 rrfreelist(reqmsg.ns);
152                 rrfreelist(reqmsg.ar);
153
154                 if(req.isslave){
155                         putactivity(0);
156                         _exits(0);
157                 }
158         }
159         refreshmain(mntpt);
160 }
161
162 static int
163 readmsg(int fd, uchar *buf, int max)
164 {
165         int n;
166         uchar x[2];
167
168         if(readn(fd, x, 2) != 2)
169                 return -1;
170         n = x[0]<<8 | x[1];
171         if(n > max)
172                 return -1;
173         if(readn(fd, buf, n) != n)
174                 return -1;
175         return n;
176 }
177
178 static void
179 reply(int fd, DNSmsg *rep, Request *req)
180 {
181         int len, rv;
182         char tname[32];
183         uchar buf[64*1024];
184         RR *rp;
185
186         if(debug){
187                 dnslog("%d: reply (%s) %s %s %ux",
188                         req->id, caller,
189                         rep->qd->owner->name,
190                         rrname(rep->qd->type, tname, sizeof tname),
191                         rep->flags);
192                 for(rp = rep->an; rp; rp = rp->next)
193                         dnslog("an %R", rp);
194                 for(rp = rep->ns; rp; rp = rp->next)
195                         dnslog("ns %R", rp);
196                 for(rp = rep->ar; rp; rp = rp->next)
197                         dnslog("ar %R", rp);
198         }
199
200
201         len = convDNS2M(rep, buf+2, sizeof(buf) - 2);
202         buf[0] = len>>8;
203         buf[1] = len;
204         rv = write(fd, buf, len+2);
205         if(rv != len+2){
206                 dnslog("[%d] sending reply: %d instead of %d", getpid(), rv,
207                         len+2);
208                 exits(0);
209         }
210 }
211
212 /*
213  *  Hash table for domain names.  The hash is based only on the
214  *  first element of the domain name.
215  */
216 extern DN       *ht[HTLEN];
217
/*
 * numelem - number of dot-separated elements in name:
 * the count of '.' characters plus one ("" and "com" give 1).
 */
static int
numelem(char *name)
{
	char *p;
	int dots;

	dots = 0;
	for(p = name; *p != '\0'; p++)
		dots += (*p == '.');
	return dots + 1;
}
229
230 int
231 inzone(DN *dp, char *name, int namelen, int depth)
232 {
233         int n;
234
235         if(dp->name == nil)
236                 return 0;
237         if(numelem(dp->name) != depth)
238                 return 0;
239         n = strlen(dp->name);
240         if(n < namelen)
241                 return 0;
242         if(cistrcmp(name, dp->name + n - namelen) != 0)
243                 return 0;
244         if(n > namelen && dp->name[n - namelen - 1] != '.')
245                 return 0;
246         return 1;
247 }
248
249 static void
250 dnzone(DNSmsg *reqp, DNSmsg *repp, Request *req)
251 {
252         DN *dp, *ndp;
253         RR r, *rp;
254         int h, depth, found, nlen;
255
256         memset(repp, 0, sizeof(*repp));
257         repp->id = reqp->id;
258         repp->qd = reqp->qd;
259         reqp->qd = reqp->qd->next;
260         repp->qd->next = 0;
261         repp->flags = Fauth | Fresp | Oquery;
262         if(!norecursion)
263                 repp->flags |= Fcanrec;
264         dp = repp->qd->owner;
265
266         /* send the soa */
267         repp->an = rrlookup(dp, Tsoa, NOneg);
268         reply(1, repp, req);
269         if(repp->an == 0)
270                 goto out;
271         rrfreelist(repp->an);
272         repp->an = nil;
273
274         nlen = strlen(dp->name);
275
276         /* construct a breadth-first search of the name space (hard with a hash) */
277         repp->an = &r;
278         for(depth = numelem(dp->name); ; depth++){
279                 found = 0;
280                 for(h = 0; h < HTLEN; h++)
281                         for(ndp = ht[h]; ndp; ndp = ndp->next)
282                                 if(inzone(ndp, dp->name, nlen, depth)){
283                                         for(rp = ndp->rr; rp; rp = rp->next){
284                                                 /*
285                                                  * there shouldn't be negatives,
286                                                  * but just in case.
287                                                  * don't send any soa's,
288                                                  * ns's are enough.
289                                                  */
290                                                 if (rp->negative ||
291                                                     rp->type == Tsoa)
292                                                         continue;
293                                                 r = *rp;
294                                                 r.next = 0;
295                                                 reply(1, repp, req);
296                                         }
297                                         found = 1;
298                                 }
299                 if(!found)
300                         break;
301         }
302
303         /* resend the soa */
304         repp->an = rrlookup(dp, Tsoa, NOneg);
305         reply(1, repp, req);
306         rrfreelist(repp->an);
307         repp->an = nil;
308 out:
309         rrfree(repp->qd);
310         repp->qd = nil;
311 }
312
313 static void
314 getcaller(char *dir)
315 {
316         int fd, n;
317         static char remote[128];
318
319         snprint(remote, sizeof(remote), "%s/remote", dir);
320         fd = open(remote, OREAD);
321         if(fd < 0)
322                 return;
323         n = read(fd, remote, sizeof remote - 1);
324         close(fd);
325         if(n <= 0)
326                 return;
327         if(remote[n-1] == '\n')
328                 n--;
329         remote[n] = 0;
330         caller = remote;
331 }
332
333 static void
334 refreshmain(char *net)
335 {
336         int fd;
337         char file[128];
338
339         snprint(file, sizeof(file), "%s/dns", net);
340         if(debug)
341                 dnslog("refreshing %s", file);
342         fd = open(file, ORDWR);
343         if(fd < 0)
344                 dnslog("can't refresh %s", file);
345         else {
346                 fprint(fd, "refresh");
347                 close(fd);
348         }
349 }
350
351 /*
352  *  the following varies between dnsdebug and dns
353  */
354 void
355 logreply(int id, uchar *addr, DNSmsg *mp)
356 {
357         RR *rp;
358
359         dnslog("%d: rcvd %I flags:%s%s%s%s%s", id, addr,
360                 mp->flags & Fauth? " auth": "",
361                 mp->flags & Ftrunc? " trunc": "",
362                 mp->flags & Frecurse? " rd": "",
363                 mp->flags & Fcanrec? " ra": "",
364                 (mp->flags & (Fauth|Rmask)) == (Fauth|Rname)? " nx": "");
365         for(rp = mp->qd; rp != nil; rp = rp->next)
366                 dnslog("%d: rcvd %I qd %s", id, addr, rp->owner->name);
367         for(rp = mp->an; rp != nil; rp = rp->next)
368                 dnslog("%d: rcvd %I an %R", id, addr, rp);
369         for(rp = mp->ns; rp != nil; rp = rp->next)
370                 dnslog("%d: rcvd %I ns %R", id, addr, rp);
371         for(rp = mp->ar; rp != nil; rp = rp->next)
372                 dnslog("%d: rcvd %I ar %R", id, addr, rp);
373 }
374
375 void
376 logsend(int id, int subid, uchar *addr, char *sname, char *rname, int type)
377 {
378         char buf[12];
379
380         dnslog("%d.%d: sending to %I/%s %s %s",
381                 id, subid, addr, sname, rname, rrname(type, buf, sizeof buf));
382 }
383
384 RR*
385 getdnsservers(int class)
386 {
387         return dnsservers(class);
388 }