]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ssh.c
ssh: actually handle flow control and channel id's
[plan9front.git] / sys / src / cmd / ssh.c
1 #include <u.h>
2 #include <libc.h>
3 #include <mp.h>
4 #include <libsec.h>
5 #include <auth.h>
6 #include <authsrv.h>
7
8 enum {
9         MSG_DISCONNECT = 1,
10         MSG_IGNORE,
11         MSG_UNIMPLEMENTED,
12         MSG_DEBUG,
13         MSG_SERVICE_REQUEST,
14         MSG_SERVICE_ACCEPT,
15
16         MSG_KEXINIT = 20,
17         MSG_NEWKEYS,
18
19         MSG_ECDH_INIT = 30,
20         MSG_ECDH_REPLY,
21
22         MSG_USERAUTH_REQUEST = 50,
23         MSG_USERAUTH_FAILURE,
24         MSG_USERAUTH_SUCCESS,
25         MSG_USERAUTH_BANNER,
26
27         MSG_USERAUTH_PK_OK = 60,
28         MSG_USERAUTH_INFO_REQUEST = 60,
29         MSG_USERAUTH_INFO_RESPONSE = 61,
30
31         MSG_GLOBAL_REQUEST = 80,
32         MSG_REQUEST_SUCCESS,
33         MSG_REQUEST_FAILURE,
34
35         MSG_CHANNEL_OPEN = 90,
36         MSG_CHANNEL_OPEN_CONFIRMATION,
37         MSG_CHANNEL_OPEN_FAILURE,
38         MSG_CHANNEL_WINDOW_ADJUST,
39         MSG_CHANNEL_DATA,
40         MSG_CHANNEL_EXTENDED_DATA,
41         MSG_CHANNEL_EOF,
42         MSG_CHANNEL_CLOSE,
43         MSG_CHANNEL_REQUEST,
44         MSG_CHANNEL_SUCCESS,
45         MSG_CHANNEL_FAILURE,
46 };
47
48
49 enum {
50         Overhead = 256,         // enougth for MSG_CHANNEL_DATA header
51         MaxPacket = 1<<15,
52         WinPackets = 8,         // (1<<15) * 8 = 256K
53 };
54
55 typedef struct
56 {
57         u32int          seq;
58         u32int          kex;
59         u32int          chan;
60
61         int             win;
62         int             pkt;
63         int             eof;
64
65         Chachastate     cs1;
66         Chachastate     cs2;
67
68         uchar           *r;
69         uchar           *w;
70         uchar           b[Overhead + MaxPacket];
71
72         char            *v;
73         int             pid;
74         Rendez;
75 } Oneway;
76
77 int nsid;
78 uchar sid[256];
79 char thumb[2*SHA2_256dlen+1];
80
81 int fd, intr, raw, debug;
82 char *user, *service, *status, *host, *cmd;
83
84 Oneway recv, send;
85 void dispatch(void);
86
87 void
88 shutdown(void)
89 {
90         recv.eof = send.eof = 1;
91         if(send.pid > 0)
92                 postnote(PNPROC, send.pid, "shutdown");
93 }
94
95 void
96 catch(void*, char *msg)
97 {
98         if(strstr(msg, "interrupt") != nil){
99                 intr = 1;
100                 noted(NCONT);
101         }
102         noted(NDFLT);
103 }
104
105 int
106 wasintr(void)
107 {
108         char err[ERRMAX];
109         int r;
110
111         if(intr)
112                 return 1;
113         memset(err, 0, sizeof(err));
114         errstr(err, sizeof(err));
115         r = strstr(err, "interrupt") != nil;
116         errstr(err, sizeof(err));
117         return r;
118 }
119
120 #define PUT4(p, u) (p)[0] = (u)>>24, (p)[1] = (u)>>16, (p)[2] = (u)>>8, (p)[3] = (u)
121 #define GET4(p) (u32int)(p)[3] | (u32int)(p)[2]<<8 | (u32int)(p)[1]<<16 | (u32int)(p)[0]<<24
122
123 int
124 vpack(uchar *p, int n, char *fmt, va_list a)
125 {
126         uchar *p0 = p, *e = p+n;
127         u32int u;
128         mpint *m;
129         void *s;
130         int c;
131
132         for(;;){
133                 switch(c = *fmt++){
134                 case '\0':
135                         return p - p0;
136                 case '_':
137                         if(++p > e) goto err;
138                         break;
139                 case '.':
140                         *va_arg(a, void**) = p;
141                         break;
142                 case 'b':
143                         if(p >= e) goto err;
144                         *p++ = va_arg(a, int);
145                         break;
146                 case 'm':
147                         m = va_arg(a, mpint*);
148                         u = (mpsignif(m)+8)/8;
149                         if(p+4 > e) goto err;
150                         PUT4(p, u), p += 4;
151                         if(u > e-p) goto err;
152                         mptober(m, p, u), p += u;
153                         break;
154                 case '[':
155                 case 's':
156                         s = va_arg(a, void*);
157                         u = va_arg(a, int);
158                         if(c == 's'){
159                                 if(p+4 > e) goto err;
160                                 PUT4(p, u), p += 4;
161                         }
162                         if(u > e-p) goto err;
163                         memmove(p, s, u);
164                         p += u;
165                         break;
166                 case 'u':
167                         u = va_arg(a, int);
168                         if(p+4 > e) goto err;
169                         PUT4(p, u), p += 4;
170                         break;
171                 }
172         }
173 err:
174         return -1;
175 }
176
177 int
178 vunpack(uchar *p, int n, char *fmt, va_list a)
179 {
180         uchar *p0 = p, *e = p+n;
181         u32int u;
182         mpint *m;
183         void *s;
184
185         for(;;){
186                 switch(*fmt++){
187                 case '\0':
188                         return p - p0;
189                 case '_':
190                         if(++p > e) goto err;
191                         break;
192                 case '.':
193                         *va_arg(a, void**) = p;
194                         break;
195                 case 'b':
196                         if(p >= e) goto err;
197                         *va_arg(a, int*) = *p++;
198                         break;
199                 case 'm':
200                         if(p+4 > e) goto err;
201                         u = GET4(p), p += 4;
202                         if(u > e-p) goto err;
203                         m = va_arg(a, mpint*);
204                         betomp(p, u, m), p += u;
205                         break;
206                 case 's':
207                         if(p+4 > e) goto err;
208                         u = GET4(p), p += 4;
209                         if(u > e-p) goto err;
210                         *va_arg(a, void**) = p;
211                         *va_arg(a, int*) = u;
212                         p += u;
213                         break;
214                 case '[':
215                         s = va_arg(a, void*);
216                         u = va_arg(a, int);
217                         if(u > e-p) goto err;
218                         memmove(s, p, u);
219                         p += u;
220                         break;
221                 case 'u':
222                         if(p+4 > e) goto err;
223                         u = GET4(p);
224                         *va_arg(a, int*) = u;
225                         p += 4;
226                         break;
227                 }
228         }
229 err:
230         return -1;
231 }
232
233 int
234 pack(uchar *p, int n, char *fmt, ...)
235 {
236         va_list a;
237         va_start(a, fmt);
238         n = vpack(p, n, fmt, a);
239         va_end(a);
240         return n;
241 }
242 int
243 unpack(uchar *p, int n, char *fmt, ...)
244 {
245         va_list a;
246         va_start(a, fmt);
247         n = vunpack(p, n, fmt, a);
248         va_end(a);
249         return n;
250 }
251
252 void
253 setupcs(Oneway *c, uchar otk[32])
254 {
255         uchar iv[8];
256
257         memset(otk, 0, 32);
258         pack(iv, sizeof(iv), "uu", 0, c->seq);
259         chacha_setiv(&c->cs1, iv);
260         chacha_setiv(&c->cs2, iv);
261         chacha_setblock(&c->cs1, 0);
262         chacha_setblock(&c->cs2, 0);
263         chacha_encrypt(otk, 32, &c->cs2);
264 }
265
266 void
267 sendpkt(char *fmt, ...)
268 {
269         static uchar buf[sizeof(send.b)];
270         int n, pad;
271         va_list a;
272
273         va_start(a, fmt);
274         n = vpack(send.b, sizeof(send.b), fmt, a);
275         va_end(a);
276         if(n < 0) {
277 toobig:         sysfatal("sendpkt: message too big");
278                 return;
279         }
280         send.r = send.b;
281         send.w = send.b+n;
282
283 if(debug > 1)
284         fprint(2, "sendpkt: (%d) %.*H\n", send.r[0], (int)(send.w-send.r), send.r);
285
286         if(nsid){
287                 /* undocumented */
288                 pad = ChachaBsize - ((5+n) % ChachaBsize) + 4;
289         } else {
290                 for(pad=4; (5+n+pad) % 8; pad++)
291                         ;
292         }
293         prng(send.w, pad);
294         n = pack(buf, sizeof(buf)-16, "ub[[", 1+n+pad, pad, send.b, n, send.w, pad);
295         if(n < 0) goto toobig;
296         if(nsid){
297                 uchar otk[32];
298
299                 setupcs(&send, otk);
300                 chacha_encrypt(buf, 4, &send.cs1);
301                 chacha_encrypt(buf+4, n-4, &send.cs2);
302                 poly1305(buf, n, otk, sizeof(otk), buf+n, nil);
303                 n += 16;
304         }
305
306         if(write(fd, buf, n) != n)
307                 sysfatal("write: %r");
308
309         send.seq++;
310 }
311
312 int
313 readall(int fd, uchar *data, int len)
314 {
315         int n, tot;
316
317         for(tot = 0; tot < len; tot += n){
318                 n = read(fd, data+tot, len-tot);
319                 if(n <= 0){
320                         if(n < 0 && wasintr()){
321                                 n = 0;
322                                 continue;
323                         } else if(n == 0)
324                                 werrstr("eof");
325                         break;
326                 }
327         }
328         return tot;
329 }
330
331 int
332 recvpkt(void)
333 {
334         uchar otk[32], tag[16];
335         DigestState *ds = nil;
336         int n;
337
338         if(readall(fd, recv.b, 4) != 4)
339                 sysfatal("read1: %r");
340         if(nsid){
341                 setupcs(&recv, otk);
342                 ds = poly1305(recv.b, 4, otk, sizeof(otk), nil, nil);
343                 chacha_encrypt(recv.b, 4, &recv.cs1);
344                 unpack(recv.b, 4, "u", &n);
345                 n += 16;
346         } else {
347                 unpack(recv.b, 4, "u", &n);
348         }
349         if(n < 8 || n > sizeof(recv.b)){
350 badlen:         sysfatal("bad length %d", n);
351         }
352         if(readall(fd, recv.b, n) != n)
353                 sysfatal("read2: %r");
354         if(nsid){
355                 n -= 16;
356                 if(n < 0) goto badlen;
357                 poly1305(recv.b, n, otk, sizeof(otk), tag, ds);
358                 if(tsmemcmp(tag, recv.b+n, 16) != 0)
359                         sysfatal("bad tag");
360                 chacha_encrypt(recv.b, n, &recv.cs2);
361         }
362         n -= recv.b[0]+1;
363         if(n < 1) goto badlen;
364
365         recv.r = recv.b + 1;
366         recv.w = recv.r + n;
367         recv.seq++;
368
369 if(debug > 1)
370         fprint(2, "recvpkt: (%d) %.*H\n", recv.r[0], (int)(recv.w-recv.r), recv.r);
371
372         return recv.r[0];
373 }
374
375 static char sshrsa[] = "ssh-rsa";
376
377 int
378 rsapub2ssh(RSApub *rsa, uchar *data, int len)
379 {
380         return pack(data, len, "smm", sshrsa, sizeof(sshrsa)-1, rsa->ek, rsa->n);
381 }
382
383 RSApub*
384 ssh2rsapub(uchar *data, int len)
385 {
386         RSApub *pub;
387         char *s;
388         int n;
389
390         pub = rsapuballoc();
391         pub->n = mpnew(0);
392         pub->ek = mpnew(0);
393         if(unpack(data, len, "smm", &s, &n, pub->ek, pub->n) < 0
394         || n != sizeof(sshrsa)-1 || memcmp(s, sshrsa, n) != 0){
395                 rsapubfree(pub);
396                 return nil;
397         }
398         return pub;
399 }
400
401 int
402 rsasig2ssh(RSApub *pub, mpint *S, uchar *data, int len)
403 {
404         int l = (mpsignif(pub->n)+7)/8;
405         if(4+7+4+l > len)
406                 return -1;
407         mptober(S, data+4+7+4, l);
408         return pack(data, len, "ss", sshrsa, sizeof(sshrsa)-1, data+4+7+4, l);
409 }
410
411 mpint*
412 ssh2rsasig(uchar *data, int len)
413 {
414         mpint *m;
415         char *s;
416         int n;
417
418         m = mpnew(0);
419         if(unpack(data, len, "sm", &s, &n, m) < 0
420         || n != sizeof(sshrsa)-1 || memcmp(s, sshrsa, n) != 0){
421                 mpfree(m);
422                 return nil;
423         }
424         return m;
425 }
426
427 /* libsec */
428 extern mpint* pkcs1padbuf(uchar *buf, int len, mpint *modulus, int blocktype);
429 extern int asn1encodedigest(DigestState* (*fun)(uchar*, ulong, uchar*, DigestState*),
430         uchar *digest, uchar *buf, int len);
431
432 mpint*
433 pkcs1digest(uchar *data, int len, RSApub *pub)
434 {
435         uchar digest[SHA1dlen], buf[256];
436
437         sha1(data, len, digest, nil);
438         return pkcs1padbuf(buf, asn1encodedigest(sha1, digest, buf, sizeof(buf)), pub->n, 1);
439 }
440
441 int
442 pkcs1verify(uchar *data, int len, RSApub *pub, mpint *S)
443 {
444         mpint *V;
445         int ret;
446
447         V = pkcs1digest(data, len, pub);
448         ret = V != nil;
449         if(ret){
450                 rsaencrypt(pub, S, S);
451                 ret = mpcmp(V, S) == 0;
452                 mpfree(V);
453         }
454         return ret;
455 }
456
457 DigestState*
458 hashstr(void *data, ulong len, DigestState *ds)
459 {
460         uchar l[4];
461         pack(l, 4, "u", len);
462         return sha2_256((uchar*)data, len, nil, sha2_256(l, 4, nil, ds));
463 }
464
465 void
466 kdf(uchar *k, int nk, uchar *h, char x, uchar *out, int len)
467 {
468         uchar digest[SHA2_256dlen], *out0;
469         DigestState *ds;
470         int n;
471
472         ds = hashstr(k, nk, nil);
473         ds = sha2_256(h, sizeof(digest), nil, ds);
474         ds = sha2_256((uchar*)&x, 1, nil, ds);
475         sha2_256(sid, nsid, digest, ds);
476         for(out0=out;;){
477                 n = len;
478                 if(n > sizeof(digest))
479                         n = sizeof(digest);
480                 memmove(out, digest, n);
481                 len -= n;
482                 if(len == 0)
483                         break;
484                 out += n;
485                 ds = hashstr(k, nk, nil);
486                 ds = sha2_256(h, sizeof(digest), nil, ds);
487                 sha2_256(out0, out-out0, digest, ds);
488         }
489 }
490
491 void
492 kex(int gotkexinit)
493 {
494         static char kexalgs[] = "curve25519-sha256,curve25519-sha256@libssh.org";
495         static char hostkeyalgs[] = "ssh-rsa";
496         static char cipheralgs[] = "chacha20-poly1305@openssh.com";
497         static char zipalgs[] = "none";
498         static char macalgs[] = "";
499         static char langs[] = "";
500
501         uchar cookie[16], x[32], yc[32], z[32], k[32+1], h[SHA2_256dlen], *ys, *ks, *sig;
502         uchar k12[2*ChachaKeylen];
503         int i, nk, nys, nks, nsig;
504         DigestState *ds;
505         mpint *S, *K;
506         RSApub *pub;
507
508         ds = hashstr(send.v, strlen(send.v), nil);      
509         ds = hashstr(recv.v, strlen(recv.v), ds);
510
511         genrandom(cookie, sizeof(cookie));
512         sendpkt("b[ssssssssssbu", MSG_KEXINIT,
513                 cookie, sizeof(cookie),
514                 kexalgs, sizeof(kexalgs)-1,
515                 hostkeyalgs, sizeof(hostkeyalgs)-1,
516                 cipheralgs, sizeof(cipheralgs)-1,
517                 cipheralgs, sizeof(cipheralgs)-1,
518                 macalgs, sizeof(macalgs)-1,
519                 macalgs, sizeof(macalgs)-1,
520                 zipalgs, sizeof(zipalgs)-1,
521                 zipalgs, sizeof(zipalgs)-1,
522                 langs, sizeof(langs)-1,
523                 langs, sizeof(langs)-1,
524                 0,
525                 0);
526         ds = hashstr(send.r, send.w-send.r, ds);
527
528         if(!gotkexinit){
529         Next0:  switch(recvpkt()){
530                 default:
531                         dispatch();
532                         goto Next0;
533                 case MSG_KEXINIT:
534                         break;
535                 }
536         }
537         ds = hashstr(recv.r, recv.w-recv.r, ds);
538
539         if(debug){
540                 char *tab[] = {
541                         "kexalgs", "hostalgs",
542                         "cipher1", "cipher2",
543                         "mac1", "mac2",
544                         "zip1", "zip2",
545                         "lang1", "lang2",
546                         nil,
547                 }, **t, *s;
548                 uchar *p = recv.r+17;
549                 int n;
550                 for(t=tab; *t != nil; t++){
551                         if(unpack(p, recv.w-p, "s.", &s, &n, &p) < 0)
552                                 break;
553                         fprint(2, "%s: %.*s\n", *t, n, s);
554                 }
555         }
556
557         curve25519_dh_new(x, yc);
558         yc[31] &= ~0x80;
559
560         sendpkt("bs", MSG_ECDH_INIT, yc, sizeof(yc));
561 Next1:  switch(recvpkt()){
562         default:
563                 dispatch();
564                 goto Next1;
565         case MSG_KEXINIT:
566                 sysfatal("inception");
567         case MSG_ECDH_REPLY:
568                 if(unpack(recv.r, recv.w-recv.r, "_sss", &ks, &nks, &ys, &nys, &sig, &nsig) < 0)
569                         sysfatal("bad ECDH_REPLY");
570                 break;
571         }
572
573         if(nys != 32)
574                 sysfatal("bad server ECDH ephermal public key length");
575
576         ds = hashstr(ks, nks, ds);
577         ds = hashstr(yc, 32, ds);
578         ds = hashstr(ys, 32, ds);
579
580         sha2_256(ks, nks, h, nil);
581         i = snprint(thumb, sizeof(thumb), "%.*[", sizeof(h), h);
582         while(i > 0 && thumb[i-1] == '=')
583                 thumb[--i] = '\0';
584
585 if(debug)
586         fprint(2, "host fingerprint: %s\n", thumb);
587
588         if((pub = ssh2rsapub(ks, nks)) == nil)
589                 sysfatal("bad server public key");
590         if((S = ssh2rsasig(sig, nsig)) == nil)
591                 sysfatal("bad server signature");
592
593         curve25519_dh_finish(x, ys, z);
594
595         K = betomp(z, 32, nil);
596         nk = (mpsignif(K)+8)/8;
597         mptober(K, k, nk);
598         mpfree(K);
599
600         ds = hashstr(k, nk, ds);
601         sha2_256(nil, 0, h, ds);
602         if(!pkcs1verify(h, sizeof(h), pub, S))
603                 sysfatal("server verification failed");
604         mpfree(S);
605         rsapubfree(pub);
606
607         sendpkt("b", MSG_NEWKEYS);
608 Next2:  switch(recvpkt()){
609         default:
610                 dispatch();
611                 goto Next2;
612         case MSG_KEXINIT:
613                 sysfatal("inception");
614         case MSG_NEWKEYS:
615                 break;
616         }
617
618         /* next key exchange */
619         recv.kex = recv.seq + 100000;
620         send.kex = send.seq + 100000;
621
622         if(nsid == 0)
623                 memmove(sid, h, nsid = sizeof(h));
624
625         kdf(k, nk, h, 'C', k12, sizeof(k12));
626         setupChachastate(&send.cs1, k12+1*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
627         setupChachastate(&send.cs2, k12+0*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
628
629         kdf(k, nk, h, 'D', k12, sizeof(k12));
630         setupChachastate(&recv.cs1, k12+1*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
631         setupChachastate(&recv.cs2, k12+0*ChachaKeylen, ChachaKeylen, nil, 64/8, 20);
632 }
633
634 static char *authnext;
635
636 int
637 authok(char *meth)
638 {
639         int ok = authnext == nil || strstr(authnext, meth) != nil;
640 if(debug)
641         fprint(2, "userauth %s %s\n", meth, ok ? "ok" : "skipped");
642         return ok;
643 }
644
645 int
646 authfailure(char *meth)
647 {
648         char *s;
649         int n, partial;
650
651         if(unpack(recv.r, recv.w-recv.r, "_sb", &s, &n, &partial) < 0)
652                 sysfatal("bad auth failure response");
653         free(authnext);
654         authnext = smprint("%.*s", n, s);
655 if(debug)
656         fprint(2, "userauth %s failed: partial=%d, next=%s\n", meth, partial, authnext);
657         return partial != 0 || !authok(meth);
658 }
659
660 int
661 pubkeyauth(void)
662 {
663         static char authmeth[] = "publickey";
664
665         uchar pk[4096], sig[4096];
666         int npk, nsig;
667
668         int afd, n;
669         char *s;
670         mpint *S;
671         AuthRpc *rpc;
672         RSApub *pub;
673
674         if(!authok(authmeth))
675                 return -1;
676
677         if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
678                 return -1;
679         if((rpc = auth_allocrpc(afd)) == nil){
680                 close(afd);
681                 return -1;
682         }
683
684         s = "proto=rsa service=ssh role=client";
685         if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
686                 auth_freerpc(rpc);
687                 close(afd);
688                 return -1;
689         }
690
691         pub = rsapuballoc();
692         pub->n = mpnew(0);
693         pub->ek = mpnew(0);
694
695         while(auth_rpc(rpc, "read", nil, 0) == ARok){
696                 s = rpc->arg;
697                 if(strtomp(s, &s, 16, pub->n) == nil)
698                         break;
699                 if(*s++ != ' ')
700                         continue;
701                 if(strtomp(s, nil, 16, pub->ek) == nil)
702                         continue;
703                 npk = rsapub2ssh(pub, pk, sizeof(pk));
704
705                 sendpkt("bsssbss", MSG_USERAUTH_REQUEST,
706                         user, strlen(user),
707                         service, strlen(service),
708                         authmeth, sizeof(authmeth)-1,
709                         0,
710                         sshrsa, sizeof(sshrsa)-1,
711                         pk, npk);
712 Next1:          switch(recvpkt()){
713                 default:
714                         dispatch();
715                         goto Next1;
716                 case MSG_USERAUTH_FAILURE:
717                         if(authfailure(authmeth))
718                                 goto Failed;
719                         continue;
720                 case MSG_USERAUTH_SUCCESS:
721                 case MSG_USERAUTH_PK_OK:
722                         break;
723                 }
724
725                 /* sign sid and the userauth request */
726                 n = pack(send.b, sizeof(send.b), "sbsssbss",
727                         sid, nsid,
728                         MSG_USERAUTH_REQUEST,
729                         user, strlen(user),
730                         service, strlen(service),
731                         authmeth, sizeof(authmeth)-1,
732                         1,
733                         sshrsa, sizeof(sshrsa)-1,
734                         pk, npk);
735                 S = pkcs1digest(send.b, n, pub);
736                 n = snprint((char*)send.b, sizeof(send.b), "%B", S);
737                 mpfree(S);
738
739                 if(auth_rpc(rpc, "write", (char*)send.b, n) != ARok)
740                         break;
741                 if(auth_rpc(rpc, "read", nil, 0) != ARok)
742                         break;
743
744                 S = strtomp(rpc->arg, nil, 16, nil);
745                 nsig = rsasig2ssh(pub, S, sig, sizeof(sig));
746                 mpfree(S);
747
748                 /* send final userauth request with the signature */
749                 sendpkt("bsssbsss", MSG_USERAUTH_REQUEST,
750                         user, strlen(user),
751                         service, strlen(service),
752                         authmeth, sizeof(authmeth)-1,
753                         1,
754                         sshrsa, sizeof(sshrsa)-1,
755                         pk, npk,
756                         sig, nsig);
757 Next2:          switch(recvpkt()){
758                 default:
759                         dispatch();
760                         goto Next2;
761                 case MSG_USERAUTH_FAILURE:
762                         if(authfailure(authmeth))
763                                 goto Failed;
764                         continue;
765                 case MSG_USERAUTH_SUCCESS:
766                         break;
767                 }
768                 rsapubfree(pub);
769                 auth_freerpc(rpc);
770                 close(afd);
771                 return 0;
772         }
773 Failed:
774         rsapubfree(pub);
775         auth_freerpc(rpc);
776         close(afd);
777         return -1;      
778 }
779
780 int
781 passauth(void)
782 {
783         static char authmeth[] = "password";
784         UserPasswd *up;
785
786         if(!authok(authmeth))
787                 return -1;
788
789         up = auth_getuserpasswd(auth_getkey, "proto=pass servive=ssh user=%q server=%q thumb=%q",
790                 user, host, thumb);
791         if(up == nil)
792                 return -1;
793
794         sendpkt("bsssbs", MSG_USERAUTH_REQUEST,
795                 user, strlen(user),
796                 service, strlen(service),
797                 authmeth, sizeof(authmeth)-1,
798                 0,
799                 up->passwd, strlen(up->passwd));
800
801         memset(up->passwd, 0, strlen(up->passwd));
802         free(up);
803
804 Next0:  switch(recvpkt()){
805         default:
806                 dispatch();
807                 goto Next0;
808         case MSG_USERAUTH_FAILURE:
809                 werrstr("wrong password");
810                 authfailure(authmeth);
811                 return -1;
812         case MSG_USERAUTH_SUCCESS:
813                 return 0;
814         }
815 }
816
817 int
818 kbintauth(void)
819 {
820         static char authmeth[] = "keyboard-interactive";
821
822         char *name, *inst, *s, *a;
823         int fd, i, n, m;
824         int nquest, echo;
825         uchar *ans, *answ;
826
827         if(!authok(authmeth))
828                 return -1;
829
830         sendpkt("bsssss", MSG_USERAUTH_REQUEST,
831                 user, strlen(user),
832                 service, strlen(service),
833                 authmeth, sizeof(authmeth)-1,
834                 "", 0,
835                 "", 0);
836
837 Next0:  switch(recvpkt()){
838         default:
839                 dispatch();
840                 goto Next0;
841         case MSG_USERAUTH_FAILURE:
842                 authfailure(authmeth);
843                 return -1;
844         case MSG_USERAUTH_SUCCESS:
845                 return 0;
846         case MSG_USERAUTH_INFO_REQUEST:
847                 break;
848         }
849 Retry:
850         if((fd = open("/dev/cons", OWRITE)) < 0)
851                 return -1;
852
853         if(unpack(recv.r, recv.w-recv.r, "_ss.", &name, &n, &inst, &m, &recv.r) < 0)
854                 sysfatal("bad info request: name, inst");
855
856         while(n > 0 && strchr("\r\n\t ", name[n-1]) != nil)
857                 n--;
858         while(m > 0 && strchr("\r\n\t ", inst[m-1]) != nil)
859                 m--;
860
861         if(n > 0)
862                 fprint(fd, "%.*s\n", n, name);
863         if(m > 0)
864                 fprint(fd, "%.*s\n", m, inst);
865
866         /* lang, nprompt */
867         if(unpack(recv.r, recv.w-recv.r, "su.", &s, &n, &nquest, &recv.r) < 0)
868                 sysfatal("bad info request: lang, #quest");
869
870         ans = answ = nil;
871         for(i = 0; i < nquest; i++){
872                 if(unpack(recv.r, recv.w-recv.r, "sb.", &s, &n, &echo, &recv.r) < 0)
873                         sysfatal("bad info request: question [%d]", i);
874
875                 while(n > 0 && strchr("\r\n\t :", s[n-1]) != nil)
876                         n--;
877                 s[n] = '\0';
878
879                 if((a = readcons(s, nil, !echo)) == nil)
880                         sysfatal("readcons: %r");
881
882                 n = answ - ans;
883                 m = strlen(a)+4;
884                 if((s = realloc(ans, n + m)) == nil)
885                         sysfatal("realloc: %r");
886                 ans = (uchar*)s;
887                 answ = ans+n;
888                 answ += pack(answ, m, "s", a, m-4);
889         }
890
891         sendpkt("bu[", MSG_USERAUTH_INFO_RESPONSE, i, ans, answ - ans);
892         free(ans);
893         close(fd);
894
895 Next1:  switch(recvpkt()){
896         default:
897                 dispatch();
898                 goto Next1;
899         case MSG_USERAUTH_INFO_REQUEST:
900                 goto Retry;
901         case MSG_USERAUTH_FAILURE:
902                 authfailure(authmeth);
903                 return -1;
904         case MSG_USERAUTH_SUCCESS:
905                 return 0;
906         }
907 }
908
909 void
910 dispatch(void)
911 {
912         char *s;
913         uchar *p;
914         int n, b, c;
915
916         switch(recv.r[0]){
917         case MSG_IGNORE:
918         case MSG_GLOBAL_REQUEST:
919                 return;
920         case MSG_DISCONNECT:
921                 if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
922                         break;
923                 sysfatal("disconnect: (%d) %.*s", c, n, s);
924                 return;
925         case MSG_DEBUG:
926                 if(unpack(recv.r, recv.w-recv.r, "__sb", &s, &n, &c) < 0)
927                         break;
928                 if(c != 0 || debug) fprint(2, "%s: %.*s\n", argv0, n, s);
929                 return;
930         case MSG_USERAUTH_BANNER:
931                 if(unpack(recv.r, recv.w-recv.r, "_s", &s, &n) < 0)
932                         break;
933                 if(raw) write(2, s, n);
934                 return;
935         case MSG_CHANNEL_DATA:
936                 if(unpack(recv.r, recv.w-recv.r, "_us", &c, &s, &n) < 0)
937                         break;
938                 if(c != recv.chan)
939                         break;
940                 if(write(1, s, n) != n)
941                         sysfatal("write out: %r");
942         Winadjust:
943                 recv.win -= n;
944                 if(recv.win < recv.pkt){
945                         n = WinPackets*recv.pkt;
946                         recv.win += n;
947                         sendpkt("buu", MSG_CHANNEL_WINDOW_ADJUST, send.chan, n);
948                 }
949                 return;
950         case MSG_CHANNEL_EXTENDED_DATA:
951                 if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
952                         break;
953                 if(c != recv.chan)
954                         break;
955                 if(b == 1) write(2, s, n);
956                 goto Winadjust;
957         case MSG_CHANNEL_WINDOW_ADJUST:
958                 if(unpack(recv.r, recv.w-recv.r, "_uu", &c, &n) < 0)
959                         break;
960                 if(c != recv.chan)
961                         break;
962                 send.win += n;
963                 if(send.win >= send.pkt)
964                         rwakeup(&send);
965                 return;
966         case MSG_CHANNEL_REQUEST:
967                 if(unpack(recv.r, recv.w-recv.r, "_usb.", &c, &s, &n, &b, &p) < 0)
968                         break;
969                 if(c != recv.chan)
970                         break;
971                 if(n == 11 && memcmp(s, "exit-signal", n) == 0){
972                         if(unpack(p, recv.w-p, "s", &s, &n) < 0)
973                                 break;
974                         if(n != 0 && status == nil)
975                                 status = smprint("%.*s", n, s);
976                 } else if(n == 11 && memcmp(s, "exit-status", n) == 0){
977                         if(unpack(p, recv.w-p, "u", &n) < 0)
978                                 break;
979                         if(n != 0 && status == nil)
980                                 status = smprint("%d", n);
981                 } else if(debug) {
982                         fprint(2, "%s: channel request: %.*s\n", argv0, n, s);
983                 }
984                 return;
985         case MSG_CHANNEL_EOF:
986                 recv.eof = 1;
987                 if(!raw) write(1, "", 0);
988                 return;
989         case MSG_CHANNEL_CLOSE:
990                 shutdown();
991                 return;
992         case MSG_KEXINIT:
993                 kex(1);
994                 return;
995         }
996         sysfatal("got: %.*H", (int)(recv.w - recv.r), recv.r);
997 }
998
999 char*
1000 readline(void)
1001 {
1002         uchar *p;
1003
1004         for(p = send.b; p < &send.b[sizeof(send.b)-1]; p++){
1005                 *p = '\0';
1006                 if(read(fd, p, 1) != 1 || *p == '\n')
1007                         break;
1008         }
1009         while(p >= send.b && (*p == '\n' || *p == '\r'))
1010                 *p-- = '\0';
1011         return (char*)send.b;
1012 }
1013
1014 static struct {
1015         char    *term;
1016         int     xpixels;
1017         int     ypixels;
1018         int     lines;
1019         int     cols;
1020 } tty = {
1021         "dumb",
1022         0,
1023         0,
1024         0,
1025         0,
1026 };
1027
1028 void
1029 rawon(void)
1030 {
1031         int ctl;
1032         char *s;
1033
1034         close(0);
1035         if(open("/dev/cons", OREAD) != 0)
1036                 sysfatal("open: %r");
1037         close(1);
1038         if(open("/dev/cons", OWRITE) != 1)
1039                 sysfatal("open: %r");
1040         dup(1, 2);
1041         if((ctl = open("/dev/consctl", OWRITE)) >= 0)
1042                 write(ctl, "rawon", 5);
1043         if(s = getenv("TERM")){
1044                 tty.term = s;
1045                 if(s = getenv("XPIXELS")){
1046                         tty.xpixels = atoi(s);
1047                         free(s);
1048                 }
1049                 if(s = getenv("YPIXELS")){
1050                         tty.ypixels = atoi(s);
1051                         free(s);
1052                 }
1053                 if(s = getenv("LINES")){
1054                         tty.lines = atoi(s);
1055                         free(s);
1056                 }
1057                 if(s = getenv("COLS")){
1058                         tty.cols = atoi(s);
1059                         free(s);
1060                 }
1061         }
1062 }
1063
1064 void
1065 usage(void)
1066 {
1067         fprint(2, "usage: %s [-dR] [-u user] [user@]host [cmd]\n", argv0);
1068         exits("usage");
1069 }
1070
1071 void
1072 main(int argc, char *argv[])
1073 {
1074         static QLock sl;
1075         int b, n, c;
1076         char *s;
1077
1078         quotefmtinstall();
1079         fmtinstall('B', mpfmt);
1080         fmtinstall('H', encodefmt);
1081         fmtinstall('[', encodefmt);
1082
1083         s = getenv("TERM");
1084         raw = s != nil && strcmp(s, "dumb") != 0;
1085         free(s);
1086
1087         ARGBEGIN {
1088         case 'd':
1089                 debug++;
1090                 break;
1091         case 'R':
1092                 raw = 0;
1093                 break;
1094         case 'u':
1095                 user = EARGF(usage());
1096                 break;
1097         } ARGEND;
1098
1099         if(argc == 0)
1100                 usage();
1101
1102         host = *argv++;
1103         if(user == nil){
1104                 s = strchr(host, '@');
1105                 if(s != nil){
1106                         *s++ = '\0';
1107                         user = host;
1108                         host = s;
1109                 }
1110         }
1111         for(cmd = nil; *argv != nil; argv++){
1112                 if(cmd == nil)
1113                         cmd = strdup(*argv);
1114                 else {
1115                         s = smprint("%s %q", cmd, *argv);
1116                         free(cmd);
1117                         cmd = s;
1118                 }
1119         }
1120         if(cmd != nil)
1121                 raw = 0;
1122
1123         if((fd = dial(netmkaddr(host, nil, "ssh"), nil, nil, nil)) < 0)
1124                 sysfatal("dial: %r");
1125
1126         send.v = "SSH-2.0-(9)";
1127         fprint(fd, "%s\r\n", send.v);
1128         recv.v = readline();
1129         if(debug)
1130                 fprint(2, "server verison: %s\n", recv.v);
1131         if(strncmp("SSH-2.0-", recv.v, 8) != 0)
1132                 sysfatal("bad server version: %s", recv.v);
1133         recv.v = strdup(recv.v);
1134
1135         send.l = recv.l = &sl;
1136
1137         kex(0);
1138
1139         if(user == nil)
1140                 user = getuser();
1141         service = "ssh-connection";
1142
1143         sendpkt("bs", MSG_SERVICE_REQUEST, "ssh-userauth", 12);
1144 Next0:  switch(recvpkt()){
1145         default:
1146                 dispatch();
1147                 goto Next0;
1148         case MSG_SERVICE_ACCEPT:
1149                 break;
1150         }
1151
1152         if(pubkeyauth() < 0 && passauth() < 0 && kbintauth() < 0)
1153                 sysfatal("auth: %r");
1154
1155         recv.pkt = MaxPacket;
1156         recv.win = WinPackets*recv.pkt;
1157         recv.chan = 0;
1158
1159         /* open hailing frequencies */
1160         sendpkt("bsuuu", MSG_CHANNEL_OPEN,
1161                 "session", 7,
1162                 recv.chan,
1163                 recv.win,
1164                 recv.pkt);
1165
1166 Next1:  switch(recvpkt()){
1167         default:
1168                 dispatch();
1169                 goto Next1;
1170         case MSG_CHANNEL_OPEN_FAILURE:
1171                 if(unpack(recv.r, recv.w-recv.r, "_uus", &c, &b, &s, &n) < 0)
1172                         n = strlen(s = "???");
1173                 sysfatal("channel open failure: (%d) %.*s", b, n, s);
1174         case MSG_CHANNEL_OPEN_CONFIRMATION:
1175                 break;
1176         }
1177
1178         if(unpack(recv.r, recv.w-recv.r, "_uuuu", &recv.chan, &send.chan, &send.win, &send.pkt) < 0)
1179                 sysfatal("bad channel open confirmation");
1180         if(send.pkt <= 0 || send.pkt > MaxPacket)
1181                 send.pkt = MaxPacket;
1182
1183         notify(catch);
1184         atexit(shutdown);
1185
1186         recv.pid = getpid();
1187         n = rfork(RFPROC|RFMEM);
1188         if(n < 0)
1189                 sysfatal("fork: %r");
1190
1191         /* parent reads and dispatches packets */
1192         if(n > 0) {
1193                 send.pid = n;
1194                 while((send.eof|recv.eof) == 0){
1195                         recvpkt();
1196                         qlock(&sl);                                     
1197                         dispatch();
1198                         if((int)(send.kex - send.seq) <= 0 || (int)(recv.kex - recv.seq) <= 0)
1199                                 kex(0);
1200                         qunlock(&sl);
1201                 }
1202                 exits(status);
1203         }
1204
1205         /* child reads input and sends packets */
1206         qlock(&sl);
1207         if(raw) {
1208                 rawon();
1209                 sendpkt("busbsuuuus", MSG_CHANNEL_REQUEST,
1210                         send.chan,
1211                         "pty-req", 7,
1212                         0,
1213                         tty.term, strlen(tty.term),
1214                         tty.cols,
1215                         tty.lines,
1216                         tty.xpixels,
1217                         tty.ypixels,
1218                         "", 0);
1219         }
1220         if(cmd == nil){
1221                 sendpkt("busb", MSG_CHANNEL_REQUEST,
1222                         send.chan,
1223                         "shell", 5,
1224                         0);
1225         } else {
1226                 sendpkt("busbs", MSG_CHANNEL_REQUEST,
1227                         send.chan,
1228                         "exec", 4,
1229                         0,
1230                         cmd, strlen(cmd));
1231         }
1232         for(;;){
1233                 static uchar buf[MaxPacket];
1234                 qunlock(&sl);
1235                 n = read(0, buf, send.pkt);
1236                 qlock(&sl);
1237                 if(send.eof)
1238                         break;
1239                 if(n < 0 && wasintr()){
1240                         if(!raw) break;
1241                         sendpkt("busbs", MSG_CHANNEL_REQUEST,
1242                                 send.chan,
1243                                 "signal", 6,
1244                                 0,
1245                                 "INT", 3);
1246                         intr = 0;
1247                         continue;
1248                 }
1249                 if(n <= 0)
1250                         break;
1251                 send.win -= n;
1252                 while(send.win < 0)
1253                         rsleep(&send);
1254                 sendpkt("bus", MSG_CHANNEL_DATA,
1255                         send.chan,
1256                         buf, n);
1257         }
1258         if(send.eof++ == 0)
1259                 sendpkt("bu", raw ? MSG_CHANNEL_CLOSE : MSG_CHANNEL_EOF, send.chan);
1260         qunlock(&sl);
1261
1262         exits(nil);
1263 }