问题描述
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{Logger: newLogger})
if err != nil {
panic(err)
}
db = db.Set("gorm:table_options", "ENGINE=InnoDB DEFAULT CHARSET=utf8")
if err := db.AutoMigrate(&SQLRecord{}, &DistributeLock{}); err != nil {
fmt.Println(err)
}
r := &SQLRecord{} // Table name: sql_record
if err := db.Take(&r).Error; err != nil {
fmt.Println(err)
}
fmt.Printf("r: %+v\n", r)
var l DistributeLock // Table name: distribute_lock
if err := db.Model(&DistributeLock{}).Take(&l); err != nil {
fmt.Println(err)
}
fmt.Printf("l: %+v", l)
这段代码先建立了一条数据库连接,然后使用 AutoMigrate 建表,接下来分别查了 sql_record 表和 distribute_lock 表并打印查询结果。
那么我们遇到的问题是,查询第一张表的时候 gorm 生成的查询语句为
SELECT * FROM `t_sql_record` LIMIT 1
可以看出来是没问题的,但是查询第二张表的时候,gorm 生成的查询语句为
SELECT * FROM `t_sql_record` LIMIT 1
可以看出来生成的查询语句中要查询的表还是上一次的表,导致结果出错。
问题原因
背景知识
Gorm 中包括三种方法,链式方法,Finisher 方法和新建会话方法。
链式方法是将 Clauses 修改或添加到当前 Statement 的方法。常见的有 Where,Select,Joins 等。
Finishers 是会立即执行注册回调的方法,然后生成并执行 SQL。常见的有 Create,First,Find,Save 等。
在初始化了 *gorm.DB 或 新建会话方法 后, 调用新建会话方法会创建一个新的 Statement 实例而不是使用当前的。常见的有 Session,WithContext 等。
下面是 DB 的结构体,里面有一个 clone 参数,这个参数来判断当前 DB 是否可被 clone。
// DB GORM DB definition
type DB struct {
*Config
Error error
RowsAffected int64
Statement *Statement
clone int
}
Open() 之后会默认返回一个 clone=1 的 DB,我们称之为初始 DB,这个初始 DB 是可以被 clone 的。
链式方法会根据当前 DB 是否可以进行 clone 来决定是返回原 DB 还是返回 clone 的新 DB。之后的 Where,Select 等操作都是操作这个 DB 的 statement。而且 clone 出来的 DB 不允许再次被 clone,但是初始 DB 可以被 clone 多次。
原因
其实 Set() 方法也算是链式方法的一种,这个函数返回的是 clone 初始 DB 出来的 DB。在下面操作 t_sql_record 表的时候因为 DB 已经被 clone 过一次,所以之后的 Take() 方法修改 statement 都会在这个 DB 上修改。当想操作 distribute_lock 表的时候用的是上一张表的 statement,导致出错。
正常的逻辑是什么样的呢
我们正常调用 gorm 一般都会这么写,以此为例
db.Model().Where().Find()
Model 是一个链式方法,会返回一个 clone 的 DB,之后的 Where(),Find() 方法都会在这个新 DB上修改,当经过 Finisher 方法后(查到值后)这个新 DB 会被销毁,而不会对原有 DB 有任何影响,因为操作的都是 clone 后的 DB。
再结合代码分析下错误原因
当我们调用 Set() 方法后就得到了一个 clone 后的 DB,该 DB 不可被再次 clone。我们将这个 DB 赋值给了原 DB。
db = db.Set("gorm:table_options", "ENGINE=InnoDB DEFAULT CHARSET=utf8")
所以当查询第一张表的时候,因为 DB 已经被 clone 过了一次,不可以再次被 clone,所以 Take() 方法没有返回一个新的 DB,而是继续使用了原有的 DB,在此基础上修改了 DB 的 statement,将 Table 字段改为了当前需要操作的表名,即 sql_record。
if err := db.Take(&r).Error; err != nil {
fmt.Println(err)
}
当我们操作下一张表的时候,Model() 方法也不会返回新的 DB,而是继续使用原有 DB,但是原有 DB 的 statement 的 Table 字段已经被修改为 sql_record 了,所以就会生成错误的查询语句,导致错误的结果。
if err := db.Model(&DistributeLock{}).Take(&l); err != nil {
fmt.Println(err)
}
源码分析
我们可以看到,Set() 方法调用了 db.getInstance()。
// Set store value with key into current db instance's context
func (db *DB) Set(key string, value interface{}) *DB {
tx := db.getInstance()
tx.Statement.Settings.Store(key, value)
return tx
}
getInstance() 方法可以看出来是根据 DB 里面的 clone 参数来判断当前 DB 是否可以被 clone,如果可以 clone 则返回一个新的 tx 回去,该 tx 的 statement 是一个新的 statement,并且 clone 参数默认为 0,即不可再次 clone。如果不可以被 clone,则返回当前 DB。
func (db *DB) getInstance() *DB {
if db.clone > 0 {
tx := &DB{Config: db.Config, Error: db.Error}
if db.clone == 1 {
// clone with new statement
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
}
} else {
// with clone statement
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
}
return tx
}
return db
}
像 Take() 方法也会调用到 getInstance() 方法
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1)
// do something
}
func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: &limit})
return
}
Model() 方法也一样
func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Model = value
return
}
解决方案
知道了原因,那么就有两种方案解决。
方案一
既然 Set() 方法是一个链式方法,返回的是 clone 后的 DB,那么就按照链式调用的方式创建表就好了。
if err := db.Set("gorm:table_options", "ENGINE=InnoDB DEFAULT CHARSET=utf8").AutoMigrate(&SQLRecord{}, &DistributeLock{}); err != nil {
fmt.Println(err)
}
方案二
既然 gorm 提供了新建会话方法,那我们每次使用的时候新建一个会话就好了。
if err := db.WithContext(context.Background()).Take(&r).Error; err != nil {
fmt.Println(err)
}
if err := db.WithContext(context.Background()).Model(&DistributeLock{}).Take(&l); err != nil {
fmt.Println(err)
}