入门学习MNIST手写数字识别

一、MNIST数据集

1.MNIST数据集简介

MNIST数据集是一个公开的数据集,相当于深度学习的hello world,用来检验一个模型/库/框架是否有效的一个评价指标。

MNIST数据集是由0〜9手写数字图片和数字标签所组成的,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。MNIST 数据集来自美国国家标准与技术研究所,整个训练集由250个不同人的手写数字组成,其中50%来自美国高中学生,50%来自人口普查的工作人员。

2.MNIST数据集包含四部分

MNIST数据集是集成的API无需手动下载,可以通过torch里面的API直接获取。

官方文档:https://pytorch.org/vision/stable/datasets.html

参数:

  • root:指的是下载的目录
  •  train:如果设置成True的话表示取训练集,如果要取测试集就设置成False
  • download:如果设置成True,会先判断是否下载过,如果未下载过,就会下载文件;如果已经下载过了,设置成True和False都一样,不会重新下载。
  • transform:是对图片进行预处理的一些操作,可以将一个PIL 图片翻译成张量或其他内容           
from torchvision.datasets import MNIST#获取MNIST的数据集
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=my_transforms)
#root表示下载路径,训练模式取训练集,是否下载:是

注:当我们后期要训练自己的数据集的时候需要将MNIST 数据集换成我们自己的数据集

print(len(mnist_train))#len表示求它的长度,训练集有10000张
print(mnist_train[0])#getitem表示通过索引把图像取出来


运行结果:
60000
(<PIL.Image.Image image mode=L size=28x28 at 0x2175A412320>, 5)

解释说明:
运行结果中60000是训练集的长度,即训练集有60000张图片
第二行返回结果有两部分,一部分是PIL图像,另一部分是图片的标签
标签为5,就表示这个数字是5.

怎样能看一下这张图片是什么呢?

  • 首先需要安装并导入matlab库

import matplotlib.pyplot as plt#安装matlab库并导入matplotlib.pyplot
from torchvision.datasets import MNIST#获取MNIST的数据集
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=None)
#root表示下载路径,训练模式取训练集,是否下载:是
print(len(mnist_train))#len表示求它的长度,训练集有10000张
print(mnist_train[0][0])#getitem表示通过索引把图像取出来
image = mnist_train[0][0]#取出具体的一张图片
plt.imshow(image)
plt.show()#把图片展示出来
print(mnist_train[0][1])#把图片的标签打印出来


运行结果:
60000
<PIL.Image.Image image mode=L size=28x28 at 0x1DAF3D3BB00>
5

图片展示  

遇到OMP报错的话,在代码中添加下面两行代码即可

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"#前两行代码解决一个OMP报错

 可以参考下面链接:

OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.OMP: Hint_fencecat的博客-CSDN博客OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.OMP: Hinthttps://blog.csdn.net/fencecat/article/details/122887204?spm=1001.2014.3001.5502

有时即使模型再好,识别率也达不到100%;因为有些数字写的实在太飘逸了,标签也是随心所欲😂

二、数据加载

MNIST数据集继承了torch.utils.data.Dataset

需要自己实现__len__和__getitem__两个方法:

  • __len__实现获取数据集长度的操作
  • __getitem__实现获取第几个对象的操作,通过索引的方式把图片取出来。

torch已封装好的加载器

前边已经得到MNIST数据集的实例化对象,接下来就可以进行数据的加载,加载器功能较多,如果自己实现的话会比较复杂,我们可以借助torch已经封装好的加载器来处理

官方文档https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader#导入数据加载器
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=None)
dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True)
#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
print(dataloader)#打印一下

运行结果:
<torch.utils.data.dataloader.DataLoader object at 0x00000297C76711D0>

迭代DataLoader类:

# 加载器
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader#导入数据加载器
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=None)
dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True)
#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
# print(dataloader)#打印一下
for i in dataloader:
    print(i)


