]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ssh.c
e9021a53bc50f2a27ebacb793615933013eb2290
[plan9front.git] / sys / src / cmd / ssh.c
1 #include <u.h>
2 #include <libc.h>
3 #include <mp.h>
4 #include <libsec.h>
5 #include <auth.h>
6 #include <authsrv.h>
7
/*
 * SSH message numbers (assigned in RFC 4250 §4.1) used by this client.
 */
enum {
	MSG_DISCONNECT = 1,
	MSG_IGNORE,
	MSG_UNIMPLEMENTED,
	MSG_DEBUG,
	MSG_SERVICE_REQUEST,
	MSG_SERVICE_ACCEPT,

	MSG_KEXINIT = 20,
	MSG_NEWKEYS,

	MSG_ECDH_INIT = 30,
	MSG_ECDH_REPLY,

	MSG_USERAUTH_REQUEST = 50,
	MSG_USERAUTH_FAILURE,
	MSG_USERAUTH_SUCCESS,
	MSG_USERAUTH_BANNER,

	/*
	 * 60 is deliberately shared: it is PK_OK while "publickey" is
	 * being tried (pubkeyauth) and INFO_REQUEST while
	 * "keyboard-interactive" is being tried (kbintauth).
	 */
	MSG_USERAUTH_PK_OK = 60,
	MSG_USERAUTH_INFO_REQUEST = 60,
	MSG_USERAUTH_INFO_RESPONSE = 61,

	MSG_GLOBAL_REQUEST = 80,
	MSG_REQUEST_SUCCESS,
	MSG_REQUEST_FAILURE,

	MSG_CHANNEL_OPEN = 90,
	MSG_CHANNEL_OPEN_CONFIRMATION,
	MSG_CHANNEL_OPEN_FAILURE,
	MSG_CHANNEL_WINDOW_ADJUST,
	MSG_CHANNEL_DATA,
	MSG_CHANNEL_EXTENDED_DATA,
	MSG_CHANNEL_EOF,
	MSG_CHANNEL_CLOSE,
	MSG_CHANNEL_REQUEST,
	MSG_CHANNEL_SUCCESS,
	MSG_CHANNEL_FAILURE,
};
47
48
/*
 * Buffer sizing: each Oneway buffer holds Overhead + MaxPacket bytes.
 * When the receive window drops below one packet, dispatch grants
 * WinPackets * recv.pkt more bytes.
 */
enum {
	Overhead = 256,		// enough for MSG_CHANNEL_DATA header
	MaxPacket = 1<<15,
	WinPackets = 8,		// (1<<15) * 8 = 256K
};

int MaxPwTries = 3; // retry this often for keyboard-interactive
56
/*
 * State for one direction of the connection.  The two instances
 * are recv (server to client) and send (client to server).
 */
typedef struct
{
	u32int		seq;	/* packet sequence number, bumped once per packet; also the ChaCha nonce (setupcs) */
	u32int		kex;	/* sequence number at which to re-key; kex sets it to seq + 100000 */
	u32int		chan;	/* channel number: compared for recv, sent for send (dispatch) */

	int		win;	/* flow-control window in bytes */
	int		pkt;	/* packet size used for window accounting; NOTE(review): set outside this view */
	int		eof;	/* set by shutdown */

	Chachastate	cs1;	/* encrypts the 4-byte packet length */
	Chachastate	cs2;	/* encrypts the payload; its first 32 keystream bytes per packet are the Poly1305 key */

	uchar		*r;	/* start of the current message in b */
	uchar		*w;	/* end of the current message in b */
	uchar		b[Overhead + MaxPacket];

	char		*v;	/* version string, hashed into the exchange hash by kex */
	int		pid;	/* process that shutdown posts a "shutdown" note to */
	Rendez;		/* dispatch does rwakeup(&send) when the send window opens */
} Oneway;
78
/* session identifier: the exchange hash of the first key exchange (set by kex) */
int nsid;
uchar sid[256];
/* base64 SHA-256 host key fingerprint (filled in by kex) and the thumbprint file checked against */
char thumb[2*SHA2_256dlen+1], *thumbfile;

/*
 * fd carries the SSH connection (read by recvpkt, written by sendpkt);
 * intr is set by catch when an "interrupt" note arrives.
 * NOTE(review): raw, port, mux, debug and the strings below are
 * presumably initialised from the command line outside this view.
 */
int fd, intr, raw, port, mux, debug;
char *user, *service, *status, *host, *remote, *cmd;

Oneway recv, send;
void dispatch(void);
89 void
90 shutdown(void)
91 {
92         recv.eof = send.eof = 1;
93         if(send.pid > 0)
94                 postnote(PNPROC, send.pid, "shutdown");
95 }
96
97 void
98 catch(void*, char *msg)
99 {
100         if(strcmp(msg, "interrupt") == 0){
101                 intr = 1;
102                 noted(NCONT);
103         }
104         noted(NDFLT);
105 }
106
107 int
108 wasintr(void)
109 {
110         char err[ERRMAX];
111         int r;
112
113         memset(err, 0, sizeof(err));
114         errstr(err, sizeof(err));
115         r = strcmp(err, "interrupted") == 0;
116         errstr(err, sizeof(err));
117         return r;
118 }
119
/*
 * Big-endian 32-bit store and load.  Both expansions are fully
 * parenthesized so they can be used inside larger expressions
 * (e.g. GET4(p) & 0xff) without precedence surprises.
 */
