go 无锁(乐观锁)队列的实现

1 篇文章 0 订阅

采用轻量级的CAS(Compare and swap)乐观锁,实现安全的队列,代码如下:

/**
 * 线程安全的队列,使用轻量级的 CAS 锁
 */
package cas_queue

import (
	"errors"
	"fmt"
	"runtime"
	"sync/atomic"
)

type casCache struct {
	putNo uint32
	getNo uint32
	value interface{}
}

// lock free queue
type CASQueue struct {
	capacity uint32
	capMod    uint32
	putPos    uint32
	getPos    uint32
	cache     []casCache
}

func NewCASQueue(capacity uint32) *CASQueue {
	q := new(CASQueue)
	q.capacity = minQuantity(capacity)
	q.capMod = q.capacity - 1
	q.putPos = 0
	q.getPos = 0
	q.cache = make([]casCache, q.capacity)
	for i := range q.cache {
		cache := &q.cache[i]
		cache.getNo = uint32(i)
		cache.putNo = uint32(i)
	}
	cache := &q.cache[0]
	cache.getNo = q.capacity
	cache.putNo = q.capacity
	return q
}

func (q *CASQueue) String() string {
	getPos := atomic.LoadUint32(&q.getPos)
	putPos := atomic.LoadUint32(&q.putPos)
	return fmt.Sprintf("Queue{capacity: %v, capMod: %v, putPos: %v, getPos: %v}",
		q.capacity, q.capMod, putPos, getPos)
}

func (q *CASQueue) getCapacity() uint32 {
	return q.capacity
}

/**
 * 获取当前元素个数
 */
func (q *CASQueue) Quantity() uint32 {
	var putPos, getPos uint32
	var quantity uint32
	getPos = atomic.LoadUint32(&q.getPos)
	putPos = atomic.LoadUint32(&q.putPos)

	if putPos >= getPos {
		quantity = putPos - getPos
	} else {
		quantity = q.capMod + (putPos - getPos)
	}

	return quantity
}

/**
 * put queue functions
 * ok: 如果成功添加,则为true,反之同一时间有多个线程put导致写入失败或者队列长度不够,则返回false
 * quantity: 返回代表队列的元素个数,如果大于等于capMod - 1则表示空间满了
 */
func (q *CASQueue) putMayFail(val interface{}) (ok bool, quantity uint32) {
	var putPos, putPosNew, getPos, posCnt uint32
	var cache *casCache
	capMod := q.capMod

	getPos = atomic.LoadUint32(&q.getPos)
	putPos = atomic.LoadUint32(&q.putPos)

	if putPos >= getPos {
		posCnt = putPos - getPos
	} else {
		posCnt = capMod + (putPos - getPos)
	}

	// 空间不足
	if posCnt >= capMod-1 {
		runtime.Gosched()
		return false, posCnt
	}

	putPosNew = putPos + 1
	if !atomic.CompareAndSwapUint32(&q.putPos, putPos, putPosNew) {
		runtime.Gosched()
		return false, posCnt
	}

	cache = &q.cache[putPosNew & capMod]

	for {
		getNo := atomic.LoadUint32(&cache.getNo)
		putNo := atomic.LoadUint32(&cache.putNo)
		if putPosNew == putNo && getNo == putNo {
			cache.value = val
			atomic.AddUint32(&cache.putNo, q.capacity)
			return true, posCnt + 1
		} else {
			runtime.Gosched()
		}
	}
}

/**
 * 添加一个元素到队列,如果队列满了则报错
 */
func (q *CASQueue) Put(val interface{}) error {
	var ok bool
	var quantity uint32
	for !ok { // 写入失败,没拿到CAS锁,则继续写入
		ok, quantity = q.putMayFail(val)
		// 队列长度不够了,则直接返回错误
		if quantity >= q.capMod - 1 {
			errMsg := fmt.Sprintf("queue almost overflow, the capacity is [%d], now the quantity is [%d]", q.capacity, quantity)
			return errors.New(errMsg)
		}
	}
	return nil
}

/**
 * 获取队列中的数据
 * ok: 获取成功为 true,否则false
 * quantity: 当前剩下的数据量,为0且ok为false则说明没有数据可读了
 */
func (q *CASQueue) getMayFail() (val interface{}, ok bool, quantity uint32) {
	var putPos, getPos, getPosNew, posCnt uint32
	var cache *casCache
	capMod := q.capMod

	putPos = atomic.LoadUint32(&q.putPos)
	getPos = atomic.LoadUint32(&q.getPos)

	if putPos >= getPos {
		posCnt = putPos - getPos
	} else {
		posCnt = capMod + (putPos - getPos)
	}

	if posCnt < 1 {
		runtime.Gosched()
		return nil, false, posCnt
	}

	getPosNew = getPos + 1
	if !atomic.CompareAndSwapUint32(&q.getPos, getPos, getPosNew) {
		runtime.Gosched()
		return nil, false, posCnt
	}

	cache = &q.cache[getPosNew&capMod]

	for {
		getNo := atomic.LoadUint32(&cache.getNo)
		putNo := atomic.LoadUint32(&cache.putNo)
		if getPosNew == getNo && getNo == putNo-q.capacity {
			val = cache.value
			cache.value = nil
			atomic.AddUint32(&cache.getNo, q.capacity)
			return val, true, posCnt - 1
		} else {
			runtime.Gosched()
		}
	}
}

/**
 * 获取队列中的数据
 * emptyFlag: false 则此次没获取到数据,原因是队列为空,true则获取到了数据
 */
func (q *CASQueue) Get() (val interface{}, emptyFlag bool) {
	var ok bool
	var quantity uint32
	var v interface{}
	for !ok { // 写入失败,没拿到CAS锁,则继续写入
		v, ok, quantity = q.getMayFail()
		// 队列为空
		if quantity == 0 && !ok {
			return nil, false
		}
	}
	return v, true
}

// round 到最近的2的倍数
func minQuantity(v uint32) uint32 {
	v--
	v |= v >> 1
	v |= v >> 2
	v |= v >> 4
	v |= v >> 8
	v |= v >> 16
	v++
	return v
}

测试方法:

package cas_queue

import (
	"fmt"
	"runtime"
	"testing"
	"time"
)

func addFuncSucc(q *CASQueue, prefix int, l int) {
	for i := 0; i < l; i++ {
		v := fmt.Sprintf("%d---%d", prefix, i)
		err := q.Put(v)
		if err != nil {
			fmt.Println(err)
		}
		fmt.Printf("%v ", v)
	}
	fmt.Println()
}

func getFuncSucc(q *CASQueue, l int) {
	for i := 0; i < l; i++ {
		_, flag := q.Get()
		if !flag {
			fmt.Printf("get fail, the queue is empty\n")
		}
	}
}

// 测试并发时,查看是否做到了线程安全
func TestPutManayTimesSucc(t *testing.T) {
	runtime.GOMAXPROCS(runtime.NumCPU())

	q := NewCASQueue(10000)
	l := 1000
	for i := 0; i < 10; i++ {
		prefix := i
		go func() {
			fmt.Printf("truely prefix [%d]\n", prefix)
			addFuncSucc(q, prefix, l)
		}()
	}
	time.Sleep(1 * time.Second)
	fmt.Println("going")
	time.Sleep(1 * time.Second)
	fmt.Println(q.Quantity())

	for i := 0; i < 10; i++ {
		go func() {
			getFuncSucc(q, l)
		}()
	}

	time.Sleep(1 * time.Second)
	fmt.Println("going")
	time.Sleep(1 * time.Second)
	fmt.Println(q.Quantity())
}
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值