]> git.lizzy.rs Git - plan9front.git/blobdiff - sys/src/cmd/mpc.y
awk: make empty FS unicodely-correct.
[plan9front.git] / sys / src / cmd / mpc.y
index b63fb383deb3c2d52d209a4d04bbdf5ed9c5f120..67b0fec9fe7a5475d28e705d03d99e23045a303c 100644 (file)
@@ -28,6 +28,7 @@ struct Node
        Node*   l;
        Node*   r;
        Sym*    s;
+       mpint*  m;
        int     n;
 };
 
@@ -63,7 +64,6 @@ void  fcom(Node*,Node*,Node*);
 {
        Sym*    sval;
        Node*   node;
-       long    lval;
 }
 
 %type  <node>  name num args expr bool block elif stmnt stmnts
@@ -79,7 +79,7 @@ void  fcom(Node*,Node*,Node*);
 %left  '^'
 %right '('
 
-%token <lval>  MOD IF ELSE WHILE BREAK 
+%token MOD IF ELSE WHILE BREAK 
 %token <sval>  NAME NUM
 
 %%
@@ -238,11 +238,11 @@ expr:
        {
                $$ = new('e', $1, $2);
        }
-|      expr LSH num
+|      expr LSH expr
        {
                $$ = new(LSH, $1, $3);
        }
-|      expr RSH num
+|      expr RSH expr
        {
                $$ = new(RSH, $1, $3);
        }
@@ -390,6 +390,7 @@ new(int c, Node *l, Node *r)
        n->l = l;
        n->r = r;
        n->s = nil;
+       n->m = nil;
        n->n = lineno;
        return n;
 }
@@ -561,7 +562,7 @@ complex(Node *n)
 {
        if(n->c == NAME)
                return 0;
-       if(n->c == NUM && strlen(n->s->n) == 1 && atoi(n->s->n) < 3)
+       if(n->c == NUM && n->m->sign > 0 && mpcmp(n->m, mptwo) <= 0)
                return 0;
        return 1;
 }
@@ -569,38 +570,132 @@ complex(Node *n)
 void
 bcom(Node *n, Node *t);
 
