]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/aan.c
aan: use sync messages as keep alives
[plan9front.git] / sys / src / cmd / aan.c
1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <fcall.h>
5 #include <thread.h>
6
/*
 * Convert a count of nanoseconds, microseconds, milliseconds or
 * seconds to nanoseconds as a vlong (the unit nsec() returns).
 * The argument is parenthesized so expressions convert as a whole.
 */
#define NS(x)	((vlong)(x))
#define US(x)	(NS(x) * 1000LL)
#define MS(x)	(US(x) * 1000LL)
#define S(x)	(MS(x) * 1000LL)
11
12 enum {
13         Synctime = S(8),
14         Nbuf = 10,
15         K = 1024,
16         Bufsize = 8 * K,
17         Stacksize = 8 * K,
18         Timer = 0,                              // Alt channels.
19         Unsent = 1,
20         Maxto = 24 * 3600,                      // A full day to reconnect.
21         Hdrsz = 3*4,
22 };
23
24 typedef struct {
25         uchar   nb[4];          // Number of data bytes in this message
26         uchar   msg[4];         // Message number
27         uchar   acked[4];       // Number of messages acked
28 } Hdr;
29
30 typedef struct {
31         Hdr     hdr;
32         uchar   buf[Bufsize];
33 } Buf;
34
35 static Channel  *unsent;
36 static Channel  *unacked;
37 static Channel  *empty;
38 static int      netfd;
39 static int      inmsg;
40 static char     *devdir;
41 static int      debug;
42 static int      done;
43 static char     *dialstring;
44 static int      maxto = Maxto;
45 static char     *Logname = "aan";
46 static int      client;
47 static int      reader = -1;
48 static int      lostsync;
49
50 static Alt a[] = {
51         /*      c       v        op   */
52         {       nil,    nil,    CHANRCV                 },      // timer
53         {       nil,    nil,    CHANRCV                 },      // unsent
54         {       nil,    nil,    CHANEND         },
55 };
56
57 static void             fromnet(void*);
58 static void             fromclient(void*);
59 static int              reconnect(int);
60 static void             synchronize(void);
61 static int              sendcommand(ulong, ulong);
62 static void             showmsg(int, char *, Buf *);
63 static int              writen(int, uchar *, int);
64 static void             dmessage(int, char *, ...);
65 static void             timerproc(void *);
66
67 static void
68 usage(void)
69 {
70         fprint(2, "Usage: %s [-cd] [-m maxto] dialstring|netdir\n", argv0);
71         exits("usage");
72 }
73
74  
75 static int
76 catch(void *, char *s)
77 {
78         if (!strcmp(s, "alarm")) {
79                 syslog(0, Logname, "Timed out while waiting for reconnect, exiting...");
80                 threadexitsall(nil);
81         }
82         return 0;
83 }
84
85 static void*
86 emalloc(int n)
87 {
88         ulong pc;
89         void *v;
90
91         pc = getcallerpc(&n);
92         v = malloc(n);
93         if(v == nil)
94                 sysfatal("Cannot allocate memory; pc=%lux", pc);
95         setmalloctag(v, pc);
96         return v;
97 }
98
99 void
100 threadmain(int argc, char **argv)
101 {
102         vlong synctime;
103         int i, n, failed;
104         Channel *timer;
105         Hdr hdr;
106         Buf *b;
107
108         ARGBEGIN {
109         case 'c':
110                 client++;
111                 break;
112         case 'd':
113                 debug++;
114                 break;
115         case 'm':
116                 maxto = (int)strtol(EARGF(usage()), nil, 0);
117                 break;
118         default:
119                 usage();
120         } ARGEND;
121
122         if (argc != 1)
123                 usage();
124
125         if (!client) {
126                 char *p;
127
128                 devdir = argv[0];
129                 if ((p = strstr(devdir, "/local")) != nil)
130                         *p = '\0';
131         }
132         else
133                 dialstring = argv[0];
134
135         if (debug > 0) {
136                 int fd = open("#c/cons", OWRITE|OCEXEC);        
137                 dup(fd, 2);
138         }
139
140         fmtinstall('F', fcallfmt);
141
142         atnotify(catch, 1);
143
144         /*
145          * Set up initial connection. use short timeout
146          * of 60 seconds so we wont hang arround for too
147          * long if there is some general connection problem
148          * (like NAT).
149          */
150         netfd = reconnect(60);
151
152         unsent = chancreate(sizeof(Buf *), Nbuf);
153         unacked = chancreate(sizeof(Buf *), Nbuf);
154         empty = chancreate(sizeof(Buf *), Nbuf);
155         timer = chancreate(sizeof(uchar *), 1);
156         if(unsent == nil || unacked == nil || empty == nil || timer == nil)
157                 sysfatal("Cannot allocate channels");
158
159         for (i = 0; i < Nbuf; i++)
160                 sendp(empty, emalloc(sizeof(Buf)));
161
162         reader = proccreate(fromnet, nil, Stacksize);
163         if (reader < 0)
164                 sysfatal("Cannot start fromnet; %r");
165
166         if (proccreate(fromclient, nil, Stacksize) < 0)
167                 sysfatal("Cannot start fromclient; %r");
168
169         if (proccreate(timerproc, timer, Stacksize) < 0)
170                 sysfatal("Cannot start timerproc; %r");
171
172         a[Timer].c = timer;
173         a[Unsent].c = unsent;
174         a[Unsent].v = &b;
175
176 Restart:
177         synctime = nsec() + Synctime;
178         failed = 0;
179         lostsync = 0;
180         while (!done) {
181                 if (failed) {
182                         // Wait for the netreader to die.
183                         while (netfd >= 0) {
184                                 dmessage(1, "main; waiting for netreader to die\n");
185                                 threadint(reader);
186                                 sleep(1000);
187                         }
188
189                         // the reader died; reestablish the world.
190                         netfd = reconnect(maxto);
191                         synchronize();
192                         goto Restart;
193                 }
194
195                 switch (alt(a)) {
196                 case Timer:
197                         if (netfd < 0 || nsec() < synctime)
198                                 break;
199
200                         PBIT32(hdr.nb, 0);
201                         PBIT32(hdr.acked, inmsg);
202                         PBIT32(hdr.msg, -1);
203
204                         if (writen(netfd, (uchar *)&hdr, Hdrsz) < 0) {
205                                 dmessage(2, "main; writen failed; %r\n");
206                                 failed = 1;
207                                 continue;
208                         }
209
210                         if(++lostsync > 2){
211                                 dmessage(2, "main; lost sync\n");
212                                 failed = 1;
213                                 continue;
214                         }
215                         synctime = nsec() + Synctime;
216                         break;
217
218                 case Unsent:
219                         sendp(unacked, b);
220
221                         PBIT32(b->hdr.acked, inmsg);
222
223                         if (writen(netfd, (uchar *)&b->hdr, Hdrsz) < 0) {
224                                 dmessage(2, "main; writen failed; %r\n");
225                                 failed = 1;
226                         }
227
228                         n = GBIT32(b->hdr.nb);
229                         if (writen(netfd, b->buf, n) < 0) {
230                                 dmessage(2, "main; writen failed; %r\n");
231                                 failed = 1;
232                         }
233
234                         if (n == 0)
235                                 done = 1;
236                         break;
237                 }
238         }
239         syslog(0, Logname, "exiting...");
240         threadexitsall(nil);
241 }
242
243
244 static void
245 fromclient(void*)
246 {
247         static int outmsg;
248         int n;
249         Buf *b;
250
251         threadsetname("fromclient");
252
253         do {
254                 b = recvp(empty);
255                 n = read(0, b->buf, Bufsize);
256                 if (n <= 0) {
257                         if (n < 0)
258                                 dmessage(2, "fromclient; Cannot read 9P message; %r\n");
259                         else
260                                 dmessage(2, "fromclient; Client terminated\n");
261                         n = 0;
262                 }
263                 PBIT32(b->hdr.nb, n);
264                 PBIT32(b->hdr.msg, outmsg);
265                 showmsg(1, "fromclient", b);
266                 sendp(unsent, b);
267                 outmsg++;
268         } while(n > 0);
269 }
270
271 static void
272 fromnet(void*)
273 {
274         extern void _threadnote(void *, char *);
275         static int lastacked;
276         int n, m, len, acked;
277         Buf *b;
278
279         notify(_threadnote);
280
281         threadsetname("fromnet");
282
283         b = emalloc(sizeof(Buf));
284         while (!done) {
285                 while (netfd < 0) {
286                         if(done)
287                                 return;
288                         dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n", inmsg);
289                         sleep(1000);
290                 }
291
292                 // Read the header.
293                 len = readn(netfd, (uchar *)&b->hdr, Hdrsz);
294                 if (len <= 0) {
295                         if (len < 0)
296                                 dmessage(1, "fromnet; (hdr) network failure; %r\n");
297                         else
298                                 dmessage(1, "fromnet; (hdr) network closed\n");
299                         close(netfd);
300                         netfd = -1;
301                         continue;
302                 }
303                 lostsync = 0;   // reset timeout
304                 n = GBIT32(b->hdr.nb);
305                 m = GBIT32(b->hdr.msg);
306                 acked = GBIT32(b->hdr.acked);
307                 dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d, acked %d, lastacked %d\n",
308                         len, n, m, acked, lastacked);
309
310                 if (n == 0) {
311                         if (m >= 0) {
312                                 dmessage(1, "fromnet; network closed\n");
313                                 break;
314                         }
315                         continue;
316                 }
317
318                 if (n > Bufsize) {
319                         dmessage(1, "fromnet; message too big %d > %d\n", n, Bufsize);
320                         break;
321                 }
322
323                 len = readn(netfd, b->buf, n);
324                 if (len <= 0 || len != n) {
325                         if (len == 0)
326                                 dmessage(1, "fromnet; network closed\n");
327                         else
328                                 dmessage(1, "fromnet; network failure; %r\n");
329                         close(netfd);
330                         netfd = -1;
331                         continue;
332                 }
333
334                 if (m < inmsg) {
335                         dmessage(1, "fromnet; skipping message %d, currently at %d\n", m, inmsg);
336                         continue;
337                 }                       
338
339                 // Process the acked list.
340                 while(lastacked != acked) {
341                         Buf *rb;
342
343                         rb = recvp(unacked);
344                         m = GBIT32(rb->hdr.msg);
345                         if (m != lastacked) {
346                                 dmessage(1, "fromnet; rb %p, msg %d, lastacked %d\n", rb, m, lastacked);
347                                 sysfatal("fromnet; bug");
348                         }
349                         PBIT32(rb->hdr.msg, -1);
350                         sendp(empty, rb);
351                         lastacked++;
352                 } 
353                 inmsg++;
354
355                 showmsg(1, "fromnet", b);
356
357                 if (writen(1, b->buf, len) < 0) 
358                         sysfatal("fromnet; cannot write to client; %r");
359         }
360         done = 1;
361 }
362
363 static int
364 reconnect(int secs)
365 {
366         NetConnInfo *nci;
367         char ldir[40];
368         int lcfd, fd;
369
370         if (dialstring) {
371                 syslog(0, Logname, "dialing %s", dialstring);
372                 alarm(secs*1000);
373                 while ((fd = dial(dialstring, nil, ldir, nil)) < 0) {
374                         char err[32];
375
376                         err[0] = '\0';
377                         errstr(err, sizeof err);
378                         if (strstr(err, "connection refused")) {
379                                 dmessage(1, "reconnect; server died...\n");
380                                 threadexitsall("server died...");
381                         }
382                         dmessage(1, "reconnect: dialed %s; %s\n", dialstring, err);
383                         sleep(1000);
384                 }
385                 alarm(0);
386                 syslog(0, Logname, "reconnected to %s", dialstring);
387         } 
388         else {
389                 syslog(0, Logname, "waiting for connection on %s", devdir);
390                 alarm(secs*1000);
391                 if ((lcfd = listen(devdir, ldir)) < 0) 
392                         sysfatal("reconnect; cannot listen; %r");
393                 if ((fd = accept(lcfd, ldir)) < 0)
394                         sysfatal("reconnect; cannot accept; %r");
395                 alarm(0);
396                 close(lcfd);
397         }
398
399         if(nci = getnetconninfo(ldir, fd)){
400                 syslog(0, Logname, "connected from %s", nci->rsys);
401                 threadsetname(client? "client %s %s" : "server %s %s", ldir, nci->rsys);
402                 freenetconninfo(nci);
403         } else
404                 syslog(0, Logname, "connected");
405
406         return fd;
407 }
408
409 static void
410 synchronize(void)
411 {
412         Channel *tmp;
413         Buf *b;
414         int n;
415
416         // Ignore network errors here.  If we fail during 
417         // synchronization, the next alarm will pick up 
418         // the error.
419
420         tmp = chancreate(sizeof(Buf *), Nbuf);
421         while ((b = nbrecvp(unacked)) != nil) {
422                 n = GBIT32(b->hdr.nb);
423                 writen(netfd, (uchar *)&b->hdr, Hdrsz);
424                 writen(netfd, b->buf, n);
425                 sendp(tmp, b);
426         }
427         chanfree(unacked);
428         unacked = tmp;
429 }
430
431 static void
432 showmsg(int level, char *s, Buf *b)
433 {
434         int n;
435
436         if (b == nil) {
437                 dmessage(level, "%s; b == nil\n", s);
438                 return;
439         }
440         n = GBIT32(b->hdr.nb);
441         dmessage(level, "%s;  (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s, n, 
442                         b->buf[0], b->buf[1], b->buf[2],
443                         b->buf[3], b->buf[4], b->buf[5],
444                         b->buf[6], b->buf[7], b->buf[8], b);
445 }
446
447 static int
448 writen(int fd, uchar *buf, int nb)
449 {
450         int len = nb;
451
452         while (nb > 0) {
453                 int n;
454
455                 if (fd < 0) 
456                         return -1;
457
458                 if ((n = write(fd, buf, nb)) < 0) {
459                         dmessage(1, "writen; Write failed; %r\n");
460                         return -1;
461                 }
462                 dmessage(2, "writen: wrote %d bytes\n", n);
463
464                 buf += n;
465                 nb -= n;
466         }
467         return len;
468 }
469
470 static void
471 timerproc(void *x)
472 {
473         Channel *timer = x;
474
475         threadsetname("timer");
476
477         while (!done) {
478                 sleep((Synctime / MS(1)) >> 1);
479                 sendp(timer, "timer");
480         }
481 }
482
483 static void
484 dmessage(int level, char *fmt, ...)
485 {
486         va_list arg; 
487
488         if (level > debug) 
489                 return;
490
491         va_start(arg, fmt);
492         vfprint(2, fmt, arg);
493         va_end(arg);
494 }