PyTorch 学习笔记(五):多层全连接网络实现MNIST手写数字分类

先来介绍几个比较重要的函数

1. torch.nn.CrossEntropyLoss()函数

交叉熵损失函数,在pytorch中若模型使用CrossEntropyLoss这个loss函数,则不应该在最后一层再使用softmax进行激活,因为CrossEntropyLoss函数包括了softmax和计算交叉熵两个过程。

分析实例:https://www.jianshu.com/p/e184663e2f8a

2. torchvision.transforms函数

这个类提供了很多图片预处理的方法
transforms.ToTensor()是将图片转换成PyTorch中处理的对象Tensor,并且转化的过程中PyTorch自动将图片标准化了,也就是Tensor的范围是0-1;transforms.Normalize()需要传入两个参数:第一个参数是均值,第二个参数是方差,做的处理就是减均值,再除以方差。

transforms.Compose()将各种预处理操作组合在一起,如在下面的代码中先利用transforms.ToTensor将像素点的值由0-255转换到0-1,transforms.Normalize([0.5], [0.5])表示减去0.5再除以0.5,这样可以将图片转化到了-1到1之间,注意这是因为图片是灰度图,所以只有一个通道,如果是彩色图,有三个通道,那么用transforms.Normalize([a, b, c], [d, e, f])来表示每个通道对应的均值和方差。

3. torch.utils.data.DataLoader函数

torch.utils.data.DataLoader建立一个数据迭代器,先看看 dataloader.py脚本是怎么写的:

init(构造函数)中的几个重要的属性:

  • dataset:(数据类型 dataset)
    输入的数据类型。看名字感觉就像是数据库,C#里面也有dataset类,理论上应该还有下一级的datatable。这应当是原始数据的输入。PyTorch内也有这种数据结构。这里先不管,估计和C#的类似,这里只需要知道是输入数据类型是dataset就可以了。
  • batch_size:(数据类型 int)
    每次输入数据的行数,默认为1。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。
  • shuffle:(数据类型 bool)
    洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。
  • collate_fn:(数据类型 callable,没见过的类型)
    将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。(不太明白作用是什么,就暂时默认False)
  • batch_sampler:(数据类型 Sampler)
    批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。我想,应该是每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。所以,它和一捆一捆按顺序输入,数据洗牌,数据采样,等模式是不兼容的。
  • sampler:(数据类型 Sampler)
    采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。
  • num_workers:(数据类型 Int)
    工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。注意:这个数字必须是大于等于0的,负数估计会出错。
  • pin_memory:(数据类型 bool)
    内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。
  • drop_last:(数据类型 bool)
    丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。
  • timeout:(数据类型 numeric)
    超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。
  • worker_init_fn(数据类型 callable,没见过的类型)
    子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据。

脚本分析:https://blog.csdn.net/u014380165/article/details/79058479

多层全连接网络实现MNIST手写数字分类的代码如下:

import torch
from torch import optim, nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np

# 加载数据集
def get_data():
    # 定义数据预处理操作, transforms.Compose将各种预处理操作组合在一起
    data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    train_dataset = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=data_tf, download=True)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    #test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    return train_loader, test_dataset

# 构建模型,三层神经网络
class batch_net(nn.Module):
    def __init__(self, in_dim, hidden1_dim, hidden2_dim, out_dim):
        super(batch_net, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, hidden1_dim), nn.BatchNorm1d(hidden1_dim), nn.ReLU(True))
        self.layer2 = nn.Sequential(nn.Linear(hidden1_dim, hidden2_dim), nn.BatchNorm1d(hidden2_dim), nn.ReLU(True))
        self.layer3 = nn.Sequential(nn.Linear(hidden2_dim, out_dim))
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x



if __name__ == "__main__":
    # 超参数配置
    batch_size = 64
    learning_rate = 1e-2
    num_epoches = 5
    # 加载数据集
    train_dataset, test_dataset = get_data()
    # 导入网络,并定义损失函数和优化器
    model = batch_net(28*28, 300, 100, 10)
    if torch.cuda.is_available():
        model = model.cuda()
    criterion = nn.CrossEntropyLoss()
    opitimizer = optim.SGD(model.parameters(), lr=learning_rate)
    # 开始训练
    for i in range(num_epoches):
        for img, label in train_dataset:
            img = img.view(64, -1)
            img = Variable(img)
            #print(img.size())
            label = Variable(label)
            # forward
            out = model(img)
            loss = criterion(out, label)
            # backward
            opitimizer.zero_grad()
            loss.backward()
            opitimizer.step()
            # 打印
            print("epoches= {},loss is {}".format(i, loss))
    # 测试
    model.eval()
    count = 0
    for data in test_dataset:
        img, label = data
        img = img.view(img.size(0), -1)
        img = Variable(img, volatile=True)
        #label = Variable(label, volatile=True)
        out = model(img)
        _, predict = torch.max(out, 1)
        if predict == label:
            count += 1
    print("acc = {}".format(count/len(test_dataset)))

运行结果:
在这里插入图片描述

  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
要使用PyTorchMNIST手写数字图像进行分类,你可以按照以下步骤进行操作: 1. 首先,使用PyTorch定义ResNet50网络模型。ResNet50是一种深度卷积神经网络,适用于图像分类任务。 2. 接下来,使用PyTorch加载MNIST数据集。PyTorch提供了方便的数据加载工具,你可以使用torchvision.datasets.MNIST函数加载MNIST数据集。如果是第一次运行代码,PyTorch会自动下载数据集。 3. 在加载数据集之后,你可以对数据进行一些预处理操作。例如,可以使用torchvision.transforms.Compose函数将多个转换操作组合在一起,比如将图像转换为Tensor,并进行标准化。 4. 接着,你可以创建训练数据加载器和测试数据加载器。可以使用torch.utils.data.DataLoader函数来创建数据加载器。训练数据加载器用于训练模型,测试数据加载器用于评估模型的性能。你可以指定批量大小、是否打乱数据等参数。 5. 然后,你可以使用定义好的网络模型、数据加载器和损失函数,进行训练过程。训练过程中,可以使用优化器(如SGD或Adam)来更新模型的参数,并计算损失值。训练过程中可以显示损失值的变化情况。 总结起来,对于MNIST手写数字图像分类PyTorch代码,你需要定义ResNet50网络模型,加载MNIST数据集,进行数据预处理,创建训练和测试数据加载器,并进行训练过程。 请注意,上述步骤只是一个大致的框架,具体的代码实现可能会有所不同。你可以根据自己的需求和实际情况对代码进行调整和修改。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [Resnet50卷积神经网络训练MNIST手写数字图像分类 Pytorch训练代码](https://download.csdn.net/download/baidu_36499789/87418795)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *2* *3* [用PyTorch实现MNIST手写数字识别(非常详细)](https://blog.csdn.net/sxf1061700625/article/details/105870851)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值