#define PUT4(p, u) ((p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u))
#define GET4(p) ((u32int)(p)[3] | (u32int)(p)[2]<<8 | (u32int)(p)[1]<<16 | (u32int)(p)[0]<<24)
122
123 int
124 vpack(uchar *p, int n, char *fmt, va_list a)
125 {
126         uchar *p0 = p, *e = p+n;
127         u32int u;
128         mpint *m;
129         void *s;
130         int c;
131
132         for(;;){
133                 switch(c = *fmt++){
134                 case '\0':
135                         return p - p0;
136                 case '_':
137                         if(++p > e) goto err;
138                         break;
139                 case '.':
140                         *va_arg(a, void**) = p;
141                         break;
142                 case 'b':
143                         if(p >= e) goto err;
144                         *p++ = va_arg(a, int);
145                         break;
146                 case 'm':
147                         m = va_arg(a, mpint*);
148                         u = (mpsignif(m)+8)/8;
149                         if(p+4 > e) goto err;
150                         PUT4(p, u), p += 4;
151                         if(u > e-p) goto err;
152                         mptober(m, p, u), p += u;
153                         break;
154                 case '[':
155                 case 's':
156                         s = va_arg(a, void*);
157                         u = va_arg(a, int);
158                         if(c == 's'){
159                                 if(p+4 > e) goto err;
160                                 PUT4(p, u), p += 4;
161                         }
162                         if(u > e-p) goto err;
163                         memmove(p, s, u);
164                         p += u;
165                         break;
166                 case 'u':
167                         u = va_arg(a, int);
168                         if(p+4 > e) goto err;
169                         PUT4(p, u), p += 4;
170                         break;
171                 }
172         }
173 err:
174         return -1;
175 }
176
177 int
178 vunpack(uchar *p, int n, char *fmt, va_list a)
179 {
180         uchar *p0 = p, *e = p+n;
181         u32int u;
182         mpint *m;
183         void *s;
184
185         for(;;){
186                 switch(*fmt++){
187                 case '\0':
188                         return p - p0;
189                 case '_':
190                         if(++p > e) goto err;
191                         break;
192                 case '.':
193                         *va_arg(a, void**) = p;
194                         break;
195                 case 'b':
196                         if(p >= e) goto err;
197                         *va_arg(a, int*) = *p++;
198                         break;
199                 case 'm':
200                         if(p+4 > e) goto err;
201                         u = GET4(p), p += 4;
202                         if(u > e-p) goto err;
203                         m = va_arg(a, mpint*);
204                         betomp(p, u, m), p += u;
205                         break;
206                 case 's':
207                         if(p+4 > e) goto err;
208                         u = GET4(p), p += 4;
209                         if(u > e-p) goto err;
210                         *va_arg(a, void**) = p;
211                         *va_arg(a, int*) = u;
212                         p += u;
213                         break;
214                 case '[':
215                         s = va_arg(a, void*);
216                         u = va_arg(a, int);
217                         if(u > e-p) goto err;
218                         memmove(s, p, u);
219                         p += u;
220                         break;
221                 case 'u':
222                         if(p+4 > e) goto err;
223                         u = GET4(p);
224                         *va_arg(a, int*) = u;
225                         p += 4;
226                         break;
227                 }
228         }
229 err:
230         return -1;
231 }
232
233 int
234 pack(uchar *p, int n, char *fmt, ...)
235 {
236         va_list a;
237         va_start(a, fmt);
238         n = vpack(p, n, fmt, a);
239         va_end(a);
240         return n;
241 }
242 int
243 unpack(uchar *p, int n, char *fmt, ...)
244 {
245         va_list a;
246         va_start(a, fmt);
247         n = vunpack(p, n, fmt, a);
248         va_end(a);
249         return n;
250 }
251
252 void
253 setupcs(Oneway *c, uchar otk[32])
254 {
255         uchar iv[8];
256
257         memset(otk, 0, 32);
258         pack(iv, sizeof(iv), "uu", 0, c->seq);
259         chacha_setiv(&c->cs1, iv);
260         chacha_setiv(&c->cs2, iv);
261         chacha_setblock(&c->cs1, 0);
262         chacha_setblock(&c->cs2, 0);
263         chacha_encrypt(otk, 32, &c->cs2);
264 }
265
266 void
267 sendpkt(char *fmt, ...)
268 {
269         static uchar buf[sizeof(send.b)];
270         int n, pad;
271         va_list a;
272
273         va_start(a, fmt);
274         n = vpack(send.b, sizeof(send.b), fmt, a);
275         va_end(a);
276         if(n < 0) {
277 toobig:         sysfatal("sendpkt: message too big");
278                 return;
279         }
280         send.r = send.b;
281         send.w = send.b+n;
282
283 if(debug > 1)
284         fprint(2, "sendpkt: (%d) %.*H\n", send.r[0], (int)(send.w-send.r), send.r);
285
286         if(nsid){
287                 /* undocumented */
288                 pad = ChachaBsize - ((5+n) % ChachaBsize) + 4;
289         } else {
290                 for(pad=4; (5+n+pad) % 8; pad++)
291                         ;
292         }
293         prng(send.w, pad);
294         n = pack(buf, sizeof(buf)-16, "ub[[", 1+n+pad, pad, send.b, n, send.w, pad);
295         if(n < 0) goto toobig;
296         if(nsid){
297                 uchar otk[32];
298
299                 setupcs(&send, otk);
300                 chacha_encrypt(buf, 4, &send.cs1);
301                 chacha_encrypt(buf+4, n-4, &send.cs2);
302                 poly1305(buf, n, otk, sizeof(otk), buf+n, nil);
303                 n += 16;
304         }
305
306         if(write(fd, buf, n) != n)
307                 sysfatal("write: %r");
308
309         send.seq++;
310 }
311
312 int
313 readall(int fd, uchar *data, int len)
314 {
315         int n, tot;
316
317         for(tot = 0; tot < len; tot += n){
318                 n = read(fd, data+tot, len-tot);
319                 if(n <= 0){
320                         if(n < 0 && wasintr()){
321                                 n = 0;
322                                 continue;
323                         } else if(n == 0)
324                                 werrstr("eof");
325                         break;
326                 }
327         }
328         return tot;
329 }
330
331 int
332 recvpkt(void)
333 {
334         uchar otk[32], tag[16];
335         DigestState *ds = nil;
336         int n;
337
338         if(readall(fd, recv.b, 4) != 4)
339                 sysfatal("read1: %r");
340         if(nsid){
341                 setupcs(&recv, otk);
342                 ds = poly1305(recv.b, 4, otk, sizeof(otk), nil, nil);
343                 chacha_encrypt(recv.b, 4, &recv.cs1);
344                 unpack(recv.b, 4, "u", &n);
345                 n += 16;
346         } else {
347                 unpack(recv.b, 4, "u", &n);
348         }
349         if(n < 8 || n > sizeof(recv.b)){
350 badlen:         sysfatal("bad length %d", n);
351         }
352         if(readall(fd, recv.b, n) != n)
353                 sysfatal("read2: %r");
354         if(nsid){
355                 n -= 16;
356                 if(n < 0) goto badlen;
357                 poly1305(recv.b, n, otk, sizeof(otk), tag, ds);
358                 if(tsmemcmp(tag, recv.b+n, 16) != 0)
359                         sysfatal("bad tag");
360                 chacha_encrypt(recv.b, n, &recv.cs2);
361         }
362         n -= recv.b[0]+1;
363         if(n < 1) goto badlen;
364
365         recv.r = recv.b + 1;
366         recv.w = recv.r + n;
367         recv.seq++;
368
369 if(debug > 1)
370         fprint(2, "recvpkt: (%d) %.*H\n", recv.r[0], (int)(recv.w-recv.r), recv.r);
371
372         return recv.r[0];
373 }
374
static char sshrsa[] = "ssh-rsa";	/* the only host key / public key algorithm used here */
376
377 int
378 rsapub2ssh(RSApub *rsa, uchar *data, int len)
379 {
380         return pack(data, len, "smm", sshrsa, sizeof(sshrsa)-1, rsa->ek, rsa->n);
381 }
382
383 RSApub*
384 ssh2rsapub(uchar *data, int len)
385 {
386         RSApub *pub;
387         char *s;
388         int n;
389
390         pub = rsapuballoc();
391         pub->n = mpnew(0);
392         pub->ek = mpnew(0);
393         if(unpack(data, len, "smm", &s, &n, pub->ek, pub->n) < 0
394         || n != sizeof(sshrsa)-1 || memcmp(s, sshrsa, n) != 0){
395                 rsapubfree(pub);
396                 return nil;
397         }
398         return pub;
399 }
400
401 int
402 rsasig2ssh(RSApub *pub, mpint *S, uchar *data, int len)
403 {
404         int l = (mpsignif(pub->n)+7)/8;
405         if(4+7+4+l > len)
406                 return -1;
407         mptober(S, data+4+7+4, l);
408         return pack(data, len, "ss", sshrsa, sizeof(sshrsa)-1, data+4+7+4, l);
409 }
410
411 mpint*
412 ssh2rsasig(uchar *data, int len)
413 {
414         mpint *m;
415         char *s;
416         int n;
417
418         m = mpnew(0);
419         if(unpack(data, len, "sm", &s, &n, m) < 0
420         || n != sizeof(sshrsa)-1 || memcmp(s, sshrsa, n) != 0){
421                 mpfree(m);
422                 return nil;
423         }
424         return m;
425 }
426
427 mpint*
428 pkcs1digest(uchar *data, int len, RSApub *pub)
429 {
430         uchar digest[SHA1dlen], buf[256];
431
432         sha1(data, len, digest, nil);
433         return pkcs1padbuf(buf, asn1encodedigest(sha1, digest, buf, sizeof(buf)), pub->n, 1);
434 }
435
436 int
437 pkcs1verify(uchar *data, int len, RSApub *pub, mpint *S)
438 {
439         mpint *V;
440         int ret;
441
442         V = pkcs1digest(data, len, pub);
443         ret = V != nil;
444         if(ret){
445                 rsaencrypt(pub, S, S);
446                 ret = mpcmp(V, S) == 0;
447                 mpfree(V);
448         }
449         return ret;
450 }
451
452 DigestState*
453 hashstr(void *data, ulong len, DigestState *ds)
454 {
455         uchar l[4];
456         pack(l, 4, "u", len);
457         return sha2_256((uchar*)data, len, nil, sha2_256(l, 4, nil, ds));
458 }
459
460 void
461 kdf(uchar *k, int nk, uchar *h, char x, uchar *out, int len)
462 {
463         uchar digest[SHA2_256dlen], *out0;
464         DigestState *ds;
465         int n;
466
467         ds = hashstr(k, nk, nil);
468         ds = sha2_256(h, sizeof(digest), nil, ds);
469         ds = sha2_256((uchar*)&x, 1, nil, ds);
470         sha2_256(sid, nsid, digest, ds);
471         for(out0=out;;){
472                 n = len;
473                 if(n > sizeof(digest))
474                         n = sizeof(digest);
475                 memmove(out, digest, n);
476                 len -= n;
477                 if(len == 0)
478                         break;
479                 out += n;
480                 ds = hashstr(k, nk, nil);
481                 ds = sha2_256(h, sizeof(digest), nil, ds);
482                 sha2_256(out0, out-out0, digest, ds);
483         }
484 }
485
/*
 * kex runs one curve25519-sha256 key exchange and installs fresh
 * chacha20-poly1305 keys in send and recv.  Every failure is fatal.
 *
 * gotkexinit is non-zero when the server's KEXINIT has already been
 * read (dispatch calls kex(1)).  The message is then taken from
 * recv.r instead of being waited for.
 *
 * The exchange hash H is SHA-256 over, in this order:
 *	our version string, the server's version string,
 *	our KEXINIT payload, the server's KEXINIT payload,
 *	the server host key blob, our ephemeral public key,
 *	the server's ephemeral public key, the shared secret K.
 * The first H becomes the session id (sid).
 */
