[pytorch conv2d BatchNorm融合]

from TinyViT

class ConvBN2d(Sequential):
    """An implementation of Conv2d + BatchNorm2d with support of fusion.

    Modified from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The size of the convolution kernel.
            Default: 1.
        stride (int): The stride of the convolution.
            Default: 1.
        padding (int): The padding of the convolution.
            Default: 0.
        dilation (int): The dilation of the convolution.
            Default: 1.
        groups (int): The number of groups in the convolution.
            Default: 1.
        bn_weight_init (float): The initial value of the weight of
            the nn.BatchNorm2d layer. Default: 1.0.
        init_cfg (dict): The initialization config of the module.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bn_weight_init=1.0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.add_module(
            'conv2d',
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=False))
        bn2d = nn.BatchNorm2d(num_features=out_channels)
        # bn initialization
        torch.nn.init.constant_(bn2d.weight, bn_weight_init)
        torch.nn.init.constant_(bn2d.bias, 0)
        self.add_module('bn2d', bn2d)

    @torch.no_grad()
    def fuse(self):
        conv2d, bn2d = self._modules.values()
        w = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5
        w = conv2d.weight * w[:, None, None, None]
        b = bn2d.bias - bn2d.running_mean * bn2d.weight / \
            (bn2d.running_var + bn2d.eps)**0.5

        m = nn.Conv2d(
            in_channels=w.size(1) * self.c.groups,
            out_channels=w.size(0),
            kernel_size=w.shape[2:],
            stride=self.conv2d.stride,
            padding=self.conv2d.padding,
            dilation=self.conv2d.dilation,
            groups=self.conv2d.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

放飞自我的Coder

你的鼓励很棒棒哦~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值