]> git.lizzy.rs Git - uwu-lang.git/blob - src/load.c
f6386d6341a5feeb9737ddf271d2fc39a11de93f
[uwu-lang.git] / src / load.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <stdarg.h>
5 #include <libgen.h>
6 #include <dlfcn.h>
7 #include "common/err.h"
8 #include "common/str.h"
9 #include "common/file.h"
10 #include "common/dl.h"
11 #include "load.h"
12 #include "parse.h"
13
14 #define DEBUG 0
15
16 // helper functions
17
// Runs a path function such as basename() or dirname() on a throwaway
// duplicate of name, so the caller's string is never handed to fn directly.
// The result of fn is duplicated before the scratch copy is released.
// Returns that duplicate; the caller is responsible for freeing it.
static char *wrap_name_func(const char *name, char *(*fn)(char *))
{
	char *scratch = strdup(name);
	char *owned = strdup(fn(scratch));

	free(scratch);
	return owned;
}
27
// Returns a newly allocated copy of basename(name), computed through
// wrap_name_func so that name itself is left untouched. Caller frees.
static char *basename_wrapper(const char *name)
{
	char *(*const path_fn)(char *) = basename;

	return wrap_name_func(name, path_fn);
}
32
// Returns a newly allocated copy of dirname(name), computed through
// wrap_name_func so that name itself is left untouched. Caller frees.
static char *dirname_wrapper(const char *name)
{
	char *(*const path_fn)(char *) = dirname;

	return wrap_name_func(name, path_fn);
}
37
38 // type definitions
39
40 typedef struct
41 {
42         char *name;
43         UwUVMFunction *ref;
44 } FunctionLink;
45
46 typedef struct
47 {
48         char *path;                                     // path without file extension
49         char *filename;                         // path with file extension
50         char *environment;                      // directory path
51
52         UwUVMModuleType type;                   // native (.so) or plain (.uwu)
53
54         FunctionLink *functions;        // required functions
55         size_t    num_functions;        // number of required functions
56         size_t loaded_functions;        // number of loaded functions (<= num_functions)
57
58         union
59         {
60                 AbstractSyntaxTree ast; // abstract syntax tree generated by parser (for plain modules)
61                 void *lib;                              // dlopen() shared object handle (for native modules)
62         } handle;
63 } Module;
64
65 typedef struct
66 {
67         Module   **modules;                     // loaded modules
68         size_t num_modules;                     // count for modules
69
70         char *std_path;                         // path to standard library
71
72         Program program;                        // the result program
73 } LoadState;
74
75 // functions
76
// Probes the candidate file names for a module path in this order: the path
// itself, the path with ".uwu" appended, the path with ".so" appended.
// Returns the first candidate for which file_exists() reports true as a
// mallocated string owned by the caller, or NULL if none of them exists.
static inline char *get_filename(const char *module_path)
{
	const char *const patterns[] = { "%s", "%s.uwu", "%s.so" };
	const size_t num_patterns = sizeof patterns / sizeof patterns[0];

	for (size_t idx = 0; idx < num_patterns; idx++) {
		char *candidate = asprintf_wrapper(patterns[idx], module_path);

		if (file_exists(candidate))
			return candidate;

		// not this one: discard it and try the next pattern
		free(candidate);
	}

	return NULL;
}
99
100 // module_path is a mallocated string
101 static Module *require_module(LoadState *state, char *module_path)
102 {
103         for (size_t i = 0; i < state->num_modules; i++) {
104                 Module *module = state->modules[i];
105
106                 if (strcmp(module_path, module->path) == 0) {
107                         free(module_path);
108                         return module;
109                 }
110         }
111
112         char *filename = get_filename(module_path);
113
114         if (! filename)
115                 error("module error: module %s not found\n", module_path);
116
117         size_t filename_len = strlen(filename);
118         UwUVMModuleType type = (filename_len >= 3 && strcmp(filename + filename_len - 3, ".so") == 0) ? MODULE_NATIVE : MODULE_PLAIN;
119
120         state->modules = realloc(state->modules, sizeof *state->modules * ++state->num_modules);
121         Module *module = state->modules[state->num_modules - 1] = malloc(sizeof *module);
122
123         *module = (Module) {
124                 .path = module_path,
125                 .filename = filename,
126                 .environment = dirname_wrapper(module_path),
127
128                 .type = type,
129
130                 .functions = NULL,
131                 .num_functions = 0,
132                 .loaded_functions = 0,
133         };
134
135         if (type == MODULE_PLAIN) {
136                 module->handle.ast = parse_file(filename);
137         } else {
138                 state->program.libraries = realloc(state->program.libraries, sizeof(void *) * ++state->program.num_libraries);
139                 state->program.libraries[state->program.num_libraries - 1] = module->handle.lib = dlopen(filename, RTLD_LAZY);
140
141                 check_dlerror();
142         }
143
144         return module;
145 }
146
147 static UwUVMFunction *require_function(LoadState *state, Module *module, const char *name)
148 {
149         for (size_t i = 0; i < module->num_functions; i++) {
150                 FunctionLink *link = &module->functions[i];
151
152                 if (strcmp(link->name, name) == 0)
153                         return link->ref;
154         }
155
156         UwUVMFunction *ref = malloc(sizeof *ref);
157         ref->type = module->type;
158
159         state->program.functions = realloc(state->program.functions, sizeof *state->program.functions * ++state->program.num_functions);
160         state->program.functions[state->program.num_functions - 1] = ref;
161
162         module->functions = realloc(module->functions, sizeof *module->functions * ++module->num_functions);
163         module->functions[module->num_functions - 1] = (FunctionLink) {
164                 .name = strdup(name),
165                 .ref = ref,
166         };
167
168         return ref;
169 }
170
171 static UwUVMFunction *resolve_function(LoadState *state, Module *caller_module, const char *full_name)
172 {
173         size_t len = strlen(full_name);
174
175         const char *fnname;
176         for (fnname = &full_name[len - 1]; *fnname != ':' && fnname > full_name; fnname--)
177                 ;
178
179         if (*fnname == ':')
180                 fnname++;
181
182         if (*fnname == '\0')
183                 error("module error: empty function name referenced/called by module %s\n", caller_module->filename);
184
185         Module *callee_module;
186
187         if (fnname == full_name) {
188                 callee_module = caller_module;
189         } else {
190                 const char *caller_path = caller_module->environment;
191                 const char *callee_name = full_name;
192
193                 if (*callee_name == ':') {
194                         caller_path = state->std_path;
195                         callee_name++;
196                 }
197
198                 size_t path_len = fnname - callee_name;
199                 char callee_path[path_len];
200
201                 for (size_t i = 0; i < path_len; i++)
202                         callee_path[i] = (i == path_len - 1) ? '\0'
203                                 : (callee_name[i] == ':') ? '/'
204                                 : callee_name[i];
205
206                 callee_module = require_module(state, asprintf_wrapper("%s/%s", caller_path, callee_path));
207         }
208
209         return require_function(state, callee_module, fnname);
210 }
211
212 static void translate_expression(LoadState *state, Module *module, UwUVMExpression *vm_expr, ParseExpression *parse_expr)
213 {
214         UwUVMFunction *vm_function;
215
216         if (parse_expr->type == EX_FNNAME || parse_expr->type == EX_FNCALL) {
217                 vm_function = resolve_function(state, module, parse_expr->value.str_value);
218                 free(parse_expr->value.str_value);
219         }
220
221         switch (vm_expr->type = parse_expr->type) {
222                 case EX_INTLIT:
223                 case EX_ARGNUM:
224                         vm_expr->value.int_value = parse_expr->value.int_value;
225                         break;
226
227                 case EX_STRLIT:
228                         vm_expr->value.str_value = parse_expr->value.str_value;
229                         break;
230
231                 case EX_FNNAME:
232                         vm_expr->value.ref_value = vm_function;
233                         break;
234
235                 case EX_FNCALL:
236                         vm_expr->value.cll_value.function = vm_function;
237                         vm_expr->value.cll_value.args = malloc(sizeof(UwUVMExpression) * parse_expr->num_children);
238                         vm_expr->value.cll_value.num_args = parse_expr->num_children;
239
240                         for (size_t i = 0; i < parse_expr->num_children; i++)
241                                 translate_expression(state, module, &vm_expr->value.cll_value.args[i], parse_expr->children[i]);
242
243                         if (parse_expr->children)
244                                 free(parse_expr->children);
245                         break;
246
247                 default:
248                         break;
249         }
250
251         free(parse_expr);
252 }
253
254 static void load_functions(LoadState *state, Module *module)
255 {
256         for (; module->loaded_functions < module->num_functions; module->loaded_functions++) {
257                 FunctionLink *link = &module->functions[module->loaded_functions];
258
259                 if (module->type == MODULE_PLAIN) {
260                         ParseFunction **function = NULL;
261
262                         for (size_t i = 0; i < module->handle.ast.num_functions; i++) {
263                                 ParseFunction **fn = &module->handle.ast.functions[i];
264
265                                 if (*fn && strcmp((*fn)->name, link->name) == 0) {
266                                         function = fn;
267                                         break;
268                                 }
269                         }
270
271                         if (function) {
272                                 translate_expression(state, module, link->ref->value.plain = malloc(sizeof(UwUVMExpression)), (*function)->expression);
273                                 free((*function)->name);
274                                 free(*function);
275
276                                 *function = NULL;
277                         } else {
278                                 error("module error: no function %s in module %s\n", link->name, module->filename);
279                         }
280                 } else {
281                         char *symbol = asprintf_wrapper("uwu_%s", link->name);
282                         link->ref->value.native = dlsym(module->handle.lib, symbol);
283
284                         check_dlerror();
285                         free(symbol);
286                 }
287         }
288 }
289
290 static void free_expression(ParseExpression *expr)
291 {
292         if (expr->type == EX_FNCALL) {
293                 for (size_t i = 0; i < expr->num_children; i++)
294                         free_expression(expr->children[i]);
295
296                 if (expr->children)
297                         free(expr->children);
298         }
299
300         if (expr->type != EX_INTLIT && expr->type != EX_ARGNUM)
301                 free(expr->value.str_value);
302
303         free(expr);
304 }
305
306 Program load_program(const char *progname, const char *modname)
307 {
308         char *prog_dirname = dirname_wrapper(progname);
309         char *api_path = asprintf_wrapper("%s/api/api.so", prog_dirname);
310
311         LoadState state = {
312                 .modules = NULL,
313                 .num_modules = 0,
314                 .std_path = asprintf_wrapper("%s/std", prog_dirname),
315                 .program = {
316                         .api_library = dlopen(api_path, RTLD_NOW | RTLD_GLOBAL),
317                         .main_function = NULL,
318                         .functions = NULL,
319                         .num_functions = 0,
320                         .libraries = NULL,
321                         .num_libraries = 0,
322                 },
323         };
324
325         free(prog_dirname);
326         free(api_path);
327
328         state.program.main_function = require_function(&state, require_module(&state, strdup(modname)), "main");
329
330         while (true) {
331                 bool fully_loaded = true;
332
333                 for (size_t i = 0; i < state.num_modules; i++) {
334                         Module *module = state.modules[i];
335
336 #if DEBUG
337                         printf("%s %lu/%lu\n", module->filename, module->loaded_functions, module->num_functions);
338 #endif
339
340                         if (module->loaded_functions < module->num_functions) {
341                                 fully_loaded = false;
342                                 load_functions(&state, module);
343                         }
344                 }
345
346                 if (fully_loaded)
347                         break;
348         }
349
350         free(state.std_path);
351
352         for (size_t i = 0; i < state.num_modules; i++) {
353                 Module *module = state.modules[i];
354
355                 free(module->path);
356                 free(module->filename);
357                 free(module->environment);
358
359                 for (size_t f = 0; f < module->num_functions; f++)
360                         free(module->functions[f].name);
361
362                 free(module->functions);
363
364                 if (module->type == MODULE_PLAIN) {
365                         for (size_t f = 0; f < module->handle.ast.num_functions; f++) {
366                                 ParseFunction *function = module->handle.ast.functions[f];
367
368                                 if (function) {
369                                         free_expression(function->expression);
370                                         free(function->name);
371                                         free(function);
372                                 }
373                         }
374
375                         if (module->handle.ast.functions)
376                                 free(module->handle.ast.functions);
377                 }
378
379                 free(module);
380         }
381
382         free(state.modules);
383
384         return state.program;
385 }