深度学习3:PyTorch保存和调用深度学习模型及模型参数

1.查看各层模型参数

param_num = 0
for param in model.parameters():#将model替换为你自己的模型
	param_num+=torch.numel(param)#累计模型参数总量
	print(param.shape)#输出模型各层的参数形状
print(param_num)#输出网络参数的总量

2.模型参数冻结

参数有一个很重要的属性:requires_grad,默认为True,意为:可学习的;如果是False,则不可被更新学习

#对所有参数进行冻结
for param in model.parameters():
	param.requires_grad = False

3.仅保存模型参数

保存:

torch.save(model.state_dict(),'model.pth')

调用

new_model= Model()#创建一个结构完全一样的模型
model_paramters = torch.load('model.pth')#加载保存的参数
new_model.load_state_dict(model_paramters)#更新new模型

4.保存整个模型

保存

#保存
torch.save(model,'model.pth')

调用

#调用
new_model = torch.load('model.pth')
  • 9
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值