nn.Dropout is a commonly used technique for preventing overfitting, and it also brings other benefits such as improving model robustness.
In PyTorch, Dropout has two parameters:
- p – probability of an element to be zeroed. Default: 0.5
  p is the probability that an element is set to 0, i.e. the probability of it being 'dropped'.
- inplace – If set to True, will do this operation in-place. Default: False
  Whether to perform the operation in place; defaults to False (see the quick sketch below).
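As a quick sketch of the inplace flag (the tensor name here is illustrative, not from the original), note that with inplace=True the input tensor itself is overwritten:

import torch
from torch import nn

x = torch.ones(5)
d = nn.Dropout(p=0.5, inplace=True)
d(x)
print(x)
# x itself has been modified: dropped entries are 0.0 and the survivors are
# scaled to 1/(1-0.5) = 2.0 (the usual inverted-dropout scaling during training)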
A few examples are recorded below to aid understanding:
import torch
import numpy as np
from torch import nn

test1 = torch.tensor(np.ones(10), dtype=torch.float32)  # nn.Linear weights are float32, so cast the float64 numpy array
print(test1)
# tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
shape = test1.size()[0]
l1 = nn.Linear(shape, shape)
test1 = l1(test1)
'''tensor([ 1.2093, 0.4491, -1.3389, -0.1974, -0.3614, -0.4381, 0.0361, 0.5686,
-0.3484, 0.0110], grad_fn=<AddBackward0>)'''
d = nn.Dropout(0.3)
test1 = d(test1)
'''
tensor([ 1.7276, 0.6416, -0.0000, -0.2820, -0.0000, -0.6258, 0.0000, 0.0000,
-0.4977, 0.0157], grad_fn=<MulBackward0>)
note:
In practice each element is zeroed independently with probability 0.3, so on a
single small tensor the observed fraction of dropped entries is only roughly
0.3; over a large number of elements it converges to 0.3. Also note that during
training the surviving elements are rescaled by 1/(1-p), e.g. 1.2093 / 0.7 ≈ 1.7276
above, which is why grad_fn shows MulBackward0.
'''
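To check both points empirically, here is a small sketch (variable names are mine) that measures the zero fraction on a large tensor and verifies that in eval() mode Dropout becomes the identity:

import torch
from torch import nn

d = nn.Dropout(0.3)               # a freshly built module is in training mode
big = torch.ones(1000000)
out = d(big)
print((out == 0).float().mean())  # close to 0.30 on a sample this large
print(out.max())                  # ≈ 1.4286 = 1/(1-0.3): survivors are rescaled

d.eval()                          # evaluation mode: dropout becomes a no-op
print(torch.equal(d(big), big))   # True, the input passes through unchanged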
l2 = nn.Linear(shape, shape)
test1 = l2(test1)
'''
tensor([ 0.2900, 0.2310, 0.0016, -0.0806, 0.5717, -0.3444, -0.4782, 0.0626,
0.5609, 0.5971], grad_fn=<AddBackward0>)
'''
print(l2.weight)
print(l2.weight.size())
'''
tensor([[-0.1172, 0.1366, -0.1985, -0.2270, 0.2187, 0.0618, 0.0044, 0.0445,
-0.2501, -0.1953],
[ 0.2652, -0.1562, -0.0418, -0.1766, 0.1065, -0.0898, 0.0416, 0.2394,
-0.0220, -0.1494],
[ 0.0389, 0.1469, -0.0870, 0.1732, 0.0536, -0.1480, -0.2159, -0.2245,
0.1741, 0.1966],
[-0.0640, 0.1407, -0.2929, -0.2074, -0.0870, 0.0739, 0.1317, 0.1397,
-0.1414, -0.1190],
[ 0.2213, -0.0548, -0.2381, -0.3103, -0.0315, 0.0421, 0.2332, -0.0250,
-0.0922, -0.2802],
[-0.2725, 0.1465, -0.0210, 0.2928, 0.2386, 0.1554, -0.0050, -0.2504,
-0.3003, 0.1500],
[-0.1510, -0.2754, 0.1225, -0.0894, -0.0776, 0.1202, 0.2340, -0.0242,
0.2875, 0.1247],
[ 0.2576, 0.2990, -0.2164, 0.1260, -0.2340, 0.2892, -0.1079, 0.0174,
0.2302, -0.0158],
[ 0.2372, 0.1228, 0.0886, 0.2946, 0.2048, -0.3042, -0.2103, -0.1585,
0.2601, -0.3085],
[ 0.1054, 0.2851, 0.2042, -0.2554, 0.0978, -0.2108, 0.2923, 0.1692,
-0.1682, -0.0952]], requires_grad=True)
torch.Size([10, 10])
'''
One small point worth recording: even without building an nn.Module, a layer created and used inline, as in
l2 = nn.Linear(shape, shape)
test1 = l2(test1)
still has its weight registered as a trainable Parameter (requires_grad=True, as the printout above shows). The forward call by itself does not update l2.weight, though; the weight only changes after backward() and an optimizer step, as the sketch below demonstrates.
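A minimal sketch of that point (the dummy loss and the SGD optimizer are my own choices, purely for illustration): the weight stays fixed across forward passes and only changes after an optimizer step:

import torch
from torch import nn

l2 = nn.Linear(10, 10)
x = torch.randn(10)
w_before = l2.weight.clone()

out = l2(x)                               # forward pass alone
print(torch.equal(l2.weight, w_before))   # True: forward does not touch the weight

loss = out.sum()                          # dummy scalar loss, illustration only
opt = torch.optim.SGD(l2.parameters(), lr=0.1)
loss.backward()                           # fills l2.weight.grad
opt.step()                                # only now does the weight change
print(torch.equal(l2.weight, w_before))   # False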