git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ndb/dnstcp.c
ircrc: freenode -> oftc
[plan9front.git] / sys / src / cmd / ndb / dnstcp.c
1 /*
2  * dnstcp - serve dns via tcp
3  */
4 #include <u.h>
5 #include <libc.h>
6 #include <bio.h>
7 #include <ndb.h>
8 #include <ip.h>
9 #include "dns.h"
10
11 Cfg cfg;
12
13 char    *caller = "";
14 char    *dbfile;
15 int     anyone;
16 int     debug;
17 char    *logfile = "dns";
18 int     maxage = 60*60;
19 char    mntpt[Maxpath];
20 int     needrefresh;
21 ulong   now;
22 vlong   nowns;
23 int     traceactivity;
24 char    *zonerefreshprogram;
25
26 static int      readmsg(int, uchar*, int);
27 static void     reply(int, DNSmsg*, Request*);
28 static void     dnzone(DNSmsg*, DNSmsg*, Request*, uchar*);
29 static void     getcaller(char*);
30 static void     refreshmain(char*);
31
32 void
33 usage(void)
34 {
35         fprint(2, "usage: %s [-adrR] [-f ndbfile] [-x netmtpt] [conndir]\n", argv0);
36         exits("usage");
37 }
38
39 void
40 main(int argc, char *argv[])
41 {
42         volatile int len, rcode;
43         volatile char tname[32];
44         char *volatile err, *volatile ext = "";
45         volatile uchar buf[64*1024], callip[IPaddrlen];
46         volatile DNSmsg reqmsg, repmsg;
47         volatile Request req;
48
49         cfg.cachedb = 1;
50         ARGBEGIN{
51         case 'a':
52                 anyone++;
53                 break;
54         case 'd':
55                 debug++;
56                 break;
57         case 'f':
58                 dbfile = EARGF(usage());
59                 break;
60         case 'r':
61                 cfg.resolver = 1;
62                 break;
63         case 'R':
64                 norecursion = 1;
65                 break;
66         case 'x':
67                 ext = EARGF(usage());
68                 break;
69         default:
70                 usage();
71                 break;
72         }ARGEND
73
74         if(argc > 0)
75                 getcaller(argv[0]);
76
77         cfg.inside = 1;
78         dninit();
79
80         if(*ext == '/')
81                 snprint(mntpt, sizeof mntpt, "%s", ext);
82         else
83                 snprint(mntpt, sizeof mntpt, "/net%s", ext);
84
85         dnslog("dnstcp call from %s", caller);
86         memset(callip, 0, sizeof callip);
87         parseip(callip, caller);
88
89         srand(truerand());
90         db2cache(1);
91
92         memset(&req, 0, sizeof req);
93         setjmp(req.mret);
94         req.isslave = 0;
95         procsetname("main loop");
96
97         alarm(10*1000);
98
99         /* loop on requests */
100         for(;; putactivity(0)){
101                 now = time(nil);
102                 memset(&repmsg, 0, sizeof repmsg);
103                 len = readmsg(0, buf, sizeof buf);
104                 if(len <= 0)
105                         break;
106
107                 getactivity(&req, 0);
108                 req.aborttime = timems() + S2MS(15*Min);
109                 rcode = 0;
110                 memset(&reqmsg, 0, sizeof reqmsg);
111                 err = convM2DNS(buf, len, &reqmsg, &rcode);
112                 if(err){
113                         dnslog("server: input error: %s from %s", err, caller);
114                         free(err);
115                         break;
116                 }
117                 if (rcode == 0)
118                         if(reqmsg.qdcount < 1){
119                                 dnslog("server: no questions from %s", caller);
120                                 break;
121                         } else if(reqmsg.flags & Fresp){
122                                 dnslog("server: reply not request from %s",
123                                         caller);
124                                 break;
125                         } else if((reqmsg.flags & Omask) != Oquery){
126                                 dnslog("server: op %d from %s",
127                                         reqmsg.flags & Omask, caller);
128                                 break;
129                         }
130
131                 if(reqmsg.qd == nil){
132                         dnslog("server: no question RR from %s", caller);
133                         break;
134                 }
135
136                 if(debug)
137                         dnslog("[%d] %d: serve (%s) %d %s %s",
138                                 getpid(), req.id, caller,
139                                 reqmsg.id, reqmsg.qd->owner->name,
140                                 rrname(reqmsg.qd->type, tname, sizeof tname));
141
142                 /* loop through each question */
143                 while(reqmsg.qd)
144                         if(reqmsg.qd->type == Taxfr)
145                                 dnzone(&reqmsg, &repmsg, &req, callip);
146                         else {
147                                 dnserver(&reqmsg, &repmsg, &req, callip, rcode);
148                                 reply(1, &repmsg, &req);
149                                 rrfreelist(repmsg.qd);
150                                 rrfreelist(repmsg.an);
151                                 rrfreelist(repmsg.ns);
152                                 rrfreelist(repmsg.ar);
153                         }
154                 rrfreelist(reqmsg.qd);          /* qd will be nil */
155                 rrfreelist(reqmsg.an);
156                 rrfreelist(reqmsg.ns);
157                 rrfreelist(reqmsg.ar);
158
159                 if(req.isslave){
160                         putactivity(0);
161                         _exits(0);
162                 }
163         }
164         refreshmain(mntpt);
165 }
166
167 static int
168 readmsg(int fd, uchar *buf, int max)
169 {
170         int n;
171         uchar x[2];
172
173         if(readn(fd, x, 2) != 2)
174                 return -1;
175         n = x[0]<<8 | x[1];
176         if(n > max)
177                 return -1;
178         if(readn(fd, buf, n) != n)
179                 return -1;
180         return n;
181 }
182
183 static void
184 reply(int fd, DNSmsg *rep, Request *req)
185 {
186         int len, rv;
187         char tname[32];
188         uchar buf[64*1024];
189         RR *rp;
190
191         if(debug){
192                 dnslog("%d: reply (%s) %s %s %ux",
193                         req->id, caller,
194                         rep->qd->owner->name,
195                         rrname(rep->qd->type, tname, sizeof tname),
196                         rep->flags);
197                 for(rp = rep->an; rp; rp = rp->next)
198                         dnslog("an %R", rp);
199                 for(rp = rep->ns; rp; rp = rp->next)
200                         dnslog("ns %R", rp);
201                 for(rp = rep->ar; rp; rp = rp->next)
202                         dnslog("ar %R", rp);
203         }
204
205
206         len = convDNS2M(rep, buf+2, sizeof(buf) - 2);
207         buf[0] = len>>8;
208         buf[1] = len;
209         rv = write(fd, buf, len+2);
210         if(rv != len+2){
211                 dnslog("[%d] sending reply: %d instead of %d", getpid(), rv,
212                         len+2);
213                 exits(0);
214         }
215 }
216
217 /*
218  *  Hash table for domain names.  The hash is based only on the
219  *  first element of the domain name.
220  */
221 extern DN       *ht[HTLEN];
222
/*
 * numelem - number of dot-separated elements in name, i.e. the
 * number of '.' characters plus one ("" counts as one element).
 */
