]> git.lizzy.rs Git - mt.git/blob - internal/mkserialize/mkserialize.go
fix errors in usage of zstd
[mt.git] / internal / mkserialize / mkserialize.go
1 package main
2
3 import (
4         "bufio"
5         "flag"
6         "fmt"
7         "go/ast"
8         "go/printer"
9         "go/token"
10         "go/types"
11         "io"
12         "log"
13         "os"
14         "strconv"
15         "strings"
16
17         "golang.org/x/tools/go/ast/astutil"
18         "golang.org/x/tools/go/packages"
19 )
20
21 var (
22         pkg *packages.Package
23
24         serializeFmt   = make(map[string]string)
25         deserializeFmt = make(map[string]string)
26
27         uint8T = types.Universe.Lookup("uint8").Type()
28         byteT  = types.Universe.Lookup("byte").Type()
29
30         serialize   []*types.Named
31         inSerialize = make(map[string]bool)
32
33         consts = make(map[*ast.StructType][]*ast.Comment)
34 )
35
36 func structPragma(c *ast.Comment, sp *[]func(), expr string, de bool) {
37         fields := strings.SplitN(strings.TrimPrefix(c.Text, "//mt:"), " ", 2)
38         arg := ""
39         if len(fields) == 2 {
40                 arg = fields[1]
41         }
42         switch fields[0] {
43         case "const":
44                 tv, err := types.Eval(pkg.Fset, pkg.Types, c.Slash, arg)
45                 if err != nil {
46                         error(c.Pos(), err)
47                 }
48
49                 if de {
50                         fmt.Println("{")
51                         x := newVar()
52                         fmt.Println("var", x, typeStr(tv.Type))
53                         y := newVar()
54                         fmt.Println(y, ":=", arg)
55                         genSerialize(tv.Type, x, token.NoPos, nil, de)
56                         fmt.Println("if", x, "!=", y,
57                                 `{ chk(fmt.Errorf("const %v: %v",`, strconv.Quote(arg), ",", x, ")) }")
58                         fmt.Println("}")
59                 } else {
60                         v := newVar()
61                         fmt.Println("{", v, ":=", arg)
62                         genSerialize(tv.Type, v, c.Slash+token.Pos(len("//mt:const ")), nil, de)
63                         fmt.Println("}")
64                 }
65         case "assert":
66                 fmt.Printf("if !("+arg+") {", expr)
67                 fmt.Printf("chk(errors.New(%q))\n", "assertion failed: "+arg)
68                 fmt.Println("}")
69         case "zlib":
70                 if de {
71                         fmt.Println("{ r, err := zlib.NewReader(byteReader{r}); chk(err)")
72                         *sp = append(*sp, func() {
73                                 fmt.Println("chk(r.Close()) }")
74                         })
75                 } else {
76                         fmt.Println("{ w := zlib.NewWriter(w)")
77                         *sp = append(*sp, func() {
78                                 fmt.Println("chk(w.Close()) }")
79                         })
80                 }
81         case "zstd":
82                 if de {
83                         fmt.Println("{ r, err := zstd.NewReader(byteReader{r}); chk(err)")
84                         *sp = append(*sp, func() {
85                                 fmt.Println("r.Close() }")
86                         })
87                 } else {
88                         fmt.Println("{ w, err := zstd.NewWriter(w); chk(err)")
89                         *sp = append(*sp, func() {
90                                 fmt.Println("chk(w.Close()) }")
91                         })
92                 }
93         case "lenhdr":
94                 if arg != "8" && arg != "16" && arg != "32" {
95                         error(c.Pos(), "usage: //mt:lenhdr (8|16|32)")
96                 }
97
98                 fmt.Println("{")
99
100                 if !de {
101                         fmt.Println("ow := w")
102                         fmt.Println("w := new(bytes.Buffer)")
103                 }
104
105                 var cg ast.CommentGroup
106                 if de {
107                         t := types.Universe.Lookup("uint" + arg).Type()
108                         fmt.Println("var n", t)
109                         genSerialize(t, "n", token.NoPos, nil, de)
110                         if arg == "64" {
111                                 fmt.Println(`if n > math.MaxInt64 { panic("too big len") }`)
112                         }
113                         fmt.Println("r := &io.LimitedReader{R: r, N: int64(n)}")
114                 } else {
115                         switch arg {
116                         case "8", "32":
117                                 cg.List = []*ast.Comment{{Text: "//mt:len" + arg}}
118                         case "16":
119                         }
120                 }
121
122                 *sp = append(*sp, func() {
123                         if de {
124                                 fmt.Println("if r.N > 0",
125                                         `{ chk(fmt.Errorf("%d bytes of trailing data", r.N)) }`)
126                         } else {
127                                 fmt.Println("{")
128                                 fmt.Println("buf := w")
129                                 fmt.Println("w := ow")
130                                 byteSlice := types.NewSlice(types.Typ[types.Byte])
131                                 genSerialize(byteSlice, "buf.Bytes()", token.NoPos, &cg, de)
132                                 fmt.Println("}")
133                         }
134
135                         fmt.Println("}")
136                 })
137         case "end":
138                 (*sp)[len(*sp)-1]()
139                 *sp = (*sp)[:len(*sp)-1]
140         case "if":
141                 fmt.Printf(strings.TrimPrefix(c.Text, "//mt:")+" {\n", expr)
142                 *sp = append(*sp, func() {
143                         fmt.Println("}")
144                 })
145         case "ifde":
146                 if !de {
147                         fmt.Println("/*")
148                 }
149         }
150 }
151
152 func genSerialize(t types.Type, expr string, pos token.Pos, doc *ast.CommentGroup, de bool) {
153         var lenhdr types.Type = types.Typ[types.Uint16]
154
155         useMethod := true
156         if doc != nil {
157                 for _, c := range doc.List {
158                         pragma := true
159                         switch c.Text {
160                         case "//mt:32to16":
161                                 t = types.Typ[types.Int16]
162                                 if de {
163                                         v := newVar()
164                                         fmt.Println("var", v, "int16")
165                                         defer fmt.Println(expr + " = int32(" + v + ")")
166                                         expr = v
167                                 } else {
168                                         expr = "int16(" + expr + ")"
169                                 }
170                                 pos = token.NoPos
171                         case "//mt:32tou16":
172                                 t = types.Typ[types.Uint16]
173                                 if de {
174                                         v := newVar()
175                                         fmt.Println("var", v, "uint16")
176                                         defer fmt.Println(expr + " = int32(" + v + ")")
177                                         expr = v
178                                 } else {
179                                         expr = "uint16(" + expr + ")"
180                                 }
181                                 pos = token.NoPos
182                         case "//mt:utf16":
183                                 t = types.NewSlice(types.Typ[types.Uint16])
184                                 if de {
185                                         v := newVar()
186                                         fmt.Println("var", v, typeStr(t))
187                                         defer fmt.Println(expr + " = string(utf16.Decode(" + v + "))")
188                                         expr = v
189                                 } else {
190                                         v := newVar()
191                                         fmt.Println(v, ":= utf16.Encode([]rune("+expr+"))")
192                                         expr = v
193                                 }
194                                 pos = token.NoPos
195                         case "//mt:raw":
196                                 lenhdr = nil
197                         case "//mt:len8":
198                                 lenhdr = types.Typ[types.Uint8]
199                         case "//mt:len32":
200                                 lenhdr = types.Typ[types.Uint32]
201                         case "//mt:opt":
202                                 fmt.Println("if err := pcall(func() {")
203                                 defer fmt.Println("}); err != nil && err != io.EOF",
204                                         "{ chk(err) }")
205                         default:
206                                 pragma = false
207                         }
208                         if pragma {
209                                 useMethod = false
210                         }
211                 }
212         }
213
214         str := types.TypeString(t, types.RelativeTo(pkg.Types))
215         if de {
216                 if or, ok := deserializeFmt[str]; ok {
217                         fmt.Println("{")
218                         fmt.Println("p := &" + expr)
219                         fmt.Print(or)
220                         fmt.Println("}")
221                         return
222                 }
223         } else {
224                 if or, ok := serializeFmt[str]; ok {
225                         fmt.Println("{")
226                         fmt.Println("x := " + expr)
227                         fmt.Print(or)
228                         fmt.Println("}")
229                         return
230                 }
231         }
232
233         expr = "(" + expr + ")"
234
235         switch t := t.(type) {
236         case *types.Named:
237                 if !useMethod {
238                         t := t.Underlying()
239                         genSerialize(t, "*(*"+typeStr(t)+")("+"&"+expr+")", pos, doc, de)
240                         return
241                 }
242
243                 method := "Serialize"
244                 if de {
245                         method = "Deserialize"
246                 }
247                 for i := 0; i < t.NumMethods(); i++ {
248                         m := t.Method(i)
249                         if m.Name() == method {
250                                 rw := "w"
251                                 if de {
252                                         rw = "r"
253                                 }
254                                 fmt.Println("chk(" + expr + "." + method + "(" + rw + "))")
255                                 return
256                         }
257                 }
258
259                 mkSerialize(t)
260
261                 fmt.Println("if err := pcall(func() {")
262                 if de {
263                         fmt.Println(expr + ".deserialize(r)")
264                 } else {
265                         fmt.Println(expr + ".serialize(w)")
266                 }
267                 fmt.Println("}); err != nil",
268                         `{`,
269                         `if err == io.EOF { chk(io.EOF) };`,
270                         `chk(fmt.Errorf("%s: %w", `+strconv.Quote(t.String())+`, err))`,
271                         `}`)
272         case *types.Struct:
273                 st := pos2node(pos)[0].(*ast.StructType)
274
275                 a := consts[st]
276                 b := st.Fields.List
277
278                 // Merge sorted slices.
279                 c := make([]ast.Node, 0, len(a)+len(b))
280                 for i, j := 0, 0; i < len(a) || j < len(b); {
281                         if i < len(a) && (j >= len(b) || a[i].Pos() < b[j].Pos()) {
282                                 c = append(c, a[i])
283                                 i++
284                         } else {
285                                 c = append(c, b[j])
286                                 j++
287                         }
288                 }
289
290                 var (
291                         stk []func()
292                         i   int
293                 )
294                 for _, field := range c {
295                         switch field := field.(type) {
296                         case *ast.Comment:
297                                 structPragma(field, &stk, expr, de)
298                         case *ast.Field:
299                                 n := len(field.Names)
300                                 if n == 0 {
301                                         n = 1
302                                 }
303                                 for ; n > 0; n-- {
304                                         f := t.Field(i)
305                                         genSerialize(f.Type(), expr+"."+f.Name(), field.Type.Pos(), field.Doc, de)
306                                         i++
307                                 }
308                         }
309                 }
310
311                 if len(stk) > 0 {
312                         error(pos, "missing //mt:end")
313                 }
314         case *types.Basic:
315                 switch t.Kind() {
316                 case types.String:
317                         byteSlice := types.NewSlice(types.Typ[types.Byte])
318                         if de {
319                                 v := newVar()
320                                 fmt.Println("var", v, byteSlice)
321                                 genSerialize(byteSlice, v, token.NoPos, doc, de)
322                                 fmt.Println(expr, "=", "string(", v, ")")
323                         } else {
324                                 genSerialize(byteSlice, "[]byte"+expr, token.NoPos, doc, de)
325                         }
326                 default:
327                         error(pos, "can't serialize ", t)
328                 }
329         case *types.Slice:
330                 if de {
331                         if lenhdr != nil {
332                                 v := newVar()
333                                 fmt.Println("var", v, lenhdr)
334                                 genSerialize(lenhdr, v, pos, nil, de)
335                                 fmt.Printf("%s = make(%v, %s)\n",
336                                         expr, typeStr(t), v)
337                                 genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
338                         } else {
339                                 if b, ok := t.Elem().(*types.Basic); ok && b.Kind() == types.Byte {
340                                         fmt.Println("{")
341                                         fmt.Println("var err error")
342                                         fmt.Println(expr, ", err = io.ReadAll(r)")
343                                         fmt.Println("chk(err)")
344                                         fmt.Println("}")
345                                         return
346                                 }
347
348                                 fmt.Println("for {")
349                                 v := newVar()
350                                 fmt.Println("var", v, typeStr(t.Elem()))
351                                 fmt.Println("err := pcall(func() {")
352                                 if pos.IsValid() {
353                                         pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
354                                 }
355                                 genSerialize(t.Elem(), v, pos, nil, de)
356                                 fmt.Println("})")
357                                 fmt.Println("if err == io.EOF { break }")
358                                 fmt.Println(expr + " = append(" + expr + ", " + v + ")")
359                                 fmt.Println("chk(err)")
360                                 fmt.Println("}")
361                         }
362                 } else {
363                         if lenhdr != nil {
364                                 fmt.Println("if len("+expr+") >",
365                                         "math.Max"+strings.Title(lenhdr.String()),
366                                         "{ chk(ErrTooLong) }")
367                                 genSerialize(lenhdr, lenhdr.String()+"(len("+expr+"))", pos, nil, de)
368                         }
369                         genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
370                 }
371         case *types.Array:
372                 et := t.Elem()
373                 if et == byteT || et == uint8T {
374                         if de {
375                                 fmt.Println("{",
376                                         "_, err := io.ReadFull(r, "+expr+"[:]);",
377                                         "chk(err)",
378                                         "}")
379                         } else {
380                                 fmt.Println("{",
381                                         "_, err := w.Write("+expr+"[:]);",
382                                         "chk(err)",
383                                         "}")
384                         }
385                         break
386                 }
387                 i := newVar()
388                 fmt.Println("for", i, ":= range", expr, "{")
389                 if pos.IsValid() {
390                         pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
391                 }
392                 genSerialize(et, expr+"["+i+"]", pos, nil, de)
393                 fmt.Println("}")
394         default:
395                 error(pos, "can't serialize ", t)
396         }
397 }
398
// readOverrides loads hand-written code overrides from the file at path into
// override, keyed by type string.
//
// Every line except an empty one ("\n") is split at its first tab. A
// non-empty first column selects a new key. An empty first column keeps the
// previous key. The second column, including its newline, is appended to the
// current key's entry. Lines that come before any key are printed to stdout
// unchanged. An I/O error, a line without a tab, or a missing final newline
// is fatal.
func readOverrides(path string, override map[string]string) {
	file, err := os.Open(path)
	if err != nil {
		log.Fatal(err)
	}
	defer file.Close()

	rd := bufio.NewReader(file)
	var (
		lineNo int
		key    string
	)
	for {
		text, err := rd.ReadString('\n')
		switch {
		case err == io.EOF:
			if text != "" {
				log.Fatal("no newline at end of ", file.Name())
			}
			return
		case err != nil:
			log.Fatal(err)
		}
		lineNo++

		if text == "\n" {
			continue
		}

		cols := strings.SplitN(text, "\t", 2)
		if len(cols) < 2 {
			log.Fatal(file.Name(), ":", lineNo, ": missing tab")
		}
		if first := cols[0]; first != "" {
			key = first
		}

		if key == "" {
			fmt.Print(cols[1])
		} else {
			override[key] += cols[1]
		}
	}
}
442
443 func mkSerialize(t *types.Named) {
444         if !inSerialize[t.String()] {
445                 serialize = append(serialize, t)
446                 inSerialize[t.String()] = true
447         }
448 }
449
// varNo counts the temporaries handed out so far by newVar.
var varNo int

