Go Concurrency Patterns: Pipelines and cancellation

原文地址: https://blog.golang.org/pipelines

简介

Go 语言提供的并发原语使得可以很方便的构建数据流 pipeline,使用这样的 pipeline 可以高效的利用 I/O 和多 cpu 的优势. 这篇文章我们将展示如何构建并使用 pipeline.

什么是 pipeline ?

在 go 语言中没有正式的定义什么是 pipeline. 它只是众多并发程序类型中的一种. 非正式的说,pipeline 是一系列通过 channel 联系起来的 stage. 每个 stage 包含多个执行相同功能的 goroutine. 在每个 stage 中, goroutine 执行以下操作:

  • 从输入 channel 中读取数据
  • 处理数据,产生新的数据
  • 将数据发送到输出 channel

除了第一个和最后一个 stage,每个 stage 可以拥有任意数量的 输入channel 和 输出channel。 第一个和最后一个 stage 只能有一个输入channel一个输出channel. 第一个 stage 也被称为 SourceProducer, 最后一个 stage 被称为 SinkConsumer

接下来,我们通过一个简单的示例来说明.

平方数

假设我们的 pipeline 有三个 stage.

第一个 stage 是 gen, 用来将与一组数字转化为一个 channel.

func gen(nums ...int) <-chan int {
    out := make(chan int)
    go func() {
        for _, n := range nums {
            out <- n
        }
        close(out)
    }()
    return out
}

第二个 stage 是 sq, 从 输入channel 中接收数字,计算数字的平方数,并将数字写入输出channel中.

func sq(in <-chan int) <-chan int {
    out := make(chan int)
    go func() {
        for n := range in {
            out <- n * n
        }
        close(out)
    }()
    return out
}

main 函数中建立该 pipeline,并运行最后最后一个 stage. 最后一个 stage 从第二个 stage 中接收平方数,并将接收到的数据打印出来.

func main() {
    // Set up the pipeline.
    c := gen(2, 3)
    out := sq(c)

    // Consume the output.
    fmt.Println(<-out) // 4
    fmt.Println(<-out) // 9
}

因为 gen 的输入channel 和输出 channel具有相同的输入和输出类型,因此我们可以重复的使用他们任意次.

我们可以将 main 方法重写为如下形式:

func main() {
    // Set up the pipeline and consume the output.
    for n := range sq(sq(gen(2, 3))) {
        fmt.Println(n) // 16 then 81
    }
}
扇入,扇出

多个函数可以从一个channel中读取数据,直到这个channel关闭,这叫做 扇出(fan-out). 通过这种方式,我们可以将一些列任务分派给多个 woker,这些 worker 可以在多个 CPU 上执行或者进行 I/O 操作.

一个函数可以从多个输入 channel 中读取并处理数据,直到所有的 channel 被关闭. 并将输出写入到同一个输出channel 上,处理完数据后关闭输出 channel. 这叫做 扇入(fan-in).

举个例子,我们可以运行两个 sq 方法,这两个方法均从同一个输入 channel 上读取数据. 这里我们再引入另外一个方法 merge, 该方法用于将两个 sq 的输出整合到通过一个输出channel中.

func main() {
    in := gen(2, 3)

    // Distribute the sq work across two goroutines that both read from in.
    c1 := sq(in)
    c2 := sq(in)

    // Consume the merged output from c1 and c2.
    for n := range merge(c1, c2) {
        fmt.Println(n) // 4 then 9, or 9 then 4
    }
}
func merge(cs ...<-chan int) <-chan int {
    var wg sync.WaitGroup
    out := make(chan int)

    // Start an output goroutine for each input channel in cs.  output
    // copies values from c to out until c is closed, then calls wg.Done.
    output := func(c <-chan int) {
        for n := range c {
            out <- n
        }
        wg.Done()
    }
    wg.Add(len(cs))
    for _, c := range cs {
        go output(c)
    }

    // Start a goroutine to close out once all the output goroutines are
    // done.  This must start after the wg.Add call.
    go func() {
        wg.Wait()
        close(out)
    }()
    return out
}
尽快停止

