]> git.lizzy.rs Git - plan9front.git/blobdiff - sys/src/libsec/port/dh.c
libsec: mpconv -> mpfmt
[plan9front.git] / sys / src / libsec / port / dh.c
index 70f6a864c37856dc319589a4a2d4389620164996..c43595e02590d1758243e624c7e49df7ec7bdd70 100644 (file)
@@ -3,38 +3,72 @@
 #include <libsec.h>
 
 mpint*
-dh_new(DHstate *dh, mpint *p, mpint *g)
+dh_new(DHstate *dh, mpint *p, mpint *q, mpint *g)
 {
+       mpint *pm1;
+       int n;
+
        memset(dh, 0, sizeof(*dh));
-       dh->g = mpcopy(g);
+       if(mpcmp(g, mpone) <= 0)
+               return nil;
+
+       n = mpsignif(p);
+       pm1 = mpnew(n);
+       mpsub(p, mpone, pm1);
        dh->p = mpcopy(p);
-       if(dh->g != nil && dh->p != nil){
-               dh->x = mprand(mpsignif(dh->p), genrandom, nil);
-               dh->y = mpnew(0);
-               if(dh->x != nil && dh->y != nil){
-                       mpexp(dh->g, dh->x, dh->p, dh->y);
-                       return dh->y;
-               }
+       dh->g = mpcopy(g);
+       dh->q = mpcopy(q != nil ? q : pm1);
+       dh->x = mpnew(mpsignif(dh->q));
+       dh->y = mpnew(n);
+       for(;;){
+               mpnrand(dh->q, genrandom, dh->x);
+               mpexp(dh->g, dh->x, dh->p, dh->y);
+               if(mpcmp(dh->y, mpone) > 0 && mpcmp(dh->y, pm1) < 0)
+                       break;
        }
-       dh_finish(dh, nil);
-       return nil;
+       mpfree(pm1);
+
+       return dh->y;
 }
 
 mpint*
-dh_finish(DHstate *dh, mpint *pub)
+dh_finish(DHstate *dh, mpint *y)
 {
-       mpint *k;
+       mpint *k = nil;
+
+       if(y == nil || dh->x == nil || dh->p == nil || dh->q == nil)
+               goto Out;
+
+       /* y > 1 */
+       if(mpcmp(y, mpone) <= 0)
+               goto Out;
+
+       k = mpnew(mpsignif(dh->p));
 
-       k = nil;
-       if(pub != nil && dh->x != nil && dh->p != nil){
-               if((k = mpnew(0)) != nil)
-                       mpexp(pub, dh->x, dh->p, k);
+       /* y < p-1 */
+       mpsub(dh->p, mpone, k);
+       if(mpcmp(y, k) >= 0){
+Bad:
+               mpfree(k);
+               k = nil;
+               goto Out;
        }
-       mpfree(dh->g);
+
+       /* y**q % p == 1 if q < p-1 */
+       if(mpcmp(dh->q, k) < 0){
+               mpexp(y, dh->q, dh->p, k);
+               if(mpcmp(k, mpone) != 0)
+                       goto Bad;
+       }
+
+       mpexp(y, dh->x, dh->p, k);
+
+Out:
        mpfree(dh->p);
+       mpfree(dh->q);
+       mpfree(dh->g);
        mpfree(dh->x);
        mpfree(dh->y);
        memset(dh, 0, sizeof(*dh));
        return k;
 }
-