]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/sshnet.c
merge
[plan9front.git] / sys / src / cmd / sshnet.c
1 /*
2  * SSH network file system.
3  * Presents remote TCP stack as /net-style file system.
4  */
5
6 #include <u.h>
7 #include <libc.h>
8 #include <bio.h>
9 #include <ndb.h>
10 #include <thread.h>
11 #include <fcall.h>
12 #include <9p.h>
13
14 typedef struct Client Client;
15 typedef struct Msg Msg;
16
/*
 * File types.  A qid path keeps the type in its low 8 bits (TYPE) and
 * fillstat uses it as an index into tab[].  Qn is the per-client
 * directory tcp/N; the types after it (Qctl..Qlisten) are the files
 * inside that directory (see clientgen).
 */
17 enum
18 {
19         Qroot,
20         Qcs,
21         Qtcp,
22         Qclone,
23         Qn,
24         Qctl,
25         Qdata,
26         Qlocal,
27         Qremote,
28         Qstatus,
29         Qlisten,
30 };
31
/* A qid path packs the file type (low 8 bits) and, for per-client files, the client number above it. */
32 #define PATH(type, n)           ((type)|((n)<<8))
33 #define TYPE(path)              ((int)(path) & 0xFF)
34 #define NUM(path)               ((uint)(path)>>8)
35
36 int sessionopen = 0;            /* incremented when the server confirms SESSIONCHAN */
37 Channel *sshmsgchan;            /* chan(Msg*) */
38 Channel *fsreqchan;             /* chan(Req*) */
39 Channel *fsreqwaitchan;         /* chan(nil) */
40 Channel *fsclunkchan;           /* chan(Fid*) */
41 Channel *fsclunkwaitchan;       /* chan(nil) */
42 ulong time0;                    /* atime/mtime reported for every file (fillstat) */
43
/*
 * Client states:
 *	Closed		free; newclient reuses a Closed slot once ref == 0
 *	Dialing		"connect" sent CHANNEL_OPEN, waiting for the server's answer
 *	Listen		"announce" sent a tcpip-forward request
 *	Established	the server confirmed the channel
 *	Teardown	we sent CHANNEL_CLOSE; the server's CLOSE moves us to Closed
 *	Finished	the server closed an Established channel; closeclient
 *			answers with our CLOSE when the last reference goes
 */
44 enum
45 {
46         Closed,
47         Dialing,
48         Listen,
49         Established,
50         Teardown,
51         Finished,
52 };
53
/* Indexed by Client.state; this is what tcp/N/status reads back. */
54 char *statestr[] = {
55         "Closed",
56         "Dialing",
57         "Listen",
58         "Established",
59         "Teardown",
60         "Finished",
61 };
62
/* One TCP connection (or announcement) carried over an SSH channel. */
63 struct Client
64 {
65         int ref;                /* open fids on this client's files (fsopen / closeclient) */
66         int state;              /* Closed .. Finished */
67         int num;                /* index in client[]; also our channel number on the wire */
68         int servernum;          /* server's channel number: recipient of what we send */
69
70         int rport, lport;
71         char *rhost;            /* remote end, shown by tcp/N/remote */
72         char *lhost;            /* local end, shown by tcp/N/local; "" means any */
73
74         int sendpkt;            /* server's max packet size, clamped to MaxPacket */
75         int sendwin;            /* bytes the server will still accept from us */
76         int recvwin;            /* our receive window; CHANNEL_DATA is dropped when <= 0 */
77         int recvacc;            /* bytes consumed by readers, not yet announced */
78
79         int eof;                /* no more data will arrive */
80
81         Req *wq;                /* queued writes, a pending connect, or pending listen opens (linked via Req.aux) */
82         Req **ewq;              /* tail link of wq */
83
84         Req *rq;                /* pending data reads (linked via Req.aux) */
85         Req **erq;              /* tail link of rq */
86
87         Msg *mq;                /* received data not yet read (linked via Msg.link) */
88         Msg **emq;              /* tail link of mq */
89 };
90
/* SSH connection-protocol message numbers (RFC 4254) and buffer sizing. */
91 enum {
92         MSG_GLOBAL_REQUEST = 80,
93
94         MSG_CHANNEL_OPEN = 90,
95         MSG_CHANNEL_OPEN_CONFIRMATION,
96         MSG_CHANNEL_OPEN_FAILURE,
97         MSG_CHANNEL_WINDOW_ADJUST,
98         MSG_CHANNEL_DATA,
99         MSG_CHANNEL_EXTENDED_DATA,
100         MSG_CHANNEL_EOF,
101         MSG_CHANNEL_CLOSE,
102         MSG_CHANNEL_REQUEST,
103         MSG_CHANNEL_SUCCESS,
104         MSG_CHANNEL_FAILURE,
105
106         Overhead = 256,         /* room in Msg.buf beyond one packet; presumably for message headers */
107         MaxPacket = (1<<15)-256,        /* 32K is maxatomic for pipe */
108         WinPackets = 8,         /* initial receive window is WinPackets*MaxPacket (ctlwrite) */
109
110         SESSIONCHAN = 1<<24,    /* our channel number for the session channel, checked before any client lookup */
111 };
112
/* One message read from or written to sshfd. */
113 struct Msg
114 {
115         Msg     *link;          /* next in a Client's mq */
116
117         uchar   *rp;            /* next unread byte */
118         uchar   *wp;            /* end of valid data */
119         uchar   *ep;            /* end of buf */
120         uchar   buf[MaxPacket + Overhead];
121 };
122
/*
 * Big-endian 32-bit store/load, the byte order of SSH wire integers.
 * Both expansions are fully parenthesized so they are safe inside larger
 * expressions: the old GET4 expanded to a bare chain of `|`, so
 * `GET4(p) == x` or `GET4(p) + 1` bound the operator to the last term only.
 * Arguments are still evaluated more than once; pass side-effect-free ones.
 */
