Pytorch---使用Pytorch实现多分类问题

一、代码中的数据集可以通过运行以下代码进行获取

train_ds = torchvision.datasets.MNIST(root=r'dataset', train=True, transform=ToTensor(), download=True)
test_ds = torchvision.datasets.MNIST(root=r'dataset', train=False, transform=ToTensor(), download=True)

二、代码运行环境

Pytorch-gpu==1.7.1
Python==3.7

三、数据集处理代码如下所示

import torchvision
from torchvision.transforms import ToTensor
import torch.utils.data
import matplotlib.pyplot as plt
import numpy as np


def make_dataset():
    train_ds = torchvision.datasets.MNIST(root=r'dataset', train=True, transform=ToTensor(), download=True)
    test_ds = torchvision.datasets.MNIST(root=r'dataset', train=False, transform=ToTensor(), download=True)
    train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=64, shuffle=True)
    test_dl = torch.utils.data.DataLoader(dataset=test_ds, batch_size=64)
    return train_dl, test_dl


if __name__ == '__main__':
    train, test = make_dataset()
    images, label = next(iter(train))
    plt.figure(figsize=(10, 3))
    for i, img in enumerate(images[:10]):
        np_img = img.numpy()
        np_img = np.squeeze(np_img)
        plt.subplot(1, 10, i + 1)
        plt.imshow(np_img)
        plt.axis('off')
        plt.title(str(label[i].numpy()))
    plt.show()

四、模型的构建代码如下所示

from torch import nn
import torch


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.liner_1 = nn.Linear(in_features=28 * 28, out_features=120)
        self.liner_2 = nn.Linear(in_features=120, out_features=84)
        self.liner_3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, input):
        x = input.view(-1, 28 * 28)
        x = torch.relu(self.liner_1(x))
        x = torch.relu(self.liner_2(x))
        logits = self.liner_3(x)
        return logits

五、模型的训练代码如下所示

import torch
from data_loader import make_dataset
from model_loader import Model
from torch import nn
import tqdm
import os

if __name__ == '__main__':
    # 进行数据的加载
    train_dl, test_dl = make_dataset()

    # 进行模型的加载
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Model().to(device)

    # 定义相关的训练参数
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=opt, milestones=[25, 50, 75], gamma=0.1)
    epochs = 100

    for epoch in range(epochs):
        # 开始进行训练
        train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
        train_tqdm.set_description_str('Train_Epoch {:3d}'.format(epoch))
        model.train()
        for image, label in train_tqdm:
            image, label = image.to(device), label.to(device)
            pred = model(image)
            loss = loss_fn(pred, label)
            opt.zero_grad()
            loss.backward()
            opt.step()
            with torch.no_grad():
                train_tqdm.set_postfix_str('Train_Loss is {:.14f}'.format(loss_fn(pred, label).item()))
        train_tqdm.close()
        # 开始进行测试
        with torch.no_grad():
            test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
            test_tqdm.set_description_str('Test_Epoch {:3d}'.format(epoch))
            model.eval()
            for image, label in test_tqdm:
                image, label = image.to(device), label.to(device)
                pred = model(image)
                loss = loss_fn(pred, label)
                test_tqdm.set_postfix_str('Test_Loss is {:.14f}'.format(loss.item()))
            test_tqdm.close()
        # 进行动态学习率的调整
        scheduler.step()

    # 进行模型的保存
    if not os.path.exists('model_data'):
        os.mkdir('model_data')
    torch.save(model.state_dict(), r'model_data\model.pth')

六、模型的预测代码如下所示

from model_loader import Model
from data_loader import make_dataset
import torch
import matplotlib.pyplot as plt
import matplotlib

# 进行数据的加载
train_dl, test_dl = make_dataset()

# 进行模型的加载
model = Model()
model_state_dict = torch.load(r'model_data\model.pth')
model.load_state_dict(model_state_dict)
model.eval()

# 进行模型的预测
index = 5
image, label = next(iter(test_dl))
with torch.no_grad():
    pred = model(image)
    pred = torch.argmax(input=pred, dim=-1)
    show_image = torch.squeeze(image)
    matplotlib.rc("font", family='Microsoft YaHei')
    plt.imshow(show_image[index])
    plt.title('预测结果为:' + str(pred[index].numpy()) + ',标签结果为:' + str(label[index].numpy()))
    plt.axis('off')
    plt.savefig('result.png')
    plt.show()

七、代码的运行结果如下所示

在这里插入图片描述

  • 0
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
2023-pytorch是一个开源深度学习库,它在计算机视觉领域有广泛的应用。本文将手把手教你如何使用CSDN来学习和获取有关2023-pytorch分类的相关资料和教程。 首先,在你的浏览器中打开CSDN的官方网站,网址为www.csdn.net。在主页上,你可以看到各种热门的技术文章、博客和论坛。在搜索框中输入"2023-pytorch分类",然后点击搜索按钮。CSDN将会为你展示与该关键词相关的所有内容。 接下来,你可以通过筛选工具来找到特定类型的文章或教程。例如,你可以选择只查看博客、文章或教程。你还可以选择按照发布时间或热度排序来获取最新或最受欢迎的内容。 当你找到一篇感兴趣的教程时,点击进入阅读。通常,教程会提供详细的步骤和示例代码,帮助你了解如何使用2023-pytorch进行分类任务。你可以按照教程中的指示一步一步地操作,并理解每个步骤的原理和作用。 除了阅读教程外,CSDN还提供了一个活跃的技术问答社区。你可以在这里向其他用户提问、讨论问题,或分享你的学习体验和心得。社区中的任何人都可以回答你的问题,所以不要犹豫,积极参与其中。 此外,CSDN还为用户提供了博客功能,你可以创建自己的博客来记录学习过程和分享实践经验。通过写博客,你还可以得到其他人的反馈和建议,不断提升自己的技术水平。 总结起来,要使用CSDN学习和获取有关2023-pytorch分类的相关资料和教程,你可以通过搜索功能找到相关内容,阅读教程并按照指导一步一步地进行实践,参与技术问答社区以及利用博客功能分享你的学习心得和经验。通过这些途径,你将能够快速掌握2023-pytorch分类的基本原理和应用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

水哥很水

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值