DBB Code Learning

Original text

README.md

Diverse Branch Block: Building a Convolution as an Inception-like Unit
DBB is a building block for ConvNets (Convolutional Neural Networks) that can replace regular convolutions, improving accuracy without increasing inference-time cost.
You can also get the equivalent kernel and bias in a differentiable way at any time (get_equivalent_kernel_bias in diversebranchblock.py).

Abstract

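1. Proposes DBB, a CNN building block that replaces regular convolutions in a network and improves accuracy without increasing inference-time cost;
2. DBB relies on six transformations that convert its multi-branch training-time structure into a single convolution;
3. It is robust, achieving good results in classification, object detection, and segmentation.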

Use our pretrained models

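1. The models used in the paper's experiments can be downloaded from Baidu Cloud (https://pan.baidu.com/s/1wPaQnLKyNjF_bEMNRo4z6Q, access code "dbbk").
2. To make transfer learning to other tasks easier, both training-time and inference-time models are provided. For example, with ResNet-18 (IMGNET_PATH is the root directory that contains "train" and "val"):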

python test.py IMGNET_PATH train ResNet-18_DBB_7101.pth -a ResNet-18 -t DBB

Convert the training-time models into inference-time

Use convert.py to convert a training-time model into an inference-time model:

python convert.py [weights file of the training-time model to load] [path to save] -a [architecture name]
# example
python convert.py ResNet-18_DBB_7101.pth ResNet-18_DBB_7101_deploy.pth -a ResNet-18

Then run test.py on the inference-time model. The argument "deploy" tells it to build the inference-time structure.

python test.py IMGNET_PATH deploy ResNet-18_DBB_7101_deploy.pth -a ResNet-18 -t DBB

ImageNet training

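Compared with the official PyTorch ImageNet example, this project changes the model-building code and adds a cosine learning-rate schedule, among other things. Example training and test commands: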

python train.py -a ResNet-18 -t DBB --dist-url tcp://127.0.0.1:23333 --dist-backend nccl --multiprocessing-distributed --world-size 1 --rank 0 --workers 64 IMGNET_PATH
python test.py IMGNET_PATH train model_best.pth.tar -a ResNet-18

Use like this in your own code

Suppose your model looks like this:

class SomeModel(nn.Module):
    def __init__(self, ...):
        ...
        self.some_conv = nn.Conv2d(...)
        self.some_bn = nn.BatchNorm2d(...)
        ...
        
    def forward(self, inputs):
        out = ...
        out = self.some_bn(self.some_conv(out))
        ...

For training, you can replace the conv-BN with a DBB, so your model becomes:

class SomeModel(nn.Module):
    def __init__(self, ...):
        ...
        self.some_dbb = DiverseBranchBlock(..., deploy=False)
        ...
        
    def forward(self, inputs):
        out = ...
        out = self.some_dbb(out)
        ...

Train the model like any regular model. Afterwards, call switch_to_deploy on every DBB, then test and save:

model = SomeModel(...)
train(model)
for m in model.modules():
    if hasattr(m, 'switch_to_deploy'):
        m.switch_to_deploy()
test(model)
save(model)

FAQs

Q: Is the inference-time model’s output the same as the training-time model?
A: Yes. You can verify this with:

python dbb_verify.py

Q: What is the relationship between DBB and RepVGG?
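A: RepVGG is a plain architecture, and a RepVGG block shows no advantage over a single 3x3 conv (as reported in the RepVGG paper, it improves ResNet-50 by only 0.03%). DBB is a universal building block that can be used on numerous architectures.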

Structural Re-parameterization Universe

RepLKNet (CVPR 2022): https://arxiv.org/abs/2203.06717
RepOptimizer (ICLR 2023): https://arxiv.org/pdf/2205.15242.pdf
RepVGG (CVPR 2021): https://arxiv.org/abs/2101.03697
RepMLP (CVPR 2022): https://arxiv.org/abs/2112.11081
ResRep (ICCV 2021): https://openaccess.thecvf.com/content/ICCV2021/papers/Ding_ResRep_Lossless_CNN_Pruning_via_Decoupling_Remembering_and_Forgetting_ICCV_2021_paper.pdf
ACB (ICCV 2019): http://openaccess.thecvf.com/content_ICCV_2019/papers/Ding_ACNet_Strengthening_the_Kernel_Skeletons_for_Powerful_CNN_via_Asymmetric_ICCV_2019_paper.pdf
DBB (CVPR 2021): https://arxiv.org/abs/2103.13425

acb.py

class ACBlock(nn.Module)

Parameters

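in_channels, number of input channels
out_channels, number of output channels
kernel_size, kernel size K
stride=1, stride
padding=0, zero padding
dilation=1, dilation rate
groups=1, number of groups (group convolution)
padding_mode='zeros', padding mode
deploy=False, if True, build only the fused conv used at inference time
use_affine=True, whether the BN layers have learnable affine parameters
reduce_gamma=False, if True, initialize every BN weight (gamma) to 1/3
gamma_init=None, if given, initialize every BN weight to this value (cannot be combined with reduce_gamma)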

deploy

    # If deploy=True, build a single plain conv (with bias) from the input arguments
    # If deploy=False, build a KxK conv-BN plus the Kx1 and 1xK conv-BN branches
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False,
                 use_affine=True, reduce_gamma=False, gamma_init=None ):
        super(ACBlock, self).__init__()
        self.deploy = deploy
        if deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size,kernel_size), stride=stride,
                                      padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=(kernel_size, kernel_size), stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=False,
                                         padding_mode=padding_mode)
            self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)


            if padding - kernel_size // 2 >= 0:
                #   Common use case. E.g., k=3, p=1 or k=5, p=2
                self.crop = 0
                #   Compared to the KxK layer, the padding of the 1xK layer and Kx1 layer should be adjusted to align the sliding windows (Fig 2 in the paper)
                hor_padding = [padding - kernel_size // 2, padding]
                ver_padding = [padding, padding - kernel_size // 2]
            else:
                #   A negative "padding" (padding - kernel_size//2 < 0, which is not a common use case) is cropping.
                #   Since nn.Conv2d does not support negative padding, we implement it manually
                self.crop = kernel_size // 2 - padding
                hor_padding = [0, padding]
                ver_padding = [padding, 0]

            self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, 1),
                                      stride=stride,
                                      padding=ver_padding, dilation=dilation, groups=groups, bias=False,
                                      padding_mode=padding_mode)

            self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, kernel_size),
                                      stride=stride,
                                      padding=hor_padding, dilation=dilation, groups=groups, bias=False,
                                      padding_mode=padding_mode)
            self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
            self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

            if reduce_gamma:
                self.init_gamma(1.0 / 3)

            if gamma_init is not None:
                assert not reduce_gamma
                self.init_gamma(gamma_init)

others

    def _fuse_bn_tensor(self, conv, bn):
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, bn.bias - bn.running_mean * bn.weight / std

    def _add_to_square_kernel(self, square_kernel, asym_kernel):
        asym_h = asym_kernel.size(2)
        asym_w = asym_kernel.size(3)
        square_h = square_kernel.size(2)
        square_w = square_kernel.size(3)
        square_kernel[:, :, square_h // 2 - asym_h // 2: square_h // 2 - asym_h // 2 + asym_h,
        square_w // 2 - asym_w // 2: square_w // 2 - asym_w // 2 + asym_w] += asym_kernel

    def get_equivalent_kernel_bias(self):
        hor_k, hor_b = self._fuse_bn_tensor(self.hor_conv, self.hor_bn)
        ver_k, ver_b = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)
        square_k, square_b = self._fuse_bn_tensor(self.square_conv, self.square_bn)
        self._add_to_square_kernel(square_k, hor_k)
        self._add_to_square_kernel(square_k, ver_k)
        return square_k, hor_b + ver_b + square_b

    def switch_to_deploy(self):
        deploy_k, deploy_b = self.get_equivalent_kernel_bias()
        self.deploy = True
        self.fused_conv = nn.Conv2d(in_channels=self.square_conv.in_channels,
                                    out_channels=self.square_conv.out_channels,
                                    kernel_size=self.square_conv.kernel_size, stride=self.square_conv.stride,
                                    padding=self.square_conv.padding, dilation=self.square_conv.dilation,
                                    groups=self.square_conv.groups, bias=True,
                                    padding_mode=self.square_conv.padding_mode)
        self.__delattr__('square_conv')
        self.__delattr__('square_bn')
        self.__delattr__('hor_conv')
        self.__delattr__('hor_bn')
        self.__delattr__('ver_conv')
        self.__delattr__('ver_bn')
        self.fused_conv.weight.data = deploy_k
        self.fused_conv.bias.data = deploy_b

    def init_gamma(self, gamma_value):
        init.constant_(self.square_bn.weight, gamma_value)
        init.constant_(self.ver_bn.weight, gamma_value)
        init.constant_(self.hor_bn.weight, gamma_value)
        print('init gamma of square, ver and hor as ', gamma_value)

    def single_init(self):
        init.constant_(self.square_bn.weight, 1.0)
        init.constant_(self.ver_bn.weight, 0.0)
        init.constant_(self.hor_bn.weight, 0.0)
        print('init gamma of square as 1, ver and hor as 0')

    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.square_conv(input)
            square_outputs = self.square_bn(square_outputs)
            if self.crop > 0:
                ver_input = input[:, :, :, self.crop:-self.crop]
                hor_input = input[:, :, self.crop:-self.crop, :]
            else:
                ver_input = input
                hor_input = input
            vertical_outputs = self.ver_conv(ver_input)
            vertical_outputs = self.ver_bn(vertical_outputs)
            horizontal_outputs = self.hor_conv(hor_input)
            horizontal_outputs = self.hor_bn(horizontal_outputs)
            result = square_outputs + vertical_outputs + horizontal_outputs
            return result
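get_equivalent_kernel_bias relies on the fact that a 1xK and a Kx1 convolution can be absorbed into the center row and column of a KxK kernel. The sketch below is a standalone illustration, not code from the repository. It assumes stride 1, K=3, and padding=1, and it leaves BN out, because BN fusion is a separate per-channel scaling handled by _fuse_bn_tensor. It adds the asymmetric kernels into the square kernel, as _add_to_square_kernel does, and compares the result with running the three branches separately.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)
square_k = torch.randn(6, 4, 3, 3)
hor_k = torch.randn(6, 4, 1, 3)   # 1xK kernel
ver_k = torch.randn(6, 4, 3, 1)   # Kx1 kernel

# Three-branch output; paddings follow ACBlock with p=1, K=3:
# hor: [p - K//2, p] = (0, 1), ver: [p, p - K//2] = (1, 0)
branches = (F.conv2d(x, square_k, padding=1)
            + F.conv2d(x, hor_k, padding=(0, 1))
            + F.conv2d(x, ver_k, padding=(1, 0)))

# Merge: add the asymmetric kernels into the center row / column of the square kernel
merged = square_k.clone()
merged[:, :, 1:2, :] += hor_k
merged[:, :, :, 1:2] += ver_k
fused = F.conv2d(x, merged, padding=1)

print(torch.allclose(branches, fused, atol=1e-5))
```

This only works because the paddings are chosen so that all three sliding windows share the same center, which is exactly what the hor_padding / ver_padding logic in __init__ arranges.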

alexnet.py

Defines AlexNet.

import torch.nn as nn
import torch.nn.functional as F
from convnet_utils import conv_bn, conv_bn_relu

def create_stem(channels):
    stem = nn.Sequential()
    stem.add_module('conv1', conv_bn_relu(in_channels=3, out_channels=channels[0], kernel_size=11, stride=4, padding=2))
    stem.add_module('maxpool1', nn.MaxPool2d(kernel_size=3, stride=2))
    stem.add_module('conv2', conv_bn_relu(in_channels=channels[0], out_channels=channels[1], kernel_size=5, padding=2))
    stem.add_module('maxpool2', nn.MaxPool2d(kernel_size=3, stride=2))
    stem.add_module('conv3', conv_bn_relu(in_channels=channels[1], out_channels=channels[2], kernel_size=3, padding=1))
    stem.add_module('conv4', conv_bn_relu(in_channels=channels[2], out_channels=channels[3], kernel_size=3, padding=1))
    stem.add_module('conv5', conv_bn_relu(in_channels=channels[3], out_channels=channels[4], kernel_size=3, padding=1))
    stem.add_module('maxpool3', nn.MaxPool2d(kernel_size=3, stride=2))
    return stem

class AlexNet(nn.Module):

    def __init__(self):
        super(AlexNet, self).__init__()
        channels = [64, 192, 384, 384, 256]
        self.stem = create_stem(channels)
        self.linear1 = nn.Linear(in_features=channels[4] * 6 * 6, out_features=4096)
        self.relu1 = nn.ReLU()
        self.drop1 = nn.Dropout(0.5)
        self.linear2 = nn.Linear(in_features=4096, out_features=4096)
        self.relu2 = nn.ReLU()
        self.drop2 = nn.Dropout(0.5)
        self.linear3 = nn.Linear(in_features=4096, out_features=1000)

    def forward(self, x):
        out = self.stem(x)
        out = out.view(out.size(0), -1)
        out = self.linear1(out)
        out = self.relu1(out)
        out = self.drop1(out)
        out = self.linear2(out)
        out = self.relu2(out)
        out = self.drop2(out)
        out = self.linear3(out)
        return out

def create_AlexNet():
    return AlexNet()

dbb_transforms.py

Code for the six DBB transforms.
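For reference, here is what the first three transforms compute, read directly off transI_fusebn, transII_addbranch and transIII_1x1_kxk below. For a conv kernel W followed by BN with running mean μ, running variance σ², weight γ, bias β and epsilon ε, the per-output-channel formulas are as follows. For Transform III (groups = 1), the formula is exact only when the padding before the KxK conv is filled with the first branch's per-channel output for a zero input rather than with zeros, which is what BNAndPadLayer does.

```latex
% Transform I (transI_fusebn), output channel j
W'_{j} = \frac{\gamma_j}{\sqrt{\sigma_j^2 + \epsilon}}\, W_{j}, \qquad
b'_{j} = \beta_j - \frac{\mu_j \gamma_j}{\sqrt{\sigma_j^2 + \epsilon}}

% Transform II (transII_addbranch): parallel branches with the same configuration
W' = \sum_i W_i, \qquad b' = \sum_i b_i

% Transform III (transIII_1x1_kxk, groups = 1):
% 1x1 kernel K^{(1)} of shape D x C x 1 x 1 with bias b^{(1)},
% followed by KxK kernel K^{(2)} of shape E x D x K x K with bias b^{(2)}
W'_{e,c,h,w} = \sum_{d=1}^{D} K^{(2)}_{e,d,h,w}\, K^{(1)}_{d,c}, \qquad
b'_{e} = \sum_{d,h,w} K^{(2)}_{e,d,h,w}\, b^{(1)}_{d} + b^{(2)}_{e}
```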

import torch
import numpy as np
import torch.nn.functional as F


def transI_fusebn(kernel, bn):
    gamma = bn.weight
    std = (bn.running_var + bn.eps).sqrt()
    return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std


def transII_addbranch(kernels, biases):
    return sum(kernels), sum(biases)


def transIII_1x1_kxk(k1, b1, k2, b2, groups):
    if groups == 1:
        k = F.conv2d(k2, k1.permute(1, 0, 2, 3))  # merged kernel: k[e, c] = sum_d k2[e, d] * k1[d, c]
        b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3))
    else:
        k_slices = []
        b_slices = []
        k1_T = k1.permute(1, 0, 2, 3)
        k1_group_width = k1.size(0) // groups
        k2_group_width = k2.size(0) // groups
        for g in range(groups):
            k1_T_slice = k1_T[:, g * k1_group_width:(g + 1) * k1_group_width, :, :]
            k2_slice = k2[g * k2_group_width:(g + 1) * k2_group_width, :, :, :]
            k_slices.append(F.conv2d(k2_slice, k1_T_slice))
            b_slices.append(
                (k2_slice * b1[g * k1_group_width:(g + 1) * k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3)))
        k, b_hat = transIV_depthconcat(k_slices, b_slices)
    return k, b_hat + b2


def transIV_depthconcat(kernels, biases):
    return torch.cat(kernels, dim=0), torch.cat(biases)


def transV_avg(channels, kernel_size, groups):
    input_dim = channels // groups
    k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
    k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
    return k


#   This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels
def transVI_multiscale(kernel, target_kernel_size):
    H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
    W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
    # F.pad pads the last dimension (W) first, then H
    return F.pad(kernel, [W_pixels_to_pad, W_pixels_to_pad, H_pixels_to_pad, H_pixels_to_pad])
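The transforms can be checked numerically against the layers they replace. The sketch below re-declares local copies of transI_fusebn and transIII_1x1_kxk (the latter restricted to groups=1) so that it runs without the repository's files. The shapes, random statistics and tolerances are arbitrary choices made for this illustration. BN is used in eval mode because fusion is only valid with running statistics.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Local copies of the transforms listed above (transIII restricted to groups=1)
def fuse_bn(kernel, bn):
    std = (bn.running_var + bn.eps).sqrt()
    t = (bn.weight / std).reshape(-1, 1, 1, 1)
    return kernel * t, bn.bias - bn.running_mean * bn.weight / std

def merge_1x1_kxk(k1, b1, k2, b2):
    k = F.conv2d(k2, k1.permute(1, 0, 2, 3))
    b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3))
    return k, b_hat + b2

x = torch.randn(2, 4, 10, 10)

# Transform I: conv (no bias) followed by BN in eval mode == one conv with bias
conv = nn.Conv2d(4, 6, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(6)
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
nn.init.uniform_(bn.weight, 0.5, 1.5)
nn.init.uniform_(bn.bias, -0.5, 0.5)
bn.eval()
with torch.no_grad():
    k, b = fuse_bn(conv.weight, bn)
    ok_I = torch.allclose(bn(conv(x)), F.conv2d(x, k, b, padding=1), atol=1e-5)

# Transform III: 1x1 conv then KxK conv (no padding) == one KxK conv
k1, b1 = torch.randn(5, 4, 1, 1), torch.randn(5)
k2, b2 = torch.randn(6, 5, 3, 3), torch.randn(6)
k, b = merge_1x1_kxk(k1, b1, k2, b2)
ok_III = torch.allclose(F.conv2d(F.conv2d(x, k1, b1), k2, b2), F.conv2d(x, k, b), atol=1e-4)

# Transform V: KxK average pooling == conv whose kernel is 1/K^2 on the matching channel
kernel_size, channels = 3, 4
k_avg = torch.zeros(channels, channels, kernel_size, kernel_size)
k_avg[torch.arange(channels), torch.arange(channels)] = 1.0 / kernel_size ** 2
ok_V = torch.allclose(F.avg_pool2d(x, kernel_size, stride=1, padding=1),
                      F.conv2d(x, k_avg, padding=1), atol=1e-6)

# Transform VI: a 1x1 kernel zero-padded to KxK (with padding K//2) == the 1x1 conv
k_small = torch.randn(6, 4, 1, 1)
k_big = F.pad(k_small, [1, 1, 1, 1])
ok_VI = torch.allclose(F.conv2d(x, k_small), F.conv2d(x, k_big, padding=1), atol=1e-5)

print(ok_I, ok_III, ok_V, ok_VI)
```

Transform V matches here because nn.AvgPool2d and F.avg_pool2d default to count_include_pad=True, so padded zeros are counted exactly as a zero-padded convolution counts them.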

convnet_utils.py

Defines some basic helpers that are used in convert.py.

Supplement (group convolution)

Regular convolution

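First, let's be clear about what a regular convolution does. With the convolution defined below (3 input channels, 4 output channels), there are four kernels, and each kernel has as many channels as the input. Each input channel is convolved with the matching channel of a kernel, and the per-channel results are summed to produce one feature map. The remaining feature maps are produced the same way. That is a regular convolution.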

nn.Conv2d(in_channels=3,
          out_channels=4,
          kernel_size=3)
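The description above can be checked directly in PyTorch. This small sketch inspects the weight shape of that exact layer and rebuilds output channel 0 by hand, convolving each input channel with the matching kernel channel and summing. The input size and seed are arbitrary, and bias is disabled so that the sum is the whole output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, bias=False)
print(conv.weight.shape)  # four kernels, each with 3 channels: torch.Size([4, 3, 3, 3])

x = torch.randn(1, 3, 5, 5)
out = conv(x)

# Rebuild output channel 0: convolve each input channel with the matching
# channel of kernel 0, then sum the three single-channel results
manual = sum(
    F.conv2d(x[:, c:c + 1], conv.weight[0:1, c:c + 1]) for c in range(3)
)
print(torch.allclose(out[:, 0:1], manual, atol=1e-5))
```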

Group convolution

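In the original figure (not reproduced here), the left side shows a regular convolution and the right side a group convolution, and the difference is visually obvious. As the name suggests, group convolution splits the channels into groups, whereas regular convolution convolves across all channels directly (the actual computation is of course more involved). Each small red stacked block is one kernel.
The figure shows a feature map going from 12 channels to 6 channels via a regular convolution and via a group convolution. Note first that the number of kernels is the same in both cases. However, with three groups as in the figure, each group-convolution kernel has one third as many channels as a regular-convolution kernel.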
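The same 12-to-6-channel example can be reproduced with nn.Conv2d's groups argument. The sketch below compares weight shapes and then confirms that a grouped conv is just independent convolutions over channel slices, concatenated. With groups=3, group g reads input channels 4g..4g+3 and writes output channels 2g..2g+1. The sizes come from the figure description; everything else is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
regular = nn.Conv2d(12, 6, kernel_size=3, bias=False)
grouped = nn.Conv2d(12, 6, kernel_size=3, groups=3, bias=False)

# Same number of kernels (6), but each grouped kernel sees only 12 / 3 = 4 channels
print(regular.weight.shape)  # torch.Size([6, 12, 3, 3])
print(grouped.weight.shape)  # torch.Size([6, 4, 3, 3])

# A grouped conv equals three independent convs on channel slices, concatenated
x = torch.randn(1, 12, 8, 8)
out = grouped(x)
parts = [
    F.conv2d(x[:, 4 * g:4 * (g + 1)], grouped.weight[2 * g:2 * (g + 1)])
    for g in range(3)
]
print(torch.allclose(out, torch.cat(parts, dim=1), atol=1e-5))
```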

Depthwise separable convolution

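Going one step further, the depthwise convolution inside the depthwise separable convolution of the MobileNet series is simply a group convolution with groups = input channels = output channels. For a 3-channel input, this amounts to cutting each 3-channel kernel down to a 1-channel kernel while keeping 3 kernels.
A depthwise separable convolution consists of a Depthwise Convolution followed by a Pointwise Convolution.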

Depthwise Convolution

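Depthwise Convolution is a group convolution in which each kernel handles exactly one channel and each channel is convolved by exactly one kernel. As a result, the output feature map has the same number of channels as the input.
Take a figure from Zhihu as an example: the input is a 3-channel color image. Depthwise Convolution does not change the number of channels, so the output is a 3-channel feature map.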

Pointwise Convolution

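The weakness of Depthwise Convolution follows from how it computes. Each input channel is processed independently, so features from different channels at the same spatial position are never combined; this is in fact a general weakness of group convolution. MobileNet compensates for it with a Pointwise Convolution.
A Pointwise Convolution works like a regular convolution, but with a 1x1 kernel. It therefore forms weighted combinations of the previous feature maps along the depth (channel) dimension to produce new features.
Stripped of the fancy name, a Pointwise Convolution is just the 1x1 conv we commonly use to change the number of channels.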
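Putting the two halves together, here is a minimal depthwise separable block next to a regular conv with the same input and output channels. The 3-to-4-channel sizes are an illustrative assumption. The parameter count shows why MobileNet uses the factorization: 3·1·3·3 + 4·3·1·1 = 39 weights versus 4·3·3·3 = 108 (biases disabled).

```python
import torch
import torch.nn as nn

# Depthwise: groups == in_channels == out_channels, one 1-channel kernel per channel
depthwise = nn.Conv2d(3, 3, kernel_size=3, padding=1, groups=3, bias=False)
# Pointwise: a plain 1x1 conv that mixes channels (here 3 -> 4)
pointwise = nn.Conv2d(3, 4, kernel_size=1, bias=False)
separable = nn.Sequential(depthwise, pointwise)
regular = nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False)

x = torch.randn(1, 3, 16, 16)
print(depthwise(x).shape)   # channel count unchanged: torch.Size([1, 3, 16, 16])
print(separable(x).shape)   # torch.Size([1, 4, 16, 16])

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(separable), n_params(regular))  # 39 108
```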

Supplement (some helper functions)

nn.Identity()

nn.Identity() is essentially the identity mapping f(x)=x. The PyTorch documentation describes it as a placeholder identity operator that is argument-insensitive. A frequently quoted explanation from the PyTorch community reads: "nn.Identity() will just return its input, but does not show that it is a view. nn.Identity will just return the input without any clone usage or manipulation of the input. The input and output would thus be the same. You might not want to use this layer as it's not "doing anything" besides just returning the input. However, there are use cases where users needed exactly this (e.g. to replace another layer) and were manually creating custom modules to do so and asked for the nn.Identity layer in the PyTorch nn backend. Since more and more users were depending on it, it was created. However, as already said, this layer might not be interesting for you."
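A short sketch of the two properties that matter here. nn.Identity returns the very same tensor object, and it can stand in for another layer without changing forward(). The second property is how ConvBN below represents "no nonlinearity". The small Sequential head is a made-up example.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 8)
identity = nn.Identity()
print(identity(x) is x)  # True: the input object itself is returned

# Typical use: swap a layer out without touching forward()
head = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
head[1] = nn.Identity()          # drop the activation
print(head)

# ConvBN uses the same pattern: no nonlinearity given -> nn.Identity()
nonlinear = None
act = nn.Identity() if nonlinear is None else nonlinear
print(torch.equal(act(x), x))
```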

hasattr( )

hasattr() checks whether an object has a given attribute.

hasattr(object, name)
object -- the object.
name -- a string, the attribute name.
return
True if the object has the attribute, otherwise False.
class variable:
    x = 1
    y = 'a'
    z = True

dd = variable() 
print(hasattr(dd, 'x'))
print(hasattr(dd, 'y'))
print(hasattr(dd, 'z'))
print(hasattr(dd, 'no'))

True
True
True
False
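The code below applies hasattr to nn.Module submodules: ConvBN.forward checks hasattr(self, 'bn'), and convert.py checks hasattr(m, 'switch_to_deploy'). Submodules live in the module's internal registry and are found through nn.Module.__getattr__, so hasattr works on them as usual, and it returns False once the submodule is deleted. TinyConvBN is a hypothetical class used only for this illustration.

```python
import torch.nn as nn

class TinyConvBN(nn.Module):   # hypothetical minimal module, for illustration only
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, bias=False)
        self.bn = nn.BatchNorm2d(8)

m = TinyConvBN()
print(hasattr(m, 'bn'))    # True: found via nn.Module.__getattr__
del m.bn                   # removes the BN from the submodule registry
print(hasattr(m, 'bn'))    # False: forward() can branch on this, as ConvBN does
```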

Code

import torch
import torch.nn as nn
from diversebranchblock import DiverseBranchBlock
from acb import ACBlock
from dbb_transforms import transI_fusebn

CONV_BN_IMPL = 'base'

DEPLOY_FLAG = False


class ConvBN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, groups, deploy=False, nonlinear=None):
        super().__init__()
        # Use an identity (no activation) unless a nonlinearity is given
        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nonlinear
        # deploy=True: a single conv with bias; otherwise a conv without bias followed by BN
        if deploy:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                  stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True)
        else:
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                  stride=stride, padding=padding, dilation=dilation, groups=groups, bias=False)
            self.bn = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, x):
        # If a BN exists (training-time structure), apply conv -> BN; otherwise just the conv
        if hasattr(self, 'bn'):
            return self.nonlinear(self.bn(self.conv(x)))
        else:
            return self.nonlinear(self.conv(x))

    # Switch to the deploy (inference-time) structure
    def switch_to_deploy(self):
        # Transform I: fuse the BN into the conv's kernel and bias
        kernel, bias = transI_fusebn(self.conv.weight, self.bn)
        # Create a new conv (with bias) to receive the fused training-time weights
        conv = nn.Conv2d(in_channels=self.conv.in_channels, out_channels=self.conv.out_channels,
                         kernel_size=self.conv.kernel_size,
                         stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation,
                         groups=self.conv.groups, bias=True)
        # Load the Transform I kernel and bias into the new conv
        conv.weight.data = kernel
        conv.bias.data = bias

        for para in self.parameters():
            para.detach_()
        self.__delattr__('conv')
        self.__delattr__('bn')
        self.conv = conv


# conv-bn
def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1):
    if CONV_BN_IMPL == 'base' or kernel_size == 1 or kernel_size >= 7:
        blk_type = ConvBN
    elif CONV_BN_IMPL == 'ACB':
        blk_type = ACBlock
    else:
        blk_type = DiverseBranchBlock
    return blk_type(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                    padding=padding, dilation=dilation, groups=groups, deploy=DEPLOY_FLAG)


# conv-bn-relu
def conv_bn_relu(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1):
    if CONV_BN_IMPL == 'base' or kernel_size == 1 or kernel_size >= 7:
        blk_type = ConvBN
    elif CONV_BN_IMPL == 'ACB':
        blk_type = ACBlock
    else:
        blk_type = DiverseBranchBlock
    return blk_type(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                    padding=padding, dilation=dilation, groups=groups, deploy=DEPLOY_FLAG, nonlinear=nn.ReLU())


def switch_conv_bn_impl(block_type):
    assert block_type in ['base', 'DBB', 'ACB']
    global CONV_BN_IMPL
    CONV_BN_IMPL = block_type


def switch_deploy_flag(deploy):
    global DEPLOY_FLAG
    DEPLOY_FLAG = deploy
    print('deploy flag: ', DEPLOY_FLAG)


def build_model(arch):
    if arch == 'ResNet-18':
        from resnet import create_Res18
        model = create_Res18()
    elif arch == 'ResNet-50':
        from resnet import create_Res50
        model = create_Res50()
    elif arch == 'MobileNet':
        from mobilenet import create_MobileNet
        model = create_MobileNet()
    else:
        raise ValueError('TODO')
    return model

convert.py

Supplement (argparse)

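argparse is Python's built-in command-line argument parsing module, which makes reading command-line arguments simple.
Basic usage takes three steps. First, create an argument parser with argparse.ArgumentParser. Then add arguments to it with its add_argument method. Finally, get the parsed arguments from its parse_args method.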

import argparse


def main():
    # Create a parser object with ArgumentParser (usually called the argument parser)
    parser = argparse.ArgumentParser(description="Demo of argparse")
    # Add arguments with add_argument; -n and --name refer to the same argument, and default is used when it is omitted
    parser.add_argument('-n', '--name', default=' Li ')
    parser.add_argument('-y', '--year', default='20')
    # parse_args returns the parsed arguments
    args = parser.parse_args()
    print(args)
    name = args.name
    year = args.year
    print('Hello {}  {}'.format(name, year))


if __name__ == '__main__':
    main()
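The same three steps appear in convert.py below. This sketch reuses that parser's argument layout, two positionals plus an optional -a/--arch, and passes an explicit list to parse_args, which is a convenient way to try a parser without a shell. The file names are made up for the example.

```python
import argparse

# Same argument layout as convert.py: two positionals plus an optional -a/--arch
parser = argparse.ArgumentParser(description='DBB Conversion')
parser.add_argument('load', metavar='LOAD', help='path to the weights file')
parser.add_argument('save', metavar='SAVE', help='path to save the converted weights')
parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18')

# parse_args accepts an explicit list instead of reading sys.argv
args = parser.parse_args(['train.pth', 'deploy.pth', '-a', 'ResNet-50'])
print(args.load, args.save, args.arch)   # train.pth deploy.pth ResNet-50

args = parser.parse_args(['train.pth', 'deploy.pth'])
print(args.arch)                         # ResNet-18 (the default)
```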

Supplement (conv.weight.data)

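conv.weight.data gives access to a conv's kernel. In structural re-parameterization it is commonly used to overwrite a conv's kernel and bias.
For conv = nn.Conv2d(...), conv.weight is a torch.nn.parameter.Parameter (a Tensor subclass that the module registers as a learnable parameter), whereas conv.weight.data is the underlying plain torch.Tensor, detached from autograd.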

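import torch
import torch.nn as nn

# Randomly create a tensor to use as the kernel (out=1, in=1, 3x3)
kernel_data = torch.rand(1, 1, 3, 3)
print(kernel_data)
# Define a convolution that will receive the kernel
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 3), stride=1, padding=1, padding_mode='zeros', bias=False)
# Show the conv's current (randomly initialized) kernel
print(conv.weight.data)
# Replace the conv's kernel with kernel_data
conv.weight = nn.Parameter(kernel_data)
# Show the kernel after the replacement
print(conv.weight.data)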
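The switch_to_deploy methods in this project write the fused kernel with weight.data = ... rather than assigning a new nn.Parameter. The sketch below contrasts the two approaches. Assigning to .data keeps the same Parameter object, so anything already holding a reference to it (an optimizer, for instance) still sees it. Assigning a new nn.Parameter registers a different object. The all-ones kernel at the end makes the output easy to check by hand: with padding 1, each output is the sum of a 3x3 window over an all-ones input.

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=True)
param_before = conv.weight

# 1) Overwrite the tensor behind the existing Parameter (as switch_to_deploy does)
conv.weight.data = torch.ones(1, 1, 3, 3)
conv.bias.data = torch.zeros(1)
print(conv.weight is param_before)                                     # True
print(isinstance(conv.weight, nn.Parameter), conv.weight.requires_grad)  # True True

# 2) Replace the Parameter itself
conv.weight = nn.Parameter(torch.full((1, 1, 3, 3), 2.0))
print(conv.weight is param_before)                                     # False

# All-ones kernel, padding 1, zero bias: corners 4, edges 6, center 9
conv.weight.data = torch.ones(1, 1, 3, 3)
x = torch.ones(1, 1, 3, 3)
print(conv(x))
```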

补充(delattr)

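When reading an attribute that cannot be found on an object, __getattr__ is called. When assigning or modifying an attribute, __setattr__ is called. When deleting an attribute, __delattr__ is called.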

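class Foo:
    x = 1

    def __init__(self, y):
        self.y = y  # goes through __setattr__

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails
        print('----> from getattr: the attribute you are looking for does not exist')

    def __setattr__(self, key, value):
        print('----> from setattr')
        # self.key = value  # this would recurse infinitely
        self.__dict__[key] = value  # use this instead

    def __delattr__(self, item):
        print('----> from delattr')
        # del self.item  # this would also recurse infinitely
        self.__dict__.pop(item)


f1 = Foo(10)   # prints: ----> from setattr
f1.z           # prints: ----> from getattr: ...
del f1.y       # prints: ----> from delattr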
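The switch_to_deploy methods in ConvBN, ACBlock and DiverseBranchBlock call self.__delattr__(...) on their training-time submodules. For an nn.Module this goes through nn.Module.__delattr__, which removes the submodule from the module's registry, so its parameters also disappear from state_dict() and the saved inference-time checkpoint. TwoBranch is a hypothetical module used only to show this.

```python
import torch.nn as nn

class TwoBranch(nn.Module):   # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.branch_a = nn.Conv2d(4, 4, 3, padding=1)
        self.branch_b = nn.Conv2d(4, 4, 1)

m = TwoBranch()
print(sorted(m.state_dict().keys()))
# ['branch_a.bias', 'branch_a.weight', 'branch_b.bias', 'branch_b.weight']

# Same effect as: del m.branch_b
m.__delattr__('branch_b')
print(sorted(m.state_dict().keys()))
# ['branch_a.bias', 'branch_a.weight']
print(hasattr(m, 'branch_b'))  # False
```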

Code

import argparse
import os
import torch
from convnet_utils import switch_conv_bn_impl, switch_deploy_flag, build_model

parser = argparse.ArgumentParser(description='DBB Conversion')
parser.add_argument('load', metavar='LOAD', help='path to the weights file')
parser.add_argument('save', metavar='SAVE', help='path to save the converted weights')
parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18')

def convert():
    args = parser.parse_args()

    switch_conv_bn_impl('DBB')
    switch_deploy_flag(False)
    train_model = build_model(args.arch)

    if 'hdf5' in args.load:
        from utils import model_load_hdf5
        model_load_hdf5(train_model, args.load)
    elif os.path.isfile(args.load):
        print("=> loading checkpoint '{}'".format(args.load))
        checkpoint = torch.load(args.load)
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()}  # strip the names
        train_model.load_state_dict(ckpt)
    else:
        print("=> no checkpoint found at '{}'".format(args.load))

    for m in train_model.modules():
        if hasattr(m, 'switch_to_deploy'):
            m.switch_to_deploy()

    torch.save(train_model.state_dict(), args.save)


if __name__ == '__main__':
    convert()
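The heart of convert.py is the loop that calls switch_to_deploy on every module that has it. The sketch below runs that loop on a toy model so it works without the repository. MiniConvBN is a hypothetical stand-in that re-implements ConvBN's conv-BN fusion (Transform I); it is not the repository's class. The check compares eval-mode outputs before and after conversion. A few training-mode passes first give the BN layers non-trivial running statistics.

```python
import torch
import torch.nn as nn

class MiniConvBN(nn.Module):
    """Hypothetical stand-in for ConvBN: conv (no bias) + BN, fusable into one conv."""
    def __init__(self, cin, cout, k):
        super().__init__()
        self.conv = nn.Conv2d(cin, cout, k, padding=k // 2, bias=False)
        self.bn = nn.BatchNorm2d(cout)

    def forward(self, x):
        if hasattr(self, 'bn'):
            return self.bn(self.conv(x))
        return self.conv(x)

    def switch_to_deploy(self):
        std = (self.bn.running_var + self.bn.eps).sqrt()
        kernel = self.conv.weight * (self.bn.weight / std).reshape(-1, 1, 1, 1)
        bias = self.bn.bias - self.bn.running_mean * self.bn.weight / std
        fused = nn.Conv2d(self.conv.in_channels, self.conv.out_channels,
                          self.conv.kernel_size, padding=self.conv.padding, bias=True)
        fused.weight.data = kernel.detach()
        fused.bias.data = bias.detach()
        del self.conv, self.bn
        self.conv = fused

torch.manual_seed(0)
model = nn.Sequential(MiniConvBN(3, 8, 3), nn.ReLU(), MiniConvBN(8, 4, 3))
model.train()
with torch.no_grad():
    for _ in range(5):                        # update BN running statistics
        model(torch.randn(4, 3, 16, 16))
model.eval()

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    before = model(x)
    for m in model.modules():                 # same loop as convert.py
        if hasattr(m, 'switch_to_deploy'):
            m.switch_to_deploy()
    after = model(x)

print(torch.allclose(before, after, atol=1e-5))
print(list(model.state_dict().keys()))
```

The equivalence only holds in eval mode. In training mode BN normalizes with batch statistics, which a fixed fused conv cannot reproduce.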

dbb_verify.py

import torch
import torch.nn as nn
from diversebranchblock import DiverseBranchBlock


if __name__ == '__main__':
    x = torch.randn(1, 32, 56, 56)
    for k in (3, 5):
        for s in (1, 2):
            dbb = DiverseBranchBlock(in_channels=32, out_channels=64, kernel_size=k, stride=s, padding=k//2,
                                           groups=2, deploy=False)
            for module in dbb.modules():
                if isinstance(module, torch.nn.BatchNorm2d):
                    nn.init.uniform_(module.running_mean, 0, 0.1)
                    nn.init.uniform_(module.running_var, 0, 0.1)
                    nn.init.uniform_(module.weight, 0, 0.1)
                    nn.init.uniform_(module.bias, 0, 0.1)
            dbb.eval()
            print(dbb)
            train_y = dbb(x)
            dbb.switch_to_deploy()
            deploy_y = dbb(x)
            print(dbb)
            print('========================== The diff is')
            print(((train_y - deploy_y) ** 2).sum())

diversebranchblock.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from dbb_transforms import *

def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                   padding_mode='zeros'):
    conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                           stride=stride, padding=padding, dilation=dilation, groups=groups,
                           bias=False, padding_mode=padding_mode)
    bn_layer = nn.BatchNorm2d(num_features=out_channels, affine=True)
    se = nn.Sequential()
    se.add_module('conv', conv_layer)
    se.add_module('bn', bn_layer)
    return se


class IdentityBasedConv1x1(nn.Conv2d):

    def __init__(self, channels, groups=1):
        super(IdentityBasedConv1x1, self).__init__(in_channels=channels, out_channels=channels, kernel_size=1, stride=1, padding=0, groups=groups, bias=False)

        assert channels % groups == 0
        input_dim = channels // groups
        id_value = np.zeros((channels, input_dim, 1, 1))
        for i in range(channels):
            id_value[i, i % input_dim, 0, 0] = 1
        self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
        nn.init.zeros_(self.weight)

    def forward(self, input):
        kernel = self.weight + self.id_tensor.to(self.weight.device)
        result = F.conv2d(input, kernel, None, stride=1, padding=0, dilation=self.dilation, groups=self.groups)
        return result

    def get_actual_kernel(self):
        return self.weight + self.id_tensor.to(self.weight.device)


class BNAndPadLayer(nn.Module):
    def __init__(self,
                 pad_pixels,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super(BNAndPadLayer, self).__init__()
        self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
        self.pad_pixels = pad_pixels

    def forward(self, input):
        output = self.bn(input)
        if self.pad_pixels > 0:
            if self.bn.affine:
                pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(self.bn.running_var + self.bn.eps)
            else:
                pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps)
            output = F.pad(output, [self.pad_pixels] * 4)
            pad_values = pad_values.view(1, -1, 1, 1)
            output[:, :, 0:self.pad_pixels, :] = pad_values
            output[:, :, -self.pad_pixels:, :] = pad_values
            output[:, :, :, 0:self.pad_pixels] = pad_values
            output[:, :, :, -self.pad_pixels:] = pad_values
        return output

    @property
    def weight(self):
        return self.bn.weight

    @property
    def bias(self):
        return self.bn.bias

    @property
    def running_mean(self):
        return self.bn.running_mean

    @property
    def running_var(self):
        return self.bn.running_var

    @property
    def eps(self):
        return self.bn.eps


class DiverseBranchBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1,
                 internal_channels_1x1_3x3=None,
                 deploy=False, nonlinear=None, single_init=False):
        super(DiverseBranchBlock, self).__init__()
        self.deploy = deploy

        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nonlinear

        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.groups = groups
        assert padding == kernel_size // 2

        if deploy:
            self.dbb_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                                      padding=padding, dilation=dilation, groups=groups, bias=True)

        else:

            self.dbb_origin = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups)

            self.dbb_avg = nn.Sequential()
            if groups < out_channels:
                self.dbb_avg.add_module('conv',
                                        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                                                  stride=1, padding=0, groups=groups, bias=False))
                self.dbb_avg.add_module('bn', BNAndPadLayer(pad_pixels=padding, num_features=out_channels))
                self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
                self.dbb_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride,
                                       padding=0, groups=groups)
            else:
                self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding))

            self.dbb_avg.add_module('avgbn', nn.BatchNorm2d(out_channels))


            if internal_channels_1x1_3x3 is None:
                internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels   # For mobilenet, it is better to have 2X internal channels

            self.dbb_1x1_kxk = nn.Sequential()
            if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_module('idconv1', IdentityBasedConv1x1(channels=in_channels, groups=groups))
            else:
                self.dbb_1x1_kxk.add_module('conv1', nn.Conv2d(in_channels=in_channels, out_channels=internal_channels_1x1_3x3,
                                                            kernel_size=1, stride=1, padding=0, groups=groups, bias=False))
            self.dbb_1x1_kxk.add_module('bn1', BNAndPadLayer(pad_pixels=padding, num_features=internal_channels_1x1_3x3, affine=True))
            self.dbb_1x1_kxk.add_module('conv2', nn.Conv2d(in_channels=internal_channels_1x1_3x3, out_channels=out_channels,
                                                            kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=False))
            self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))

        #   The experiments reported in the paper used the default initialization of bn.weight (all as 1). But changing the initialization may be useful in some cases.
        if single_init:
            #   Initialize the bn.weight of dbb_origin as 1 and others as 0. This is not the default setting.
            self.single_init()

    def get_equivalent_kernel_bias(self):
        k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight, self.dbb_origin.bn)

        if hasattr(self, 'dbb_1x1'):
            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, self.dbb_1x1.bn)
            k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
        else:
            k_1x1, b_1x1 = 0, 0

        if hasattr(self.dbb_1x1_kxk, 'idconv1'):
            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
        else:
            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
        k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first, self.dbb_1x1_kxk.bn1)
        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2)
        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first, b_1x1_kxk_first, k_1x1_kxk_second, b_1x1_kxk_second, groups=self.groups)

        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
        k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device), self.dbb_avg.avgbn)
        if hasattr(self.dbb_avg, 'conv'):
            k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight, self.dbb_avg.bn)
            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first, b_1x1_avg_first, k_1x1_avg_second, b_1x1_avg_second, groups=self.groups)
        else:
            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second

        return transII_addbranch((k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged), (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged))

    def switch_to_deploy(self):
        if hasattr(self, 'dbb_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels, out_channels=self.dbb_origin.conv.out_channels,
                                     kernel_size=self.dbb_origin.conv.kernel_size, stride=self.dbb_origin.conv.stride,
                                     padding=self.dbb_origin.conv.padding, dilation=self.dbb_origin.conv.dilation, groups=self.dbb_origin.conv.groups, bias=True)
        self.dbb_reparam.weight.data = kernel
        self.dbb_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('dbb_origin')
        self.__delattr__('dbb_avg')
        if hasattr(self, 'dbb_1x1'):
            self.__delattr__('dbb_1x1')
        self.__delattr__('dbb_1x1_kxk')

    def forward(self, inputs):

        if hasattr(self, 'dbb_reparam'):
            return self.nonlinear(self.dbb_reparam(inputs))

        out = self.dbb_origin(inputs)
        if hasattr(self, 'dbb_1x1'):
            out += self.dbb_1x1(inputs)
        out += self.dbb_avg(inputs)
        out += self.dbb_1x1_kxk(inputs)
        return self.nonlinear(out)

    def init_gamma(self, gamma_value):
        if hasattr(self, "dbb_origin"):
            torch.nn.init.constant_(self.dbb_origin.bn.weight, gamma_value)
        if hasattr(self, "dbb_1x1"):
            torch.nn.init.constant_(self.dbb_1x1.bn.weight, gamma_value)
        if hasattr(self, "dbb_avg"):
            torch.nn.init.constant_(self.dbb_avg.avgbn.weight, gamma_value)
        if hasattr(self, "dbb_1x1_kxk"):
            torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight, gamma_value)

    def single_init(self):
        self.init_gamma(0.0)
        if hasattr(self, "dbb_origin"):
            torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0)
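The sketch below checks, on random data, the two identities that get_equivalent_kernel_bias relies on. First, an eval-mode BatchNorm can be folded into the preceding bias-free conv, which is the job transI_fusebn does. Second, a 1x1 kernel zero-padded to kxk can be added to a kxk kernel, so parallel branch outputs collapse into a single conv, which is what transVI_multiscale plus transII_addbranch do. fuse_bn and pad_1x1_to_kxk are hypothetical stand-ins written for illustration, not the repo's own functions. The sketch assumes PyTorch, an odd kernel size, and "same" padding.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fuse_bn(kernel, bn):
    # Hypothetical helper: fold eval-mode BN into a bias-free conv.
    # BN(y) = gamma * (y - mean) / std + beta, with std = sqrt(running_var + eps)
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std
    return kernel * scale.reshape(-1, 1, 1, 1), bn.bias - bn.running_mean * scale


def pad_1x1_to_kxk(kernel, k):
    # Hypothetical helper: put the 1x1 weights at the center of a zero kxk kernel.
    p = (k - 1) // 2
    return F.pad(kernel, [p, p, p, p])


torch.manual_seed(0)
cin, cout, k = 4, 8, 3
conv_kxk, bn_kxk = nn.Conv2d(cin, cout, k, padding=k // 2, bias=False), nn.BatchNorm2d(cout)
conv_1x1, bn_1x1 = nn.Conv2d(cin, cout, 1, bias=False), nn.BatchNorm2d(cout)

with torch.no_grad():
    # Give each BN non-trivial statistics and affine params, then use inference mode.
    for bn in (bn_kxk, bn_1x1):
        bn.running_mean.uniform_(-1.0, 1.0)
        bn.running_var.uniform_(0.5, 2.0)
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-0.5, 0.5)
        bn.eval()

    x = torch.randn(2, cin, 16, 16)
    # Training-time structure: two conv-BN branches whose outputs are summed.
    y_branches = bn_kxk(conv_kxk(x)) + bn_1x1(conv_1x1(x))

    # Inference-time structure: one kxk conv with bias.
    k_a, b_a = fuse_bn(conv_kxk.weight, bn_kxk)
    k_b, b_b = fuse_bn(conv_1x1.weight, bn_1x1)
    fused = nn.Conv2d(cin, cout, k, padding=k // 2, bias=True)
    fused.weight.copy_(k_a + pad_1x1_to_kxk(k_b, k))
    fused.bias.copy_(b_a + b_b)
    y_fused = fused(x)

max_abs_diff = (y_branches - y_fused).abs().max().item()
print(max_abs_diff)  # should be tiny (float32 round-off)
```

The equivalence only holds with BN in eval mode: in training mode BN normalizes with batch statistics, which is why the conversion is done once after training (switch_to_deploy, driven by convert.py).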

mobilenet.py

import torch.nn as nn
import torch.nn.functional as F
from convnet_utils import conv_bn_relu
MOBILE_CHANNELS = [32,
                   32, 64,
                   64, 128,
                   128, 128,
                   128, 256,
                   256, 256,
                   256, 512,
                   512, 512, 512, 512, 512, 512, 512, 512, 512, 512,
                   512, 1024,
                   1024, 1024]

class MobileV1Block(nn.Module):
    '''Depthwise conv + Pointwise conv'''
    def __init__(self, in_planes, out_planes, stride=1):
        super(MobileV1Block, self).__init__()
        self.depthwise = conv_bn_relu(in_channels=in_planes, out_channels=in_planes, kernel_size=3,
                                          stride=stride, padding=1, groups=in_planes)
        self.pointwise = conv_bn_relu(in_channels=in_planes, out_channels=out_planes, kernel_size=1,
                                          stride=1, padding=0)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class MobileV1(nn.Module):

    def __init__(self, num_classes):
        super(MobileV1, self).__init__()
        channels = MOBILE_CHANNELS
        assert len(channels) == 27
        self.conv1 = conv_bn_relu(in_channels=3, out_channels=channels[0], kernel_size=3, stride=2, padding=1)
        blocks = []
        for block_idx in range(13):
            depthwise_channels = int(channels[block_idx * 2 + 1])
            pointwise_channels = int(channels[block_idx * 2 + 2])
            stride = 2 if block_idx in [1, 3, 5, 11] else 1
            blocks.append(MobileV1Block(in_planes=depthwise_channels, out_planes=pointwise_channels, stride=stride))
        self.stem = nn.Sequential(*blocks)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        self.linear = nn.Linear(channels[-1], num_classes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.stem(out)
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def create_MobileNet():
    return MobileV1(num_classes=1000)
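As a sanity check on the schedule above, the pure-Python sketch below replays MobileV1's loop over MOBILE_CHANNELS without building any layers. It asserts that each block's depthwise width equals the previous block's pointwise width, and it tracks the spatial size through the stride-2 layers (conv1 plus blocks 1, 3, 5 and 11). mobilenet_v1_shapes and conv_out are hypothetical helper names, and the 224x224 input is an assumption matching val_preprocess(224).

```python
MOBILE_CHANNELS = [32,
                   32, 64,
                   64, 128,
                   128, 128,
                   128, 256,
                   256, 256,
                   256, 512,
                   512, 512, 512, 512, 512, 512, 512, 512, 512, 512,
                   512, 1024,
                   1024, 1024]


def conv_out(size, kernel_size=3, stride=1, padding=1):
    # Standard conv output-size formula (dilation 1).
    return (size + 2 * padding - kernel_size) // stride + 1


def mobilenet_v1_shapes(img_size=224, channels=MOBILE_CHANNELS):
    # Returns (block_idx, in_channels, out_channels, stride, spatial_size) per block.
    size = conv_out(img_size, stride=2)  # conv1: 3x3, stride 2, padding 1
    prev_out = channels[0]
    shapes = []
    for block_idx in range(13):
        dw = channels[block_idx * 2 + 1]
        pw = channels[block_idx * 2 + 2]
        assert dw == prev_out, "depthwise input must match previous output"
        stride = 2 if block_idx in [1, 3, 5, 11] else 1
        # Only the 3x3 depthwise conv can change the size; the 1x1 pointwise keeps it.
        size = conv_out(size, stride=stride)
        shapes.append((block_idx, dw, pw, stride, size))
        prev_out = pw
    return shapes


for row in mobilenet_v1_shapes():
    print(row)
```

Five stride-2 layers give an overall downsampling factor of 32, so a 224x224 image reaches the global average pool as a 7x7 map with 1024 channels, which is why self.linear takes channels[-1] inputs.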

resnet.py

import torch.nn as nn
import torch.nn.functional as F
from convnet_utils import conv_bn, conv_bn_relu

class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = conv_bn(in_channels=in_planes, out_channels=self.expansion * planes, kernel_size=1, stride=stride)
        else:
            self.shortcut = nn.Identity()
        self.conv1 = conv_bn_relu(in_channels=in_planes, out_channels=planes, kernel_size=3, stride=stride, padding=1)
        self.conv2 = conv_bn(in_channels=planes, out_channels=self.expansion * planes, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = out + self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = conv_bn(in_planes, self.expansion*planes, kernel_size=1, stride=stride)
        else:
            self.shortcut = nn.Identity()

        self.conv1 = conv_bn_relu(in_planes, planes, kernel_size=1)
        self.conv2 = conv_bn_relu(planes, planes, kernel_size=3, stride=stride, padding=1)
        self.conv3 = conv_bn(planes, self.expansion*planes, kernel_size=1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=1000, width_multiplier=1):
        super(ResNet, self).__init__()

        self.in_planes = int(64 * width_multiplier)
        self.stage0 = nn.Sequential()
        self.stage0.add_module('conv1', conv_bn_relu(in_channels=3, out_channels=self.in_planes, kernel_size=7, stride=2, padding=3))
        self.stage0.add_module('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        self.stage1 = self._make_stage(block, int(64 * width_multiplier), num_blocks[0], stride=1)
        self.stage2 = self._make_stage(block, int(128 * width_multiplier), num_blocks[1], stride=2)
        self.stage3 = self._make_stage(block, int(256 * width_multiplier), num_blocks[2], stride=2)
        self.stage4 = self._make_stage(block, int(512 * width_multiplier), num_blocks[3], stride=2)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        self.linear = nn.Linear(int(512*block.expansion*width_multiplier), num_classes)

    def _make_stage(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        blocks = []
        for stride in strides:
            blocks.append(block(in_planes=self.in_planes, planes=int(planes), stride=stride))
            self.in_planes = int(planes * block.expansion)
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = self.stage0(x)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def create_Res18():
    return ResNet(BasicBlock, [2,2,2,2], num_classes=1000, width_multiplier=1)


def create_Res50():
    return ResNet(Bottleneck, [3,4,6,3], num_classes=1000, width_multiplier=1)
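The hypothetical pure-Python helper below mirrors ResNet._make_stage and the shortcut condition shared by BasicBlock and Bottleneck. For every block it lists the stage, input width, output width, stride, and whether a 1x1 conv_bn projection shortcut is built instead of nn.Identity. It makes one detail visible: in ResNet-50 the very first block already needs a projection even though its stride is 1, because Bottleneck expands 64 channels to 256.

```python
def resnet_stage_plan(expansion, num_blocks, width_multiplier=1):
    # Returns (stage, in_planes, out_planes, stride, uses_projection) per block.
    stage_cfg = [(64, 1), (128, 2), (256, 2), (512, 2)]  # (base planes, first stride)
    in_planes = int(64 * width_multiplier)  # output width of stage0
    plan = []
    for stage_idx, ((base, first_stride), n) in enumerate(zip(stage_cfg, num_blocks), start=1):
        planes = int(base * width_multiplier)
        for stride in [first_stride] + [1] * (n - 1):
            out_planes = expansion * planes
            # Same condition as in BasicBlock.__init__ / Bottleneck.__init__
            uses_projection = stride != 1 or in_planes != out_planes
            plan.append((stage_idx, in_planes, out_planes, stride, uses_projection))
            in_planes = int(planes * expansion)
    return plan


res18 = resnet_stage_plan(expansion=1, num_blocks=[2, 2, 2, 2])
res50 = resnet_stage_plan(expansion=4, num_blocks=[3, 4, 6, 3])
for name, plan in (("ResNet-18", res18), ("ResNet-50", res50)):
    n_proj = sum(p[-1] for p in plan)
    print(name, len(plan), "blocks,", n_proj, "projection shortcuts, final width", plan[-1][2])
```

For create_Res18() this gives three projection shortcuts (the first block of stages 2 to 4) and 512 output channels; for create_Res50() it gives four and 2048 channels, matching int(512*block.expansion*width_multiplier) in self.linear.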

test.py

import argparse
import os
import time
import torch
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
from utils import accuracy, ProgressMeter, AverageMeter, val_preprocess
from convnet_utils import switch_deploy_flag, switch_conv_bn_impl, build_model

parser = argparse.ArgumentParser(description='PyTorch ImageNet Test')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('mode', metavar='MODE', default='train', choices=['train', 'deploy'], help='train or deploy')
parser.add_argument('weights', metavar='WEIGHTS', help='path to the weights file')
parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18')
parser.add_argument('-t', '--blocktype', metavar='BLK', default='DBB', choices=['DBB', 'ACB', 'base'])
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=100, type=int,
                    metavar='N',
                    help='mini-batch size (default: 100) for test')

def test():
    args = parser.parse_args()

    switch_deploy_flag(args.mode == 'deploy')
    switch_conv_bn_impl(args.blocktype)
    model = build_model(args.arch)

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
        use_gpu = False
    else:
        model = model.cuda()
        use_gpu = True

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    if 'hdf5' in args.weights:
        from utils import model_load_hdf5
        model_load_hdf5(model, args.weights)
    elif os.path.isfile(args.weights):
        print("=> loading checkpoint '{}'".format(args.weights))
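        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
        # load_state_dict then copies the tensors onto the model's device
        checkpoint = torch.load(args.weights, map_location='cpu')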
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        ckpt = {k.replace('module.', ''):v for k,v in checkpoint.items()}   # strip the names
        model.load_state_dict(ckpt)
    else:
        print("=> no checkpoint found at '{}'".format(args.weights))


    cudnn.benchmark = True

    # Data loading code
    valdir = os.path.join(args.data, 'val')

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, val_preprocess(224)),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    validate(val_loader, model, criterion, use_gpu)


def validate(val_loader, model, criterion, use_gpu):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if use_gpu:
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg




if __name__ == '__main__':
    test()
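test.py strips the 'module.' prefix that nn.DataParallel and DistributedDataParallel add to parameter names (train.py saves model.state_dict() from the wrapped model) using str.replace. Because replace removes every occurrence, a key that happens to contain 'module.' in the middle of its name would be mangled. The hypothetical helper below removes the prefix only at the start of each key; the sample keys are made up for illustration.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    # Remove the wrapper prefix only when it appears at the start of a key.
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Made-up keys; the string values stand in for tensors.
ckpt = {"module.stage0.conv1.weight": "w0", "module.head.submodule.bias": "b1"}
print(dict(strip_module_prefix(ckpt)))
# {'stage0.conv1.weight': 'w0', 'head.submodule.bias': 'b1'}
print({k.replace("module.", ""): v for k, v in ckpt.items()})
# {'stage0.conv1.weight': 'w0', 'head.subbias': 'b1'}
```

For the layer names visible in this listing (stage0, dbb_origin, dbb_1x1_kxk, shortcut and so on) both approaches give the same keys, so the difference only matters if the loading code is reused for other models.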

train.py

import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils import AverageMeter, accuracy, ProgressMeter, val_preprocess, strong_train_preprocess, standard_train_preprocess

IMAGENET_TRAINSET_SIZE = 1281167


parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')

parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18')
parser.add_argument('-t', '--blocktype', metavar='BLK', default='DBB', choices=['DBB', 'ACB', 'base'])

parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--epochs', default=120, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

best_acc1 = 0


def sgd_optimizer(model, lr, momentum, weight_decay):
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        apply_lr = lr
        apply_wd = weight_decay
        if 'bias' in key:
            apply_lr = 2 * lr       #   Just a Caffe-style common practice. Made no difference.
        if 'depth' in key:
            apply_wd = 0
        print('set weight decay ', key, apply_wd)
        params += [{'params': [value], 'lr': apply_lr, 'weight_decay': apply_wd}]
    optimizer = torch.optim.SGD(params, lr, momentum=momentum)
    return optimizer

def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    #   =========================== build model
    from convnet_utils import switch_deploy_flag, switch_conv_bn_impl, build_model
    switch_deploy_flag(False)
    switch_conv_bn_impl(args.blocktype)
    model = build_model(args.arch)

    if gpu == 0:
        for name, param in model.named_parameters():
            print(name, param.size())

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], broadcast_buffers=False)
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model, broadcast_buffers=False)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()


    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = sgd_optimizer(model, args.lr, args.momentum, args.weight_decay)

    lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=args.epochs * IMAGENET_TRAINSET_SIZE // args.batch_size // ngpus_per_node)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    trans = strong_train_preprocess(224) if 'ResNet' in args.arch else standard_train_preprocess(224)
    print('aug is ', trans)
    train_dataset = datasets.ImageFolder(traindir, trans)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)


    val_dataset = datasets.ImageFolder(valdir, val_preprocess(224))
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)


    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
            print('set sampler')
        # adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
                'scheduler': lr_scheduler.state_dict(),
            }, is_best, filename='{}_{}.pth.tar'.format(args.arch, args.blocktype))


def train(train_loader, model, criterion, optimizer, epoch, args, lr_scheduler):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5, ],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        lr_scheduler.step()

        if i % args.print_freq == 0 and args.gpu == 0:
            progress.display(i)
        if i % 1000 == 0 and args.gpu == 0:
            print('cur lr: ', lr_scheduler.get_last_lr()[0])




def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)


        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, filename.replace('.pth.tar', '_best.pth.tar'))




if __name__ == '__main__':
    main()
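train.py calls lr_scheduler.step() once per iteration, so T_max is measured in iterations rather than epochs. The sketch below recomputes that T_max for the defaults (120 epochs, total batch size 256), assuming a single node with 8 GPUs in the multiprocessing-distributed path, where args.batch_size has already been divided down to 32 per process. It then evaluates the closed-form cosine annealing curve documented for PyTorch's CosineAnnealingLR (eta_min = 0, its default) at a few points. per_iteration_t_max and cosine_lr are hypothetical names.

```python
import math

IMAGENET_TRAINSET_SIZE = 1281167


def per_iteration_t_max(epochs, total_batch_size, ngpus_per_node):
    # Same arithmetic as train.py in the multiprocessing-distributed path,
    # where args.batch_size is first divided by ngpus_per_node.
    per_process_batch = int(total_batch_size / ngpus_per_node)
    return epochs * IMAGENET_TRAINSET_SIZE // per_process_batch // ngpus_per_node


def cosine_lr(step, base_lr, t_max, eta_min=0.0):
    # Closed-form cosine annealing without restarts.
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


t_max = per_iteration_t_max(epochs=120, total_batch_size=256, ngpus_per_node=8)
print(t_max)  # 600547
for frac in (0.0, 0.25, 0.5, 0.75, 1.0):
    step = int(frac * t_max)
    print(f"step {step:>6}: lr = {cosine_lr(step, 0.1, t_max):.5f}")
```

Each parameter group keeps its own initial lr, so the bias groups that sgd_optimizer builds with 2x lr follow the same curve scaled by two. Note that in the plain DataParallel path args.batch_size is not divided, so the same expression appears to yield a T_max that is ngpus_per_node times smaller.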

utils.py

import torch
import torchvision.transforms as transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

class PCALighting(object):
    """Lighting noise(AlexNet - style PCA - based noise)"""
    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        if self.alphastd == 0:
            return img
        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        rgb = self.eigvec.type_as(img).clone()\
            .mul(alpha.view(1, 3).expand(3, 3))\
            .mul(self.eigval.view(1, 3).expand(3, 3))\
            .sum(1).squeeze()
        return img.add(rgb.view(3, 1, 1).expand_as(img))


imagenet_pca = {
    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.Tensor([
        [-0.5675,  0.7192,  0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948,  0.4203],
    ])
}

def strong_train_preprocess(img_size):
    trans = transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, saturation=0.4, hue=0.4),
        transforms.ToTensor(),
        PCALighting(0.1, imagenet_pca['eigval'], imagenet_pca['eigvec']),
        normalize,
    ])
    print('---------------------- strong dataaug!')
    return trans

def standard_train_preprocess(img_size):
    trans = transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    print('---------------------- weak dataaug!')
    return trans

def val_preprocess(img_size):
    trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])
    return trans

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def read_hdf5(file_path):
    import h5py
    import numpy as np
    result = {}
    with h5py.File(file_path, 'r') as f:
        for k in f.keys():
            value = np.asarray(f[k])
            result[str(k).replace('+', '/')] = value
    print('read {} arrays from {}'.format(len(result), file_path))
    f.close()
    return result

def model_load_hdf5(model:torch.nn.Module, hdf5_path, ignore_keys='stage0.'):
    weights_dict = read_hdf5(hdf5_path)
    for name, param in model.named_parameters():
        print('load param: ', name, param.size())
        if name in weights_dict:
            np_value = weights_dict[name]
        else:
            np_value = weights_dict[name.replace(ignore_keys, '')]
        value = torch.from_numpy(np_value).float()
        assert tuple(value.size()) == tuple(param.size())
        param.data = value
    for name, param in model.named_buffers():
        print('load buffer: ', name, param.size())
        if name in weights_dict:
            np_value = weights_dict[name]
        else:
            np_value = weights_dict[name.replace(ignore_keys, '')]
        value = torch.from_numpy(np_value).float()
        assert tuple(value.size()) == tuple(param.size())
        param.data = value
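accuracy() takes the maxk highest-scoring classes per sample with output.topk, transposes them to shape (maxk, batch), and counts a top-k hit when the target appears in any of the first k rows. The pure-Python sketch below reproduces that logic on a hand-made batch so the arithmetic can be checked by eye. topk_accuracy is a hypothetical name, and ties here are broken by sort order, which may differ from torch.topk.

```python
def topk_accuracy(scores, targets, topk=(1,)):
    # scores: per-sample lists of class scores; targets: integer labels.
    # Returns one percentage per k, like accuracy() in utils.py.
    maxk = max(topk)
    ranked = [sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:maxk]
              for row in scores]
    res = []
    for k in topk:
        correct = sum(1 for preds, t in zip(ranked, targets) if t in preds[:k])
        res.append(100.0 * correct / len(targets))
    return res


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
targets = [1, 1, 0, 0]
print(topk_accuracy(scores, targets, topk=(1, 2)))  # [50.0, 75.0]
```

In train.py and test.py these per-batch percentages go into AverageMeter.update(value, n=batch_size), so the reported average is weighted by batch size and stays exact even when the last validation batch is smaller.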

Author: ZRX_GIS