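Contents
Why a convolution layer followed by a batch normalization layer does not need a bias b
A detailed look at the parameters of convolution and BN layers in a CNN
Source: "A detailed look at the parameters of convolution and BN layers in a CNN" (CSDN blog)
BN layer parameters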
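In PyTorch, batch normalization (BN) is implemented by torch.nn.BatchNorm1d, torch.nn.BatchNorm2d and torch.nn.BatchNorm3d, one class per input dimensionality. Each class takes several constructor arguments that control the layer's behaviour. The main ones are:
num_features: the number of input features. For BatchNorm1d it is the feature dimension; for BatchNorm2d it is the number of feature maps, i.e. the number of channels of the input tensor.
eps: a small value added to the variance inside the square root of the denominator, to avoid division by zero. The default is 1e-5.
momentum: the weight given to the current batch statistics when updating the running mean and variance, i.e. running = (1 - momentum) * running + momentum * batch_statistic. It sets the balance between history and new information. The default is 0.1.
affine: a boolean that specifies whether a learnable affine transform is applied to the normalized output (multiply by "gamma", the weight, and add "beta", the bias). The default is True.
track_running_stats: a boolean that specifies whether the layer tracks a running mean and variance over the training batches. In training mode these statistics are updated; in evaluation mode they are used for normalization. If it is False, batch statistics are used in both modes. The default is True.
Besides these arguments, a BatchNorm layer learns two parameters during training:
Weight (gamma): the scale applied to the normalized values, learned only when affine=True.
Bias (beta): the shift applied to the normalized values, learned only when affine=True.
Tuning these parameters and hyperparameters affects how the model learns and its final performance. For example, under PyTorch's convention a larger momentum makes the running mean and variance more sensitive to the most recent batches, while a smaller value gives smoother, more stable statistics that adapt to new data more slowly. Increasing eps can help with numerical stability when a channel's variance is close to zero, which is more likely with small batch sizes. The affine and track_running_stats options let you adapt the layer's behaviour to specific training or evaluation needs.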
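The effect of momentum can be checked with a few lines of arithmetic. The sketch below is plain Python rather than PyTorch, so that the update rule stays visible; it assumes the rule running = (1 - momentum) * running + momentum * batch_mean and a running mean that starts at zero. The function names and the batch means are made up for illustration.

```python
def update_running_mean(running_mean, batch_mean, momentum=0.1):
    """One update step: new = (1 - momentum) * old + momentum * batch value."""
    return (1.0 - momentum) * running_mean + momentum * batch_mean


def track(batch_means, momentum):
    """Apply the update once per batch, starting from a running mean of 0.0."""
    running = 0.0
    for batch_mean in batch_means:
        running = update_running_mean(running, batch_mean, momentum)
    return running


# Hypothetical per-batch means of one channel: steady at 1.0, then a jump to 5.0.
batch_means = [1.0, 1.0, 1.0, 5.0]

slow = track(batch_means, momentum=0.1)  # small momentum: history dominates
fast = track(batch_means, momentum=0.9)  # large momentum: last batch dominates

print(f"momentum=0.1 -> running mean {slow:.4f}")
print(f"momentum=0.9 -> running mean {fast:.4f}")
```

With the larger momentum the running mean follows the last batch closely, while the smaller momentum barely moves after the jump.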
Why a convolution layer followed by a batch normalization layer does not need a bias b
Source: "Why a convolution layer followed by a batch normalization layer does not need a bias b" (CSDN blog)
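The usual argument is that BN subtracts the per-channel batch mean, so a constant bias added to every activation of a channel is removed again, and BN's own beta takes over the role of the shift. The following plain-Python sketch illustrates this for one channel in training mode; the activation values and the bias are hypothetical, and batch_norm is a simplified stand-in, not the PyTorch implementation.

```python
import math


def batch_norm(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize one channel with its own batch mean and (biased) batch variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]


# Hypothetical conv outputs of one channel across a batch, and a hypothetical conv bias.
conv_out = [0.5, -1.2, 3.0, 0.7]
b = 2.5

without_bias = batch_norm(conv_out)
with_bias = batch_norm([v + b for v in conv_out])

# The bias shifts the mean by b and leaves the variance unchanged, so it cancels out.
max_diff = max(abs(p - q) for p, q in zip(without_bias, with_bias))
print("largest difference:", max_diff)
```

In PyTorch this is why a convolution placed directly before a BN layer is commonly created with bias=False.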
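Fusing convolution and BN layers
This explanation is also good:
Source: "Fuse BN at deep learning inference time for an easy speedup of about 5%" - osc_s7aj86hu's personal space - OSCHINA (Chinese open source technical community)
I asked a PhD colleague about this: grouped convolutions can be merged; for a depthwise (per-channel) convolution, where BN is also per-channel, perhaps it cannot be merged? Note that the fold only rescales each output channel's filter and bias, so the same per-channel fold used by the merge code later in these notes should apply to depthwise convolutions as well.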
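1. Why merge the BN layer
When training deep network models, a BN (Batch Normalization) layer speeds up convergence and helps control overfitting; it is usually placed after a convolution layer. By normalizing the data, BN helps mitigate vanishing and exploding gradients. Although BN is beneficial during training, at inference time it adds extra per-layer computation, which slows the forward pass and uses more memory or GPU memory. Many state-of-the-art models (ResNet, MobileNet, Xception, ShuffleNet, etc.) use BN, so it is worth merging the BN parameters into the convolution layer to speed up the forward pass.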
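2. The math behind merging a BN layer into a convolution layer
Let the convolution output be $y = W * x + b$. At inference time BN uses the running mean $\mu$, the running variance $\sigma^2$, the scale $\gamma$ and the shift $\beta$ (this matches the merge code later in these notes):
$$\mathrm{BN}(y) = \gamma \cdot \frac{y - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
Then:
$$\mathrm{BN}(W * x + b) = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \, (W * x) + \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \, (b - \mu) + \beta$$
After merging, a single convolution with the following weight and bias gives the same output:
$$W' = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \, W, \qquad b' = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \, (b - \mu) + \beta$$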
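The folding rule used by the merge scripts in these notes (scale = gamma / sqrt(running_var + eps), new weight = scale * weight, new bias = scale * (bias - running_mean) + beta) can be checked numerically. The sketch below is plain Python and treats a 1x1 convolution at a single pixel, which reduces to a matrix-vector product; all function names and numbers are hypothetical and chosen only for illustration.

```python
import math


def conv1x1(weight, bias, x):
    """1x1 convolution at one pixel: out[o] = sum_c weight[o][c] * x[c] + bias[o]."""
    return [sum(w * xc for w, xc in zip(row, x)) + b for row, b in zip(weight, bias)]


def batch_norm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BN: one (gamma, beta, mean, var) tuple per output channel."""
    return [g * (v - m) / math.sqrt(s + eps) + bt
            for v, g, bt, m, s in zip(y, gamma, beta, mean, var)]


def fold_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Return conv weight and bias with the BN layer folded in."""
    scale = [g / math.sqrt(s + eps) for g, s in zip(gamma, var)]
    new_weight = [[sc * w for w in row] for sc, row in zip(scale, weight)]
    new_bias = [sc * (b - m) + bt for sc, b, m, bt in zip(scale, bias, mean, beta)]
    return new_weight, new_bias


# Hypothetical layer: 3 input channels, 2 output channels.
weight = [[0.2, -0.5, 1.0], [1.5, 0.3, -0.7]]
bias = [0.1, -0.2]
gamma, beta = [1.2, 0.8], [0.05, -0.3]
running_mean, running_var = [0.4, -0.1], [0.9, 2.5]
x = [1.0, 2.0, -1.0]

reference = batch_norm_eval(conv1x1(weight, bias, x), gamma, beta, running_mean, running_var)
fused_w, fused_b = fold_bn(weight, bias, gamma, beta, running_mean, running_var)
fused = conv1x1(fused_w, fused_b, x)

max_diff = max(abs(p - q) for p, q in zip(reference, fused))
print("largest difference between conv+BN and fused conv:", max_diff)
```

Because the scale is applied per output channel, the same check works regardless of how many input channels each filter sees.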
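3. Experimental results
Machine: GTX 1080Ti GPU, i7 CPU
This experiment compares a ResNet50 model before and after merging the BN layers. Classification accuracy is unchanged and the forward pass is noticeably faster.
Model                      CPU forward time   GPU forward time
ResNet50 (before merge)    176.17 ms          11.03 ms
ResNet50 (after merge)     161.69 ms          7.3 ms
Speedup (as reported)      8.96%              33.27%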
————————————————
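Copyright notice: this is an original article by CSDN blogger "小麦草", licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.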
Original link: https://blog.csdn.net/kangdi7547/article/details/81348254
https://github.com/vietnamican/conv-bn-merge/blob/main/convbnmerge/convbnmerge.py
Merge code 1:
Source: "Fusing BN and CONV layers in PyTorch (merge_bn)" (CSDN blog)
import torch
import os
from collections import OrderedDict
import cv2
import numpy as np
import torchvision.transforms as transforms
""" Parameters and variables """
IMAGENET = '/home/zym/ImageNet/ILSVRC2012_img_val_256xN_list.txt'
LABEL = '/home/zym/ImageNet/synset.txt'
TEST_ITER = 10
SAVE = False
TEST_AFTER_MERGE = True
""" Functions """
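def merge(params, name, layer):
    # The conv weights/bias and the BN tensors are kept in module-level globals
    # between calls, because the state dict is walked one tensor at a time.
    global weights, bias
    global bn_param

    if layer == 'Convolution':
        # save weights and bias when a conv layer is encountered
        if 'weight' in name:
            weights = params.data
            # start from a zero bias in case the conv layer has none
            bias = torch.zeros(weights.size()[0], device=weights.device)
        elif 'bias' in name:
            bias = params.data
        bn_param = {}

    elif layer == 'BatchNorm':
        # save bn params under their short names (weight, bias, running_mean, running_var)
        bn_param[name.split('.')[-1]] = params.data
        # running_var is the last bn tensor this script needs
        if 'running_var' in name:
            # merge bn; 1e-5 is PyTorch's default eps, adjust it if the BN layers use another value
            tmp = bn_param['weight'] / torch.sqrt(bn_param['running_var'] + 1e-5)
            weights = tmp.view(tmp.size()[0], 1, 1, 1) * weights
            bias = tmp * (bias - bn_param['running_mean']) + bn_param['bias']
            return weights, bias

    return None, None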
""" Main functions """
# import pytorch model
import models.shufflenetv2.shufflenetv2_merge as shufflenetv2
pytorch_net = shufflenetv2.ShuffleNetV2().eval()
model_path = shufflenetv2.weight_file

# load weights
print('Finding trained model weights...')
trained_weights = None
try:
    for file in os.listdir(model_path):
        if 'pth' in file:
            print('Loading weights from %s ...' % file)
            trained_weights = torch.load(os.path.join(model_path, file))
            # pytorch_net.load_state_dict(trained_weights)
            print('Weights load success')
            break
except Exception as e:
    raise ValueError('No trained model found or loading error occurs') from e
if trained_weights is None:
    # the directory was readable but contained no .pth file
    raise ValueError('No trained model found in %s' % model_path)
# go through pytorch net
print('Going through pytorch net weights...')
new_weights = OrderedDict()
inner_product_flag = False
for name, params in trained_weights.items():
    if 'num_batches_tracked' in name:
        # BN bookkeeping counter (a 0-dim tensor); it is not needed after merging and
        # would otherwise fall into the inner product branch below
        continue
    if len(params.size()) == 4:
        _, _ = merge(params, name, 'Convolution')
        prev_layer = name
    elif len(params.size()) == 1 and not inner_product_flag:
        w, b = merge(params, name, 'BatchNorm')
        if w is not None:
            new_weights[prev_layer] = w
            new_weights[prev_layer.replace('weight', 'bias')] = b
    else:
        # inner product layer
        # once the inner product layer is reached, its bias could be mistaken
        # for a 'BatchNorm' tensor because len(params.size()) == 1
        new_weights[name] = params
        inner_product_flag = True
# align names in new_weights with the pytorch model:
# after the BatchNorm layers are removed from the model,
# the layer names of the old and new model no longer line up
print('Aligning weight names...')
pytorch_net_key_list = list(pytorch_net.state_dict().keys())
new_weights_key_list = list(new_weights.keys())
assert len(pytorch_net_key_list) == len(new_weights_key_list)
# rebuild the dict in one pass; renaming in place could overwrite an entry
# whose old name equals a new name that has not been processed yet
new_weights = OrderedDict(
    (pytorch_net_key_list[index], new_weights[new_weights_key_list[index]])
    for index in range(len(pytorch_net_key_list)))

# save new weights
if SAVE:
    torch.save(new_weights, os.path.join(model_path, file.replace('.pth', '_merged.pth')))
# test merged pytorch model
if TEST_AFTER_MERGE:
    try:
        pytorch_net.load_state_dict(new_weights)
        print('Pytorch net load weights success~')
    except Exception as e:
        raise ValueError('Load new weights error') from e
    print('-' * 50)
    with open(LABEL) as f:
        labels = f.read().splitlines()
    with open(IMAGENET) as f:
        images = f.read().splitlines()
    for _ in range(TEST_ITER):
        # cv2 loads images in BGR channel order
        image_path, label = images[np.random.randint(0, len(images))].split(' ')
        # image_path, label = images[0].split(' ')
        input_image = cv2.imread(image_path)
        input_image = cv2.resize(input_image, (224, 224))
        input_image = transforms.Compose([transforms.ToTensor(),
                                          transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                               std=[0.229, 0.224, 0.225])
                                          ])(input_image)
        input_image = input_image.view(1, 3, 224, 224)
        # inference only, so no gradients are needed
        with torch.no_grad():
            output_logits = pytorch_net(input_image)
        _, index = output_logits.max(dim=1)
        print('true label: \t%s' % labels[int(label)])
        print('predict label:\t%s' % labels[int(index)])
        print('-' * 50)
Merge code 2:
https://github.com/owphoo/pytorch_merge_bn/blob/master/pytorch_merge_bn.py
import torch
import os
from collections import OrderedDict
import numpy as np

# True once the most recent conv/deconv has been written to new_weights
merged = True
def merge(params, name, layer, deconv_layer_names=('deconv',)):
    # global variables
    global weights, bias
    global bn_param
    global merged
    # a tensor belongs to a deconv layer if any configured name is a substring of its name
    is_deconv = any(deconv_name in name for deconv_name in deconv_layer_names)

    if layer == 'Convolution':
        # save weights and bias when a conv layer is encountered
        if 'weight' in name:
            weights = params.data
            # deconv weights are laid out (in, out, kH, kW), so the output channels are dim 1
            out_channels = weights.size()[1] if is_deconv else weights.size()[0]
            bias = torch.zeros(out_channels, device=weights.device)
            merged = False
        elif 'bias' in name:
            bias = params.data
        bn_param = {}

    elif layer == 'BatchNorm':
        # save bn params
        bn_param[name.split('.')[-1]] = params.data
        # running_var is the last bn tensor this script needs
        if 'running_var' in name:
            # merge bn; 1e-5 is PyTorch's default eps
            tmp = bn_param['weight'] / torch.sqrt(bn_param['running_var'] + 1e-5)
            if is_deconv:
                # move the output channel axis to the front, scale, then move it back
                weights = (tmp.view(tmp.size()[0], 1, 1, 1) * weights.permute(1, 0, 2, 3)).permute(1, 0, 2, 3)
            else:
                weights = tmp.view(tmp.size()[0], 1, 1, 1) * weights
            bias = tmp * (bias - bn_param['running_mean']) + bn_param['bias']
            return weights, bias

    return None, None
import sys

if __name__ == '__main__':
    if len(sys.argv) != 2:
        print('Usage: python pytorch_merge_bn.py YOUR_MODEL')
        sys.exit(-1)
    model_path = sys.argv[1]
    print('input model: ', model_path)
    checkpoint = torch.load(model_path)
    trained_weights = checkpoint['net_state_dict']
    '''
    ## conv_bn_relu module
    #           NAME      |          SIZE
    # conv4.0.weight           torch.Size([128, 256, 3, 3])
    # conv4.1.weight           torch.Size([256])
    # conv4.1.bias             torch.Size([256])
    # conv4.1.running_mean     torch.Size([256])
    # conv4.1.running_var      torch.Size([256])
    ## deconv_bn_relu module
    #           NAME      |          SIZE
    # deconv4.0.weight         torch.Size([256, 128, 4, 4])
    # deconv4.1.weight         torch.Size([128])
    # deconv4.1.bias           torch.Size([128])
    # deconv4.1.running_mean   torch.Size([128])
    # deconv4.1.running_var    torch.Size([128])
    '''
    # check it in your net modules
    deconv_layer_names = ['deconv4', 'deconv3', 'deconv2', 'deconv1']
    temp = []
    for deconv_name in deconv_layer_names:
        temp.append(deconv_name + '.0')  # the deconv layer itself
        temp.append(deconv_name + '.1')  # the BN layer that follows it
    deconv_layer_names = temp

    # go through pytorch net
    new_weights = OrderedDict()
    inner_product_flag = False
    for name, params in trained_weights.items():
        print('name: ', name, params.size())
        if len(params.size()) == 4:
            _, _ = merge(params, name, 'Convolution', deconv_layer_names=deconv_layer_names)
            prev_layer = name
            # print('prev1: ', prev_layer)
        elif len(params.size()) == 1 and not inner_product_flag:
            w, b = merge(params, name, 'BatchNorm', deconv_layer_names=deconv_layer_names)
            # print('prev2: ', prev_layer)
            if w is not None:
                new_weights[prev_layer] = w
                new_weights[prev_layer.replace('weight', 'bias')] = b
                # mergebn
                merged = True
        else:
            # inner product layer (TODO, inner product layer may have bn module)
            if 'num_batches_tracked' not in name:
                new_weights[name] = params
                inner_product_flag = True
            # num_batches_tracked counters are dropped

    # for the last conv/deconv if it has no bn module
    if merged is False:
        new_weights[prev_layer] = weights
        new_weights[prev_layer.replace('weight', 'bias')] = bias
    checkpoint['net_state_dict'] = new_weights

    # save new weights next to the input file, prefixed with merge_bn_
    model_dir, model_name = os.path.split(model_path)
    torch.save(checkpoint, os.path.join(model_dir or '.', 'merge_bn_' + model_name))