void
kex(int gotkexinit)
{
	static char kexalgs[] = "curve25519-sha256,curve25519-sha256@libssh.org";
	static char cipheralgs[] = "chacha20-poly1305@openssh.com";
	static char zipalgs[] = "none";
	static char macalgs[] = "hmac-sha1";	/* work around for github.com */
	static char langs[] = "";

	uchar cookie[16], x[32], yc[32], z[32], k[32+1], h[SHA2_256dlen], *ys, *ks, *sig;
	uchar k12[2*ChachaKeylen];
	int i, nk, nys, nks, nsig;
	DigestState *ds;
	mpint *S, *K;
	RSApub *pub;

	ds = hashstr(send.v, strlen(send.v), nil);
	ds = hashstr(recv.v, strlen(recv.v), ds);

	/*
	 * Our KEXINIT offers fixed algorithm lists.  The trailing 0, 0
	 * are the first_kex_packet_follows byte and the reserved uint32
	 * (RFC 4253 §7.1).
	 */
	genrandom(cookie, sizeof(cookie));
	sendpkt("b[ssssssssssbu", MSG_KEXINIT,
		cookie, sizeof(cookie),
		kexalgs, sizeof(kexalgs)-1,
		sshrsa, sizeof(sshrsa)-1,
		cipheralgs, sizeof(cipheralgs)-1,
		cipheralgs, sizeof(cipheralgs)-1,
		macalgs, sizeof(macalgs)-1,
		macalgs, sizeof(macalgs)-1,
		zipalgs, sizeof(zipalgs)-1,
		zipalgs, sizeof(zipalgs)-1,
		langs, sizeof(langs)-1,
		langs, sizeof(langs)-1,
		0,
		0);
	ds = hashstr(send.r, send.w-send.r, ds);

	/* wait for the server's KEXINIT unless dispatch has already read it */
	if(!gotkexinit){
	Next0:	switch(recvpkt()){
		default:
			dispatch();
			goto Next0;
		case MSG_KEXINIT:
			break;
		}
	}
	ds = hashstr(recv.r, recv.w-recv.r, ds);

	if(debug){
		char *tab[] = {
			"kexalgs", "hostalgs",
			"cipher1", "cipher2",
			"mac1", "mac2",
			"zip1", "zip2",
			"lang1", "lang2",
			nil,
		}, **t, *s;
		/* skip the message type byte and the 16-byte cookie */
		uchar *p = recv.r+17;
		int n;
		for(t=tab; *t != nil; t++){
			if(unpack(p, recv.w-p, "s.", &s, &n, &p) < 0)
				break;
			fprint(2, "%s: %.*s\n", *t, utfnlen(s, n), s);
		}
	}

	/* our ephemeral key pair: secret x, public yc */
	curve25519_dh_new(x, yc);
	/* NOTE(review): clears the top bit of the last public-key byte; presumably for interoperability — confirm */
	yc[31] &= ~0x80;

	sendpkt("bs", MSG_ECDH_INIT, yc, sizeof(yc));
Next1:	switch(recvpkt()){
	default:
		dispatch();
		goto Next1;
	case MSG_KEXINIT:
		sysfatal("inception");
	case MSG_ECDH_REPLY:
		/* server host key blob, server ephemeral public key, signature over H */
		if(unpack(recv.r, recv.w-recv.r, "_sss", &ks, &nks, &ys, &nys, &sig, &nsig) < 0)
			sysfatal("bad ECDH_REPLY");
		break;
	}

	if(nys != 32)
		sysfatal("bad server ECDH ephermal public key length");

	ds = hashstr(ks, nks, ds);
	ds = hashstr(yc, 32, ds);
	ds = hashstr(ys, 32, ds);

	/*
	 * The host key is checked against thumbfile only while thumb is
	 * still empty.  thumb is filled in here on the first exchange.
	 * NOTE(review): unless something else clears thumb, a later
	 * re-key accepts whatever host key the server presents without
	 * consulting the thumbprints — confirm this is intended.
	 */
	if(thumb[0] == 0){
		Thumbprint *ok;

		sha2_256(ks, nks, h, nil);
		i = enc64(thumb, sizeof(thumb), h, sizeof(h));
		/* strip base64 '=' padding */
		while(i > 0 && thumb[i-1] == '=')
			i--;
		thumb[i] = '\0';

		if(debug)
			fprint(2, "host fingerprint: %s\n", thumb);

		ok = initThumbprints(thumbfile, nil, "ssh");
		if(ok == nil || !okThumbprint(h, sizeof(h), ok)){
			if(ok != nil) werrstr("unknown host");
			fprint(2, "%s: %r\n", argv0);
			fprint(2, "verify hostkey: %s %.*[\n", sshrsa, nks, ks);
			fprint(2, "add thumbprint after verification:\n");
			fprint(2, "\techo 'ssh sha256=%s server=%s' >> %q\n", thumb, host, thumbfile);
			sysfatal("checking hostkey failed: %r");
		}
		freeThumbprints(ok);
	}

	if((pub = ssh2rsapub(ks, nks)) == nil)
		sysfatal("bad server public key");
	if((S = ssh2rsasig(sig, nsig)) == nil)
		sysfatal("bad server signature");

	if(!curve25519_dh_finish(x, ys, z))
		sysfatal("unlucky shared key");

	/* shared secret K as mpint bytes: (significant bits + 8)/8 leaves room for a leading zero */
	K = betomp(z, 32, nil);
	nk = (mpsignif(K)+8)/8;
	mptober(K, k, nk);
	mpfree(K);

	ds = hashstr(k, nk, ds);
	sha2_256(nil, 0, h, ds);
	/* the server must have signed H with its host key */
	if(!pkcs1verify(h, sizeof(h), pub, S))
		sysfatal("server verification failed");
	mpfree(S);
	rsapubfree(pub);

	sendpkt("b", MSG_NEWKEYS);
Next2:	switch(recvpkt()){
	default:
		dispatch();
		goto Next2;
	case MSG_KEXINIT:
		sysfatal("inception");
	case MSG_NEWKEYS:
		break;
	}

	/* next key exchange */
	recv.kex = recv.seq + 100000;
	send.kex = send.seq + 100000;

	/* the session id is the H of the first exchange only */
	if(nsid == 0)
		memmove(sid, h, nsid = sizeof(h));

	/*
	 * Key letter 'C' keys send and 'D' keys recv.  In each pair the
	 * second half keys cs1 (packet length) and the first half keys
	 * cs2 (payload and Poly1305 key).
	 */
	kdf(k, nk, h, 'C', k12, sizeof(k12));
	setupChachastate(&send.cs1, k12+1*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
	setupChachastate(&send.cs2, k12+0*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);

	kdf(k, nk, h, 'D', k12, sizeof(k12));
	setupChachastate(&recv.cs1, k12+1*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
	setupChachastate(&recv.cs2, k12+0*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
}
644
static char *authnext;	/* method list from the last USERAUTH_FAILURE; nil until the first one (set by authfailure) */
646
647 int
648 authok(char *meth)
649 {
650         int ok = authnext == nil || strstr(authnext, meth) != nil;
651 if(debug)
652         fprint(2, "userauth %s %s\n", meth, ok ? "ok" : "skipped");
653         return ok;
654 }
655
656 int
657 authfailure(char *meth)
658 {
659         char *s;
660         int n, partial;
661
662         if(unpack(recv.r, recv.w-recv.r, "_sb", &s, &n, &partial) < 0)
663                 sysfatal("bad auth failure response");
664         free(authnext);
665         authnext = smprint("%.*s", utfnlen(s, n), s);
666 if(debug)
667         fprint(2, "userauth %s failed: partial=%d, next=%s\n", meth, partial, authnext);
668         return partial != 0 || !authok(meth);
669 }
670
671 int
672 noneauth(void)
673 {
674         static char authmeth[] = "none";
675
676         if(!authok(authmeth))
677                 return -1;
678
679         sendpkt("bsss", MSG_USERAUTH_REQUEST,
680                 user, strlen(user),
681                 service, strlen(service),
682                 authmeth, sizeof(authmeth)-1);
683
684 Next0:  switch(recvpkt()){
685         default:
686                 dispatch();
687                 goto Next0;
688         case MSG_USERAUTH_FAILURE:
689                 werrstr("authentication needed");
690                 authfailure(authmeth);
691                 return -1;
692         case MSG_USERAUTH_SUCCESS:
693                 return 0;
694         }
695 }
696
697 int
698 pubkeyauth(void)
699 {
700         static char authmeth[] = "publickey";
701
702         uchar pk[4096], sig[4096];
703         int npk, nsig;
704
705         int afd, n;
706         char *s;
707         mpint *S;
708         AuthRpc *rpc;
709         RSApub *pub;
710
711         if(!authok(authmeth))
712                 return -1;
713
714         if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
715                 return -1;
716         if((rpc = auth_allocrpc(afd)) == nil){
717                 close(afd);
718                 return -1;
719         }
720
721         s = "proto=rsa service=ssh role=client";
722         if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
723                 auth_freerpc(rpc);
724                 close(afd);
725                 return -1;
726         }
727
728         pub = rsapuballoc();
729         pub->n = mpnew(0);
730         pub->ek = mpnew(0);
731
732         while(auth_rpc(rpc, "read", nil, 0) == ARok){
733                 s = rpc->arg;
734                 if(strtomp(s, &s, 16, pub->n) == nil)
735                         break;
736                 if(*s++ != ' ')
737                         continue;
738                 if(strtomp(s, nil, 16, pub->ek) == nil)
739                         continue;
740                 npk = rsapub2ssh(pub, pk, sizeof(pk));
741
742                 sendpkt("bsssbss", MSG_USERAUTH_REQUEST,
743                         user, strlen(user),
744                         service, strlen(service),
745                         authmeth, sizeof(authmeth)-1,
746                         0,
747                         sshrsa, sizeof(sshrsa)-1,
748                         pk, npk);
749 Next1:          switch(recvpkt()){
750                 default:
751                         dispatch();
752                         goto Next1;
753                 case MSG_USERAUTH_FAILURE:
754                         if(authfailure(authmeth))
755                                 goto Failed;
756                         continue;
757                 case MSG_USERAUTH_SUCCESS:
758                 case MSG_USERAUTH_PK_OK:
759                         break;
760                 }
761
762                 /* sign sid and the userauth request */
763                 n = pack(send.b, sizeof(send.b), "sbsssbss",
764                         sid, nsid,
765                         MSG_USERAUTH_REQUEST,
766                         user, strlen(user),
767                         service, strlen(service),
768                         authmeth, sizeof(authmeth)-1,
769                         1,
770                         sshrsa, sizeof(sshrsa)-1,
771                         pk, npk);
772                 S = pkcs1digest(send.b, n, pub);
773                 n = snprint((char*)send.b, sizeof(send.b), "%B", S);
774                 mpfree(S);
775
776                 if(auth_rpc(rpc, "write", (char*)send.b, n) != ARok)
777                         break;
778                 if(auth_rpc(rpc, "read", nil, 0) != ARok)
779                         break;
780
781                 S = strtomp(rpc->arg, nil, 16, nil);
782                 nsig = rsasig2ssh(pub, S, sig, sizeof(sig));
783                 mpfree(S);
784
785                 /* send final userauth request with the signature */
786                 sendpkt("bsssbsss", MSG_USERAUTH_REQUEST,
787                         user, strlen(user),
788                         service, strlen(service),
789                         authmeth, sizeof(authmeth)-1,
790                         1,
791                         sshrsa, sizeof(sshrsa)-1,
792                         pk, npk,
793                         sig, nsig);
794 Next2:          switch(recvpkt()){
795                 default:
796                         dispatch();
797                         goto Next2;
798                 case MSG_USERAUTH_FAILURE:
799                         if(authfailure(authmeth))
800                                 goto Failed;
801                         continue;
802                 case MSG_USERAUTH_SUCCESS:
803                         break;
804                 }
805                 rsapubfree(pub);
806                 auth_freerpc(rpc);
807                 close(afd);
808                 return 0;
809         }
810 Failed:
811         rsapubfree(pub);
812         auth_freerpc(rpc);
813         close(afd);
814         return -1;      
815 }
816
817 int
818 passauth(void)
819 {
820         static char authmeth[] = "password";
821         UserPasswd *up;
822
823         if(!authok(authmeth))
824                 return -1;
825
826         up = auth_getuserpasswd(auth_getkey, "proto=pass service=ssh user=%q server=%q thumb=%q",
827                 user, host, thumb);
828         if(up == nil)
829                 return -1;
830
831         sendpkt("bsssbs", MSG_USERAUTH_REQUEST,
832                 user, strlen(user),
833                 service, strlen(service),
834                 authmeth, sizeof(authmeth)-1,
835                 0,
836                 up->passwd, strlen(up->passwd));
837
838         memset(up->passwd, 0, strlen(up->passwd));
839         free(up);
840
841 Next0:  switch(recvpkt()){
842         default:
843                 dispatch();
844                 goto Next0;
845         case MSG_USERAUTH_FAILURE:
846                 werrstr("wrong password");
847                 authfailure(authmeth);
848                 return -1;
849         case MSG_USERAUTH_SUCCESS:
850                 return 0;
851         }
852 }
853
854 int
855 kbintauth(void)
856 {
857         static char authmeth[] = "keyboard-interactive";
858         int tries;
859
860         char *name, *inst, *s, *a;
861         int fd, i, n, m;
862         int nquest, echo;
863         uchar *ans, *answ;
864         tries = 0;
865
866         if(!authok(authmeth))
867                 return -1;
868
869 Loop:
870         if(++tries > MaxPwTries)
871                 return -1;
872                 
873         sendpkt("bsssss", MSG_USERAUTH_REQUEST,
874                 user, strlen(user),
875                 service, strlen(service),
876                 authmeth, sizeof(authmeth)-1,
877                 "", 0,
878                 "", 0);
879
880 Next0:  switch(recvpkt()){
881         default:
882                 dispatch();
883                 goto Next0;
884         case MSG_USERAUTH_FAILURE:
885                 werrstr("keyboard-interactive failed");
886                 if(authfailure(authmeth))
887                         return -1;
888                 goto Loop;
889         case MSG_USERAUTH_SUCCESS:
890                 return 0;
891         case MSG_USERAUTH_INFO_REQUEST:
892                 break;
893         }
894 Retry:
895         if((fd = open("/dev/cons", OWRITE)) < 0)
896                 return -1;
897
898         if(unpack(recv.r, recv.w-recv.r, "_ss.", &name, &n, &inst, &m, &recv.r) < 0)
899                 sysfatal("bad info request: name, inst");
900
901         while(n > 0 && strchr("\r\n\t ", name[n-1]) != nil)
902                 n--;
903         while(m > 0 && strchr("\r\n\t ", inst[m-1]) != nil)
904                 m--;
905
906         if(n > 0)
907                 fprint(fd, "%.*s\n", utfnlen(name, n), name);
908         if(m > 0)
909                 fprint(fd, "%.*s\n", utfnlen(inst, m), inst);
910
911         /* lang, nprompt */
912         if(unpack(recv.r, recv.w-recv.r, "su.", &s, &n, &nquest, &recv.r) < 0)
913                 sysfatal("bad info request: lang, #quest");
914
915         ans = answ = nil;
916         for(i = 0; i < nquest; i++){
917                 if(unpack(recv.r, recv.w-recv.r, "sb.", &s, &n, &echo, &recv.r) < 0)
918                         sysfatal("bad info request: question [%d]", i);
919
920                 while(n > 0 && strchr("\r\n\t :", s[n-1]) != nil)
921                         n--;
922                 s[n] = '\0';
923
924                 if((a = readcons(s, nil, !echo)) == nil)
925                         sysfatal("readcons: %r");
926
927                 n = answ - ans;
928                 m = strlen(a)+4;
929                 if((s = realloc(ans, n + m)) == nil)
930                         sysfatal("realloc: %r");
931                 ans = (uchar*)s;
932                 answ = ans+n;
933                 answ += pack(answ, m, "s", a, m-4);
934         }
935
936         sendpkt("bu[", MSG_USERAUTH_INFO_RESPONSE, i, ans, answ - ans);
937         free(ans);
938         close(fd);
939
940 Next1:  switch(recvpkt()){
941         default:
942                 dispatch();
943                 goto Next1;
944         case MSG_USERAUTH_INFO_REQUEST:
945                 goto Retry;
946         case MSG_USERAUTH_FAILURE:
947                 werrstr("keyboard-interactive failed");
948                 if(authfailure(authmeth))
949                         return -1;
950                 goto Loop;
951         case MSG_USERAUTH_SUCCESS:
952                 return 0;
953         }
954 }
955
/*
 * dispatch handles the packet currently held in recv (recv.r up to
 * recv.w) that the caller did not consume itself.
 *
 * Transport-level messages (ignore, global request, disconnect, debug,
 * userauth banner, kexinit) are handled in every mode.  In mux mode
 * every other packet is copied verbatim to fd 1.  Otherwise channel
 * messages addressed to recv.chan are processed here.  Anything
 * unrecognised, malformed, or addressed to another channel ends in
 * sysfatal.
 */
void
dispatch(void)
{
	char *s;
	uchar *p;
	int n, b, c;

	switch(recv.r[0]){
	case MSG_IGNORE:
		return;
	case MSG_GLOBAL_REQUEST:
		if(unpack(recv.r, recv.w-recv.r, "_sb", &s, &n, &b) < 0)
			break;
		if(debug)
			fprint(2, "%s: global request: %.*s\n",
				argv0, utfnlen(s, n), s);
		/* no global requests are supported: refuse when a reply is wanted */
		if(b != 0)
			sendpkt("b", MSG_REQUEST_FAILURE);
		return;
	case MSG_DISCONNECT:
		if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
			break;
		sysfatal("disconnect: (%d) %.*s", c, utfnlen(s, n), s);
		return;
	case MSG_DEBUG:
		/*
		 * Layout: byte type, boolean always_display, string message,
		 * string language tag (RFC 4253 section 11.3).  The flag comes
		 * before the message; the language tag is ignored.
		 */
		if(unpack(recv.r, recv.w-recv.r, "_bs", &c, &s, &n) < 0)
			break;
		if(c != 0 || debug)
			fprint(2, "%s: %.*s\n", argv0, utfnlen(s, n), s);
		return;
	case MSG_USERAUTH_BANNER:
		if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0)
			break;
		if(raw) write(2, s, n);
		return;
	case MSG_KEXINIT:
		/* server-initiated re-keying */
		kex(1);
		return;
	}

	/* mux mode: hand every remaining packet to whoever is on fd 1 */
	if(mux){
		n = recv.w - recv.r;
		if(write(1, recv.r, n) != n)
			sysfatal("write out: %r");
		return;
	}

	switch(recv.r[0]){
	case MSG_CHANNEL_DATA:
		if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
			break;
		if(c != recv.chan)
			break;
		if(write(1, s, n) != n)
			sysfatal("write out: %r");
	Winadjust:
		/* n bytes of our receive window were consumed; top it up when low */
		recv.win -= n;
		if(recv.win < recv.pkt){
			n = WinPackets*recv.pkt;
			recv.win += n;
			sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, send.chan, n);
		}
		return;
	case MSG_CHANNEL_EXTENDED_DATA:
		if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
			break;
		if(c != recv.chan)
			break;
		/* data type 1 is copied to fd 2; other types are only accounted */
		if(b == 1) write(2, s, n);
		goto Winadjust;
	case MSG_CHANNEL_WINDOW_ADJUST:
		if(unpack(recv.r, recv.w-recv.r, "_uu", &c, &n) < 0)
			break;
		if(c != recv.chan)
			break;
		send.win += n;
		/*
		 * The input loop in main pre-charges send.win and sleeps
		 * while it is negative.  Wake it as soon as that condition
		 * clears; waiting for a full packet's worth of window could
		 * stall it forever.
		 */
		if(send.win >= 0)
			rwakeup(&send);
		return;
	case MSG_CHANNEL_REQUEST:
		if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0)
			break;
		if(c != recv.chan)
			break;
		if(n == 11 && memcmp(s, "exit-signal", n) == 0){
			if(unpack(p, recv.w-p, "s", &s, &n) < 0)
				break;
			/* remember the first non-empty signal name as exit status */
			if(n != 0 && status == nil)
				status = smprint("%.*s", utfnlen(s, n), s);
			c = MSG_CHANNEL_SUCCESS;
		} else if(n == 11 && memcmp(s, "exit-status", n) == 0){
			if(unpack(p, recv.w-p, "u", &n) < 0)
				break;
			/* a zero exit code leaves status nil (success) */
			if(n != 0 && status == nil)
				status = smprint("%d", n);
			c = MSG_CHANNEL_SUCCESS;
		} else {
			if(debug)
				fprint(2, "%s: channel request: %.*s\n",
					argv0, utfnlen(s, n), s);
			c = MSG_CHANNEL_FAILURE;
		}
		if(b != 0)
			sendpkt("bu", c, recv.chan);
		return;
	case MSG_CHANNEL_EOF:
		recv.eof = 1;
		/* a zero-length write presumably signals EOF to the reader of fd 1 */
		if(!raw) write(1, "", 0);
		return;
	case MSG_CHANNEL_CLOSE:
		shutdown();
		return;
	}
	sysfatal("got: %.*H", (int)(recv.w - recv.r), recv.r);
}
1071
/*
 * readline reads one line from fd into send.b and returns it as a C
 * string.  It reads a single byte at a time so that nothing past the
 * newline is consumed.
 *
 * The trailing newline and any carriage returns before it are removed.
 * The line is truncated to fit send.b and is always NUL-terminated.  On
 * EOF or a read error, the bytes read so far are returned ("" if none).
 *
 * The result lives in send.b, so a caller that keeps it must copy it.
 */
