我们以上一篇文章的Letnet模型为例子来进行保存,这里我们主要是储存和加载模型参数,这样能有效节约空间
def load_param(model, path):
if os.path.exists(path):
model.load_state_dict(torch.load(path))
def save_param(model, path):
torch.save(model.state_dict(), path)
测试集
test_loder = DataLoader(test_set, batch_size=4, shuffle=False, num_workers=2)
测试函数
def test(test_loder, model):
correct = 0
total = 0
for data in test_loder:
image, labels = data
if CUDA:
image = image.cuda()
labels = labels.cuda()
outputs = model(image)
_,predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print('Accuracy on the test set : %d%%' % (100 * correct / total))
这里是通过模型后选出10个里面最大值的编号存在predicted中,然后再与其标签相对比
if __name__ == '__main__':
load_param(let_net, 'model.pkl')
train(let_net, criterion, sge, epochs=2)
save_param(let_net, 'model.pkt')
test(test_loder, let_net)