Go 源码系列:database/sql(sql.go)包解析

样例 demo

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/go-sql-driver/mysql"
)

// main runs the demo and exits with a non-zero status if it fails.
//
// The actual work lives in run so that its deferred Close calls execute
// before the process exits: log.Fatal calls os.Exit, which skips any
// pending defers in the calling function.
func main() {
	if err := run(); err != nil {
		log.Fatal(err)
	}
}

// run opens the "hello" database through the mysql driver, selects the user
// with id 1 and logs every returned (id, name) pair.
func run() error {
	// sql.Open resolves the "mysql" driver registered by the blank import.
	db, err := sql.Open("mysql", "root:1234567@tcp(127.0.0.1:3306)/hello")
	if err != nil {
		return fmt.Errorf("opening database: %w", err)
	}
	defer db.Close()

	rows, err := db.Query("select id, name from users where id = ?", 1)
	if err != nil {
		return fmt.Errorf("querying users: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		var (
			id   int
			name string
		)
		if err := rows.Scan(&id, &name); err != nil {
			return fmt.Errorf("scanning row: %w", err)
		}
		log.Println(id, name)
	}

	// rows.Err reports an error that ended the iteration early, as opposed
	// to the result set simply running out of rows.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("iterating rows: %w", err)
	}
	return nil
}

代码走读

以下代码摘自 github.com/go-sql-driver/mysql 驱动包,并穿插了标准库 database/sql 与 database/sql/driver 中的相关定义(如 sql.Open、Driver 接口):


// init registers this package's driver with database/sql under the name
// "mysql". sql.Register stores it in database/sql's package-level drivers
// map, which is what sql.Open("mysql", ...) later looks up.
func init() {
	drv := &MySQLDriver{}
	sql.Register("mysql", drv)
}
// The two interfaces below are declared in database/sql/driver.
type (
	// Driver is satisfied by any type that has an Open method. Open returns
	// a new connection for the given data source name.
	Driver interface {
		Open(name string) (Conn, error)
	}

	// DriverContext is satisfied by any type that has an OpenConnector
	// method. sql.Open checks a registered driver for this interface and,
	// when present, builds a Connector from the name once.
	DriverContext interface {
		OpenConnector(name string) (Connector, error)
	}
)
// MySQLDriver implements both driver.Driver (Open) and driver.DriverContext
// (OpenConnector).
//
// Open parses dsn into a *Config and immediately dials a new connection
// with a background context.
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
	cfg, err := ParseDSN(dsn)
	if err != nil {
		return nil, err
	}
	return (&connector{cfg: cfg}).Connect(context.Background())
}
// OpenConnector implements driver.DriverContext.
// It parses dsn once and returns a connector holding the resulting *Config;
// no network connection is made here.
func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
	cfg, err := ParseDSN(dsn)
	if err != nil {
		return nil, err
	}
	conn := &connector{cfg: cfg}
	return conn, nil
}

