标准库提供的 sync.Cond 只支持 Wait()、Signal()、Broadcast() 三个方法,功能比较简单;实际使用中有时需要带超时的等待或可中断的等待,因此这里基于 sync.Locker 和 channel 实现一个扩展版的 Condition,支持以下方法:
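Await() //等待,直到被 Signal/SignalAll 唤醒
AwaitWithTimeOut(timeout time.Duration) bool //带超时的等待:被唤醒返回 true,超时返回 false
AwaitNanos(nanos time.Duration) time.Duration //带超时的等待,并返回剩余时间
Signal() //唤醒一个等待者
SignalAll() //唤醒所有等待者
Interrupt() //中断等待:被中断的等待者会 panic("Interrupt exception")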
实现源码:
package common
import (
"sync"
"time"
)
// Condition is a condition variable bound to the Locker L. In addition to the
// Signal/SignalAll operations of sync.Cond it supports waiting with a timeout
// and interrupting all waiters.
//
// Every Await* method must be called with L held. The caller is registered as
// a waiter before L is released, so a Signal, SignalAll or Interrupt issued by
// a goroutine that acquires L after the waiter entered Await* cannot be lost.
// L is re-acquired before any Await* method returns or panics.
//
// A Condition must not be copied after first use; create it with NewCondition.
type Condition struct {
	// L is the lock held by callers of the Await* methods.
	L sync.Locker

	// mu guards waiters.
	mu sync.Mutex

	// waiters is the FIFO queue of blocked Await* callers. Each channel has
	// capacity 1 and receives exactly one value: true when signaled, false
	// when interrupted.
	waiters []chan bool
}

// NewCondition returns a Condition whose Await* methods release and
// re-acquire locker.
func NewCondition(locker sync.Locker) *Condition {
	return &Condition{L: locker}
}

// Await registers the caller as a waiter (while c.L is still held), releases
// c.L and blocks until Signal or SignalAll wakes it. c.L is re-acquired before
// Await returns. If Interrupt is called while waiting, Await re-acquires c.L
// and panics with "Interrupt exception".
func (c *Condition) Await() {
	ch := c.enqueue()
	c.L.Unlock()
	defer c.L.Lock()
	c.park(ch, nil)
}

// AwaitWithTimeOut is like Await but gives up after timeout. It reports true
// if the caller was woken by Signal/SignalAll and false if it timed out. c.L
// is re-acquired before it returns; Interrupt makes it panic with
// "Interrupt exception".
func (c *Condition) AwaitWithTimeOut(timeout time.Duration) bool {
	ch := c.enqueue()
	c.L.Unlock()
	defer c.L.Lock()
	timer := time.NewTimer(timeout)
	defer timer.Stop() // release the timer promptly when woken early
	return c.park(ch, timer.C)
}

// AwaitNanos is like AwaitWithTimeOut but returns an estimate of how much of
// nanos is left when it returns (after c.L has been re-acquired). A result
// <= 0 means the deadline has passed. The estimate uses the monotonic clock,
// so wall-clock adjustments do not affect it.
func (c *Condition) AwaitNanos(nanos time.Duration) time.Duration {
	deadline := time.Now().Add(nanos)
	c.AwaitWithTimeOut(nanos)
	return time.Until(deadline)
}

// Signal wakes the longest-waiting goroutine, if any, and returns immediately
// when there is none. It may be called with or without c.L held.
func (c *Condition) Signal() {
	c.notify(true, false)
}

// SignalAll wakes every goroutine currently waiting on c. It may be called
// with or without c.L held.
func (c *Condition) SignalAll() {
	c.notify(true, true)
}

// Interrupt wakes every goroutine currently waiting on c. Each of them makes
// its Await* call panic with "Interrupt exception" after re-acquiring c.L.
func (c *Condition) Interrupt() {
	c.notify(false, true)
}

// enqueue appends a new waiter to the queue and returns its notification
// channel.
func (c *Condition) enqueue() chan bool {
	ch := make(chan bool, 1)
	c.mu.Lock()
	c.waiters = append(c.waiters, ch)
	c.mu.Unlock()
	return ch
}

// cancel removes ch from the queue. It reports false if a notifier had
// already dequeued ch, in which case ch already holds the notification.
func (c *Condition) cancel(ch chan bool) bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	for i, w := range c.waiters {
		if w == ch {
			c.waiters = append(c.waiters[:i], c.waiters[i+1:]...)
			return true
		}
	}
	return false
}

// park blocks the registered waiter ch until it is notified or timeout fires
// (a nil timeout never fires). It returns true when signaled and false on
// timeout, and panics with "Interrupt exception" when interrupted.
func (c *Condition) park(ch chan bool, timeout <-chan time.Time) bool {
	var signaled bool
	select {
	case signaled = <-ch:
	case <-timeout:
		if c.cancel(ch) {
			return false
		}
		// A notifier dequeued ch concurrently with the timeout and has
		// already delivered its value; consume it so the wake-up is not lost.
		signaled = <-ch
	}
	if !signaled {
		panic("Interrupt exception")
	}
	return true
}

// notify delivers signaled to the oldest waiter, or to every waiter when all
// is true, removing them from the queue.
func (c *Condition) notify(signaled, all bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	for len(c.waiters) > 0 {
		ch := c.waiters[0]
		c.waiters[0] = nil // drop the reference held by the backing array
		c.waiters = c.waiters[1:]
		// Capacity 1 and exactly one send per waiter: this never blocks.
		ch <- signaled
		if !all {
			return
		}
	}
}
测试代码(位于 package main;其中的 "./common" 是相对导入路径,在 Go modules 模式下需改为 "<模块路径>/common"):
import (
"fmt"
"./common"
"sync"
"time"
)
func main() {
	cond := common.NewCondition(new(sync.Mutex))
	var wg sync.WaitGroup

	// waiter runs body in a new goroutine while holding cond.L, which is how
	// the Await* methods are meant to be called.
	waiter := func(body func()) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			cond.L.Lock()
			defer cond.L.Unlock()
			body()
		}()
	}

	// Goroutine 1 waits with no deadline.
	waiter(func() {
		cond.Await()
		fmt.Println("1")
	})

	// Goroutine 2 gives up after 1s, before the SignalAll issued at ~3s.
	waiter(func() {
		fmt.Println(cond.AwaitWithTimeOut(1 * time.Second))
		fmt.Println("2")
	})

	// Goroutine 3 waits up to 5s, so the SignalAll at ~3s is expected to
	// wake it before its deadline.
	waiter(func() {
		fmt.Println(cond.AwaitWithTimeOut(5 * time.Second))
		fmt.Println("3")
	})

	time.Sleep(3 * time.Second)
	// Swap in cond.Signal() or cond.Interrupt() to try the other operations.
	cond.SignalAll()
	wg.Wait()
}
输出结果:
false
2
true
3
1
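可以看到:goroutine 2 超时,返回 false;goroutine 3 在超时前被 SignalAll 唤醒,返回 true;goroutine 1 未设置超时,一直等待到被 SignalAll 唤醒后正常退出。(goroutine 1 和 3 被同时唤醒后需要竞争锁,二者的输出顺序取决于调度,并不固定。)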