]> git.lizzy.rs Git - plan9front.git/blobdiff - sys/src/cmd/ssh.c
cc: fix void cast crash
[plan9front.git] / sys / src / cmd / ssh.c
index c0877f289c18100a7b8ce239bd6edc7a1e3d13d3..239a4f542c2e5d914b5086346311f3029cabaed3 100644 (file)
@@ -52,9 +52,7 @@ enum {
        WinPackets = 8,         // (1<<15) * 8 = 256K
 };
 
-enum {
-       MaxPwTries = 3 // retry this often for keyboard-interactive
-};
+int MaxPwTries = 3; // retry this often for keyboard-interactive
 
 typedef struct
 {
@@ -82,8 +80,8 @@ int nsid;
 uchar sid[256];
 char thumb[2*SHA2_256dlen+1], *thumbfile;
 
-int fd, intr, raw, debug;
-char *user, *service, *status, *host, *cmd;
+int fd, intr, raw, port, mux, debug;
+char *user, *service, *status, *host, *remote, *cmd;
 
 Oneway recv, send;
 void dispatch(void);
@@ -99,7 +97,7 @@ shutdown(void)
 void
 catch(void*, char *msg)
 {
-       if(strstr(msg, "interrupt") != nil){
+       if(strcmp(msg, "interrupt") == 0){
                intr = 1;
                noted(NCONT);
        }
@@ -112,11 +110,9 @@ wasintr(void)
        char err[ERRMAX];
        int r;
 
-       if(intr)
-               return 1;
        memset(err, 0, sizeof(err));
        errstr(err, sizeof(err));
-       r = strstr(err, "interrupt") != nil;
+       r = strcmp(err, "interrupted") == 0;
        errstr(err, sizeof(err));
        return r;
 }
@@ -428,11 +424,6 @@ ssh2rsasig(uchar *data, int len)
        return m;
 }
 
-/* libsec */
-extern mpint* pkcs1padbuf(uchar *buf, int len, mpint *modulus, int blocktype);
-extern int asn1encodedigest(DigestState* (*fun)(uchar*, ulong, uchar*, DigestState*),
-       uchar *digest, uchar *buf, int len);
-
 mpint*
 pkcs1digest(uchar *data, int len, RSApub *pub)
 {
@@ -498,7 +489,7 @@ kex(int gotkexinit)
        static char kexalgs[] = "curve25519-sha256,curve25519-sha256@libssh.org";
        static char cipheralgs[] = "chacha20-poly1305@openssh.com";
        static char zipalgs[] = "none";
-       static char macalgs[] = "";
+       static char macalgs[] = "hmac-sha1";    /* work around for github.com */
        static char langs[] = "";
 
        uchar cookie[16], x[32], yc[32], z[32], k[32+1], h[SHA2_256dlen], *ys, *ks, *sig;
@@ -553,7 +544,7 @@ kex(int gotkexinit)
                for(t=tab; *t != nil; t++){
                        if(unpack(p, recv.w-p, "s.", &s, &n, &p) < 0)
                                break;
-                       fprint(2, "%s: %.*s\n", *t, n, s);
+                       fprint(2, "%s: %.*s\n", *t, utfnlen(s, n), s);
                }
        }
 
@@ -670,7 +661,7 @@ authfailure(char *meth)
        if(unpack(recv.r, recv.w-recv.r, "_sb", &s, &n, &partial) < 0)
                sysfatal("bad auth failure response");
        free(authnext);
-       authnext = smprint("%.*s", n, s);
+       authnext = smprint("%.*s", utfnlen(s, n), s);
 if(debug)
        fprint(2, "userauth %s failed: partial=%d, next=%s\n", meth, partial, authnext);
        return partial != 0 || !authok(meth);
@@ -912,9 +903,9 @@ Retry:
                m--;
 
        if(n > 0)
-               fprint(fd, "%.*s\n", n, name);
+               fprint(fd, "%.*s\n", utfnlen(name, n), name);
        if(m > 0)
-               fprint(fd, "%.*s\n", m, inst);
+               fprint(fd, "%.*s\n", utfnlen(inst, m), inst);
 
        /* lang, nprompt */
        if(unpack(recv.r, recv.w-recv.r, "su.", &s, &n, &nquest, &recv.r) < 0)
@@ -970,23 +961,45 @@ dispatch(void)
 
        switch(recv.r[0]){
        case MSG_IGNORE:
+               return;
        case MSG_GLOBAL_REQUEST:
+               if(unpack(recv.r, recv.w-recv.r, "_sb", &s, &n, &b) < 0)
+                       break;
+               if(debug)
+                       fprint(2, "%s: global request: %.*s\n",
+                               argv0, utfnlen(s, n), s);
+               if(b != 0)
+                       sendpkt("b", MSG_REQUEST_FAILURE);
                return;
        case MSG_DISCONNECT:
                if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
                        break;
-               sysfatal("disconnect: (%d) %.*s", c, n, s);
+               sysfatal("disconnect: (%d) %.*s", c, utfnlen(s, n), s);
                return;
        case MSG_DEBUG:
                if(unpack(recv.r, recv.w-recv.r, "__sb", &s, &n, &c) < 0)
                        break;
-               if(c != 0 || debug) fprint(2, "%s: %.*s\n", argv0, n, s);
+               if(c != 0 || debug)
+                       fprint(2, "%s: %.*s\n", argv0, utfnlen(s, n), s);
                return;
        case MSG_USERAUTH_BANNER:
                if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0)
                        break;
                if(raw) write(2, s, n);
                return;
+       case MSG_KEXINIT:
+               kex(1);
+               return;
+       }
+
+       if(mux){
+               n = recv.w - recv.r;
+               if(write(1, recv.r, n) != n)
+                       sysfatal("write out: %r");
+               return;
+       }
+
+       switch(recv.r[0]){
        case MSG_CHANNEL_DATA:
                if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
                        break;
@@ -1027,15 +1040,22 @@ dispatch(void)
                        if(unpack(p, recv.w-p, "s", &s, &n) < 0)
                                break;
                        if(n != 0 && status == nil)
-                               status = smprint("%.*s", n, s);
+                               status = smprint("%.*s", utfnlen(s, n), s);
+                       c = MSG_CHANNEL_SUCCESS;
                } else if(n == 11 && memcmp(s, "exit-status", n) == 0){
                        if(unpack(p, recv.w-p, "u", &n) < 0)
                                break;
                        if(n != 0 && status == nil)
                                status = smprint("%d", n);
-               } else if(debug) {
-                       fprint(2, "%s: channel request: %.*s\n", argv0, n, s);
+                       c = MSG_CHANNEL_SUCCESS;
+               } else {
+                       if(debug)
+                               fprint(2, "%s: channel request: %.*s\n",
+                                       argv0, utfnlen(s, n), s);
+                       c = MSG_CHANNEL_FAILURE;
                }
