]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/aan.c
9bootfat: rename open() to fileinit and make it static, as it's really an internal funct...
[plan9front.git] / sys / src / cmd / aan.c
1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <fcall.h>
5 #include <thread.h>
6
/*
 * Time-unit helpers: convert a count in the named unit to nanoseconds,
 * the unit returned by nsec().  The argument is parenthesized so that an
 * expression such as NS(a+b) converts the whole sum, not just `a`.
 */
#define NS(x)		((vlong)(x))
#define US(x)		(NS(x) * 1000LL)
#define MS(x)		(US(x) * 1000LL)
#define S(x)		(MS(x) * 1000LL)

#define LOGNAME	"aan"		// name passed to syslog()
13
14 enum {
15         Synctime = S(8),
16         Nbuf = 10,
17         K = 1024,
18         Bufsize = 8 * K,
19         Stacksize = 8 * K,
20         Timer = 0,                                      // Alt channels.
21         Unsent = 1,
22         Maxto = 24 * 3600,                      // A full day to reconnect.
23 };
24
25 typedef struct Endpoints Endpoints;
26 struct Endpoints {
27         char    *lsys;
28         char    *lserv;
29         char    *rsys;
30         char    *rserv;
31 };
32
33 typedef struct {
34         ulong           nb;             // Number of data bytes in this message
35         ulong           msg;            // Message number
36         ulong           acked;  // Number of messages acked
37 } Hdr;
38
39 typedef struct t_Buf {
40         Hdr                     hdr;
41         uchar           buf[Bufsize];
42 } Buf;
43
44 static char     *progname;
45 static Channel  *unsent;
46 static Channel  *unacked;
47 static Channel  *empty;
48 static int              netfd;
49 static int              inmsg;
50 static char     *devdir;
51 static int              debug;
52 static int              done;
53 static char     *dialstring;
54 static int              maxto = Maxto;
55 static char     *Logname = LOGNAME;
56 static int              client;
57
58 static Alt a[] = {
59         /*      c       v        op   */
60         {       nil,    nil,    CHANRCV                 },      // timer
61         {       nil,    nil,    CHANRCV                 },      // unsent
62         {       nil,    nil,    CHANEND         },
63 };
64
65 static void             fromnet(void*);
66 static void             fromclient(void*);
67 static void             reconnect(void);
68 static void             synchronize(void);
69 static int              sendcommand(ulong, ulong);
70 static void             showmsg(int, char *, Buf *);
71 static int              writen(int, uchar *, int);
72 static int              getport(char *);
73 static void             dmessage(int, char *, ...);
74 static void             timerproc(void *);
75 static Endpoints *getendpoints(char *);
76 static void             freeendpoints(Endpoints *);
77
78 static void
79 usage(void)
80 {
81         fprint(2, "Usage: %s [-cd] [-m maxto] dialstring|netdir\n", progname);
82         exits("usage");
83 }
84
85 static int
86 catch(void *, char *s)
87 {
88         if (!strcmp(s, "alarm")) {
89                 syslog(0, Logname, "Timed out while waiting for client on %s, exiting...",
90                            devdir);
91                 threadexitsall(nil);
92         }
93         return 0;
94 }
95
96 void
97 threadmain(int argc, char **argv)
98 {
99         int i, failed;
100         Buf *b;
101         Channel *timer;
102         vlong synctime;
103
104         progname = argv[0];
105         ARGBEGIN {
106         case 'c':
107                 client++;
108                 break;
109         case 'd':
110                 debug++;
111                 break;
112         case 'm':
113                 maxto = (int)strtol(EARGF(usage()), (char **)nil, 0);
114                 break;
115         default:
116                 usage();
117         } ARGEND;
118
119         if (argc != 1)
120                 usage();
121
122         if (!client) {
123                 char *p;
124
125                 devdir = argv[0];
126                 if ((p = strstr(devdir, "/local")) != nil)
127                         *p = '\0';
128         }
129         else
130                 dialstring = argv[0];
131
132         if (debug > 0) {
133                 int fd = open("#c/cons", OWRITE|OCEXEC);        
134                 dup(fd, 2);
135         }
136
137         fmtinstall('F', fcallfmt);
138
139         atnotify(catch, 1);
140
141         unsent = chancreate(sizeof(Buf *), Nbuf);
142         unacked = chancreate(sizeof(Buf *), Nbuf);
143         empty = chancreate(sizeof(Buf *), Nbuf);
144         timer = chancreate(sizeof(uchar *), 1);
145
146         for (i = 0; i != Nbuf; i++) {
147                 Buf *b = malloc(sizeof(Buf));
148                 sendp(empty, b);
149         }
150
151         netfd = -1;
152
153         if (proccreate(fromnet, nil, Stacksize) < 0)
154                 sysfatal("%s; Cannot start fromnet; %r", progname);
155
156         reconnect();            // Set up the initial connection.
157         synchronize();
158
159         if (proccreate(fromclient, nil, Stacksize) < 0)
160                 sysfatal("%s; Cannot start fromclient; %r", progname);
161
162         if (proccreate(timerproc, timer, Stacksize) < 0)
163                 sysfatal("%s; Cannot start timerproc; %r", progname);
164
165         a[Timer].c = timer;
166         a[Unsent].c = unsent;
167         a[Unsent].v = &b;
168
169         synctime = nsec() + Synctime;
170         failed = 0;
171         while (!done) {
172                 vlong now;
173                 int delta;
174
175                 if (failed) {
176                         // Wait for the netreader to die.
177                         while (netfd >= 0) {
178                                 dmessage(1, "main; waiting for netreader to die\n");
179                                 sleep(1000);
180                         }
181
182                         // the reader died; reestablish the world.
183                         reconnect();
184                         synchronize();
185                         failed = 0;
186                 }
187
188                 now = nsec();
189                 delta = (synctime - nsec()) / MS(1);
190
191                 if (delta <= 0) {
192                         Hdr hdr;
193
194                         hdr.nb = 0;
195                         hdr.acked = inmsg;
196                         hdr.msg = -1;
197
198                         if (writen(netfd, (uchar *)&hdr, sizeof(Hdr)) < 0) {
199                                 dmessage(2, "main; writen failed; %r\n");
200                                 failed = 1;
201                                 continue;
202                         }
203                         synctime = nsec() + Synctime;
204                         assert(synctime > now);
205                 }
206
207                 switch (alt(a)) {
208                 case Timer:
209                         break;
210
211                 case Unsent:
212                         sendp(unacked, b);
213
214                         b->hdr.acked = inmsg;
215
216                         if (writen(netfd, (uchar *)&b->hdr, sizeof(Hdr)) < 0) {
217                                 dmessage(2, "main; writen failed; %r\n");
218                                 failed = 1;
219                         }
220
221                         if (writen(netfd, b->buf, b->hdr.nb) < 0) {
222                                 dmessage(2, "main; writen failed; %r\n");
223                                 failed = 1;
224                         }
225
226                         if (b->hdr.nb == 0)
227                                 done = 1;
228                         break;
229                 }
230         }
231         syslog(0, Logname, "exiting...");
232         threadexitsall(nil);
233 }
234
235
236 static void
237 fromclient(void*)
238 {
239         static int outmsg;
240
241         for (;;) {
242                 Buf *b;
243
244                 b = recvp(empty);       
245                 if ((int)(b->hdr.nb = read(0, b->buf, Bufsize)) <= 0) {
246                         if ((int)b->hdr.nb < 0)
247                                 dmessage(2, "fromclient; Cannot read 9P message; %r\n");
248                         else
249                                 dmessage(2, "fromclient; Client terminated\n");
250                         b->hdr.nb = 0;
251                 }
252                 b->hdr.msg = outmsg++;
253
254                 showmsg(1, "fromclient", b);
255                 sendp(unsent, b);
256                 
257                 if (b->hdr.nb == 0)
258                         break;
259         }
260 }
261
262 static void
263 fromnet(void*)
264 {
265         static int lastacked;
266         Buf *b;
267
268         b = (Buf *)malloc(sizeof(Buf));
269         assert(b);
270
271         while (!done) {
272                 int len, acked, i;
273
274                 while (netfd < 0) {
275                         dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n", 
276                                           inmsg);
277                         sleep(1000);
278                 }
279
280                 // Read the header.
281                 if ((len = readn(netfd, &b->hdr, sizeof(Hdr))) <= 0) {
282                         if (len < 0)
283                                 dmessage(1, "fromnet; (hdr) network failure; %r\n");
284                         else
285                                 dmessage(1, "fromnet; (hdr) network closed\n");
286                         close(netfd);
287                         netfd = -1;
288                         continue;
289                 }
290                 dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d\n", len,
291                                 b->hdr.nb, b->hdr.msg);
292
293                 if (b->hdr.nb == 0) {
294                         if  ((long)b->hdr.msg >= 0) {
295                                 dmessage(1, "fromnet; network closed\n");
296                                 break;
297                         }
298                         continue;
299                 }
300         
301                 if ((len = readn(netfd, b->buf, b->hdr.nb)) <= 0 || len != b->hdr.nb) {
302                         if (len == 0)
303                                 dmessage(1, "fromnet; network closed\n");
304                         else
305                                 dmessage(1, "fromnet; network failure; %r\n");
306                         close(netfd);
307                         netfd = -1;
308                         continue;
309                 }
310
311                 if (b->hdr.msg < inmsg) {
312                         dmessage(1, "fromnet; skipping message %d, currently at %d\n",
313                                          b->hdr.msg, inmsg);
314                         continue;
315                 }                       
316
317                 // Process the acked list.
318                 acked = b->hdr.acked - lastacked;
319                 for (i = 0; i != acked; i++) {
320                         Buf *rb;
321
322                         rb = recvp(unacked);
323                         if (rb->hdr.msg != lastacked + i) {
324                                 dmessage(1, "rb %p, msg %d, lastacked %d, i %d\n",
325                                                 rb, rb? rb->hdr.msg: -2, lastacked, i);
326                                 assert(0);
327                         }
328                         rb->hdr.msg = -1;
329                         sendp(empty, rb);
330                 } 
331                 lastacked = b->hdr.acked;
332
333                 inmsg++;
334
335                 showmsg(1, "fromnet", b);
336
337                 if (writen(1, b->buf, len) < 0) 
338                         sysfatal("fromnet; cannot write to client; %r");
339         }
340         done = 1;
341 }
342
343 static void
344 reconnect(void)
345 {
346         char ldir[40];
347         int lcfd, fd;
348
349         if (dialstring) {
350                 syslog(0, Logname, "dialing %s", dialstring);
351                 while ((fd = dial(dialstring, nil, nil, nil)) < 0) {
352                         char err[32];
353
354                         err[0] = '\0';
355                         errstr(err, sizeof err);
356                         if (strstr(err, "connection refused")) {
357                                 dmessage(1, "reconnect; server died...\n");
358                                 threadexitsall("server died...");
359                         }
360                         dmessage(1, "reconnect: dialed %s; %s\n", dialstring, err);
361                         sleep(1000);
362                 }
363                 syslog(0, Logname, "reconnected to %s", dialstring);
364         } 
365         else {
366                 Endpoints *ep;
367
368                 syslog(0, Logname, "waiting for connection on %s", devdir);
369                 alarm(maxto * 1000);
370                 if ((lcfd = listen(devdir, ldir)) < 0) 
371                         sysfatal("reconnect; cannot listen; %r");
372         
373                 if ((fd = accept(lcfd, ldir)) < 0)
374                         sysfatal("reconnect; cannot accept; %r");
375                 alarm(0);
376                 close(lcfd);
377                 
378                 ep = getendpoints(ldir);
379                 dmessage(1, "rsys '%s'\n", ep->rsys);
380                 syslog(0, Logname, "connected from %s", ep->rsys);
381                 freeendpoints(ep);
382         }
383         
384         netfd = fd;             // Wakes up the netreader.
385 }
386
387 static void
388 synchronize(void)
389 {
390         Channel *tmp;
391         Buf *b;
392
393         // Ignore network errors here.  If we fail during 
394         // synchronization, the next alarm will pick up 
395         // the error.
396
397         tmp = chancreate(sizeof(Buf *), Nbuf);
398         while ((b = nbrecvp(unacked)) != nil) {
399                 writen(netfd, (uchar *)b, sizeof(Hdr) + b->hdr.nb);
400                 sendp(tmp, b);
401         }
402         chanfree(unacked);
403         unacked = tmp;
404 }
405
406 static void
407 showmsg(int level, char *s, Buf *b)
408 {
409         if (b == nil) {
410                 dmessage(level, "%s; b == nil\n", s);
411                 return;
412         }
413
414         dmessage(level, 
415                         "%s;  (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s, 
416                         b->hdr.nb, 
417                         b->buf[0], b->buf[1], b->buf[2],
418                         b->buf[3], b->buf[4], b->buf[5],
419                         b->buf[6], b->buf[7], b->buf[8], b);
420 }
421
422 static int
423 writen(int fd, uchar *buf, int nb)
424 {
425         int len = nb;
426
427         while (nb > 0) {
428                 int n;
429
430                 if (fd < 0) 
431                         return -1;
432
433                 if ((n = write(fd, buf, nb)) < 0) {
434                         dmessage(1, "writen; Write failed; %r\n");
435                         return -1;
436                 }
437                 dmessage(2, "writen: wrote %d bytes\n", n);
438
439                 buf += n;
440                 nb -= n;
441         }
442         return len;
443 }
444
445 static void
446 timerproc(void *x)
447 {
448         Channel *timer = x;
449         while (!done) {
450                 sleep((Synctime / MS(1)) >> 1);
451                 sendp(timer, "timer");
452         }
453 }
454
455 static void
456 dmessage(int level, char *fmt, ...)
457 {
458         va_list arg; 
459
460         if (level > debug) 
461                 return;
462
463         va_start(arg, fmt);
464         vfprint(2, fmt, arg);
465         va_end(arg);
466 }
467
468 static void
469 getendpoint(char *dir, char *file, char **sysp, char **servp)
470 {
471         int fd, n;
472         char buf[128];
473         char *sys, *serv;
474
475         sys = serv = 0;
476
477         snprint(buf, sizeof buf, "%s/%s", dir, file);
478         fd = open(buf, OREAD);
479         if(fd >= 0){
480                 n = read(fd, buf, sizeof(buf)-1);
481                 if(n>0){
482                         buf[n-1] = 0;
483                         serv = strchr(buf, '!');
484                         if(serv){
485                                 *serv++ = 0;
486                                 serv = strdup(serv);
487                         }
488                         sys = strdup(buf);
489                 }
490                 close(fd);
491         }
492         if(serv == 0)
493                 serv = strdup("unknown");
494         if(sys == 0)
495                 sys = strdup("unknown");
496         *servp = serv;
497         *sysp = sys;
498 }
499
500 static Endpoints *
501 getendpoints(char *dir)
502 {
503         Endpoints *ep;
504
505         ep = malloc(sizeof(*ep));
506         getendpoint(dir, "local", &ep->lsys, &ep->lserv);
507         getendpoint(dir, "remote", &ep->rsys, &ep->rserv);
508         return ep;
509 }
510
511 static void
512 freeendpoints(Endpoints *ep)
513 {
514         free(ep->lsys);
515         free(ep->rsys);
516         free(ep->lserv);
517         free(ep->rserv);
518         free(ep);
519 }
520