PyTorch中BN层与CONV层的融合(merge_bn)

本文介绍了如何在PyTorch中实现BN层与CONV层的融合,类似于Caffe中的merge_bn操作。通过安装numpy, torch, torchvision和cv2等库,将PyTorch模型与权重准备就绪,然后运行特定代码实现融合。" 120090710,11027864,C++编程:深入理解数组与指针,"['C++', '数组操作', '指针操作', '编程基础']
摘要由CSDN通过智能技术生成

之前发了很久之前写好的一篇关于Caffe中merge_bn的博客,详情可见
Caffe中BN层与CONV层的融合(merge_bn)

今天由于工作需要要对PyTorch模型进行merge_bn,发现网上貌似还没有类似的现成代码,决定自己写个脚本,思路和方法见上面的博客即可,具体的步骤如下:

  1. 要求安装的包有

    numpy
    torch
    torchvision
    cv2

  2. 准备好自己的PyTorch模型,包括模型和权重,放在你能import到的地方即可
  3. 修改代码中的相应路径和import
  4. run and merge!

具体的代码如下

import torch
import os
from collections import OrderedDict
import cv2
import numpy as np
import torchvision.transforms as transforms


"""  Parameters and variables  """
IMAGENET = '/home/zym/ImageNet/ILSVRC2012_img_val_256xN_list.txt'
LABEL = '/home/zym/ImageNet/synset.txt'
TEST_ITER = 10
SAVE = False
TEST_AFTER_MERGE = True


"""  Functions  """
def merge(params, name, layer):
    # global variables
    global weights, bias
    global bn_param

    if layer == 'Convolution':
        # save weights and bias when meet conv layer
        if 'weight' in name:
            weights = params.data
            bias = torch.zeros(weights.size()[0])
        elif 'bias' in name:
            bias = params.data
        bn_param = {
   }

    elif layer == 'BatchNorm':
        # save bn params
        bn_param[name.split('.')[-1]] = params.data

        # running_var is the last bn param in pytorch
        if 'running_var' in name:
            # let us merge bn ~
            tmp = bn_param[
评论 25
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值