GPT-2代码解读[3]:Block

GPT-2代码解读[3]:Block

Overview

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fBtHYbm0-1582276934126)(C:\Users\co\AppData\Roaming\Typora\typora-user-images\1582248633145.png)]

模型由12个基本块构成,每一块由三部分构成,我们已经考虑过和Embedding与Attention相关的部分,现在考虑最后一部分:Add&MLP。

记Attention层的输出为a,块输入为x。

Add&MLP层的信息流动如下:

x = x + a x=x+a x=x+a

m = m l p ( x ) m=mlp(x) m=mlp(x)

x = x + m x=x+m x=x+m

首先经过一层残差将a附加于已有信息x,确保不损失已有信息。

然后将此时信息做一次mlp1,得到特征m。

最后将特征m附加于已有信息x,得到本层的输出。

可以训练的参数是:mlp中的两个线性层

def block(x, scope, *, past, hparams):

nx = x.shape[-1].value
a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
x = x + a
m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
x = x + m

前两句我们已经熟知,分别是取embedding维度为nx和做Attention。

注意这里在做Attention层之前先对x做normalization,这也是cv里常见的做法,暂不分析。

x=x+a

残差操作,这里的’+'就是element-wise plus。

此时x=a=[batch,seq,embedding],加和之后依然是[batch,seq,embedding]。

m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)

m是经过mlp进一步提取的特征。

def mlp(x, scope, n_state, *, hparams):

n_state表示第一层线性变换的特征维度。

nx = x.shape[-1].value
h = gelu(conv1d(x, 'c_fc', n_state))
h2 = conv1d(h, 'c_proj', nx)

线性变换到n_state维,gelu激活,再变换回nx维。

x = x + m

将mlp得到的信息m残差加和到已有信息x。


  1. mlp是中间带有激活函数的两层线性层,Attention机制一般与mlp配合使用,否则缺失非线性变换。 ↩︎

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值