class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: FFN(x) = w_2(dropout(relu(w_1(x)))).

    Projects the last dimension d_model -> d_ff, applies ReLU and dropout,
    then projects back d_ff -> d_model, so the output shape equals the input shape.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)   # expansion: d_model -> d_ff
        self.w_2 = nn.Linear(d_ff, d_model)   # contraction: d_ff -> d_model
        self.dropout = nn.Dropout(dropout)    # module form tracks train/eval mode

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        return self.w_2(self.dropout(hidden))
剖析源码
1 剖析点1:self.w_1 = nn.Linear(d_model, d_ff)
这里的d_model是embedding的长度一般取512
d_ff是inner_layer的维度:2048
2 剖析点2:nn.Dropout(dropout)
参考:https://blog.csdn.net/weixin_42979152/article/details/113769291
注意区别 nn.Dropout(p)（模块形式，在 __init__ 中实例化，随 model.train()/eval() 自动切换）
和 F.dropout(x, p=p)（函数形式，需要把输入张量作为第一个参数传入）
x = torch.rand(2, 3)
print(x)
torch.nn.functional.dropout(x)  # 函数形式调用，未显式传 training 参数
torch.nn.functional.dropout
的函数头是 torch.nn.functional.dropout(input, p=0.5, training=True, inplace=False)
注意：F.dropout 不会感知所在模块的 train/eval 状态，在 nn.Module 中使用时应显式传入 training=self.training，否则 model.eval() 之后 dropout 也不会自动关闭。torch.nn.Dropout
的函数头是torch.nn.Dropout(p=0.5,inplace=False)
功能:将输入的张量的部分元素设置为0,对于每次前向调用,被置为0的元素都是随机的
p:为将元素置为0的概率
inplace:若设置为True则表示原地操作
输入是张量
输出也是张量,且输出的张量与输入的张量维度相同
# Demo of nn.Dropout: each element is zeroed with probability p=0.3 and the
# survivors are scaled by 1/(1-p) (inverted dropout), so the expected value
# of the output matches the input -- visible in the printed tensors below.
# NOTE: the original snippet wrapped the tensor in autograd.Variable, which
# has been deprecated since PyTorch 0.4; plain tensors carry autograd state.
m = nn.Dropout(p=0.3)
input = torch.rand(2, 3)
print(input)
output = m(input)  # same shape as input; the zeroed positions change per call
print(output)
#输出
tensor([[0.6117, 0.5744, 0.0756],
[0.9749, 0.2046, 0.4306]])
tensor([[0.0000, 0.8205, 0.1080],
[0.0000, 0.0000, 0.0000]])
import torch
import torch.nn as nn
class MyModel(nn.Module):
    """Toy module demonstrating that every nn.Dropout call draws a fresh random mask."""

    def __init__(self):
        super(MyModel, self).__init__()
        self.dropout_1 = nn.Dropout(0.5)
        self.dropout_2 = nn.Dropout(0.5)

    def forward(self, input):
        masked = self.dropout_1(input)
        print(masked)
        # Same layer called a second time: a different random mask is drawn.
        masked = self.dropout_1(input)
        print(masked)
        # A separate Dropout layer with the same p behaves identically.
        other = self.dropout_2(input)
        print(other)
if __name__ == '__main__':
    # Driver: push one random 3x3 tensor through MyModel to print three
    # independently drawn dropout masks of the same input.
    sample = torch.rand((3, 3))
    print(sample.shape)
    print(sample)
    model = MyModel()
    model.forward(sample)
#输出
torch.Size([3, 3])
tensor([[0.2487, 0.1715, 0.3385],
[0.0692, 0.9432, 0.7410],
[0.6616, 0.7565, 0.8751]])
tensor([[0.0000, 0.3430, 0.6769],
[0.0000, 1.8864, 1.4819],
[1.3233, 1.5130, 1.7502]])
tensor([[0.0000, 0.0000, 0.0000],
[0.1385, 0.0000, 0.0000],
[0.0000, 1.5130, 1.7502]])
tensor([[0.4974, 0.3430, 0.0000],
[0.0000, 1.8864, 0.0000],
[0.0000, 0.0000, 0.0000]])