Golang后端学习笔记 — 7. Golang如何处理数据库事务锁

上一节,学习了如何实现一个简单的转账事务,但是,我们还没做更新账户余额的操作,因为,它稍复杂一些,需要小心处理并发事务以避免死锁。

本节,将实现这个功能,顺便学习一下数据库锁,以及如何调试死锁的情况。(有点硬核,需要耐心学习,最好自己手动操作一遍,以便深入理解)

测试驱动开发(TDD)

这次,将使用一种不同的实现方式,即测试驱动开发(TDD)。
思路是:首先编写测试,然后逐渐改进功能代码直到测试通过。

接着上一节的store_test.go,需要TODO的地方,检查更新后的账户余额

		// 3. 为`ToAccount`创建一个账目记录,金额为`+10`
		result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.ToAccountID,
			Amount:    arg.Amount,
		})
		if err != nil {
			return err
		}

		// TODO: 更新账户余额操作后面再做

要完成这个单元测试,需要先检查钱是从哪转出来的,然后,检查钱转到哪个账户里面去了。

		// 首先,检查钱是从哪转出来的,
		fromAccount := result.FromAccount
		require.NotEmpty(t, fromAccount)
		require.Equal(t, account1.ID, fromAccount.ID)

		// 然后,检查钱转到哪个账户里面去了
		toAccount := result.ToAccount
		require.NotEmpty(t, toAccount)
		require.Equal(t, account2.ID, toAccount.ID)

接着,计算一下,转出的金额 = account1.Balance - fromAccount.Balance,转入的金额 = toAccount.Balance - account2.Balance,需要检查以下几项:

  • 转出的金额应该和转入的金额相等
  • 并且转出的金额必须大于0
  • 转出的金额应该可以被每笔交易的金额整除。因为,如果产生了5次转账,这里设置的是每次转出10,总共转出了50,所以,50除以10是可以整除的。
  • 转出总金额除以每次转出的金额等于转出次数kk必须是大于等于1,并且小于等于n
  • 此外,每笔交易的 k 必须是唯一的,意思是第1笔交易时,k应该等于1,第2笔交易时,k应该等于2…
		// 转账方:转出的金额
		diff1 := account1.Balance - fromAccount.Balance
		// 收钱方:转入的金额
		diff2 := toAccount.Balance - account2.Balance
		// 这两个值应该相同
		require.Equal(t, diff1, diff2)
		// 转出来的钱应该大于0
		require.True(t, diff1 > 0)
		// 转出的金额应该可以被每笔交易的金额整除
		require.True(t, diff1%amount == 0)

最后,在最后在 for 循环外面,检查这两个账户的最终余额:

	// 最后在 for 循环外面,检查这两个账户的最终余额
	updateAccount1, err := testQueries.GetAccount(context.Background(), account1.ID)
	require.NoError(t, err)

	updateAccount2, err := testQueries.GetAccount(context.Background(), account2.ID)
	require.NoError(t, err)

	fmt.Println(">> after:", updateAccount1.Balance, updateAccount2.Balance)
	// account1的余额减去转账次数乘以每次转账的金额,必须等于最终的余额
	require.Equal(t, account1.Balance-int64(n)*amount, updateAccount1.Balance)
	require.Equal(t, account2.Balance+int64(n)*amount, updateAccount2.Balance)

完整的代码如下:

package db

