Introduction to the Transformer

The position-wise feed-forward network (FFN) is the second sublayer in every Transformer encoder and decoder layer. It applies the same two-layer transformation to each position independently: FFN(x) = max(0, x·W1 + b1)·W2 + b2, expanding to an inner dimension d_ff that is larger than the model dimension d_model (2048 and 512 in the original paper).
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w1 = nn.Linear(d_model, d_ff)    # expand: d_model -> d_ff
        self.w2 = nn.Linear(d_ff, d_model)    # project back: d_ff -> d_model
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.w1(x)       # (batch, seq_len, d_ff)
        x = F.relu(x)        # non-linearity between the two linear layers
        x = self.dropout(x)  # regularization on the inner activations
        x = self.w2(x)       # back to (batch, seq_len, d_model)
        return x
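Because both linear layers act only on the last dimension, every position is transformed independently, which is what "position-wise" means. Here is a minimal sanity check of that property, assuming the class above is in scope (the shapes and names below are illustrative, not part of the original walkthrough):

torch.manual_seed(0)
check = PositionwiseFeedForward(d_model=512, d_ff=2048, dropout=0.0).eval()

a = torch.randn(1, 4, 512)
b = a.clone()
b[0, 2] += 1.0  # perturb the input at position 2 only

out_a, out_b = check(a), check(b)
keep = torch.tensor([True, True, False, True])  # every position except 2
print(torch.allclose(out_a[0, keep], out_b[0, keep]))  # True: other positions unaffected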
d_model = 512
d_ff = 2048
dropout = 0.2

# out_mha: output of the multi-head attention sublayer from the previous
# section, shape (batch_size, seq_len, d_model) = (2, 4, 512)
x = out_mha

ff = PositionwiseFeedForward(d_model, d_ff, dropout)
out_ff = ff(x)
print(out_ff)
print(out_ff.shape)
tensor([[[ 1.5385, -0.8189, -1.6781,  ...,  1.5737,  0.4905,  0.8483],
         [ 2.8966, -2.4892, -1.9388,  ...,  1.7022, -0.2211, -0.7838],
         [ 1.4625, -1.2973, -0.4546,  ...,  2.4504, -1.5376, -0.8824],
         [ 2.4148, -1.8958, -1.6720,  ...,  1.6979,  0.3737, -0.1442]],

        [[ 0.9309,  1.1935,  1.1984,  ...,  2.3999,  0.3744,  0.2678],
         [ 1.2424,  0.0684,  1.7166,  ...,  2.2012, -0.7395,  0.5636],
         [ 1.3801, -0.1511,  1.3062,  ...,  1.5764, -0.5672,  0.4452],
         [ 1.5959,  0.1437,  1.5425,  ...,  2.1625, -1.0858,  0.1428]]],
       grad_fn=<AddBackward0>)
torch.Size([2, 4, 512])
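The output keeps the input shape, so it can feed directly into the next step. In a full encoder layer the FFN output is not used on its own: the original paper wraps each sublayer in a residual connection followed by layer normalization ("Add & Norm"). A minimal sketch of that step, reusing x and out_ff from above (the out_layer name is illustrative):

# Add & Norm around the feed-forward sublayer (post-norm, as in the paper):
# LayerNorm(x + FFN(x)); dropout on the sublayer output is omitted here.
norm = nn.LayerNorm(d_model)
out_layer = norm(x + out_ff)   # residual connection, then layer normalization
print(out_layer.shape)         # torch.Size([2, 4, 512])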