Go MySQL 高级特性实现指南
1. 连接池管理
package db
import (
"time"
"github.com/jmoiron/sqlx"
)
// DBPool bundles the connection handles produced by NewDBPool.
type DBPool struct {
	// Master is the handle opened from the master DSN.
	Master *sqlx.DB
	// Slaves holds one handle per slave DSN, in the order the DSNs were given.
	Slaves []*sqlx.DB
}
// NewDBPool connects to the master and to every slave DSN, applying the pool
// limits for each role (master: 100 open / 10 idle, slaves: 50 open / 5 idle;
// both recycle connections after an hour and drop idle ones after 30 minutes).
//
// If any connection fails, every handle opened so far is closed before the
// error is returned, so a failed call leaks no connections.
//
// NOTE(review): sqlx.Connect("mysql", ...) requires the MySQL driver to be
// registered (typically a blank import of github.com/go-sql-driver/mysql);
// that import is not visible in this snippet — confirm it exists.
func NewDBPool(masterDSN string, slaveDSNs []string) (*DBPool, error) {
	master, err := openDB(masterDSN, 100, 10)
	if err != nil {
		return nil, err
	}
	slaves := make([]*sqlx.DB, 0, len(slaveDSNs))
	for _, dsn := range slaveDSNs {
		slave, err := openDB(dsn, 50, 5)
		if err != nil {
			// Release everything opened so far; a Close failure here is
			// secondary to the connect error we are about to return.
			for _, s := range slaves {
				_ = s.Close()
			}
			_ = master.Close()
			return nil, err
		}
		slaves = append(slaves, slave)
	}
	return &DBPool{
		Master: master,
		Slaves: slaves,
	}, nil
}

// openDB connects to dsn and applies the pool limits shared by every role:
// at most maxOpen open and maxIdle idle connections, each connection recycled
// after one hour and closed after 30 minutes of idleness.
func openDB(dsn string, maxOpen, maxIdle int) (*sqlx.DB, error) {
	db, err := sqlx.Connect("mysql", dsn)
	if err != nil {
		return nil, err
	}
	db.SetMaxOpenConns(maxOpen)
	db.SetMaxIdleConns(maxIdle)
	db.SetConnMaxLifetime(time.Hour)
	db.SetConnMaxIdleTime(time.Minute * 30)
	return db, nil
}
2. ORM 映射实现
package orm
import (
	"fmt"
	"reflect"
	"sort"
	"strings"
	"time"
)
// Model is implemented by structs that the ORM can map to a table.
type Model interface {
	// TableName reports the SQL table the model is stored in.
	TableName() string
}
// BaseModel carries an int64 id plus creation and update timestamps.
// The db tags give the column names and the json tags the serialized keys.
type BaseModel struct {
	ID        int64     `db:"id" json:"id"`
	CreatedAt time.Time `db:"created_at" json:"created_at"`
	UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// Tag is the parsed mapping information for one struct field.
type Tag struct {
	// Name is the column name taken from the field's db tag.
	Name string
	// Options holds the extra settings produced by parseTagOptions.
	Options map[string]string
}
// ModelMapper caches the table name and db-tagged fields of a single Model.
type ModelMapper struct {
	model     Model           // the model instance being mapped
	tableName string          // value returned by model.TableName()
	fields    map[string]*Tag // keyed by Go struct field name
}
// NewModelMapper builds a mapper for model. It records model.TableName()
// and populates the field map by reflecting over the model's struct fields.
func NewModelMapper(model Model) *ModelMapper {
	mapper := new(ModelMapper)
	mapper.model = model
	mapper.tableName = model.TableName()
	mapper.fields = map[string]*Tag{}
	mapper.parseModel()
	return mapper
}
// parseModel records every db-tagged field of m.model in m.fields, keyed by
// Go field name.
//
// Anonymous embedded structs (or pointers to structs) without a db tag are
// walked recursively, so columns declared on e.g. an embedded BaseModel are
// mapped too. A model whose underlying type is not a struct maps no fields
// instead of panicking in NumField.
func (m *ModelMapper) parseModel() {
	t := reflect.TypeOf(m.model)
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return
	}
	m.collectFields(t, false)
}

