1.保存和加载模型
1.保存参数
torch.save(model_object.state_dict(), ‘params.pkl’)
model_object.load_state_dict(torch.load(‘params.pkl’))
model_object:为模型的实例化
加载时,要定义该类和实例
2.保存模型
torch.save(model_object, ‘model.pkl’)
model = torch.load(‘model.pkl’)
model_object:为模型的实例化
加载时,model即为可用模型
2.torch.max(test_output, 1)
torch.max(test_output, 1)
输出格式[tensor([最大值]),tensor([最大值的位置])]
参数1/0,输出行/列最大
3.直接调用数据集送入模型
mages=Variable(test_dataset.test_data[:100].reshape(-1, 28*28).float())
.test_data[:100]:test_dataset中标签为test_data的数据
reshape:改变数据类型
.float():改变数据byte为float型
4.torch.max(test_output, 1)[1].data.numpy().squeeze()