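1. Load the pretrained weights
import torch
# model_weights_path: path to the pretrained weight file; device: e.g. torch.device("cpu") or torch.device("cuda")
pre_weights = torch.load(model_weights_path, map_location=device)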
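A minimal save/load round trip, assuming the checkpoint file holds a plain state_dict. Some checkpoints wrap it in a larger dict, which would need unwrapping first. The toy model and the temporary file path are hypothetical stand-ins.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in model; any nn.Module behaves the same way.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Save only the state_dict: an ordered mapping from parameter/buffer names to tensors.
path = os.path.join(tempfile.mkdtemp(), "weights.pth")
torch.save(model.state_dict(), path)

# map_location moves every stored tensor to the given device,
# so weights saved on a GPU can be loaded on a CPU-only machine.
device = torch.device("cpu")
pre_weights = torch.load(path, map_location=device)

print(list(pre_weights.keys()))
```

The keys are the names that step 2 matches against the new model's own state_dict.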
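2. Load only those pretrained weights whose layer names and shapes match the current model. This handles models whose classifier (number of classes) or the channel count of some layers has been changed.
net = yourmodel()  # replace yourmodel with your own model class
net_dict = net.state_dict()
# Keep a pretrained tensor only if the model has a parameter with the same name and the same shape
# (comparing numel() alone would accept e.g. transposed shapes, and indexing raises KeyError for names the model lacks)
pre_dict = {k: v for k, v in pre_weights.items()
            if k in net_dict and net_dict[k].shape == v.shape}
# strict=False: parameters absent from pre_dict keep their initial values and are reported in missing_keys
# instead of raising an error; a shape mismatch would still raise, hence the filtering above
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)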
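The filtering step can be checked on a toy case: a backbone pretrained for 1000 classes reused for a 5-class task. `TinyNet` is a hypothetical model with a `features` backbone and a `classifier` head, loosely following the layout of torchvision's VGG; it is not the author's model.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model with a `features` backbone and a `classifier` head."""

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.features(x).mean(dim=(2, 3))  # global average pooling -> (N, 8)
        return self.classifier(x)


# Stand-in for a model pretrained on 1000 classes.
pretrained = TinyNet(num_classes=1000)
pre_weights = pretrained.state_dict()

# New task with 5 classes: the backbone shapes match, the classifier shapes do not.
net = TinyNet(num_classes=5)
net_dict = net.state_dict()
pre_dict = {k: v for k, v in pre_weights.items()
            if k in net_dict and net_dict[k].shape == v.shape}

missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)
print("missing:", missing_keys)        # the freshly initialised classifier
print("unexpected:", unexpected_keys)  # nothing left over

# Loading the unfiltered dict raises even with strict=False (size mismatch).
try:
    net.load_state_dict(pre_weights, strict=False)
    raised = False
except RuntimeError:
    raised = True
print("unfiltered load raised:", raised)
```

Only `classifier.weight` and `classifier.bias` should end up in `missing_keys`; every backbone tensor, including the BatchNorm running statistics, is copied over.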
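3. Freeze the pretrained weights of the feature-extraction layers
# assumes the model keeps its backbone in a `features` submodule (as torchvision's VGG, AlexNet and MobileNet do)
for params in net.features.parameters():
    params.requires_grad = False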
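A small check that freezing works as intended, assuming a hypothetical `TinyNet` with a `features` submodule. Passing only the parameters that still require gradients to the optimizer is a common companion to this step; it is shown here as one option, not as part of the original recipe.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyNet(nn.Module):
    """Hypothetical toy model with a `features` backbone and a `classifier` head."""

    def __init__(self, num_classes=5):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).mean(dim=(2, 3)))


net = TinyNet()
for params in net.features.parameters():
    params.requires_grad = False

# Hand only the still-trainable parameters (the classifier) to the optimizer.
trainable = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)

backbone_before = net.features[0].weight.detach().clone()
head_bias_before = net.classifier.bias.detach().clone()

# One training step on random data.
x = torch.randn(4, 3, 10, 10)
y = torch.randint(0, 5, (4,))
loss = nn.functional.cross_entropy(net(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("frozen grad:", net.features[0].weight.grad)  # no gradient is computed for frozen weights
print("backbone unchanged:", torch.equal(backbone_before, net.features[0].weight))
print("head updated:", not torch.equal(head_bias_before, net.classifier.bias))
```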
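4. Setting requires_grad = False is not enough to freeze BatchNorm (BN) layers. Their affine weight and bias are learned through backpropagation and are frozen like any other parameter, but their running mean and running variance are buffers: in training mode the forward pass updates them from each batch's per-channel statistics, with no gradient involved. To keep the pretrained statistics, the BN layers must additionally be put in eval mode. Because net.train() switches every submodule back to training mode, this has to be re-applied after each net.train() call.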
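The claim about BN buffers can be verified directly. This sketch uses a hypothetical three-layer `nn.Sequential` and the same class-name check as this step's `freeze_bn`; it compares `running_mean` before and after a forward pass in three situations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())

# Freeze every parameter, as in step 3.
for p in net.parameters():
    p.requires_grad = False


def freeze_bn(ly):
    # Same check as this step's helper: put BatchNorm layers into eval mode.
    if ly.__class__.__name__.find('BatchNorm') != -1:
        ly.eval()


bn = net[1]
x = torch.randn(8, 3, 6, 6)

# 1) requires_grad=False alone: the forward pass still updates the running statistics.
net.train()
before = bn.running_mean.clone()
net(x)
changed_in_train = not torch.equal(before, bn.running_mean)

# 2) After freeze_bn, BN normalises with its stored statistics and leaves them untouched.
net.apply(freeze_bn)
before = bn.running_mean.clone()
net(x)
changed_in_eval = not torch.equal(before, bn.running_mean)

# 3) net.train() puts every submodule, BN included, back into training mode,
#    so freeze_bn has to be applied again afterwards.
net.train()
bn_training_again = bn.training

print(changed_in_train, changed_in_eval, bn_training_again)
```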
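def freeze_bn(ly):
    # put every BatchNorm layer (BatchNorm1d/2d/3d, ...) into eval mode so its running statistics stop updating
    classname = ly.__class__.__name__
    if classname.find('BatchNorm') != -1:
        ly.eval()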
net.apply(fre