gorm_generate根据表生成对应结构体

一. 先了解jen代码生成

  1. 有一个专门用于代码生成的三方库"github.com/dave/jennifer/jen",下方一个简单的示例
    它可以让你以编程的方式创建和修改.go文件
package main

import "github.com/dave/jennifer/jen"

func main() {
	//导包 "github.com/dave/jennifer/jen"
	
	//1.通过jen可以创建文件,入参指定的是pacakge
	f := jen.NewFile("main")

	//2.在文件中添加结构体的定义
	f.Type().Id("Person").Struct(
		//2.1添加属性
		jen.Id("Age").Int(),
		//设置属性的tag
		jen.Id("ID").Int64().Tag(map[string]string{
			"gorm": "primary_key",
		}),
		jen.Id("Name").String().Tag(map[string]string{
			"json": "name",
		}),
	)

	//3.在文件中添加函数
	f.Func().Id("main").Params().Block(
		// 创建一个结构体变量
		jen.Id("p").Op(":=").Id("Person").Values(
			jen.Id("Name").Op(":").Lit("Alice"),
			jen.Id("Age").Op(":").Lit(20),
		),
		// 调用函数
		jen.Id("p").Dot("PrintInfo").Call(),
	)

	//4.在文件中给结构体绑定方法
	f.Func().Params(jen.Id("p").Op("*").Id("Person")).Id("PrintInfo").Params().Block(
		jen.Qual("fmt", "Printf").Call(
			jen.Lit("Name: %s, Age: %d\n"),
			jen.Id("p").Dot("Name"),
			jen.Id("p").Dot("Age"),
		),
	)

	//5.输出保存到文件
	f.Save("example.go")
}
  1. 执行完毕后会生成如下example.go文件
package main

import "fmt"

type Person struct {
	Age  int
	ID   int64  `gorm:"primary_key"`
	Name string `json:"name"`
}

func main() {
	p := Person{Name: "Alice", Age: 20}
	p.PrintInfo()
}
func (p *Person) PrintInfo() {
	fmt.Printf("Name: %s, Age: %d\n", p.Name, p.Age)
}

二. gorm.io/gorm 新版本 gorm + jen 生成表对应结构体

介绍 gorm Migrator

  1. 在使用新版本的gorm时底层有一个Migrator迁移器,内部提供了一些用来创建,修改,删除,以及获取数据库的元数据信息的方法
CurrentDatabase:用于获取当前数据库的名称
TableType:用于获取一个表的类型,是普通表还是视图,返回一个TableType类型的值
ColumnTypes: 用于获取表字段信息的
GetIndexes: 用户获取索引信息的
AddColumn:用于给一个表添加一个列,根据你定义的结构体的字段来生成对应的类型和约束
AutoMigrate:用于自动迁移你的schema,保持你的schema是最新的。它会创建表,缺少的外键,约束,列和索引,并且会更改现有列的类型(如果其大小、精度、是否为空可更改)
FullDataTypeOf:用于获取给定字段的完整数据类型,包括长度、精度、是否为空等
GetTypeAliases:用于获取给定数据库类型名称的别名,例如int和integer
CreateTable:用于创建一个或多个表,根据你定义的结构体来生成对应的字段和类型,以及主键、外键、索引等约束
DropTable:用于删除一个或多个表,如果表不存在,不会报错
HasTable:用于检查一个表是否存在,可以传入表名或结构体
RenameTable:用于重命名一个表,可以传入旧表名和新表名,或者旧结构体和新结构体
GetTables:用于获取数据库中所有表的名称,返回一个字符串切片
package main

import (
	"fmt"
	"github.com/dave/jennifer/jen"
	_ "github.com/go-sql-driver/mysql"
	"gorm.io/driver/mysql"
	"gorm.io/gorm" //TODO 注意与"github.com/jinzhu/gorm"的不同,"github.com/jinzhu/gorm"是gorm的老版本
	"gorm.io/gorm/logger"
	"strconv"
	"strings"
)

var db *gorm.DB
var err error

func init() {
	// 连接数据库,开启日志模式
	dsn := "账号:密码@tcp(ip:3306)/lm_pms?charset=utf8mb4&parseTime=True&loc=Local"
	db, err = gorm.Open(
		mysql.Open(dsn),
		&gorm.Config{Logger: logger.Default.LogMode(logger.Info)},
	)
	if err != nil {
		panic("failed to connect database")
	}
}

