上一节,学习了如何实现一个简单的转账事务,但是,我们还没做更新账户余额的操作,因为,它稍复杂一些,需要小心处理并发事务以避免死锁。
本节,将实现这个功能,顺便学习一下数据库锁,以及如何调试死锁的情况。(有点硬核,需要耐心学习,最好自己手动操作一遍,以便深入理解)
测试驱动开发(TDD)
这次,将使用一种不同的实现方式,即测试驱动开发(TDD
)。
思路是:首先编写测试,然后逐渐改进功能代码直到测试通过。
接着上一节的store_test.go
,需要TODO
的地方,检查更新后的账户余额
// 3. 为`ToAccount`创建一个账目记录,金额为`+10`
result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
AccountID: arg.ToAccountID,
Amount: arg.Amount,
})
if err != nil {
return err
}
// TODO: 更新账户余额操作后面再做
要完成这个单元测试,需要先检查钱是从哪转出来的,然后,检查钱转到哪个账户里面去了。
// 首先,检查钱是从哪转出来的,
fromAccount := result.FromAccount
require.NotEmpty(t, fromAccount)
require.Equal(t, account1.ID, fromAccount.ID)
// 然后,检查钱转到哪个账户里面去了
toAccount := result.ToAccount
require.NotEmpty(t, toAccount)
require.Equal(t, account2.ID, toAccount.ID)
接着,计算一下,转出的金额 = account1.Balance - fromAccount.Balance
,转入的金额 = toAccount.Balance - account2.Balance
,需要检查以下几项:
- 转出的金额应该和转入的金额相等
- 并且转出的金额必须大于0
- 转出的金额应该可以被每笔交易的金额整除。因为,如果产生了5次转账,这里设置的是每次转出10,总共转出了50,所以,50除以10是可以整除的。
- 转出总金额除以每次转出的金额等于转出次数
k
,k
必须是大于等于1
,并且小于等于n
的 - 此外,每笔交易的 k 必须是唯一的,意思是第1笔交易时,k应该等于1,第2笔交易时,k应该等于2…
// 转账方:转出的金额
diff1 := account1.Balance - fromAccount.Balance
// 收钱方:转入的金额
diff2 := toAccount.Balance - account2.Balance
// 这两个值应该相同
require.Equal(t, diff1, diff2)
// 转出来的钱应该大于0
require.True(t, diff1 > 0)
// 转出的金额应该可以被每笔交易的金额整除
require.True(t, diff1%amount == 0)
最后,在最后在 for 循环外面,检查这两个账户的最终余额:
// 最后在 for 循环外面,检查这两个账户的最终余额
updateAccount1, err := testQueries.GetAccount(context.Background(), account1.ID)
require.NoError(t, err)
updateAccount2, err := testQueries.GetAccount(context.Background(), account2.ID)
require.NoError(t, err)
fmt.Println(">> after:", updateAccount1.Balance, updateAccount2.Balance)
// account1的余额减去转账次数乘以每次转账的金额,必须等于最终的余额
require.Equal(t, account1.Balance-int64(n)*amount, updateAccount1.Balance)
require.Equal(t, account2.Balance+int64(n)*amount, updateAccount2.Balance)
完整的代码如下:
package db
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestTransferTx(t *testing.T) {
store := NewStore(testDB)
account1 := createRandomAccount(t)
account2 := createRandomAccount(t)
fmt.Println(">> before:", account1.Balance, account2.Balance)
n := 5
amount := int64(10)
errs := make(chan error)
results := make(chan TransferTxResult)
for i := 0; i < n; i++ {
go func() {
result, err := store.TransferTx(context.Background(), TransferTxParams{
FromAccountID: account1.ID,
ToAccountID: account2.ID,
Amount: amount,
})
errs <- err
results <- result
}()
}
// 检查结果
existed := make(map[int]bool)
for i := 0; i < n; i++ {
err := <-errs
require.NoError(t, err)
result := <-results
require.NotEmpty(t, result)
// check transfer
transfer := result.Transfer
require.NotEmpty(t, transfer)
require.Equal(t, account1.ID, transfer.FromAccountID)
require.Equal(t, account2.ID, transfer.ToAccountID)
require.Equal(t, amount, transfer.Amount)
require.NotZero(t, transfer.ID)
require.NotZero(t, transfer.CreatedAt)
_, err = store.GetTransfer(context.Background(), transfer.ID)
require.NoError(t, err)
// check entries
formEntry := result.FromEntry
require.NotEmpty(t, formEntry)
require.Equal(t, account1.ID, formEntry.AccountID)
require.Equal(t, -amount, formEntry.Amount)
require.NotZero(t, formEntry.ID)
require.NotZero(t, formEntry.CreatedAt)
_, err = store.GetEntry(context.Background(), formEntry.ID)
require.NoError(t, err)
toEntry := result.ToEntry
require.NotEmpty(t, toEntry)
require.Equal(t, account2.ID, toEntry.AccountID)
require.Equal(t, amount, toEntry.Amount)
require.NotZero(t, toEntry.ID)
require.NotZero(t, toEntry.CreatedAt)
_, err = store.GetEntry(context.Background(), toEntry.ID)
require.NoError(t, err)
// 首先,检查钱是从哪转出来的,
fromAccount := result.FromAccount
require.NotEmpty(t, fromAccount)
require.Equal(t, account1.ID, fromAccount.ID)
// 然后,检查钱转到哪个账户里面去了
toAccount := result.ToAccount
require.NotEmpty(t, toAccount)
require.Equal(t, account2.ID, toAccount.ID)
// 检查更新后的账户余额
fmt.Println(">> tx:", fromAccount.Balance, toAccount.Balance)
// 转账方:转出的金额
diff1 := account1.Balance - fromAccount.Balance
// 收钱方:转入的金额
diff2 := toAccount.Balance - account2.Balance
// 这两个值应该相同
require.Equal(t, diff1, diff2)
// 转出来的钱应该大于0
require.True(t, diff1 > 0)
// 转出的金额应该可以被每笔交易的金额整除
require.True(t, diff1%amount == 0)
// 计算 k = diff1 除以 每笔交易的金额,k 必须是大于等于1,并且小于等于n的
// 此外,每笔交易的 k 必须是唯一的,意思是第1笔交易时,k应该等于1,第2笔交易时,k应该等于2...
// 上面,需要定义一个新变量,existed
k := int(diff1 / amount)
require.True(t, k >= 1 && k <= n)
// 检查这个map,不应该包含 k
require.NotContains(t, existed, k)
// 之后,给这个 map 赋值
existed[k] = true
}
// 最后在 for 循环外面,检查这两个账户的最终余额
updateAccount1, err := testQueries.GetAccount(context.Background(), account1.ID)
require.NoError(t, err)
updateAccount2, err := testQueries.GetAccount(context.Background(), account2.ID)
require.NoError(t, err)
fmt.Println(">> after:", updateAccount1.Balance, updateAccount2.Balance)
// account1的余额减去转账次数乘以每次转账的金额,必须等于最终的余额
require.Equal(t, account1.Balance-int64(n)*amount, updateAccount1.Balance)
require.Equal(t, account2.Balance+int64(n)*amount, updateAccount2.Balance)
}
运行,run test
,可以看到如下错误,
这是因为具体的功能实现还没写,好,接着打开store.go
文件,来编写未实现的功能。
更新账户的余额(错误的方法)
首先,从数据库中获取该account
,然后从余额中增加或减去一些金额,然后更新到数据库。但是,如果没有适当的锁机制,这通常会出现错误。像这样:
// 从数据库中获取 account -> 更新账户余额
account1, err := q.GetAccount(ctx, arg.FromAccountID)
if err != nil {
return err
}
result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.FromAccountID,
Balance: account1.Balance - arg.Amount,
})
if err != nil {
return err
}
account2, err := q.GetAccount(ctx, arg.ToAccountID)
if err != nil {
return err
}
result.ToAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.ToAccountID,
Balance: account2.Balance + arg.Amount,
})
if err != nil {
return err
}
运行一下单元测试run test
,来看一下,发生了错误:
在日志中,可以看到,前两笔转账交易是正确的,到第3笔转账,出现了问题,account1
账户的余额并没有减10
,仍然是166
。
要了解具体的原因,让我们看一下GetAccount
的SQL
语句:
-- name: GetAccount :one
SELECT * FROM accounts
WHERE id = $1 LIMIT 1;
它就是普通的SQL查询语句,它不会阻止来自其他转账交易同时读取相同账户的数据。所以,2个并发交易可以得到account1
相同的值,导致并发的情况出错。
为了演示一下,这个是怎么发生的,让我们打开两个终端,都进入到postgres
控制台:
docker exec -it postgres14 psql -U root -d simple_bank
不加锁查询的情况
让我们在两个不同的终端运行2个并行事务。
- 在第一个控制台输入
BEGIN;
开启第一个事务 - 在第二个控制台输入
BEGIN;
开启第二个事务 - 在第一个控制台输入
select * from accounts where id=1;
- 在第二个控制台也输入
select * from accounts where id=1;
可以看到,相同账户的数据在两个事务里,会立即返回出来,并不会被阻止。因此,我们先在这两个控制台里输入ROLLBACK;
回滚这两个事务,并学习一下如何修复它。
加锁查询的情况
- 在第一个控制台输入
BEGIN;
开启第一个事务 - 在第二个控制台输入
BEGIN;
开启第二个事务 - 在第一个控制台输入
select * from accounts where id=1 for update;
- 在第二个控制台输入
select * from accounts where id=1 for update;
,可以看到,它不返回结果了,被阻塞住了,必须等待第一个事务提交或回滚。 - 在第一个控制台输入
update accounts set balance=500 where id=1;
,切换到第二个控制台,看到它仍然阻塞 - 在第一个控制台输入
COMMIT;
,切换到第二个控制台,可以看到返回出来结果了,balance
是500
加锁更新账户余额
知道问题在哪里了,我们回到account.sql
文件,再增加一条sql
语句GetAccountForUpdate
:
-- name: GetAccountForUpdate :one
SELECT * FROM accounts
WHERE id = $1 LIMIT 1
FOR UPDATE;
然后,在项目终端下运行make sqlc
来重新生成代码,然后,可以看到account.sql.go
文件里新增加了个GetAccountForUpdate
函数。
在store.go
里面就可以用到它了,把更新账户余额的操作修改一下:
// 从数据库中获取 account -> 更新账户余额
account1, err := q.GetAccountForUpdate(ctx, arg.FromAccountID)
if err != nil {
return err
}
result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.FromAccountID,
Balance: account1.Balance - arg.Amount,
})
if err != nil {
return err
}
account2, err := q.GetAccountForUpdate(ctx, arg.ToAccountID)
if err != nil {
return err
}
result.ToAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.ToAccountID,
Balance: account2.Balance + arg.Amount,
})
if err != nil {
return err
}
再次,运行单元测试run test
,不幸的是,它又出错了。
这次报的错是:检测到死锁了 deadlock detected
,所以,怎么解决呢?接下来,学习如何调试这种死锁的情况。
调试死锁
为了弄清楚为什么会发生死锁(deadlock
),我们需要打印一些日志,查看一下哪个事务正在调用哪个查询sql
,以及调用顺序是怎么样的。
所以,我们必须给每个事务分配个名称,并通过上下文参数传递给 TransferTx()
函数,在store_test.go
里面的for i := 0; i < n; i++ {
内,增加
txName := fmt.Sprintf("tx %d", i+1)
为了把这个txName
通过context
传递,需要在store.go
里面定义一个空的结构体
var txKey = struct{}{}
这里的第二个{}
表示空的对象,回到store_test.go
改造一下,把txKey
和txName
作为参数传给context.WithValue()
:
for i := 0; i < n; i++ {
txName := fmt.Sprintf("tx %d", i+1)
go func() {
ctx := context.WithValue(context.Background(), txKey, txName)
result, err := store.TransferTx(ctx, TransferTxParams{
FromAccountID: account1.ID,
ToAccountID: account2.ID,
Amount: amount,
})
errs <- err
results <- result
}()
}
这样,上下文就可以保存事务名称了。
来到store.go
里面的TransferTx
函数,增加日志输出:
var txKey = struct{}{}
func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
var result TransferTxResult
err := store.execTx(ctx, func(q *Queries) error {
var err error
txName := ctx.Value(txKey)
// 1. 创建一个金额等于`10`的转账记录
fmt.Println(txName, "create transfer")
result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
FromAccountID: arg.FromAccountID,
ToAccountID: arg.ToAccountID,
Amount: arg.Amount,
})
if err != nil {
return err
}
// 2. 为`FromAccount`创建一个账目记录,金额为`-10`
fmt.Println(txName, "create entry 1")
result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
AccountID: arg.FromAccountID,
Amount: -arg.Amount,
})
if err != nil {
return err
}
// 3. 为`ToAccount`创建一个账目记录,金额为`+10`
fmt.Println(txName, "create entry 2")
result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
AccountID: arg.ToAccountID,
Amount: arg.Amount,
})
if err != nil {
return err
}
// 从数据库中获取 account -> 更新账户余额
fmt.Println(txName, "get account 1")
account1, err := q.GetAccountForUpdate(ctx, arg.FromAccountID)
if err != nil {
return err
}
fmt.Println(txName, "update account 1")
result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.FromAccountID,
Balance: account1.Balance - arg.Amount,
})
if err != nil {
return err
}
fmt.Println(txName, "get account 2")
account2, err := q.GetAccountForUpdate(ctx, arg.ToAccountID)
if err != nil {
return err
}
fmt.Println(txName, "update account 2")
result.ToAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.ToAccountID,
Balance: account2.Balance + arg.Amount,
})
if err != nil {
return err
}
return err
})
return result, err
}
为了便于调试,把store_test.go
里面的并发数量n
改成3,便于跟踪调试,运行它run test
,可以看到保存了,并输出了我们打印的日志:
=== RUN TestTransferTx
>> before: 103 980
tx 3 create transfer
tx 3 create entry 1
tx 3 create entry 2
tx 2 create transfer
tx 1 create transfer
tx 3 get account 1
tx 3 update account 1
tx 3 get account 2
tx 3 update account 2
tx 2 create entry 1
tx 1 create entry 1
tx 1 create entry 2
tx 2 create entry 2
tx 1 get account 1
tx 2 get account 1
>> tx: 93 990
tx 2 update account 1
到这里出错了,现在就要跟踪一下,是怎么出这个错误的,我们来用go
执行的过程模拟一下,先把go
执行转账事务所需要的sql
整理出来,就是下面这个流程:
BEGIN;
-- create transfer
INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES (1, 2, 10) RETURNING *;
-- create entry 1
INSERT INTO entries (account_id, amount) VALUES (1, -10) RETURNING *;
-- create entry 2
INSERT INTO entries (account_id, amount) VALUES (2, 10) RETURNING *;
-- get account 1
SELECT * FROM accounts WHERE id = 1 FOR UPDATE;
-- update account 1
UPDATE accounts SET balance = 90 WHERE id = 1 RETURNING *;
-- get account 2
SELECT * FROM accounts WHERE id = 2 FOR UPDATE;
-- update account 2
UPDATE accounts SET balance = 110 WHERE id = 2 RETURNING *;
因为我们设置了3个并发,因此,想上面操作的那样,我们打开3个postgres
控制台:
接下来,按照日志的执行顺序,分别依次在不同的postgres
控制台里输入这些sql
- 首先这3个控制台都输入
BEGIN;
开启事务 - 日志
tx 3 create transfer
,就在第3
个控制台输入,INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES (1, 2, 10) RETURNING *;
- 日志
tx 3 create entry 1
,就在第3
个控制台输入,INSERT INTO entries (account_id, amount) VALUES (1, -10) RETURNING *;
- 依次类推…
- 直到这一条,
tx 3 get account 1
,在第3个控制台执行SELECT * FROM accounts WHERE id = 1 FOR UPDATE;
,发现卡住了,为什么INSERT transfers
的时候,会引发SELECT accounts
操作死锁,是不同的表啊;带着问题,我们Google搜postgres lock
,在postgres
的WIKI
里(https://wiki.postgresql.org/wiki/Lock_Monitoring),看到了这段:
SELECT blocked_locks.pid AS blocked_pid,
blocked_activity.usename AS blocked_user,
blocking_locks.pid AS blocking_pid,
blocking_activity.usename AS blocking_user,
blocked_activity.query AS blocked_statement,
blocking_activity.query AS current_statement_in_blocking_process
FROM pg_catalog.pg_locks blocked_locks
JOIN pg_catalog.pg_stat_activity blocked_activity ON blocked_activity.pid = blocked_locks.pid
JOIN pg_catalog.pg_locks blocking_locks
ON blocking_locks.locktype = blocked_locks.locktype
AND blocking_locks.database IS NOT DISTINCT FROM blocked_locks.database
AND blocking_locks.relation IS NOT DISTINCT FROM blocked_locks.relation
AND blocking_locks.page IS NOT DISTINCT FROM blocked_locks.page
AND blocking_locks.tuple IS NOT DISTINCT FROM blocked_locks.tuple
AND blocking_locks.virtualxid IS NOT DISTINCT FROM blocked_locks.virtualxid
AND blocking_locks.transactionid IS NOT DISTINCT FROM blocked_locks.transactionid
AND blocking_locks.classid IS NOT DISTINCT FROM blocked_locks.classid
AND blocking_locks.objid IS NOT DISTINCT FROM blocked_locks.objid
AND blocking_locks.objsubid IS NOT DISTINCT FROM blocked_locks.objsubid
AND blocking_locks.pid != blocked_locks.pid
JOIN pg_catalog.pg_stat_activity blocking_activity ON blocking_activity.pid = blocking_locks.pid
WHERE NOT blocked_locks.granted;
-
把这段
sql
放到navicat
里执行一下,可以看到:
发现,blocked
的sql就是SELECT * FROM accounts WHERE id = 1 FOR UPDATE;
,而引发blocking
的sql是INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES (1, 2, 10) RETURNING *;
,因此,对这两个不同的表操作,确实会引发相互阻塞。 -
让我们更深入的了解一下为什么
SELECT
查询必须等待INSERT
操作,回到postgres WIKI
页面,往下继续看,找到1条,找出数据库中所有的锁
SELECT a.datname,
l.relation::regclass,
l.transactionid,
l.mode,
l.GRANTED,
a.usename,
a.query,
a.query_start,
age(now(), a.query_start) AS "age",
a.pid
FROM pg_stat_activity a
JOIN pg_locks l ON l.pid = a.pid
ORDER BY a.query_start;
把它复制到navicat
中,并稍作修改,因为要查看一些更多的信息,在a.datname,
后面增加a.application_name,
,用来查看锁来自哪个应用。
这个sql中,
l.relation::regclass
其实就是表名,l.transactionid
是事务ID,- 在
l.mode
后面增加l.locktype
用来查看锁的类型, l.GRANTED
是锁是否被授权,a.usename
是运行sql
的用户名,a.query
是持有或试图获取锁的sql
语句,a.query_start
和age(now(), a.query_start) AS "age"
在这个场景中部重要,删除它,a.pid
是进程IDORDER BY a.query_start
改为ORDER BY a.pid
更好一下,因为我们在psql
控制台有3个不同的处理进程,更容易看出来哪个锁是属于哪个事务,- 增加
WHERE a.application_name = 'psql'
只关注和psql
控制台相关的 - 如下。
SELECT a.application_name,
l.relation::regclass,
l.transactionid,
l.mode,
l.locktype,
l.GRANTED,
a.usename,
a.query,
a.pid
FROM pg_stat_activity a
JOIN pg_locks l ON l.pid = a.pid
WHERE a.application_name = 'psql'
ORDER BY a.pid;
在,navicat
里运行一下,
看到只有一条,GRANTED
为f
的记录,可以看到它的transactionid
是1075
,可以看到transactionid
为1075
的还有一条,它正在进行INSERT
操作,pid
为6991
,但是,为什么SELECT FROM accounts
表需要从INSERT INTO transfers
的事务中获取锁呢?
来看一下,我们的sql
建表语句,有这样的外键约束:
ALTER TABLE "transfers" ADD FOREIGN KEY ("from_account_id") REFERENCES "accounts" ("id");
所以,对transfers
表的from_account_id
做任何更新操作时,都会影响到accounts
表做SELECT FOR UPDATE
的操作,因为,这里需要获取一个锁来防止冲突,保证数据的一致性。这样,就解释了deadlock
是如何发生的。如何解决这个问题呢?
让我们先把这3个控制台执行ROLLBACK;
并\q
退出。
修复死锁方案1
我们知道,死锁是由外键约束引起的,那我们尝试把这几个外键约束的sql
都删除掉,在000001_init_schema.up.sql
文件里,先注释掉下面这几行:
-- ALTER TABLE "entries" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id");
-- ALTER TABLE "transfers" ADD FOREIGN KEY ("from_account_id") REFERENCES "accounts" ("id");
-- ALTER TABLE "transfers" ADD FOREIGN KEY ("to_account_id") REFERENCES "accounts" ("id");
让我们把数据库降级一下make migratedown
,再make migrateup
,之后,再来运行这个单元测试run test
,可以看到成功了。
但是,这不是最好的解决方案,因为,这样失去了保持数据一致性的良好约束。所以,恢复刚才注释掉的那几行sql
代码。并运行make migratedown
和make migrateup
。
修复死锁方案2
打开account.sql
文件,看这里
-- name: UpdateAccount :one
UPDATE accounts SET balance = $2 WHERE id = $1
RETURNING *;
更新账户余额的时候,id是永远不会更改的,因为它是账户表的主键。所以,如果我们能告诉postgres
,SELECT * FROM accounts FOR UPDATE
时不更新主键,那么postgres
将不需要获取事务锁,因此,就不会出现deadlock
。
恩,有个知识点,FOR UPDATE
的时候,我们只需要明确一下,NO KEY UPDATE
,说明我们不更新主键ID,像这样。
-- name: GetAccountForUpdate :one
SELECT * FROM accounts
WHERE id = $1 LIMIT 1
FOR NO KEY UPDATE;
在项目终端里,在运行make sqlc
重新生成代码,再运行这个单元测试,可以看到成功了。
之后,清理一下代码,把我们之前加的打印日志都去掉吧,再运行一下单元测试,没问题,通过。
可以,看到这个日志,每次交易后的2个账户余额是怎么变化的。
修复死锁方案3
有一种更好的方法来实现更新余额操作。目前,我们必须执行两个sql
来获取并更新账户余额:
account1, err := q.GetAccountForUpdate(ctx, arg.FromAccountID)
if err != nil {
return err
}
result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
ID: arg.FromAccountID,
Balance: account1.Balance - arg.Amount,
})
if err != nil {
return err
}
其实,可以通过1个sql
语句来搞定。所以,我们在account.sql
里再增加一个AddAccountBalance
语句,它和UpdateAccount
类似,不同的是设置 balance
等于 balance
+ 第二个参数,如下:
-- name: AddAccountBalance :one
UPDATE accounts SET balance = balance + $2 WHERE id = $1
RETURNING *;
之后,运行make sqlc
来生成代码,看到account.sql.go
里面增加了如下代码:
type AddAccountBalanceParams struct {
ID int64 `json:"id"`
Balance int64 `json:"balance"`
}
但是,这个参数Balance
有点混淆,我们是让余额增加或减少金额,而不是更新Balance
本身,在sqlc
可以这样做:
-- name: AddAccountBalance :one
UPDATE accounts SET balance = balance + sqlc.arg(amount)
WHERE id = sqlc.arg(id)
RETURNING *;
把$2
改成sqlc.arg(amount)
,$1
改成sqlc.arg(id)
,之后,在运行make sqlc
,可以看到生成的代码变成了这样:
type AddAccountBalanceParams struct {
Amount int64 `json:"amount"`
ID int64 `json:"id"`
}
Balance
变成了Amount
,这样语义就清晰了。
回到store.go
文件,删除GetAccountForUpdate
,把UpdateAccount
改为AddAccountBalance
,完整代码如下:
package db
import (
"context"
"database/sql"
"fmt"
)
type Store struct {
*Queries
db *sql.DB
}
func NewStore(db *sql.DB) *Store {
return &Store{
db: db,
Queries: New(db),
}
}
func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
tx, err := store.db.BeginTx(ctx, nil)
if err != nil {
return err
}
q := New(tx)
err = fn(q)
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
}
return err
}
return tx.Commit()
}
type TransferTxParams struct {
FromAccountID int64 `json:"from_account_id"`
ToAccountID int64 `json:"to_account_id"`
Amount int64 `json:"amount"`
}
type TransferTxResult struct {
Transfer Transfer `json:"transfer"`
FromAccount Account `json:"from_account"`
ToAccount Account `json:"to_account"`
FromEntry Entry `json:"from_entry"`
ToEntry Entry `json:"to_entry"`
}
func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
var result TransferTxResult
err := store.execTx(ctx, func(q *Queries) error {
var err error
// 1. 创建一个金额等于`10`的转账记录
result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
FromAccountID: arg.FromAccountID,
ToAccountID: arg.ToAccountID,
Amount: arg.Amount,
})
if err != nil {
return err
}
// 2. 为`FromAccount`创建一个账目记录,金额为`-10`
result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
AccountID: arg.FromAccountID,
Amount: -arg.Amount,
})
if err != nil {
return err
}
// 3. 为`ToAccount`创建一个账目记录,金额为`+10`
result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
AccountID: arg.ToAccountID,
Amount: arg.Amount,
})
if err != nil {
return err
}
// 更新账户余额
result.FromAccount, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
ID: arg.FromAccountID,
Amount: -arg.Amount,
})
if err != nil {
return err
}
result.ToAccount, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
ID: arg.ToAccountID,
Amount: arg.Amount,
})
if err != nil {
return err
}
return err
})
return result, err
}
再次运行单元测试,没问题,测试通过。
好了,本次课程学习结束。下节,将继续学习如何避免数据库事务查询中的死锁顺序问题