#include <libsec.h>
mpint*
-dh_new(DHstate *dh, mpint *p, mpint *g)
+dh_new(DHstate *dh, mpint *p, mpint *q, mpint *g)
{
+ mpint *pm1;
+ int n;
+
memset(dh, 0, sizeof(*dh));
- dh->g = mpcopy(g);
+ if(mpcmp(g, mpone) <= 0)
+ return nil;
+
+ n = mpsignif(p);
+ pm1 = mpnew(n);
+ mpsub(p, mpone, pm1);
dh->p = mpcopy(p);
- if(dh->g != nil && dh->p != nil){
- dh->x = mprand(mpsignif(dh->p), genrandom, nil);
- dh->y = mpnew(0);
- if(dh->x != nil && dh->y != nil){
- mpexp(dh->g, dh->x, dh->p, dh->y);
- return dh->y;
- }
+ dh->g = mpcopy(g);
+ dh->q = mpcopy(q != nil ? q : pm1);
+ dh->x = mpnew(mpsignif(dh->q));
+ dh->y = mpnew(n);
+ for(;;){
+ mpnrand(dh->q, genrandom, dh->x);
+ mpexp(dh->g, dh->x, dh->p, dh->y);
+ if(mpcmp(dh->y, mpone) > 0 && mpcmp(dh->y, pm1) < 0)
+ break;
}
- dh_finish(dh, nil);
- return nil;
+ mpfree(pm1);
+
+ return dh->y;
}
mpint*
-dh_finish(DHstate *dh, mpint *pub)
+dh_finish(DHstate *dh, mpint *y)
{
- mpint *k;
+ mpint *k = nil;
+
+ if(y == nil || dh->x == nil || dh->p == nil || dh->q == nil)
+ goto Out;
+
+ /* y > 1 */
+ if(mpcmp(y, mpone) <= 0)
+ goto Out;
+
+ k = mpnew(mpsignif(dh->p));
- k = nil;
- if(pub != nil && dh->x != nil && dh->p != nil){
- if((k = mpnew(0)) != nil)
- mpexp(pub, dh->x, dh->p, k);
+ /* y < p-1 */
+ mpsub(dh->p, mpone, k);
+ if(mpcmp(y, k) >= 0){
+Bad:
+ mpfree(k);
+ k = nil;
+ goto Out;
}
- mpfree(dh->g);
+
+ /* y**q % p == 1 if q < p-1 */
+ if(mpcmp(dh->q, k) < 0){
+ mpexp(y, dh->q, dh->p, k);
+ if(mpcmp(k, mpone) != 0)
+ goto Bad;
+ }
+
+ mpexp(y, dh->x, dh->p, k);
+
+Out:
mpfree(dh->p);
+ mpfree(dh->q);
+ mpfree(dh->g);
mpfree(dh->x);
mpfree(dh->y);
memset(dh, 0, sizeof(*dh));
return k;
}
-