]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/auth/factotum/rsa.c
factotum: rsa: use mptober() to get right adjusted result instead of mptobe() and...
[plan9front.git] / sys / src / cmd / auth / factotum / rsa.c
1 /*
2  * RSA authentication.
3  *
4  * Old ssh client protocol:
5  *      read public key
6  *              if you don't like it, read another, repeat
7  *      write challenge
8  *      read response
9  *
10  * all numbers are hexadecimal biginits parsable with strtomp.
11  *
12  * Sign (PKCS #1 using hash=sha1 or hash=md5)
13  *      write hash(msg)
14  *      read signature(hash(msg))
15  *
16  * Verify:
17  *      write hash(msg)
18  *      write signature(hash(msg))
19  *      read ok or fail
20  */
21
22 #include "dat.h"
23
24 enum {
25         CHavePub,
26         CHaveResp,
27         VNeedHash,
28         VNeedSig,
29         VHaveResp,
30         SNeedHash,
31         SHaveResp,
32         Maxphase,
33 };
34
35 static char *phasenames[] = {
36 [CHavePub]      "CHavePub",
37 [CHaveResp]     "CHaveResp",
38 [VNeedHash]     "VNeedHash",
39 [VNeedSig]      "VNeedSig",
40 [VHaveResp]     "VHaveResp",
41 [SNeedHash]     "SNeedHash",
42 [SHaveResp]     "SHaveResp",
43 };
44
45 struct State
46 {
47         RSApriv *priv;
48         mpint *resp;
49         int off;
50         Key *key;
51         mpint *digest;
52         int sigresp;
53 };
54
55 static mpint* mkdigest(RSApub *key, char *hashalg, uchar *hash, uint dlen);
56
57 static RSApriv*
58 readrsapriv(Key *k)
59 {
60         char *a;
61         RSApriv *priv;
62
63         priv = rsaprivalloc();
64
65         if((a=_strfindattr(k->attr, "ek"))==nil || (priv->pub.ek=strtomp(a, nil, 16, nil))==nil)
66                 goto Error;
67         if((a=_strfindattr(k->attr, "n"))==nil || (priv->pub.n=strtomp(a, nil, 16, nil))==nil)
68                 goto Error;
69         if(k->privattr == nil)          /* only public half */
70                 return priv;
71         if((a=_strfindattr(k->privattr, "!p"))==nil || (priv->p=strtomp(a, nil, 16, nil))==nil)
72                 goto Error;
73         if((a=_strfindattr(k->privattr, "!q"))==nil || (priv->q=strtomp(a, nil, 16, nil))==nil)
74                 goto Error;
75         if((a=_strfindattr(k->privattr, "!kp"))==nil || (priv->kp=strtomp(a, nil, 16, nil))==nil)
76                 goto Error;
77         if((a=_strfindattr(k->privattr, "!kq"))==nil || (priv->kq=strtomp(a, nil, 16, nil))==nil)
78                 goto Error;
79         if((a=_strfindattr(k->privattr, "!c2"))==nil || (priv->c2=strtomp(a, nil, 16, nil))==nil)
80                 goto Error;
81         if((a=_strfindattr(k->privattr, "!dk"))==nil || (priv->dk=strtomp(a, nil, 16, nil))==nil)
82                 goto Error;
83         return priv;
84
85 Error:
86         rsaprivfree(priv);
87         return nil;
88 }
89
90 static int
91 rsainit(Proto*, Fsstate *fss)
92 {
93         Keyinfo ki;
94         State *s;
95         char *role;
96
97         if((role = _strfindattr(fss->attr, "role")) == nil)
98                 return failure(fss, "rsa role not specified");
99         if(strcmp(role, "client") == 0)
100                 fss->phase = CHavePub;
101         else if(strcmp(role, "sign") == 0)
102                 fss->phase = SNeedHash;
103         else if(strcmp(role, "verify") == 0)
104                 fss->phase = VNeedHash;
105         else
106                 return failure(fss, "rsa role %s unimplemented", role);
107
108         s = emalloc(sizeof *s);
109         fss->phasename = phasenames;
110         fss->maxphase = Maxphase;
111         fss->ps = s;
112
113         switch(fss->phase){
114         case SNeedHash:
115         case VNeedHash:
116                 mkkeyinfo(&ki, fss, nil);
117                 if(findkey(&s->key, &ki, nil) != RpcOk)
118                         return failure(fss, nil);
119                 /* signing needs private key */
120                 if(fss->phase == SNeedHash && s->key->privattr == nil)
121                         return failure(fss,
122                                 "missing private half of key -- cannot sign");
123         }
124         return RpcOk;
125 }
126
127 static int
128 rsaread(Fsstate *fss, void *va, uint *n)
129 {
130         RSApriv *priv;
131         State *s;
132         mpint *m;
133         Keyinfo ki;
134         int len;
135
136         s = fss->ps;
137         switch(fss->phase){
138         default:
139                 return phaseerror(fss, "read");
140         case CHavePub:
141                 if(s->key){
142                         closekey(s->key);
143                         s->key = nil;
144                 }
145                 mkkeyinfo(&ki, fss, nil);
146                 ki.skip = s->off;
147                 ki.noconf = 1;
148                 if(findkey(&s->key, &ki, nil) != RpcOk)
149                         return failure(fss, nil);
150                 s->off++;
151                 priv = s->key->priv;
152                 *n = snprint(va, *n, "%B %B", priv->pub.n, priv->pub.ek);
153                 return RpcOk;
154         case CHaveResp:
155                 *n = snprint(va, *n, "%B", s->resp);
156                 fss->phase = Established;
157                 return RpcOk;
158         case SHaveResp:
159                 priv = s->key->priv;
160                 len = (mpsignif(priv->pub.n)+7)/8;
161                 if(len > *n)
162                         return failure(fss, "signature buffer too short");
163                 *n = len;
164                 m = rsadecrypt(priv, s->digest, nil);
165                 mptober(m, (uchar*)va, len);
166                 mpfree(m);
167                 fss->phase = Established;
168                 return RpcOk;
169         case VHaveResp:
170                 *n = snprint(va, *n, "%s", s->sigresp == 0? "ok":
171                         "signature does not verify");
172                 fss->phase = Established;
173                 return RpcOk;
174         }
175 }
176
177 static int
178 rsawrite(Fsstate *fss, void *va, uint n)
179 {
180         RSApriv *priv;
181         mpint *m, *mm;
182         State *s;
183         char *hash;
184         int dlen;
185
186         s = fss->ps;
187         switch(fss->phase){
188         default:
189                 return phaseerror(fss, "write");
190         case CHavePub:
191                 if(s->key == nil)
192                         return failure(fss, "no current key");
193                 switch(canusekey(fss, s->key)){
194                 case -1:
195                         return RpcConfirm;
196                 case 0:
197                         return failure(fss, "confirmation denied");
198                 case 1:
199                         break;
200                 }
201                 m = strtomp(va, nil, 16, nil);
202                 if(m == nil)
203                         return failure(fss, "invalid challenge value");
204                 m = rsadecrypt(s->key->priv, m, m);
205                 s->resp = m;
206                 fss->phase = CHaveResp;
207                 return RpcOk;
208         case SNeedHash:
209         case VNeedHash:
210                 /* get hash type from key */
211                 hash = _strfindattr(s->key->attr, "hash");
212                 if(hash == nil)
213                         hash = "sha1";
214                 if(strcmp(hash, "sha1") == 0)
215                         dlen = SHA1dlen;
216                 else if(strcmp(hash, "md5") == 0)
217                         dlen = MD5dlen;
218                 else if(strcmp(hash, "sha256") == 0)
219                         dlen = SHA2_256dlen;
220                 else
221                         return failure(fss, "unknown hash function %s", hash);
222                 if(n != dlen)
223                         return failure(fss, "hash length %d should be %d",
224                                 n, dlen);
225                 priv = s->key->priv;
226                 s->digest = mkdigest(&priv->pub, hash, (uchar *)va, n);
227                 if(s->digest == nil)
228                         return failure(fss, nil);
229                 if(fss->phase == VNeedHash)
230                         fss->phase = VNeedSig;
231                 else
232                         fss->phase = SHaveResp;
233                 return RpcOk;
234         case VNeedSig:
235                 priv = s->key->priv;
236                 m = betomp((uchar*)va, n, nil);
237                 mm = rsaencrypt(&priv->pub, m, nil);
238                 s->sigresp = mpcmp(s->digest, mm);
239                 mpfree(m);
240                 mpfree(mm);
241                 fss->phase = VHaveResp;
242                 return RpcOk;
243         }
244 }
245
246 static void
247 rsaclose(Fsstate *fss)
248 {
249         State *s;
250
251         s = fss->ps;
252         if(s->key)
253                 closekey(s->key);
254         if(s->resp)
255                 mpfree(s->resp);
256         if(s->digest)
257                 mpfree(s->digest);
258         free(s);
259 }
260
261 static int
262 rsaaddkey(Key *k, int before)
263 {
264         fmtinstall('B', mpfmt);
265
266         if((k->priv = readrsapriv(k)) == nil){
267                 werrstr("malformed key data");
268                 return -1;
269         }
270         return replacekey(k, before);
271 }
272
273 static void
274 rsaclosekey(Key *k)
275 {
276         rsaprivfree(k->priv);
277 }
278
279 Proto rsa = {
280 .name=  "rsa",
281 .init=          rsainit,
282 .write= rsawrite,
283 .read=  rsaread,
284 .close= rsaclose,
285 .addkey=        rsaaddkey,
286 .closekey=      rsaclosekey,
287 };
288
289 /*
290  * Simple ASN.1 encodings.
291  * Lengths < 128 are encoded as 1-bytes constants,
292  * making our life easy.
293  */
294
295 /*
296  * Hash OIDs
297  *
298  * SHA1 = 1.3.14.3.2.26
299  * MDx = 1.2.840.113549.2.x
300  * SHA256 = 2.16.840.1.101.3.4.2.1
301  */
302 #define O0(a,b) ((a)*40+(b))
303 #define O2(x)   \
304         (((x)>> 7)&0x7F)|0x80, \
305         ((x)&0x7F)
306 #define O3(x)   \
307         (((x)>>14)&0x7F)|0x80, \
308         (((x)>> 7)&0x7F)|0x80, \
309         ((x)&0x7F)
310 uchar oidsha1[] = { O0(1, 3), 14, 3, 2, 26 };
311 uchar oidmd5[] = { O0(1, 2), O2(840), O3(113549), 2, 5 };
312 uchar oidsha256[] = { O0(2, 16), O2(840), 1, 101, 3, 4, 2, 1 };
313 /*
314  *      DigestInfo ::= SEQUENCE {
315  *              digestAlgorithm AlgorithmIdentifier,
316  *              digest OCTET STRING
317  *      }
318  *
319  * except that OpenSSL seems to sign
320  *
321  *      DigestInfo ::= SEQUENCE {
322  *              SEQUENCE{ digestAlgorithm AlgorithmIdentifier, NULL }
323  *              digest OCTET STRING
324  *      }
325  *
326  * instead.  Sigh.
327  */
328 static int
329 mkasn1(uchar *asn1, char *alg, uchar *d, uint dlen)
330 {
331         uchar *obj, *p;
332         uint olen;
333
334         if(strcmp(alg, "sha1") == 0){
335                 obj = oidsha1;
336                 olen = sizeof(oidsha1);
337         }else if(strcmp(alg, "md5") == 0){
338                 obj = oidmd5;
339                 olen = sizeof(oidmd5);
340         }else if(strcmp(alg, "sha256") == 0){
341                 obj = oidsha256;
342                 olen = sizeof(oidsha256);
343         }else{
344                 sysfatal("bad alg in mkasn1");
345                 return -1;
346         }
347
348         p = asn1;
349         *p++ = 0x30;            /* sequence */
350         p++;
351
352         *p++ = 0x30;            /* another sequence */
353         p++;
354
355         *p++ = 0x06;            /* object id */
356         *p++ = olen;
357         memmove(p, obj, olen);
358         p += olen;
359
360         *p++ = 0x05;            /* null */
361         *p++ = 0;
362
363         asn1[3] = p - (asn1+4); /* end of inner sequence */
364
365         *p++ = 0x04;            /* octet string */
366         *p++ = dlen;
367         memmove(p, d, dlen);
368         p += dlen;
369
370         asn1[1] = p - (asn1+2); /* end of outer sequence */
371         return p - asn1;
372 }
373
374 static mpint*
375 mkdigest(RSApub *key, char *hashalg, uchar *hash, uint dlen)
376 {
377         mpint *m;
378         uchar asn1[512], *buf;
379         int len, n, pad;
380
381         /*
382          * Create ASN.1
383          */
384         n = mkasn1(asn1, hashalg, hash, dlen);
385
386         /*
387          * PKCS#1 padding
388          */
389         len = (mpsignif(key->n)+7)/8 - 1;
390         if(len < n+2){
391                 werrstr("rsa key too short");
392                 return nil;
393         }
394         pad = len - (n+2);
395         buf = emalloc(len);
396         buf[0] = 0x01;
397         memset(buf+1, 0xFF, pad);
398         buf[1+pad] = 0x00;
399         memmove(buf+1+pad+1, asn1, n);
400         m = betomp(buf, len, nil);
401         free(buf);
402         return m;
403 }