func main() {

	//首先通过Raw()可以执行原生SQL
	//当前执行"SHOW CREATE TABLE hotel_list"获取指定表的建表语句
	rows, err := db.Raw("SHOW CREATE TABLE hotel_list").Rows()
	if err != nil {
		fmt.Println(err)
	}
	for rows.Next() {
		var table string
		var sql string
		rows.Scan(&table, &sql)
		fmt.Println(table, sql)
	}

	//1.通过Migrator()方法可以获取到一个Migrator迁移器
	//这个迁移器内部提供了一些用来创建,修改,删除,以及获取数据库的元数据信息的方法

	//我们可以点进去看一下
	migrator := db.Migrator()

	//3.调用ColumnTypes()获取表的列类型信息
	columns, err := migrator.ColumnTypes("hotel_list")
	if err != nil {
		fmt.Println(err)
	}
	for _, column := range columns {
		// 获取列的名称,类型,长度,是否为空等信息
		name := column.Name()
		valtype := column.DatabaseTypeName()
		length, _ := column.Length()
		nullable, _ := column.Nullable()
		valDefault, _ := column.DefaultValue()
		comment, _ := column.Comment()
		fmt.Printf("name: %v, valtype: %v,length %v, nullable %v, valDefault %v, comment %v", name, valtype, length, nullable, valDefault, comment)

	}

	//4.调用
	Index, err := migrator.GetIndexes("users")

	tab, err := migrator.TableType("users")
	fmt.Print(Index, tab)

}

gorm Migrator 读取数据库表结构 + jen 生成对应结构体

package main

import (
	"github.com/dave/jennifer/jen"
	_ "github.com/go-sql-driver/mysql"
	"gorm.io/driver/mysql"
	"gorm.io/gorm" //TODO 注意与"github.com/jinzhu/gorm"的不同,"github.com/jinzhu/gorm"是gorm的老版本
	"gorm.io/gorm/logger"
	"strings"
)

var db *gorm.DB

func init() {
	// 连接数据库,开启日志模式
	dsn := "账号:密码@tcp(ip:3306)/lm_pms?charset=utf8mb4&parseTime=True&loc=Local"
	db, err := gorm.Open(
		mysql.Open(dsn),
		&gorm.Config{Logger: logger.Default.LogMode(logger.Info)},
	)
	if err != nil {
		panic("failed to connect database")
	}
	db = db
}

func main() {
	//1.创建一个jen文件对象,用于生成代码
	f := jen.NewFile("main")

	//2.需要生成结构体的表名
	dbTableName := "pms_billpay_log"

	//3.获取迁移器,通过迁移器获取表信息,字段信息
	migrator := db.Migrator()
	tab, err := migrator.TableType(dbTableName)
	if nil != err {
		panic(err)
	}
	columns, err := migrator.ColumnTypes(dbTableName)
	if err != nil {
		panic(err)
	}

	//4.获取表名,将表名转换为驼峰命名并且首字母大写
	name := underscoreToCamelCase(tab.Name())

	//5.生成结构体
	f.Type().Id(name).StructFunc(func(g *jen.Group) {
		for _, c := range columns {

			//5.1获取属性名
			statement := g.Id(underscoreToCamelCase(c.Name()))
			//5.2获取属性类型,将mysql数据类型转换为golang中使用的
			//TODO 注意此处只是简单的判断了几个类型,要根据实际需求进行转换
			cType := c.DatabaseTypeName()
			switch cType {
			case "int", "tinyint", "smallint":
				statement.Int()
			case "mediumint", "bigint":
				statement.Qual("github.com/shopspring/decimal", "Decimal")
			case "float":
				statement.Float32()
			case "double":
				statement.Float64()
			case "decimal":
				statement.Qual("github.com/shopspring/decimal", "Decimal")
			case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "blob":
				statement.String()
			case "date", "time", "datetime", "timestamp":
				statement.Qual("time", "Time")
			case "bit", "bool":
				statement.Bool()
			case "json":
				statement.Interface()
			default:
				statement.Interface()
			}
			//5.3获取属性后的tag,设置tag
			tagM := getTagMap(c)
			statement.Tag(tagM)
		}
	})

	//6.生成保存结构体的文件
	f.Save("example.go")
}