// collectFields adds the db-tagged fields of struct type t to m.fields.
// promoted is true while walking an embedded struct; a promoted field never
// replaces an entry already recorded under the same Go field name, while a
// directly declared field always does (mirroring Go's field shadowing).
func (m *ModelMapper) collectFields(t reflect.Type, promoted bool) {
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		tag := field.Tag.Get("db")
		if field.Anonymous && tag == "" {
			ft := field.Type
			if ft.Kind() == reflect.Ptr {
				ft = ft.Elem()
			}
			if ft.Kind() == reflect.Struct {
				m.collectFields(ft, true)
				continue
			}
		}
		// Anything after the first comma (e.g. `db:"id,pk"`) is an option,
		// presumably consumed by parseTagOptions; only the prefix is the column.
		name := strings.SplitN(tag, ",", 2)[0]
		if name == "" || name == "-" {
			continue
		}
		if _, exists := m.fields[field.Name]; promoted && exists {
			continue
		}
		m.fields[field.Name] = &Tag{
			Name:    name,
			Options: parseTagOptions(field.Tag),
		}
	}
}
// InsertSQL returns a parameterized INSERT statement for the mapped table,
// with one "?" placeholder per mapped column.
//
// m.fields is a map, so its iteration order is random; the columns are
// therefore sorted by column name to make the statement (and the order in
// which callers must bind values) deterministic across calls.
func (m *ModelMapper) InsertSQL() string {
	cols := make([]string, 0, len(m.fields))
	for _, tag := range m.fields {
		cols = append(cols, tag.Name)
	}
	sort.Strings(cols)
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(cols)), ",")
	return fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES (%s)",
		m.tableName,
		strings.Join(cols, ","),
		placeholders,
	)
}
3. 读写分离实现
package db
import (
"context"
"math/rand"
)
// DBRouter chooses a connection pool per operation: the master, or one of
// the slaves (falling back to the master when no slaves are configured).
type DBRouter struct {
	master  *sqlx.DB
	slaves  []*sqlx.DB
	counter uint64 // round-robin cursor, advanced with sync/atomic
}
// Master returns the master pool unconditionally.
func (r *DBRouter) Master() *sqlx.DB { return r.master }
// Slave returns a uniformly random slave pool, or the master when the
// router has no slaves.
func (r *DBRouter) Slave() *sqlx.DB {
	n := len(r.slaves)
	if n > 0 {
		return r.slaves[rand.Intn(n)]
	}
	return r.master
}
// RoundRobinSlave returns the slaves in rotation, or the master when the
// router has no slaves.
//
// The cursor is advanced and read in a single atomic operation (the value
// returned by AddUint64). The previous version re-read r.counter without
// atomics, which is a data race and let concurrent callers land on the same
// slave. Assumes r.slaves is not modified concurrently.
func (r *DBRouter) RoundRobinSlave() *sqlx.DB {
	if len(r.slaves) == 0 {
		return r.master
	}
	next := atomic.AddUint64(&r.counter, 1)
	return r.slaves[next%uint64(len(r.slaves))]
}
// UserRepository reads and writes users through a DBRouter, sending
// writes to the master and reads to a slave.
type UserRepository struct {
	router *DBRouter // picks master vs. slave per call
}
// Create inserts user's name and age into the users table on the master.
//
// The statement is run with ExecContext because INSERT returns no result
// set. Issuing it through QueryRowxContext and only calling Err() never
// scans the row, so the underlying rows — and the pooled connection they
// hold — are not released.
func (r *UserRepository) Create(ctx context.Context, user *User) error {
	_, err := r.router.Master().ExecContext(ctx,
		"INSERT INTO users (name, age) VALUES (?, ?)",
		user.Name, user.Age,
	)
	return err
}
// Get loads the user with the given id from a slave chosen by the router.
// On any error (including sqlx's sql.ErrNoRows when no row matches) it
// returns a nil *User, so callers never see a half-populated value.
//
// NOTE(review): reads go to a replica, so a user written by Create moments
// earlier may not be visible yet if replication is lagging.
func (r *UserRepository) Get(ctx context.Context, id int64) (*User, error) {
	var user User
	if err := r.router.Slave().GetContext(ctx, &user,
		"SELECT * FROM users WHERE id = ?", id); err != nil {
		return nil, err
	}
	return &user, nil
}
4. 分库分表实现
package sharding
import (
"fmt"
"hash/crc32"
)
// ShardConfig sets how many databases the key space is spread over, and how
// many tables each database holds. Both counts are used as divisors by
// ShardRouter.CalcShardLocation.
type ShardConfig struct {
	DBCount    int // number of databases
	TableCount int // number of tables per database
}
// ShardRouter maps a shard key to a database index and a table index
// according to config, and hands out the pool for a database index.
type ShardRouter struct {
	config ShardConfig
	pools  map[int]*sqlx.DB // database index -> connection pool
}
// CalcShardLocation hashes shardKey with CRC-32 (IEEE). The database index
// is the hash modulo DBCount; the table index is the quotient hash/DBCount
// modulo TableCount, so the two indices use different parts of the hash.
//
// Both DBCount and TableCount must be positive: a zero count triggers an
// integer-divide-by-zero panic.
func (r *ShardRouter) CalcShardLocation(shardKey string) (dbIndex, tableIndex int) {
	sum := crc32.ChecksumIEEE([]byte(shardKey))
	dbs := uint32(r.config.DBCount)
	tables := uint32(r.config.TableCount)
	return int(sum % dbs), int(sum / dbs % tables)
}
// GetTableName returns the physical table for a shard: baseTable, an
// underscore, then the decimal tableIndex (e.g. "users", 3 -> "users_3").
func (r *ShardRouter) GetTableName(baseTable string, tableIndex int) string {
	return baseTable + "_" + fmt.Sprint(tableIndex)
}
// GetDB returns the pool registered for dbIndex, or nil when none is.
func (r *ShardRouter) GetDB(dbIndex int) *sqlx.DB {
	db, ok := r.pools[dbIndex]
	if !ok {
		return nil
	}
	return db
}
// UserShardRepository stores users in the database and table selected by a
// ShardRouter from each user's UserID.
type UserShardRepository struct {
	router *ShardRouter // resolves UserID -> (database, table)
}
// Create inserts user into the users_<n> table of the database that
// user.UserID hashes to.
//
// If the router has no pool for the computed database index, an error is
// returned rather than panicking on a nil *sqlx.DB.
func (r *UserShardRepository) Create(user *User) error {
	dbIndex, tableIndex := r.router.CalcShardLocation(user.UserID)
	db := r.router.GetDB(dbIndex)
	if db == nil {
		return fmt.Errorf("no database configured for shard %d", dbIndex)
	}
	// The table name is derived from a constant base and an integer index,
	// so formatting it into the SQL is safe; row values stay bind parameters.
	query := fmt.Sprintf(
		"INSERT INTO %s (user_id, name, age) VALUES (?, ?, ?)",
		r.router.GetTableName("users", tableIndex),
	)
	_, err := db.Exec(query, user.UserID, user.Name, user.Age)
	return err
}
5. 主从复制监控
package replication
import (
"context"
"time"
)
// ReplicationStatus is a snapshot of master/slave replication health.
type ReplicationStatus struct {
	// Binary log coordinates reported by the master.
	MasterFile     string
	MasterPosition int

	// Replication state reported by the slave.
	SlaveIORunning      bool
	SlaveSQLRunning     bool
	SecondsBehindMaster int
	LastError           string
}
// ReplicationMonitor inspects one master/slave pair.
type ReplicationMonitor struct {
	master *sqlx.DB // queried with SHOW MASTER STATUS
	slave  *sqlx.DB // queried with SHOW SLAVE STATUS
}
// CheckStatus reads the master's binlog coordinates and the slave's
// replication state.
//
// SHOW MASTER STATUS and SHOW SLAVE STATUS return many more columns than
// ReplicationStatus holds, so a positional Scan into a few destinations
// always fails. Instead, the first row of each result is read by column name:
//   - File / Position                 -> MasterFile / MasterPosition
//   - Slave_IO_Running / _SQL_Running -> true only when the value is "Yes"
//   - Seconds_Behind_Master           -> -1 when NULL (replication stopped)
//   - Last_Error, else Last_IO_Error  -> LastError
//
// An error is returned when a statement fails, yields no row (e.g. binary
// logging disabled, or the slave not configured as a replica), or reports a
// non-numeric position/lag.
//
// NOTE(review): newer MySQL releases replace these statements with
// SHOW BINARY LOG STATUS / SHOW REPLICA STATUS — verify for your server.
func (m *ReplicationMonitor) CheckStatus(ctx context.Context) (*ReplicationStatus, error) {
	rows, err := m.master.QueryContext(ctx, "SHOW MASTER STATUS")
	if err != nil {
		return nil, fmt.Errorf("get master status failed: %w", err)
	}
	masterRow, err := readStatusRow(rows)
	if err != nil {
		return nil, fmt.Errorf("get master status failed: %w", err)
	}

	rows, err = m.slave.QueryContext(ctx, "SHOW SLAVE STATUS")
	if err != nil {
		return nil, fmt.Errorf("get slave status failed: %w", err)
	}
	slaveRow, err := readStatusRow(rows)
	if err != nil {
		return nil, fmt.Errorf("get slave status failed: %w", err)
	}

	status := &ReplicationStatus{
		MasterFile:          string(masterRow["File"]),
		SlaveIORunning:      string(slaveRow["Slave_IO_Running"]) == "Yes",
		SlaveSQLRunning:     string(slaveRow["Slave_SQL_Running"]) == "Yes",
		SecondsBehindMaster: -1,
		LastError:           string(slaveRow["Last_Error"]),
	}
	// Last_Error only covers the SQL thread; fall back to the IO thread's error.
	if status.LastError == "" {
		status.LastError = string(slaveRow["Last_IO_Error"])
	}
	if status.MasterPosition, err = statusInt(masterRow, "Position"); err != nil {
		return nil, fmt.Errorf("get master status failed: %w", err)
	}
	if slaveRow["Seconds_Behind_Master"] != nil {
		if status.SecondsBehindMaster, err = statusInt(slaveRow, "Seconds_Behind_Master"); err != nil {
			return nil, fmt.Errorf("get slave status failed: %w", err)
		}
	}
	return status, nil
}

