]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/mpc.y
libaml: fix gc bug, need to amltake()/amldrop() temporary buffer
[plan9front.git] / sys / src / cmd / mpc.y
1 %{
2
3 #include        <u.h>
4 #include        <libc.h>
5 #include        <bio.h>
6 #include        <mp.h>
7
8 typedef struct Sym Sym;
9 typedef struct Node Node;
10
11 enum {
12         FSET    = 1,
13         FUSE    = 2,
14         FARG    = 4,
15         FLOC    = 8,
16 };
17
18 struct Sym
19 {
20         Sym*    l;
21         int     f;
22         char    n[];
23 };
24
25 struct Node
26 {
27         int     c;
28         Node*   l;
29         Node*   r;
30         Sym*    s;
31         mpint*  m;
32         int     n;
33 };
34
35 #pragma varargck type "N" Node*
36
37 int     ntmp;
38 Node    *ftmps, *atmps;
39 Node    *modulo;
40
41 Node*   new(int, Node*, Node*);
42 Sym*    sym(char*);
43
44 Biobuf  bin;
45 int     goteof;
46 int     lineno;
47 int     clevel;
48 char*   filename;
49
50 int     getch(void);
51 void    ungetc(void);
52 void    yyerror(char*);
53 int     yyparse(void);
54 void    diag(Node*, char*, ...);
55 void    com(Node*);
56 void    fcom(Node*,Node*,Node*);
57
58 #pragma varargck argpos cprint 1
59 #pragma varargck argpos diag 2
60
61 %}
62
63 %union
64 {
65         Sym*    sval;
66         Node*   node;
67 }
68
69 %type   <node>  name num args expr bool block elif stmnt stmnts
70
71 %left   '{' '}' ';'
72 %right  '=' ','
73 %right  '?' ':'
74 %left   EQ NEQ '<' '>'
75 %left   LSH RSH
76 %left   '+' '-'
77 %left   '/' '%'
78 %left   '*'
79 %left   '^'
80 %right  '('
81
82 %token  MOD IF ELSE WHILE BREAK 
83 %token  <sval>  NAME NUM
84
85 %%
86
87 prog:
88         prog func
89 |       func
90
91 func:
92         name args stmnt
93         {
94                 fcom($1, $2, $3);
95         }
96
97 args:
98         '(' expr ')'
99         {
100                 $$ = $2;
101         }
102 |       '(' ')'
103         {
104                 $$ = nil;
105         }
106
107 name:
108         NAME
109         {
110                 $$ = new(NAME,nil,nil);
111                 $$->s = $1;
112         }
113 num:
114         NUM
115         {
116                 $$ = new(NUM,nil,nil);
117                 $$->s = $1;
118         }
119
120 elif:
121         ELSE IF '(' bool ')' stmnt
122         {
123                 $$ = new('?', $4, new(':', $6, nil));
124         }
125 |       ELSE IF '(' bool ')' stmnt elif
126         {
127                 $$ = new('?', $4, new(':', $6, $7));
128         }
129 |       ELSE stmnt
130         {
131                 $$ = $2;
132         }
133
134 sem:
135         sem ';'
136 |       ';'
137
138 stmnt:
139         expr '=' expr sem
140         {
141                 $$ = new('=', $1, $3);
142         }
143 |       MOD args stmnt
144         {
145                 $$ = new('m', $2, $3);
146         }
147 |       IF '(' bool ')' stmnt
148         {
149                 $$ = new('?', $3, new(':', $5, nil));
150         }
151 |       IF '(' bool ')' stmnt elif
152         {
153                 $$ = new('?', $3, new(':', $5, $6));
154         }
155 |       WHILE '(' bool ')' stmnt
156         {
157                 $$ = new('@', new('?', $3, new(':', $5, new('b', nil, nil))), nil);
158         }
159 |       BREAK sem
160         {
161                 $$ = new('b', nil, nil);
162         }
163 |       expr sem
164         {
165                 if($1->c == NAME)
166                         $$ = new('e', $1, nil);
167                 else
168                         $$ = $1;
169         }
170 |       block
171
172 block:
173         '{' stmnts '}'
174         {
175                 $$ = $2;
176         }
177
178 stmnts:
179         stmnts stmnt
180         {
181                 $$ = new('\n', $1, $2);
182         }
183 |       stmnt
184
185 expr:
186         '(' expr ')'
187         {
188                 $$ = $2;
189         }
190 |       name
191         {
192                 $$ = $1;
193         }
194 |       num
195         {
196                 $$ = $1;
197         }
198 |       '-' expr
199         {
200                 $$ = new(NUM, nil, nil);
201                 $$->s = sym("0");
202                 $$->s->f = 0;
203                 $$ = new('-', $$, $2);
204         }
205 |       expr ',' expr
206         {
207                 $$ = new(',', $1, $3);
208         }
209 |       expr '^' expr
210         {
211                 $$ = new('^', $1, $3);
212         }
213 |       expr '*' expr
214         {
215                 $$ = new('*', $1, $3);
216         }
217 |       expr '/' expr
218         {
219                 $$ = new('/', $1, $3);
220         }
221 |       expr '%' expr
222         {
223                 $$ = new('%', $1, $3);
224         }
225 |       expr '+' expr
226         {
227                 $$ = new('+', $1, $3);
228         }
229 |       expr '-' expr
230         {
231                 $$ = new('-', $1, $3);
232         }
233 |       bool '?' expr ':' expr
234         {
235                 $$ = new('?', $1, new(':', $3, $5));
236         }
237 |       name args
238         {
239                 $$ = new('e', $1, $2);
240         }
241 |       expr LSH expr
242         {
243                 $$ = new(LSH, $1, $3);
244         }
245 |       expr RSH expr
246         {
247                 $$ = new(RSH, $1, $3);
248         }
249
250 bool:
251         '(' bool ')'
252         {
253                 $$ = $2;
254         }
255 |       '!' bool
256         {
257                 $$ = new('!', $2, nil);
258         }
259 |       expr EQ expr
260         {
261                 $$ = new(EQ, $1, $3);
262         }
263 |       expr NEQ expr
264         {
265                 $$ = new('!', new(EQ, $1, $3), nil);
266         }
267 |       expr '>' expr
268         {
269                 $$ = new('>', $1, $3);
270         }
271 |       expr '<' expr
272         {
273                 $$ = new('<', $1, $3);
274         }
275
276 %%
277
278 int
279 yylex(void)
280 {
281         static char buf[200];
282         char *p;
283         int c;
284
285 Loop:
286         c = getch();
287         switch(c){
288         case -1:
289                 return -1;
290         case ' ':
291         case '\t':
292         case '\n':
293                 goto Loop;
294         case '#':
295                 while((c = getch()) > 0)
296                         if(c == '\n')
297                                 break;
298                 goto Loop;
299         }
300
301         switch(c){
302         case '?': case ':':
303         case '+': case '-':
304         case '*': case '^':
305         case '/': case '%':
306         case '{': case '}':
307         case '(': case ')':
308         case ',': case ';':
309                 return c;
310         case '<':
311                 if(getch() == '<') return LSH;
312                 ungetc();
313                 return '<';
314         case '>': 
315                 if(getch() == '>') return RSH;
316                 ungetc();
317                 return '>';
318         case '=':
319                 if(getch() == '=') return EQ;
320                 ungetc();
321                 return '=';
322         case '!':
323                 if(getch() == '=') return NEQ;
324                 ungetc();
325                 return '!';
326         }
327
328         ungetc();
329         p = buf;
330         for(;;){
331                 c = getch();
332                 if((c >= Runeself)
333                 || (c == '_')
334                 || (c >= 'a' && c <= 'z')
335                 || (c >= 'A' && c <= 'Z')
336                 || (c >= '0' && c <= '9')){
337                         *p++ = c;
338                         continue;
339                 }
340                 ungetc();
341                 break;
342         }
343         *p = '\0';
344
345         if(strcmp(buf, "mod") == 0)
346                 return MOD;
347         if(strcmp(buf, "if") == 0)
348                 return IF;
349         if(strcmp(buf, "else") == 0)
350                 return ELSE;
351         if(strcmp(buf, "while") == 0)
352                 return WHILE;
353         if(strcmp(buf, "break") == 0)
354                 return BREAK;
355
356         yylval.sval = sym(buf);
357         yylval.sval->f = 0;
358         return (buf[0] >= '0' && buf[0] <= '9') ? NUM : NAME;
359 }
360
361
362 int
363 getch(void)
364 {
365         int c;
366
367         c = Bgetc(&bin);
368         if(c == Beof){
369                 goteof = 1;
370                 return -1;
371         }
372         if(c == '\n')
373                 lineno++;
374         return c;
375 }
376
377 void
378 ungetc(void)
379 {
380         Bungetc(&bin);
381 }
382
383 Node*
384 new(int c, Node *l, Node *r)
385 {
386         Node *n;
387
388         n = malloc(sizeof(Node));
389         n->c = c;
390         n->l = l;
391         n->r = r;
392         n->s = nil;
393         n->m = nil;
394         n->n = lineno;
395         return n;
396 }
397
398 Sym*
399 sym(char *n)
400 {
401         static Sym *tab[128];
402         Sym *s;
403         ulong h, t;
404         int i;
405
406         h = 0;
407         for(i=0; n[i] != '\0'; i++){
408                 t = h & 0xf8000000;
409                 h <<= 5;
410                 h ^= t>>27;
411                 h ^= (ulong)n[i];
412         }
413         h %= nelem(tab);
414         for(s = tab[h]; s != nil; s = s->l)
415                 if(strcmp(s->n, n) == 0)
416                         return s;
417         s = malloc(sizeof(Sym)+i+1);
418         memmove(s->n, n, i+1);
419         s->f = 0;
420         s->l = tab[h];
421         tab[h] = s;
422         return s;
423 }
424
425 void
426 yyerror(char *s)
427 {
428         fprint(2, "%s:%d: %s\n", filename, lineno, s);
429         exits(s);
430 }
431 void
432 cprint(char *fmt, ...)
433 {
434         static char buf[1024], tabs[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t";
435         char *p, *x;
436         va_list a;
437
438         va_start(a, fmt);
439         vsnprint(buf, sizeof(buf), fmt, a);
440         va_end(a);
441
442         p = buf;
443         while((x = strchr(p, '\n')) != nil){
444                 x++;
445                 write(1, p, x-p);
446                 p = &tabs[sizeof(tabs)-1 - clevel];
447                 if(*p != '\0')
448                         write(1, p, strlen(p));
449                 p = x;
450         }
451         if(*p != '\0')
452                 write(1, p, strlen(p));
453 }
454
455 Node*
456 alloctmp(void)
457 {
458         Node *t;
459
460         t = ftmps;
461         if(t != nil)
462                 ftmps = t->l;
463         else {
464                 char n[16];
465
466                 snprint(n, sizeof(n), "tmp%d", ++ntmp);
467                 t = new(NAME, nil, nil);
468                 t->s = sym(n);
469
470                 cprint("mpint *");
471         }
472         cprint("%N = mpnew(0);\n", t);
473         t->s->f &= ~(FSET|FUSE);
474         t->l = atmps;
475         atmps = t;
476         return t;
477 }
478
479 int
480 isconst(Node *n)
481 {
482         if(n->c == NUM)
483                 return 1;
484         if(n->c == NAME){
485                 return  n->s == sym("mpzero") ||
486                         n->s == sym("mpone") ||
487                         n->s == sym("mptwo");
488         }
489         return 0;
490 }
491
492 int
493 istmp(Node *n)
494 {
495         Node *l;
496
497         if(n->c == NAME){
498                 for(l = atmps; l != nil; l = l->l){
499                         if(l->s == n->s)
500                                 return 1;
501                 }
502         }
503         return 0;
504 }
505
506
507 void
508 freetmp(Node *t)
509 {
510         Node **ll, *l;
511
512         if(t == nil)
513                 return;
514         if(t->c == ','){
515                 freetmp(t->l);
516                 freetmp(t->r);
517                 return;
518         }
519         if(t->c != NAME)
520                 return;
521
522         ll = &atmps;
523         for(l = atmps; l != nil; l = l->l){
524                 if(l == t){
525                         cprint("mpfree(%N);\n", t);
526                         *ll = t->l;
527                         t->l = ftmps;
528                         ftmps = t;
529                         return;
530                 }
531                 ll = &l->l;
532         }
533 }
534
535 int
536 symref(Node *n, Sym *s)
537 {
538         if(n == nil)
539                 return 0;
540         if(n->c == NAME && n->s == s)
541                 return 1;
542         return symref(n->l, s) || symref(n->r, s);
543 }
544
545 void
546 nodeset(Node *n)
547 {
548         if(n == nil)
549                 return;
550         if(n->c == NAME){
551                 n->s->f |= FSET;
552                 return;
553         }
554         if(n->c == ','){
555                 nodeset(n->l);
556                 nodeset(n->r);
557         }
558 }
559
560 int
561 complex(Node *n)
562 {
563         if(n->c == NAME)
564                 return 0;
565         if(n->c == NUM && n->m->sign > 0 && mpcmp(n->m, mptwo) <= 0)
566                 return 0;
567         return 1;
568 }
569
570 void
571 bcom(Node *n, Node *t);
572
573 Node*
574 ccom(Node *f)
575 {
576         Node *l, *r;
577
578         if(f == nil)
579                 return nil;
580
581         if(f->m != nil)
582                 return f;
583         f->m = (void*)~0;
584
585         switch(f->c){
586         case NUM:
587                 f->m = strtomp(f->s->n, nil, 0, nil);
588                 if(f->m == nil)
589                         diag(f, "bad constant");
590                 goto out;
591
592         case LSH:
593         case RSH:
594                 break;
595
596         case '+':
597         case '-':
598         case '*':
599         case '/':
600         case '%':
601         case '^':
602                 if(modulo == nil || modulo->c == NUM)
603                         break;
604
605                 /* wet floor */
606         default:
607                 return f;
608         }
609
610         f->l = l = ccom(f->l);
611         f->r = r = ccom(f->r);
612         if(l == nil || r == nil || l->c != NUM || r->c != NUM)
613                 return f;
614
615         f->m = mpnew(0);
616         switch(f->c){
617         case LSH:
618         case RSH:
619                 if(mpsignif(r->m) > 32)
620                         diag(f, "bad shift");
621                 if(f->c == LSH)
622                         mpleft(l->m, mptoi(r->m), f->m);
623                 else
624                         mpright(l->m, mptoi(r->m), f->m);
625                 goto out;
626
627         case '+':
628                 mpadd(l->m, r->m, f->m);
629                 break;
630         case '-':
631                 mpsub(l->m, r->m, f->m);
632                 break;
633         case '*':
634                 mpmul(l->m, r->m, f->m);
635                 break;
636         case '/':
637                 if(modulo != nil){
638                         mpinvert(r->m, modulo->m, f->m);
639                         mpmul(f->m, l->m, f->m);
640                 } else {
641                         mpdiv(l->m, r->m, f->m, nil);
642                 }
643                 break;
644         case '%':
645                 mpmod(l->m, r->m, f->m);
646                 break;
647         case '^':
648                 mpexp(l->m, r->m, modulo != nil ? modulo->m : nil, f->m);
649                 goto out;
650         }
651         if(modulo != nil)
652                 mpmod(f->m, modulo->m, f->m);
653
654 out:
655         f->l = nil;
656         f->r = nil;
657         f->s = nil;
658         f->c = NUM;
659         return f;
660 }
661
662 Node*
663 ecom(Node *f, Node *t)
664 {
665         Node *l, *r, *t2;
666
667         if(f == nil)
668                 return nil;
669
670         f = ccom(f);
671         if(f->c == NUM){
672                 if(f->m->sign < 0){
673                         f->m->sign = 1;
674                         t = ecom(f, t);
675                         f->m->sign = -1;
676                         if(isconst(t))
677                                 t = ecom(t, alloctmp());
678                         cprint("%N->sign = -1;\n", t);
679                         return t;
680                 }
681                 if(mpcmp(f->m, mpzero) == 0){
682                         f->c = NAME;
683                         f->s = sym("mpzero");
684                         f->s->f = FSET;
685                         return ecom(f, t);
686                 }
687                 if(mpcmp(f->m, mpone) == 0){
688                         f->c = NAME;
689                         f->s = sym("mpone");
690                         f->s->f = FSET;
691                         return ecom(f, t);
692                 }
693                 if(mpcmp(f->m, mptwo) == 0){
694                         f->c = NAME;
695                         f->s = sym("mptwo");
696                         f->s->f = FSET;
697                         return ecom(f, t);
698                 }
699         }
700
701         if(f->c == ','){
702                 if(t != nil)
703                         diag(f, "cannot assign list to %N", t);
704                 f->l = ecom(f->l, nil);
705                 f->r = ecom(f->r, nil);
706                 return f;
707         }
708
709         l = r = nil;
710         if(f->c == NAME){
711                 if((f->s->f & FSET) == 0)
712                         diag(f, "name used but not set");
713                 f->s->f |= FUSE;
714                 if(t == nil)
715                         return f;
716                 if(f->s != t->s)
717                         cprint("mpassign(%N, %N);\n", f, t);
718                 goto out;
719         }
720
721         if(t == nil)
722                 t = alloctmp();
723
724         if(f->c == '?'){
725                 bcom(f, t);
726                 goto out;
727         }
728
729         if(f->c == 'e'){
730                 r = ecom(f->r, nil);
731                 if(r == nil)
732                         cprint("%N(%N);\n", f->l, t);
733                 else
734                         cprint("%N(%N, %N);\n", f->l, r, t);
735                 goto out;
736         }
737
738         if(t->c != NAME)
739                 diag(f, "destination %N not a name", t);
740
741         switch(f->c){
742         case NUM:
743                 if(mpsignif(f->m) <= 32)
744                         cprint("uitomp(%udUL, %N);\n", mptoui(f->m), t);
745                 else if(mpsignif(f->m) <= 64)
746                         cprint("uvtomp(%lludULL, %N);\n", mptouv(f->m), t);
747                 else
748                         cprint("strtomp(\"%.16B\", nil, 16, %N);\n", f->m, t);
749                 goto out;
750         case LSH:
751         case RSH:
752                 r = ccom(f->r);
753                 if(r == nil || r->c != NUM || mpsignif(r->m) > 32)
754                         diag(f, "bad shift");
755                 l = f->l->c == NAME ? f->l : ecom(f->l, t);
756                 if(f->c == LSH)
757                         cprint("mpleft(%N, %d, %N);\n", l, mptoi(r->m), t);
758                 else
759                         cprint("mpright(%N, %d, %N);\n", l, mptoi(r->m), t);
760                 goto out;
761         case '*':
762         case '/':
763                 l = ecom(f->l, nil);
764                 r = ecom(f->r, nil);
765                 break;
766         default:
767                 l = ccom(f->l);
768                 r = ccom(f->r);
769                 l = ecom(l, complex(l) && !symref(r, t->s) ? t : nil);
770                 r = ecom(r, complex(r) && l->s != t->s ? t : nil);
771                 break;
772         }
773
774
775         if(modulo != nil){
776                 switch(f->c){
777                 case '+':
778                         cprint("mpmodadd(%N, %N, %N, %N);\n", l, r, modulo, t);
779                         goto out;
780                 case '-':
781                         cprint("mpmodsub(%N, %N, %N, %N);\n", l, r, modulo, t);
782                         goto out;
783                 case '*':
784                 Modmul:
785                         if(l->s == sym("mptwo") || r->s == sym("mptwo"))
786                                 cprint("mpmodadd(%N, %N, %N, %N); // 2*%N\n",
787                                         r->s == sym("mptwo") ? l : r,
788                                         r->s == sym("mptwo") ? l : r,
789                                         modulo, t,
790                                         r);
791                         else
792                                 cprint("mpmodmul(%N, %N, %N, %N);\n", l, r, modulo, t);
793                         goto out;
794                 case '/':
795                         if(l->s == sym("mpone")){
796                                 cprint("mpinvert(%N, %N, %N);\n", r, modulo, t);
797                                 goto out;
798                         }
799                         t2 = alloctmp();
800                         cprint("mpinvert(%N, %N, %N);\n", r, modulo, t2);
801                         cprint("mpmodmul(%N, %N, %N, %N);\n", l, t2, modulo, t);
802                         freetmp(t2);
803                         goto out;
804                 case '^':
805                         if(r->s == sym("mptwo")){
806                                 r = l;
807                                 goto Modmul;
808                         }
809                         cprint("mpexp(%N, %N, %N, %N);\n", l, r, modulo, t);
810                         goto out;
811                 }
812         }
813
814         switch(f->c){
815         case '+':
816                 cprint("mpadd(%N, %N, %N);\n", l, r, t);
817                 goto out;
818         case '-':
819                 if(l->s == sym("mpzero")){
820                         r = ecom(r, t);
821                         cprint("%N->sign = -%N->sign;\n", t, t);
822                 } else
823                         cprint("mpsub(%N, %N, %N);\n", l, r, t);
824                 goto out;
825         case '*':
826         Mul:
827                 if(l->s == sym("mptwo") || r->s == sym("mptwo"))
828                         cprint("mpleft(%N, 1, %N);\n", r->s == sym("mptwo") ? l : r, t);
829                 else
830                         cprint("mpmul(%N, %N, %N);\n", l, r, t);
831                 goto out;
832         case '/':
833                 cprint("mpdiv(%N, %N, %N, %N);\n", l, r, t, nil);
834                 goto out;
835         case '%':
836                 cprint("mpmod(%N, %N, %N);\n", l, r, t);
837                 goto out;
838         case '^':
839                 if(r->s == sym("mptwo")){
840                         r = l;
841                         goto Mul;
842                 }
843                 cprint("mpexp(%N, %N, nil, %N);\n", l, r, t);
844                 goto out;
845         default:
846                 diag(f, "unknown operation");
847         }
848
849 out:
850         if(l != t)
851                 freetmp(l);
852         if(r != t)
853                 freetmp(r);
854         nodeset(t);
855         return t;
856 }
857
858 void
859 bcom(Node *n, Node *t)
860 {
861         Node *f, *l, *r;
862         int neg = 0;
863
864         l = r = nil;
865         f = n->l;
866 Loop:
867         switch(f->c){
868         case '!':
869                 neg = !neg;
870                 f = f->l;
871                 goto Loop;
872         case '>':
873         case '<':
874         case EQ:
875                 l = ecom(f->l, nil);
876                 r = ecom(f->r, nil);
877                 if(t != nil) {
878                         Node *b1, *b2;
879
880                         b1 = ecom(n->r->l, nil);
881                         b2 = ecom(n->r->r, nil);
882                         cprint("mpsel(");
883
884                         if(l->s == r->s)
885                                 cprint("0");
886                         else {
887                                 if(f->c == '>')
888                                         cprint("-");
889                                 cprint("mpcmp(%N, %N)", l, r);
890                         }
891                         if(f->c == EQ)
892                                 neg = !neg;
893                         else
894                                 cprint(" >> (sizeof(int)*8-1)");
895
896                         cprint(", %N, %N, %N);\n", neg ? b2 : b1, neg ? b1 : b2, t);
897                         freetmp(b1);
898                         freetmp(b2);
899                 } else {
900                         cprint("if(");
901
902                         if(l->s == r->s)
903                                 cprint("0");
904                         else
905                                 cprint("mpcmp(%N, %N)", l, r);
906                         if(f->c == EQ)
907                                 cprint(neg ? " != 0" : " == 0");
908                         else if(f->c == '>')
909                                 cprint(neg ? " <= 0" : " > 0");
910                         else
911                                 cprint(neg ? " >= 0" : " < 0");
912
913                         cprint(")");
914                         com(n->r);
915                 }
916                 break;
917         default:
918                 diag(n, "saw %N in boolean expression", f);
919         }
920         freetmp(l);
921         freetmp(r);
922 }
923
924 void
925 com(Node *n)
926 {
927         Node *l, *r;
928
929 Loop:
930         if(n != nil)
931         switch(n->c){
932         case '\n':
933                 com(n->l);
934                 n = n->r;
935                 goto Loop;
936         case '?':
937                 bcom(n, nil);
938                 break;
939         case 'b':
940                 for(l = atmps; l != nil; l = l->l)
941                         cprint("mpfree(%N);\n", l);
942                 cprint("break;\n");
943                 break;
944         case '@':
945                 cprint("for(;;)");
946         case ':':
947                 clevel++;
948                 cprint("{\n");
949                 l = ftmps;
950                 r = atmps;
951                 if(n->c == '@')
952                         atmps = nil;
953                 ftmps = nil;
954                 com(n->l);
955                 if(n->r != nil){
956                         cprint("}else{\n");
957                         ftmps = nil;
958                         com(n->r);
959                 }
960                 ftmps = l;
961                 atmps = r;
962                 clevel--;
963                 cprint("}\n");
964                 break;
965         case 'm':
966                 l = modulo;
967                 modulo = ecom(n->l, nil);
968                 com(n->r);
969                 freetmp(modulo);
970                 modulo = l;
971                 break;
972         case 'e':
973                 if(n->r == nil)
974                         cprint("%N();\n", n->l);
975                 else {
976                         r = ecom(n->r, nil);
977                         cprint("%N(%N);\n", n->l, r);
978                         freetmp(r);
979                 }
980                 break;
981         case '=':
982                 ecom(n->r, n->l);
983                 break;
984         }
985 }
986
987 Node*
988 flocs(Node *n, Node *r)
989 {
990 Loop:
991         if(n != nil)
992         switch(n->c){
993         default:
994                 r = flocs(n->l, r);
995                 r = flocs(n->r, r);
996                 n = n->r;
997                 goto Loop;
998         case '=':
999                 n = n->l;
1000                 if(n == nil)
1001                         diag(n, "lhs is nil");
1002                 while(n->c == ','){
1003                         n->c = '=';
1004                         r = flocs(n, r);
1005                         n->c = ',';
1006                         n = n->r;
1007                         if(n == nil)
1008                                 return r;
1009                 }
1010                 if(n->c == NAME && (n->s->f & (FARG|FLOC)) == 0){
1011                         n->s->f = FLOC;
1012                         return new(',', n, r);
1013                 }
1014                 break;
1015         }
1016         return r;
1017 }
1018
1019 void
1020 fcom(Node *f, Node *a, Node *b)
1021 {
1022         Node *a0, *l0, *l;
1023
1024         ntmp = 0;
1025         ftmps = atmps = modulo = nil;
1026         clevel = 1;
1027         cprint("void %N(", f);
1028         a0 = a;
1029         while(a != nil){
1030                 if(a != a0)
1031                         cprint(", ");
1032                 l = a->c == NAME ? a : a->l;
1033                 l->s->f = FARG|FSET;
1034                 cprint("mpint *%N", l);
1035                 a = a->r;
1036         }
1037         cprint("){\n");
1038         l0 = flocs(b, nil);
1039         for(a = l0; a != nil; a = a->r)
1040                 cprint("mpint *%N = mpnew(0);\n", a->l);
1041         com(b);
1042         for(a = l0; a != nil; a = a->r)
1043                 cprint("mpfree(%N);\n", a->l);
1044         clevel = 0;
1045         cprint("}\n");
1046 }
1047
1048 void
1049 diag(Node *n, char *fmt, ...)
1050 {
1051         static char buf[1024];
1052         va_list a;
1053         
1054         va_start(a, fmt);
1055         vsnprint(buf, sizeof(buf), fmt, a);
1056         va_end(a);
1057
1058         fprint(2, "%s:%d: for %N; %s\n", filename, n->n, n, buf);
1059         exits("error");
1060 }
1061
1062 int
1063 Nfmt(Fmt *f)
1064 {
1065         Node *n = va_arg(f->args, Node*);
1066
1067         if(n == nil)
1068                 return fmtprint(f, "nil");
1069
1070         if(n->c == ',')
1071                 return fmtprint(f, "%N, %N", n->l, n->r);
1072
1073         switch(n->c){
1074         case NUM:
1075                 if(n->m != nil)
1076                         return fmtprint(f, "%B", n->m);
1077                 /* wet floor */
1078         case NAME:
1079                 return fmtprint(f, "%s", n->s->n);
1080         case EQ:
1081                 return fmtprint(f, "==");
1082         case IF:
1083                 return fmtprint(f, "if");
1084         case ELSE:
1085                 return fmtprint(f, "else");
1086         case MOD:
1087                 return fmtprint(f, "mod");
1088         default:
1089                 return fmtprint(f, "%c", (char)n->c);
1090         }
1091 }
1092
1093 void
1094 parse(int fd, char *file)
1095 {
1096         Binit(&bin, fd, OREAD);
1097         filename = file;
1098         clevel = 0;
1099         lineno = 1;
1100         goteof = 0;
1101         while(!goteof)
1102                 yyparse();
1103         Bterm(&bin);
1104 }
1105
1106 void
1107 usage(void)
1108 {
1109         fprint(2, "%s [file ...]\n", argv0);
1110         exits("usage");
1111 }
1112
1113 void
1114 main(int argc, char *argv[])
1115 {
1116         fmtinstall('N', Nfmt);
1117         fmtinstall('B', mpfmt);
1118
1119         ARGBEGIN {
1120         default:
1121                 usage();
1122         } ARGEND;
1123
1124         if(argc == 0){
1125                 parse(0, "<stdin>");
1126                 exits(nil);
1127         }
1128         while(*argv != nil){
1129                 int fd;
1130
1131                 if((fd = open(*argv, OREAD)) < 0){
1132                         fprint(2, "%s: %r\n", *argv);
1133                         exits("error");
1134                 }
1135                 parse(fd, *argv);
1136                 close(fd);
1137                 argv++;
1138         }
1139         exits(nil);
1140 }