17 "golang.org/x/tools/go/ast/astutil"
18 "golang.org/x/tools/go/packages"
// Code overrides keyed by a type string. Both maps are filled by
// readOverrides and looked up in genSerialize using the type's string
// relative to the loaded package.
24 serializeFmt = make(map[string]string)
25 deserializeFmt = make(map[string]string)
// Universe types that genSerialize compares array element types against.
27 uint8T = types.Universe.Lookup("uint8").Type()
28 byteT = types.Universe.Lookup("byte").Type()
// serialize is the list of named types to generate methods for; mkSerialize
// appends to it and records t.String() in inSerialize so that a type is
// only queued once.
30 serialize []*types.Named
31 inSerialize = make(map[string]bool)
// consts collects "//mt:" comments under the *ast.StructType found on the
// AST path enclosing each comment.
33 consts = make(map[*ast.StructType][]*ast.Comment)
// structPragma prints the generated code for one "//mt:" pragma comment c.
//
// sp points at a stack of callbacks: the pragmas visible here that open a
// scope in the generated code push a closure which prints the matching
// closing code. expr is passed as the operand when pragma text is used as a
// Printf format, and de is forwarded unchanged to genSerialize (presumably
// true when deserialization code is being generated - confirm with callers).
36 func structPragma(c *ast.Comment, sp *[]func(), expr string, de bool) {
// Drop the "//mt:" prefix and split at the first space into at most two fields.
37 fields := strings.SplitN(strings.TrimPrefix(c.Text, "//mt:"), " ", 2)
// Evaluate arg as an expression of the package at the comment's position to
// learn its type.
44 tv, err := types.Eval(pkg.Fset, pkg.Types, c.Slash, arg)
// Emitted code: declare x with the evaluated type, bind y to arg, run the
// generated code for x, then report an error through chk if x differs from y.
52 fmt.Println("var", x, typeStr(tv.Type))
54 fmt.Println(y, ":=", arg)
55 genSerialize(tv.Type, x, token.NoPos, nil, de)
56 fmt.Println("if", x, "!=", y,
57 `{ chk(fmt.Errorf("const %v: %v",`, strconv.Quote(arg), ",", x, ")) }")
// Emitted code: open a scope binding v to arg and generate code for v. The
// position handed on is just past the "//mt:const " prefix of the comment.
61 fmt.Println("{", v, ":=", arg)
62 genSerialize(tv.Type, v, c.Slash+token.Pos(len("//mt:const ")), nil, de)
// arg is spliced into the format string, so any verb it contains is filled
// in with expr. The emitted check fails with "assertion failed: <arg>".
66 fmt.Printf("if !("+arg+") {", expr)
67 fmt.Printf("chk(errors.New(%q))\n", "assertion failed: "+arg)
// zlib: shadow r with a zlib reader, or w with a zlib writer, inside a new
// scope; the pushed closure prints the Close call and the closing brace.
71 fmt.Println("{ r, err := zlib.NewReader(byteReader{r}); chk(err)")
72 *sp = append(*sp, func() {
73 fmt.Println("chk(r.Close()) }")
76 fmt.Println("{ w := zlib.NewWriter(w)")
77 *sp = append(*sp, func() {
78 fmt.Println("chk(w.Close()) }")
// zstd: same pattern as zlib; note the emitted reader Close is not wrapped in chk.
83 fmt.Println("{ r, err := zstd.NewReader(byteReader{r}); chk(err)")
84 *sp = append(*sp, func() {
85 fmt.Println("r.Close() }")
88 fmt.Println("{ w, err := zstd.NewWriter(w); chk(err)")
89 *sp = append(*sp, func() {
90 fmt.Println("chk(w.Close()) }")
// Only 8, 16 and 32 are accepted as the length header width.
94 if arg != "8" && arg != "16" && arg != "32" {
95 error(c.Pos(), "usage: //mt:lenhdr (8|16|32)")
// Emitted code keeps the outer writer as ow and shadows w with a fresh
// bytes.Buffer.
101 fmt.Println("ow := w")
102 fmt.Println("w := new(bytes.Buffer)")
// cg becomes a synthetic comment group carrying a "//mt:len<arg>" directive.
105 var cg ast.CommentGroup
// Emitted code: declare n as uint<arg>, generate code for n, panic if n
// exceeds math.MaxInt64, then shadow r with an io.LimitedReader capped at n.
107 t := types.Universe.Lookup("uint" + arg).Type()
108 fmt.Println("var n", t)
109 genSerialize(t, "n", token.NoPos, nil, de)
111 fmt.Println(`if n > math.MaxInt64 { panic("too big len") }`)
113 fmt.Println("r := &io.LimitedReader{R: r, N: int64(n)}")
117 cg.List = []*ast.Comment{{Text: "//mt:len" + arg}}
// Closing callback for the length-prefixed section. Visible here: an emitted
// check for unread bytes in the limited reader, and emitted code that
// restores the outer writer and generates code for the buffered bytes using
// the synthetic comment group.
// NOTE(review): which of these run for which value of de is decided on lines
// outside this view.
122 *sp = append(*sp, func() {
124 fmt.Println("if r.N > 0",
125 `{ chk(fmt.Errorf("%d bytes of trailing data", r.N)) }`)
128 fmt.Println("buf := w")
129 fmt.Println("w := ow")
130 byteSlice := types.NewSlice(types.Typ[types.Byte])
131 genSerialize(byteSlice, "buf.Bytes()", token.NoPos, &cg, de)
// Pop the top callback off the stack.
139 *sp = (*sp)[:len(*sp)-1]
// The pragma text after "//mt:" is used as a format string with expr as its
// operand, followed by an opening brace; a callback is pushed for this scope.
141 fmt.Printf(strings.TrimPrefix(c.Text, "//mt:")+" {\n", expr)
142 *sp = append(*sp, func() {
// genSerialize prints Go code that handles a value of type t which the
// generated code can reach through the expression expr.
//
// pos is the source position of the corresponding type expression; several
// callers pass token.NoPos when there is none. doc is a comment group whose
// entries are inspected before the main type switch. de is passed on to every
// recursive call (presumably it selects deserialization - confirm; the
// conditions testing it are outside this view).
152 func genSerialize(t types.Type, expr string, pos token.Pos, doc *ast.CommentGroup, de bool) {
// Length header type used for slices; uint16 unless replaced below.
153 var lenhdr types.Type = types.Typ[types.Uint16]
157 for _, c := range doc.List {
// Handle the value as int16: either emitted code declares an int16 temporary
// v and a deferred print assigns it back to expr converted to int32, or expr
// is wrapped in an int16 conversion.
161 t = types.Typ[types.Int16]
164 fmt.Println("var", v, "int16")
165 defer fmt.Println(expr + " = int32(" + v + ")")
168 expr = "int16(" + expr + ")"
// Same pattern with uint16.
172 t = types.Typ[types.Uint16]
175 fmt.Println("var", v, "uint16")
176 defer fmt.Println(expr + " = int32(" + v + ")")
179 expr = "uint16(" + expr + ")"
// Handle the value as a []uint16: the emitted code either decodes the
// temporary back into a string with utf16.Decode (deferred print) or fills
// the temporary from expr with utf16.Encode.
183 t = types.NewSlice(types.Typ[types.Uint16])
186 fmt.Println("var", v, typeStr(t))
187 defer fmt.Println(expr + " = string(utf16.Decode(" + v + "))")
191 fmt.Println(v, ":= utf16.Encode([]rune("+expr+"))")
// Alternative length header widths.
198 lenhdr = types.Typ[types.Uint8]
200 lenhdr = types.Typ[types.Uint32]
// Wrap the code emitted by the rest of this call in a pcall closure; the
// deferred print closes it once this function returns.
202 fmt.Println("if err := pcall(func() {")
203 defer fmt.Println("}); err != nil && err != io.EOF",
// Type string relative to the loaded package; this is the key into the
// override maps.
214 str := types.TypeString(t, types.RelativeTo(pkg.Types))
216 if or, ok := deserializeFmt[str]; ok {
218 fmt.Println("p := &" + expr)
224 if or, ok := serializeFmt[str]; ok {
226 fmt.Println("x := " + expr)
// Parenthesize expr so that suffixes appended below bind to the whole expression.
233 expr = "(" + expr + ")"
235 switch t := t.(type) {
// Recurse with expr reinterpreted through a pointer conversion to typeStr(t).
239 genSerialize(t, "*(*"+typeStr(t)+")("+"&"+expr+")", pos, doc, de)
// Look through the type's methods for one named Serialize (or Deserialize)
// and emit a chk-wrapped call to it.
243 method := "Serialize"
245 method = "Deserialize"
247 for i := 0; i < t.NumMethods(); i++ {
249 if m.Name() == method {
254 fmt.Println("chk(" + expr + "." + method + "(" + rw + "))")
// Emit a pcall-wrapped call to the generated deserialize(r) or serialize(w)
// method. In the emitted code io.EOF is passed to chk as is, and other errors
// are wrapped with the type's name.
261 fmt.Println("if err := pcall(func() {")
263 fmt.Println(expr + ".deserialize(r)")
265 fmt.Println(expr + ".serialize(w)")
267 fmt.Println("}); err != nil",
269 `if err == io.EOF { chk(io.EOF) };`,
270 `chk(fmt.Errorf("%s: %w", `+strconv.Quote(t.String())+`, err))`,
// The innermost AST node at pos must be a struct type; the unchecked
// assertion panics otherwise.
273 st := pos2node(pos)[0].(*ast.StructType)
278 // Merge sorted slices.
279 c := make([]ast.Node, 0, len(a)+len(b))
280 for i, j := 0, 0; i < len(a) || j < len(b); {
281 if i < len(a) && (j >= len(b) || a[i].Pos() < b[j].Pos()) {
// Walk the merged nodes: one case hands the node to structPragma together
// with the callback stack stk, another generates code for the struct's
// fields using the field's doc comment group.
294 for _, field := range c {
295 switch field := field.(type) {
297 structPragma(field, &stk, expr, de)
299 n := len(field.Names)
305 genSerialize(f.Type(), expr+"."+f.Name(), field.Type.Pos(), field.Doc, de)
312 error(pos, "missing //mt:end")
// Strings go through a byte slice: either via a temporary that is converted
// back to string, or by converting expr to []byte.
317 byteSlice := types.NewSlice(types.Typ[types.Byte])
320 fmt.Println("var", v, byteSlice)
321 genSerialize(byteSlice, v, token.NoPos, doc, de)
322 fmt.Println(expr, "=", "string(", v, ")")
324 genSerialize(byteSlice, "[]byte"+expr, token.NoPos, doc, de)
327 error(pos, "can't serialize ", t)
// Slice with a length header: generate code for a temporary v of the header
// type, emit a make call for expr, then recurse using a zero-length array
// type with the same element type.
333 fmt.Println("var", v, lenhdr)
334 genSerialize(lenhdr, v, pos, nil, de)
335 fmt.Printf("%s = make(%v, %s)\n",
337 genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
// Byte element type: the emitted code reads everything left in r with io.ReadAll.
339 if b, ok := t.Elem().(*types.Basic); ok && b.Kind() == types.Byte {
341 fmt.Println("var err error")
342 fmt.Println(expr, ", err = io.ReadAll(r)")
343 fmt.Println("chk(err)")
// Other element types: the emitted code runs the element code for a temporary
// v inside pcall, breaks on io.EOF and appends v to expr.
350 fmt.Println("var", v, typeStr(t.Elem()))
351 fmt.Println("err := pcall(func() {")
353 pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
355 genSerialize(t.Elem(), v, pos, nil, de)
357 fmt.Println("if err == io.EOF { break }")
358 fmt.Println(expr + " = append(" + expr + ", " + v + ")")
359 fmt.Println("chk(err)")
// Emitted code rejects lengths above math.Max<Lenhdr> with ErrTooLong (for
// example math.MaxUint16), then code is generated for the length converted to
// the header type and for the elements.
364 fmt.Println("if len("+expr+") >",
365 "math.Max"+strings.Title(lenhdr.String()),
366 "{ chk(ErrTooLong) }")
367 genSerialize(lenhdr, lenhdr.String()+"(len("+expr+"))", pos, nil, de)
369 genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
// byte/uint8 elements are transferred in one io.ReadFull or w.Write call on
// expr[:].
373 if et == byteT || et == uint8T {
376 "_, err := io.ReadFull(r, "+expr+"[:]);",
381 "_, err := w.Write("+expr+"[:]);",
// Otherwise emit a range loop over expr and generate code per element.
388 fmt.Println("for", i, ":= range", expr, "{")
390 pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
392 genSerialize(et, expr+"["+i+"]", pos, nil, de)
395 error(pos, "can't serialize ", t)
// readOverrides reads the file at path and accumulates its entries in
// override. Of the parsing, this view shows that lines are read up to '\n',
// that each line is split at the first tab, and that the text after the tab
// is appended to override[col1].
399 func readOverrides(path string, override map[string]string) {
400 f, err := os.Open(path)
406 b := bufio.NewReader(f)
410 ln, err := b.ReadString('\n')
// Fatal error naming the file when it does not end in a newline.
414 log.Fatal("no newline at end of ", f.Name())
426 fields := strings.SplitN(ln, "\t", 2)
// A line without a tab is reported with file name and line number and is fatal.
427 if len(fields) == 1 {
428 log.Fatal(f.Name(), ":", line, ": missing tab")
// += appends, so repeated keys concatenate their text.
439 override[col1] += fields[1]
// mkSerialize appends t to the serialize list unless its String() form has
// already been recorded in inSerialize.
443 func mkSerialize(t *types.Named) {
444 if !inSerialize[t.String()] {
445 serialize = append(serialize, t)
446 inSerialize[t.String()] = true
// newVar returns an identifier made of "local" followed by varNo.
// NOTE(review): varNo is declared and updated on lines outside this view.
452 func newVar() string {
454 return fmt.Sprint("local", varNo)
// pos2node returns the result of interval2node for the empty interval
// [pos, pos].
457 func pos2node(pos token.Pos) []ast.Node {
458 return interval2node(pos, pos)
// interval2node looks through the package's syntax trees for a file whose
// extent contains [start, end] and asks astutil.PathEnclosingInterval for the
// AST path enclosing that interval.
// NOTE(review): the return statements are outside this view; callers index
// the result with [0] and [1].
461 func interval2node(start, end token.Pos) []ast.Node {
462 for _, f := range pkg.Syntax {
463 if f.Pos() <= start && end <= f.End() {
464 if path, _ := astutil.PathEnclosingInterval(f, start, end); path != nil {
// error reports a fatal problem at source position pos. Note that this
// function shadows the predeclared identifier error inside the package.
472 func error(pos token.Pos, a ...interface{}) {
// Prefix the message with the resolved file position and exit via log.Fatal.
476 log.Fatal(append([]interface{}{pkg.Fset.Position(pos), ": "}, a...)...)
// typeStr formats t with types.TypeString using a custom package qualifier
// (the qualifier body is outside this view).
479 func typeStr(t types.Type) string {
480 return types.TypeString(t, func(p *types.Package) string {
// typeNames lists the names of the types that are looked up in the package
// scope and queued for method generation with mkSerialize.
489 var typeNames = []string{
505 "ToSrvRemovedSounds",
506 "ToSrvNodeMetaFields",
516 "ToCltAcceptSudoMode",
524 "ToCltCSMRestrictionFlags",
537 "ToCltAnnounceMedia",
546 "ToCltSpawnParticle",
547 "ToCltAddParticleSpawner",
552 "ToCltSetHotbarParam",
555 "ToCltOverrideDayNightRatio",
556 "ToCltLocalPlayerAnim",
558 "ToCltDelParticleSpawner",
561 "ToCltUpdatePlayerList",
564 "ToCltNodeMetasChanged",
568 "ToCltSRPBytesSaltB",
569 "ToCltFormspecPrepend",
// NOTE(review): these statements appear to be the body of main; the func
// header is outside this view.
593 log.SetPrefix("mkserialize: ")
// Load the packages named on the command line with syntax trees, imports and
// type information (further mode bits may be set on lines not shown).
597 cfg := &packages.Config{Mode: packages.NeedSyntax |
600 packages.NeedImports |
602 packages.NeedTypesInfo}
603 pkgs, err := packages.Load(cfg, flag.Args()...)
607 if packages.PrintErrors(pkgs) > 0 {
612 log.Fatal("must be exactly 1 package")
// The generated file starts with the package clause of the loaded package.
616 fmt.Println("package", pkg.Name)
// Fill the override maps from the two .fmt files.
618 readOverrides("serialize.fmt", serializeFmt)
619 readOverrides("deserialize.fmt", deserializeFmt)
// Collect "//mt:" comments under their struct type. The path element at
// index 1 must be an *ast.StructType; the unchecked assertion panics otherwise.
621 for _, f := range pkg.Syntax {
622 for _, cg := range f.Comments {
623 for _, c := range cg.List {
624 if !strings.HasPrefix(c.Text, "//mt:") {
627 st := interval2node(c.Pos(), c.End())[1].(*ast.StructType)
628 consts[st] = append(consts[st], c)
// Queue every listed type; names missing from the package scope are logged.
633 for _, name := range typeNames {
634 obj := pkg.Types.Scope().Lookup(name)
636 log.Println("undeclared identifier: ", name)
639 mkSerialize(obj.Type().(*types.Named))
// Index loop with len(serialize) re-evaluated on every iteration, presumably
// so that types queued during generation are processed too - confirm.
642 for i := 0; i < len(serialize); i++ {
// Each type gets two generated methods, one per value of de.
643 for _, de := range []bool{false, true} {
645 sig := "serialize(w io.Writer)"
647 sig = "deserialize(r io.Reader)"
649 fmt.Println("\nfunc (obj *" + t.Obj().Name() + ") " + sig + " {")
// Take the type expression from the enclosing *ast.TypeSpec, print it back to
// source text, and generate code for obj converted to that underlying type
// expression.
651 tExpr := pos2node(pos)[1].(*ast.TypeSpec).Type
652 var b strings.Builder
653 printer.Fprint(&b, pkg.Fset, tExpr)
654 genSerialize(pkg.TypesInfo.Types[tExpr].Type, "*(*("+b.String()+"))(obj)", tExpr.Pos(), nil, de)