很多时候sql查询都不允许select * 出现要求指定列名,如果你是用gormv2 ,恭喜你可以使用QueryFields属性,如果是gormv1版本怎么样,难道要升级gormV2吗,这里提供种反射的实现,可能不是最优解,但只是一个方案。
首先mysq建一个表
CREATE TABLE `test` (
`id` BIGINT(20) NOT NULL,
`name` VARCHAR(5) DEFAULT NULL,
`age` INT(11) DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=INNODB DEFAULT CHARSET=utf8mb4
go的代码如下:
package main
import (
"fmt"
"gorm.io/driver/mysql"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"reflect"
"strings"
_ "github.com/go-sql-driver/mysql"
"gorm.io/gorm"
)
func main() {
dns := "root:root@tcp(192.168.100.30:3306)/demo?charset=utf8&parseTime=True&loc=Local"
config := &gorm.Config{
NamingStrategy: schema.NamingStrategy{
SingularTable: true,
},
//QueryFields: true,
Logger: logger.Default.LogMode(logger.Info),
}
db, err := gorm.Open(mysql.Open(dns), config)
if err != nil {
fmt.Println(fmt.Sprintf("Open err:%v", err))
}
var ret []*Test
//一般查询
fmt.Println("一般查询")
err = db.Table("test").Where("id>1").Find(&ret).Debug().Error
if err != nil {
fmt.Println(fmt.Sprintf("select err:%v", err))
}
//通过反射指定列
fmt.Println("通过反射指定列")
err = db.Table("test").Where("id>1").Select(GetAllFields(new(Test))).Find(&ret).Debug().Error
if err != nil {
fmt.Println(fmt.Sprintf("select err:%v", err))
}
//通过QueryFields 指定列
fmt.Println("通过QueryFields 指定列")
tx := db.Table("test")
tx.QueryFields = true
err = tx.Where("id>1").Find(&ret).Debug().Error
if err != nil {
fmt.Println(fmt.Sprintf("select err:%v", err))
}
}
type Test struct {
ID int64 `gorm:"type:bigint(20);column:id;primary_key"`
Name string `gorm:"type:varchar(5);column:name"`
Age int `gorm:"type:int(11);column:age"`
}
func GetAllFields(info interface{}) string {
tagName := strings.ToUpper("column")
var arr []string
el := reflect.TypeOf(info).Elem()
for i := 0; i < el.NumField(); i++ {
structTag := el.Field(i).Tag
tags := parseTagSetting(structTag)
if columnName, ok := tags[tagName]; ok && columnName != "-" {
arr = append(arr, columnName)
}
}
sqlFields := "`" + strings.Join(arr, "`,`") + "`"
return sqlFields
}
func parseTagSetting(tags reflect.StructTag) map[string]string {
setting := map[string]string{}
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
if str == "" {
continue
}
tags := strings.Split(str, ";")
for _, value := range tags {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if len(v) >= 2 {
setting[k] = strings.Join(v[1:], ":")
} else {
setting[k] = k
}
}
}
return setting
}
运行结果: