go-errgroup使用

errgroup用于goroutine的同步,g.Go接收func() error函数作为参数,如果有一个goroutine返回error,则调用cancel函数取消context。因此一般接收的函数内部使用闭包,使用外部提供的context。

type Group struct {
	cancel func()    	// 用于取消context
	wg sync.WaitGroup   // 用于线程同步
	sem chan token		// 用于控制并发goroutine数量
	errOnce sync.Once   // 只cancel一次
	err     error		//  记录某个错误
}

// 其主要方法
func WithContext(ctx context.Context) (*Group, context.Context) { ... } // 创建Group
func (g *Group) Go(f func() error) { ... }  // 启动goroutine
func (g *Group) Wait() error       { ... }  // 等待所有线程结束
func (g *Group) SetLimit(n int)    { ... }  // 设置并发数量

Parallel(并发任务同步)

type Result string
type Search func(ctx context.Context, query string) (Result, error)

func myFunc(kind, query string) (Result, error) {
	// 模拟函数调用,web耗时2秒成功,image耗时4秒失败,video耗时6秒成功
	if kind == "web" {
		fmt.Println("web start")
		time.Sleep(2  * time.Second)
		fmt.Println("web end")
		return Result(fmt.Sprintf("%s result for %q", kind, query)), nil
	} else if kind == "image" {
		fmt.Println("image start")
		time.Sleep(4  * time.Second)
		fmt.Println("image end")
		return "", errors.New("image failed")
	} else {
		fmt.Println("video start")
		time.Sleep(6 * time.Second)
		fmt.Println("video end")
		return Result(fmt.Sprintf("%s result for %q", kind, query)), nil
	}
}
func fakeSearch(kind string) Search {
	return func(ctx context.Context, query string) (Result, error){
		done := make(chan Result, 1) // buf chan, 防止ctx直接返回时goroutine阻塞, 用来传递返回的Result
		errch := make(chan error, 1) // buf chan, 防止ctx直接返回时goroutine阻塞, 用来传递返回的err
		go func() {
			// 添加一个验证是否goroutine泄露的打印,如果有,说明每个goroutine都执行完了,不存在泄露
			defer func() {
				fmt.Printf("%s back goroutine end\n", kind)
			}()
			resp, err := myFunc(kind, query)
			if err != nil {
				errch <- err
				close(done)
			} else {
				done <- resp
			}
		}()
		select {
		case <- ctx.Done(): // 如果这里返回,会造成goroutine泄露,该怎么办?chan buf设置为1
			fmt.Printf("%s ctx.Done\n", kind)  // 由于image返回报错,导致kind为video时不等后台myFunc执行完就直接走到这里,
			return "", ctx.Err()
		case rsp, ok := <- done: // 使用done是否关闭来区分返回成功和失败
			if !ok {
				err := <- errch
				return "", err
			} else {
				return rsp, nil
			}
		}

	}
}

var (
	Web = fakeSearch("web")
	Image = fakeSearch("image")
	Video = fakeSearch("video")
)

func main() {
	Google := func(ctx context.Context, query string) ([]Result, error) {
		g, ctx := errgroup.WithContext(ctx)

		searches := []Search{Web, Image, Video}
		results := make([]Result, len(searches))

		for i, search := range searches {
			i, search := i, search
			// g.Go只接收无参函数,此无参函数内部,大部分变量都是函数外的变量,因此形成闭包
			g.Go(func() error {
				// search:     是外部变量,且由于延迟绑定,前一步必须重新赋值给search;
				// ctx, query: 是外部变量,但是不会变化,所以使用安全?
				result, err := search(ctx, query)
				if err == nil {
					results[i] = result // 给外面的results赋值?这里不会造成data race吗(存在data race)?
				}
				return err
			})
		}
		if err := g.Wait(); err != nil {
			return nil, err
		}
		return results, nil
	}


	results, err := Google(context.Background(), "golang")
	if err != nil {
		fmt.Printf("Google err: %v\n", err)

		time.Sleep(5 * time.Second)
		return
	}

	time.Sleep(5 * time.Second)
	// 都成功,打印所有结果
	for _, result := range results {
		fmt.Println(result)
	}

}

上面程序用了两个chan,下面用一个chan实现:

type Result string
type Search func(ctx context.Context, query string) (Result, error)

type response struct {
	resp Result
	err error
}
// myFunc 模拟真实函数调用,web耗时2秒成功,image耗时4秒失败,video耗时6秒成功
func myFunc(kind, query string) (Result, error) {
	// 模拟函数调用,web耗时2秒成功,image耗时4秒失败,video耗时6秒成功
	if kind == "web" {
		fmt.Println("web start")
		time.Sleep(2  * time.Second)
		fmt.Println("web end")
		return Result(fmt.Sprintf("%s result for %q", kind, query)), nil
	} else if kind == "image" {
		fmt.Println("image start")
		time.Sleep(4  * time.Second)
		fmt.Println("image end")
		return "", errors.New("image failed")
	} else {
		fmt.Println("video start")
		time.Sleep(6 * time.Second)
		fmt.Println("video end")
		return Result(fmt.Sprintf("%s result for %q", kind, query)), nil
	}
}

func fakeSearch(kind string) Search {
	return func(ctx context.Context, query string) (Result, error){
		done := make(chan response, 1) // 为避免goroutine泄露

		// 后台线程负责调用函数,往done中写数据
		go func() {
			resp, err := myFunc(kind, query)  // 真正的耗时操作,这里有个问题:就是ctx取消时,这个耗时操作仍在继续,无法取消
			done <- response{resp, err}
		}()

		// 主线程监听done和ctx
		select {
		case <- ctx.Done(): // 如果这里返回,是否会造成goroutine泄露?
			return "", ctx.Err()
		case r := <- done:
			return r.resp, r.err
		}
	}
}

var (
	Web = fakeSearch("web")
	Image = fakeSearch("image")
	Video = fakeSearch("video")
)
func main() {
	Google := func(ctx context.Context, query string) ([]Result, error) {
		g, ctx := errgroup.WithContext(ctx)

		searches := []Search{Web, Image, Video}    // 构造三个函数
		results := make([]Result, len(searches))   // 三个函数的结果

		for i, search := range searches {
			i, search := i, search

			g.Go(func() error {
				// search:     是外部变量,且由于延迟绑定,前一步必须重新赋值给search;
				// ctx, query: 是外部变量,但是不会变化,所以使用安全?
				result, err := search(ctx, query)
				if err == nil {
					results[i] = result // 给外面的results赋值?这里不会造成data race吗(确实存在data race)?
				}
				return err
			})
		}

		if err := g.Wait(); err != nil {
			return nil, err
		}
		return results, nil
	}


	results, err := Google(context.Background(), "golang")
	if err != nil {
		fmt.Printf("Google err: %v\n", err)

		time.Sleep(5 * time.Second)
		return
	}

	time.Sleep(5 * time.Second)
	// 都成功,打印所有结果
	for _, result := range results {
		fmt.Println(result)
	}

}

pipeline(管道模式处理)

package main

import (
	"context"
	"crypto/md5"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"

	"golang.org/x/sync/errgroup"
)

type result struct {
	path string
	sum  [md5.Size]byte
}

// MD5All pipeline模式,一些goroutine往里写,另外一些goroutine往外读