char*
readline(void)
{
	uchar *p, *e;

	p = send.b;
	e = &send.b[sizeof(send.b)-1];	/* reserve the last byte for the NUL */
	while(p < e){
		if(read(fd, p, 1) != 1 || *p == '\n')
			break;
		p++;
	}
	/* overwrites the '\n', or terminates after EOF/error/truncation */
	*p = '\0';
	/* never step p below send.b */
	while(p > send.b && p[-1] == '\r')
		*--p = '\0';
	return (char*)send.b;
}
1086
/*
 * Terminal description.  term comes from $TERM ("" when unset) and is
 * set by main.  The dimensions are refreshed from the environment by
 * getdim().
 */
static struct {
	char	*term;
	int	cols, lines;		/* size in character cells */
	int	xpixels, ypixels;	/* size in pixels */
	int	gen;			/* last $WINCH value seen; main starts it at -1 */
} tty;
1095
/*
 * getdim refreshes tty from the environment variables WINCH, XPIXELS,
 * YPIXELS, LINES and COLS.
 *
 * If $WINCH is set and equals the generation recorded by the previous
 * call, nothing is re-read and 0 is returned.  Otherwise the dimensions
 * present in the environment are loaded and 1 is returned.  main uses
 * the result to tell a window change from an interrupt.
 */
