pytorch加载模型、冻结、优化
加载模型
读取torch.save保存的模型
weights=torch.load(path)
读取pickle.dump保存的模型
with open('a.pkl', 'wb') as f:
pickle.dump(score_dict, f)
weights=pickle.load(f)
model加载模型
直接加载
model.load_state_dict(weights)
字典生成式加载
self.load_state_dict({k:v for k,v in weights.items()})
#self.load_state_dict({k.replace("module.",""):v for k,v in weights.items()})
collections.OrderedDict()加载
b = collections.OrderedDict()
for k,v in weights.items():
b[k]=v
#b[k.replace("module.","")]=v
self.load_state_dict(b)
更新参数
new_weights=model.state_dict()
new_weights.update(weights) # 将weights中参数更新至new_weights中
self.load_state_dict(new_weights)
冻结模型
for key, value in model.named_parameters():# named_parameters()包含网络模块名称 key为模型模块名称 value为模型模块值,可以通过判断模块名称进行对应模块冻结
value.requires_grad = True
for value in model.parameters()():#不包含网络模块名称 value为模型模块值
value.requires_grad = True
局部优化模型
使用filter过滤需要的模块参数
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, model.parameters()), #只更新 requires_grad=True的参数,即进行反向传播的参数
lr=,
momentum=,
weight_decay=,
nesterov=
)