一. 先了解jen代码生成
有一个专门用于代码生成的第三方库 "github.com/dave/jennifer/jen",它可以让你以编程的方式创建和修改 .go 文件。下面是一个简单的示例:
package main
import "github.com/dave/jennifer/jen"
// main builds a small Go source file with jennifer and writes it to
// example.go. The generated file contains a Person struct, a main function
// that constructs a Person, and a PrintInfo method that prints it.
func main() {
	f := jen.NewFile("main")

	// Person struct with struct tags on ID and Name.
	f.Type().Id("Person").Struct(
		jen.Id("Age").Int(),
		jen.Id("ID").Int64().Tag(map[string]string{
			"gorm": "primary_key",
		}),
		jen.Id("Name").String().Tag(map[string]string{
			"json": "name",
		}),
	)

	// Generated main: create a Person value and call its PrintInfo method.
	f.Func().Id("main").Params().Block(
		jen.Id("p").Op(":=").Id("Person").Values(
			jen.Id("Name").Op(":").Lit("Alice"),
			jen.Id("Age").Op(":").Lit(20),
		),
		jen.Id("p").Dot("PrintInfo").Call(),
	)

	// Generated method with a pointer receiver; Qual makes jennifer add the
	// "fmt" import to the output file.
	f.Func().Params(jen.Id("p").Op("*").Id("Person")).Id("PrintInfo").Params().Block(
		jen.Qual("fmt", "Printf").Call(
			jen.Lit("Name: %s, Age: %d\n"),
			jen.Id("p").Dot("Name"),
			jen.Id("p").Dot("Age"),
		),
	)

	// Save returns an error (render or write failure); do not drop it.
	if err := f.Save("example.go"); err != nil {
		panic(err)
	}
}
执行完毕后会生成如下example.go文件
package main
import "fmt"
type Person struct {
Age int
ID int64 `gorm:"primary_key"`
Name string `json:"name"`
}
func main ( ) {
p := Person{ Name: "Alice" , Age: 20 }
p. PrintInfo ( )
}
func ( p * Person) PrintInfo ( ) {
fmt. Printf ( "Name: %s, Age: %d\n" , p. Name, p. Age)
}
二. gorm.io/gorm 新版本 gorm + jen 生成表对应结构体
介绍 gorm Migrator
在使用新版本的gorm时底层有一个Migrator迁移器,内部提供了一些用来创建,修改,删除,以及获取数据库的元数据信息的方法
CurrentDatabase:用于获取当前数据库的名称
TableType:用于获取一个表的类型,是普通表还是视图,返回一个TableType类型的值
ColumnTypes:用于获取表的字段信息
GetIndexes:用于获取表的索引信息
AddColumn:用于给一个表添加一个列,根据你定义的结构体的字段来生成对应的类型和约束
AutoMigrate:用于自动迁移你的schema,保持你的schema是最新的。它会创建表,缺少的外键,约束,列和索引,并且会更改现有列的类型(如果其大小、精度、是否为空可更改)
FullDataTypeOf:用于获取给定字段的完整数据类型,包括长度、精度、是否为空等
GetTypeAliases:用于获取给定数据库类型名称的别名,例如int 和integer
CreateTable:用于创建一个或多个表,根据你定义的结构体来生成对应的字段和类型,以及主键、外键、索引等约束
DropTable:用于删除一个或多个表,如果表不存在,不会报错
HasTable:用于检查一个表是否存在,可以传入表名或结构体
RenameTable:用于重命名一个表,可以传入旧表名和新表名,或者旧结构体和新结构体
GetTables:用于获取数据库中所有表的名称,返回一个字符串切片
package main
import (
"fmt"
"github.com/dave/jennifer/jen"
_ "github.com/go-sql-driver/mysql"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"strconv"
"strings"
)
// db is the package-level GORM handle; init assigns it and main queries through it.
var db * gorm. DB
// err receives the result of the gorm.Open call performed in init.
var err error
// init opens the MySQL connection described by dsn and stores it in the
// package-level db. SQL statements are logged at Info level. It panics when
// the connection cannot be established.
func init() {
	dsn := "账号:密码@tcp(ip:3306)/lm_pms?charset=utf8mb4&parseTime=True&loc=Local"
	db, err = gorm.Open(
		mysql.Open(dsn),
		&gorm.Config{Logger: logger.Default.LogMode(logger.Info)},
	)
	if err != nil {
		// Keep the underlying cause in the message; without it the failure
		// (bad credentials, unreachable host, ...) cannot be diagnosed.
		panic("failed to connect database: " + err.Error())
	}
}
// main demonstrates how to read table metadata with GORM:
//  1. run SHOW CREATE TABLE through a raw query,
//  2. list the column metadata of hotel_list through the Migrator,
//  3. fetch the indexes and the table type of users through the Migrator.
//
// Every step prints its error and stops, because the following steps cannot
// use a failed result.
func main() {
	rows, err := db.Raw("SHOW CREATE TABLE hotel_list").Rows()
	if err != nil {
		fmt.Println(err)
		return
	}
	// Release the underlying connection once the result set is consumed.
	defer rows.Close()
	for rows.Next() {
		var table string
		var sql string
		if err := rows.Scan(&table, &sql); err != nil {
			fmt.Println(err)
			return
		}
		fmt.Println(table, sql)
	}
	// Next returns false both at the end of the set and on error.
	if err := rows.Err(); err != nil {
		fmt.Println(err)
		return
	}

	migrator := db.Migrator()
	columns, err := migrator.ColumnTypes("hotel_list")
	if err != nil {
		fmt.Println(err)
		return
	}
	for _, column := range columns {
		name := column.Name()
		valtype := column.DatabaseTypeName()
		// The second return value of these accessors is an "ok" flag, not an
		// error; the zero value is printed when the driver has no information.
		length, _ := column.Length()
		nullable, _ := column.Nullable()
		valDefault, _ := column.DefaultValue()
		comment, _ := column.Comment()
		fmt.Printf("name: %v, valtype: %v,length %v, nullable %v, valDefault %v, comment %v\n", name, valtype, length, nullable, valDefault, comment)
	}

	indexes, err := migrator.GetIndexes("users")
	if err != nil {
		fmt.Println(err)
		return
	}
	tab, err := migrator.TableType("users")
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Print(indexes, tab)
}
gorm Migrator 读取数据库表结构 + jen 生成对应结构体
package main
import (
"github.com/dave/jennifer/jen"
_ "github.com/go-sql-driver/mysql"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"strings"
)
// db is the package-level GORM handle that main uses to obtain the Migrator.
var db * gorm. DB
// init opens the MySQL connection described by dsn and stores it in the
// package-level db. SQL statements are logged at Info level. It panics when
// the connection cannot be established.
func init() {
	dsn := "账号:密码@tcp(ip:3306)/lm_pms?charset=utf8mb4&parseTime=True&loc=Local"

	// err is declared separately so the assignment below can use "=".
	// "db, err := ..." would create a new local db that shadows the
	// package-level variable and leaves the global nil for main.
	var err error
	db, err = gorm.Open(
		mysql.Open(dsn),
		&gorm.Config{Logger: logger.Default.LogMode(logger.Info)},
	)
	if err != nil {
		panic("failed to connect database: " + err.Error())
	}
}
// main reads the metadata of table pms_billpay_log through the GORM Migrator
// and generates example.go containing a struct with one field per column.
// Field names are the CamelCase form of the column names, field types are
// chosen from the database type name, and the tags come from getTagMap.
func main() {
	f := jen.NewFile("main")
	dbTableName := "pms_billpay_log"

	migrator := db.Migrator()
	tab, err := migrator.TableType(dbTableName)
	if err != nil {
		panic(err)
	}
	columns, err := migrator.ColumnTypes(dbTableName)
	if err != nil {
		panic(err)
	}

	name := underscoreToCamelCase(tab.Name())
	f.Type().Id(name).StructFunc(func(g *jen.Group) {
		for _, c := range columns {
			statement := g.Id(underscoreToCamelCase(c.Name()))
			switch c.DatabaseTypeName() {
			case "int", "tinyint", "smallint":
				statement.Int()
			case "mediumint", "bigint":
				// NOTE(review): integer columns are mapped to decimal.Decimal
				// here; int64 may be the more natural choice — confirm intent.
				statement.Qual("github.com/shopspring/decimal", "Decimal")
			case "float":
				statement.Float32()
			case "double":
				statement.Float64()
			case "decimal":
				statement.Qual("github.com/shopspring/decimal", "Decimal")
			case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "blob":
				statement.String()
			case "date", "time", "datetime", "timestamp":
				statement.Qual("time", "Time")
			case "bit", "bool":
				statement.Bool()
			default:
				// "json" and every unrecognized type fall back to interface{}.
				statement.Interface()
			}
			statement.Tag(getTagMap(c))
		}
	})

	// Save returns an error (render or write failure); do not drop it.
	if err := f.Save("example.go"); err != nil {
		panic(err)
	}
}
// getTagMap builds the struct tags for one column.
//
// The returned map has two entries:
//   - "gorm": a semicolon-terminated list of settings (primary_key, Unique,
//     not null, type, column, comment, default) derived from the column
//     metadata;
//   - "json": the column name in lowerCamelCase.
func getTagMap(column gorm.ColumnType) map[string]string {
	var tag strings.Builder

	if isPK, ok := column.PrimaryKey(); ok && isPK {
		tag.WriteString("primary_key;")
	}
	if isUnique, ok := column.Unique(); ok && isUnique {
		tag.WriteString("Unique;")
	}
	// Only emit "not null" when the driver actually reported nullability;
	// an unknown value must not be turned into a constraint.
	if nullable, ok := column.Nullable(); ok && !nullable {
		tag.WriteString("not null;")
	}
	if t, ok := column.ColumnType(); ok {
		tag.WriteString("type:" + t + ";")
	}
	tag.WriteString("column:" + column.Name() + ";")
	if comment, ok := column.Comment(); ok && len(comment) > 0 {
		tag.WriteString("comment:'" + comment + "';")
	}
	if defaultVal, ok := column.DefaultValue(); ok && len(defaultVal) > 0 {
		tag.WriteString("default:" + defaultVal + ";")
	}

	return map[string]string{
		"gorm": tag.String(),
		"json": lowerFirst(underscoreToCamelCase(column.Name())),
	}
}
// underscoreToCamelCase converts a snake_case identifier into CamelCase with
// an upper-case first letter, e.g. "user_name" -> "UserName".
//
// A name without underscores only has its first character upper-cased and
// the rest is kept verbatim ("userID" -> "UserID"). With underscores, every
// segment is capitalized and its remainder lower-cased. Empty segments
// (leading, trailing or doubled underscores) are skipped, and "" yields "".
func underscoreToCamelCase(name string) string {
	// splitFirst separates the first character from the rest on a rune
	// boundary, so a multi-byte leading character is never cut in half.
	splitFirst := func(s string) (head, tail string) {
		for i := range s {
			if i > 0 {
				return s[:i], s[i:]
			}
		}
		return s, ""
	}

	if !strings.Contains(name, "_") {
		head, tail := splitFirst(name)
		return strings.ToUpper(head) + tail
	}

	var b strings.Builder
	for _, part := range strings.Split(name, "_") {
		if part == "" {
			continue
		}
		head, tail := splitFirst(part)
		b.WriteString(strings.ToUpper(head))
		b.WriteString(strings.ToLower(tail))
	}
	return b.String()
}
// lowerFirst returns str with its first character lower-cased and the rest
// unchanged. The split happens on a rune boundary, so a multi-byte leading
// character is handled as a whole. The empty string is returned as is.
func lowerFirst(str string) string {
	for i := range str {
		// The first index greater than zero is where the second rune starts.
		if i > 0 {
			return strings.ToLower(str[:i]) + str[i:]
		}
	}
	// Empty or single-character input.
	return strings.ToLower(str)
}
三. github.com/jinzhu/gorm 老版本 gorm + jen 生成表结构
老版本的gorm中没有Migrator,所以需要编写专门读取指定表信息的sql,例如读取指定表的字段名、字段类型、字段索引、最大长度、默认值、是否允许为null、注释等信息,然后使用jen生成该表对应的结构体
package main
import (
"github.com/dave/jennifer/jen"
_ "github.com/go-sql-driver/mysql"
"github.com/jinzhu/gorm"
"log"
"strconv"
"strings"
)
// Column holds the metadata of one table column as selected from
// information_schema.columns in main. The gorm tag on each field names the
// result column it is filled from.
type Column struct {
// ColumnName is the column's name; it becomes the struct field name and the "column:" tag.
ColumnName string `gorm:"column:COLUMN_NAME"`
// ColumnKey is the key marker; getTagMap treats "PRI" as primary key.
ColumnKey string `gorm:"column:COLUMN_KEY"`
// DataType is the type name used to pick the Go field type in main.
DataType string `gorm:"column:DATA_TYPE"`
// CharacterMaximumLength is appended to the type in the tag when greater than zero.
CharacterMaximumLength int `gorm:"column:CHARACTER_MAXIMUM_LENGTH"`
// IsNullable is compared against "NO" by getTagMap to emit "not null".
IsNullable string `gorm:"column:IS_NULLABLE"`
// ColumnDefault is the default value; emitted as "default:" when non-empty.
ColumnDefault string `gorm:"column:COLUMN_DEFAULT"`
// ColumnComment is the column comment; emitted as "comment:" when non-empty.
ColumnComment string `gorm:"column:COLUMN_COMMENT"`
}
// Connection settings and the table the generator works on.
const (
	// dsn is the MySQL data source name passed to gorm.Open.
	dsn = "账号:密码@tcp(ip:3306)/lm_pms?charset=utf8mb4&parseTime=True&loc=Local"
	// databaseName is the schema searched in information_schema.columns.
	databaseName = "lm_pms"
	// tableName is the table whose struct and methods are generated.
	tableName = "pms_billpay_log"
)

