]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/aan.c
ip/cifsd: dont return garbage in upper 32 bit of unix extension stat fields
[plan9front.git] / sys / src / cmd / aan.c
1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <fcall.h>
5 #include <thread.h>
6
/*
 * Time conversions to nanoseconds (the unit returned by nsec()),
 * yielding vlong.  The argument is parenthesised so that an
 * expression such as S(a+b) is converted as a whole rather than
 * having the cast bind only to its first operand.
 * NOTE: the argument itself is evaluated in its own type; cast it
 * to vlong first if it could overflow before the conversion.
 */
#define NS(x)	((vlong)(x))
#define US(x)	(NS(x) * 1000LL)
#define MS(x)	(US(x) * 1000LL)
#define S(x)	(MS(x) * 1000LL)
11
12 enum {
13         Synctime = S(8),
14         Nbuf = 10,
15         K = 1024,
16         Bufsize = 8 * K,
17         Stacksize = 8 * K,
18         Timer = 0,                              // Alt channels.
19         Unsent = 1,
20         Maxto = 24 * 3600,                      // A full day to reconnect.
21         Hdrsz = 3*4,
22 };
23
/*
 * Header that precedes every message on the wire.  Each field is a
 * 32-bit counter stored with PBIT32 and read back with GBIT32.  The
 * struct is written and read as raw bytes, so its size must equal Hdrsz.
 */
typedef struct {
	uchar	nb[4];		/* payload length; 0 means EOF, or keepalive when msg is -1 */
	uchar	msg[4];		/* sequence number of this message; -1 marks a keepalive */
	uchar	acked[4];	/* sender's count of in-order messages received so far */
} Hdr;

/* One message: its header followed by up to Bufsize bytes of payload. */
typedef struct {
	Hdr	hdr;
	uchar	buf[Bufsize];
} Buf;
34
/*
 * Buffer queues (channels of Buf*).
 * NOTE(review): these and the flags below are shared between procs
 * without locking; the code relies on word-sized accesses.
 */
static Channel	*empty;		/* free buffers, refilled once the peer acks them */
static Channel	*unsent;	/* filled by fromclient, drained by threadmain */
static Channel	*unacked;	/* sent (or queued to send), awaiting the peer's ack */

/* Connection state. */
static int	client;		/* -c: dial dialstring instead of listening on devdir */
static char	*dialstring;	/* client mode: address to dial */
static char	*devdir;	/* server mode: network directory to listen on */
static int	netfd;		/* current connection; -1 while disconnected */
static int	reader = -1;	/* id returned by proccreate for fromnet */

/* Sequencing and liveness. */
static ulong	inmsg;		/* next message number expected from the peer */
static ulong	outmsg;		/* number stamped on the next outgoing message */
static int	lostsync;	/* keepalives sent since the last header arrived */

/* Options and global control. */
static int	debug;		/* -d */
static int	maxto = Maxto;	/* -m: seconds to wait for a reconnect */
static int	done;		/* set when either side reaches EOF */
static char	*Logname = "aan";
50
/*
 * Alt set for threadmain's loop, indexed by Timer and Unsent.
 * The channel and value pointers start out nil; threadmain fills
 * them in once the channels have been created.
 */
static Alt a[] = {
	{ nil, nil, CHANRCV },	/* [Timer]  ticks from timerproc */
	{ nil, nil, CHANRCV },	/* [Unsent] buffers queued by fromclient */
	{ nil, nil, CHANEND },
};
57
/* Proc entry points. */
static void	fromclient(void*);
static void	fromnet(void*);
static void	timerproc(void*);

/* Connection helpers. */
static int	reconnect(int);
static void	synchronize(void);
static int	writen(int, uchar*, int);
64
/* Print the command synopsis on stderr and exit with status "usage". */
static void
usage(void)
{
	fprint(2, "Usage: %s [-cd] [-m maxto] dialstring|netdir\n", argv0);
	exits("usage");
}
71
72  
/*
 * Note handler installed with atnotify.  An "alarm" note means the
 * reconnect timeout set by reconnect() expired: log it and exit.
 * Any other note is reported as not handled (return 0).
 */
static int
catch(void *, char *note)
{
	if(strcmp(note, "alarm") != 0)
		return 0;

	syslog(0, Logname, "Timed out while waiting for reconnect, exiting...");
	threadexitsall(nil);
	return 0;	/* not reached */
}
82
/*
 * malloc that never returns nil: on failure it exits via sysfatal,
 * reporting the caller's pc.  The block is tagged with that pc.
 */
