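I recently posted an article I had written a long time ago about merge_bn in Caffe; for details, see
Fusing the BN and CONV layers in Caffe (merge_bn)
Today I needed to run merge_bn on a PyTorch model for work and could not find any ready-made code for it online, so I decided to write a script myself. The idea and method are the same as in the post above. The steps are: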
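- Install the required packages:
  - numpy
  - torch
  - torchvision
  - cv2 (OpenCV)
- Prepare your own PyTorch model, both the model definition and the weights, and put it somewhere you can import it from
- Update the paths and imports in the code to match
- run and merge!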
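For reference, the arithmetic behind the merge can be sketched without any framework. This is only a minimal illustration, not the script from this post: it treats each output channel as a plain dot product (a 1x1 convolution evaluated at a single pixel), and the names `fold_bn`, `linear`, and `batch_norm` are hypothetical helpers introduced just for this example. It assumes inference-mode BN, y = gamma * (x - mean) / sqrt(var + eps) + beta.

```python
import math

def fold_bn(weights, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold per-channel BN statistics into the preceding layer's weights and bias."""
    new_weights, new_bias = [], []
    for c, row in enumerate(weights):
        # Per-channel scale applied by BN at inference time
        scale = gamma[c] / math.sqrt(var[c] + eps)
        new_weights.append([w * scale for w in row])
        new_bias.append((bias[c] - mean[c]) * scale + beta[c])
    return new_weights, new_bias

def linear(weights, bias, x):
    # One dot product per output channel (hypothetical stand-in for a conv)
    return [sum(w * v for w, v in zip(row, x)) + b
            for row, b in zip(weights, bias)]

def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    # Inference-mode BN using stored running statistics
    return [g * (v - m) / math.sqrt(s + eps) + b
            for v, g, b, m, s in zip(y, gamma, beta, mean, var)]

# Toy values (made up for illustration): 2 output channels, 3 inputs
weights = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
bias = [0.1, -0.2]
gamma, beta = [1.2, 0.8], [0.05, -0.1]
mean, var = [0.3, -0.4], [2.0, 0.5]
x = [1.0, 2.0, -1.0]

reference = batch_norm(linear(weights, bias, x), gamma, beta, mean, var)
merged_w, merged_b = fold_bn(weights, bias, gamma, beta, mean, var)
merged = linear(merged_w, merged_b, x)
max_diff = max(abs(a - b) for a, b in zip(reference, merged))
```

Here `max_diff` compares the layer followed by BN against the single merged layer. The two should agree up to floating-point rounding, which is the same kind of check the script's `TEST_AFTER_MERGE` flag suggests.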
The full code is as follows:
import torch
import os
from collections import OrderedDict
import cv2
import numpy as np
import torchvision.transforms as transforms
""" Parameters and variables """
IMAGENET = '/home/zym/ImageNet/ILSVRC2012_img_val_256xN_list.txt'
LABEL = '/home/zym/ImageNet/synset.txt'
TEST_ITER = 10
SAVE = False
TEST_AFTER_MERGE = True
""" Functions """
def merge(params, name, layer):
    # global variables
    global weights, bias
    global bn_param

    if layer == 'Convolution':
        # save weights and bias when meet conv layer
        if 'weight' in name:
            weights = params.data
            bias = torch.zeros(weights.size()[0])
        elif 'bias' in name:
            bias = params.data
        bn_param = {}

    elif layer == 'BatchNorm':
        # save bn params
        bn_param[name.split('.')[-1]] = params.data

        # running_var is the last bn param in pytorch
        if 'running_var' in name:
            # let us merge bn ~
            tmp = bn_param[