博客新址: http://blog.xuezhisd.top
邮箱:xuezhisd@126.com
在MXNet中使用Bucketing
Bucketing是一种训练多个不同但又相似的结构的网络,这些网络共享相同的参数集。一个典型的应用是循环神经网络(RNNs)。在使用符号网络定义的工具箱中,实现RNNs通常会沿时间轴将网络显式地展开。显式地展开RNNs之前需要知道序列的长度。为了处理序列中的所有元素,我们需要将网络展开成最大可能的序列长度。然而这很浪费资源,因为对于较短的序列,大部分计算都是在填充后的数据上执行的。
Bucketing,是从 Tensorflow’s sequence training example 借鉴而来的一个简单的方法。它不再将网络展开成最大可能长度,而是展开成多个不同长度的实例(比如,长度为5, 10, 20, 30)。在训练过程中,对于不同长度的最小批数据,我们使用最恰当的展开模型。对于RNNs,尽管这些模型具有不同的架构,但参数在时间轴上是共享的。尽管选出的不同bucket的模型,并以不同的最小批来训练,但本质上都是在优化相同的参数集。MXNet 在所有的执行器中重复使用中间的存储缓存。
对于简单的RNNs,可以使用一个for循环来遍历输入序列,通过保持状态和沿时间的梯度之间的连接的方式沿时间反向传播。而然,这可能会使降低处理速度。这个方法能够处理不同长度的序列。但对于更加复杂的模型(比如,使用序列到序列网络的翻译模型)来说,并不容易展开。在这个例程中,我们将介绍MXNet的允许我们事先bucketing的APIs。
不同长度的序列训练PTB
在这个例程中,我们使用 PennTreeBank language model example 。如果你对这个例程不熟悉,请首先查看 原教程 (in Julia)。
例程中使用的架构是两个LSTM层,加一个简单的单词嵌入层。原例程将模型沿时间展开成固定长度(32)。本例程将介绍如何使用bucketing来实现变长序列训练。
为了使用bucketing,MXNet需要知道如何为不同长度的序列构建一个新的展开的符号架构(图)。为了实现这个目的,我们不是构建一个使用固定 Symbol
的模型,而是使用一个回调函数,该函数对新的bucket key 生成一个新的 Symbol
。
model = mx.model.FeedForward(
ctx = contexts,
symbol = sym_gen)
sym_gen
必须是一个函数,它只有一个输入,即 bucket_key
;并为这个bucket返回一个 Symbol
。我们使用序列长度作为 bucket key。任何对象都可以用作bucket key。比如,在神经网络翻译应用中,不同长度的输入和输出序列的组合对应于不同的展开方式,一对长度值(输入/输出长度)可以用作bucket key。
def sym_gen(seq_len):
return lstm_unroll(num_lstm_layer, seq_len, len(vocab),
num_hidden=num_hidden, num_embed=num_embed,
num_label=len(vocab))
数据迭代器需要报告 default_bucket_key
,它允许MXNet在读取数据之前初始化参数。现在,模型能够以不同的buckets进行训练,这是通过共享参数和不同buckets之间的计算缓存。
为了训练,我们还需要为 DataIter
添加一些额外的bits。除了报告之前提到的 default_bucket_key
之外,还需要为每最小批报告当前的 bucket_key
。更具体的说,在每个最下批中,通过 DataIter
返回的 DataBatch
对象需要包含下面的附加属性:
bucket_key
: 对应于一批数据的 bucket key。 在本例程中,它是指一批数据的序列长度。如果该bucket key对应的执行器还没有创建,将根据由函数gen_sym
以bucket key为参数生成的符号模型,构建该bucket key对应的执行器。该执行器将会放在缓存中,以便未来使用。注意:生成的Symbol
s 可能是任意的,但他们应具有相同的可训练参数和辅助状态。provide_data
: 和DataIter
对象报告的信息相同。 因为现在每个bucket都对应一个不同的架构,它们可以有不同的输入。同时,确保DataIter
对象返回的provide_data
信息和default_bucket_key
的架构是兼容的。.provide_label
: 和provide_data
相同。
现在,DataIter
负责将数据分到不同的 buckets。 假如已经激活随机化,在么个最小批中,DataIter
随机选择一个 bucket (根据一个由bucket尺寸均衡的分布),然后从bucket中随机选择一个序列来组成一个最小批数据。如果有必要,它将对最小批中的不同长度的序列进行填充。
获取一个读取文本序列的 DataIter
(它通过实现上述的API)的完整实现,请查看 example/rnn/lstm_ptb_bucketing.py。在本例中,你可以使用静态配置的 bucketing (比如,buckets = [10, 20, 30, 40, 50, 60]
), 或者让 MXnet 根据dataset (buckets = []
)自动生成 bucketing。后一种方法是通过添加一个和长度和输入数量相同的bucket(bucket足够长)来实现的。获取更多信息,请查看 default_gen_buckets().
Beyond Sequence Training
在本例程中,简单的描述了bucketing API是如何工作的。然而,bucketing API不限于上文使用的序列长度的bucketing。bucket的键(key)可以是任意的对象,只要 gen_sym
返回的架构兼容即可。