用Pytorch自建6层神经网络训练Fashion-MNIST数据集,测试准确率达到 92%

一般的深度学习入门例子是 MNIST 的训练和测试,几乎就算是深度学习领域的 HELLO WORLD 了,但是,有一个问题是,MNIST 太简单了,初学者闭着眼镜随便构造几层网络就可以将准确率提升到 90% 以上。但是,初学者这算入门了吗?

答案是没有。

现实开发当中的例子可没有这么简单,如果让初学者直接去上手 VOC 或者是 COCO 这样的数据集,很可能自己搭建的神经网络准确率不超过 30%。

是的,如果不用开源的 VGG、GooLeNet、ResNet 等等,也许你手下敲出的代码命中率还没有瞎猜高。

这篇文章介绍如何用 Pytorch 训练一个自建的神经网络去训练 Fashion-MNIST 数据集。

Fashion-MNIST

Fashion-MINST 的目的是为了替代 MNIST。

这是它的地址:https://github.com/zalandoresearch/fashion-mnist

它是一系列的服装图片集合,总共有 10 个类别。

60000 张训练图片,10000 张测试图片。

Label Description
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boo

在这里插入图片描述

Fashion-MNIST 体积并不大,更方便的是像 Tensorflow 和 Pytorch 目前的版本都可以直接用代码下载。

训练神经网络的步骤

下面我张图是我自己制作的,每次要写相关博客时,我都会翻出来温习一下。
在这里插入图片描述
它会提醒我要做这些事情:

  1. 处理数据
  2. 搭建模型
  3. 确定 Loss 函数
  4. 训练

下面文章就按这样的步骤来讲解

处理数据

下载数据

Pytorch 现在通过现成的 API 就可以下载 Fashion-MINST 的数据

trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                       download=True, transform=transform1)
testloader = torch.utils.data.DataLoader(testset, batch_size=50,
                                         shuffle=False, num_workers=2)

上面创建了两个 DataLoader ,分别用来加载训练集的图片和测试集的图片。

代码运行后,会在当前目录的 data 目录下存放对应的文件。

在这里插入图片描述

数据增强

为提高模型的泛化能力,一般会将数据进行增强操作。

transform = transforms.Compose(
    [
     transforms.RandomHorizontalFlip(),
     transforms.RandomGrayscale(),
     transforms.ToTensor()])

transform1 = transforms.Compose(
    [
     transforms.ToTensor()])

transform 主要进行了左右翻转,灰度随机变换,用来给训练的图像进行数据加强,测试的图片就不需要了。

transform 的引用传递到前面的数据集对应的 API 就可以了,非常方便。

搭建模型

虽然是自己搭建的神经网络,但是却参考了 VGG 的网络架构。

在这里插入图片描述
总共 6 层网络,4 层卷积层,2 层全连接。

用 Pytorch 实现起来也非常方便。

class Net(nn.Module):


    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(1,64,1,padding=1)
        self.conv2 = nn.Conv2d(64,64,3,padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()

        self.conv3 = nn.Conv2d(64,128,3,padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3,padding=1)
        self.pool2 = nn.MaxPool2d(2, 2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()

        self.fc5 = nn.Linear(128*8*8,512)
        self.drop1 = nn.Dropout2d()
        self.fc6 = nn.Linear(512,10)


    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pool1(x)
        x = self.bn1(x)
        x = self.relu1(x)


        x = self.conv3(x)
        x = self.conv4(x)
        x 
  • 20
    点赞
  • 197
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 17
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

frank909

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值