// Open (database/sql) looks up the driver registered as driverName and
// returns a *DB for dataSourceName.
//
// If the driver also implements driver.DriverContext (go-sql-driver/mysql
// does), its OpenConnector is called once here. For the mysql driver this
// yields a connector{cfg *Config}, which is handed to OpenDB. Otherwise the
// DSN and the plain driver are wrapped in a dsnConnector.
func Open(driverName, dataSourceName string) (*DB, error) {
	driversMu.RLock()
	registered, found := drivers[driverName]
	driversMu.RUnlock()
	if !found {
		return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
	}

	dc, isCtx := registered.(driver.DriverContext)
	if !isCtx {
		return OpenDB(dsnConnector{dsn: dataSourceName, driver: registered}), nil
	}

	connector, err := dc.OpenConnector(dataSourceName)
	if err != nil {
		return nil, err
	}
	return OpenDB(connector), nil
}
// Open is the driver.Driver implementation in go-sql-driver/mysql.
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
	// Parse the DSN (Data Source Name) into a *Config.
	parsed, parseErr := ParseDSN(dsn)
	if parseErr != nil {
		return nil, parseErr
	}
	// Wrap the config in a connector and dial through it.
	c := &connector{cfg: parsed}
	return c.Connect(context.Background())
}
// ParseDSN parses the DSN string to a Config.
//
// Excerpt: the function walks dsn with index loops that are elided here
// ("..."), so only the field assignments remain.
// NOTE(review): i, j and k are indices from those elided loops and are reused
// for different positions (j ends the password slice but later also marks
// where the "?params" part begins). The snippet is therefore not compilable
// as shown; confirm against the driver's dsn.go. The error returned by
// parseDSNParams is also discarded in this excerpt.
func ParseDSN(dsn string) (cfg *Config, err error) {
  cfg = NewConfig()
  // Extract, in turn: user, password, network, address, database name and
  // the optional "?key=value&..." parameters.
  // [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
  // Find the last '/' (since the password or the net addr might contain a '/')
  ...
  cfg.Passwd = dsn[k+1 : j]
  ...
  cfg.User = dsn[:k]
  cfg.Addr = dsn[k+1 : i-1]
  ...
  cfg.Net = dsn[j+1 : k]
  ...
  // Parse the MySQL parameters that follow the '?'.
  parseDSNParams(cfg, dsn[j+1:])
  cfg.DBName = dsn[i+1 : j]
}
// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *Config, params string) (err error) {
  // & 切开query字符串
	for _, v := range strings.Split(params, "&") {
    param := strings.SplitN(v, "=", 2)
    // cfg params
    // 再用= 切开 
    // switch param 匹配设置 config
		switch value := param[1]; param[0] {
		// Disable INFILE allowlist / enable all files
		case "allowAllFiles":
			var isBool bool
			cfg.AllowAllFiles, isBool = readBool(value)
      ....
      // 设置 allowCleartextPasswords allowFallbackToPlaintext checkConnLiveness
      // parseTime readTimeout rejectReadOnly strict tls
    }
  }
}

打开连接:建立 TCP socket 连接

// connector is the driver.Connector of the mysql driver. It carries only the
// parsed DSN configuration (an immutable private copy); each Connect call
// dials a new connection from it.
type connector struct{ cfg *Config }

// 实现 driver.Connector 接口
// Connect implements driver.Connector interface.
// Connect returns a connection to the database.
// 返回一个数据库连接
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
  // New mysqlConn
	mc := &mysqlConn{
		maxAllowedPacket: maxPacketSize,
		maxWriteSize:     maxPacketSize - 1,
		closech:          make(chan struct{}),
		cfg:              c.cfg,
	}
	mc.parseTime = mc.cfg.ParseTime
  // dial 通信连接
  dial, ok := dials[mc.cfg.Net]
  	if ok {
		dctx := ctx
      // 如果设置了 io 超时时间 ,ctx 要设置为 WithTimeout的
		if mc.cfg.Timeout > 0 {
			var cancel context.CancelFunc
			dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
			defer cancel()
		}
    // tcp 请求开启连接  
		mc.netConn, err = dial(dctx, mc.cfg.Addr)
	} else {
    // 说明在dials map: dials     map[string]DialContextFunc 中没有该 net
    // 设置一个 Dialer
		nd := net.Dialer{Timeout: mc.cfg.Timeout}
      // 
		mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
    ...
   // Call startWatcher for context support (From Go 1.8)
      // 开启mysqlconnection 的 watcher 监听channel [watcher,finished]
	  mc.startWatcher()
    mc.buf = newBuffer(mc.netConn) 
    // Reading Handshake Initialization Packet 握手协议包
  	authData, plugin, err := mc.readHandshakePacket()  
    // Send Client Authentication Packet auth 认证
	  authResp, err := mc.auth(authData, plugin) 
    ...
     mc.writeHandshakeResponsePacket(authResp, plugin)
     mc.handleAuthResult(authData, plugin)
     // Handle DSN Params 根据params 设置参数  
		 err = mc.handleParams() 
      // 返回一个 mysqlConn
     return mc,nil 
	}
}

