]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/sshnet.c
sshnet: reduce memory consumption by lowering stack sizes
[plan9front.git] / sys / src / cmd / sshnet.c
1 /*
2  * SSH network file system.
3  * Presents remote TCP stack as /net-style file system.
4  */
5
6 #include <u.h>
7 #include <libc.h>
8 #include <bio.h>
9 #include <ndb.h>
10 #include <thread.h>
11 #include <fcall.h>
12 #include <9p.h>
13
typedef struct Client Client;
typedef struct Msg Msg;

/*
 * File types.  The low 8 bits of a qid path hold one of these and the
 * remaining bits hold a connection number (see PATH/TYPE/NUM).
 * The order must match tab[], which fillstat/fsopen index by type.
 */
enum
{
	Qroot,
	Qcs,
	Qtcp,
	Qclone,
	Qn,
	Qctl,
	Qdata,
	Qlocal,
	Qremote,
	Qstatus,
};

#define PATH(type, n)		((type)|((n)<<8))	/* build a qid path */
#define TYPE(path)		((int)(path) & 0xFF)	/* file type of a qid path */
#define NUM(path)		((uint)(path)>>8)	/* connection number of a qid path */

Channel *ssherrchan;		/* chan(char*) */	/* result of the SESSIONCHAN open: nil or an error string (handlemsg) */
Channel *sshmsgchan;		/* chan(Msg*) */	/* messages read by sshreadproc, consumed by fsnetproc */
Channel *fsreqchan;		/* chan(Req*) */	/* 9P requests forwarded by fssend */
Channel *fsreqwaitchan;		/* chan(nil) */	/* fsnetproc acknowledges each dispatched request */
Channel *fsclunkchan;		/* chan(Fid*) */	/* fids being destroyed (fsdestroyfid) */
Channel *fsclunkwaitchan;	/* chan(nil) */	/* fsnetproc acknowledges each destroyed fid */
ulong time0;			/* atime/mtime reported by fillstat */

/* connection states; statestr[] below must list them in the same order */
enum
{
	Closed,
	Dialing,
	Established,
	Teardown,
};

/* state names returned by reads of the status file */
char *statestr[] = {
	"Closed",
	"Dialing",
	"Established",
	"Teardown",
};
57
/* state of one /net/tcp/N connection, carried over an SSH channel */
struct Client
{
	int ref;		/* opened fids on this connection's files (fsopen / closeclient) */
	int state;		/* Closed, Dialing, Established or Teardown */
	int num;		/* our channel number; also this client's index in client[] */
	int servernum;		/* peer's channel number, set when the channel open is answered */
	char *connect;		/* address given in the last "connect" ctl message */

	int sendpkt;		/* largest data chunk we send at once (from OPEN_CONFIRMATION) */
	int sendwin;		/* bytes we may still send before the peer adjusts the window */
	int recvwin;		/* receive window as tracked locally */
	int recvacc;		/* consumed bytes not yet announced with WINDOW_ADJUST */

	Req *wq;		/* pending data writes, or the pending connect while Dialing */
	Req **ewq;		/* tail link of wq */

	Req *rq;		/* pending data reads */
	Req **erq;		/* tail link of rq */

	Msg *mq;		/* received data not yet read */
	Msg **emq;		/* tail link of mq */
};
80
/* SSH connection-protocol message numbers and local tuning constants */
enum {
	MSG_CHANNEL_OPEN = 90,
	MSG_CHANNEL_OPEN_CONFIRMATION,
	MSG_CHANNEL_OPEN_FAILURE,
	MSG_CHANNEL_WINDOW_ADJUST,
	MSG_CHANNEL_DATA,
	MSG_CHANNEL_EXTENDED_DATA,
	MSG_CHANNEL_EOF,
	MSG_CHANNEL_CLOSE,
	MSG_CHANNEL_REQUEST,
	MSG_CHANNEL_SUCCESS,
	MSG_CHANNEL_FAILURE,

	MaxPacket = 1<<15,	/* packet size we advertise; also the Msg buffer size */
	WinPackets = 8,		/* initial receive window, in units of MaxPacket */

	SESSIONCHAN = 1<<24,	/* channel whose open result goes to ssherrchan rather than a client */
};

/*
 * One SSH message.  Valid bytes are buf[rp, wp); ep marks the end of buf.
 * NOTE(review): a CHANNEL_DATA message carrying a full MaxPacket payload
 * is 9 bytes (type + channel + length) larger than buf; confirm the ssh
 * side never delivers such a message in a single read.
 */
struct Msg
{
	Msg	*link;		/* next message in a Client's mq */

