]> git.lizzy.rs Git - plan9front.git/blobdiff - sys/src/cmd/sshnet.c
rio, kbdfs: increase read buffer for high latency kbdfs support
[plan9front.git] / sys / src / cmd / sshnet.c
index 815803b6c62c060ff91f04b9338b7d44187fb300..ecb730abcff52effaf5961568e2e5c53b36f439c 100755 (executable)
@@ -26,6 +26,7 @@ enum
        Qlocal,
        Qremote,
        Qstatus,
+       Qlisten,
 };
 
 #define PATH(type, n)          ((type)|((n)<<8))
@@ -44,15 +45,19 @@ enum
 {
        Closed,
        Dialing,
+       Listen,
        Established,
        Teardown,
+       Finished,
 };
 
 char *statestr[] = {
        "Closed",
        "Dialing",
+       "Listen",
        "Established",
        "Teardown",
+       "Finished",
 };
 
 struct Client
@@ -61,13 +66,18 @@ struct Client
        int state;
        int num;
        int servernum;
-       char *connect;
+
+       int rport, lport;
+       char *rhost;
+       char *lhost;
 
        int sendpkt;
        int sendwin;
        int recvwin;
        int recvacc;
 
+       int eof;
+
        Req *wq;
        Req **ewq;
 
@@ -79,6 +89,8 @@ struct Client
 };
 
 enum {
+       MSG_GLOBAL_REQUEST = 80,
+
        MSG_CHANNEL_OPEN = 90,
        MSG_CHANNEL_OPEN_CONFIRMATION,
        MSG_CHANNEL_OPEN_FAILURE,
@@ -91,7 +103,8 @@ enum {
        MSG_CHANNEL_SUCCESS,
        MSG_CHANNEL_FAILURE,
 
-       MaxPacket = 1<<15,
+       Overhead = 256,
+       MaxPacket = (1<<15)-256,        /* 32K is maxatomic for pipe */
        WinPackets = 8,
 
        SESSIONCHAN = 1<<24,
@@ -104,7 +117,7 @@ struct Msg
        uchar   *rp;
        uchar   *wp;
        uchar   *ep;
-       uchar   buf[MaxPacket];
+       uchar   buf[MaxPacket + Overhead];
 };
 
 #define PUT4(p, u) (p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u)
@@ -114,8 +127,6 @@ int nclient;
 Client **client;
 char *mtpt;
 int sshfd;
-int localport;
-char localip[] = "::";
 
 int
 vpack(uchar *p, int n, char *fmt, va_list a)
@@ -296,6 +307,31 @@ getclient(int num)
        return client[num];
 }
 