// mysqlConn is the driver.Conn implementation of go-sql-driver/mysql.
// connector.Connect creates one for every connection it dials, and
// QueryContext / query run on it.
//
// In Connect: buf wraps netConn (newBuffer), cfg is the connector's config,
// parseTime is copied from cfg.ParseTime, and maxAllowedPacket/maxWriteSize
// start from maxPacketSize. NOTE(review): sequence is presumably the MySQL
// packet sequence id; confirm in packets.go.
type mysqlConn struct {
	buf              buffer
	netConn          net.Conn
	rawConn          net.Conn // underlying connection when netConn is TLS connection.
	affectedRows     uint64
	insertId         uint64
	cfg              *Config
	maxAllowedPacket int
	maxWriteSize     int
	writeTimeout     time.Duration
	flags            clientFlag
	status           statusFlag
	sequence         uint8
	parseTime        bool
	reset            bool // set when the Go SQL package calls ResetSession

	// for context support (Go 1.8+): closech is created in Connect, and
	// startWatcher is what sets up the watcher/finished channels.
	watching bool
	watcher  chan<- context.Context
	closech  chan struct{}
	finished chan<- struct{}
	canceled atomicError // set non-nil if conn is canceled
	closed   atomicBool  // set when conn is closed, before closech is closed
}

// DialContext (package net) opens a network connection to address on the
// named network using the settings of d.
//
// Excerpt: only the core path is shown. The address is resolved into a list
// of candidate addresses. They are then dialed in parallel when fallbacks
// exist, otherwise from primaries alone.
// NOTE(review): primaries and fallbacks come from the elided code (presumably
// a split of addrs); confirm in net/dial.go.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
  ...
  // Resolve host:port into concrete addresses.
  addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
  sd := &sysDialer{
		Dialer:  *d,
		network: network,
		address: address,
	}
  var c Conn
	// Race primaries against fallbacks when there are fallbacks; otherwise
	// dial the primaries (presumably one after another).
	if len(fallbacks) > 0 {
		c, err = sd.dialParallel(ctx, primaries, fallbacks)
	} else {
		c, err = sd.dialSerial(ctx, primaries)
	}
  ...
}

db.Query 发起查询

rows, err := db.Query("select id, name from users where id = ?", 1)
// 走到这个方法
func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
	var rows *Rows
	var err error
	var isBadConn bool
  // 没超出重试次数
	for i := 0; i < maxBadConnRetries; i++ {
    // 执行查询
		rows, err = db.query(ctx, query, args, cachedOrNewConn)
		isBadConn = errors.Is(err, driver.ErrBadConn)
		if !isBadConn {
			break
		}
	}
	if isBadConn {
    // 创建新 conn 查询
    return db.query(ctx, query, args, alwaysNewConn:1)
	}
	return rows, err
}

// query obtains a driver connection according to strategy (the connection
// reuse strategy) and runs the query on it through queryDC.
//
// How db.conn picks the connection:
//   - if an idle connection may be used and one exists, it takes the last
//     entry of db.freeConn;
//   - otherwise, if opening another connection is allowed, it opens one;
//   - if not, it registers a request in db.connRequests and waits on that
//     request's channel until a connection is handed over.
func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) {
	conn, connErr := db.conn(ctx, strategy)
	if connErr != nil {
		return nil, connErr
	}
	// No transaction context here (txctx is nil). conn.releaseConn frees
	// the connection once the resulting *Rows, which takes ownership of it,
	// is done.
	return db.queryDC(ctx, nil, conn, conn.releaseConn, query, args)
}

