]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/sshnet.c
sshnet: pass on open failure error message, simplify
[plan9front.git] / sys / src / cmd / sshnet.c
1 /*
2  * SSH network file system.
3  * Presents remote TCP stack as /net-style file system.
4  */
5
6 #include <u.h>
7 #include <libc.h>
8 #include <bio.h>
9 #include <ndb.h>
10 #include <thread.h>
11 #include <fcall.h>
12 #include <9p.h>
13
14 typedef struct Client Client;
15 typedef struct Msg Msg;
16
17 enum
18 {
19         Qroot,
20         Qcs,
21         Qtcp,
22         Qclone,
23         Qn,
24         Qctl,
25         Qdata,
26         Qlocal,
27         Qremote,
28         Qstatus,
29 };
30
31 #define PATH(type, n)           ((type)|((n)<<8))
32 #define TYPE(path)              ((int)(path) & 0xFF)
33 #define NUM(path)               ((uint)(path)>>8)
34
35 Channel *ssherrchan;            /* chan(char*) */
36 Channel *sshmsgchan;            /* chan(Msg*) */
37 Channel *fsreqchan;             /* chan(Req*) */
38 Channel *fsreqwaitchan;         /* chan(nil) */
39 Channel *fsclunkchan;           /* chan(Fid*) */
40 Channel *fsclunkwaitchan;       /* chan(nil) */
41 ulong time0;
42
/* connection states; statestr below must list them in the same order */
enum
{
	Closed,		/* unused; newclient may hand it out again */
	Dialing,	/* channel open sent, awaiting the server's reply */
	Established,	/* channel open */
	Teardown,	/* we sent a close, awaiting the server's */
	Finished,	/* server sent a close, ours goes out on last clunk */
};

