pytorch:融合conv和bn

from copy import deepcopy
import torch
import torch.nn as nn

def fuse_conv_and_bn(conv, bn):
    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    # prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # prepare spatial bias
    b_conv = (
        torch.zeros(conv.weight.size(0), device=conv.weight.device)
        if conv.bias is None
        else conv.bias
    )
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
        torch.sqrt(bn.running_var + bn.eps)
    )
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv


def fuse_model(model):
    from yolox.models.network_blocks import BaseConv

    for m in model.modules():
        if type(m) is BaseConv and hasattr(m, "bn"):
            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
            delattr(m, "bn")  # remove batchnorm
            m.forward = m.fuseforward  # update forward
    return model


def replace_module(module, replaced_module_type, new_module_type, replace_func=None):
    """
    Replace given type in module to a new type. mostly used in deploy.

    Args:
        module (nn.Module): model to apply replace operation.
        replaced_module_type (Type): module type to be replaced.
        new_module_type (Type)
        replace_func (function): python function to describe replace logic. Defalut value None.

    Returns:
        model (nn.Module): module that already been replaced.
    """

    def default_replace_func(replaced_module_type, new_module_type):
        return new_module_type()

    if replace_func is None:
        replace_func = default_replace_func

    model = module
    if isinstance(module, replaced_module_type):
        model = replace_func(replaced_module_type, new_module_type)
    else:  # recurrsively replace
        for name, child in module.named_children():
            new_child = replace_module(child, replaced_module_type, new_module_type)
            if new_child is not child:  # child is already replaced
                model.add_module(name, new_child)

    return model



if __name__ == '__main__':
    from torchvision.models import resnet18
    model = resnet18(True)
    fusedconv = fuse_conv_and_bn(model.conv1, model.bn1)
    delattr(model, "bn1")
    model.add_module("conv1", fusedconv)
    print(model)
    


  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是一个简单的DCA特征融合pytorch代码实现: ```python import torch import torch.nn as nn class DCA(nn.Module): def __init__(self, in_channels): super(DCA, self).__init__() self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(in_channels) self.relu1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(in_channels) self.relu2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) self.bn3 = nn.BatchNorm2d(in_channels) self.relu3 = nn.ReLU(inplace=True) def forward(self, x): identity = x out = self.conv1(x) out = self.bn1(out) out = self.relu1(out) out = self.conv2(out) out = self.bn2(out) out = self.relu2(out) out = self.conv3(out) out = self.bn3(out) out = out + identity out = self.relu3(out) return out class DCA_Fusion(nn.Module): def __init__(self, in_channels, num_branches): super(DCA_Fusion, self).__init__() self.num_branches = num_branches self.dca_layers = nn.ModuleList() for i in range(num_branches): self.dca_layers.append(DCA(in_channels)) def forward(self, x): out = x[0] for i in range(1, self.num_branches): out = out + x[i] for i in range(self.num_branches): x[i] = self.dca_layers[i](x[i]) out = out / self.num_branches out = torch.cat(x, dim=1) return out ``` 在这个实现中,我们首先定义了一个DCA模块,它包含了三个卷积BN层,每个卷积后面都有ReLU激活函数。然后我们定义了一个DCA_Fusion模块,它包含了多个DCA模块,并且实现了特征融合的功能。在forward函数中,我们首先对输入的多个特征图进行求和,然后将每个特征图都输入到对应的DCA模块中进行特征增强,最后将增强后的特征图进行拼接。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值