一些常用的函数技巧

1. 计算模型训练时间

import time

start = time.time()
# 输出训练所耗总时间
time_all = time.time() - start
print('Training complete in {:.2f}m | {:.2f}s'.format(time_all // 60, time_all % 60))

2. 模型的保存与加载

# 保存模型的网络结构
torch.save(model.state_dict(), MODEL_PATH)

# 加载模型
model = MyModel()
model.load_state_dict(torch.load(MODEL_PATH))

3. 预测正确个数统计

output: 模型预测输出

target:  真实标签

方法1:

_, pred = torch.max(output.data, dim=1)
correct = torch.sum(preds == labels).item()
# _, pred = torch.max(output, dim=1)
# correct = (pred == target).sum().item()

print(pred)
# tensor([8, 5, 5, 8, 8, 8, 8, 8])
print(target)
# tensor([8, 4, 0, 1, 8, 4, 1, 5])
print((pred == target).sum().item())
# 2

方法2: 推荐!!!

pred = output.argmax(dim=1, keepdim=True)
# pred = output.max(dim=1, keepdim=True)[1]  # 值,索引
correct = pred.eq(target.view_as(pred)).sum().item()
# 返回tensor([True, False, True, ..., False]),其中view_as把target 转换为与pred一样的形状

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值