// Identifiers derived from tableName for use in the generated code.
var (
	// TabNameCamelCaseUp is the generated struct name (upper CamelCase).
	TabNameCamelCaseUp = underscoreToCamelCase(tableName)
	// TabNameCamelCaselower is the lowerCamelCase form used for local variables.
	TabNameCamelCaselower = lowerFirst(TabNameCamelCaseUp)
)
// main reads the column metadata of tableName from information_schema and
// generates <tableName>.go. The generated file contains:
//   - a struct with one tagged field per column,
//   - a package-level db variable (placeholder for the project's connection),
//   - the methods TableName, Add, Update, UpdateSave, SearchFirst, Search and
//     SelectByPage, plus the function SelectById.
func main() {
	db, err := gorm.Open("mysql", dsn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Bind the schema and table names as parameters instead of splicing them
	// into the SQL text.
	var columns []Column
	const query = "SELECT COLUMN_NAME, DATA_TYPE, COLUMN_KEY, CHARACTER_MAXIMUM_LENGTH, IS_NULLABLE, COLUMN_DEFAULT, COLUMN_COMMENT " +
		"FROM information_schema.columns WHERE table_schema = ? AND table_name = ?"
	err = db.Debug().Raw(query, databaseName, tableName).Scan(&columns).Error
	if err != nil {
		panic(err)
	}

	// receiver renders a fresh "c *<Struct>" receiver; jen statements are
	// mutable, so each method needs its own instance.
	receiver := func() *jen.Statement {
		return jen.Id("c").Op("*").Id(TabNameCamelCaseUp)
	}
	// listVar is the name of the result slice used inside generated methods.
	listVar := TabNameCamelCaselower + "List"

	f := jen.NewFile("main")

	// Struct definition: one field per column.
	f.Type().Id(TabNameCamelCaseUp).StructFunc(func(g *jen.Group) {
		for _, c := range columns {
			statement := g.Id(underscoreToCamelCase(c.ColumnName))
			switch c.DataType {
			case "int", "tinyint", "smallint":
				statement.Int()
			case "mediumint", "bigint":
				// NOTE(review): integer columns are mapped to decimal.Decimal
				// here; int64 may be the more natural choice — confirm intent.
				statement.Qual("github.com/shopspring/decimal", "Decimal")
			case "float":
				statement.Float32()
			case "double":
				statement.Float64()
			case "decimal":
				statement.Qual("github.com/shopspring/decimal", "Decimal")
			case "char", "varchar", "text", "tinytext", "mediumtext", "longtext", "blob":
				statement.String()
			case "date", "time", "datetime", "timestamp":
				statement.Qual("time", "Time")
			case "bit", "bool":
				statement.Bool()
			default:
				// "json" and every unrecognized type fall back to interface{}.
				statement.Interface()
			}
			statement.Tag(getTagMap(c))
		}
	})

	// Placeholder connection variable so the generated file compiles alone.
	f.Comment("TODO 代表数据库连接的全局变量,项目启动时先初始化数据库连接,在执行操作数据库方法时就可以直接使用全局变量的连接,不要来回传递了," +
		"此处是为了防止生成的代码报错添加的,项目中如果已经存在数据库连接变量,将生成代码中的这个变量删除,直接使用已经存在的就可以")
	f.Var().Id("db").Op("*").Qual("github.com/jinzhu/gorm", "DB")

	// func (c *T) TableName() string
	f.Func().Params(receiver()).Id("TableName").Params().String().Block(
		jen.Return(jen.Lit(tableName)),
	)

	// func (c *T) Add() error
	f.Func().Params(receiver()).Id("Add").Params().Error().Block(
		jen.Return(jen.Id("db").Dot("Create").Call(jen.Id("c")).Dot("Error")),
	)

	// func (c *T) Update() error
	f.Func().Params(receiver()).Id("Update").Params().Error().Block(
		jen.Return(jen.Id("db").Dot("Model").Call(jen.Id("c")).Dot("Update").Call(jen.Id("c")).Dot("Error")),
	)

	// func (c *T) UpdateSave() error
	f.Func().Params(receiver()).Id("UpdateSave").Params().Error().Block(
		jen.Return(jen.Id("db").Dot("Model").Call(jen.Id("c")).Dot("Save").Call(jen.Id("c")).Dot("Error")),
	)

	// func (c *T) SearchFirst() error
	f.Func().Params(receiver()).Id("SearchFirst").Params().Error().Block(
		jen.Return(jen.Id("db").Dot("Model").Call(jen.Id("c")).Dot("Where").Call(jen.Id("c")).Dot("First").Call(jen.Id("c")).Dot("Error")),
	)

	// func (c *T) Search() ([]T, error)
	f.Func().Params(receiver()).Id("Search").Params().Params(jen.Index().Id(TabNameCamelCaseUp), jen.Error()).Block(
		jen.Var().Id(listVar).Index().Id(TabNameCamelCaseUp),
		jen.Err().Op(":=").Id("db").Dot("Model").Call(jen.Id("c")).Dot("Where").Call(jen.Id("c")).Dot("Find").Call(jen.Op("&").Id(listVar)).Dot("Error"),
		jen.Return(jen.Id(listVar), jen.Err()),
	)

	// func SelectById(id int) (*T, error)
	f.Func().Id("SelectById").Params(jen.Id("id").Int()).Params(jen.Op("*").Id(TabNameCamelCaseUp), jen.Error()).Block(
		jen.Var().Id(TabNameCamelCaselower).Id(TabNameCamelCaseUp),
		jen.Err().Op(":=").Id("db").Dot("First").Call(jen.Op("&").Id(TabNameCamelCaselower), jen.Id("id")).Dot("Error"),
		jen.If(jen.Err().Op("!=").Nil()).Block(
			jen.Return(jen.Nil(), jen.Err()),
		),
		jen.Return(jen.Op("&").Id(TabNameCamelCaselower), jen.Nil()),
	)

	// func (c *T) SelectByPage(pageNo int, pageSize int) ([]T, int, error)
	//
	// The filtered query is kept in a local "query" variable. Assigning it to
	// the package-level db would leak this call's Where condition into every
	// later query of the generated package.
	f.Func().Params(receiver()).Id("SelectByPage").Params(jen.Id("pageNo").Int(), jen.Id("pageSize").Int()).Params(jen.Index().Id(TabNameCamelCaseUp), jen.Int(), jen.Error()).Block(
		jen.Var().Id(listVar).Index().Id(TabNameCamelCaseUp),
		jen.Var().Id("count").Int(),
		jen.Id("query").Op(":=").Id("db").Dot("Model").Call(jen.Op("&").Id(TabNameCamelCaseUp).Values()).Dot("Where").Call(jen.Id("c")),
		jen.Id("limit").Op(":=").Id("pageSize"),
		jen.Id("offset").Op(":=").Id("pageSize").Op("*").Parens(jen.Id("pageNo").Op("-").Lit(1)),
		jen.Comment("注意Limit 方法和 Offset 方法必须在 Find 方法之前调用,否则会出现错误。"),
		jen.Err().Op(":=").Id("query").Dot("Count").Call(jen.Op("&").Id("count")).Dot("Limit").Call(jen.Id("limit")).Dot("Offset").Call(jen.Id("offset")).Dot("Find").Call(jen.Op("&").Id(listVar)).Dot("Error"),
		jen.If(jen.Err().Op("!=").Nil()).Block(
			jen.Return(jen.Nil(), jen.Lit(0), jen.Err()),
		),
		jen.Return(jen.Id(listVar), jen.Id("count"), jen.Nil()),
	)

	err = f.Save(tableName + ".go")
	if err != nil {
		log.Fatal(err)
	}
}
// getTagMap builds the struct tags for one column.
//
// The returned map has two entries:
//   - "gorm": a semicolon-terminated list of settings (primary_key, not null,
//     type, column, comment, default) derived from the column metadata;
//   - "json": the column name in lowerCamelCase.
func getTagMap(column Column) map[string]string {
	var tag strings.Builder

	if column.ColumnKey == "PRI" {
		tag.WriteString("primary_key;")
	}
	if column.IsNullable == "NO" {
		tag.WriteString("not null;")
	}

	// The data type must be written under the "type:" key; a bare
	// "varchar(64);" entry is not a setting gorm recognizes.
	tag.WriteString("type:" + column.DataType)
	if column.CharacterMaximumLength > 0 {
		tag.WriteString("(" + strconv.Itoa(column.CharacterMaximumLength) + ")")
	}
	tag.WriteString(";")

	tag.WriteString("column:" + column.ColumnName + ";")
	if len(column.ColumnComment) > 0 {
		tag.WriteString("comment:'" + column.ColumnComment + "';")
	}
	if len(column.ColumnDefault) > 0 {
		tag.WriteString("default:" + column.ColumnDefault + ";")
	}

	return map[string]string{
		"gorm": tag.String(),
		"json": lowerFirst(underscoreToCamelCase(column.ColumnName)),
	}
}
// underscoreToCamelCase converts a snake_case identifier into CamelCase with
// an upper-case first letter, e.g. "pms_billpay_log" -> "PmsBillpayLog".
//
// A name without underscores only has its first character upper-cased and
// the rest is kept verbatim ("userID" -> "UserID"). With underscores, every
// segment is capitalized and its remainder lower-cased. Empty segments
// (leading, trailing or doubled underscores) are skipped, and "" yields "".
func underscoreToCamelCase(name string) string {
	// splitFirst separates the first character from the rest on a rune
	// boundary, so a multi-byte leading character is never cut in half.
	splitFirst := func(s string) (head, tail string) {
		for i := range s {
			if i > 0 {
				return s[:i], s[i:]
			}
		}
		return s, ""
	}

	if !strings.Contains(name, "_") {
		head, tail := splitFirst(name)
		return strings.ToUpper(head) + tail
	}

	var b strings.Builder
	for _, part := range strings.Split(name, "_") {
		if part == "" {
			continue
		}
		head, tail := splitFirst(part)
		b.WriteString(strings.ToUpper(head))
		b.WriteString(strings.ToLower(tail))
	}
	return b.String()
}
// lowerFirst returns str with its first character lower-cased and the rest
// unchanged. The split happens on a rune boundary, so a multi-byte leading
// character is handled as a whole. The empty string is returned as is.
func lowerFirst(str string) string {
	for i := range str {
		// The first index greater than zero is where the second rune starts.
		if i > 0 {
			return strings.ToLower(str[:i]) + str[i:]
		}
	}
	// Empty or single-character input.
	return strings.ToLower(str)
}