]> git.lizzy.rs Git - plan9front.git/blob - sys/src/libsec/port/ecc.c
libsec: make sure Elem is zero initialized so freevalfields() wont cause accidents
[plan9front.git] / sys / src / libsec / port / ecc.c
1 #include "os.h"
2 #include <mp.h>
3 #include <libsec.h>
4 #include <ctype.h>
5
6 void
7 ecassign(ECdomain *, ECpoint *a, ECpoint *b)
8 {
9         b->inf = a->inf;
10         mpassign(a->x, b->x);
11         mpassign(a->y, b->y);
12 }
13
14 void
15 ecadd(ECdomain *dom, ECpoint *a, ECpoint *b, ECpoint *s)
16 {
17         mpint *l, *k, *sx, *sy;
18
19         if(a->inf && b->inf){
20                 s->inf = 1;
21                 return;
22         }
23         if(a->inf){
24                 ecassign(dom, b, s);
25                 return;
26         }
27         if(b->inf){
28                 ecassign(dom, a, s);
29                 return;
30         }
31         if(mpcmp(a->x, b->x) == 0 && (mpcmp(a->y, mpzero) == 0 || mpcmp(a->y, b->y) != 0)){
32                 s->inf = 1;
33                 return;
34         }
35         l = mpnew(0);
36         k = mpnew(0);
37         sx = mpnew(0);
38         sy = mpnew(0);
39         if(mpcmp(a->x, b->x) == 0 && mpcmp(a->y, b->y) == 0){
40                 mpadd(mpone, mptwo, k);
41                 mpmul(a->x, a->x, l);
42                 mpmul(l, k, l);
43                 mpadd(l, dom->a, l);
44                 mpleft(a->y, 1, k);
45                 mpmod(k, dom->p, k);
46                 mpinvert(k, dom->p, k);
47                 mpmul(k, l, l);
48                 mpmod(l, dom->p, l);
49
50                 mpleft(a->x, 1, k);
51                 mpmul(l, l, sx);
52                 mpsub(sx, k, sx);
53                 mpmod(sx, dom->p, sx);
54
55                 mpsub(a->x, sx, sy);
56                 mpmul(l, sy, sy);
57                 mpsub(sy, a->y, sy);
58                 mpmod(sy, dom->p, sy);
59                 mpassign(sx, s->x);
60                 mpassign(sy, s->y);
61                 mpfree(sx);
62                 mpfree(sy);
63                 mpfree(l);
64                 mpfree(k);
65                 return;
66         }
67         mpsub(b->y, a->y, l);
68         mpmod(l, dom->p, l);
69         mpsub(b->x, a->x, k);
70         mpmod(k, dom->p, k);
71         mpinvert(k, dom->p, k);
72         mpmul(k, l, l);
73         mpmod(l, dom->p, l);
74         
75         mpmul(l, l, sx);
76         mpsub(sx, a->x, sx);
77         mpsub(sx, b->x, sx);
78         mpmod(sx, dom->p, sx);
79         
80         mpsub(a->x, sx, sy);
81         mpmul(sy, l, sy);
82         mpsub(sy, a->y, sy);
83         mpmod(sy, dom->p, sy);
84         
85         mpassign(sx, s->x);
86         mpassign(sy, s->y);
87         mpfree(sx);
88         mpfree(sy);
89         mpfree(l);
90         mpfree(k);
91 }
92
93 void
94 ecmul(ECdomain *dom, ECpoint *a, mpint *k, ECpoint *s)
95 {
96         ECpoint ns, na;
97         mpint *l;
98
99         if(a->inf || mpcmp(k, mpzero) == 0){
100                 s->inf = 1;
101                 return;
102         }
103         ns.inf = 1;
104         ns.x = mpnew(0);
105         ns.y = mpnew(0);
106         na.x = mpnew(0);
107         na.y = mpnew(0);
108         ecassign(dom, a, &na);
109         l = mpcopy(k);
110         l->sign = 1;
111         while(mpcmp(l, mpzero) != 0){
112                 if(l->p[0] & 1)
113                         ecadd(dom, &na, &ns, &ns);
114                 ecadd(dom, &na, &na, &na);
115                 mpright(l, 1, l);
116         }
117         if(k->sign < 0){
118                 ns.y->sign = -1;
119                 mpmod(ns.y, dom->p, ns.y);
120         }
121         ecassign(dom, &ns, s);
122         mpfree(ns.x);
123         mpfree(ns.y);
124         mpfree(na.x);
125         mpfree(na.y);
126         mpfree(l);
127 }
128
129 int
130 ecverify(ECdomain *dom, ECpoint *a)
131 {
132         mpint *p, *q;
133         int r;
134
135         if(a->inf)
136                 return 1;
137         
138         p = mpnew(0);
139         q = mpnew(0);
140         mpmul(a->y, a->y, p);
141         mpmod(p, dom->p, p);
142         mpmul(a->x, a->x, q);
143         mpadd(q, dom->a, q);
144         mpmul(a->x, q, q);
145         mpadd(q, dom->b, q);
146         mpmod(q, dom->p, q);
147         r = mpcmp(p, q);
148         mpfree(p);
149         mpfree(q);
150         return r == 0;
151 }
152
153 int
154 ecpubverify(ECdomain *dom, ECpub *a)
155 {
156         ECpoint p;
157         int r;
158
159         if(a->inf)
160                 return 0;
161         if(!ecverify(dom, a))
162                 return 0;
163         p.x = mpnew(0);
164         p.y = mpnew(0);
165         ecmul(dom, a, dom->n, &p);
166         r = p.inf;
167         mpfree(p.x);
168         mpfree(p.y);
169         return r;
170 }
171
172 static void
173 fixnibble(uchar *a)
174 {
175         if(*a >= 'a')
176                 *a -= 'a'-10;
177         else if(*a >= 'A')
178                 *a -= 'A'-10;
179         else
180                 *a -= '0';
181 }
182
183 static int
184 octet(char **s)
185 {
186         uchar c, d;
187         
188         c = *(*s)++;
189         if(!isxdigit(c))
190                 return -1;
191         d = *(*s)++;
192         if(!isxdigit(d))
193                 return -1;
194         fixnibble(&c);
195         fixnibble(&d);
196         return (c << 4) | d;
197 }
198
199 static mpint*
200 halfpt(ECdomain *dom, char *s, char **rptr, mpint *out)
201 {
202         char *buf, *r;
203         int n;
204         mpint *ret;
205         
206         n = ((mpsignif(dom->p)+7)/8)*2;
207         if(strlen(s) < n)
208                 return 0;
209         buf = malloc(n+1);
210         buf[n] = 0;
211         memcpy(buf, s, n);
212         ret = strtomp(buf, &r, 16, out);
213         *rptr = s + (r - buf);
214         free(buf);
215         return ret;
216 }
217
218 static int
219 mpleg(mpint *a, mpint *b)
220 {
221         int r, k;
222         mpint *m, *n, *t;
223         
224         r = 1;
225         m = mpcopy(a);
226         n = mpcopy(b);
227         for(;;){
228                 if(mpcmp(m, n) > 0)
229                         mpmod(m, n, m);
230                 if(mpcmp(m, mpzero) == 0){
231                         r = 0;
232                         break;
233                 }
234                 if(mpcmp(m, mpone) == 0)
235                         break;
236                 k = mplowbits0(m);
237                 if(k > 0){
238                         if(k & 1)
239                                 switch(n->p[0] & 15){
240                                 case 3: case 5: case 11: case 13:
241                                         r = -r;
242                                 }
243                         mpright(m, k, m);
244                 }
245                 if((n->p[0] & 3) == 3 && (m->p[0] & 3) == 3)
246                         r = -r;
247                 t = m;
248                 m = n;
249                 n = t;
250         }
251         mpfree(m);
252         mpfree(n);
253         return r;
254 }
255
256 static int
257 mpsqrt(mpint *n, mpint *p, mpint *r)
258 {
259         mpint *a, *t, *s, *xp, *xq, *yp, *yq, *zp, *zq, *N;
260
261         if(mpleg(n, p) == -1)
262                 return 0;
263         a = mpnew(0);
264         t = mpnew(0);
265         s = mpnew(0);
266         N = mpnew(0);
267         xp = mpnew(0);
268         xq = mpnew(0);
269         yp = mpnew(0);
270         yq = mpnew(0);
271         zp = mpnew(0);
272         zq = mpnew(0);
273         for(;;){
274                 for(;;){
275                         mprand(mpsignif(p), genrandom, a);
276                         if(mpcmp(a, mpzero) > 0 && mpcmp(a, p) < 0)
277                                 break;
278                 }
279                 mpmul(a, a, t);
280                 mpsub(t, n, t);
281                 mpmod(t, p, t);
282                 if(mpleg(t, p) == -1)
283                         break;
284         }
285         mpadd(p, mpone, N);
286         mpright(N, 1, N);
287         mpmul(a, a, t);
288         mpsub(t, n, t);
289         mpassign(a, xp);
290         uitomp(1, xq);
291         uitomp(1, yp);
292         uitomp(0, yq);
293         while(mpcmp(N, mpzero) != 0){
294                 if(N->p[0] & 1){
295                         mpmul(xp, yp, zp);
296                         mpmul(xq, yq, zq);
297                         mpmul(zq, t, zq);
298                         mpadd(zp, zq, zp);
299                         mpmod(zp, p, zp);
300                         mpmul(xp, yq, zq);
301                         mpmul(xq, yp, s);
302                         mpadd(zq, s, zq);
303                         mpmod(zq, p, yq);
304                         mpassign(zp, yp);
305                 }
306                 mpmul(xp, xp, zp);
307                 mpmul(xq, xq, zq);
308                 mpmul(zq, t, zq);
309                 mpadd(zp, zq, zp);
310                 mpmod(zp, p, zp);
311                 mpmul(xp, xq, zq);
312                 mpadd(zq, zq, zq);
313                 mpmod(zq, p, xq);
314                 mpassign(zp, xp);
315                 mpright(N, 1, N);
316         }
317         if(mpcmp(yq, mpzero) != 0)
318                 abort();
319         mpassign(yp, r);
320         mpfree(a);
321         mpfree(t);
322         mpfree(s);
323         mpfree(N);
324         mpfree(xp);
325         mpfree(xq);
326         mpfree(yp);
327         mpfree(yq);
328         mpfree(zp);
329         mpfree(zq);
330         return 1;
331 }
332
333 ECpoint*
334 strtoec(ECdomain *dom, char *s, char **rptr, ECpoint *ret)
335 {
336         int allocd, o;
337         mpint *r;
338
339         allocd = 0;
340         if(ret == nil){
341                 allocd = 1;
342                 ret = mallocz(sizeof(*ret), 1);
343                 if(ret == nil)
344                         return nil;
345                 ret->x = mpnew(0);
346                 ret->y = mpnew(0);
347         }
348         o = 0;
349         switch(octet(&s)){
350         case 0:
351                 ret->inf = 1;
352                 return ret;
353         case 3:
354                 o = 1;
355         case 2:
356                 if(halfpt(dom, s, &s, ret->x) == nil)
357                         goto err;
358                 r = mpnew(0);
359                 mpmul(ret->x, ret->x, r);
360                 mpadd(r, dom->a, r);
361                 mpmul(r, ret->x, r);
362                 mpadd(r, dom->b, r);
363                 if(!mpsqrt(r, dom->p, r)){
364                         mpfree(r);
365                         goto err;
366                 }
367                 if((r->p[0] & 1) != o)
368                         mpsub(dom->p, r, r);
369                 mpassign(r, ret->y);
370                 mpfree(r);
371                 if(!ecverify(dom, ret))
372                         goto err;
373                 return ret;
374         case 4:
375                 if(halfpt(dom, s, &s, ret->x) == nil)
376                         goto err;
377                 if(halfpt(dom, s, &s, ret->y) == nil)
378                         goto err;
379                 if(!ecverify(dom, ret))
380                         goto err;
381                 return ret;
382         }
383 err:
384         if(rptr)
385                 *rptr = s;
386         if(allocd){
387                 mpfree(ret->x);
388                 mpfree(ret->y);
389                 free(ret);
390         }
391         return nil;
392 }
393
394 ECpriv*
395 ecgen(ECdomain *dom, ECpriv *p)
396 {
397         if(p == nil){
398                 p = mallocz(sizeof(*p), 1);
399                 if(p == nil)
400                         return nil;
401                 p->x = mpnew(0);
402                 p->y = mpnew(0);
403                 p->d = mpnew(0);
404         }
405         for(;;){
406                 mprand(mpsignif(dom->n), genrandom, p->d);
407                 if(mpcmp(p->d, mpzero) > 0 && mpcmp(p->d, dom->n) < 0)
408                         break;
409         }
410         ecmul(dom, dom->G, p->d, p);
411         return p;
412 }
413
414 void
415 ecdsasign(ECdomain *dom, ECpriv *priv, uchar *dig, int len, mpint *r, mpint *s)
416 {
417         ECpriv tmp;
418         mpint *E, *t;
419
420         tmp.x = mpnew(0);
421         tmp.y = mpnew(0);
422         tmp.d = mpnew(0);
423         E = betomp(dig, len, nil);
424         t = mpnew(0);
425         if(mpsignif(dom->n) < 8*len)
426                 mpright(E, 8*len - mpsignif(dom->n), E);
427         for(;;){
428                 ecgen(dom, &tmp);
429                 mpmod(tmp.x, dom->n, r);
430                 if(mpcmp(r, mpzero) == 0)
431                         continue;
432                 mpmul(r, priv->d, s);
433                 mpadd(E, s, s);
434                 mpinvert(tmp.d, dom->n, t);
435                 mpmul(s, t, s);
436                 mpmod(s, dom->n, s);
437                 if(mpcmp(s, mpzero) != 0)
438                         break;
439         }
440         mpfree(t);
441         mpfree(E);
442         mpfree(tmp.x);
443         mpfree(tmp.y);
444         mpfree(tmp.d);
445 }
446
447 int
448 ecdsaverify(ECdomain *dom, ECpub *pub, uchar *dig, int len, mpint *r, mpint *s)
449 {
450         mpint *E, *t, *u1, *u2;
451         ECpoint R, S;
452         int ret;
453
454         if(mpcmp(r, mpone) < 0 || mpcmp(s, mpone) < 0 || mpcmp(r, dom->n) >= 0 || mpcmp(r, dom->n) >= 0)
455                 return 0;
456         E = betomp(dig, len, nil);
457         if(mpsignif(dom->n) < 8*len)
458                 mpright(E, 8*len - mpsignif(dom->n), E);
459         t = mpnew(0);
460         u1 = mpnew(0);
461         u2 = mpnew(0);
462         R.x = mpnew(0);
463         R.y = mpnew(0);
464         S.x = mpnew(0);
465         S.y = mpnew(0);
466         mpinvert(s, dom->n, t);
467         mpmul(E, t, u1);
468         mpmod(u1, dom->n, u1);
469         mpmul(r, t, u2);
470         mpmod(u2, dom->n, u2);
471         ecmul(dom, dom->G, u1, &R);
472         ecmul(dom, pub, u2, &S);
473         ecadd(dom, &R, &S, &R);
474         ret = 0;
475         if(!R.inf){
476                 mpmod(R.x, dom->n, t);
477                 ret = mpcmp(r, t) == 0;
478         }
479         mpfree(t);
480         mpfree(u1);
481         mpfree(u2);
482         mpfree(R.x);
483         mpfree(R.y);
484         mpfree(S.x);
485         mpfree(S.y);
486         return ret;
487 }
488
489 static char *code = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz";
490
491 void
492 base58enc(uchar *src, char *dst, int len)
493 {
494         mpint *n, *r, *b;
495         char *sdst, t;
496         
497         sdst = dst;
498         n = betomp(src, len, nil);
499         b = uitomp(58, nil);
500         r = mpnew(0);
501         while(mpcmp(n, mpzero) != 0){
502                 mpdiv(n, b, n, r);
503                 *dst++ = code[mptoui(r)];
504         }
505         for(; *src == 0; src++)
506                 *dst++ = code[0];
507         dst--;
508         while(dst > sdst){
509                 t = *sdst;
510                 *sdst++ = *dst;
511                 *dst-- = t;
512         }
513 }
514
515 int
516 base58dec(char *src, uchar *dst, int len)
517 {
518         mpint *n, *b, *r;
519         char *t;
520         int l;
521         
522         n = mpnew(0);
523         r = mpnew(0);
524         b = uitomp(58, nil);
525         for(; *src; src++){
526                 t = strchr(code, *src);
527                 if(t == nil){
528                         mpfree(n);
529                         mpfree(r);
530                         mpfree(b);
531                         werrstr("invalid base58 char");
532                         return -1;
533                 }
534                 uitomp(t - code, r);
535                 mpmul(n, b, n);
536                 mpadd(n, r, n);
537         }
538         memset(dst, 0, len);
539         l = (mpsignif(n) + 7) / 8;
540         mptobe(n, dst + (len - l), l, nil);
541         mpfree(n);
542         mpfree(r);
543         mpfree(b);
544         return 0;
545 }