修改后的模型加载预训练的权重
当对模型进行部分修改后,模型未修改部分如何加载修改前的预训练权重:
model = ResNetBackbone() # 修改后的模型,对应的层,模块名称未变
# 官方的resnet50预训练权重
url_50 = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
save_model = model_zoo.load_url(url_50)
model_dict = model.state_dict() # 模型的参数字典
state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
# 更新并加载
model_dict.update(state_dict)
model.load_state_dict(model_dict)
中间发现个小问题,官方给的状态字典的key比model.state_dict
返回的状态字典键数少
import torch.utils.model_zoo as model_zoo
import torchvision.models as models
url_18 = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
save_model = model_zoo.load_url(url_18)
print(type(save_model)) # <class 'collections.OrderedDict'>
print(len(save_model)) # 102
model18 = models.resnet18()
model18_state_dict = model18.state_dict()
print(type(model18_state_dict)) # <class 'collections.OrderedDict'>
print(len(model18_state_dict)) # 122
model18.load_state_dict(save_model) # model18.load_state_dict(save_model)
为什么两个字典中红的key数量不同
# 输出两者不同的key
unique_keys_in_B = set(model18_state_dict.keys()) - set(save_model.keys())
for key in unique_keys_in_B:
print(key)
# layer2.1.bn1.num_batches_tracked
# layer3.1.bn2.num_batches_tracked
# layer4.1.bn1.num_batches_tracked
# layer4.0.bn1.num_batches_tracked
# layer1.0.bn2.num_batches_tracked
# layer3.1.bn1.num_batches_tracked
# layer2.0.downsample.1.num_batches_tracked
# layer2.0.bn1.num_batches_tracked
# layer3.0.bn2.num_batches_tracked
# layer2.0.bn2.num_batches_tracked
# layer1.0.bn1.num_batches_tracked
# bn1.num_batches_tracked
# layer4.0.downsample.1.num_batches_tracked
# layer1.1.bn2.num_batches_tracked
# layer3.0.downsample.1.num_batches_tracked
# layer4.1.bn2.num_batches_tracked
# layer1.1.bn1.num_batches_tracked
# layer3.0.bn1.num_batches_tracked
# layer4.0.bn2.num_batches_tracked
# layer2.1.bn2.num_batches_tracked
查了一下:
bn1.num_batches_tracked
:追踪批次数(Number of Batches Tracked)
- 这个参数用于追踪 Batch Normalization 层已经处理过的批次数量。它在推理过程中不起作用,只在训练过程中更新。
本想这自己保存后键数也是102。但当把这些当前模型的状态字典保存,再加载,还是122。
torch.save(model18.state_dict(), 'resnet18.pth')
new_save_dict = torch.load('resnet18.pth')
print(type(new_save_dict)) # <class 'collections.OrderedDict'>
print(len(new_save_dict)) # 122
在深度学习中,“BN” 是 Batch Normalization(批标准化)的缩写,是一种用于加速神经网络训练和提高模型性能的技术。Batch Normalization 通过对每个批次的输入进行标准化,将输入归一化为均值为0、方差为1的分布,从而减少了内部协变量偏移(Internal Covariate Shift),有助于网络的训练。
在 PyTorch 中,bn1.weight
、bn1.bias
、bn1.running_mean
、bn1.running_var
和 bn1.num_batches_tracked
是 Batch Normalization 层(通常用 nn.BatchNorm2d
表示)中的参数。
bn1.weight
:缩放参数(Scale Parameter)- 这是一个可学习的参数,用于对标准化后的特征进行缩放。它控制着 Batch Normalization 层的输出范围,从而增强了网络的表达能力。在训练过程中,这个参数会根据输入数据自动更新。
bn1.bias
:平移参数(Shift Parameter)- 这也是一个可学习的参数,用于对标准化后的特征进行平移。它控制着 Batch Normalization 层的输出的平均值,从而增强了网络的表达能力。在训练过程中,这个参数会根据输入数据自动更新。
bn1.running_mean
:运行时均值(Running Mean)- 这是一个用于存储训练过程中计算得到的特征的均值的值。它用于在推理过程中对输入数据进行标准化。在训练过程中,每个批次的均值都会被累积到这个参数中。
bn1.running_var
:运行时方差(Running Variance)- 这是一个用于存储训练过程中计算得到的特征的方差的值。它用于在推理过程中对输入数据进行标准化。在训练过程中,每个批次的方差都会被累积到这个参数中。
bn1.num_batches_tracked
:追踪批次数(Number of Batches Tracked)- 这个参数用于追踪 Batch Normalization 层已经处理过的批次数量。它在推理过程中不起作用,只在训练过程中更新。
这是BN的计算公式:
B
N
(
x
)
=
γ
⊙
x
−
μ
^
B
θ
^
B
+
β
BN(\rm{x}) = \gamma \ \odot \ \frac{\rm{x} - \hat{\mu}_\mathcal{B}}{\hat \theta _ \mathcal{B}} + \beta
BN(x)=γ ⊙ θ^Bx−μ^B+β