]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/upas/smtp/mxdial.c
upas/smtp: allow remote port to be set to something other than smtp (e.g. ssmtp)
[plan9front.git] / sys / src / cmd / upas / smtp / mxdial.c
1 #include "common.h"
2 #include "smtp.h"
3 #include <ndb.h>
4
5 char    *bustedmxs[Maxbustedmx];
6
7 static void
8 expand(DS *ds)
9 {
10         char *s;
11         Ndbtuple *t;
12
13         s = ds->host + 1;
14         t = csipinfo(ds->netdir, "sys", sysname(), &s, 1);
15         if(t != nil){
16                 strecpy(ds->expand, ds->expand+sizeof ds->expand, t->val);
17                 ds->host = ds->expand;
18         }
19         ndbfree(t);
20 }
21
22 /* break up an address to its component parts */
23 void
24 dialstringparse(char *str, DS *ds)
25 {
26         char *p, *p2;
27
28         strecpy(ds->buf, ds->buf + sizeof ds->buf, str);
29         p = strchr(ds->buf, '!');
30         if(p == 0) {
31                 ds->netdir = 0;
32                 ds->proto = "net";
33                 ds->host = ds->buf;
34         } else {
35                 if(*ds->buf != '/'){
36                         ds->netdir = 0;
37                         ds->proto = ds->buf;
38                 } else {
39                         for(p2 = p; *p2 != '/'; p2--)
40                                 ;
41                         *p2++ = 0;
42                         ds->netdir = ds->buf;
43                         ds->proto = p2;
44                 }
45                 *p = 0;
46                 ds->host = p + 1;
47         }
48         ds->service = strchr(ds->host, '!');
49         if(ds->service)
50                 *ds->service++ = 0;
51         else
52                 ds->service = "smtp";
53         if(*ds->host == '$')
54                 expand(ds);
55 }
56
57 void
58 mxtabfree(Mxtab *mx)
59 {
60         free(mx->mx);
61         memset(mx, 0, sizeof *mx);
62 }
63
64 static void
65 mxtabrealloc(Mxtab *mx)
66 {
67         if(mx->nmx < mx->amx)
68                 return;
69         if(mx->amx == 0)
70                 mx->amx = 1;
71         mx->amx <<= 1;
72         mx->mx = realloc(mx->mx, sizeof mx->mx[0] * mx->amx);
73         if(mx->mx == nil)
74                 sysfatal("no memory for mx");
75 }
76
77 static void
78 mxtabadd(Mxtab *mx, char *host, char *ip, char *net, int pref)
79 {
80         int i;
81         Mx *x;
82
83         mxtabrealloc(mx);
84         x = mx->mx;
85         for(i = mx->nmx; i>0 && x[i-1].pref>pref && x[i-1].netdir == net; i--)
86                 x[i] = x[i-1];
87         strecpy(x[i].host, x[i].host + sizeof x[i].host, host);
88         if(ip != nil)
89                 strecpy(x[i].ip, x[i].ip + sizeof x[i].ip, ip);
90         else
91                 x[i].ip[0] = 0;
92         x[i].netdir = net;
93         x[i].pref = pref;
94         x[i].valid = 1;
95         mx->nmx++;
96 }
97
98 static int
99 timeout(void*, char *msg)
100 {
101         if(strstr(msg, "alarm"))
102                 return 1;
103         return 0;
104 }
105
106 static long
107 timedwrite(int fd, void *buf, long len, long ms)
108 {
109         long n, oalarm;
110
111         atnotify(timeout, 1);
112         oalarm = alarm(ms);
113         n = pwrite(fd, buf, len, 0);
114         alarm(oalarm);
115         atnotify(timeout, 0);
116         return n;
117 }
118
119 static int
120 dnslookup(Mxtab *mx, int fd, char *query, char *domain, char *net, int pref0)
121 {
122         int n;
123         char buf[1024], *f[4];
124
125         n = timedwrite(fd, query, strlen(query), 60*1000);
126         if(n < 0){
127                 rerrstr(buf, sizeof buf);
128                 dprint("dns: %s\n", buf);
129                 if(strstr(buf, "dns failure")){
130                         /* if dns fails for the mx lookup, we have to stop */
131                         close(fd);
132                         return -1;
133                 }
134                 return 0;
135         }
136
137         seek(fd, 0, 0);
138         for(;;){
139                 if((n = read(fd, buf, sizeof buf - 1)) < 1)
140                         break;
141                 buf[n] = 0;
142         //      chat("dns: %s\n", buf);
143                 n = tokenize(buf, f, nelem(f));
144                 if(n < 2)
145                         continue;
146                 if(strcmp(f[1], "mx") == 0 && n == 4){
147                         if(strchr(domain, '.') == 0)
148                                 strcpy(domain, f[0]);
149                         mxtabadd(mx, f[3], nil, net, atoi(f[2]));
150                 }
151                 else if (strcmp(f[1], "ip") == 0 && n == 3){
152                         if(strchr(domain, '.') == 0)
153                                 strcpy(domain, f[0]);
154                         mxtabadd(mx, f[0], f[2], net, pref0);
155                 }
156         }
157
158         return 0;
159 }
160
161 static int
162 busted(char *mx)
163 {
164         char **bmp;
165
166         for (bmp = bustedmxs; *bmp != nil; bmp++)
167                 if (strcmp(mx, *bmp) == 0)
168                         return 1;
169         return 0;
170 }
171
172 static void
173 complain(Mxtab *mx, char *domain)
174 {
175         char buf[1024], *e, *p;
176         int i;
177
178         p = buf;
179         e = buf + sizeof buf;
180         for(i = 0; i < mx->nmx; i++)
181                 p = seprint(p, e, "%s ", mx->mx[i].ip);
182         syslog(0, "smtpd.mx", "loopback for %s %s", domain, buf);
183 }
184
185 static int
186 okaymx(Mxtab *mx, char *domain)
187 {
188         int i;
189         Mx *x;
190
191         /* look for malicious dns entries; TODO use badcidr in ../spf/ to catch more than ip4 */
192         for(i = 0; i < mx->nmx; i++){
193                 x = mx->mx + i;
194                 if(x->valid && strcmp(x->ip, "127.0.0.1") == 0){
195                         dprint("illegal: domain %s lists 127.0.0.1 as mail server", domain);
196                         complain(mx, domain);
197                         werrstr("illegal: domain %s lists 127.0.0.1 as mail server", domain);
198                         return -1;
199                 }
200                 if(x->valid && busted(x->host)){
201                         dprint("lookup: skipping busted mx %s\n", x->host);
202                         x->valid = 0;
203                 }
204         }
205         return 0;
206 }
207
208 static int
209 lookup(Mxtab *mx, char *net, char *host, char *domain, char *type)
210 {
211         char dns[128], buf[1024];
212         int fd, i;
213         Mx *x;
214
215         snprint(dns, sizeof dns, "%s/dns", net);
216         fd = open(dns, ORDWR);
217         if(fd == -1)
218                 return -1;
219
220         snprint(buf, sizeof buf, "%s %s", host, type);
221         dprint("sending %s '%s'\n", dns, buf);
222         dnslookup(mx, fd, buf, domain, net, 10000);
223
224         for(i = 0; i < mx->nmx; i++){
225                 x = mx->mx + i;
226                 if(x->ip[0] != 0)
227                         continue;
228                 x->valid = 0;
229
230                 snprint(buf, sizeof buf, "%s %s", x->host, "ip");
231                 dprint("sending %s '%s'\n", dns, buf);
232                 dnslookup(mx, fd, buf, domain, net, x->pref);
233         }
234
235         close(fd);
236
237         if(strcmp(type, "mx") == 0){
238                 if(okaymx(mx, domain) == -1)
239                         return -1;
240                 for(i = 0; i < mx->nmx; i++){
241                         x = mx->mx + i;
242                         dprint("mx list: %s     %d      %s\n", x->host, x->pref, x->ip);
243                 }
244                 dprint("\n");
245         }
246
247         return 0;
248 }
249
250 static int
251 lookcall(Mxtab *mx, DS *d, char *domain, char *type)
252 {
253         char buf[1024];
254         int i;
255         Mx *x;
256
257         if(lookup(mx, d->netdir, d->host, domain, type) == -1){
258                 for(i = 0; i < mx->nmx; i++)
259                         if(mx->mx[i].netdir == d->netdir)
260                                 mx->mx[i].valid = 0;
261                 return -1;
262         }
263
264         for(i = 0; i < mx->nmx; i++){
265                 x = mx->mx + i;
266                 if(x->ip[0] == 0 || x->valid == 0){
267                         x->valid = 0;
268                         continue;
269                 }
270                 snprint(buf, sizeof buf, "%s/%s!%s!%s", d->netdir, d->proto,
271                         x->ip /*x->host*/, d->service);
272                 dprint("mxdial trying %s        [%s]\n", x->host, buf);
273                 atnotify(timeout, 1);
274                 alarm(10*1000);
275                 mx->fd = dial(buf, 0, 0, 0);
276                 alarm(0);
277                 atnotify(timeout, 0);
278                 if(mx->fd >= 0){
279                         mx->pmx = i;
280                         return mx->fd;
281                 }
282                 dprint("        failed %r\n");
283                 x->valid = 0;
284         }
285
286         return -1;
287 }
288
289 int
290 mxdial0(char *addr, char *ddomain, char *gdomain, Mxtab *mx)
291 {
292         int nd, i, j;
293         DS *d;
294         static char *tab[] = {"mx", "ip", };
295
296         dprint("mxdial(%s, %s, %s, mx)\n", addr, ddomain, gdomain);
297         memset(mx, 0, sizeof *mx);
298         d = mx->ds;
299         dialstringparse(addr, d + 0);
300         nd = 1;
301         if(d[0].netdir == nil){
302                 d[1] = d[0];
303                 d[0].netdir = "/net";
304                 d[1].netdir = "/net.alt";
305                 nd = 2;
306         }
307
308         /* search all networks for mx records; then ip records */
309         for(j = 0; j < nelem(tab); j++)
310                 for(i = 0; i < nd; i++)
311                         if(lookcall(mx, d + i, ddomain, tab[j]) != -1)
312                                 return mx->fd;
313
314         /* grotty: try gateway machine by ip only (fixme: try cs lookup) */
315         if(gdomain != nil){
316                 dialstringparse(netmkaddr(gdomain, 0, "smtp"), d + 0);
317                 if(lookcall(mx, d + 0, gdomain, "ip") != -1)
318                         return mx->fd;
319         }
320
321         return -1;
322 }
323
324 int
325 mxdial(char *addr, char *ddomain, char *gdomain, Mx *x)
326 {
327         int fd;
328         Mxtab mx;
329
330         memset(x, 0, sizeof *x);
331         fd = mxdial0(addr, ddomain, gdomain, &mx);
332         if(fd >= 0 && mx.pmx >= 0)
333                 *x = mx.mx[mx.pmx];
334         mxtabfree(&mx);
335         return fd;
336 }