1. trax库了解一下
详细了解戳 trax
trax库是google开源的一个深度学习代码库,基于tensorflow, jax实现了主流的深度学习模型。市场上有这么多开源的深度学习模型实现库,为什么还要搞个trax呢?它的特点就是聚焦端到端的深度学习模型,主打的就是实现简洁,好理解。当然也可以用它来实现自己的模型,CPU、TPU、GPU都能支持,选择trax的源代码, 当然也是因为google对模型的实现正确性更有保证(这一点很重要,对于初学者,看到错误的实现不一定能识别出来)。
tranformer源代码地址:transformer
2. PositionEncoding在Transformer模型中的位置
这里上经典的tranformer架构图看一下,如下图,这个PositionEncoding在Encoder与Decoder的输入端都有使用,从图上看是PositionEncoding的输出与输入相加后做为Encoder-block与Decoder-block的输入, 实际的实现又是怎么做的呢?继续往下看。。。
3. PositionEncoding层的原理
位置编码是为了把序列的位置信息考虑进模型,在tranformer中,作者说因为模型不包含递归(如RNN)和卷积,所以加入PositionEncoding以利用序列的位置信息。
看一下原论文:Attention is all you need, 给出了以下两个公式:
其中pos表示序列的位置,i代表维度(指嵌入维度,tranformer的嵌入维度是512,即dmodel是512,i的取值范围是0<= 2i < 2i + 1 <= 512, 这也就符合论文中说的函数波长从2π 到 10000 · 2π)。
从这个公式可以看出第0和1个嵌入维度位置函数周期是相同的,一个使用sin, 一个使用cos; 类推,第2,3个嵌入维度周期也是相同,波长按指数级扩大一些。。。
位置信息处理都有哪些方法呢?看一下new bing的回答,可以参考一下:
4. trax中PositionEncoding层的实现
在开始看PositionEncoding层之前,先看一下Tranformer模型层的组织方式,上源代码,函数也不长,直接都贴上了:
def Transformer(input_vocab_size,
output_vocab_size=None,
d_model=D_MODEL,
d_ff=D_FF,
n_encoder_layers=N_LAYERS,
n_decoder_layers=N_LAYERS,
n_heads=N_HEADS,
max_len=MAX_SEQUENCE_LENGTH,
dropout=DROPOUT_RATE,
dropout_shared_axes=DROPOUT_SHARED_AXES,
mode=MODE,
ff_activation=FF_ACTIVATION_TYPE):
# Avoid 'predict' mode in encoder, since encoder doesn't run stepwise.
encoder_mode = 'eval' if mode == 'predict' else mode
# Share embedding weights if no separate output vocab size.
in_embedder = tl.Embedding(input_vocab_size, d_model)
if output_vocab_size is None:
out_embedder = in_embedder
output_vocab_size = input_vocab_size
else:
out_embedder = tl.Embedding(output_vocab_size, d_model)
def _Dropout():
return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
def _EncBlock():
return _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
mode, ff_activation)
def _Encoder():
encoder = tl.Serial(
in_embedder,
_Dropout(),
# 这是编码器的位置编码层
tl.PositionalEncoding(max_len=max_len, mode=encoder_mode),
[_EncBlock() for _ in range(n_encoder_layers)],
tl.LayerNorm(),
)
return tl.Cache(encoder) if mode == 'predict' else encoder
def _EncDecBlock():
return _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
dropout_shared_axes, mode, ff_activation)
# Input to model is encoder-side tokens and decoder-side tokens: tok_d, tok_e
# Model output is decoder-side vectors and decoder-side tokens: vec_d tok_d
return tl.Serial(
tl.Select([0, 1, 1]), # Copies decoder tokens for use in loss.
# Encode.
tl.Branch([], tl.PaddingMask()), # tok_e masks tok_d tok_d
_Encoder(),
# Decode.
tl.Select([2, 1, 0]), # Re-orders inputs: tok_d masks vec_e .....
tl.ShiftRight(mode=mode), # 预测时直接返回x
out_embedder,
_Dropout(),
# 这是解码器的位置编码层
tl.PositionalEncoding(max_len=max_len, mode=mode),
tl.Branch([], tl.EncoderDecoderMask()), # vec_d masks_e masks vec_e tok_d ..... .....
[_EncDecBlock() for _ in range(n_decoder_layers)],
tl.LayerNorm(),
tl.Select([0], n_in=3), # Drops masks and encoding vectors.
# Map vectors to match output vocab size.
tl.Dense(output_vocab_size),
)
模型所有的层都放在一个tl.Serial组合器内,根据return语句上方注释可以看到,模型的输入是两个token序列,一个是编码器的输入tok_e, 一个是解码器的输入tok_d, tok_e和tok_d可以放在一个元组里一起输入给模型,它们形状都是(batch, seq_length)(tl.Serial接受一个张量或一个元组/列表作为输入,参数在Serial的各层传递是按stack方式处理的,这一点有点奇葩,没有接触过的很容易掉坑里)。
tok_d,tok_e在经过词嵌入层处理后,会进行一次Dropout处理,然后进入位置编码层PositionEncodeing处理。下面看PositionEncodeing的源代码, 它包括两部分:初始化和调用,初始化的代码如下:
def init_weights_and_state(self, input_signature):
"""Randomly initializes the positional encoding vectors.
Args:
input_signature: :py:class:`ShapeDtype` instance characterizing the input
this layer should compute on.
"""
d_feature = input_signature.shape[-1]
if self._d_feature is not None:
d_feature = self._d_feature
# 初始化一个保存位置编码的矩阵,形状(seq_length, d_feature)
pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
# 序列位置数组position, (seq_length, 1)
position = np.arange(0, self._max_len)[:, np.newaxis]
# 嵌入维度位置数组div_term, 形状(d_feature/2, ), 是一个一维数组
div_term = np.exp(
np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
# 填充位置编码矩阵嵌入维度偶数下标位置的数据, position * div_term得到形状(seq_length, d_feature/2)的矩阵
pe[:, 0::2] = np.sin(position * div_term)
# 嵌入维度奇数下标位置的数据
pe[:, 1::2] = np.cos(position * div_term) # [self._max_len, d_feature]
if self._use_bfloat16:
pe = pe.astype(jnp.bfloat16)
w = jnp.array(pe) # Trainable parameters, initialized above.
if self._d_feature is not None:
ff = init.GlorotUniformInitializer()(
(d_feature, input_signature.shape[-1]), self.rng)
self.weights = w, ff
else:
self.weights = w
if self._mode == 'predict':
self.state = jnp.zeros((), dtype=jnp.int32)
在计算pos/(10000^(2i/dmodel))时, 分成了两部分计算:pos和1/(10000^(2i/dmodel))
再看下调用时的代码:
def forward(self, inputs):
"""Returns the input activations, with added positional information."""
weights = self.weights
if self._d_feature is not None:
weights, ff = weights
weights = jnp.dot(weights[:inputs.shape[1], :], ff)
if len(weights.shape) < 3: # old checkpoints have 1 in first dim already
# 初始化时weight是(len, d_feature)形态的, 这里给它加了一个batch维度, 即第0维
weights = weights[None, :, :] # [1, self._max_len, d_feature]
if self._mode != 'predict':
# 模型训练跑的分支
x = inputs
symbol_size = jnp.shape(x)[1]
if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
# 指定从0位置开始或非训练模式,如eval模式, 从位置0取symbol_size个序列作为位置编码数据
px = weights[:, :symbol_size, :]
else:
# 随机从位置0开始取位置编码数据
rng1, rng2 = fastmath.random.split(self.rng, 2)
start = fastmath.random.randint(rng1, (), 0, self._max_offset_to_add)
start_from_zero = fastmath.random.uniform(rng2, (), jnp.float32, 0, 1)
start = jnp.where(start_from_zero < self._start_from_zero_prob,
jnp.zeros((), dtype=jnp.int32), start)
px = fastmath.dynamic_slice_in_dim(weights, start, symbol_size,
axis=1)
# dropout规则处理
if self._dropout == 0:
return x + px
else:
noise_shape = list(px.shape)
for dim in self._dropout_broadcast_dims:
noise_shape[dim] = 1
keep_prob = 1.0 - self._dropout
keep = fastmath.random.bernoulli(self.rng, keep_prob,
tuple(noise_shape))
multiplier = keep.astype(x.dtype) / keep_prob
return x + px * multiplier
else:
# 模型预测跑的分支
if self._dropout != 0:
raise ValueError(f'In predict mode, but dropout rate '
f'({self._dropout}) is not zero.')
# State in this class is only used for fast inference. In that case,
# the model is called with consecutive elements position-by-position.
# This positional encoding layer stores the index of the current
# position and increments it on each call.
# 根据输入序列形状,从weight取一个相同开关的位置编码矩阵相加
emb = fastmath.dynamic_slice_in_dim(
weights, self.state, inputs.shape[1], axis=1)
self.state += inputs.shape[1]
return inputs + emb
看代码可以看出这个PositionEncoding层是需要初始化的,不初始化,forward函数是会报错的,因为weight不初始化是一个空tuple, 没有shape属性, if len(weights.shape) < 3: 这个判断会出错, 但是在模型的定义中并没有看到初始化的代码?这个问题后面有时间再看看,有兴趣的也可以自己分析一下,如果有答案,也欢迎留言讨论。