# Modified from TinyViT (https://github.com/microsoft/Cream/tree/main/TinyViT)
class ConvBN2d(Sequential):
    """An implementation of Conv2d + BatchNorm2d with support of fusion.

    Modified from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The size of the convolution kernel.
            Default: 1.
        stride (int): The stride of the convolution.
            Default: 1.
        padding (int): The padding of the convolution.
            Default: 0.
        dilation (int): The dilation of the convolution.
            Default: 1.
        groups (int): The number of groups in the convolution.
            Default: 1.
        bn_weight_init (float): The initial value of the weight of
            the nn.BatchNorm2d layer. Default: 1.0.
        init_cfg (dict): The initialization config of the module.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bn_weight_init=1.0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # No conv bias: the following BatchNorm2d provides the affine shift,
        # so a bias here would be redundant.
        self.add_module(
            'conv2d',
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=False))
        bn2d = nn.BatchNorm2d(num_features=out_channels)
        # bn initialization
        torch.nn.init.constant_(bn2d.weight, bn_weight_init)
        torch.nn.init.constant_(bn2d.bias, 0)
        self.add_module('bn2d', bn2d)

    @torch.no_grad()
    def fuse(self):
        """Fuse the Conv2d and BatchNorm2d into a single Conv2d.

        Folds the (inference-time) BatchNorm2d affine transform and running
        statistics into the convolution weights and a new bias, so the fused
        module computes the same function as ``bn2d(conv2d(x))`` in eval mode.

        Returns:
            nn.Conv2d: A new convolution layer (with bias) equivalent to
            this Conv2d + BatchNorm2d pair in eval mode.
        """
        conv2d, bn2d = self._modules.values()
        # Per-output-channel BN scale: gamma / sqrt(var + eps).
        w = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5
        w = conv2d.weight * w[:, None, None, None]
        # Fused bias: beta - mean * gamma / sqrt(var + eps).
        b = bn2d.bias - bn2d.running_mean * bn2d.weight / \
            (bn2d.running_var + bn2d.eps)**0.5
        # BUGFIX: was ``self.c.groups`` — no submodule named ``c`` exists
        # (modules are registered as 'conv2d'/'bn2d'), which raised
        # AttributeError. Use the conv layer bound above instead. Note that
        # for grouped convs, w.size(1) is in_channels // groups, hence the
        # multiplication to recover the true in_channels.
        m = nn.Conv2d(
            in_channels=w.size(1) * conv2d.groups,
            out_channels=w.size(0),
            kernel_size=w.shape[2:],
            stride=conv2d.stride,
            padding=conv2d.padding,
            dilation=conv2d.dilation,
            groups=conv2d.groups,
            device=conv2d.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m