]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/aan.c
rsa: rename getkey() to getrsakey(), document rsa2csr in rsa(8)
[plan9front.git] / sys / src / cmd / aan.c
1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <fcall.h>
5 #include <thread.h>
6
/*
 * Time unit conversions.  Each macro yields a vlong number of
 * nanoseconds: NS(x) is x nanoseconds, US(x) x microseconds,
 * MS(x) x milliseconds and S(x) x seconds.
 * The argument is parenthesized before the cast so that a compound
 * expression is converted as a whole rather than only its first operand.
 */
#define NS(x)	((vlong)(x))
#define US(x)	(NS(x) * 1000LL)
#define MS(x)	(US(x) * 1000LL)
#define S(x)	(MS(x) * 1000LL)
11
12 enum {
13         Synctime = S(8),
14         Nbuf = 10,
15         K = 1024,
16         Bufsize = 8 * K,
17         Stacksize = 8 * K,
18         Timer = 0,                              // Alt channels.
19         Unsent = 1,
20         Maxto = 24 * 3600,                      // A full day to reconnect.
21         Hdrsz = 3*4,
22 };
23
24 typedef struct {
25         uchar   nb[4];          // Number of data bytes in this message
26         uchar   msg[4];         // Message number
27         uchar   acked[4];       // Number of messages acked
28 } Hdr;
29
30 typedef struct {
31         Hdr     hdr;
32         uchar   buf[Bufsize];
33 } Buf;
34
35 static Channel  *unsent;
36 static Channel  *unacked;
37 static Channel  *empty;
38 static int      netfd;
39 static int      inmsg;
40 static char     *devdir;
41 static int      debug;
42 static int      done;
43 static char     *dialstring;
44 static int      maxto = Maxto;
45 static char     *Logname = "aan";
46 static int      client;
47 static int      reader = -1;
48 static int      lostsync;
49
50 static Alt a[] = {
51         /*      c       v        op   */
52         {       nil,    nil,    CHANRCV                 },      // timer
53         {       nil,    nil,    CHANRCV                 },      // unsent
54         {       nil,    nil,    CHANEND         },
55 };
56
57 static void             fromnet(void*);
58 static void             fromclient(void*);
59 static int              reconnect(int);
60 static void             synchronize(void);
61 static int              sendcommand(ulong, ulong);
62 static void             showmsg(int, char *, Buf *);
63 static int              writen(int, uchar *, int);
64 static void             dmessage(int, char *, ...);
65 static void             timerproc(void *);
66
67 static void
68 usage(void)
69 {
70         fprint(2, "Usage: %s [-cd] [-m maxto] dialstring|netdir\n", argv0);
71         exits("usage");
72 }
73
74  
75 static int
76 catch(void *, char *s)
77 {
78         if (!strcmp(s, "alarm")) {
79                 syslog(0, Logname, "Timed out while waiting for reconnect, exiting...");
80                 threadexitsall(nil);
81         }
82         return 0;
83 }
84
85 static void*
86 emalloc(int n)
87 {
88         ulong pc;
89         void *v;
90
91         pc = getcallerpc(&n);
92         v = malloc(n);
93         if(v == nil)
94                 sysfatal("Cannot allocate memory; pc=%lux", pc);
95         setmalloctag(v, pc);
96         return v;
97 }
98
99 void
100 threadmain(int argc, char **argv)
101 {
102         vlong synctime;
103         int i, n, failed;
104         Channel *timer;
105         Hdr hdr;
106         Buf *b;
107
108         ARGBEGIN {
109         case 'c':
110                 client++;
111                 break;
112         case 'd':
113                 debug++;
114                 break;
115         case 'm':
116                 maxto = (int)strtol(EARGF(usage()), nil, 0);
117                 break;
118         default:
119                 usage();
120         } ARGEND;
121
122         if (argc != 1)
123                 usage();
124
125         if (!client) {
126                 char *p;
127
128                 devdir = argv[0];
129                 if ((p = strstr(devdir, "/local")) != nil)
130                         *p = '\0';
131         }
132         else
133                 dialstring = argv[0];
134
135         if (debug > 0) {
136                 int fd = open("#c/cons", OWRITE|OCEXEC);        
137                 dup(fd, 2);
138         }
139
140         fmtinstall('F', fcallfmt);
141
142         atnotify(catch, 1);
143
144         /*
145          * Set up initial connection. use short timeout
146          * of 60 seconds so we wont hang arround for too
147          * long if there is some general connection problem
148          * (like NAT).
149          */
150         netfd = reconnect(60);
151
152         unsent = chancreate(sizeof(Buf *), Nbuf);
153         unacked = chancreate(sizeof(Buf *), Nbuf);
154         empty = chancreate(sizeof(Buf *), Nbuf);
155         timer = chancreate(sizeof(uchar *), 1);
156         if(unsent == nil || unacked == nil || empty == nil || timer == nil)
157                 sysfatal("Cannot allocate channels");
158
159         for (i = 0; i < Nbuf; i++)
160                 sendp(empty, emalloc(sizeof(Buf)));
161
162         reader = proccreate(fromnet, nil, Stacksize);
163         if (reader < 0)
164                 sysfatal("Cannot start fromnet; %r");
165
166         if (proccreate(fromclient, nil, Stacksize) < 0)
167                 sysfatal("Cannot start fromclient; %r");
168
169         if (proccreate(timerproc, timer, Stacksize) < 0)
170                 sysfatal("Cannot start timerproc; %r");
171
172         a[Timer].c = timer;
173         a[Unsent].c = unsent;
174         a[Unsent].v = &b;
175
176 Restart:
177         synctime = nsec() + Synctime;
178         failed = 0;
179         lostsync = 0;
180         while (!done) {
181                 if (netfd < 0 || failed) {
182                         // Wait for the netreader to die.
183                         while (netfd >= 0) {
184                                 dmessage(1, "main; waiting for netreader to die\n");
185                                 threadint(reader);
186                                 sleep(1000);
187                         }
188
189                         // the reader died; reestablish the world.
190                         netfd = reconnect(maxto);
191                         synchronize();
192                         goto Restart;
193                 }
194
195                 switch (alt(a)) {
196                 case Timer:
197                         if (netfd < 0 || nsec() < synctime)
198                                 break;
199
200                         PBIT32(hdr.nb, 0);
201                         PBIT32(hdr.acked, inmsg);
202                         PBIT32(hdr.msg, -1);
203
204                         if (writen(netfd, (uchar *)&hdr, Hdrsz) < 0) {
205                                 dmessage(2, "main; writen failed; %r\n");
206                                 failed = 1;
207                                 continue;
208                         }
209
210                         if(++lostsync > 2){
211                                 syslog(0, Logname, "connection seems hung up...");
212                                 failed = 1;
213                                 continue;
214                         }
215                         synctime = nsec() + Synctime;
216                         break;
217
218                 case Unsent:
219                         sendp(unacked, b);
220
221                         if (netfd < 0)
222                                 break;
223
224                         PBIT32(b->hdr.acked, inmsg);
225
226                         if (writen(netfd, (uchar *)&b->hdr, Hdrsz) < 0) {
227                                 dmessage(2, "main; writen failed; %r\n");
228                                 failed = 1;
229                         }
230
231                         n = GBIT32(b->hdr.nb);
232                         if (writen(netfd, b->buf, n) < 0) {
233                                 dmessage(2, "main; writen failed; %r\n");
234                                 failed = 1;
235                         }
236
237                         if (n == 0)
238                                 done = 1;
239                         break;
240                 }
241         }
242         syslog(0, Logname, "exiting...");
243         threadexitsall(nil);
244 }
245
246
247 static void
248 fromclient(void*)
249 {
250         static int outmsg;
251         int n;
252         Buf *b;
253
254         threadsetname("fromclient");
255
256         do {
257                 b = recvp(empty);
258                 n = read(0, b->buf, Bufsize);
259                 if (n <= 0) {
260                         if (n < 0)
261                                 dmessage(2, "fromclient; Cannot read 9P message; %r\n");
262                         else
263                                 dmessage(2, "fromclient; Client terminated\n");
264                         n = 0;
265                 }
266                 PBIT32(b->hdr.nb, n);
267                 PBIT32(b->hdr.msg, outmsg);
268                 showmsg(1, "fromclient", b);
269                 sendp(unsent, b);
270                 outmsg++;
271         } while(n > 0);
272 }
273
274 static void
275 fromnet(void*)
276 {
277         extern void _threadnote(void *, char *);
278         static int lastacked;
279         int n, m, len, acked;
280         Buf *b;
281
282         notify(_threadnote);
283
284         threadsetname("fromnet");
285
286         b = emalloc(sizeof(Buf));
287         while (!done) {
288                 while (netfd < 0) {
289                         if(done)
290                                 return;
291                         dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n", inmsg);
292                         sleep(1000);
293                 }
294
295                 // Read the header.
296                 len = readn(netfd, (uchar *)&b->hdr, Hdrsz);
297                 if (len <= 0) {
298                         if (len < 0)
299                                 dmessage(1, "fromnet; (hdr) network failure; %r\n");
300                         else
301                                 dmessage(1, "fromnet; (hdr) network closed\n");
302                         close(netfd);
303                         netfd = -1;
304                         continue;
305                 }
306                 lostsync = 0;   // reset timeout
307                 n = GBIT32(b->hdr.nb);
308                 m = GBIT32(b->hdr.msg);
309                 acked = GBIT32(b->hdr.acked);
310                 dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d, acked %d, lastacked %d\n",
311                         len, n, m, acked, lastacked);
312
313                 if (n == 0) {
314                         if (m >= 0) {
315                                 dmessage(1, "fromnet; network closed\n");
316                                 break;
317                         }
318                         continue;
319                 }
320
321                 if (n > Bufsize) {
322                         dmessage(1, "fromnet; message too big %d > %d\n", n, Bufsize);
323                         break;
324                 }
325
326                 len = readn(netfd, b->buf, n);
327                 if (len <= 0 || len != n) {
328                         if (len == 0)
329                                 dmessage(1, "fromnet; network closed\n");
330                         else
331                                 dmessage(1, "fromnet; network failure; %r\n");
332                         close(netfd);
333                         netfd = -1;
334                         continue;
335                 }
336
337                 if (m < inmsg) {
338                         dmessage(1, "fromnet; skipping message %d, currently at %d\n", m, inmsg);
339                         continue;
340                 }                       
341
342                 // Process the acked list.
343                 while(lastacked != acked) {
344                         Buf *rb;
345
346                         rb = recvp(unacked);
347                         m = GBIT32(rb->hdr.msg);
348                         if (m != lastacked) {
349                                 dmessage(1, "fromnet; rb %p, msg %d, lastacked %d\n", rb, m, lastacked);
350                                 sysfatal("fromnet; bug");
351                         }
352                         PBIT32(rb->hdr.msg, -1);
353                         sendp(empty, rb);
354                         lastacked++;
355                 } 
356                 inmsg++;
357
358                 showmsg(1, "fromnet", b);
359
360                 if (writen(1, b->buf, len) < 0) 
361                         sysfatal("fromnet; cannot write to client; %r");
362         }
363         done = 1;
364 }
365
366 static int
367 reconnect(int secs)
368 {
369         NetConnInfo *nci;
370         char ldir[40];
371         int lcfd, fd;
372
373         if (dialstring) {
374                 syslog(0, Logname, "dialing %s", dialstring);
375                 alarm(secs*1000);
376                 while ((fd = dial(dialstring, nil, ldir, nil)) < 0) {
377                         char err[32];
378
379                         err[0] = '\0';
380                         errstr(err, sizeof err);
381                         if (strstr(err, "connection refused")) {
382                                 dmessage(1, "reconnect; server died...\n");
383                                 threadexitsall("server died...");
384                         }
385                         dmessage(1, "reconnect: dialed %s; %s\n", dialstring, err);
386                         sleep(1000);
387                 }
388                 alarm(0);
389                 syslog(0, Logname, "reconnected to %s", dialstring);
390         } 
391         else {
392                 syslog(0, Logname, "waiting for connection on %s", devdir);
393                 alarm(secs*1000);
394                 if ((lcfd = listen(devdir, ldir)) < 0) 
395                         sysfatal("reconnect; cannot listen; %r");
396                 if ((fd = accept(lcfd, ldir)) < 0)
397                         sysfatal("reconnect; cannot accept; %r");
398                 alarm(0);
399                 close(lcfd);
400         }
401
402         if(nci = getnetconninfo(ldir, fd)){
403                 syslog(0, Logname, "connected from %s", nci->rsys);
404                 threadsetname(client? "client %s %s" : "server %s %s", ldir, nci->rsys);
405                 freenetconninfo(nci);
406         } else
407                 syslog(0, Logname, "connected");
408
409         return fd;
410 }
411
412 static void
413 synchronize(void)
414 {
415         Channel *tmp;
416         Buf *b;
417         int n;
418
419         // Ignore network errors here.  If we fail during 
420         // synchronization, the next alarm will pick up 
421         // the error.
422
423         tmp = chancreate(sizeof(Buf *), Nbuf);
424         while ((b = nbrecvp(unacked)) != nil) {
425                 n = GBIT32(b->hdr.nb);
426                 writen(netfd, (uchar *)&b->hdr, Hdrsz);
427                 writen(netfd, b->buf, n);
428                 sendp(tmp, b);
429         }
430         chanfree(unacked);
431         unacked = tmp;
432 }
433
434 static void
435 showmsg(int level, char *s, Buf *b)
436 {
437         int n;
438
439         if (b == nil) {
440                 dmessage(level, "%s; b == nil\n", s);
441                 return;
442         }
443         n = GBIT32(b->hdr.nb);
444         dmessage(level, "%s;  (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s, n, 
445                         b->buf[0], b->buf[1], b->buf[2],
446                         b->buf[3], b->buf[4], b->buf[5],
447                         b->buf[6], b->buf[7], b->buf[8], b);
448 }
449
450 static int
451 writen(int fd, uchar *buf, int nb)
452 {
453         int len = nb;
454
455         while (nb > 0) {
456                 int n;
457
458                 if (fd < 0) 
459                         return -1;
460
461                 if ((n = write(fd, buf, nb)) < 0) {
462                         dmessage(1, "writen; Write failed; %r\n");
463                         return -1;
464                 }
465                 dmessage(2, "writen: wrote %d bytes\n", n);
466
467                 buf += n;
468                 nb -= n;
469         }
470         return len;
471 }
472
473 static void
474 timerproc(void *x)
475 {
476         Channel *timer = x;
477
478         threadsetname("timer");
479
480         while (!done) {
481                 sleep((Synctime / MS(1)) >> 1);
482                 sendp(timer, "timer");
483         }
484 }
485
486 static void
487 dmessage(int level, char *fmt, ...)
488 {
489         va_list arg; 
490
491         if (level > debug) 
492                 return;
493
494         va_start(arg, fmt);
495         vfprint(2, fmt, arg);
496         va_end(arg);
497 }