#define PUT4(p, u) ((p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u))
#define GET4(p) ((u32int)(p)[3] | (u32int)(p)[2]<<8 | (u32int)(p)[1]<<16 | (u32int)(p)[0]<<24)
125
/*
 * client[0..nclient) holds every Client ever allocated.  newclient grows
 * the array in steps of 16 and reuses slots that are Closed and unreferenced.
 */
126 int nclient;
127 Client **client;
128 char *mtpt, *service;   /* mtpt prefixes cs answers (cswrite); service is not used in the code visible here */
129 int sshfd;              /* ssh message stream: read by sshreadproc, written by sendmsg */
130
131 int
132 vpack(uchar *p, int n, char *fmt, va_list a)
133 {
134         uchar *p0 = p, *e = p+n;
135         u32int u;
136         void *s;
137         int c;
138
139         for(;;){
140                 switch(c = *fmt++){
141                 case '\0':
142                         return p - p0;
143                 case '_':
144                         if(++p > e) goto err;
145                         break;
146                 case '.':
147                         *va_arg(a, void**) = p;
148                         break;
149                 case 'b':
150                         if(p >= e) goto err;
151                         *p++ = va_arg(a, int);
152                         break;
153                 case '[':
154                 case 's':
155                         s = va_arg(a, void*);
156                         u = va_arg(a, int);
157                         if(c == 's'){
158                                 if(p+4 > e) goto err;
159                                 PUT4(p, u), p += 4;
160                         }
161                         if(u > e-p) goto err;
162                         memmove(p, s, u);
163                         p += u;
164                         break;
165                 case 'u':
166                         u = va_arg(a, int);
167                         if(p+4 > e) goto err;
168                         PUT4(p, u), p += 4;
169                         break;
170                 }
171         }
172 err:
173         return -1;
174 }
175
176 int
177 vunpack(uchar *p, int n, char *fmt, va_list a)
178 {
179         uchar *p0 = p, *e = p+n;
180         u32int u;
181         void *s;
182
183         for(;;){
184                 switch(*fmt++){
185                 case '\0':
186                         return p - p0;
187                 case '_':
188                         if(++p > e) goto err;
189                         break;
190                 case '.':
191                         *va_arg(a, void**) = p;
192                         break;
193                 case 'b':
194                         if(p >= e) goto err;
195                         *va_arg(a, int*) = *p++;
196                         break;
197                 case 's':
198                         if(p+4 > e) goto err;
199                         u = GET4(p), p += 4;
200                         if(u > e-p) goto err;
201                         *va_arg(a, void**) = p;
202                         *va_arg(a, int*) = u;
203                         p += u;
204                         break;
205                 case '[':
206                         s = va_arg(a, void*);
207                         u = va_arg(a, int);
208                         if(u > e-p) goto err;
209                         memmove(s, p, u);
210                         p += u;
211                         break;
212                 case 'u':
213                         if(p+4 > e) goto err;
214                         u = GET4(p);
215                         *va_arg(a, int*) = u;
216                         p += 4;
217                         break;
218                 }
219         }
220 err:
221         return -1;
222 }
223
224 Msg*
225 allocmsg(void)
226 {
227         Msg *m;
228
229         m = emalloc9p(sizeof(Msg));
230         m->link = nil;
231         m->rp = m->wp = m->buf;
232         m->ep = m->rp + sizeof(m->buf);
233         return m;
234 }
235
236 Msg*
237 pack(Msg *m, char *fmt, ...)
238 {
239         va_list a;
240         int n;
241
242         if(m == nil)
243                 m = allocmsg();
244         va_start(a, fmt);
245         n = vpack(m->wp, m->ep - m->wp, fmt, a);
246         if(n < 0)
247                 sysfatal("pack faild");
248         m->wp += n;
249         va_end(a);
250         return m;
251 }
252
253 int
254 unpack(Msg *m, char *fmt, ...)
255 {
256         va_list a;
257         int n;
258
259         va_start(a, fmt);
260         n = vunpack(m->rp, m->wp - m->rp, fmt, a);
261         if(n > 0)
262                 m->rp += n;
263         va_end(a);
264         return n;
265 }
266
267 void
268 sendmsg(Msg *m)
269 {
270         int n;
271
272         if(m == nil)
273                 return;
274         n = m->wp - m->rp;
275         if(n > 0){
276                 if(write(sshfd, m->rp, n) != n)
277                         sysfatal("write to ssh failed: %r");
278         }
279         free(m);
280 }
281
282 int
283 newclient(void)
284 {
285         int i;
286         Client *c;
287
288         for(i=0; i<nclient; i++)
289                 if(client[i]->ref==0 && client[i]->state == Closed)
290                         return i;
291
292         if(nclient%16 == 0)
293                 client = erealloc9p(client, (nclient+16)*sizeof(client[0]));
294
295         c = emalloc9p(sizeof(Client));
296         memset(c, 0, sizeof(*c));
297         c->num = nclient;
298         client[nclient++] = c;
299         return c->num;
300 }
301
302 Client*
303 getclient(int num)
304 {
305         if(num < 0 || num >= nclient)
306                 return nil;
307         return client[num];
308 }
309
310 Client*
311 acceptclient(char *lhost, int lport, char *rhost, int rport)
312 {
313         Client *c, *nc;
314         int i;
315
316         for(i = 0; i < nclient; i++){
317                 c = client[i];
318                 if(c->state == Listen && c->lport == lport && c->wq != nil){
319                         nc = client[newclient()];
320                         nc->wq = c->wq;
321                         c->wq = nc->wq->aux;
322                         nc->wq->aux = nil;
323                         free(nc->lhost);
324                         nc->lhost = lhost;
325                         nc->lport = lport;
326                         free(nc->rhost);
327                         nc->rhost = rhost;
328                         nc->rport = rport;
329                         return nc;
330                 }
331         }
332         return nil;
333 }
334
335 void
336 adjustwin(Client *c, int len)
337 {
338         c->recvacc += len;
339         if(c->recvacc >= MaxPacket*WinPackets/2 || c->recvwin < MaxPacket){
340                 sendmsg(pack(nil, "buu", MSG_CHANNEL_WINDOW_ADJUST, c->servernum, c->recvacc));
341                 c->recvacc = 0;
342         }
343         c->recvwin += len;
344 }
345
346 void
347 senddata(Client *c, void *data, int len)
348 {
349         sendmsg(pack(nil, "bus", MSG_CHANNEL_DATA, c->servernum, (char*)data, len));
350         c->sendwin -= len;
351 }
352
353 void
354 queuerreq(Client *c, Req *r)
355 {
356         if(c->rq==nil)
357                 c->erq = &c->rq;
358         *c->erq = r;
359         r->aux = nil;
360         c->erq = (Req**)&r->aux;
361 }
362
363 void
364 queuermsg(Client *c, Msg *m)
365 {
366         if(c->mq==nil)
367                 c->emq = &c->mq;
368         *c->emq = m;
369         m->link = nil;
370         c->emq = (Msg**)&m->link;
371 }
372
373 void
374 matchrmsgs(Client *c)
375 {
376         Req *r;
377         Msg *m;
378         int n, rm;
379
380         while((r = c->rq) != nil && (m = c->mq) != nil){
381                 c->rq = r->aux;
382                 r->aux = nil;
383                 rm = 0;
384                 n = r->ifcall.count;
385                 if(n >= m->wp - m->rp){
386                         n = m->wp - m->rp;
387                         c->mq = m->link;
388                         rm = 1;
389                 }
390                 memmove(r->ofcall.data, m->rp, n);
391                 if(rm)
392                         free(m);
393                 else
394                         m->rp += n;
395                 r->ofcall.count = n;
396                 respond(r, nil);
397                 adjustwin(c, n);
398         }
399
400         if(c->eof){
401                 while((r = c->rq) != nil){
402                         c->rq = r->aux;
403                         r->aux = nil;
404                         r->ofcall.count = 0;
405                         respond(r, nil);
406                 }
407         }
408 }
409
410 void
411 queuewreq(Client *c, Req *r)
412 {
413         if(c->wq==nil)
414                 c->ewq = &c->wq;
415         *c->ewq = r;
416         r->aux = nil;
417         c->ewq = (Req**)&r->aux;
418 }
419
420 void
421 procwreqs(Client *c)
422 {
423         Req *r;
424         int n;
425
426         while((r = c->wq) != nil && (n = c->sendwin) > 0){
427                 if(n > c->sendpkt)
428                         n = c->sendpkt;
429                 if(r->ifcall.count > n){
430                         senddata(c, r->ifcall.data, n);
431                         r->ifcall.count -= n;
432                         memmove(r->ifcall.data, (char*)r->ifcall.data + n, r->ifcall.count);
433                         continue;
434                 }
435                 c->wq = (Req*)r->aux;
436                 r->aux = nil;
437                 senddata(c, r->ifcall.data, r->ifcall.count);
438                 r->ofcall.count = r->ifcall.count;
439                 respond(r, nil);
440         }
441 }
442
443 Req*
444 findreq(Client *c, Req *r)
445 {
446         Req **l;
447
448         for(l=&c->rq; *l; l=(Req**)&(*l)->aux){
449                 if(*l == r){
450                         *l = r->aux;
451                         if(*l == nil)
452                                 c->erq = l;
453                         return r;
454                 }
455         }
456         for(l=&c->wq; *l; l=(Req**)&(*l)->aux){
457                 if(*l == r){
458                         *l = r->aux;
459                         if(*l == nil)
460                                 c->ewq = l;
461                         return r;
462                 }
463         }
464         return nil;
465 }
466
467 void
468 hangupclient(Client *c, char *err)
469 {
470         Req *r;
471
472         c->eof = 1;
473         c->recvwin = 0;
474         c->sendwin = 0;
475         while((r = c->wq) != nil){
476                 c->wq = r->aux;
477                 r->aux = nil;
478                 respond(r, err);
479         }
480         matchrmsgs(c);
481 }
482
483 void
484 teardownclient(Client *c)
485 {
486         c->state = Teardown;
487         hangupclient(c, "i/o on hungup channel");
488         sendmsg(pack(nil, "bu", MSG_CHANNEL_CLOSE, c->servernum));
489 }
490
491 void
492 closeclient(Client *c)
493 {
494         Msg *m;
495
496         if(--c->ref)
497                 return;
498         switch(c->state){
499         case Established:
500                 teardownclient(c);
501                 break;
502         case Finished:
503                 c->state = Closed;
504                 sendmsg(pack(nil, "bu", MSG_CHANNEL_CLOSE, c->servernum));
505                 break;
506         case Listen:
507                 c->state = Closed;
508                 sendmsg(pack(nil, "bsbsu", MSG_GLOBAL_REQUEST,
509                         "cancel-tcpip-forward", 20,
510                         0,
511                         c->lhost, strlen(c->lhost),
512                         c->lport));
513                 break;
514         }
515         while((m = c->mq) != nil){
516                 c->mq = m->link;
517                 free(m);
518         }
519 }
520
521 void
522 sshreadproc(void*)
523 {
524         Msg *m;
525         int n;
526
527         for(;;){
528                 m = allocmsg();
529                 n = read(sshfd, m->rp, m->ep - m->rp);
530                 if(n <= 0)
531                         sysfatal("eof on ssh connection");
532                 m->wp += n;
533                 sendp(sshmsgchan, m);
534         }
535 }
536
537 typedef struct Tab Tab;
/* Name and permission bits of one kind of file in the tree. */
538 struct Tab
539 {
540         char *name;
541         ulong mode;
542 };
543
/*
 * Indexed by file type (the Q* enum): fillstat reads tab[TYPE(path)] and
 * fswalk1 scans it in order.  The nil entry is the per-client directory
 * (Qn), named by the client number.  The high byte of mode (mode>>24,
 * e.g. from DMDIR) becomes the qid type.
 */
