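I. How the dropout network works:
Dropout is used during deep network training to prevent overfitting. In the per-sample form implemented below (DropPath), each sample image in a batch has probability drop_prob of being left out of the model's computation, and every value in the tensor of a left-out image is set to 0. The steps are as follows (explained with reference to the code below):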
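1. Take a batch of images of shape 4X2X3X3 (4 images, each of size 3X3X2), and set the drop probability drop_prob=0.5;
2. Compute the keep probability keep_prob = 1 - drop_prob;
3. Build the shape of the per-sample mask, which here is 4X1X1X1;
4. Sample from a Bernoulli distribution (bernoulli_) to choose at random which images are kept and which are dropped;
5. Divide the values of the kept images by keep_prob (ordinary division, not floor division), and multiply the values of the dropped images by 0.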
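The five steps above can be sketched one line at a time. This is a minimal illustration rather than the original implementation: the variable names (mask_shape, mask, out) and the fixed seed are assumptions added for the example.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed so repeated runs pick the same samples

# Step 1: a batch of 4 samples, each 2X3X3, plus the drop probability.
x = torch.arange(72, dtype=torch.float32).reshape(4, 2, 3, 3)
drop_prob = 0.5

# Step 2: probability that a sample is kept.
keep_prob = 1 - drop_prob

# Step 3: mask shape (4, 1, 1, 1): one entry per sample, broadcast over the other dims.
mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)

# Step 4: draw one 0/1 value per sample from Bernoulli(keep_prob).
mask = torch.empty(mask_shape).bernoulli_(keep_prob)

# Step 5: kept samples are divided by keep_prob, dropped samples become all zeros.
out = x * mask / keep_prob

for i in range(x.shape[0]):
    status = "kept" if mask[i].item() == 1 else "dropped"
    print(f"sample {i}: {status}, max value {out[i].max().item()}")
```

Because the mask has one entry per sample, a whole image is either scaled or zeroed; individual pixels are never dropped on their own.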
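II. Code walkthrough:
Note: when running this code, assume the input is a 4X2X3X3 tensor. If it holds images, 4 is the batch size and 2X3X3 is the size of each image (3X3X2); for everything else, see the comments in the code.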
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn


# drop_prob: the fraction to drop (0,1). In a drop_path branch, each sample in the batch has
# probability drop_prob of not being "executed" by self.conv(x); it is passed on as zeros instead.
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    # If drop_prob == 0. or not training, return x itself, i.e. leave x untouched.
    if drop_prob == 0. or not training:
        return x
    # The fraction to keep (0,1).
    keep_prob = 1 - drop_prob
    # Shape of the per-sample mask: (4, 1, 1, 1) for a 4X2X3X3 input.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    print(shape)
    # bernoulli_ fills the tensor in place with draws from a Bernoulli distribution with parameter
    # keep_prob. random_tensor may come out as tensor([[[[1.]]],[[[1.]]],[[[0.]]],[[[1.]]]]);
    # the 0s and 1s are assigned at random.
    # That output covers 4 samples: samples 1, 2 and 4 are kept, while every element of sample 3
    # is set to 0, so sample 3 takes no part in the later computation.
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    # print(random_tensor)
    if keep_prob > 0.0 and scale_by_keep:
        # In-place true division: random_tensor = random_tensor / keep_prob (not floor division).
        random_tensor.div_(keep_prob)
        # print(random_tensor)
    # Kept samples are divided by the keep probability; dropped samples become all zeros.
    return x * random_tensor


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=0.5, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


if __name__ == "__main__":
    # Build a 4X2X3X3 NumPy array.
    np_data = np.arange(72).reshape((4, 2, 3, 3))
    # array ----> tensor
    datas = torch.tensor(np_data, dtype=torch.float32)
    print(datas.shape)
    # Call the DropPath module. A newly constructed nn.Module is in training mode,
    # so samples are actually dropped here.
    Drop = DropPath()
    output = Drop(datas)
    print("datas:", datas)
    print("output:", output)
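
Why divide the kept samples by keep_prob? With that scaling, the average output over many random draws equals the input, so the following layers see activations on the same scale during training and inference. The sketch below checks this on a single number using only the Python standard library. The helper drop_path_scalar is hypothetical and written for this illustration; it is not part of the code above.

```python
import random


def drop_path_scalar(x, drop_prob, rng):
    """Hypothetical scalar version of drop_path: keep x with probability keep_prob."""
    keep_prob = 1.0 - drop_prob
    keep = 1.0 if rng.random() < keep_prob else 0.0  # one Bernoulli(keep_prob) draw
    return x * keep / keep_prob  # true division, as div_ does in the tensor code


rng = random.Random(0)  # assumption: fixed seed so the run is repeatable
samples = [drop_path_scalar(3.0, 0.5, rng) for _ in range(100_000)]
mean = sum(samples) / len(samples)

print("distinct outputs:", sorted(set(samples)))  # only 0.0 (dropped) and 6.0 (kept, scaled)
print("mean output:", mean)  # close to the input value 3.0
```

Each draw returns either 0 or x / keep_prob, and the average settles near x. Floor division would round the scaled mask down and would not preserve this property in general.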