BN layer study notes: fusing convolution and BN layers

Contents

Detailed explanation of the convolution layer and BN layer parameters in a CNN

BN layer parameters

Why no bias b when a convolution layer is followed by batch normalization

Fusing convolution and BN layers

Merge code 1:

Merge code 2:

Detailed explanation of the convolution layer and BN layer parameters in a CNN

Reference: "Detailed explanation of the convolution layer and BN layer parameters in a CNN" (CSDN blog)

BN layer parameters

In PyTorch, batch normalization (BN) layers are implemented by the classes torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, etc., for inputs of different dimensionality. These BN layers expose several parameters and hyperparameters that control their behavior and performance. The main ones are:

num_features: the number of input features. For BatchNorm1d this is the feature dimension; for BatchNorm2d it is the number of feature maps, i.e. the number of channels of the input tensor.

eps: a small value added to the variance (inside the square root) to avoid division by zero. The default is 1e-5.

momentum: the value used when updating the running mean and variance, balancing historical statistics against the current batch. The default is 0.1.

affine: boolean; whether to apply a learnable affine transform to the normalized output (multiply by a weight "gamma" and add a bias "beta"). Defaults to True.

track_running_stats: boolean; whether to track the running mean and variance over the training set. In training mode these statistics are updated; in evaluation mode they are used for normalization. Defaults to True.

Besides these constructor arguments, a BatchNorm layer automatically learns two important parameters during training:

weight (gamma): the scale applied to the normalized values; learned only when affine=True.
bias (beta): the shift applied to the normalized values; learned only when affine=True.

Tuning these parameters and hyperparameters affects the model's learning capacity and final performance. Note PyTorch's update convention, running_stat = (1 - momentum) * running_stat + momentum * batch_stat: a larger momentum makes the running mean and variance more sensitive to new batches, while a smaller value makes them more stable but slower to adapt to new data. Adjusting eps helps avoid numerical-stability problems, especially with very deep networks or small batch sizes. The affine and track_running_stats options let you control the layer's behavior to suit specific training or evaluation needs; a minimal usage sketch follows.
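A minimal sketch of these options in use (all shapes and values here are illustrative):

import torch
import torch.nn as nn

# 2D BN over 16 channels; eps, momentum, affine and track_running_stats
# are written out with their default values for reference
bn = nn.BatchNorm2d(num_features=16, eps=1e-5, momentum=0.1,
                    affine=True, track_running_stats=True)

x = torch.randn(8, 16, 32, 32)          # (batch, channels, H, W)
y = bn(x)                               # training mode: normalizes with batch
                                        # stats and updates the running stats

print(bn.weight.shape, bn.bias.shape)   # gamma and beta: torch.Size([16]) each
print(bn.running_mean.shape)            # per-channel running statistics

bn.eval()                               # eval mode: uses the running statistics
y_eval = bn(x)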

Why no bias b when a convolution layer is followed by batch normalization

Reference: "Why no bias b when a convolution layer is followed by batch normalization" (CSDN blog). In short: BN subtracts the per-channel mean of its input, so any constant bias added by the preceding convolution cancels out exactly, and BN's own beta parameter takes over the bias's role. Dropping the conv bias therefore saves parameters without changing the function, as the check below demonstrates.
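A minimal numerical check (assumptions: random data, default BatchNorm2d settings, training mode):

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3, 16, 16)

conv_b = nn.Conv2d(3, 4, 3, bias=True)    # conv with a bias
conv_nb = nn.Conv2d(3, 4, 3, bias=False)  # same kernels, no bias
with torch.no_grad():
    conv_nb.weight.copy_(conv_b.weight)
bn = nn.BatchNorm2d(4)                    # training mode by default

# BN subtracts the per-channel batch mean, which absorbs the constant
# bias, so both pipelines produce identical outputs
y_b = bn(conv_b(x))
y_nb = bn(conv_nb(x))
print(torch.allclose(y_b, y_nb, atol=1e-6))   # True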

Fusing convolution and BN layers

Another good explanation:

"Fusing BN for deep-learning inference: an easy ~5% speedup" (OSCHINA blog by osc_s7aj86hu)

Asked a PhD colleague: grouped convolutions can be merged, but for depthwise (channel-independent) convolutions, since BN is per-channel, can they perhaps not be merged? In fact, because the BN scale and shift act per output channel, the fusion simply rescales each output channel's filter, so it works for grouped and depthwise convolutions as well; the numerical check after the derivation below uses a grouped convolution to confirm this.

1.  Why merge the BN layer

When training deep networks, a BN (Batch Normalization) layer speeds up convergence and helps control overfitting; it is usually placed after a convolution layer. By normalizing the data, BN effectively mitigates vanishing and exploding gradients. But while BN is beneficial during training, it adds extra layer computations at inference time, hurting performance and consuming extra memory or GPU memory. Since many state-of-the-art models (ResNet, MobileNet, Xception, ShuffleNet, etc.) use BN, it is worthwhile to fold the BN parameters into the convolution layer to speed up forward inference.
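As an aside, recent PyTorch versions ship a utility that performs this fusion for standard module layouts; a minimal sketch (assuming torch.ao.quantization is available; '0' and '1' are the Sequential's submodule names for the conv and the BN):

import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules

m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
fused = fuse_modules(m, [['0', '1']])   # folds BN into the conv; BN becomes Identity

x = torch.randn(1, 3, 32, 32)
print(torch.allclose(m(x), fused(x), atol=1e-6))   # outputs should match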

2.  The math behind merging a BN layer into a convolution layer

For output channel $i$, the convolution computes $y_i = W_i * x + b_i$. At inference time the following BN layer applies its per-channel running mean $\mu_i$, running variance $\sigma_i^2$, scale $\gamma_i$ and shift $\beta_i$. Then we have:

$$\hat{y}_i = \gamma_i \cdot \frac{(W_i * x + b_i) - \mu_i}{\sqrt{\sigma_i^2 + \varepsilon}} + \beta_i$$

After merging:

$$W_i' = \frac{\gamma_i}{\sqrt{\sigma_i^2 + \varepsilon}}\, W_i, \qquad b_i' = \frac{\gamma_i}{\sqrt{\sigma_i^2 + \varepsilon}}\,(b_i - \mu_i) + \beta_i$$

so that $\hat{y}_i = W_i' * x + b_i'$: a single convolution with rescaled weights and an adjusted bias reproduces conv + BN exactly.
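These formulas can be checked numerically by fusing by hand; a minimal sketch (names are illustrative; groups=2 is used to confirm that grouped convolutions fuse the same way, since the scale is per output channel):

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(4, 8, 3, groups=2, bias=True)
bn = nn.BatchNorm2d(8)

# give the BN layer non-trivial statistics and affine parameters
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
conv.eval(); bn.eval()

# W' = gamma / sqrt(var + eps) * W ;  b' = gamma / sqrt(var + eps) * (b - mu) + beta
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
fused = nn.Conv2d(4, 8, 3, groups=2, bias=True)
with torch.no_grad():
    fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))
    fused.bias.copy_(scale * (conv.bias - bn.running_mean) + bn.bias)

x = torch.randn(2, 4, 16, 16)
print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))   # True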

3.  Experimental results

Machine: GTX 1080Ti GPU, Intel i7 CPU

This experiment compares a ResNet-50 model before and after BN fusion: classification accuracy is unchanged, and inference speed improves significantly.

Model                        CPU forward time    GPU forward time
ResNet-50 (before merging)   176.17 ms           11.03 ms
ResNet-50 (after merging)    161.69 ms           7.30 ms
Improvement                  8.96%               33.27%
Source: CSDN blog post by 小麦草, licensed under CC 4.0 BY-SA.
Original: https://blog.csdn.net/kangdi7547/article/details/81348254

https://github.com/vietnamican/conv-bn-merge/blob/main/convbnmerge/convbnmerge.py

Merge code 1:

Reference: "Merging BN and CONV layers in PyTorch (merge_bn)" (CSDN blog)

import torch
import os
from collections import OrderedDict
import cv2
import numpy as np
import torchvision.transforms as transforms


"""  Parameters and variables  """
IMAGENET = '/home/zym/ImageNet/ILSVRC2012_img_val_256xN_list.txt'
LABEL = '/home/zym/ImageNet/synset.txt'
TEST_ITER = 10
SAVE = False
TEST_AFTER_MERGE = True


"""  Functions  """
def merge(params, name, layer):
    # global variables
    global weights, bias
    global bn_param

    if layer == 'Convolution':
        # save weights and bias when meet conv layer
        if 'weight' in name:
            weights = params.data
            bias = torch.zeros(weights.size()[0], device=weights.device)
        elif 'bias' in name:
            bias = params.data
        bn_param = {}

    elif layer == 'BatchNorm':
        # save bn params
        bn_param[name.split('.')[-1]] = params.data

        # running_var is the last bn param in pytorch
        if 'running_var' in name:
            # merge BN: scale = gamma / sqrt(running_var + eps);
            # the hard-coded 1e-5 assumes BatchNorm's default eps
            tmp = bn_param['weight'] / torch.sqrt(bn_param['running_var'] + 1e-5)
            weights = tmp.view(tmp.size()[0], 1, 1, 1) * weights
            bias = tmp*(bias - bn_param['running_mean']) + bn_param['bias']

            return weights, bias

    return None, None


"""  Main functions  """
# import pytorch model
import models.shufflenetv2.shufflenetv2_merge as shufflenetv2
pytorch_net = shufflenetv2.ShuffleNetV2().eval()
model_path = shufflenetv2.weight_file

# load weights
print('Finding trained model weights...')
trained_weights = None
try:
    for file in os.listdir(model_path):
        if 'pth' in file:
            print('Loading weights from %s ...' % file)
            trained_weights = torch.load(os.path.join(model_path, file))
            # note: pytorch_net already has its BN layers removed, so the raw
            # weights cannot be loaded directly; they are merged below instead
            print('Weights load success')
            break
except OSError:
    pass
if trained_weights is None:
    raise ValueError('No trained model found or loading error occurs')

# go through pytorch net
print('Going through pytorch net weights...')
new_weights = OrderedDict()
inner_product_flag = False
for name, params in trained_weights.items():
    # 'num_batches_tracked' buffers (added in PyTorch 0.4.1) are 0-dim
    # tensors; skip them so they don't trip the inner-product branch below
    if 'num_batches_tracked' in name:
        continue
    if len(params.size()) == 4:
        _, _ = merge(params, name, 'Convolution')
        prev_layer = name
    elif len(params.size()) == 1 and not inner_product_flag:
        w, b = merge(params, name, 'BatchNorm')
        if w is not None:
            new_weights[prev_layer] = w
            new_weights[prev_layer.replace('weight', 'bias')] = b
    else:
        # inner product layer
        # if meet inner product layer,
        # the next bias weight can be misclassified as 'BatchNorm' layer as len(params.size()) == 1
        new_weights[name] = params
        inner_product_flag = True

# align names in new_weights with pytorch model
# after move BatchNorm layer in pytorch model,
# the layer names between old model and new model will mis-align
print('Aligning weight names...')
pytorch_net_key_list = list(pytorch_net.state_dict().keys())
new_weights_key_list = list(new_weights.keys())
assert len(pytorch_net_key_list) == len(new_weights_key_list)
for index in range(len(pytorch_net_key_list)):
    new_weights[pytorch_net_key_list[index]] = new_weights.pop(new_weights_key_list[index])

# save new weights
if SAVE:
    torch.save(new_weights, model_path + '/' + file.replace('.pth', '_merged.pth'))

# test merged pytorch model
if TEST_AFTER_MERGE:
    try:
        pytorch_net.load_state_dict(new_weights)
        print('Pytorch net load weights success~')
    except:
        raise ValueError('Load new weights error')

    print('-' * 50)
    with open(LABEL) as f:
        labels = f.read().splitlines()
    with open(IMAGENET) as f:
        images = f.read().splitlines()
        for _ in range(TEST_ITER):
            # cv2 loads images as BGR; the Normalize statistics below are the
            # standard ImageNet RGB values, so strictly the image should first be
            # converted, e.g. input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
            image_path, label = images[np.random.randint(0, len(images))].split(' ')
            # image_path, label = images[0].split(' ')
            input_image = cv2.imread(image_path)
            input_image = cv2.resize(input_image, (224, 224))
            input_image = transforms.Compose([transforms.ToTensor(),
                                              transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                   std=[0.229, 0.224, 0.225])
                                              ])(input_image)
            input_image = input_image.view(1, 3, 224, 224)
            output_logits = pytorch_net(input_image)
            _, index = output_logits.max(dim=1)
            print('true label: \t%s' % labels[int(label)])
            print('predict label:\t%s' % labels[int(index)])
            print('-' * 50)

Merge code 2:

https://github.com/owphoo/pytorch_merge_bn/blob/master/pytorch_merge_bn.py

import torch
import os
import sys
from collections import OrderedDict

merged = True
def merge(params, name, layer, deconv_layer_names=['deconv']):
    # global variables
    global weights, bias
    global bn_param
    global merged

    # detect transposed-convolution (deconv) layers by name
    is_deconv = False
    for deconv_name in deconv_layer_names:
        if deconv_name in name:
            is_deconv = True
            break

    if layer == 'Convolution':
        # save weights and bias when meeting a conv/deconv layer
        if 'weight' in name:
            weights = params.data
            # ConvTranspose2d weights are (in_ch, out_ch, kH, kW), so the
            # number of output channels is dimension 1 rather than dimension 0
            if is_deconv:
                bias = torch.zeros(weights.size()[1], device=weights.device)
            else:
                bias = torch.zeros(weights.size()[0], device=weights.device)
            merged = False
        elif 'bias' in name:
            bias = params.data
        bn_param = {}

    elif layer == 'BatchNorm':
        # save bn params
        bn_param[name.split('.')[-1]] = params.data

        # running_var is the last bn param in pytorch
        if 'running_var' in name:
            # merge BN: scale = gamma / sqrt(running_var + eps);
            # the hard-coded 1e-5 assumes BatchNorm's default eps
            tmp = bn_param['weight'] / torch.sqrt(bn_param['running_var'] + 1e-5)
            if is_deconv:
                weights = (tmp.view(tmp.size()[0], 1, 1, 1) * weights.permute(1,0,2,3)).permute(1,0,2,3)
            else:
                weights = tmp.view(tmp.size()[0], 1, 1, 1) * weights
            bias = tmp * (bias - bn_param['running_mean']) + bn_param['bias']

            return weights, bias

    return None, None


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print('Usage: python pytorch_merge_bn.py YOUR_MODEL')
        sys.exit(-1)
    model_path = sys.argv[1]
    print('input model: ', model_path)
    checkpoint = torch.load(model_path)

    trained_weights = checkpoint['net_state_dict']

    '''
    ## conv_bn_relu module
    #           NAME           |           SIZE
    #   conv4.0.weight              torch.Size([128, 256, 3, 3])
    #   conv4.1.weight              torch.Size([256])
    #   conv4.1.bias                torch.Size([256])
    #   conv4.1.running_mean        torch.Size([256])
    #   conv4.1.running_var         torch.Size([256])
    ## deconv_bn_relu module
    #           NAME           |           SIZE
    #   deconv4.0.weight             torch.Size([256, 128, 4, 4])
    #   deconv4.1.weight             torch.Size([128])
    #   deconv4.1.bias               torch.Size([128])
    #   deconv4.1.running_mean       torch.Size([128])
    #   deconv4.1.running_var        torch.Size([128])
    '''

    # check it in your net modules
    deconv_layer_names = ['deconv4', 'deconv3', 'deconv2', 'deconv1']
    temp = []
    for deconv_name in deconv_layer_names:
        temp.append(deconv_name + '.0')
        temp.append(deconv_name + '.1')
    deconv_layer_names = temp

    # go through pytorch net
    new_weights = OrderedDict()
    inner_product_flag = False
    for name, params in trained_weights.items():
        print('name: ', name, params.size())
        if len(params.size()) == 4:
            _, _ = merge(params, name, 'Convolution', deconv_layer_names=deconv_layer_names)
            prev_layer = name
            # print('prev1: ', prev_layer)
        elif len(params.size()) == 1 and not inner_product_flag:
            # 1-dim params are treated as BN parameters here; this heuristic
            # relies on convs that feed a BN layer having bias=False
            w, b = merge(params, name, 'BatchNorm', deconv_layer_names=deconv_layer_names)
            if w is not None:
                new_weights[prev_layer] = w
                new_weights[prev_layer.replace('weight', 'bias')] = b
                # only mark as merged once running_var has been folded in
                merged = True
        else:
            # inner product layer (TODO, inner product layer may have bn module)
            if name.find('num_batches_tracked') == -1:
                new_weights[name] = params
                inner_product_flag = True
            else:
                pass

    # for the last conv/deconv if it has no bn module 
    if merged is False:
        new_weights[prev_layer] = weights
        new_weights[prev_layer.replace('weight', 'bias')] = bias

    checkpoint['net_state_dict'] = new_weights

    # save the merged weights next to the input model
    model_dir = os.path.dirname(model_path) or './'
    model_name = os.path.basename(model_path)
    torch.save(checkpoint, os.path.join(model_dir, 'merge_bn_' + model_name))
