]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ssh.c
sshfs: usage
[plan9front.git] / sys / src / cmd / ssh.c
1 #include <u.h>
2 #include <libc.h>
3 #include <mp.h>
4 #include <libsec.h>
5 #include <auth.h>
6 #include <authsrv.h>
7
/*
 * SSH message numbers used by this client.  Every value is written
 * out explicitly so the table can be checked at a glance.
 */
enum {
	/* transport layer, generic */
	MSG_DISCONNECT			= 1,
	MSG_IGNORE			= 2,
	MSG_UNIMPLEMENTED		= 3,
	MSG_DEBUG			= 4,
	MSG_SERVICE_REQUEST		= 5,
	MSG_SERVICE_ACCEPT		= 6,

	/* algorithm negotiation */
	MSG_KEXINIT			= 20,
	MSG_NEWKEYS			= 21,

	/* elliptic curve key exchange */
	MSG_ECDH_INIT			= 30,
	MSG_ECDH_REPLY			= 31,

	/* user authentication, generic */
	MSG_USERAUTH_REQUEST		= 50,
	MSG_USERAUTH_FAILURE		= 51,
	MSG_USERAUTH_SUCCESS		= 52,
	MSG_USERAUTH_BANNER		= 53,

	/* user authentication, method specific: 60 means different things per method */
	MSG_USERAUTH_PK_OK		= 60,	/* publickey */
	MSG_USERAUTH_INFO_REQUEST	= 60,	/* keyboard-interactive */
	MSG_USERAUTH_INFO_RESPONSE	= 61,

	/* connection protocol, global requests */
	MSG_GLOBAL_REQUEST		= 80,
	MSG_REQUEST_SUCCESS		= 81,
	MSG_REQUEST_FAILURE		= 82,

	/* connection protocol, channels */
	MSG_CHANNEL_OPEN		= 90,
	MSG_CHANNEL_OPEN_CONFIRMATION	= 91,
	MSG_CHANNEL_OPEN_FAILURE	= 92,
	MSG_CHANNEL_WINDOW_ADJUST	= 93,
	MSG_CHANNEL_DATA		= 94,
	MSG_CHANNEL_EXTENDED_DATA	= 95,
	MSG_CHANNEL_EOF			= 96,
	MSG_CHANNEL_CLOSE		= 97,
	MSG_CHANNEL_REQUEST		= 98,
	MSG_CHANNEL_SUCCESS		= 99,
	MSG_CHANNEL_FAILURE		= 100,
};
47
48
/* buffer and flow-control sizing */
enum {
	Overhead	= 256,		/* headroom in each packet buffer for headers, e.g. MSG_CHANNEL_DATA */
	MaxPacket	= 1<<15,	/* packet buffer size excluding Overhead */
	WinPackets	= 8,		/* window refill = WinPackets packets; (1<<15) * 8 = 256K */
};

/* how many keyboard-interactive rounds are attempted before giving up */
int MaxPwTries = 3;
56
57 typedef struct
58 {
59         u32int          seq;
60         u32int          kex;
61         u32int          chan;
62
63         int             win;
64         int             pkt;
65         int             eof;
66
67         Chachastate     cs1;
68         Chachastate     cs2;
69
70         uchar           *r;
71         uchar           *w;
72         uchar           b[Overhead + MaxPacket];
73
74         char            *v;
75         int             pid;
76         Rendez;
77 } Oneway;
78
79 int nsid;
80 uchar sid[256];
81 char thumb[2*SHA2_256dlen+1], *thumbfile;
82
83 int fd, intr, raw, debug;
84 char *user, *service, *status, *host, *cmd;
85
86 Oneway recv, send;
87 void dispatch(void);
88
89 void
90 shutdown(void)
91 {
92         recv.eof = send.eof = 1;
93         if(send.pid > 0)
94                 postnote(PNPROC, send.pid, "shutdown");
95 }
96
97 void
98 catch(void*, char *msg)
99 {
100         if(strcmp(msg, "interrupt") == 0){
101                 intr = 1;
102                 noted(NCONT);
103         }
104         noted(NDFLT);
105 }
106
107 int
108 wasintr(void)
109 {
110         char err[ERRMAX];
111         int r;
112
113         if(intr)
114                 return 1;
115         memset(err, 0, sizeof(err));
116         errstr(err, sizeof(err));
117         r = strcmp(err, "interrupted") == 0;
118         errstr(err, sizeof(err));
119         return r;
120 }
121
/*
 * Store / load a 32-bit big-endian value at p.  Both expansions are
 * fully parenthesized so they act as a single expression in any context
 * (e.g. GET4(p) == x or GET4(p) & mask).  Arguments are evaluated more
 * than once, so they must not have side effects.
 */
