]> git.lizzy.rs Git - uwu-lang.git/blob - src/collect.c
fadb9d10129026496782d805d8fd212b2d96f3e9
[uwu-lang.git] / src / collect.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <stdarg.h>
5 #include <libgen.h>
6 #include <dlfcn.h>
7 #include "err.h"
8 #include "util.h"
9 #include "collect.h"
10 #include "parse.h"
11
12 #define DEBUG 0
13
14 // helper functions
15
// Run a path function on a private, writable copy of name.
//
// fn is handed a duplicate of name, so it is free to modify its argument.
// Whatever string fn returns is duplicated before the scratch copy is
// released, which keeps the result valid even if fn returned a pointer into
// that scratch copy.
//
// Returns a newly allocated string; the caller is responsible for free()ing it.
static char *wrap_name_func(const char *name, char *(*fn)(char *))
{
	char *scratch = strdup(name);
	char *owned = strdup(fn(scratch));

	free(scratch);
	return owned;
}
25
// Non-destructive basename(): name itself is never modified.
// Returns a newly allocated string that the caller must free().
static char *basename_wrapper(const char *name)
{
	char *last_component = wrap_name_func(name, basename);

	return last_component;
}
30
// Non-destructive dirname(): name itself is never modified.
// Returns a newly allocated string that the caller must free().
static char *dirname_wrapper(const char *name)
{
	char *parent_dir = wrap_name_func(name, dirname);

	return parent_dir;
}
35
// Report whether filename can be opened for reading.
// The file is only probed: it is closed again right away.
static bool file_exists(const char *filename)
{
	FILE *probe = fopen(filename, "r");

	if (probe == NULL)
		return false;

	fclose(probe);
	return true;
}
47
48 // type definitions
49
50 typedef struct
51 {
52         char *name;
53         UwUVMFunction *ref;
54 } FunctionLink;
55
56 typedef struct
57 {
58         char *path;                                     // path without file extension
59         char *filename;                         // path with file extension
60         char *environment;                      // directory path
61
62         UwUVMModuleType type;                   // native (.so) or plain (.uwu)
63
64         FunctionLink *functions;        // required functions
65         size_t    num_functions;        // number of required functions
66         size_t loaded_functions;        // number of loaded functions (<= num_functions)
67
68         union
69         {
70                 AbstractSyntaxTree ast; // abstract syntax tree generated by parser (for plain modules)
71                 void *lib;                              // dlopen() shared object handle (for native modules)
72         } handle;
73 } Module;
74
75 typedef struct 
76 {
77         Module   **modules;                     // loaded modules
78         size_t num_modules;                     // count for modules
79         
80         char *std_path;                         // path to standard library
81
82         UwUVMProgram program;                   // the result program
83 } CollectorState;
84
85 // functions
86
// Find the file that backs module_path.
//
// Candidates are probed in this order: the path exactly as given, the path
// with ".uwu" appended, the path with ".so" appended. The first candidate
// that file_exists() accepts is returned.
//
// Returns a mallocated string (caller frees it), or NULL if no candidate exists.
static inline char *get_filename(const char *module_path)
{
	const char *const patterns[] = { "%s", "%s.uwu", "%s.so" };
	const size_t num_patterns = sizeof patterns / sizeof patterns[0];

	for (size_t i = 0; i < num_patterns; i++) {
		char *candidate = asprintf_wrapper(patterns[i], module_path);

		if (file_exists(candidate))
			return candidate;

		free(candidate);
	}

	return NULL;
}
109
110 // module_path is a mallocated string
111 static Module *require_module(CollectorState *state, char *module_path)
112 {
113         for (size_t i = 0; i < state->num_modules; i++) {
114                 Module *module = state->modules[i];
115
116                 if (strcmp(module_path, module->path) == 0) {
117                         free(module_path);
118                         return module;
119                 }
120         }
121
122         char *filename = get_filename(module_path);
123
124         if (! filename)
125                 error("error: module %s not found\n", module_path);
126
127         size_t filename_len = strlen(filename);
128         UwUVMModuleType type = (filename_len >= 3 && strcmp(filename + filename_len - 3, ".so") == 0) ? MODULE_NATIVE : MODULE_PLAIN;
129
130         state->modules = realloc(state->modules, sizeof *state->modules * ++state->num_modules);
131         Module *module = state->modules[state->num_modules - 1] = malloc(sizeof *module);
132
133         *module = (Module) {
134                 .path = module_path,
135                 .filename = filename,
136                 .environment = dirname_wrapper(module_path),
137
138                 .type = type,
139
140                 .functions = NULL,
141                 .num_functions = 0,
142                 .loaded_functions = 0,
143         };
144
145         if (type == MODULE_PLAIN) {
146                 module->handle.ast = parse_file(filename);
147         } else {
148                 state->program.libraries = realloc(state->program.libraries, sizeof(void *) * ++state->program.num_libraries);
149                 state->program.libraries[state->program.num_libraries - 1] = module->handle.lib = dlopen(filename, RTLD_LAZY);
150
151                 char *err = dlerror();
152                 if (err)
153                         error("%s\n", err);
154         }
155
156         return module;
157 }
158
159 static UwUVMFunction *require_function(CollectorState *state, Module *module, const char *name)
160 {
161         for (size_t i = 0; i < module->num_functions; i++) {
162                 FunctionLink *link = &module->functions[i];
163
164                 if (strcmp(link->name, name) == 0)
165                         return link->ref;
166         }
167
168         UwUVMFunction *ref = malloc(sizeof *ref);
169         ref->type = module->type;
170
171         state->program.functions = realloc(state->program.functions, sizeof *state->program.functions * ++state->program.num_functions);
172         state->program.functions[state->program.num_functions - 1] = ref;       
173
174         module->functions = realloc(module->functions, sizeof *module->functions * ++module->num_functions);
175         module->functions[module->num_functions - 1] = (FunctionLink) {
176                 .name = strdup(name),
177                 .ref = ref,
178         };
179
180         return ref;
181 }
182
183 static UwUVMFunction *resolve_function(CollectorState *state, Module *caller_module, const char *full_name)
184 {
185         size_t len = strlen(full_name);
186
187         const char *fnname;
188         for (fnname = &full_name[len - 1]; *fnname != ':' && fnname > full_name; fnname--)
189                 ;
190
191         if (*fnname == ':')
192                 fnname++;
193
194         if (*fnname == '\0')
195                 error("error: empty function name\n");
196
197         Module *callee_module;
198
199         if (fnname == full_name) {
200                 callee_module = caller_module;
201         } else {
202                 const char *caller_path = caller_module->environment;
203                 const char *callee_name = full_name;
204
205                 if (*callee_name == ':') {
206                         caller_path = state->std_path;
207                         callee_name++;
208                 }
209
210                 size_t path_len = fnname - callee_name; 
211                 char callee_path[path_len];
212
213                 for (size_t i = 0; i < path_len; i++)
214                         callee_path[i] = (i == path_len - 1) ? '\0'
215                                 : (callee_name[i] == ':') ? '/'
216                                 : callee_name[i];
217
218                 callee_module = require_module(state, asprintf_wrapper("%s/%s", caller_path, callee_path));
219         }
220
221         return require_function(state, callee_module, fnname);
222 }
223
224 static void translate_expression(CollectorState *state, Module *module, UwUVMExpression *vm_expr, ParseExpression *parse_expr)
225 {
226         UwUVMFunction *vm_function;
227
228         if (parse_expr->type == EX_FNNAME || parse_expr->type == EX_FNCALL) {
229                 vm_function = resolve_function(state, module, parse_expr->value.str_value);
230                 free(parse_expr->value.str_value);
231         }
232         
233         switch (vm_expr->type = parse_expr->type) {
234                 case EX_INTLIT:
235                 case EX_ARGNUM:
236                         vm_expr->value.int_value = parse_expr->value.int_value;
237                         break;
238
239                 case EX_STRLIT:
240                         vm_expr->value.str_value = parse_expr->value.str_value;
241                         break;
242
243                 case EX_FNNAME:
244                         vm_expr->value.ref_value = vm_function;
245                         break;
246
247                 case EX_FNCALL:
248                         vm_expr->value.cll_value.function = vm_function;
249                         vm_expr->value.cll_value.args = malloc(sizeof(UwUVMExpression) * parse_expr->num_children);
250                         vm_expr->value.cll_value.num_args = parse_expr->num_children;
251
252                         for (size_t i = 0; i < parse_expr->num_children; i++)
253                                 translate_expression(state, module, &vm_expr->value.cll_value.args[i], parse_expr->children[i]);
254
255                         if (parse_expr->children)
256                                 free(parse_expr->children);
257                         break;
258
259                 default:
260                         break;
261         }               
262
263         free(parse_expr);
264 }
265
266 static void load_functions(CollectorState *state, Module *module)
267 {
268         for (; module->loaded_functions < module->num_functions; module->loaded_functions++) {
269                 FunctionLink *linkptr = &module->functions[module->loaded_functions];
270                 FunctionLink link = *linkptr;
271
272                 bool found = false;
273
274                 if (module->type == MODULE_PLAIN) {
275                         ParseFunction **function = NULL;
276
277                         for (size_t i = 0; i < module->handle.ast.num_functions; i++) {
278                                 ParseFunction **fn = &module->handle.ast.functions[i];
279
280                                 if (*fn && strcmp((*fn)->name, link.name) == 0) {
281                                         function = fn;
282                                         break;
283                                 }
284                         }
285
286                         if (function) {
287                                 found = true;
288                                 linkptr = NULL;
289
290                                 translate_expression(state, module, link.ref->value.plain = malloc(sizeof(UwUVMExpression)), (*function)->expression);
291                                 free((*function)->name);
292                                 free(*function);
293
294                                 *function = NULL;
295                         }
296                 } else {
297                         char *symbol = asprintf_wrapper("uwu_%s", link.name);
298                         linkptr->ref->value.native = dlsym(module->handle.lib, symbol);
299
300                         if (! dlerror())
301                                 found = true;
302         
303                         free(symbol);
304                 }
305
306                 if (! found)
307                         error("error: no function %s in module %s\n", link.name, module->filename);
308         }
309 }
310
311 static void free_expression(ParseExpression *expr)
312 {
313         if (expr->type == EX_FNCALL) {
314                 for (size_t i = 0; i < expr->num_children; i++)
315                         free_expression(expr->children[i]);
316
317                 if (expr->children)
318                         free(expr->children);
319         }
320
321         if (expr->type != EX_INTLIT && expr->type != EX_ARGNUM)
322                 free(expr->value.str_value);
323
324         free(expr);     
325 }
326
327 UwUVMProgram create_program(const char *progname, const char *modname)
328 {
329         char *prog_dirname = dirname_wrapper(progname);
330         char *api_path = asprintf_wrapper("%s/api/api.so", prog_dirname);
331
332         CollectorState state = {
333                 .modules = NULL,
334                 .num_modules = 0,
335                 .std_path = asprintf_wrapper("%s/std", prog_dirname),
336                 .program = {
337                         .api_library = dlopen(api_path, RTLD_NOW | RTLD_GLOBAL),
338                         .main_function = NULL,
339                         .functions = NULL,
340                         .num_functions = 0,
341                         .libraries = NULL,
342                         .num_libraries = 0,
343                 },
344         };
345
346         free(prog_dirname);
347         free(api_path);
348         
349         state.program.main_function = require_function(&state, require_module(&state, strdup(modname)), "main");
350
351         while (true) {
352                 bool fully_loaded = true;
353
354                 for (size_t i = 0; i < state.num_modules; i++) {
355                         Module *module = state.modules[i];
356
357 #if DEBUG
358                         printf("%s %lu/%lu\n", module->filename, module->loaded_functions, module->num_functions);
359 #endif
360
361                         if (module->loaded_functions < module->num_functions) {
362                                 fully_loaded = false;
363                                 load_functions(&state, module);
364                         }
365                 }
366
367                 if (fully_loaded)
368                         break;
369         }
370
371         free(state.std_path);
372
373         for (size_t i = 0; i < state.num_modules; i++) {
374                 Module *module = state.modules[i];
375
376                 free(module->path);
377                 free(module->filename);
378                 free(module->environment);
379
380                 for (size_t f = 0; f < module->num_functions; f++)
381                         free(module->functions[f].name);
382
383                 free(module->functions);
384
385                 if (module->type == MODULE_PLAIN) {
386                         for (size_t f = 0; f < module->handle.ast.num_functions; f++) {
387                                 ParseFunction *function = module->handle.ast.functions[f];
388
389                                 if (function) {
390                                         free_expression(function->expression);
391                                         free(function->name);
392                                         free(function);
393                                 }
394                         }
395
396                         if (module->handle.ast.functions)
397                                 free(module->handle.ast.functions);
398                 }
399                 
400                 free(module);
401         }
402
403         free(state.modules);
404
405         return state.program;
406 }