深入了解gorm Scan的使用

前言

在使用gorm查询数据保存时,可以通过Scan快速方便地将数据存储到指定数据类型中,减少数据的手动转存及赋值过程。

使用示例:

type Result struct {
    Name string
    Age  int
}

var result Result
db.Table("users").Select("name, age").Where("name = ?", 3).Scan(&result)

// Raw SQL
db.Raw("SELECT name, age FROM users WHERE name = ?", 3).Scan(&result)

那么,你知道:

  1. Scan支持哪些数据类型吗?
  2. Scan如何确定接收类型的数据与查询数据之间的匹配关系的呢?

我们带着这两个问题去看下相关的源码。

Scan

Scan源码

// Scan scan value to a struct
func (s *DB) Scan(dest interface{}) *DB {
    return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db
}

注释中说是将value scan到struct,实际不只是,后面源码中会给出答案。

Set

Set是将dest存储在DB的values(sync.Map)中,key为gorm:query_destination,方便后续的取出。

// Set set value by name
func (scope *Scope) Set(name string, value interface{}) *Scope {
    scope.db.InstantSet(name, value)
    return scope
}

// InstantSet instant set setting, will affect current db
func (s *DB) InstantSet(name string, value interface{}) *DB {
    s.values.Store(name, value)
    return s
}

// DB contains information for current db connection
type DB struct {
    sync.RWMutex
    Value        interface{}
    Error        error
    RowsAffected int64

    // single db
    db                SQLCommon
    blockGlobalUpdate bool
    logMode           logModeValue
    logger            logger
    search            *search
    values            sync.Map

    // global db
    parent        *DB
    callbacks     *Callback
    dialect       Dialect
    singularTable bool

    // function to be used to override the creating of a new timestamp
    nowFuncOverride func() time.Time
}
queryCallback

查询的具体处理是在gorm/callback_query.go文件中的queryCallback中处理的。

queryCallback包含了所有查询的处理,此处仅关注Scan的处理,其他的处理忽略。

// queryCallback used to query data from database
func queryCallback(scope *Scope) {
    ...

    var (
        isSlice, isPtr bool
        resultType     reflect.Type
        results        = scope.IndirectValue()
    )

    ...
    // 取出存储的dest
    if value, ok := scope.Get("gorm:query_destination"); ok {
        results = indirect(reflect.ValueOf(value))//如果是指针取其指向的值
    }
    // 判断results的类型,如果kind不为slice或struct,则报错
    if kind := results.Kind(); kind == reflect.Slice {//slice的处理
        isSlice = true
        resultType = results.Type().Elem()//获取slice内子元素的类型
        results.Set(reflect.MakeSlice(results.Type(), 0, 0))//根据子元素类型,初始化slice

        if resultType.Kind() == reflect.Ptr {//slice的元素为指针类型的处理
            isPtr = true//标记指针类型
            resultType = resultType.Elem()//取指针指向的具体类型
        }
    } else if kind != reflect.Struct {//非slice及struct的报错处理
        scope.Err(errors.New("unsupported destination, should be slice or struct"))
        return
    }
    // 准备查询
    scope.prepareQuerySQL()
    // 没有错误,开始查询
    if !scope.HasError() {
        scope.db.RowsAffected = 0

        ...
        //  正式开始查询
        if rows, err := scope.SQLDB().Query(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {//查询未出错
            defer rows.Close()

            columns, _ := rows.Columns()//获取列名
            for rows.Next() {//循环处理查询到的所有rows
                scope.db.RowsAffected++

                elem := results
                if isSlice {//slice的处理
                    elem = reflect.New(resultType).Elem()//根据类型构造slice的elem
                }
                // 具体scan的处理
                scope.scan(rows, columns, scope.New(elem.Addr().Interface()).Fields())

                if isSlice {//slice数据的组装
                    if isPtr {//根据是否指针,存储对应的指针或值
                        results.Set(reflect.Append(results, elem.Addr()))
                    } else {
                        results.Set(reflect.Append(results, elem))
                    }
                }
            }

            if err := rows.Err(); err != nil {//查询出错
                scope.Err(err)
            } else if scope.db.RowsAffected == 0 && !isSlice {//未查询到数据,需要注意的是:仅struct时会报错,slice并不会报错
                scope.Err(ErrRecordNotFound)
            }
        }
    }
}

需要注意的是: queryCallback中只检查类型是slice或struct及它们的指针类型,所以Scan至少要求接受数据的类型是slice或struct及它们的指针类型。

queryCallback的关于Scan的处理过程大致如下:

  1. 根据key取出存储在values中的dest,获取其(指针的)值results
  2. 判断results的类型
    • slice处理,获取slice内子元素的类型,初始化slice
    • 非struct及slice报错
  3. 查询数据出错报错处理
  4. 查找数据未出错
    • 获取列名
    • 循环将数据scan到elem中
    • 若是slice,将elem存入slice中
    • 记录获取到的数据条数
  5. 未查找到数据,且不是slice的报未查找到错误
获取接收数据的fields
// Fields get value's fields
func (scope *Scope) Fields() []*Field {
    if scope.fields == nil {
        var (
            fields             []*Field
            indirectScopeValue = scope.IndirectValue()
            isStruct           = indirectScopeValue.Kind() == reflect.Struct
        )

        for _, structField := range scope.GetModelStruct().StructFields {
            if isStruct {
                fieldValue := indirectScopeValue
                for _, name := range structField.Names {
                    if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() {
                        fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
                    }
                    fieldValue = reflect.Indirect(fieldValue).FieldByName(name)
                }
                fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)})
            } else {
                fields = append(fields, &Field{StructField: structField, IsBlank: true})
            }
        }
        scope.fields = &fields
    }

    return *scope.fields
}

