(a)原始输入图像
(b)绿色部分表示激活的特征单元。b图表示随机dropout掉部分激活单元，但这样dropout后，由于卷积特征在空间上高度相关，网络仍能从被dropout掉的激活单元的相邻单元中学习到同样的信息。
(c)绿色部分表示激活的特征单元。c图表示本文提出的DropBlock：通过dropout掉一整片相邻的区域（比如头和脚），迫使网络去学习狗的其他部位的特征来实现正确分类，从而获得更好的泛化能力。
code:
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
class DropBlock2D(nn.Module):
    r"""Randomly zeroes 2D spatial blocks of the input tensor.

    As described in the paper
    `DropBlock: A regularization method for convolutional networks`_ ,
    dropping whole blocks of feature map allows to remove semantic
    information as compared to regular dropout.

    One spatial block mask is sampled per sample and shared by all of its
    channels. Surviving activations are rescaled by
    ``total_positions / kept_positions`` so the expected magnitude is
    preserved. In eval mode, or when ``drop_prob == 0``, the input is
    returned unchanged.

    Args:
        drop_prob (float): probability of an element to be dropped.
        block_size (int): size of the block to drop

    Shape:
        - Input: `(N, C, H, W)`
        - Output: `(N, C, H, W)`

    .. _DropBlock: A regularization method for convolutional networks:
       https://arxiv.org/abs/1810.12890
    """

    def __init__(self, drop_prob, block_size):
        super().__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        """Apply DropBlock to ``x`` when training.

        Args:
            x: 4-D tensor ``(bsize, channels, height, width)``.

        Returns:
            ``x`` itself in eval mode or when ``drop_prob == 0``. Otherwise
            a new tensor of the same shape and dtype, with dropped blocks
            zeroed and the remaining values rescaled.

        Raises:
            AssertionError: if ``x`` is not 4-dimensional.
        """
        # shape: (bsize, channels, height, width)
        assert x.dim() == 4, \
            "Expected input with 4 dimensions (bsize, channels, height, width)"

        if not self.training or self.drop_prob == 0.:
            return x

        gamma = self._compute_gamma(x)

        # Keep the mask in the input's floating dtype so half/bf16 activations
        # are not promoted to float32. Non-float inputs fall back to float32,
        # because max_pool2d needs a floating mask.
        mask_dtype = x.dtype if x.is_floating_point() else torch.float32
        # Sample one seed map per sample (shared across channels), directly on
        # the input's device to avoid a host->device copy every call.
        seeds = torch.rand(x.shape[0], *x.shape[2:], device=x.device)
        mask = (seeds < gamma).to(mask_dtype)

        block_mask = self._compute_block_mask(mask)
        out = x * block_mask[:, None, :, :]

        # Count kept positions in float32 so a half-precision sum cannot
        # overflow. Clamp to >= 1: if every position was dropped, the output
        # is all zeros instead of 0 * inf = NaN.
        kept = block_mask.sum(dtype=torch.float32).clamp(min=1.0)
        return out * (block_mask.numel() / kept)

    def _compute_block_mask(self, mask):
        """Expand each seed in ``mask`` into a ``block_size`` square.

        Args:
            mask: ``(bsize, height, width)`` tensor with 1 at seed positions.

        Returns:
            ``(bsize, height, width)`` tensor with 0 inside dropped blocks and
            1 elsewhere.
        """
        # A stride-1 max-pool marks every position within block_size // 2 of
        # a seed, i.e. it dilates each seed into a block.
        block_mask = F.max_pool2d(input=mask[:, None, :, :],
                                  kernel_size=(self.block_size, self.block_size),
                                  stride=(1, 1),
                                  padding=self.block_size // 2)
        # For an even kernel, this padding yields one extra row and column;
        # trim them to restore (height, width).
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1]
        return 1 - block_mask.squeeze(1)

    def _compute_gamma(self, x):
        """Return the per-position probability of sampling a block seed.

        Seeds are sampled over the whole feature map, and each zeroes up to
        ``block_size ** 2`` positions. Dividing by that area makes the
        expected dropped fraction roughly ``drop_prob`` before block overlap.
        This is a simplification of the paper's formula, which additionally
        multiplies by ``feat_size**2 / (feat_size - block_size + 1)**2``.
        ``x`` is unused.
        """
        return self.drop_prob / (self.block_size ** 2)
if __name__ == "__main__":
    # Smoke run: push a random batch of feature maps through DropBlock2D
    # and report the resulting shape (it should match the input).
    batch, n_feats, height, width = 10, 2, 8, 8
    features = torch.rand(batch, n_feats, height, width)
    regularizer = DropBlock2D(drop_prob=0.3, block_size=3)
    print(regularizer(features).shape)