// queryDC executes a query on the given connection.
// The connection gets released by the releaseConn function.
// The ctx context is from a query method and the txctx context is from an
// optional transaction context.
//
// Two paths: if dc.ci implements driver.QueryerContext or driver.Queryer, the
// query is sent directly. If it implements neither, or the driver answers
// driver.ErrSkip, the query is prepared as a statement and run through that.
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) {
	// Prefer driver.QueryerContext; fall back to the older driver.Queryer.
	queryerCtx, ok := dc.ci.(driver.QueryerContext)
	var queryer driver.Queryer
	if !ok {
		queryer, ok = dc.ci.(driver.Queryer)
	}
	if ok {
		var nvdargs []driver.NamedValue
		var rowsi driver.Rows
		var err error
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			// Run the query through whichever interface was found above
			// (presumably ctxDriverQuery uses the non-nil one), yielding a
			// driver.Rows.
			rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
		})
		// driver.ErrSkip means the driver declined the direct query; fall
		// through to the prepare-based path below.
		if err != driver.ErrSkip {
			if err != nil {
				releaseConn(err)
				return nil, err
			}
			// Note: ownership of dc passes to the *Rows, to be freed
			// with releaseConn.
			rows := &Rows{
				dc:          dc,
				releaseConn: releaseConn,
				rowsi:       rowsi,
			}
			rows.initContextClose(ctx, txctx)
			// Hand the rows (and with them the connection) to the caller.
			return rows, nil
		}
	}

	var si driver.Stmt
	var err error
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		releaseConn(err)
		return nil, err
	}
	// Wrap the prepared driver.Stmt, with dc as its Locker.
	ds := &driverStmt{Locker: dc, si: si}
	// Execute the prepared statement with args, producing a driver.Rows.
	rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
	if err != nil {
		ds.Close()
		// Release the connection, reporting the error.
		releaseConn(err)
		return nil, err
	}

	// Note: ownership of ci passes to the *Rows, to be freed
	// with releaseConn.
	// closeStmt records ds so it can be closed along with the rows
	// (presumably in Rows.Close).
	rows := &Rows{
		dc:          dc,
		releaseConn: releaseConn,
		rowsi:       rowsi,
		closeStmt:   ds,
	}
	rows.initContextClose(ctx, txctx)
	return rows, nil
}

// ctxDriverStmtQuery runs a prepared statement query.
// If si implements driver.StmtQueryContext, ctx and the named values are
// passed straight through. Otherwise the named values are converted to plain
// driver.Values and the legacy Stmt.Query is called, but only if ctx is not
// already done at that moment.
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
	if queryCtx, ok := si.(driver.StmtQueryContext); ok {
		return queryCtx.QueryContext(ctx, nvdargs)
	}

	values, convErr := namedValueToValue(nvdargs)
	if convErr != nil {
		return nil, convErr
	}

	// Non-blocking check: bail out if ctx was cancelled, otherwise query.
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
		return si.Query(values)
	}
}

执行 mysql 驱动的 QueryContext

// QueryContext implements driver.QueryerContext for the mysql driver.
// It converts the named arguments to plain values, arms context cancellation
// with watchCancel, and runs the text-protocol query. On failure mc.finish is
// called right away. On success mc.finish is handed to the rows
// (rows.finish) instead of being called here.
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	values, err := namedValueToValue(args)
	if err != nil {
		return nil, err
	}

	if watchErr := mc.watchCancel(ctx); watchErr != nil {
		return nil, watchErr
	}

	rows, queryErr := mc.query(query, values)
	if queryErr != nil {
		mc.finish()
		return nil, queryErr
	}
	rows.finish = mc.finish
	return rows, nil
}