GetModelStruct是一个超长长长的func(近500行代码),看着头皮发麻,主要是ModelStruct(声明数据结构的struct)的解析处理。好消息是,如果你比较数据gorm的model规则,这部分不需要具体到每一行去看,着重点关注下面几行代码即可。

func (scope *Scope) GetModelStruct() *ModelStruct {
    var modelStruct ModelStruct
    // Scope value can't be nil
    if scope.Value == nil {
        return &modelStruct
    }

    reflectType := reflect.ValueOf(scope.Value).Type()
    for reflectType.Kind() == reflect.Slice || reflectType.Kind() == reflect.Ptr {
        reflectType = reflectType.Elem()
    }

    // Scope value need to be a struct
    if reflectType.Kind() != reflect.Struct {
        return &modelStruct
    }
    ...
            // Even it is ignored, also possible to decode db value into the field
            if value, ok := field.TagSettingsGet("COLUMN"); ok {
                field.DBName = value
            } else {
                field.DBName = ToColumnName(fieldStruct.Name)
            }

            modelStruct.StructFields = append(modelStruct.StructFields, field)
        }
    ...
    return &modelStruct
}

前面是对接收数据类型的检查,要求子元素必须是struct或其指针类型,否则返回空的ModelStruct。因此,Scan支持的数据类型仅为struct及struct slice以及它们的指针类型。如此,回答了问题1

最后几行的代码意思是:如果指定对应column的列名,则使用指定的列名,否则使用默认规则主动将key转换成对应的列名。

再回过头来看Fields,主要是获取struct(或其指针类型)的fields并完成fieldValue的封装。

scan

scan是具体将数据存入对应fields的过程。

