请求合并
package req_merge
import (
"errors"
"fmt"
"runtime"
"sync"
)
// Group merges concurrent requests that share a key: while a call for a key
// is in flight, later callers with the same key join it instead of running
// fn again. The zero value is ready to use; the map is created lazily by
// Do and DoChan.
type Group struct {
	// The embedded mutex guards m and the bookkeeping fields (dups, chans,
	// released) of every call stored in it.
	sync.Mutex

	// m maps each key to the call currently registered for it.
	m map[string]*call
}
// call tracks one execution of fn for a key together with everyone waiting
// on its outcome.
type call struct {
	// wg is marked done by doCall once fn has finished; callers that joined
	// through Do block on it.
	wg sync.WaitGroup

	// val and err hold fn's results. They are written before wg.Done and
	// are read by joined Do callers only after wg.Wait returns.
	val interface{}
	err error

	// released is set by Release (under the Group lock) after it removed this
	// call's key from the map; the key may since map to a newer call.
	released bool

	// dups counts callers that joined this call instead of starting their own.
	dups int

	// chans holds one channel per DoChan caller; each receives the Result.
	chans []chan<- Result
}
// Result is the outcome of a merged call as delivered on a DoChan channel.
type Result struct {
	// Val is the value returned by fn.
	Val interface{}
	// Err is the error returned by fn.
	Err error
	// Dups is how many extra callers joined the call while it was in flight.
	Dups int
}
// panicError records a value recovered from a panic inside fn. The panic is
// re-raised later, from doCall's deferred function or from a fresh goroutine,
// and by then the original panic site is no longer on the stack. The message
// therefore embeds the stack trace captured when the panic was recovered.
type panicError struct {
	err string
}

// Error returns the panic message followed by the captured stack trace.
func (p panicError) Error() string {
	return p.err
}

// panicE supplies the fixed prefix used by every panicError message.
var panicE = panicError{err: "panic:程序运行错误"}

// newPanicError wraps the recovered value err in a panicError. It appends the
// calling goroutine's stack, captured at the moment newPanicError is called.
// doCall calls it from its recovering deferred function, where the panicking
// frames of fn are still on the stack.
func newPanicError(err interface{}) panicError {
	// Grow the buffer until the whole trace fits; runtime.Stack truncates
	// silently when the buffer is too small.
	stack := make([]byte, 1024)
	for {
		n := runtime.Stack(stack, false)
		if n < len(stack) {
			stack = stack[:n]
			break
		}
		stack = make([]byte, 2*len(stack))
	}
	return panicError{err: fmt.Sprintf("%v: %v\n\n%s", panicE.Error(), err, stack)}
}
// runtimeError is stored in call.err when fn neither returned nor had a
// recovered panic (for example it called runtime.Goexit). Callers that joined
// through Do react to it by calling runtime.Goexit themselves.
type runtimeError struct {
	err string
}

// Error implements the error interface.
func (r runtimeError) Error() string { return r.err }

// runtimeE supplies the fixed prefix used by every runtimeError message.
var runtimeE = runtimeError{err: "runtime: 运行时异常"}