截至目前,我们将所有的 pipeline 函数设计为如下模式:

  • 当前 stage 应该关闭 输出channel,当我们处理完了所有的输入数据,并且所有的输出数据已经发送到了 输出channel 之后.
  • 当前 stage 应该持续接收数据直到 输入channel 被关闭.

这样设计使得我们可以再接收stage 中使用 range 循环来处理所有的数据,当所有数据被处理并发送到输出channel之后,我们的循环为自动退出.

但是在真实情况下,我们往往不会接收从输入channel中接收所有的数据. 有时,我们仅仅需要读取输入数据的一个子集便可以继续往下进行了. 更通常的情况下,stage 提前退出,因为上流 stage 发生了错误. 在这种情况下,我们不应该等待所有的数据到来,并且我们希望上流 stage 直接退出而不是继续产生哪些我们已经不在需要的数据.

在我们的例子中,如果当前 stage 无法正确的处理所有的 输入数据,那么上流尝试继续发送数据到 stage 会被永久的阻塞住.

	// Consume the first value from the output.
    out := merge(c1, c2)
    fmt.Println(<-out) // 4 or 9
    return
    // Since we didn't receive the second value from out,
    // one of the output goroutines is hung attempting to send it.

这会导致资源泄露. goroutine 会消耗内存和运行时资源, goroutine 堆栈中的对该 channel 的引用会阻止垃圾回收器回收该 channel 所占的资源,直到它自己退出.

我们需要我们 pipeline 中的上流 stage 总是能自动退出即使下流 stage 无法接收该stage 所产生的所有数据. 一种方案是给输出channel设置 buffer. buffer 中可以保存指定数量的数据,只要buffer没有满,往这样的channel 中发送数据的操作总是能立马返回.

c := make(chan int, 2) // buffer size 2
c <- 1  // succeeds immediately
c <- 2  // succeeds immediately
c <- 3  // blocks until another goroutine does <-c and receives 1

如果我们在创建一个输出channel的时候,便直到需要发送多少数据,那么使用 buffer 会简化我们的代码.

func gen(nums ...int) <-chan int {
    // 这里,对于每个输入数字,我们均会产生一个输出,
    // 因此我们便可以将输出 channel 的buffer 大小设置为输入 nums 的大小
    // 这样我们往 out channel 中发送数据的操作永远不会阻塞当前方法
    out := make(chan int, len(nums))
    for _, n := range nums {
        out <- n
    }
    close(out)
    return out
}

另外一种方案是,下流 stage 通知上流stage,它已经停止接收数据了.

取消接收

当我们在 main 方法中决定不再从 out channel 中接收数据,直接退出的时候,我们必须通知上流 stage,我们已经不再从该 channel 中接受数据了. 我们可以通过一个 done channel 来实现.

func main() {
    in := gen(2, 3)

    // Distribute the sq work across two goroutines that both read from in.
    c1 := sq(in)
    c2 := sq(in)

    // 因为当前 stage 有两个上流 channel,因此我们将 done 的 buffer 大小初始化为 2
    done := make(chan struct{}, 2)
    out := merge(done, c1, c2)
    fmt.Println(<-out) // 4 or 9

    // Tell the remaining senders we're leaving.
    done <- struct{}{}
    done <- struct{}{}
}

上流 stage 需要做如下修改:

func merge(done <-chan struct{}, cs ...<-chan int) <-chan int {
    var wg sync.WaitGroup
    out := make(chan int)

    // Start an output goroutine for each input channel in cs.  output
    // copies values from c to out until c is closed or it receives a value
    // from done, then output calls wg.Done.
    output := func(c <-chan int) {
        for n := range c {
            // 这里使用 select 语句代替原先的单纯发送数据的操作
            // 以便当下流 stage 停止接收,往 done channel 上发送停止接收的信号
            select {
            case out <- n:
            // 当我们在 main 方法中往 done channel 发送数据后,我们便会在这里接收到该数据
            // 我们便可以结束当前 stage 了
            case <-done: 
            }
        }
        wg.Done()
    }
    // ... the rest is unchanged ...
}

这种方法存在一个问题,那就是对于每个下流 stage,都得知道上流 stage 的数量,这样我们才能确定 done channel 的大小. 这看起来并不是一个优雅的解决方案.