	uchar	*rp;		/* read position */
	uchar	*wp;		/* write position / end of valid data */
	uchar	*ep;		/* end of buf */
	uchar	buf[MaxPacket];
};
109
110 #define PUT4(p, u) (p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u)
111 #define GET4(p) (u32int)(p)[3] | (u32int)(p)[2]<<8 | (u32int)(p)[1]<<16 | (u32int)(p)[0]<<24
112
113 int nclient;
114 Client **client;
115 char *mtpt;
116 int sshfd;
117 int localport;
118 char localip[] = "::";
119
120 int
121 vpack(uchar *p, int n, char *fmt, va_list a)
122 {
123         uchar *p0 = p, *e = p+n;
124         u32int u;
125         void *s;
126         int c;
127
128         for(;;){
129                 switch(c = *fmt++){
130                 case '\0':
131                         return p - p0;
132                 case '_':
133                         if(++p > e) goto err;
134                         break;
135                 case '.':
136                         *va_arg(a, void**) = p;
137                         break;
138                 case 'b':
139                         if(p >= e) goto err;
140                         *p++ = va_arg(a, int);
141                         break;
142                 case '[':
143                 case 's':
144                         s = va_arg(a, void*);
145                         u = va_arg(a, int);
146                         if(c == 's'){
147                                 if(p+4 > e) goto err;
148                                 PUT4(p, u), p += 4;
149                         }
150                         if(u > e-p) goto err;
151                         memmove(p, s, u);
152                         p += u;
153                         break;
154                 case 'u':
155                         u = va_arg(a, int);
156                         if(p+4 > e) goto err;
157                         PUT4(p, u), p += 4;
158                         break;
159                 }
160         }
161 err:
162         return -1;
163 }
164
165 int
166 vunpack(uchar *p, int n, char *fmt, va_list a)
167 {
168         uchar *p0 = p, *e = p+n;
169         u32int u;
170         void *s;
171
172         for(;;){
173                 switch(*fmt++){
174                 case '\0':
175                         return p - p0;
176                 case '_':
177                         if(++p > e) goto err;
178                         break;
179                 case '.':
180                         *va_arg(a, void**) = p;
181                         break;
182                 case 'b':
183                         if(p >= e) goto err;
184                         *va_arg(a, int*) = *p++;
185                         break;
186                 case 's':
187                         if(p+4 > e) goto err;
188                         u = GET4(p), p += 4;
189                         if(u > e-p) goto err;
190                         *va_arg(a, void**) = p;
191                         *va_arg(a, int*) = u;
192                         p += u;
193                         break;
194                 case '[':
195                         s = va_arg(a, void*);
196                         u = va_arg(a, int);
197                         if(u > e-p) goto err;
198                         memmove(s, p, u);
199                         p += u;
200                         break;
201                 case 'u':
202                         if(p+4 > e) goto err;
203                         u = GET4(p);
204                         *va_arg(a, int*) = u;
205                         p += 4;
206                         break;
207                 }
208         }
209 err:
210         return -1;
211 }
212
213 Msg*
214 allocmsg(void)
215 {
216         Msg *m;
217
218         m = emalloc9p(sizeof(Msg));
219         m->link = nil;
220         m->rp = m->wp = m->buf;
221         m->ep = m->rp + sizeof(m->buf);
222         return m;
223 }
224
225 Msg*
226 pack(Msg *m, char *fmt, ...)
227 {
228         va_list a;
229         int n;
230
231         if(m == nil)
232                 m = allocmsg();
233         va_start(a, fmt);
234         n = vpack(m->wp, m->ep - m->wp, fmt, a);
235         if(n < 0)
236                 sysfatal("pack faild");
237         m->wp += n;
238         va_end(a);
239         return m;
240 }
241
242 int
243 unpack(Msg *m, char *fmt, ...)
244 {
245         va_list a;
246         int n;
247
248         va_start(a, fmt);
249         n = vunpack(m->rp, m->wp - m->rp, fmt, a);
250         if(n > 0)
251                 m->rp += n;
252         va_end(a);
253         return n;
254 }
255
256 void
257 sendmsg(Msg *m)
258 {
259         int n;
260
261         if(m == nil)
262                 return;
263         n = m->wp - m->rp;
264         if(n > 0){
265                 if(write(sshfd, m->rp, n) != n)
266                         sysfatal("write to ssh failed: %r");
267         }
268         free(m);
269 }
270
271 int
272 newclient(void)
273 {
274         int i;
275         Client *c;
276
277         for(i=0; i<nclient; i++)
278                 if(client[i]->ref==0 && client[i]->state == Closed)
279                         return i;
280
281         if(nclient%16 == 0)
282                 client = erealloc9p(client, (nclient+16)*sizeof(client[0]));
283
284         c = emalloc9p(sizeof(Client));
285         memset(c, 0, sizeof(*c));
286         c->num = nclient;
287         client[nclient++] = c;
288         return c->num;
289 }
290
291 Client*
292 getclient(int num)
293 {
294         if(num < 0 || num >= nclient)
295                 return nil;
296         return client[num];
297 }
298
299 void
300 adjustwin(Client *c, int len)
301 {
302         c->recvacc += len;
303         if(c->recvacc >= MaxPacket*WinPackets/2 || c->recvwin < MaxPacket){
304                 sendmsg(pack(nil, "buu", MSG_CHANNEL_WINDOW_ADJUST, c->servernum, c->recvacc));
305                 c->recvacc = 0;
306         }
307         c->recvwin += len;
308 }
309
310 void
311 senddata(Client *c, void *data, int len)
312 {
313         sendmsg(pack(nil, "bus", MSG_CHANNEL_DATA, c->servernum, (char*)data, len));
314         c->sendwin -= len;
315 }
316
317 void
318 queuerreq(Client *c, Req *r)
319 {
320         if(c->rq==nil)
321                 c->erq = &c->rq;
322         *c->erq = r;
323         r->aux = nil;
324         c->erq = (Req**)&r->aux;
325 }
326
327 void
328 queuermsg(Client *c, Msg *m)
329 {
330         if(c->mq==nil)
331                 c->emq = &c->mq;
332         *c->emq = m;
333         m->link = nil;
334         c->emq = (Msg**)&m->link;
335 }
336
337 void
338 matchrmsgs(Client *c)
339 {
340         Req *r;
341         Msg *m;
342         int n, rm;
343
344         while(c->rq != nil && c->mq != nil){
345                 r = c->rq;
346                 c->rq = r->aux;
347
348                 rm = 0;
349                 m = c->mq;
350                 n = r->ifcall.count;
351                 if(n >= m->wp - m->rp){
352                         n = m->wp - m->rp;
353                         c->mq = m->link;
354                         rm = 1;
355                 }
356                 memmove(r->ofcall.data, m->rp, n);
357                 if(rm)
358                         free(m);
359                 else
360                         m->rp += n;
361                 r->ofcall.count = n;
362                 respond(r, nil);
363                 adjustwin(c, n);
364         }
365 }
366
367 void
368 queuewreq(Client *c, Req *r)
369 {
370         if(c->wq==nil)
371                 c->ewq = &c->wq;
372         *c->ewq = r;
373         r->aux = nil;
374         c->ewq = (Req**)&r->aux;
375 }
376
377 void
378 procwreqs(Client *c)
379 {
380         Req *r;
381         int n;
382
383         while((r = c->wq) != nil && (n = c->sendwin) > 0){
384                 if(n > c->sendpkt)
385                         n = c->sendpkt;
386                 if(r->ifcall.count > n){
387                         senddata(c, r->ifcall.data, n);
388                         r->ifcall.count -= n;
389                         memmove(r->ifcall.data, (char*)r->ifcall.data + n, r->ifcall.count);
390                         continue;
391                 }
392                 c->wq = (Req*)r->aux;
393                 r->aux = nil;
394                 senddata(c, r->ifcall.data, r->ifcall.count);
395                 r->ofcall.count = r->ifcall.count;
396                 respond(r, nil);
397         }
398 }
399
400 Req*
401 findreq(Client *c, Req *r)
402 {
403         Req **l;
404
405         for(l=&c->rq; *l; l=(Req**)&(*l)->aux){
406                 if(*l == r){
407                         *l = r->aux;
408                         if(*l == nil)
409                                 c->erq = l;
410                         return r;
411                 }
412         }
413         for(l=&c->wq; *l; l=(Req**)&(*l)->aux){
414                 if(*l == r){
415                         *l = r->aux;
416                         if(*l == nil)
417                                 c->ewq = l;
418                         return r;
419                 }
420         }
421         return nil;
422 }
423
424 void
425 dialedclient(Client *c)
426 {
427         Req *r;
428
429         if(r=c->wq){
430                 if(r->aux != nil)
431                         sysfatal("more than one outstanding dial request (BUG)");
432                 if(c->state == Established)
433                         respond(r, nil);
434                 else
435                         respond(r, "connect failed");
436         }
437         c->wq = nil;
438 }
439
440 void
441 teardownclient(Client *c)
442 {
443         c->state = Teardown;
444         sendmsg(pack(nil, "bu", MSG_CHANNEL_EOF, c->servernum));
445 }
446
447 void
448 hangupclient(Client *c)
449 {
450         Req *r, *next;
451         Msg *m, *mnext;
452
453         c->state = Closed;
454         for(m=c->mq; m; m=mnext){
455                 mnext = m->link;
456                 free(m);
457         }
458         c->mq = nil;
459         for(r=c->rq; r; r=next){
460                 next = r->aux;
461                 respond(r, "hangup on network connection");
462         }
463         c->rq = nil;
464         for(r=c->wq; r; r=next){
465                 next = r->aux;
466                 respond(r, "hangup on network connection");
467         }
468         c->wq = nil;
469 }
470
471 void
472 closeclient(Client *c)
473 {
474         Msg *m, *next;
475
476         if(--c->ref)
477                 return;
478
479         if(c->rq != nil || c->wq != nil)
480                 sysfatal("ref count reached zero with requests pending (BUG)");
481
482         for(m=c->mq; m; m=next){
483                 next = m->link;
484                 free(m);
485         }
486         c->mq = nil;
487
488         if(c->state != Closed)
489                 teardownclient(c);
490 }
491
492         
493 void
494 sshreadproc(void*)
495 {
496         Msg *m;
497         int n;
498
499         for(;;){
500                 m = allocmsg();
501                 n = read(sshfd, m->rp, m->ep - m->rp);
502                 if(n <= 0)
503                         sysfatal("eof on ssh connection");
504                 m->wp += n;
505                 sendp(sshmsgchan, m);
506         }
507 }
508
/* name and mode of each file type; indexed by the Q* constants */
typedef struct Tab Tab;
struct Tab
{
	char *name;	/* nil for the numbered connection directory (Qn) */
	ulong mode;
};