// newRuntimeError builds a runtimeError whose message is runtimeE's prefix
// followed by err formatted with %v.
func newRuntimeError(err interface{}) runtimeError {
	msg := fmt.Sprintf("%v: %v", runtimeE.Error(), err)
	return runtimeError{err: msg}
}
// Do executes fn for key, letting only one execution per key run at a time.
// A caller that arrives while an execution for the same key is registered
// waits for it and receives the same val and err.
//
// shared reports whether the result went to more than one caller. It is
// always true for a joining caller, and true for the executing caller when
// anyone joined.
//
// If the shared execution panicked, every joining caller re-panics with the
// stored panicError. If fn exited without returning or panicking (recorded as
// a runtimeError), joining callers call runtime.Goexit.
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) {
	g.Lock()
	if g.m == nil {
		g.m = make(map[string]*call)
	}

	c, inFlight := g.m[key]
	if !inFlight {
		// First caller for this key: register the call and run fn here.
		c = new(call)
		c.wg.Add(1)
		g.m[key] = c
		g.Unlock()

		g.doCall(c, key, fn)
		return c.val, c.err, c.dups > 0
	}

	// Someone is already running fn for this key: join it and wait.
	c.dups++
	g.Unlock()
	c.wg.Wait()

	switch c.err.(type) {
	case panicError:
		panic(c.err)
	case runtimeError:
		runtime.Goexit()
	}
	return c.val, c.err, true
}
// doCall runs fn on behalf of the call c registered under key and then
// publishes the outcome. It distinguishes three ways fn can finish:
//   - normal return: c.val/c.err hold fn's results and a Result is sent on
//     every DoChan channel in c.chans;
//   - panic: the panic is recovered and stored in c.err as a panicError,
//     then re-raised by the deferred function below;
//   - neither (fn's goroutine exits some other way, e.g. runtime.Goexit):
//     c.err is set to a runtimeError and nothing is sent.
// In every case c.wg is marked done so Do callers waiting on it wake up, and
// key is removed from g.m unless Release already removed it.
func (g *Group)doCall(c *call,key string,fn func()(interface{},error)) {
// normalReturn becomes true only after fn has returned.
normalReturn := false
// recovered becomes true when the inner closure returned although fn did
// not, i.e. fn panicked and the panic was recovered there.
recovered := false
defer func() {
// fn neither returned nor had its panic recovered: the goroutine is
// exiting by other means (such as runtime.Goexit).
if !normalReturn && !recovered{
c.err = newRuntimeError(errors.New("退出"))
}
// Wake Do callers blocked in c.wg.Wait; c.val/c.err are final here.
// NOTE(review): this runs before g.Lock, so a Do caller that still finds
// c in g.m during this window joins a call that has already finished.
c.wg.Done()
g.Lock()
defer g.Unlock()
// If Release ran, key was already deleted and may now map to a newer call.
if !c.released{
delete(g.m,key)
}
switch c.err.(type) {
case panicError:
// DoChan receivers cannot observe a panic through their channel, so
// the panic is raised in a fresh goroutine (where nothing can recover
// it) and this goroutine blocks forever. g's lock stays held because
// the deferred Unlock never runs. NOTE(review): blocking is presumably
// meant to keep this goroutine visible in the crash dump.
if len(c.chans)>0{
go func() {panic(c.err)}()
select {}
}else{
// No DoChan receivers: re-panic so the panic propagates to doCall's
// caller.
panic(c.err)
}
case runtimeError:
// The goroutine is already exiting; DoChan receivers get no Result.
default:
// Each channel was created by DoChan with capacity 1 and receives
// exactly one send.
for _, ch := range c.chans {
ch <- Result{
Val: c.val,
Err: c.err,
Dups: c.dups,
}
}
}
}()
// Run fn inside a nested closure: a panic from fn is recovered by the
// closure's defer, whereas runtime.Goexit (which recover does not stop)
// unwinds straight through to the outer deferred function above.
func(){
defer func() {
if !normalReturn{
if r := recover();r != nil{
c.err = newPanicError(r)
}
}
}()
c.val,c.err = fn()
normalReturn = true
}()
// Reaching this point without normalReturn means the closure swallowed a
// panic from fn.
if !normalReturn{
recovered = true
}
}
// DoChan is like Do but does not block. It returns a channel (capacity 1)
// that receives the Result when the shared execution returns normally.
// When a new call is registered, fn runs in its own goroutine.
//
// If fn panics, the panic is re-raised in a fresh goroutine rather than
// delivered on the channel. If fn exits without returning or panicking,
// nothing is ever sent.
func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
	result := make(chan Result, 1)

	g.Lock()
	if g.m == nil {
		g.m = make(map[string]*call)
	}

	if existing, ok := g.m[key]; ok {
		// Join the in-flight call; doCall will send on result when it ends.
		existing.dups++
		existing.chans = append(existing.chans, result)
		g.Unlock()
		return result
	}

	c := &call{chans: []chan<- Result{result}}
	c.wg.Add(1)
	g.m[key] = c
	g.Unlock()

	go g.doCall(c, key, fn)
	return result
}
// Release forgets key, so the next Do or DoChan for it registers a fresh
// call instead of joining the current one. Callers already attached to the
// current call still receive its outcome. The call is flagged as released,
// so its doCall will not delete whatever key maps to afterwards.
func (g *Group) Release(key string) {
	g.Lock()
	defer g.Unlock()

	c, ok := g.m[key]
	if !ok {
		return
	}
	c.released = true
	delete(g.m, key)
}
测试
package req_merge
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
// count is the simulated backend load shared by the tests. getArticle
// increments it and sleeps that many milliseconds, and the tests reset it
// from a timer. It is touched from many goroutines, so access it through
// sync/atomic.
var (
	count int32
)
// TestGroup_Do fires n concurrent getArticle calls without request merging,
// so every caller hits the simulated backend and latency grows with count.
// It logs the total elapsed time for comparison with TestGroup_Do2.
func TestGroup_Do(t *testing.T) {
	// Reset the simulated load after one second. StoreInt32 is used because
	// AddInt32(&count, -count) reads count non-atomically, which races with
	// the concurrent atomic increments in getArticle.
	time.AfterFunc(1*time.Second, func() {
		atomic.StoreInt32(&count, 0)
	})
	var (
		wg  sync.WaitGroup
		now = time.Now()
		n   = 1000
	)
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			res, err := getArticle(1)
			if err != nil || res != "article: 1" {
				// t.Errorf is safe from any goroutine, unlike panic, which
				// would abort the whole test binary.
				t.Errorf("getArticle(1) = (%q, %v), want (%q, nil)", res, err, "article: 1")
			}
		}()
	}
	wg.Wait()
	t.Logf("同时发起 %d 次请求,耗时: %s", n, time.Since(now))
}
// TestGroup_Do2 fires the same n concurrent requests as TestGroup_Do, but
// routes them through a Group. Callers that overlap in time share a single
// getArticle execution. It logs the total elapsed time for comparison.
func TestGroup_Do2(t *testing.T) {
	// Reset the simulated load after one second. StoreInt32 avoids the racy
	// plain read of count in AddInt32(&count, -count).
	time.AfterFunc(1*time.Second, func() {
		atomic.StoreInt32(&count, 0)
	})
	var (
		wg  sync.WaitGroup
		now = time.Now()
		n   = 1000
		sg  = &Group{}
	)
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			res, err := reqMergeGetArticle(sg, 1)
			if err != nil || res != "article: 1" {
				// t.Errorf is safe from any goroutine; panic would abort the binary.
				t.Errorf("reqMergeGetArticle(1) = (%q, %v), want (%q, nil)", res, err, "article: 1")
			}
		}()
	}
	wg.Wait()
	t.Logf("同时发起 %d 次请求,耗时: %s", n, time.Since(now))
}
// getArticle simulates a backend lookup whose latency grows with load. Each
// call increments the shared count and then sleeps that many milliseconds.
// It always returns "article: <id>" and a nil error.
func getArticle(id int) (article string, err error) {
	// Use the value returned by AddInt32 instead of re-reading count: a plain
	// read would race with atomic updates from other goroutines.
	load := atomic.AddInt32(&count, 1)
	time.Sleep(time.Duration(load) * time.Millisecond)
	return fmt.Sprintf("article: %d", id), nil
}
// reqMergeGetArticle fetches article id through sg, so concurrent requests
// for the same id share one getArticle execution.
func reqMergeGetArticle(sg *Group, id int) (string, error) {
	v, err, _ := sg.Do(fmt.Sprintf("%d", id), func() (interface{}, error) {
		return getArticle(id)
	})
	// Comma-ok: a plain v.(string) would panic if fn ever returned a nil value
	// (typically alongside an error).
	article, _ := v.(string)
	return article, err
}
// reqMergeGetArticleHang issues a merged request through Do whose fn never
// returns, so this function (and every caller sharing the key) blocks forever.
func reqMergeGetArticleHang(sg *Group, id int) (string, error) {
	v, err, _ := sg.Do(fmt.Sprintf("%d", id), func() (interface{}, error) {
		// Simulate a backend call that never completes. An empty select is a
		// terminating statement, so no (unreachable) return is needed.
		select {}
	})
	article, _ := v.(string)
	return article, err
}
// TestGroup_DoHang shows that Do gives callers no way to stop waiting: when
// fn never returns, every caller sharing the key blocks with it.
//
// Waiting for the callers would hang until go test's -timeout kills the
// binary. Instead, the test observes them for a bounded window and fails if
// any returned. The blocked goroutines are deliberately leaked.
func TestGroup_DoHang(t *testing.T) {
	time.AfterFunc(1*time.Second, func() {
		atomic.StoreInt32(&count, 0)
	})
	var (
		returned int32
		now      = time.Now()
		n        = 1000
		sg       = &Group{}
		observe  = 200 * time.Millisecond
	)
	for i := 0; i < n; i++ {
		go func() {
			reqMergeGetArticleHang(sg, 1)
			atomic.AddInt32(&returned, 1)
		}()
	}
	time.Sleep(observe)
	if r := atomic.LoadInt32(&returned); r != 0 {
		t.Fatalf("%d of %d calls returned although fn never returns", r, n)
	}
	t.Logf("同时发起 %d 次请求, %s 后仍全部阻塞", n, time.Since(now))
}
// reqMergeGetArticleNoHang issues a merged request through DoChan and waits
// for either its Result or the end of ctx, so the caller can give up on a fn
// that never returns. fn deliberately blocks forever. When ctx ends, this
// returns "超时" together with ctx.Err().
func reqMergeGetArticleNoHang(ctx context.Context, sg *Group, id int) (string, error) {
	result := sg.DoChan(fmt.Sprintf("%d", id), func() (interface{}, error) {
		// Simulate a backend call that never completes. An empty select is a
		// terminating statement, so no (unreachable) return is needed.
		select {}
	})
	select {
	case r := <-result:
		article, _ := r.Val.(string)
		return article, r.Err
	case <-ctx.Done():
		return "超时", ctx.Err()
	}
}
// TestGroup_DoNoHang2 checks that DoChan plus a context lets callers escape a
// fn that never returns. Every caller must come back with the timeout
// sentinel once the context expires. The single goroutine running fn is
// leaked by design.
func TestGroup_DoNoHang2(t *testing.T) {
	time.AfterFunc(1*time.Second, func() {
		atomic.StoreInt32(&count, 0)
	})
	var (
		wg  sync.WaitGroup
		now = time.Now()
		n   = 1000
		sg  = &Group{}
	)
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			res, err := reqMergeGetArticleNoHang(ctx, sg, 1)
			// fn never returns, so the only valid outcome is the timeout.
			if res != "超时" || err != context.DeadlineExceeded {
				t.Errorf("got (%q, %v), want (%q, %v)", res, err, "超时", context.DeadlineExceeded)
			}
		}()
	}
	wg.Wait()
	t.Logf("同时发起 %d 次请求,耗时: %s", n, time.Since(now))
}
// TestName starts a goroutine that sleeps for 100ms and returns immediately
// without waiting for it; it makes no assertions.
func TestName(t *testing.T) {
	go time.Sleep(100 * time.Millisecond)
}