]> git.lizzy.rs Git - micro.git/blob - tools/testgen.go
2465ae0a5d5ed27ef37835957a9ab242d45c97ff
[micro.git] / tools / testgen.go
1 package main
2
3 import (
4         "fmt"
5         "io/ioutil"
6         "log"
7         "os"
8         "regexp"
9         "strings"
10
11         "github.com/robertkrimen/otto/ast"
12         "github.com/robertkrimen/otto/parser"
13 )
14
15 type walker struct {
16         nodes []ast.Node
17 }
18
19 func (w *walker) Enter(node ast.Node) ast.Visitor {
20         w.nodes = append(w.nodes, node)
21         return w
22 }
23
24 func (w *walker) Exit(node ast.Node) {
25 }
26
27 func getAllNodes(node ast.Node) []ast.Node {
28         w := &walker{}
29         ast.Walk(w, node)
30         return w.nodes
31 }
32
33 func getCalls(node ast.Node, name string) []*ast.CallExpression {
34         nodes := []*ast.CallExpression{}
35         for _, n := range getAllNodes(node) {
36                 if ce, ok := n.(*ast.CallExpression); ok {
37                         var calleeName string
38                         switch callee := ce.Callee.(type) {
39                         case *ast.Identifier:
40                                 calleeName = callee.Name
41                         case *ast.DotExpression:
42                                 calleeName = callee.Identifier.Name
43                         default:
44                                 continue
45                         }
46                         if calleeName == name {
47                                 nodes = append(nodes, ce)
48                         }
49                 }
50         }
51         return nodes
52 }
53
54 func getPropertyValue(node ast.Node, key string) ast.Expression {
55         for _, p := range node.(*ast.ObjectLiteral).Value {
56                 if p.Key == key {
57                         return p.Value
58                 }
59         }
60         return nil
61 }
62
63 type operation struct {
64         startLine   int
65         startColumn int
66         endLine     int
67         endColumn   int
68         text        []string
69 }
70
71 type check struct {
72         before     []string
73         operations []operation
74         after      []string
75 }
76
77 type test struct {
78         description string
79         checks      []check
80 }
81
82 func stringSliceToGoSource(slice []string) string {
83         var b strings.Builder
84         b.WriteString("[]string{\n")
85         for _, s := range slice {
86                 b.WriteString(fmt.Sprintf("%#v,\n", s))
87         }
88         b.WriteString("}")
89         return b.String()
90 }
91
92 func testToGoTest(test test, name string) string {
93         var b strings.Builder
94
95         b.WriteString("func Test")
96         b.WriteString(name)
97         b.WriteString("(t *testing.T) {\n")
98
99         for _, c := range test.checks {
100                 b.WriteString("check(\n")
101                 b.WriteString("t,\n")
102                 b.WriteString(fmt.Sprintf("%v,\n", stringSliceToGoSource(c.before)))
103                 b.WriteString("[]operation{\n")
104                 for _, op := range c.operations {
105                         b.WriteString("operation{\n")
106                         b.WriteString(fmt.Sprintf("start: Loc{%v, %v},\n", op.startColumn, op.startLine))
107                         b.WriteString(fmt.Sprintf("end: Loc{%v, %v},\n", op.endColumn, op.endLine))
108                         b.WriteString(fmt.Sprintf("text: %v,\n", stringSliceToGoSource(op.text)))
109                         b.WriteString("},\n")
110                 }
111                 b.WriteString("},\n")
112                 b.WriteString(fmt.Sprintf("%v,\n", stringSliceToGoSource(c.after)))
113                 b.WriteString(")\n")
114         }
115
116         b.WriteString("}\n")
117
118         return b.String()
119 }
120
121 func nodeToStringSlice(node ast.Node) []string {
122         var result []string
123         for _, s := range node.(*ast.ArrayLiteral).Value {
124                 result = append(result, s.(*ast.StringLiteral).Value)
125         }
126         return result
127 }
128
129 func nodeToStringSlice2(node ast.Node) []string {
130         var result []string
131         for _, o := range node.(*ast.ArrayLiteral).Value {
132                 result = append(result, getPropertyValue(o, "text").(*ast.StringLiteral).Value)
133         }
134         return result
135 }
136
137 func nodeToInt(node ast.Node) int {
138         return int(node.(*ast.NumberLiteral).Value.(int64))
139 }
140
141 func getChecks(node ast.Node) []check {
142         checks := []check{}
143
144         for _, ce := range getCalls(node, "testApplyEdits") {
145                 if len(ce.ArgumentList) != 3 {
146                         // Wrong function
147                         continue
148                 }
149
150                 before := nodeToStringSlice2(ce.ArgumentList[0])
151                 after := nodeToStringSlice2(ce.ArgumentList[2])
152
153                 var operations []operation
154                 for _, op := range ce.ArgumentList[1].(*ast.ArrayLiteral).Value {
155                         args := getPropertyValue(op, "range").(*ast.NewExpression).ArgumentList
156                         operations = append(operations, operation{
157                                 startLine:   nodeToInt(args[0]) - 1,
158                                 startColumn: nodeToInt(args[1]) - 1,
159                                 endLine:     nodeToInt(args[2]) - 1,
160                                 endColumn:   nodeToInt(args[3]) - 1,
161                                 text:        []string{getPropertyValue(op, "text").(*ast.StringLiteral).Value},
162                         })
163                 }
164
165                 checks = append(checks, check{before, operations, after})
166         }
167
168         for _, ce := range getCalls(node, "testApplyEditsWithSyncedModels") {
169                 if len(ce.ArgumentList) > 3 && ce.ArgumentList[3].(*ast.BooleanLiteral).Value {
170                         // inputEditsAreInvalid == true
171                         continue
172                 }
173
174                 before := nodeToStringSlice(ce.ArgumentList[0])
175                 after := nodeToStringSlice(ce.ArgumentList[2])
176
177                 var operations []operation
178                 for _, op := range getCalls(ce.ArgumentList[1], "editOp") {
179                         operations = append(operations, operation{
180                                 startLine:   nodeToInt(op.ArgumentList[0]) - 1,
181                                 startColumn: nodeToInt(op.ArgumentList[1]) - 1,
182                                 endLine:     nodeToInt(op.ArgumentList[2]) - 1,
183                                 endColumn:   nodeToInt(op.ArgumentList[3]) - 1,
184                                 text:        nodeToStringSlice(op.ArgumentList[4]),
185                         })
186                 }
187
188                 checks = append(checks, check{before, operations, after})
189         }
190
191         return checks
192 }
193
194 func getTests(node ast.Node) []test {
195         tests := []test{}
196         for _, ce := range getCalls(node, "test") {
197                 description := ce.ArgumentList[0].(*ast.StringLiteral).Value
198                 body := ce.ArgumentList[1].(*ast.FunctionLiteral).Body
199                 checks := getChecks(body)
200                 if len(checks) > 0 {
201                         tests = append(tests, test{description, checks})
202                 }
203         }
204         return tests
205 }
206
207 func main() {
208         var tests []test
209
210         for _, filename := range os.Args[1:] {
211                 source, err := ioutil.ReadFile(filename)
212                 if err != nil {
213                         log.Fatalln(err)
214                 }
215
216                 program, err := parser.ParseFile(nil, "", source, parser.IgnoreRegExpErrors)
217                 if err != nil {
218                         log.Fatalln(err)
219                 }
220
221                 tests = append(tests, getTests(program)...)
222         }
223
224         if len(tests) == 0 {
225                 log.Fatalln("no tests found!")
226         }
227
228         fmt.Println("// This file is generated from VSCode model tests by the testgen tool.")
229         fmt.Println("// DO NOT EDIT THIS FILE BY HAND; your changes will be overwritten!\n")
230         fmt.Println("package buffer")
231         fmt.Println(`import "testing"`)
232
233         re := regexp.MustCompile(`[^\w]`)
234         usedNames := map[string]bool{}
235
236         for _, test := range tests {
237                 name := strings.Title(strings.ToLower(test.description))
238                 name = re.ReplaceAllLiteralString(name, "")
239                 if name == "" {
240                         name = "Unnamed"
241                 }
242                 if usedNames[name] {
243                         for i := 2; ; i++ {
244                                 newName := fmt.Sprintf("%v_%v", name, i)
245                                 if !usedNames[newName] {
246                                         name = newName
247                                         break
248                                 }
249                         }
250                 }
251                 usedNames[name] = true
252
253                 fmt.Println(testToGoTest(test, name))
254         }
255 }