我们需要一种解决方案,这个解决方案不需要知道上流和下流的 stage 数量.

在 go 中,我们可以通过关闭 channel 来实现. 因为试图从一个已经关闭的 channel 上接收数据总是会直接返回,返回值是一个对应数据类型的 zero 值.

这意味着,我们只需要在 main 函数中关闭 done channel,然后所有尝试从 done 中接收信号的上流stage 都会收到一个零值,这样他们便可以直接退出了.

修改 main 函数,使用这种方案. 我们需要给每个上流 stage 增加一个done channel 参数,这样,当 在main 中,我们关闭 done 之后,所有上流 stage 都能收到信号,并退出. 上流stage 的实现类似与 merge 的实现,略.

func main() {
    // Set up a done channel that's shared by the whole pipeline,
    // and close that channel when this pipeline exits, as a signal
    // for all the goroutines we started to exit.
    done := make(chan struct{}) // 注意,这里 done 不要 buffer
    defer close(done) // 使用 defer,在 main 函数退出时,该 channel 会被关闭

    in := gen(done, 2, 3)

    // Distribute the sq work across two goroutines that both read from in.
    c1 := sq(done, in)
    c2 := sq(done, in)

    // Consume the first value from output.
    out := merge(done, c1, c2)
    fmt.Println(<-out) // 4 or 9

    // done will be closed by the deferred call.
}
计算文件 MD5 checksum

接下来,我们看一个更加真实的例子.

MD5 经常被用来计算文件的 checksum. md5sum 命令可以输出一组文件的 checksum.

% md5sum *.go
d47c2bbc28298ca9befdfbc5d3aa4e65  bounded.go
ee869afd31f83cbb2d10ee81b2b831dc  parallel.go
b88175e65fdcbc01ac08aaf1fd9b5e96  serial.go

在这个例子中,我们来实现 md5sum 命令. 不同的是我们的md5sum 命令接收一个目录,输出这个目录下所有文件的 checksum,按照路径排序.

func main() {
    // Calculate the MD5 sum of all files under the specified directory,
    // then print the results sorted by path name.
    m, err := MD5All(os.Args[1])
    if err != nil {
        fmt.Println(err)
        return
    }
    var paths []string
    for path := range m {
        paths = append(paths, path)
    }
    sort.Strings(paths)
    for _, path := range paths {
        fmt.Printf("%x  %s\n", m[path], path)
    }
}

MD5All 的实现如下

// MD5All reads all the files in the file tree rooted at root and returns a map
// from file path to the MD5 sum of the file's contents.  If the directory walk
// fails or any read operation fails, MD5All returns an error.
func MD5All(root string) (map[string][md5.Size]byte, error) {
    m := make(map[string][md5.Size]byte)
    err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
        if err != nil {
            return err
        }
        if !info.Mode().IsRegular() {
            return nil
        }
        data, err := ioutil.ReadFile(path)
        if err != nil {
            return err
        }
        m[path] = md5.Sum(data)
        return nil
    })
    if err != nil {
        return nil, err
    }
    return m, nil
}
并行化计算 MD5 checksum

在这节中,我们将 MD5All 拆分为两个有两个 stage 的 pipeline. 第一个stage sumFiles 遍历文件目录,计算文件 checksum,并将结果发送到输出 channel 中, 计算结果的类型为 result.

