这里以除了最后的全连接层,冻结网络其他参数为例:
weigth_path = './net.pth'
weights_dict = torch.load(weight_path, map_location=device)
# 只保留和模型参数个数一个的预训练参数块
load_weights_dict = {k: v for k, v in weights_dict.items()
if model.state_dict()[k].numel() == v.numel()}
# 加载权重
model.load_state_dict(load_weights_dict, strict=False)
for name, para in model.named_parameters():
# 除最后的全连接层外,其他权重全部冻结,注意这里的fc是在定义模型时命名的一个块
if "fc" not in name:
para.requires_grad_(False)
# 这个变量保存所有训练的参数,是供之后优化器使用的
pg = [p for p in model.parameters() if p.requires_grad]