]> git.lizzy.rs Git - plan9front.git/blob - sys/src/libsec/port/ecc.c
libsec: move zero check to curve25519_dh_finish()
[plan9front.git] / sys / src / libsec / port / ecc.c
1 #include "os.h"
2 #include <mp.h>
3 #include <libsec.h>
4 #include <ctype.h>
5
6 extern void jacobian_affine(mpint *p,
7         mpint *X, mpint *Y, mpint *Z);
8 extern void jacobian_dbl(mpint *p, mpint *a,
9         mpint *X1, mpint *Y1, mpint *Z1,
10         mpint *X3, mpint *Y3, mpint *Z3);
11 extern void jacobian_add(mpint *p, mpint *a,
12         mpint *X1, mpint *Y1, mpint *Z1,
13         mpint *X2, mpint *Y2, mpint *Z2,
14         mpint *X3, mpint *Y3, mpint *Z3);
15
16 void
17 ecassign(ECdomain *dom, ECpoint *a, ECpoint *b)
18 {
19         if((b->inf = a->inf) != 0)
20                 return;
21         mpassign(a->x, b->x);
22         mpassign(a->y, b->y);
23         if(b->z != nil){
24                 mpassign(a->z != nil ? a->z : mpone, b->z);
25                 return;
26         }
27         if(a->z != nil){
28                 b->z = mpcopy(a->z);
29                 jacobian_affine(dom->p, b->x, b->y, b->z);
30                 mpfree(b->z);
31                 b->z = nil;
32         }
33 }
34
35 void
36 ecadd(ECdomain *dom, ECpoint *a, ECpoint *b, ECpoint *s)
37 {
38         if(a->inf && b->inf){
39                 s->inf = 1;
40                 return;
41         }
42         if(a->inf){
43                 ecassign(dom, b, s);
44                 return;
45         }
46         if(b->inf){
47                 ecassign(dom, a, s);
48                 return;
49         }
50
51         if(s->z == nil){
52                 s->z = mpcopy(mpone);
53                 ecadd(dom, a, b, s);
54                 if(!s->inf)
55                         jacobian_affine(dom->p, s->x, s->y, s->z);
56                 mpfree(s->z);
57                 s->z = nil;
58                 return;
59         }
60
61         if(a == b)
62                 jacobian_dbl(dom->p, dom->a,
63                         a->x, a->y, a->z != nil ? a->z : mpone,
64                         s->x, s->y, s->z);
65         else
66                 jacobian_add(dom->p, dom->a,
67                         a->x, a->y, a->z != nil ? a->z : mpone,
68                         b->x, b->y, b->z != nil ? b->z : mpone,
69                         s->x, s->y, s->z);
70         s->inf = mpcmp(s->z, mpzero) == 0;
71 }
72
73 void
74 ecmul(ECdomain *dom, ECpoint *a, mpint *k, ECpoint *s)
75 {
76         ECpoint ns, na;
77         mpint *l;
78
79         if(a->inf || mpcmp(k, mpzero) == 0){
80                 s->inf = 1;
81                 return;
82         }
83         ns.inf = 1;
84         ns.x = mpnew(0);
85         ns.y = mpnew(0);
86         ns.z = mpnew(0);
87         na.x = mpnew(0);
88         na.y = mpnew(0);
89         na.z = mpnew(0);
90         ecassign(dom, a, &na);
91         l = mpcopy(k);
92         l->sign = 1;
93         while(mpcmp(l, mpzero) != 0){
94                 if(l->p[0] & 1)
95                         ecadd(dom, &na, &ns, &ns);
96                 ecadd(dom, &na, &na, &na);
97                 mpright(l, 1, l);
98         }
99         if(k->sign < 0 && !ns.inf){
100                 ns.y->sign = -1;
101                 mpmod(ns.y, dom->p, ns.y);
102         }
103         ecassign(dom, &ns, s);
104         mpfree(ns.x);
105         mpfree(ns.y);
106         mpfree(ns.z);
107         mpfree(na.x);
108         mpfree(na.y);
109         mpfree(na.z);
110         mpfree(l);
111 }
112
113 int
114 ecverify(ECdomain *dom, ECpoint *a)
115 {
116         mpint *p, *q;
117         int r;
118
119         if(a->inf)
120                 return 1;
121
122         assert(a->z == nil);    /* need affine coordinates */
123         p = mpnew(0);
124         q = mpnew(0);
125         mpmodmul(a->y, a->y, dom->p, p);
126         mpmodmul(a->x, a->x, dom->p, q);
127         mpmodadd(q, dom->a, dom->p, q);
128         mpmodmul(q, a->x, dom->p, q);
129         mpmodadd(q, dom->b, dom->p, q);
130         r = mpcmp(p, q);
131         mpfree(p);
132         mpfree(q);
133         return r == 0;
134 }
135
136 int
137 ecpubverify(ECdomain *dom, ECpub *a)
138 {
139         ECpoint p;
140         int r;
141
142         if(a->inf)
143                 return 0;
144         if(!ecverify(dom, a))
145                 return 0;
146         p.x = mpnew(0);
147         p.y = mpnew(0);
148         p.z = mpnew(0);
149         ecmul(dom, a, dom->n, &p);
150         r = p.inf;
151         mpfree(p.x);
152         mpfree(p.y);
153         mpfree(p.z);
154         return r;
155 }
156
157 static void
158 fixnibble(uchar *a)
159 {
160         if(*a >= 'a')
161                 *a -= 'a'-10;
162         else if(*a >= 'A')
163                 *a -= 'A'-10;
164         else
165                 *a -= '0';
166 }
167
168 static int
169 octet(char **s)
170 {
171         uchar c, d;
172         
173         c = *(*s)++;
174         if(!isxdigit(c))
175                 return -1;
176         d = *(*s)++;
177         if(!isxdigit(d))
178                 return -1;
179         fixnibble(&c);
180         fixnibble(&d);
181         return (c << 4) | d;
182 }
183
184 static mpint*
185 halfpt(ECdomain *dom, char *s, char **rptr, mpint *out)
186 {
187         char *buf, *r;
188         int n;
189         mpint *ret;
190         
191         n = ((mpsignif(dom->p)+7)/8)*2;
192         if(strlen(s) < n)
193                 return 0;
194         buf = malloc(n+1);
195         buf[n] = 0;
196         memcpy(buf, s, n);
197         ret = strtomp(buf, &r, 16, out);
198         *rptr = s + (r - buf);
199         free(buf);
200         return ret;
201 }
202
203 static int
204 mpleg(mpint *a, mpint *b)
205 {
206         int r, k;
207         mpint *m, *n, *t;
208         
209         r = 1;
210         m = mpcopy(a);
211         n = mpcopy(b);
212         for(;;){
213                 if(mpcmp(m, n) > 0)
214                         mpmod(m, n, m);
215                 if(mpcmp(m, mpzero) == 0){
216                         r = 0;
217                         break;
218                 }
219                 if(mpcmp(m, mpone) == 0)
220                         break;
221                 k = mplowbits0(m);
222                 if(k > 0){
223                         if(k & 1)
224                                 switch(n->p[0] & 15){
225                                 case 3: case 5: case 11: case 13:
226                                         r = -r;
227                                 }
228                         mpright(m, k, m);
229                 }
230                 if((n->p[0] & 3) == 3 && (m->p[0] & 3) == 3)
231                         r = -r;
232                 t = m;
233                 m = n;
234                 n = t;
235         }
236         mpfree(m);
237         mpfree(n);
238         return r;
239 }
240
241 static int
242 mpsqrt(mpint *n, mpint *p, mpint *r)
243 {
244         mpint *a, *t, *s, *xp, *xq, *yp, *yq, *zp, *zq, *N;
245
246         if(mpleg(n, p) == -1)
247                 return 0;
248         a = mpnew(0);
249         t = mpnew(0);
250         s = mpnew(0);
251         N = mpnew(0);
252         xp = mpnew(0);
253         xq = mpnew(0);
254         yp = mpnew(0);
255         yq = mpnew(0);
256         zp = mpnew(0);
257         zq = mpnew(0);
258         for(;;){
259                 for(;;){
260                         mpnrand(p, genrandom, a);
261                         if(mpcmp(a, mpzero) > 0)
262                                 break;
263                 }
264                 mpmul(a, a, t);
265                 mpsub(t, n, t);
266                 mpmod(t, p, t);
267                 if(mpleg(t, p) == -1)
268                         break;
269         }
270         mpadd(p, mpone, N);
271         mpright(N, 1, N);
272         mpmul(a, a, t);
273         mpsub(t, n, t);
274         mpassign(a, xp);
275         uitomp(1, xq);
276         uitomp(1, yp);
277         uitomp(0, yq);
278         while(mpcmp(N, mpzero) != 0){
279                 if(N->p[0] & 1){
280                         mpmul(xp, yp, zp);
281                         mpmul(xq, yq, zq);
282                         mpmul(zq, t, zq);
283                         mpadd(zp, zq, zp);
284                         mpmod(zp, p, zp);
285                         mpmul(xp, yq, zq);
286                         mpmul(xq, yp, s);
287                         mpadd(zq, s, zq);
288                         mpmod(zq, p, yq);
289                         mpassign(zp, yp);
290                 }
291                 mpmul(xp, xp, zp);
292                 mpmul(xq, xq, zq);
293                 mpmul(zq, t, zq);
294                 mpadd(zp, zq, zp);
295                 mpmod(zp, p, zp);
296                 mpmul(xp, xq, zq);
297                 mpadd(zq, zq, zq);
298                 mpmod(zq, p, xq);
299                 mpassign(zp, xp);
300                 mpright(N, 1, N);
301         }
302         if(mpcmp(yq, mpzero) != 0)
303                 abort();
304         mpassign(yp, r);
305         mpfree(a);
306         mpfree(t);
307         mpfree(s);
308         mpfree(N);
309         mpfree(xp);
310         mpfree(xq);
311         mpfree(yp);
312         mpfree(yq);
313         mpfree(zp);
314         mpfree(zq);
315         return 1;
316 }
317
318 ECpoint*
319 strtoec(ECdomain *dom, char *s, char **rptr, ECpoint *ret)
320 {
321         int allocd, o;
322         mpint *r;
323
324         allocd = 0;
325         if(ret == nil){
326                 allocd = 1;
327                 ret = mallocz(sizeof(*ret), 1);
328                 if(ret == nil)
329                         return nil;
330                 ret->x = mpnew(0);
331                 ret->y = mpnew(0);
332         }
333         ret->inf = 0;
334         o = 0;
335         switch(octet(&s)){
336         case 0:
337                 ret->inf = 1;
338                 break;
339         case 3:
340                 o = 1;
341         case 2:
342                 if(halfpt(dom, s, &s, ret->x) == nil)
343                         goto err;
344                 r = mpnew(0);
345                 mpmul(ret->x, ret->x, r);
346                 mpadd(r, dom->a, r);
347                 mpmul(r, ret->x, r);
348                 mpadd(r, dom->b, r);
349                 if(!mpsqrt(r, dom->p, r)){
350                         mpfree(r);
351                         goto err;
352                 }
353                 if((r->p[0] & 1) != o)
354                         mpsub(dom->p, r, r);
355                 mpassign(r, ret->y);
356                 mpfree(r);
357                 if(!ecverify(dom, ret))
358                         goto err;
359                 break;
360         case 4:
361                 if(halfpt(dom, s, &s, ret->x) == nil)
362                         goto err;
363                 if(halfpt(dom, s, &s, ret->y) == nil)
364                         goto err;
365                 if(!ecverify(dom, ret))
366                         goto err;
367                 break;
368         }
369         if(ret->z != nil && !ret->inf)
370                 mpassign(mpone, ret->z);
371         return ret;
372
373 err:
374         if(rptr)
375                 *rptr = s;
376         if(allocd){
377                 mpfree(ret->x);
378                 mpfree(ret->y);
379                 free(ret);
380         }
381         return nil;
382 }
383
384 ECpriv*
385 ecgen(ECdomain *dom, ECpriv *p)
386 {
387         if(p == nil){
388                 p = mallocz(sizeof(*p), 1);
389                 if(p == nil)
390                         return nil;
391                 p->x = mpnew(0);
392                 p->y = mpnew(0);
393                 p->d = mpnew(0);
394         }
395         for(;;){
396                 mpnrand(dom->n, genrandom, p->d);
397                 if(mpcmp(p->d, mpzero) > 0)
398                         break;
399         }
400         ecmul(dom, &dom->G, p->d, p);
401         return p;
402 }
403
404 void
405 ecdsasign(ECdomain *dom, ECpriv *priv, uchar *dig, int len, mpint *r, mpint *s)
406 {
407         ECpriv tmp;
408         mpint *E, *t;
409
410         tmp.x = mpnew(0);
411         tmp.y = mpnew(0);
412         tmp.z = nil;
413         tmp.d = mpnew(0);
414         E = betomp(dig, len, nil);
415         t = mpnew(0);
416         if(mpsignif(dom->n) < 8*len)
417                 mpright(E, 8*len - mpsignif(dom->n), E);
418         for(;;){
419                 ecgen(dom, &tmp);
420                 mpmod(tmp.x, dom->n, r);
421                 if(mpcmp(r, mpzero) == 0)
422                         continue;
423                 mpmul(r, priv->d, s);
424                 mpadd(E, s, s);
425                 mpinvert(tmp.d, dom->n, t);
426                 mpmodmul(s, t, dom->n, s);
427                 if(mpcmp(s, mpzero) != 0)
428                         break;
429         }
430         mpfree(t);
431         mpfree(E);
432         mpfree(tmp.x);
433         mpfree(tmp.y);
434         mpfree(tmp.d);
435 }
436
437 int
438 ecdsaverify(ECdomain *dom, ECpub *pub, uchar *dig, int len, mpint *r, mpint *s)
439 {
440         mpint *E, *t, *u1, *u2;
441         ECpoint R, S;
442         int ret;
443
444         if(mpcmp(r, mpone) < 0 || mpcmp(s, mpone) < 0 || mpcmp(r, dom->n) >= 0 || mpcmp(r, dom->n) >= 0)
445                 return 0;
446         E = betomp(dig, len, nil);
447         if(mpsignif(dom->n) < 8*len)
448                 mpright(E, 8*len - mpsignif(dom->n), E);
449         t = mpnew(0);
450         u1 = mpnew(0);
451         u2 = mpnew(0);
452         R.x = mpnew(0);
453         R.y = mpnew(0);
454         R.z = mpnew(0);
455         S.x = mpnew(0);
456         S.y = mpnew(0);
457         S.z = mpnew(0);
458         mpinvert(s, dom->n, t);
459         mpmodmul(E, t, dom->n, u1);
460         mpmodmul(r, t, dom->n, u2);
461         ecmul(dom, &dom->G, u1, &R);
462         ecmul(dom, pub, u2, &S);
463         ecadd(dom, &R, &S, &R);
464         ret = 0;
465         if(!R.inf){
466                 jacobian_affine(dom->p, R.x, R.y, R.z);
467                 mpmod(R.x, dom->n, t);
468                 ret = mpcmp(r, t) == 0;
469         }
470         mpfree(E);
471         mpfree(t);
472         mpfree(u1);
473         mpfree(u2);
474         mpfree(R.x);
475         mpfree(R.y);
476         mpfree(R.z);
477         mpfree(S.x);
478         mpfree(S.y);
479         mpfree(S.z);
480         return ret;
481 }
482
483 static char *code = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz";
484
485 void
486 base58enc(uchar *src, char *dst, int len)
487 {
488         mpint *n, *r, *b;
489         char *sdst, t;
490         
491         sdst = dst;
492         n = betomp(src, len, nil);
493         b = uitomp(58, nil);
494         r = mpnew(0);
495         while(mpcmp(n, mpzero) != 0){
496                 mpdiv(n, b, n, r);
497                 *dst++ = code[mptoui(r)];
498         }
499         for(; *src == 0; src++)
500                 *dst++ = code[0];
501         *dst-- = 0;
502         while(dst > sdst){
503                 t = *sdst;
504                 *sdst++ = *dst;
505                 *dst-- = t;
506         }
507 }
508
509 int
510 base58dec(char *src, uchar *dst, int len)
511 {
512         mpint *n, *b, *r;
513         char *t;
514         
515         n = mpnew(0);
516         r = mpnew(0);
517         b = uitomp(58, nil);
518         for(; *src; src++){
519                 t = strchr(code, *src);
520                 if(t == nil){
521                         mpfree(n);
522                         mpfree(r);
523                         mpfree(b);
524                         werrstr("invalid base58 char");
525                         return -1;
526                 }
527                 uitomp(t - code, r);
528                 mpmul(n, b, n);
529                 mpadd(n, r, n);
530         }
531         mptober(n, dst, len);
532         mpfree(n);
533         mpfree(r);
534         mpfree(b);
535         return 0;
536 }
537
538 void
539 ecdominit(ECdomain *dom, void (*init)(mpint *p, mpint *a, mpint *b, mpint *x, mpint *y, mpint *n, mpint *h))
540 {
541         memset(dom, 0, sizeof(*dom));
542         dom->p = mpnew(0);
543         dom->a = mpnew(0);
544         dom->b = mpnew(0);
545         dom->G.x = mpnew(0);
546         dom->G.y = mpnew(0);
547         dom->n = mpnew(0);
548         dom->h = mpnew(0);
549         if(init){
550                 (*init)(dom->p, dom->a, dom->b, dom->G.x, dom->G.y, dom->n, dom->h);
551                 dom->p = mpfield(dom->p);
552         }
553 }
554
555 void
556 ecdomfree(ECdomain *dom)
557 {
558         mpfree(dom->p);
559         mpfree(dom->a);
560         mpfree(dom->b);
561         mpfree(dom->G.x);
562         mpfree(dom->G.y);
563         mpfree(dom->n);
564         mpfree(dom->h);
565         memset(dom, 0, sizeof(*dom));
566 }
567
568 int
569 ecencodepub(ECdomain *dom, ECpub *pub, uchar *data, int len)
570 {
571         int n;
572
573         n = (mpsignif(dom->p)+7)/8;
574         if(len < 1 + 2*n)
575                 return 0;
576         len = 1 + 2*n;
577         data[0] = 0x04;
578         mptober(pub->x, data+1, n);
579         mptober(pub->y, data+1+n, n);
580         return len;
581 }
582
583 ECpub*
584 ecdecodepub(ECdomain *dom, uchar *data, int len)
585 {
586         ECpub *pub;
587         int n;
588
589         n = (mpsignif(dom->p)+7)/8;
590         if(len != 1 + 2*n || data[0] != 0x04)
591                 return nil;
592         pub = mallocz(sizeof(*pub), 1);
593         if(pub == nil)
594                 return nil;
595         pub->x = betomp(data+1, n, nil);
596         pub->y = betomp(data+1+n, n, nil);
597         if(!ecpubverify(dom, pub)){
598                 ecpubfree(pub);
599                 pub = nil;
600         }
601         return pub;
602 }
603
604 void
605 ecpubfree(ECpub *p)
606 {
607         if(p == nil)
608                 return;
609         mpfree(p->x);
610         mpfree(p->y);
611         free(p);
612 }