// MD5All 读取以root为根节点的文件树结构,返回文件路径和md5值对应的map。
// 如果目录遍历失败或者任何读操作失败,函数返回一个错误。
// MD5All reads all the files in the file tree rooted at root and returns a map
// from file path to the MD5 sum of the file's contents. If the directory walk
// fails or any read operation fails, MD5All returns an error.
func MD5All(ctx context.Context, root string) (map[string][md5.Size]byte, error) {
	// 当g.Wait()返回时ctx被取消
	// 当MD5All返回,甚至是错误返回时,所有的goroutine都已经结束,他们所用的内存将被垃圾回收
	// ctx is canceled when g.Wait() returns. When this version of MD5All returns
	// - even in case of error! - we know that all of the goroutines have finished
	// and the memory they were using can be garbage-collected.
	g, ctx := errgroup.WithContext(ctx)
	paths := make(chan string) // 这里采用unbuffer模式,是否会泄露?(下面超时时发送方直接close,因此接收方不会阻塞)

	// 这里只开一个线程用于walk,将获取到的path放入到paths 这个chan中
	g.Go(func() error {
		defer close(paths) // 引用外部的paths  由发送者进行close

		return filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
			if err != nil {
				return err
			}
			if !info.Mode().IsRegular() {
				return nil
			}

			select {
			case paths <- path:
			case <-ctx.Done(): // 引用外部的ctx
				return ctx.Err()
			}
			return nil
		})
	})

	// Start a fixed number of goroutines to read and digest files.
	c := make(chan result) // 这里也是unbuffer,是否会泄露
	const numDigesters = 20
	// 计算md5比较耗时,开20个goroutine同时计算
	for i := 0; i < numDigesters; i++ {
		g.Go(func() error {
			for path := range paths { // 引用外部paths, 外部的paths何时关闭(walk执行完或执行失败或超时)?
				data, err := ioutil.ReadFile(path)
				if err != nil {
					return err
				}
				select {
				case c <- result{path, md5.Sum(data)}:
				case <-ctx.Done(): // 引用外部的ctx
					return ctx.Err()
				}
			}
			return nil
		})
	}

	// 等待所有goroutine结束
	go func() {
		g.Wait()
		close(c) // 这里关闭后,main goroutine才不会阻塞
	}()

	m := make(map[string][md5.Size]byte)

	for r := range c { // c何时close?等所有的goroutine都结束
		m[r.path] = r.sum
	}

	// 检查是否有goroutine失败。
	// 因为g累积了错误,我们不需要在单独的结果集中发送或检查它们。
	// 这里意思是g的错误只记录一个,不需要再通过chan发送错误了
	// Check whether any of the goroutines failed. Since g is accumulating the
	// errors, we don't need to send them (or check for them) in the individual
	// results sent on the channel.

	if err := g.Wait(); err != nil {
		return nil, err
	}
	return m, nil
}

func MyMD5All(ctx context.Context, root string) (map[string][md5.Size]byte, error) {
	// 传入一个ctx和root节点,获取遍历文件的md5
	path_chan := make(chan string) // buffered or unbuffered, 待定
	g, ctx := errgroup.WithContext(ctx)

	// g.Go中函数一旦返回error,触发g.cancel,取消ctx
	g.Go(func() error {
		defer close(path_chan)

		return filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
			if err != nil {
				return err
			}
			if !info.Mode().IsRegular() {
				return nil
			}

			select {
			case path_chan <- path:
			case <-ctx.Done():
				return ctx.Err()
			}
			return nil
		})
	})

	result_chan := make(chan result) // buffered or unbuffered, 待定

	for i := 0; i < 10; i++ {
		g.Go(func() error {
			// close时退出for循环
			for path := range path_chan {
				data, err := ioutil.ReadFile(path)
				if err != nil {
					return err
				}

				select {
				case result_chan <- result{path, md5.Sum(data)}:
				case <-ctx.Done():
					return ctx.Err()
				}
			}
			return nil
		})
	}

	//
	go func() {
		g.Wait()
		close(result_chan)
	}()

	m := make(map[string][md5.Size]byte)
	for r := range result_chan {
		m[r.path] = r.sum
	}

	// Wait依托于sync.WaitGroup的Wait和context.Cancel,都是可以重复调用的
	if err := g.Wait(); err != nil {
		return nil, err
	}
	return m, nil
}

// Pipeline demonstrates the use of a Group to implement a multi-stage
// pipeline: a version of the MD5All function with bounded parallelism from
// https://blog.golang.org/pipelines.
func main() {
	m, err := MD5All(context.Background(), ".")
	if err != nil {
		log.Fatal(err)
	}

	for k, sum := range m {
		fmt.Printf("%s:\t%x\n", k, sum)
	}
}

注意点:

  • 在errgroup使用中,由于func (g *Group) Go(f func() error)接收的是无参数函数,因此内部通过闭包的方式引用外部参数,比如常用的ctx。
  • errgroup中一旦某个g.Go()返回错误,会触发g.err赋值和cancel ctx, 且只会触发一次,这样其他的goroutine会通过监听ctx的方式来结束任务,就达到了一个任务失败,全部任务级联取消的效果。
  • errgroup内部通过名为sem的buffered chan来控制并发的goroutine数。
  • errgroup有个缺点,
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值