原文:Go Concurrency Patterns: Pipelines and cancellation。
引言
Go 的并发基础数据使得码农能很容易地构建能有效利用 I/O 和多 CPU 的流式数据管道。这篇文章提供了一些使用这些管道的例子、强调了当操作失败时的处理技巧、并介绍了整洁地处理失败的技巧。
什么是管道?
对于 Go 中的管道,并没有一个正式的定义;它只是许多并行程序中的一种。非正式地说,管道是一系列由 channel 连接的阶段,每个阶段都是一组运行同一个函数的的 goroutine。在每个阶段,这些 goroutine 会执行以下操作:
- 通过入站 channel 从上游获取数据
- 使用一些函数处理获取到的数据,一般会产生新的数据
- 通过出站 channel 向下游发送数据
每个阶段都有一些入站和出站 channel,除了第一个和最后一个阶段。第一个阶段只有出站 channel,最后一个阶段只有入站 channel。第一个阶段有时候被称作源,或生产者。最后一个阶段有时候被称为池,或消费者。
我们将会以一个简单的管道例子作为开始,来解释方法和技巧。稍后我们会提供一个更符合现实的例子。
平方数
考虑一个有三个阶段的管道。
第一个阶段,gen,是一个将整型列表转化为一个发送列表中整型值的 channel 的函数。gen 函数首先启动一个 goroutine,去将整型值发送到 channel 中。当所有值都发送出去后,关闭这个 channel。
func gen(nums ...int) <-chan int {
out := make(chan int)
go func() {
for _, n := range nums {
out <- n
}
close(out)
}()
return out
}
第二个阶段,sq,从入站 channel 中读取整型值,并返回另一个出站 channel。这个 channel 会发送每个接收到的整型值的平方。在入站 channel 被关闭,以及这个阶段把所有数据都发送给下游之后,出站 channel 会被关闭:
func sq(in <-chan int) <-chan int {
out := make(chan int)
go func() {
for n := range in {
out <- n * n
}
close(out)
}()
return out
}
main 函数建立管道并运行最终的阶段:它从第二个阶段获取数据,并打印每个数据,直到 channel 被关闭:
func main() {
// Set up the pipeline.
c := gen(2,3)
out := sq(c)
// Consume the output.
fmt.Println(<-out) // 4
fmt.Println(<-out) // 9
}
由于 sq 的入站和出站 channel 的类型是一样的,我们可以任意次组合它。我们也可以把 main 函数重写成一个循环,跟其他阶段一样:
func main() {
// 设置管道并消费输出
for n := range sq(sq(gen(2, 3))) {
fmt.Println(n) // 16 然后是 81
}
}
扇入和扇出
许多函数可以从同一个 channel 读取数据,直到这个 channel 被关闭;这个过程叫做扇出 (fan-out)。它提供了一种将工作分配给一组工作节点,从而并行利用 CPU 和 I/O 资源的方式。
一个函数可以通过把多个输入 channel 复用到单个 channel (当所有输入 channel 都关闭后,这个单个 channel 会被关闭),来从多个输入读取数据,并处理这些数据,直到所有的输入都关闭。这个过程叫扇入 (fan-in)。
我们可以修改我们的管道,来运行两个 sq 实例,这两个实例从同一个输入 channel 读取数据。下面我们介绍一个新的函数,merge,来扇入结果:
func main() {
in := gen(2, 3)
// 将 sq 工作分发给两个 goroutine,这两个 goroutine 都从 in 中读取数据。
c1 := sq(in)
c2 := sq(in)
// 消费从 c1 和 c2 合并的输出 channel
for n := range merge(c1, c2) {
fmt.Println(n) // 打印 4 和 9,或者打印 9 和 4
}
}
merge 函数将一系列 channel 转换为一个单独的 channel。它内部对每个入站 channel 启动一个 goroutine,并从各入站 channel 中读取并复制数据,然后发送给一个单独的出站 channel。一旦所有的输出 goroutine 启动之后,merge 函数启动另一个 goroutine,当所有对出站 channel 的发送结束后关闭出站 channel。
往一个已经被关闭的 channel 发送数据会引起错误,所以在调用关闭之前,确保所有的发送都已经完成是很重要的。sync.WaitGroup 类型提供了一个简单的方式来准备同步:
func merge(cs ...<-chan int) <-chan int {
var wg sync.WaitGroup
out := make(chan int)
// 对于 cs 中的每个输入 channel 启动一个输出 goroutine。
// 从 c 中输出复制的数据到 output,直到 c 被关闭,然后调用 wg.Done。
output := func(c <-chan int) {
for n:= range c {
out <- n
}
wg.Done()
}
wg.Add(len(cs))
for _, c := range cs {
go output(c)
}
// 启动一个 goroutine,当所有的输出 goroutine 都结束后,关闭输出 channel,
// 这个步骤必须在 wg.Add 调用后开启。
go func() {
wg.Wait()
close(out)
}()
return out
}
突然停止
对于我们的管道函数有这样的模式:
- 在一个阶段中,当所有发送操作都结束后,出站 channel 应该被关闭。
- 在一个阶段中,持续对入站 channel 进行读操作,直到入站 channel 被关闭。
这个模式允许每个接收阶段被写成一个 range 循环,并保证所有 goroutine 在把所有数据都成功发送给下游后会立即退出。
但是在真实的管道中,各阶段并不总是获取到所有的入站数据。有时下面的情况会被有意设计:接收方可能只需要获取上游发送数据的一个子集来处理。更常发生的情况是,由于上游阶段传入的数据代表一个错误从而造成下游阶段过早退出。在这些情况下,接收方不应继续等待剩余的数据,我们希望上游的阶段停止产生下游下游阶段不需要的数据。
在我们的例子中,如果一个阶段不再消费入站数据,那么尝试发送这些数据的上游 goroutine 将会被永远阻塞:
// 消费从输出获取的第一个数据
out := merge(c1, c2)
fmt.Println(<-out) // 4 或 9
return
// 由于我们没有从 out 中接收第二个值,
// 其中一个 goroutine 会被挂起,并不停尝试去发送第二个值。
}
这是一种资源泄露:goroutine 消耗内存和运行资源,goroutine 栈中的堆引用会保持数据不被垃圾回收。Goroutine 不会被垃圾回收,它们必须自己退出。
我们需要准备这个管道中的上游的阶段,让它们能正常退出,即使下游的阶段没能获取所有的入站数据。其中一种方式是把出站 channel 改为有缓冲 channel。这种 channel 能保存固定数量的数据;当 channel 中还有空间时,发送操作会立即接受:
c := make(chan int, 2) // buffer 大小为 2
c <- 1 // 立刻成功
c <- 2 // 立刻成功
c <- 3 // 保持阻塞,直到另一个 goroutine 执行 <-c 并获取 1
如果在创建 channel 的时候已知要发送数据的个数,那么有缓冲的 channel 能简化代码。比如,我们可以重写 gen 函数,将输入的整型值写入一个有缓冲的 channel 中,避免创建一个新的 goroutine (因为往未满的有缓冲 channel 中写数据是不阻塞的):
func gen(nums ...int) <-chan int {
out := make(chan int, len(nums))
for _, n := range nums {
out <- n
}
close(out)
return out
}
返回到我们管道中的阻塞 goroutine,我们可能考虑对 merge 函数返回的出站 channel 加一个缓冲:
func merge(cs ...<-chan int) <-chan int {
var wg sync.WaitGroup
out := make(chan int, 1) // 给未读的输入数据设置足够的空间
// ... 后面的代码保持不变 ...
虽然这样做修复了阻塞 goroutine 的问题,但这是很糟糕的代码。在这里选择缓冲空间为 1 是基于知道 merge 会接收的数据个数及下游阶段会消费的数据量。这种代码是很脆弱的:如果我们给 gen 传递一个额外的数据,或者如果下游阶段读取了更少的数据,那么我们仍然会有阻塞 goroutine 的情况。
所以我们需要提供一种方式,让下游阶段去告诉发送者它们将停止接收输入数据。
显式取消
当 main 函数决定不接收所有来自 out 的数据并退出时,它必须告诉上游阶段的 goroutine 去放弃发送数据。它通过一个叫做 done 的 channel 来发送这样的消息。它发送两个值,因为有两个潜在的阻塞发送者:
func main() {
in := gen(2, 3)
// 把 sq 的工作分发给两个 goroutine,这两个 goroutine 都从 in 中读取数据
c1 := sq(in)
c2 := sq(in)
// 从输出中消费第一个值
done := make(chan struct{}, 2)
out := merge(done, c1, c2)
fmt.Println(<-out) // 打印 4 或 9
// 告诉剩下的发送者,我们要退出了
done <- struct{}{}
done <- struct{}{}
}
发送端的 goroutine 使用 select 语句来代替原有发送操作。这个 select 语句处理向 out 发送数据、或从 done 读取数据。done 的类型是一个空的结构体,它的类型无关紧要:它表明应该停止向 out 发送数据。output goroutine 继续循环从入站 channel 读取数据,以防止它的上游阶段被阻塞。(稍后我们会讨论如何允许这个循环早点结束)
func merge(done <-chan struct{}, cs ...<-chan int) <-chan int {
var wg sync.WaitGroup
out := make(chan int)
// 对 cs 中的每个输入 channel 都启动一个输出 goroutine。
// 输出 goroutine 从 c 中复制数据并发送到 out,直到 c 被关闭,或它从 done
// 中接收到数据,然后输出 goroutine 调用 wg.Done。
output := func(c <- chan int) {
for n := range c {
select {
case out <- n:
case <- done:
}
}
wg.Done()
}
// ... 下面的部分没有变化 ...
}
这种方式有个问题:每个下游接受者需要知道有可能被阻塞的上游发送者的数量,并在过早返回时给每个发送者准备发送一个信号。跟踪这样计数即冗长又容易出错。
我们需要一种方式来告诉未知且没有数量限制的 goroutine 去停止把它们的数据发送给下游。在 Go 里,我们可以通过关闭 channel 来实现,因为对于一个关闭 channel 的获取操作总能立即返回,并获取一个零值。
这意味着 main 函数可以通过简单地关闭 done channel 来解除所有发送者的阻塞。这个关闭操作是一个有效的向所有发送者广播的信号。我们扩展一下我们的管道函数,把 done 作为一个新的参数,并通过 defer 语句来处理 channel close 的情况,从而使所有从 main 退出的路径都能出发向管道阶段发送退出信号。
func main() {
// 建立一个 done channel,用于在所有管道中共享,
// 当管道退出时,关闭 done,作为通知所有 goroutine 退出的信号。
done := make(chan struct{})
defer close(done)
in := gen(done, 2, 3)
// 把 sq 的工作分配到两个 goroutine 中,
// 这两个 goroutine 都从 in 中读取数据
c1 := sq(done, in)
c2 := sq(done, in)
// 从输出消费第一个值
out := merge(done, c1, c2)
fmt.Println(<-out) // 输出 4 或 9
}
现在每个管道阶段都可以在 done 被关闭时自由地退出了。merge 函数中的 output goroutine 可以在不消耗它的入站 channel 的前提下退出,因为它知道当 done 被关闭时,上游的发送者,即 sq,会停止发送数据。output 通过 defer 语句保证所有退出路径都会调用 wg.Done:
func merge(done <-chan struct{}, cs ...<-chan int) <-chan int {
var wg sync.WaitGroup
out := make(chan int)
output := func(c <-chan int) {
defer wg.Done()
for n := range c {
select {
case out <- n:
case <- done:
return
}
}
}
// ... 剩下的部分没有变化 ...
}
相似地,sq 可以在 done 被关闭时立刻返回。sq 通过 defer 语句保证在所有返回路径上都会关闭 out channel:
func sq(done <-chan struct{}, in <-chan int) <-chan int {
out := make(chan int)
go func() {
defer close(out)
for n := range in {
select {
case out <- n * n:
case <-done:
return
}
}
}()
return out
}
下面是构建管道的指导原则:
- 当每个阶段中的发送操作完成时,关闭出站 channel。
- 每个阶段持续从入站 channel 获取数据,直到这些 channel 被关闭,或发送方被阻塞。
管道可以通过两种方式来保证发送方不会被阻塞:或者确保 channel 有足够大的缓冲,或者当接受者可能放弃 channel 的时候发送显式信号。
获取一棵树的摘要
下面我们考虑一个更现实的管道设计。
MD5 是一种消息摘要算法,它可以用来计算文件的校验和。命令 md5sum 用于分别打印一组文件的摘要信息。
% md5sum *.go
d47c2bbc28298ca9befdfbc5d3aa4e65 bounded.go
ee869afd31f83cbb2d10ee81b2b831dc parallel.go
b88175e65fdcbc01ac08aaf1fd9b5e96 serial.go
我们的程序例子就像 md5sum 例子。不提供的是,我们的程序获取单个路径作为参数,并按照文件名称的顺序打印这个路径下每个文件的摘要信息。
% go run serial.go .
d47c2bbc28298ca9befdfbc5d3aa4e65 bounded.go
ee869afd31f83cbb2d10ee81b2b831dc parallel.go
b88175e65fdcbc01ac08aaf1fd9b5e96 serial.go
程序的主函数调用一个帮助程序,MD5ALL,来获取一个键值对分别为文件名称和摘要信息的 map。然后主函数对结果进行排序,并打印出来。
func main() {
// 计算给定路径下所有文件的 MD5 校验和,
// 然后将按路径名称排序的结果打印出来。
m, err := MD5All(os.Args[1])
if err != nil {
fmt.Println(err)
return
}
var paths []string
for path := range m {
paths = append(paths, path)
}
sort.Strings(paths)
for _, path := range paths {
fmt.Println("%x %s\n", m[path], path)
}
}
MD5All 函数是我们讨论的重点。在 serial.go 文件中的实现没有用到并发,只是简单地遍历树,并且读取每个文件,然后计算每个文件的校验和。
// MD5All 读取文件树中的所有文件,并返回一个键值对分别是文件路径和文件内容 MD5 校验和的 map。
// 如果路径遍历失败或读取操作失败的话,MD5All 返回操作。
func MD5All(root string) (map[string][md5.Size]byte, err) {
m := make(map[string][md5.Size])
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.Mode().IsRegular() {
return nil
}
data, err := ioutil.ReadFile(path)
if err != nil {
return err
}
m[path] = md5.Sum(data)
return nil
})
if err != nil {
return nil, err
}
return m, nil
}
并行摘要
在 parallel.go,我们把 MD5All 分成一个两阶段的管道。第一个阶段,sumFiles,遍历文件树,在一个新的 goroutine 中计算文件的摘要,然后把结果发送给一个 channel 中,结果的类型是 result:
type result struct {
path string
sum [md5.Size]byte
err error
}
sumFiles 返回两个 channel,一个用于传递 result,一个用于传递 filepath.Walk 错误结果。walk 函数对每个文件启动一个新的 goroutine 来进行处理,然后检查 done channel。如果 done 被关闭了,那么 walk 程序立刻停止:
func sumFiles(done <-chan struct{}, root string) (<-chan result, <-chan error) {
// 对每个普通文件,启动一个 goroutine 来计算文件校验和,并把结果发送到 c。
// 把 walk 的结果发送到 errc。
c := make(chan result)
errc := make(chan error, 1)
go func() {
var wg sync.WaitGroup
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.Mode().IsRegular() {
return nil
}
wg.Add(1)
go func() {
data, err := ioutil.ReadFile(path)
select {
case c <- result{path, md5.Sum(data), err}:
case <-done:
}
wg.Done()
}()
// 如果 done 被关闭,则退出程序
select {
case <-done:
return errors.New("walk canceled")
default:
return nil
}
})
// 遍历返回后,所有调用的 wg.Add 都完成了
// 当所有发送都完成时,立即启动一个 goroutine 来关闭 c
go func() {
wg.Wait()
close(c)
}()
// 这里不用使用 select,因为 errc 是有缓冲的
errc <- err
}()
return c, errc
}
MD5All 从 c.MD5All 获取摘要信息。如果过早退出的话,则返回错误,并通过 defer 来关闭 done:
func MD5All(root string) (map[string][md5.Size]byte, error) {
// 当 MD5All 返回时关闭 done channel;
// 这个关闭操作可能会在从 c 和 errc 获取所有值之前执行。
done := make(chan struct{})
defer close(done)
c, errc := sumFiles(done, root)
m := make(map[string][md5.Size]byte)
for r := range c {
if r.err != nil {
return nil, r.err
}
m[r.path] = r.sum
}
if err := <-errc; err != nil {
return nil, err
}
return m, nil
}
受限并行
在 parallel.go 中的 MD5All 实现对每个文件都启动了一个 goroutine。当一个路径下有非常多文件的时候,这种操作会分配过多内存,超过机器的承受能力。
我们可以通过限制读文件的并发量来限制内存分配量。在 bounded.go 里,我们通过创建固定数量的 goroutine 来实现。现在我们的管道有三个阶段:遍历树,读取文件并计算摘要,以及收集摘要信息。
在第一个阶段,walkFiles 发送树中的常规文件:
func walkFiles(done <-chan struct{}, root string) (<-chan string, <-chan error) {
paths := make(chan string)
errc := make(chan error, 1)
go func() {
// 在遍历返回后关闭 paths channel。
defer close(paths)
// 不需要使用 select 来发送,因为 errc 是有缓冲的。
errc <- filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.Mode().IsRegular() {
return nil
}
select {
case paths <- path:
case <-done:
return errors.New("walk canceled")
}
return nil
})
}()
return paths, errc
}
第二个阶段启动一个固定数量的 digester goroutine 去从 paths 获取文件名称,并把结果 result 通过 channel c 发送出去:
func digester(done <-chan struct{}, paths <-chan string, c chan<- result) {
for path := range paths {
data, err := ioutil.ReadFile(path)
select {
case c <- result{path, md5.Sum(data), err}:
case <-done:
return
}
}
}
跟之前的例子不同,digester 不会关闭它的输出 channel,因为多个 goroutine 会通过一个共享的 channel 去发送消息。当所有的 digester 都完成工作之后,MD5All 会完成对 channel 的关闭操作:
// 启动固定数量的 goroutine 去读取文件并计算摘要。
c := make(chan result)
var wg sync.WaitGroup
const numDigesters = 20
wg.Add(numDigesters)
for i := 0; i < numDigesters; i++ {
go func() {
digester(done, paths, c)
wg.Done()
}()
}
go func() {
wg.wait()
close(c)
}()
我们也可以让每个 digester 自己去创建并返回输出 channel,但是这样我们就需要额外的 goroutine 去扇入结果。
最后一个阶段从 c 获取所有的结果,然后检查 errc 是否有错误信息。这个错误检查不能在更早的时间检查,因为在这个时间点之前,walkFiles 可能被向下游发送数据的操作所阻塞:
m := make(map[string][md5.Size]byte)
for r := range c {
if r.err != nil {
return nil, r.err
}
m[r.path] = r.sum
}
// 检查 Walk 是否失败
if err := <-errc; err != nil {
return nil, err
}
return m, nil
}
结论
本文介绍了使用 Go 构建流式数据管道的技巧。处理这种管道的失败情况需要技巧,因为管道中每个阶段都可能在向下游发送数据时被阻塞,而这时下游可能已经不再读取上游数据了。我们展示了如何使用关闭一个 channel 来作为一个完成信号,去广播给管道启动的所有 goroutine。由此我们定义了正确构建管道的原则。
延伸阅读:
- Go Concurrent Patterns (video) 介绍了 Go 并发基本数据类型基础已经几种使用它们的方法。
- Advanced Go Concurrent Patterns (video) 介绍了更多对 Go 基础类型的复杂使用,尤其是 select。
- Douglas Mcllroy 的文章:Squinting at Power Series,介绍了Go 类型的并发如何为复杂的计算提供优雅支持。
-- 作者:Sameer Ajmani