static void*
emalloc(int n)
{
	uintptr caller;
	void *p;

	caller = getcallerpc(&n);
	if((p = malloc(n)) == nil)
		sysfatal("Cannot allocate memory; pc=%#p", caller);
	setmalloctag(p, caller);
	return p;
}
96
/*
 * Parse arguments, establish the first connection, start the
 * fromnet/fromclient/timer procs, then multiplex outgoing client
 * buffers and keepalive ticks onto the network.
 *
 * Usage: aan [-cd] [-m maxto] dialstring|netdir
 *	-c	client mode: dial the argument instead of listening on it
 *	-d	write debug output to the console
 *	-m	seconds reconnect() may wait before the alarm ends the program
 *
 * On a write failure, or once lostsync (cleared by fromnet whenever a
 * header arrives) exceeds 2, the connection is torn down, re-made with
 * reconnect(maxto), and unacknowledged messages are replayed by
 * synchronize().
 */
void
threadmain(int argc, char **argv)
{
	vlong synctime;
	int i, n, failed;
	Channel *timer;
	Hdr hdr;
	Buf *b;

	ARGBEGIN {
	case 'c':
		client++;
		break;
	case 'd':
		debug++;
		break;
	case 'm':
		maxto = (int)strtol(EARGF(usage()), nil, 0);
		break;
	default:
		usage();
	} ARGEND;

	if (argc != 1)
		usage();

	if (!client) {
		char *p;

		// Server mode: truncate the argument at the first "/local".
		devdir = argv[0];
		if ((p = strstr(devdir, "/local")) != nil)
			*p = '\0';
	}
	else
		dialstring = argv[0];

	if (debug > 0) {
		int fd = open("#c/cons", OWRITE|OCEXEC);

		// Send diagnostics (fd 2) to the console.  Close the temporary
		// descriptor so it is not leaked, but never close fd 2 itself
		// (open may have returned it if it was free).
		if (fd >= 0 && fd != 2) {
			dup(fd, 2);
			close(fd);
		}
	}

	atnotify(catch, 1);

	/*
	 * Set up initial connection. Use a short timeout
	 * of 60 seconds so we won't hang around for too
	 * long if there is some general connection problem
	 * (like NAT).
	 */
	netfd = reconnect(60);

	unsent = chancreate(sizeof(Buf *), Nbuf);
	unacked = chancreate(sizeof(Buf *), Nbuf);
	empty = chancreate(sizeof(Buf *), Nbuf);
	timer = chancreate(sizeof(uchar *), 1);
	if(unsent == nil || unacked == nil || empty == nil || timer == nil)
		sysfatal("Cannot allocate channels");

	// Fill the free-buffer pool with Nbuf buffers.
	for (i = 0; i < Nbuf; i++)
		sendp(empty, emalloc(sizeof(Buf)));

	reader = proccreate(fromnet, nil, Stacksize);
	if (reader < 0)
		sysfatal("Cannot start fromnet; %r");

	if (proccreate(fromclient, nil, Stacksize) < 0)
		sysfatal("Cannot start fromclient; %r");

	if (proccreate(timerproc, timer, Stacksize) < 0)
		sysfatal("Cannot start timerproc; %r");

	a[Timer].c = timer;
	a[Unsent].c = unsent;
	a[Unsent].v = &b;

Restart:
	synctime = nsec() + Synctime;
	failed = 0;
	lostsync = 0;
	while (!done) {
		if (netfd < 0 || failed) {
			// Wait for the netreader to die.
			while (netfd >= 0) {
				if(debug) fprint(2, "main; waiting for netreader to die\n");
				threadint(reader);
				sleep(1000);
			}

			// the reader died; reestablish the world.
			netfd = reconnect(maxto);
			synchronize();
			goto Restart;
		}

		switch (alt(a)) {
		case Timer:
			// Send a keepalive (empty header, msg -1) at most once per Synctime.
			if (netfd < 0 || nsec() < synctime)
				break;

			PBIT32(hdr.nb, 0);
			PBIT32(hdr.acked, inmsg);
			PBIT32(hdr.msg, -1);

			if (writen(netfd, (uchar *)&hdr, Hdrsz) < 0) {
				failed = 1;
				continue;
			}

			// The third keepalive without any header from the peer
			// in between means the link is considered dead.
			if(++lostsync > 2){
				syslog(0, Logname, "connection seems hung up...");
				failed = 1;
				continue;
			}
			synctime = nsec() + Synctime;
			break;

		case Unsent:
			// Queue for replay before sending, so that a failed write
			// is recovered by synchronize() after reconnecting.
			sendp(unacked, b);

			if (netfd < 0)
				break;

			PBIT32(b->hdr.acked, inmsg);

			if (writen(netfd, (uchar *)&b->hdr, Hdrsz) < 0)
				failed = 1;
			else {
				n = GBIT32(b->hdr.nb);
				if (writen(netfd, b->buf, n) < 0)
					failed = 1;
				// A zero-length message is fromclient's EOF.
				if (n == 0)
					done = 1;
			}
			break;
		}
	}
	syslog(0, Logname, "exiting...");
	threadexitsall(nil);
}
236
237
/*
 * Proc that reads the local client (fd 0) into free buffers and queues
 * them on unsent, each stamped with its length and sequence number.
 * A read error is treated like EOF.  The final zero-length buffer
 * tells threadmain the client is done, and this proc then returns.
 */