// statusRows is the subset of *sql.Rows that readStatusRow needs.
type statusRows interface {
	Next() bool
	Err() error
	Columns() ([]string, error)
	Scan(dest ...interface{}) error
	Close() error
}

// readStatusRow closes rows after returning its first row as a map from
// column name to raw value. A NULL column maps to a nil slice.
func readStatusRow(rows statusRows) (map[string][]byte, error) {
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return nil, err
	}
	if !rows.Next() {
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, fmt.Errorf("statement returned no rows")
	}
	vals := make([][]byte, len(cols))
	dest := make([]interface{}, len(cols))
	for i := range vals {
		dest[i] = &vals[i]
	}
	if err := rows.Scan(dest...); err != nil {
		return nil, err
	}
	row := make(map[string][]byte, len(cols))
	for i, c := range cols {
		row[c] = vals[i]
	}
	return row, nil
}

// statusInt parses column col of row as a decimal integer.
func statusInt(row map[string][]byte, col string) (int, error) {
	var n int
	if _, err := fmt.Sscan(string(row[col]), &n); err != nil {
		return 0, fmt.Errorf("parse %s %q: %w", col, row[col], err)
	}
	return n, nil
}
// MonitorService periodically runs a ReplicationMonitor check.
type MonitorService struct {
	monitor  *ReplicationMonitor // the master/slave pair being watched
	interval time.Duration       // time between consecutive checks
}
// Start checks replication health once per s.interval until ctx is done.
// It blocks for that whole time, so callers run it in its own goroutine or
// as the last step of main. The first check happens after one full interval.
func (s *MonitorService) Start(ctx context.Context) {
	tick := time.NewTicker(s.interval)
	defer tick.Stop()
	for {
		select {
		case <-tick.C:
			s.checkOnce(ctx)
		case <-ctx.Done():
			return
		}
	}
}