import (
	"context"
	"fmt"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestTransferTx(t *testing.T) {
	store := NewStore(testDB)

	account1 := createRandomAccount(t)
	account2 := createRandomAccount(t)
	fmt.Println(">> before:", account1.Balance, account2.Balance)

	n := 5
	amount := int64(10)

	errs := make(chan error)
	results := make(chan TransferTxResult)

	for i := 0; i < n; i++ {
		go func() {
			result, err := store.TransferTx(context.Background(), TransferTxParams{
				FromAccountID: account1.ID,
				ToAccountID:   account2.ID,
				Amount:        amount,
			})

			errs <- err
			results <- result
		}()
	}

	// 检查结果
	existed := make(map[int]bool)

	for i := 0; i < n; i++ {
		err := <-errs
		require.NoError(t, err)

		result := <-results
		require.NotEmpty(t, result)

		// check transfer
		transfer := result.Transfer
		require.NotEmpty(t, transfer)
		require.Equal(t, account1.ID, transfer.FromAccountID)
		require.Equal(t, account2.ID, transfer.ToAccountID)
		require.Equal(t, amount, transfer.Amount)
		require.NotZero(t, transfer.ID)
		require.NotZero(t, transfer.CreatedAt)

		_, err = store.GetTransfer(context.Background(), transfer.ID)
		require.NoError(t, err)

		// check entries
		formEntry := result.FromEntry
		require.NotEmpty(t, formEntry)
		require.Equal(t, account1.ID, formEntry.AccountID)
		require.Equal(t, -amount, formEntry.Amount)
		require.NotZero(t, formEntry.ID)
		require.NotZero(t, formEntry.CreatedAt)

		_, err = store.GetEntry(context.Background(), formEntry.ID)
		require.NoError(t, err)

		toEntry := result.ToEntry
		require.NotEmpty(t, toEntry)
		require.Equal(t, account2.ID, toEntry.AccountID)
		require.Equal(t, amount, toEntry.Amount)
		require.NotZero(t, toEntry.ID)
		require.NotZero(t, toEntry.CreatedAt)

		_, err = store.GetEntry(context.Background(), toEntry.ID)
		require.NoError(t, err)

		// 首先,检查钱是从哪转出来的,
		fromAccount := result.FromAccount
		require.NotEmpty(t, fromAccount)
		require.Equal(t, account1.ID, fromAccount.ID)

		// 然后,检查钱转到哪个账户里面去了
		toAccount := result.ToAccount
		require.NotEmpty(t, toAccount)
		require.Equal(t, account2.ID, toAccount.ID)

		// 检查更新后的账户余额
		fmt.Println(">> tx:", fromAccount.Balance, toAccount.Balance)
		// 转账方:转出的金额
		diff1 := account1.Balance - fromAccount.Balance
		// 收钱方:转入的金额
		diff2 := toAccount.Balance - account2.Balance
		// 这两个值应该相同
		require.Equal(t, diff1, diff2)
		// 转出来的钱应该大于0
		require.True(t, diff1 > 0)
		// 转出的金额应该可以被每笔交易的金额整除
		require.True(t, diff1%amount == 0)

		// 计算 k = diff1 除以 每笔交易的金额,k 必须是大于等于1,并且小于等于n的
		// 此外,每笔交易的 k 必须是唯一的,意思是第1笔交易时,k应该等于1,第2笔交易时,k应该等于2...
		// 上面,需要定义一个新变量,existed
		k := int(diff1 / amount)
		require.True(t, k >= 1 && k <= n)
		// 检查这个map,不应该包含 k
		require.NotContains(t, existed, k)
		// 之后,给这个 map 赋值
		existed[k] = true
	}
	// 最后在 for 循环外面,检查这两个账户的最终余额
	updateAccount1, err := testQueries.GetAccount(context.Background(), account1.ID)
	require.NoError(t, err)

	updateAccount2, err := testQueries.GetAccount(context.Background(), account2.ID)
	require.NoError(t, err)

	fmt.Println(">> after:", updateAccount1.Balance, updateAccount2.Balance)
	// account1的余额减去转账次数乘以每次转账的金额,必须等于最终的余额
	require.Equal(t, account1.Balance-int64(n)*amount, updateAccount1.Balance)
	require.Equal(t, account2.Balance+int64(n)*amount, updateAccount2.Balance)
}

运行,run test,可以看到如下错误,
run test 错误

这是因为具体的功能实现还没写,好,接着打开store.go文件,来编写未实现的功能。

更新账户的余额(错误的方法)

首先,从数据库中获取该account,然后从余额中增加或减去一些金额,然后更新到数据库。但是,如果没有适当的锁机制,这通常会出现错误。像这样:

		// 从数据库中获取 account -> 更新账户余额
		account1, err := q.GetAccount(ctx, arg.FromAccountID)
		if err != nil {
			return err
		}

		result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.FromAccountID,
			Balance: account1.Balance - arg.Amount,
		})
		if err != nil {
			return err
		}

		account2, err := q.GetAccount(ctx, arg.ToAccountID)
		if err != nil {
			return err
		}

		result.ToAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.ToAccountID,
			Balance: account2.Balance + arg.Amount,
		})
		if err != nil {
			return err
		}

