]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/dtracy/type.c
aux/realemu: run cpuproc in same fd group as fileserver
[plan9front.git] / sys / src / cmd / dtracy / type.c
1 #include <u.h>
2 #include <libc.h>
3 #include <ctype.h>
4 #include <dtracy.h>
5 #include <bio.h>
6 #include "dat.h"
7 #include "fns.h"
8
9 Node *
10 icast(int sign, int size, Node *n)
11 {
12         Type *t;
13         
14         t = type(TYPINT, sign, size);
15         return node(OCAST, t, n);
16 }
17
18 /*
19         the type checker checks types.
20         the result is an expression that is correct if evaluated with 64-bit operands all the way.
21         to maintain c-like semantics, this means adding casts all over the place, which will get optimised later.
22         
23         note we use kencc, NOT ansi c, semantics for unsigned.
24 */
25
26 Node *
27 typecheck(Node *n)
28 {
29         int s1, s2, sign;
30
31         switch(/*nodetype*/n->type){
32         case OSYM:
33                 switch(n->sym->type){
34                 case SYMNONE: error("undeclared '%s'", n->sym->name); break;
35                 case SYMVAR: n->typ = n->sym->typ; break;
36                 default: sysfatal("typecheck: unknown symbol type %d", n->sym->type);
37                 }
38                 break;
39         case ONUM:
40                 if((vlong)n->num >= -0x80000000LL && (vlong)n->num <= 0x7fffffffLL)
41                         n->typ = type(TYPINT, 4, 1);
42                 else
43                         n->typ = type(TYPINT, 8, 1);
44                 break;
45         case OSTR:
46                 n->typ = type(TYPSTRING);
47                 break;
48         case OBIN:
49                 n->n1 = typecheck(n->n1);
50                 n->n2 = typecheck(n->n2);
51                 if(n->n1->typ == nil || n->n2->typ == nil)
52                         break;
53                 if(n->n1->typ->type != TYPINT){
54                         error("%τ not allowed in operation", n->n1->typ);
55                         break;
56                 }
57                 if(n->n2->typ->type != TYPINT){
58                         error("%τ not allowed in operation", n->n2->typ);
59                         break;
60                 }
61                 s1 = n->n1->typ->size;
62                 s2 = n->n2->typ->size;
63                 sign = n->n1->typ->sign && n->n2->typ->sign;
64                 switch(n->op){
65                 case OPADD:
66                 case OPSUB:
67                 case OPMUL:
68                 case OPDIV:
69                 case OPMOD:
70                 case OPAND:
71                 case OPOR:
72                 case OPXOR:
73                 case OPXNOR:
74                         n->typ = type(TYPINT, 8, sign);
75                         if(s1 > 4 || s2 > 4){
76                                 n->n1 = icast(8, sign, n->n1);
77                                 n->n2 = icast(8, sign, n->n2);
78                                 return n;
79                         }else{
80                                 n->n1 = icast(4, sign, n->n1);
81                                 n->n2 = icast(4, sign, n->n2);
82                                 return icast(4, sign, n);
83                         }
84                 case OPEQ:
85                 case OPNE:
86                 case OPLT:
87                 case OPLE:
88                         n->typ = type(TYPINT, 4, sign);
89                         if(s1 > 4 || s2 > 4){
90                                 n->n1 = icast(8, sign, n->n1);
91                                 n->n2 = icast(8, sign, n->n2);
92                                 return n;
93                         }else{
94                                 n->n1 = icast(4, sign, n->n1);
95                                 n->n2 = icast(4, sign, n->n2);
96                                 return n;
97                         }
98                 case OPLAND:
99                 case OPLOR:
100                         n->typ = type(TYPINT, 4, sign);
101                         return n;
102                 case OPLSH:
103                 case OPRSH:
104                         if(n->n1->typ->size <= 4)
105                                 n->n1 = icast(4, n->n1->typ->sign, n->n1);
106                         n->typ = n->n1->typ;
107                         return icast(n->typ->size, n->typ->sign, n);
108                 default:
109                         sysfatal("typecheck: unknown op %d", n->op);
110                 }
111                 break;
112         case OCAST:
113                 n->n1 = typecheck(n->n1);
114                 if(n->n1->typ == nil)
115                         break;
116                 if(n->typ->type == TYPINT && n->n1->typ->type == TYPINT){
117                 }else if(n->typ == n->n1->typ){
118                 }else if(n->typ->type == TYPSTRING && n->n1->typ->type == TYPINT){
119                 }else
120                         error("can't cast from %τ to %τ", n->n1->typ, n->typ);
121                 break;
122         case OLNOT:
123                 n->n1 = typecheck(n->n1);
124                 if(n->n1->typ == nil)
125                         break;
126                 if(n->n1->typ->type != TYPINT){
127                         error("%τ not allowed in operation", n->n1->typ);
128                         break;
129                 }
130                 n->typ = type(TYPINT, 4, 1);
131                 break;
132         case OTERN:
133                 n->n1 = typecheck(n->n1);
134                 n->n2 = typecheck(n->n2);
135                 n->n3 = typecheck(n->n3);
136                 if(n->n1->typ == nil || n->n2->typ == nil || n->n3->typ == nil)
137                         break;
138                 if(n->n1->typ->type != TYPINT){
139                         error("%τ not allowed in operation", n->n1->typ);
140                         break;
141                 }
142                 if(n->n2->typ->type == TYPINT || n->n3->typ->type == TYPINT){
143                         sign = n->n2->typ->sign && n->n3->typ->sign;
144                         s1 = n->n2->typ->size;
145                         s2 = n->n3->typ->size;
146                         if(s1 > 4 || s2 > 4){
147                                 n->n2 = icast(8, sign, n->n2);
148                                 n->n3 = icast(8, sign, n->n3);
149                                 n->typ = type(TYPINT, 8, sign);
150                                 return n;
151                         }else{
152                                 n->n2 = icast(4, sign, n->n2);
153                                 n->n3 = icast(4, sign, n->n3);
154                                 n->typ = type(TYPINT, 4, sign);
155                                 return n;
156                         }
157                 }else if(n->n2->typ == n->n3->typ){
158                         n->typ = n->n2->typ;
159                 }else
160                         error("don't know how to do ternary with %τ and %τ", n->n2->typ, n->n3->typ);
161                 break;
162         case ORECORD:
163         default: sysfatal("typecheck: unknown node type %α", n->type);
164         }
165         return n;
166 }
167
168 vlong
169 evalop(int op, int sign, vlong v1, vlong v2)
170 {
171         switch(/*oper*/op){
172         case OPADD: return v1 + v2; break;
173         case OPSUB: return v1 - v2; break;
174         case OPMUL: return v1 * v2; break;
175         case OPDIV: if(v2 == 0) sysfatal("division by zero"); return sign ? v1 / v2 : (uvlong)v1 / (uvlong)v2; break;
176         case OPMOD: if(v2 == 0) sysfatal("division by zero"); return sign ? v1 % v2 : (uvlong)v1 % (uvlong)v2; break;
177         case OPAND: return v1 & v2; break;
178         case OPOR: return v1 | v2; break;
179         case OPXOR: return v1 ^ v2; break;
180         case OPXNOR: return ~(v1 ^ v2); break;
181         case OPLSH:
182                 if((u64int)v2 >= 64)
183                         return 0;
184                 else
185                         return v1 << v2;
186                 break;
187         case OPRSH:
188                 if(sign){
189                         if((u64int)v2 >= 64)
190                                 return v1 >> 63;
191                         else
192                                 return v1 >> v2;
193                 }else{
194                         if((u64int)v2 >= 64)
195                                 return 0;
196                         else
197                                 return (u64int)v1 >> v2;
198                 }
199                 break;
200         case OPEQ: return v1 == v2; break;
201         case OPNE: return v1 != v2; break;
202         case OPLT: return v1 < v2; break;
203         case OPLE: return v1 <= v2; break;
204         case OPLAND: return v1 && v2; break;
205         case OPLOR: return v1 || v2; break;
206         default:
207                 sysfatal("cfold: unknown op %.2x", op);
208                 return 0;
209         }
210
211 }
212
213 Node *
214 addtype(Type *t, Node *n)
215 {
216         n->typ = t;
217         return n;
218 }
219
220 /* fold constants */
221
222 static Node *
223 cfold(Node *n)
224 {
225         switch(/*nodetype*/n->type){
226         case ONUM:
227         case OSYM:
228         case OSTR:
229                 return n;
230         case OBIN:
231                 n->n1 = cfold(n->n1);
232                 n->n2 = cfold(n->n2);
233                 if(n->n1->type != ONUM || n->n2->type != ONUM)
234                         return n;
235                 return addtype(n->typ, node(ONUM, evalop(n->op, n->typ->sign, n->n1->num, n->n2->num)));
236         case OLNOT:
237                 n->n1 = cfold(n->n1);
238                 if(n->n1->type == ONUM)
239                         return addtype(n->typ, node(ONUM, !n->n1->num));
240                 return n;
241         case OTERN:
242                 n->n1 = cfold(n->n1);
243                 n->n2 = cfold(n->n2);
244                 n->n3 = cfold(n->n3);
245                 if(n->n1->type == ONUM)
246                         return n->n1->num ? n->n2 : n->n3;
247                 return n;
248         case OCAST:
249                 n->n1 = cfold(n->n1);
250                 if(n->n1->type != ONUM || n->typ->type != TYPINT)
251                         return n;
252                 switch(n->typ->size << 4 | n->typ->sign){
253                 case 0x10: return addtype(n->typ, node(ONUM, (vlong)(u8int)n->n1->num));
254                 case 0x11: return addtype(n->typ, node(ONUM, (vlong)(s8int)n->n1->num));
255                 case 0x20: return addtype(n->typ, node(ONUM, (vlong)(u16int)n->n1->num));
256                 case 0x21: return addtype(n->typ, node(ONUM, (vlong)(s16int)n->n1->num));
257                 case 0x40: return addtype(n->typ, node(ONUM, (vlong)(u32int)n->n1->num));
258                 case 0x41: return addtype(n->typ, node(ONUM, (vlong)(s32int)n->n1->num));
259                 case 0x80: return addtype(n->typ, node(ONUM, n->n1->num));
260                 case 0x81: return addtype(n->typ, node(ONUM, n->n1->num));
261                 }
262                 return n;
263         case ORECORD:
264         default:
265                 fprint(2, "cfold: unknown type %α\n", n->type);
266                 return n;
267         }
268 }
269
270 /* calculate the minimum record size for each node of the expression */
271 static Node *
272 calcrecsize(Node *n)
273 {
274         switch(/*nodetype*/n->type){
275         case ONUM:
276         case OSTR:
277                 n->recsize = 0;
278                 break;
279         case OSYM:
280                 switch(n->sym->type){
281                 case SYMVAR:
282                         switch(n->sym->idx){
283                         case DTV_TIME:
284                         case DTV_PROBE:
285                                 n->recsize = 0;
286                                 break;
287                         default:
288                                 n->recsize = n->typ->size;
289                                 break;
290                         }
291                         break;
292                 default: sysfatal("calcrecsize: unknown symbol type %d", n->sym->type); return nil;
293                 }
294                 break;
295         case OBIN:
296                 n->n1 = calcrecsize(n->n1);
297                 n->n2 = calcrecsize(n->n2);
298                 n->recsize = min(n->typ->size, n->n1->recsize + n->n2->recsize);
299                 break;
300         case OLNOT:
301                 n->n1 = calcrecsize(n->n1);
302                 n->recsize = min(n->typ->size, n->n1->recsize);
303                 break;
304         case OCAST:
305                 n->n1 = calcrecsize(n->n1);
306                 if(n->typ->type == TYPSTRING)
307                         n->recsize = n->typ->size;
308                 else
309                         n->recsize = min(n->typ->size, n->n1->recsize);
310                 break;
311         case OTERN:
312                 n->n1 = calcrecsize(n->n1);
313                 n->n2 = calcrecsize(n->n2);
314                 n->n3 = calcrecsize(n->n3);
315                 n->recsize = min(n->typ->size, n->n1->recsize + n->n2->recsize + n->n3->recsize);
316                 break;
317         case ORECORD:
318         default: sysfatal("calcrecsize: unknown type %α", n->type); return nil;
319         }
320         return n;
321 }
322
323 /* insert ORECORD nodes to mark the subexpression that we will pass to the kernel */
324 static Node *
325 insrecord(Node *n)
326 {
327         if(n->recsize == 0)
328                 return n;
329         if(n->typ->size == n->recsize)
330                 return addtype(n->typ, node(ORECORD, n));
331         switch(/*nodetype*/n->type){
332         case ONUM:
333         case OSTR:
334         case OSYM:
335                 break;
336         case OBIN:
337                 n->n1 = insrecord(n->n1);
338                 n->n2 = insrecord(n->n2);
339                 break;
340         case OLNOT:
341         case OCAST:
342                 n->n1 = insrecord(n->n1);
343                 break;
344         case OTERN:
345                 n->n1 = insrecord(n->n1);
346                 n->n2 = insrecord(n->n2);
347                 n->n3 = insrecord(n->n3);
348                 break;
349         case ORECORD:
350         default: sysfatal("insrecord: unknown type %α", n->type); return nil;
351         }
352         return n;
353 }
354
355 /*
356         delete useless casts.
357         going down we determine the number of bits (m) needed to be correct at each stage.
358         going back up we determine the number of bits (n->databits) which can be either 0 or 1.
359         all other bits are either zero (n->upper == UPZX) or sign-extended (n->upper == UPSX).
360         note that by number of bits we always mean a consecutive block starting from the LSB.
361         
362         we can delete a cast if it either affects only bits not needed (according to m) or
363         if it's a no-op (according to databits, upper).
364 */
365 static Node *
366 elidecasts(Node *n, int m)
367 {
368         switch(/*nodetype*/n->type){
369         case OSTR:
370                 return n;
371         case ONUM:
372                 n->databits = n->typ->size * 8;
373                 n->upper = n->typ->sign ? UPSX : UPZX;
374                 break;
375         case OSYM:
376                 /* TODO: make less pessimistic */
377                 n->databits = 64;
378                 break;
379         case OBIN:
380                 switch(/*oper*/n->op){
381                 case OPADD:
382                 case OPSUB:
383                         n->n1 = elidecasts(n->n1, m);
384                         n->n2 = elidecasts(n->n2, m);
385                         n->databits = min(64, max(n->n1->databits, n->n2->databits) + 1);
386                         n->upper = n->n1->upper | n->n2->upper;
387                         break;
388                 case OPMUL:
389                         n->n1 = elidecasts(n->n1, m);
390                         n->n2 = elidecasts(n->n2, m);
391                         n->databits = min(64, n->n1->databits + n->n2->databits);
392                         n->upper = n->n1->upper | n->n2->upper;
393                         break;
394                 case OPAND:
395                 case OPOR:
396                 case OPXOR:
397                 case OPXNOR:
398                         n->n1 = elidecasts(n->n1, m);
399                         n->n2 = elidecasts(n->n2, m);
400                         if(n->op == OPAND && (n->n1->upper == UPZX || n->n2->upper == UPZX)){
401                                 n->upper = UPZX;
402                                 if(n->n1->upper == UPZX && n->n2->upper == UPZX)
403                                         n->databits = min(n->n1->databits, n->n2->databits);
404                                 else if(n->n1->upper == UPZX)
405                                         n->databits = n->n1->databits;
406                                 else
407                                         n->databits = n->n2->databits;
408                         }else{
409                                 n->databits = max(n->n1->databits, n->n2->databits);
410                                 n->upper = n->n1->upper | n->n2->upper;
411                         }
412                         break;
413                 case OPLSH:
414                         n->n1 = elidecasts(n->n1, m);
415                         n->n2 = elidecasts(n->n2, 64);
416                         if(n->n2->type == ONUM && n->n2->num >= 0 && n->n1->databits + (uvlong)n->n2->num <= 64)
417                                 n->databits = n->n1->databits + n->n2->num;
418                         else
419                                 n->databits = 64;
420                         n->upper = n->n1->upper;
421                         break;
422                 case OPRSH:
423                         n->n1 = elidecasts(n->n1, 64);
424                         n->n2 = elidecasts(n->n2, 64);
425                         if(n->n1->upper == n->typ->sign){
426                                 n->databits = n->n1->databits;
427                                 n->upper = n->n1->upper;
428                         }else{
429                                 n->databits = 64;
430                                 n->upper = UPZX;
431                         }
432                         break;
433                 case OPEQ:
434                 case OPNE:
435                 case OPLT:
436                 case OPLE:
437                 case OPLAND:
438                 case OPLOR:
439                         n->n1 = elidecasts(n->n1, 64);
440                         n->n2 = elidecasts(n->n2, 64);
441                         n->databits = 1;
442                         n->upper = UPZX;
443                         break;
444                 case OPDIV:
445                 case OPMOD:
446                 default:
447                         n->n1 = elidecasts(n->n1, 64);
448                         n->n2 = elidecasts(n->n2, 64);
449                         n->databits = 64;
450                         n->upper = UPZX;
451                         break;
452                 }
453                 break;
454         case OLNOT:
455                 n->n1 = elidecasts(n->n1, 64);
456                 n->databits = 1;
457                 n->upper = UPZX;
458                 break;
459         case OCAST:
460                 switch(n->typ->type){
461                 case TYPINT:
462                         n->n1 = elidecasts(n->n1, min(n->typ->size * 8, m));
463                         if(n->n1->databits < n->typ->size * 8 && n->n1->upper == n->typ->sign){
464                                 n->databits = n->n1->databits;
465                                 n->upper = n->n1->upper;
466                         }else{
467                                 n->databits = n->typ->size * 8;
468                                 n->upper = n->typ->sign ? UPSX : UPZX;
469                         }
470                         if(n->typ->size * 8 >= m) return n->n1;
471                         if(n->typ->size * 8 >= n->n1->databits && n->typ->sign == n->n1->upper) return n->n1;
472                         if(n->typ->size * 8 > n->n1->databits && n->typ->sign && !n->n1->upper) return n->n1;
473                         break;
474                 case TYPSTRING:
475                         n->n1 = elidecasts(n->n1, 64);
476                         break;
477                 default:
478                         sysfatal("elidecasts: don't know how to cast %τ to %τ", n->n1->typ, n->typ);
479                 }
480                 break;
481         case ORECORD:
482                 n->n1 = elidecasts(n->n1, min(n->typ->size * 8, m));
483                 if(n->n1->databits < n->typ->size * 8 && n->n1->upper == n->typ->sign){
484                         n->databits = n->n1->databits;
485                         n->upper = n->n1->upper;
486                 }else{
487                         n->databits = n->typ->size * 8;
488                         n->upper = n->typ->sign ? UPSX : UPZX;
489                 }
490                 break;
491         case OTERN:
492                 n->n1 = elidecasts(n->n1, 64);
493                 n->n2 = elidecasts(n->n2, m);
494                 n->n3 = elidecasts(n->n3, m);
495                 if(n->n2->upper == n->n3->upper){
496                         n->databits = max(n->n2->databits, n->n3->databits);
497                         n->upper = n->n2->upper;
498                 }else{
499                         if(n->n3->upper == UPSX)
500                                 n->databits = max(min(64, n->n2->databits + 1), n->n3->databits);
501                         else
502                                 n->databits = max(min(64, n->n3->databits + 1), n->n2->databits);
503                         n->upper = UPSX;
504                 }
505                 break;
506         default: sysfatal("elidecasts: unknown type %α", n->type);
507         }
508 //      print("need %d got %d%c %ε\n", n->needbits, n->databits, "ZS"[n->upper], n);
509         return n;
510 }
511
512
513 Node *
514 exprcheck(Node *n, int pred)
515 {
516         if(dflag) print("start       %ε\n", n);
517         n = typecheck(n);
518         if(errors) return n;
519         if(dflag) print("typecheck   %ε\n", n);
520         n = cfold(n);
521         if(dflag) print("cfold       %ε\n", n);
522         if(!pred){
523                 n = insrecord(calcrecsize(n));
524                 if(dflag) print("insrecord   %ε\n", n);
525         }
526         n = elidecasts(n, 64);
527         if(dflag) print("elidecasts  %ε\n", n);
528         return n;
529 }