]> git.lizzy.rs Git - plan9front.git/blob - sys/src/libmp/ntest.c
libmp: add new tests
[plan9front.git] / sys / src / libmp / ntest.c
1 #include <u.h>
2 #include <libc.h>
3 #include <mp.h>
4 #include <libsec.h>
5
6 typedef struct ldint ldint;
7
8 struct ldint {
9         int n;
10         u8int *b;
11 };
12
13 ldint _ldzero = {1, (u8int*)"\0"};
14 ldint _ldone = {2, (u8int*)"\1\0"};
15 ldint *ldzero = &_ldzero;
16 ldint *ldone = &_ldone;
17
18 static int
19 ldget(ldint *a, int n)
20 {
21         if(n < 0) return 0;
22         if(n >= a->n) return a->b[a->n - 1]&1;
23         return a->b[n]&1;
24 }
25
26 static void
27 ldbits(ldint *a, int n)
28 {
29         a->b = realloc(a->b, n);
30         a->n = n;
31 }
32
33 static ldint *
34 ldnorm(ldint *a)
35 {
36         int i;
37
38         if(a->n > 0){
39                 for(i = a->n - 2; i >= 0; i--)
40                         if(a->b[i] != a->b[a->n-1])
41                                 break;
42                 ldbits(a, i + 2);
43         }
44         return a;
45 }
46
47 static void
48 ldneg(ldint *a)
49 {
50         int i, c;
51         
52         c = 1;
53         for(i = 0; i < a->n; i++){
54                 c += 1 ^ a->b[i] & 1;
55                 a->b[i] = c & 1;
56                 c >>= 1;
57         }
58 }
59
60 static int
61 max(int a, int b)
62 {
63         return a>b? a : b;
64 }
65
66 static ldint *
67 ldnew(int n)
68 {
69         ldint *a;
70         
71         a = malloc(sizeof(ldint));
72         a->b = malloc(n);
73         a->n = n;
74         return a;
75 }
76
77 static void
78 ldfree(ldint *a)
79 {
80         if(a == nil) return;
81         free(a->b);
82         free(a);
83 }
84
85 static void
86 ldsanity(ldint *a)
87 {
88         int i;
89         
90         for(i = 0; i < a->n; i++)
91                 assert(a->b[i] < 2);
92 }
93
94 static ldint *
95 ldrand(int n, ldint *a)
96 {
97         int i;
98         
99         if(a == nil)
100                 a = ldnew(n);
101         else
102                 ldbits(a, n);
103         for(i = 0; i < n; i++)
104                 a->b[i] = rand() & 1;
105         return a;
106 }
107
108 static mpint *
109 ldtomp(ldint *a, mpint *b)
110 {
111         int s, c, i;
112
113         if(b == nil)
114                 b = mpnew(0);
115         mpbits(b, a->n);
116         s = a->b[a->n - 1] & 1;
117         b->sign = 1 - 2 * s;
118         c = s;
119         memset(b->p, 0, (a->n + Dbits - 1) / Dbits * Dbytes);
120         for(i = 0; i < a->n; i++){
121                 c += s ^ a->b[i] & 1;
122                 b->p[i / Dbits] |= (mpdigit)(c & 1) << (i & Dbits - 1);
123                 c >>= 1;
124         }
125         b->top = (a->n + Dbits - 1) / Dbits;
126         mpnorm(b);
127         return b;
128 }
129
130 static void
131 mptold(mpint *b, ldint *a)
132 {
133         int i, j, n, c;
134
135         n = mpsignif(b) + 1;
136         ldbits(a, n);
137         memset(a->b, 0, n);
138         for(i = 0; i <= b->top; i++)
139                 for(j = 0; j < Dbits; j++)
140                         if(Dbits * i + j < n)
141                                 a->b[Dbits * i + j] = b->p[i] >> j & 1;
142         if(b->sign < 0){
143                 c = 1;
144                 for(i = 0; i < a->n; i++){
145                         c += 1 ^ a->b[i] & 1;
146                         a->b[i] = c & 1;
147                         c >>= 1;
148                 }
149         }       
150 }
151
152 static ldint *
153 itold(int n, ldint *a)
154 {
155         int i;
156
157         if(a == nil)
158                 a = ldnew(sizeof(n)*8);
159         else
160                 ldbits(a, sizeof(n)*8);
161         for(i = 0; i < sizeof(n)*8; i++)
162                 a->b[i] = n >> i & 1;
163         ldnorm(a);
164         return a;
165 }
166
167 static ldint *
168 pow2told(int n, ldint *a)
169 {
170         int k;
171         
172         k = abs(n);
173         if(a == nil)
174                 a = ldnew(k+2);
175         else
176                 ldbits(a, k+2);
177         memset(a->b, 0, k+2);
178         a->b[k] = 1;
179         if(n < 0) ldneg(a);
180         return a;
181 }
182
183 static int
184 ldfmt(Fmt *f)
185 {
186         ldint *a;
187         char *b, *p;
188         int i, d, s, c;
189         
190         a = va_arg(f->args, ldint *);
191         d = (a->n + 3) / 4;
192         b = calloc(1, d + 1);
193         c = s = a->b[a->n - 1];
194         for(i = 0; i < a->n; i++){
195                 c += s^ldget(a, i);
196                 b[d - 1 - (i >> 2)] |= (c & 1) << (i & 3);
197                 c >>= 1;
198         }
199         for(i = 0; i < d; i++)
200                 b[i] = "0123456789ABCDEF"[b[i]];
201         p = b;
202         while(*p == '0' && p[1] != 0) p++;
203         if(a->b[a->n - 1]) fmtrune(f, '-');
204         fmtprint(f, "0x%s", p);
205         free(b);
206         return 0;
207 }
208
209 static int
210 ldmpeq(ldint *a, mpint *b)
211 {
212         int i, c;
213
214         if(b->sign > 0){
215                 for(i = 0; i < b->top * Dbits; i++)
216                         if(ldget(a, i) != (b->p[i / Dbits] >> (i & Dbits - 1) & 1))
217                                 return 0;
218                 for(; i < a->n; i++)
219                         if(a->b[i] != 0)
220                                 return 0;
221                 return 1;
222         }else{
223                 c = 1;
224                 for(i = 0; i < b->top * Dbits; i++){
225                         c += !ldget(a, i);
226                         if((c & 1) != (b->p[i / Dbits] >> (i & Dbits - 1) & 1))
227                                 return 0;
228                         c >>= 1;
229                 }
230                 for(; i < a->n; i++)
231                         if(a->b[i] != 1)
232                                 return 0;
233                 return 1;
234         }
235 }
236
237 static mpint *
238 mptarget(void)
239 {
240         mpint *r;
241         int i, n;
242         
243         r = mpnew(0);
244         n = nrand(16);
245         mpbits(r, n * Dbits);
246         r->top = n;
247         prng((void *) r->p, n * Dbytes);
248         r->sign = 1 - 2 * (rand() & 1);
249         return r;
250 }
251
252 static void
253 ldadd(ldint *a, ldint *b, ldint *q)
254 {
255         int r, i, x, c;
256         
257         r = max(a->n, b->n) + 1;
258         ldbits(q, r);
259         c = 0;
260         for(i = 0; i < r; i++){
261                 c += ldget(a, i) + ldget(b, i);
262                 q->b[i] = c & 1;
263                 c >>= 1;
264         }
265         ldnorm(q);
266 }
267
268 static void
269 ldsub(ldint *a, ldint *b, ldint *q)
270 {
271         int r, i, x, c;
272         
273         r = max(a->n, b->n) + 1;
274         ldbits(q, r);
275         c = 1;
276         for(i = 0; i < r; i++){
277                 c += ldget(a, i) + (1^ldget(b, i));
278                 q->b[i] = c & 1;
279                 c >>= 1;
280         }
281         ldnorm(q);
282 }
283
284 static void
285 ldmul(ldint *a, ldint *b, ldint *q)
286 {
287         int c1, c2, co, s1, s2, so, i, j;
288         
289         c1 = s1 = a->b[a->n - 1] & 1;
290         s2 = b->b[b->n - 1] & 1;
291         so = s1 ^ s2;
292         ldbits(q, a->n + b->n + 1);
293         memset(q->b, 0, a->n + b->n + 1);
294         for(i = 0; i < a->n; i++){
295                 c1 += s1 ^ a->b[i] & 1;
296                 if((c1 & 1) != 0){
297                         c2 = s2;
298                         for(j = 0; j < b->n; j++){
299                                 c2 += (s2 ^ b->b[j] & 1) + q->b[i + j];
300                                 q->b[i + j] = c2 & 1;
301                                 c2 >>= 1;
302                         }
303                         for(; c2 > 0; j++){
304                                 assert(i + j < q->n);
305                                 q->b[i + j] = c2 & 1;
306                                 c2 >>= 1;
307                         }
308                 }
309                 c1 >>= 1;
310         }
311         co = so;
312         for(i = 0; i < q->n; i++){
313                 co += so ^ q->b[i];
314                 q->b[i] = co & 1;
315                 co >>= 1;
316         }
317 }
318
319 static void
320 lddiv(ldint *a, ldint *b, ldint *q, ldint *r)
321 {
322         int n, i, j, c, s, k;
323         
324         n = max(a->n, b->n) + 1;
325         ldbits(q, n);
326         ldbits(r, n);
327         memset(r->b, 0, n);
328         c = s = a->b[a->n-1];
329         for(i = 0; i < n; i++){
330                 c += s ^ ldget(a, i);
331                 q->b[i] = c & 1;
332                 c >>= 1;
333         }
334         for(i = 0; i < n; i++){
335                 for(j = n-1; --j >= 0; )
336                         r->b[j + 1] = r->b[j];
337                 r->b[0] = q->b[n - 1];
338                 for(j = n-1; --j >= 0; )
339                         q->b[j + 1] = q->b[j];
340                 q->b[0] = !r->b[n - 1];
341                 c = s = r->b[n - 1] == b->b[b->n - 1];
342                 for(j = 0; j < n; j++){
343                         c += r->b[j] + (s ^ ldget(b, j));
344                         r->b[j] = c & 1;
345                         c >>= 1;
346                 }
347         }
348         for(j = n-1; --j >= 0; )
349                 q->b[j + 1] = q->b[j];
350         q->b[0] = 1;
351         if(r->b[r->n - 1]){
352                 c = 0;
353                 for(j = 0; j < n; j++){
354                         c += 1 + q->b[j];
355                         q->b[j] = c & 1;
356                         c >>= 1;
357                 }
358                 c = s = b->b[b->n - 1];
359                 for(j = 0; j < n; j++){
360                         c += r->b[j] + (s ^ ldget(b, j));
361                         r->b[j] = c & 1;
362                         c >>= 1;
363                 }
364         }
365         c = s = a->b[a->n-1] ^ b->b[b->n-1];
366         for(j = 0; j < n; j++){
367                 c += s ^ q->b[j];
368                 q->b[j] = c & 1;
369                 c >>= 1;
370         }
371         c = s = a->b[a->n-1];
372         for(j = 0; j < n; j++){
373                 c += s ^ r->b[j];
374                 r->b[j] = c & 1;
375                 c >>= 1;
376         }
377         ldnorm(q);
378         ldnorm(r);
379 }
380
381 static void
382 lddivq(ldint *a, ldint *b, ldint *q)
383 {
384         ldint *r;
385         
386         if(ldmpeq(b, mpzero)){
387                 memset(q->b, 0, q->n);
388                 return;
389         }
390         r = ldnew(0);
391         lddiv(a, b, q, r);
392         ldfree(r);
393 }
394
395 static void
396 mpdivq(mpint *a, mpint *b, mpint *q)
397 {
398         if(mpcmp(b, mpzero) == 0){
399                 mpassign(mpzero, q);
400                 return;
401         }
402         mpdiv(a, b, q, nil);
403 }
404
405 static void
406 lddivr(ldint *a, ldint *b, ldint *r)
407 {
408         ldint *q;
409         
410         if(ldmpeq(b, mpzero)){
411                 memset(r->b, 0, r->n);
412                 return;
413         }
414         q = ldnew(0);
415         lddiv(a, b, q, r);
416         ldfree(q);
417 }
418
419 static void
420 mpdivr(mpint *a, mpint *b, mpint *r)
421 {
422         if(mpcmp(b, mpzero) == 0){
423                 mpassign(mpzero, r);
424                 return;
425         }
426         mpdiv(a, b, nil, r);
427 }
428
429 static void
430 ldand(ldint *a, ldint *b, ldint *q)
431 {
432         int r, i, x, c;
433         
434         r = max(a->n, b->n);
435         ldbits(q, r);
436         for(i = 0; i < r; i++)
437                 q->b[i] = ldget(a, i) & ldget(b, i);
438         ldnorm(q);
439 }
440
441 static void
442 ldbic(ldint *a, ldint *b, ldint *q)
443 {
444         int r, i, x, c;
445         
446         r = max(a->n, b->n);
447         ldbits(q, r);
448         for(i = 0; i < r; i++)
449                 q->b[i] = ldget(a, i) & ~ldget(b, i);
450         ldnorm(q);
451 }
452
453 static void
454 ldor(ldint *a, ldint *b, ldint *q)
455 {
456         int r, i, x, c;
457         
458         r = max(a->n, b->n);
459         ldbits(q, r);
460         for(i = 0; i < r; i++)
461                 q->b[i] = ldget(a, i) | ldget(b, i);
462         ldnorm(q);
463 }
464
465 static void
466 ldxor(ldint *a, ldint *b, ldint *q)
467 {
468         int r, i, x, c;
469         
470         r = max(a->n, b->n);
471         ldbits(q, r);
472         for(i = 0; i < r; i++)
473                 q->b[i] = ldget(a, i) ^ ldget(b, i);
474         ldnorm(q);
475 }
476
477 typedef struct Test2 Test2;
478 struct Test2 {
479         char *name;
480         void (*dut)(mpint *, mpint *, mpint *);
481         void (*ref)(ldint *, ldint *, ldint *);
482 };
483
484 int
485 validate(char *name, ldint *ex, mpint *res, char *str)
486 {
487         int rv;
488
489         rv = 1;
490         if(res->top == 0 && res->sign < 0){
491                 fprint(2, "FAIL: %s: %s: got -0, shouldn't happen\n", name, str);
492                 rv =0;
493         }else if(!ldmpeq(ex, res)){
494                 fprint(2, "FAIL: %s: %s: got %#B, expected %L\n", name, str, res, ex);
495                 rv = 0;
496         }
497         free(str);
498         return rv;
499 }
500
501 int
502 test2(Test2 *t, ldint *a, ldint *b)
503 {
504         ldint *c;
505         mpint *ma, *mb, *rc;
506         int rv;
507         
508         c = ldnew(0);
509         t->ref(a, b, c);
510         ldsanity(a);
511         ldsanity(b);
512         ldsanity(c);
513         ma = ldtomp(a, nil);
514         mb = ldtomp(b, nil);
515         rc = mptarget();
516         t->dut(ma, mb, rc);
517         rv = validate(t->name, c, rc, smprint("%L and %L", a, b));
518         ldtomp(a, ma);
519         ldtomp(b, mb);
520         t->dut(ma, mb, mb);
521         rv = validate(t->name, c, mb, smprint("%L and %L (aliased to result)", a, b));
522         ldtomp(a, ma);
523         ldtomp(b, mb);
524         t->dut(ma, mb, ma);
525         rv = validate(t->name, c, ma, smprint("%L (aliased to result) and %L", a, b));
526         ldfree(c);
527         mpfree(rc);
528         mpfree(ma);
529         mpfree(mb);
530         return rv;
531 }
532
533 int
534 test2x(Test2 *t, ldint *a)
535 {
536         ldint *c;
537         mpint *ma, *rc;
538         int rv;
539         
540         c = ldnew(0);
541         t->ref(a, a, c);
542         ldsanity(a);
543         ldsanity(c);
544         ma = ldtomp(a, nil);
545         rc = mptarget();
546         t->dut(ma, ma, rc);
547         rv = validate(t->name, c, rc, smprint("%L and %L (aliased to each other)", a, a));
548         ldtomp(a, ma);
549         t->dut(ma, ma, ma);
550         rv = validate(t->name, c, ma, smprint("%L and %L (both aliased to result)", a, a));
551         ldfree(c);
552         mpfree(rc);
553         mpfree(ma);
554         return rv;
555 }
556
557 void
558 run2(Test2 *t)
559 {
560         int i, j, ok;
561         ldint *a, *b, *c;
562         
563         a = ldnew(32);
564         b = ldnew(32);
565         c = ldnew(32);
566         ok = 1;
567         for(i = -128; i <= 128; i++)
568                 for(j = -128; j <= 128; j++){
569                         itold(i, a);
570                         itold(j, b);
571                         ok &= test2(t, a, b);
572                         pow2told(i, a);
573                         itold(j, b);
574                         ok &= test2(t, a, b);
575                         ok &= test2(t, b, a);
576                         pow2told(i, a);
577                         pow2told(j, b);
578                         ok &= test2(t, a, b);                   
579                 }
580         for(i = 1; i <= 4; i++)
581                 for(j = 1; j <= 4; j++){
582                         ldrand(i * Dbits, a);
583                         ldrand(j * Dbits, b);
584                         ok &= test2(t, a, b);
585                 }
586         for(i = -128; i <= 128; i++){
587                 itold(i, a);
588                 ok &= test2x(t, a);
589                 pow2told(i, a);
590                 ok &= test2x(t, a);
591         }
592         ldfree(a);
593         ldfree(b);
594         if(ok)
595                 fprint(2, "%s: passed\n", t->name);
596 }
597
598 Test2 tests2[] = {
599         "mpdiv(q)", mpdivq, lddivq,
600         "mpdiv(r)", mpdivr, lddivr,
601         "mpmul", mpmul, ldmul,
602         "mpadd", mpadd, ldadd,
603         "mpsub", mpsub, ldsub,
604         "mpand", mpand, ldand,
605         "mpor", mpor, ldor,
606         "mpbic", mpbic, ldbic,
607         "mpxor", mpxor, ldxor,
608 };
609
610 void
611 all2(void)
612 {
613         Test2 *t;
614         
615         for(t = tests2; t < tests2 + nelem(tests2); t++)
616                 run2(t);
617 }
618
619 void
620 main()
621 {
622         fmtinstall('B', mpfmt);
623         fmtinstall('L', ldfmt);
624         all2();
625 }