运行一下单元测试run test,来看一下,发生了错误:
数据库事务异常
在日志中,可以看到,前两笔转账交易是正确的,到第3笔转账,出现了问题,account1账户的余额并没有减10,仍然是166

要了解具体的原因,让我们看一下GetAccountSQL语句:

-- name: GetAccount :one
SELECT * FROM accounts
WHERE id = $1 LIMIT 1;

它就是普通的SQL查询语句,它不会阻止来自其他转账交易同时读取相同账户的数据。所以,2个并发交易可以得到account1相同的值,导致并发的情况出错。

为了演示一下,这个是怎么发生的,让我们打开两个终端,都进入到postgres控制台:

docker exec -it postgres14 psql -U root -d simple_bank

进入postgres控制台

不加锁查询的情况

让我们在两个不同的终端运行2个并行事务。

  • 在第一个控制台输入BEGIN;开启第一个事务
  • 在第二个控制台输入BEGIN;开启第二个事务
  • 在第一个控制台输入select * from accounts where id=1;
  • 在第二个控制台也输入select * from accounts where id=1;

可以看到,相同账户的数据在两个事务里,会立即返回出来,并不会被阻止。因此,我们先在这两个控制台里输入ROLLBACK;回滚这两个事务,并学习一下如何修复它。

加锁查询的情况

  • 在第一个控制台输入BEGIN;开启第一个事务
  • 在第二个控制台输入BEGIN;开启第二个事务
  • 在第一个控制台输入select * from accounts where id=1 for update;
  • 在第二个控制台输入select * from accounts where id=1 for update;,可以看到,它不返回结果了,被阻塞住了,必须等待第一个事务提交或回滚。
  • 在第一个控制台输入update accounts set balance=500 where id=1;,切换到第二个控制台,看到它仍然阻塞
  • 在第一个控制台输入COMMIT;,切换到第二个控制台,可以看到返回出来结果了,balance500

加锁更新账户余额

知道问题在哪里了,我们回到account.sql文件,再增加一条sql语句GetAccountForUpdate

-- name: GetAccountForUpdate :one
SELECT * FROM accounts
WHERE id = $1 LIMIT 1
FOR UPDATE;

然后,在项目终端下运行make sqlc来重新生成代码,然后,可以看到account.sql.go文件里新增加了个GetAccountForUpdate函数。

store.go里面就可以用到它了,把更新账户余额的操作修改一下:

		// 从数据库中获取 account -> 更新账户余额
		account1, err := q.GetAccountForUpdate(ctx, arg.FromAccountID)
		if err != nil {
			return err
		}

		result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.FromAccountID,
			Balance: account1.Balance - arg.Amount,
		})
		if err != nil {
			return err
		}

		account2, err := q.GetAccountForUpdate(ctx, arg.ToAccountID)
		if err != nil {
			return err
		}

		result.ToAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.ToAccountID,
			Balance: account2.Balance + arg.Amount,
		})
		if err != nil {
			return err
		}

再次,运行单元测试run test,不幸的是,它又出错了。
deadlock detected
这次报的错是:检测到死锁了 deadlock detected,所以,怎么解决呢?接下来,学习如何调试这种死锁的情况。

调试死锁

为了弄清楚为什么会发生死锁(deadlock),我们需要打印一些日志,查看一下哪个事务正在调用哪个查询sql,以及调用顺序是怎么样的。

