]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/sshnet.c
sshnet: bring back sshnet using ssh(1) mux mode
[plan9front.git] / sys / src / cmd / sshnet.c
1 /*
2  * SSH network file system.
3  * Presents remote TCP stack as /net-style file system.
4  */
5
6 #include <u.h>
7 #include <libc.h>
8 #include <bio.h>
9 #include <ndb.h>
10 #include <thread.h>
11 #include <fcall.h>
12 #include <9p.h>
13
/* forward declarations; both structures are defined later in this file */
typedef struct Client Client;
typedef struct Msg Msg;
16
/*
 * File types served by this file system.  The value occupies the low
 * 8 bits of a qid path (see PATH/TYPE) and is also the index into tab[].
 */
enum
{
	Qroot,		/* "/" */
	Qcs,		/* "cs" */
	Qtcp,		/* "tcp" directory */
	Qclone,		/* "clone" */
	Qn,		/* per-connection directory, named by its number */
	Qctl,		/* "ctl" */
	Qdata,		/* "data" */
	Qlocal,		/* "local" */
	Qremote,	/* "remote" */
	Qstatus,	/* "status" */
};
30
/*
 * Qid path encoding: file type in the low 8 bits, connection number
 * in the bits above.
 */
#define PATH(type, n)		((type)|((n)<<8))
#define TYPE(path)		((int)(path) & 0xFF)
#define NUM(path)		((uint)(path)>>8)
34
/* channels feeding fsnetproc, plus the matching completion channels */
Channel *sshmsgchan;		/* chan(Msg*): messages read from ssh by sshreadproc */
Channel *fsreqchan;		/* chan(Req*): 9P requests sent by fssend */
Channel *fsreqwaitchan;		/* chan(nil): signals fssend that the request was handled */
Channel *fsclunkchan;		/* chan(Fid*): fids being destroyed, sent by fsdestroyfid */
Channel *fsclunkwaitchan;	/* chan(nil): signals fsdestroyfid that the clunk was handled */
ulong time0;			/* reported as atime and mtime of every file (see fillstat) */
41
/* Client.state values; statestr[] holds the matching names */
enum
{
	Closed,		/* no channel; slot may be reused once ref is 0 */
	Dialing,	/* CHANNEL_OPEN sent, waiting for confirmation or failure */
	Established,	/* channel open; data reads and writes allowed */
	Teardown,	/* CHANNEL_EOF sent by teardownclient */
};
49
/* text returned when reading a status file, indexed by Client.state */
char *statestr[] = {
	"Closed",
	"Dialing",
	"Established",
	"Teardown",
};
56
/*
 * One connection directory (tcp/N) and the ssh channel behind it.
 * The three queues are singly linked lists with a pointer to the
 * tail link; Reqs are chained through Req.aux, Msgs through Msg.link.
 */
struct Client
{
	int ref;		/* open fids on files of this connection (fsopen/closeclient) */
	int state;		/* Closed, Dialing, Established or Teardown */
	int num;		/* index in client[]; used as our channel number */
	int servernum;		/* channel number used when sending to the peer */
	char *connect;		/* address given to the ctl "connect" command */

	int sendpkt;		/* largest data packet sent at once (from open confirmation) */
	int sendwin;		/* bytes we may still send before a window adjust */
	int recvwin;		/* receive window accounting, see adjustwin */
	int recvacc;		/* bytes delivered to readers not yet reported to the peer */

	Req *wq;		/* pending writes (or the dial request while Dialing) */
	Req **ewq;

	Req *rq;		/* pending data reads */
	Req **erq;

	Msg *mq;		/* received data not yet read */
	Msg **emq;
};
79
/*
 * Message type bytes exchanged with ssh (first byte of every message),
 * followed by the sizes used for buffers and flow control.
 */
enum {
	MSG_CHANNEL_OPEN = 90,
	MSG_CHANNEL_OPEN_CONFIRMATION,
	MSG_CHANNEL_OPEN_FAILURE,
	MSG_CHANNEL_WINDOW_ADJUST,
	MSG_CHANNEL_DATA,
	MSG_CHANNEL_EXTENDED_DATA,
	MSG_CHANNEL_EOF,
	MSG_CHANNEL_CLOSE,
	MSG_CHANNEL_REQUEST,
	MSG_CHANNEL_SUCCESS,
	MSG_CHANNEL_FAILURE,

	MaxPacket = 1<<15,	/* size of Msg.buf; maximum packet size we advertise */
	WinPackets = 8,		/* initial receive window, in units of MaxPacket */
};
96
/* one ssh message, either being built for sending or as received */
struct Msg
{
	Msg	*link;		/* next message in a Client's mq */

