(三)训练模型_测试模型_分类别打印模型准确率

(1)设置损失Loss和优化器optim

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

CrossEntropyLoss: This criterion computes the cross entropy loss between input and target.

SGD: Implements stochastic gradient descent (optionally with momentum).

 (2)训练模型

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # 权重参数梯度清零
        optimizer.zero_grad()
        # 正向和反向传播
        outputs = net.forward(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # Performs a single optimization step (parameter update).
        optimizer.step()

        # 显示损失
        running_loss += loss.item() # item() 把loss取出来
        # 每迭代2000个小的批次,打印一次loss
        if i % 2000 == 1999:
            print(['[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss/2000)])
            running_loss = 0

(3)测试模型

# 测试模型
correct = 0
total = 0

with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print("Accuracy of the total network on the 10000 test images: %d %%" % (correct*100/total))

 (4)分类别打印模型准确率

# 分类别打印模型准确率
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print("Accuracy of %5s : %2d %%" % (classes[i], (100*class_correct[i]/class_total[i])))

Loss Functionsicon-default.png?t=M7J4https://pytorch.org/docs/stable/nn.html#loss-functions

torch.max()icon-default.png?t=M7J4https://pytorch.org/docs/stable/torch.html#reduction-ops

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值