Transformer的PositionEncoding源代码解读

1. trax库了解一下

详细了解戳 trax

trax库是google开源的一个深度学习代码库,基于tensorflow, jax实现了主流的深度学习模型。市场上有这么多开源的深度学习模型实现库,为什么还要搞个trax呢?它的特点就是聚焦端到端的深度学习模型,主打的就是实现简洁,好理解。当然也可以用它来实现自己的模型,CPU、TPU、GPU都能支持,选择trax的源代码, 当然也是因为google对模型的实现正确性更有保证(这一点很重要,对于初学者,看到错误的实现不一定能识别出来)。

tranformer源代码地址:transformer

2. PositionEncodingTransformer模型中的位置

这里上经典的tranformer架构图看一下,如下图,这个PositionEncoding在Encoder与Decoder的输入端都有使用,从图上看是PositionEncoding的输出与输入相加后做为Encoder-block与Decoder-block的输入, 实际的实现又是怎么做的呢?继续往下看。。。

3. PositionEncoding层的原理

位置编码是为了把序列的位置信息考虑进模型,在tranformer中,作者说因为模型不包含递归(如RNN)和卷积,所以加入PositionEncoding以利用序列的位置信息。

看一下原论文:Attention is all you need, 给出了以下两个公式:

 其中pos表示序列的位置,i代表维度(指嵌入维度,tranformer的嵌入维度是512,即dmodel是512,i的取值范围是0<= 2i < 2i + 1 <= 512, 这也就符合论文中说的函数波长从2π 到 10000 · 2π)。

从这个公式可以看出第0和1个嵌入维度位置函数周期是相同的,一个使用sin, 一个使用cos; 类推,第2,3个嵌入维度周期也是相同,波长按指数级扩大一些。。。

位置信息处理都有哪些方法呢?看一下new bing的回答,可以参考一下:

 

4. trax中PositionEncoding层的实现

在开始看PositionEncoding层之前,先看一下Tranformer模型层的组织方式,上源代码,函数也不长,直接都贴上了:

def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=D_MODEL,
                d_ff=D_FF,
                n_encoder_layers=N_LAYERS,
                n_decoder_layers=N_LAYERS,
                n_heads=N_HEADS,
                max_len=MAX_SEQUENCE_LENGTH,
                dropout=DROPOUT_RATE,
                dropout_shared_axes=DROPOUT_SHARED_AXES,
                mode=MODE,
                ff_activation=FF_ACTIVATION_TYPE): 
  # Avoid 'predict' mode in encoder, since encoder doesn't run stepwise.
  encoder_mode = 'eval' if mode == 'predict' else mode

  # Share embedding weights if no separate output vocab size.
  in_embedder = tl.Embedding(input_vocab_size, d_model)
  if output_vocab_size is None:
    out_embedder = in_embedder
    output_vocab_size = input_vocab_size
  else:
    out_embedder = tl.Embedding(output_vocab_size, d_model)

  def _Dropout():
    return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  def _EncBlock():
    return _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation)
  def _Encoder():
    encoder = tl.Serial(
        in_embedder,
        _Dropout(),
        #  这是编码器的位置编码层
        tl.PositionalEncoding(max_len=max_len, mode=encoder_mode),
        [_EncBlock() for _ in range(n_encoder_layers)],
        tl.LayerNorm(),
    )
    return tl.Cache(encoder) if mode == 'predict' else encoder

  def _EncDecBlock():
    return _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                                dropout_shared_axes, mode, ff_activation)

  # Input to model is encoder-side tokens and decoder-side tokens: tok_d, tok_e
  # Model output is decoder-side vectors and decoder-side tokens: vec_d  tok_d
  return tl.Serial(
      tl.Select([0, 1, 1]),  # Copies decoder tokens for use in loss.

      # Encode.
      tl.Branch([], tl.PaddingMask()),  # tok_e masks tok_d tok_d
      _Encoder(),

      # Decode.
      tl.Select([2, 1, 0]),  # Re-orders inputs: tok_d masks vec_e .....
      tl.ShiftRight(mode=mode),  # 预测时直接返回x
      out_embedder,
      _Dropout(),
      #  这是解码器的位置编码层
      tl.PositionalEncoding(max_len=max_len, mode=mode),
      tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks_e masks vec_e tok_d ..... .....
      [_EncDecBlock() for _ in range(n_decoder_layers)],
      tl.LayerNorm(),
      tl.Select([0], n_in=3),  # Drops masks and encoding vectors.

      # Map vectors to match output vocab size.
      tl.Dense(output_vocab_size),
  )

