PyTorch中BN层与CONV层的融合(merge_bn)

之前发了很久之前写好的一篇关于Caffe中merge_bn的博客,详情可见
Caffe中BN层与CONV层的融合(merge_bn)

今天由于工作需要要对PyTorch模型进行merge_bn,发现网上貌似还没有类似的现成代码,决定自己写个脚本,思路和方法见上面的博客即可,具体的步骤如下:

  1. 要求安装的包有

    numpy
    torch
    torchvision
    cv2

  2. 准备好自己的PyTorch模型,包括模型和权重,放在你能import到的地方即可
  3. 修改代码中的相应路径和import
  4. run and merge!

具体的代码如下

import torch
import os
from collections import OrderedDict
import cv2
import numpy as np
import torchvision.transforms as transforms


"""  Parameters and variables  """
IMAGENET = '/home/zym/ImageNet/ILSVRC2012_img_val_256xN_list.txt'
LABEL = '/home/zym/ImageNet/synset.txt'
TEST_ITER = 10
SAVE = False
TEST_AFTER_MERGE = True


"""  Functions  """
def merge(params, name, layer):
    # global variables
    global weights, bias
    global bn_param

    if layer == 'Convolution':
        # save weights and bias when meet conv layer
        if 'weight' in name:
            weights = params.data
            bias = torch.zeros(weights.size()[0])
        elif 'bias' in name:
            bias = params.data
        bn_param = {
   }

    elif layer == 'BatchNorm':
        # save bn params
        bn_param[name.split('.')[-1]] = params.data

        # running_var is the last bn param in pytorch
        if 'running_var' in name:
            # let us merge bn ~
            tmp = bn_param[
  • 4
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 25
    评论
PytorchBN层是Batch Normalization的缩写,用于在深度学习模型对输入数据进行归一化处理。BN层的作用是通过对每个小批量的输入数据进行归一化,使得模型在训练过程更加稳定和快速收敛。\[1\] 在Pytorch,使用BN层的方法如下所示: ```python from torch import nn # 创建一个BN层对象,需要传入特征的通道数num_features作为参数 bn = nn.BatchNorm2d(num_features) # 输入数据 input = torch.randn(batch_size, num_features, height, width) # 将输入数据传入BN层进行处理 output = bn(input) ``` 其,`num_features`表示输入数据的通道数,`batch_size`表示输入数据的批量大小,`height`和`width`表示输入数据的高度和宽度。\[1\] 在BN层的类,还有一些其他的参数可以进行设置,例如`eps`表示用于数值稳定性的小值,默认为1e-5;`momentum`表示用于计算移动平均的动量,默认为0.1;`affine`表示是否学习BN层的参数γ和β,默认为True;`track_running_stats`表示是否跟踪训练过程的统计数据,默认为True。\[2\] 需要注意的是,BN层的参数γ和β是否可学习是由`affine`参数控制的,默认情况下是可学习的,即可通过反向传播进行更新。而BN层的统计数据更新是在每一次训练阶段的`model.train()`后的`forward()`方法自动实现的,而不是在梯度计算与反向传播更新`optim.step()`完成。\[3\] #### 引用[.reference_title] - *1* [一起来学PyTorch——神经网络(BN层)](https://blog.csdn.net/TomorrowZoo/article/details/129531658)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [pytorchBN层简介](https://blog.csdn.net/lpj822/article/details/109772094)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 25
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值