Pytorch模型的搭建——以图像分类任务为例

Pytorch模型的搭建——以图像分类任务为例

1. 搭建训练模型

  1. 反向传播经典三步骤:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
  1. 前后加上loss
        loss = criterion(y_pred, label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += float(loss.item())
  1. print每一个epoch的结果
    print('epoch {}/{}\tTrain loss: {:.4f}\tTrain accuracy: {:.2f}%'.
          format(epoch + 1, num_epochs, running_loss / (index + 1), correct_pred.item() / (batch_size * (index + 1)) * 100))
    print('Time: {:.2f}s'.format(end - start))
  1. 绘制训练过程中准确率-epoch和损失-epoch图
    Loss_list.append(running_loss / (len(train_data)))
    Accuracy_list.append(100 * correct_pred / (len(train_data)))
    
x1 = range(0, num_epochs)
x2 = range(0, num_epochs)
y1 = Accuracy_list
y2 = Loss_list
plt.subplot(2, 1, 1)
plt.plot(x1, y1, 'o-')
plt.title('Train accuracy vs. epoches')
plt.ylabel('Train accuracy')
plt.subplot(2, 1, 2)
plt.plot(x2, y2, '.-')
plt.xlabel('Train loss vs. epoches')
plt.ylabel('Train loss')
# plt.show()
plt.savefig("accuracy_loss.jpg")
plt.close()
  1. 最后调用train
model.train()
  1. 训练总代码
"""
train
"""
# 先定义两个数组
Loss_list = []
Accuracy_list = []
for epoch in range(num_epochs):
    start = time.perf_counter()
    # model.train()
    running_loss = 0.0
    correct_pred = 0

    # for index, data in enumerate(cifar100_training_loader):
    for index, data in enumerate(train_loader):
        image, label = data
        image = image.to(DEVICE)
        label = label.to(DEVICE)
        y_pred = resnet50(image)

        _, pred = torch.max(y_pred, 1)
        correct_pred += (pred == label).sum()

        loss = criterion(y_pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += float(loss.item())

    end = time.perf_counter()
    # batch_size * (index + 1) = 800
    print('epoch {}/{}\tTrain loss: {:.4f}\tTrain accuracy: {:.2f}%'.
          format(epoch + 1, num_epochs, running_loss / (index + 1), correct_pred.item() / (batch_size * (index + 1)) * 100))
    print('Time: {:.2f}s'.format(end - start))

    Loss_list.append(running_loss / (len(train_data)))
    Accuracy_list.append(100 * correct_pred / (len(train_data)))

# 这里迭代num_epochs次,所以x的取值范围为(0,num_epochs)
x1 = range(0, num_epochs)
x2 = range(0, num_epochs)
y1 = Accuracy_list
y2 = Loss_list
plt.subplot(2, 1, 1)
plt.plot(x1, y1, 'o-')
plt.title('Test accuracy vs. epoches')
plt.ylabel('Test accuracy')
plt.subplot(2, 1, 2)
plt.plot(x2, y2, '.-')
plt.xlabel('Test loss vs. epoches')
plt.ylabel('Test loss')
# plt.show()
plt.savefig("accuracy_loss.jpg")
plt.close()

print('Finished training!')

model.train()

2. 搭建测试模型

"""
test
"""
test_loss = 0.0
correct_pred = 0
model.eval()
# for _, data in enumerate(cifar100_test_loader):
for _, data in enumerate(test_loader):
    image, label = data
    image = image.to(DEVICE)
    lable = label.to(DEVICE)
    y_pred = model(image)
    # print("y_pred:", y_pred)

    _, pred = torch.max(y_pred, 1)
    # pred_5 = torch.max(y_pred, 5)
    correct_pred += (pred == label.cuda()).sum()

    loss = criterion(y_pred.cuda(), label.cuda())
    test_loss += float(loss.item())

    # loss_5 = criterion(y_pred_5.cuda(), label.cuda())
    # test_loss5 += float(loss_5.item())

    C2 = confusion_matrix(lable.cpu(), pred.cpu(), labels=[0, 1, 2])
    # print(C2)
    sns.heatmap(C2, annot=True)
    plt.savefig("heatmap.jpg")
    # plt.show()
    plt.close()
    print('Test loss: {:.4f}\tTest accuracy_1: {:.2f}%'.format(test_loss / (len(test_data)), correct_pred.item() / (len(test_data))* 100))

3. 保存/加载模型参数

  1. 保存
torch.save(model.state_dict(), './train_epoch200_.pth')

  1. 加载模型
model = torch.load('model.pkl')
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值