]> git.lizzy.rs Git - mt.git/blob - internal/mkserialize/mkserialize.go
be1ba223258f47ecf3a311d7f07e87709d80b67e
[mt.git] / internal / mkserialize / mkserialize.go
1 package main
2
3 import (
4         "bufio"
5         "flag"
6         "fmt"
7         "go/ast"
8         "go/printer"
9         "go/token"
10         "go/types"
11         "io"
12         "log"
13         "os"
14         "strconv"
15         "strings"
16
17         "golang.org/x/tools/go/ast/astutil"
18         "golang.org/x/tools/go/packages"
19 )
20
21 var (
22         pkg *packages.Package
23
24         serializeFmt   = make(map[string]string)
25         deserializeFmt = make(map[string]string)
26
27         uint8T = types.Universe.Lookup("uint8").Type()
28         byteT  = types.Universe.Lookup("byte").Type()
29
30         serialize   []*types.Named
31         inSerialize = make(map[string]bool)
32
33         consts = make(map[*ast.StructType][]*ast.Comment)
34 )
35
36 func structPragma(c *ast.Comment, sp *[]func(), expr string, de bool) {
37         fields := strings.SplitN(strings.TrimPrefix(c.Text, "//mt:"), " ", 2)
38         arg := ""
39         if len(fields) == 2 {
40                 arg = fields[1]
41         }
42         switch fields[0] {
43         case "const":
44                 tv, err := types.Eval(pkg.Fset, pkg.Types, c.Slash, arg)
45                 if err != nil {
46                         error(c.Pos(), err)
47                 }
48
49                 if de {
50                         fmt.Println("{")
51                         x := newVar()
52                         fmt.Println("var", x, typeStr(tv.Type))
53                         y := newVar()
54                         fmt.Println(y, ":=", arg)
55                         genSerialize(tv.Type, x, token.NoPos, nil, de)
56                         fmt.Println("if", x, "!=", y,
57                                 `{ chk(fmt.Errorf("const %v: %v",`, strconv.Quote(arg), ",", x, ")) }")
58                         fmt.Println("}")
59                 } else {
60                         v := newVar()
61                         fmt.Println("{", v, ":=", arg)
62                         genSerialize(tv.Type, v, c.Slash+token.Pos(len("//mt:const ")), nil, de)
63                         fmt.Println("}")
64                 }
65         case "assert":
66                 fmt.Printf("if !("+arg+") {", expr)
67                 fmt.Printf("chk(errors.New(%q))\n", "assertion failed: "+arg)
68                 fmt.Println("}")
69         case "zlib":
70                 if de {
71                         fmt.Println("{ r, err := zlib.NewReader(byteReader{r}); chk(err)")
72                         *sp = append(*sp, func() {
73                                 fmt.Println("chk(r.Close()) }")
74                         })
75                 } else {
76                         fmt.Println("{ w := zlib.NewWriter(w)")
77                         *sp = append(*sp, func() {
78                                 fmt.Println("chk(w.Close()) }")
79                         })
80                 }
81         case "lenhdr":
82                 if arg != "8" && arg != "16" && arg != "32" {
83                         error(c.Pos(), "usage: //mt:lenhdr (8|16|32)")
84                 }
85
86                 fmt.Println("{")
87
88                 if !de {
89                         fmt.Println("ow := w")
90                         fmt.Println("w := new(bytes.Buffer)")
91                 }
92
93                 var cg ast.CommentGroup
94                 if de {
95                         t := types.Universe.Lookup("uint" + arg).Type()
96                         fmt.Println("var n", t)
97                         genSerialize(t, "n", token.NoPos, nil, de)
98                         if arg == "64" {
99                                 fmt.Println(`if n > math.MaxInt64 { panic("too big len") }`)
100                         }
101                         fmt.Println("r := &io.LimitedReader{R: r, N: int64(n)}")
102                 } else {
103                         switch arg {
104                         case "8", "32":
105                                 cg.List = []*ast.Comment{{Text: "//mt:len" + arg}}
106                         case "16":
107                         }
108                 }
109
110                 *sp = append(*sp, func() {
111                         if de {
112                                 fmt.Println("if r.N > 0",
113                                         `{ chk(fmt.Errorf("%d bytes of trailing data", r.N)) }`)
114                         } else {
115                                 fmt.Println("{")
116                                 fmt.Println("buf := w")
117                                 fmt.Println("w := ow")
118                                 byteSlice := types.NewSlice(types.Typ[types.Byte])
119                                 genSerialize(byteSlice, "buf.Bytes()", token.NoPos, &cg, de)
120                                 fmt.Println("}")
121                         }
122
123                         fmt.Println("}")
124                 })
125         case "end":
126                 (*sp)[len(*sp)-1]()
127                 *sp = (*sp)[:len(*sp)-1]
128         case "if":
129                 fmt.Printf(strings.TrimPrefix(c.Text, "//mt:")+" {\n", expr)
130                 *sp = append(*sp, func() {
131                         fmt.Println("}")
132                 })
133         case "ifde":
134                 if !de {
135                         fmt.Println("/*")
136                 }
137         }
138 }
139
140 func genSerialize(t types.Type, expr string, pos token.Pos, doc *ast.CommentGroup, de bool) {
141         var lenhdr types.Type = types.Typ[types.Uint16]
142
143         useMethod := true
144         if doc != nil {
145                 for _, c := range doc.List {
146                         pragma := true
147                         switch c.Text {
148                         case "//mt:32to16":
149                                 t = types.Typ[types.Int16]
150                                 if de {
151                                         v := newVar()
152                                         fmt.Println("var", v, "int16")
153                                         defer fmt.Println(expr + " = int32(" + v + ")")
154                                         expr = v
155                                 } else {
156                                         expr = "int16(" + expr + ")"
157                                 }
158                                 pos = token.NoPos
159                         case "//mt:32tou16":
160                                 t = types.Typ[types.Uint16]
161                                 if de {
162                                         v := newVar()
163                                         fmt.Println("var", v, "uint16")
164                                         defer fmt.Println(expr + " = int32(" + v + ")")
165                                         expr = v
166                                 } else {
167                                         expr = "uint16(" + expr + ")"
168                                 }
169                                 pos = token.NoPos
170                         case "//mt:utf16":
171                                 t = types.NewSlice(types.Typ[types.Uint16])
172                                 if de {
173                                         v := newVar()
174                                         fmt.Println("var", v, typeStr(t))
175                                         defer fmt.Println(expr + " = string(utf16.Decode(" + v + "))")
176                                         expr = v
177                                 } else {
178                                         v := newVar()
179                                         fmt.Println(v, ":= utf16.Encode([]rune("+expr+"))")
180                                         expr = v
181                                 }
182                                 pos = token.NoPos
183                         case "//mt:raw":
184                                 lenhdr = nil
185                         case "//mt:len8":
186                                 lenhdr = types.Typ[types.Uint8]
187                         case "//mt:len32":
188                                 lenhdr = types.Typ[types.Uint32]
189                         case "//mt:opt":
190                                 fmt.Println("if err := pcall(func() {")
191                                 defer fmt.Println("}); err != nil && err != io.EOF",
192                                         "{ chk(err) }")
193                         default:
194                                 pragma = false
195                         }
196                         if pragma {
197                                 useMethod = false
198                         }
199                 }
200         }
201
202         str := types.TypeString(t, types.RelativeTo(pkg.Types))
203         if de {
204                 if or, ok := deserializeFmt[str]; ok {
205                         fmt.Println("{")
206                         fmt.Println("p := &" + expr)
207                         fmt.Print(or)
208                         fmt.Println("}")
209                         return
210                 }
211         } else {
212                 if or, ok := serializeFmt[str]; ok {
213                         fmt.Println("{")
214                         fmt.Println("x := " + expr)
215                         fmt.Print(or)
216                         fmt.Println("}")
217                         return
218                 }
219         }
220
221         expr = "(" + expr + ")"
222
223         switch t := t.(type) {
224         case *types.Named:
225                 if !useMethod {
226                         t := t.Underlying()
227                         genSerialize(t, "*(*"+typeStr(t)+")("+"&"+expr+")", pos, doc, de)
228                         return
229                 }
230
231                 method := "Serialize"
232                 if de {
233                         method = "Deserialize"
234                 }
235                 for i := 0; i < t.NumMethods(); i++ {
236                         m := t.Method(i)
237                         if m.Name() == method {
238                                 rw := "w"
239                                 if de {
240                                         rw = "r"
241                                 }
242                                 fmt.Println("chk(" + expr + "." + method + "(" + rw + "))")
243                                 return
244                         }
245                 }
246
247                 mkSerialize(t)
248
249                 fmt.Println("if err := pcall(func() {")
250                 if de {
251                         fmt.Println(expr + ".deserialize(r)")
252                 } else {
253                         fmt.Println(expr + ".serialize(w)")
254                 }
255                 fmt.Println("}); err != nil",
256                         `{`,
257                         `if err == io.EOF { chk(io.EOF) };`,
258                         `chk(fmt.Errorf("%s: %w", `+strconv.Quote(t.String())+`, err))`,
259                         `}`)
260         case *types.Struct:
261                 st := pos2node(pos)[0].(*ast.StructType)
262
263                 a := consts[st]
264                 b := st.Fields.List
265
266                 // Merge sorted slices.
267                 c := make([]ast.Node, 0, len(a)+len(b))
268                 for i, j := 0, 0; i < len(a) || j < len(b); {
269                         if i < len(a) && (j >= len(b) || a[i].Pos() < b[j].Pos()) {
270                                 c = append(c, a[i])
271                                 i++
272                         } else {
273                                 c = append(c, b[j])
274                                 j++
275                         }
276                 }
277
278                 var (
279                         stk []func()
280                         i   int
281                 )
282                 for _, field := range c {
283                         switch field := field.(type) {
284                         case *ast.Comment:
285                                 structPragma(field, &stk, expr, de)
286                         case *ast.Field:
287                                 n := len(field.Names)
288                                 if n == 0 {
289                                         n = 1
290                                 }
291                                 for ; n > 0; n-- {
292                                         f := t.Field(i)
293                                         genSerialize(f.Type(), expr+"."+f.Name(), field.Type.Pos(), field.Doc, de)
294                                         i++
295                                 }
296                         }
297                 }
298
299                 if len(stk) > 0 {
300                         error(pos, "missing //mt:end")
301                 }
302         case *types.Basic:
303                 switch t.Kind() {
304                 case types.String:
305                         byteSlice := types.NewSlice(types.Typ[types.Byte])
306                         if de {
307                                 v := newVar()
308                                 fmt.Println("var", v, byteSlice)
309                                 genSerialize(byteSlice, v, token.NoPos, doc, de)
310                                 fmt.Println(expr, "=", "string(", v, ")")
311                         } else {
312                                 genSerialize(byteSlice, "[]byte"+expr, token.NoPos, doc, de)
313                         }
314                 default:
315                         error(pos, "can't serialize ", t)
316                 }
317         case *types.Slice:
318                 if de {
319                         if lenhdr != nil {
320                                 v := newVar()
321                                 fmt.Println("var", v, lenhdr)
322                                 genSerialize(lenhdr, v, pos, nil, de)
323                                 fmt.Printf("%s = make(%v, %s)\n",
324                                         expr, typeStr(t), v)
325                                 genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
326                         } else {
327                                 if b, ok := t.Elem().(*types.Basic); ok && b.Kind() == types.Byte {
328                                         fmt.Println("{")
329                                         fmt.Println("var err error")
330                                         fmt.Println(expr, ", err = io.ReadAll(r)")
331                                         fmt.Println("chk(err)")
332                                         fmt.Println("}")
333                                         return
334                                 }
335
336                                 fmt.Println("for {")
337                                 v := newVar()
338                                 fmt.Println("var", v, typeStr(t.Elem()))
339                                 fmt.Println("err := pcall(func() {")
340                                 if pos.IsValid() {
341                                         pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
342                                 }
343                                 genSerialize(t.Elem(), v, pos, nil, de)
344                                 fmt.Println("})")
345                                 fmt.Println("if err == io.EOF { break }")
346                                 fmt.Println(expr + " = append(" + expr + ", " + v + ")")
347                                 fmt.Println("chk(err)")
348                                 fmt.Println("}")
349                         }
350                 } else {
351                         if lenhdr != nil {
352                                 fmt.Println("if len("+expr+") >",
353                                         "math.Max"+strings.Title(lenhdr.String()),
354                                         "{ chk(ErrTooLong) }")
355                                 genSerialize(lenhdr, lenhdr.String()+"(len("+expr+"))", pos, nil, de)
356                         }
357                         genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
358                 }
359         case *types.Array:
360                 et := t.Elem()
361                 if et == byteT || et == uint8T {
362                         if de {
363                                 fmt.Println("{",
364                                         "_, err := io.ReadFull(r, "+expr+"[:]);",
365                                         "chk(err)",
366                                         "}")
367                         } else {
368                                 fmt.Println("{",
369                                         "_, err := w.Write("+expr+"[:]);",
370                                         "chk(err)",
371                                         "}")
372                         }
373                         break
374                 }
375                 i := newVar()
376                 fmt.Println("for", i, ":= range", expr, "{")
377                 if pos.IsValid() {
378                         pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
379                 }
380                 genSerialize(et, expr+"["+i+"]", pos, nil, de)
381                 fmt.Println("}")
382         default:
383                 error(pos, "can't serialize ", t)
384         }
385 }
386
// readOverrides loads hand-written code templates from the file at path into
// override.
//
// Every line must end in a newline. Blank lines are skipped, and every other
// line must contain a tab. Text before the first tab, if non-empty, names
// the entry that the following lines belong to; an empty first column
// continues the previous entry. Text after the first tab, including the
// newline, is appended to override[name]. Lines that come before any name
// has been given are printed to standard output unchanged.
//
// Any I/O or format error is fatal.
func readOverrides(path string, override map[string]string) {
	file, err := os.Open(path)
	if err != nil {
		log.Fatal(err)
	}
	defer file.Close()

	rd := bufio.NewReader(file)
	key := ""
	for lineNo := 1; ; lineNo++ {
		text, err := rd.ReadString('\n')
		switch {
		case err == io.EOF:
			if text != "" {
				log.Fatal("no newline at end of ", file.Name())
			}
			return
		case err != nil:
			log.Fatal(err)
		}

		if text == "\n" {
			continue
		}

		tab := strings.IndexByte(text, '\t')
		if tab < 0 {
			log.Fatal(file.Name(), ":", lineNo, ": missing tab")
		}
		if name := text[:tab]; name != "" {
			key = name
		}
		code := text[tab+1:]

		if key == "" {
			fmt.Print(code)
		} else {
			override[key] += code
		}
	}
}
430
431 func mkSerialize(t *types.Named) {
432         if !inSerialize[t.String()] {
433                 serialize = append(serialize, t)
434                 inSerialize[t.String()] = true
435         }
436 }
437
// varNo is the number of temporaries handed out so far by newVar.
var varNo int

