通过一道考题,来实现一个简单的并发安全的Map,以加深对并发编程和channel的理解。
要求实现一个map:
(1)面向高并发
(2)只存在插入和删除操作,时间复杂度为O(1)
(3)查询时,若Key存在,直接返回Val;若Key不存在,阻塞直到KV对被放入后,获取Val返回;
如果等待指定时长仍未放入,返回超时错误;
1. 先实现上面的1、2点
比较容易想到的就是使用一个Map和一把锁来实现并发安全,如果只考虑第1、2点的话,读写锁是一个比较好的选择。
先写一个MyConCurrentMap的类,封装读写锁和map:
type MyConCurrentMap struct {
sync.RWMutex // golang支持结构体嵌入,允许将一个结构体嵌入到另一个结构体,从而继承被嵌入类型的字段和方法
mp map[int]int
}
实现Get方法:
func (m *MyConCurrentMap) Get(k int) (int, error) {
m.RLock()
defer m.RUnlock()
val, ok := m.mp[k]
if ok {
return val, nil
}
return 0, errors.New("can't get val")
}
实现Put方法:
func (m *MyConCurrentMap) Put(k, v int) {
m.Lock()
defer m.Unlock()
m.mp[k] = v
}
2. 如何实现第3点要求
上面我们在没有第三点限制的前提下,实现了第1、2点要求,如果加了第3点,我们应该怎么实现呢?
最先想到可能是使用sync.Cond,在没有读到数据时使用cond.Wait()等待,等到put数据时使用cond.Signal()唤醒。但是这样使用会出现惊群效应,因为可能有很多goroutine在等待获取不同的key,使用cond.Signal会把他们全都唤醒。
所以我们考虑使用channel,将channel存在map里,每一个key都对应一个channel,这样就可以实现put时只向key对应的chanel中发信号,解除对应goroutine的阻塞。因为在Put中会创建chan,相当于插入操作,而我们之前用的是读写锁的读锁,不能保证写的并发安全,所以我们还是改用mutex。更改后的代码如下:
type MyConCurrentMap struct {
sync.Mutex // golang支持结构体嵌入,允许将一个结构体嵌入到另一个结构体,从而继承被嵌入类型的字段和方法
mp map[int]int
keytoCh map[int]chan struct{} // struct{}是一个空结构体,不占用空间,可以当信号使用
}
func (m *MyConCurrentMap) Put(k, v int) {
m.Lock()
defer m.Unlock()
m.mp[k] = v
ch, ok := m.keytoCh[k]
if !ok { // 说明没有在等待的goroutine
return
}
ch <- struct{}{} // 将信号传入唤醒等待的goroutine
}
func (m *MyConCurrentMap) Get(k int) (int, error) {
m.Lock()
val, ok := m.mp[k]
if ok {
m.Unlock()
return val, nil
}
m.keytoCh[k] = make(chan int)
m.Unlock()
<-ch //从k的chan中接收信号
m.Lock()
val = m.mp[k]
m.Unlock()
return val, nil
}
上面的代码中仍然有很多问题,例如,m.keytoCh[k]=make(chan int) 会一直覆盖之前的goroutine创建的ch,导致之前的goroutine无法被唤醒,所以我们需要在创建chan时判断chan是否被创建。同时因为同一时间只有一个goroutine可以接收到chan中的数据,我们只是往里面简单的发送一个v,只能唤醒一个goroutine。这个时候就可以用到chan的特性:当一个chan被关闭时,所有的读goroutine都会被唤醒,所有的写goroutine也会被唤醒,但是会panic。正好符合我们的需求,更新后的代码如下:
func (m *MyConCurrentMap) Put(k, v int) {
m.Lock()
defer m.Unlock()
m.mp[k] = v
ch, ok := m.keytoCh[k]
if !ok { // 说明没有在等待的goroutine
return
}
close(ch) // 将v传入唤醒等待的goroutine
}
func (m *MyConCurrentMap) Get(k int) (int, error) {
m.Lock()
val, ok := m.mp[k]
if ok {
m.Unlock()
return val, nil
}
ch, ok := m.keytoCh[k]
if !ok {
ch = make(chan struct{})
m.keytoCh[k] = ch
}
m.Unlock()
<-ch //从k的chan中接收信号
m.Lock()
val = m.mp[k]
m.Unlock()
return val, nil
}
我们还要考虑在put中重复关闭同一个chan的问题,可以利用chan的特性:从已关闭的chan中读数据不会阻塞。同时还要实现题目中的超时报错功能,代码如下:
func (m *MyConCurrentMap) Put(k, v int) {
m.Lock()
defer m.Unlock()
m.mp[k] = v
ch, ok := m.keytoCh[k]
if !ok { // 说明没有在等待的goroutine
return
}
select {
case <-ch: //如果channel被关闭,则会走到这个分支return
return
default:
close(ch) // 当一个chan被关闭时,所有的读goroutine都会被唤醒,所有的写goroutine也会被唤醒,但是会panic
}
}
func (m *MyConCurrentMap) Get(k int, timeout time.Duration) (int, error) {
m.Lock()
val, ok := m.mp[k]
if ok {
m.Unlock()
return val, nil
}
ch, ok := m.keytoCh[k]
if !ok {
ch = make(chan struct{})
m.keytoCh[k] = ch
}
m.Unlock()
select {
case <-time.After(timeout):
return -1, errors.New("timeout")
case <-ch: //从k的chan中接收信号
}
m.Lock()
val = m.mp[k]
m.Unlock()
return val, nil
}
最后我们再加上初始化的函数就大功告成了:
type MyConCurrentMap struct {
sync.Mutex // golang支持结构体嵌入,允许将一个结构体嵌入到另一个结构体,从而继承被嵌入类型的字段和方法
mp map[int]int
keytoCh map[int]chan struct{}
}
func NewMyConCurrentMap() *MyConCurrentMap {
return &MyConCurrentMap{
mp: make(map[int]int),
keytoCh: make(map[int]chan struct{}),
}
}
func (m *MyConCurrentMap) Put(k, v int) {
m.Lock()
defer m.Unlock()
m.mp[k] = v
ch, ok := m.keytoCh[k]
if !ok { // 说明没有在等待的goroutine
return
}
select {
case <-ch: //如果channel被关闭,则会走到这个分支return
return
default:
close(ch) // 当一个chan被关闭时,所有的读goroutine都会被唤醒,所有的写goroutine也会被唤醒,但是会panic
}
}
func (m *MyConCurrentMap) Get(k int, timeout time.Duration) (int, error) {
m.Lock()
val, ok := m.mp[k]
if ok {
m.Unlock()
return val, nil
}
ch, ok := m.keytoCh[k]
if !ok {
ch = make(chan struct{})
m.keytoCh[k] = ch
}
m.Unlock()
select {
case <-time.After(timeout):
return -1, errors.New("timeout")
case <-ch: //从k的chan中接收信号
}
m.Lock()
val = m.mp[k]
m.Unlock()
return val, nil
}
测试用例:
func TestMap_PutGet(t *testing.T) {
m := NewMyConCurrentMap()
// 测试写入后立即读取
m.Put(4, 40)
v, err := m.Get(4, 100*time.Millisecond)
if err != nil {
t.Errorf("Get failed: %v", err)
}
if v != 40 {
t.Errorf("Expected value 40, got %d", v)
}
m.Put(4, 20)
v, err = m.Get(4, 100*time.Millisecond)
if err != nil {
t.Errorf("Get failed: %v", err)
}
if v != 20 {
t.Errorf("Expected value 20, got %d", v)
}
}
func TestMap_timeoutGet(t *testing.T) {
m := NewMyConCurrentMap()
_, err := m.Get(2, 100*time.Millisecond)
if err != nil {
t.Error(err)
}
}
func TestMyMap_ConcurrentPutGet(t *testing.T) {
m := NewMyConCurrentMap()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
m.Put(i, i*10)
v, err := m.Get(i, 10*time.Second)
if err != nil {
t.Errorf("Get failed: %v", err)
}
if v != i*10 {
t.Errorf("Expected value %d, got %d", i*10, v)
}
}(i)
}
wg.Wait()
}
func TestMyMap_ConcurrentPutGet2(t *testing.T) {
m := NewMyConCurrentMap()
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
m.Put(i, i*10)
}
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
v, err := m.Get(i, 10*time.Second)
if err != nil {
t.Errorf("Get failed: %v", err)
}
if v != i*10 {
t.Errorf("Expected value %d, got %d", i*10, v)
}
}(i)
}
wg.Wait()
}