所以,我们必须给每个事务分配个名称,并通过上下文参数传递给 TransferTx() 函数,在store_test.go里面的for i := 0; i < n; i++ {内,增加

txName := fmt.Sprintf("tx %d", i+1)

为了把这个txName通过context传递,需要在store.go里面定义一个空的结构体

var txKey = struct{}{}

这里的第二个{}表示空的对象,回到store_test.go改造一下,把txKeytxName作为参数传给context.WithValue()

	for i := 0; i < n; i++ {
		txName := fmt.Sprintf("tx %d", i+1)
		go func() {
			ctx := context.WithValue(context.Background(), txKey, txName)
			result, err := store.TransferTx(ctx, TransferTxParams{
				FromAccountID: account1.ID,
				ToAccountID:   account2.ID,
				Amount:        amount,
			})

			errs <- err
			results <- result
		}()
	}

这样,上下文就可以保存事务名称了。

来到store.go里面的TransferTx函数,增加日志输出:

var txKey = struct{}{}

func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
	var result TransferTxResult

	err := store.execTx(ctx, func(q *Queries) error {
		var err error

		txName := ctx.Value(txKey)

		// 1. 创建一个金额等于`10`的转账记录
		fmt.Println(txName, "create transfer")
		result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
			FromAccountID: arg.FromAccountID,
			ToAccountID:   arg.ToAccountID,
			Amount:        arg.Amount,
		})
		if err != nil {
			return err
		}

		// 2. 为`FromAccount`创建一个账目记录,金额为`-10`
		fmt.Println(txName, "create entry 1")
		result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.FromAccountID,
			Amount:    -arg.Amount,
		})
		if err != nil {
			return err
		}

		// 3. 为`ToAccount`创建一个账目记录,金额为`+10`
		fmt.Println(txName, "create entry 2")
		result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.ToAccountID,
			Amount:    arg.Amount,
		})
		if err != nil {
			return err
		}

		// 从数据库中获取 account -> 更新账户余额
		fmt.Println(txName, "get account 1")
		account1, err := q.GetAccountForUpdate(ctx, arg.FromAccountID)
		if err != nil {
			return err
		}

		fmt.Println(txName, "update account 1")
		result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.FromAccountID,
			Balance: account1.Balance - arg.Amount,
		})
		if err != nil {
			return err
		}

		fmt.Println(txName, "get account 2")
		account2, err := q.GetAccountForUpdate(ctx, arg.ToAccountID)
		if err != nil {
			return err
		}

		fmt.Println(txName, "update account 2")
		result.ToAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.ToAccountID,
			Balance: account2.Balance + arg.Amount,
		})
		if err != nil {
			return err
		}

		return err
	})

	return result, err
}

为了便于调试,把store_test.go里面的并发数量n改成3,便于跟踪调试,运行它run test,可以看到保存了,并输出了我们打印的日志:

=== RUN   TestTransferTx
>> before: 103 980
tx 3 create transfer
tx 3 create entry 1
tx 3 create entry 2
tx 2 create transfer
tx 1 create transfer
tx 3 get account 1
tx 3 update account 1
tx 3 get account 2
tx 3 update account 2
tx 2 create entry 1
tx 1 create entry 1
tx 1 create entry 2
tx 2 create entry 2
tx 1 get account 1
tx 2 get account 1
>> tx: 93 990
tx 2 update account 1

到这里出错了,现在就要跟踪一下,是怎么出这个错误的,我们来用go执行的过程模拟一下,先把go执行转账事务所需要的sql整理出来,就是下面这个流程:

BEGIN;
-- create transfer
INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES  (1, 2, 10) RETURNING *;
-- create entry 1
INSERT INTO entries (account_id, amount) VALUES (1, -10) RETURNING *;
-- create entry 2
INSERT INTO entries (account_id, amount) VALUES (2, 10) RETURNING *;
-- get account 1
SELECT * FROM accounts WHERE id = 1 FOR UPDATE;
-- update account 1
UPDATE accounts SET balance = 90 WHERE id = 1 RETURNING *;
-- get account 2
SELECT * FROM accounts WHERE id = 2 FOR UPDATE;
-- update account 2
UPDATE accounts SET balance = 110 WHERE id = 2 RETURNING *;