	uchar	*rp;		/* next byte to consume */
	uchar	*wp;		/* end of valid data / next byte to fill */
	uchar	*ep;		/* end of buf */
	uchar	buf[MaxPacket];
};
106
/*
 * Big-endian 32-bit store (PUT4) and load (GET4) on a byte pointer.
 * Each expands to one fully parenthesized expression, so it can be
 * combined with other operators (GET4(p) & mask, GET4(p) + 1, ...)
 * without precedence surprises.  Arguments are evaluated more than
 * once and must be free of side effects.
 */
#define PUT4(p, u) ((p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u))
#define GET4(p) ((u32int)(p)[3] | (u32int)(p)[2]<<8 | (u32int)(p)[1]<<16 | (u32int)(p)[0]<<24)
109
int nclient;		/* number of entries in client[] */
Client **client;	/* all connections ever created, indexed by Client.num */
char *mtpt;		/* mount point; prefixes the clone path returned by cs */
int sshfd;		/* descriptor for messages to and from ssh */
int localport;		/* port shown in "local" and sent as originator port */
char localip[] = "::";	/* address shown in "local" and sent as originator address */
116
117 int
118 vpack(uchar *p, int n, char *fmt, va_list a)
119 {
120         uchar *p0 = p, *e = p+n;
121         u32int u;
122         void *s;
123         int c;
124
125         for(;;){
126                 switch(c = *fmt++){
127                 case '\0':
128                         return p - p0;
129                 case '_':
130                         if(++p > e) goto err;
131                         break;
132                 case '.':
133                         *va_arg(a, void**) = p;
134                         break;
135                 case 'b':
136                         if(p >= e) goto err;
137                         *p++ = va_arg(a, int);
138                         break;
139                 case '[':
140                 case 's':
141                         s = va_arg(a, void*);
142                         u = va_arg(a, int);
143                         if(c == 's'){
144                                 if(p+4 > e) goto err;
145                                 PUT4(p, u), p += 4;
146                         }
147                         if(u > e-p) goto err;
148                         memmove(p, s, u);
149                         p += u;
150                         break;
151                 case 'u':
152                         u = va_arg(a, int);
153                         if(p+4 > e) goto err;
154                         PUT4(p, u), p += 4;
155                         break;
156                 }
157         }
158 err:
159         return -1;
160 }
161
162 int
163 vunpack(uchar *p, int n, char *fmt, va_list a)
164 {
165         uchar *p0 = p, *e = p+n;
166         u32int u;
167         void *s;
168
169         for(;;){
170                 switch(*fmt++){
171                 case '\0':
172                         return p - p0;
173                 case '_':
174                         if(++p > e) goto err;
175                         break;
176                 case '.':
177                         *va_arg(a, void**) = p;
178                         break;
179                 case 'b':
180                         if(p >= e) goto err;
181                         *va_arg(a, int*) = *p++;
182                         break;
183                 case 's':
184                         if(p+4 > e) goto err;
185                         u = GET4(p), p += 4;
186                         if(u > e-p) goto err;
187                         *va_arg(a, void**) = p;
188                         *va_arg(a, int*) = u;
189                         p += u;
190                         break;
191                 case '[':
192                         s = va_arg(a, void*);
193                         u = va_arg(a, int);
194                         if(u > e-p) goto err;
195                         memmove(s, p, u);
196                         p += u;
197                         break;
198                 case 'u':
199                         if(p+4 > e) goto err;
200                         u = GET4(p);
201                         *va_arg(a, int*) = u;
202                         p += 4;
203                         break;
204                 }
205         }
206 err:
207         return -1;
208 }
209
210 Msg*
211 allocmsg(void)
212 {
213         Msg *m;
214
215         m = emalloc9p(sizeof(Msg));
216         m->link = nil;
217         m->rp = m->wp = m->buf;
218         m->ep = m->rp + sizeof(m->buf);
219         return m;
220 }
221
222 Msg*
223 pack(Msg *m, char *fmt, ...)
224 {
225         va_list a;
226         int n;
227
228         if(m == nil)
229                 m = allocmsg();
230         va_start(a, fmt);
231         n = vpack(m->wp, m->ep - m->wp, fmt, a);
232         if(n < 0)
233                 sysfatal("pack faild");
234         m->wp += n;
235         va_end(a);
236         return m;
237 }
238
239 int
240 unpack(Msg *m, char *fmt, ...)
241 {
242         va_list a;
243         int n;
244
245         va_start(a, fmt);
246         n = vunpack(m->rp, m->wp - m->rp, fmt, a);
247         if(n > 0)
248                 m->rp += n;
249         va_end(a);
250         return n;
251 }
252
253 void
254 sendmsg(Msg *m)
255 {
256         int n;
257
258         if(m == nil)
259                 return;
260         n = m->wp - m->rp;
261         if(n > 0){
262                 if(write(sshfd, m->rp, n) != n)
263                         sysfatal("write to ssh failed: %r");
264         }
265         free(m);
266 }
267
268 int
269 newclient(void)
270 {
271         int i;
272         Client *c;
273
274         for(i=0; i<nclient; i++)
275                 if(client[i]->ref==0 && client[i]->state == Closed)
276                         return i;
277
278         if(nclient%16 == 0)
279                 client = erealloc9p(client, (nclient+16)*sizeof(client[0]));
280
281         c = emalloc9p(sizeof(Client));
282         memset(c, 0, sizeof(*c));
283         c->num = nclient;
284         client[nclient++] = c;
285         return c->num;
286 }
287
288 Client*
289 getclient(int num)
290 {
291         if(num < 0 || num >= nclient)
292                 return nil;
293         return client[num];
294 }
295
296 void
297 adjustwin(Client *c, int len)
298 {
299         c->recvacc += len;
300         if(c->recvacc >= MaxPacket*WinPackets/2 || c->recvwin < MaxPacket){
301                 sendmsg(pack(nil, "buu", MSG_CHANNEL_WINDOW_ADJUST, c->servernum, c->recvacc));
302                 c->recvacc = 0;
303         }
304         c->recvwin += len;
305 }
306
307 void
308 senddata(Client *c, void *data, int len)
309 {
310         sendmsg(pack(nil, "bus", MSG_CHANNEL_DATA, c->servernum, (char*)data, len));
311         c->sendwin -= len;
312 }
313
314 void
315 queuerreq(Client *c, Req *r)
316 {
317         if(c->rq==nil)
318                 c->erq = &c->rq;
319         *c->erq = r;
320         r->aux = nil;
321         c->erq = (Req**)&r->aux;
322 }
323
324 void
325 queuermsg(Client *c, Msg *m)
326 {
327         if(c->mq==nil)
328                 c->emq = &c->mq;
329         *c->emq = m;
330         m->link = nil;
331         c->emq = (Msg**)&m->link;
332 }
333
334 void
335 matchrmsgs(Client *c)
336 {
337         Req *r;
338         Msg *m;
339         int n, rm;
340
341         while(c->rq != nil && c->mq != nil){
342                 r = c->rq;
343                 c->rq = r->aux;
344
345                 rm = 0;
346                 m = c->mq;
347                 n = r->ifcall.count;
348                 if(n >= m->wp - m->rp){
349                         n = m->wp - m->rp;
350                         c->mq = m->link;
351                         rm = 1;
352                 }
353                 memmove(r->ofcall.data, m->rp, n);
354                 if(rm)
355                         free(m);
356                 else
357                         m->rp += n;
358                 r->ofcall.count = n;
359                 respond(r, nil);
360                 adjustwin(c, n);
361         }
362 }
363
364 void
365 queuewreq(Client *c, Req *r)
366 {
367         if(c->wq==nil)
368                 c->ewq = &c->wq;
369         *c->ewq = r;
370         r->aux = nil;
371         c->ewq = (Req**)&r->aux;
372 }
373
374 void
375 procwreqs(Client *c)
376 {
377         Req *r;
378         int n;
379
380         while((r = c->wq) != nil && (n = c->sendwin) > 0){
381                 if(n > c->sendpkt)
382                         n = c->sendpkt;
383                 if(r->ifcall.count > n){
384                         senddata(c, r->ifcall.data, n);
385                         r->ifcall.count -= n;
386                         memmove(r->ifcall.data, (char*)r->ifcall.data + n, r->ifcall.count);
387                         continue;
388                 }
389                 c->wq = (Req*)r->aux;
390                 r->aux = nil;
391                 senddata(c, r->ifcall.data, r->ifcall.count);
392                 r->ofcall.count = r->ifcall.count;
393                 respond(r, nil);
394         }
395 }
396
397 Req*
398 findreq(Client *c, Req *r)
399 {
400         Req **l;
401
402         for(l=&c->rq; *l; l=(Req**)&(*l)->aux){
403                 if(*l == r){
404                         *l = r->aux;
405                         if(*l == nil)
406                                 c->erq = l;
407                         return r;
408                 }
409         }
410         for(l=&c->wq; *l; l=(Req**)&(*l)->aux){
411                 if(*l == r){
412                         *l = r->aux;
413                         if(*l == nil)
414                                 c->ewq = l;
415                         return r;
416                 }
417         }
418         return nil;
419 }
420
421 void
422 dialedclient(Client *c)
423 {
424         Req *r;
425
426         if(r=c->wq){
427                 if(r->aux != nil)
428                         sysfatal("more than one outstanding dial request (BUG)");
429                 if(c->state == Established)
430                         respond(r, nil);
431                 else
432                         respond(r, "connect failed");
433         }
434         c->wq = nil;
435 }
436
437 void
438 teardownclient(Client *c)
439 {
440         c->state = Teardown;
441         sendmsg(pack(nil, "bu", MSG_CHANNEL_EOF, c->servernum));
442 }
443
444 void
445 hangupclient(Client *c)
446 {
447         Req *r, *next;
448         Msg *m, *mnext;
449
450         c->state = Closed;
451         for(m=c->mq; m; m=mnext){
452                 mnext = m->link;
453                 free(m);
454         }
455         c->mq = nil;
456         for(r=c->rq; r; r=next){
457                 next = r->aux;
458                 respond(r, "hangup on network connection");
459         }
460         c->rq = nil;
461         for(r=c->wq; r; r=next){
462                 next = r->aux;
463                 respond(r, "hangup on network connection");
464         }
465         c->wq = nil;
466 }
467
468 void
469 closeclient(Client *c)
470 {
471         Msg *m, *next;
472
473         if(--c->ref)
474                 return;
475
476         if(c->rq != nil || c->wq != nil)
477                 sysfatal("ref count reached zero with requests pending (BUG)");
478
479         for(m=c->mq; m; m=next){
480                 next = m->link;
481                 free(m);
482         }
483         c->mq = nil;
484
485         if(c->state != Closed)
486                 teardownclient(c);
487 }
488
489         
490 void
491 sshreadproc(void*)
492 {
493         Msg *m;
494         int n;
495
496         for(;;){
497                 m = allocmsg();
498                 n = read(sshfd, m->rp, m->ep - m->rp);
499                 if(n <= 0)
500                         sysfatal("eof on ssh connection");
501                 m->wp += n;
502                 sendp(sshmsgchan, m);
503         }
504 }
505
/* name and mode of each file type */
typedef struct Tab Tab;
struct Tab
{
	char *name;	/* nil for the connection directory, which is named by number */
	ulong mode;	/* permission bits, plus DMDIR for directories */
};

