17 "golang.org/x/tools/go/ast/astutil"
18 "golang.org/x/tools/go/packages"
// serializeFmt and deserializeFmt hold override code snippets keyed by a
// type string (types.TypeString relative to the package being generated);
// genSerialize looks a type up here before generating code for it, and
// main fills both maps from serialize.fmt / deserialize.fmt.
24 serializeFmt = make(map[string]string)
25 deserializeFmt = make(map[string]string)
// uint8T and byteT are the universe uint8/byte types; genSerialize compares
// array element types against them to choose a bulk read/write path.
27 uint8T = types.Universe.Lookup("uint8").Type()
28 byteT = types.Universe.Lookup("byte").Type()
// serialize is the worklist of named types to generate methods for, and
// inSerialize records (by t.String()) which types are already queued.
30 serialize []*types.Named
31 inSerialize = make(map[string]bool)
// consts collects the //mt: pragma comments found inside each struct type.
33 consts = make(map[*ast.StructType][]*ast.Comment)
// structPragma emits generated code for one //mt: pragma comment c that
// appears among the fields of a struct. expr is the Go expression for the
// struct value; de selects deserialization (true) or serialization (false).
// Pragmas that open a scope push a closure onto *sp which later emits the
// code closing that scope, so *sp acts as a stack of pending closers.
36 func structPragma(c *ast.Comment, sp *[]func(), expr string, de bool) {
// Split "//mt:name arg" into the pragma name and its optional argument.
37 fields := strings.SplitN(strings.TrimPrefix(c.Text, "//mt:"), " ", 2)
// Type-check the argument expression in the package scope at the
// comment's position to learn its type.
44 tv, err := types.Eval(pkg.Fset, pkg.Types, c.Slash, arg)
// Presumably the decode path of //mt:const: decode a value of the constant's
// type into x and fail if it differs from arg.
52 fmt.Println("var", x, typeStr(tv.Type))
54 fmt.Println(y, ":=", arg)
55 genSerialize(tv.Type, x, token.NoPos, nil, de)
56 fmt.Println("if", x, "!=", y,
57 `{ chk(fmt.Errorf("const %v: %v",`, strconv.Quote(arg), ",", x, ")) }")
// Presumably the encode path: bind arg to a local and emit it. The position
// points just past "//mt:const " so errors land on the argument.
61 fmt.Println("{", v, ":=", arg)
62 genSerialize(tv.Type, v, c.Slash+token.Pos(len("//mt:const ")), nil, de)
// arg is spliced into the format string, so a verb such as %s in it is
// replaced by expr.
66 fmt.Printf("if !("+arg+") {", expr)
67 fmt.Printf("chk(errors.New(%q))\n", "assertion failed: "+arg)
// Decode the following fields through a zlib reader; the pushed closer
// closes the reader and ends the block.
71 fmt.Println("{ r, err := zlib.NewReader(byteReader{r}); chk(err)")
72 *sp = append(*sp, func() {
73 fmt.Println("chk(r.Close()) }")
// Encode the following fields through a zlib writer; the pushed closer
// flushes/closes it and ends the block.
76 fmt.Println("{ w := zlib.NewWriter(w)")
77 *sp = append(*sp, func() {
78 fmt.Println("chk(w.Close()) }")
// Only 8-, 16- and 32-bit length headers are accepted.
82 if arg != "8" && arg != "16" && arg != "32" {
83 error(c.Pos(), "usage: //mt:lenhdr (8|16|32)")
// Encode: buffer the following fields so their total length is known before
// it is written; ow keeps the real writer.
89 fmt.Println("ow := w")
90 fmt.Println("w := new(bytes.Buffer)")
93 var cg ast.CommentGroup
// Decode: read the uintN length header, then limit r to exactly n bytes.
95 t := types.Universe.Lookup("uint" + arg).Type()
96 fmt.Println("var n", t)
97 genSerialize(t, "n", token.NoPos, nil, de)
99 fmt.Println(`if n > math.MaxInt64 { panic("too big len") }`)
101 fmt.Println("r := &io.LimitedReader{R: r, N: int64(n)}")
// Synthesized doc comment handed to genSerialize below, presumably so the
// buffered bytes get a length header of the requested width.
105 cg.List = []*ast.Comment{{Text: "//mt:len" + arg}}
110 *sp = append(*sp, func() {
// Reject bytes left unread inside the length-delimited region.
112 fmt.Println("if r.N > 0",
113 `{ chk(fmt.Errorf("%d bytes of trailing data", r.N)) }`)
// Restore the real writer and emit the buffered bytes as a byte slice.
116 fmt.Println("buf := w")
117 fmt.Println("w := ow")
118 byteSlice := types.NewSlice(types.Typ[types.Byte])
119 genSerialize(byteSlice, "buf.Bytes()", token.NoPos, &cg, de)
// Drop the innermost pending closer from the stack.
127 *sp = (*sp)[:len(*sp)-1]
// Any other pragma is emitted verbatim as a statement header (with a %s verb
// replaced by expr) that opens a block; the pushed closer ends that block.
129 fmt.Printf(strings.TrimPrefix(c.Text, "//mt:")+" {\n", expr)
130 *sp = append(*sp, func() {
// genSerialize prints Go code that serializes (de == false) or deserializes
// (de == true) a value of type t held in the Go expression expr.
//
// pos locates the type expression in the source. It is used to find the
// matching *ast.StructType / *ast.ArrayType node and for error positions.
// doc is the field's doc comment; its //mt: pragmas adjust the encoding.
140 func genSerialize(t types.Type, expr string, pos token.Pos, doc *ast.CommentGroup, de bool) {
// Default width of the length prefix written before slices and strings.
141 var lenhdr types.Type = types.Typ[types.Uint16]
// Field-level pragmas from the doc comment.
145 for _, c := range doc.List {
// Encode on the wire as int16; on decode the deferred line converts the
// temporary back with int32(...).
149 t = types.Typ[types.Int16]
152 fmt.Println("var", v, "int16")
153 defer fmt.Println(expr + " = int32(" + v + ")")
156 expr = "int16(" + expr + ")"
// Same as above, but as uint16 on the wire.
160 t = types.Typ[types.Uint16]
163 fmt.Println("var", v, "uint16")
164 defer fmt.Println(expr + " = int32(" + v + ")")
167 expr = "uint16(" + expr + ")"
// Encode a string as a []uint16 of UTF-16 code units.
171 t = types.NewSlice(types.Typ[types.Uint16])
174 fmt.Println("var", v, typeStr(t))
175 defer fmt.Println(expr + " = string(utf16.Decode(" + v + "))")
179 fmt.Println(v, ":= utf16.Encode([]rune("+expr+"))")
// Override the length-prefix width.
186 lenhdr = types.Typ[types.Uint8]
188 lenhdr = types.Typ[types.Uint32]
// Run the value's code inside pcall; the deferred tail tolerates io.EOF.
190 fmt.Println("if err := pcall(func() {")
191 defer fmt.Println("}); err != nil && err != io.EOF",
// Hand-written overrides from the *.fmt files take precedence, keyed by the
// package-relative type string.
202 str := types.TypeString(t, types.RelativeTo(pkg.Types))
204 if or, ok := deserializeFmt[str]; ok {
206 fmt.Println("p := &" + expr)
212 if or, ok := serializeFmt[str]; ok {
214 fmt.Println("x := " + expr)
// Parenthesize so later suffixes ("." field access, "[i]", conversions)
// bind to the whole expression.
221 expr = "(" + expr + ")"
223 switch t := t.(type) {
// Re-view expr as type t through a pointer conversion so the generated
// code works on the same storage (enclosing case not visible here).
227 genSerialize(t, "*(*"+typeStr(t)+")("+"&"+expr+")", pos, doc, de)
// Types with their own Serialize/Deserialize method are delegated to it.
231 method := "Serialize"
233 method = "Deserialize"
235 for i := 0; i < t.NumMethods(); i++ {
237 if m.Name() == method {
242 fmt.Println("chk(" + expr + "." + method + "(" + rw + "))")
// Otherwise call the generated serialize/deserialize method. io.EOF is
// passed through unchanged; other errors are wrapped with the type name.
249 fmt.Println("if err := pcall(func() {")
251 fmt.Println(expr + ".deserialize(r)")
253 fmt.Println(expr + ".serialize(w)")
255 fmt.Println("}); err != nil",
257 `if err == io.EOF { chk(io.EOF) };`,
258 `chk(fmt.Errorf("%s: %w", `+strconv.Quote(t.String())+`, err))`,
// Structs: interleave fields and //mt: pragma comments in source order.
// a and b are defined on lines not shown (presumably the struct's fields
// and consts[st]).
261 st := pos2node(pos)[0].(*ast.StructType)
266 // Merge sorted slices.
267 c := make([]ast.Node, 0, len(a)+len(b))
268 for i, j := 0, 0; i < len(a) || j < len(b); {
269 if i < len(a) && (j >= len(b) || a[i].Pos() < b[j].Pos()) {
282 for _, field := range c {
283 switch field := field.(type) {
// Pragmas may open scopes that push closers onto stk.
285 structPragma(field, &stk, expr, de)
287 n := len(field.Names)
293 genSerialize(f.Type(), expr+"."+f.Name(), field.Type.Pos(), field.Doc, de)
// A scope opened by a pragma was never closed.
300 error(pos, "missing //mt:end")
// Strings are coded as byte slices; decoding converts back with string(...).
305 byteSlice := types.NewSlice(types.Typ[types.Byte])
308 fmt.Println("var", v, byteSlice)
309 genSerialize(byteSlice, v, token.NoPos, doc, de)
310 fmt.Println(expr, "=", "string(", v, ")")
312 genSerialize(byteSlice, "[]byte"+expr, token.NoPos, doc, de)
315 error(pos, "can't serialize ", t)
// Slice decode with a length header: read the length, make the slice, then
// fill it via the array path. The 0 length is irrelevant because the array
// path ranges over expr or uses expr[:].
321 fmt.Println("var", v, lenhdr)
322 genSerialize(lenhdr, v, pos, nil, de)
323 fmt.Printf("%s = make(%v, %s)\n",
325 genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
// A []byte decoded on this path consumes the rest of r (presumably the
// branch for slices without a length header).
327 if b, ok := t.Elem().(*types.Basic); ok && b.Kind() == types.Byte {
329 fmt.Println("var err error")
330 fmt.Println(expr, ", err = io.ReadAll(r)")
331 fmt.Println("chk(err)")
// Other element types: decode elements until io.EOF, appending each.
338 fmt.Println("var", v, typeStr(t.Elem()))
339 fmt.Println("err := pcall(func() {")
// Narrow pos to the element type so nested AST lookups find the right node.
341 pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
343 genSerialize(t.Elem(), v, pos, nil, de)
345 fmt.Println("if err == io.EOF { break }")
346 fmt.Println(expr + " = append(" + expr + ", " + v + ")")
347 fmt.Println("chk(err)")
// Slice encode: reject lengths the header cannot hold, write the header,
// then the elements. strings.Title maps e.g. "uint16" to "Uint16" to form
// math.MaxUint16.
// NOTE(review): strings.Title is deprecated since Go 1.18; it is harmless
// here because the input is a short ASCII type name.
352 fmt.Println("if len("+expr+") >",
353 "math.Max"+strings.Title(lenhdr.String()),
354 "{ chk(ErrTooLong) }")
355 genSerialize(lenhdr, lenhdr.String()+"(len("+expr+"))", pos, nil, de)
357 genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
// Byte arrays (and byte slices routed here) are read and written in bulk.
361 if et == byteT || et == uint8T {
364 "_, err := io.ReadFull(r, "+expr+"[:]);",
369 "_, err := w.Write("+expr+"[:]);",
// Other arrays: one element at a time, in index order.
376 fmt.Println("for", i, ":= range", expr, "{")
378 pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
380 genSerialize(et, expr+"["+i+"]", pos, nil, de)
383 error(pos, "can't serialize ", t)
// readOverrides reads the override file at path into override. Each line is
// split at its first tab. The text after the tab is appended to
// override[col1], so repeated keys accumulate. col1 is derived from the
// first column on lines not shown here.
//
// Malformed input aborts the program via log.Fatal: a missing newline at
// the end of the file, or a line without a tab.
387 func readOverrides(path string, override map[string]string) {
388 f, err := os.Open(path)
394 b := bufio.NewReader(f)
398 ln, err := b.ReadString('\n')
402 log.Fatal("no newline at end of ", f.Name())
414 fields := strings.SplitN(ln, "\t", 2)
415 if len(fields) == 1 {
416 log.Fatal(f.Name(), ":", line, ": missing tab")
427 override[col1] += fields[1]
// mkSerialize queues t for method generation unless a type with the same
// t.String() is already queued.
431 func mkSerialize(t *types.Named) {
432 if !inSerialize[t.String()] {
433 serialize = append(serialize, t)
434 inSerialize[t.String()] = true
// newVar returns an identifier "local<N>" built from the package-level
// counter varNo, for temporaries in generated code.
// NOTE(review): uniqueness depends on varNo being advanced per call; that
// line is not visible here.
440 func newVar() string {
442 return fmt.Sprint("local", varNo)
// pos2node returns the path of AST nodes enclosing the single position pos.
// It is interval2node with an empty interval. Callers treat element [0] as
// the innermost node, e.g. the *ast.StructType or *ast.ArrayType at pos.
445 func pos2node(pos token.Pos) []ast.Node {
446 return interval2node(pos, pos)
// interval2node finds the file in pkg.Syntax whose range contains
// [start, end]. It returns the enclosing node path reported by
// astutil.PathEnclosingInterval, which lists the innermost node first and
// the file last.
449 func interval2node(start, end token.Pos) []ast.Node {
450 for _, f := range pkg.Syntax {
451 if f.Pos() <= start && end <= f.End() {
452 if path, _ := astutil.PathEnclosingInterval(f, start, end); path != nil {
// error reports a fatal generator error. It logs the source position of pos,
// then ": ", then a, and exits via log.Fatal.
// NOTE(review): declaring a package-level func named error shadows the
// predeclared error type throughout this package.
460 func error(pos token.Pos, a ...interface{}) {
464 log.Fatal(append([]interface{}{pkg.Fset.Position(pos), ": "}, a...)...)
// typeStr renders t as Go source text via types.TypeString with a custom
// package qualifier. The qualifier body is not shown here; presumably it
// omits the name of the package being generated.
467 func typeStr(t types.Type) string {
468 return types.TypeString(t, func(p *types.Package) string {
// typeNames lists the types for which serialize/deserialize methods are
// generated. main looks each one up in the package scope and requires it
// to be a named type.
477 var typeNames = []string{
493 "ToSrvRemovedSounds",
494 "ToSrvNodeMetaFields",
504 "ToCltAcceptSudoMode",
512 "ToCltCSMRestrictionFlags",
525 "ToCltAnnounceMedia",
534 "ToCltSpawnParticle",
535 "ToCltAddParticleSpawner",
540 "ToCltSetHotbarParam",
543 "ToCltOverrideDayNightRatio",
544 "ToCltLocalPlayerAnim",
546 "ToCltDelParticleSpawner",
549 "ToCltUpdatePlayerList",
552 "ToCltNodeMetasChanged",
556 "ToCltSRPBytesSaltB",
557 "ToCltFormspecPrepend",
581 log.SetPrefix("mkserialize: ")
// Load the target package with syntax trees and type information, since
// generation needs both the AST and go/types views.
585 cfg := &packages.Config{Mode: packages.NeedSyntax |
588 packages.NeedImports |
590 packages.NeedTypesInfo}
591 pkgs, err := packages.Load(cfg, flag.Args()...)
595 if packages.PrintErrors(pkgs) > 0 {
600 log.Fatal("must be exactly 1 package")
// The generated code belongs to the same package it was derived from.
604 fmt.Println("package", pkg.Name)
606 readOverrides("serialize.fmt", serializeFmt)
607 readOverrides("deserialize.fmt", deserializeFmt)
// Collect every //mt: comment, grouped by its enclosing struct type. Path
// element [1] is taken because [0] is presumably the struct's field list.
609 for _, f := range pkg.Syntax {
610 for _, cg := range f.Comments {
611 for _, c := range cg.List {
612 if !strings.HasPrefix(c.Text, "//mt:") {
615 st := interval2node(c.Pos(), c.End())[1].(*ast.StructType)
616 consts[st] = append(consts[st], c)
// Seed the worklist with the explicitly listed types.
621 for _, name := range typeNames {
622 obj := pkg.Types.Scope().Lookup(name)
624 log.Println("undeclared identifier: ", name)
627 mkSerialize(obj.Type().(*types.Named))
// len(serialize) is re-evaluated each iteration, so types queued while
// generating are also processed.
630 for i := 0; i < len(serialize); i++ {
631 for _, de := range []bool{false, true} {
633 sig := "serialize(w io.Writer)"
635 sig = "deserialize(r io.Reader)"
637 fmt.Println("\nfunc (obj *" + t.Obj().Name() + ") " + sig + " {")
// Find the declared type expression of t and reprint it, so the method can
// view obj through that expression.
639 tExpr := pos2node(pos)[1].(*ast.TypeSpec).Type
640 var b strings.Builder
641 printer.Fprint(&b, pkg.Fset, tExpr)
642 genSerialize(pkg.TypesInfo.Types[tExpr].Type, "*(*("+b.String()+"))(obj)", tExpr.Pos(), nil, de)