]> git.lizzy.rs Git - mt.git/blob - internal/mkserialize/mkserialize.go
s/ToSrvModChan(.*)/ToSrv\1ModChan/g
[mt.git] / internal / mkserialize / mkserialize.go
1 package main
2
3 import (
4         "bufio"
5         "flag"
6         "fmt"
7         "go/ast"
8         "go/printer"
9         "go/token"
10         "go/types"
11         "io"
12         "log"
13         "os"
14         "strconv"
15         "strings"
16
17         "golang.org/x/tools/go/ast/astutil"
18         "golang.org/x/tools/go/packages"
19 )
20
21 var (
22         pkg *packages.Package
23
24         serializeFmt   = make(map[string]string)
25         deserializeFmt = make(map[string]string)
26
27         uint8T = types.Universe.Lookup("uint8").Type()
28         byteT  = types.Universe.Lookup("byte").Type()
29
30         serialize   []*types.Named
31         inSerialize = make(map[string]bool)
32
33         consts = make(map[*ast.StructType][]*ast.Comment)
34 )
35
36 func structPragma(c *ast.Comment, sp *[]func(), expr string, de bool) {
37         fields := strings.SplitN(strings.TrimPrefix(c.Text, "//mt:"), " ", 2)
38         arg := ""
39         if len(fields) == 2 {
40                 arg = fields[1]
41         }
42         switch fields[0] {
43         case "const":
44                 tv, err := types.Eval(pkg.Fset, pkg.Types, c.Slash, arg)
45                 if err != nil {
46                         error(c.Pos(), err)
47                 }
48
49                 if de {
50                         fmt.Println("{")
51                         x := newVar()
52                         fmt.Println("var", x, typeStr(tv.Type))
53                         genSerialize(tv.Type, x, token.NoPos, nil, de)
54                         fmt.Println("if", x, "!=", "(", tv.Value, ")",
55                                 `{ chk(fmt.Errorf("const %v: %v",`, tv.Value, ",", x, ")) }")
56                         fmt.Println("}")
57                 } else {
58                         v := newVar()
59                         fmt.Println("{", v, ":=", arg)
60                         genSerialize(tv.Type, v, c.Slash+token.Pos(len("//mt:const ")), nil, de)
61                         fmt.Println("}")
62                 }
63         case "assert":
64                 fmt.Printf("if !("+arg+") {", expr)
65                 fmt.Printf("chk(errors.New(%q))\n", "assertion failed: "+arg)
66                 fmt.Println("}")
67         case "zlib":
68                 if de {
69                         fmt.Println("{ r, err := zlib.NewReader(byteReader{r}); chk(err)")
70                         *sp = append(*sp, func() {
71                                 fmt.Println("chk(r.Close()) }")
72                         })
73                 } else {
74                         fmt.Println("{ w := zlib.NewWriter(w)")
75                         *sp = append(*sp, func() {
76                                 fmt.Println("chk(w.Close()) }")
77                         })
78                 }
79         case "lenhdr":
80                 if arg != "8" && arg != "16" && arg != "32" {
81                         error(c.Pos(), "usage: //mt:lenhdr (8|16|32)")
82                 }
83
84                 fmt.Println("{")
85
86                 if !de {
87                         fmt.Println("ow := w")
88                         fmt.Println("w := new(bytes.Buffer)")
89                 }
90
91                 var cg ast.CommentGroup
92                 if de {
93                         t := types.Universe.Lookup("uint" + arg).Type()
94                         fmt.Println("var n", t)
95                         genSerialize(t, "n", token.NoPos, nil, de)
96                         if arg == "64" {
97                                 fmt.Println(`if n > math.MaxInt64 { panic("too big len") }`)
98                         }
99                         fmt.Println("r := &io.LimitedReader{r, int64(n)}")
100                 } else {
101                         switch arg {
102                         case "8", "32":
103                                 cg.List = []*ast.Comment{{Text: "//mt:len" + arg}}
104                         case "16":
105                         }
106                 }
107
108                 *sp = append(*sp, func() {
109                         if de {
110                                 fmt.Println("if r.N > 0",
111                                         `{ chk(fmt.Errorf("%d bytes of trailing data", r.N)) }`)
112                         } else {
113                                 fmt.Println("{")
114                                 fmt.Println("buf := w")
115                                 fmt.Println("w := ow")
116                                 byteSlice := types.NewSlice(types.Typ[types.Byte])
117                                 genSerialize(byteSlice, "buf.Bytes()", token.NoPos, &cg, de)
118                                 fmt.Println("}")
119                         }
120
121                         fmt.Println("}")
122                 })
123         case "end":
124                 (*sp)[len(*sp)-1]()
125                 *sp = (*sp)[:len(*sp)-1]
126         case "if":
127                 fmt.Printf(strings.TrimPrefix(c.Text, "//mt:")+" {\n", expr)
128                 *sp = append(*sp, func() {
129                         fmt.Println("}")
130                 })
131         case "ifde":
132                 if !de {
133                         fmt.Println("/*")
134                 }
135         }
136 }
137
138 func genSerialize(t types.Type, expr string, pos token.Pos, doc *ast.CommentGroup, de bool) {
139         var lenhdr types.Type = types.Typ[types.Uint16]
140
141         useMethod := true
142         if doc != nil {
143                 for _, c := range doc.List {
144                         pragma := true
145                         switch c.Text {
146                         case "//mt:utf16":
147                                 t = types.NewSlice(types.Typ[types.Uint16])
148                                 if de {
149                                         v := newVar()
150                                         fmt.Println("var", v, typeStr(t))
151                                         defer fmt.Println(expr + " = string(utf16.Decode(" + v + "))")
152                                         expr = v
153                                 } else {
154                                         expr = "utf16.Encode([]rune(" + expr + "))"
155                                 }
156                                 pos = token.NoPos
157                         case "//mt:raw":
158                                 lenhdr = nil
159                         case "//mt:len8":
160                                 lenhdr = types.Typ[types.Uint8]
161                         case "//mt:len32":
162                                 lenhdr = types.Typ[types.Uint32]
163                         case "//mt:opt":
164                                 fmt.Println("if err := pcall(func() {")
165                                 defer fmt.Println("}); err != nil && err != io.EOF",
166                                         "{ chk(err) }")
167                         default:
168                                 pragma = false
169                         }
170                         if pragma {
171                                 useMethod = false
172                         }
173                 }
174         }
175
176         str := types.TypeString(t, types.RelativeTo(pkg.Types))
177         if de {
178                 if or, ok := deserializeFmt[str]; ok {
179                         fmt.Println("{")
180                         fmt.Println("p := &" + expr)
181                         fmt.Print(or)
182                         fmt.Println("}")
183                         return
184                 }
185         } else {
186                 if or, ok := serializeFmt[str]; ok {
187                         fmt.Println("{")
188                         fmt.Println("x := " + expr)
189                         fmt.Print(or)
190                         fmt.Println("}")
191                         return
192                 }
193         }
194
195         expr = "(" + expr + ")"
196
197         switch t := t.(type) {
198         case *types.Named:
199                 if !useMethod {
200                         t := t.Underlying()
201                         genSerialize(t, "*(*"+typeStr(t)+")("+"&"+expr+")", pos, doc, de)
202                         return
203                 }
204
205                 method := "Serialize"
206                 if de {
207                         method = "Deserialize"
208                 }
209                 for i := 0; i < t.NumMethods(); i++ {
210                         m := t.Method(i)
211                         if m.Name() == method {
212                                 rw := "w"
213                                 if de {
214                                         rw = "r"
215                                 }
216                                 fmt.Println("chk(" + expr + "." + method + "(" + rw + "))")
217                                 return
218                         }
219                 }
220
221                 mkSerialize(t)
222
223                 fmt.Println("if err := pcall(func() {")
224                 if de {
225                         fmt.Println(expr + ".deserialize(r)")
226                 } else {
227                         fmt.Println(expr + ".serialize(w)")
228                 }
229                 fmt.Println("}); err != nil",
230                         `{`,
231                         `if err == io.EOF { chk(io.EOF) };`,
232                         `chk(fmt.Errorf("%s: %w", `+strconv.Quote(t.String())+`, err))`,
233                         `}`)
234         case *types.Struct:
235                 st := pos2node(pos)[0].(*ast.StructType)
236
237                 a := consts[st]
238                 b := st.Fields.List
239
240                 // Merge sorted slices.
241                 c := make([]ast.Node, 0, len(a)+len(b))
242                 for i, j := 0, 0; i < len(a) || j < len(b); {
243                         if i < len(a) && (j >= len(b) || a[i].Pos() < b[j].Pos()) {
244                                 c = append(c, a[i])
245                                 i++
246                         } else {
247                                 c = append(c, b[j])
248                                 j++
249                         }
250                 }
251
252                 var (
253                         stk []func()
254                         i   int
255                 )
256                 for _, field := range c {
257                         switch field := field.(type) {
258                         case *ast.Comment:
259                                 structPragma(field, &stk, expr, de)
260                         case *ast.Field:
261                                 n := len(field.Names)
262                                 if n == 0 {
263                                         n = 1
264                                 }
265                                 for ; n > 0; n-- {
266                                         f := t.Field(i)
267                                         genSerialize(f.Type(), expr+"."+f.Name(), field.Type.Pos(), field.Doc, de)
268                                         i++
269                                 }
270                         }
271                 }
272
273                 if len(stk) > 0 {
274                         error(pos, "missing //mt:end")
275                 }
276         case *types.Basic:
277                 switch t.Kind() {
278                 case types.String:
279                         byteSlice := types.NewSlice(types.Typ[types.Byte])
280                         if de {
281                                 v := newVar()
282                                 fmt.Println("var", v, byteSlice)
283                                 genSerialize(byteSlice, v, token.NoPos, doc, de)
284                                 fmt.Println(expr, "=", "string(", v, ")")
285                         } else {
286                                 genSerialize(byteSlice, "[]byte"+expr, token.NoPos, doc, de)
287                         }
288                 default:
289                         error(pos, "can't serialize ", t)
290                 }
291         case *types.Slice:
292                 if de {
293                         if lenhdr != nil {
294                                 v := newVar()
295                                 fmt.Println("var", v, lenhdr)
296                                 genSerialize(lenhdr, v, pos, nil, de)
297                                 fmt.Printf("%s = make(%v, %s)\n",
298                                         expr, typeStr(t), v)
299                                 genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
300                         } else {
301                                 fmt.Println("for {")
302                                 v := newVar()
303                                 fmt.Println("var", v, typeStr(t.Elem()))
304                                 fmt.Println("err := pcall(func() {")
305                                 if pos.IsValid() {
306                                         pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
307                                 }
308                                 genSerialize(t.Elem(), v, pos, nil, de)
309                                 fmt.Println("})")
310                                 fmt.Println("if err == io.EOF { break }")
311                                 fmt.Println(expr + " = append(" + expr + ", " + v + ")")
312                                 fmt.Println("chk(err)")
313                                 fmt.Println("}")
314                         }
315                 } else {
316                         if lenhdr != nil {
317                                 fmt.Println("if len("+expr+") >",
318                                         "math.Max"+strings.Title(lenhdr.String()),
319                                         "{ chk(ErrTooLong) }")
320                                 genSerialize(lenhdr, lenhdr.String()+"(len("+expr+"))", pos, nil, de)
321                         }
322                         genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
323                 }
324         case *types.Array:
325                 et := t.Elem()
326                 if et == byteT || et == uint8T {
327                         if de {
328                                 fmt.Println("{",
329                                         "_, err := io.ReadFull(r, "+expr+"[:]);",
330                                         "chk(err)",
331                                         "}")
332                         } else {
333                                 fmt.Println("{",
334                                         "_, err := w.Write("+expr+"[:]);",
335                                         "chk(err)",
336                                         "}")
337                         }
338                         break
339                 }
340                 i := newVar()
341                 fmt.Println("for", i, ":= range", expr, "{")
342                 if pos.IsValid() {
343                         pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
344                 }
345                 genSerialize(et, expr+"["+i+"]", pos, nil, de)
346                 fmt.Println("}")
347         default:
348                 error(pos, "can't serialize ", t)
349         }
350 }
351
// readOverrides loads the override file at path into override.
//
// Every line other than a bare "\n" must contain a tab.
//
// - Key column: text before the tab, when non-empty, becomes the current
//   key. It stays current for later lines whose first column is empty.
// - Body: text after the tab (including its newline) is appended to
//   override[key].
// - Preamble: while no key has been set yet, the body is printed to stdout
//   instead.
//
// Lines that are just "\n" are skipped. Open/read errors, a line without a
// tab, and a last line without a trailing newline are fatal.
func readOverrides(path string, override map[string]string) {
	file, err := os.Open(path)
	if err != nil {
		log.Fatal(err)
	}
	defer file.Close()

	rd := bufio.NewReader(file)
	key := ""
	for lineNo := 1; ; lineNo++ {
		text, err := rd.ReadString('\n')
		if err == io.EOF {
			if text != "" {
				log.Fatal("no newline at end of ", file.Name())
			}
			return
		}
		if err != nil {
			log.Fatal(err)
		}

		if text == "\n" {
			continue
		}

		tab := strings.IndexByte(text, '\t')
		if tab < 0 {
			log.Fatal(file.Name(), ":", lineNo, ": missing tab")
		}
		if name := text[:tab]; name != "" {
			key = name
		}
		body := text[tab+1:]

		if key == "" {
			fmt.Print(body)
		} else {
			override[key] += body
		}
	}
}
395
396 func mkSerialize(t *types.Named) {
397         if !inSerialize[t.String()] {
398                 serialize = append(serialize, t)
399                 inSerialize[t.String()] = true
400         }
401 }
402
// varNo counts the temporaries handed out so far by newVar.
var varNo int

