在用python加载数据集训练模型时,可以先下载好数据集放在指定文件夹这样python不会重复下载数据集
比如下载地址为root=’./DataSet’,就在你的py文件所在的文件夹下新建一个DataSet文件夹,把下载好的数据集放到里面
trainset = torchvision.datasets.CIFAR10(root='./DataSet', train=True,
download=True, transform=transform)
为什么训练模型的时候建议把定义网络单独提出来做成一个py文件
因为训练好了模型再调用,如以下代码,如果你的模型定义和训练在一起,那么你from model2 import LeNet的时候,会把model2.py文件里所有的代码都执行一遍,也就是会重新训练,把定义模型的部分单独提出来就不会再次训练了
from model2 import LeNet
import torch
# 实例化模型
net = LeNet()
PATH = 'cifar_net_10.pth'
# 将训练好的参数导入
net.load_state_dict(torch.load(PATH))
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗?
其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save()函数保存模型文件时,各人有不同的喜好,有些人喜欢用.pt后缀,有些人喜欢用.pth或.pkl.用相同的torch.save()语句保存出来的模型文件没有什么不同。
在pytorch官方的文档/代码里,有用.pt的,也有用.pth的。一般惯例是使用.pth,但是官方文档里貌似.pt更多,而且官方也不是很在意固定用一种。
PATH = './cifar_net_10.pth'
torch.save(net.state_dict(), PATH)