// 发送 mysql 的真正查询,并返回结果
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
 ... 
  // 错误处理  和 query 预处理
		query = prepared
	 ...
	// Send command 发送命令 包
  err := mc.writeCommandPacketStr(comQuery:3, query)
	if err == nil {
		// Read Result
		var resLen int
    // 读返回结果头-> 结果长度
		resLen, err = mc.readResultSetHeaderPacket()
		if err == nil {
			rows := new(textRows)
			rows.mc = mc

			if resLen == 0 {
				rows.rs.done = true

				switch err := rows.NextResultSet(); err {
				case nil, io.EOF:
					return rows, nil
				default:
					return nil, err
				}
			}

			// Columns
      // 读结果 column 类型
			rows.rs.columns, err = mc.readColumns(resLen)
      // 返回 row 持有 mysqlconnection 和结果的字节开始地址
			return rows, err
		}
	}
	return nil, mc.markBadConn(err)
}

// Result-set types of the mysql driver. mysqlRows holds the connection, the
// current result set and the finish callback. binaryRows and textRows both
// embed it; textRows is what query returns for COM_QUERY, and binaryRows
// presumably serves prepared statements.
type (
	mysqlRows struct {
		mc     *mysqlConn
		rs     resultSet
		finish func()
	}

	binaryRows struct {
		mysqlRows
	}

	textRows struct {
		mysqlRows
	}
)

读取结果:rows.Next 与 rows.Scan

	// Excerpt of the demo's main: iterate over the result set row by row.
	rows, err := db.Query("select id, name from users where id = ?", 1)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()
	// Each rows.Next() call reads the next row into rows.lastcols;
	// rows.Scan then converts those values into id and name.
	for rows.Next() {
		err := rows.Scan(&id, &name)
		if err != nil {
			log.Fatal(err)
		}
		log.Println(id, name)
	}

// Rows.Next delegates to nextLocked, which fetches the next row from the
// driver into rs.lastcols.
//
// Excerpt: the handling of rs.lasterr after rowsi.Next is elided ("...").
 func (rs *Rows) nextLocked() (doClose, ok bool) {
	if rs.closed {
		return false, false
	}

	// Lock the driver connection before calling the driver interface
	// rowsi to prevent a Tx from rolling back the connection at the same time.
	rs.dc.Lock()
	defer rs.dc.Unlock()
	// Allocate rs.lastcols lazily, one slot per result column. It is only
	// allocated while nil, so later rows reuse the same slice.
	if rs.lastcols == nil {
		rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
	}
	// Ask the driver rows (rowsi) to decode the next row into rs.lastcols.
	rs.lasterr = rs.rowsi.Next(rs.lastcols)
	...
	// Error handling (elided).
}

// Next implements driver.Rows for binaryRows (called from Rows.nextLocked).
// It returns io.EOF when the rows no longer hold a connection, returns any
// pending connection error from mc.error(), and otherwise decodes the next
// row of the stream into dest.
func (rows *binaryRows) Next(dest []driver.Value) error {
	mc := rows.mc
	if mc == nil {
		return io.EOF
	}
	if err := mc.error(); err != nil {
		return err
	}
	return rows.readRow(dest)
}
// Next implements driver.Rows for textRows. Once the rows no longer hold a
// connection it reports io.EOF. Otherwise it first surfaces any connection
// error, then reads the next row from the stream into dest.
func (rows *textRows) Next(dest []driver.Value) error {
	mc := rows.mc
	if mc == nil {
		return io.EOF
	}
	if err := mc.error(); err != nil {
		return err
	}
	return rows.readRow(dest)
}


// Scan copies the columns of the current row (rs.lastcols, filled by Next)
// into the values pointed at by dest, converting each one with
// convertAssignRows. A conversion failure is reported with the column index
// and name.
//
// Excerpt: the checks at the top of the function are elided ("...").
// NOTE(review): the loop indexes dest[i] for every column, so the elided part
// presumably verifies len(dest) == len(rs.lastcols); confirm.
func (rs *Rows) Scan(dest ...any) error {
...
	for i, sv := range rs.lastcols {
		// Convert the current row's value sv to the destination's type and
		// store it through the dest[i] pointer.
		err := convertAssignRows(dest[i], sv, rs)
		if err != nil {
			return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
		}
	}
	return nil
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值