/* indexed by the Q* file type; the order must match that enum */
Tab tab[] =
{
	"/",		DMDIR|0555,
	"cs",		0666,
	"tcp",		DMDIR|0555,
	"clone",	0666,
	nil,		DMDIR|0555,
	"ctl",		0666,
	"data",		0666,
	"local",	0444,
	"remote",	0444,
	"status",	0444,
};
526
527 static void
528 fillstat(Dir *d, uvlong path)
529 {
530         Tab *t;
531
532         memset(d, 0, sizeof(*d));
533         d->uid = estrdup9p("ssh");
534         d->gid = estrdup9p("ssh");
535         d->qid.path = path;
536         d->atime = d->mtime = time0;
537         t = &tab[TYPE(path)];
538         if(t->name)
539                 d->name = estrdup9p(t->name);
540         else{
541                 d->name = smprint("%ud", NUM(path));
542                 if(d->name == nil)
543                         sysfatal("out of memory");
544         }
545         d->qid.type = t->mode>>24;
546         d->mode = t->mode;
547 }
548
549 static void
550 fsattach(Req *r)
551 {
552         if(r->ifcall.aname && r->ifcall.aname[0]){
553                 respond(r, "invalid attach specifier");
554                 return;
555         }
556         r->fid->qid.path = PATH(Qroot, 0);
557         r->fid->qid.type = QTDIR;
558         r->fid->qid.vers = 0;
559         r->ofcall.qid = r->fid->qid;
560         respond(r, nil);
561 }
562
563 static void
564 fsstat(Req *r)
565 {
566         fillstat(&r->d, r->fid->qid.path);
567         respond(r, nil);
568 }
569
570 static int
571 rootgen(int i, Dir *d, void*)
572 {
573         i += Qroot+1;
574         if(i <= Qtcp){
575                 fillstat(d, i);
576                 return 0;
577         }
578         return -1;
579 }
580
581 static int
582 tcpgen(int i, Dir *d, void*)
583 {
584         i += Qtcp+1;
585         if(i < Qn){
586                 fillstat(d, i);
587                 return 0;
588         }
589         i -= Qn;
590         if(i < nclient){
591                 fillstat(d, PATH(Qn, i));
592                 return 0;
593         }
594         return -1;
595 }
596
597 static int
598 clientgen(int i, Dir *d, void *aux)
599 {
600         Client *c;
601
602         c = aux;
603         i += Qn+1;
604         if(i <= Qstatus){
605                 fillstat(d, PATH(i, c->num));
606                 return 0;
607         }
608         return -1;
609 }
610
611 static char*
612 fswalk1(Fid *fid, char *name, Qid *qid)
613 {
614         int i, n;
615         char buf[32];
616         ulong path;
617
618         path = fid->qid.path;
619         if(!(fid->qid.type&QTDIR))
620                 return "walk in non-directory";
621
622         if(strcmp(name, "..") == 0){
623                 switch(TYPE(path)){
624                 case Qn:
625                         qid->path = PATH(Qtcp, NUM(path));
626                         qid->type = tab[Qtcp].mode>>24;
627                         return nil;
628                 case Qtcp:
629                         qid->path = PATH(Qroot, 0);
630                         qid->type = tab[Qroot].mode>>24;
631                         return nil;
632                 case Qroot:
633                         return nil;
634                 default:
635                         return "bug in fswalk1";
636                 }
637         }
638
639         i = TYPE(path)+1;
640         for(; i<nelem(tab); i++){
641                 if(i==Qn){
642                         n = atoi(name);
643                         snprint(buf, sizeof buf, "%d", n);
644                         if(n < nclient && strcmp(buf, name) == 0){
645                                 qid->path = PATH(i, n);
646                                 qid->type = tab[i].mode>>24;
647                                 return nil;
648                         }
649                         break;
650                 }
651                 if(strcmp(name, tab[i].name) == 0){
652                         qid->path = PATH(i, NUM(path));
653                         qid->type = tab[i].mode>>24;
654                         return nil;
655                 }
656                 if(tab[i].mode&DMDIR)
657                         break;
658         }
659         return "directory entry not found";
660 }
661
/* per-fid state of an open "cs" file, kept in Fid.aux */
typedef struct Cs Cs;
struct Cs
{
	char *resp;	/* translation produced by the last successful write, or nil */
	int isnew;	/* set by a write; cleared by the next read at offset 0 */
};
668
669 static int
670 ndbfindport(char *p)
671 {
672         char *s, *port;
673         int n;
674         static Ndb *db;
675
676         if(*p == '\0')
677                 return -1;
678
679         n = strtol(p, &s, 0);
680         if(*s == '\0')
681                 return n;
682
683         if(db == nil){
684                 db = ndbopen("/lib/ndb/common");
685                 if(db == nil)
686                         return -1;
687         }
688
689         port = ndbgetvalue(db, nil, "tcp", p, "port", nil);
690         if(port == nil)
691                 return -1;
692         n = atoi(port);
693         free(port);
694
695         return n;
696 }       
697
698 static void
699 csread(Req *r)
700 {
701         Cs *cs;
702
703         cs = r->fid->aux;
704         if(cs->resp==nil){
705                 respond(r, "cs read without write");
706                 return;
707         }
708         if(r->ifcall.offset==0){
709                 if(!cs->isnew){
710                         r->ofcall.count = 0;
711                         respond(r, nil);
712                         return;
713                 }
714                 cs->isnew = 0;
715         }
716         readstr(r, cs->resp);
717         respond(r, nil);
718 }
719
720 static void
721 cswrite(Req *r)
722 {
723         int port, nf;
724         char err[ERRMAX], *f[4], *s, *ns;
725         Cs *cs;
726
727         cs = r->fid->aux;
728         s = emalloc9p(r->ifcall.count+1);
729         memmove(s, r->ifcall.data, r->ifcall.count);
730         s[r->ifcall.count] = '\0';
731
732         nf = getfields(s, f, nelem(f), 0, "!");
733         if(nf != 3){
734                 free(s);
735                 respond(r, "can't translate");
736                 return;
737         }
738         if(strcmp(f[0], "tcp") != 0 && strcmp(f[0], "net") != 0){
739                 free(s);
740                 respond(r, "unknown protocol");
741                 return;
742         }
743         port = ndbfindport(f[2]);
744         if(port <= 0){
745                 free(s);
746                 respond(r, "no translation found");
747                 return;
748         }
749
750         ns = smprint("%s/tcp/clone %s!%d", mtpt, f[1], port);
751         if(ns == nil){
752                 free(s);
753                 rerrstr(err, sizeof err);
754                 respond(r, err);
755                 return;
756         }
757         free(s);
758         free(cs->resp);
759         cs->resp = ns;
760         cs->isnew = 1;
761         r->ofcall.count = r->ifcall.count;
762         respond(r, nil);
763 }
764
765 static void
766 ctlread(Req *r, Client *c)
767 {
768         char buf[32];
769
770         sprint(buf, "%d", c->num);
771         readstr(r, buf);
772         respond(r, nil);
773 }
774
775 static void
776 ctlwrite(Req *r, Client *c)
777 {
778         char *f[3], *s;
779         int nf;
780
781         s = emalloc9p(r->ifcall.count+1);
782         memmove(s, r->ifcall.data, r->ifcall.count);
783         s[r->ifcall.count] = '\0';
784
785         nf = tokenize(s, f, 3);
786         if(nf == 0){
787                 free(s);
788                 r->ofcall.count = r->ifcall.count;
789                 respond(r, nil);
790                 return;
791         }
792
793         if(strcmp(f[0], "hangup") == 0){
794                 if(c->state != Established)
795                         goto Badarg;
796                 if(nf != 1)
797                         goto Badarg;
798                 teardownclient(c);
799                 r->ofcall.count = r->ifcall.count;
800                 respond(r, nil);
801         }else if(strcmp(f[0], "connect") == 0){
802                 if(c->state != Closed)
803                         goto Badarg;
804                 if(nf != 2)
805                         goto Badarg;
806                 c->connect = estrdup9p(f[1]);
807                 nf = getfields(f[1], f, nelem(f), 0, "!");
808                 if(nf != 2){
809                         free(c->connect);
810                         c->connect = nil;
811                         goto Badarg;
812                 }
813                 c->sendwin = MaxPacket;
814                 c->recvwin = WinPackets * MaxPacket;
815                 c->recvacc = 0;
816                 c->state = Dialing;
817                 queuewreq(c, r);
818
819                 sendmsg(pack(nil, "bsuuususu", MSG_CHANNEL_OPEN,
820                         "direct-tcpip", 12,
821                         c->num, c->recvwin, MaxPacket,
822                         f[0], strlen(f[0]), ndbfindport(f[1]),
823                         localip, strlen(localip), localport));
824         }else{
825         Badarg:
826                 respond(r, "bad or inappropriate tcp control message");
827         }
828         free(s);
829 }
830
831 static void
832 dataread(Req *r, Client *c)
833 {
834         if(c->state != Established){
835                 respond(r, "not connected");
836                 return;
837         }
838         queuerreq(c, r);
839         matchrmsgs(c);
840 }
841
842 static void
843 datawrite(Req *r, Client *c)
844 {
845         if(c->state != Established){
846                 respond(r, "not connected");
847                 return;
848         }
849         if(r->ifcall.count == 0){
850                 r->ofcall.count = r->ifcall.count;
851                 respond(r, nil);
852                 return;
853         }
854         queuewreq(c, r);
855         procwreqs(c);
856 }
857
858 static void
859 localread(Req *r)
860 {
861         char buf[128];
862
863         snprint(buf, sizeof buf, "%s!%d\n", localip, localport);
864         readstr(r, buf);
865         respond(r, nil);
866 }
867
868 static void
869 remoteread(Req *r, Client *c)
870 {
871         char *s;
872         char buf[128];
873
874         s = c->connect;
875         if(s == nil)
876                 s = "::!0";
877         snprint(buf, sizeof buf, "%s\n", s);
878         readstr(r, buf);
879         respond(r, nil);
880 }
881
882 static void
883 statusread(Req *r, Client *c)
884 {
885         char *s;
886
887         s = statestr[c->state];
888         readstr(r, s);
889         respond(r, nil);
890 }
891
892 static void
893 fsread(Req *r)
894 {
895         char e[ERRMAX];
896         ulong path;
897
898         path = r->fid->qid.path;
899         switch(TYPE(path)){
900         default:
901                 snprint(e, sizeof e, "bug in fsread path=%lux", path);
902                 respond(r, e);
903                 break;
904
905         case Qroot:
906                 dirread9p(r, rootgen, nil);
907                 respond(r, nil);
908                 break;
909
910         case Qcs:
911                 csread(r);
912                 break;
913
914         case Qtcp:
915                 dirread9p(r, tcpgen, nil);
916                 respond(r, nil);
917                 break;
918
919         case Qn:
920                 dirread9p(r, clientgen, client[NUM(path)]);
921                 respond(r, nil);
922                 break;
923
924         case Qctl:
925                 ctlread(r, client[NUM(path)]);
926                 break;
927
928         case Qdata:
929                 dataread(r, client[NUM(path)]);
930                 break;
931
932         case Qlocal:
933                 localread(r);
934                 break;
935
936         case Qremote:
937                 remoteread(r, client[NUM(path)]);
938                 break;
939
940         case Qstatus:
941                 statusread(r, client[NUM(path)]);
942                 break;
943         }
944 }
945
946 static void
947 fswrite(Req *r)
948 {
949         ulong path;
950         char e[ERRMAX];
951
952         path = r->fid->qid.path;
953         switch(TYPE(path)){
954         default:
955                 snprint(e, sizeof e, "bug in fswrite path=%lux", path);
956                 respond(r, e);
957                 break;
958
959         case Qcs:
960                 cswrite(r);
961                 break;
962
963         case Qctl:
964                 ctlwrite(r, client[NUM(path)]);
965                 break;
966
967         case Qdata:
968                 datawrite(r, client[NUM(path)]);
969                 break;
970         }
971 }
972
973 static void
974 fsopen(Req *r)
975 {
976         static int need[4] = { 4, 2, 6, 1 };
977         ulong path;
978         int n;
979         Tab *t;
980         Cs *cs;
981
982         /*
983          * lib9p already handles the blatantly obvious.
984          * we just have to enforce the permissions we have set.
985          */
986         path = r->fid->qid.path;
987         t = &tab[TYPE(path)];
988         n = need[r->ifcall.mode&3];
989         if((n&t->mode) != n){
990                 respond(r, "permission denied");
991                 return;
992         }
993
994         switch(TYPE(path)){
995         case Qcs:
996                 cs = emalloc9p(sizeof(Cs));
997                 r->fid->aux = cs;
998                 respond(r, nil);
999                 break;
1000         case Qclone:
1001                 n = newclient();
1002                 path = PATH(Qctl, n);
1003                 r->fid->qid.path = path;
1004                 r->ofcall.qid.path = path;
1005                 if(chatty9p)
1006                         fprint(2, "open clone => path=%lux\n", path);
1007                 t = &tab[Qctl];
1008                 /* fall through */
1009         default:
1010                 if(t-tab >= Qn)
1011                         client[NUM(path)]->ref++;
1012                 respond(r, nil);
1013                 break;
1014         }
1015 }
1016
1017 static void
1018 fsflush(Req *r)
1019 {
1020         int i;
1021
1022         for(i=0; i<nclient; i++)
1023                 if(findreq(client[i], r->oldreq))
1024                         respond(r->oldreq, "interrupted");
1025         respond(r, nil);
1026 }
1027
1028 static void
1029 handlemsg(Msg *m)
1030 {
1031         int chan, win, pkt, n;
1032         Client *c;
1033         char *s;
1034
1035         switch(m->rp[0]){
1036         case MSG_CHANNEL_WINDOW_ADJUST:
1037                 if(unpack(m, "_uu", &chan, &n) < 0)
1038                         break;
1039                 c = getclient(chan);
1040                 if(c != nil && c->state==Established){
1041                         c->sendwin += n;
1042                         procwreqs(c);
1043                 }
1044                 break;
1045         case MSG_CHANNEL_DATA:
1046                 if(unpack(m, "_us", &chan, &s, &n) < 0)
1047                         break;
1048                 c = getclient(chan);
1049                 if(c != nil && c->state==Established){
1050                         c->recvwin -= n;
1051                         m->rp = (uchar*)s;
1052                         queuermsg(c, m);
1053                         matchrmsgs(c);
1054                         return;
1055                 }
1056                 break;
1057         case MSG_CHANNEL_EOF:
1058                 if(unpack(m, "_u", &chan) < 0)
1059                         break;
1060                 c = getclient(chan);
1061                 if(c != nil){
1062                         hangupclient(c);
1063                         m->rp = m->wp = m->buf;
1064                         sendmsg(pack(m, "bu", MSG_CHANNEL_CLOSE, c->servernum));
1065                         return;
1066                 }
1067                 break;
1068         case MSG_CHANNEL_CLOSE:
1069                 if(unpack(m, "_u", &chan) < 0)
1070                         break;
1071                 c = getclient(chan);
1072                 if(c != nil)
1073                         hangupclient(c);
1074                 break;
1075         case MSG_CHANNEL_OPEN_CONFIRMATION:
1076                 if(unpack(m, "_uuuu", &chan, &n, &win, &pkt) < 0)
1077                         break;
1078                 c = getclient(chan);
1079                 if(c == nil || c->state != Dialing)
1080                         break;
1081                 if(pkt <= 0 || pkt > MaxPacket)
1082                         pkt = MaxPacket;
1083                 c->sendpkt = pkt;
1084                 c->sendwin = win;
1085                 c->servernum = n;
1086                 c->state = Established;
1087                 dialedclient(c);
1088                 break;
1089         case MSG_CHANNEL_OPEN_FAILURE:
1090                 if(unpack(m, "_uu", &chan, &n) < 0)
1091                         break;
1092                 c = getclient(chan);
1093                 if(c == nil || c->state != Dialing)
1094                         break;
1095                 c->servernum = n;
1096                 c->state = Closed;
1097                 dialedclient(c);
1098                 break;
1099         }
1100         free(m);
1101 }
1102
1103 void
1104 fsnetproc(void*)
1105 {
1106         ulong path;
1107         Alt a[4];
1108         Cs *cs;
1109         Fid *fid;
1110         Req *r;
1111         Msg *m;
1112
1113         threadsetname("fsthread");
1114
1115         a[0].op = CHANRCV;
1116         a[0].c = fsclunkchan;
1117         a[0].v = &fid;
1118         a[1].op = CHANRCV;
1119         a[1].c = fsreqchan;
1120         a[1].v = &r;
1121         a[2].op = CHANRCV;
1122         a[2].c = sshmsgchan;
1123         a[2].v = &m;
1124         a[3].op = CHANEND;
1125
1126         for(;;){
1127                 switch(alt(a)){
1128                 case 0:
1129                         path = fid->qid.path;
1130                         switch(TYPE(path)){
1131                         case Qcs:
1132                                 cs = fid->aux;
1133                                 if(cs){
1134                                         free(cs->resp);
1135                                         free(cs);
1136                                 }
1137                                 break;
1138                         }
1139                         if(fid->omode != -1 && TYPE(path) >= Qn)
1140                                 closeclient(client[NUM(path)]);
1141                         sendp(fsclunkwaitchan, nil);
1142                         break;
1143                 case 1:
1144                         switch(r->ifcall.type){
1145                         case Tattach:
1146                                 fsattach(r);
1147                                 break;
1148                         case Topen:
1149                                 fsopen(r);
1150                                 break;
1151                         case Tread:
1152                                 fsread(r);
1153                                 break;
1154                         case Twrite:
1155                                 fswrite(r);
1156                                 break;
1157                         case Tstat:
1158                                 fsstat(r);
1159                                 break;
1160                         case Tflush:
1161                                 fsflush(r);
1162                                 break;
1163                         default:
1164                                 respond(r, "bug in fsthread");
1165                                 break;
1166                         }
1167                         sendp(fsreqwaitchan, 0);
1168                         break;
1169                 case 2:
1170                         handlemsg(m);
1171                         break;
1172                 }
1173         }
1174 }
1175
1176 static void
1177 fssend(Req *r)
1178 {
1179         sendp(fsreqchan, r);
1180         recvp(fsreqwaitchan);   /* avoids need to deal with spurious flushes */
1181 }
1182
1183 static void
1184 fsdestroyfid(Fid *fid)
1185 {
1186         sendp(fsclunkchan, fid);
1187         recvp(fsclunkwaitchan);
1188 }
1189
1190 void
1191 takedown(Srv*)
1192 {
1193         threadexitsall("done");
1194 }
1195
1196 Srv fs = 
1197 {
1198 .attach=                fssend,
1199 .destroyfid=    fsdestroyfid,
1200 .walk1=         fswalk1,
1201 .open=          fssend,
1202 .read=          fssend,
1203 .write=         fssend,
1204 .stat=          fssend,
1205 .flush=         fssend,
1206 .end=           takedown,
1207 };
1208
/* pipe shared with the ssh child: startssh closes pfd[0] and dups pfd[1] onto fds 0 and 1 */
int pfd[2];

