模型修改后加载预训练权重

修改后的模型加载预训练的权重

当对模型进行部分修改后,模型未修改部分如何加载修改前的预训练权重:

model = ResNetBackbone()                         # 修改后的模型,对应的层,模块名称未变

# 官方的resnet50预训练权重
url_50 = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
save_model = model_zoo.load_url(url_50)

model_dict = model.state_dict()                  # 模型的参数字典

state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}

# 更新并加载
model_dict.update(state_dict)
model.load_state_dict(model_dict)

中间发现个小问题,官方给的状态字典的key比model.state_dict返回的状态字典键数少

import torch.utils.model_zoo as model_zoo
import torchvision.models as models

url_18 = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
save_model = model_zoo.load_url(url_18)
print(type(save_model))                     # <class 'collections.OrderedDict'>
print(len(save_model))                      # 102

model18 =  models.resnet18()
model18_state_dict = model18.state_dict()   
print(type(model18_state_dict))            # <class 'collections.OrderedDict'>
print(len(model18_state_dict))             # 122

model18.load_state_dict(save_model)        # model18.load_state_dict(save_model)

为什么两个字典中红的key数量不同

# 输出两者不同的key
unique_keys_in_B = set(model18_state_dict.keys()) - set(save_model.keys())
for key in unique_keys_in_B:
    print(key)
    

# layer2.1.bn1.num_batches_tracked
# layer3.1.bn2.num_batches_tracked
# layer4.1.bn1.num_batches_tracked
# layer4.0.bn1.num_batches_tracked
# layer1.0.bn2.num_batches_tracked
# layer3.1.bn1.num_batches_tracked
# layer2.0.downsample.1.num_batches_tracked
# layer2.0.bn1.num_batches_tracked
# layer3.0.bn2.num_batches_tracked
# layer2.0.bn2.num_batches_tracked
# layer1.0.bn1.num_batches_tracked
# bn1.num_batches_tracked
# layer4.0.downsample.1.num_batches_tracked
# layer1.1.bn2.num_batches_tracked
# layer3.0.downsample.1.num_batches_tracked
# layer4.1.bn2.num_batches_tracked
# layer1.1.bn1.num_batches_tracked
# layer3.0.bn1.num_batches_tracked
# layer4.0.bn2.num_batches_tracked
# layer2.1.bn2.num_batches_tracked

查了一下:

bn1.num_batches_tracked:追踪批次数(Number of Batches Tracked)

  • 这个参数用于追踪 Batch Normalization 层已经处理过的批次数量。它在推理过程中不起作用,只在训练过程中更新

本想这自己保存后键数也是102。但当把这些当前模型的状态字典保存,再加载,还是122。

torch.save(model18.state_dict(), 'resnet18.pth')
new_save_dict = torch.load('resnet18.pth')
print(type(new_save_dict))                         # <class 'collections.OrderedDict'>
print(len(new_save_dict))                          # 122

在深度学习中,“BN” 是 Batch Normalization(批标准化)的缩写,是一种用于加速神经网络训练和提高模型性能的技术。Batch Normalization 通过对每个批次的输入进行标准化,将输入归一化为均值为0、方差为1的分布,从而减少了内部协变量偏移(Internal Covariate Shift),有助于网络的训练。

在 PyTorch 中,bn1.weightbn1.biasbn1.running_meanbn1.running_varbn1.num_batches_tracked 是 Batch Normalization 层(通常用 nn.BatchNorm2d 表示)中的参数。

  1. bn1.weight:缩放参数(Scale Parameter)
    • 这是一个可学习的参数,用于对标准化后的特征进行缩放。它控制着 Batch Normalization 层的输出范围,从而增强了网络的表达能力。在训练过程中,这个参数会根据输入数据自动更新。
  2. bn1.bias:平移参数(Shift Parameter)
    • 这也是一个可学习的参数,用于对标准化后的特征进行平移。它控制着 Batch Normalization 层的输出的平均值,从而增强了网络的表达能力。在训练过程中,这个参数会根据输入数据自动更新。
  3. bn1.running_mean:运行时均值(Running Mean)
    • 这是一个用于存储训练过程中计算得到的特征的均值的值。它用于在推理过程中对输入数据进行标准化。在训练过程中,每个批次的均值都会被累积到这个参数中。
  4. bn1.running_var:运行时方差(Running Variance)
    • 这是一个用于存储训练过程中计算得到的特征的方差的值。它用于在推理过程中对输入数据进行标准化。在训练过程中,每个批次的方差都会被累积到这个参数中。
  5. bn1.num_batches_tracked:追踪批次数(Number of Batches Tracked)
    • 这个参数用于追踪 Batch Normalization 层已经处理过的批次数量。它在推理过程中不起作用,只在训练过程中更新

这是BN的计算公式:
B N ( x ) = γ   ⊙   x − μ ^ B θ ^ B + β BN(\rm{x}) = \gamma \ \odot \ \frac{\rm{x} - \hat{\mu}_\mathcal{B}}{\hat \theta _ \mathcal{B}} + \beta BN(x)=γ  θ^Bxμ^B+β

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值