]> git.lizzy.rs Git - uwu-lang.git/blob - src/load.c
Redesign function names
[uwu-lang.git] / src / load.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <stdarg.h>
5 #include <libgen.h>
6 #include <dlfcn.h>
7 #include "common/err.h"
8 #include "common/str.h"
9 #include "common/file.h"
10 #include "common/dl.h"
11 #include "load.h"
12 #include "parse.h"
13
14 #define DEBUG 0
15
16 // helper functions
17
// Returns the directory part of name as a newly allocated string that the
// caller must free().
//
// dirname() is allowed to modify its argument and to return a pointer into it
// (or into static storage), so it is handed a private copy of name and its
// result is duplicated before that copy is released.
//
// Terminates the process if memory cannot be allocated.
static char *dirname_wrapper(const char *name)
{
        char *copy = strdup(name);

        // dirname(NULL) returns ".", which would silently turn a failed
        // allocation into a wrong directory
        if (! copy) {
                perror("strdup");
                exit(EXIT_FAILURE);
        }

        char *result_copy = strdup(dirname(copy));

        free(copy);

        if (! result_copy) {
                perror("strdup");
                exit(EXIT_FAILURE);
        }

        return result_copy;
}
27
28 // type definitions
29
// Ties the name of a function that has been required from a module to the
// VM function object created for it (see require_function).
typedef struct
{
        char *name;             // function name; a strdup()ed copy that is freed at the end of load_program
        UwUVMFunction *ref;     // function object, also stored in the program's function list
} FunctionLink;
35
// A module that has been located on disk, together with the functions that
// have been required from it so far.
typedef struct
{
        char *path;                 // path without file extension (the string passed to load_module; used to detect already loaded modules)
        char *filename;             // path with file extension (the file found by get_filename)
        char *environment;          // directory path (dirname of path; search path for names with a leading dot)

        UwUVMModuleType type;       // native (.so) or plain (.uwu)

        FunctionLink *functions;    // required functions
        size_t    num_functions;    // number of required functions
        size_t loaded_functions;    // number of loaded functions (<= num_functions)

        // which member is valid depends on type
        union
        {
                AbstractSyntaxTree ast; // abstract syntax tree generated by parser (for plain modules)
                void *lib;              // dlopen() shared object handle (for native modules)
        } handle;
} Module;
54
// Working state of one load_program call: the modules found so far, where to
// search for further modules, and the program being assembled.
typedef struct
{
        Module   **modules;         // loaded modules (each one allocated separately)
        size_t num_modules;         // count for modules

        char     **module_paths;    // module search paths
        size_t num_module_paths;    // count for module_paths
        char  *module_paths_str;    // module search paths, stringified (printed in "no module" errors)

        char *std_path;             // path to standard library (NOTE(review): never assigned or read in the code visible here)

        Program program;            // the result program
} LoadState;
68
69 // functions
70
// Finds the file that backs a module: module_path itself is tried first, then
// module_path with ".uwu" appended, then with ".so" appended.
//
// Returns the first candidate for which file_exists() succeeds, as an
// allocated string the caller has to free(), or NULL if none exists.
static inline char *get_filename(const char *module_path)
{
        static const char *const suffixes[] = { "", ".uwu", ".so" };
        const size_t num_suffixes = sizeof suffixes / sizeof *suffixes;

        for (size_t s = 0; s < num_suffixes; s++) {
                char *candidate = asprintf_wrapper("%s%s", module_path, suffixes[s]);

                if (file_exists(candidate))
                        return candidate;

                free(candidate);
        }

        return NULL;
}
93
94 // module_path is a mallocated string
95 static Module *load_module(LoadState *state, char *module_path)
96 {
97         for (size_t i = 0; i < state->num_modules; i++) {
98                 Module *module = state->modules[i];
99
100                 if (strcmp(module_path, module->path) == 0) {
101                         free(module_path);
102                         return module;
103                 }
104         }
105
106         char *filename = get_filename(module_path);
107
108         if (! filename) {
109                 free(module_path);
110                 return NULL;
111         }
112
113         size_t filename_len = strlen(filename);
114         UwUVMModuleType type = (filename_len >= 3 && strcmp(filename + filename_len - 3, ".so") == 0) ? MODULE_NATIVE : MODULE_PLAIN;
115
116         state->modules = realloc(state->modules, sizeof *state->modules * ++state->num_modules);
117         Module *module = state->modules[state->num_modules - 1] = malloc(sizeof *module);
118
119         *module = (Module) {
120                 .path = module_path,
121                 .filename = filename,
122                 .environment = dirname_wrapper(module_path),
123
124                 .type = type,
125
126                 .functions = NULL,
127                 .num_functions = 0,
128                 .loaded_functions = 0,
129         };
130
131         if (type == MODULE_PLAIN) {
132                 module->handle.ast = parse_file(filename);
133         } else {
134                 state->program.libraries = realloc(state->program.libraries, sizeof(void *) * ++state->program.num_libraries);
135                 state->program.libraries[state->program.num_libraries - 1] = module->handle.lib = dlopen(filename, RTLD_LAZY);
136
137                 check_dlerror();
138         }
139
140         return module;
141 }
142
143 static UwUVMFunction *require_function(LoadState *state, Module *module, const char *name)
144 {
145         for (size_t i = 0; i < module->num_functions; i++) {
146                 FunctionLink *link = &module->functions[i];
147
148                 if (strcmp(link->name, name) == 0)
149                         return link->ref;
150         }
151
152         UwUVMFunction *ref = malloc(sizeof *ref);
153         ref->type = module->type;
154
155         state->program.functions = realloc(state->program.functions, sizeof *state->program.functions * ++state->program.num_functions);
156         state->program.functions[state->program.num_functions - 1] = ref;
157
158         module->functions = realloc(module->functions, sizeof *module->functions * ++module->num_functions);
159         module->functions[module->num_functions - 1] = (FunctionLink) {
160                 .name = strdup(name),
161                 .ref = ref,
162         };
163
164         return ref;
165 }
166
167 static UwUVMFunction *resolve_function(LoadState *state, Module *caller_module, const char *full_name)
168 {
169         size_t len = strlen(full_name);
170
171         const char *fnname;
172         for (fnname = &full_name[len - 1]; *fnname != '.' && fnname > full_name; fnname--)
173                 ;
174
175         if (*fnname == '.')
176                 fnname++;
177
178         if (*fnname == '\0')
179                 error("module error: empty function name referenced/called by module %s\n", caller_module->filename);
180
181         Module *callee_module;
182
183         if (fnname == full_name) {
184                 callee_module = caller_module;
185         } else {
186                 char     **environments     = state->module_paths;
187                 size_t num_environments     = state->num_module_paths;
188                 char      *environments_str = state->module_paths_str;
189
190                 const char *callee_name = full_name;
191
192                 if (*callee_name == '.') {
193                         callee_name++;
194
195                         environments = &caller_module->environment;
196                         num_environments = 1;
197                         environments_str = caller_module->environment;
198                 }
199
200                 size_t path_len = fnname - callee_name;
201                 char callee_path[path_len];
202
203                 for (size_t i = 0; i < path_len; i++)
204                         callee_path[i] = (i == path_len - 1) ? '\0'
205                                 : (callee_name[i] == '.') ? '/'
206                                 : callee_name[i];
207
208                 for (size_t i = 0; i < num_environments; i++)
209                         if ((callee_module = load_module(state, asprintf_wrapper("%s/%s", environments[i], callee_path))))
210                                 break;
211
212                 if (! callee_module)
213                         error("module error: no module %s in path %s\n", callee_path, environments_str);
214         }
215
216         return require_function(state, callee_module, fnname);
217 }
218
219 static void translate_expression(LoadState *state, Module *module, UwUVMExpression *vm_expr, ParseExpression *parse_expr)
220 {
221         UwUVMFunction *vm_function;
222
223         if (parse_expr->type == EX_FNNAME || parse_expr->type == EX_FNCALL) {
224                 vm_function = resolve_function(state, module, parse_expr->value.str_value);
225                 free(parse_expr->value.str_value);
226         }
227
228         switch (vm_expr->type = parse_expr->type) {
229                 case EX_INTLIT:
230                 case EX_ARGNUM:
231                         vm_expr->value.int_value = parse_expr->value.int_value;
232                         break;
233
234                 case EX_STRLIT:
235                         vm_expr->value.str_value = parse_expr->value.str_value;
236                         break;
237
238                 case EX_FNNAME:
239                         vm_expr->value.ref_value = vm_function;
240                         break;
241
242                 case EX_FNCALL:
243                         vm_expr->value.cll_value.function = vm_function;
244                         vm_expr->value.cll_value.args = malloc(sizeof(UwUVMExpression) * parse_expr->num_children);
245                         vm_expr->value.cll_value.num_args = parse_expr->num_children;
246
247                         for (size_t i = 0; i < parse_expr->num_children; i++)
248                                 translate_expression(state, module, &vm_expr->value.cll_value.args[i], parse_expr->children[i]);
249
250                         if (parse_expr->children)
251                                 free(parse_expr->children);
252                         break;
253
254                 default:
255                         break;
256         }
257
258         free(parse_expr);
259 }
260
261 static void load_functions(LoadState *state, Module *module)
262 {
263         for (; module->loaded_functions < module->num_functions; module->loaded_functions++) {
264                 FunctionLink *link = &module->functions[module->loaded_functions];
265
266                 if (module->type == MODULE_PLAIN) {
267                         ParseFunction **function = NULL;
268
269                         for (size_t i = 0; i < module->handle.ast.num_functions; i++) {
270                                 ParseFunction **fn = &module->handle.ast.functions[i];
271
272                                 if (*fn && strcmp((*fn)->name, link->name) == 0) {
273                                         function = fn;
274                                         break;
275                                 }
276                         }
277
278                         if (function) {
279                                 translate_expression(state, module, link->ref->value.plain = malloc(sizeof(UwUVMExpression)), (*function)->expression);
280                                 free((*function)->name);
281                                 free(*function);
282
283                                 *function = NULL;
284                         } else {
285                                 error("module error: no function %s in module %s\n", link->name, module->filename);
286                         }
287                 } else {
288                         char *symbol = asprintf_wrapper("uwu_%s", link->name);
289                         link->ref->value.native = dlsym(module->handle.lib, symbol);
290
291                         check_dlerror();
292                         free(symbol);
293                 }
294         }
295 }
296
297 static void free_expression(ParseExpression *expr)
298 {
299         if (expr->type == EX_FNCALL) {
300                 for (size_t i = 0; i < expr->num_children; i++)
301                         free_expression(expr->children[i]);
302
303                 if (expr->children)
304                         free(expr->children);
305         }
306
307         if (expr->type != EX_INTLIT && expr->type != EX_ARGNUM)
308                 free(expr->value.str_value);
309
310         free(expr);
311 }
312
313 Program load_program(const char *progname, const char *modname)
314 {
315         char *prog_dirname = dirname_wrapper(progname);
316         char *api_path = asprintf_wrapper("%s/api/api.so", prog_dirname);
317
318         LoadState state = {
319                 .modules = NULL,
320                 .num_modules = 0,
321                 .program = {
322                         .api_library = dlopen(api_path, RTLD_NOW | RTLD_GLOBAL),
323                         .main_function = NULL,
324                         .functions = NULL,
325                         .num_functions = 0,
326                         .libraries = NULL,
327                         .num_libraries = 0,
328                 },
329         };
330
331         char *uwu_module_path = getenv("UWU_MODULE_PATH");
332
333         if (uwu_module_path) {
334                 char  *uwu_module_path_ptr = state.module_paths_str = uwu_module_path;
335                 char  *uwu_module_path_base_ptr = uwu_module_path_ptr;
336                 size_t uwu_module_path_len = 1;
337
338                 state.num_module_paths = 0;
339                 state.module_paths = NULL;
340
341                 for (;; uwu_module_path_ptr++, uwu_module_path_len++) {
342                         if (*uwu_module_path_ptr == '\0' || *uwu_module_path_ptr == ':') {
343                                 state.module_paths = realloc(state.module_paths, sizeof(char **) * ++state.num_module_paths);
344                                 strncpy(state.module_paths[state.num_module_paths - 1] = malloc(uwu_module_path_len), uwu_module_path_base_ptr, uwu_module_path_len)[uwu_module_path_len - 1] = '\0';
345
346                                 uwu_module_path_len = 0;
347                                 uwu_module_path_base_ptr = uwu_module_path_ptr + 1;
348                         }
349
350                         if (*uwu_module_path_ptr == '\0')
351                                 break;
352                 }
353         } else {
354                 state.module_paths_str = asprintf_wrapper("%s/std", prog_dirname);
355                 state.num_module_paths = 1;
356                 state.module_paths = malloc(sizeof(char **));
357                 state.module_paths[0] = state.module_paths_str;
358         }
359
360         free(prog_dirname);
361         free(api_path);
362
363         Module *main_module = load_module(&state, strdup(modname));
364
365         if (! main_module)
366                 error("module error: requested module %s not found\n", modname);
367
368         state.program.main_function = require_function(&state, main_module, "main");
369
370         while (true) {
371                 bool fully_loaded = true;
372
373                 for (size_t i = 0; i < state.num_modules; i++) {
374                         Module *module = state.modules[i];
375
376 #if DEBUG
377                         printf("%s %lu/%lu\n", module->filename, module->loaded_functions, module->num_functions);
378 #endif
379
380                         if (module->loaded_functions < module->num_functions) {
381                                 fully_loaded = false;
382                                 load_functions(&state, module);
383                         }
384                 }
385
386                 if (fully_loaded)
387                         break;
388         }
389
390         for (size_t i = 0; i < state.num_module_paths; i++)
391                 free(state.module_paths[i]);
392
393         free(state.module_paths);
394
395         for (size_t i = 0; i < state.num_modules; i++) {
396                 Module *module = state.modules[i];
397
398                 free(module->path);
399                 free(module->filename);
400                 free(module->environment);
401
402                 for (size_t f = 0; f < module->num_functions; f++)
403                         free(module->functions[f].name);
404
405                 free(module->functions);
406
407                 if (module->type == MODULE_PLAIN) {
408                         for (size_t f = 0; f < module->handle.ast.num_functions; f++) {
409                                 ParseFunction *function = module->handle.ast.functions[f];
410
411                                 if (function) {
412                                         free_expression(function->expression);
413                                         free(function->name);
414                                         free(function);
415                                 }
416                         }
417
418                         if (module->handle.ast.functions)
419                                 free(module->handle.ast.functions);
420                 }
421
422                 free(module);
423         }
424
425         free(state.modules);
426
427         return state.program;
428 }