// 组装属性tag
// TODO 在组装tag时,并不是越多越好,例如默认值, 是否允许为空,数据类型,数据长度等,如果代码中使用不到,就不要组装了,因为在后续迭代过程中,由于遗漏,可能会出现数据库表设置与当前代码tag中的不一致
func getTagMap(column gorm.ColumnType) map[string]string {
	tagStr := ""
	if f, _ := column.PrimaryKey(); f {
		tagStr += "primary_key;"
	}
	if f, _ := column.Unique(); f {
		tagStr += "Unique;"
	}
	/*if f, _ := column.AutoIncrement(); f {
		tagStr += "autoIncrement;"
	}*/

	if f, _ := column.Nullable(); !f {
		tagStr += "not null;"
	}
	if t, ok := column.ColumnType(); ok {
		tagStr += "type:" + t + ";"
	}
	/*if le, ok := column.Length(); ok {
		tagStr += "(" + strconv.FormatInt(le, 10) + ");"
	} else {
		tagStr += ";"
	}*/

	tagStr += "column:" + column.Name() + ";"
	if comment, ok := column.Comment(); ok && len(comment) > 0 {
		tagStr += "comment:'" + comment + "';"
	}

	if defaultVal, ok := column.DefaultValue(); ok && len(defaultVal) > 0 {
		tagStr += "default:" + defaultVal + ";"
	}

	m := make(map[string]string)
	m["gorm"] = tagStr
	name := column.Name()
	m["json"] = lowerFirst(lowerFirst(underscoreToCamelCase(name)))
	return m
}

// 将带下划线的数据库命名转换为驼峰命名,首字母大写
func underscoreToCamelCase(name string) string {
	result := ""
	if !strings.Contains(name, "_") {
		first := strings.Title(name[:1])
		return first + name[1:]
	}
	parts := strings.Split(name, "_")

	// 遍历字符串切片,对每个部分进行转换
	for _, part := range parts {
		// 使用Title函数,将每个部分的首字母转换为大写
		part = strings.Title(part)
		// 使用Replace函数,将每个部分的其他字母转换为小写
		part = strings.Replace(part, part[1:], strings.ToLower(part[1:]), 1)
		// 将转换后的部分拼接到结果字符串中
		result += part
	}
	// 返回结果字符串
	return result
}

// 字符串首字母小写
func lowerFirst(str string) string {
	// 判断字符串是否为空
	if str == "" {
		// 如果为空,直接返回空字符串
		return ""
	}
	// 使用strings包的ToLower方法,将字符串的第一个字符转换为小写
	first := strings.ToLower(str[:1])
	// 将转换后的第一个字符和剩余的字符串拼接起来,返回结果
	return first + str[1:]
}

三. github.com/jinzhu/gorm 老版本 gorm + jen 生成表结构

  1. 老版本的gorm中没有Migrator , 所以编写专门读取指定表信息的sql,例如读取指定表的字段名,字段类型,字段索引,最大长度,默认值,是否运行为null,注释等信息,然后使用jen生成对应该表的结构体
package main

import (
	"github.com/dave/jennifer/jen"
	_ "github.com/go-sql-driver/mysql"
	"github.com/jinzhu/gorm" //TODO 注意与"gorm.io/gorm"的不同,"github.com/jinzhu/gorm"是gorm的老版本
	"log"
	"strconv"
	"strings"
)

// 1.定义用于存储表字段信息的结构体
type Column struct {
	ColumnName             string `gorm:"column:COLUMN_NAME"`              //字段名称
	ColumnKey              string `gorm:"column:COLUMN_KEY"`               //索引类型
	DataType               string `gorm:"column:DATA_TYPE"`                //数据类型
	CharacterMaximumLength int    `gorm:"column:CHARACTER_MAXIMUM_LENGTH"` //设置的最大长度限制
	IsNullable             string `gorm:"column:IS_NULLABLE"`              //是否可以为空
	ColumnDefault          string `gorm:"column:COLUMN_DEFAULT"`           //默认值
	ColumnComment          string `gorm:"column:COLUMN_COMMENT"`           //注释
}

// 2.数据库地址
const dsn = "账号:密码@tcp(ip:3306)/lm_pms?charset=utf8mb4&parseTime=True&loc=Local"

// 3.需要生成的表所在库名
const databaseName = "lm_pms"

// 4.需要生成的表名
const tableName = "pms_billpay_log"

// 5.表名的驼峰命名并且首字母大写
var TabNameCamelCaseUp = underscoreToCamelCase(tableName)

// 6.表名的驼峰命名,并且首字母小写
var TabNameCamelCaselower = lowerFirst(TabNameCamelCaseUp)