// newVar returns a fresh identifier ("local1", "local2", ...) for a temporary
// variable in the generated code.
func newVar() string {
	varNo++
	return fmt.Sprintf("local%d", varNo)
}
456
457 func pos2node(pos token.Pos) []ast.Node {
458         return interval2node(pos, pos)
459 }
460
461 func interval2node(start, end token.Pos) []ast.Node {
462         for _, f := range pkg.Syntax {
463                 if f.Pos() <= start && end <= f.End() {
464                         if path, _ := astutil.PathEnclosingInterval(f, start, end); path != nil {
465                                 return path
466                         }
467                 }
468         }
469         return nil
470 }
471
472 func error(pos token.Pos, a ...interface{}) {
473         if !pos.IsValid() {
474                 log.Fatal(a...)
475         }
476         log.Fatal(append([]interface{}{pkg.Fset.Position(pos), ": "}, a...)...)
477 }
478
479 func typeStr(t types.Type) string {
480         return types.TypeString(t, func(p *types.Package) string {
481                 if p == pkg.Types {
482                         return ""
483                 }
484
485                 return p.Name()
486         })
487 }
488
// typeNames lists the named types that main queues for method generation, in
// this order. Types they reference without their own Serialize/Deserialize
// methods are queued later by genSerialize through mkSerialize. Names not
// declared in the package are logged and skipped.
var typeNames = []string{
	"ToSrvNil",
	"ToSrvInit",
	"ToSrvInit2",
	"ToSrvJoinModChan",
	"ToSrvLeaveModChan",
	"ToSrvMsgModChan",
	"ToSrvPlayerPos",
	"ToSrvGotBlks",
	"ToSrvDeletedBlks",
	"ToSrvInvAction",
	"ToSrvChatMsg",
	"ToSrvFallDmg",
	"ToSrvSelectItem",
	"ToSrvRespawn",
	"ToSrvInteract",
	"ToSrvRemovedSounds",
	"ToSrvNodeMetaFields",
	"ToSrvInvFields",
	"ToSrvReqMedia",
	"ToSrvCltReady",
	"ToSrvFirstSRP",
	"ToSrvSRPBytesA",
	"ToSrvSRPBytesM",

	"ToCltHello",
	"ToCltAcceptAuth",
	"ToCltAcceptSudoMode",
	"ToCltDenySudoMode",
	"ToCltKick",
	"ToCltBlkData",
	"ToCltAddNode",
	"ToCltRemoveNode",
	"ToCltInv",
	"ToCltTimeOfDay",
	"ToCltCSMRestrictionFlags",
	"ToCltAddPlayerVel",
	"ToCltMediaPush",
	"ToCltChatMsg",
	"ToCltAORmAdd",
	"ToCltAOMsgs",
	"ToCltHP",
	"ToCltMovePlayer",
	"ToCltLegacyKick",
	"ToCltFOV",
	"ToCltDeathScreen",
	"ToCltMedia",
	"ToCltNodeDefs",
	"ToCltAnnounceMedia",
	"ToCltItemDefs",
	"ToCltPlaySound",
	"ToCltStopSound",
	"ToCltPrivs",
	"ToCltInvFormspec",
	"ToCltDetachedInv",
	"ToCltShowFormspec",
	"ToCltMovement",
	"ToCltSpawnParticle",
	"ToCltAddParticleSpawner",
	"ToCltAddHUD",
	"ToCltRmHUD",
	"ToCltChangeHUD",
	"ToCltHUDFlags",
	"ToCltSetHotbarParam",
	"ToCltBreath",
	"ToCltSkyParams",
	"ToCltOverrideDayNightRatio",
	"ToCltLocalPlayerAnim",
	"ToCltEyeOffset",
	"ToCltDelParticleSpawner",
	"ToCltCloudParams",
	"ToCltFadeSound",
	"ToCltUpdatePlayerList",
	"ToCltModChanMsg",
	"ToCltModChanSig",
	"ToCltNodeMetasChanged",
	"ToCltSunParams",
	"ToCltMoonParams",
	"ToCltStarParams",
	"ToCltSRPBytesSaltB",
	"ToCltFormspecPrepend",

	"AOCmdProps",
	"AOCmdPos",
	"AOCmdTextureMod",
	"AOCmdSprite",
	"AOCmdHP",
	"AOCmdArmorGroups",
	"AOCmdAnim",
	"AOCmdBonePos",
	"AOCmdAttach",
	"AOCmdPhysOverride",
	"AOCmdSpawnInfant",
	"AOCmdAnimSpeed",

	"NodeMeta",
	"MinimapMode",
	"NodeDef",
	"PointedNode",
	"PointedAO",
}
590
591 func main() {
592         log.SetFlags(0)
593         log.SetPrefix("mkserialize: ")
594
595         flag.Parse()
596
597         cfg := &packages.Config{Mode: packages.NeedSyntax |
598                 packages.NeedName |
599                 packages.NeedDeps |
600                 packages.NeedImports |
601                 packages.NeedTypes |
602                 packages.NeedTypesInfo}
603         pkgs, err := packages.Load(cfg, flag.Args()...)
604         if err != nil {
605                 log.Fatal(err)
606         }
607         if packages.PrintErrors(pkgs) > 0 {
608                 os.Exit(1)
609         }
610
611         if len(pkgs) != 1 {
612                 log.Fatal("must be exactly 1 package")
613         }
614         pkg = pkgs[0]
615
616         fmt.Println("package", pkg.Name)
617
618         readOverrides("serialize.fmt", serializeFmt)
619         readOverrides("deserialize.fmt", deserializeFmt)
620
621         for _, f := range pkg.Syntax {
622                 for _, cg := range f.Comments {
623                         for _, c := range cg.List {
624                                 if !strings.HasPrefix(c.Text, "//mt:") {
625                                         continue
626                                 }
627                                 st := interval2node(c.Pos(), c.End())[1].(*ast.StructType)
628                                 consts[st] = append(consts[st], c)
629                         }
630                 }
631         }
632
633         for _, name := range typeNames {
634                 obj := pkg.Types.Scope().Lookup(name)
635                 if obj == nil {
636                         log.Println("undeclared identifier: ", name)
637                         continue
638                 }
639                 mkSerialize(obj.Type().(*types.Named))
640         }
641
642         for i := 0; i < len(serialize); i++ {
643                 for _, de := range []bool{false, true} {
644                         t := serialize[i]
645                         sig := "serialize(w io.Writer)"
646                         if de {
647                                 sig = "deserialize(r io.Reader)"
648                         }
649                         fmt.Println("\nfunc (obj *" + t.Obj().Name() + ") " + sig + " {")
650                         pos := t.Obj().Pos()
651                         tExpr := pos2node(pos)[1].(*ast.TypeSpec).Type
652                         var b strings.Builder
653                         printer.Fprint(&b, pkg.Fset, tExpr)
654                         genSerialize(pkg.TypesInfo.Types[tExpr].Type, "*(*("+b.String()+"))(obj)", tExpr.Pos(), nil, de)
655                         fmt.Println("}")
656                 }
657         }
658 }