// newVar returns a fresh identifier ("local1", "local2", ...) for a
// temporary variable in the generated code.
func newVar() string {
	varNo++
	return "local" + strconv.Itoa(varNo)
}
409
410 func pos2node(pos token.Pos) []ast.Node {
411         return interval2node(pos, pos)
412 }
413
414 func interval2node(start, end token.Pos) []ast.Node {
415         for _, f := range pkg.Syntax {
416                 if f.Pos() <= start && end <= f.End() {
417                         if path, _ := astutil.PathEnclosingInterval(f, start, end); path != nil {
418                                 return path
419                         }
420                 }
421         }
422         return nil
423 }
424
425 func error(pos token.Pos, a ...interface{}) {
426         if !pos.IsValid() {
427                 log.Fatal(a...)
428         }
429         log.Fatal(append([]interface{}{pkg.Fset.Position(pos), ": "}, a...)...)
430 }
431
432 func typeStr(t types.Type) string {
433         return types.TypeString(t, func(p *types.Package) string {
434                 if p == pkg.Types {
435                         return ""
436                 }
437
438                 return p.Name()
439         })
440 }
441
// typeNames lists the identifiers that main looks up in the package scope
// and queues for serialize/deserialize generation, in this order. Named
// types they reference are picked up transitively during generation.
var typeNames = []string{
	// ToSrv* types.
	"ToSrvNil",
	"ToSrvInit",
	"ToSrvInit2",
	"ToSrvJoinModChan",
	"ToSrvLeaveModChan",
	"ToSrvMsgModChan",
	"ToSrvPlayerPos",
	"ToSrvGotBlks",
	"ToSrvDeletedBlks",
	"ToSrvInvAction",
	"ToSrvChatMsg",
	"ToSrvFallDmg",
	"ToSrvSelectItem",
	"ToSrvRespawn",
	"ToSrvInteract",
	"ToSrvRemovedSounds",
	"ToSrvNodeMetaFields",
	"ToSrvInvFields",
	"ToSrvReqMedia",
	"ToSrvCltReady",
	"ToSrvFirstSRP",
	"ToSrvSRPBytesA",
	"ToSrvSRPBytesM",

	// ToClt* types.
	"ToCltHello",
	"ToCltAcceptAuth",
	"ToCltAcceptSudoMode",
	"ToCltDenySudoMode",
	"ToCltDisco",
	"ToCltBlkData",
	"ToCltAddNode",
	"ToCltRemoveNode",
	"ToCltInv",
	"ToCltTimeOfDay",
	"ToCltCSMRestrictionFlags",
	"ToCltAddPlayerVel",
	"ToCltMediaPush",
	"ToCltChatMsg",
	"ToCltAORmAdd",
	"ToCltAOMsgs",
	"ToCltHP",
	"ToCltMovePlayer",
	"ToCltDiscoLegacy",
	"ToCltFOV",
	"ToCltDeathScreen",
	"ToCltMedia",
	"ToCltNodeDefs",
	"ToCltAnnounceMedia",
	"ToCltItemDefs",
	"ToCltPlaySound",
	"ToCltStopSound",
	"ToCltPrivs",
	"ToCltInvFormspec",
	"ToCltDetachedInv",
	"ToCltShowFormspec",
	"ToCltMovement",
	"ToCltSpawnParticle",
	"ToCltAddParticleSpawner",
	"ToCltAddHUD",
	"ToCltRmHUD",
	"ToCltChangeHUD",
	"ToCltHUDFlags",
	"ToCltSetHotbarParam",
	"ToCltBreath",
	"ToCltSkyParams",
	"ToCltOverrideDayNightRatio",
	"ToCltLocalPlayerAnim",
	"ToCltEyeOffset",
	"ToCltDelParticleSpawner",
	"ToCltCloudParams",
	"ToCltFadeSound",
	"ToCltUpdatePlayerList",
	"ToCltModChanMsg",
	"ToCltModChanSig",
	"ToCltNodeMetasChanged",
	"ToCltSunParams",
	"ToCltMoonParams",
	"ToCltStarParams",
	"ToCltSRPBytesSaltB",
	"ToCltFormspecPrepend",

	// AOCmd* types.
	"AOCmdProps",
	"AOCmdPos",
	"AOCmdTextureMod",
	"AOCmdSprite",
	"AOCmdHP",
	"AOCmdArmorGroups",
	"AOCmdAnim",
	"AOCmdBonePos",
	"AOCmdAttach",
	"AOCmdPhysOverride",
	"AOCmdSpawnInfant",
	"AOCmdAnimSpeed",

	// Other types.
	"NodeMeta",
	"MinimapMode",
	"NodeDef",
	"PointedNode",
	"PointedAO",
}
543
544 func main() {
545         log.SetFlags(0)
546         log.SetPrefix("mkserialize: ")
547
548         flag.Parse()
549
550         cfg := &packages.Config{Mode: packages.NeedSyntax |
551                 packages.NeedName |
552                 packages.NeedDeps |
553                 packages.NeedImports |
554                 packages.NeedTypes |
555                 packages.NeedTypesInfo}
556         pkgs, err := packages.Load(cfg, flag.Args()...)
557         if err != nil {
558                 log.Fatal(err)
559         }
560         if packages.PrintErrors(pkgs) > 0 {
561                 os.Exit(1)
562         }
563
564         if len(pkgs) != 1 {
565                 log.Fatal("must be exactly 1 package")
566         }
567         pkg = pkgs[0]
568
569         fmt.Println("package", pkg.Name)
570
571         readOverrides("serialize.fmt", serializeFmt)
572         readOverrides("deserialize.fmt", deserializeFmt)
573
574         for _, f := range pkg.Syntax {
575                 for _, cg := range f.Comments {
576                         for _, c := range cg.List {
577                                 if !strings.HasPrefix(c.Text, "//mt:") {
578                                         continue
579                                 }
580                                 st := interval2node(c.Pos(), c.End())[1].(*ast.StructType)
581                                 consts[st] = append(consts[st], c)
582                         }
583                 }
584         }
585
586         for _, name := range typeNames {
587                 obj := pkg.Types.Scope().Lookup(name)
588                 if obj == nil {
589                         log.Println("undeclared identifier: ", name)
590                         continue
591                 }
592                 mkSerialize(obj.Type().(*types.Named))
593         }
594
595         for i := 0; i < len(serialize); i++ {
596                 for _, de := range []bool{false, true} {
597                         t := serialize[i]
598                         sig := "serialize(w io.Writer)"
599                         if de {
600                                 sig = "deserialize(r io.Reader)"
601                         }
602                         fmt.Println("\nfunc (obj *" + t.Obj().Name() + ") " + sig + " {")
603                         pos := t.Obj().Pos()
604                         tExpr := pos2node(pos)[1].(*ast.TypeSpec).Type
605                         var b strings.Builder
606                         printer.Fprint(&b, pkg.Fset, tExpr)
607                         genSerialize(pkg.TypesInfo.Types[tExpr].Type, "*(*("+b.String()+"))(obj)", tExpr.Pos(), nil, de)
608                         fmt.Println("}")
609                 }
610         }
611 }