]> git.lizzy.rs Git - metalua.git/blob - src/lib/metalua/extension/match.mlua
fbc032ccdd9a82d614a9682169dba36a23c5ee95
[metalua.git] / src / lib / metalua / extension / match.mlua
1 ----------------------------------------------------------------------
2 -- Metalua samples:  $Id$
3 --
4 -- Summary: Structural pattern matching for metalua ADT.
5 --
6 ----------------------------------------------------------------------
7 --
8 -- Copyright (c) 2006-2008, Fabien Fleutot <metalua@gmail.com>.
9 --
10 -- This software is released under the MIT Licence, see licence.txt
11 -- for details.
12 --
13 --------------------------------------------------------------------------------
14 --
15 -- Glossary:
16 --
17 -- * term_seq: the tested stuff, a sequence of terms
18 -- * pattern_element: might match one term of a term seq. Represented
19 --   as expression ASTs.
20 -- * pattern_seq: might match a term_seq
21 -- * pattern_group: several pattern seqs, one of them might match
22 --                  the term seq.
23 -- * case: pattern_group * guard option * block
24 -- * match_statement: tested term_seq * case list 
25 --
26 -- Hence a complete match statement is a:
27 --
28 -- { list(expr),  list{ list(list(expr)), expr or false, block } } 
29 --
30 -- Implementation hints
31 -- ====================
32 --
33 -- The implementation is made as modular as possible, so that parts
34 -- can be reused in other extensions. The priviledged way to share
35 -- contextual information across functions is through the 'cfg' table
36 -- argument. Its fields include:
37 --
38 -- * code: code generated from pattern. A pattern_(element|seq|group)
39 --   is compiled as a sequence of instructions which will jump to
40 --   label [cfg.on_failure] if the tested term doesn't match.
41 --
42 -- * on_failure: name of the label where the code will jump if the
43 --   pattern doesn't match
44 --
45 -- * locals: names of local variables used by the pattern. This
46 --   includes bound variables, and temporary variables used to
47 --   destructure tables. Names are stored as keys of the table,
48 --   values are meaningless.
49 --
50 -- * after_success: label where the code must jump after a pattern
51 --   succeeded to capture a term, and the guard suceeded if there is
52 --   any, and the conditional block has run.
53 --
54 -- * ntmp: number of temporary variables used to destructurate table
55 --   in the current case.
56 --
57 -- Code generation is performed by acc_xxx() functions, which accumulate
58 -- code in cfg.code:
59 --
60 -- * acc_test(test, cfg) will generate a jump to cfg.on_failure 
61 --   *when the test returns TRUE*
62 --
63 -- * acc_stat accumulates a statement
64 -- 
65 -- * acc_assign accumulate an assignment statement, and makes sure that 
66 --   the LHS variable the registered as local in cfg.locals.
67 --   
68 ----------------------------------------------------------------------
69
70 -- TODO: hygiene wrt type()
71 -- TODO: cfg.ntmp isn't reset as often as it could. I'm not even sure
72 --       the corresponding locals are declared.
73
74 module ('spmatch', package.seeall)
75
76 tmpvar_base = mlp.gensym 'submatch.' [1]
77 function next_tmpvar(cfg)
78    assert (cfg.ntmp, "No cfg.ntmp imbrication level in the match compiler")
79    cfg.ntmp = cfg.ntmp+1
80    return `Id{ tmpvar_base .. cfg.ntmp }
81 end
82
83 -- Code accumulators
84 acc_stat = |x,cfg| table.insert (cfg.code, x)
85 acc_test = |x,cfg| acc_stat(+{stat: if -{x} then -{`Goto{cfg.on_failure}} end}, cfg)
86 -- lhs :: `Id{ string }
87 -- rhs :: expr
88 function acc_assign (lhs, rhs, cfg)
89    assert(lhs.tag=='Id')
90    cfg.locals[lhs[1]] = true
91    acc_stat (`Set{ {lhs}, {rhs} }, cfg)
92 end
93
94 literal_tags = table.transpose{ 'String', 'Number', 'True', 'False', 'Nil' }
95
96 -- pattern :: `Id{ string }
97 -- term    :: expr
98 function id_pattern_element_builder (pattern, term, cfg)
99    assert (pattern.tag == "Id")
100    if pattern[1] == "_" then 
101       -- "_" is used as a dummy var ==> no assignment, no == checking
102       cfg.locals._ = true
103    elseif cfg.locals[pattern[1]] then 
104       -- This var is already bound ==> test for equality
105       acc_test (+{ -{term} ~= -{pattern} }, cfg)
106    else
107       -- Free var ==> bind it, and remember it for latter linearity checking
108       acc_assign (pattern, term, cfg) 
109       cfg.locals[pattern[1]] = true
110    end
111 end
112
113 -- Concatenate code in [cfg.code], that will jump to label
114 -- [cfg.on_failure] if [pattern] doesn't match [term]. [pattern]
115 -- should be an identifier, or at least cheap to compute and
116 -- side-effects free.
117 --
118 -- pattern :: pattern_element
119 -- term    :: expr
120 function pattern_element_builder (pattern, term, cfg)
121    if literal_tags[pattern.tag] then
122       acc_test (+{ -{term} ~= -{pattern} }, cfg)
123    elseif "Id" == pattern.tag then 
124       id_pattern_element_builder (pattern, term, cfg)
125    elseif "Op" == pattern.tag and "div" == pattern[1] then
126       regexp_pattern_element_builder (pattern, term, cfg)
127    elseif "Op" == pattern.tag and "eq" == pattern[1] then
128       eq_pattern_element_builder (pattern, term, cfg)
129    elseif "Table" == pattern.tag then
130       table_pattern_element_builder (pattern, term, cfg)
131    else 
132       error ("Invalid pattern: "..table.tostring(pattern, "nohash"))
133    end
134 end
135
136 function eq_pattern_element_builder (pattern, term, cfg)
137    local _, pat1, pat2 = unpack (pattern)
138    local ntmp_save = cfg.ntmp
139    pattern_element_builder (pat1, term, cfg)
140    cfg.ntmp = ntmp_save
141    pattern_element_builder (pat2, term, cfg)
142 end
143
144 -- pattern :: `Op{ 'div', string, list{`Id string} or `Id{ string }}
145 -- term    :: expr
146 function regexp_pattern_element_builder (pattern, term, cfg)
147    local op, regexp, sub_pattern = unpack(pattern)
148
149    -- Sanity checks --
150    assert (op=='div', "Don't know what to do with that op in a pattern")
151    assert (regexp.tag=="String", 
152            "Left hand side operand for '/' in a pattern must be "..
153            "a literal string representing a regular expression")
154    assert (sub_pattern.tag=="Table",
155            "Right hand side operand for '/' in a pattern must be "..
156            "an identifier or a list of identifiers")
157    for x in ivalues(sub_pattern) do
158       assert (x.tag=="Id" or x.tag=='Dots',
159               "Right hand side operand for '/' in a pattern must be "..
160               "a list of identifiers")
161    end
162
163    -- Regexp patterns can only match strings
164    acc_test (+{ type(-{term}) ~= 'string' }, cfg)
165    -- put all captures in a list
166    local capt_list  = +{ { string.strmatch(-{term}, -{regexp}) } }
167    -- save them in a var_n for recursive decomposition
168    local v2 = next_tmpvar(cfg)
169    acc_stat (+{stat: local -{v2} = -{capt_list} }, cfg)
170    -- was capture successful?
171    acc_test (+{ not next (-{v2}) }, cfg)
172    pattern_element_builder (sub_pattern, v2, cfg)
173 end
174
175 -- pattern :: pattern and `Table{ }
176 -- term    :: expr
177 function table_pattern_element_builder (pattern, term, cfg)
178    local seen_dots, len = false, 0
179    acc_test (+{ type( -{term} ) ~= "table" }, cfg)
180    for i = 1, #pattern do
181       local key, sub_pattern
182       if pattern[i].tag=="Pair" then -- Explicit key/value pair
183          key, sub_pattern = unpack (pattern[i])
184          assert (literal_tags[key.tag], "Invalid key")
185       else -- Implicit key
186          len, key, sub_pattern = len+1, `Number{ len+1 }, pattern[i]
187       end
188       
189       -- '...' can only appear in final position
190       -- Could be fixed actually...
191       assert (not seen_dots, "Wrongly placed `...' ")
192
193       if sub_pattern.tag == "Id" then 
194          -- Optimization: save a useless [ v(n+1)=v(n).key ]
195          id_pattern_element_builder (sub_pattern, `Index{ term, key }, cfg)
196          if sub_pattern[1] ~= "_" then 
197             acc_test (+{ -{sub_pattern} == nil }, cfg)
198          end
199       elseif sub_pattern.tag == "Dots" then
200          -- Remember to suppress arity checking
201          seen_dots = true
202       else
203          -- Business as usual:
204          local v2 = next_tmpvar(cfg)
205          acc_assign (v2, `Index{ term, key }, cfg)
206          pattern_element_builder (sub_pattern, v2, cfg)
207          -- TODO: restore ntmp?
208       end
209    end
210    if not seen_dots then -- Check arity
211       acc_test (+{ #-{term} ~= -{`Number{len}} }, cfg)
212    end
213 end
214
215 -- Jumps to [cfg.on_faliure] if pattern_seq doesn't match
216 -- term_seq.
217 function pattern_seq_builder (pattern_seq, term_seq, cfg)
218    if #pattern_seq ~= #term_seq then error ("Bad seq arity") end
219    cfg.locals = { } -- reset bound variables between alternatives
220    for i=1, #pattern_seq do
221       cfg.ntmp = 1 -- reset the tmp var generator
222       pattern_element_builder(pattern_seq[i], term_seq[i], cfg)
223    end
224 end
225
226 --------------------------------------------------
227 -- for each case i:
228 --   pattern_seq_builder_i:
229 --    * on failure, go to on_failure_i
230 --    * on success, go to on_success
231 --   label on_success:
232 --   block
233 --   goto after_success
234 --   label on_failure_i
235 --------------------------------------------------
236 function case_builder (case, term_seq, cfg)
237    local patterns_group, guard, block = unpack(case)
238    local on_success = mlp.gensym 'on_success' [1]
239    for i = 1, #patterns_group do
240       local pattern_seq = patterns_group[i]
241       cfg.on_failure = mlp.gensym 'match_fail' [1]
242       pattern_seq_builder (pattern_seq, term_seq, cfg)
243       if i<#patterns_group then
244          acc_stat (`Goto{on_success}, cfg)
245          acc_stat (`Label{cfg.on_failure}, cfg)
246       end
247    end
248    acc_stat (`Label{on_success}, cfg)
249    if guard then acc_test (+{not -{guard}}, cfg) end
250    block.tag = 'Do'
251    acc_stat (block, cfg)
252    acc_stat (`Goto{cfg.after_success}, cfg)
253    acc_stat (`Label{cfg.on_failure}, cfg)
254 end
255
256 function match_builder (x)
257    local term_seq, cases = unpack(x)
258    local cfg = { 
259       code          = `Do{ },
260       after_success = mlp.gensym "_after_success" }
261
262    local match_locals
263
264    -- Make sure that all tested terms are variables or literals
265    for i=1, #term_seq do
266       local t = term_seq[i]
267       -- Capture problem: the following would compile wrongly:
268       --    `match x with x -> end'
269       -- Temporary workaround: suppress the condition, so that
270       -- all external variables are copied into unique names.
271       --if t.tag ~= 'Id' and not literal_tags[t.tag] then
272          local v = mlp.gensym 'v'
273          if not match_locals then match_locals = `Local{ {v}, {t} } else
274             table.insert(match_locals[1], v)
275             table.insert(match_locals[2], t)
276          end
277          term_seq[i] = v
278       --end
279    end
280    
281    if match_locals then acc_stat(match_locals, cfg) end
282
283    for i=1, #cases do
284       local case_cfg = { 
285          after_success = cfg.after_success,
286          code         = `Do{ },
287          locals       = { } }
288       case_builder (cases[i], term_seq, case_cfg)
289       if next (case_cfg.locals) then
290          local case_locals = { }
291          table.insert (case_cfg.code, 1, `Local{ case_locals, { } })
292          for v in keys (case_cfg.locals) do
293             table.insert (case_locals, `Id{ v })
294          end
295       end
296       acc_stat(case_cfg.code, cfg)
297    end
298    acc_stat(+{error 'mismatch'}, cfg)
299    acc_stat(`Label{cfg.after_success}, cfg)
300    return cfg.code
301 end
302
303 ----------------------------------------------------------------------
304 -- Syntactical front-end
305 ----------------------------------------------------------------------
306
307 mlp.lexer:add{ "match", "with", "->" }
308 mlp.block.terminators:add "|"
309
310 match_cases_list_parser = gg.list{ name = "match cases list",
311    gg.sequence{ name = "match case",
312       gg.list{ name  = "match case patterns list",
313          primary     = mlp.expr_list,
314          separators  = "|",
315          terminators = { "->", "if" } },
316       gg.onkeyword{ "if", mlp.expr, consume = true },
317       "->",
318       mlp.block },
319    separators  = "|",
320    terminators = "end" }
321
322 mlp.stat:add{ name = "match statement",
323    "match", 
324    mlp.expr_list, 
325    "with", gg.optkeyword "|",
326    match_cases_list_parser,
327    "end",
328    builder = |x| match_builder{ x[1], x[3] } }
329