int
getdim(void)
{
	char *s;
	int g;

	if(s = getenv("WINCH")){
		g = atoi(s);
		free(s);	/* release before the early return below */
		if(tty.gen == g)
			return 0;
		tty.gen = g;
	}
	if(s = getenv("XPIXELS")){
		tty.xpixels = atoi(s);
		free(s);
	}
	if(s = getenv("YPIXELS")){
		tty.ypixels = atoi(s);
		free(s);
	}
	if(s = getenv("LINES")){
		tty.lines = atoi(s);
		free(s);
	}
	if(s = getenv("COLS")){
		tty.cols = atoi(s);
		free(s);
	}
	return 1;
}
1127
/*
 * rawon rebinds fds 0 and 1 to /dev/cons and makes fd 2 a copy of fd 1.
 * It then asks the console for raw input and window-change notes, and
 * loads the current terminal dimensions with getdim().
 */
void
rawon(void)
{
	int cfd;

	/* open returns the lowest free fd, so each open refills the slot just closed */
	close(0);
	cfd = open("/dev/cons", OREAD);
	if(cfd != 0)
		sysfatal("open: %r");

	close(1);
	cfd = open("/dev/cons", OWRITE);
	if(cfd != 1)
		sysfatal("open: %r");
	dup(1, 2);

	/*
	 * The consctl fd is intentionally never closed.  Presumably raw
	 * mode lasts only while it stays open (see cons(3)).
	 */
	cfd = open("/dev/consctl", OWRITE);
	if(cfd >= 0){
		write(cfd, "rawon", 5);
		write(cfd, "winchon", 7);	/* vt(1): interrupt note on window change */
	}

	getdim();
}
1146
1147 #pragma    varargck    type  "k"   char*
1148
/*
 * kfmt implements the %k verb.  It prints its char* argument as one
 * single-quoted shell word, writing each embedded ' as '\''.  main uses
 * it to join command arguments, presumably for a POSIX shell on the
 * server.
 *
 * The argument string is only read, never modified.  Returns 0 on
 * success, -1 if output fails.
 */
