Protobuf 和代码生成辅助方法

文章讲述了如何克服Protobuf在Go代码生成时的局限性,包括无法通过Tag指定列名的问题,提出了侵入式修改protobufGo插件的方案。同时,文章讨论了代码生成辅助方法,特别是针对ORM框架,如何利用AST和模板编程自动生成Predicate方法。还提供了一个示例,展示了如何解析Go源文件并生成ORM相关代码。
摘要由CSDN通过智能技术生成

Protobuf 和代码生成辅助方法

Protobuf

在元数据里面,说过 Protobuf 这种代码生成 的,无法利用 Tag 来指定列名
image.png
image.png
我们希望能够达成图二这种效果,而不是图一那种。

Protobuf 的局限性

Protobuf 虽然暴露了插件机制,但是插件并不能 修改生成的 Go 代码,插件只能自己额外生成一些 代码。
image.png
所以实际上不能利用 protobuf 的插件机制

修改 protobuf Go 插件

  • clone 原本的 protobuf 仓库
  • 修改 protobuf 仓库里的核心代码
  • 安装修改后的 Go 插件
  • 执行 protoc 命令

protobuf 必然会生成 json 的标签, 所以只需要生成找到 json 标签的位置,然后插入我们的代码。
image.png
image.png
image.png
这种是侵入式的修改方案,不过我们别无选择。 如果在公司可以维护一个自己定制过的 protobuf Go 插件 仓库

例子

image.png
image.png
实际上,最开始考虑过 Google 的 field option 扩展,但是依旧用不了,只能用这种方案。

代码生成辅助方法

Predicate 设计有一些缺陷,这些缺陷 是可以改进的:

  • 生成字段名的常量
  • 生成 Predicate
type User struct {
    Name     string
    Age      *int
    NickName *sql.NullString
    Picture  []byte
}

type UserDetail struct {
    Address string
}

**基本上就是 AST + 模板编程 **:

  • AST 读取 Go 源文件内容,解析每个类型及其
  • 生成 import 内容,并且将 orm 依赖添加进去
  • 生成 const 内容
  • 生成辅助方法

具体实现

在 gen/orm_gen/ 目录下

package {{ .Package }}

import (
"gitee.com/geektime-geekbang/geektime-go/orm"
{{range $idx, $import := .Imports }}
    {{$import}}
{{end -}}
)
{{- $ops := .Ops -}}
{{range $i, $type := .Types }}

    const (
    {{- range $j, $field := .Fields}}
        {{$type.Name }}{{$field.Name}} = "{{$field.Name}}"
    {{- end}}
    )

    {{range $j, $field := .Fields}}
        {{- range $k, $op := $ops}}
            func {{$type.Name }}{{$field.Name}}{{$op}}(val {{$field.Type}}) orm.Predicate {
            return orm.C("{{$field.Name}}").{{$op}}(val)
            }
        {{end}}
    {{- end}}
{{- end}}
type SingleFileEntryVisitor struct {
	file *fileVisitor
}

func (s *SingleFileEntryVisitor) Get() File {
	if s.file != nil {
		return s.file.Get()
	}
	return File{}
}

func (s *SingleFileEntryVisitor) Visit(node ast.Node) ast.Visitor {
	n, ok := node.(*ast.File)
	if ok {
		s.file = &fileVisitor{
			pkg: n.Name.String(),
		}
		return s.file
	}
	return s
}

type fileVisitor struct {
	pkg     string
	imports []string
	types   []*typeVisitor
}

func (f *fileVisitor) Visit(node ast.Node) ast.Visitor {
	switch n := node.(type) {
	case *ast.TypeSpec:
		res := &typeVisitor{
			name:   n.Name.String(),
			fields: make([]Field, 0),
		}
		if f.types == nil {
			f.types = make([]*typeVisitor, 0)
		}
		f.types = append(f.types, res)
		return res
	case *ast.ImportSpec:
		path := n.Path.Value
		if n.Name != nil && n.Name.String() != "" {
			path = n.Name.String() + " " + path
		}
		if f.imports == nil {
			f.imports = make([]string, 0)
		}
		f.imports = append(f.imports, path)
	}
	return f
}

