errgroup
是 Go 语言官方扩展库 x/sync
中的一个包,它提供了一种方式来并行运行多个 goroutine,并在所有 goroutine 都完成时返回第一个发生的错误(如果有的话)。这对于需要并行处理多个任务并等待它们全部完成,同时需要处理其中任何一个可能发生的错误的场景非常有用。
errgroup
是 Go 语言中用于管理多个 goroutine 的同步和错误处理的库。使用 errgroup
可以简化并发代码的编写,使得错误处理更加简洁和一致。
注意:goroutine任务需要有取消功能才能立即终止其它任务返回。
errgroup可以等待所有任务完成再返回,也可以等到第一个错误出现时终止其它任务,取决于业务逻辑。
它的主要作用包括:
- 同步: errgroup.Group 提供了一个 Wait 方法,这个方法会阻塞调用者,直到组内的所有goroutine都完成执行。
- 错误传播: errgroup 能确保第一个发生的错误会被立即传播给所有其他goroutine,这样可以避免在多个并发任务中检查每个任务的状态,简化错误处理逻辑。
- 取消上下文: errgroup 结合 context.Context 使用,可以在外部请求取消时,通知所有goroutine停止执行。返回第一个错误原因。
- 限制并发: 通过 SetLimit 方法,errgroup 可以限制同时运行的goroutine数量。配合done()方法。cancel()触发done()。
案例1
fetch其中一个报错了,其它的goroutine还在运行。errgroup返回的是最后一个错误。
package main
import (
"errors"
"fmt"
"golang.org/x/sync/errgroup"
"math/rand"
"time"
)
func fetch(url string) (string, error) {
randomNumber := rand.Intn(10) + 5
fmt.Println(randomNumber)
time.Sleep(time.Duration(randomNumber) * time.Second)
fmt.Println("fetch")
return url, errors.New("error happens")
}
func main() {
urls := []string{
"http://example.com",
"http://example.org",
"http://example.net",
}
rand.Seed(time.Now().UnixNano())
var eg errgroup.Group
for _, url := range urls {
eg.Go(func() error {
body, err := fetch(url)
if err != nil {
return err
}
fmt.Printf("Fetched %s: %s\n", url, body)
return nil
})
}
if err := eg.Wait(); err != nil {
fmt.Printf("Failed to fetch one or more URLs: %v\n", err)
}
}
输出:
14
7
9
fetch
fetch
fetch
Failed to fetch one or more URLs: error happens
这个并不是我们想要的结果,我们期望其中一个goroutine报错后,其它的任务终止。
案例2
当碰到错误会立即停止所有goroutine。
package main
import (
"context"
"errors"
"fmt"
"golang.org/x/sync/errgroup"
"strings"
"time"
)
func main() {
queryUrls := map[string]string{
"url1": "http://localhost/url1",
"url2": "http://localhost/url2",
"url3": "http://localhost/url3",
}
var results []string
ctx, cancel := context.WithCancel(context.Background())
eg, errCtx := errgroup.WithContext(ctx)
for _, url := range queryUrls {
url := url
eg.Go(func() error {
result, err := query(errCtx, url)
if err != nil {
//其实这里不用手动取消,看完源码就知道为啥了
cancel()
return err
}
results = append(results, fmt.Sprintf("url:%s -- ret: %v", url, result))
return nil
})
}
err := eg.Wait()
if err != nil {
fmt.Println("eg.Wait error:", err)
return
}
for k, v := range results {
fmt.Printf("%v ---> %v\n", k, v)
}
}
func query(errCtx context.Context, url string) (ret string, err error) {
fmt.Printf("请求 %s 开始....\n", url)
// 假设这里是发送请求,获取数据
if strings.Contains(url, "url2") {
// 假设请求 url2 时出现错误
time.Sleep(time.Second * 2)
return "", errors.New("请求出错")
} else if strings.Contains(url, "url3") {
// 假设 请求 url3 需要1秒
select {
case <-errCtx.Done():
ret, err = "", errors.New("请求3被取消")
fmt.Println("请求3被取消")
return
case <-time.After(time.Second * 3):
fmt.Printf("请求 %s 结束....\n", url)
return "success3", nil
}
} else {
select {
case <-errCtx.Done():
ret, err = "", errors.New("请求1被取消")
fmt.Println("请求1被取消")
return
case <-time.After(time.Second):
fmt.Printf("请求 %s 结束....\n", url)
return "success1", nil
}
}
}
输出:
请求 http://localhost/url2 开始....
请求 http://localhost/url3 开始....
请求 http://localhost/url1 开始....
请求 http://localhost/url1 结束....
请求3被取消
eg.Wait error: 请求出错
eg.Wait()
会阻塞,直到所有的 goroutine 都完成执行或者其中一个 goroutine 返回了错误。如果有错误发生,eg.Wait()
会返回第一个遇到的错误。
通过使用 errgroup
,我们可以更容易地管理多个 goroutine,并在其中一个 goroutine 发生错误时取消其他 goroutine。