type result struct {
    path string
    sum  [md5.Size]byte
    err  error
}
func sumFiles(done <-chan struct{}, root string) (<-chan result, <-chan error) {
    // For each regular file, start a goroutine that sums the file and sends
    // the result on c.  Send the result of the walk on errc.
    c := make(chan result)
    errc := make(chan error, 1)
    // 主线程开启一个 goroutine, 在goroutine 中遍历文件,并计算checksum,将结果输出到 c channel,如果发生错误,将错误信息发送到 errc channel
    go func() {
        var wg sync.WaitGroup
        err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
            if err != nil {
                return err
            }
            if !info.Mode().IsRegular() {
                return nil
            }
            wg.Add(1)
            // 为每个文件使用一个单独的 goroutine 来计算文件 checksum
            go func() {
                data, err := ioutil.ReadFile(path)
                // 尝试往 channel c 中发送计算结果,如果发送操作被阻塞且 done 已经被关闭
                // select 语句便会进入 done 对应的 case,程序得以继续往下进行
                select {
                case c <- result{path, md5.Sum(data), err}:
                case <-done:
                }
                wg.Done()
            }()
            // Abort the walk if done is closed.
            select {
            case <-done:
                return errors.New("walk canceled")
            default:
                return nil
            }
        })
        // Walk has returned, so all calls to wg.Add are done.  Start a
        // goroutine to close c once all the sends are done.
        // 等待所有计算文件 checksum 的 goroutine 退出
        go func() { 
            wg.Wait()
            close(c) // 结束时,关闭 channel c
        }()
        // No select needed here, since errc is buffered.
        errc <- err
    }()
    return c, errc
}

MD5All 用来接收 checksum 或者 sumfiles 中发生的错误.

func MD5All(root string) (map[string][md5.Size]byte, error) {
    // MD5All closes the done channel when it returns; it may do so before
    // receiving all the values from c and errc.
    done := make(chan struct{})
    defer close(done)

    c, errc := sumFiles(done, root)

    m := make(map[string][md5.Size]byte)
    // 从 c 上读取数据,无论 sumFiles 是否正常结束,
    // range c 都确保我们不会阻塞在这个 for 循环处
    for r := range c {
        if r.err != nil {
            return nil, r.err
        }
        m[r.path] = r.sum
    }
    // 检查是否发生错误
    if err := <-errc; err != nil {
        return nil, err
    }
    return m, nil
}
限制并行数量

在上一节中,我们给每个文件创建一个 goroutine 用来计算文件的 MD5 checksum. 这里有一个问题,如果某个目录下有很多文件,那么我们便需要创建大量个 goroutine,这可能会超出实际的物理内存大小.

我们可以通过限制并行处理的文件数量来解决这个问题. 这里,我们通过创建指定数量的 goroutine 来读取文件. 此时,我们的 pipeline 就需要有三个stage 了: 遍历文件目录,读取数据并计算 MD5 checksum, 收集计算结果.

第一个 stage walkFiles 读取文件并将结果写入输出 channel 中

func walkFiles(done <-chan struct{}, root string) (<-chan string, <-chan error) {
    paths := make(chan string)
    errc := make(chan error, 1)
    go func() {
        // Close the paths channel after Walk returns.
        defer close(paths)
        // No select needed for this send, since errc is buffered.
        errc <- filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
            if err != nil {
                return err
            }
            if !info.Mode().IsRegular() {
                return nil
            }
            select {
            case paths <- path:
            case <-done:
                return errors.New("walk canceled")
            }
            return nil
        })
    }()
    return paths, errc
}

第二个 stage 启用指定数量个 goroutine 执行 digester 方法. 这个 goroutine 从 paths channel 中读取文件路径并计算 MD5 checksum,将结果输出到 channel c 上

// 注意,这里我们不关闭 channel c,因为我们有多个 goroutine 往 c 中发送数据
func digester(done <-chan struct{}, paths <-chan string, c chan<- result) {
    for path := range paths {
        data, err := ioutil.ReadFile(path)
        select {
        case c <- result{path, md5.Sum(data), err}:
        case <-done:
            return
        }
    }
}
 // Start a fixed number of goroutines to read and digest files.
    c := make(chan result)
    var wg sync.WaitGroup
    const numDigesters = 20
    wg.Add(numDigesters)
    for i := 0; i < numDigesters; i++ {
        go func() {
            digester(done, paths, c)
            wg.Done()
        }()
    }
    go func() {
        wg.Wait()
        close(c)
    }()

最后一个 stage 从 channel c 上接收计算结果或者错误信息.

 	m := make(map[string][md5.Size]byte)
    for r := range c {
        if r.err != nil {
            return nil, r.err
        }
        m[r.path] = r.sum
    }
    // Check whether the Walk failed.
    if err := <-errc; err != nil {
        return nil, err
    }
    return m, nil

END!!!

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值