/* printable names of the states, as read from tcp/N/status */
char *statestr[] = {
	"Closed",
	"Dialing",
	"Established",
	"Teardown",
	"Finished",
};
59
60 struct Client
61 {
62         int ref;
63         int state;
64         int num;
65         int servernum;
66         char *connect;
67
68         int sendpkt;
69         int sendwin;
70         int recvwin;
71         int recvacc;
72
73         int eof;
74
75         Req *wq;
76         Req **ewq;
77
78         Req *rq;
79         Req **erq;
80
81         Msg *mq;
82         Msg **emq;
83 };
84
85 enum {
86         MSG_CHANNEL_OPEN = 90,
87         MSG_CHANNEL_OPEN_CONFIRMATION,
88         MSG_CHANNEL_OPEN_FAILURE,
89         MSG_CHANNEL_WINDOW_ADJUST,
90         MSG_CHANNEL_DATA,
91         MSG_CHANNEL_EXTENDED_DATA,
92         MSG_CHANNEL_EOF,
93         MSG_CHANNEL_CLOSE,
94         MSG_CHANNEL_REQUEST,
95         MSG_CHANNEL_SUCCESS,
96         MSG_CHANNEL_FAILURE,
97
98         Overhead = 256,
99         MaxPacket = (1<<15)-256,        /* 32K is maxatomic for pipe */
100         WinPackets = 8,
101
102         SESSIONCHAN = 1<<24,
103 };
104
105 struct Msg
106 {
107         Msg     *link;
108
109         uchar   *rp;
110         uchar   *wp;
111         uchar   *ep;
112         uchar   buf[MaxPacket + Overhead];
113 };
114
115 #define PUT4(p, u) (p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u)
116 #define GET4(p) (u32int)(p)[3] | (u32int)(p)[2]<<8 | (u32int)(p)[1]<<16 | (u32int)(p)[0]<<24
117
118 int nclient;
119 Client **client;
120 char *mtpt;
121 int sshfd;
122 int localport;
123 char localip[] = "::";
124
125 int
126 vpack(uchar *p, int n, char *fmt, va_list a)
127 {
128         uchar *p0 = p, *e = p+n;
129         u32int u;
130         void *s;
131         int c;
132
133         for(;;){
134                 switch(c = *fmt++){
135                 case '\0':
136                         return p - p0;
137                 case '_':
138                         if(++p > e) goto err;
139                         break;
140                 case '.':
141                         *va_arg(a, void**) = p;
142                         break;
143                 case 'b':
144                         if(p >= e) goto err;
145                         *p++ = va_arg(a, int);
146                         break;
147                 case '[':
148                 case 's':
149                         s = va_arg(a, void*);
150                         u = va_arg(a, int);
151                         if(c == 's'){
152                                 if(p+4 > e) goto err;
153                                 PUT4(p, u), p += 4;
154                         }
155                         if(u > e-p) goto err;
156                         memmove(p, s, u);
157                         p += u;
158                         break;
159                 case 'u':
160                         u = va_arg(a, int);
161                         if(p+4 > e) goto err;
162                         PUT4(p, u), p += 4;
163                         break;
164                 }
165         }
166 err:
167         return -1;
168 }
169
170 int
171 vunpack(uchar *p, int n, char *fmt, va_list a)
172 {
173         uchar *p0 = p, *e = p+n;
174         u32int u;
175         void *s;
176
177         for(;;){
178                 switch(*fmt++){
179                 case '\0':
180                         return p - p0;
181                 case '_':
182                         if(++p > e) goto err;
183                         break;
184                 case '.':
185                         *va_arg(a, void**) = p;
186                         break;
187                 case 'b':
188                         if(p >= e) goto err;
189                         *va_arg(a, int*) = *p++;
190                         break;
191                 case 's':
192                         if(p+4 > e) goto err;
193                         u = GET4(p), p += 4;
194                         if(u > e-p) goto err;
195                         *va_arg(a, void**) = p;
196                         *va_arg(a, int*) = u;
197                         p += u;
198                         break;
199                 case '[':
200                         s = va_arg(a, void*);
201                         u = va_arg(a, int);
202                         if(u > e-p) goto err;
203                         memmove(s, p, u);
204                         p += u;
205                         break;
206                 case 'u':
207                         if(p+4 > e) goto err;
208                         u = GET4(p);
209                         *va_arg(a, int*) = u;
210                         p += 4;
211                         break;
212                 }
213         }
214 err:
215         return -1;
216 }
217
218 Msg*
219 allocmsg(void)
220 {
221         Msg *m;
222
223         m = emalloc9p(sizeof(Msg));
224         m->link = nil;
225         m->rp = m->wp = m->buf;
226         m->ep = m->rp + sizeof(m->buf);
227         return m;
228 }
229
230 Msg*
231 pack(Msg *m, char *fmt, ...)
232 {
233         va_list a;
234         int n;
235
236         if(m == nil)
237                 m = allocmsg();
238         va_start(a, fmt);
239         n = vpack(m->wp, m->ep - m->wp, fmt, a);
240         if(n < 0)
241                 sysfatal("pack faild");
242         m->wp += n;
243         va_end(a);
244         return m;
245 }
246
247 int
248 unpack(Msg *m, char *fmt, ...)
249 {
250         va_list a;
251         int n;
252
253         va_start(a, fmt);
254         n = vunpack(m->rp, m->wp - m->rp, fmt, a);
255         if(n > 0)
256                 m->rp += n;
257         va_end(a);
258         return n;
259 }
260
261 void
262 sendmsg(Msg *m)
263 {
264         int n;
265
266         if(m == nil)
267                 return;
268         n = m->wp - m->rp;
269         if(n > 0){
270                 if(write(sshfd, m->rp, n) != n)
271                         sysfatal("write to ssh failed: %r");
272         }
273         free(m);
274 }
275
276 int
277 newclient(void)
278 {
279         int i;
280         Client *c;
281
282         for(i=0; i<nclient; i++)
283                 if(client[i]->ref==0 && client[i]->state == Closed)
284                         return i;
285
286         if(nclient%16 == 0)
287                 client = erealloc9p(client, (nclient+16)*sizeof(client[0]));
288
289         c = emalloc9p(sizeof(Client));
290         memset(c, 0, sizeof(*c));
291         c->num = nclient;
292         client[nclient++] = c;
293         return c->num;
294 }
295
296 Client*
297 getclient(int num)
298 {
299         if(num < 0 || num >= nclient)
300                 return nil;
301         return client[num];
302 }
303
304 void
305 adjustwin(Client *c, int len)
306 {
307         c->recvacc += len;
308         if(c->recvacc >= MaxPacket*WinPackets/2 || c->recvwin < MaxPacket){
309                 sendmsg(pack(nil, "buu", MSG_CHANNEL_WINDOW_ADJUST, c->servernum, c->recvacc));
310                 c->recvacc = 0;
311         }
312         c->recvwin += len;
313 }
314
315 void
316 senddata(Client *c, void *data, int len)
317 {
318         sendmsg(pack(nil, "bus", MSG_CHANNEL_DATA, c->servernum, (char*)data, len));
319         c->sendwin -= len;
320 }
321
322 void
323 queuerreq(Client *c, Req *r)
324 {
325         if(c->rq==nil)
326                 c->erq = &c->rq;
327         *c->erq = r;
328         r->aux = nil;
329         c->erq = (Req**)&r->aux;
330 }
331
332 void
333 queuermsg(Client *c, Msg *m)
334 {
335         if(c->mq==nil)
336                 c->emq = &c->mq;
337         *c->emq = m;
338         m->link = nil;
339         c->emq = (Msg**)&m->link;
340 }
341
342 void
343 matchrmsgs(Client *c)
344 {
345         Req *r;
346         Msg *m;
347         int n, rm;
348
349         while((r = c->rq) != nil && (m = c->mq) != nil){
350                 c->rq = r->aux;
351                 r->aux = nil;
352                 rm = 0;
353                 n = r->ifcall.count;
354                 if(n >= m->wp - m->rp){
355                         n = m->wp - m->rp;
356                         c->mq = m->link;
357                         rm = 1;
358                 }
359                 memmove(r->ofcall.data, m->rp, n);
360                 if(rm)
361                         free(m);
362                 else
363                         m->rp += n;
364                 r->ofcall.count = n;
365                 respond(r, nil);
366                 adjustwin(c, n);
367         }
368
369         if(c->eof){
370                 while((r = c->rq) != nil){
371                         c->rq = r->aux;
372                         r->aux = nil;
373                         r->ofcall.count = 0;
374                         respond(r, nil);
375                 }
376         }
377 }
378
379 void
380 queuewreq(Client *c, Req *r)
381 {
382         if(c->wq==nil)
383                 c->ewq = &c->wq;
384         *c->ewq = r;
385         r->aux = nil;
386         c->ewq = (Req**)&r->aux;
387 }
388
389 void
390 procwreqs(Client *c)
391 {
392         Req *r;
393         int n;
394
395         while((r = c->wq) != nil && (n = c->sendwin) > 0){
396                 if(n > c->sendpkt)
397                         n = c->sendpkt;
398                 if(r->ifcall.count > n){
399                         senddata(c, r->ifcall.data, n);
400                         r->ifcall.count -= n;
401                         memmove(r->ifcall.data, (char*)r->ifcall.data + n, r->ifcall.count);
402                         continue;
403                 }
404                 c->wq = (Req*)r->aux;
405                 r->aux = nil;
406                 senddata(c, r->ifcall.data, r->ifcall.count);
407                 r->ofcall.count = r->ifcall.count;
408                 respond(r, nil);
409         }
410 }
411
412 Req*
413 findreq(Client *c, Req *r)
414 {
415         Req **l;
416
417         for(l=&c->rq; *l; l=(Req**)&(*l)->aux){
418                 if(*l == r){
419                         *l = r->aux;
420                         if(*l == nil)
421                                 c->erq = l;
422                         return r;
423                 }
424         }
425         for(l=&c->wq; *l; l=(Req**)&(*l)->aux){
426                 if(*l == r){
427                         *l = r->aux;
428                         if(*l == nil)
429                                 c->ewq = l;
430                         return r;
431                 }
432         }
433         return nil;
434 }
435
436 void
437 hangupclient(Client *c, char *err)
438 {
439         Req *r;
440
441         c->eof = 1;
442         c->recvwin = 0;
443         c->sendwin = 0;
444         while((r = c->wq) != nil){
445                 c->wq = r->aux;
446                 r->aux = nil;
447                 respond(r, err);
448         }
449         matchrmsgs(c);
450 }
451
452 void
453 teardownclient(Client *c)
454 {
455         c->state = Teardown;
456         hangupclient(c, "i/o on hungup channel");
457         sendmsg(pack(nil, "bu", MSG_CHANNEL_CLOSE, c->servernum));
458 }
459
460 void
461 closeclient(Client *c)
462 {
463         Msg *m;
464
465         if(--c->ref)
466                 return;
467         switch(c->state){
468         case Established:
469                 teardownclient(c);
470                 break;
471         case Finished:
472                 c->state = Closed;
473                 sendmsg(pack(nil, "bu", MSG_CHANNEL_CLOSE, c->servernum));
474                 break;
475         }
476         while((m = c->mq) != nil){
477                 c->mq = m->link;
478                 free(m);
479         }
480 }
481
482 void
483 sshreadproc(void*)
484 {
485         Msg *m;
486         int n;
487
488         for(;;){
489                 m = allocmsg();
490                 n = read(sshfd, m->rp, m->ep - m->rp);
491                 if(n <= 0)
492                         sysfatal("eof on ssh connection");
493                 m->wp += n;
494                 sendp(sshmsgchan, m);
495         }
496 }
497
498 typedef struct Tab Tab;
499 struct Tab
500 {
501         char *name;
502         ulong mode;
503 };
504
505 Tab tab[] =
506 {
507         "/",            DMDIR|0555,
508         "cs",           0666,
509         "tcp",          DMDIR|0555,     
510         "clone",        0666,
511         nil,            DMDIR|0555,
512         "ctl",          0666,
513         "data",         0666,
514         "local",        0444,
515         "remote",       0444,
516         "status",       0444,
517 };
518
519 static void
520 fillstat(Dir *d, uvlong path)
521 {
522         Tab *t;
523
524         memset(d, 0, sizeof(*d));
525         d->uid = estrdup9p("ssh");
526         d->gid = estrdup9p("ssh");
527         d->qid.path = path;
528         d->atime = d->mtime = time0;
529         t = &tab[TYPE(path)];
530         if(t->name)
531                 d->name = estrdup9p(t->name);
532         else{
533                 d->name = smprint("%ud", NUM(path));
534                 if(d->name == nil)
535                         sysfatal("out of memory");
536         }
537         d->qid.type = t->mode>>24;
538         d->mode = t->mode;
539 }
540
541 static void
542 fsattach(Req *r)
543 {
544         if(r->ifcall.aname && r->ifcall.aname[0]){
545                 respond(r, "invalid attach specifier");
546                 return;
547         }
548         r->fid->qid.path = PATH(Qroot, 0);
549         r->fid->qid.type = QTDIR;
550         r->fid->qid.vers = 0;
551         r->ofcall.qid = r->fid->qid;
552         respond(r, nil);
553 }
554
555 static void
556 fsstat(Req *r)
557 {
558         fillstat(&r->d, r->fid->qid.path);
559         respond(r, nil);
560 }
561
562 static int
563 rootgen(int i, Dir *d, void*)
564 {
565         i += Qroot+1;
566         if(i <= Qtcp){
567                 fillstat(d, i);
568                 return 0;
569         }
570         return -1;
571 }
572
573 static int
574 tcpgen(int i, Dir *d, void*)
575 {
576         i += Qtcp+1;
577         if(i < Qn){
578                 fillstat(d, i);
579                 return 0;
580         }
581         i -= Qn;
582         if(i < nclient){
583                 fillstat(d, PATH(Qn, i));
584                 return 0;
585         }
586         return -1;
587 }
588
589 static int
590 clientgen(int i, Dir *d, void *aux)
591 {
592         Client *c;
593
594         c = aux;
595         i += Qn+1;
596         if(i <= Qstatus){
597                 fillstat(d, PATH(i, c->num));
598                 return 0;
599         }
600         return -1;
601 }
602
603 static char*
604 fswalk1(Fid *fid, char *name, Qid *qid)
605 {
606         int i, n;
607         char buf[32];
608         ulong path;
609
610         path = fid->qid.path;
611         if(!(fid->qid.type&QTDIR))
612                 return "walk in non-directory";
613
614         if(strcmp(name, "..") == 0){
615                 switch(TYPE(path)){
616                 case Qn:
617                         qid->path = PATH(Qtcp, NUM(path));
618                         qid->type = tab[Qtcp].mode>>24;
619                         return nil;
620                 case Qtcp:
621                         qid->path = PATH(Qroot, 0);
622                         qid->type = tab[Qroot].mode>>24;
623                         return nil;
624                 case Qroot:
625                         return nil;
626                 default:
627                         return "bug in fswalk1";
628                 }
629         }
630
631         i = TYPE(path)+1;
632         for(; i<nelem(tab); i++){
633                 if(i==Qn){
634                         n = atoi(name);
635                         snprint(buf, sizeof buf, "%d", n);
636                         if(n < nclient && strcmp(buf, name) == 0){
637                                 qid->path = PATH(i, n);
638                                 qid->type = tab[i].mode>>24;
639                                 return nil;
640                         }
641                         break;
642                 }
643                 if(strcmp(name, tab[i].name) == 0){
644                         qid->path = PATH(i, NUM(path));
645                         qid->type = tab[i].mode>>24;
646                         return nil;
647                 }
648                 if(tab[i].mode&DMDIR)
649                         break;
650         }
651         return "directory entry not found";
652 }
653
/* Per-fid state of an open /cs file. */
typedef struct Cs Cs;
struct Cs
{
	char	*resp;	/* translation from the last write, or nil */
	int	isnew;	/* resp not yet returned by a read at offset 0 */
};
660
661 static int
662 ndbfindport(char *p)
663 {
664         char *s, *port;
665         int n;
666         static Ndb *db;
667
668         if(*p == '\0')
669                 return -1;
670
671         n = strtol(p, &s, 0);
672         if(*s == '\0')
673                 return n;
674
675         if(db == nil){
676                 db = ndbopen("/lib/ndb/common");
677                 if(db == nil)
678                         return -1;
679         }
680
681         port = ndbgetvalue(db, nil, "tcp", p, "port", nil);
682         if(port == nil)
683                 return -1;
684         n = atoi(port);
685         free(port);
686
687         return n;
688 }       
689
690 static void
691 csread(Req *r)
692 {
693         Cs *cs;
694
695         cs = r->fid->aux;
696         if(cs->resp==nil){
697                 respond(r, "cs read without write");
698                 return;
699         }
700         if(r->ifcall.offset==0){
701                 if(!cs->isnew){
702                         r->ofcall.count = 0;
703                         respond(r, nil);
704                         return;
705                 }
706                 cs->isnew = 0;
707         }
708         readstr(r, cs->resp);
709         respond(r, nil);
710 }
711
712 static void
713 cswrite(Req *r)
714 {
715         int port, nf;
716         char err[ERRMAX], *f[4], *s, *ns;
717         Cs *cs;
718
719         cs = r->fid->aux;
720         s = emalloc9p(r->ifcall.count+1);
721         memmove(s, r->ifcall.data, r->ifcall.count);
722         s[r->ifcall.count] = '\0';
723
724         nf = getfields(s, f, nelem(f), 0, "!");
725         if(nf != 3){
726                 free(s);
727                 respond(r, "can't translate");
728                 return;
729         }
730         if(strcmp(f[0], "tcp") != 0 && strcmp(f[0], "net") != 0){
731                 free(s);
732                 respond(r, "unknown protocol");
733                 return;
734         }
735         port = ndbfindport(f[2]);
736         if(port < 0){
737                 free(s);
738                 respond(r, "no translation found");
739                 return;
740         }
741
742         ns = smprint("%s/tcp/clone %s!%d", mtpt, f[1], port);
743         if(ns == nil){
744                 free(s);
745                 rerrstr(err, sizeof err);
746                 respond(r, err);
747                 return;
748         }
749         free(s);
750         free(cs->resp);
751         cs->resp = ns;
752         cs->isnew = 1;
753         r->ofcall.count = r->ifcall.count;
754         respond(r, nil);
755 }
756
757 static void
758 ctlread(Req *r, Client *c)
759 {
760         char buf[32];
761
762         sprint(buf, "%d", c->num);
763         readstr(r, buf);
764         respond(r, nil);
765 }
766
767 static void
768 ctlwrite(Req *r, Client *c)
769 {
770         char *f[3], *s;
771         int nf;
772
773         s = emalloc9p(r->ifcall.count+1);
774         r->ofcall.count = r->ifcall.count;
775         memmove(s, r->ifcall.data, r->ifcall.count);
776         s[r->ifcall.count] = '\0';
777
778         nf = tokenize(s, f, 3);
779         if(nf == 0){
780                 free(s);
781                 respond(r, nil);
782                 return;
783         }
784
785         if(strcmp(f[0], "hangup") == 0){
786                 if(c->state != Established)
787                         goto Badarg;
788                 if(nf != 1)
789                         goto Badarg;
790                 teardownclient(c);
791                 respond(r, nil);
792         }else if(strcmp(f[0], "connect") == 0){
793                 if(c->state != Closed)
794                         goto Badarg;
795                 if(nf != 2)
796                         goto Badarg;
797                 free(c->connect);
798                 c->connect = estrdup9p(f[1]);
799                 nf = getfields(f[1], f, nelem(f), 0, "!");
800                 if(nf != 2)
801                         goto Badarg;
802                 c->recvwin = WinPackets*MaxPacket;
803                 c->recvacc = 0;
804                 c->state = Dialing;
805                 queuewreq(c, r);
806
807                 sendmsg(pack(nil, "bsuuususu", MSG_CHANNEL_OPEN,
808                         "direct-tcpip", 12,
809                         c->num, c->recvwin, MaxPacket,
810                         f[0], strlen(f[0]), ndbfindport(f[1]),
811                         localip, strlen(localip), localport));
812         }else{
813         Badarg:
814                 respond(r, "bad or inappropriate tcp control message");
815         }
816         free(s);
817 }
818
819 static void
820 dataread(Req *r, Client *c)
821 {
822         if(c->state < Established){
823                 respond(r, "not connected");
824                 return;
825         }
826         queuerreq(c, r);
827         matchrmsgs(c);
828 }
829
830 static void
831 datawrite(Req *r, Client *c)
832 {
833         if(c->state != Established){
834                 respond(r, "not connected");
835                 return;
836         }
837         if(r->ifcall.count == 0){
838                 r->ofcall.count = r->ifcall.count;
839                 respond(r, nil);
840                 return;
841         }
842         queuewreq(c, r);
843         procwreqs(c);
844 }
845
846 static void
847 localread(Req *r)
848 {
849         char buf[128];
850
851         snprint(buf, sizeof buf, "%s!%d\n", localip, localport);
852         readstr(r, buf);
853         respond(r, nil);
854 }
855
856 static void
857 remoteread(Req *r, Client *c)
858 {
859         char *s;
860         char buf[128];
861
862         s = c->connect;
863         if(s == nil)
864                 s = "::!0";
865         snprint(buf, sizeof buf, "%s\n", s);
866         readstr(r, buf);
867         respond(r, nil);
868 }
869
870 static void
871 statusread(Req *r, Client *c)
872 {
873         char *s;
874
875         s = statestr[c->state];
876         readstr(r, s);
877         respond(r, nil);
878 }
879
880 static void
881 fsread(Req *r)
882 {
883         char e[ERRMAX];
884         ulong path;
885
886         path = r->fid->qid.path;
887         switch(TYPE(path)){
888         default:
889                 snprint(e, sizeof e, "bug in fsread path=%lux", path);
890                 respond(r, e);
891                 break;
892
893         case Qroot:
894                 dirread9p(r, rootgen, nil);
895                 respond(r, nil);
896                 break;
897
898         case Qcs:
899                 csread(r);
900                 break;
901
902         case Qtcp:
903                 dirread9p(r, tcpgen, nil);
904                 respond(r, nil);
905                 break;
906
907         case Qn:
908                 dirread9p(r, clientgen, client[NUM(path)]);
909                 respond(r, nil);
910                 break;
911
912         case Qctl:
913                 ctlread(r, client[NUM(path)]);
914                 break;
915
916         case Qdata:
917                 dataread(r, client[NUM(path)]);
918                 break;
919
920         case Qlocal:
921                 localread(r);
922                 break;
923
924         case Qremote:
925                 remoteread(r, client[NUM(path)]);
926                 break;
927
928         case Qstatus:
929                 statusread(r, client[NUM(path)]);
930                 break;
931         }
932 }
933
934 static void
935 fswrite(Req *r)
936 {
937         ulong path;
938         char e[ERRMAX];
939
940         path = r->fid->qid.path;
941         switch(TYPE(path)){
942         default:
943                 snprint(e, sizeof e, "bug in fswrite path=%lux", path);
944                 respond(r, e);
945                 break;
946
947         case Qcs:
948                 cswrite(r);
949                 break;
950
951         case Qctl:
952                 ctlwrite(r, client[NUM(path)]);
953                 break;
954
955         case Qdata:
956                 datawrite(r, client[NUM(path)]);
957                 break;
958         }
959 }
960
961 static void
962 fsopen(Req *r)
963 {
964         static int need[4] = { 4, 2, 6, 1 };
965         ulong path;
966         int n;
967         Tab *t;
968         Cs *cs;
969
970         /*
971          * lib9p already handles the blatantly obvious.
972          * we just have to enforce the permissions we have set.
973          */
974         path = r->fid->qid.path;
975         t = &tab[TYPE(path)];
976         n = need[r->ifcall.mode&3];
977         if((n&t->mode) != n){
978                 respond(r, "permission denied");
979                 return;
980         }
981
982         switch(TYPE(path)){
983         case Qcs:
984                 cs = emalloc9p(sizeof(Cs));
985                 r->fid->aux = cs;
986                 respond(r, nil);
987                 break;
988         case Qclone:
989                 n = newclient();
990                 path = PATH(Qctl, n);
991                 r->fid->qid.path = path;
992                 r->ofcall.qid.path = path;
993                 if(chatty9p)
994                         fprint(2, "open clone => path=%lux\n", path);
995                 t = &tab[Qctl];
996                 /* fall through */
997         default:
998                 if(t-tab >= Qn)
999                         client[NUM(path)]->ref++;
1000                 respond(r, nil);
1001                 break;
1002         }
1003 }
1004
1005 static void
1006 fsflush(Req *r)
1007 {
1008         int i;
1009
1010         for(i=0; i<nclient; i++)
1011                 if(findreq(client[i], r->oldreq))
1012                         respond(r->oldreq, "interrupted");
1013         respond(r, nil);
1014 }
1015
1016 static void
1017 handlemsg(Msg *m)
1018 {
1019         int chan, win, pkt, n;
1020         Client *c;
1021         char *s;
1022
1023         switch(m->rp[0]){
1024         case MSG_CHANNEL_WINDOW_ADJUST:
1025                 if(unpack(m, "_uu", &chan, &n) < 0)
1026                         break;
1027                 c = getclient(chan);
1028                 if(c != nil && c->state == Established){
1029                         c->sendwin += n;
1030                         procwreqs(c);
1031                 }
1032                 break;
1033         case MSG_CHANNEL_DATA:
1034                 if(unpack(m, "_us", &chan, &s, &n) < 0)
1035                         break;
1036                 c = getclient(chan);
1037                 if(c != nil && c->state == Established){
1038                         if(c->recvwin <= 0)
1039                                 break;
1040                         c->recvwin -= n;
1041                         m->rp = (uchar*)s;
1042                         queuermsg(c, m);
1043                         matchrmsgs(c);
1044                         return;
1045                 }
1046                 break;
1047         case MSG_CHANNEL_EOF:
1048                 if(unpack(m, "_u", &chan) < 0)
1049                         break;
1050                 c = getclient(chan);
1051                 if(c != nil && c->state == Established){
1052                         c->eof = 1;
1053                         c->recvwin = 0;
1054                         matchrmsgs(c);
1055                 }
1056                 break;
1057         case MSG_CHANNEL_CLOSE:
1058                 if(unpack(m, "_u", &chan) < 0)
1059                         break;
1060                 c = getclient(chan);
1061                 if(c == nil)
1062                         break;
1063                 switch(c->state){
1064                 case Established:
1065                         c->state = Finished;
1066                         hangupclient(c, "connection closed");
1067                         break;
1068                 case Teardown:
1069                         c->state = Closed;
1070                         break;
1071                 }
1072                 break;
1073         case MSG_CHANNEL_OPEN_CONFIRMATION:
1074                 if(unpack(m, "_uuuu", &chan, &n, &win, &pkt) < 0)
1075                         break;
1076                 if(chan == SESSIONCHAN){
1077                         sendp(ssherrchan, nil);
1078                         break;
1079                 }
1080                 c = getclient(chan);
1081                 if(c == nil || c->state != Dialing)
1082                         break;
1083                 if(pkt <= 0 || pkt > MaxPacket)
1084                         pkt = MaxPacket;
1085                 c->eof = 0;
1086                 c->sendpkt = pkt;
1087                 c->sendwin = win;
1088                 c->servernum = n;
1089                 c->state = Established;
1090                 if(c->wq != nil){
1091                         respond(c->wq, nil);
1092                         c->wq = nil;
1093                 }
1094                 break;
1095         case MSG_CHANNEL_OPEN_FAILURE:
1096                 if(unpack(m, "_u____s", &chan, &s, &n) < 0)
1097                         break;
1098                 s = smprint("%.*s", utfnlen(s, n), s);
1099                 if(chan == SESSIONCHAN){
1100                         sendp(ssherrchan, s);
1101                         break;
1102                 }
1103                 c = getclient(chan);
1104                 if(c == nil || c->state != Dialing){
1105                         free(s);
1106                         break;
1107                 }
1108                 c->state = Closed;
1109                 hangupclient(c, s);
1110                 break;
1111         }
1112         free(m);
1113 }
1114
1115 void
1116 fsnetproc(void*)
1117 {
1118         ulong path;
1119         Alt a[4];
1120         Cs *cs;
1121         Fid *fid;
1122         Req *r;
1123         Msg *m;
1124
1125         threadsetname("fsthread");
1126
1127         a[0].op = CHANRCV;
1128         a[0].c = fsclunkchan;
1129         a[0].v = &fid;
1130         a[1].op = CHANRCV;
1131         a[1].c = fsreqchan;
1132         a[1].v = &r;
1133         a[2].op = CHANRCV;
1134         a[2].c = sshmsgchan;
1135         a[2].v = &m;
1136         a[3].op = CHANEND;
1137
1138         for(;;){
1139                 switch(alt(a)){
1140                 case 0:
1141                         path = fid->qid.path;
1142                         switch(TYPE(path)){
1143                         case Qcs:
1144                                 cs = fid->aux;
1145                                 if(cs){
1146                                         free(cs->resp);
1147                                         free(cs);
1148                                 }
1149                                 break;
1150                         }
1151                         if(fid->omode != -1 && TYPE(path) >= Qn)
1152                                 closeclient(client[NUM(path)]);
1153                         sendp(fsclunkwaitchan, nil);
1154                         break;
1155                 case 1:
1156                         switch(r->ifcall.type){
1157                         case Tattach:
1158                                 fsattach(r);
1159                                 break;
1160                         case Topen:
1161                                 fsopen(r);
1162                                 break;
1163                         case Tread:
1164                                 fsread(r);
1165                                 break;
1166                         case Twrite:
1167                                 fswrite(r);
1168                                 break;
1169                         case Tstat:
1170                                 fsstat(r);
1171                                 break;
1172                         case Tflush:
1173                                 fsflush(r);
1174                                 break;
1175                         default:
1176                                 respond(r, "bug in fsthread");
1177                                 break;
1178                         }
1179                         sendp(fsreqwaitchan, 0);
1180                         break;
1181                 case 2:
1182                         handlemsg(m);
1183                         break;
1184                 }
1185         }
1186 }
1187
1188 static void
1189 fssend(Req *r)
1190 {
1191         sendp(fsreqchan, r);
1192         recvp(fsreqwaitchan);   /* avoids need to deal with spurious flushes */
1193 }
1194
1195 static void
1196 fsdestroyfid(Fid *fid)
1197 {
1198         sendp(fsclunkchan, fid);
1199         recvp(fsclunkwaitchan);
1200 }
1201
1202 void
1203 takedown(Srv*)
1204 {
1205         threadexitsall("done");
1206 }
1207
1208 Srv fs = 
1209 {
1210 .attach=                fssend,
1211 .destroyfid=    fsdestroyfid,
1212 .walk1=         fswalk1,
1213 .open=          fssend,
1214 .read=          fssend,
1215 .write=         fssend,
1216 .stat=          fssend,
1217 .flush=         fssend,
1218 .end=           takedown,
1219 };
1220
1221 int pfd[2];
1222 int sshargc;
1223 char **sshargv;
1224
1225 void
1226 startssh(void *)
1227 {
1228         char *f;
1229
1230         close(pfd[0]);
1231         dup(pfd[1], 0);
1232         dup(pfd[1], 1);
1233         close(pfd[1]);
1234         if(strncmp(sshargv[0], "./", 2) != 0)
1235                 f = smprint("/bin/%s", sshargv[0]);
1236         else
1237                 f = sshargv[0];
1238         procexec(nil, f, sshargv);
1239         sysfatal("exec: %r");
1240 }
1241
1242 void
1243 ssh(int argc, char *argv[])
1244 {
1245         Alt a[3];
1246         Waitmsg *w;
1247         char *e;
1248
1249         sshargc = argc + 2;
1250         sshargv = emalloc9p(sizeof(char *) * (sshargc + 1));
1251         sshargv[0] = "ssh";
1252         sshargv[1] = "-X";
1253         memcpy(sshargv + 2, argv, argc * sizeof(char *));
1254
1255         pipe(pfd);
1256         sshfd = pfd[0];
1257         procrfork(startssh, nil, 8*1024, RFFDG|RFNOTEG|RFNAMEG);
1258         close(pfd[1]);
1259
1260         sendmsg(pack(nil, "bsuuu", MSG_CHANNEL_OPEN,
1261                 "session", 7,
1262                 SESSIONCHAN,
1263                 MaxPacket,
1264                 MaxPacket));
1265
1266         a[0].op = CHANRCV;
1267         a[0].c = threadwaitchan();
1268         a[0].v = &w;
1269         a[1].op = CHANRCV;
1270         a[1].c = ssherrchan;
1271         a[1].v = &e;
1272         a[2].op = CHANEND;
1273
1274         switch(alt(a)){
1275         case 0:
1276                 sysfatal("ssh failed: %s", w->msg);
1277         case 1:
1278                 if(e != nil)
1279                         sysfatal("ssh failed: %s", e);
1280         }
1281         chanclose(ssherrchan);
1282 }
1283
1284 void
1285 usage(void)
1286 {
1287         fprint(2, "usage: sshnet [-m mtpt] [ssh options]\n");
1288         exits("usage");
1289 }
1290
1291 void
1292 threadmain(int argc, char **argv)
1293 {
1294         char *service;
1295
1296         fmtinstall('H', encodefmt);
1297
1298         mtpt = "/net";
1299         service = nil;
1300         ARGBEGIN{
1301         case 'D':
1302                 chatty9p++;
1303                 break;
1304         case 'm':
1305                 mtpt = EARGF(usage());
1306                 break;
1307         case 's':
1308                 service = EARGF(usage());
1309                 break;
1310         default:
1311                 usage();
1312         }ARGEND
1313
1314         if(argc == 0)
1315                 usage();
1316
1317         time0 = time(0);
1318         ssherrchan = chancreate(sizeof(char*), 0);
1319         sshmsgchan = chancreate(sizeof(Msg*), 16);
1320         fsreqchan = chancreate(sizeof(Req*), 0);
1321         fsreqwaitchan = chancreate(sizeof(void*), 0);
1322         fsclunkchan = chancreate(sizeof(Fid*), 0);
1323         fsclunkwaitchan = chancreate(sizeof(void*), 0);
1324         procrfork(fsnetproc, nil, 8*1024, RFNAMEG|RFNOTEG);
1325         procrfork(sshreadproc, nil, 8*1024, RFNAMEG|RFNOTEG);
1326
1327         ssh(argc, argv);
1328
1329         threadpostmountsrv(&fs, service, mtpt, MREPL);
1330         exits(0);
1331 }