1. 计算模型训练时间
import time
start = time.time()
# 输出训练所耗总时间
time_all = time.time() - start
print('Training complete in {:.2f}m | {:.2f}s'.format(time_all // 60, time_all % 60))
2. 模型的保存与加载
# 保存模型的网络结构
torch.save(model.state_dict(), MODEL_PATH)
# 加载模型
model = MyModel()
model.load_state_dict(torch.load(MODEL_PATH))
3. 预测正确个数统计
output: 模型预测输出
target: 真实标签
方法1:
_, pred = torch.max(output.data, dim=1)
correct = torch.sum(preds == labels).item()
# _, pred = torch.max(output, dim=1)
# correct = (pred == target).sum().item()
print(pred)
# tensor([8, 5, 5, 8, 8, 8, 8, 8])
print(target)
# tensor([8, 4, 0, 1, 8, 4, 1, 5])
print((pred == target).sum().item())
# 2
方法2: 推荐!!!
pred = output.argmax(dim=1, keepdim=True)
# pred = output.max(dim=1, keepdim=True)[1] # 值,索引
correct = pred.eq(target.view_as(pred)).sum().item()
# 返回tensor([True, False, True, ..., False]),其中view_as把target 转换为与pred一样的形状