]> git.lizzy.rs Git - uwu-lang.git/blob - src/load.c
edeba9759814b3850d0155b0bb5316b8b2eb71ec
[uwu-lang.git] / src / load.c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdarg.h>
#include <libgen.h>
#include <dlfcn.h>
#include <sys/stat.h>
#include "common/err.h"
#include "common/str.h"
#include "load.h"
#include "parse.h"
11
// Set to 1 to make load_program() print, for every module it sweeps over,
// the module's filename and its loaded/required function counts.
#define DEBUG 0
13
14 // helper functions
15
// Applies a libgen-style function such as basename()/dirname() to `name`
// without touching the caller's string.
//
// POSIX allows those functions to modify their argument and to return a
// pointer into it (or into static storage). So the function works on a
// private scratch copy, and the result is duplicated before the scratch copy
// is released.
//
// Returns a malloc'd string owned by the caller.
static char *wrap_name_func(const char *name, char *(*fn)(char *))
{
	size_t name_size = strlen(name) + 1;
	char *scratch = malloc(name_size);
	if (scratch)
		memcpy(scratch, name, name_size);

	const char *computed = fn(scratch);

	size_t out_size = strlen(computed) + 1;
	char *out = malloc(out_size);
	if (out)
		memcpy(out, computed, out_size);

	free(scratch);
	return out;
}
25
// Returns a malloc'd copy of basename(name); `name` is not modified.
// NOTE(review): nothing in this file calls this helper at the moment.
static char *basename_wrapper(const char *name)
{
	char *(*const last_component)(char *) = basename;
	return wrap_name_func(name, last_component);
}
30
// Returns a malloc'd copy of dirname(name); `name` is not modified.
static char *dirname_wrapper(const char *name)
{
	char *(*const parent_dir)(char *) = dirname;
	return wrap_name_func(name, parent_dir);
}
35
// Returns true if `filename` names a regular file that can be opened for reading.
//
// Directories are rejected explicitly. On some platforms (e.g. glibc/Linux),
// fopen(dir, "r") succeeds, which made get_filename() accept a directory that
// shares the module's name instead of "<module>.uwu" / "<module>.so".
static bool file_exists(const char *filename)
{
	struct stat st;

	if (stat(filename, &st) != 0 || !S_ISREG(st.st_mode))
		return false;

	// Keep the previous requirement that the file is actually readable.
	FILE *f = fopen(filename, "r");
	if (!f)
		return false;

	fclose(f);
	return true;
}
47
48 // type definitions
49
// Links a function name required from a module to the VM function slot that
// callers already reference. The slot's value is filled in later by
// load_functions().
typedef struct
{
	char *name;		// function name (strdup'd copy, owned by the link)
	UwUVMFunction *ref;	// heap-allocated slot, also listed in Program.functions
} FunctionLink;
55
// Loader bookkeeping for one module: a plain .uwu source file or a native .so library.
typedef struct
{
	char *path;			// module path as requested (usually without extension); require_module() looks modules up by it
	char *filename;			// file found by get_filename(): path itself, or path + ".uwu" / ".so"
	char *environment;		// dirname of path; base directory for relative module references

	UwUVMModuleType type;		// native (.so) or plain (.uwu)

	FunctionLink *functions;	// required functions
	size_t    num_functions;	// number of required functions
	size_t loaded_functions;	// number of loaded functions (<= num_functions); entries [0, loaded_functions) are resolved

	union
	{
		AbstractSyntaxTree ast;	// abstract syntax tree generated by parser (for plain modules)
		void *lib;		// dlopen() shared object handle (for native modules)
	} handle;
} Module;
74
// State threaded through one run of load_program().
typedef struct
{
	Module   **modules;		// loaded modules (each Module is malloc'd, so pointers to it survive realloc of this array)
	size_t num_modules;		// count for modules

	char *std_path;			// path to standard library ("<dirname of progname>/std"); base for ':'-prefixed references

	Program program;		// the result program
} LoadState;
84
85 // functions
86
// Finds the file backing a module path. It probes, in order, "<path>",
// "<path>.uwu" and "<path>.so", and returns the first candidate accepted by
// file_exists().
//
// Returns a malloc'd string owned by the caller, or NULL when no candidate matches.
static inline char *get_filename(const char *module_path)
{
	static const char *const candidates[] = {
		"%s",
		"%s.uwu",
		"%s.so",
	};

	for (size_t idx = 0; idx < sizeof candidates / sizeof *candidates; idx++) {
		char *candidate = asprintf_wrapper(candidates[idx], module_path);

		if (file_exists(candidate))
			return candidate;

		free(candidate);
	}

	return NULL;
}
109
110 // module_path is a mallocated string
111 static Module *require_module(LoadState *state, char *module_path)
112 {
113         for (size_t i = 0; i < state->num_modules; i++) {
114                 Module *module = state->modules[i];
115
116                 if (strcmp(module_path, module->path) == 0) {
117                         free(module_path);
118                         return module;
119                 }
120         }
121
122         char *filename = get_filename(module_path);
123
124         if (! filename)
125                 error("error: module %s not found\n", module_path);
126
127         size_t filename_len = strlen(filename);
128         UwUVMModuleType type = (filename_len >= 3 && strcmp(filename + filename_len - 3, ".so") == 0) ? MODULE_NATIVE : MODULE_PLAIN;
129
130         state->modules = realloc(state->modules, sizeof *state->modules * ++state->num_modules);
131         Module *module = state->modules[state->num_modules - 1] = malloc(sizeof *module);
132
133         *module = (Module) {
134                 .path = module_path,
135                 .filename = filename,
136                 .environment = dirname_wrapper(module_path),
137
138                 .type = type,
139
140                 .functions = NULL,
141                 .num_functions = 0,
142                 .loaded_functions = 0,
143         };
144
145         if (type == MODULE_PLAIN) {
146                 module->handle.ast = parse_file(filename);
147         } else {
148                 state->program.libraries = realloc(state->program.libraries, sizeof(void *) * ++state->program.num_libraries);
149                 state->program.libraries[state->program.num_libraries - 1] = module->handle.lib = dlopen(filename, RTLD_LAZY);
150
151                 char *err = dlerror();
152                 if (err)
153                         error("library error: %s\n", err);
154         }
155
156         return module;
157 }
158
159 static UwUVMFunction *require_function(LoadState *state, Module *module, const char *name)
160 {
161         for (size_t i = 0; i < module->num_functions; i++) {
162                 FunctionLink *link = &module->functions[i];
163
164                 if (strcmp(link->name, name) == 0)
165                         return link->ref;
166         }
167
168         UwUVMFunction *ref = malloc(sizeof *ref);
169         ref->type = module->type;
170
171         state->program.functions = realloc(state->program.functions, sizeof *state->program.functions * ++state->program.num_functions);
172         state->program.functions[state->program.num_functions - 1] = ref;
173
174         module->functions = realloc(module->functions, sizeof *module->functions * ++module->num_functions);
175         module->functions[module->num_functions - 1] = (FunctionLink) {
176                 .name = strdup(name),
177                 .ref = ref,
178         };
179
180         return ref;
181 }
182
183 static UwUVMFunction *resolve_function(LoadState *state, Module *caller_module, const char *full_name)
184 {
185         size_t len = strlen(full_name);
186
187         const char *fnname;
188         for (fnname = &full_name[len - 1]; *fnname != ':' && fnname > full_name; fnname--)
189                 ;
190
191         if (*fnname == ':')
192                 fnname++;
193
194         if (*fnname == '\0')
195                 error("error: empty function name\n");
196
197         Module *callee_module;
198
199         if (fnname == full_name) {
200                 callee_module = caller_module;
201         } else {
202                 const char *caller_path = caller_module->environment;
203                 const char *callee_name = full_name;
204
205                 if (*callee_name == ':') {
206                         caller_path = state->std_path;
207                         callee_name++;
208                 }
209
210                 size_t path_len = fnname - callee_name;
211                 char callee_path[path_len];
212
213                 for (size_t i = 0; i < path_len; i++)
214                         callee_path[i] = (i == path_len - 1) ? '\0'
215                                 : (callee_name[i] == ':') ? '/'
216                                 : callee_name[i];
217
218                 callee_module = require_module(state, asprintf_wrapper("%s/%s", caller_path, callee_path));
219         }
220
221         return require_function(state, callee_module, fnname);
222 }
223
224 static void translate_expression(LoadState *state, Module *module, UwUVMExpression *vm_expr, ParseExpression *parse_expr)
225 {
226         UwUVMFunction *vm_function;
227
228         if (parse_expr->type == EX_FNNAME || parse_expr->type == EX_FNCALL) {
229                 vm_function = resolve_function(state, module, parse_expr->value.str_value);
230                 free(parse_expr->value.str_value);
231         }
232
233         switch (vm_expr->type = parse_expr->type) {
234                 case EX_INTLIT:
235                 case EX_ARGNUM:
236                         vm_expr->value.int_value = parse_expr->value.int_value;
237                         break;
238
239                 case EX_STRLIT:
240                         vm_expr->value.str_value = parse_expr->value.str_value;
241                         break;
242
243                 case EX_FNNAME:
244                         vm_expr->value.ref_value = vm_function;
245                         break;
246
247                 case EX_FNCALL:
248                         vm_expr->value.cll_value.function = vm_function;
249                         vm_expr->value.cll_value.args = malloc(sizeof(UwUVMExpression) * parse_expr->num_children);
250                         vm_expr->value.cll_value.num_args = parse_expr->num_children;
251
252                         for (size_t i = 0; i < parse_expr->num_children; i++)
253                                 translate_expression(state, module, &vm_expr->value.cll_value.args[i], parse_expr->children[i]);
254
255                         if (parse_expr->children)
256                                 free(parse_expr->children);
257                         break;
258
259                 default:
260                         break;
261         }
262
263         free(parse_expr);
264 }
265
266 static void load_functions(LoadState *state, Module *module)
267 {
268         for (; module->loaded_functions < module->num_functions; module->loaded_functions++) {
269                 FunctionLink *link = &module->functions[module->loaded_functions];
270
271                 if (module->type == MODULE_PLAIN) {
272                         ParseFunction **function = NULL;
273
274                         for (size_t i = 0; i < module->handle.ast.num_functions; i++) {
275                                 ParseFunction **fn = &module->handle.ast.functions[i];
276
277                                 if (*fn && strcmp((*fn)->name, link->name) == 0) {
278                                         function = fn;
279                                         break;
280                                 }
281                         }
282
283                         if (function) {
284                                 translate_expression(state, module, link->ref->value.plain = malloc(sizeof(UwUVMExpression)), (*function)->expression);
285                                 free((*function)->name);
286                                 free(*function);
287
288                                 *function = NULL;
289                         } else {
290                                 error("error: no function %s in module %s\n", link->name, module->filename);
291                         }
292                 } else {
293                         char *symbol = asprintf_wrapper("uwu_%s", link->name);
294                         link->ref->value.native = dlsym(module->handle.lib, symbol);
295
296                         char *err = dlerror();
297                         if (err)
298                                 error("library error: %s\n", err);
299
300                         free(symbol);
301                 }
302         }
303 }
304
305 static void free_expression(ParseExpression *expr)
306 {
307         if (expr->type == EX_FNCALL) {
308                 for (size_t i = 0; i < expr->num_children; i++)
309                         free_expression(expr->children[i]);
310
311                 if (expr->children)
312                         free(expr->children);
313         }
314
315         if (expr->type != EX_INTLIT && expr->type != EX_ARGNUM)
316                 free(expr->value.str_value);
317
318         free(expr);
319 }
320
321 Program load_program(const char *progname, const char *modname)
322 {
323         char *prog_dirname = dirname_wrapper(progname);
324         char *api_path = asprintf_wrapper("%s/api/api.so", prog_dirname);
325
326         LoadState state = {
327                 .modules = NULL,
328                 .num_modules = 0,
329                 .std_path = asprintf_wrapper("%s/std", prog_dirname),
330                 .program = {
331                         .api_library = dlopen(api_path, RTLD_NOW | RTLD_GLOBAL),
332                         .main_function = NULL,
333                         .functions = NULL,
334                         .num_functions = 0,
335                         .libraries = NULL,
336                         .num_libraries = 0,
337                 },
338         };
339
340         free(prog_dirname);
341         free(api_path);
342
343         state.program.main_function = require_function(&state, require_module(&state, strdup(modname)), "main");
344
345         while (true) {
346                 bool fully_loaded = true;
347
348                 for (size_t i = 0; i < state.num_modules; i++) {
349                         Module *module = state.modules[i];
350
351 #if DEBUG
352                         printf("%s %lu/%lu\n", module->filename, module->loaded_functions, module->num_functions);
353 #endif
354
355                         if (module->loaded_functions < module->num_functions) {
356                                 fully_loaded = false;
357                                 load_functions(&state, module);
358                         }
359                 }
360
361                 if (fully_loaded)
362                         break;
363         }
364
365         free(state.std_path);
366
367         for (size_t i = 0; i < state.num_modules; i++) {
368                 Module *module = state.modules[i];
369
370                 free(module->path);
371                 free(module->filename);
372                 free(module->environment);
373
374                 for (size_t f = 0; f < module->num_functions; f++)
375                         free(module->functions[f].name);
376
377                 free(module->functions);
378
379                 if (module->type == MODULE_PLAIN) {
380                         for (size_t f = 0; f < module->handle.ast.num_functions; f++) {
381                                 ParseFunction *function = module->handle.ast.functions[f];
382
383                                 if (function) {
384                                         free_expression(function->expression);
385                                         free(function->name);
386                                         free(function);
387                                 }
388                         }
389
390                         if (module->handle.ast.functions)
391                                 free(module->handle.ast.functions);
392                 }
393
394                 free(module);
395         }
396
397         free(state.modules);
398
399         return state.program;
400 }