5 mpint* strtomp(char *buf, char **rptr, int base, mpint *b)
6 char* mptoa(mpint *b, int base, char *buf, int blen)
7 mpint* betomp(uchar *buf, uint blen, mpint *b)
8 int mptobe(mpint *b, uchar *buf, uint blen, uchar **bufp)
9 void mptober(mpint *b, uchar *buf, int blen)
10 mpint* letomp(uchar *buf, uint blen, mpint *b)
11 int mptole(mpint *b, uchar *buf, uint blen, uchar **bufp)
12 void mptolel(mpint *b, uchar *buf, int blen)
14 mpint* uitomp(uint, mpint*)
16 mpint* itomp(int, mpint*)
17 mpint* vtomp(vlong, mpint*)
19 mpint* uvtomp(uvlong, mpint*)
21 mpint* dtomp(double, mpint*)
23 void mpexp(mpint *b, mpint *e, mpint *m, mpint *res)
24 void mpmod(mpint *b, mpint *m, mpint *remainder)
25 void mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum)
26 void mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff)
27 void mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod)
28 void mpsel(int s, mpint *b1, mpint *b2, mpint *res)
29 void mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y)
30 void mpinvert(mpint *b, mpint *m, mpint *res)
31 void mpdigdiv(mpdigit *dividend, mpdigit divisor, mpdigit *quotient)
32 void mpvecadd(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *sum)
33 void mpvecsub(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *diff)
34 void mpvecdigmuladd(mpdigit *b, int n, mpdigit m, mpdigit *p)
35 int mpvecdigmulsub(mpdigit *b, int n, mpdigit m, mpdigit *p)
36 void mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen,mpdigit *p)
37 int mpveccmp(mpdigit *a, int alen, mpdigit *b, int blen)
45 typedef struct ldint ldint; /* lets the rest of the file write plain `ldint` instead of `struct ldint` */
52 ldint _ldzero = {1, (u8int*)"\0"}; /* static storage behind ldzero: count 1, a single entry of value 0 (presumably members n then b -- struct body not visible here, confirm) */
53 ldint _ldone = {2, (u8int*)"\1\0"}; /* static storage behind ldone: count 2, entries 1 then 0, i.e. lowest entry set and the top entry clear */
54 ldint *ldzero = &_ldzero; /* pointer handle to _ldzero. NOTE(review): its data points at a string literal, so it must not be written through or passed to ldbits, which reallocs a->b */
55 ldint *ldone = &_ldone; /* pointer handle to _ldone; same read-only caveat as ldzero */
58 ldget(ldint *a, int n)
61 if(n >= a->n) return a->b[a->n - 1]&1;
66 ldbits(ldint *a, int n)
68 a->b = realloc(a->b, n);
78 for(i = a->n - 2; i >= 0; i--)
79 if(a->b[i] != a->b[a->n-1])
95 for(i = 0; i < a->n; i++){
100 if(c != a->b[a->n - 1]){
117 a = malloc(sizeof(ldint));
138 for(i = 0; i < a->n; i++)
143 ldrand(int n, ldint *a)
151 for(i = 0; i < n; i++)
152 a->b[i] = rand() & 1;
157 ldtomp(ldint *a, mpint *b)
164 s = a->b[a->n - 1] & 1;
167 memset(b->p, 0, (a->n + Dbits - 1) / Dbits * Dbytes);
168 for(i = 0; i < a->n; i++){
169 c += s ^ a->b[i] & 1;
170 b->p[i / Dbits] |= (mpdigit)(c & 1) << (i & Dbits - 1);
173 b->top = (a->n + Dbits - 1) / Dbits;
179 mptold(mpint *b, ldint *a)
186 for(i = 0; i <= b->top; i++)
187 for(j = 0; j < Dbits; j++)
188 if(Dbits * i + j < n)
189 a->b[Dbits * i + j] = b->p[i] >> j & 1;
192 for(i = 0; i < a->n; i++){
193 c += 1 ^ a->b[i] & 1;
201 itold(int n, ldint *a)
206 a = ldnew(sizeof(n)*8);
208 ldbits(a, sizeof(n)*8);
209 for(i = 0; i < sizeof(n)*8; i++)
210 a->b[i] = n >> i & 1;
216 pow2told(int n, ldint *a)
225 memset(a->b, 0, k+2);
239 a = va_arg(f->args, ldint *);
241 b = calloc(1, d + 1);
242 c = s = a->b[a->n - 1];
243 for(i = 0; i < a->n; i++){
245 b[d - 1 - (i >> 2)] |= (c & 1) << (i & 3);
248 for(i = 0; i < d; i++)
249 b[i] = "0123456789ABCDEF"[b[i]];
251 while(*p == '0' && p[1] != 0) p++;
252 if(a->b[a->n - 1]) fmtrune(f, '-');
253 fmtprint(f, "0x%s", p);
264 a = va_arg(f->args, mpint *);
265 fmtprint(f, "(sign=%d,top=%d,size=%d,", a->sign, a->top, a->size);
267 fmtprint(f, "%ullx", (uvlong)a->p[i]);
268 if(++i == a->top) break;
270 for(j = i+1; j < a->top; j++)
271 if(a->p[i] != a->p[j])
278 for(i=a->top;i<a->size;){
279 fmtprint(f, "%ullx", (uvlong)a->p[i]);
280 if(++i == a->size) break;
282 for(j = i+1; j < a->top; j++)
283 if(a->p[i] != a->p[j])
294 ldcmp(ldint *a, ldint *b)
300 if(a->b[a->n-1] != b->b[b->n-1])
301 return b->b[b->n - 1] - a->b[a->n - 1];
302 for(i = r - 1; --i >= 0; ){
312 ldmagcmp(ldint *a, ldint *b)
327 ldmpeq(ldint *a, mpint *b)
332 for(i = 0; i < b->top * Dbits; i++)
333 if(ldget(a, i) != (b->p[i / Dbits] >> (i & Dbits - 1) & 1))
341 for(i = 0; i < b->top * Dbits; i++){
343 if((c & 1) != (b->p[i / Dbits] >> (i & Dbits - 1) & 1))
362 mpbits(r, n * Dbits);
364 prng((void *) r->p, n * Dbytes);
365 r->sign = 1 - 2 * (rand() & 1);
370 ldadd(ldint *a, ldint *b, ldint *q)
374 r = max(a->n, b->n) + 1;
377 for(i = 0; i < r; i++){
378 c += ldget(a, i) + ldget(b, i);
386 ldmagadd(ldint *a, ldint *b, ldint *q)
388 int i, r, s1, s2, c1, c2, co;
390 r = max(a->n, b->n) + 2;
393 s1 = c1 = a->b[a->n - 1] & 1;
394 s2 = c2 = b->b[b->n - 1] & 1;
395 for(i = 0; i < r; i++){
396 c1 += s1 ^ ldget(a, i) & 1;
397 c2 += s2 ^ ldget(b, i) & 1;
398 co += (c1 & 1) + (c2 & 1);
408 ldmagsub(ldint *a, ldint *b, ldint *q)
410 int i, r, s1, s2, c1, c2, co;
412 r = max(a->n, b->n) + 2;
415 s1 = c1 = a->b[a->n - 1] & 1;
416 s2 = c2 = 1 ^ b->b[b->n - 1] & 1;
417 for(i = 0; i < r; i++){
418 c1 += s1 ^ ldget(a, i) & 1;
419 c2 += s2 ^ ldget(b, i) & 1;
420 co += (c1 & 1) + (c2 & 1);
430 ldsub(ldint *a, ldint *b, ldint *q)
434 r = max(a->n, b->n) + 1;
437 for(i = 0; i < r; i++){
438 c += ldget(a, i) + (1^ldget(b, i));
446 ldmul(ldint *a, ldint *b, ldint *q)
448 int c1, c2, co, s1, s2, so, i, j;
450 c1 = s1 = a->b[a->n - 1] & 1;
451 s2 = b->b[b->n - 1] & 1;
453 ldbits(q, a->n + b->n + 1);
454 memset(q->b, 0, a->n + b->n + 1);
455 for(i = 0; i < a->n; i++){
456 c1 += s1 ^ a->b[i] & 1;
459 for(j = 0; j < b->n; j++){
460 c2 += (s2 ^ b->b[j] & 1) + q->b[i + j];
461 q->b[i + j] = c2 & 1;
465 assert(i + j < q->n);
466 q->b[i + j] = c2 & 1;
473 for(i = 0; i < q->n; i++){
481 lddiv(ldint *a, ldint *b, ldint *q, ldint *r)
483 int n, i, j, c, s, k;
485 n = max(a->n, b->n) + 1;
489 c = s = a->b[a->n-1];
490 for(i = 0; i < n; i++){
491 c += s ^ ldget(a, i);
495 for(i = 0; i < n; i++){
496 for(j = n-1; --j >= 0; )
497 r->b[j + 1] = r->b[j];
498 r->b[0] = q->b[n - 1];
499 for(j = n-1; --j >= 0; )
500 q->b[j + 1] = q->b[j];
501 q->b[0] = !r->b[n - 1];
502 c = s = r->b[n - 1] == b->b[b->n - 1];
503 for(j = 0; j < n; j++){
504 c += r->b[j] + (s ^ ldget(b, j));
509 for(j = n-1; --j >= 0; )
510 q->b[j + 1] = q->b[j];
514 for(j = 0; j < n; j++){
519 c = s = b->b[b->n - 1];
520 for(j = 0; j < n; j++){
521 c += r->b[j] + (s ^ ldget(b, j));
526 c = s = a->b[a->n-1] ^ b->b[b->n-1];
527 for(j = 0; j < n; j++){
532 c = s = a->b[a->n-1];
533 for(j = 0; j < n; j++){
543 lddivq(ldint *a, ldint *b, ldint *q)
547 if(ldmpeq(b, mpzero)){
548 memset(q->b, 0, q->n);
557 mpdivq(mpint *a, mpint *b, mpint *q)
559 if(mpcmp(b, mpzero) == 0){
567 lddivr(ldint *a, ldint *b, ldint *r)
571 if(ldmpeq(b, mpzero)){
572 memset(r->b, 0, r->n);
581 mpdivr(mpint *a, mpint *b, mpint *r)
583 if(mpcmp(b, mpzero) == 0){
591 ldand(ldint *a, ldint *b, ldint *q)
597 for(i = 0; i < r; i++)
598 q->b[i] = ldget(a, i) & ldget(b, i);
603 ldbic(ldint *a, ldint *b, ldint *q)
609 for(i = 0; i < r; i++)
610 q->b[i] = ldget(a, i) & ~ldget(b, i);
615 ldor(ldint *a, ldint *b, ldint *q)
621 for(i = 0; i < r; i++)
622 q->b[i] = ldget(a, i) | ldget(b, i);
627 ldxor(ldint *a, ldint *b, ldint *q)
633 for(i = 0; i < r; i++)
634 q->b[i] = ldget(a, i) ^ ldget(b, i);
639 ldleft(ldint *a, int n, ldint *b)
651 for(i = 0; i < -n; i++)
657 for(i = 0; i < a->n + n; i++){
658 c += a->b[i - n] & 1;
664 memmove(b->b + n, a->b, a->n);
671 ldasr(ldint *a, int n, ldint *b)
679 b->b[0] = a->b[a->n - 1];
683 memmove(b->b, a->b + n, a->n - n);
688 ldtrunc(ldint *a, int n, ldint *b)
693 memmove(b->b, a->b, n);
695 memmove(b->b, a->b, a->n);
696 memset(b->b + a->n, a->b[a->n - 1], n - a->n);
702 ldxtend(ldint *a, int n, ldint *b)
706 memmove(b->b, a->b, n);
708 memmove(b->b, a->b, a->n);
709 memset(b->b + a->n, a->b[a->n - 1], n - a->n);
715 mpnot_(mpint *a, int, mpint *b)
721 ldnot(ldint *a, int, ldint *b)
726 for(i = 0; i < a->n; i++)
727 b->b[i] = a->b[i] ^ 1;
730 enum { NTEST = 2*257 }; /* 514: upper bound of the i/j loops in the test drivers (presumably the number of distinct testgen indices -- confirm against testgen) */
732 testgen(int i, ldint *a)
740 typedef struct Test2 Test2; /* table entry for a two-operand test; the visible members and table rows suggest a name plus an mpint `dut` and an ldint `ref` function -- confirm against the struct body */
743 void (*dut)(mpint *, mpint *, mpint *);
744 void (*ref)(ldint *, ldint *, ldint *);
747 typedef struct Test1i Test1i; /* table entry for a one-operand-plus-int test; the visible members and table rows suggest a name, a flags field (NONEG), and `dut`/`ref` functions taking an int -- confirm against the struct body */
750 enum { NONEG = 1 } flags;
751 void (*dut)(mpint *, int, mpint *);
752 void (*ref)(ldint *, int, ldint *);
756 validate(char *name, ldint *ex, mpint *res, char *str)
761 if(res->top == 0 && res->sign < 0){
762 fprint(2, "FAIL: %s: %s: got -0, shouldn't happen\n", name, str);
764 }else if(!ldmpeq(ex, res)){
765 fprint(2, "FAIL: %s: %s: got %#B, expected %L\n", name, str, res, ex);
773 test2(Test2 *t, ldint *a, ldint *b)
788 rv = validate(t->name, c, rc, smprint("%L and %L", a, b));
792 rv = validate(t->name, c, mb, smprint("%L and %L (aliased to result)", a, b));
796 rv = validate(t->name, c, ma, smprint("%L (aliased to result) and %L", a, b));
805 test2x(Test2 *t, ldint *a)
818 rv = validate(t->name, c, rc, smprint("%L and %L (aliased to each other)", a, a));
821 rv = validate(t->name, c, ma, smprint("%L and %L (both aliased to result)", a, a));
838 for(i = 0; i < NTEST; i++){
839 for(j = 0; j < NTEST; j++){
842 ok &= test2(t, a, b);
847 for(i = 1; i <= 4; i++)
848 for(j = 1; j <= 4; j++){
849 ldrand(i * Dbits, a);
850 ldrand(j * Dbits, b);
851 ok &= test2(t, a, b);
856 fprint(2, "%s: passed\n", t->name);
860 "mpadd", mpadd, ldadd,
861 "mpmagadd", mpmagadd, ldmagadd,
862 "mpsub", mpsub, ldsub,
863 "mpmagsub", mpmagsub, ldmagsub,
864 "mpand", mpand, ldand,
866 "mpbic", mpbic, ldbic,
867 "mpxor", mpxor, ldxor,
868 "mpmul", mpmul, ldmul,
869 "mpdiv(q)", mpdivq, lddivq,
870 "mpdiv(r)", mpdivr, lddivr,
878 for(t = tests2; t < tests2 + nelem(tests2); t++)
883 test1i(Test1i *t, ldint *a, int b)
896 rv = validate(t->name, c, rc, smprint("%L and %d", a, b));
899 rv = validate(t->name, c, ma, smprint("%L (aliased to result) and %d", a, b));
915 for(i = 0; i < NTEST; i++)
916 for(j = (t->flags & NONEG) != 0 ? 0 : -128; j <= 128; j++){
918 ok &= test1i(t, a, j);
923 fprint(2, "%s: passed\n", t->name);
928 "mpleft", 0, mpleft, ldleft,
929 "mpasr", 0, mpasr, ldasr,
930 "mptrunc", NONEG, mptrunc, ldtrunc,
931 "mpxtend", NONEG, mpxtend, ldxtend,
932 "mpnot", NONEG, mpnot_, ldnot, /* hack */
940 for(t = tests1i; t < tests1i + nelem(tests1i); t++)
956 for(i = 0; i < NTEST; i++){
958 for(j = 0; j < a->n; j++)
965 fprint(2, "FAIL: mplowbits0: %#B: got %d, expected %d\n", ma, k, j);
968 for(j = a->n - 2; j >= 0; j--)
969 if(a->b[j] != a->b[a->n-1])
971 for(k = j-1; k >= 0; k--)
974 if(a->b[a->n - 1] && k < 0) j++;
979 fprint(2, "FAIL: mpsignif: %#B: got %d, expected %d\n", ma, k, j);
983 if(sigok) fprint(2, "mpsignif: passed\n");
984 if(lowok0) fprint(2, "mplowbits0: passed\n");
1003 for(i = 0; i < NTEST; i++)
1004 for(j = 0; j < NTEST; j++){
1014 fprint(2, "FAIL: mpcmp: %L and %L: got %d, expected %d\n", a, b, k, l);
1020 k = mpmagcmp(ma, mb);
1024 fprint(2, "FAIL: mpmagcmp: %L and %L: got %d, expected %d\n", a, b, k, l);
1032 if(cmpok) fprint(2, "mpcmp: passed\n");
1033 if(magcmpok) fprint(2, "mpmagcmp: passed\n");
1039 fmtinstall('B', mpfmt);
1040 fmtinstall(L'β', mpdetfmt);
1041 fmtinstall('L', ldfmt);