]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/aan.c
aan: didn't ask about sendcommand
[plan9front.git] / sys / src / cmd / aan.c
1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <fcall.h>
5 #include <thread.h>
6
/*
 * Time-unit helpers: convert a count in the named unit to
 * nanoseconds (the unit nsec() is compared against below),
 * as a vlong.  The argument is parenthesized so that an
 * expression argument such as S(a+b) is converted as a whole
 * rather than casting only its first operand.
 */
#define NS(x)	((vlong)(x))
#define US(x)	(NS(x) * 1000LL)
#define MS(x)	(US(x) * 1000LL)
#define S(x)	(MS(x) * 1000LL)
11
12 enum {
13         Synctime = S(8),
14         Nbuf = 10,
15         K = 1024,
16         Bufsize = 8 * K,
17         Stacksize = 8 * K,
18         Timer = 0,                              // Alt channels.
19         Unsent = 1,
20         Maxto = 24 * 3600,                      // A full day to reconnect.
21         Hdrsz = 3*4,
22 };
23
24 typedef struct {
25         uchar   nb[4];          // Number of data bytes in this message
26         uchar   msg[4];         // Message number
27         uchar   acked[4];       // Number of messages acked
28 } Hdr;
29
30 typedef struct {
31         Hdr     hdr;
32         uchar   buf[Bufsize];
33 } Buf;
34
35 static Channel  *unsent;
36 static Channel  *unacked;
37 static Channel  *empty;
38 static int      netfd;
39 static int      inmsg;
40 static char     *devdir;
41 static int      debug;
42 static int      done;
43 static char     *dialstring;
44 static int      maxto = Maxto;
45 static char     *Logname = "aan";
46 static int      client;
47 static int      reader = -1;
48 static int      lostsync;
49
50 static Alt a[] = {
51         /*      c       v        op   */
52         {       nil,    nil,    CHANRCV                 },      // timer
53         {       nil,    nil,    CHANRCV                 },      // unsent
54         {       nil,    nil,    CHANEND         },
55 };
56
57 static void             fromnet(void*);
58 static void             fromclient(void*);
59 static int              reconnect(int);
60 static void             synchronize(void);
61 static void             showmsg(int, char *, Buf *);
62 static int              writen(int, uchar *, int);
63 static void             dmessage(int, char *, ...);
64 static void             timerproc(void *);
65
66 static void
67 usage(void)
68 {
69         fprint(2, "Usage: %s [-cd] [-m maxto] dialstring|netdir\n", argv0);
70         exits("usage");
71 }
72
73  
74 static int
75 catch(void *, char *s)
76 {
77         if (!strcmp(s, "alarm")) {
78                 syslog(0, Logname, "Timed out while waiting for reconnect, exiting...");
79                 threadexitsall(nil);
80         }
81         return 0;
82 }
83
84 static void*
85 emalloc(int n)
86 {
87         ulong pc;
88         void *v;
89
90         pc = getcallerpc(&n);
91         v = malloc(n);
92         if(v == nil)
93                 sysfatal("Cannot allocate memory; pc=%lux", pc);
94         setmalloctag(v, pc);
95         return v;
96 }
97
98 void
99 threadmain(int argc, char **argv)
100 {
101         vlong synctime;
102         int i, n, failed;
103         Channel *timer;
104         Hdr hdr;
105         Buf *b;
106
107         ARGBEGIN {
108         case 'c':
109                 client++;
110                 break;
111         case 'd':
112                 debug++;
113                 break;
114         case 'm':
115                 maxto = (int)strtol(EARGF(usage()), nil, 0);
116                 break;
117         default:
118                 usage();
119         } ARGEND;
120
121         if (argc != 1)
122                 usage();
123
124         if (!client) {
125                 char *p;
126
127                 devdir = argv[0];
128                 if ((p = strstr(devdir, "/local")) != nil)
129                         *p = '\0';
130         }
131         else
132                 dialstring = argv[0];
133
134         if (debug > 0) {
135                 int fd = open("#c/cons", OWRITE|OCEXEC);        
136                 dup(fd, 2);
137         }
138
139         fmtinstall('F', fcallfmt);
140
141         atnotify(catch, 1);
142
143         /*
144          * Set up initial connection. use short timeout
145          * of 60 seconds so we wont hang arround for too
146          * long if there is some general connection problem
147          * (like NAT).
148          */
149         netfd = reconnect(60);
150
151         unsent = chancreate(sizeof(Buf *), Nbuf);
152         unacked = chancreate(sizeof(Buf *), Nbuf);
153         empty = chancreate(sizeof(Buf *), Nbuf);
154         timer = chancreate(sizeof(uchar *), 1);
155         if(unsent == nil || unacked == nil || empty == nil || timer == nil)
156                 sysfatal("Cannot allocate channels");
157
158         for (i = 0; i < Nbuf; i++)
159                 sendp(empty, emalloc(sizeof(Buf)));
160
161         reader = proccreate(fromnet, nil, Stacksize);
162         if (reader < 0)
163                 sysfatal("Cannot start fromnet; %r");
164
165         if (proccreate(fromclient, nil, Stacksize) < 0)
166                 sysfatal("Cannot start fromclient; %r");
167
168         if (proccreate(timerproc, timer, Stacksize) < 0)
169                 sysfatal("Cannot start timerproc; %r");
170
171         a[Timer].c = timer;
172         a[Unsent].c = unsent;
173         a[Unsent].v = &b;
174
175 Restart:
176         synctime = nsec() + Synctime;
177         failed = 0;
178         lostsync = 0;
179         while (!done) {
180                 if (netfd < 0 || failed) {
181                         // Wait for the netreader to die.
182                         while (netfd >= 0) {
183                                 dmessage(1, "main; waiting for netreader to die\n");
184                                 threadint(reader);
185                                 sleep(1000);
186                         }
187
188                         // the reader died; reestablish the world.
189                         netfd = reconnect(maxto);
190                         synchronize();
191                         goto Restart;
192                 }
193
194                 switch (alt(a)) {
195                 case Timer:
196                         if (netfd < 0 || nsec() < synctime)
197                                 break;
198
199                         PBIT32(hdr.nb, 0);
200                         PBIT32(hdr.acked, inmsg);
201                         PBIT32(hdr.msg, -1);
202
203                         if (writen(netfd, (uchar *)&hdr, Hdrsz) < 0) {
204                                 dmessage(2, "main; writen failed; %r\n");
205                                 failed = 1;
206                                 continue;
207                         }
208
209                         if(++lostsync > 2){
210                                 syslog(0, Logname, "connection seems hung up...");
211                                 failed = 1;
212                                 continue;
213                         }
214                         synctime = nsec() + Synctime;
215                         break;
216
217                 case Unsent:
218                         sendp(unacked, b);
219
220                         if (netfd < 0)
221                                 break;
222
223                         PBIT32(b->hdr.acked, inmsg);
224
225                         if (writen(netfd, (uchar *)&b->hdr, Hdrsz) < 0) {
226                                 dmessage(2, "main; writen failed; %r\n");
227                                 failed = 1;
228                         }
229
230                         n = GBIT32(b->hdr.nb);
231                         if (writen(netfd, b->buf, n) < 0) {
232                                 dmessage(2, "main; writen failed; %r\n");
233                                 failed = 1;
234                         }
235
236                         if (n == 0)
237                                 done = 1;
238                         break;
239                 }
240         }
241         syslog(0, Logname, "exiting...");
242         threadexitsall(nil);
243 }
244
245
246 static void
247 fromclient(void*)
248 {
249         static int outmsg;
250         int n;
251         Buf *b;
252
253         threadsetname("fromclient");
254
255         do {
256                 b = recvp(empty);
257                 n = read(0, b->buf, Bufsize);
258                 if (n <= 0) {
259                         if (n < 0)
260                                 dmessage(2, "fromclient; Cannot read 9P message; %r\n");
261                         else
262                                 dmessage(2, "fromclient; Client terminated\n");
263                         n = 0;
264                 }
265                 PBIT32(b->hdr.nb, n);
266                 PBIT32(b->hdr.msg, outmsg);
267                 showmsg(1, "fromclient", b);
268                 sendp(unsent, b);
269                 outmsg++;
270         } while(n > 0);
271 }
272
273 static void
274 fromnet(void*)
275 {
276         extern void _threadnote(void *, char *);
277         static int lastacked;
278         int n, m, len, acked;
279         Buf *b;
280
281         notify(_threadnote);
282
283         threadsetname("fromnet");
284
285         b = emalloc(sizeof(Buf));
286         while (!done) {
287                 while (netfd < 0) {
288                         if(done)
289                                 return;
290                         dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n", inmsg);
291                         sleep(1000);
292                 }
293
294                 // Read the header.
295                 len = readn(netfd, (uchar *)&b->hdr, Hdrsz);
296                 if (len <= 0) {
297                         if (len < 0)
298                                 dmessage(1, "fromnet; (hdr) network failure; %r\n");
299                         else
300                                 dmessage(1, "fromnet; (hdr) network closed\n");
301                         close(netfd);
302                         netfd = -1;
303                         continue;
304                 }
305                 lostsync = 0;   // reset timeout
306                 n = GBIT32(b->hdr.nb);
307                 m = GBIT32(b->hdr.msg);
308                 acked = GBIT32(b->hdr.acked);
309                 dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d, acked %d, lastacked %d\n",
310                         len, n, m, acked, lastacked);
311
312                 if (n == 0) {
313                         if (m >= 0) {
314                                 dmessage(1, "fromnet; network closed\n");
315                                 break;
316                         }
317                         continue;
318                 }
319
320                 if (n > Bufsize) {
321                         dmessage(1, "fromnet; message too big %d > %d\n", n, Bufsize);
322                         break;
323                 }
324
325                 len = readn(netfd, b->buf, n);
326                 if (len <= 0 || len != n) {
327                         if (len == 0)
328                                 dmessage(1, "fromnet; network closed\n");
329                         else
330                                 dmessage(1, "fromnet; network failure; %r\n");
331                         close(netfd);
332                         netfd = -1;
333                         continue;
334                 }
335
336                 if (m < inmsg) {
337                         dmessage(1, "fromnet; skipping message %d, currently at %d\n", m, inmsg);
338                         continue;
339                 }                       
340
341                 // Process the acked list.
342                 while(lastacked != acked) {
343                         Buf *rb;
344
345                         rb = recvp(unacked);
346                         m = GBIT32(rb->hdr.msg);
347                         if (m != lastacked) {
348                                 dmessage(1, "fromnet; rb %p, msg %d, lastacked %d\n", rb, m, lastacked);
349                                 sysfatal("fromnet; bug");
350                         }
351                         PBIT32(rb->hdr.msg, -1);
352                         sendp(empty, rb);
353                         lastacked++;
354                 } 
355                 inmsg++;
356
357                 showmsg(1, "fromnet", b);
358
359                 if (writen(1, b->buf, len) < 0) 
360                         sysfatal("fromnet; cannot write to client; %r");
361         }
362         done = 1;
363 }
364
365 static int
366 reconnect(int secs)
367 {
368         NetConnInfo *nci;
369         char ldir[40];
370         int lcfd, fd;
371
372         if (dialstring) {
373                 syslog(0, Logname, "dialing %s", dialstring);
374                 alarm(secs*1000);
375                 while ((fd = dial(dialstring, nil, ldir, nil)) < 0) {
376                         char err[32];
377
378                         err[0] = '\0';
379                         errstr(err, sizeof err);
380                         if (strstr(err, "connection refused")) {
381                                 dmessage(1, "reconnect; server died...\n");
382                                 threadexitsall("server died...");
383                         }
384                         dmessage(1, "reconnect: dialed %s; %s\n", dialstring, err);
385                         sleep(1000);
386                 }
387                 alarm(0);
388                 syslog(0, Logname, "reconnected to %s", dialstring);
389         } 
390         else {
391                 syslog(0, Logname, "waiting for connection on %s", devdir);
392                 alarm(secs*1000);
393                 if ((lcfd = listen(devdir, ldir)) < 0) 
394                         sysfatal("reconnect; cannot listen; %r");
395                 if ((fd = accept(lcfd, ldir)) < 0)
396                         sysfatal("reconnect; cannot accept; %r");
397                 alarm(0);
398                 close(lcfd);
399         }
400
401         if(nci = getnetconninfo(ldir, fd)){
402                 syslog(0, Logname, "connected from %s", nci->rsys);
403                 threadsetname(client? "client %s %s" : "server %s %s", ldir, nci->rsys);
404                 freenetconninfo(nci);
405         } else
406                 syslog(0, Logname, "connected");
407
408         return fd;
409 }
410
411 static void
412 synchronize(void)
413 {
414         Channel *tmp;
415         Buf *b;
416         int n;
417
418         // Ignore network errors here.  If we fail during 
419         // synchronization, the next alarm will pick up 
420         // the error.
421
422         tmp = chancreate(sizeof(Buf *), Nbuf);
423         while ((b = nbrecvp(unacked)) != nil) {
424                 n = GBIT32(b->hdr.nb);
425                 writen(netfd, (uchar *)&b->hdr, Hdrsz);
426                 writen(netfd, b->buf, n);
427                 sendp(tmp, b);
428         }
429         chanfree(unacked);
430         unacked = tmp;
431 }
432
433 static void
434 showmsg(int level, char *s, Buf *b)
435 {
436         int n;
437
438         if (b == nil) {
439                 dmessage(level, "%s; b == nil\n", s);
440                 return;
441         }
442         n = GBIT32(b->hdr.nb);
443         dmessage(level, "%s;  (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s, n, 
444                         b->buf[0], b->buf[1], b->buf[2],
445                         b->buf[3], b->buf[4], b->buf[5],
446                         b->buf[6], b->buf[7], b->buf[8], b);
447 }
448
449 static int
450 writen(int fd, uchar *buf, int nb)
451 {
452         int len = nb;
453
454         while (nb > 0) {
455                 int n;
456
457                 if (fd < 0) 
458                         return -1;
459
460                 if ((n = write(fd, buf, nb)) < 0) {
461                         dmessage(1, "writen; Write failed; %r\n");
462                         return -1;
463                 }
464                 dmessage(2, "writen: wrote %d bytes\n", n);
465
466                 buf += n;
467                 nb -= n;
468         }
469         return len;
470 }
471
472 static void
473 timerproc(void *x)
474 {
475         Channel *timer = x;
476
477         threadsetname("timer");
478
479         while (!done) {
480                 sleep((Synctime / MS(1)) >> 1);
481                 sendp(timer, "timer");
482         }
483 }
484
485 static void
486 dmessage(int level, char *fmt, ...)
487 {
488         va_list arg; 
489
490         if (level > debug) 
491                 return;
492
493         va_start(arg, fmt);
494         vfprint(2, fmt, arg);
495         va_end(arg);
496 }