17 "golang.org/x/tools/go/ast/astutil"
18 "golang.org/x/tools/go/packages"
24 serializeFmt = make(map[string]string)
25 deserializeFmt = make(map[string]string)
27 uint8T = types.Universe.Lookup("uint8").Type()
28 byteT = types.Universe.Lookup("byte").Type()
30 serialize []*types.Named
31 inSerialize = make(map[string]bool)
33 consts = make(map[*ast.StructType][]*ast.Comment)
36 func structPragma(c *ast.Comment, sp *[]func(), expr string, de bool) {
37 fields := strings.SplitN(strings.TrimPrefix(c.Text, "//mt:"), " ", 2)
44 tv, err := types.Eval(pkg.Fset, pkg.Types, c.Slash, arg)
52 fmt.Println("var", x, typeStr(tv.Type))
53 genSerialize(tv.Type, x, token.NoPos, nil, de)
54 fmt.Println("if", x, "!=", "(", tv.Value, ")",
55 `{ chk(fmt.Errorf("const %v: %v",`, tv.Value, ",", x, ")) }")
59 fmt.Println("{", v, ":=", arg)
60 genSerialize(tv.Type, v, c.Slash+token.Pos(len("//mt:const ")), nil, de)
64 fmt.Printf("if !("+arg+") {", expr)
65 fmt.Printf("chk(errors.New(%q))\n", "assertion failed: "+arg)
69 fmt.Println("{ r, err := zlib.NewReader(byteReader{r}); chk(err)")
70 *sp = append(*sp, func() {
71 fmt.Println("chk(r.Close()) }")
74 fmt.Println("{ w := zlib.NewWriter(w)")
75 *sp = append(*sp, func() {
76 fmt.Println("chk(w.Close()) }")
80 if arg != "8" && arg != "16" && arg != "32" {
81 error(c.Pos(), "usage: //mt:lenhdr (8|16|32)")
87 fmt.Println("ow := w")
88 fmt.Println("w := new(bytes.Buffer)")
91 var cg ast.CommentGroup
93 t := types.Universe.Lookup("uint" + arg).Type()
94 fmt.Println("var n", t)
95 genSerialize(t, "n", token.NoPos, nil, de)
97 fmt.Println(`if n > math.MaxInt64 { panic("too big len") }`)
99 fmt.Println("r := &io.LimitedReader{r, int64(n)}")
103 cg.List = []*ast.Comment{{Text: "//mt:len" + arg}}
108 *sp = append(*sp, func() {
110 fmt.Println("if r.N > 0",
111 `{ chk(fmt.Errorf("%d bytes of trailing data", r.N)) }`)
114 fmt.Println("buf := w")
115 fmt.Println("w := ow")
116 byteSlice := types.NewSlice(types.Typ[types.Byte])
117 genSerialize(byteSlice, "buf.Bytes()", token.NoPos, &cg, de)
125 *sp = (*sp)[:len(*sp)-1]
127 fmt.Printf(strings.TrimPrefix(c.Text, "//mt:")+" {\n", expr)
128 *sp = append(*sp, func() {
138 func genSerialize(t types.Type, expr string, pos token.Pos, doc *ast.CommentGroup, de bool) {
139 var lenhdr types.Type = types.Typ[types.Uint16]
143 for _, c := range doc.List {
147 t = types.NewSlice(types.Typ[types.Uint16])
150 fmt.Println("var", v, typeStr(t))
151 defer fmt.Println(expr + " = string(utf16.Decode(" + v + "))")
154 expr = "utf16.Encode([]rune(" + expr + "))"
160 lenhdr = types.Typ[types.Uint8]
162 lenhdr = types.Typ[types.Uint32]
164 fmt.Println("if err := pcall(func() {")
165 defer fmt.Println("}); err != nil && err != io.EOF",
176 str := types.TypeString(t, types.RelativeTo(pkg.Types))
178 if or, ok := deserializeFmt[str]; ok {
180 fmt.Println("p := &" + expr)
186 if or, ok := serializeFmt[str]; ok {
188 fmt.Println("x := " + expr)
195 expr = "(" + expr + ")"
197 switch t := t.(type) {
201 genSerialize(t, "*(*"+typeStr(t)+")("+"&"+expr+")", pos, doc, de)
205 method := "Serialize"
207 method = "Deserialize"
209 for i := 0; i < t.NumMethods(); i++ {
211 if m.Name() == method {
216 fmt.Println("chk(" + expr + "." + method + "(" + rw + "))")
223 fmt.Println("if err := pcall(func() {")
225 fmt.Println(expr + ".deserialize(r)")
227 fmt.Println(expr + ".serialize(w)")
229 fmt.Println("}); err != nil",
231 `if err == io.EOF { chk(io.EOF) };`,
232 `chk(fmt.Errorf("%s: %w", `+strconv.Quote(t.String())+`, err))`,
235 st := pos2node(pos)[0].(*ast.StructType)
240 // Merge sorted slices.
241 c := make([]ast.Node, 0, len(a)+len(b))
242 for i, j := 0, 0; i < len(a) || j < len(b); {
243 if i < len(a) && (j >= len(b) || a[i].Pos() < b[j].Pos()) {
256 for _, field := range c {
257 switch field := field.(type) {
259 structPragma(field, &stk, expr, de)
261 n := len(field.Names)
267 genSerialize(f.Type(), expr+"."+f.Name(), field.Type.Pos(), field.Doc, de)
274 error(pos, "missing //mt:end")
279 byteSlice := types.NewSlice(types.Typ[types.Byte])
282 fmt.Println("var", v, byteSlice)
283 genSerialize(byteSlice, v, token.NoPos, doc, de)
284 fmt.Println(expr, "=", "string(", v, ")")
286 genSerialize(byteSlice, "[]byte"+expr, token.NoPos, doc, de)
289 error(pos, "can't serialize ", t)
295 fmt.Println("var", v, lenhdr)
296 genSerialize(lenhdr, v, pos, nil, de)
297 fmt.Printf("%s = make(%v, %s)\n",
299 genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
303 fmt.Println("var", v, typeStr(t.Elem()))
304 fmt.Println("err := pcall(func() {")
306 pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
308 genSerialize(t.Elem(), v, pos, nil, de)
310 fmt.Println("if err == io.EOF { break }")
311 fmt.Println(expr + " = append(" + expr + ", " + v + ")")
312 fmt.Println("chk(err)")
317 fmt.Println("if len("+expr+") >",
318 "math.Max"+strings.Title(lenhdr.String()),
319 "{ chk(ErrTooLong) }")
320 genSerialize(lenhdr, lenhdr.String()+"(len("+expr+"))", pos, nil, de)
322 genSerialize(types.NewArray(t.Elem(), 0), expr, pos, nil, de)
326 if et == byteT || et == uint8T {
329 "_, err := io.ReadFull(r, "+expr+"[:]);",
334 "_, err := w.Write("+expr+"[:]);",
341 fmt.Println("for", i, ":= range", expr, "{")
343 pos = pos2node(pos)[0].(*ast.ArrayType).Elt.Pos()
345 genSerialize(et, expr+"["+i+"]", pos, nil, de)
348 error(pos, "can't serialize ", t)
352 func readOverrides(path string, override map[string]string) {
353 f, err := os.Open(path)
359 b := bufio.NewReader(f)
363 ln, err := b.ReadString('\n')
367 log.Fatal("no newline at end of ", f.Name())
379 fields := strings.SplitN(ln, "\t", 2)
380 if len(fields) == 1 {
381 log.Fatal(f.Name(), ":", line, ": missing tab")
392 override[col1] += fields[1]
396 func mkSerialize(t *types.Named) {
397 if !inSerialize[t.String()] {
398 serialize = append(serialize, t)
399 inSerialize[t.String()] = true
405 func newVar() string {
407 return fmt.Sprint("local", varNo)
410 func pos2node(pos token.Pos) []ast.Node {
411 return interval2node(pos, pos)
414 func interval2node(start, end token.Pos) []ast.Node {
415 for _, f := range pkg.Syntax {
416 if f.Pos() <= start && end <= f.End() {
417 if path, _ := astutil.PathEnclosingInterval(f, start, end); path != nil {
425 func error(pos token.Pos, a ...interface{}) {
429 log.Fatal(append([]interface{}{pkg.Fset.Position(pos), ": "}, a...)...)
432 func typeStr(t types.Type) string {
433 return types.TypeString(t, func(p *types.Package) string {
442 var typeNames = []string{
458 "ToSrvRemovedSounds",
459 "ToSrvNodeMetaFields",
469 "ToCltAcceptSudoMode",
477 "ToCltCSMRestrictionFlags",
490 "ToCltAnnounceMedia",
499 "ToCltSpawnParticle",
500 "ToCltAddParticleSpawner",
505 "ToCltSetHotbarParam",
508 "ToCltOverrideDayNightRatio",
509 "ToCltLocalPlayerAnim",
511 "ToCltDelParticleSpawner",
514 "ToCltUpdatePlayerList",
517 "ToCltNodeMetasChanged",
521 "ToCltSRPBytesSaltB",
522 "ToCltFormspecPrepend",
546 log.SetPrefix("mkserialize: ")
550 cfg := &packages.Config{Mode: packages.NeedSyntax |
553 packages.NeedImports |
555 packages.NeedTypesInfo}
556 pkgs, err := packages.Load(cfg, flag.Args()...)
560 if packages.PrintErrors(pkgs) > 0 {
565 log.Fatal("must be exactly 1 package")
569 fmt.Println("package", pkg.Name)
571 readOverrides("serialize.fmt", serializeFmt)
572 readOverrides("deserialize.fmt", deserializeFmt)
574 for _, f := range pkg.Syntax {
575 for _, cg := range f.Comments {
576 for _, c := range cg.List {
577 if !strings.HasPrefix(c.Text, "//mt:") {
580 st := interval2node(c.Pos(), c.End())[1].(*ast.StructType)
581 consts[st] = append(consts[st], c)
586 for _, name := range typeNames {
587 obj := pkg.Types.Scope().Lookup(name)
589 log.Println("undeclared identifier: ", name)
592 mkSerialize(obj.Type().(*types.Named))
595 for i := 0; i < len(serialize); i++ {
596 for _, de := range []bool{false, true} {
598 sig := "serialize(w io.Writer)"
600 sig = "deserialize(r io.Reader)"
602 fmt.Println("\nfunc (obj *" + t.Obj().Name() + ") " + sig + " {")
604 tExpr := pos2node(pos)[1].(*ast.TypeSpec).Type
605 var b strings.Builder
606 printer.Fprint(&b, pkg.Fset, tExpr)
607 genSerialize(pkg.TypesInfo.Types[tExpr].Type, "*(*("+b.String()+"))(obj)", tExpr.Pos(), nil, de)