func (f *fileVisitor) Get() File {
	types := make([]Type, 0, len(f.types))
	for _, t := range f.types {
		types = append(types, t.Get())
	}
	return File{
		Package: f.pkg,
		Imports: f.imports,
		Types:   types,
	}
}

type typeVisitor struct {
	name   string
	fields []Field
}

func (t *typeVisitor) Visit(node ast.Node) ast.Visitor {
	fd, ok := node.(*ast.Field)
	if ok {
		// 所以实际上我们在这里并没有处理 map, channel 之类的类型
		var typName string
		switch fdType := fd.Type.(type) {
		case *ast.Ident:
			typName = fdType.String()
		case *ast.StarExpr:
			switch expr := fdType.X.(type) {
			case *ast.Ident:
				typName = fmt.Sprintf("*%s", expr.String())
			case *ast.SelectorExpr:
				x := expr.X.(*ast.Ident).String()
				name := expr.Sel.String()
				typName = fmt.Sprintf("*%s.%s", x, name)
			}
		case *ast.SelectorExpr:
			x := fdType.X.(*ast.Ident).String()
			name := fdType.Sel.String()
			typName = fmt.Sprintf("%s.%s", x, name)
		case *ast.ArrayType:
			// 其它类型我们都不能处理处理,本来在 ORM 框架里面我们也没有支持
			switch ele := fdType.Elt.(type) {
			case *ast.Ident:
				typName = fmt.Sprintf("[]%s", ele)
			}
		}
		t.fields = append(t.fields, Field{
			Type: typName,
			Name: fd.Names[0].String(),
		})
		return nil
	}
	return t
}

func (t *typeVisitor) Get() Type {
	return Type{
		Name:   t.name,
		Fields: t.fields,
	}
}

type File struct {
	Package string
	Imports []string
	Types   []Type
}

type Type struct {
	Name   string
	Fields []Field
}

type Field struct {
	Name string
	Type string
}

func main() {
	// 用户必须输入一个 src,限制为文件
	// 然后我们会在同目录下生成代码
	src := os.Args[1]
	// Dir返回路径的除最后一个元素之外的所有元素,通常是路径的目录。
	dstDir := filepath.Dir(src)
	// Base返回路径的最后一个元素
	fileName := filepath.Base(src)
	// LastIndexByte返回s中c的最后一个实例的索引,如果s中不存在c,则返回-1。
	idx := strings.LastIndexByte(fileName, '.')
	dst := filepath.Join(dstDir, fileName[:idx]+"_gen.go")
	f, err := os.Create(dst)
	if err != nil {
		fmt.Println(err)
		return
	}
	err = gen(f, src)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("生成成功")

}

// Go 会读取 tpl.gohtml 里面的内容填充到变量 tpl 里面
//go:embed tpl.gohtml
var genOrm string

type OrmFile struct {
	File
	Ops []string
}

func gen(writer io.Writer, srcFile string) error {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, srcFile, nil, parser.ParseComments)
	if err != nil {
		return err
	}
	tv := &SingleFileEntryVisitor{}
	ast.Walk(tv, f)
	file := tv.Get()

	tpl := template.New("gen_orm")
	tpl, err = tpl.Parse(genOrm)
	if err != nil {
		return err
	}
	return tpl.Execute(writer, OrmFile{
		File: file,
		Ops:  []string{"LT", "GT", "EQ"},
	})
}

单元测试