int
kfmt(Fmt *f)
{
	char *s, *p;

	s = va_arg(f->args, char*);
	if(fmtstrcpy(f, "'") < 0)
		return -1;
	while((p = strchr(s, '\'')) != nil){
		/* %.*s precision counts runes, hence utfnlen as elsewhere in this file */
		if(fmtprint(f, "%.*s'\\''", utfnlen(s, p - s), s) < 0)
			return -1;
		s = p+1;
	}
	if(fmtstrcpy(f, s) < 0 || fmtstrcpy(f, "'") < 0)
		return -1;
	return 0;
}
1167
/*
 * usage prints the command synopsis and exits with status "usage".
 *
 * All options must come before the host.  main takes the first argument
 * left after option parsing as the host and joins the rest into the
 * remote command.  With -h the host is given as the option's argument
 * instead.
 */
void
usage(void)
{
	fprint(2, "usage: %s [-dRrX] [-t thumbfile] [-T tries] [-u user] [-W remote!port] [-h] [user@]host [cmd args...]\n", argv0);
	exits("usage");
}
1174
/*
 * main parses options and dials host.  It then performs the version
 * exchange, key exchange and user authentication, and opens a channel:
 * a "direct-tcpip" forward for -W, or a "session" (with a pty in raw
 * mode) running a shell, a subsystem (cmd starting with '#') or a
 * command.
 *
 * With -X no channel is opened and packets are relayed raw over
 * fds 0/1.  Finally it forks: the parent reads and dispatches packets
 * from the server, and the child forwards standard input.
 */