报错:
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>
#batch必须包含张量,numpy数组,数字,字典或列表,不支持PIL图像

怎样才能迭代将PIL图像打印呢?需要引入图像处理

三、transforms图像处理

1.导入transforms方法,并将MNIST数据集的transfrom改为transforms.ToTensor()

#图片处理
#导入transforms方法,并将MNIST数据集中transform改为transforms.ToTensor()
from torchvision import transforms#导入transforms方法
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader#导入数据加载器
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True)
#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
# print(dataloader)#打印一下
for i in dataloader:
    print(i)

运行结果:将PIL图像转换成了张量形式
[tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],
        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]]), tensor([8, 0])]

2.集合transforms.Compose(transforms)可以将transforms组合起来使用

#图片处理
#导入transforms方法,并将MNIST数据集中transform改为transforms.ToTensor()
from torchvision import transforms#导入transforms方法
from torchvision.datasets import MNIST

my_transforms = transforms.Compose(
    [transforms.PILToTensor()])
from torch.utils.data import DataLoader#导入数据加载器
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True)#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
for i in dataloader:
    print(i)
    exit()#打印一次后退出

运行结果:
[tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.1608, 0.5961, 0.9137, 0.5961, 0.5961,
           0.2000, 0.0392, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.7961, 0.9922, 0.9882, 0.9922, 0.9882,
           0.9922, 0.6745, 0.1608, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.4000, 0.9961, 0.9922, 0.4000, 0.2392,
           0.6392, 0.9529, 0.9176, 0.2000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.4000, 0.9922, 0.9882, 0.0000, 0.0000,
           0.0000, 0.3176, 0.9922, 0.9098, 0.1608, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.4000, 0.9961, 0.9922, 0.0000, 0.0000,
           0.0000, 0.0000, 0.5176, 0.9922, 0.6392, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0784, 0.8353, 0.9882, 0.1608, 0.0000,
           0.0000, 0.1608, 0.5176, 0.9882, 0.8745, 0.0784, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.3608, 0.9922, 0.8392, 0.2000,
           0.4431, 0.9137, 0.7961, 0.7961, 0.3216, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.2000, 0.9882, 0.9922, 0.9882,
           0.9922, 0.8314, 0.0784, 0.0784, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7961, 0.9961, 0.9922,
           0.5569, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6353, 0.9922, 0.9882,
           0.0784, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.5176, 0.9922, 0.9961, 0.9922,
           0.2431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.1608, 0.9922, 0.9882, 0.9922, 0.9882,
           0.4000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.6392, 0.9961, 0.6745, 0.5961, 0.9922,
           0.4000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0824, 0.8745, 0.8353, 0.0392, 0.2784, 0.9882,
           0.7176, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.2039, 0.9922, 0.7961, 0.0000, 0.1608, 0.9529,
           0.9961, 0.1961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.5176, 0.9882, 0.7961, 0.0000, 0.0000, 0.7961,
           0.9922, 0.1961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.6000, 0.9922, 0.8000, 0.0000, 0.1216, 0.9137,
           1.0000, 0.1961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.5961, 0.9882, 0.7961, 0.0000, 0.6784, 0.9882,
           0.6745, 0.0392, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.3608, 0.9922, 1.0000, 0.9922, 1.0000, 0.9922,
           0.1608, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0392, 0.5137, 0.8353, 0.9882, 0.9137, 0.2745,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000]]]]), tensor([8])]
Process finished with exit code 0
  • 打印时将images与labels分开
# for i in dataloader:
#     print(i)
#     exit()#打印一次后退出


# for (images, labels) in dataloader:
#     print(images, labels)


#for i in dataloader:
#    print(i[0], i[1])

 3.transfroms方法

官方文档:torchvision.transforms — Torchvision 0.11.0 documentationhttps://pytorch.org/vision/stable/transforms.html

