DropPath代码
DropPath代码
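最近在学习ViT模型，记录一下其中的DropPath操作：训练时，对一个batch中的每个sample独立地以drop_prob的概率将其全部特征值置为0，被保留的sample则除以keep_prob（即1 - drop_prob），使输出的期望与推理时保持一致；推理时（training为False）或drop_prob为0时，输入原样返回：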
ViT github源码地址链接
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of a batch (Stochastic Depth / DropPath).

    One keep/drop decision is drawn per entry along dimension 0 of ``x`` and
    broadcast over all remaining dimensions. Kept samples are scaled by
    ``1 / keep_prob`` so the expected value of the output equals the input.

    Args:
        x: Input tensor; dimension 0 is treated as the batch dimension.
        drop_prob: Probability of zeroing each sample, in ``[0, 1]``.
            ``None`` is treated like ``0.``.
        training: When ``False`` the input is returned unchanged.

    Returns:
        ``x`` itself when nothing can be dropped, otherwise a new tensor of
        the same shape with the dropped samples set to zero.

    Raises:
        ValueError: If ``drop_prob`` lies outside ``[0, 1]`` while dropping
            is active.
    """
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    if not 0. <= drop_prob <= 1.:
        raise ValueError(
            "drop_prob must be in the range [0, 1], got {}".format(drop_prob)
        )
    keep_prob = 1 - drop_prob
    # Mask shape (b, 1, 1, ...): works with tensors of any rank, not just
    # 2D ConvNets.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # keep_prob + U[0, 1) lies in [keep_prob, 1 + keep_prob); flooring it
    # yields 1 with probability keep_prob (sample kept) and 0 otherwise.
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    # Dividing by keep_prob keeps the expectation consistent between
    # training and inference. Skip it when keep_prob == 0: every sample is
    # dropped anyway, and x / 0 * 0 would turn the output into NaN.
    if keep_prob > 0.:
        random_tensor.div_(keep_prob)
    return x * random_tensor
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around ``drop_path`` that uses the module's own
    ``training`` flag to decide whether dropping is active.

    Args:
        drop_prob: Probability of zeroing each sample while the module is in
            training mode. ``None`` (the default) is treated as ``0.``, which
            makes the module a no-op.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Return ``drop_path(x, drop_prob, self.training)``."""
        # drop_path expects a numeric probability, so map the "unset" default
        # of None to 0. instead of passing it through.
        drop_prob = 0. if self.drop_prob is None else self.drop_prob
        return drop_path(x, drop_prob, self.training)