Tab tab[] =
{
	"/",		DMDIR|0555,
	"cs",		0666,
	"tcp",		DMDIR|0555,	
	"clone",	0666,
	nil,		DMDIR|0555,
	"ctl",		0666,
	"data",		0666,
	"local",	0444,
	"remote",	0444,
	"status",	0444,
};
529
530 static void
531 fillstat(Dir *d, uvlong path)
532 {
533         Tab *t;
534
535         memset(d, 0, sizeof(*d));
536         d->uid = estrdup9p("ssh");
537         d->gid = estrdup9p("ssh");
538         d->qid.path = path;
539         d->atime = d->mtime = time0;
540         t = &tab[TYPE(path)];
541         if(t->name)
542                 d->name = estrdup9p(t->name);
543         else{
544                 d->name = smprint("%ud", NUM(path));
545                 if(d->name == nil)
546                         sysfatal("out of memory");
547         }
548         d->qid.type = t->mode>>24;
549         d->mode = t->mode;
550 }
551
552 static void
553 fsattach(Req *r)
554 {
555         if(r->ifcall.aname && r->ifcall.aname[0]){
556                 respond(r, "invalid attach specifier");
557                 return;
558         }
559         r->fid->qid.path = PATH(Qroot, 0);
560         r->fid->qid.type = QTDIR;
561         r->fid->qid.vers = 0;
562         r->ofcall.qid = r->fid->qid;
563         respond(r, nil);
564 }
565
566 static void
567 fsstat(Req *r)
568 {
569         fillstat(&r->d, r->fid->qid.path);
570         respond(r, nil);
571 }
572
573 static int
574 rootgen(int i, Dir *d, void*)
575 {
576         i += Qroot+1;
577         if(i <= Qtcp){
578                 fillstat(d, i);
579                 return 0;
580         }
581         return -1;
582 }
583
584 static int
585 tcpgen(int i, Dir *d, void*)
586 {
587         i += Qtcp+1;
588         if(i < Qn){
589                 fillstat(d, i);
590                 return 0;
591         }
592         i -= Qn;
593         if(i < nclient){
594                 fillstat(d, PATH(Qn, i));
595                 return 0;
596         }
597         return -1;
598 }
599
600 static int
601 clientgen(int i, Dir *d, void *aux)
602 {
603         Client *c;
604
605         c = aux;
606         i += Qn+1;
607         if(i <= Qstatus){
608                 fillstat(d, PATH(i, c->num));
609                 return 0;
610         }
611         return -1;
612 }
613
614 static char*
615 fswalk1(Fid *fid, char *name, Qid *qid)
616 {
617         int i, n;
618         char buf[32];
619         ulong path;
620
621         path = fid->qid.path;
622         if(!(fid->qid.type&QTDIR))
623                 return "walk in non-directory";
624
625         if(strcmp(name, "..") == 0){
626                 switch(TYPE(path)){
627                 case Qn:
628                         qid->path = PATH(Qtcp, NUM(path));
629                         qid->type = tab[Qtcp].mode>>24;
630                         return nil;
631                 case Qtcp:
632                         qid->path = PATH(Qroot, 0);
633                         qid->type = tab[Qroot].mode>>24;
634                         return nil;
635                 case Qroot:
636                         return nil;
637                 default:
638                         return "bug in fswalk1";
639                 }
640         }
641
642         i = TYPE(path)+1;
643         for(; i<nelem(tab); i++){
644                 if(i==Qn){
645                         n = atoi(name);
646                         snprint(buf, sizeof buf, "%d", n);
647                         if(n < nclient && strcmp(buf, name) == 0){
648                                 qid->path = PATH(i, n);
649                                 qid->type = tab[i].mode>>24;
650                                 return nil;
651                         }
652                         break;
653                 }
654                 if(strcmp(name, tab[i].name) == 0){
655                         qid->path = PATH(i, NUM(path));
656                         qid->type = tab[i].mode>>24;
657                         return nil;
658                 }
659                 if(tab[i].mode&DMDIR)
660                         break;
661         }
662         return "directory entry not found";
663 }
664
/* per-fid state of an open cs file (allocated in fsopen, freed on clunk) */
typedef struct Cs Cs;
struct Cs
{
	char *resp;	/* translation produced by the last successful write */
	int isnew;	/* resp has not yet been returned by a read at offset 0 */
};
671
672 static int
673 ndbfindport(char *p)
674 {
675         char *s, *port;
676         int n;
677         static Ndb *db;
678
679         if(*p == '\0')
680                 return -1;
681
682         n = strtol(p, &s, 0);
683         if(*s == '\0')
684                 return n;
685
686         if(db == nil){
687                 db = ndbopen("/lib/ndb/common");
688                 if(db == nil)
689                         return -1;
690         }
691
692         port = ndbgetvalue(db, nil, "tcp", p, "port", nil);
693         if(port == nil)
694                 return -1;
695         n = atoi(port);
696         free(port);
697
698         return n;
699 }       
700
701 static void
702 csread(Req *r)
703 {
704         Cs *cs;
705
706         cs = r->fid->aux;
707         if(cs->resp==nil){
708                 respond(r, "cs read without write");
709                 return;
710         }
711         if(r->ifcall.offset==0){
712                 if(!cs->isnew){
713                         r->ofcall.count = 0;
714                         respond(r, nil);
715                         return;
716                 }
717                 cs->isnew = 0;
718         }
719         readstr(r, cs->resp);
720         respond(r, nil);
721 }
722
723 static void
724 cswrite(Req *r)
725 {
726         int port, nf;
727         char err[ERRMAX], *f[4], *s, *ns;
728         Cs *cs;
729
730         cs = r->fid->aux;
731         s = emalloc9p(r->ifcall.count+1);
732         memmove(s, r->ifcall.data, r->ifcall.count);
733         s[r->ifcall.count] = '\0';
734
735         nf = getfields(s, f, nelem(f), 0, "!");
736         if(nf != 3){
737                 free(s);
738                 respond(r, "can't translate");
739                 return;
740         }
741         if(strcmp(f[0], "tcp") != 0 && strcmp(f[0], "net") != 0){
742                 free(s);
743                 respond(r, "unknown protocol");
744                 return;
745         }
746         port = ndbfindport(f[2]);
747         if(port < 0){
748                 free(s);
749                 respond(r, "no translation found");
750                 return;
751         }
752
753         ns = smprint("%s/tcp/clone %s!%d", mtpt, f[1], port);
754         if(ns == nil){
755                 free(s);
756                 rerrstr(err, sizeof err);
757                 respond(r, err);
758                 return;
759         }
760         free(s);
761         free(cs->resp);
762         cs->resp = ns;
763         cs->isnew = 1;
764         r->ofcall.count = r->ifcall.count;
765         respond(r, nil);
766 }
767
768 static void
769 ctlread(Req *r, Client *c)
770 {
771         char buf[32];
772
773         sprint(buf, "%d", c->num);
774         readstr(r, buf);
775         respond(r, nil);
776 }
777
778 static void
779 ctlwrite(Req *r, Client *c)
780 {
781         char *f[3], *s;
782         int nf;
783
784         s = emalloc9p(r->ifcall.count+1);
785         r->ofcall.count = r->ifcall.count;
786         memmove(s, r->ifcall.data, r->ifcall.count);
787         s[r->ifcall.count] = '\0';
788
789         nf = tokenize(s, f, 3);
790         if(nf == 0){
791                 free(s);
792                 respond(r, nil);
793                 return;
794         }
795
796         if(strcmp(f[0], "hangup") == 0){
797                 if(c->state != Established)
798                         goto Badarg;
799                 if(nf != 1)
800                         goto Badarg;
801                 teardownclient(c);
802                 respond(r, nil);
803         }else if(strcmp(f[0], "connect") == 0){
804                 if(c->state != Closed)
805                         goto Badarg;
806                 if(nf != 2)
807                         goto Badarg;
808                 free(c->connect);
809                 c->connect = estrdup9p(f[1]);
810                 nf = getfields(f[1], f, nelem(f), 0, "!");
811                 if(nf != 2)
812                         goto Badarg;
813                 c->sendwin = MaxPacket;
814                 c->recvwin = WinPackets * MaxPacket;
815                 c->recvacc = 0;
816                 c->state = Dialing;
817                 queuewreq(c, r);
818
819                 sendmsg(pack(nil, "bsuuususu", MSG_CHANNEL_OPEN,
820                         "direct-tcpip", 12,
821                         c->num, c->recvwin, MaxPacket,
822                         f[0], strlen(f[0]), ndbfindport(f[1]),
823                         localip, strlen(localip), localport));
824         }else{
825         Badarg:
826                 respond(r, "bad or inappropriate tcp control message");
827         }
828         free(s);
829 }
830
831 static void
832 dataread(Req *r, Client *c)
833 {
834         if(c->state != Established){
835                 respond(r, "not connected");
836                 return;
837         }
838         queuerreq(c, r);
839         matchrmsgs(c);
840 }
841
842 static void
843 datawrite(Req *r, Client *c)
844 {
845         if(c->state != Established){
846                 respond(r, "not connected");
847                 return;
848         }
849         if(r->ifcall.count == 0){
850                 r->ofcall.count = r->ifcall.count;
851                 respond(r, nil);
852                 return;
853         }
854         queuewreq(c, r);
855         procwreqs(c);
856 }
857
858 static void
859 localread(Req *r)
860 {
861         char buf[128];
862
863         snprint(buf, sizeof buf, "%s!%d\n", localip, localport);
864         readstr(r, buf);
865         respond(r, nil);
866 }
867
868 static void
869 remoteread(Req *r, Client *c)
870 {
871         char *s;
872         char buf[128];
873
874         s = c->connect;
875         if(s == nil)
876                 s = "::!0";
877         snprint(buf, sizeof buf, "%s\n", s);
878         readstr(r, buf);
879         respond(r, nil);
880 }
881
882 static void
883 statusread(Req *r, Client *c)
884 {
885         char *s;
886
887         s = statestr[c->state];
888         readstr(r, s);
889         respond(r, nil);
890 }
891
892 static void
893 fsread(Req *r)
894 {
895         char e[ERRMAX];
896         ulong path;
897
898         path = r->fid->qid.path;
899         switch(TYPE(path)){
900         default:
901                 snprint(e, sizeof e, "bug in fsread path=%lux", path);
902                 respond(r, e);
903                 break;
904
905         case Qroot:
906                 dirread9p(r, rootgen, nil);
907                 respond(r, nil);
908                 break;
909
910         case Qcs:
911                 csread(r);
912                 break;
913
914         case Qtcp:
915                 dirread9p(r, tcpgen, nil);
916                 respond(r, nil);
917                 break;
918
919         case Qn:
920                 dirread9p(r, clientgen, client[NUM(path)]);
921                 respond(r, nil);
922                 break;
923
924         case Qctl:
925                 ctlread(r, client[NUM(path)]);
926                 break;
927
928         case Qdata:
929                 dataread(r, client[NUM(path)]);
930                 break;
931
932         case Qlocal:
933                 localread(r);
934                 break;
935
936         case Qremote:
937                 remoteread(r, client[NUM(path)]);
938                 break;
939
940         case Qstatus:
941                 statusread(r, client[NUM(path)]);
942                 break;
943         }
944 }
945
946 static void
947 fswrite(Req *r)
948 {
949         ulong path;
950         char e[ERRMAX];
951
952         path = r->fid->qid.path;
953         switch(TYPE(path)){
954         default:
955                 snprint(e, sizeof e, "bug in fswrite path=%lux", path);
956                 respond(r, e);
957                 break;
958
959         case Qcs:
960                 cswrite(r);
961                 break;
962
963         case Qctl:
964                 ctlwrite(r, client[NUM(path)]);
965                 break;
966
967         case Qdata:
968                 datawrite(r, client[NUM(path)]);
969                 break;
970         }
971 }
972
973 static void
974 fsopen(Req *r)
975 {
976         static int need[4] = { 4, 2, 6, 1 };
977         ulong path;
978         int n;
979         Tab *t;
980         Cs *cs;
981
982         /*
983          * lib9p already handles the blatantly obvious.
984          * we just have to enforce the permissions we have set.
985          */
986         path = r->fid->qid.path;
987         t = &tab[TYPE(path)];
988         n = need[r->ifcall.mode&3];
989         if((n&t->mode) != n){
990                 respond(r, "permission denied");
991                 return;
992         }
993
994         switch(TYPE(path)){
995         case Qcs:
996                 cs = emalloc9p(sizeof(Cs));
997                 r->fid->aux = cs;
998                 respond(r, nil);
999                 break;
1000         case Qclone:
1001                 n = newclient();
1002                 path = PATH(Qctl, n);
1003                 r->fid->qid.path = path;
1004                 r->ofcall.qid.path = path;
1005                 if(chatty9p)
1006                         fprint(2, "open clone => path=%lux\n", path);
1007                 t = &tab[Qctl];
1008                 /* fall through */
1009         default:
1010                 if(t-tab >= Qn)
1011                         client[NUM(path)]->ref++;
1012                 respond(r, nil);
1013                 break;
1014         }
1015 }
1016
1017 static void
1018 fsflush(Req *r)
1019 {
1020         int i;
1021
1022         for(i=0; i<nclient; i++)
1023                 if(findreq(client[i], r->oldreq))
1024                         respond(r->oldreq, "interrupted");
1025         respond(r, nil);
1026 }
1027
1028 static void
1029 handlemsg(Msg *m)
1030 {
1031         int chan, win, pkt, n, l;
1032         Client *c;
1033         char *s;
1034
1035         switch(m->rp[0]){
1036         case MSG_CHANNEL_WINDOW_ADJUST:
1037                 if(unpack(m, "_uu", &chan, &n) < 0)
1038                         break;
1039                 c = getclient(chan);
1040                 if(c != nil && c->state==Established){
1041                         c->sendwin += n;
1042                         procwreqs(c);
1043                 }
1044                 break;
1045         case MSG_CHANNEL_DATA:
1046                 if(unpack(m, "_us", &chan, &s, &n) < 0)
1047                         break;
1048                 c = getclient(chan);
1049                 if(c != nil && c->state==Established){
1050                         c->recvwin -= n;
1051                         m->rp = (uchar*)s;
1052                         queuermsg(c, m);
1053                         matchrmsgs(c);
1054                         return;
1055                 }
1056                 break;
1057         case MSG_CHANNEL_EOF:
1058                 if(unpack(m, "_u", &chan) < 0)
1059                         break;
1060                 c = getclient(chan);
1061                 if(c != nil){
1062                         hangupclient(c);
1063                         m->rp = m->wp = m->buf;
1064                         sendmsg(pack(m, "bu", MSG_CHANNEL_CLOSE, c->servernum));
1065                         return;
1066                 }
1067                 break;
1068         case MSG_CHANNEL_CLOSE:
1069                 if(unpack(m, "_u", &chan) < 0)
1070                         break;
1071                 c = getclient(chan);
1072                 if(c != nil)
1073                         hangupclient(c);
1074                 break;
1075         case MSG_CHANNEL_OPEN_CONFIRMATION:
1076                 if(unpack(m, "_uuuu", &chan, &n, &win, &pkt) < 0)
1077                         break;
1078                 if(chan == SESSIONCHAN){
1079                         sendp(ssherrchan, nil);
1080                         break;
1081                 }
1082                 c = getclient(chan);
1083                 if(c == nil || c->state != Dialing)
1084                         break;
1085                 if(pkt <= 0 || pkt > MaxPacket)
1086                         pkt = MaxPacket;
1087                 c->sendpkt = pkt;
1088                 c->sendwin = win;
1089                 c->servernum = n;
1090                 c->state = Established;
1091                 dialedclient(c);
1092                 break;
1093         case MSG_CHANNEL_OPEN_FAILURE:
1094                 if(unpack(m, "_uus", &chan, &n, &s, &l) < 0)
1095                         break;
1096                 if(chan == SESSIONCHAN){
1097                         sendp(ssherrchan, smprint("%.*s", utfnlen(s, l), s));
1098                         break;
1099                 }
1100                 c = getclient(chan);
1101                 if(c == nil || c->state != Dialing)
1102                         break;
1103                 c->servernum = n;
1104                 c->state = Closed;
1105                 dialedclient(c);
1106                 break;
1107         }
1108         free(m);
1109 }
1110
1111 void
1112 fsnetproc(void*)
1113 {
1114         ulong path;
1115         Alt a[4];
1116         Cs *cs;
1117         Fid *fid;
1118         Req *r;
1119         Msg *m;
1120
1121         threadsetname("fsthread");
1122
1123         a[0].op = CHANRCV;
1124         a[0].c = fsclunkchan;
1125         a[0].v = &fid;
1126         a[1].op = CHANRCV;
1127         a[1].c = fsreqchan;
1128         a[1].v = &r;
1129         a[2].op = CHANRCV;
1130         a[2].c = sshmsgchan;
1131         a[2].v = &m;
1132         a[3].op = CHANEND;
1133
1134         for(;;){
1135                 switch(alt(a)){
1136                 case 0:
1137                         path = fid->qid.path;
1138                         switch(TYPE(path)){
1139                         case Qcs:
1140                                 cs = fid->aux;
1141                                 if(cs){
1142                                         free(cs->resp);
1143                                         free(cs);
1144                                 }
1145                                 break;
1146                         }
1147                         if(fid->omode != -1 && TYPE(path) >= Qn)
1148                                 closeclient(client[NUM(path)]);
1149                         sendp(fsclunkwaitchan, nil);
1150                         break;
1151                 case 1:
1152                         switch(r->ifcall.type){
1153                         case Tattach:
1154                                 fsattach(r);
1155                                 break;
1156                         case Topen:
1157                                 fsopen(r);
1158                                 break;
1159                         case Tread:
1160                                 fsread(r);
1161                                 break;
1162                         case Twrite:
1163                                 fswrite(r);
1164                                 break;
1165                         case Tstat:
1166                                 fsstat(r);
1167                                 break;
1168                         case Tflush:
1169                                 fsflush(r);
1170                                 break;
1171                         default:
1172                                 respond(r, "bug in fsthread");
1173                                 break;
1174                         }
1175                         sendp(fsreqwaitchan, 0);
1176                         break;
1177                 case 2:
1178                         handlemsg(m);
1179                         break;
1180                 }
1181         }
1182 }
1183
1184 static void
1185 fssend(Req *r)
1186 {
1187         sendp(fsreqchan, r);
1188         recvp(fsreqwaitchan);   /* avoids need to deal with spurious flushes */
1189 }
1190
1191 static void
1192 fsdestroyfid(Fid *fid)
1193 {
1194         sendp(fsclunkchan, fid);
1195         recvp(fsclunkwaitchan);
1196 }
1197
1198 void
1199 takedown(Srv*)
1200 {
1201         threadexitsall("done");
1202 }
1203
1204 Srv fs = 
1205 {
1206 .attach=                fssend,
1207 .destroyfid=    fsdestroyfid,
1208 .walk1=         fswalk1,
1209 .open=          fssend,
1210 .read=          fssend,
1211 .write=         fssend,
1212 .stat=          fssend,
1213 .flush=         fssend,
1214 .end=           takedown,
1215 };
1216
/*
 * pfd: pipe to the ssh child; ssh() keeps pfd[0] as sshfd and
 * startssh() makes pfd[1] the child's standard input and output.
 * sshargc: number of arguments in sshargv, excluding the final nil.
 */
