pytorch保存模型pth,如何加载和使用PyTorch(.pth.tar)模型

I am not very familiar with Torch, and I primarily use Tensorflow. I, however, need to use a retrained inception model that was retrained in Torch. Due to the large amount of computing resources required to retrain an inception model for my particular application, I would like to use the model that was already retrained.

This model is saved as a .pth.tar file.

I would like to be able to first load this model. So far, I have been able to figure out that I must use the following:

model = torch.load('iNat_2018_InceptionV3.pth.tar', map_location='cpu')

This seems to work, because print(model) prints out a large set of numbers and other values, which I presume are the values for the weights an biases.

After this, I need to be able to classify an image with it. I haven't been able to figure this out. How must I format the image? Should the image be converted into an array? After this, how must I pass the input data to the network?

解决方案

you basically need to do the same as in tensorflow. That is, when you store a network, only the parameters (i.e. the trainable objects in your network) will be stored, but not the "glue", that is all the logic you need to use a trained model.

So if you have a .pth.tar file, you can load it, thereby overriding the parameter values of a model already defined.

That means that the general procedure of saving/loading a model is as follows:

write your network definition (i.e. your nn.Module object)

train or otherwise change the network's parameters in a way you want

save the parameters using torch.save

when you want to use that network, use the same definition of an nn.Module object to first instantiate a pytorch network

then override the values of the network's parameters using torch.load

Here's a discussion with some references on how to do this: pytorch forums

And here's a super short mwe:

# to store

torch.save({

'state_dict': model.state_dict(),

'optimizer' : optimizer.state_dict(),

}, 'filename.pth.tar')

# to load

checkpoint = torch.load('filename.pth.tar')

model.load_state_dict(checkpoint['state_dict'])

optimizer.load_state_dict(checkpoint['optimizer'])

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值