+               if(b != 0)
+                       sendpkt("bu", c, recv.chan);
                return;
        case MSG_CHANNEL_EOF:
                recv.eof = 1;
@@ -1044,9 +1064,6 @@ dispatch(void)
        case MSG_CHANNEL_CLOSE:
                shutdown();
                return;
-       case MSG_KEXINIT:
-               kex(1);
-               return;
        }
        sysfatal("got: %.*H", (int)(recv.w - recv.r), recv.r);
 }
@@ -1072,19 +1089,35 @@ static struct {
        int     ypixels;
        int     lines;
        int     cols;
-} tty = {
-       "dumb",
-       0,
-       0,
-       0,
-       0,
-};
+} tty;
+
+void
+getdim(void)
+{
+       char *s;
+
+       if(s = getenv("XPIXELS")){
+               tty.xpixels = atoi(s);
+               free(s);
+       }
+       if(s = getenv("YPIXELS")){
+               tty.ypixels = atoi(s);
+               free(s);
+       }
+       if(s = getenv("LINES")){
+               tty.lines = atoi(s);
+               free(s);
+       }
+       if(s = getenv("COLS")){
+               tty.cols = atoi(s);
+               free(s);
+       }
+}
 
 void
 rawon(void)
 {
        int ctl;
-       char *s;
 
        close(0);
        if(open("/dev/cons", OREAD) != 0)
@@ -1093,33 +1126,38 @@ rawon(void)
        if(open("/dev/cons", OWRITE) != 1)
                sysfatal("open: %r");
        dup(1, 2);
-       if((ctl = open("/dev/consctl", OWRITE)) >= 0)
+       if((ctl = open("/dev/consctl", OWRITE)) >= 0){
                write(ctl, "rawon", 5);
-       if(s = getenv("TERM")){
-               tty.term = s;
-               if(s = getenv("XPIXELS")){
-                       tty.xpixels = atoi(s);
-                       free(s);
-               }
-               if(s = getenv("YPIXELS")){
-                       tty.ypixels = atoi(s);
-                       free(s);
-               }
-               if(s = getenv("LINES")){
-                       tty.lines = atoi(s);
-                       free(s);
-               }
-               if(s = getenv("COLS")){
-                       tty.cols = atoi(s);
-                       free(s);
-               }
+               write(ctl, "winchon", 7);       /* vt(1): interrupt note on window change */
        }
+       getdim();
+}
+
+#pragma           varargck    type  "k"   char*
+
+kfmt(Fmt *f)
+{
+       char *s, *p;
+       int n;
+
+       s = va_arg(f->args, char*);
+       n = fmtstrcpy(f, "'");
+       while((p = strchr(s, '\'')) != nil){
+               *p = '\0';
+               n += fmtstrcpy(f, s);
+               *p = '\'';
+               n += fmtstrcpy(f, "'\\''");
+               s = p+1;
+       }
+       n += fmtstrcpy(f, s);
+       n += fmtstrcpy(f, "'");
+       return n;
 }
 
 void
 usage(void)
 {
-       fprint(2, "usage: %s [-dR] [-t thumbfile] [-u user] [user@]host [cmd]\n", argv0);
+       fprint(2, "usage: %s [-dR] [-t thumbfile] [-T tries] [-u user] [-h] [user@]host [-W remote!port] [cmd args...]\n", argv0);
        exits("usage");
 }
 
@@ -1134,30 +1172,61 @@ main(int argc, char *argv[])
        fmtinstall('B', mpfmt);
        fmtinstall('H', encodefmt);
        fmtinstall('[', encodefmt);
+       fmtinstall('k', kfmt);
 
-       s = getenv("TERM");
-       raw = s != nil && strcmp(s, "dumb") != 0;
-       free(s);
+       tty.term = getenv("TERM");
+       if(tty.term == nil)
+               tty.term = "";
+       raw = *tty.term != 0;
 
        ARGBEGIN {
        case 'd':
                debug++;
                break;
+       case 'W':
+               remote = EARGF(usage());
+               s = strrchr(remote, '!');
+               if(s == nil)
+                       s = strrchr(remote, ':');
+               if(s == nil)
+                       usage();
+               *s++ = 0;
+               port = atoi(s);
+               raw = 0;
+               break;
        case 'R':
                raw = 0;
                break;
+       case 'r':
+               raw = 2; /* bloody */
+               break;
        case 'u':
                user = EARGF(usage());
                break;
+       case 'h':
+               host = EARGF(usage());
+               break;
        case 't':
                thumbfile = EARGF(usage());
                break;
+       case 'T':
+               MaxPwTries = strtol(EARGF(usage()), &s, 0);
+               if(*s != 0) usage();
+               break;
+       case 'X':
+               mux = 1;
+               raw = 0;
+               break;
+       default:
+               usage();
        } ARGEND;
 
-       if(argc == 0)
-               usage();
+       if(host == nil){
+               if(argc == 0)
+                       usage();
+               host = *argv++;
+       }
 
-       host = *argv++;
        if(user == nil){
                s = strchr(host, '@');
                if(s != nil){
@@ -1166,17 +1235,21 @@ main(int argc, char *argv[])
                        host = s;
                }
        }
+
        for(cmd = nil; *argv != nil; argv++){
-               if(cmd == nil)
+               if(cmd == nil){
                        cmd = strdup(*argv);
-               else {
-                       s = smprint("%s %q", cmd, *argv);
+                       if(raw == 1)
+                               raw = 0;
+               }else{
+                       s = smprint("%s %k", cmd, *argv);
                        free(cmd);
                        cmd = s;
                }
        }
-       if(cmd != nil)
-               raw = 0;
+
+       if(remote != nil && cmd != nil)
+               usage();
 
        if((fd = dial(netmkaddr(host, nil, "ssh"), nil, nil, nil)) < 0)
                sysfatal("dial: %r");
@@ -1199,9 +1272,6 @@ main(int argc, char *argv[])
 
        kex(0);
 
-
-       service = "ssh-connection";
-
        sendpkt("bs", MSG_SERVICE_REQUEST, "ssh-userauth", 12);
 Next0: switch(recvpkt()){
        default:
@@ -1211,20 +1281,39 @@ Next0:  switch(recvpkt()){
                break;
        }
 
+       service = "ssh-connection";
        if(noneauth() < 0 && pubkeyauth() < 0 && passauth() < 0 && kbintauth() < 0)
                sysfatal("auth: %r");
 
-       recv.pkt = MaxPacket;
-       recv.win = WinPackets*recv.pkt;
-       recv.chan = 0;
+       recv.pkt = send.pkt = MaxPacket;
+       recv.win = send.win =  WinPackets*recv.pkt;
+       recv.chan = send.win = 0;
 
-       /* open hailing frequencies */
-       sendpkt("bsuuu", MSG_CHANNEL_OPEN,
-               "session", 7,
-               recv.chan,
-               recv.win,
-               recv.pkt);
+       if(mux)
+               goto Mux;
 
+       /* open hailing frequencies */
+       if(remote != nil){
+               NetConnInfo *nci = getnetconninfo(nil, fd);
+               if(nci == nil)
+                       sysfatal("can't get netconninfo: %r");
+               sendpkt("bsuuususu", MSG_CHANNEL_OPEN,
+                       "direct-tcpip", 12,
+                       recv.chan,
+                       recv.win,
+                       recv.pkt,
+                       remote, strlen(remote),
+                       port,
+                       nci->laddr, strlen(nci->laddr),
+                       atoi(nci->lserv));
+               free(nci);
+       } else {
+               sendpkt("bsuuu", MSG_CHANNEL_OPEN,
+                       "session", 7,
+                       recv.chan,
+                       recv.win,
+                       recv.pkt);
+       }
 Next1: switch(recvpkt()){
        default:
                dispatch();
@@ -1232,7 +1321,7 @@ Next1:    switch(recvpkt()){
        case MSG_CHANNEL_OPEN_FAILURE:
                if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
                        n = strlen(s = "???");
-               sysfatal("channel open failure: (%d) %.*s", b, n, s);
+               sysfatal("channel open failure: (%d) %.*s", b, utfnlen(s, n), s);
        case MSG_CHANNEL_OPEN_CONFIRMATION:
                break;
        }
@@ -1242,30 +1331,9 @@ Next1:   switch(recvpkt()){
        if(send.pkt <= 0 || send.pkt > MaxPacket)
                send.pkt = MaxPacket;
 
-       notify(catch);
-       atexit(shutdown);
-
-       recv.pid = getpid();
-       n = rfork(RFPROC|RFMEM);
-       if(n < 0)
-               sysfatal("fork: %r");
-
-       /* parent reads and dispatches packets */
-       if(n > 0) {
-               send.pid = n;
-               while((send.eof|recv.eof) == 0){
-                       recvpkt();
-                       qlock(&sl);                                     
-                       dispatch();
-                       if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0)
-                               kex(0);
-                       qunlock(&sl);
-               }
-               exits(status);
-       }
+       if(remote != nil)
+               goto Mux;
 
-       /* child reads input and sends packets */
-       qlock(&sl);
        if(raw) {
                rawon();
                sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
@@ -1284,6 +1352,12 @@ Next1:   switch(recvpkt()){
                        send.chan,
                        "shell", 5,
                        0);
+       } else if(*cmd == '#') {
+               sendpkt("busbs", MSG_CHANNEL_REQUEST,
+                       send.chan,
+                       "subsystem", 9,
+                       0,
+                       cmd+1, strlen(cmd)-1);
        } else {
                sendpkt("busbs", MSG_CHANNEL_REQUEST,
                        send.chan,
@@ -1291,6 +1365,32 @@ Next1:   switch(recvpkt()){
                        0,
                        cmd, strlen(cmd));
        }
+
+Mux:
+       notify(catch);
+       atexit(shutdown);
+
+       recv.pid = getpid();
+       n = rfork(RFPROC|RFMEM);
+       if(n < 0)
+               sysfatal("fork: %r");
+
+       /* parent reads and dispatches packets */
+       if(n > 0) {
+               send.pid = n;
+               while(recv.eof == 0){
+                       recvpkt();
+                       qlock(&sl);                                     
+                       dispatch();
+                       if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0)
+                               kex(0);
+                       qunlock(&sl);
+               }
+               exits(status);
+       }
+
+       /* child reads input and sends packets */
+       qlock(&sl);
        for(;;){
                static uchar buf[MaxPacket];
                qunlock(&sl);
@@ -1298,8 +1398,19 @@ Next1:   switch(recvpkt()){
                qlock(&sl);
                if(send.eof)
                        break;
-               if(n < 0 && wasintr()){
+               if(n < 0 && wasintr())
+                       intr = 1;
+               if(intr){
                        if(!raw) break;
+                       getdim();
+                       sendpkt("busbuuuu", MSG_CHANNEL_REQUEST,
+                               send.chan,
+                               "window-change", 13,
+                               0,
+                               tty.cols,
+                               tty.lines,
+                               tty.xpixels,
+                               tty.ypixels);
                        sendpkt("busbs", MSG_CHANNEL_REQUEST,
                                send.chan,
                                "signal", 6,
@@ -1310,6 +1421,10 @@ Next1:   switch(recvpkt()){
                }
                if(n <= 0)
                        break;
+               if(mux){
+                       sendpkt("[", buf, n);
+                       continue;
+               }
                send.win -= n;
                while(send.win < 0)
                        rsleep(&send);
@@ -1317,8 +1432,10 @@ Next1:   switch(recvpkt()){
                        send.chan,
                        buf, n);
        }
-       if(send.eof++ == 0)
+       if(send.eof++ == 0 && !mux)
                sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, send.chan);
+       else if(recv.pid > 0 && mux)
+               postnote(PNPROC, recv.pid, "shutdown");
        qunlock(&sl);
 
        exits(nil);