int pfd[2], sshargc;
/* nil-terminated argument vector for the ssh client */
char **sshargv;
1220
1221 void
1222 startssh(void *)
1223 {
1224         char *f;
1225
1226         close(pfd[0]);
1227         dup(pfd[1], 0);
1228         dup(pfd[1], 1);
1229         close(pfd[1]);
1230         if(strncmp(sshargv[0], "./", 2) != 0)
1231                 f = smprint("/bin/%s", sshargv[0]);
1232         else
1233                 f = sshargv[0];
1234         procexec(nil, f, sshargv);
1235         sysfatal("exec: %r");
1236 }
1237
1238 void
1239 ssh(int argc, char *argv[])
1240 {
1241         Alt a[3];
1242         Waitmsg *w;
1243         char *e;
1244
1245         sshargc = argc + 2;
1246         sshargv = emalloc9p(sizeof(char *) * (sshargc + 1));
1247         sshargv[0] = "ssh";
1248         sshargv[1] = "-X";
1249         memcpy(sshargv + 2, argv, argc * sizeof(char *));
1250
1251         pipe(pfd);
1252         sshfd = pfd[0];
1253         procrfork(startssh, nil, 8*1024, RFFDG|RFNOTEG|RFNAMEG);
1254         close(pfd[1]);
1255
1256         sendmsg(pack(nil, "bsuuu", MSG_CHANNEL_OPEN,
1257                 "session", 7,
1258                 SESSIONCHAN,
1259                 MaxPacket,
1260                 MaxPacket));
1261
1262         a[0].op = CHANRCV;
1263         a[0].c = threadwaitchan();
1264         a[0].v = &w;
1265         a[1].op = CHANRCV;
1266         a[1].c = ssherrchan;
1267         a[1].v = &e;
1268         a[2].op = CHANEND;
1269
1270         switch(alt(a)){
1271         case 0:
1272                 sysfatal("ssh failed: %s", w->msg);
1273         case 1:
1274                 if(e != nil)
1275                         sysfatal("ssh failed: %s", e);
1276         }
1277         chanclose(ssherrchan);
1278 }
1279
/*
 * Print the command synopsis and exit with status "usage".
 * Lists every option threadmain accepts except the -D debug flag.
 * A host argument is required.  Options meant for ssh itself must
 * follow "--" so ARGBEGIN does not reject them.
 */