(1) transfroms简介

  • transfroms是一种常用的图像转换方法,他们可以通过Compose方法组合到一起,这样可以实现许多个transfroms对图像进行处理。transfroms方法提供图像的精细化处理,例如在分割任务的情况下  ,你必须建立一个更复杂的转换管道,这时transfroms方法是很有用的。
  • 很多转换器既接受PIL图像,也接受tensor图像。一张tensor图像是形状为(C, H, W)的张量,这里C表示通道数,H和W 是图像的高和宽。1batch 的tensor图像是一个形状为(B, C, H, W) 的张量,这里B表示在batch上有多少张图片。
  • transfroms方法处理过后,会把通道移到最前边。比如MNIST h*w*c为:28*28*1,tensor处理完,通道数会提前,并且做了轴交换,变为了c*h*w为:1*28*28,为什么要这样设计呢?据说是做矩阵加减乘除以及卷积等运算是需要用cuda和cudnn的函数的,而这些接口都设成chw格式了。

a. 轴交换

transfroms方法处理过后,如果我们需要把图片转回PIL,需要进行一次轴交换;因为无法处理一个28通道数的图片。

#轴交换之前打印一下图片的形状
from torchvision import transforms#导入transforms方法
from torchvision.datasets import MNIST

my_transforms = transforms.Compose(
    [transforms.PILToTensor()]
)

from torch.utils.data import DataLoader
mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=  transforms.ToTensor())
dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True)

for (images, labels) in dataloader:
    print(images.shape)
    exit()#打印一次后退出

#运行结果
torch.Size([1, 1, 28, 28])

#说明:
这里第一个“1”表1 batch_size,即一次加载一张图片
第二个“1”表示通道数,后边两个“28”分别表示图片的高和宽
#使用make_grid方法将两张图片融合
from torchvision.utils import make_grid##即使一张图片我们也要将它融合一下,使用make_grid方法
from torchvision import transforms#导入transforms方法
from torchvision.datasets import MNIST

my_transforms = transforms.Compose(
    [transforms.PILToTensor()]
)#将多个transforms组合在一起,还可以加入标准化等图像处理
from torch.utils.data import DataLoader#导入数据加载器



mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True)
#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
print(dataloader)#打印一下
for (images, labels) in dataloader:
    print(make_grid(images).shape)
    exit()#打印一次后退出

#运行结果:
<torch.utils.data.dataloader.DataLoader object at 0x00000289D3634438>
torch.Size([3, 28, 28])
Process finished with exit code 0

#结果图像的形状变成了3*28*28

#如果将上述代码中dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True)换成
dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True)
#运行结果变为torch.Size([3, 32, 62]),相当于把两张图片融合了

b. 使用轴交换边回去

#轴交换
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"#前两行代码解决一个OMP报错
from torchvision.utils import make_grid##即使一张图片我们也要将它融合一下,使用make_grid方法
from torchvision import transforms#导入transforms方法
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt#安装matlab库并导入matplotlib.pyplot

my_transforms = transforms.Compose(
    [transforms.PILToTensor()]
)#将多个transforms组合在一起,还可以加入标准化等图像处理
from torch.utils.data import DataLoader#导入数据加载器


mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor())
dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True)
#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱
print(dataloader)#打印一下
for (images, labels) in dataloader:
    image = make_grid(images).permute(1, 2, 0).numpy()
    #permute(1, 2, 0)实际上就是把通道数移到后边的过程,忘记的话回看第二节课的视频
    #轴交换了之后转换成numpy的数组,之后就可以做加载了
    plt.imshow(image)
    plt.show()
    exit()



#运行结果
<torch.utils.data.dataloader.DataLoader object at 0x0000027DCB7A4710>
Process finished with exit code 0

图片展示:

添加代码:print(labels)可以将标签打印出来

这个操作一般只有调试时才会用,正常运算不需要把tensor图像转换成PIL图像再看一下

(2)进阶了解transfroms方法

 参考文档:PyTorch 学习笔记:transforms的二十二个方法(transforms用法非常详细)_liangbaqiang的博客-CSDN博客_transforms.scale

