]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/forp/cvt.c
/sys/src/cmd/ndb/dns.h:
[plan9front.git] / sys / src / cmd / forp / cvt.c
1 #include <u.h>
2 #include <libc.h>
3 #include <mp.h>
4 #include <sat.h>
5 #include "dat.h"
6 #include "fns.h"
7
8 SATSolve *sat;
9 int satvar = 3; /* 1 = false, 2 = true */
10 #define SVAR(n, i) ((n)->vars[(i) < (n)->size ? (i) : (n)->size - 1])
11 int nassertvar;
12 int *assertvar;
13
14 static int
15 max(int a, int b)
16 {
17         return a < b ? b : a;
18 }
19
20 static int
21 min(int a, int b)
22 {
23         return a > b ? b : a;
24 }
25
26 static void
27 symsat(Node *n)
28 {
29         Symbol *s;
30         int i;
31         
32         s = n->sym;
33         assert(s->type == SYMBITS);
34         n->size = s->size + ((s->flags & SYMFSIGNED) == 0);
35         n->vars = emalloc(sizeof(int) * n->size);
36         for(i = 0; i < s->size; i++){
37                 if(s->vars[i] == 0)
38                         s->vars[i] = satvar++;
39                 n->vars[i] = s->vars[i];
40         }
41         if((s->flags & SYMFSIGNED) == 0)
42                 n->vars[i] = 1;
43 }
44
45 static void
46 numsat(Node *n)
47 {
48         mpint *m;
49         int i, sz, j;
50         
51         m = n->num;
52         assert(m != nil);
53         assert(m->sign > 0);
54         sz = mpsignif(m) + 1;
55         n->size = sz;
56         n->vars = emalloc(sizeof(int) * sz);
57         for(i = 0; i < m->top; i++){
58                 for(j = 0; j < Dbits; j++)
59                         if(i * Dbits + j < sz-1)
60                                 n->vars[i * Dbits + j] = 1 + ((m->p[i] >> j & 1) != 0);
61         }
62         n->vars[sz - 1] = 1;
63 }
64
65 static void
66 nodevars(Node *n, int nv)
67 {
68         int i;
69
70         n->size = nv;
71         n->vars = emalloc(sizeof(int) * nv);
72         for(i = 0; i < nv; i++)
73                 n->vars[i] = 1;
74 }
75
76 static void
77 assign(Node *t, Node *n)
78 {
79         Symbol *s;
80         int i;
81         
82         s = t->sym;
83         for(i = 0; i < s->size; i++){
84                 if(i < n->size)
85                         s->vars[i] = n->vars[i];
86                 else
87                         s->vars[i] = n->vars[n->size - 1];
88         }
89 }
90
91 static void
92 opeq(Node *r, Node *n1, Node *n2, int neq)
93 {
94         int i, m, a, b, *t;
95
96         nodevars(r, 2);
97         m = max(n1->size, n2->size);
98         t = malloc(m * sizeof(int));
99         for(i = 0; i < m; i++){
100                 a = SVAR(n1, i);
101                 b = SVAR(n2, i);
102                 t[i] = satlogicv(sat, neq ? 6 : 9, a, b, 0);
103         }
104         if(neq)
105                 r->vars[0] = sator1(sat, t, m);
106         else
107                 r->vars[0] = satand1(sat, t, m);
108         free(t);
109 }
110
111 static void
112 oplogic(Node *r, Node *n1, Node *n2, int op)
113 {
114         int m, i, a, b, *t;
115         
116         m = max(n1->size, n2->size);
117         nodevars(r, m);
118         t = r->vars;
119         for(i = 0; i < m; i++){
120                 a = SVAR(n1, i);
121                 b = SVAR(n2, i);
122                 switch(op){
123                 case OPOR:
124                         t[i] = satorv(sat, a, b, 0);
125                         break;
126                 case OPAND:
127                         t[i] = satandv(sat, a, b, 0);
128                         break;
129                 case OPXOR:
130                         t[i] = satlogicv(sat, 6, a, b, 0);
131                         break;
132                 default: abort();
133                 }
134         }
135 }
136
137 static int
138 tologic(Node *n)
139 {
140         int i;
141
142         for(i = 1; i < n->size; i++)
143                 if(n->vars[i] != 1)
144                         break;
145         if(i == n->size)
146                 return n->vars[0];
147         return sator1(sat, n->vars, n->size);
148 }
149
150 static void
151 opllogic(Node *rn, Node *n1, Node *n2, int op)
152 {
153         int a, b;
154         
155         a = tologic(n1);
156         b = tologic(n2);
157         nodevars(rn, 2);
158         switch(op){
159         case OPLAND:
160                 rn->vars[0] = satandv(sat, a, b, 0);
161                 break;
162         case OPLOR:
163                 rn->vars[0] = satorv(sat, a, b, 0);
164                 break;
165         case OPIMP:
166                 rn->vars[0] = satorv(sat, -a, b, 0);
167                 break;
168         case OPEQV:
169                 rn->vars[0] = satlogicv(sat, 9, a, b, 0);
170                 break;
171         default:
172                 abort();
173         }
174 }
175
176 static void
177 opcom(Node *r, Node *n1)
178 {
179         int i;
180         
181         nodevars(r, n1->size);
182         for(i = 0; i < n1->size; i++)
183                 r->vars[i] = -n1->vars[i];
184 }
185
186 static void
187 opneg(Node *r, Node *n1)
188 {
189         int i, c;
190         
191         nodevars(r, n1->size);
192         c = 2;
193         for(i = 0; i < n1->size; i++){
194                 r->vars[i] = satlogicv(sat, 9, n1->vars[i], c, 0);
195                 if(i < n1->size - 1)
196                         c = satandv(sat, -n1->vars[i], c, 0);
197         }
198 }
199
200 static void
201 opnot(Node *r, Node *n1)
202 {
203         nodevars(r, 2);
204         r->vars[0] = -tologic(n1);
205 }
206
207 static void
208 opadd(Node *rn, Node *n1, Node *n2, int sub)
209 {
210         int i, m, c, a, b;
211         
212         m = max(n1->size, n2->size) + 1;
213         nodevars(rn, m);
214         c = 1 + sub;
215         sub = 1 - 2 * sub;
216         for(i = 0; i < m; i++){
217                 a = SVAR(n1, i);
218                 b = SVAR(n2, i) * sub;
219                 rn->vars[i] = satlogicv(sat, 0x96, c, a, b, 0);
220                 c = satlogicv(sat, 0xe8, c, a, b, 0);
221         }
222 }
223
224 static void
225 oplt(Node *rn, Node *n1, Node *n2, int le)
226 {
227         int i, m, a, b, t, *r;
228         
229         nodevars(rn, 2);
230         m = max(n1->size, n2->size);
231         r = emalloc(sizeof(int) * (m + le));
232         t = 2;
233         for(i = m; --i >= 0; ){
234                 if(i == m - 1){
235                         a = SVAR(n2, i);
236                         b = SVAR(n1, i);
237                 }else{
238                         a = SVAR(n1, i);
239                         b = SVAR(n2, i);
240                 }
241                 r[i] = satandv(sat, -a, b, t, 0);
242                 t = satlogicv(sat, 0x90, a, b, t, 0);
243         }
244         if(le)
245                 r[m] = t;
246         rn->vars[0] = sator1(sat, r, m + le);
247 }
248
249 static void
250 opidx(Node *rn, Node *n1, Node *n2, Node *n3)
251 {
252         int i, j, k, s;
253         
254         k = mptoi(n2->num);
255         if(n3 == nil) j = k;
256         else j = mptoi(n3->num);
257         if(j > k){
258                 nodevars(rn, 1);
259                 return;
260         }
261         s = k - j + 1;
262         nodevars(rn, s + 1);
263         for(i = 0; i < s; i++)
264                 rn->vars[i] = SVAR(n1, j + i);
265 }
266
267 static void
268 oprsh(Node *rn, Node *n1, Node *n2)
269 {
270         int i, j, a, b, q;
271
272         nodevars(rn, n1->size);
273         memcpy(rn->vars, n1->vars, sizeof(int) * n1->size);
274         for(i = 0; i < n2->size; i++){
275                 if(n2->vars[i] == 1) continue;
276                 if(n2->vars[i] == 2){
277                         for(j = 0; j < n1->size; j++)
278                                 rn->vars[j] = SVAR(rn, j + (1<<i));
279                         continue;
280                 }
281                 for(j = 0; j < n1->size; j++){
282                         a = rn->vars[j];
283                         b = SVAR(rn, j + (1<<i));
284                         q = n2->vars[i];
285                         rn->vars[j] = satlogicv(sat, 0xca, a, b, q, 0);
286                 }
287         }
288 }
289
290 static void
291 oplsh(Node *rn, Node *n1, Node *n2, uint sz)
292 {
293         int i, j, a, b, q;
294         u32int m;
295         
296         m = 0;
297         for(i = n2->size; --i >= 0; )
298                 m = m << 1 | n2->vars[i] != m;
299         m += n1->size;
300         if(m > sz) m = sz;
301         nodevars(rn, m);
302         for(i = 0; i < m; i++)
303                 rn->vars[i] = SVAR(n1, i);
304         for(i = 0; i < n2->size; i++){
305                 if(n2->vars[i] == 1) continue;
306                 if(n2->vars[i] == 2){
307                         for(j = m; --j >= 0; )
308                                 rn->vars[j] = j >= 1<<i ? rn->vars[j - (1<<i)] : 1;
309                         continue;
310                 }
311                 for(j = m; --j >= 0; ){
312                         a = rn->vars[j];
313                         b = j >= 1<<i ? rn->vars[j - (1<<i)] : 1;
314                         q = n2->vars[i];
315                         rn->vars[j] = satlogicv(sat, 0xca, a, b, q, 0);
316                 }
317         }       
318 }
319
320 static void
321 optern(Node *rn, Node *n1, Node *n2, Node *n3, uint sz)
322 {
323         uint m;
324         int i, a, b, q;
325         
326         m = n2->size;
327         if(n3->size > m) m = n3->size;
328         if(m > sz) m = sz;
329         nodevars(rn, m);
330         q = tologic(n1);
331         for(i = 0; i < m; i++){
332                 a = SVAR(n3, i);
333                 b = SVAR(n2, i);
334                 rn->vars[i] = satlogicv(sat, 0xca, a, b, q, 0);
335         }
336 }
337
338 static int *
339 opmul(int *n1v, int s1, int *n2v, int s2)
340 {
341         int i, k, t, s;
342         int *r, *q0, *q1, *z, nq0, nq1, nq;
343
344         s1--;
345         s2--;
346         r = emalloc(sizeof(int) * (s1 + s2 + 2));
347         nq = 2 * (min(s1, s2) + 2);
348         q0 = emalloc(nq * sizeof(int));
349         q1 = emalloc(nq * sizeof(int));
350         nq0 = nq1 = 0;
351         for(k = 0; k <= s1 + s2 + 1; k++){
352                 if(k == s1 || k == s1 + s2 + 1){ assert(nq0 < nq); q0[nq0++] = 2; }
353                 if(k == s2){ assert(nq0 < nq); q0[nq0++] = 2; }
354                 for(i = max(0, k - s2); i <= k && i <= s1; i++){
355                         assert(nq0 < nq);
356                         t = satandv(sat, n1v[i], n2v[k - i], 0);
357                         q0[nq0++] = i == s1 ^ k-i == s2 ? -t : t;
358                 }
359                 assert(nq0 > 0);
360                 while(nq0 > 1){
361                         if(nq0 == 2){
362                                 t = satlogicv(sat, 0x6, q0[0], q0[1], 0);
363                                 s = satandv(sat, q0[0], q0[1], 0);
364                                 q0[0] = t;
365                                 assert(nq1 < nq);
366                                 q1[nq1++] = s;
367                                 break;
368                         }
369                         t = satlogicv(sat, 0x96, q0[nq0-3], q0[nq0-2], q0[nq0-1], 0);
370                         s = satlogicv(sat, 0xe8, q0[nq0-3], q0[nq0-2], q0[nq0-1], 0);
371                         q0[nq0-3] = t;
372                         nq0 -= 2;
373                         assert(nq1 < nq);
374                         q1[nq1++] = s;
375                 }
376                 r[k] = q0[0];
377                 z=q0, q0=q1, q1=z;
378                 nq0 = nq1;
379                 nq1 = 0;
380         }
381         free(q0);
382         free(q1);
383         return r;
384 }
385
386 static void
387 opabs(Node *q, Node *n)
388 {
389         int i;
390         int s, c;
391
392         nodevars(q, n->size + 1);
393         s = n->vars[n->size - 1];
394         c = s;
395         for(i = 0; i < n->size; i++){
396                 q->vars[i] = satlogicv(sat, 0x96, n->vars[i], s, c, 0);
397                 c = satandv(sat, -n->vars[i], c, 0);
398         }
399 }
400
401 static void
402 opdiv(Node *q, Node *r, Node *n1, Node *n2)
403 {
404         Node *s;
405         int i, s1, sr,zr;
406         
407         if(q == nil) q = node(ASTTEMP);
408         if(r == nil) r = node(ASTTEMP);
409         nodevars(q, n1->size);
410         nodevars(r, n2->size);
411         for(i = 0; i < n1->size; i++)
412                 q->vars[i] = satvar++;
413         for(i = 0; i < n2->size; i++)
414                 r->vars[i] = satvar++;
415         s = node(ASTBIN, OPEQ, node(ASTBIN, OPADD, node(ASTBIN, OPMUL, q, n2), r), n1); convert(s, -1); assume(s);
416         s = node(ASTBIN, OPLT, node(ASTUN, OPABS, r), node(ASTUN, OPABS, n2)); convert(s, -1); assume(s);
417         s1 = n1->vars[n1->size - 1];
418         sr = r->vars[r->size - 1];
419         zr = -sator1(sat, r->vars, r->size);
420         sataddv(sat, zr, sr, -s1, 0);
421         sataddv(sat, zr, -sr, s1, 0);
422 }
423
424 void
425 convert(Node *n, uint sz)
426 {
427         if(n->size > 0) return;
428         switch(n->type){
429         case ASTTEMP:
430                 assert(n->size > 0);
431                 break;
432         case ASTSYM:
433                 symsat(n);
434                 break;
435         case ASTNUM:
436                 numsat(n);
437                 break;
438         case ASTBIN:
439                 if(n->op == OPASS){
440                         if(n->n1 == nil || n->n1->type != ASTSYM)
441                                 error(n, "convert: '%ε' invalid lval", n->n1);
442                         convert(n->n2, n->n1->sym->size);
443                         assert(n->n2->size > 0);
444                         assign(n->n1, n->n2);
445                         break;
446                 }
447                 switch(n->op){
448                 case OPAND: case OPOR: case OPXOR:
449                 case OPADD: case OPSUB: case OPLSH:
450                 case OPCOMMA:
451                         convert(n->n1, sz);
452                         convert(n->n2, sz);
453                         break;
454                 default:
455                         convert(n->n1, -1);
456                         convert(n->n2, -1);
457                 }
458                 assert(n->n1->size > 0);
459                 assert(n->n2->size > 0);
460                 switch(n->op){
461                 case OPCOMMA: n->size = n->n2->size; n->vars = n->n2->vars; break;
462                 case OPEQ: opeq(n, n->n1, n->n2, 0); break;
463                 case OPNEQ: opeq(n, n->n1, n->n2, 1); break;
464                 case OPLT: oplt(n, n->n1, n->n2, 0); break;
465                 case OPLE: oplt(n, n->n1, n->n2, 1); break;
466                 case OPGT: oplt(n, n->n2, n->n1, 0); break;
467                 case OPGE: oplt(n, n->n2, n->n1, 1); break;
468                 case OPXOR: case OPAND: case OPOR: oplogic(n, n->n1, n->n2, n->op); break;
469                 case OPLAND: case OPLOR: case OPIMP: case OPEQV: opllogic(n, n->n1, n->n2, n->op); break;
470                 case OPADD: opadd(n, n->n1, n->n2, 0); break;
471                 case OPSUB: opadd(n, n->n1, n->n2, 1); break;
472                 case OPLSH: oplsh(n, n->n1, n->n2, sz); break;
473                 case OPRSH: oprsh(n, n->n1, n->n2); break;
474                 case OPMUL: n->vars = opmul(n->n1->vars, n->n1->size, n->n2->vars, n->n2->size); n->size = n->n1->size + n->n2->size; break;
475                 case OPDIV: opdiv(n, nil, n->n1, n->n2); break;
476                 case OPMOD: opdiv(nil, n, n->n1, n->n2); break;
477                 default:
478                         error(n, "convert: unimplemented op %O", n->op);
479                 }
480                 break;
481         case ASTUN:
482                 convert(n->n1, sz);
483                 switch(n->op){
484                 case OPCOM: opcom(n, n->n1); break;
485                 case OPNEG: opneg(n, n->n1); break;
486                 case OPNOT: opnot(n, n->n1); break;
487                 case OPABS: opabs(n, n->n1); break;
488                 default:
489                         error(n, "convert: unimplemented op %O", n->op);
490                 }
491                 break;
492         case ASTIDX:
493                 if(n->n2->type != ASTNUM || n->n3 != nil && n->n3->type != ASTNUM)
494                         error(n, "non-constant in indexing expression");
495                 convert(n->n1, n->n3 != nil ? mptoi(n->n3->num) - mptoi(n->n2->num) + 1 : 1);
496                 opidx(n, n->n1, n->n2, n->n3);
497                 break;
498         case ASTTERN:
499                 convert(n->n1, -1);
500                 convert(n->n2, sz);
501                 convert(n->n3, sz);
502                 optern(n, n->n1, n->n2, n->n3, sz);
503                 break;
504         default:
505                 error(n, "convert: unimplemented %α", n->type);
506         }
507 }
508
509 void
510 assume(Node *n)
511 {
512         assert(n->size > 0);
513         satadd1(sat, n->vars, n->size);
514 }
515
516 void
517 obviously(Node *n)
518 {
519         assertvar = realloc(assertvar, sizeof(int) * (nassertvar + 1));
520         assert(assertvar != nil);
521         assertvar[nassertvar++] = -tologic(n);
522 }
523
524 void
525 cvtinit(void)
526 {
527         sat = sataddv(nil, -1, 0);
528         sataddv(sat, 2, 0);
529 }