void
usage(void)
{
	fprint(2, "usage: sshnet [-m mtpt] [-s service] [-- ssh options] [user@]host\n");
	exits("usage");
}
1286
1287 void
1288 threadmain(int argc, char **argv)
1289 {
1290         char *service;
1291
1292         fmtinstall('H', encodefmt);
1293
1294         mtpt = "/net";
1295         service = nil;
1296         ARGBEGIN{
1297         case 'D':
1298                 chatty9p++;
1299                 break;
1300         case 'm':
1301                 mtpt = EARGF(usage());
1302                 break;
1303         case 's':
1304                 service = EARGF(usage());
1305                 break;
1306         default:
1307                 usage();
1308         }ARGEND
1309
1310         if(argc == 0)
1311                 usage();
1312
1313         time0 = time(0);
1314         ssherrchan = chancreate(sizeof(char*), 0);
1315         sshmsgchan = chancreate(sizeof(Msg*), 16);
1316         fsreqchan = chancreate(sizeof(Req*), 0);
1317         fsreqwaitchan = chancreate(sizeof(void*), 0);
1318         fsclunkchan = chancreate(sizeof(Fid*), 0);
1319         fsclunkwaitchan = chancreate(sizeof(void*), 0);
1320         procrfork(fsnetproc, nil, 8*1024, RFNAMEG|RFNOTEG);
1321         procrfork(sshreadproc, nil, 8*1024, RFNAMEG|RFNOTEG);
1322
1323         ssh(argc, argv);
1324
1325         threadpostmountsrv(&fs, service, mtpt, MREPL);
1326         exits(0);
1327 }