]> git.lizzy.rs Git - uwu-lang.git/blobdiff - src/load.c
Allow passing arguments to program, refactor directory structure
[uwu-lang.git] / src / load.c
diff --git a/src/load.c b/src/load.c
new file mode 100644 (file)
index 0000000..edeba97
--- /dev/null
@@ -0,0 +1,400 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdarg.h>
+#include <libgen.h>
+#include <dlfcn.h>
+#include "common/err.h"
+#include "common/str.h"
+#include "load.h"
+#include "parse.h"
+
+#define DEBUG 0
+
+// helper functions
+
+static char *wrap_name_func(const char *name, char *(*fn)(char *))
+{
+       char *copy = strdup(name);
+       char *result = fn(copy);
+       char *result_copy = strdup(result);
+
+       free(copy);
+       return result_copy;
+}
+
+static char *basename_wrapper(const char *name)
+{
+       return wrap_name_func(name, &basename);
+}
+
+static char *dirname_wrapper(const char *name)
+{
+       return wrap_name_func(name, &dirname);
+}
+
+static bool file_exists(const char *filename)
+{
+       FILE *f = fopen(filename, "r");
+
+       if (f) {
+               fclose(f);
+               return true;
+       }
+
+       return false;
+}
+
+// type definitions
+
+typedef struct
+{
+       char *name;
+       UwUVMFunction *ref;
+} FunctionLink;
+
+typedef struct
+{
+       char *path;                                     // path without file extension
+       char *filename;                         // path with file extension
+       char *environment;                      // directory path
+
+       UwUVMModuleType type;                   // native (.so) or plain (.uwu)
+
+       FunctionLink *functions;        // required functions
+       size_t    num_functions;        // number of required functions
+       size_t loaded_functions;        // number of loaded functions (<= num_functions)
+
+       union
+       {
+               AbstractSyntaxTree ast; // abstract syntax tree generated by parser (for plain modules)
+               void *lib;                              // dlopen() shared object handle (for native modules)
+       } handle;
+} Module;
+
+typedef struct
+{
+       Module   **modules;                     // loaded modules
+       size_t num_modules;                     // count for modules
+
+       char *std_path;                         // path to standard library
+
+       Program program;                        // the result program
+} LoadState;
+
+// functions
+
+// returns mallocated string
+static inline char *get_filename(const char *module_path)
+{
+       const char *try_names[3] = {
+               "%s",
+               "%s.uwu",
+               "%s.so",
+       };
+
+       char *filename;
+
+       for (int i = 0; i < 3; i++) {
+               filename = asprintf_wrapper(try_names[i], module_path);
+
+               if (file_exists(filename))
+                       return filename;
+               else
+                       free(filename);
+       }
+
+       return NULL;
+}
+
+// module_path is a mallocated string
+static Module *require_module(LoadState *state, char *module_path)
+{
+       for (size_t i = 0; i < state->num_modules; i++) {
+               Module *module = state->modules[i];
+
+               if (strcmp(module_path, module->path) == 0) {
+                       free(module_path);
+                       return module;
+               }
+       }
+
+       char *filename = get_filename(module_path);
+
+       if (! filename)
+               error("error: module %s not found\n", module_path);
+
+       size_t filename_len = strlen(filename);
+       UwUVMModuleType type = (filename_len >= 3 && strcmp(filename + filename_len - 3, ".so") == 0) ? MODULE_NATIVE : MODULE_PLAIN;
+
+       state->modules = realloc(state->modules, sizeof *state->modules * ++state->num_modules);
+       Module *module = state->modules[state->num_modules - 1] = malloc(sizeof *module);
+
+       *module = (Module) {
+               .path = module_path,
+               .filename = filename,
+               .environment = dirname_wrapper(module_path),
+
+               .type = type,
+
+               .functions = NULL,
+               .num_functions = 0,
+               .loaded_functions = 0,
+       };
+
+       if (type == MODULE_PLAIN) {
+               module->handle.ast = parse_file(filename);
+       } else {
+               state->program.libraries = realloc(state->program.libraries, sizeof(void *) * ++state->program.num_libraries);
+               state->program.libraries[state->program.num_libraries - 1] = module->handle.lib = dlopen(filename, RTLD_LAZY);
+
+               char *err = dlerror();
+               if (err)
+                       error("library error: %s\n", err);
+       }
+
+       return module;
+}
+
+static UwUVMFunction *require_function(LoadState *state, Module *module, const char *name)
+{
+       for (size_t i = 0; i < module->num_functions; i++) {
+               FunctionLink *link = &module->functions[i];
+
+               if (strcmp(link->name, name) == 0)
+                       return link->ref;
+       }
+
+       UwUVMFunction *ref = malloc(sizeof *ref);
+       ref->type = module->type;
+
+       state->program.functions = realloc(state->program.functions, sizeof *state->program.functions * ++state->program.num_functions);
+       state->program.functions[state->program.num_functions - 1] = ref;
+
+       module->functions = realloc(module->functions, sizeof *module->functions * ++module->num_functions);
+       module->functions[module->num_functions - 1] = (FunctionLink) {
+               .name = strdup(name),
+               .ref = ref,
+       };
+
+       return ref;
+}
+
+static UwUVMFunction *resolve_function(LoadState *state, Module *caller_module, const char *full_name)
+{
+       size_t len = strlen(full_name);
+
+       const char *fnname;
+       for (fnname = &full_name[len - 1]; *fnname != ':' && fnname > full_name; fnname--)
+               ;
+
+       if (*fnname == ':')
+               fnname++;
+
+       if (*fnname == '\0')
+               error("error: empty function name\n");
+
+       Module *callee_module;
+
+       if (fnname == full_name) {
+               callee_module = caller_module;
+       } else {
+               const char *caller_path = caller_module->environment;
+               const char *callee_name = full_name;
+
+               if (*callee_name == ':') {
+                       caller_path = state->std_path;
+                       callee_name++;
+               }
+
+               size_t path_len = fnname - callee_name;
+               char callee_path[path_len];
+
+               for (size_t i = 0; i < path_len; i++)
+                       callee_path[i] = (i == path_len - 1) ? '\0'
+                               : (callee_name[i] == ':') ? '/'
+                               : callee_name[i];
+
+               callee_module = require_module(state, asprintf_wrapper("%s/%s", caller_path, callee_path));
+       }
+
+       return require_function(state, callee_module, fnname);
+}
+
+static void translate_expression(LoadState *state, Module *module, UwUVMExpression *vm_expr, ParseExpression *parse_expr)
+{
+       UwUVMFunction *vm_function;
+
+       if (parse_expr->type == EX_FNNAME || parse_expr->type == EX_FNCALL) {
+               vm_function = resolve_function(state, module, parse_expr->value.str_value);
+               free(parse_expr->value.str_value);
+       }
+
+       switch (vm_expr->type = parse_expr->type) {
+               case EX_INTLIT:
+               case EX_ARGNUM:
+                       vm_expr->value.int_value = parse_expr->value.int_value;
+                       break;
+
+               case EX_STRLIT:
+                       vm_expr->value.str_value = parse_expr->value.str_value;
+                       break;
+
+               case EX_FNNAME:
+                       vm_expr->value.ref_value = vm_function;
+                       break;
+
+               case EX_FNCALL:
+                       vm_expr->value.cll_value.function = vm_function;
+                       vm_expr->value.cll_value.args = malloc(sizeof(UwUVMExpression) * parse_expr->num_children);
+                       vm_expr->value.cll_value.num_args = parse_expr->num_children;
+
+                       for (size_t i = 0; i < parse_expr->num_children; i++)
+                               translate_expression(state, module, &vm_expr->value.cll_value.args[i], parse_expr->children[i]);
+
+                       if (parse_expr->children)
+                               free(parse_expr->children);
+                       break;
+
+               default:
+                       break;
+       }
+
+       free(parse_expr);
+}
+
+static void load_functions(LoadState *state, Module *module)
+{
+       for (; module->loaded_functions < module->num_functions; module->loaded_functions++) {
+               FunctionLink *link = &module->functions[module->loaded_functions];
+
+               if (module->type == MODULE_PLAIN) {
+                       ParseFunction **function = NULL;
+
+                       for (size_t i = 0; i < module->handle.ast.num_functions; i++) {
+                               ParseFunction **fn = &module->handle.ast.functions[i];
+
+                               if (*fn && strcmp((*fn)->name, link->name) == 0) {
+                                       function = fn;
+                                       break;
+                               }
+                       }
+
+                       if (function) {
+                               translate_expression(state, module, link->ref->value.plain = malloc(sizeof(UwUVMExpression)), (*function)->expression);
+                               free((*function)->name);
+                               free(*function);
+
+                               *function = NULL;
+                       } else {
+                               error("error: no function %s in module %s\n", link->name, module->filename);
+                       }
+               } else {
+                       char *symbol = asprintf_wrapper("uwu_%s", link->name);
+                       link->ref->value.native = dlsym(module->handle.lib, symbol);
+
+                       char *err = dlerror();
+                       if (err)
+                               error("library error: %s\n", err);
+
+                       free(symbol);
+               }
+       }
+}
+
+static void free_expression(ParseExpression *expr)
+{
+       if (expr->type == EX_FNCALL) {
+               for (size_t i = 0; i < expr->num_children; i++)
+                       free_expression(expr->children[i]);
+
+               if (expr->children)
+                       free(expr->children);
+       }
+
+       if (expr->type != EX_INTLIT && expr->type != EX_ARGNUM)
+               free(expr->value.str_value);
+
+       free(expr);
+}
+
+Program load_program(const char *progname, const char *modname)
+{
+       char *prog_dirname = dirname_wrapper(progname);
+       char *api_path = asprintf_wrapper("%s/api/api.so", prog_dirname);
+
+       LoadState state = {
+               .modules = NULL,
+               .num_modules = 0,
+               .std_path = asprintf_wrapper("%s/std", prog_dirname),
+               .program = {
+                       .api_library = dlopen(api_path, RTLD_NOW | RTLD_GLOBAL),
+                       .main_function = NULL,
+                       .functions = NULL,
+                       .num_functions = 0,
+                       .libraries = NULL,
+                       .num_libraries = 0,
+               },
+       };
+
+       free(prog_dirname);
+       free(api_path);
+
+       state.program.main_function = require_function(&state, require_module(&state, strdup(modname)), "main");
+
+       while (true) {
+               bool fully_loaded = true;
+
+               for (size_t i = 0; i < state.num_modules; i++) {
+                       Module *module = state.modules[i];
+
+#if DEBUG
+                       printf("%s %lu/%lu\n", module->filename, module->loaded_functions, module->num_functions);
+#endif
+
+                       if (module->loaded_functions < module->num_functions) {
+                               fully_loaded = false;
+                               load_functions(&state, module);
+                       }
+               }
+
+               if (fully_loaded)
+                       break;
+       }
+
+       free(state.std_path);
+
+       for (size_t i = 0; i < state.num_modules; i++) {
+               Module *module = state.modules[i];
+
+               free(module->path);
+               free(module->filename);
+               free(module->environment);
+
+               for (size_t f = 0; f < module->num_functions; f++)
+                       free(module->functions[f].name);
+
+               free(module->functions);
+
+               if (module->type == MODULE_PLAIN) {
+                       for (size_t f = 0; f < module->handle.ast.num_functions; f++) {
+                               ParseFunction *function = module->handle.ast.functions[f];
+
+                               if (function) {
+                                       free_expression(function->expression);
+                                       free(function->name);
+                                       free(function);
+                               }
+                       }
+
+                       if (module->handle.ast.functions)
+                               free(module->handle.ast.functions);
+               }
+
+               free(module);
+       }
+
+       free(state.modules);
+
+       return state.program;
+}