#define PUT4(p, u) ((p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u))
#define GET4(p) ((u32int)(p)[3] | (u32int)(p)[2]<<8 | (u32int)(p)[1]<<16 | (u32int)(p)[0]<<24)
124
125 int
126 vpack(uchar *p, int n, char *fmt, va_list a)
127 {
128         uchar *p0 = p, *e = p+n;
129         u32int u;
130         mpint *m;
131         void *s;
132         int c;
133
134         for(;;){
135                 switch(c = *fmt++){
136                 case '\0':
137                         return p - p0;
138                 case '_':
139                         if(++p > e) goto err;
140                         break;
141                 case '.':
142                         *va_arg(a, void**) = p;
143                         break;
144                 case 'b':
145                         if(p >= e) goto err;
146                         *p++ = va_arg(a, int);
147                         break;
148                 case 'm':
149                         m = va_arg(a, mpint*);
150                         u = (mpsignif(m)+8)/8;
151                         if(p+4 > e) goto err;
152                         PUT4(p, u), p += 4;
153                         if(u > e-p) goto err;
154                         mptober(m, p, u), p += u;
155                         break;
156                 case '[':
157                 case 's':
158                         s = va_arg(a, void*);
159                         u = va_arg(a, int);
160                         if(c == 's'){
161                                 if(p+4 > e) goto err;
162                                 PUT4(p, u), p += 4;
163                         }
164                         if(u > e-p) goto err;
165                         memmove(p, s, u);
166                         p += u;
167                         break;
168                 case 'u':
169                         u = va_arg(a, int);
170                         if(p+4 > e) goto err;
171                         PUT4(p, u), p += 4;
172                         break;
173                 }
174         }
175 err:
176         return -1;
177 }
178
179 int
180 vunpack(uchar *p, int n, char *fmt, va_list a)
181 {
182         uchar *p0 = p, *e = p+n;
183         u32int u;
184         mpint *m;
185         void *s;
186
187         for(;;){
188                 switch(*fmt++){
189                 case '\0':
190                         return p - p0;
191                 case '_':
192                         if(++p > e) goto err;
193                         break;
194                 case '.':
195                         *va_arg(a, void**) = p;
196                         break;
197                 case 'b':
198                         if(p >= e) goto err;
199                         *va_arg(a, int*) = *p++;
200                         break;
201                 case 'm':
202                         if(p+4 > e) goto err;
203                         u = GET4(p), p += 4;
204                         if(u > e-p) goto err;
205                         m = va_arg(a, mpint*);
206                         betomp(p, u, m), p += u;
207                         break;
208                 case 's':
209                         if(p+4 > e) goto err;
210                         u = GET4(p), p += 4;
211                         if(u > e-p) goto err;
212                         *va_arg(a, void**) = p;
213                         *va_arg(a, int*) = u;
214                         p += u;
215                         break;
216                 case '[':
217                         s = va_arg(a, void*);
218                         u = va_arg(a, int);
219                         if(u > e-p) goto err;
220                         memmove(s, p, u);
221                         p += u;
222                         break;
223                 case 'u':
224                         if(p+4 > e) goto err;
225                         u = GET4(p);
226                         *va_arg(a, int*) = u;
227                         p += 4;
228                         break;
229                 }
230         }
231 err:
232         return -1;
233 }
234
235 int
236 pack(uchar *p, int n, char *fmt, ...)
237 {
238         va_list a;
239         va_start(a, fmt);
240         n = vpack(p, n, fmt, a);
241         va_end(a);
242         return n;
243 }
244 int
245 unpack(uchar *p, int n, char *fmt, ...)
246 {
247         va_list a;
248         va_start(a, fmt);
249         n = vunpack(p, n, fmt, a);
250         va_end(a);
251         return n;
252 }
253
254 void
255 setupcs(Oneway *c, uchar otk[32])
256 {
257         uchar iv[8];
258
259         memset(otk, 0, 32);
260         pack(iv, sizeof(iv), "uu", 0, c->seq);
261         chacha_setiv(&c->cs1, iv);
262         chacha_setiv(&c->cs2, iv);
263         chacha_setblock(&c->cs1, 0);
264         chacha_setblock(&c->cs2, 0);
265         chacha_encrypt(otk, 32, &c->cs2);
266 }
267
268 void
269 sendpkt(char *fmt, ...)
270 {
271         static uchar buf[sizeof(send.b)];
272         int n, pad;
273         va_list a;
274
275         va_start(a, fmt);
276         n = vpack(send.b, sizeof(send.b), fmt, a);
277         va_end(a);
278         if(n < 0) {
279 toobig:         sysfatal("sendpkt: message too big");
280                 return;
281         }
282         send.r = send.b;
283         send.w = send.b+n;
284
285 if(debug > 1)
286         fprint(2, "sendpkt: (%d) %.*H\n", send.r[0], (int)(send.w-send.r), send.r);
287
288         if(nsid){
289                 /* undocumented */
290                 pad = ChachaBsize - ((5+n) % ChachaBsize) + 4;
291         } else {
292                 for(pad=4; (5+n+pad) % 8; pad++)
293                         ;
294         }
295         prng(send.w, pad);
296         n = pack(buf, sizeof(buf)-16, "ub[[", 1+n+pad, pad, send.b, n, send.w, pad);
297         if(n < 0) goto toobig;
298         if(nsid){
299                 uchar otk[32];
300
301                 setupcs(&send, otk);
302                 chacha_encrypt(buf, 4, &send.cs1);
303                 chacha_encrypt(buf+4, n-4, &send.cs2);
304                 poly1305(buf, n, otk, sizeof(otk), buf+n, nil);
305                 n += 16;
306         }
307
308         if(write(fd, buf, n) != n)
309                 sysfatal("write: %r");
310
311         send.seq++;
312 }
313
314 int
315 readall(int fd, uchar *data, int len)
316 {
317         int n, tot;
318
319         for(tot = 0; tot < len; tot += n){
320                 n = read(fd, data+tot, len-tot);
321                 if(n <= 0){
322                         if(n < 0 && wasintr()){
323                                 n = 0;
324                                 continue;
325                         } else if(n == 0)
326                                 werrstr("eof");
327                         break;
328                 }
329         }
330         return tot;
331 }
332
333 int
334 recvpkt(void)
335 {
336         uchar otk[32], tag[16];
337         DigestState *ds = nil;
338         int n;
339
340         if(readall(fd, recv.b, 4) != 4)
341                 sysfatal("read1: %r");
342         if(nsid){
343                 setupcs(&recv, otk);
344                 ds = poly1305(recv.b, 4, otk, sizeof(otk), nil, nil);
345                 chacha_encrypt(recv.b, 4, &recv.cs1);
346                 unpack(recv.b, 4, "u", &n);
347                 n += 16;
348         } else {
349                 unpack(recv.b, 4, "u", &n);
350         }
351         if(n < 8 || n > sizeof(recv.b)){
352 badlen:         sysfatal("bad length %d", n);
353         }
354         if(readall(fd, recv.b, n) != n)
355                 sysfatal("read2: %r");
356         if(nsid){
357                 n -= 16;
358                 if(n < 0) goto badlen;
359                 poly1305(recv.b, n, otk, sizeof(otk), tag, ds);
360                 if(tsmemcmp(tag, recv.b+n, 16) != 0)
361                         sysfatal("bad tag");
362                 chacha_encrypt(recv.b, n, &recv.cs2);
363         }
364         n -= recv.b[0]+1;
365         if(n < 1) goto badlen;
366
367         recv.r = recv.b + 1;
368         recv.w = recv.r + n;
369         recv.seq++;
370
371 if(debug > 1)
372         fprint(2, "recvpkt: (%d) %.*H\n", recv.r[0], (int)(recv.w-recv.r), recv.r);
373
374         return recv.r[0];
375 }
376
377 static char sshrsa[] = "ssh-rsa";
378
379 int
380 rsapub2ssh(RSApub *rsa, uchar *data, int len)
381 {
382         return pack(data, len, "smm", sshrsa, sizeof(sshrsa)-1, rsa->ek, rsa->n);
383 }
384
385 RSApub*
386 ssh2rsapub(uchar *data, int len)
387 {
388         RSApub *pub;
389         char *s;
390         int n;
391
392         pub = rsapuballoc();
393         pub->n = mpnew(0);
394         pub->ek = mpnew(0);
395         if(unpack(data, len, "smm", &s, &n, pub->ek, pub->n) < 0
396         || n != sizeof(sshrsa)-1 || memcmp(s, sshrsa, n) != 0){
397                 rsapubfree(pub);
398                 return nil;
399         }
400         return pub;
401 }
402
403 int
404 rsasig2ssh(RSApub *pub, mpint *S, uchar *data, int len)
405 {
406         int l = (mpsignif(pub->n)+7)/8;
407         if(4+7+4+l > len)
408                 return -1;
409         mptober(S, data+4+7+4, l);
410         return pack(data, len, "ss", sshrsa, sizeof(sshrsa)-1, data+4+7+4, l);
411 }
412
413 mpint*
414 ssh2rsasig(uchar *data, int len)
415 {
416         mpint *m;
417         char *s;
418         int n;
419
420         m = mpnew(0);
421         if(unpack(data, len, "sm", &s, &n, m) < 0
422         || n != sizeof(sshrsa)-1 || memcmp(s, sshrsa, n) != 0){
423                 mpfree(m);
424                 return nil;
425         }
426         return m;
427 }
428
429 mpint*
430 pkcs1digest(uchar *data, int len, RSApub *pub)
431 {
432         uchar digest[SHA1dlen], buf[256];
433
434         sha1(data, len, digest, nil);
435         return pkcs1padbuf(buf, asn1encodedigest(sha1, digest, buf, sizeof(buf)), pub->n, 1);
436 }
437
438 int
439 pkcs1verify(uchar *data, int len, RSApub *pub, mpint *S)
440 {
441         mpint *V;
442         int ret;
443
444         V = pkcs1digest(data, len, pub);
445         ret = V != nil;
446         if(ret){
447                 rsaencrypt(pub, S, S);
448                 ret = mpcmp(V, S) == 0;
449                 mpfree(V);
450         }
451         return ret;
452 }
453
454 DigestState*
455 hashstr(void *data, ulong len, DigestState *ds)
456 {
457         uchar l[4];
458         pack(l, 4, "u", len);
459         return sha2_256((uchar*)data, len, nil, sha2_256(l, 4, nil, ds));
460 }
461
462 void
463 kdf(uchar *k, int nk, uchar *h, char x, uchar *out, int len)
464 {
465         uchar digest[SHA2_256dlen], *out0;
466         DigestState *ds;
467         int n;
468
469         ds = hashstr(k, nk, nil);
470         ds = sha2_256(h, sizeof(digest), nil, ds);
471         ds = sha2_256((uchar*)&x, 1, nil, ds);
472         sha2_256(sid, nsid, digest, ds);
473         for(out0=out;;){
474                 n = len;
475                 if(n > sizeof(digest))
476                         n = sizeof(digest);
477                 memmove(out, digest, n);
478                 len -= n;
479                 if(len == 0)
480                         break;
481                 out += n;
482                 ds = hashstr(k, nk, nil);
483                 ds = sha2_256(h, sizeof(digest), nil, ds);
484                 sha2_256(out0, out-out0, digest, ds);
485         }
486 }
487
488 void
489 kex(int gotkexinit)
490 {
491         static char kexalgs[] = "curve25519-sha256,curve25519-sha256@libssh.org";
492         static char cipheralgs[] = "chacha20-poly1305@openssh.com";
493         static char zipalgs[] = "none";
494         static char macalgs[] = "";
495         static char langs[] = "";
496
497         uchar cookie[16], x[32], yc[32], z[32], k[32+1], h[SHA2_256dlen], *ys, *ks, *sig;
498         uchar k12[2*ChachaKeylen];
499         int i, nk, nys, nks, nsig;
500         DigestState *ds;
501         mpint *S, *K;
502         RSApub *pub;
503
504         ds = hashstr(send.v, strlen(send.v), nil);      
505         ds = hashstr(recv.v, strlen(recv.v), ds);
506
507         genrandom(cookie, sizeof(cookie));
508         sendpkt("b[ssssssssssbu", MSG_KEXINIT,
509                 cookie, sizeof(cookie),
510                 kexalgs, sizeof(kexalgs)-1,
511                 sshrsa, sizeof(sshrsa)-1,
512                 cipheralgs, sizeof(cipheralgs)-1,
513                 cipheralgs, sizeof(cipheralgs)-1,
514                 macalgs, sizeof(macalgs)-1,
515                 macalgs, sizeof(macalgs)-1,
516                 zipalgs, sizeof(zipalgs)-1,
517                 zipalgs, sizeof(zipalgs)-1,
518                 langs, sizeof(langs)-1,
519                 langs, sizeof(langs)-1,
520                 0,
521                 0);
522         ds = hashstr(send.r, send.w-send.r, ds);
523
524         if(!gotkexinit){
525         Next0:  switch(recvpkt()){
526                 default:
527                         dispatch();
528                         goto Next0;
529                 case MSG_KEXINIT:
530                         break;
531                 }
532         }
533         ds = hashstr(recv.r, recv.w-recv.r, ds);
534
535         if(debug){
536                 char *tab[] = {
537                         "kexalgs", "hostalgs",
538                         "cipher1", "cipher2",
539                         "mac1", "mac2",
540                         "zip1", "zip2",
541                         "lang1", "lang2",
542                         nil,
543                 }, **t, *s;
544                 uchar *p = recv.r+17;
545                 int n;
546                 for(t=tab; *t != nil; t++){
547                         if(unpack(p, recv.w-p, "s.", &s, &n, &p) < 0)
548                                 break;
549                         fprint(2, "%s: %.*s\n", *t, n, s);
550                 }
551         }
552
553         curve25519_dh_new(x, yc);
554         yc[31] &= ~0x80;
555
556         sendpkt("bs", MSG_ECDH_INIT, yc, sizeof(yc));
557 Next1:  switch(recvpkt()){
558         default:
559                 dispatch();
560                 goto Next1;
561         case MSG_KEXINIT:
562                 sysfatal("inception");
563         case MSG_ECDH_REPLY:
564                 if(unpack(recv.r, recv.w-recv.r, "_sss", &ks, &nks, &ys, &nys, &sig, &nsig) < 0)
565                         sysfatal("bad ECDH_REPLY");
566                 break;
567         }
568
569         if(nys != 32)
570                 sysfatal("bad server ECDH ephermal public key length");
571
572         ds = hashstr(ks, nks, ds);
573         ds = hashstr(yc, 32, ds);
574         ds = hashstr(ys, 32, ds);
575
576         if(thumb[0] == 0){
577                 Thumbprint *ok;
578
579                 sha2_256(ks, nks, h, nil);
580                 i = enc64(thumb, sizeof(thumb), h, sizeof(h));
581                 while(i > 0 && thumb[i-1] == '=')
582                         i--;
583                 thumb[i] = '\0';
584
585                 if(debug)
586                         fprint(2, "host fingerprint: %s\n", thumb);
587
588                 ok = initThumbprints(thumbfile, nil, "ssh");
589                 if(ok == nil || !okThumbprint(h, sizeof(h), ok)){
590                         if(ok != nil) werrstr("unknown host");
591                         fprint(2, "%s: %r\n", argv0);
592                         fprint(2, "verify hostkey: %s %.*[\n", sshrsa, nks, ks);
593                         fprint(2, "add thumbprint after verification:\n");
594                         fprint(2, "\techo 'ssh sha256=%s server=%s' >> %q\n", thumb, host, thumbfile);
595                         sysfatal("checking hostkey failed: %r");
596                 }
597                 freeThumbprints(ok);
598         }
599
600         if((pub = ssh2rsapub(ks, nks)) == nil)
601                 sysfatal("bad server public key");
602         if((S = ssh2rsasig(sig, nsig)) == nil)
603                 sysfatal("bad server signature");
604
605         curve25519_dh_finish(x, ys, z);
606
607         K = betomp(z, 32, nil);
608         nk = (mpsignif(K)+8)/8;
609         mptober(K, k, nk);
610         mpfree(K);
611
612         ds = hashstr(k, nk, ds);
613         sha2_256(nil, 0, h, ds);
614         if(!pkcs1verify(h, sizeof(h), pub, S))
615                 sysfatal("server verification failed");
616         mpfree(S);
617         rsapubfree(pub);
618
619         sendpkt("b", MSG_NEWKEYS);
620 Next2:  switch(recvpkt()){
621         default:
622                 dispatch();
623                 goto Next2;
624         case MSG_KEXINIT:
625                 sysfatal("inception");
626         case MSG_NEWKEYS:
627                 break;
628         }
629
630         /* next key exchange */
631         recv.kex = recv.seq + 100000;
632         send.kex = send.seq + 100000;
633
634         if(nsid == 0)
635                 memmove(sid, h, nsid = sizeof(h));
636
637         kdf(k, nk, h, 'C', k12, sizeof(k12));
638         setupChachastate(&send.cs1, k12+1*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
639         setupChachastate(&send.cs2, k12+0*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
640
641         kdf(k, nk, h, 'D', k12, sizeof(k12));
642         setupChachastate(&recv.cs1, k12+1*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
643         setupChachastate(&recv.cs2, k12+0*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
644 }
645
646 static char *authnext;
647
648 int
649 authok(char *meth)
650 {
651         int ok = authnext == nil || strstr(authnext, meth) != nil;
652 if(debug)
653         fprint(2, "userauth %s %s\n", meth, ok ? "ok" : "skipped");
654         return ok;
655 }
656
657 int
658 authfailure(char *meth)
659 {
660         char *s;
661         int n, partial;
662
663         if(unpack(recv.r, recv.w-recv.r, "_sb", &s, &n, &partial) < 0)
664                 sysfatal("bad auth failure response");
665         free(authnext);
666         authnext = smprint("%.*s", n, s);
667 if(debug)
668         fprint(2, "userauth %s failed: partial=%d, next=%s\n", meth, partial, authnext);
669         return partial != 0 || !authok(meth);
670 }
671
672 int
673 noneauth(void)
674 {
675         static char authmeth[] = "none";
676
677         if(!authok(authmeth))
678                 return -1;
679
680         sendpkt("bsss", MSG_USERAUTH_REQUEST,
681                 user, strlen(user),
682                 service, strlen(service),
683                 authmeth, sizeof(authmeth)-1);
684
685 Next0:  switch(recvpkt()){
686         default:
687                 dispatch();
688                 goto Next0;
689         case MSG_USERAUTH_FAILURE:
690                 werrstr("authentication needed");
691                 authfailure(authmeth);
692                 return -1;
693         case MSG_USERAUTH_SUCCESS:
694                 return 0;
695         }
696 }
697
698 int
699 pubkeyauth(void)
700 {
701         static char authmeth[] = "publickey";
702
703         uchar pk[4096], sig[4096];
704         int npk, nsig;
705
706         int afd, n;
707         char *s;
708         mpint *S;
709         AuthRpc *rpc;
710         RSApub *pub;
711
712         if(!authok(authmeth))
713                 return -1;
714
715         if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
716                 return -1;
717         if((rpc = auth_allocrpc(afd)) == nil){
718                 close(afd);
719                 return -1;
720         }
721
722         s = "proto=rsa service=ssh role=client";
723         if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
724                 auth_freerpc(rpc);
725                 close(afd);
726                 return -1;
727         }
728
729         pub = rsapuballoc();
730         pub->n = mpnew(0);
731         pub->ek = mpnew(0);
732
733         while(auth_rpc(rpc, "read", nil, 0) == ARok){
734                 s = rpc->arg;
735                 if(strtomp(s, &s, 16, pub->n) == nil)
736                         break;
737                 if(*s++ != ' ')
738                         continue;
739                 if(strtomp(s, nil, 16, pub->ek) == nil)
740                         continue;
741                 npk = rsapub2ssh(pub, pk, sizeof(pk));
742
743                 sendpkt("bsssbss", MSG_USERAUTH_REQUEST,
744                         user, strlen(user),
745                         service, strlen(service),
746                         authmeth, sizeof(authmeth)-1,
747                         0,
748                         sshrsa, sizeof(sshrsa)-1,
749                         pk, npk);
750 Next1:          switch(recvpkt()){
751                 default:
752                         dispatch();
753                         goto Next1;
754                 case MSG_USERAUTH_FAILURE:
755                         if(authfailure(authmeth))
756                                 goto Failed;
757                         continue;
758                 case MSG_USERAUTH_SUCCESS:
759                 case MSG_USERAUTH_PK_OK:
760                         break;
761                 }
762
763                 /* sign sid and the userauth request */
764                 n = pack(send.b, sizeof(send.b), "sbsssbss",
765                         sid, nsid,
766                         MSG_USERAUTH_REQUEST,
767                         user, strlen(user),
768                         service, strlen(service),
769                         authmeth, sizeof(authmeth)-1,
770                         1,
771                         sshrsa, sizeof(sshrsa)-1,
772                         pk, npk);
773                 S = pkcs1digest(send.b, n, pub);
774                 n = snprint((char*)send.b, sizeof(send.b), "%B", S);
775                 mpfree(S);
776
777                 if(auth_rpc(rpc, "write", (char*)send.b, n) != ARok)
778                         break;
779                 if(auth_rpc(rpc, "read", nil, 0) != ARok)
780                         break;
781
782                 S = strtomp(rpc->arg, nil, 16, nil);
783                 nsig = rsasig2ssh(pub, S, sig, sizeof(sig));
784                 mpfree(S);
785
786                 /* send final userauth request with the signature */
787                 sendpkt("bsssbsss", MSG_USERAUTH_REQUEST,
788                         user, strlen(user),
789                         service, strlen(service),
790                         authmeth, sizeof(authmeth)-1,
791                         1,
792                         sshrsa, sizeof(sshrsa)-1,
793                         pk, npk,
794                         sig, nsig);
795 Next2:          switch(recvpkt()){
796                 default:
797                         dispatch();
798                         goto Next2;
799                 case MSG_USERAUTH_FAILURE:
800                         if(authfailure(authmeth))
801                                 goto Failed;
802                         continue;
803                 case MSG_USERAUTH_SUCCESS:
804                         break;
805                 }
806                 rsapubfree(pub);
807                 auth_freerpc(rpc);
808                 close(afd);
809                 return 0;
810         }
811 Failed:
812         rsapubfree(pub);
813         auth_freerpc(rpc);
814         close(afd);
815         return -1;      
816 }
817
818 int
819 passauth(void)
820 {
821         static char authmeth[] = "password";
822         UserPasswd *up;
823
824         if(!authok(authmeth))
825                 return -1;
826
827         up = auth_getuserpasswd(auth_getkey, "proto=pass service=ssh user=%q server=%q thumb=%q",
828                 user, host, thumb);
829         if(up == nil)
830                 return -1;
831
832         sendpkt("bsssbs", MSG_USERAUTH_REQUEST,
833                 user, strlen(user),
834                 service, strlen(service),
835                 authmeth, sizeof(authmeth)-1,
836                 0,
837                 up->passwd, strlen(up->passwd));
838
839         memset(up->passwd, 0, strlen(up->passwd));
840         free(up);
841
842 Next0:  switch(recvpkt()){
843         default:
844                 dispatch();
845                 goto Next0;
846         case MSG_USERAUTH_FAILURE:
847                 werrstr("wrong password");
848                 authfailure(authmeth);
849                 return -1;
850         case MSG_USERAUTH_SUCCESS:
851                 return 0;
852         }
853 }
854
855 int
856 kbintauth(void)
857 {
858         static char authmeth[] = "keyboard-interactive";
859         int tries;
860
861         char *name, *inst, *s, *a;
862         int fd, i, n, m;
863         int nquest, echo;
864         uchar *ans, *answ;
865         tries = 0;
866
867         if(!authok(authmeth))
868                 return -1;
869
870 Loop:
871         if(++tries > MaxPwTries)
872                 return -1;
873                 
874         sendpkt("bsssss", MSG_USERAUTH_REQUEST,
875                 user, strlen(user),
876                 service, strlen(service),
877                 authmeth, sizeof(authmeth)-1,
878                 "", 0,
879                 "", 0);
880
881 Next0:  switch(recvpkt()){
882         default:
883                 dispatch();
884                 goto Next0;
885         case MSG_USERAUTH_FAILURE:
886                 werrstr("keyboard-interactive failed");
887                 if(authfailure(authmeth))
888                         return -1;
889                 goto Loop;
890         case MSG_USERAUTH_SUCCESS:
891                 return 0;
892         case MSG_USERAUTH_INFO_REQUEST:
893                 break;
894         }
895 Retry:
896         if((fd = open("/dev/cons", OWRITE)) < 0)
897                 return -1;
898
899         if(unpack(recv.r, recv.w-recv.r, "_ss.", &name, &n, &inst, &m, &recv.r) < 0)
900                 sysfatal("bad info request: name, inst");
901
902         while(n > 0 && strchr("\r\n\t ", name[n-1]) != nil)
903                 n--;
904         while(m > 0 && strchr("\r\n\t ", inst[m-1]) != nil)
905                 m--;
906
907         if(n > 0)
908                 fprint(fd, "%.*s\n", n, name);
909         if(m > 0)
910                 fprint(fd, "%.*s\n", m, inst);
911
912         /* lang, nprompt */
913         if(unpack(recv.r, recv.w-recv.r, "su.", &s, &n, &nquest, &recv.r) < 0)
914                 sysfatal("bad info request: lang, #quest");
915
916         ans = answ = nil;
917         for(i = 0; i < nquest; i++){
918                 if(unpack(recv.r, recv.w-recv.r, "sb.", &s, &n, &echo, &recv.r) < 0)
919                         sysfatal("bad info request: question [%d]", i);
920
921                 while(n > 0 && strchr("\r\n\t :", s[n-1]) != nil)
922                         n--;
923                 s[n] = '\0';
924
925                 if((a = readcons(s, nil, !echo)) == nil)
926                         sysfatal("readcons: %r");
927
928                 n = answ - ans;
929                 m = strlen(a)+4;
930                 if((s = realloc(ans, n + m)) == nil)
931                         sysfatal("realloc: %r");
932                 ans = (uchar*)s;
933                 answ = ans+n;
934                 answ += pack(answ, m, "s", a, m-4);
935         }
936
937         sendpkt("bu[", MSG_USERAUTH_INFO_RESPONSE, i, ans, answ - ans);
938         free(ans);
939         close(fd);
940
941 Next1:  switch(recvpkt()){
942         default:
943                 dispatch();
944                 goto Next1;
945         case MSG_USERAUTH_INFO_REQUEST:
946                 goto Retry;
947         case MSG_USERAUTH_FAILURE:
948                 werrstr("keyboard-interactive failed");
949                 if(authfailure(authmeth))
950                         return -1;
951                 goto Loop;
952         case MSG_USERAUTH_SUCCESS:
953                 return 0;
954         }
955 }
956
957 void
958 dispatch(void)
959 {
960         char *s;
961         uchar *p;
962         int n, b, c;
963
964         switch(recv.r[0]){
965         case MSG_IGNORE:
966         case MSG_GLOBAL_REQUEST:
967                 return;
968         case MSG_DISCONNECT:
969                 if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
970                         break;
971                 sysfatal("disconnect: (%d) %.*s", c, n, s);
972                 return;
973         case MSG_DEBUG:
974                 if(unpack(recv.r, recv.w-recv.r, "__sb", &s, &n, &c) < 0)
975                         break;
976                 if(c != 0 || debug) fprint(2, "%s: %.*s\n", argv0, n, s);
977                 return;
978         case MSG_USERAUTH_BANNER:
979                 if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0)
980                         break;
981                 if(raw) write(2, s, n);
982                 return;
983         case MSG_CHANNEL_DATA:
984                 if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
985                         break;
986                 if(c != recv.chan)
987                         break;
988                 if(write(1, s, n) != n)
989                         sysfatal("write out: %r");
990         Winadjust:
991                 recv.win -= n;
992                 if(recv.win < recv.pkt){
993                         n = WinPackets*recv.pkt;
994                         recv.win += n;
995                         sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, send.chan, n);
996                 }
997                 return;
998         case MSG_CHANNEL_EXTENDED_DATA:
999                 if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
1000                         break;
1001                 if(c != recv.chan)
1002                         break;
1003                 if(b == 1) write(2, s, n);
1004                 goto Winadjust;
1005         case MSG_CHANNEL_WINDOW_ADJUST:
1006                 if(unpack(recv.r, recv.w-recv.r, "_uu", &c, &n) < 0)
1007                         break;
1008                 if(c != recv.chan)
1009                         break;
1010                 send.win += n;
1011                 if(send.win >= send.pkt)
1012                         rwakeup(&send);
1013                 return;
1014         case MSG_CHANNEL_REQUEST:
1015                 if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0)
1016                         break;
1017                 if(c != recv.chan)
1018                         break;
1019                 if(n == 11 && memcmp(s, "exit-signal", n) == 0){
1020                         if(unpack(p, recv.w-p, "s", &s, &n) < 0)
1021                                 break;
1022                         if(n != 0 && status == nil)
1023                                 status = smprint("%.*s", n, s);
1024                 } else if(n == 11 && memcmp(s, "exit-status", n) == 0){
1025                         if(unpack(p, recv.w-p, "u", &n) < 0)
1026                                 break;
1027                         if(n != 0 && status == nil)
1028                                 status = smprint("%d", n);
1029                 } else if(debug) {
1030                         fprint(2, "%s: channel request: %.*s\n", argv0, n, s);
1031                 }
1032                 return;
1033         case MSG_CHANNEL_EOF:
1034                 recv.eof = 1;
1035                 if(!raw) write(1, "", 0);
1036                 return;
1037         case MSG_CHANNEL_CLOSE:
1038                 shutdown();
1039                 return;
1040         case MSG_KEXINIT:
1041                 kex(1);
1042                 return;
1043         }
1044         sysfatal("got: %.*H", (int)(recv.w - recv.r), recv.r);
1045 }
1046
1047 char*
1048 readline(void)
1049 {
1050         uchar *p;
1051
1052         for(p = send.b; p < &send.b[sizeof(send.b)-1]; p++){
1053                 *p = '\0';
1054                 if(read(fd, p, 1) != 1 || *p == '\n')
1055                         break;
1056         }
1057         while(p >= send.b && (*p == '\n' || *p == '\r'))
1058                 *p-- = '\0';
1059         return (char*)send.b;
1060 }
1061
/*
 * Terminal description sent to the server in the pty-req and
 * window-change channel requests (field order there follows
 * RFC 4254 §6.2: columns, rows, width pixels, height pixels).
 */
