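`F.dropout` is a function. Its arguments are the input tensor, the drop probability `p`, and a boolean `training` (it also accepts `inplace`). Only when `training` is `True` does it zero out a random subset of the elements, each with probability `p`, and scale the remaining elements by 1/(1-p). When `training` is `False`, dropout has no effect. By default, `training` is `True`.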
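To make the rule concrete, here is a minimal hand-written sketch of inverted dropout using `torch.bernoulli`. It only illustrates the behaviour described above and is not PyTorch's actual kernel; `my_dropout` is a hypothetical helper name. The 1/(1-p) factor keeps each element's expected value unchanged: an element survives with probability 1-p and is then worth x/(1-p), so its expectation is (1-p)·x/(1-p) = x.

```python
import torch


def my_dropout(x: torch.Tensor, p: float = 0.5, training: bool = True) -> torch.Tensor:
    """Hypothetical re-implementation of inverted dropout (illustration only)."""
    if not training or p == 0.0:
        # Inference (or p = 0): dropout is the identity mapping.
        return x
    if p == 1.0:
        # Every element is dropped; this also avoids dividing by zero.
        return torch.zeros_like(x)
    # keep_mask is 1 with probability (1 - p) and 0 with probability p.
    keep_mask = torch.bernoulli(torch.full_like(x, 1.0 - p))
    # Zero the dropped elements and scale the survivors by 1 / (1 - p).
    return x * keep_mask / (1.0 - p)


if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.randn(5, 5)
    y = my_dropout(x, p=0.5, training=True)
    kept = y != 0
    # With p = 0.5, every surviving element equals twice the input.
    print(torch.allclose(y[kept], x[kept] * 2))
    # With training=False the input comes back unchanged.
    print(torch.equal(my_dropout(x, p=0.5, training=False), x))
```

Since the mask is random, only the rule (zero or scaled by 1/(1-p)) is reproducible, not the exact positions of the zeros.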
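`nn.Dropout` is a class whose constructor takes only two arguments: the probability `p` and `inplace`. Constructing it returns a module object; calling that object on an input tensor returns the processed tensor. If `inplace` is set to `True` at construction time, the original input tensor is modified as well and ends up identical to the output. `inplace` defaults to `False`. Instead of a `training` argument, `nn.Dropout` follows the module's own training mode, which is switched by `train()` and `eval()`.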
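A minimal sketch comparing the two forms inside a custom module: `nn.Dropout` picks up the mode set by `model.eval()`, whereas `F.dropout` only does so if `training=self.training` is passed explicitly, because its own default is `training=True`. The class names `WithModule` and `WithFunctional` are hypothetical, and the seed and tensor shape are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WithModule(nn.Module):
    """Uses the nn.Dropout class; train()/eval() are handled automatically."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.drop = nn.Dropout(p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.drop(x)


class WithFunctional(nn.Module):
    """Uses F.dropout; the module's training flag is forwarded by hand."""

    def __init__(self, p: float = 0.5):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.dropout(x, p=self.p, training=self.training)


if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.randn(5, 5)
    for model in (WithModule(), WithFunctional()):
        model.eval()
        # In eval mode both variants return the input unchanged.
        print(type(model).__name__, torch.equal(model(x), x))

    # inplace=True writes the result into the input tensor itself.
    x2 = x.clone()
    out = nn.Dropout(p=0.5, inplace=True)(x2)
    print(out.data_ptr() == x2.data_ptr())
```

Dropping `training=self.training` from `WithFunctional.forward` would leave dropout active even after `eval()`, which is the usual pitfall of the functional form.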
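```python
import torch
import torch.nn as nn
import torch.nn.functional as F

if __name__ == '__main__':
    input = torch.randn(5, 5)

    # Functional dropout in training mode: some elements are zeroed and the
    # rest are scaled by 1/(1-p). The input itself is not modified.
    output1 = F.dropout(input, p=0.5, training=True)
    print(input, '\n', output1)

    # Module dropout with inplace=True: the input tensor is overwritten.
    m = nn.Dropout(p=0.5, inplace=True)
    output2 = m(input)
    print(input, '\n', output2, '\n', (input == output2))

    # training=False: dropout does nothing and returns the input unchanged.
    output3 = F.dropout(input, p=0.5, training=False)
    print(input, '\n', output3, '\n', (input == output3).all())
```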
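Output:

In the first call, some elements of the result are set to 0, and every element that survives has been multiplied by 2, i.e. 1/(1-0.5) (dropped negative elements print as `-0.0000`). Because `inplace` is `True` in the second call, the input is overwritten by the output, so the two tensors print identically and the element-wise comparison is all `True`. With `training=False`, the output is the same as the input: dropout acts as the identity mapping.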
```text
tensor([[ 1.2217,  0.0197,  1.6543, -0.2873, -1.8832],
        [ 0.9392, -0.4237,  0.2885, -1.2082,  0.3996],
        [ 1.2309, -0.1735, -0.4891,  0.2473, -1.3866],
        [ 1.1813,  1.4792,  0.1614, -0.3758,  0.4791],
        [ 1.4066,  0.6809,  1.1930, -0.6618,  1.1439]])
tensor([[ 0.0000,  0.0000,  0.0000, -0.0000, -0.0000],
        [ 1.8785, -0.0000,  0.0000, -2.4164,  0.0000],
        [ 2.4619, -0.0000, -0.9783,  0.4946, -2.7733],
        [ 2.3626,  0.0000,  0.0000, -0.7517,  0.0000],
        [ 0.0000,  1.3618,  2.3861, -1.3235,  0.0000]])
tensor([[ 2.4434,  0.0394,  3.3086, -0.5745, -0.0000],
        [ 0.0000, -0.8475,  0.0000, -0.0000,  0.0000],
        [ 0.0000, -0.3470, -0.9783,  0.4946, -2.7733],
        [ 0.0000,  2.9584,  0.0000, -0.0000,  0.9582],
        [ 0.0000,  0.0000,  2.3861, -0.0000,  2.2878]])
tensor([[ 2.4434,  0.0394,  3.3086, -0.5745, -0.0000],
        [ 0.0000, -0.8475,  0.0000, -0.0000,  0.0000],
        [ 0.0000, -0.3470, -0.9783,  0.4946, -2.7733],
        [ 0.0000,  2.9584,  0.0000, -0.0000,  0.9582],
        [ 0.0000,  0.0000,  2.3861, -0.0000,  2.2878]])
tensor([[True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True]])
tensor([[ 2.4434,  0.0394,  3.3086, -0.5745, -0.0000],
        [ 0.0000, -0.8475,  0.0000, -0.0000,  0.0000],
        [ 0.0000, -0.3470, -0.9783,  0.4946, -2.7733],
        [ 0.0000,  2.9584,  0.0000, -0.0000,  0.9582],
        [ 0.0000,  0.0000,  2.3861, -0.0000,  2.2878]])
tensor([[ 2.4434,  0.0394,  3.3086, -0.5745, -0.0000],
        [ 0.0000, -0.8475,  0.0000, -0.0000,  0.0000],
        [ 0.0000, -0.3470, -0.9783,  0.4946, -2.7733],
        [ 0.0000,  2.9584,  0.0000, -0.0000,  0.9582],
        [ 0.0000,  0.0000,  2.3861, -0.0000,  2.2878]])
tensor(True)
```
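Instead of comparing printed tensors by eye, the three observations can be checked with assertions-style prints. This is a sketch under assumed settings: a fixed seed, a large 1000×1000 tensor so the dropped fraction is close to `p`, and `p=0.2` (so the scale is 1/(1-0.2) = 1.25) to show the factor is not specific to 0.5. The exact zero pattern will not match the output above, since the mask is random.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
p = 0.2
x = torch.randn(1000, 1000)

# 1) training=True: roughly a fraction p of the elements is zeroed and the
#    survivors are scaled by 1 / (1 - p).
out = F.dropout(x, p=p, training=True)
kept = out != 0
print('dropped fraction:', 1 - kept.float().mean().item())
print('survivors scaled by 1/(1-p):', torch.allclose(out[kept], x[kept] / (1 - p)))

# 2) inplace=True: the input tensor is overwritten by the result.
x_copy = x.clone()
out_inplace = nn.Dropout(p=p, inplace=True)(x_copy)
print('input overwritten:', torch.equal(out_inplace, x_copy) and not torch.equal(x_copy, x))

# 3) training=False: dropout is the identity mapping.
print('identity when training=False:', torch.equal(F.dropout(x, p=p, training=False), x))
```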