git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ssh/agent.c
kernel: keep segment locked for data2txt
[plan9front.git] / sys / src / cmd / ssh / agent.c
1 #include "ssh.h"
2 #include <bio.h>
3
4 typedef struct Key Key;
5 struct Key
6 {
7         mpint *mod;
8         mpint *ek;
9         char *comment;
10 };
11
12 typedef struct Achan Achan;
13 struct Achan
14 {
15         int open;
16         u32int chan;    /* of remote */
17         uchar lbuf[4];
18         uint nlbuf;
19         uint len;
20         uchar *data;
21         int ndata;
22         int needeof;
23         int needclosed;
24 };
25
26 Achan achan[16];
27
28 static char*
29 find(char **f, int nf, char *k)
30 {
31         int i, len;
32
33         len = strlen(k);
34         for(i=1; i<nf; i++)     /* i=1: f[0] is "key" */
35                 if(strncmp(f[i], k, len) == 0 && f[i][len] == '=')
36                         return f[i]+len+1;
37         return nil;
38 }
39
40 static int
41 listkeys(Key **kp)
42 {
43         Biobuf *b;
44         Key *k;
45         int nk;
46         char *p, *f[20];
47         int nf;
48         mpint *mod, *ek;
49         
50         *kp = nil;
51         if((b = Bopen("/mnt/factotum/ctl", OREAD)) == nil)
52                 return -1;
53         
54         k = nil;
55         nk = 0;
56         while((p = Brdline(b, '\n')) != nil){
57                 p[Blinelen(b)-1] = '\0';
58                 nf = tokenize(p, f, nelem(f));
59                 if(nf == 0 || strcmp(f[0], "key") != 0)
60                         continue;
61                 p = find(f, nf, "proto");
62                 if(p == nil || strcmp(p, "rsa") != 0)
63                         continue;
64                 p = find(f, nf, "n");
65                 if(p == nil || (mod = strtomp(p, nil, 16, nil)) == nil)
66                         continue;
67                 p = find(f, nf, "ek");
68                 if(p == nil || (ek = strtomp(p, nil, 16, nil)) == nil){
69                         mpfree(mod);
70                         continue;
71                 }
72                 p = find(f, nf, "comment");
73                 if(p == nil)
74                         p = "";
75                 k = erealloc(k, (nk+1)*sizeof(k[0]));
76                 k[nk].mod = mod;
77                 k[nk].ek = ek;
78                 k[nk].comment = emalloc(strlen(p)+1);
79                 strcpy(k[nk].comment, p);
80                 nk++;
81         }
82         Bterm(b);
83         *kp = k;
84         return nk;      
85 }
86
87
88 static int
89 dorsa(mpint *mod, mpint *exp, mpint *chal, uchar chalbuf[32])
90 {
91         int afd;
92         AuthRpc *rpc;
93         mpint *m;
94         char buf[4096], *p;
95         mpint *decr, *unpad;
96
97         USED(exp);
98
99         snprint(buf, sizeof buf, "proto=rsa service=ssh role=client");
100         if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0){
101                 debug(DBG_AUTH, "open /mnt/factotum/rpc: %r\n");
102                 return -1;
103         }
104         if((rpc = auth_allocrpc(afd)) == nil){
105                 debug(DBG_AUTH, "auth_allocrpc: %r\n");
106                 close(afd);
107                 return -1;
108         }
109         if(auth_rpc(rpc, "start", buf, strlen(buf)) != ARok){
110                 debug(DBG_AUTH, "auth_rpc start failed: %r\n");
111         Die:
112                 auth_freerpc(rpc);
113                 close(afd);
114                 return -1;
115         }
116         m = nil;
117         debug(DBG_AUTH, "trying factotum rsa keys\n");
118         while(auth_rpc(rpc, "read", nil, 0) == ARok){
119                 debug(DBG_AUTH, "try %s\n", (char*)rpc->arg);
120                 m = strtomp(rpc->arg, nil, 16, nil);
121                 if(mpcmp(m, mod) == 0)
122                         break;
123                 mpfree(m);
124                 m = nil;
125         }
126         if(m == nil)
127                 goto Die;
128         mpfree(m);
129         
130         p = mptoa(chal, 16, nil, 0);
131         if(p == nil){
132                 debug(DBG_AUTH, "\tmptoa failed: %r\n");
133                 goto Die;
134         }
135         if(auth_rpc(rpc, "write", p, strlen(p)) != ARok){
136                 debug(DBG_AUTH, "\tauth_rpc write failed: %r\n");
137                 free(p);
138                 goto Die;
139         }
140         free(p);
141         if(auth_rpc(rpc, "read", nil, 0) != ARok){
142                 debug(DBG_AUTH, "\tauth_rpc read failed: %r\n");
143                 goto Die;
144         }
145         decr = strtomp(rpc->arg, nil, 16, nil);
146         if(decr == nil){
147                 debug(DBG_AUTH, "\tdecr %s failed\n", rpc->arg);
148                 goto Die;
149         }
150         debug(DBG_AUTH, "\tdecrypted %B\n", decr);
151         unpad = rsaunpad(decr);
152         if(unpad == nil){
153                 debug(DBG_AUTH, "\tunpad %B failed\n", decr);
154                 mpfree(decr);
155                 goto Die;
156         }
157         debug(DBG_AUTH, "\tunpadded %B\n", unpad);
158         mpfree(decr);
159         mptoberjust(unpad, chalbuf, 32);
160         mpfree(unpad);
161         auth_freerpc(rpc);
162         close(afd);
163         return 0;
164 }
165
166 int
167 startagent(Conn *c)
168 {
169         int ret;
170         Msg *m;
171
172         m = allocmsg(c, SSH_CMSG_AGENT_REQUEST_FORWARDING, 0);
173         sendmsg(m);
174
175         m = recvmsg(c, -1);
176         switch(m->type){
177         case SSH_SMSG_SUCCESS:
178                 debug(DBG_AUTH, "agent allocated\n");
179                 ret = 0;
180                 break;
181         case SSH_SMSG_FAILURE:
182                 debug(DBG_AUTH, "agent failed to allocate\n");
183                 ret = -1;
184                 break;
185         default:
186                 badmsg(m, 0);
187                 ret = -1;
188                 break;
189         }
190         free(m);
191         return ret;
192 }
193
194 void handlefullmsg(Conn*, Achan*);
195
196 void
197 handleagentmsg(Msg *m)
198 {
199         u32int chan, len;
200         int n;
201         Achan *a;
202
203         assert(m->type == SSH_MSG_CHANNEL_DATA);
204
205         debug(DBG_AUTH, "agent data\n");
206         debug(DBG_AUTH, "\t%.*H\n", (int)(m->ep - m->rp), m->rp);
207         chan = getlong(m);
208         len = getlong(m);
209         if(m->rp+len != m->ep)
210                 sysfatal("got bad channel data");
211
212         if(chan >= nelem(achan))
213                 error("bad channel in agent request");
214
215         a = &achan[chan];
216
217         while(m->rp < m->ep){
218                 if(a->nlbuf < 4){
219                         a->lbuf[a->nlbuf++] = getbyte(m);
220                         if(a->nlbuf == 4){
221                                 a->len = (a->lbuf[0]<<24) | (a->lbuf[1]<<16) | (a->lbuf[2]<<8) | a->lbuf[3];
222                                 a->data = erealloc(a->data, a->len);
223                                 a->ndata = 0;
224                         }
225                         continue;
226                 }
227                 if(a->ndata < a->len){
228                         n = a->len - a->ndata;
229                         if(n > m->ep - m->rp)
230                                 n = m->ep - m->rp;
231                         memmove(a->data+a->ndata, getbytes(m, n), n);
232                         a->ndata += n;
233                 }
234                 if(a->ndata == a->len){
235                         handlefullmsg(m->c, a);
236                         a->nlbuf = 0;
237                 }
238         }
239 }
240
241 void
242 handlefullmsg(Conn *c, Achan *a)
243 {
244         int i;
245         u32int chan, len, n, rt;
246         uchar type;
247         Msg *m, mm;
248         Msg *r;
249         Key *k;
250         int nk;
251         mpint *mod, *ek, *chal;
252         uchar sessid[16];
253         uchar chalbuf[32];
254         uchar digest[16];
255         DigestState *s;
256         static int first;
257
258         assert(a->len == a->ndata);
259
260         chan = a->chan;
261         mm.rp = a->data;
262         mm.ep = a->data+a->ndata;
263         mm.c = c;
264         m = &mm;
265
266         type = getbyte(m);
267
268         if(first == 0){
269                 first++;
270                 fmtinstall('H', encodefmt);
271         }
272
273         switch(type){
274         default:
275                 debug(DBG_AUTH, "unknown msg type\n");
276         Failure:
277                 debug(DBG_AUTH, "agent sending failure\n");
278                 r = allocmsg(m->c, SSH_MSG_CHANNEL_DATA, 13);
279                 putlong(r, chan);
280                 putlong(r, 5);
281                 putlong(r, 1);
282                 putbyte(r, SSH_AGENT_FAILURE);
283                 sendmsg(r);
284                 return;
285
286         case SSH_AGENTC_REQUEST_RSA_IDENTITIES:
287                 debug(DBG_AUTH, "agent request identities\n");
288                 nk = listkeys(&k);
289                 if(nk < 0)
290                         goto Failure;
291                 len = 1+4;      /* type, nk */
292                 for(i=0; i<nk; i++){
293                         len += 4;
294                         len += 2+(mpsignif(k[i].ek)+7)/8;
295                         len += 2+(mpsignif(k[i].mod)+7)/8;
296                         len += 4+strlen(k[i].comment);
297                 }
298                 r = allocmsg(m->c, SSH_MSG_CHANNEL_DATA, 12+len);
299                 putlong(r, chan);
300                 putlong(r, len+4);
301                 putlong(r, len);
302                 putbyte(r, SSH_AGENT_RSA_IDENTITIES_ANSWER);
303                 putlong(r, nk);
304                 for(i=0; i<nk; i++){
305                         debug(DBG_AUTH, "\t%B %B %s\n", k[i].ek, k[i].mod, k[i].comment);
306                         putlong(r, mpsignif(k[i].mod));
307                         putmpint(r, k[i].ek);
308                         putmpint(r, k[i].mod);
309                         putstring(r, k[i].comment);
310                         mpfree(k[i].ek);
311                         mpfree(k[i].mod);
312                         free(k[i].comment);
313                 }
314                 free(k);
315                 sendmsg(r);
316                 break;
317
318         case SSH_AGENTC_RSA_CHALLENGE:
319                 n = getlong(m);
320                 USED(n);        /* number of bits in key; who cares? */
321                 ek = getmpint(m);
322                 mod = getmpint(m);
323                 chal = getmpint(m);
324                 memmove(sessid, getbytes(m, 16), 16);
325                 rt = getlong(m);
326                 debug(DBG_AUTH, "agent challenge %B %B %B %ud (%p %p)\n",
327                         ek, mod, chal, rt, m->rp, m->ep);
328                 if(rt != 1 || dorsa(mod, ek, chal, chalbuf) < 0){
329                         mpfree(ek);
330                         mpfree(mod);
331                         mpfree(chal);
332                         goto Failure;
333                 }
334                 s = md5(chalbuf, 32, nil, nil);
335                 md5(sessid, 16, digest, s);
336                 r = allocmsg(m->c, SSH_MSG_CHANNEL_DATA, 12+1+16);
337                 putlong(r, chan);
338                 putlong(r, 4+16+1);
339                 putlong(r, 16+1);
340                 putbyte(r, SSH_AGENT_RSA_RESPONSE);
341                 putbytes(r, digest, 16);
342                 debug(DBG_AUTH, "digest %.16H\n", digest);
343                 sendmsg(r);
344                 mpfree(ek);
345                 mpfree(mod);
346                 mpfree(chal);
347                 return;
348
349         case SSH_AGENTC_ADD_RSA_IDENTITY:
350                 goto Failure;
351 /*
352                 n = getlong(m);
353                 pubmod = getmpint(m);
354                 pubexp = getmpint(m);
355                 privexp = getmpint(m);
356                 pinversemodq = getmpint(m);
357                 p = getmpint(m);
358                 q = getmpint(m);
359                 comment = getstring(m);
360                 add to factotum;
361                 send SSH_AGENT_SUCCESS or SSH_AGENT_FAILURE;
362 */
363
364         case SSH_AGENTC_REMOVE_RSA_IDENTITY:
365                 goto Failure;
366 /*
367                 n = getlong(m);
368                 pubmod = getmpint(m);
369                 pubexp = getmpint(m);
370                 tell factotum to del key
371                 send SSH_AGENT_SUCCESS or SSH_AGENT_FAILURE;
372 */
373         }
374 }
375
376 void
377 handleagentopen(Msg *m)
378 {
379         int i;
380         u32int remote;
381
382         assert(m->type == SSH_SMSG_AGENT_OPEN);
383         remote = getlong(m);
384         debug(DBG_AUTH, "agent open %d\n", remote);
385
386         for(i=0; i<nelem(achan); i++)
387                 if(achan[i].open == 0 && achan[i].needeof == 0 && achan[i].needclosed == 0)
388                         break;
389         if(i == nelem(achan)){
390                 m = allocmsg(m->c, SSH_MSG_CHANNEL_OPEN_FAILURE, 4);
391                 putlong(m, remote);
392                 sendmsg(m);
393                 return;
394         }
395
396         debug(DBG_AUTH, "\tremote %d is local %d\n", remote, i);
397         achan[i].open = 1;
398         achan[i].needeof = 1;
399         achan[i].needclosed = 1;
400         achan[i].nlbuf = 0;
401         achan[i].chan = remote;
402         m = allocmsg(m->c, SSH_MSG_CHANNEL_OPEN_CONFIRMATION, 8);
403         putlong(m, remote);
404         putlong(m, i);
405         sendmsg(m);
406 }
407
408 void
409 handleagentieof(Msg *m)
410 {
411         u32int local;
412
413         assert(m->type == SSH_MSG_CHANNEL_INPUT_EOF);
414         local = getlong(m);
415         debug(DBG_AUTH, "agent close %d\n", local);
416         if(local < nelem(achan)){
417                 debug(DBG_AUTH, "\tlocal %d is remote %d\n", local, achan[local].chan);
418                 achan[local].open = 0;
419 /*
420                 m = allocmsg(m->c, SSH_MSG_CHANNEL_OUTPUT_CLOSED, 4);
421                 putlong(m, achan[local].chan);
422                 sendmsg(m);
423 */
424                 if(achan[local].needeof){
425                         achan[local].needeof = 0;
426                         m = allocmsg(m->c, SSH_MSG_CHANNEL_INPUT_EOF, 4);
427                         putlong(m, achan[local].chan);
428                         sendmsg(m);
429                 }
430         }
431 }
432
433 void
434 handleagentoclose(Msg *m)
435 {
436         u32int local;
437
438         assert(m->type == SSH_MSG_CHANNEL_OUTPUT_CLOSED);
439         local = getlong(m);
440         debug(DBG_AUTH, "agent close %d\n", local);
441         if(local < nelem(achan)){
442                 debug(DBG_AUTH, "\tlocal %d is remote %d\n", local, achan[local].chan);
443                 if(achan[local].needclosed){
444                         achan[local].needclosed = 0;
445                         m = allocmsg(m->c, SSH_MSG_CHANNEL_OUTPUT_CLOSED, 4);
446                         putlong(m, achan[local].chan);
447                         sendmsg(m);
448                 }
449         }
450 }