544 Tab tab[] =
545 {
546         "/",            DMDIR|0555,
547         "cs",           0666,
548         "tcp",          DMDIR|0555,     
549         "clone",        0666,
550         nil,            DMDIR|0555,
551         "ctl",          0666,
552         "data",         0666,
553         "local",        0444,
554         "remote",       0444,
555         "status",       0444,
556         "listen",       0666,
557 };
558
559 static void
560 fillstat(Dir *d, uvlong path)
561 {
562         Tab *t;
563
564         memset(d, 0, sizeof(*d));
565         d->uid = estrdup9p("ssh");
566         d->gid = estrdup9p("ssh");
567         d->qid.path = path;
568         d->atime = d->mtime = time0;
569         t = &tab[TYPE(path)];
570         if(t->name)
571                 d->name = estrdup9p(t->name);
572         else{
573                 d->name = smprint("%ud", NUM(path));
574                 if(d->name == nil)
575                         sysfatal("out of memory");
576         }
577         d->qid.type = t->mode>>24;
578         d->mode = t->mode;
579 }
580
581 static void
582 fsattach(Req *r)
583 {
584         if(r->ifcall.aname && r->ifcall.aname[0]){
585                 respond(r, "invalid attach specifier");
586                 return;
587         }
588         r->fid->qid.path = PATH(Qroot, 0);
589         r->fid->qid.type = QTDIR;
590         r->fid->qid.vers = 0;
591         r->ofcall.qid = r->fid->qid;
592         respond(r, nil);
593 }
594
595 static void
596 fsstat(Req *r)
597 {
598         fillstat(&r->d, r->fid->qid.path);
599         respond(r, nil);
600 }
601
602 static int
603 rootgen(int i, Dir *d, void*)
604 {
605         i += Qroot+1;
606         if(i <= Qtcp){
607                 fillstat(d, i);
608                 return 0;
609         }
610         return -1;
611 }
612
613 static int
614 tcpgen(int i, Dir *d, void*)
615 {
616         i += Qtcp+1;
617         if(i < Qn){
618                 fillstat(d, i);
619                 return 0;
620         }
621         i -= Qn;
622         if(i < nclient){
623                 fillstat(d, PATH(Qn, i));
624                 return 0;
625         }
626         return -1;
627 }
628
629 static int
630 clientgen(int i, Dir *d, void *aux)
631 {
632         Client *c;
633
634         c = aux;
635         i += Qn+1;
636         if(i <= Qstatus){
637                 fillstat(d, PATH(i, c->num));
638                 return 0;
639         }
640         return -1;
641 }
642
643 static char*
644 fswalk1(Fid *fid, char *name, Qid *qid)
645 {
646         int i, n;
647         char buf[32];
648         ulong path;
649
650         path = fid->qid.path;
651         if(!(fid->qid.type&QTDIR))
652                 return "walk in non-directory";
653
654         if(strcmp(name, "..") == 0){
655                 switch(TYPE(path)){
656                 case Qn:
657                         qid->path = PATH(Qtcp, NUM(path));
658                         qid->type = tab[Qtcp].mode>>24;
659                         return nil;
660                 case Qtcp:
661                         qid->path = PATH(Qroot, 0);
662                         qid->type = tab[Qroot].mode>>24;
663                         return nil;
664                 case Qroot:
665                         return nil;
666                 default:
667                         return "bug in fswalk1";
668                 }
669         }
670
671         i = TYPE(path)+1;
672         for(; i<nelem(tab); i++){
673                 if(i==Qn){
674                         n = atoi(name);
675                         snprint(buf, sizeof buf, "%d", n);
676                         if(n < nclient && strcmp(buf, name) == 0){
677                                 qid->path = PATH(i, n);
678                                 qid->type = tab[i].mode>>24;
679                                 return nil;
680                         }
681                         break;
682                 }
683                 if(strcmp(name, tab[i].name) == 0){
684                         qid->path = PATH(i, NUM(path));
685                         qid->type = tab[i].mode>>24;
686                         return nil;
687                 }
688                 if(tab[i].mode&DMDIR)
689                         break;
690         }
691         return "directory entry not found";
692 }
693
694 typedef struct Cs Cs;
/* Per-fid state of an open cs file; fsopen allocates it into fid->aux. */
695 struct Cs
696 {
697         char *resp;     /* answer to the last cswrite, or nil before any write */
698         int isnew;      /* set by cswrite; cleared when csread serves offset 0 */
699 };
700
701 static int
702 ndbfindport(char *p)
703 {
704         char *s, *port;
705         int n;
706         static Ndb *db;
707
708         if(*p == '\0')
709                 return -1;
710
711         n = strtol(p, &s, 0);
712         if(*s == '\0')
713                 return n;
714
715         if(db == nil){
716                 db = ndbopen("/lib/ndb/common");
717                 if(db == nil)
718                         return -1;
719         }
720
721         port = ndbgetvalue(db, nil, "tcp", p, "port", nil);
722         if(port == nil)
723                 return -1;
724         n = atoi(port);
725         free(port);
726
727         return n;
728 }       
729
730 static void
731 csread(Req *r)
732 {
733         Cs *cs;
734
735         cs = r->fid->aux;
736         if(cs->resp==nil){
737                 respond(r, "cs read without write");
738                 return;
739         }
740         if(r->ifcall.offset==0){
741                 if(!cs->isnew){
742                         r->ofcall.count = 0;
743                         respond(r, nil);
744                         return;
745                 }
746                 cs->isnew = 0;
747         }
748         readstr(r, cs->resp);
749         respond(r, nil);
750 }
751
752 static void
753 cswrite(Req *r)
754 {
755         int port, nf;
756         char err[ERRMAX], *f[4], *s, *ns;
757         Cs *cs;
758
759         cs = r->fid->aux;
760         s = emalloc9p(r->ifcall.count+1);
761         memmove(s, r->ifcall.data, r->ifcall.count);
762         s[r->ifcall.count] = '\0';
763
764         nf = getfields(s, f, nelem(f), 0, "!");
765         if(nf != 3){
766                 free(s);
767                 respond(r, "can't translate");
768                 return;
769         }
770         if(strcmp(f[0], "tcp") != 0 && strcmp(f[0], "net") != 0){
771                 free(s);
772                 respond(r, "unknown protocol");
773                 return;
774         }
775         port = ndbfindport(f[2]);
776         if(port < 0){
777                 free(s);
778                 respond(r, "no translation found");
779                 return;
780         }
781
782         ns = smprint("%s/tcp/clone %s!%d", mtpt, f[1], port);
783         if(ns == nil){
784                 free(s);
785                 rerrstr(err, sizeof err);
786                 respond(r, err);
787                 return;
788         }
789         free(s);
790         free(cs->resp);
791         cs->resp = ns;
792         cs->isnew = 1;
793         r->ofcall.count = r->ifcall.count;
794         respond(r, nil);
795 }
796
797 static void
798 ctlread(Req *r, Client *c)
799 {
800         char buf[32];
801
802         sprint(buf, "%d", c->num);
803         readstr(r, buf);
804         respond(r, nil);
805 }
806
807 static void
808 ctlwrite(Req *r, Client *c)
809 {
810         char *f[3], *s;
811         int nf, port;
812
813         s = emalloc9p(r->ifcall.count+1);
814         r->ofcall.count = r->ifcall.count;
815         memmove(s, r->ifcall.data, r->ifcall.count);
816         s[r->ifcall.count] = '\0';
817
818         nf = tokenize(s, f, 3);
819         if(nf == 0){
820                 free(s);
821                 respond(r, nil);
822                 return;
823         }
824
825         if(strcmp(f[0], "hangup") == 0){
826                 if(c->state != Established)
827                         goto Badarg;
828                 if(nf != 1)
829                         goto Badarg;
830                 teardownclient(c);
831                 respond(r, nil);
832         }else if(strcmp(f[0], "connect") == 0){
833                 if(nf != 2 || c->state != Closed)
834                         goto Badarg;
835                 if(getfields(f[1], f, nelem(f), 0, "!") != 2)
836                         goto Badarg;
837                 if((port = ndbfindport(f[1])) < 0)
838                         goto Badarg;
839                 free(c->lhost);
840                 c->lhost = estrdup9p("::");
841                 c->lport = 0;
842                 free(c->rhost);
843                 c->rhost = estrdup9p(f[0]);
844                 c->rport = port;
845                 c->recvwin = WinPackets*MaxPacket;
846                 c->recvacc = 0;
847                 c->state = Dialing;
848                 queuewreq(c, r);
849
850                 sendmsg(pack(nil, "bsuuususu", MSG_CHANNEL_OPEN,
851                         "direct-tcpip", 12,
852                         c->num, c->recvwin, MaxPacket,
853                         c->rhost, strlen(c->rhost), c->rport,
854                         c->lhost, strlen(c->lhost), c->lport));
855         }else if(strcmp(f[0], "announce") == 0){
856                 if(nf != 2 || c->state != Closed)
857                         goto Badarg;
858                 if(getfields(f[1], f, nelem(f), 0, "!") != 2)
859                         goto Badarg;
860                 if((port = ndbfindport(f[1])) < 0)
861                         goto Badarg;
862                 if(strcmp(f[0], "*") == 0)
863                         f[0] = "";
864                 free(c->lhost);
865                 c->lhost = estrdup9p(f[0]);
866                 c->lport = port;
867                 free(c->rhost);
868                 c->rhost = estrdup9p("::");
869                 c->rport = 0;
870                 c->state = Listen;
871                 sendmsg(pack(nil, "bsbsu", MSG_GLOBAL_REQUEST,
872                         "tcpip-forward", 13, 0,
873                         c->lhost, strlen(c->lhost), c->lport));
874                 respond(r, nil);
875         }else{
876         Badarg:
877                 respond(r, "bad or inappropriate tcp control message");
878         }
879         free(s);
880 }
881
882 static void
883 dataread(Req *r, Client *c)
884 {
885         if(c->state < Established){
886                 respond(r, "not connected");
887                 return;
888         }
889         queuerreq(c, r);
890         matchrmsgs(c);
891 }
892
893 static void
894 datawrite(Req *r, Client *c)
895 {
896         if(c->state != Established){
897                 respond(r, "not connected");
898                 return;
899         }
900         if(r->ifcall.count == 0){
901                 r->ofcall.count = r->ifcall.count;
902                 respond(r, nil);
903                 return;
904         }
905         queuewreq(c, r);
906         procwreqs(c);
907 }
908
909 static void
910 localread(Req *r, Client *c)
911 {
912         char buf[128], *s;
913
914         s = c->lhost;
915         if(s == nil)
916                 s = "::";
917         else if(*s == 0)
918                 s = "*";
919         snprint(buf, sizeof buf, "%s!%d\n", s, c->lport);
920         readstr(r, buf);
921         respond(r, nil);
922 }
923
924 static void
925 remoteread(Req *r, Client *c)
926 {
927         char buf[128], *s;
928
929         s = c->rhost;
930         if(s == nil)
931                 s = "::";
932         snprint(buf, sizeof buf, "%s!%d\n", s, c->rport);
933         readstr(r, buf);
934         respond(r, nil);
935 }
936
937 static void
938 statusread(Req *r, Client *c)
939 {
940         char *s;
941
942         s = statestr[c->state];
943         readstr(r, s);
944         respond(r, nil);
945 }
946
947 static void
948 fsread(Req *r)
949 {
950         char e[ERRMAX];
951         ulong path;
952
953         path = r->fid->qid.path;
954         switch(TYPE(path)){
955         default:
956                 snprint(e, sizeof e, "bug in fsread path=%lux", path);
957                 respond(r, e);
958                 break;
959
960         case Qroot:
961                 dirread9p(r, rootgen, nil);
962                 respond(r, nil);
963                 break;
964
965         case Qcs:
966                 csread(r);
967                 break;
968
969         case Qtcp:
970                 dirread9p(r, tcpgen, nil);
971                 respond(r, nil);
972                 break;
973
974         case Qn:
975                 dirread9p(r, clientgen, client[NUM(path)]);
976                 respond(r, nil);
977                 break;
978
979         case Qctl:
980                 ctlread(r, client[NUM(path)]);
981                 break;
982
983         case Qdata:
984                 dataread(r, client[NUM(path)]);
985                 break;
986
987         case Qlocal:
988                 localread(r, client[NUM(path)]);
989                 break;
990
991         case Qremote:
992                 remoteread(r, client[NUM(path)]);
993                 break;
994
995         case Qstatus:
996                 statusread(r, client[NUM(path)]);
997                 break;
998         }
999 }
1000
1001 static void
1002 fswrite(Req *r)
1003 {
1004         ulong path;
1005         char e[ERRMAX];
1006
1007         path = r->fid->qid.path;
1008         switch(TYPE(path)){
1009         default:
1010                 snprint(e, sizeof e, "bug in fswrite path=%lux", path);
1011                 respond(r, e);
1012                 break;
1013
1014         case Qcs:
1015                 cswrite(r);
1016                 break;
1017
1018         case Qctl:
1019                 ctlwrite(r, client[NUM(path)]);
1020                 break;
1021
1022         case Qdata:
1023                 datawrite(r, client[NUM(path)]);
1024                 break;
1025         }
1026 }
1027
1028 static void
1029 fsopen(Req *r)
1030 {
1031         static int need[4] = { 4, 2, 6, 1 };
1032         ulong path;
1033         int n;
1034         Tab *t;
1035         Cs *cs;
1036
1037         /*
1038          * lib9p already handles the blatantly obvious.
1039          * we just have to enforce the permissions we have set.
1040          */
1041         path = r->fid->qid.path;
1042         t = &tab[TYPE(path)];
1043         n = need[r->ifcall.mode&3];
1044         if((n&t->mode) != n){
1045                 respond(r, "permission denied");
1046                 return;
1047         }
1048
1049         switch(TYPE(path)){
1050         case Qcs:
1051                 cs = emalloc9p(sizeof(Cs));
1052                 r->fid->aux = cs;
1053                 respond(r, nil);
1054                 break;
1055         case Qlisten:
1056                 if(client[NUM(path)]->state != Listen){
1057                         respond(r, "no address set");
1058                         break;
1059                 }
1060                 queuewreq(client[NUM(path)], r);
1061                 break;
1062         case Qclone:
1063                 n = newclient();
1064                 path = PATH(Qctl, n);
1065                 r->fid->qid.path = path;
1066                 r->ofcall.qid.path = path;
1067                 if(chatty9p)
1068                         fprint(2, "open clone => path=%lux\n", path);
1069                 t = &tab[Qctl];
1070                 /* fall through */
1071         default:
1072                 if(t-tab >= Qn)
1073                         client[NUM(path)]->ref++;
1074                 respond(r, nil);
1075                 break;
1076         }
1077 }
1078
1079 static void
1080 fsflush(Req *r)
1081 {
1082         int i;
1083
1084         for(i=0; i<nclient; i++)
1085                 if(findreq(client[i], r->oldreq))
1086                         respond(r->oldreq, "interrupted");
1087         respond(r, nil);
1088 }
1089
1090 static void
1091 handlemsg(Msg *m)
1092 {
1093         int chan, win, pkt, lport, rport, n, ln, rn;
1094         char *s, *lhost, *rhost;
1095         Client *c;
1096
1097         switch(m->rp[0]){
1098         case MSG_CHANNEL_WINDOW_ADJUST:
1099                 if(unpack(m, "_uu", &chan, &n) < 0)
1100                         break;
1101                 c = getclient(chan);
1102                 if(c != nil && c->state == Established){
1103                         c->sendwin += n;
1104                         procwreqs(c);
1105                 }
1106                 break;
1107         case MSG_CHANNEL_DATA:
1108                 if(unpack(m, "_us", &chan, &s, &n) < 0)
1109                         break;
1110                 c = getclient(chan);
1111                 if(c != nil && c->state == Established){
1112                         if(c->recvwin <= 0)
1113                                 break;
1114                         c->recvwin -= n;
1115                         m->rp = (uchar*)s;
1116                         queuermsg(c, m);
1117                         matchrmsgs(c);
1118                         return;
1119                 }
1120                 break;
1121         case MSG_CHANNEL_EOF:
1122                 if(unpack(m, "_u", &chan) < 0)
1123                         break;
1124                 c = getclient(chan);
1125                 if(c != nil && c->state == Established){
1126                         c->eof = 1;
1127                         c->recvwin = 0;
1128                         matchrmsgs(c);
1129                 }
1130                 break;
1131         case MSG_CHANNEL_CLOSE:
1132                 if(unpack(m, "_u", &chan) < 0)
1133                         break;
1134                 c = getclient(chan);
1135                 if(c == nil)
1136                         break;
1137                 switch(c->state){
1138                 case Established:
1139                         c->state = Finished;
1140                         hangupclient(c, "connection closed");
1141                         break;
1142                 case Teardown:
1143                         c->state = Closed;
1144                         break;
1145                 }
1146                 break;
1147         case MSG_CHANNEL_OPEN_CONFIRMATION:
1148                 if(unpack(m, "_uuuu", &chan, &n, &win, &pkt) < 0)
1149                         break;
1150                 if(chan == SESSIONCHAN){
1151                         sessionopen++;
1152                         break;
1153                 }
1154                 c = getclient(chan);
1155                 if(c == nil || c->state != Dialing)
1156                         break;
1157                 if(pkt <= 0 || pkt > MaxPacket)
1158                         pkt = MaxPacket;
1159                 c->eof = 0;
1160                 c->sendpkt = pkt;
1161                 c->sendwin = win;
1162                 c->servernum = n;
1163                 if(c->wq == nil){
1164                         teardownclient(c);
1165                         break;
1166                 }
1167                 respond(c->wq, nil);
1168                 c->wq = nil;
1169                 c->state = Established;
1170                 break;
1171         case MSG_CHANNEL_OPEN_FAILURE:
1172                 if(unpack(m, "_u____s", &chan, &s, &n) < 0)
1173                         break;
1174                 s = smprint("%.*s", utfnlen(s, n), s);
1175                 if(chan == SESSIONCHAN){
1176                         sysfatal("ssh failed: %s", s);
1177                         break;
1178                 }
1179                 c = getclient(chan);
1180                 if(c != nil && c->state == Dialing){
1181                         c->state = Closed;
1182                         hangupclient(c, s);
1183                 }
1184                 free(s);
1185                 break;
1186         case MSG_CHANNEL_OPEN:
1187                 if(unpack(m, "_suuususu", &s, &n, &chan,
1188                         &win, &pkt,
1189                         &lhost, &ln, &lport,
1190                         &rhost, &rn, &rport) < 0)
1191                         break;
1192                 if(n != 15 || strncmp(s, "forwarded-tcpip", 15) != 0){
1193                         n = 3, s = "unknown open type";
1194                 Reject:
1195                         sendmsg(pack(nil, "buus", MSG_CHANNEL_OPEN_FAILURE,
1196                                 chan, n, s, strlen(s)));
1197                         break;
1198                 }
1199                 lhost = smprint("%.*s", utfnlen(lhost, ln), lhost);
1200                 rhost = smprint("%.*s", utfnlen(rhost, rn), rhost);
1201                 c = acceptclient(lhost, lport, rhost, rport);
1202                 if(c == nil){
1203                         free(lhost);
1204                         free(rhost);
1205                         n = 2, s = "connection refused";
1206                         goto Reject;
1207                 }
1208                 c->servernum = chan;
1209                 c->recvwin = WinPackets*MaxPacket;
1210                 c->recvacc = 0;
1211                 c->eof = 0;
1212                 c->sendpkt = pkt;
1213                 c->sendwin = win;
1214                 c->state = Established;
1215
1216                 sendmsg(pack(nil, "buuuu", MSG_CHANNEL_OPEN_CONFIRMATION,
1217                         c->servernum, c->num, c->recvwin, MaxPacket));
1218
1219                 c->ref++;
1220                 c->wq->fid->qid.path = PATH(Qctl, c->num);
1221                 c->wq->ofcall.qid.path = c->wq->fid->qid.path;
1222                 respond(c->wq, nil);
1223                 c->wq = nil;
1224                 break;
1225         }
1226         free(m);
1227 }
1228
1229 void
1230 fsnetproc(void*)
1231 {
1232         ulong path;
1233         Alt a[4];
1234         Cs *cs;
1235         Fid *fid;
1236         Req *r;
1237         Msg *m;
1238
1239         threadsetname("fsthread");
1240
1241         a[0].op = CHANRCV;
1242         a[0].c = fsclunkchan;
1243         a[0].v = &fid;
1244         a[1].op = CHANRCV;
1245         a[1].c = fsreqchan;
1246         a[1].v = &r;
1247         a[2].op = CHANRCV;
1248         a[2].c = sshmsgchan;
1249         a[2].v = &m;
1250         a[3].op = CHANEND;
1251
1252         for(;;){
1253                 switch(alt(a)){
1254                 case 0:
1255                         path = fid->qid.path;
1256                         switch(TYPE(path)){
1257                         case Qcs:
1258                                 cs = fid->aux;
1259                                 if(cs){
1260                                         free(cs->resp);
1261                                         free(cs);
1262                                 }
1263                                 break;
1264                         }
1265                         if(fid->omode != -1 && TYPE(path) >= Qn)
1266                                 closeclient(client[NUM(path)]);
1267                         sendp(fsclunkwaitchan, nil);
1268                         break;
1269                 case 1:
1270                         switch(r->ifcall.type){
1271                         case Tattach:
1272                                 fsattach(r);
1273                                 break;
1274                         case Topen:
1275                                 fsopen(r);
1276                                 break;
1277                         case Tread:
1278                                 fsread(r);
1279                                 break;
1280                         case Twrite:
1281                                 fswrite(r);
1282                                 break;
1283                         case Tstat:
1284                                 fsstat(r);
1285                                 break;
1286                         case Tflush:
1287                                 fsflush(r);
1288                                 break;
1289                         default:
1290                                 respond(r, "bug in fsthread");
1291                                 break;
1292                         }
1293                         sendp(fsreqwaitchan, 0);
1294                         break;
1295                 case 2:
1296                         handlemsg(m);
1297                         break;
1298                 }
1299         }
1300 }
1301
1302 static void
1303 fssend(Req *r)
1304 {
1305         sendp(fsreqchan, r);
1306         recvp(fsreqwaitchan);   /* avoids need to deal with spurious flushes */
1307 }
1308
1309 static void
1310 fsdestroyfid(Fid *fid)
1311 {
1312         sendp(fsclunkchan, fid);
1313         recvp(fsclunkwaitchan);
1314 }
1315
1316 static void
1317 startup(Srv*)
1318 {
1319         proccreate(fsnetproc, nil, 8*1024);
1320 }
1321
1322 void
1323 takedown(Srv*)
1324 {
1325         threadexitsall("done");
1326 }
1327
1328 Srv fs = 
1329 {
1330 .attach=        fssend,
1331 .destroyfid=    fsdestroyfid,
1332 .walk1=         fswalk1,
1333 .open=          fssend,
1334 .read=          fssend,
1335 .write=         fssend,
1336 .stat=          fssend,
1337 .flush=         fssend,
1338 .start=         startup,
1339 .end=           takedown,
1340 };
1341
/*
 * State shared between ssh() and the startssh child proc.
 */
