10 icast(int sign, int size, Node *n)
14 t = type(TYPINT, sign, size);
15 return node(OCAST, t, n);
19 the type checker checks types.
20 the result is an expression that is correct if evaluated with 64-bit operands all the way.
21 to maintain c-like semantics, this means adding casts all over the place, which will get optimised later.
23 note we use kencc, NOT ansi c, semantics for unsigned.
31 switch(/*nodetype*/n->type){
34 case SYMNONE: error("undeclared '%s'", n->sym->name); break;
35 case SYMVAR: n->typ = n->sym->typ; break;
36 default: sysfatal("typecheck: unknown symbol type %d", n->sym->type);
40 if((vlong)n->num >= -0x80000000LL && (vlong)n->num <= 0x7fffffffLL)
41 n->typ = type(TYPINT, 4, 1);
43 n->typ = type(TYPINT, 8, 1);
46 n->typ = type(TYPSTRING);
49 n->n1 = typecheck(n->n1);
50 n->n2 = typecheck(n->n2);
51 if(n->n1->typ == nil || n->n2->typ == nil)
53 if(n->n1->typ->type != TYPINT){
54 error("%τ not allowed in operation", n->n1->typ);
57 if(n->n2->typ->type != TYPINT){
58 error("%τ not allowed in operation", n->n2->typ);
61 s1 = n->n1->typ->size;
62 s2 = n->n2->typ->size;
63 sign = n->n1->typ->sign && n->n2->typ->sign;
74 n->typ = type(TYPINT, 8, sign);
76 n->n1 = icast(8, sign, n->n1);
77 n->n2 = icast(8, sign, n->n2);
80 n->n1 = icast(4, sign, n->n1);
81 n->n2 = icast(4, sign, n->n2);
82 return icast(4, sign, n);
88 n->typ = type(TYPINT, 4, sign);
90 n->n1 = icast(8, sign, n->n1);
91 n->n2 = icast(8, sign, n->n2);
94 n->n1 = icast(4, sign, n->n1);
95 n->n2 = icast(4, sign, n->n2);
100 n->typ = type(TYPINT, 4, sign);
104 if(n->n1->typ->size <= 4)
105 n->n1 = icast(4, n->n1->typ->sign, n->n1);
107 return icast(n->typ->size, n->typ->sign, n);
109 sysfatal("typecheck: unknown op %d", n->op);
113 n->n1 = typecheck(n->n1);
114 if(n->n1->typ == nil)
116 if(n->typ->type == TYPINT && n->n1->typ->type == TYPINT){
117 }else if(n->typ == n->n1->typ){
118 }else if(n->typ->type == TYPSTRING && n->n1->typ->type == TYPINT){
120 error("can't cast from %τ to %τ", n->n1->typ, n->typ);
123 n->n1 = typecheck(n->n1);
124 if(n->n1->typ == nil)
126 if(n->n1->typ->type != TYPINT){
127 error("%τ not allowed in operation", n->n1->typ);
130 n->typ = type(TYPINT, 4, 1);
133 n->n1 = typecheck(n->n1);
134 n->n2 = typecheck(n->n2);
135 n->n3 = typecheck(n->n3);
136 if(n->n1->typ == nil || n->n2->typ == nil || n->n3->typ == nil)
138 if(n->n1->typ->type != TYPINT){
139 error("%τ not allowed in operation", n->n1->typ);
142 if(n->n2->typ->type == TYPINT || n->n3->typ->type == TYPINT){
143 sign = n->n2->typ->sign && n->n3->typ->sign;
144 s1 = n->n2->typ->size;
145 s2 = n->n3->typ->size;
146 if(s1 > 4 || s2 > 4){
147 n->n2 = icast(8, sign, n->n2);
148 n->n3 = icast(8, sign, n->n3);
149 n->typ = type(TYPINT, 8, sign);
152 n->n2 = icast(4, sign, n->n2);
153 n->n3 = icast(4, sign, n->n3);
154 n->typ = type(TYPINT, 4, sign);
157 }else if(n->n2->typ == n->n3->typ){
160 error("don't know how to do ternary with %τ and %τ", n->n2->typ, n->n3->typ);
163 default: sysfatal("typecheck: unknown node type %α", n->type);
169 evalop(int op, int sign, vlong v1, vlong v2)
172 case OPADD: return v1 + v2; break;
173 case OPSUB: return v1 - v2; break;
174 case OPMUL: return v1 * v2; break;
175 case OPDIV: if(v2 == 0) sysfatal("division by zero"); return sign ? v1 / v2 : (uvlong)v1 / (uvlong)v2; break;
176 case OPMOD: if(v2 == 0) sysfatal("division by zero"); return sign ? v1 % v2 : (uvlong)v1 % (uvlong)v2; break;
177 case OPAND: return v1 & v2; break;
178 case OPOR: return v1 | v2; break;
179 case OPXOR: return v1 ^ v2; break;
180 case OPXNOR: return ~(v1 ^ v2); break;
197 return (u64int)v1 >> v2;
200 case OPEQ: return v1 == v2; break;
201 case OPNE: return v1 != v2; break;
202 case OPLT: return v1 < v2; break;
203 case OPLE: return v1 <= v2; break;
204 case OPLAND: return v1 && v2; break;
205 case OPLOR: return v1 || v2; break;
207 sysfatal("cfold: unknown op %.2x", op);
214 addtype(Type *t, Node *n)
225 switch(/*nodetype*/n->type){
231 n->n1 = cfold(n->n1);
232 n->n2 = cfold(n->n2);
233 if(n->n1->type != ONUM || n->n2->type != ONUM)
235 return addtype(n->typ, node(ONUM, evalop(n->op, n->typ->sign, n->n1->num, n->n2->num)));
237 n->n1 = cfold(n->n1);
238 if(n->n1->type == ONUM)
239 return addtype(n->typ, node(ONUM, !n->n1->num));
242 n->n1 = cfold(n->n1);
243 n->n2 = cfold(n->n2);
244 n->n3 = cfold(n->n3);
245 if(n->n1->type == ONUM)
246 return n->n1->num ? n->n2 : n->n3;
249 n->n1 = cfold(n->n1);
250 if(n->n1->type != ONUM || n->typ->type != TYPINT)
252 switch(n->typ->size << 4 | n->typ->sign){
253 case 0x10: return addtype(n->typ, node(ONUM, (vlong)(u8int)n->n1->num));
254 case 0x11: return addtype(n->typ, node(ONUM, (vlong)(s8int)n->n1->num));
255 case 0x20: return addtype(n->typ, node(ONUM, (vlong)(u16int)n->n1->num));
256 case 0x21: return addtype(n->typ, node(ONUM, (vlong)(s16int)n->n1->num));
257 case 0x40: return addtype(n->typ, node(ONUM, (vlong)(u32int)n->n1->num));
258 case 0x41: return addtype(n->typ, node(ONUM, (vlong)(s32int)n->n1->num));
259 case 0x80: return addtype(n->typ, node(ONUM, n->n1->num));
260 case 0x81: return addtype(n->typ, node(ONUM, n->n1->num));
265 fprint(2, "cfold: unknown type %α\n", n->type);
270 /* calculate the minimum record size for each node of the expression */
274 switch(/*nodetype*/n->type){
280 switch(n->sym->type){
288 n->recsize = n->typ->size;
292 default: sysfatal("calcrecsize: unknown symbol type %d", n->sym->type); return nil;
296 n->n1 = calcrecsize(n->n1);
297 n->n2 = calcrecsize(n->n2);
298 n->recsize = min(n->typ->size, n->n1->recsize + n->n2->recsize);
301 n->n1 = calcrecsize(n->n1);
302 n->recsize = min(n->typ->size, n->n1->recsize);
305 n->n1 = calcrecsize(n->n1);
306 if(n->typ->type == TYPSTRING)
307 n->recsize = n->typ->size;
309 n->recsize = min(n->typ->size, n->n1->recsize);
312 n->n1 = calcrecsize(n->n1);
313 n->n2 = calcrecsize(n->n2);
314 n->n3 = calcrecsize(n->n3);
315 n->recsize = min(n->typ->size, n->n1->recsize + n->n2->recsize + n->n3->recsize);
318 default: sysfatal("calcrecsize: unknown type %α", n->type); return nil;
323 /* insert ORECORD nodes to mark the subexpression that we will pass to the kernel */
329 if(n->typ->size == n->recsize)
330 return addtype(n->typ, node(ORECORD, n));
331 switch(/*nodetype*/n->type){
337 n->n1 = insrecord(n->n1);
338 n->n2 = insrecord(n->n2);
342 n->n1 = insrecord(n->n1);
345 n->n1 = insrecord(n->n1);
346 n->n2 = insrecord(n->n2);
347 n->n3 = insrecord(n->n3);
350 default: sysfatal("insrecord: unknown type %α", n->type); return nil;
356 delete useless casts.
357 going down we determine the number of bits (m) needed to be correct at each stage.
358 going back up we determine the number of bits (n->databits) which can be either 0 or 1.
359 all other bits are either zero (n->upper == UPZX) or sign-extended (n->upper == UPSX).
360 note that by number of bits we always mean a consecutive block starting from the LSB.
362 we can delete a cast if it either affects only bits not needed (according to m) or
363 if it's a no-op (according to databits, upper).
366 elidecasts(Node *n, int m)
368 switch(/*nodetype*/n->type){
372 n->databits = n->typ->size * 8;
373 n->upper = n->typ->sign ? UPSX : UPZX;
376 /* TODO: make less pessimistic */
380 switch(/*oper*/n->op){
383 n->n1 = elidecasts(n->n1, m);
384 n->n2 = elidecasts(n->n2, m);
385 n->databits = min(64, max(n->n1->databits, n->n2->databits) + 1);
386 n->upper = n->n1->upper | n->n2->upper;
389 n->n1 = elidecasts(n->n1, m);
390 n->n2 = elidecasts(n->n2, m);
391 n->databits = min(64, n->n1->databits + n->n2->databits);
392 n->upper = n->n1->upper | n->n2->upper;
398 n->n1 = elidecasts(n->n1, m);
399 n->n2 = elidecasts(n->n2, m);
400 if(n->op == OPAND && (n->n1->upper == UPZX || n->n2->upper == UPZX)){
402 if(n->n1->upper == UPZX && n->n2->upper == UPZX)
403 n->databits = min(n->n1->databits, n->n2->databits);
404 else if(n->n1->upper == UPZX)
405 n->databits = n->n1->databits;
407 n->databits = n->n2->databits;
409 n->databits = max(n->n1->databits, n->n2->databits);
410 n->upper = n->n1->upper | n->n2->upper;
414 n->n1 = elidecasts(n->n1, m);
415 n->n2 = elidecasts(n->n2, 64);
416 if(n->n2->type == ONUM && n->n2->num >= 0 && n->n1->databits + (uvlong)n->n2->num <= 64)
417 n->databits = n->n1->databits + n->n2->num;
420 n->upper = n->n1->upper;
423 n->n1 = elidecasts(n->n1, 64);
424 n->n2 = elidecasts(n->n2, 64);
425 if(n->n1->upper == n->typ->sign){
426 n->databits = n->n1->databits;
427 n->upper = n->n1->upper;
439 n->n1 = elidecasts(n->n1, 64);
440 n->n2 = elidecasts(n->n2, 64);
447 n->n1 = elidecasts(n->n1, 64);
448 n->n2 = elidecasts(n->n2, 64);
455 n->n1 = elidecasts(n->n1, 64);
460 switch(n->typ->type){
462 n->n1 = elidecasts(n->n1, min(n->typ->size * 8, m));
463 if(n->n1->databits < n->typ->size * 8 && n->n1->upper == n->typ->sign){
464 n->databits = n->n1->databits;
465 n->upper = n->n1->upper;
467 n->databits = n->typ->size * 8;
468 n->upper = n->typ->sign ? UPSX : UPZX;
470 if(n->typ->size * 8 >= m) return n->n1;
471 if(n->typ->size * 8 >= n->n1->databits && n->typ->sign == n->n1->upper) return n->n1;
472 if(n->typ->size * 8 > n->n1->databits && n->typ->sign && !n->n1->upper) return n->n1;
475 n->n1 = elidecasts(n->n1, 64);
478 sysfatal("elidecasts: don't know how to cast %τ to %τ", n->n1->typ, n->typ);
482 n->n1 = elidecasts(n->n1, min(n->typ->size * 8, m));
483 if(n->n1->databits < n->typ->size * 8 && n->n1->upper == n->typ->sign){
484 n->databits = n->n1->databits;
485 n->upper = n->n1->upper;
487 n->databits = n->typ->size * 8;
488 n->upper = n->typ->sign ? UPSX : UPZX;
492 n->n1 = elidecasts(n->n1, 64);
493 n->n2 = elidecasts(n->n2, m);
494 n->n3 = elidecasts(n->n3, m);
495 if(n->n2->upper == n->n3->upper){
496 n->databits = max(n->n2->databits, n->n3->databits);
497 n->upper = n->n2->upper;
499 if(n->n3->upper == UPSX)
500 n->databits = max(min(64, n->n2->databits + 1), n->n3->databits);
502 n->databits = max(min(64, n->n3->databits + 1), n->n2->databits);
506 default: sysfatal("elidecasts: unknown type %α", n->type);
508 // print("need %d got %d%c %ε\n", n->needbits, n->databits, "ZS"[n->upper], n);
514 exprcheck(Node *n, int pred)
516 if(dflag) print("start %ε\n", n);
519 if(dflag) print("typecheck %ε\n", n);
521 if(dflag) print("cfold %ε\n", n);
523 n = insrecord(calcrecsize(n));
524 if(dflag) print("insrecord %ε\n", n);
526 n = elidecasts(n, 64);
527 if(dflag) print("elidecasts %ε\n", n);