void
main(int argc, char *argv[])
{
	static QLock sl;
	int b, n, c;
	char *s;

	quotefmtinstall();
	fmtinstall('B', mpfmt);
	fmtinstall('H', encodefmt);
	fmtinstall('[', encodefmt);
	fmtinstall('k', kfmt);

	tty.gen = -1;
	tty.term = getenv("TERM");
	if(tty.term == nil)
		tty.term = "";
	raw = *tty.term != 0;

	ARGBEGIN {
	case 'd':
		debug++;
		break;
	case 'W':
		remote = EARGF(usage());
		s = strrchr(remote, '!');
		if(s == nil)
			s = strrchr(remote, ':');
		if(s == nil)
			usage();
		*s++ = 0;
		port = atoi(s);
		raw = 0;
		break;
	case 'R':
		raw = 0;
		break;
	case 'r':
		raw = 2; /* bloody */
		break;
	case 'u':
		user = EARGF(usage());
		break;
	case 'h':
		host = EARGF(usage());
		break;
	case 't':
		thumbfile = EARGF(usage());
		break;
	case 'T':
		MaxPwTries = strtol(EARGF(usage()), &s, 0);
		if(*s != 0) usage();
		break;
	case 'X':
		mux = 1;
		raw = 0;
		break;
	default:
		usage();
	} ARGEND;

	if(host == nil){
		if(argc == 0)
			usage();
		host = *argv++;
	}

	if(user == nil){
		s = strchr(host, '@');
		if(s != nil){
			*s++ = '\0';
			user = host;
			host = s;
		}
	}

	/* join the remaining arguments into one command line, quoting all but the first */
	for(cmd = nil; *argv != nil; argv++){
		if(cmd == nil){
			cmd = strdup(*argv);
			if(raw == 1)
				raw = 0;
		}else{
			s = smprint("%s %k", cmd, *argv);
			free(cmd);
			cmd = s;
		}
	}

	if(remote != nil && cmd != nil)
		usage();

	if((fd = dial(netmkaddr(host, nil, "ssh"), nil, nil, nil)) < 0)
		sysfatal("dial: %r");

	send.v = "SSH-2.0-(9)";
	fprint(fd, "%s\r\n", send.v);

	/*
	 * RFC 4253 section 4.2: the server may send other lines before its
	 * identification string, and those never begin with "SSH-".  Skip
	 * them.  An empty line (also what readline returns at EOF) stops
	 * the search so that a closed connection cannot loop forever.
	 */
	for(;;){
		recv.v = readline();
		if(*recv.v == '\0' || strncmp("SSH-", recv.v, 4) == 0)
			break;
		if(debug)
			fprint(2, "server banner: %s\n", recv.v);
	}
	if(debug)
		fprint(2, "server version: %s\n", recv.v);
	/* protoversion "1.99" identifies a server that also speaks 2.0 (RFC 4253 section 5.1) */
	if(strncmp("SSH-2.0-", recv.v, 8) != 0 && strncmp("SSH-1.99-", recv.v, 9) != 0)
		sysfatal("bad server version: %s", recv.v);
	recv.v = strdup(recv.v);	/* readline's buffer is send.b, which is reused */

	send.l = recv.l = &sl;

	if(user == nil)
		user = getuser();
	if(thumbfile == nil){
		s = getenv("home");
		thumbfile = smprint("%s/lib/sshthumbs", s);
		free(s);
	}

	kex(0);

	sendpkt("bs", MSG_SERVICE_REQUEST, "ssh-userauth", 12);
Next0:	switch(recvpkt()){
	default:
		dispatch();
		goto Next0;
	case MSG_SERVICE_ACCEPT:
		break;
	}

	service = "ssh-connection";
	if(noneauth() < 0 && pubkeyauth() < 0 && passauth() < 0 && kbintauth() < 0)
		sysfatal("auth: %r");

	recv.pkt = send.pkt = MaxPacket;
	recv.win = send.win = WinPackets*recv.pkt;
	/* send.chan and send.win are replaced by the open confirmation below */
	recv.chan = send.chan = 0;

	if(mux)
		goto Mux;

	/* open hailing frequencies */
	if(remote != nil){
		NetConnInfo *nci = getnetconninfo(nil, fd);
		if(nci == nil)
			sysfatal("can't get netconninfo: %r");
		sendpkt("bsuuususu", MSG_CHANNEL_OPEN,
			"direct-tcpip", 12,
			recv.chan,
			recv.win,
			recv.pkt,
			remote, strlen(remote),
			port,
			nci->laddr, strlen(nci->laddr),
			atoi(nci->lserv));
		freenetconninfo(nci);	/* releases the strings as well as the struct */
	} else {
		sendpkt("bsuuu", MSG_CHANNEL_OPEN,
			"session", 7,
			recv.chan,
			recv.win,
			recv.pkt);
	}
Next1:	switch(recvpkt()){
	default:
		dispatch();
		goto Next1;
	case MSG_CHANNEL_OPEN_FAILURE:
		if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
			n = strlen(s = "???");
		sysfatal("channel open failure: (%d) %.*s", b, utfnlen(s, n), s);
	case MSG_CHANNEL_OPEN_CONFIRMATION:
		break;
	}

	if(unpack(recv.r, recv.w-recv.r, "_uuuu", &recv.chan, &send.chan, &send.win, &send.pkt) < 0)
		sysfatal("bad channel open confirmation");
	/* buf in the input loop below holds at most MaxPacket bytes */
	if(send.pkt <= 0 || send.pkt > MaxPacket)
		send.pkt = MaxPacket;

	if(remote != nil)
		goto Mux;

	if(raw) {
		rawon();
		sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
			send.chan,
			"pty-req", 7,
			0,
			tty.term, strlen(tty.term),
			tty.cols,
			tty.lines,
			tty.xpixels,
			tty.ypixels,
			"", 0);
	}
	if(cmd == nil){
		sendpkt("busb", MSG_CHANNEL_REQUEST,
			send.chan,
			"shell", 5,
			0);
	} else if(*cmd == '#') {
		sendpkt("busbs", MSG_CHANNEL_REQUEST,
			send.chan,
			"subsystem", 9,
			0,
			cmd+1, strlen(cmd)-1);
	} else {
		sendpkt("busbs", MSG_CHANNEL_REQUEST,
			send.chan,
			"exec", 4,
			0,
			cmd, strlen(cmd));
	}