char **sshargv;	/* "ssh", "-X", then the user's arguments; allocated with one spare slot */
int sshargc;	/* number of arguments in sshargv */
int pfd[2];	/* pfd[0]: kept by us as sshfd; pfd[1]: child's stdin/stdout */
1345
1346 void
1347 startssh(void *)
1348 {
1349         char *f;
1350
1351         close(pfd[0]);
1352         dup(pfd[1], 0);
1353         dup(pfd[1], 1);
1354         close(pfd[1]);
1355         if(strncmp(sshargv[0], "./", 2) != 0)
1356                 f = smprint("/bin/%s", sshargv[0]);
1357         else
1358                 f = sshargv[0];
1359         procexec(nil, f, sshargv);
1360         sysfatal("exec: %r");
1361 }
1362
1363 void
1364 ssh(int argc, char *argv[])
1365 {
1366         Alt a[4];
1367         Waitmsg *w;
1368         Msg *m;
1369
1370         sshargc = argc + 2;
1371         sshargv = emalloc9p(sizeof(char *) * (sshargc + 1));
1372         sshargv[0] = "ssh";
1373         sshargv[1] = "-X";
1374         memcpy(sshargv + 2, argv, argc * sizeof(char *));
1375
1376         if(pipe(pfd) < 0)
1377                 sysfatal("pipe: %r");
1378         sshfd = pfd[0];
1379         procrfork(startssh, nil, 8*1024, RFFDG|RFNOTEG|RFNAMEG);
1380         close(pfd[1]);
1381
1382         procrfork(sshreadproc, nil, 8*1024, RFFDG|RFNOTEG|RFNAMEG);
1383
1384         sendmsg(pack(nil, "bsuuu", MSG_CHANNEL_OPEN,
1385                 "session", 7,
1386                 SESSIONCHAN,
1387                 MaxPacket,
1388                 MaxPacket));
1389
1390         a[0].op = CHANRCV;
1391         a[0].c = threadwaitchan();
1392         a[0].v = &w;
1393         a[1].op = CHANRCV;
1394         a[1].c = sshmsgchan;
1395         a[1].v = &m;
1396         a[2].op = CHANEND;
1397
1398         while(!sessionopen){
1399                 switch(alt(a)){
1400                 case 0:
1401                         sysfatal("ssh failed: %s", w->msg);
1402                 case 1:
1403                         handlemsg(m);
1404                 }
1405         }
1406 }
1407
/*
 * Print the command synopsis and exit.  threadmain accepts -m and -s
 * (and the debug flag -D) and requires at least one further argument
 * for ssh naming the host, so the synopsis lists them.
 */