因为我们设置了3个并发,因此,想上面操作的那样,我们打开3个postgres控制台:
3个postgres控制台
接下来,按照日志的执行顺序,分别依次在不同的postgres控制台里输入这些sql

  • 首先这3个控制台都输入BEGIN;开启事务
  • 日志tx 3 create transfer,就在第3个控制台输入,INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES (1, 2, 10) RETURNING *;
  • 日志tx 3 create entry 1,就在第3个控制台输入, INSERT INTO entries (account_id, amount) VALUES (1, -10) RETURNING *;
  • 依次类推…
  • 直到这一条,tx 3 get account 1,在第3个控制台执行 SELECT * FROM accounts WHERE id = 1 FOR UPDATE;,发现卡住了,为什么INSERT transfers的时候,会引发SELECT accounts操作死锁,是不同的表啊;带着问题,我们Google搜postgres lock,在postgresWIKI里(https://wiki.postgresql.org/wiki/Lock_Monitoring),看到了这段:
SELECT blocked_locks.pid     AS blocked_pid,
         blocked_activity.usename  AS blocked_user,
         blocking_locks.pid     AS blocking_pid,
         blocking_activity.usename AS blocking_user,
         blocked_activity.query    AS blocked_statement,
         blocking_activity.query   AS current_statement_in_blocking_process
   FROM  pg_catalog.pg_locks         blocked_locks
    JOIN pg_catalog.pg_stat_activity blocked_activity  ON blocked_activity.pid = blocked_locks.pid
    JOIN pg_catalog.pg_locks         blocking_locks 
        ON blocking_locks.locktype = blocked_locks.locktype
        AND blocking_locks.database IS NOT DISTINCT FROM blocked_locks.database
        AND blocking_locks.relation IS NOT DISTINCT FROM blocked_locks.relation
        AND blocking_locks.page IS NOT DISTINCT FROM blocked_locks.page
        AND blocking_locks.tuple IS NOT DISTINCT FROM blocked_locks.tuple
        AND blocking_locks.virtualxid IS NOT DISTINCT FROM blocked_locks.virtualxid
        AND blocking_locks.transactionid IS NOT DISTINCT FROM blocked_locks.transactionid
        AND blocking_locks.classid IS NOT DISTINCT FROM blocked_locks.classid
        AND blocking_locks.objid IS NOT DISTINCT FROM blocked_locks.objid
        AND blocking_locks.objsubid IS NOT DISTINCT FROM blocked_locks.objsubid
        AND blocking_locks.pid != blocked_locks.pid

    JOIN pg_catalog.pg_stat_activity blocking_activity ON blocking_activity.pid = blocking_locks.pid
   WHERE NOT blocked_locks.granted;
  • 把这段sql放到navicat里执行一下,可以看到:
    lock的sql
    发现,blocked的sql就是SELECT * FROM accounts WHERE id = 1 FOR UPDATE;,而引发blocking的sql是INSERT INTO transfers (from_account_id, to_account_id, amount) VALUES (1, 2, 10) RETURNING *;,因此,对这两个不同的表操作,确实会引发相互阻塞。

  • 让我们更深入的了解一下为什么SELECT查询必须等待INSERT操作,回到postgres WIKI页面,往下继续看,找到1条,找出数据库中所有的锁

SELECT a.datname,
         l.relation::regclass,
         l.transactionid,
         l.mode,
         l.GRANTED,
         a.usename,
         a.query,
         a.query_start,
         age(now(), a.query_start) AS "age",
         a.pid
FROM pg_stat_activity a
JOIN pg_locks l ON l.pid = a.pid
ORDER BY a.query_start;

把它复制到navicat中,并稍作修改,因为要查看一些更多的信息,在a.datname,后面增加a.application_name,,用来查看锁来自哪个应用。

这个sql中,

  • l.relation::regclass其实就是表名,
  • l.transactionid是事务ID,
  • l.mode后面增加l.locktype用来查看锁的类型,
  • l.GRANTED是锁是否被授权,
  • a.usename是运行sql的用户名,
  • a.query是持有或试图获取锁的sql语句,
  • a.query_startage(now(), a.query_start) AS "age"在这个场景中部重要,删除它,
  • a.pid是进程ID
  • ORDER BY a.query_start改为ORDER BY a.pid更好一下,因为我们在psql控制台有3个不同的处理进程,更容易看出来哪个锁是属于哪个事务,
  • 增加WHERE a.application_name = 'psql'只关注和psql控制台相关的
  • 如下。
SELECT a.application_name,
         l.relation::regclass,
         l.transactionid,
         l.mode,
		 l.locktype,
         l.GRANTED,
         a.usename,
         a.query,
         a.pid
FROM pg_stat_activity a
JOIN pg_locks l ON l.pid = a.pid
WHERE a.application_name = 'psql'
ORDER BY a.pid;

在,navicat里运行一下,
未授权的sql
看到只有一条,GRANTEDf的记录,可以看到它的transactionid1075,可以看到transactionid1075的还有一条,它正在进行INSERT操作,pid6991,但是,为什么SELECT FROM accounts表需要从INSERT INTO transfers的事务中获取锁呢?

来看一下,我们的sql建表语句,有这样的外键约束:

ALTER TABLE "transfers" ADD FOREIGN KEY ("from_account_id") REFERENCES "accounts" ("id");

所以,对transfers表的from_account_id做任何更新操作时,都会影响到accounts表做SELECT FOR UPDATE的操作,因为,这里需要获取一个锁来防止冲突,保证数据的一致性。这样,就解释了deadlock是如何发生的。如何解决这个问题呢?
让我们先把这3个控制台执行ROLLBACK;\q退出。

修复死锁方案1

我们知道,死锁是由外键约束引起的,那我们尝试把这几个外键约束的sql都删除掉,在000001_init_schema.up.sql文件里,先注释掉下面这几行:

-- ALTER TABLE "entries" ADD FOREIGN KEY ("account_id") REFERENCES "accounts" ("id");

-- ALTER TABLE "transfers" ADD FOREIGN KEY ("from_account_id") REFERENCES "accounts" ("id");

-- ALTER TABLE "transfers" ADD FOREIGN KEY ("to_account_id") REFERENCES "accounts" ("id");

让我们把数据库降级一下make migratedown,再make migrateup,之后,再来运行这个单元测试run test,可以看到成功了。

但是,这不是最好的解决方案,因为,这样失去了保持数据一致性的良好约束。所以,恢复刚才注释掉的那几行sql代码。并运行make migratedownmake migrateup

修复死锁方案2

打开account.sql文件,看这里

-- name: UpdateAccount :one
UPDATE accounts SET balance = $2 WHERE id = $1
RETURNING *;

更新账户余额的时候,id是永远不会更改的,因为它是账户表的主键。所以,如果我们能告诉postgres,SELECT * FROM accounts FOR UPDATE时不更新主键,那么postgres将不需要获取事务锁,因此,就不会出现deadlock

恩,有个知识点,FOR UPDATE的时候,我们只需要明确一下,NO KEY UPDATE,说明我们不更新主键ID,像这样。

-- name: GetAccountForUpdate :one
SELECT * FROM accounts
WHERE id = $1 LIMIT 1
FOR NO KEY UPDATE;

在项目终端里,在运行make sqlc重新生成代码,再运行这个单元测试,可以看到成功了。

之后,清理一下代码,把我们之前加的打印日志都去掉吧,再运行一下单元测试,没问题,通过。
测试通过

可以,看到这个日志,每次交易后的2个账户余额是怎么变化的。

修复死锁方案3

有一种更好的方法来实现更新余额操作。目前,我们必须执行两个sql来获取并更新账户余额:

		account1, err := q.GetAccountForUpdate(ctx, arg.FromAccountID)
		if err != nil {
			return err
		}

		result.FromAccount, err = q.UpdateAccount(ctx, UpdateAccountParams{
			ID:      arg.FromAccountID,
			Balance: account1.Balance - arg.Amount,
		})
		if err != nil {
			return err
		}

其实,可以通过1个sql语句来搞定。所以,我们在account.sql里再增加一个AddAccountBalance语句,它和UpdateAccount类似,不同的是设置 balance 等于 balance + 第二个参数,如下:

-- name: AddAccountBalance :one
UPDATE accounts SET balance = balance + $2 WHERE id = $1
RETURNING *;

之后,运行make sqlc来生成代码,看到account.sql.go里面增加了如下代码:

type AddAccountBalanceParams struct {
	ID      int64 `json:"id"`
	Balance int64 `json:"balance"`
}

但是,这个参数Balance有点混淆,我们是让余额增加或减少金额,而不是更新Balance本身,在sqlc可以这样做:

-- name: AddAccountBalance :one
UPDATE accounts SET balance = balance + sqlc.arg(amount) 
WHERE id = sqlc.arg(id)
RETURNING *;

$2改成sqlc.arg(amount),$1改成sqlc.arg(id),之后,在运行make sqlc,可以看到生成的代码变成了这样:

type AddAccountBalanceParams struct {
	Amount int64 `json:"amount"`
	ID     int64 `json:"id"`
}

Balance变成了Amount,这样语义就清晰了。

回到store.go文件,删除GetAccountForUpdate,把UpdateAccount改为AddAccountBalance,完整代码如下:

package db

import (
	"context"
	"database/sql"
	"fmt"
)

type Store struct {
	*Queries
	db *sql.DB
}

func NewStore(db *sql.DB) *Store {
	return &Store{
		db:      db,
		Queries: New(db),
	}
}

func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
	tx, err := store.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	q := New(tx)
	err = fn(q)
	if err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
		}
		return err
	}

	return tx.Commit()
}

