PyTorch: merge BN layers into the preceding conv layers (model from mmdetection)

import os
import torch

src_weight = 'epoch_12.pth'
dst_weight = 'merge_bn.pth'
# Epsilon added to the running variance before the square root.
BN_EPS = 1e-5


def _fold_bn_into_conv(weight, bn_weight_name, conv_weight_name, eps=BN_EPS):
    """Fold one BatchNorm layer into its paired convolution, in place.

    With ``scale = gamma / sqrt(running_var + eps)`` the fused layer is::

        conv_weight' = conv_weight * scale          (per output channel)
        conv_bias'   = (conv_bias - running_mean) * scale + beta

    The BN entries are removed from ``weight`` afterwards.

    Args:
        weight: mutable mapping of parameter name -> tensor (a state dict).
        bn_weight_name: key of the BN scale, ending in ``'.weight'``.
        conv_weight_name: key of the conv kernel, ending in ``'.weight'``.
        eps: value added to the running variance before the square root.

    Raises:
        KeyError: if the conv kernel or any required BN tensor is missing.
    """
    bn_stem = bn_weight_name[:-len('weight')]
    bn_bias_name = bn_stem + 'bias'
    bn_mean_name = bn_stem + 'running_mean'
    bn_var_name = bn_stem + 'running_var'
    bn_tracked_name = bn_stem + 'num_batches_tracked'
    conv_bias_name = conv_weight_name[:-len('weight')] + 'bias'

    # Read everything before mutating so a missing key leaves `weight` intact.
    conv_weight = weight[conv_weight_name]
    gamma = weight[bn_weight_name]   # BN scale
    beta = weight[bn_bias_name]      # BN shift
    mean = weight[bn_mean_name]
    var = weight[bn_var_name]

    # Keep an existing conv bias instead of silently replacing it with zeros.
    conv_bias = weight.get(conv_bias_name)
    if conv_bias is None:
        conv_bias = mean.new_zeros(mean.shape)

    scale = gamma / torch.sqrt(var + eps)
    # One factor per output channel (dim 0), broadcast over the other dims.
    per_channel = scale.reshape([-1] + [1] * (conv_weight.dim() - 1))

    weight[conv_weight_name] = conv_weight * per_channel
    weight[conv_bias_name] = (conv_bias - mean) * scale + beta

    weight.pop(bn_weight_name)
    weight.pop(bn_bias_name)
    weight.pop(bn_mean_name)
    weight.pop(bn_var_name)
    # This buffer is not present in every checkpoint, so do not require it.
    weight.pop(bn_tracked_name, None)


def merge_bn(weight, eps=BN_EPS):
    """Fold BatchNorm layers of a state dict into their convolutions.

    Pairing is done purely by parameter name:

    * ``...bn....weight`` is paired with the same name with every ``'bn'``
      replaced by ``'conv'``;
    * ``...downsample.1.weight`` is paired with ``...downsample.0.weight``.

    A ``bn2`` that sits next to a ``conv2_offset`` layer is left untouched.
    Entries whose name matches but that have no ``running_mean`` sibling are
    skipped, since there are no statistics to fold.

    Args:
        weight: mutable mapping of parameter name -> tensor; modified in place.
        eps: value added to the running variance before the square root.

    Returns:
        The same ``weight`` mapping, for convenience.

    Raises:
        KeyError: if a BN layer has no convolution under the derived name.
    """
    layer_names = list(weight.keys())

    # BN layers that must not be merged: the bn2 paired with a conv2_offset.
    skipped = {
        name.replace('conv2_offset', 'bn2')
        for name in layer_names
        if 'conv2_offset.weight' in name
    }

    for layer_name in layer_names:
        if not layer_name.endswith('.weight') or layer_name in skipped:
            continue

        if 'bn' in layer_name:
            conv_name = layer_name.replace('bn', 'conv')
        elif 'downsample.1' in layer_name:
            conv_name = layer_name.replace('downsample.1', 'downsample.0')
        else:
            continue

        # Only layers carrying running statistics can be folded.
        if layer_name[:-len('weight')] + 'running_mean' not in weight:
            continue

        _fold_bn_into_conv(weight, layer_name, conv_name, eps=eps)

    return weight


def main():
    """Load ``src_weight``, merge its BN layers and save to ``dst_weight``."""
    # map_location lets a checkpoint saved on GPU be converted on a CPU box.
    model = torch.load(src_weight, map_location='cpu')

    weight = merge_bn(model['state_dict'])

    for name in weight.keys():
        print(name)

    new_model = dict()
    if 'meta' in model:
        new_model['meta'] = model['meta']
    new_model['state_dict'] = weight
    torch.save(new_model, dst_weight)


if __name__ == '__main__':
    main()

mmdet/models/backbones/resnet.py 和 mmdet/models/backbones/resnext.py 也需要做相应修改:去掉已合并的 BN 层,并给对应的卷积层加上 bias。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值