模型所有的层都放在一个tl.Serial组合器内,根据return语句上方注释可以看到,模型的输入是两个token序列,一个是编码器的输入tok_e, 一个是解码器的输入tok_d, tok_e和tok_d可以放在一个元组里一起输入给模型,它们形状都是(batch, seq_length)(tl.Serial接受一个张量或一个元组/列表作为输入,参数在Serial的各层传递是按stack方式处理的,这一点有点奇葩,没有接触过的很容易掉坑里)。

tok_d,tok_e在经过词嵌入层处理后,会进行一次Dropout处理,然后进入位置编码层PositionEncodeing处理。下面看PositionEncodeing的源代码, 它包括两部分:初始化和调用,初始化的代码如下:

  def init_weights_and_state(self, input_signature):
    """Randomly initializes the positional encoding vectors.

    Args:
      input_signature: :py:class:`ShapeDtype` instance characterizing the input
          this layer should compute on.
    """
    d_feature = input_signature.shape[-1]
    if self._d_feature is not None:
      d_feature = self._d_feature
    # 初始化一个保存位置编码的矩阵,形状(seq_length, d_feature)
    pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
    # 序列位置数组position, (seq_length, 1)
    position = np.arange(0, self._max_len)[:, np.newaxis]
    # 嵌入维度位置数组div_term, 形状(d_feature/2, ), 是一个一维数组
    div_term = np.exp(
        np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
    # 填充位置编码矩阵嵌入维度偶数下标位置的数据, position * div_term得到形状(seq_length, d_feature/2)的矩阵
    pe[:, 0::2] = np.sin(position * div_term)
    # 嵌入维度奇数下标位置的数据
    pe[:, 1::2] = np.cos(position * div_term)  # [self._max_len, d_feature]
    if self._use_bfloat16:
      pe = pe.astype(jnp.bfloat16)
    w = jnp.array(pe)  # Trainable parameters, initialized above.
    if self._d_feature is not None:
      ff = init.GlorotUniformInitializer()(
          (d_feature, input_signature.shape[-1]), self.rng)
      self.weights = w, ff
    else:
      self.weights = w
    if self._mode == 'predict':
      self.state = jnp.zeros((), dtype=jnp.int32)

在计算pos/(10000^(2i/dmodel))时, 分成了两部分计算:pos和1/(10000^(2i/dmodel))

再看下调用时的代码:

  def forward(self, inputs):
    """Returns the input activations, with added positional information."""
    weights = self.weights
    if self._d_feature is not None:
      weights, ff = weights
      weights = jnp.dot(weights[:inputs.shape[1], :], ff)
    if len(weights.shape) < 3:  # old checkpoints have 1 in first dim already
      #  初始化时weight是(len, d_feature)形态的, 这里给它加了一个batch维度, 即第0维
      weights = weights[None, :, :]  # [1, self._max_len, d_feature]
    if self._mode != 'predict':
      # 模型训练跑的分支
      x = inputs
      symbol_size = jnp.shape(x)[1]
      if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
        # 指定从0位置开始或非训练模式,如eval模式, 从位置0取symbol_size个序列作为位置编码数据
        px = weights[:, :symbol_size, :]
      else:
        # 随机从位置0开始取位置编码数据
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        start = fastmath.random.randint(rng1, (), 0, self._max_offset_to_add)
        start_from_zero = fastmath.random.uniform(rng2, (), jnp.float32, 0, 1)
        start = jnp.where(start_from_zero < self._start_from_zero_prob,
                          jnp.zeros((), dtype=jnp.int32), start)
        px = fastmath.dynamic_slice_in_dim(weights, start, symbol_size,
                                           axis=1)
      # dropout规则处理
      if self._dropout == 0:
        return x + px
      else:
        noise_shape = list(px.shape)
        for dim in self._dropout_broadcast_dims:
          noise_shape[dim] = 1
        keep_prob = 1.0 - self._dropout
        keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                         tuple(noise_shape))
        multiplier = keep.astype(x.dtype) / keep_prob
        return x + px * multiplier
    else:
      # 模型预测跑的分支
      if self._dropout != 0:
        raise ValueError(f'In predict mode, but dropout rate '
                         f'({self._dropout}) is not zero.')

      # State in this class is only used for fast inference. In that case,
      # the model is called with consecutive elements position-by-position.
      # This positional encoding layer stores the index of the current
      # position and increments it on each call.
      # 根据输入序列形状,从weight取一个相同开关的位置编码矩阵相加
      emb = fastmath.dynamic_slice_in_dim(
          weights, self.state, inputs.shape[1], axis=1)
      self.state += inputs.shape[1]
      return inputs + emb

看代码可以看出这个PositionEncoding层是需要初始化的,不初始化,forward函数是会报错的,因为weight不初始化是一个空tuple, 没有shape属性, if len(weights.shape) < 3: 这个判断会出错, 但是在模型的定义中并没有看到初始化的代码?这个问题后面有时间再看看,有兴趣的也可以自己分析一下,如果有答案,也欢迎留言讨论。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值