type TransferTxParams struct {
	FromAccountID int64 `json:"from_account_id"`
	ToAccountID   int64 `json:"to_account_id"`
	Amount        int64 `json:"amount"`
}

type TransferTxResult struct {
	Transfer    Transfer `json:"transfer"`
	FromAccount Account  `json:"from_account"`
	ToAccount   Account  `json:"to_account"`
	FromEntry   Entry    `json:"from_entry"`
	ToEntry     Entry    `json:"to_entry"`
}

func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
	var result TransferTxResult

	err := store.execTx(ctx, func(q *Queries) error {
		var err error

		// 1. 创建一个金额等于`10`的转账记录
		result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
			FromAccountID: arg.FromAccountID,
			ToAccountID:   arg.ToAccountID,
			Amount:        arg.Amount,
		})
		if err != nil {
			return err
		}

		// 2. 为`FromAccount`创建一个账目记录,金额为`-10`
		result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.FromAccountID,
			Amount:    -arg.Amount,
		})
		if err != nil {
			return err
		}

		// 3. 为`ToAccount`创建一个账目记录,金额为`+10`
		result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
			AccountID: arg.ToAccountID,
			Amount:    arg.Amount,
		})
		if err != nil {
			return err
		}

		// 更新账户余额
		result.FromAccount, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
			ID:     arg.FromAccountID,
			Amount: -arg.Amount,
		})
		if err != nil {
			return err
		}

		result.ToAccount, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
			ID:     arg.ToAccountID,
			Amount: arg.Amount,
		})
		if err != nil {
			return err
		}

		return err
	})

	return result, err
}

再次运行单元测试,没问题,测试通过。

好了,本次课程学习结束。下节,将继续学习如何避免数据库事务查询中的死锁顺序问题

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值