]> git.lizzy.rs Git - metalua.git/blobdiff - metalua/extension/match.mlua
Merge branch 'master' of ssh://git.eclipse.org/gitroot/koneki/org.eclipse.koneki...
[metalua.git] / metalua / extension / match.mlua
diff --git a/metalua/extension/match.mlua b/metalua/extension/match.mlua
new file mode 100644 (file)
index 0000000..8561e05
--- /dev/null
@@ -0,0 +1,400 @@
+-------------------------------------------------------------------------------
+-- Copyright (c) 2006-2013 Fabien Fleutot and others.
+--
+-- All rights reserved.
+--
+-- This program and the accompanying materials are made available
+-- under the terms of the Eclipse Public License v1.0 which
+-- accompanies this distribution, and is available at
+-- http://www.eclipse.org/legal/epl-v10.html
+--
+-- This program and the accompanying materials are also made available
+-- under the terms of the MIT public license which accompanies this
+-- distribution, and is available at http://www.lua.org/license.html
+--
+-- Contributors:
+--     Fabien Fleutot - API and implementation
+--
+-------------------------------------------------------------------------------
+
+-------------------------------------------------------------------------------
+--
+-- Glossary:
+--
+-- * term_seq: the tested stuff, a sequence of terms
+-- * pattern_element: might match one term of a term seq. Represented
+--   as expression ASTs.
+-- * pattern_seq: might match a term_seq
+-- * pattern_group: several pattern seqs, one of them might match
+--                  the term seq.
+-- * case: pattern_group * guard option * block
+-- * match_statement: tested term_seq * case list
+--
+-- Hence a complete match statement is a:
+--
+-- { list(expr),  list{ list(list(expr)), expr or false, block } }
+--
+-- Implementation hints
+-- ====================
+--
+-- The implementation is made as modular as possible, so that parts
+-- can be reused in other extensions. The priviledged way to share
+-- contextual information across functions is through the 'cfg' table
+-- argument. Its fields include:
+--
+-- * code: code generated from pattern. A pattern_(element|seq|group)
+--   is compiled as a sequence of instructions which will jump to
+--   label [cfg.on_failure] if the tested term doesn't match.
+--
+-- * on_failure: name of the label where the code will jump if the
+--   pattern doesn't match
+--
+-- * locals: names of local variables used by the pattern. This
+--   includes bound variables, and temporary variables used to
+--   destructure tables. Names are stored as keys of the table,
+--   values are meaningless.
+--
+-- * after_success: label where the code must jump after a pattern
+--   succeeded to capture a term, and the guard suceeded if there is
+--   any, and the conditional block has run.
+--
+-- * ntmp: number of temporary variables used to destructurate table
+--   in the current case.
+--
+-- Code generation is performed by acc_xxx() functions, which accumulate
+-- code in cfg.code:
+--
+-- * acc_test(test, cfg) will generate a jump to cfg.on_failure
+--   *when the test returns TRUE*
+--
+-- * acc_stat accumulates a statement
+--
+-- * acc_assign accumulate an assignment statement, and makes sure that
+--   the LHS variable the registered as local in cfg.locals.
+--
+-------------------------------------------------------------------------------
+
+-- TODO: hygiene wrt type()
+-- TODO: cfg.ntmp isn't reset as often as it could. I'm not even sure
+--       the corresponding locals are declared.
+
+
+local gg  = require 'metalua.grammar.generator'
+local pp  = require 'metalua.pprint'
+
+----------------------------------------------------------------------
+-- This would have been best done through library 'metalua.walk',
+-- but walk depends on match, so we have to break the dependency.
+-- It replaces all instances of `...' in `ast' with `term', unless
+-- it appears in a function.
+----------------------------------------------------------------------
+local function replace_dots (ast, term)
+    local function rec (node)
+        for i, child in ipairs(node) do
+            if type(child)~="table" then -- pass
+            elseif child.tag=='Dots' then
+                if term=='ambiguous' then
+                    error ("You can't use `...' on the right of a match case when it appears "..
+                           "more than once on the left")
+                else node[i] = term end
+            elseif child.tag=='Function' then return nil
+            else rec(child) end
+        end
+    end
+    return rec(ast)
+end
+
+local tmpvar_base = gg.gensym 'submatch.' [1]
+
+local function next_tmpvar(cfg)
+   assert (cfg.ntmp, "No cfg.ntmp imbrication level in the match compiler")
+   cfg.ntmp = cfg.ntmp+1
+   return `Id{ tmpvar_base .. cfg.ntmp }
+end
+
+-- Code accumulators
+local acc_stat = |x,cfg| table.insert (cfg.code, x)
+local acc_test = |x,cfg| acc_stat(+{stat: if -{x} then -{`Goto{cfg.on_failure}} end}, cfg)
+-- lhs :: `Id{ string }
+-- rhs :: expr
+local function acc_assign (lhs, rhs, cfg)
+   assert(lhs.tag=='Id')
+   cfg.locals[lhs[1]] = true
+   acc_stat (`Set{ {lhs}, {rhs} }, cfg)
+end
+
+local literal_tags = { String=1, Number=1, True=1, False=1, Nil=1 }
+
+-- pattern :: `Id{ string }
+-- term    :: expr
+local function id_pattern_element_builder (pattern, term, cfg)
+   assert (pattern.tag == "Id")
+   if pattern[1] == "_" then
+      -- "_" is used as a dummy var ==> no assignment, no == checking
+      cfg.locals._ = true
+   elseif cfg.locals[pattern[1]] then
+      -- This var is already bound ==> test for equality
+      acc_test (+{ -{term} ~= -{pattern} }, cfg)
+   else
+      -- Free var ==> bind it, and remember it for latter linearity checking
+      acc_assign (pattern, term, cfg)
+      cfg.locals[pattern[1]] = true
+   end
+end
+
+-- mutually recursive with table_pattern_element_builder
+local pattern_element_builder
+
+-- pattern :: pattern and `Table{ }
+-- term    :: expr
+local function table_pattern_element_builder (pattern, term, cfg)
+   local seen_dots, len = false, 0
+   acc_test (+{ type( -{term} ) ~= "table" }, cfg)
+   for i = 1, #pattern do
+      local key, sub_pattern
+      if pattern[i].tag=="Pair" then -- Explicit key/value pair
+         key, sub_pattern = unpack (pattern[i])
+         assert (literal_tags[key.tag], "Invalid key")
+      else -- Implicit key
+         len, key, sub_pattern = len+1, `Number{ len+1 }, pattern[i]
+      end
+
+      -- '...' can only appear in final position
+      -- Could be fixed actually...
+      assert (not seen_dots, "Wrongly placed `...' ")
+
+      if sub_pattern.tag == "Id" then
+         -- Optimization: save a useless [ v(n+1)=v(n).key ]
+         id_pattern_element_builder (sub_pattern, `Index{ term, key }, cfg)
+         if sub_pattern[1] ~= "_" then
+            acc_test (+{ -{sub_pattern} == nil }, cfg)
+         end
+      elseif sub_pattern.tag == "Dots" then
+         -- Remember where the capture is, and thatt arity checking shouldn't occur
+         seen_dots = true
+      else
+         -- Business as usual:
+         local v2 = next_tmpvar(cfg)
+         acc_assign (v2, `Index{ term, key }, cfg)
+         pattern_element_builder (sub_pattern, v2, cfg)
+         -- TODO: restore ntmp?
+      end
+   end
+   if seen_dots then -- remember how to retrieve `...'
+      -- FIXME: check, but there might be cases where the variable -{term}
+      -- will be overridden in contrieved tables.
+      -- ==> save it now, and clean the setting statement if unused
+      if cfg.dots_replacement then cfg.dots_replacement = 'ambiguous'
+      else cfg.dots_replacement = +{ select (-{`Number{len}}, unpack(-{term})) } end
+   else -- Check arity
+      acc_test (+{ #-{term} ~= -{`Number{len}} }, cfg)
+   end
+end
+
+-- mutually recursive with pattern_element_builder
+local eq_pattern_element_builder, regexp_pattern_element_builder
+
+-- Concatenate code in [cfg.code], that will jump to label
+-- [cfg.on_failure] if [pattern] doesn't match [term]. [pattern]
+-- should be an identifier, or at least cheap to compute and
+-- side-effects free.
+--
+-- pattern :: pattern_element
+-- term    :: expr
+function pattern_element_builder (pattern, term, cfg)
+   if literal_tags[pattern.tag] then
+      acc_test (+{ -{term} ~= -{pattern} }, cfg)
+   elseif "Id" == pattern.tag then
+      id_pattern_element_builder (pattern, term, cfg)
+   elseif "Op" == pattern.tag and "div" == pattern[1] then
+      regexp_pattern_element_builder (pattern, term, cfg)
+   elseif "Op" == pattern.tag and "eq" == pattern[1] then
+      eq_pattern_element_builder (pattern, term, cfg)
+   elseif "Table" == pattern.tag then
+      table_pattern_element_builder (pattern, term, cfg)
+   else
+      error ("Invalid pattern at "..
+             tostring(pattern.lineinfo)..
+             ": "..pp.tostring(pattern, {hide_hash=true}))
+   end
+end
+
+function eq_pattern_element_builder (pattern, term, cfg)
+   local _, pat1, pat2 = unpack (pattern)
+   local ntmp_save = cfg.ntmp
+   pattern_element_builder (pat1, term, cfg)
+   cfg.ntmp = ntmp_save
+   pattern_element_builder (pat2, term, cfg)
+end
+
+-- pattern :: `Op{ 'div', string, list{`Id string} or `Id{ string }}
+-- term    :: expr
+local function regexp_pattern_element_builder (pattern, term, cfg)
+   local op, regexp, sub_pattern = unpack(pattern)
+
+   -- Sanity checks --
+   assert (op=='div', "Don't know what to do with that op in a pattern")
+   assert (regexp.tag=="String",
+           "Left hand side operand for '/' in a pattern must be "..
+           "a literal string representing a regular expression")
+   if sub_pattern.tag=="Table" then
+      for _, x in ipairs(sub_pattern) do
+        assert (x.tag=="Id" or x.tag=='Dots',
+                "Right hand side operand for '/' in a pattern must be "..
+                "a list of identifiers")
+      end
+   else
+      assert (sub_pattern.tag=="Id",
+             "Right hand side operand for '/' in a pattern must be "..
+              "an identifier or a list of identifiers")
+   end
+
+   -- Regexp patterns can only match strings
+   acc_test (+{ type(-{term}) ~= 'string' }, cfg)
+   -- put all captures in a list
+   local capt_list  = +{ { string.strmatch(-{term}, -{regexp}) } }
+   -- save them in a var_n for recursive decomposition
+   local v2 = next_tmpvar(cfg)
+   acc_stat (+{stat: local -{v2} = -{capt_list} }, cfg)
+   -- was capture successful?
+   acc_test (+{ not next (-{v2}) }, cfg)
+   pattern_element_builder (sub_pattern, v2, cfg)
+end
+
+
+-- Jumps to [cfg.on_faliure] if pattern_seq doesn't match
+-- term_seq.
+local function pattern_seq_builder (pattern_seq, term_seq, cfg)
+   if #pattern_seq ~= #term_seq then error ("Bad seq arity") end
+   cfg.locals = { } -- reset bound variables between alternatives
+   for i=1, #pattern_seq do
+      cfg.ntmp = 1 -- reset the tmp var generator
+      pattern_element_builder(pattern_seq[i], term_seq[i], cfg)
+   end
+end
+
+--------------------------------------------------
+-- for each case i:
+--   pattern_seq_builder_i:
+--    * on failure, go to on_failure_i
+--    * on success, go to on_success
+--   label on_success:
+--   block
+--   goto after_success
+--   label on_failure_i
+--------------------------------------------------
+local function case_builder (case, term_seq, cfg)
+   local patterns_group, guard, block = unpack(case)
+   local on_success = gg.gensym 'on_success' [1]
+   for i = 1, #patterns_group do
+      local pattern_seq = patterns_group[i]
+      cfg.on_failure = gg.gensym 'match_fail' [1]
+      cfg.dots_replacement = false
+      pattern_seq_builder (pattern_seq, term_seq, cfg)
+      if i<#patterns_group then
+         acc_stat (`Goto{on_success}, cfg)
+         acc_stat (`Label{cfg.on_failure}, cfg)
+      end
+   end
+   acc_stat (`Label{on_success}, cfg)
+   if guard then acc_test (+{not -{guard}}, cfg) end
+   if cfg.dots_replacement then
+      replace_dots (block, cfg.dots_replacement)
+   end
+   block.tag = 'Do'
+   acc_stat (block, cfg)
+   acc_stat (`Goto{cfg.after_success}, cfg)
+   acc_stat (`Label{cfg.on_failure}, cfg)
+end
+
+local function match_builder (x)
+   local term_seq, cases = unpack(x)
+   local cfg = {
+      code          = `Do{ },
+      after_success = gg.gensym "_after_success" }
+
+
+   -- Some sharing issues occur when modifying term_seq,
+   -- so it's replaced by a copy new_term_seq.
+   -- TODO: clean that up, and re-suppress the useless copies
+   -- (cf. remarks about capture bug below).
+   local new_term_seq = { }
+
+   local match_locals
+
+   -- Make sure that all tested terms are variables or literals
+   for i=1, #term_seq do
+      local t = term_seq[i]
+      -- Capture problem: the following would compile wrongly:
+      --    `match x with x -> end'
+      -- Temporary workaround: suppress the condition, so that
+      -- all external variables are copied into unique names.
+      --if t.tag ~= 'Id' and not literal_tags[t.tag] then
+         local v = gg.gensym 'v'
+         if not match_locals then match_locals = `Local{ {v}, {t} } else
+            table.insert(match_locals[1], v)
+            table.insert(match_locals[2], t)
+         end
+         new_term_seq[i] = v
+      --end
+   end
+   term_seq = new_term_seq
+
+   if match_locals then acc_stat(match_locals, cfg) end
+
+   for i=1, #cases do
+      local case_cfg = {
+         after_success    = cfg.after_success,
+         code             = `Do{ }
+         -- locals    = { } -- unnecessary, done by pattern_seq_builder
+      }
+      case_builder (cases[i], term_seq, case_cfg)
+      if next (case_cfg.locals) then
+         local case_locals = { }
+         table.insert (case_cfg.code, 1, `Local{ case_locals, { } })
+         for v, _ in pairs (case_cfg.locals) do
+            table.insert (case_locals, `Id{ v })
+         end
+      end
+      acc_stat(case_cfg.code, cfg)
+  end
+  local li = `String{tostring(cases.lineinfo)}
+  acc_stat(+{error('mismatch at '..-{li})}, cfg)
+  acc_stat(`Label{cfg.after_success}, cfg)
+  return cfg.code
+end
+
+----------------------------------------------------------------------
+-- Syntactical front-end
+----------------------------------------------------------------------
+
+local function extend(M)
+
+    local _M = gg.future(M)
+
+    checks('metalua.compiler.parser')
+    M.lexer:add{ "match", "with", "->" }
+    M.block.terminators:add "|"
+
+    local match_cases_list_parser = gg.list{ name = "match cases list",
+        gg.sequence{ name = "match case",
+                     gg.list{ name  = "match case patterns list",
+                              primary     = _M.expr_list,
+                              separators  = "|",
+                              terminators = { "->", "if" } },
+                     gg.onkeyword{ "if", _M.expr, consume = true },
+                     "->",
+                     _M.block },
+        separators  = "|",
+        terminators = "end" }
+
+    M.stat:add{ name = "match statement",
+                  "match",
+                  _M.expr_list,
+                  "with", gg.optkeyword "|",
+                  match_cases_list_parser,
+                  "end",
+                  builder = |x| match_builder{ x[1], x[3] } }
+end
+
+return extend
\ No newline at end of file