static struct {
	char	*term;			/* value of $TERM; raw mode needs it non-empty */
	int	xpixels, ypixels;	/* window width and height in pixels */
	int	lines, cols;		/* window height and width in character cells */
} tty;
1069
1070 void
1071 getdim(void)
1072 {
1073         char *s;
1074
1075         if(s = getenv("XPIXELS")){
1076                 tty.xpixels = atoi(s);
1077                 free(s);
1078         }
1079         if(s = getenv("YPIXELS")){
1080                 tty.ypixels = atoi(s);
1081                 free(s);
1082         }
1083         if(s = getenv("LINES")){
1084                 tty.lines = atoi(s);
1085                 free(s);
1086         }
1087         if(s = getenv("COLS")){
1088                 tty.cols = atoi(s);
1089                 free(s);
1090         }
1091 }
1092
1093 void
1094 rawon(void)
1095 {
1096         int ctl;
1097
1098         close(0);
1099         if(open("/dev/cons", OREAD) != 0)
1100                 sysfatal("open: %r");
1101         close(1);
1102         if(open("/dev/cons", OWRITE) != 1)
1103                 sysfatal("open: %r");
1104         dup(1, 2);
1105         if((ctl = open("/dev/consctl", OWRITE)) >= 0){
1106                 write(ctl, "rawon", 5);
1107                 write(ctl, "winchon", 7);       /* vt(1): interrupt note on window change */
1108         }
1109         getdim();
1110 }
1111
1112 #pragma    varargck    type  "k"   char*
1113
1114 kfmt(Fmt *f)
1115 {
1116         char *s, *p;
1117         int n;
1118
1119         s = va_arg(f->args, char*);
1120         n = fmtstrcpy(f, "'");
1121         while((p = strchr(s, '\'')) != nil){
1122                 *p = '\0';
1123                 n += fmtstrcpy(f, s);
1124                 *p = '\'';
1125                 n += fmtstrcpy(f, "'\\''");
1126                 s = p+1;
1127         }
1128         n += fmtstrcpy(f, s);
1129         n += fmtstrcpy(f, "'");
1130         return n;
1131 }
1132
1133 void
1134 usage(void)
1135 {
1136         fprint(2, "usage: %s [-dR] [-t thumbfile] [-T tries] [-u user] [-h] [user@]host [cmd args...]\n", argv0);
1137         exits("usage");
1138 }
1139
1140 void
1141 main(int argc, char *argv[])
1142 {
1143         static QLock sl;
1144         int b, n, c;
1145         char *s;
1146
1147         quotefmtinstall();
1148         fmtinstall('B', mpfmt);
1149         fmtinstall('H', encodefmt);
1150         fmtinstall('[', encodefmt);
1151         fmtinstall('k', kfmt);
1152
1153         tty.term = getenv("TERM");
1154         raw = tty.term != nil && *tty.term != 0;
1155
1156         ARGBEGIN {
1157         case 'd':
1158                 debug++;
1159                 break;
1160         case 'R':
1161                 raw = 0;
1162                 break;
1163         case 'u':
1164                 user = EARGF(usage());
1165                 break;
1166         case 'h':
1167                 host = EARGF(usage());
1168                 break;
1169         case 't':
1170                 thumbfile = EARGF(usage());
1171                 break;
1172         case 'T':
1173                 MaxPwTries = strtol(EARGF(usage()), &s, 0);
1174                 if(*s != 0) usage();
1175                 break;
1176         } ARGEND;
1177
1178         if(host == nil){
1179                 if(argc == 0)
1180                         usage();
1181                 host = *argv++;
1182         }
1183
1184         if(user == nil){
1185                 s = strchr(host, '@');
1186                 if(s != nil){
1187                         *s++ = '\0';
1188                         user = host;
1189                         host = s;
1190                 }
1191         }
1192
1193         for(cmd = nil; *argv != nil; argv++){
1194                 if(cmd == nil){
1195                         cmd = strdup(*argv);
1196                         raw = 0;
1197                 }else {
1198                         s = smprint("%s %k", cmd, *argv);
1199                         free(cmd);
1200                         cmd = s;
1201                 }
1202         }
1203
1204         if((fd = dial(netmkaddr(host, nil, "ssh"), nil, nil, nil)) < 0)
1205                 sysfatal("dial: %r");
1206
1207         send.v = "SSH-2.0-(9)";
1208         fprint(fd, "%s\r\n", send.v);
1209         recv.v = readline();
1210         if(debug)
1211                 fprint(2, "server verison: %s\n", recv.v);
1212         if(strncmp("SSH-2.0-", recv.v, 8) != 0)
1213                 sysfatal("bad server version: %s", recv.v);
1214         recv.v = strdup(recv.v);
1215
1216         send.l = recv.l = &sl;
1217
1218         if(user == nil)
1219                 user = getuser();
1220         if(thumbfile == nil)
1221                 thumbfile = smprint("%s/lib/sshthumbs", getenv("home"));
1222
1223         kex(0);
1224
1225         sendpkt("bs", MSG_SERVICE_REQUEST, "ssh-userauth", 12);
1226 Next0:  switch(recvpkt()){
1227         default:
1228                 dispatch();
1229                 goto Next0;
1230         case MSG_SERVICE_ACCEPT:
1231                 break;
1232         }
1233
1234         service = "ssh-connection";
1235         if(noneauth() < 0 && pubkeyauth() < 0 && passauth() < 0 && kbintauth() < 0)
1236                 sysfatal("auth: %r");
1237
1238         recv.pkt = MaxPacket;
1239         recv.win = WinPackets*recv.pkt;
1240         recv.chan = 0;
1241
1242         /* open hailing frequencies */
1243         sendpkt("bsuuu", MSG_CHANNEL_OPEN,
1244                 "session", 7,
1245                 recv.chan,
1246                 recv.win,
1247                 recv.pkt);
1248
1249 Next1:  switch(recvpkt()){
1250         default:
1251                 dispatch();
1252                 goto Next1;
1253         case MSG_CHANNEL_OPEN_FAILURE:
1254                 if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
1255                         n = strlen(s = "???");
1256                 sysfatal("channel open failure: (%d) %.*s", b, n, s);
1257         case MSG_CHANNEL_OPEN_CONFIRMATION:
1258                 break;
1259         }
1260
1261         if(unpack(recv.r, recv.w-recv.r, "_uuuu", &recv.chan, &send.chan, &send.win, &send.pkt) < 0)
1262                 sysfatal("bad channel open confirmation");
1263         if(send.pkt <= 0 || send.pkt > MaxPacket)
1264                 send.pkt = MaxPacket;
1265
1266         notify(catch);
1267         atexit(shutdown);
1268
1269         recv.pid = getpid();
1270         n = rfork(RFPROC|RFMEM);
1271         if(n < 0)
1272                 sysfatal("fork: %r");
1273
1274         /* parent reads and dispatches packets */
1275         if(n > 0) {
1276                 send.pid = n;
1277                 while((send.eof|recv.eof) == 0){
1278                         recvpkt();
1279                         qlock(&sl);                                     
1280                         dispatch();
1281                         if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0)
1282                                 kex(0);
1283                         qunlock(&sl);
1284                 }
1285                 exits(status);
1286         }
1287
1288         /* child reads input and sends packets */
1289         qlock(&sl);
1290         if(raw) {
1291                 rawon();
1292                 sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
1293                         send.chan,
1294                         "pty-req", 7,
1295                         0,
1296                         tty.term, strlen(tty.term),
1297                         tty.cols,
1298                         tty.lines,
1299                         tty.xpixels,
1300                         tty.ypixels,
1301                         "", 0);
1302         }
1303         if(cmd == nil){
1304                 sendpkt("busb", MSG_CHANNEL_REQUEST,
1305                         send.chan,
1306                         "shell", 5,
1307                         0);
1308         } else if(*cmd == '#') {
1309                 sendpkt("busbs", MSG_CHANNEL_REQUEST,
1310                         send.chan,
1311                         "subsystem", 9,
1312                         0,
1313                         cmd+1, strlen(cmd)-1);
1314         } else {
1315                 sendpkt("busbs", MSG_CHANNEL_REQUEST,
1316                         send.chan,
1317                         "exec", 4,
1318                         0,
1319                         cmd, strlen(cmd));
1320         }
1321         for(;;){
1322                 static uchar buf[MaxPacket];
1323                 qunlock(&sl);
1324                 n = read(0, buf, send.pkt);
1325                 qlock(&sl);
1326                 if(send.eof)
1327                         break;
1328                 if(n < 0 && wasintr()){
1329                         if(!raw) break;
1330                         if(intr){
1331                                 getdim();
1332                                 sendpkt("busbuuuu", MSG_CHANNEL_REQUEST,
1333                                         send.chan,
1334                                         "window-change", 13,
1335                                         0,
1336                                         tty.cols,
1337                                         tty.lines,
1338                                         tty.xpixels,
1339                                         tty.ypixels);
1340                                 sendpkt("busbs", MSG_CHANNEL_REQUEST,
1341                                         send.chan,
1342                                         "signal", 6,
1343                                         0,
1344                                         "INT", 3);
1345                                 intr = 0;
1346                         }
1347                         continue;
1348                 }
1349                 if(n <= 0)
1350                         break;
1351                 send.win -= n;
1352                 while(send.win < 0)
1353                         rsleep(&send);
1354                 sendpkt("bus", MSG_CHANNEL_DATA,
1355                         send.chan,
1356                         buf, n);
1357         }
1358         if(send.eof++ == 0)
1359                 sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, send.chan);
1360         qunlock(&sl);
1361
1362         exits(nil);
1363 }