static void
fromclient(void*)
{
	Buf *bp;
	int nr;

	threadsetname("fromclient");

	for(;;){
		bp = recvp(empty);
		nr = read(0, bp->buf, Bufsize);
		if(nr < 0)
			nr = 0;
		PBIT32(bp->hdr.nb, nr);
		PBIT32(bp->hdr.msg, outmsg);
		sendp(unsent, bp);
		outmsg++;
		if(nr == 0)
			break;
	}
}
257
/*
 * Proc that reads messages from netfd.
 *
 * For each in-sequence data message it advances inmsg, returns the
 * buffers the peer has acknowledged from unacked to empty, and writes
 * the payload to the local client (fd 1).  Out-of-sequence messages
 * (e.g. replays after a reconnect) are skipped.
 *
 * On a read failure or short read it closes netfd and sets it to -1,
 * then waits for threadmain to reconnect.  A zero-length non-keepalive
 * message (peer EOF) or an oversized length sets done and ends the proc.
 */
static void
fromnet(void*)
{
	extern void _threadnote(void *, char *);
	ulong m, acked, lastacked = 0;
	int n, len;
	Buf *b;

	// NOTE(review): presumably installed so that threadint() from
	// threadmain can interrupt a blocked read here; confirm.
	notify(_threadnote);

	threadsetname("fromnet");

	b = emalloc(sizeof(Buf));
	while (!done) {
		while (netfd < 0) {
			if(done)
				return;
			if(debug) fprint(2, "fromnet; waiting for connection... (inmsg %lud)\n", inmsg);
			sleep(1000);
		}

		// Read the header.  readn returns a short count if the
		// connection ends mid-header; such a header must not be
		// decoded, because part of it would be left over from the
		// previous message.
		len = readn(netfd, (uchar *)&b->hdr, Hdrsz);
		if (len != Hdrsz) {
			if (debug) {
				if (len < 0)
					fprint(2, "fromnet; (hdr) network failure; %r\n");
				else if (len == 0)
					fprint(2, "fromnet; (hdr) network closed\n");
				else
					fprint(2, "fromnet; (hdr) short header (%d bytes)\n", len);
			}
			close(netfd);
			netfd = -1;
			continue;
		}
		lostsync = 0;	// reset timeout
		n = GBIT32(b->hdr.nb);
		m = GBIT32(b->hdr.msg);
		acked = GBIT32(b->hdr.acked);
		if (n == 0) {
			// NOTE(review): the acked field of a keepalive is ignored;
			// only data messages advance lastacked.  Confirm intended.
			if (m == (ulong)-1)
				continue;
			if(debug) fprint(2, "fromnet; network closed\n");
			break;
		} else if (n < 0 || n > Bufsize) {
			if(debug) fprint(2, "fromnet; message too big %d > %d\n", n, Bufsize);
			break;
		}

		len = readn(netfd, b->buf, n);
		if (len != n) {
			// Braced explicitly: without braces the else would bind
			// to the inner if(debug) and never report failures.
			if (debug) {
				if (len == 0)
					fprint(2, "fromnet; network closed\n");
				else
					fprint(2, "fromnet; network failure; %r\n");
			}
			close(netfd);
			netfd = -1;
			continue;
		}

		if (m != inmsg) {
			if(debug) fprint(2, "fromnet; skipping message %lud, currently at %lud\n", m, inmsg);
			continue;
		}
		inmsg++;

		// Process the acked list: recycle every buffer the peer has
		// now acknowledged, in sequence order.  The signed difference
		// keeps this correct across counter wraparound.
		while((long)(acked - lastacked) > 0) {
			Buf *rb;

			if((rb = recvp(unacked)) == nil)
				break;
			m = GBIT32(rb->hdr.msg);
			if (m != lastacked) {
				if(debug) fprint(2, "fromnet; rb %p, msg %lud, lastacked %lud\n", rb, m, lastacked);
				sysfatal("fromnet; bug");
			}
			PBIT32(rb->hdr.msg, -1);
			sendp(empty, rb);
			lastacked++;
		}

		if (writen(1, b->buf, len) < 0)
			sysfatal("fromnet; cannot write to client; %r");
	}
	done = 1;
}
344
/*
 * Establish a network connection and return its data fd.
 *
 * Client mode (dialstring set): dial repeatedly, once per second,
 * until it succeeds.  A "connection refused" error ends the program.
 * Server mode: listen on devdir and accept one call; listen or accept
 * failure is fatal.
 *
 * In both modes an alarm of secs seconds bounds the wait; the "alarm"
 * note is handled by catch(), which exits the program.
 */