void
usage(void)
{
	fprint(2, "usage: sshnet [-m mtpt] [-s service] [ssh options] [user@]host\n");
	exits("usage");
}
1414
1415 void
1416 threadmain(int argc, char **argv)
1417 {
1418         fmtinstall('H', encodefmt);
1419
1420         mtpt = "/net";
1421         service = nil;
1422         ARGBEGIN{
1423         case 'D':
1424                 chatty9p++;
1425                 break;
1426         case 'm':
1427                 mtpt = EARGF(usage());
1428                 break;
1429         case 's':
1430                 service = EARGF(usage());
1431                 break;
1432         default:
1433                 usage();
1434         }ARGEND
1435
1436         if(argc == 0)
1437                 usage();
1438
1439         time0 = time(0);
1440         sshmsgchan = chancreate(sizeof(Msg*), 16);
1441         fsreqchan = chancreate(sizeof(Req*), 0);
1442         fsreqwaitchan = chancreate(sizeof(void*), 0);
1443         fsclunkchan = chancreate(sizeof(Fid*), 0);
1444         fsclunkwaitchan = chancreate(sizeof(void*), 0);
1445
1446         ssh(argc, argv);
1447
1448         threadpostmountsrv(&fs, service, mtpt, MREPL);
1449
1450         threadexits(nil);
1451 }