]> git.lizzy.rs Git - plan9front.git/blob - sys/src/libmp/port/gmfield.c
libmp: handle out of memory case in gmfield()
[plan9front.git] / sys / src / libmp / port / gmfield.c
1 #include "os.h"
2 #include <mp.h>
3 #include "dat.h"
4
5 /*
6  * fast reduction for generalized mersenne numbers (GM)
7  * using a series of additions and subtractions.
8  */
9
10 enum {
11         MAXDIG = 1024/Dbits,
12 };
13
14 typedef struct GMfield GMfield;
15 struct GMfield
16 {
17         Mfield; 
18
19         mpint   m2[1];
20
21         int     nadd;
22         int     nsub;
23         int     indx[256];
24 };
25
26 static int
27 gmreduce(Mfield *m, mpint *a, mpint *r)
28 {
29         GMfield *g = (GMfield*)m;
30         mpdigit d0, t[MAXDIG];
31         int i, j, d, *x;
32
33         if(mpmagcmp(a, g->m2) >= 0)
34                 return -1;
35
36         if(a != r)
37                 mpassign(a, r);
38
39         d = g->top;
40         mpbits(r, (d+1)*Dbits*2);
41         memmove(t+d, r->p+d, d*Dbytes);
42
43         r->sign = 1;
44         r->top = d;
45         r->p[d] = 0;
46
47         if(g->nsub > 0)
48                 mpvecdigmuladd(g->p, d, g->nsub, r->p);
49
50         x = g->indx;
51         for(i=0; i<g->nadd; i++){
52                 t[0] = 0;
53                 d0 = t[*x++];
54                 for(j=1; j<d; j++)
55                         t[j] = t[*x++];
56                 t[0] = d0;
57
58                 mpvecadd(r->p, d+1, t, d, r->p);
59         }
60
61         for(i=0; i<g->nsub; i++){
62                 t[0] = 0;
63                 d0 = t[*x++];
64                 for(j=1; j<d; j++)
65                         t[j] = t[*x++];
66                 t[0] = d0;
67
68                 mpvecsub(r->p, d+1, t, d, r->p);
69         }
70
71         mpvecdigmulsub(g->p, d, r->p[d], r->p);
72         r->p[d] = 0;
73
74         mpvecsub(r->p, d+1, g->p, d, r->p+d+1);
75         d0 = r->p[2*d+1];
76         for(j=0; j<d; j++)
77                 r->p[j] = (r->p[j] & d0) | (r->p[j+d+1] & ~d0);
78
79         mpnorm(r);
80
81         return 0;
82 }
83
84 Mfield*
85 gmfield(mpint *N)
86 {
87         int i,j,d, s, *C, *X, *x, *e;
88         mpint *M, *T;
89         GMfield *g;
90
91         d = N->top;
92         if(d <= 2 || d > MAXDIG/2 || (mpsignif(N) % Dbits) != 0)
93                 return nil;
94         g = nil;
95         T = mpnew(0);
96         M = mpcopy(N);
97         C = malloc(sizeof(int)*(d+1));
98         X = malloc(sizeof(int)*(d*d));
99         if(C == nil || X == nil)
100                 goto out;
101
102         for(i=0; i<=d; i++){
103                 if((M->p[i]>>8) != 0 && (~M->p[i]>>8) != 0)
104                         goto out;
105                 j = M->p[i];
106                 C[d - i] = -j;
107                 itomp(j, T);
108                 mpleft(T, i*Dbits, T);
109                 mpsub(M, T, M);
110         }
111         for(j=0; j<d; j++)
112                 X[j] = C[d-j];
113         for(i=1; i<d; i++){
114                 X[d*i] = X[d*(i-1) + d-1]*C[d];
115                 for(j=1; j<d; j++)
116                         X[d*i + j] = X[d*(i-1) + j-1] + X[d*(i-1) + d-1]*C[d-j];
117         }
118         g = mallocz(sizeof(GMfield) + (d+1)*sizeof(mpdigit)*2, 1);
119         if(g == nil)
120                 goto out;
121
122         g->m2->p = (mpdigit*)&g[1];
123         g->m2->size = d*2+1;
124         mpmul(N, N, g->m2);
125         mpassign(N, g);
126         g->reduce = gmreduce;
127         g->flags |= MPfield;
128
129         s = 0;
130         x = g->indx;
131         e = x + nelem(g->indx) - d;
132         for(g->nadd=0; x <= e; x += d, g->nadd++){
133                 s = 0;
134                 for(i=0; i<d; i++){
135                         for(j=0; j<d; j++){
136                                 if(X[d*i+j] > 0 && x[j] == 0){
137                                         X[d*i+j]--;
138                                         x[j] = d+i;
139                                         s = 1;
140                                         break;
141                                 }
142                         }
143                 }
144                 if(s == 0)
145                         break;
146         }
147         for(g->nsub=0; x <= e; x += d, g->nsub++){
148                 s = 0;
149                 for(i=0; i<d; i++){
150                         for(j=0; j<d; j++){
151                                 if(X[d*i+j] < 0 && x[j] == 0){
152                                         X[d*i+j]++;
153                                         x[j] = d+i;
154                                         s = 1;
155                                         break;
156                                 }
157                         }
158                 }
159                 if(s == 0)
160                         break;
161         }
162         if(s != 0){
163                 mpfree(g);
164                 g = nil;
165         }
166 out:
167         free(C);
168         free(X);
169         mpfree(M);
170         mpfree(T);
171         return g;
172 }
173