pytorch加载模型与冻结

pytorch加载模型、冻结、优化

加载模型

读取torch.save保存的模型

weights=torch.load(path)

读取pickle.dump保存的模型

with open('a.pkl', 'wb') as f:
	pickle.dump(score_dict, f)
weights=pickle.load(f)        

model加载模型

直接加载

model.load_state_dict(weights)

字典生成式加载

self.load_state_dict({k:v for k,v in weights.items()})
#self.load_state_dict({k.replace("module.",""):v for k,v in weights.items()})

collections.OrderedDict()加载

b = collections.OrderedDict()
for k,v in weights.items():
	b[k]=v
	#b[k.replace("module.","")]=v
self.load_state_dict(b)

更新参数

new_weights=model.state_dict()
new_weights.update(weights) # 将weights中参数更新至new_weights中
self.load_state_dict(new_weights)

冻结模型

for key, value in model.named_parameters():# named_parameters()包含网络模块名称 key为模型模块名称 value为模型模块值,可以通过判断模块名称进行对应模块冻结
	value.requires_grad = True
for value in model.parameters()():#不包含网络模块名称 value为模型模块值
	value.requires_grad = True

局部优化模型

使用filter过滤需要的模块参数

optimizer = optim.SGD(
 	filter(lambda p: p.requires_grad, model.parameters()),   #只更新			requires_grad=True的参数,即进行反向传播的参数
	lr=,
	momentum=,
	weight_decay=,
	nesterov=
)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值