// newVar returns a fresh identifier (local1, local2, ...) for a temporary
// variable in the generated code.
func newVar() string {
	varNo += 1
	return fmt.Sprintf("local%d", varNo)
}
444
445 func pos2node(pos token.Pos) []ast.Node {
446         return interval2node(pos, pos)
447 }
448
449 func interval2node(start, end token.Pos) []ast.Node {
450         for _, f := range pkg.Syntax {
451                 if f.Pos() <= start && end <= f.End() {
452                         if path, _ := astutil.PathEnclosingInterval(f, start, end); path != nil {
453                                 return path
454                         }
455                 }
456         }
457         return nil
458 }
459
460 func error(pos token.Pos, a ...interface{}) {
461         if !pos.IsValid() {
462                 log.Fatal(a...)
463         }
464         log.Fatal(append([]interface{}{pkg.Fset.Position(pos), ": "}, a...)...)
465 }
466
467 func typeStr(t types.Type) string {
468         return types.TypeString(t, func(p *types.Package) string {
469                 if p == pkg.Types {
470                         return ""
471                 }
472
473                 return p.Name()
474         })
475 }
476
// typeNames lists the types in the processed package that main queues for
// generated serialize/deserialize methods, in this order. Names not declared
// in the package are logged and skipped. Other named types they reference
// are queued later by genSerialize via mkSerialize.
var typeNames = []string{
	"ToSrvNil",
	"ToSrvInit",
	"ToSrvInit2",
	"ToSrvJoinModChan",
	"ToSrvLeaveModChan",
	"ToSrvMsgModChan",
	"ToSrvPlayerPos",
	"ToSrvGotBlks",
	"ToSrvDeletedBlks",
	"ToSrvInvAction",
	"ToSrvChatMsg",
	"ToSrvFallDmg",
	"ToSrvSelectItem",
	"ToSrvRespawn",
	"ToSrvInteract",
	"ToSrvRemovedSounds",
	"ToSrvNodeMetaFields",
	"ToSrvInvFields",
	"ToSrvReqMedia",
	"ToSrvCltReady",
	"ToSrvFirstSRP",
	"ToSrvSRPBytesA",
	"ToSrvSRPBytesM",

	"ToCltHello",
	"ToCltAcceptAuth",
	"ToCltAcceptSudoMode",
	"ToCltDenySudoMode",
	"ToCltKick",
	"ToCltBlkData",
	"ToCltAddNode",
	"ToCltRemoveNode",
	"ToCltInv",
	"ToCltTimeOfDay",
	"ToCltCSMRestrictionFlags",
	"ToCltAddPlayerVel",
	"ToCltMediaPush",
	"ToCltChatMsg",
	"ToCltAORmAdd",
	"ToCltAOMsgs",
	"ToCltHP",
	"ToCltMovePlayer",
	"ToCltLegacyKick",
	"ToCltFOV",
	"ToCltDeathScreen",
	"ToCltMedia",
	"ToCltNodeDefs",
	"ToCltAnnounceMedia",
	"ToCltItemDefs",
	"ToCltPlaySound",
	"ToCltStopSound",
	"ToCltPrivs",
	"ToCltInvFormspec",
	"ToCltDetachedInv",
	"ToCltShowFormspec",
	"ToCltMovement",
	"ToCltSpawnParticle",
	"ToCltAddParticleSpawner",
	"ToCltAddHUD",
	"ToCltRmHUD",
	"ToCltChangeHUD",
	"ToCltHUDFlags",
	"ToCltSetHotbarParam",
	"ToCltBreath",
	"ToCltSkyParams",
	"ToCltOverrideDayNightRatio",
	"ToCltLocalPlayerAnim",
	"ToCltEyeOffset",
	"ToCltDelParticleSpawner",
	"ToCltCloudParams",
	"ToCltFadeSound",
	"ToCltUpdatePlayerList",
	"ToCltModChanMsg",
	"ToCltModChanSig",
	"ToCltNodeMetasChanged",
	"ToCltSunParams",
	"ToCltMoonParams",
	"ToCltStarParams",
	"ToCltSRPBytesSaltB",
	"ToCltFormspecPrepend",

	"AOCmdProps",
	"AOCmdPos",
	"AOCmdTextureMod",
	"AOCmdSprite",
	"AOCmdHP",
	"AOCmdArmorGroups",
	"AOCmdAnim",
	"AOCmdBonePos",
	"AOCmdAttach",
	"AOCmdPhysOverride",
	"AOCmdSpawnInfant",
	"AOCmdAnimSpeed",

	"NodeMeta",
	"MinimapMode",
	"NodeDef",
	"PointedNode",
	"PointedAO",
}
578
579 func main() {
580         log.SetFlags(0)
581         log.SetPrefix("mkserialize: ")
582
583         flag.Parse()
584
585         cfg := &packages.Config{Mode: packages.NeedSyntax |
586                 packages.NeedName |
587                 packages.NeedDeps |
588                 packages.NeedImports |
589                 packages.NeedTypes |
590                 packages.NeedTypesInfo}
591         pkgs, err := packages.Load(cfg, flag.Args()...)
592         if err != nil {
593                 log.Fatal(err)
594         }
595         if packages.PrintErrors(pkgs) > 0 {
596                 os.Exit(1)
597         }
598
599         if len(pkgs) != 1 {
600                 log.Fatal("must be exactly 1 package")
601         }
602         pkg = pkgs[0]
603
604         fmt.Println("package", pkg.Name)
605
606         readOverrides("serialize.fmt", serializeFmt)
607         readOverrides("deserialize.fmt", deserializeFmt)
608
609         for _, f := range pkg.Syntax {
610                 for _, cg := range f.Comments {
611                         for _, c := range cg.List {
612                                 if !strings.HasPrefix(c.Text, "//mt:") {
613                                         continue
614                                 }
615                                 st := interval2node(c.Pos(), c.End())[1].(*ast.StructType)
616                                 consts[st] = append(consts[st], c)
617                         }
618                 }
619         }
620
621         for _, name := range typeNames {
622                 obj := pkg.Types.Scope().Lookup(name)
623                 if obj == nil {
624                         log.Println("undeclared identifier: ", name)
625                         continue
626                 }
627                 mkSerialize(obj.Type().(*types.Named))
628         }
629
630         for i := 0; i < len(serialize); i++ {
631                 for _, de := range []bool{false, true} {
632                         t := serialize[i]
633                         sig := "serialize(w io.Writer)"
634                         if de {
635                                 sig = "deserialize(r io.Reader)"
636                         }
637                         fmt.Println("\nfunc (obj *" + t.Obj().Name() + ") " + sig + " {")
638                         pos := t.Obj().Pos()
639                         tExpr := pos2node(pos)[1].(*ast.TypeSpec).Type
640                         var b strings.Builder
641                         printer.Fprint(&b, pkg.Fset, tExpr)
642                         genSerialize(pkg.TypesInfo.Types[tExpr].Type, "*(*("+b.String()+"))(obj)", tExpr.Pos(), nil, de)
643                         fmt.Println("}")
644                 }
645         }
646 }