动机
突然想试着用golang写一个自己的redis连接客户端,尽可能的让操作变简单,也通过写代码的过程中,学习一些知识点,不想一直局限于写web项目
简介
目前实现了简单的get和set操作,加入了直接set结构体和get结构体数据的方法,添加了拦截器链,在执行redis命令的过程中进行前置拦截和后置拦截。也加入了事务,但是事务的实现可能不够完善,后续会继续修改。使用了连接池管理对象。写这个需要很多的时间,所以后续我会慢慢的补齐其他的功能。如果有志同道合的朋友我们可以一起写,也是当锻炼能力了。
代码在github上可以看到
部分代码如下
package redis
import (
"fmt"
"time"
)
/**
redis 中间层主要代码,负责在redisClient 和 connect 中起到承上启下的作用
*/
type Redis struct {
// 配置类
Config *Config
// 连接池
connPool *ConnPool
// 前置拦截器,查询前拦截器
preInterceptors []func(redisContext *InterceptorContext)
// 后置拦截器,查询后的拦截器
postInterceptors []func(redisContext *InterceptorContext)
}
func (r *Redis) GetByteArr(cmd *RedisCommand, t *transaction) (res []byte) {
// 获取连接对象
var conn *RedisConn
if t == nil {
conn = r.connPool.Get()
// 归还连接对象
defer r.connPool.Put(conn)
} else {
if t.status > 0 {
return nil
}
conn = t.conn
}
// 构建一个命令对象
context := newInterceptorContext(conn, cmd)
// 执行前置拦截器链
r.doPreInterceptor(context)
result, success := conn.CommandGetResult(cmd.Cmd())
if success {
res = result
cmd.SetRes(res)
} else if t != nil {
t.status = 2
}
// 执行后置拦截链
r.doPostInterceptor(context)
// 置空,方便垃圾回收
context = nil
return
}
// Set 默认set方法,未携带过期时间
func (r *Redis) Set(cmd *RedisCommand, t *transaction) (success bool) {
// 获取连接对象
var conn *RedisConn
if t == nil {
conn = r.connPool.Get()
// 归还连接对象
defer r.connPool.Put(conn)
} else {
if t.status > 0 {
return false
}
conn = t.conn
}
// 构建一个命令对象
context := newInterceptorContext(conn, cmd)
// 执行前置拦截器链
r.doPreInterceptor(context)
if conn.CommandNoResult(cmd.Cmd()) {
success = true
} else if t != nil {
t.status = 2
}
// 执行后置拦截链
r.doPostInterceptor(context)
// 归还连接对象
r.connPool.Put(conn)
// 置空,方便垃圾回收
context = nil
return
}
// SetWithExp set顺便设置过期时间,-1表示用不过期,大于0的数表示过期时间,单位为毫秒,
func (r *Redis) SetWithExp(key string, value interface{}, expireTime time.Duration) (success bool) {
return
}
// AddPreInterceptor 添加前置拦截器
func (r *Redis) AddPreInterceptor(interceptor func(redisContext *InterceptorContext)) {
r.preInterceptors = append(r.preInterceptors, interceptor)
}
// AddPostInterceptor 添加后置拦截器
func (r *Redis) AddPostInterceptor(interceptor func(redisContext *InterceptorContext)) {
r.postInterceptors = append(r.postInterceptors, interceptor)
}
func (r *Redis) doPreInterceptor(redisContext *InterceptorContext) {
for _, interceptor := range r.preInterceptors {
interceptor(redisContext)
}
}
func (r *Redis) doPostInterceptor(redisContext *InterceptorContext) {
for _, interceptor := range r.postInterceptors {
interceptor(redisContext)
}
}
func (r *Redis) getRedisConn() *RedisConn {
return r.connPool.Get()
}
func (r *Redis) putRedisConn(c *RedisConn) {
r.connPool.Put(c)
}
/*
事务
*/
type transaction struct {
// 当前事务的连接对象
conn *RedisConn
// 当前事务的状态 0表示开启,1表示结束,2表示回滚
status int
}
func (t *transaction) Open() {
if t.conn.BeginTransaction() {
fmt.Println("事务开启失败!")
}
}
func (t *transaction) Close() {
if t.conn.EndTransaction() {
fmt.Println("关闭事务失败")
}
}
func (t *transaction) Rollback() {
if t.conn.Rollback() {
fmt.Println("回滚失败")
}
}
package redis
import (
"fmt"
"github.com/vmihailenco/msgpack/v5"
)
type redisCli struct {
redis Redis
}
// CreateRedisCli 创建连接对象
func CreateRedisCli(config *Config) *redisCli {
// 初始化中间连接层
redis := Redis{
Config: config,
}
// 创建连接池
pool := NewConnPool(config, nil, 5, 10, 8)
redis.connPool = pool
// 判断是否开启debug 模式
if config.Debug {
redis.AddPostInterceptor(PrintDebug)
}
// TODO 还可以继续干很多事情,后面再写
r := redisCli{
redis,
}
return &r
}
func (r *redisCli) xxx(key string) {
}
// Get 默认Get方法,直接转化为string类型返回
func (r *redisCli) Get(key string, arg ...*transaction) (res string) {
arr := r.GetByte(key, arg...)
res = string(arr)
return
}
// GetByte 获取原始二进制结果
func (r *redisCli) GetByte(key string, arg ...*transaction) (res []byte) {
command := newGetCommand(key, -1)
var t *transaction = nil
if len(arg) > 0 {
t = arg[0]
}
res = r.redis.GetByteArr(command, t)
return
}
// Set 设置任意值,结构体除外
func (r *redisCli) Set(key string, value interface{}, arg ...*transaction) (success bool) {
cmd := key + " " + convertInterfaceToString(value)
command := newSetCommand(cmd, -1)
command.AddArgs(value)
var t *transaction = nil
if len(arg) > 0 {
t = arg[0]
}
success = r.redis.Set(command, t)
return
}
// SetWithStruct 直接设置一个结构体对象
func (r *redisCli) SetWithStruct(key string, value interface{}, arg ...*transaction) (success bool) {
v, err := msgpack.Marshal(value)
if err != nil {
fmt.Printf("不能将[]byte转为%T\n", v)
fmt.Println(err)
return
}
cmd := key + " " + convertInterfaceToString(v)
command := newSetCommand(cmd, -1)
command.AddArgs(value)
var t *transaction = nil
if len(arg) > 0 {
t = arg[0]
}
success = r.redis.Set(command, t)
return
}
// GetWithStruct 获取结构体对象,获取的对象就在传进来的对象中
func (r *redisCli) GetWithStruct(key string, vt interface{}, arg ...*transaction) {
arr := r.GetByte(key, arg...)
err := msgpack.Unmarshal(arr, vt)
if err != nil {
fmt.Printf("不能将[]byte转为%T\n", vt)
}
return
}
// AddPreInterceptor 添加前置拦截器
func (r *redisCli) AddPreInterceptor(interceptor func(redisContext *InterceptorContext)) {
r.redis.preInterceptors = append(r.redis.preInterceptors, interceptor)
}
// AddPostInterceptor 添加后置拦截器
func (r *redisCli) AddPostInterceptor(interceptor func(redisContext *InterceptorContext)) {
r.redis.postInterceptors = append(r.redis.postInterceptors, interceptor)
}
// CreateTransaction 创建一个事务对象
func (r *redisCli) CreateTransaction() *transaction {
t := transaction{conn: nil, status: 0}
return &t
}
// DoTransaction 事务执行
func (r *redisCli) DoTransaction(t *transaction, f func() (err error)) {
// 开启事务
t.conn = r.redis.getRedisConn()
defer r.redis.putRedisConn(t.conn)
t.Open()
err := f()
if err != nil {
// 回滚
t.status = 2
fmt.Println(err, "事务回滚")
t.Rollback()
}
// 提交事务
t.status = 1
t.Close()
}
package redis
import (
"time"
)
type Config struct {
// redis连接地址
IpAddr string
// 端口
Port int
// 网络连接类型,默认Tcp
NetType string
//redis-server 协议类型,默认协议为3
Protocol int
//用户名
Username string
//密码
Password string
//数据库的索引
DataBase int
// 最大失败尝试次数
MaxRetries int
//连接超时时间
ConnectionTimeOut time.Duration
//最大读取时间
ReadTimeOut time.Duration
// 最大写入时间
WriteTimeout time.Duration
//连接池大小
PoolSize int
// 最大空闲连接数量
MaxIdleConn int
// 最大活动连接数量
MaxActiveConn int
// 是否开启Debug模式,开启日志打印,默认为false
Debug bool
}
func FastConfig(addr string, port int, password string) *Config {
return &Config{
IpAddr: addr,
Port: port,
Password: password,
NetType: "tcp",
DataBase: 16,
Debug: true,
}
}
package redis
import (
"time"
)
const (
get = iota
set
setex
sel
ping
del
flushdb
flushall
keys
expire
)
const (
Get = "GET "
END = "\r\n"
Set = "SET "
Multl = "MULTI"
Exec = "EXEC"
Discard = "DISCARD"
AUTH = "auth"
)
// RedisCommand 表示一个Redis命令
type RedisCommand struct {
// 命令名
cmd string
// 参数,依次为key,value……等
args []interface{}
// 超时时间
timeout time.Duration
// 返回结果
res []byte
// 命令种类
cmdType int
// 附加参数
}
func (r *RedisCommand) Res() []byte {
return r.res
}
func (r *RedisCommand) SetRes(res []byte) {
r.res = res
}
func (r *RedisCommand) Args() []interface{} {
return r.args
}
func (r *RedisCommand) SetArgs(args []interface{}) {
r.args = args
}
func (r *RedisCommand) CmdType() int {
return r.cmdType
}
func (r *RedisCommand) AddArgs(item interface{}) {
r.args = append(r.args, item)
}
func (r *RedisCommand) SetCmdType(cmdType int) {
r.cmdType = cmdType
}
func (r *RedisCommand) Cmd() string {
return r.cmd
}
func (r *RedisCommand) SetCmd(cmd string) {
r.cmd = cmd
}
func (r *RedisCommand) Timeout() time.Duration {
return r.timeout
}
func (r *RedisCommand) SetTimeout(timeout time.Duration) {
r.timeout = timeout
}
func newRedisCommand(cmd string, timeout time.Duration, res []byte, cmdType int) *RedisCommand {
return &RedisCommand{cmd: cmd, timeout: timeout, res: res, cmdType: cmdType}
}
func newGetCommand(cmd string, timeout time.Duration) *RedisCommand {
cmd = Get + cmd + END
return newRedisCommand(cmd, timeout, nil, get)
}
func newSetCommand(cmd string, timeout time.Duration) *RedisCommand {
cmd = Set + cmd + END
return newRedisCommand(cmd, timeout, nil, set)
}
package redis
import (
"bufio"
"fmt"
"net"
"strings"
)
type RedisConn struct {
// 连接对象
conn net.Conn
//是否开启事务
}
// NewRedisConn 创建一个新的连接,使用Golang的net包
func NewRedisConn(netType string, addr string) (*RedisConn, error) {
conn, err := net.Dial(netType, addr)
if err != nil {
fmt.Println(err)
return nil, err
}
return &RedisConn{conn: conn}, err
}
func (r *RedisConn) Auth(password string) bool {
authCommand := AUTH + " " + password + END
_, err := r.conn.Write([]byte(authCommand))
if err != nil {
fmt.Println("Redis权限认证错误:", err)
return false
}
reader := bufio.NewReader(r.conn)
response, err := reader.ReadString('\n')
if err != nil {
fmt.Println("认证时读取Redis返回值出错:", err)
return false
}
// 如果认证成功,可以继续发送其他命令进行操作
if response != "+OK\r\n" {
fmt.Println("Redis 认证失败.")
return false
}
return true
}
func (r *RedisConn) CommandNoResult(cmd string) (success bool) {
_, err := r.conn.Write([]byte(cmd))
if err != nil {
fmt.Println("执行命令失败", err)
return false
}
return true
}
func (r *RedisConn) CommandGetResult(cmd string) (result []byte, success bool) {
_, err := r.conn.Write([]byte(cmd))
success = true
if err != nil {
fmt.Println("执行命令失败", err)
success = false
return
}
if success {
reader := bufio.NewReader(r.conn)
respLenBytes, err := reader.ReadBytes('\n')
if err != nil {
fmt.Println("无法读取响应长度:", err)
return
}
// 解析响应长度
respLenStr := strings.TrimPrefix(string(respLenBytes), "$")
respLen := 0
_, err = fmt.Sscanf(respLenStr, "%d", &respLen)
if err != nil {
return nil, false
}
// 读取响应内容
for i := 0; i < respLen; i++ {
responseByte, err := reader.ReadByte()
if err != nil {
fmt.Println("无法读取响应内容:", err)
return
}
result = append(result, responseByte)
}
// 读取响应结束符
_, err = reader.ReadBytes('\n')
if err != nil {
fmt.Println("无法读取响应结束符:", err)
return
}
}
return
}
// BeginTransaction 开启事务
func (r *RedisConn) BeginTransaction() bool {
_, f := r.CommandGetResult("MULTI" + END)
return f
}
// EndTransaction 结束事务
func (r *RedisConn) EndTransaction() bool {
_, f := r.CommandGetResult("EXEC" + END)
return f
}
// Rollback 添加一个 Rollback 方法来执行回滚操作
func (r *RedisConn) Rollback() bool {
_, f := r.CommandGetResult("DISCARD" + END) // 发送 DISCARD 命令来取消事务
return f
}
package redis
import (
"fmt"
"strconv"
"sync"
)
type ConnPool struct {
pool chan *RedisConn
maxOpenConns int
maxIdleConns int
mu sync.Mutex
factory func() *RedisConn
}
func (p *ConnPool) SetFactory(factory func() *RedisConn) {
p.factory = factory
}
func (p *ConnPool) MaxOpenConns() int {
return p.maxOpenConns
}
func (p *ConnPool) SetMaxOpenConns(maxOpenConns int) {
p.maxOpenConns = maxOpenConns
}
func (p *ConnPool) MaxIdleConns() int {
return p.maxIdleConns
}
func (p *ConnPool) SetMaxIdleConns(maxIdleConns int) {
p.maxIdleConns = maxIdleConns
}
// DefaultFactory 默认创建连接工厂
func DefaultFactory(config *Config) *RedisConn {
dst := config.IpAddr + ":" + strconv.Itoa(config.Port)
conn, err := NewRedisConn(config.NetType, dst)
if err != nil {
return nil
}
return conn
}
func NewConnPool(config *Config, factory func() *RedisConn, initialOpenConns, maxOpenConns, maxIdleConns int) *ConnPool {
// 创建一个连接池
pool := &ConnPool{
pool: make(chan *RedisConn, maxOpenConns),
maxOpenConns: maxOpenConns,
maxIdleConns: maxIdleConns,
factory: factory,
}
df := factory == nil
// 初始化若干个连接对象
for i := 0; i < initialOpenConns; i++ {
// 使用工厂方法对连接进行创建,如果没有传入,使用默认实现的工厂对象
var conn *RedisConn
if df {
conn = DefaultFactory(config)
// 建立连接对象失败
if conn == nil {
fmt.Println("创建连接对象失败!")
break
}
} else {
//TODO 执行自定义的连接创建工厂
}
// 直到创建完成initialOpenConns个连接对象才停止循环
if conn.Auth(config.Password) {
pool.Put(conn)
}
}
return pool
}
func (p *ConnPool) Get() *RedisConn {
select {
case item := <-p.pool:
return item
default:
p.mu.Lock()
defer p.mu.Unlock()
if len(p.pool) >= p.maxOpenConns {
return nil
}
conn := p.factory()
return conn
}
}
func (p *ConnPool) Put(item *RedisConn) {
select {
case p.pool <- item:
if len(p.pool) > p.maxIdleConns {
item.conn.Close()
}
default:
item.conn.Close()
}
}