四、模型和优化器

1.简介

模型四深度学习的关键内容,是深度学习的核心。

深度神经网络的种类主要有:

  • 传统神经网络CNN
  • 卷积网络CNN
  • 循环神经网络(递归神经网络)RNN

目前比较流行的深度神经模型几乎都是卷积和循环两种模型的延伸。

2.全连接层:torch.nn.Linear

(1)简介

官方文档:https://pytorch.org/docs/stable/nn.html

对于MNIST数据集这种简单的,且样本数量足够多的项目,一个全连接层就能达到不错的效果。

后期会对这些模型的“层”进行组合实现。有卷积层、池化层、标准化层等等。

 全连接层指的是层中的每个节点都会连接它下一层的所有节点,它是模仿人脑神经结构来构建的。最左边是输入的是图像,实际上就是图像的像素点,全连接层每层之间都是线性关系。

比如:假设输入为x_{1}x_{2}、……x_{n},那么与输入层直接相连的中间层就是这样计算来的

y_{1}=x_{1}\omega _{1}+x_{2}\omega _{2}\cdots +x_{n}\omega _{n},同理可以计算出第二层的y_{2}y_{3}……y_{n},同样中间各层之间都有一个权重,下一层的输出都是由上一层的每个输入乘以相应的权重累加得出的。最终得到的是两个输出结果,这是一个二分类的问题。输出几个值几分类问题。

(2)全连接层的实现

#全连接层
#首先我们要新建一个类,这个类要继承nn.Module
class MnistModel(nn.Module):
    def __init__(self):#继承__init__方法
        super(MnistModel, self).__init__()
        self.fc2 = nn.Linear(1*28*28, 10)#最初传入的图片的像素点是1*28*28的,最后我们要收敛成10个结果
        #如果先收敛成100个,然后在写一个全连接层
        # self.fc2 = nn.Linear(1 * 28 * 28, 100)
        # self.fc2 = nn.Linear(100, 10)
        #激活函数,激励函数,通过数学手段将线性计算过程进行优化,使其加速。最常用的线性激活函数Relu
        self.relu = nn.ReLU()

    def forward(self, image):#继承前向传播的方法
        image_viwed = image.view(-1, 1*28*28)#此处需要拍平
        out = self.fc2(image_viwed)
        fcl_out = self.relu(out)#激活函数对应一下
        return out

3.优化器

#优化器官方文档:https://pytorch.org/docs/stable/optim.html

(1)简介

  • 优化器的作用就是寻求模型最优解,优化器有梯度下降,动量优化,自适应优化等,梯度下降是最原始的,也是最基础的。
  • 梯度下降算法,载入数据集,计算所有梯度,然后执行决策。依据是损失函数,通过损失进行每一步计算,梯度下降算法分为:标准梯度下降法、批量梯度下降法和随机梯度下降法。

(2)优化器实现 

from  torch import optim#导入优化器
#需要把实例化的模型传进去
model = MnistModel()
optim.Adam(model.parameters(), lr=1e-4)#这是一种自适应的优化器,不需要调参
#lr表示学习率,1e-4表示10的4次方

#优化器官方文档:https://pytorch.org/docs/stable/optim.html

4.损失函数

(1)简介

  • 损失函数,设计一个损失函数的计算方法,让他统一一个损失值,算出一个结论,进而判断下次模型要朝着那个方向去优化权重,最终损失函数的选择取决于最终的结果和标签之间的关系。
  • 每一种损失函数都对应着一种数学模型计算,目的就是把模型训练结果与标签之间建立起关系,在梯度下降优化器中,让损失不断减小的方向就是训练的方向 #损失函数的实现

(2)损失函数实现

LOST = nn.CTCLoss()#调用nn的损失函数,实例化
LOST(MODEL_RESULT, LABELS)#把模型的结果和标签传进去,得到一个数字就是损失值,就是优化器朝哪个方向去做的一个依据

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Zkaisen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值