// checkOnce runs one replication check. It logs a failed check, a lag above
// 30 seconds, or a stopped IO/SQL thread; it never stops the polling loop.
func (s *MonitorService) checkOnce(ctx context.Context) {
	status, err := s.monitor.CheckStatus(ctx)
	if err != nil {
		log.Printf("check replication status failed: %v", err)
		return
	}
	if status.SecondsBehindMaster > 30 {
		log.Printf("replication lag too high: %d seconds",
			status.SecondsBehindMaster)
	}
	if !(status.SlaveIORunning && status.SlaveSQLRunning) {
		log.Printf("replication not running, error: %s",
			status.LastError)
	}
}
使用示例
// main wires the pieces together: the master/slave pools, the read/write
// router, the shard router, and the replication monitor. It then runs the
// monitor in the foreground until the process is killed.
func main() {
	pool, err := NewDBPool(
		"root:123456@tcp(master:3306)/test",
		[]string{
			"root:123456@tcp(slave1:3306)/test",
			"root:123456@tcp(slave2:3306)/test",
		},
	)
	if err != nil {
		log.Fatal(err)
	}
	// The shard map and the monitor below index Slaves[0]; fail with a clear
	// message instead of an index-out-of-range panic. Checked before the
	// defers because log.Fatal skips deferred calls.
	if len(pool.Slaves) == 0 {
		log.Fatal("at least one slave DSN is required")
	}
	defer pool.Master.Close()
	for _, s := range pool.Slaves {
		defer s.Close()
	}

	router := &DBRouter{
		master: pool.Master,
		slaves: pool.Slaves,
	}
	userRepo := &UserRepository{router: router}

	// NOTE(review): shard 1 is mapped to a replica only to keep the example
	// short. In a real deployment every shard is its own writable primary;
	// writing to a replica breaks replication consistency.
	shardRouter := &ShardRouter{
		config: ShardConfig{
			DBCount:    2,
			TableCount: 4,
		},
		pools: map[int]*sqlx.DB{
			0: pool.Master,
			1: pool.Slaves[0],
		},
	}
	shardRepo := &UserShardRepository{router: shardRouter}

	// The repositories are what application code would use; this example
	// does not exercise them.
	_, _ = userRepo, shardRepo

	monitor := &ReplicationMonitor{
		master: pool.Master,
		slave:  pool.Slaves[0],
	}
	monitorService := &MonitorService{
		monitor:  monitor,
		interval: time.Minute,
	}
	// Start blocks until its context is cancelled. Running it in the
	// foreground keeps the process alive; launching it with `go` would let
	// main return immediately and end the program before the first check.
	monitorService.Start(context.Background())
}