/* argument vector handed to procexec in startssh, and its element count */
char **sshargv;
int sshargc;
1212
1213 void
1214 startssh(void *)
1215 {
1216         char *f;
1217
1218         close(pfd[0]);
1219         dup(pfd[1], 0);
1220         dup(pfd[1], 1);
1221         close(pfd[1]);
1222         if(strncmp(sshargv[0], "./", 2) != 0)
1223                 f = smprint("/bin/%s", sshargv[0]);
1224         else
1225                 f = sshargv[0];
1226         procexec(nil, f, sshargv);
1227         sysfatal("exec: %r");
1228 }
1229
/*
 * Print the command synopsis on file descriptor 2 and exit with
 * status "usage".  The synopsis lists every flag that threadmain's
 * option parser accepts (-D, -m, -s).
 */
void
usage(void)
{
	fprint(2, "usage: sshnet [-D] [-m mtpt] [-s service] [ssh options]\n");
	exits("usage");
}
1236
1237 void
1238 threadmain(int argc, char **argv)
1239 {
1240         char *service;
1241
1242         fmtinstall('H', encodefmt);
1243
1244         mtpt = "/net";
1245         service = nil;
1246         ARGBEGIN{
1247         case 'D':
1248                 chatty9p++;
1249                 break;
1250         case 'm':
1251                 mtpt = EARGF(usage());
1252                 break;
1253         case 's':
1254                 service = EARGF(usage());
1255                 break;
1256         default:
1257                 usage();
1258         }ARGEND
1259
1260         if(argc == 0)
1261                 usage();
1262         
1263         sshargc = argc + 2;
1264         sshargv = emalloc9p(sizeof(char *) * (sshargc + 1));
1265         sshargv[0] = "ssh";
1266         sshargv[1] = "-X";
1267         memcpy(sshargv + 2, argv, argc * sizeof(char *));
1268
1269         pipe(pfd);
1270         sshfd = pfd[0];
1271         procrfork(startssh, nil, mainstacksize, RFFDG|RFNOTEG|RFNAMEG);
1272         close(pfd[1]);
1273
1274         time0 = time(0);
1275         sshmsgchan = chancreate(sizeof(Msg*), 16);
1276         fsreqchan = chancreate(sizeof(Req*), 0);
1277         fsreqwaitchan = chancreate(sizeof(void*), 0);
1278         fsclunkchan = chancreate(sizeof(Fid*), 0);
1279         fsclunkwaitchan = chancreate(sizeof(void*), 0);
1280
1281         procrfork(sshreadproc, nil, mainstacksize, RFNAMEG|RFNOTEG);
1282         procrfork(fsnetproc, nil, mainstacksize, RFNAMEG|RFNOTEG);
1283
1284         threadpostmountsrv(&fs, service, mtpt, MREPL);
1285         exits(0);
1286 }