git.lizzy.rs Git - uwu-lang.git/blob - src/load.c
uwuint: use long instead of int to prevent YEAR2038 problem in nolambda time library
[uwu-lang.git] / src / load.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <stdarg.h>
5 #include <libgen.h>
6 #include <dlfcn.h>
7 #include "common/err.h"
8 #include "common/str.h"
9 #include "common/file.h"
10 #include "common/dl.h"
11 #include "load.h"
12 #include "parse.h"
13
14 #define DEBUG 0
15
16 // helper functions
17
// Returns the directory part of `name` as a newly allocated string that the
// caller must free. dirname() is allowed to modify the buffer it is given and
// to return a pointer into it, so it is run on a scratch copy and its result
// is duplicated before that scratch copy is released.
static char *dirname_wrapper(const char *name)
{
        char *scratch = strdup(name);
        char *dir = strdup(dirname(scratch));

        free(scratch);
        return dir;
}
27
28 // type definitions
29
30 typedef struct
31 {
32         char *name;
33         UwUVMFunction *ref;
34 } FunctionLink;
35
36 typedef struct
37 {
38         char *path;                                     // path without file extension
39         char *filename;                         // path with file extension
40         char *environment;                      // directory path
41
42         UwUVMModuleType type;                   // native (.so) or plain (.uwu)
43
44         FunctionLink *functions;        // required functions
45         size_t    num_functions;        // number of required functions
46         size_t loaded_functions;        // number of loaded functions (<= num_functions)
47
48         union
49         {
50                 AbstractSyntaxTree ast; // abstract syntax tree generated by parser (for plain modules)
51                 void *lib;                              // dlopen() shared object handle (for native modules)
52         } handle;
53 } Module;
54
55 typedef struct
56 {
57         Module   **modules;                     // loaded modules
58         size_t num_modules;                     // count for modules
59
60         char *std_path;                         // path to standard library
61
62         Program program;                        // the result program
63 } LoadState;
64
65 // functions
66
// Finds the file that backs a module. Candidates are tried in this order:
// the path as given, the path with ".uwu" appended, the path with ".so"
// appended. Returns the first candidate accepted by file_exists() as a
// malloc'd string owned by the caller, or NULL if none is accepted.
static inline char *get_filename(const char *module_path)
{
        static const char *const patterns[] = { "%s", "%s.uwu", "%s.so" };
        const size_t num_patterns = sizeof patterns / sizeof patterns[0];

        for (size_t i = 0; i < num_patterns; i++) {
                char *candidate = asprintf_wrapper(patterns[i], module_path);

                if (file_exists(candidate))
                        return candidate;

                free(candidate);
        }

        return NULL;
}
89
90 // module_path is a mallocated string
91 static Module *require_module(LoadState *state, char *module_path)
92 {
93         for (size_t i = 0; i < state->num_modules; i++) {
94                 Module *module = state->modules[i];
95
96                 if (strcmp(module_path, module->path) == 0) {
97                         free(module_path);
98                         return module;
99                 }
100         }
101
102         char *filename = get_filename(module_path);
103
104         if (! filename)
105                 error("module error: module %s not found\n", module_path);
106
107         size_t filename_len = strlen(filename);
108         UwUVMModuleType type = (filename_len >= 3 && strcmp(filename + filename_len - 3, ".so") == 0) ? MODULE_NATIVE : MODULE_PLAIN;
109
110         state->modules = realloc(state->modules, sizeof *state->modules * ++state->num_modules);
111         Module *module = state->modules[state->num_modules - 1] = malloc(sizeof *module);
112
113         *module = (Module) {
114                 .path = module_path,
115                 .filename = filename,
116                 .environment = dirname_wrapper(module_path),
117
118                 .type = type,
119
120                 .functions = NULL,
121                 .num_functions = 0,
122                 .loaded_functions = 0,
123         };
124
125         if (type == MODULE_PLAIN) {
126                 module->handle.ast = parse_file(filename);
127         } else {
128                 state->program.libraries = realloc(state->program.libraries, sizeof(void *) * ++state->program.num_libraries);
129                 state->program.libraries[state->program.num_libraries - 1] = module->handle.lib = dlopen(filename, RTLD_LAZY);
130
131                 check_dlerror();
132         }
133
134         return module;
135 }
136
137 static UwUVMFunction *require_function(LoadState *state, Module *module, const char *name)
138 {
139         for (size_t i = 0; i < module->num_functions; i++) {
140                 FunctionLink *link = &module->functions[i];
141
142                 if (strcmp(link->name, name) == 0)
143                         return link->ref;
144         }
145
146         UwUVMFunction *ref = malloc(sizeof *ref);
147         ref->type = module->type;
148
149         state->program.functions = realloc(state->program.functions, sizeof *state->program.functions * ++state->program.num_functions);
150         state->program.functions[state->program.num_functions - 1] = ref;
151
152         module->functions = realloc(module->functions, sizeof *module->functions * ++module->num_functions);
153         module->functions[module->num_functions - 1] = (FunctionLink) {
154                 .name = strdup(name),
155                 .ref = ref,
156         };
157
158         return ref;
159 }
160
161 static UwUVMFunction *resolve_function(LoadState *state, Module *caller_module, const char *full_name)
162 {
163         size_t len = strlen(full_name);
164
165         const char *fnname;
166         for (fnname = &full_name[len - 1]; *fnname != ':' && fnname > full_name; fnname--)
167                 ;
168
169         if (*fnname == ':')
170                 fnname++;
171
172         if (*fnname == '\0')
173                 error("module error: empty function name referenced/called by module %s\n", caller_module->filename);
174
175         Module *callee_module;
176
177         if (fnname == full_name) {
178                 callee_module = caller_module;
179         } else {
180                 const char *caller_path = caller_module->environment;
181                 const char *callee_name = full_name;
182
183                 if (*callee_name == ':') {
184                         caller_path = state->std_path;
185                         callee_name++;
186                 }
187
188                 size_t path_len = fnname - callee_name;
189                 char callee_path[path_len];
190
191                 for (size_t i = 0; i < path_len; i++)
192                         callee_path[i] = (i == path_len - 1) ? '\0'
193                                 : (callee_name[i] == ':') ? '/'
194                                 : callee_name[i];
195
196                 callee_module = require_module(state, asprintf_wrapper("%s/%s", caller_path, callee_path));
197         }
198
199         return require_function(state, callee_module, fnname);
200 }
201
202 static void translate_expression(LoadState *state, Module *module, UwUVMExpression *vm_expr, ParseExpression *parse_expr)
203 {
204         UwUVMFunction *vm_function;
205
206         if (parse_expr->type == EX_FNNAME || parse_expr->type == EX_FNCALL) {
207                 vm_function = resolve_function(state, module, parse_expr->value.str_value);
208                 free(parse_expr->value.str_value);
209         }
210
211         switch (vm_expr->type = parse_expr->type) {
212                 case EX_INTLIT:
213                 case EX_ARGNUM:
214                         vm_expr->value.int_value = parse_expr->value.int_value;
215                         break;
216
217                 case EX_STRLIT:
218                         vm_expr->value.str_value = parse_expr->value.str_value;
219                         break;
220
221                 case EX_FNNAME:
222                         vm_expr->value.ref_value = vm_function;
223                         break;
224
225                 case EX_FNCALL:
226                         vm_expr->value.cll_value.function = vm_function;
227                         vm_expr->value.cll_value.args = malloc(sizeof(UwUVMExpression) * parse_expr->num_children);
228                         vm_expr->value.cll_value.num_args = parse_expr->num_children;
229
230                         for (size_t i = 0; i < parse_expr->num_children; i++)
231                                 translate_expression(state, module, &vm_expr->value.cll_value.args[i], parse_expr->children[i]);
232
233                         if (parse_expr->children)
234                                 free(parse_expr->children);
235                         break;
236
237                 default:
238                         break;
239         }
240
241         free(parse_expr);
242 }
243
244 static void load_functions(LoadState *state, Module *module)
245 {
246         for (; module->loaded_functions < module->num_functions; module->loaded_functions++) {
247                 FunctionLink *link = &module->functions[module->loaded_functions];
248
249                 if (module->type == MODULE_PLAIN) {
250                         ParseFunction **function = NULL;
251
252                         for (size_t i = 0; i < module->handle.ast.num_functions; i++) {
253                                 ParseFunction **fn = &module->handle.ast.functions[i];
254
255                                 if (*fn && strcmp((*fn)->name, link->name) == 0) {
256                                         function = fn;
257                                         break;
258                                 }
259                         }
260
261                         if (function) {
262                                 translate_expression(state, module, link->ref->value.plain = malloc(sizeof(UwUVMExpression)), (*function)->expression);
263                                 free((*function)->name);
264                                 free(*function);
265
266                                 *function = NULL;
267                         } else {
268                                 error("module error: no function %s in module %s\n", link->name, module->filename);
269                         }
270                 } else {
271                         char *symbol = asprintf_wrapper("uwu_%s", link->name);
272                         link->ref->value.native = dlsym(module->handle.lib, symbol);
273
274                         check_dlerror();
275                         free(symbol);
276                 }
277         }
278 }
279
280 static void free_expression(ParseExpression *expr)
281 {
282         if (expr->type == EX_FNCALL) {
283                 for (size_t i = 0; i < expr->num_children; i++)
284                         free_expression(expr->children[i]);
285
286                 if (expr->children)
287                         free(expr->children);
288         }
289
290         if (expr->type != EX_INTLIT && expr->type != EX_ARGNUM)
291                 free(expr->value.str_value);
292
293         free(expr);
294 }
295
296 Program load_program(const char *progname, const char *modname)
297 {
298         char *prog_dirname = dirname_wrapper(progname);
299         char *api_path = asprintf_wrapper("%s/api/api.so", prog_dirname);
300
301         LoadState state = {
302                 .modules = NULL,
303                 .num_modules = 0,
304                 .std_path = asprintf_wrapper("%s/std", prog_dirname),
305                 .program = {
306                         .api_library = dlopen(api_path, RTLD_NOW | RTLD_GLOBAL),
307                         .main_function = NULL,
308                         .functions = NULL,
309                         .num_functions = 0,
310                         .libraries = NULL,
311                         .num_libraries = 0,
312                 },
313         };
314
315         free(prog_dirname);
316         free(api_path);
317
318         state.program.main_function = require_function(&state, require_module(&state, strdup(modname)), "main");
319
320         while (true) {
321                 bool fully_loaded = true;
322
323                 for (size_t i = 0; i < state.num_modules; i++) {
324                         Module *module = state.modules[i];
325
326 #if DEBUG
327                         printf("%s %lu/%lu\n", module->filename, module->loaded_functions, module->num_functions);
328 #endif
329
330                         if (module->loaded_functions < module->num_functions) {
331                                 fully_loaded = false;
332                                 load_functions(&state, module);
333                         }
334                 }
335
336                 if (fully_loaded)
337                         break;
338         }
339
340         free(state.std_path);
341
342         for (size_t i = 0; i < state.num_modules; i++) {
343                 Module *module = state.modules[i];
344
345                 free(module->path);
346                 free(module->filename);
347                 free(module->environment);
348
349                 for (size_t f = 0; f < module->num_functions; f++)
350                         free(module->functions[f].name);
351
352                 free(module->functions);
353
354                 if (module->type == MODULE_PLAIN) {
355                         for (size_t f = 0; f < module->handle.ast.num_functions; f++) {
356                                 ParseFunction *function = module->handle.ast.functions[f];
357
358                                 if (function) {
359                                         free_expression(function->expression);
360                                         free(function->name);
361                                         free(function);
362                                 }
363                         }
364
365                         if (module->handle.ast.functions)
366                                 free(module->handle.ast.functions);
367                 }
368
369                 free(module);
370         }
371
372         free(state.modules);
373
374         return state.program;
375 }