1.查看各层模型参数
param_num = 0
for param in model.parameters():#将model替换为你自己的模型
param_num+=torch.numel(param)#累计模型参数总量
print(param.shape)#输出模型各层的参数形状
print(param_num)#输出网络参数的总量
2.模型参数冻结
参数有一个很重要的属性:requires_grad,默认为True,意为:可学习的;如果是False,则不可被更新学习
#对所有参数进行冻结
for param in model.parameters():
param.requires_grad = False
3.仅保存模型参数
保存:
torch.save(model.state_dict(),'model.pth')
调用
new_model= Model()#创建一个结构完全一样的模型
model_paramters = torch.load('model.pth')#加载保存的参数
new_model.load_state_dict(model_paramters)#更新new模型
4.保存整个模型
保存
#保存
torch.save(model,'model.pth')
调用
#调用
new_model = torch.load('model.pth')