pytorch的官方教程里提供了相关说明:
只保存模型用于以后的推断的话使用.pth
或.pt
,这样可以直接加载模型
A common PyTorch convention is to save models using either a .pt or .pth file extension.
torch.save(model, "model.pth") # or .pt
model = torch.load("model.pth")
断点保存的话则使用.tar
,加载的时候模型需要使用load_state_dict()
方法
To save multiple components, organize them in a dictionary and use torch.save() to serialize the dictionary. A common PyTorch convention is to save these checkpoints using the .tar file extension.
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, "checkpoint.tar")
...
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
其中部份人群喜欢使用.pth.tar
来表明这不是一个简单的压缩tar类型的文件
其实这个问题一直有人讨论,因为pth
同时也是Python的一种格式,所以有人甚至提出要更改一种后缀来区分…不过暂时不太需要考虑这个问题…
但实际上阅读save的源码就会发现,torch只是调用了Python的pickle来完成,而且没有做任何的后缀名判断,因此无论保存成什么后缀都是可以的…