static int
numelem(char *name)
{
	int dots;
	char *p;

	dots = 0;
	for(p = name; *p != '\0'; p++)
		if(*p == '.')
			dots++;
	return dots + 1;
}
234
235 int
236 inzone(DN *dp, char *name, int namelen, int depth)
237 {
238         int n;
239
240         if(dp->name == nil)
241                 return 0;
242         if(numelem(dp->name) != depth)
243                 return 0;
244         n = strlen(dp->name);
245         if(n < namelen)
246                 return 0;
247         if(cistrcmp(name, dp->name + n - namelen) != 0)
248                 return 0;
249         if(n > namelen && dp->name[n - namelen - 1] != '.')
250                 return 0;
251         return 1;
252 }
253
254 static Server*
255 findserver(uchar *srcip, Server *servers, Request *req)
256 {
257         uchar ip[IPaddrlen];
258         RR *list, *rp;
259         int tmp;
260
261         for(; servers != nil; servers = servers->next){
262                 if(strcmp(ipattr(servers->name), "ip") == 0){
263                         if(parseip(ip, servers->name) == -1)
264                                 continue;
265                         if(ipcmp(srcip, ip) == 0)
266                                 return servers;
267                         continue;
268                 }
269
270                 tmp = cfg.resolver;
271                 cfg.resolver = 1;
272                 list = dnresolve(servers->name, Cin, isv4(srcip)? Ta: Taaaa,
273                         req, nil, 0, Recurse, 0, nil);
274                 cfg.resolver = tmp;
275
276                 for(rp = list; rp != nil; rp = rp->next){
277                         if(parseip(ip, rp->ip->name) == -1)
278                                 continue;
279                         if(ipcmp(srcip, ip) == 0)
280                                 break;
281                 }
282                 rrfreelist(list);
283                 if(rp != nil)
284                         return servers;
285         }
286         return nil;
287 }
288
289 static void
290 dnzone(DNSmsg *reqp, DNSmsg *repp, Request *req, uchar *srcip)
291 {
292         DN *dp, *ndp;
293         RR r, *rp;
294         int h, depth, found, nlen;
295
296         memset(repp, 0, sizeof(*repp));
297         repp->id = reqp->id;
298         repp->qd = reqp->qd;
299         reqp->qd = reqp->qd->next;
300         repp->qd->next = 0;
301         repp->flags = Fauth | Fresp | Oquery;
302         if(!norecursion)
303                 repp->flags |= Fcanrec;
304         dp = repp->qd->owner;
305
306         /* send the soa */
307         repp->an = rrlookup(dp, Tsoa, NOneg);
308         if(repp->an != nil && !anyone && !myip(srcip)
309         && findserver(srcip, repp->an->soa->slaves, req) == nil){
310                 dnslog("dnstcp: %I axfr %s - not a dnsslave", srcip, dp->name);
311                 rrfreelist(repp->an);
312                 repp->an = nil;
313         }
314         reply(1, repp, req);
315         if(repp->an == nil)
316                 goto out;
317         rrfreelist(repp->an);
318         repp->an = nil;
319
320         nlen = strlen(dp->name);
321
322         /* construct a breadth-first search of the name space (hard with a hash) */
323         repp->an = &r;
324         for(depth = numelem(dp->name); ; depth++){
325                 found = 0;
326                 for(h = 0; h < HTLEN; h++)
327                         for(ndp = ht[h]; ndp; ndp = ndp->next)
328                                 if(inzone(ndp, dp->name, nlen, depth)){
329                                         for(rp = ndp->rr; rp; rp = rp->next){
330                                                 /*
331                                                  * there shouldn't be negatives,
332                                                  * but just in case.
333                                                  * don't send any soa's,
334                                                  * ns's are enough.
335                                                  */
336                                                 if (rp->negative ||
337                                                     rp->type == Tsoa)
338                                                         continue;
339                                                 r = *rp;
340                                                 r.next = 0;
341                                                 reply(1, repp, req);
342                                         }
343                                         found = 1;
344                                 }
345                 if(!found)
346                         break;
347         }
348
349         /* resend the soa */
350         repp->an = rrlookup(dp, Tsoa, NOneg);
351         reply(1, repp, req);
352         rrfreelist(repp->an);
353         repp->an = nil;
354 out:
355         rrfree(repp->qd);
356         repp->qd = nil;
357 }
358
359 static void
360 getcaller(char *dir)
361 {
362         int fd, n;
363         static char remote[128];
364
365         snprint(remote, sizeof(remote), "%s/remote", dir);
366         fd = open(remote, OREAD);
367         if(fd < 0)
368                 return;
369         n = read(fd, remote, sizeof remote - 1);
370         close(fd);
371         if(n <= 0)
372                 return;
373         if(remote[n-1] == '\n')
374                 n--;
375         remote[n] = 0;
376         caller = remote;
377 }
378
379 static void
380 refreshmain(char *net)
381 {
382         int fd;
383         char file[128];
384
385         snprint(file, sizeof(file), "%s/dns", net);
386         if(debug)
387                 dnslog("refreshing %s", file);
388         fd = open(file, ORDWR);
389         if(fd < 0)
390                 dnslog("can't refresh %s", file);
391         else {
392                 fprint(fd, "refresh");
393                 close(fd);
394         }
395 }
396
397 /*
398  *  the following varies between dnsdebug and dns
399  */
400 void
401 logreply(int id, uchar *addr, DNSmsg *mp)
402 {
403         RR *rp;
404
405         dnslog("%d: rcvd %I flags:%s%s%s%s%s", id, addr,
406                 mp->flags & Fauth? " auth": "",
407                 mp->flags & Ftrunc? " trunc": "",
408                 mp->flags & Frecurse? " rd": "",
409                 mp->flags & Fcanrec? " ra": "",
410                 (mp->flags & (Fauth|Rmask)) == (Fauth|Rname)? " nx": "");
411         for(rp = mp->qd; rp != nil; rp = rp->next)
412                 dnslog("%d: rcvd %I qd %s", id, addr, rp->owner->name);
413         for(rp = mp->an; rp != nil; rp = rp->next)
414                 dnslog("%d: rcvd %I an %R", id, addr, rp);
415         for(rp = mp->ns; rp != nil; rp = rp->next)
416                 dnslog("%d: rcvd %I ns %R", id, addr, rp);
417         for(rp = mp->ar; rp != nil; rp = rp->next)
418                 dnslog("%d: rcvd %I ar %R", id, addr, rp);
419 }
420
421 void
422 logsend(int id, int subid, uchar *addr, char *sname, char *rname, int type)
423 {
424         char buf[12];
425
426         dnslog("%d.%d: sending to %I/%s %s %s",
427                 id, subid, addr, sname, rname, rrname(type, buf, sizeof buf));
428 }
429
430 RR*
431 getdnsservers(int class)
432 {
433         return dnsservers(class);
434 }