func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) {
    var (
        ignored            interface{}//默认value
        values             = make([]interface{}, len(columns))//存储接收数据的指针类型
        selectFields       []*Field//存储未匹配的接收fileds
        selectedColumnsMap = map[string]int{}//已匹配的到的列
        resetFields        = map[int]*Field{}//需要将数据转换为非指针类型的fields
    )
    // 根据查询数据的列名循环处理
    for index, column := range columns {
        values[index] = &ignored

        // rows.Scan要求所有接收数据的类型为指针类型,因此需要将selectFields转换为指针类型,再接收数据

        selectFields = fields//接收数据fields
        offset := 0
        if idx, ok := selectedColumnsMap[column]; ok {//已完成接收的fields移除
            offset = idx + 1
            selectFields = selectFields[offset:]
        }

        for fieldIndex, field := range selectFields {//循环处理剩余的fields
            if field.DBName == column {//比对查询数据的列名与接收数据的列名,一致则处理数据
                if field.Field.Kind() == reflect.Ptr {//指针类型的处理,直接取指针存入
                    values[index] = field.Field.Addr().Interface()
                } else {// 非指针类型,需要先存指针用以接收数据,后续需要重置为非指针类型
                    reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type))
                    reflectValue.Elem().Set(field.Field.Addr())
                    values[index] = reflectValue.Interface()
                    resetFields[index] = field//需要接收数据后处理
                }

                selectedColumnsMap[column] = offset + fieldIndex //记录已匹配的列

                if field.IsNormal {
                    break
                }
            }
        }
    }

    scope.Err(rows.Scan(values...))//接收数据,rows.Scan要求所有接收数据的类型为指针类型

    for index, field := range resetFields {//非指针类型需要将接收到数据类型转换
        if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() {
            field.Field.Set(v)
        }
    }
}

// Scan copies the columns in the current row into the values pointed
// at by dest. The number of values in dest must be the same as the
// number of columns in Rows.
//
// Scan converts columns read from the database into the following
// common Go types and special types provided by the sql package:
//
//    *string
//    *[]byte
//    *int, *int8, *int16, *int32, *int64
//    *uint, *uint8, *uint16, *uint32, *uint64
//    *bool
//    *float32, *float64
//    *interface{}
//    *RawBytes
//    *Rows (cursor value)
//    any type implementing Scanner (see Scanner docs)
//
// In the most simple case, if the type of the value from the source
// column is an integer, bool or string type T and dest is of type *T,
// Scan simply assigns the value through the pointer.
//
// Scan also converts between string and numeric types, as long as no
// information would be lost. While Scan stringifies all numbers
// scanned from numeric database columns into *string, scans into
// numeric types are checked for overflow. For example, a float64 with
// value 300 or a string with value "300" can scan into a uint16, but
// not into a uint8, though float64(255) or "255" can scan into a
// uint8. One exception is that scans of some float64 numbers to
// strings may lose information when stringifying. In general, scan
// floating point columns into *float64.
//
// If a dest argument has type *[]byte, Scan saves in that argument a
// copy of the corresponding data. The copy is owned by the caller and
// can be modified and held indefinitely. The copy can be avoided by
// using an argument of type *RawBytes instead; see the documentation
// for RawBytes for restrictions on its use.
//
// If an argument has type *interface{}, Scan copies the value
// provided by the underlying driver without conversion. When scanning
// from a source value of type []byte to *interface{}, a copy of the
// slice is made and the caller owns the result.
//
// Source values of type time.Time may be scanned into values of type
// *time.Time, *interface{}, *string, or *[]byte. When converting to
// the latter two, time.RFC3339Nano is used.
//
// Source values of type bool may be scanned into types *bool,
// *interface{}, *string, *[]byte, or *RawBytes.
//
// For scanning into *bool, the source may be true, false, 1, 0, or
// string inputs parseable by strconv.ParseBool.
//
// Scan can also convert a cursor returned from a query, such as
// "select cursor(select * from my_table) from dual", into a
// *Rows value that can itself be scanned from. The parent
// select query will close any cursor *Rows if the parent *Rows is closed.
func (rs *Rows) Scan(dest ...interface{}) error {
    rs.closemu.RLock()

    if rs.lasterr != nil && rs.lasterr != io.EOF {
        rs.closemu.RUnlock()
        return rs.lasterr
    }
    if rs.closed {
        err := rs.lasterrOrErrLocked(errRowsClosed)
        rs.closemu.RUnlock()
        return err
    }
    rs.closemu.RUnlock()

    if rs.lastcols == nil {
        return errors.New("sql: Scan called without calling Next")
    }
    if len(dest) != len(rs.lastcols) {
        return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
    }
    for i, sv := range rs.lastcols {
        err := convertAssignRows(dest[i], sv, rs)
        if err != nil {
            return fmt.Errorf(`sql: Scan error on column index %d, name %q: %v`, i, rs.rowsi.Columns()[i], err)
        }
    }
    return nil
}

