Go 源码系列：database/sql 包（sql.go）解析
示例代码（demo）
package main
import (
"database/sql"
"log"
_ "github.com/go-sql-driver/mysql"
)
// main runs the demo and turns any error into a single fatal log line.
//
// The actual work lives in run so that its deferred Close calls execute:
// log.Fatal calls os.Exit, which does not run deferred functions, so calling
// it after a defer in the same function would skip db.Close/rows.Close.
func main() {
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// run opens the database, verifies that it is reachable, and logs every
// (id, name) pair returned for id = 1.
//
// Per the database/sql documentation, sql.Open may only validate its
// arguments without creating a connection; Ping forces a real connection so a
// bad DSN or an unreachable server is reported here, not on the first query.
func run() error {
	db, err := sql.Open("mysql", "root:1234567@tcp(127.0.0.1:3306)/hello")
	if err != nil {
		return err
	}
	defer db.Close()

	if err := db.Ping(); err != nil {
		return err
	}

	rows, err := db.Query("select id, name from users where id = ?", 1)
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var (
			id   int
			name string
		)
		if err := rows.Scan(&id, &name); err != nil {
			return err
		}
		log.Println(id, name)
	}
	// rows.Next returns false both at the end of the result set and on error;
	// rows.Err tells the two apart.
	return rows.Err()
}
代码走读
在 github.com/go-sql-driver/mysql 包内：
// init registers the MySQL driver with database/sql under the name "mysql",
// the same name callers pass as driverName to sql.Open. sql.Open later looks
// the driver up by that name in database/sql's drivers map.
func init() {
	sql.Register("mysql", new(MySQLDriver))
}
// The two driver-side entry points that database/sql's Open checks for.
type (
	// Driver is satisfied by any type with an Open method that returns a new
	// connection for the given data source name.
	Driver interface {
		Open(name string) (Conn, error)
	}

	// DriverContext is satisfied by any type with an OpenConnector method.
	// When the registered driver implements it, sql.Open obtains a Connector
	// through it instead of wrapping the driver itself.
	DriverContext interface {
		OpenConnector(name string) (Connector, error)
	}
)
// Open implements driver.Driver for MySQLDriver, which also implements
// OpenConnector (driver.DriverContext).
//
// It parses dsn into a Config, wraps that in a connector, and dials a
// connection right away using context.Background().
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
	cfg, err := ParseDSN(dsn)
	if err != nil {
		return nil, err
	}
	return (&connector{cfg: cfg}).Connect(context.Background())
}
// OpenConnector implements driver.DriverContext.
//
// Unlike Open, it does not dial anything. It parses dsn and returns a
// connector holding the resulting Config. The network connection is only
// created when Connect is called on that connector.
func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
	cfg, err := ParseDSN(dsn)
	if err != nil {
		return nil, err
	}
	c := &connector{cfg: cfg}
	return c, nil
}
// Open (from database/sql, sql.go) looks up the driver registered under
// driverName and wraps it in a *DB.
//
// If the driver implements driver.DriverContext (as MySQLDriver does), its
// OpenConnector is called with dataSourceName (for MySQL this only parses the
// DSN), and the resulting connector is passed to OpenDB. Otherwise the driver
// and DSN are wrapped in a dsnConnector for OpenDB. OpenDB itself is not shown
// in this excerpt.
func Open(driverName, dataSourceName string) (*DB, error) {
	driversMu.RLock()
	driveri, ok := drivers[driverName]
	driversMu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
	}

	// Type assertion: does the driver offer OpenConnector?
	driverCtx, hasConnector := driveri.(driver.DriverContext)
	if !hasConnector {
		return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
	}

	connector, err := driverCtx.OpenConnector(dataSourceName)
	if err != nil {
		return nil, err
	}
	return OpenDB(connector), nil
}
// Open is MySQLDriver's driver.Driver implementation in go-sql-driver/mysql.
//
// It parses dsn (the Data Source Name) into a Config, builds a connector from
// that Config, and connects immediately using context.Background().
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
	cfg, parseErr := ParseDSN(dsn)
	if parseErr != nil {
		return nil, parseErr
	}
	c := &connector{cfg: cfg}
	return c.Connect(context.Background())
}
// ParseDSN parses the DSN string to a Config
// NOTE(review): excerpt only. The code that computes the indices i, j and k,
// the error returns, and the final return statement are elided ("..."), so
// this does not compile as shown.
func ParseDSN(dsn string) (cfg *Config, err error) {
cfg = NewConfig()
// Extracts, in turn, the user name, password, network, address and database name.
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
...
cfg.Passwd = dsn[k+1 : j]
...
cfg.User = dsn[:k]
cfg.Addr = dsn[k+1 : i-1]
...
cfg.Net = dsn[j+1 : k]
...
// Parse the MySQL parameters that follow the '?'.
// NOTE(review): the error returned by parseDSNParams is discarded here;
// presumably the full source checks it and returns early — confirm.
parseDSNParams(cfg, dsn[j+1:])
cfg.DBName = dsn[i+1 : j]
}
// parseDSNParams parses the DSN "query string" (the part after '?') into cfg.
// Values must be url.QueryEscape'ed.
//
// params is split on '&' into key=value pairs, and each key is matched
// against the known option names. A piece without '=' is skipped. Indexing
// the result of strings.SplitN(v, "=", 2) unconditionally would panic on it;
// such pieces come from a trailing '&' or from an empty query string after '?'.
func parseDSNParams(cfg *Config, params string) (err error) {
	for _, v := range strings.Split(params, "&") {
		key, value, found := strings.Cut(v, "=")
		if !found {
			continue
		}

		switch key {
		// Disable INFILE allowlist / enable all files
		case "allowAllFiles":
			var isBool bool
			cfg.AllowAllFiles, isBool = readBool(value)
			...
		// Remaining cases (elided above): allowCleartextPasswords,
		// allowFallbackToPlaintext, checkConnLiveness, parseTime, readTimeout,
		// rejectReadOnly, strict, tls, ...
		}
	}
	return nil
}
打开连接：建立 TCP socket 连接
// connector implements driver.Connector for MySQL. It holds the Config parsed
// from the DSN, which it treats as an immutable private copy.
type connector struct{ cfg *Config }
// 实现 driver.Connector 接口
// Connect implements driver.Connector interface.
// Connect returns a connection to the database.
// 返回一个数据库连接
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
cfg: c.cfg,
}
mc.parseTime = mc.cfg.ParseTime
// dial 通信连接
dial, ok := dials[mc.cfg.Net]
if ok {
dctx := ctx
// 如果设置了 io 超时时间 ,ctx 要设置为 WithTimeout的
if mc.cfg.Timeout > 0 {
var cancel context.CancelFunc
dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
defer cancel()
}
// tcp 请求开启连接
mc.netConn, err = dial(dctx, mc.cfg.Addr)
} else {
// 说明在dials map: dials map[string]DialContextFunc 中没有该 net
// 设置一个 Dialer
nd := net.Dialer{Timeout: mc.cfg.Timeout}
//
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
...
// Call startWatcher for context support (From Go 1.8)
// 开启mysqlconnection 的 watcher 监听channel [watcher,finished]
mc.startWatcher()
mc.buf = newBuffer(mc.netConn)
// Reading Handshake Initialization Packet 握手协议包
authData, plugin, err := mc.readHandshakePacket()
// Send Client Authentication Packet auth 认证
authResp, err := mc.auth(authData, plugin)
...
mc.writeHandshakeResponsePacket(authResp, plugin)
mc.handleAuthResult(authData, plugin)
// Handle DSN Params 根据params 设置参数
err = mc.handleParams()
// 返回一个 mysqlConn
return mc,nil
}
}
// mysqlConn is the MySQL driver's per-connection state. connector.Connect
// builds it and returns it as the driver.Conn.
type mysqlConn struct {
buf buffer // buffer over netConn, created in Connect via newBuffer(mc.netConn)
netConn net.Conn // the dialed network connection
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
cfg *Config // the connector's Config (same pointer, set in Connect)
maxAllowedPacket int // initialized to maxPacketSize in Connect
maxWriteSize int // initialized to maxPacketSize - 1 in Connect
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
parseTime bool // copied from cfg.ParseTime in Connect
reset bool // set when the Go SQL package calls ResetSession
// for context support (Go 1.8+)
watching bool
watcher chan<- context.Context
closech chan struct{} // created in Connect
finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed
}
// DialContext (package net) opens a network connection to address on the
// given network, using the Dialer's settings. Per the net package docs, if
// ctx expires before the connection completes, an error is returned.
// NOTE(review): excerpt only. The computation of resolveCtx, the derivation
// of primaries/fallbacks from addrs, and the error checks are not shown.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
...
// Resolve address into the list of candidate addresses to try.
addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
// sysDialer carries a copy of the Dialer plus the network/address being dialed.
sd := &sysDialer{
Dialer: *d,
network: network,
address: address,
}
var c Conn
// With fallback addresses (presumably the dual-stack IPv4/IPv6 case), race
// primaries against fallbacks; otherwise try the primaries one after another.
if len(fallbacks) > 0 {
c, err = sd.dialParallel(ctx, primaries, fallbacks)
} else {
c, err = sd.dialSerial(ctx, primaries)
}
...
}
db.Query 发起查询
// The demo's query call:
rows, err := db.Query("select id, name from users where id = ?", 1)
// ...which ends up in DB.QueryContext:
func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
var rows *Rows
var err error
var isBadConn bool
// 没超出重试次数
for i := 0; i < maxBadConnRetries; i++ {
// 执行查询
rows, err = db.query(ctx, query, args, cachedOrNewConn)
isBadConn = errors.Is(err, driver.ErrBadConn)
if !isBadConn {
break
}
}
if isBadConn {
// 创建新 conn 查询
return db.query(ctx, query, args, alwaysNewConn:1)
}
return rows, err
}
// query obtains a driver connection according to strategy (a
// connReuseStrategy) and runs the query on it via queryDC, with no
// transaction context.
//
// db.conn (not shown) resolves the connection roughly as follows:
//   - if idle connections may be used and one exists, take the last entry of
//     db.freeConn;
//   - otherwise, if no more connections may be opened, register a request in
//     db.connRequests and wait for a connection to arrive on its channel;
//   - otherwise open and return a new connection.
//
// On success, dc and its releaseConn are handed to queryDC, which passes
// ownership of dc on to the returned *Rows.
func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) {
	dc, err := db.conn(ctx, strategy)
	if err == nil {
		return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
	}
	return nil, err
}
// queryDC executes a query on the given connection.
// The connection gets released by the releaseConn function.
// The ctx context is from a query method and the txctx context is from an
// optional transaction context.
//
// There are two paths:
//  1. If the driver connection implements driver.QueryerContext or
//     driver.Queryer, the query is run on the connection directly.
//  2. If neither is implemented, or the direct call returns driver.ErrSkip,
//     the query is prepared as a statement and then executed.
//
// On every error path, releaseConn(err) is called before returning.
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) {
// Check whether the driver connection implements driver.QueryerContext,
// falling back to driver.Queryer.
queryerCtx, ok := dc.ci.(driver.QueryerContext)
var queryer driver.Queryer
if !ok {
queryer, ok = dc.ci.(driver.Queryer)
}
if ok {
var nvdargs []driver.NamedValue
var rowsi driver.Rows
var err error
// Argument conversion and the query both run inside withLock(dc, ...).
withLock(dc, func() {
nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
if err != nil {
return
}
// Run the query via QueryerContext or Queryer; the result is a driver.Rows.
rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
})
// driver.ErrSkip means the driver declined the direct path; fall through
// to the prepare-and-execute path below.
if err != driver.ErrSkip {
if err != nil {
releaseConn(err)
return nil, err
}
// Note: ownership of dc passes to the *Rows, to be freed
// with releaseConn.
rows := &Rows{
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
}
rows.initContextClose(ctx, txctx)
// Return the result set.
return rows, nil
}
}
// Prepare path: prepare the statement on the connection (also under withLock).
var si driver.Stmt
var err error
withLock(dc, func() {
si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
releaseConn(err)
return nil, err
}
// Wrap the prepared driver.Stmt together with the connection as its Locker.
ds := &driverStmt{Locker: dc, si: si}
// rowsiFromStatement executes the prepared statement and returns a driver.Rows.
rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
if err != nil {
ds.Close()
// Release the connection, reporting err.
releaseConn(err)
return nil, err
}
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
// Wrap the driver rows in *Rows; the statement is kept in closeStmt,
// presumably to be closed together with the Rows.
rows := &Rows{
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
closeStmt: ds,
}
rows.initContextClose(ctx, txctx)
return rows, nil
}
// ctxDriverStmtQuery runs a prepared statement with the given arguments.
//
// If the statement implements driver.StmtQueryContext, it receives ctx and the
// named arguments directly. Otherwise the arguments are converted to plain
// driver.Values, the call is abandoned with ctx.Err() if ctx is already done,
// and Stmt.Query is invoked. Stmt.Query takes no context, so it cannot observe
// cancellation itself.
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
	stmtCtx, supportsCtx := si.(driver.StmtQueryContext)
	if supportsCtx {
		return stmtCtx.QueryContext(ctx, nvdargs)
	}

	values, convErr := namedValueToValue(nvdargs)
	if convErr != nil {
		return nil, convErr
	}

	// Non-blocking check: bail out only if ctx is already cancelled.
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
	}

	return si.Query(values)
}
执行 mysql 驱动的 QueryContext
// QueryContext is mysqlConn's implementation of driver.QueryerContext. It is
// the method database/sql reaches through the QueryerContext assertion in
// queryDC.
//
// It converts the named arguments to driver.Values, starts watching ctx for
// cancellation, and runs the query via mc.query, which yields *textRows.
// If mc.query fails, mc.finish is called before returning the error.
// Otherwise mc.finish is stored on the rows (rows.finish) for later use.
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	values, err := namedValueToValue(args)
	if err != nil {
		return nil, err
	}
	if err = mc.watchCancel(ctx); err != nil {
		return nil, err
	}

	rows, err := mc.query(query, values)
	if err != nil {
		mc.finish()
		return nil, err
	}
	rows.finish = mc.finish
	return rows, nil
}
// query sends query to the server as a COM_QUERY packet, then reads the
// result-set header and the column definitions. Row data is NOT read here:
// textRows.Next later fetches each row from the connection stream.
//
// NOTE(review): the elided prologue ("...") holds error handling and query
// preprocessing. In upstream go-sql-driver/mysql, when args are present it
// either interpolates them client-side (producing `prepared`) or returns
// driver.ErrSkip if interpolateParams is disabled. ErrSkip makes
// database/sql's queryDC fall back to preparing the statement. Confirm
// against the driver version described.
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
	...
	query = prepared
	...
	// Send the COM_QUERY command packet.
	err := mc.writeCommandPacketStr(comQuery, query)
	if err == nil {
		// Read the result-set header; resLen is the number of result columns.
		var resLen int
		resLen, err = mc.readResultSetHeaderPacket()
		if err == nil {
			rows := new(textRows)
			rows.mc = mc

			// No result columns: there are no rows to read.
			if resLen == 0 {
				rows.rs.done = true

				switch err := rows.NextResultSet(); err {
				case nil, io.EOF:
					return rows, nil
				default:
					return nil, err
				}
			}

			// Read the column definitions.
			rows.rs.columns, err = mc.readColumns(resLen)
			// The rows keep a reference to mc; each row is read from the
			// connection on demand by Next.
			return rows, err
		}
	}
	return nil, mc.markBadConn(err)
}
// Result-set types returned by the MySQL driver.
//
// mysqlRows holds:
//   - mc: the connection the rows are read from;
//   - rs: the result-set state, including the column definitions and the
//     done flag;
//   - finish: the callback QueryContext installs from mc.finish.
//
// textRows (returned by mysqlConn.query) and binaryRows both embed mysqlRows,
// and each defines its own Next method.
type (
	mysqlRows struct {
		mc     *mysqlConn
		rs     resultSet
		finish func()
	}

	binaryRows struct {
		mysqlRows
	}

	textRows struct {
		mysqlRows
	}
)
读取结果：rows.Next 与 rows.Scan
// The demo's query-and-iterate loop again.
rows, err := db.Query("select id, name from users where id = ?", 1)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
// Each rows.Next() reads the next row into rows.lastcols; rows.Scan then
// copies lastcols into the destination variables.
for rows.Next() {
err := rows.Scan(&id, &name)
if err != nil {
// NOTE: log.Fatal calls os.Exit, so the deferred rows.Close() does not run.
log.Fatal(err)
}
log.Println(id, name)
}
// rows.Next() -> nextLocked
// nextLocked advances to the next row, storing its column values in
// rs.lastcols. The rest of the function (the error handling that decides
// doClose/ok) is elided ("...").
func (rs *Rows) nextLocked() (doClose, ok bool) {
if rs.closed {
return false, false
}
// Lock the driver connection before calling the driver interface
// rowsi to prevent a Tx from rolling back the connection at the same time.
rs.dc.Lock()
defer rs.dc.Unlock()
// Allocate rs.lastcols once, with one slot per result column; the same slice
// is reused for every later row.
if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
}
// Ask the driver rows (rs.rowsi) to fill rs.lastcols with the next row. For
// go-sql-driver/mysql, this reads the row from the connection stream.
rs.lasterr = rs.rowsi.Next(rs.lastcols)
...
// Error handling (elided).
}
// Next is the MySQL driver's driver.Rows.Next for binaryRows. Rows.nextLocked
// reaches it through rs.rowsi.Next.
//
// It returns io.EOF when rows.mc is nil. Otherwise it returns any error
// already recorded on the connection, or else reads the next row from the
// connection stream into dest.
func (rows *binaryRows) Next(dest []driver.Value) error {
	mc := rows.mc
	if mc == nil {
		return io.EOF
	}
	if err := mc.error(); err != nil {
		return err
	}
	// Fetch next row from stream.
	return rows.readRow(dest)
}
// Next is driver.Rows.Next for textRows, the result type of mysqlConn.query.
//
// It reports io.EOF when rows.mc is nil. Otherwise it surfaces any error
// already recorded on the connection, or else fills dest with the next row
// read from the connection stream.
func (rows *textRows) Next(dest []driver.Value) error {
	mc := rows.mc
	if mc == nil {
		return io.EOF
	}
	if err := mc.error(); err != nil {
		return err
	}
	return rows.readRow(dest) // fetch next row from stream
}
// Scan copies the current row's values (rs.lastcols, filled by Next) into
// the variables pointed to by dest. Each value is converted to its
// destination's type with convertAssignRows.
// NOTE(review): the elided prologue ("...") presumably verifies that Next was
// called and that len(dest) matches the column count. As shown, dest[i] would
// panic if dest were shorter than the row.
func (rs *Rows) Scan(dest ...any) error {
...
for i, sv := range rs.lastcols {
// Convert the current row's value sv to dest[i]'s type and store it there.
err := convertAssignRows(dest[i], sv, rs)
if err != nil {
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
}
}
return nil
}