static int
reconnect(int secs)
{
	NetConnInfo *nci;
	char ldir[NETPATHLEN];
	int lcfd, fd;

	if (dialstring) {
		syslog(0, Logname, "dialing %s", dialstring);
		alarm(secs*1000);
		while ((fd = dial(dialstring, nil, ldir, nil)) < 0) {
			// Sized to hold a full error string: a truncated copy
			// could lose the "connection refused" text tested below.
			char err[ERRMAX];

			err[0] = '\0';
			errstr(err, sizeof err);
			if (strstr(err, "connection refused")) {
				if(debug) fprint(2, "reconnect; server died...\n");
				threadexitsall("server died...");
			}
			if(debug) fprint(2, "reconnect: dialed %s; %s\n", dialstring, err);
			sleep(1000);
		}
		alarm(0);
		syslog(0, Logname, "reconnected to %s", dialstring);
	}
	else {
		syslog(0, Logname, "waiting for connection on %s", devdir);
		alarm(secs*1000);
		if ((lcfd = listen(devdir, ldir)) < 0)
			sysfatal("reconnect; cannot listen; %r");
		if ((fd = accept(lcfd, ldir)) < 0)
			sysfatal("reconnect; cannot accept; %r");
		alarm(0);
		close(lcfd);
	}

	if((nci = getnetconninfo(ldir, fd)) != nil){
		syslog(0, Logname, "connected from %s", nci->rsys);
		threadsetname(client? "client %s %s" : "server %s %s", ldir, nci->rsys);
		freenetconninfo(nci);
	} else
		syslog(0, Logname, "connected");

	return fd;
}
390
/*
 * Resend every message still waiting in unacked (header then payload)
 * on the new connection, preserving their order.  The messages are
 * moved into a fresh channel, which then replaces unacked, so they
 * stay pending until the peer acknowledges them.
 *
 * Write errors are deliberately ignored: a failure here shows up on
 * the next keepalive or send in threadmain.
 */
static void
synchronize(void)
{
	Channel *keep;
	Buf *bp;

	keep = chancreate(sizeof(Buf *), Nbuf);
	for(;;){
		bp = nbrecvp(unacked);
		if(bp == nil)
			break;
		writen(netfd, (uchar *)&bp->hdr, Hdrsz);
		writen(netfd, bp->buf, GBIT32(bp->hdr.nb));
		sendp(keep, bp);
	}
	chanfree(unacked);
	unacked = keep;
}
412
/*
 * Write exactly nb bytes from buf to fd, retrying partial writes.
 * Returns nb on success.  Returns -1 if a write fails or makes no
 * progress, or if fd is negative (checked only when nb > 0, so a
 * zero-length write always succeeds).
 */
static int
writen(int fd, uchar *buf, int nb)
{
	int len = nb;

	while (nb > 0) {
		int n;

		if (fd < 0)
			return -1;

		// A write that transfers nothing would otherwise spin here
		// forever, since nb never shrinks; treat it as a failure.
		if ((n = write(fd, buf, nb)) <= 0) {
			if(debug) fprint(2, "writen; Write failed; %r\n");
			return -1;
		}

		buf += n;
		nb -= n;
	}
	return len;
}
434
/*
 * Proc that posts a tick on the channel passed as its argument every
 * half Synctime, so threadmain's alt wakes up regularly to check
 * whether a keepalive is due.  Stops once done is set.
 */
static void
timerproc(void *arg)
{
	Channel *tick = arg;
	long halfperiod = (Synctime / MS(1)) >> 1;	/* milliseconds */

	threadsetname("timer");

	for(;;){
		if(done)
			break;
		sleep(halfperiod);
		sendp(tick, "timer");
	}
}