]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/upas/smtp/mxdial.c
74516a030eec082ae4f3797a6a48ba498a9bd201
[plan9front.git] / sys / src / cmd / upas / smtp / mxdial.c
1 #include "common.h"
2 #include "smtp.h"
3 #include <ndb.h>
4
5 char    *bustedmxs[Maxbustedmx];
6
7 static void
8 expand(DS *ds)
9 {
10         char *s;
11         Ndbtuple *t;
12
13         s = ds->host + 1;
14         t = csipinfo(ds->netdir, "sys", sysname(), &s, 1);
15         if(t != nil){
16                 strecpy(ds->expand, ds->expand+sizeof ds->expand, t->val);
17                 ds->host = ds->expand;
18         }
19         ndbfree(t);
20 }
21
22 /* break up an address to its component parts */
23 void
24 dialstringparse(char *str, DS *ds)
25 {
26         char *p, *p2;
27
28         strecpy(ds->buf, ds->buf + sizeof ds->buf, str);
29         p = strchr(ds->buf, '!');
30         if(p == 0) {
31                 ds->netdir = 0;
32                 ds->proto = "net";
33                 ds->host = ds->buf;
34         } else {
35                 if(*ds->buf != '/'){
36                         ds->netdir = 0;
37                         ds->proto = ds->buf;
38                 } else {
39                         for(p2 = p; *p2 != '/'; p2--)
40                                 ;
41                         *p2++ = 0;
42                         ds->netdir = ds->buf;
43                         ds->proto = p2;
44                 }
45                 *p = 0;
46                 ds->host = p + 1;
47         }
48         ds->service = strchr(ds->host, '!');
49         if(ds->service)
50                 *ds->service++ = 0;
51         if(*ds->host == '$')
52                 expand(ds);
53 }
54
55 void
56 mxtabfree(Mxtab *mx)
57 {
58         free(mx->mx);
59         memset(mx, 0, sizeof *mx);
60 }
61
62 static void
63 mxtabrealloc(Mxtab *mx)
64 {
65         if(mx->nmx < mx->amx)
66                 return;
67         if(mx->amx == 0)
68                 mx->amx = 1;
69         mx->amx <<= 1;
70         mx->mx = realloc(mx->mx, sizeof mx->mx[0] * mx->amx);
71         if(mx->mx == nil)
72                 sysfatal("no memory for mx");
73 }
74
75 static void
76 mxtabadd(Mxtab *mx, char *host, char *ip, char *net, int pref)
77 {
78         int i;
79         Mx *x;
80
81         mxtabrealloc(mx);
82         x = mx->mx;
83         for(i = mx->nmx; i>0 && x[i-1].pref>pref && x[i-1].netdir == net; i--)
84                 x[i] = x[i-1];
85         strecpy(x[i].host, x[i].host + sizeof x[i].host, host);
86         if(ip != nil)
87                 strecpy(x[i].ip, x[i].ip + sizeof x[i].ip, ip);
88         else
89                 x[i].ip[0] = 0;
90         x[i].netdir = net;
91         x[i].pref = pref;
92         x[i].valid = 1;
93         mx->nmx++;
94 }
95
96 static int
97 timeout(void*, char *msg)
98 {
99         if(strstr(msg, "alarm"))
100                 return 1;
101         return 0;
102 }
103
104 static long
105 timedwrite(int fd, void *buf, long len, long ms)
106 {
107         long n, oalarm;
108
109         atnotify(timeout, 1);
110         oalarm = alarm(ms);
111         n = pwrite(fd, buf, len, 0);
112         alarm(oalarm);
113         atnotify(timeout, 0);
114         return n;
115 }
116
117 static int
118 dnslookup(Mxtab *mx, int fd, char *query, char *domain, char *net, int pref0)
119 {
120         int n;
121         char buf[1024], *f[4];
122
123         n = timedwrite(fd, query, strlen(query), 60*1000);
124         if(n < 0){
125                 rerrstr(buf, sizeof buf);
126                 dprint("dns: %s\n", buf);
127                 if(strstr(buf, "dns failure")){
128                         /* if dns fails for the mx lookup, we have to stop */
129                         close(fd);
130                         return -1;
131                 }
132                 return 0;
133         }
134
135         seek(fd, 0, 0);
136         for(;;){
137                 if((n = read(fd, buf, sizeof buf - 1)) < 1)
138                         break;
139                 buf[n] = 0;
140         //      chat("dns: %s\n", buf);
141                 n = tokenize(buf, f, nelem(f));
142                 if(n < 2)
143                         continue;
144                 if(strcmp(f[1], "mx") == 0 && n == 4){
145                         if(strchr(domain, '.') == 0)
146                                 strcpy(domain, f[0]);
147                         mxtabadd(mx, f[3], nil, net, atoi(f[2]));
148                 }
149                 else if (strcmp(f[1], "ip") == 0 && n == 3){
150                         if(strchr(domain, '.') == 0)
151                                 strcpy(domain, f[0]);
152                         mxtabadd(mx, f[0], f[2], net, pref0);
153                 }
154         }
155
156         return 0;
157 }
158
159 static int
160 busted(char *mx)
161 {
162         char **bmp;
163
164         for (bmp = bustedmxs; *bmp != nil; bmp++)
165                 if (strcmp(mx, *bmp) == 0)
166                         return 1;
167         return 0;
168 }
169
170 static void
171 complain(Mxtab *mx, char *domain)
172 {
173         char buf[1024], *e, *p;
174         int i;
175
176         p = buf;
177         e = buf + sizeof buf;
178         for(i = 0; i < mx->nmx; i++)
179                 p = seprint(p, e, "%s ", mx->mx[i].ip);
180         syslog(0, "smtpd.mx", "loopback for %s %s", domain, buf);
181 }
182
183 static int
184 okaymx(Mxtab *mx, char *domain)
185 {
186         int i;
187         Mx *x;
188
189         /* look for malicious dns entries; TODO use badcidr in ../spf/ to catch more than ip4 */
190         for(i = 0; i < mx->nmx; i++){
191                 x = mx->mx + i;
192                 if(x->valid && strcmp(x->ip, "127.0.0.1") == 0){
193                         dprint("illegal: domain %s lists 127.0.0.1 as mail server", domain);
194                         complain(mx, domain);
195                         werrstr("illegal: domain %s lists 127.0.0.1 as mail server", domain);
196                         return -1;
197                 }
198                 if(x->valid && busted(x->host)){
199                         dprint("lookup: skipping busted mx %s\n", x->host);
200                         x->valid = 0;
201                 }
202         }
203         return 0;
204 }
205
206 static int
207 lookup(Mxtab *mx, char *net, char *host, char *domain, char *type)
208 {
209         char dns[128], buf[1024];
210         int fd, i;
211         Mx *x;
212
213         snprint(dns, sizeof dns, "%s/dns", net);
214         fd = open(dns, ORDWR);
215         if(fd == -1)
216                 return -1;
217
218         snprint(buf, sizeof buf, "%s %s", host, type);
219         dprint("sending %s '%s'\n", dns, buf);
220         dnslookup(mx, fd, buf, domain, net, 10000);
221
222         for(i = 0; i < mx->nmx; i++){
223                 x = mx->mx + i;
224                 if(x->ip[0] != 0)
225                         continue;
226                 x->valid = 0;
227
228                 snprint(buf, sizeof buf, "%s %s", x->host, "ip");
229                 dprint("sending %s '%s'\n", dns, buf);
230                 dnslookup(mx, fd, buf, domain, net, x->pref);
231         }
232
233         close(fd);
234
235         if(strcmp(type, "mx") == 0){
236                 if(okaymx(mx, domain) == -1)
237                         return -1;
238                 for(i = 0; i < mx->nmx; i++){
239                         x = mx->mx + i;
240                         dprint("mx list: %s     %d      %s\n", x->host, x->pref, x->ip);
241                 }
242                 dprint("\n");
243         }
244
245         return 0;
246 }
247
248 static int
249 lookcall(Mxtab *mx, DS *d, char *domain, char *type)
250 {
251         char buf[1024];
252         int i;
253         Mx *x;
254
255         if(lookup(mx, d->netdir, d->host, domain, type) == -1){
256                 for(i = 0; i < mx->nmx; i++)
257                         if(mx->mx[i].netdir == d->netdir)
258                                 mx->mx[i].valid = 0;
259                 return -1;
260         }
261
262         for(i = 0; i < mx->nmx; i++){
263                 x = mx->mx + i;
264                 if(x->ip[0] == 0 || x->valid == 0){
265                         x->valid = 0;
266                         continue;
267                 }
268                 snprint(buf, sizeof buf, "%s/%s!%s!%s", d->netdir, d->proto,
269                         x->ip /*x->host*/, d->service);
270                 dprint("mxdial trying %s        [%s]\n", x->host, buf);
271                 atnotify(timeout, 1);
272                 alarm(10*1000);
273                 mx->fd = dial(buf, 0, 0, 0);
274                 alarm(0);
275                 atnotify(timeout, 0);
276                 if(mx->fd >= 0){
277                         mx->pmx = i;
278                         return mx->fd;
279                 }
280                 dprint("        failed %r\n");
281                 x->valid = 0;
282         }
283
284         return -1;
285 }
286
287 int
288 mxdial0(char *addr, char *ddomain, char *gdomain, Mxtab *mx)
289 {
290         int nd, i, j;
291         DS *d;
292         static char *tab[] = {"mx", "ip", };
293
294         dprint("mxdial(%s, %s, %s, mx)\n", addr, ddomain, gdomain);
295         memset(mx, 0, sizeof *mx);
296         addr = netmkaddr(addr, 0, "smtp");
297         d = mx->ds;
298         dialstringparse(addr, d + 0);
299         nd = 1;
300         if(d[0].netdir == nil){
301                 d[1] = d[0];
302                 d[0].netdir = "/net";
303                 d[1].netdir = "/net.alt";
304                 nd = 2;
305         }
306
307         /* search all networks for mx records; then ip records */
308         for(j = 0; j < nelem(tab); j++)
309                 for(i = 0; i < nd; i++)
310                         if(lookcall(mx, d + i, ddomain, tab[j]) != -1)
311                                 return mx->fd;
312
313         /* grotty: try gateway machine by ip only (fixme: try cs lookup) */
314         if(gdomain != nil){
315                 dialstringparse(netmkaddr(gdomain, 0, "smtp"), d + 0);
316                 if(lookcall(mx, d + 0, gdomain, "ip") != -1)
317                         return mx->fd;
318         }
319
320         return -1;
321 }
322
323 int
324 mxdial(char *addr, char *ddomain, char *gdomain, Mx *x)
325 {
326         int fd;
327         Mxtab mx;
328
329         memset(x, 0, sizeof *x);
330         fd = mxdial0(addr, ddomain, gdomain, &mx);
331         if(fd >= 0 && mx.pmx >= 0)
332                 *x = mx.mx[mx.pmx];
333         mxtabfree(&mx);
334         return fd;
335 }