MxNet系列——how_to——bucketing

博客新址: http://blog.xuezhisd.top
邮箱:xuezhisd@126.com


在MXNet中使用Bucketing

Bucketing是一种训练多个不同但又相似的结构的网络,这些网络共享相同的参数集。一个典型的应用是循环神经网络(RNNs)。在使用符号网络定义的工具箱中,实现RNNs通常会沿时间轴将网络显式地展开。显式地展开RNNs之前需要知道序列的长度。为了处理序列中的所有元素,我们需要将网络展开成最大可能的序列长度。然而这很浪费资源,因为对于较短的序列,大部分计算都是在填充后的数据上执行的。

Bucketing,是从 Tensorflow’s sequence training example 借鉴而来的一个简单的方法。它不再将网络展开成最大可能长度,而是展开成多个不同长度的实例(比如,长度为5, 10, 20, 30)。在训练过程中,对于不同长度的最小批数据,我们使用最恰当的展开模型。对于RNNs,尽管这些模型具有不同的架构,但参数在时间轴上是共享的。尽管选出的不同bucket的模型,并以不同的最小批来训练,但本质上都是在优化相同的参数集。MXNet 在所有的执行器中重复使用中间的存储缓存。

对于简单的RNNs,可以使用一个for循环来遍历输入序列,通过保持状态和沿时间的梯度之间的连接的方式沿时间反向传播。而然,这可能会使降低处理速度。这个方法能够处理不同长度的序列。但对于更加复杂的模型(比如,使用序列到序列网络的翻译模型)来说,并不容易展开。在这个例程中,我们将介绍MXNet的允许我们事先bucketing的APIs。

不同长度的序列训练PTB

在这个例程中,我们使用 PennTreeBank language model example 。如果你对这个例程不熟悉,请首先查看 原教程 (in Julia)

例程中使用的架构是两个LSTM层,加一个简单的单词嵌入层。原例程将模型沿时间展开成固定长度(32)。本例程将介绍如何使用bucketing来实现变长序列训练。

为了使用bucketing,MXNet需要知道如何为不同长度的序列构建一个新的展开的符号架构(图)。为了实现这个目的,我们不是构建一个使用固定 Symbol 的模型,而是使用一个回调函数,该函数对新的bucket key 生成一个新的 Symbol

model = mx.model.FeedForward(
        ctx     = contexts,
        symbol  = sym_gen)

sym_gen 必须是一个函数,它只有一个输入,即 bucket_key;并为这个bucket返回一个 Symbol。我们使用序列长度作为 bucket key。任何对象都可以用作bucket key。比如,在神经网络翻译应用中,不同长度的输入和输出序列的组合对应于不同的展开方式,一对长度值(输入/输出长度)可以用作bucket key。

def sym_gen(seq_len):
    return lstm_unroll(num_lstm_layer, seq_len, len(vocab),
                       num_hidden=num_hidden, num_embed=num_embed,
                       num_label=len(vocab))

数据迭代器需要报告 default_bucket_key,它允许MXNet在读取数据之前初始化参数。现在,模型能够以不同的buckets进行训练,这是通过共享参数和不同buckets之间的计算缓存。

为了训练,我们还需要为 DataIter 添加一些额外的bits。除了报告之前提到的 default_bucket_key之外,还需要为每最小批报告当前的 bucket_key。更具体的说,在每个最下批中,通过 DataIter 返回的 DataBatch 对象需要包含下面的附加属性:

  • bucket_key: 对应于一批数据的 bucket key。 在本例程中,它是指一批数据的序列长度。如果该bucket key对应的执行器还没有创建,将根据由函数 gen_sym 以bucket key为参数生成的符号模型,构建该bucket key对应的执行器。该执行器将会放在缓存中,以便未来使用。注意:生成的 Symbols 可能是任意的,但他们应具有相同的可训练参数和辅助状态。
  • provide_data: 和 DataIter 对象报告的信息相同。 因为现在每个bucket都对应一个不同的架构,它们可以有不同的输入。同时,确保 DataIter 对象返回的 provide_data 信息和 default_bucket_key的架构是兼容的。.
  • provide_label: 和 provide_data相同。

现在,DataIter 负责将数据分到不同的 buckets。 假如已经激活随机化,在么个最小批中,DataIter 随机选择一个 bucket (根据一个由bucket尺寸均衡的分布),然后从bucket中随机选择一个序列来组成一个最小批数据。如果有必要,它将对最小批中的不同长度的序列进行填充。

获取一个读取文本序列的 DataIter (它通过实现上述的API)的完整实现,请查看 example/rnn/lstm_ptb_bucketing.py。在本例中,你可以使用静态配置的 bucketing (比如,buckets = [10, 20, 30, 40, 50, 60]), 或者让 MXnet 根据dataset (buckets = [])自动生成 bucketing。后一种方法是通过添加一个和长度和输入数量相同的bucket(bucket足够长)来实现的。获取更多信息,请查看 default_gen_buckets().

Beyond Sequence Training

在本例程中,简单的描述了bucketing API是如何工作的。然而,bucketing API不限于上文使用的序列长度的bucketing。bucket的键(key)可以是任意的对象,只要 gen_sym 返回的架构兼容即可。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值