Transformer series
gMLP is a Google architecture that uses only MLPs to approach Transformer performance: in their experiments it matches Transformer accuracy across a range of tasks while using fewer parameters.
gMLP paper: 【链接】
Code covered in this walkthrough: 【链接】
The gMLP architecture
class gMLP(nn.Module):
    def __init__(
        self,
        *,
        ...
    ):
        super().__init__()
        dim_ff = dim * ff_mult
        self.seq_len = seq_len
        self.prob_survival = prob_survival

        # token embedding; skipped (Identity) when num_tokens is None,
        # e.g. when the input is already a feature sequence
        self.to_embed = nn.Embedding(num_tokens, dim) if exists(num_tokens) else nn.Identity()

        # each layer computes gmlp(norm(x)) + x
        self.layers = nn.ModuleList([
            Residual(PreNorm(dim, gMLPBlock(
                dim = dim,
                dim_ff = dim_ff,
                seq_len = seq_len,
                attn_dim = attn_dim,
                causal = causal,
                act = act
            ))) for i in range(depth)
        ])

        # final LayerNorm + projection to vocabulary logits
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        ) if exists(num_tokens) else nn.Identity()

    def forward(self, x):
        x = self.to_embed(x)
        # during training, whole layers are randomly dropped with survival
        # probability prob_survival (stochastic depth); at inference all layers run
        layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival)
        out = nn.Sequential(*layers)(x)
        return self.to_logits(out)
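The `gMLPBlock` used above is defined elsewhere in the repo. Its core idea in the paper is a spatial gating unit (SGU): split the hidden channels in half, normalize one half, mix it across the sequence dimension with a learned linear map, and use the result to gate the other half. The sketch below is a minimal standalone version under that reading; the names `SpatialGatingUnit` and `TinyGMLPBlock` are mine, not from the repo.

```python
import torch
from torch import nn

class SpatialGatingUnit(nn.Module):
    # minimal sketch of the SGU from the gMLP paper: split channels,
    # mix one half along the sequence axis, gate the other half with it
    def __init__(self, dim_ff, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(dim_ff // 2)
        self.spatial_proj = nn.Linear(seq_len, seq_len)
        # the paper initializes the spatial weights near 0 and the bias to 1,
        # so the SGU starts out close to an identity mapping
        nn.init.zeros_(self.spatial_proj.weight)
        nn.init.ones_(self.spatial_proj.bias)

    def forward(self, x):                 # x: (batch, seq_len, dim_ff)
        res, gate = x.chunk(2, dim=-1)    # split channels in half
        gate = self.norm(gate)
        # apply the linear map along the sequence dimension
        gate = self.spatial_proj(gate.transpose(1, 2)).transpose(1, 2)
        return res * gate                 # elementwise gating

class TinyGMLPBlock(nn.Module):
    # channel projection in -> SGU -> channel projection out,
    # matching the overall block shape described in the paper
    def __init__(self, dim, dim_ff, seq_len):
        super().__init__()
        self.proj_in = nn.Sequential(nn.Linear(dim, dim_ff), nn.GELU())
        self.sgu = SpatialGatingUnit(dim_ff, seq_len)
        self.proj_out = nn.Linear(dim_ff // 2, dim)

    def forward(self, x):
        return self.proj_out(self.sgu(self.proj_in(x)))

x = torch.randn(2, 16, 32)                # (batch, seq_len, dim)
block = TinyGMLPBlock(dim=32, dim_ff=128, seq_len=16)
out = block(x)
print(out.shape)                          # torch.Size([2, 16, 32])
```

Note that the block is shape-preserving, which is what lets `Residual(PreNorm(...))` wrap it above: the output can be added directly back to the input.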