+Node*
+ccom(Node *f)
+{
+       Node *l, *r;
+
+       if(f == nil)
+               return nil;
+
+       if(f->m != nil)
+               return f;
+       f->m = (void*)~0;
+
+       switch(f->c){
+       case NUM:
+               f->m = strtomp(f->s->n, nil, 0, nil);
+               if(f->m == nil)
+                       diag(f, "bad constant");
+               goto out;
+
+       case LSH:
+       case RSH:
+               break;
+
+       case '+':
+       case '-':
+       case '*':
+       case '/':
+       case '%':
+       case '^':
+               if(modulo == nil || modulo->c == NUM)
+                       break;
+
+               /* wet floor */
+       default:
+               return f;
+       }
+
+       f->l = l = ccom(f->l);
+       f->r = r = ccom(f->r);
+       if(l == nil || r == nil || l->c != NUM || r->c != NUM)
+               return f;
+
+       f->m = mpnew(0);
+       switch(f->c){
+       case LSH:
+       case RSH:
+               if(mpsignif(r->m) > 32)
+                       diag(f, "bad shift");
+               if(f->c == LSH)
+                       mpleft(l->m, mptoi(r->m), f->m);
+               else
+                       mpright(l->m, mptoi(r->m), f->m);
+               goto out;
+
+       case '+':
+               mpadd(l->m, r->m, f->m);
+               break;
+       case '-':
+               mpsub(l->m, r->m, f->m);
+               break;
+       case '*':
+               mpmul(l->m, r->m, f->m);
+               break;
+       case '/':
+               if(modulo != nil){
+                       mpinvert(r->m, modulo->m, f->m);
+                       mpmul(f->m, l->m, f->m);
+               } else {
+                       mpdiv(l->m, r->m, f->m, nil);
+               }
+               break;
+       case '%':
+               mpmod(l->m, r->m, f->m);
+               break;
+       case '^':
+               mpexp(l->m, r->m, modulo != nil ? modulo->m : nil, f->m);
+               goto out;
+       }
+       if(modulo != nil)
+               mpmod(f->m, modulo->m, f->m);
+
+out:
+       f->l = nil;
+       f->r = nil;
+       f->s = nil;
+       f->c = NUM;
+       return f;
+}
+
 Node*
 ecom(Node *f, Node *t)
 {
        Node *l, *r, *t2;
-       mpint *m;
 
        if(f == nil)
                return nil;
 
+       f = ccom(f);
        if(f->c == NUM){
-               m = strtomp(f->s->n, nil, 0, nil);
-               if(m == nil)
-                       diag(f, "bad constant");
-               if(mpcmp(m, mpzero) == 0){
+               if(f->m->sign < 0){
+                       f->m->sign = 1;
+                       t = ecom(f, t);
+                       f->m->sign = -1;
+                       if(isconst(t))
+                               t = ecom(t, alloctmp());
+                       cprint("%N->sign = -1;\n", t);
+                       return t;
+               }
+               if(mpcmp(f->m, mpzero) == 0){
                        f->c = NAME;
                        f->s = sym("mpzero");
                        f->s->f = FSET;
                        return ecom(f, t);
                }
-               if(mpcmp(m, mpone) == 0){
+               if(mpcmp(f->m, mpone) == 0){
                        f->c = NAME;
                        f->s = sym("mpone");
                        f->s->f = FSET;
                        return ecom(f, t);
                }
-               if(mpcmp(m, mptwo) == 0){
+               if(mpcmp(f->m, mptwo) == 0){
                        f->c = NAME;
                        f->s = sym("mptwo");
                        f->s->f = FSET;
                        return ecom(f, t);
                }
-               mpfree(m);
        }
 
        if(f->c == ','){
@@ -645,24 +740,23 @@ ecom(Node *f, Node *t)
 
        switch(f->c){
        case NUM:
-               m = strtomp(f->s->n, nil, 0, nil);
-               if(m == nil)
-                       diag(f, "bad constant");
-               if(mpsignif(m) <= 32)
-                       cprint("uitomp(%udUL, %N);\n", mptoui(m), t);
-               else if(mpsignif(m) <= 64)
-                       cprint("uvtomp(%lludULL, %N);\n", mptouv(m), t);
+               if(mpsignif(f->m) <= 32)
+                       cprint("uitomp(%udUL, %N);\n", mptoui(f->m), t);
+               else if(mpsignif(f->m) <= 64)
+                       cprint("uvtomp(%lludULL, %N);\n", mptouv(f->m), t);
                else
-                       cprint("strtomp(\"%.16B\", nil, 16, %N);\n", m, t);
-               mpfree(m);
+                       cprint("strtomp(\"%.16B\", nil, 16, %N);\n", f->m, t);
                goto out;
        case LSH:
-               l = f->l->c == NAME ? f->l : ecom(f->l, t);
-               cprint("mpleft(%N, %N, %N);\n", l, f->r, t);
-               goto out;
        case RSH:
+               r = ccom(f->r);
+               if(r == nil || r->c != NUM || mpsignif(r->m) > 32)
+                       diag(f, "bad shift");
                l = f->l->c == NAME ? f->l : ecom(f->l, t);
-               cprint("mpright(%N, %N, %N);\n", l, f->r, t);
+               if(f->c == LSH)
+                       cprint("mpleft(%N, %d, %N);\n", l, mptoi(r->m), t);
+               else
+                       cprint("mpright(%N, %d, %N);\n", l, mptoi(r->m), t);
                goto out;
        case '*':
        case '/':
@@ -670,8 +764,10 @@ ecom(Node *f, Node *t)
                r = ecom(f->r, nil);
                break;
        default:
-               l = ecom(f->l, complex(f->l) && !symref(f->r, t->s) ? t : nil);
-               r = ecom(f->r, complex(f->r) && l->s != t->s ? t : nil);
+               l = ccom(f->l);
+               r = ccom(f->r);
+               l = ecom(l, complex(l) && !symref(r, t->s) ? t : nil);
+               r = ecom(r, complex(r) && l->s != t->s ? t : nil);
                break;
        }
 
@@ -975,8 +1071,11 @@ Nfmt(Fmt *f)
                return fmtprint(f, "%N, %N", n->l, n->r);
 
        switch(n->c){
-       case NAME:
        case NUM:
+               if(n->m != nil)
+                       return fmtprint(f, "%B", n->m);
+               /* wet floor */
+       case NAME:
                return fmtprint(f, "%s", n->s->n);
        case EQ:
                return fmtprint(f, "==");