func main() {

	//1.获取数据库连接
	db, err := gorm.Open("mysql", dsn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	//2.执行sql查询指定表的所有字段信息,包括(字段名,数据类型,索引,设置的最大长度,是否允许为null,默认值,注释)
	var columns []Column
	sql := "SELECT COLUMN_NAME, DATA_TYPE, COLUMN_KEY, CHARACTER_MAXIMUM_LENGTH, IS_NULLABLE, COLUMN_DEFAULT, COLUMN_COMMENT  FROM information_schema.columns WHERE table_schema = '" + databaseName + "' AND table_name = '" + tableName + "';"
	err = db.Debug().Raw(sql).Scan(&columns).Error
	if err != nil {
		panic(err)
	}

	//3.使用jen开始创建并生成对应表的代码,指定代码所在包
	f := jen.NewFile("main")

	//4.封装结构体
	f.Type().Id(TabNameCamelCaseUp).StructFunc(func(g *jen.Group) {
		for _, c := range columns {

			//4.1将数据库字段名转换为结构体需要的属性名
			statement := g.Id(underscoreToCamelCase(c.ColumnName))
			//4.2获取属性类型,将mysql数据类型转换为golang中使用的
			//TODO 注意此处只是简单的判断了几个类型,要根据实际需求进行转换
			switch c.DataType {
			case "int", "tinyint", "smallint":
				statement.Int()
			case "mediumint", "bigint":
				statement.Qual("github.com/shopspring/decimal", "Decimal")
			case "float":
				statement.Float32()
			case "double":
				statement.Float64()
			case "decimal":
				statement.Qual("github.com/shopspring/decimal", "Decimal")
			case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "blob":
				statement.String()
			case "date", "time", "datetime", "timestamp":
				statement.Qual("time", "Time")
			case "bit", "bool":
				statement.Bool()
			case "json":
				statement.Interface()
			default:
				statement.Interface()
			}
			//4.3获取属性后的tag,设置tag
			tagM := getTagMap(c)
			statement.Tag(tagM)
		}
	})

	f.Comment("TODO 代表数据库连接的全局变量,项目启动时先初始化数据库连接,在执行操作数据库方法时就可以直接使用全局变量的连接,不要来回传递了," +
		"此处是为了防止生成的代码报错添加的,项目中如果已经存在数据库连接变量,将生成代码中的这个变量删除,直接使用已经存在的就可以")
	f.Var().Id("db").Op("*").Qual("github.com/jinzhu/gorm", "DB")

	//6.生成TableName方法
	f.Func().Params(jen.Id("c").Op("*").Id(TabNameCamelCaseUp)).Id("TableName").Params().String().Block(
		jen.Return(jen.Lit(tableName)),
	)

	//7.生成Add方法
	f.Func().Params(jen.Id("c").Op("*").Id(TabNameCamelCaseUp)).Id("Add").Params().Error().Block(
		jen.Return(jen.Id("db").Dot("Create").Call(jen.Id("c")).Dot("Error")),
	)

	//8.生成Update方法
	f.Func().Params(jen.Id("c").Op("*").Id(TabNameCamelCaseUp)).Id("Update").Params().Error().Block(
		jen.Return(jen.Id("db").Dot("Model").Call(jen.Id("c")).Dot("Update").Call(jen.Id("c")).Dot("Error")),
	)

	//9.生成UpdateSave方法
	f.Func().Params(jen.Id("c").Op("*").Id(TabNameCamelCaseUp)).Id("UpdateSave").Params().Error().Block(
		jen.Return(jen.Id("db").Dot("Model").Call(jen.Id("c")).Dot("Save").Call(jen.Id("c")).Dot("Error")),
	)

	//10.生成SearchFirst方法
	f.Func().Params(jen.Id("c").Op("*").Id(TabNameCamelCaseUp)).Id("SearchFirst").Params().Error().Block(
		jen.Return(jen.Id("db").Dot("Model").Call(jen.Id("c")).Dot("Where").Call(jen.Id("c")).Dot("First").Call(jen.Id("c")).Dot("Error")),
	)

	//11.生成Search方法
	f.Func().Params(jen.Id("c").Op("*").Id(TabNameCamelCaseUp)).Id("Search").Params().Params(jen.Index().Id(TabNameCamelCaseUp), jen.Error()).Block(
		jen.Var().Id(TabNameCamelCaselower+"List").Index().Id(TabNameCamelCaseUp),
		jen.Err().Op(":=").Id("db").Dot("Model").Call(jen.Id("c")).Dot("Where").Call(jen.Id("c")).Dot("Find").Call(jen.Op("&").Id(TabNameCamelCaselower+"List")).Dot("Error"),
		jen.Return(jen.Id(TabNameCamelCaselower+"List"), jen.Err()),
	)

	//12.生成SelectById方法
	f.Func().Id("SelectById").Params(jen.Id("id").Int()).Params(jen.Op("*").Id(TabNameCamelCaseUp), jen.Error()).Block(
		jen.Var().Id(TabNameCamelCaselower).Id(TabNameCamelCaseUp),
		jen.Err().Op(":=").Id("db").Dot("First").Call(jen.Op("&").Id(TabNameCamelCaselower), jen.Id("id")).Dot("Error"),
		jen.If(jen.Err().Op("!=").Nil()).Block(
			jen.Return(jen.Nil(), jen.Err()),
		),
		jen.Return(jen.Op("&").Id(TabNameCamelCaselower), jen.Nil()),
	)

	//13.生成SelectByPage方法
	f.Func().Params(jen.Id("c").Op("*").Id(TabNameCamelCaseUp)).Id("SelectByPage").Params(jen.Id("pageNo").Int(), jen.Id("pageSize").Int()).Params(jen.Index().Id(TabNameCamelCaseUp), jen.Int(), jen.Error()).Block(
		jen.Var().Id(TabNameCamelCaselower+"List").Index().Id(TabNameCamelCaseUp),
		jen.Var().Id("count").Int(),
		jen.Id("db").Op("=").Id("db").Dot("Model").Call(jen.Op("&").Id(TabNameCamelCaseUp+"{}")).Dot("Where").Call(jen.Id("c")),
		jen.Id("limit").Op(":=").Id("pageSize"),
		jen.Id("offset").Op(":=").Id("pageSize").Op("*").Call(jen.Id("pageNo").Op("-").Lit(1)),
		jen.Comment("注意Limit 方法和 Offset 方法必须在 Find 方法之前调用,否则会出现错误。"),
		jen.Err().Op(":=").Id("db").Dot("Count").Call(jen.Op("&").Id("count")).Dot("Limit").Call(jen.Id("limit")).Dot("Offset").Call(jen.Id("offset")).Dot("Find").Call(jen.Op("&").Id(TabNameCamelCaselower+"List")).Dot("Error"),
		jen.If(jen.Err().Op("!=").Nil()).Block(
			jen.Return(jen.Nil(), jen.Lit(0), jen.Err()),
		),
		jen.Return(jen.Id(TabNameCamelCaselower+"List"), jen.Id("count"), jen.Nil()),
	)

	//14.保存文件
	err = f.Save(tableName + ".go")
	if err != nil {
		log.Fatal(err)
	}
}

// 组装属性tag
// TODO 在组装tag时,并不是越多越好,例如默认值, 是否允许为空,数据类型,数据长度等,如果代码中使用不到,就不要组装了,因为在后续迭代过程中,由于遗漏,可能会出现数据库表设置与当前代码tag中的不一致
func getTagMap(column Column) map[string]string {
	tagStr := ""
	//拼接索引
	//TODO 当前只判断了主键索引
	if column.ColumnKey == "PRI" {
		tagStr += "primary_key;"
	}

	if column.IsNullable == "NO" {
		tagStr += "not null;"
	}

	tagStr += column.DataType
	if column.CharacterMaximumLength > 0 {
		tagStr += "(" + strconv.Itoa(column.CharacterMaximumLength) + ")"
	}
	tagStr += ";"

	tagStr += "column:" + column.ColumnName + ";"
	if len(column.ColumnComment) > 0 {
		tagStr += "comment:'" + column.ColumnComment + "';"
	}

	if len(column.ColumnDefault) > 0 {
		tagStr += "default:" + column.ColumnDefault + ";"
	}

	m := make(map[string]string)
	m["gorm"] = tagStr
	m["json"] = lowerFirst(lowerFirst(underscoreToCamelCase(column.ColumnName)))
	return m
}

// 将带下划线的数据库命名转换为驼峰命名,首字母大写
func underscoreToCamelCase(name string) string {
	result := ""
	if !strings.Contains(name, "_") {
		first := strings.Title(name[:1])
		return first + name[1:]
	}
	parts := strings.Split(name, "_")

	// 遍历字符串切片,对每个部分进行转换
	for _, part := range parts {
		// 使用Title函数,将每个部分的首字母转换为大写
		part = strings.Title(part)
		// 使用Replace函数,将每个部分的其他字母转换为小写
		part = strings.Replace(part, part[1:], strings.ToLower(part[1:]), 1)
		// 将转换后的部分拼接到结果字符串中
		result += part
	}
	// 返回结果字符串
	return result
}

// 字符串首字母小写
func lowerFirst(str string) string {
	// 判断字符串是否为空
	if str == "" {
		// 如果为空,直接返回空字符串
		return ""
	}
	// 使用strings包的ToLower方法,将字符串的第一个字符转换为小写
	first := strings.ToLower(str[:1])
	// 将转换后的第一个字符和剩余的字符串拼接起来,返回结果
	return first + str[1:]
}
  • 10
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值