func TestFileVisitor_Get(t *testing.T) {
	testCases := []struct {
		src  string
		want File
	}{
		{
			src: `
package orm_gen
import (
	"fmt"
    "database/sql"
) 

import (
	dri "database/sql/driver"
)
type (
	StructType struct {
		// Public is a field
		// @type string
		Public string
        Ptr *sql.NullString
		Struct sql.NullInt64
		Age *int8
		Slice []byte
	}
)
`,
			want: File{
				Package: "orm_gen",
				Imports: []string{`"fmt"`, `"database/sql"`, `dri "database/sql/driver"`},
				Types: []Type{
					{
						Name: "StructType",
						Fields: []Field{
							{
								Name: "Public",
								Type: "string",
							},
							{
								Name: "Ptr",
								Type: "*sql.NullString",
							},
							{
								Name: "Struct",
								Type: "sql.NullInt64",
							},
							{
								Name: "Age",
								Type: "*int8",
							},
							{
								Name: "Slice",
								Type: "[]byte",
							},
						},
					},
				},
			},
		},
	}
	for _, tc := range testCases {
		fset := token.NewFileSet()
		f, err := parser.ParseFile(fset, "src.go", tc.src, parser.ParseComments)
		if err != nil {
			t.Fatal(err)
		}
		tv := &SingleFileEntryVisitor{}
		ast.Walk(tv, f)
		file := tv.Get()
		assert.Equal(t, tc.want, file)
	}
}
func TestGen(t *testing.T) {
	bs := &bytes.Buffer{}
	err := gen(bs, "testdata/user.go")
	require.NoError(t, err)
	assert.Equal(t, `package testdata

import (
    "gitee.com/geektime-geekbang/geektime-go/orm"

    "database/sql"
)

const (
    UserName = "Name"
    UserAge = "Age"
    UserNickName = "NickName"
    UserPicture = "Picture"
)


func UserNameLT(val string) orm.Predicate {
    return orm.C("Name").LT(val)
}

func UserNameGT(val string) orm.Predicate {
    return orm.C("Name").GT(val)
}

func UserNameEQ(val string) orm.Predicate {
    return orm.C("Name").EQ(val)
}

func UserAgeLT(val *int) orm.Predicate {
    return orm.C("Age").LT(val)
}

func UserAgeGT(val *int) orm.Predicate {
    return orm.C("Age").GT(val)
}

func UserAgeEQ(val *int) orm.Predicate {
    return orm.C("Age").EQ(val)
}

func UserNickNameLT(val *sql.NullString) orm.Predicate {
    return orm.C("NickName").LT(val)
}

func UserNickNameGT(val *sql.NullString) orm.Predicate {
    return orm.C("NickName").GT(val)
}

func UserNickNameEQ(val *sql.NullString) orm.Predicate {
    return orm.C("NickName").EQ(val)
}

func UserPictureLT(val []byte) orm.Predicate {
    return orm.C("Picture").LT(val)
}

func UserPictureGT(val []byte) orm.Predicate {
    return orm.C("Picture").GT(val)
}

func UserPictureEQ(val []byte) orm.Predicate {
    return orm.C("Picture").EQ(val)
}


const (
    UserDetailAddress = "Address"
)


func UserDetailAddressLT(val string) orm.Predicate {
    return orm.C("Address").LT(val)
}

func UserDetailAddressGT(val string) orm.Predicate {
    return orm.C("Address").GT(val)
}

func UserDetailAddressEQ(val string) orm.Predicate {
    return orm.C("Address").EQ(val)
}
`, bs.String())
}

总结

  • 怎么修改 protobuf 生成的代码?生成好的代码从实践上来说是不应该修改的,而很不幸的 是,protobuf 对应不同插件生成的目标代码,只能改插件的源码 ;
  • 怎么为 protobuf 的字段添加额外的属性?可以通过 field option 来增加额外的属性,但是 这种新的属性需要你自己写代码解析 ;
  • 代码生成的常用场景?一般来说,样板代码都可以考虑使用代码生成来替换掉,比如典型 的利用代码生成来生成数据库查询(如ENT),生成增删改查的代码,生成前端代码 ;
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值