scan的大致处理过程:

  1. 根据查询数据列名columns循环
  2. 根据接收数据的fileds循环
  3. 比对fields中的列名field与columns中列名column,
  4. 若一致,确认field的类型,如果是指针类型,则直接取指针存入values中;否则,创建指针存入values,再记录到reset中,方便后续处理。
  5. 调用sql.Scan将数据赋值到对应的values中
  6. 对于非指针类型的values,更新其值为指针指向的值

scan中关于查询与接收数据的匹配是根据列名进行匹配,而列名是根据其struct的model规则指定的,因此为保证数据能准确的Scan到,则要求接收数据的列名必须与查询数据结构的列名对应。此处回答了问题2

结合Fields中的非struct类型,values为空,将不会接收到任何数据。

总结

gorm的Scan支持接收的数据类型是struct、struct slice以及它们的指针类型(A、[]A、[]*A、*A、*[]A、*[]*A),鉴于是接收数据作其他处理,实际使用的都是指针类型。

需要注意的是:使用其他类型的slice并不会报错,但是接收不到任何数据。

gorm的Scan是根据列名进行数据匹配的,而列名是通过struct指定或自动转换的,这就要求接收数据的与查询数据的最终列名必须一致才能正常匹配,尤其是需要自定义新名称时,就需要添加gorm:"column:col_name"的tag才行。

公众号

鄙人刚刚开通了公众号,专注于分享Go开发相关内容,望大家感兴趣的支持一下,在此特别感谢。

  • 10
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 17
    评论
好的,你想了解在 Gin 框架中如何使用 GORM 吗? 首先,你需要在你的项目中导入 Gin 和 GORM 的依赖。在导入之后,你可以在你的代码中创建一个 GORM 的数据库实例,然后使用它来执行数据库的操作。 以下是一个示例代码片段,展示了如何在 Gin 中使用 GORM: ```go import ( "github.com/gin-gonic/gin" "gorm.io/driver/mysql" "gorm.io/gorm" ) func main() { // 创建数据库连接 dsn := "user:password@tcp(127.0.0.1:3306)/database_name?charset=utf8mb4&parseTime=True&loc=Local" db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) if err != nil { panic("failed to connect database") } // 将数据库实例作为 Gin 的中间件 r := gin.Default() r.Use(func(c *gin.Context) { c.Set("db", db) c.Next() }) // 在路由处理函数中使用数据库 r.GET("/users", func(c *gin.Context) { var users []User result := db.Find(&users) if result.Error != nil { c.JSON(500, gin.H{"error": "failed to get users"}) return } c.JSON(200, users) }) // 启动 Gin 服务器 r.Run(":8080") } ``` 在上面的代码中,我们创建了一个 GORM 的数据库实例,并将其作为 Gin 的中间件,使得在路由处理函数中可以直接使用该数据库实例来进行数据库的操作。在示例代码中,我们定义了一个 `/users` 的路由,当接收到 GET 请求时,会使用 GORM 查询数据库中的所有用户,并将结果以 JSON 格式返回给客户端。 希望这个例子可以帮助你理解在 Gin 中如何使用 GORM

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值