+Client*
+acceptclient(char *lhost, int lport, char *rhost, int rport)
+{
+       Client *c, *nc;
+       int i;
+
+       for(i = 0; i < nclient; i++){
+               c = client[i];
+               if(c->state == Listen && c->lport == lport && c->wq != nil){
+                       nc = client[newclient()];
+                       nc->wq = c->wq;
+                       c->wq = nc->wq->aux;
+                       nc->wq->aux = nil;
+                       free(nc->lhost);
+                       nc->lhost = lhost;
+                       nc->lport = lport;
+                       free(nc->rhost);
+                       nc->rhost = rhost;
+                       nc->rport = rport;
+                       return nc;
+               }
+       }
+       return nil;
+}
+
 void
 adjustwin(Client *c, int len)
 {
@@ -341,12 +377,10 @@ matchrmsgs(Client *c)
        Msg *m;
        int n, rm;
 
-       while(c->rq != nil && c->mq != nil){
-               r = c->rq;
+       while((r = c->rq) != nil && (m = c->mq) != nil){
                c->rq = r->aux;
-
+               r->aux = nil;
                rm = 0;
-               m = c->mq;
                n = r->ifcall.count;
                if(n >= m->wp - m->rp){
                        n = m->wp - m->rp;
@@ -362,6 +396,15 @@ matchrmsgs(Client *c)
                respond(r, nil);
                adjustwin(c, n);
        }
+
+       if(c->eof){
+               while((r = c->rq) != nil){
+                       c->rq = r->aux;
+                       r->aux = nil;
+                       r->ofcall.count = 0;
+                       respond(r, nil);
+               }
+       }
 }
 
 void
@@ -422,74 +465,59 @@ findreq(Client *c, Req *r)
 }
 
 void
-dialedclient(Client *c)
+hangupclient(Client *c, char *err)
 {
        Req *r;
 
-       if(r=c->wq){
-               if(r->aux != nil)
-                       sysfatal("more than one outstanding dial request (BUG)");
-               if(c->state == Established)
-                       respond(r, nil);
-               else
-                       respond(r, "connect failed");
+       c->eof = 1;
+       c->recvwin = 0;
+       c->sendwin = 0;
+       while((r = c->wq) != nil){
+               c->wq = r->aux;
+               r->aux = nil;
+               respond(r, err);
        }
-       c->wq = nil;
+       matchrmsgs(c);
 }
 
 void
 teardownclient(Client *c)
 {
        c->state = Teardown;
-       sendmsg(pack(nil, "bu", MSG_CHANNEL_EOF, c->servernum));
-}
-
-void
-hangupclient(Client *c)
-{
-       Req *r, *next;
-       Msg *m, *mnext;
-
-       c->state = Closed;
-       for(m=c->mq; m; m=mnext){
-               mnext = m->link;
-               free(m);
-       }
-       c->mq = nil;
-       for(r=c->rq; r; r=next){
-               next = r->aux;
-               respond(r, "hangup on network connection");
-       }
-       c->rq = nil;
-       for(r=c->wq; r; r=next){
-               next = r->aux;
-               respond(r, "hangup on network connection");
-       }
-       c->wq = nil;
+       hangupclient(c, "i/o on hungup channel");
+       sendmsg(pack(nil, "bu", MSG_CHANNEL_CLOSE, c->servernum));
 }
 
 void
 closeclient(Client *c)
 {
-       Msg *m, *next;
+       Msg *m;
 
        if(--c->ref)
                return;
-
-       if(c->rq != nil || c->wq != nil)
-               sysfatal("ref count reached zero with requests pending (BUG)");
-
-       for(m=c->mq; m; m=next){
-               next = m->link;
+       switch(c->state){
+       case Established:
+               teardownclient(c);
+               break;
+       case Finished:
+               c->state = Closed;
+               sendmsg(pack(nil, "bu", MSG_CHANNEL_CLOSE, c->servernum));
+               break;
+       case Listen:
+               c->state = Closed;
+               sendmsg(pack(nil, "bsbsu", MSG_GLOBAL_REQUEST,
+                       "cancel-tcpip-forward", 20,
+                       0,
+                       c->lhost, strlen(c->lhost),
+                       c->lport));
+               break;
+       }
+       while((m = c->mq) != nil){
+               c->mq = m->link;
                free(m);
        }
-       c->mq = nil;
-
-       if(c->state != Closed)
-               teardownclient(c);
 }
 
-       
 void
 sshreadproc(void*)
 {
@@ -525,6 +553,7 @@ Tab tab[] =
        "local",        0444,
        "remote",       0444,
        "status",       0444,
+       "listen",       0666,
 };
 
 static void
@@ -779,7 +808,7 @@ static void
 ctlwrite(Req *r, Client *c)
 {
        char *f[3], *s;
-       int nf;
+       int nf, port;
 
        s = emalloc9p(r->ifcall.count+1);
        r->ofcall.count = r->ifcall.count;
@@ -801,19 +830,19 @@ ctlwrite(Req *r, Client *c)
                teardownclient(c);
                respond(r, nil);
        }else if(strcmp(f[0], "connect") == 0){
-               if(c->state != Closed)
+               if(nf != 2 || c->state != Closed)
                        goto Badarg;
-               if(nf != 2)
+               if(getfields(f[1], f, nelem(f), 0, "!") != 2)
                        goto Badarg;
-               c->connect = estrdup9p(f[1]);
-               nf = getfields(f[1], f, nelem(f), 0, "!");
-               if(nf != 2){
-                       free(c->connect);
-                       c->connect = nil;
+               if((port = ndbfindport(f[1])) < 0)
                        goto Badarg;
-               }
-               c->sendwin = MaxPacket;
-               c->recvwin = WinPackets * MaxPacket;
+               free(c->lhost);
+               c->lhost = estrdup9p("::");
+               c->lport = 0;
+               free(c->rhost);
+               c->rhost = estrdup9p(f[0]);
+               c->rport = port;
+               c->recvwin = WinPackets*MaxPacket;
                c->recvacc = 0;
                c->state = Dialing;
                queuewreq(c, r);
@@ -821,8 +850,28 @@ ctlwrite(Req *r, Client *c)
                sendmsg(pack(nil, "bsuuususu", MSG_CHANNEL_OPEN,
                        "direct-tcpip", 12,
                        c->num, c->recvwin, MaxPacket,
-                       f[0], strlen(f[0]), ndbfindport(f[1]),
-                       localip, strlen(localip), localport));
+                       c->rhost, strlen(c->rhost), c->rport,
+                       c->lhost, strlen(c->lhost), c->lport));
+       }else if(strcmp(f[0], "announce") == 0){
+               if(nf != 2 || c->state != Closed)
+                       goto Badarg;
+               if(getfields(f[1], f, nelem(f), 0, "!") != 2)
+                       goto Badarg;
+               if((port = ndbfindport(f[1])) < 0)
+                       goto Badarg;
+               if(strcmp(f[0], "*") == 0)
+                       f[0] = "";
+               free(c->lhost);
+               c->lhost = estrdup9p(f[0]);
+               c->lport = port;
+               free(c->rhost);
+               c->rhost = estrdup9p("::");
+               c->rport = 0;
+               c->state = Listen;
+               sendmsg(pack(nil, "bsbsu", MSG_GLOBAL_REQUEST,
+                       "tcpip-forward", 13, 0,
+                       c->lhost, strlen(c->lhost), c->lport));
+               respond(r, nil);
        }else{
        Badarg:
                respond(r, "bad or inappropriate tcp control message");
@@ -833,7 +882,7 @@ ctlwrite(Req *r, Client *c)
 static void
 dataread(Req *r, Client *c)
 {
-       if(c->state != Established){
+       if(c->state < Established){
                respond(r, "not connected");
                return;
        }
@@ -858,11 +907,16 @@ datawrite(Req *r, Client *c)
 }
 
 static void
-localread(Req *r)
+localread(Req *r, Client *c)
 {
-       char buf[128];
+       char buf[128], *s;
 
-       snprint(buf, sizeof buf, "%s!%d\n", localip, localport);
+       s = c->lhost;
+       if(s == nil)
+               s = "::";
+       else if(*s == 0)
+               s = "*";
+       snprint(buf, sizeof buf, "%s!%d\n", s, c->lport);
        readstr(r, buf);
        respond(r, nil);
 }
@@ -870,13 +924,12 @@ localread(Req *r)
 static void
 remoteread(Req *r, Client *c)
 {
-       char *s;
-       char buf[128];
+       char buf[128], *s;
 
-       s = c->connect;
+       s = c->rhost;
        if(s == nil)
-               s = "::!0";
-       snprint(buf, sizeof buf, "%s\n", s);
+               s = "::";
+       snprint(buf, sizeof buf, "%s!%d\n", s, c->rport);
        readstr(r, buf);
        respond(r, nil);
 }
@@ -932,7 +985,7 @@ fsread(Req *r)
                break;
 
        case Qlocal:
-               localread(r);
+               localread(r, client[NUM(path)]);
                break;
 
        case Qremote:
@@ -999,6 +1052,13 @@ fsopen(Req *r)
                r->fid->aux = cs;
                respond(r, nil);
                break;
+       case Qlisten:
+               if(client[NUM(path)]->state != Listen){
+                       respond(r, "no address set");
+                       break;
+               }
+               queuewreq(client[NUM(path)], r);
+               break;
        case Qclone:
                n = newclient();
                path = PATH(Qctl, n);
@@ -1030,16 +1090,16 @@ fsflush(Req *r)
 static void
 handlemsg(Msg *m)
 {
-       int chan, win, pkt, n, l;
+       int chan, win, pkt, lport, rport, n, ln, rn;
+       char *s, *lhost, *rhost;
        Client *c;
-       char *s;
 
        switch(m->rp[0]){
        case MSG_CHANNEL_WINDOW_ADJUST:
                if(unpack(m, "_uu", &chan, &n) < 0)
                        break;
                c = getclient(chan);
-               if(c != nil && c->state==Established){
+               if(c != nil && c->state == Established){
                        c->sendwin += n;
                        procwreqs(c);
                }
@@ -1048,7 +1108,9 @@ handlemsg(Msg *m)
                if(unpack(m, "_us", &chan, &s, &n) < 0)
                        break;
                c = getclient(chan);
-               if(c != nil && c->state==Established){
+               if(c != nil && c->state == Established){
+                       if(c->recvwin <= 0)
+                               break;
                        c->recvwin -= n;
                        m->rp = (uchar*)s;
                        queuermsg(c, m);
@@ -1060,19 +1122,27 @@ handlemsg(Msg *m)
                if(unpack(m, "_u", &chan) < 0)
                        break;
                c = getclient(chan);
-               if(c != nil){
-                       hangupclient(c);
-                       m->rp = m->wp = m->buf;
-                       sendmsg(pack(m, "bu", MSG_CHANNEL_CLOSE, c->servernum));
-                       return;
+               if(c != nil && c->state == Established){
+                       c->eof = 1;
+                       c->recvwin = 0;
+                       matchrmsgs(c);
                }
                break;
        case MSG_CHANNEL_CLOSE:
                if(unpack(m, "_u", &chan) < 0)
                        break;
                c = getclient(chan);
-               if(c != nil)
-                       hangupclient(c);
+               if(c == nil)
+                       break;
+               switch(c->state){
+               case Established:
+                       c->state = Finished;
+                       hangupclient(c, "connection closed");
+                       break;
+               case Teardown:
+                       c->state = Closed;
+                       break;
+               }
                break;
        case MSG_CHANNEL_OPEN_CONFIRMATION:
                if(unpack(m, "_uuuu", &chan, &n, &win, &pkt) < 0)
@@ -1086,25 +1156,71 @@ handlemsg(Msg *m)
                        break;
                if(pkt <= 0 || pkt > MaxPacket)
                        pkt = MaxPacket;
+               c->eof = 0;
                c->sendpkt = pkt;
                c->sendwin = win;
                c->servernum = n;
+               if(c->wq == nil){
+                       teardownclient(c);
+                       break;
+               }
+               respond(c->wq, nil);
+               c->wq = nil;
                c->state = Established;
-               dialedclient(c);
                break;
        case MSG_CHANNEL_OPEN_FAILURE:
-               if(unpack(m, "_uus", &chan, &n, &s, &l) < 0)
+               if(unpack(m, "_u____s", &chan, &s, &n) < 0)
                        break;
+               s = smprint("%.*s", utfnlen(s, n), s);
                if(chan == SESSIONCHAN){
-                       sendp(ssherrchan, smprint("%.*s", utfnlen(s, l), s));
+                       sendp(ssherrchan, s);
                        break;
                }
                c = getclient(chan);
-               if(c == nil || c->state != Dialing)
+               if(c != nil && c->state == Dialing){
+                       c->state = Closed;
+                       hangupclient(c, s);
+               }
+               free(s);
+               break;
+       case MSG_CHANNEL_OPEN:
+               if(unpack(m, "_suuususu", &s, &n, &chan,
+                       &win, &pkt,
+                       &lhost, &ln, &lport,
+                       &rhost, &rn, &rport) < 0)
                        break;
-               c->servernum = n;
-               c->state = Closed;
-               dialedclient(c);
+               if(n != 15 || strncmp(s, "forwarded-tcpip", 15) != 0){
+                       n = 3, s = "unknown open type";
+               Reject:
+                       sendmsg(pack(nil, "buus", MSG_CHANNEL_OPEN_FAILURE,
+                               chan, n, s, strlen(s)));
+                       break;
+               }
+               lhost = smprint("%.*s", utfnlen(lhost, ln), lhost);
+               rhost = smprint("%.*s", utfnlen(rhost, rn), rhost);
+               c = acceptclient(lhost, lport, rhost, rport);
+               if(c == nil){
+                       free(lhost);
+                       free(rhost);
+                       n = 2, s = "connection refused";
+                       goto Reject;
+               }
+               c->servernum = chan;
+               c->recvwin = WinPackets*MaxPacket;
+               c->recvacc = 0;
+               c->eof = 0;
+               c->sendpkt = pkt;
+               c->sendwin = win;
+               c->state = Established;
+
+               sendmsg(pack(nil, "buuuu", MSG_CHANNEL_OPEN_CONFIRMATION,
+                       c->servernum, c->num, c->recvwin, MaxPacket));
+
+               c->ref++;
+               c->wq->fid->qid.path = PATH(Qctl, c->num);
+               c->wq->ofcall.qid.path = c->wq->fid->qid.path;
+               respond(c->wq, nil);
+               c->wq = nil;
                break;
        }
        free(m);
@@ -1252,7 +1368,7 @@ ssh(int argc, char *argv[])
 
        pipe(pfd);
        sshfd = pfd[0];
-       procrfork(startssh, nil, mainstacksize, RFFDG|RFNOTEG|RFNAMEG);
+       procrfork(startssh, nil, 8*1024, RFFDG|RFNOTEG|RFNAMEG);
        close(pfd[1]);
 
        sendmsg(pack(nil, "bsuuu", MSG_CHANNEL_OPEN,
@@ -1319,8 +1435,8 @@ threadmain(int argc, char **argv)
        fsreqwaitchan = chancreate(sizeof(void*), 0);
        fsclunkchan = chancreate(sizeof(Fid*), 0);
        fsclunkwaitchan = chancreate(sizeof(void*), 0);
-       procrfork(fsnetproc, nil, mainstacksize, RFNAMEG|RFNOTEG);
-       procrfork(sshreadproc, nil, mainstacksize, RFNAMEG|RFNOTEG);
+       procrfork(fsnetproc, nil, 8*1024, RFNAMEG|RFNOTEG);
+       procrfork(sshreadproc, nil, 8*1024, RFNAMEG|RFNOTEG);
 
        ssh(argc, argv);