Mux:
	notify(catch);
	atexit(shutdown);

	recv.pid = getpid();
	n = rfork(RFPROC|RFMEM);
	if(n < 0)
		sysfatal("fork: %r");

	/* parent reads and dispatches packets */
	if(n > 0) {
		send.pid = n;
		while(recv.eof == 0){
			recvpkt();
			qlock(&sl);
			dispatch();
			/* re-key when either direction's sequence number reaches its kex mark */
			if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0)
				kex(0);
			qunlock(&sl);
		}
		exits(status);
	}

	/* child reads input and sends packets */
	qlock(&sl);
	for(;;){
		static uchar buf[MaxPacket];
		qunlock(&sl);
		n = read(0, buf, send.pkt);
		qlock(&sl);
		if(send.eof)
			break;
		if(n < 0 && wasintr())
			intr = 1;
		if(intr){
			if(!raw) break;
			if(getdim()){
				sendpkt("busbuuuu", MSG_CHANNEL_REQUEST,
					send.chan,
					"window-change", 13,
					0,
					tty.cols,
					tty.lines,
					tty.xpixels,
					tty.ypixels);
			}else{
				sendpkt("busbs", MSG_CHANNEL_REQUEST,
					send.chan,
					"signal", 6,
					0,
					"INT", 3);
			}
			intr = 0;
			continue;
		}
		if(n <= 0)
			break;
		if(mux){
			sendpkt("[", buf, n);
			continue;
		}
		/* charge the window up front; dispatch wakes us once it is non-negative */
		send.win -= n;
		while(send.win < 0)
			rsleep(&send);
		sendpkt("bus", MSG_CHANNEL_DATA,
			send.chan,
			buf, n);
	}
	if(send.eof++ == 0 && !mux)
		sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, send.chan);
	else if(recv.pid > 0 && mux)
		postnote(PNPROC, recv.pid, "shutdown");
	qunlock(&sl);

	exits(nil);
}