神经网络分类任务---手写数字识别

Mnist分类任务

网络基本构建与训练方法,常用函数解析

torch.nn.functional模块

nn.Module模块

读取Mnist数据集

会自动进行下载(如果下载出问题可以在这个github下载,在文件根目录下创建data/mnist/,放入mnist文件夹即可)

%matplotlib inline
from pathlib import Path
import requests
​
DATA_PATH =Path("data")
PATH = DATA_PATH / "mnist"
​
PATH.mkdir(parents=True, exist_ok=True)
​
URL = "http://deeplearning.net/mnist.pkl.gz"
FILENAME = "mnist.pkl.gz"
​
if not (PATH / FILENAME).exists():
    content = requests.get(URL + FILENAME).content
    (PATH / FILENAME).open("wb").write(content)
import pickle
import gzip
​
with gzip.open((PATH / FILENAME).as_posix(),"rb") as f:
    ((x_train,y_train), (x_valid,y_valid),_) = pickle.load(f, encoding="latin-1")

from matplotlib import pyplot

import numpy as np



pyplot.imshow(x_train[0].reshape((28,28)),cmap="gray")
print(x_train.shape)
(50000, 784)

50000是样本数字,784是像素点个数=28*28*1= h*w*颜色通道, 在黑白图中只有一个颜色通道

import torch
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]) tensor([5, 0, 4,  ..., 8, 4, 8])
torch.Size([50000, 784])
tensor(0) tensor(9)

torch.nn.functional中有很多功能,后续会常用到。那么什么时候用nn.Module,什么时候用nn.function呢,一般情况下,如果模型有可学习的参数,最好用 nn.Module,其他情况nn.function会更简单一些

import torch.nn.functional as F
loss_func = F.cross_entropy
def model(xb):
    return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs]
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float, requires_grad=True)
bs = 64
bias = torch.zeros(10,requires_grad=True)
print(loss_func(model(xb),yb))
tensor(13.7654, grad_fn=<NllLossBackward0>)

创建一个model来简化代码

必须继承nn.Module且在其构造函数中需要调用nn.Module的构造函数

无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播

Module中的可学习参数可以通过named_parameters()或者paramters()返回迭代器

from torch import nn
class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)

    def forward(self,x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.out(x)
        return x
net = Mnist_NN()
print(net)
Mnist_NN(
  (hidden1): Linear(in_features=784, out_features=128, bias=True)
  (hidden2): Linear(in_features=128, out_features=256, bias=True)
  (out): Linear(in_features=256, out_features=10, bias=True)
)

可以打印我们定义好名字里的权重和偏置项

for name, parameter in net.named_parameters():
    print(name,parameter,parameter.size())
hidden1.weight Parameter containing: tensor([[-0.0325, -0.0265, 0.0178, ..., 0.0236, -0.0253, -0.0139], [ 0.0308, 0.0199, -0.0114, ..., 0.0319, -0.0297, 0.0206], [-0.0147, -0.0082, 0.0126, ..., -0.0241, -0.0222, 0.0104], ..., [-0.0152, -0.0187, 0.0223, ..., 0.0195, 0.0154, 0.0051], [ 0.0291, 0.0240, -0.0014, ..., -0.0316, 0.0345, 0.0214], [ 0.0138, 0.0118, 0.0059, ..., -0.0114, 0.0332, -0.0270]], requires_grad=True) torch.Size([128, 784]) hidden1.bias Parameter containing: tensor([ 0.0219, 0.0226, 0.0348, -0.0243, 0.0202, -0.0118, 0.0346, -0.0002, -0.0283, -0.0212, -0.0095, -0.0199, 0.0057, 0.0276, 0.0143, 0.0054, -0.0161, 0.0051, -0.0192, -0.0025, 0.0189, -0.0038, 0.0085, -0.0146, 0.0177, 0.0061, -0.0138, 0.0245, 0.0140, 0.0343, 0.0356, 0.0205, -0.0245, -0.0167, 0.0282, -0.0108, 0.0239, -0.0163, -0.0229, 0.0036, 0.0009, 0.0230, -0.0138, -0.0155, -0.0064, 0.0093, -0.0183, 0.0110, -0.0118, -0.0205, 0.0016, -0.0222, -0.0291, -0.0331, -0.0143, 0.0260, -0.0066, -0.0024, 0.0276, -0.0136, -0.0265, 0.0211, 0.0253, -0.0289, -0.0162, 0.0335, 0.0345, -0.0209, -0.0156, -0.0180, 0.0076, -0.0350, 0.0277, 0.0195, -0.0021, 0.0170, 0.0041, -0.0119, -0.0008, -0.0015, 0.0237, 0.0309, -0.0213, 0.0351, 0.0204, 0.0217, 0.0099, -0.0098, -0.0136, -0.0297, 0.0317, -0.0105, 0.0293, 0.0130, 0.0253, 0.0099, -0.0207, -0.0041, 0.0094, 0.0176, -0.0293, 0.0131, 0.0274, -0.0302, 0.0082, -0.0308, -0.0162, 0.0076, 0.0265, -0.0062, 0.0130, -0.0330, -0.0103, 0.0308, 0.0115, 0.0160, -0.0197, -0.0085, 0.0040, 0.0033, -0.0086, 0.0281, 0.0263, -0.0082, 0.0184, -0.0117, -0.0284, -0.0238], requires_grad=True) torch.Size([128]) hidden2.weight Parameter containing: tensor([[-0.0340, -0.0734, 0.0346, ..., 0.0559, -0.0386, -0.0846], [ 0.0821, -0.0031, 0.0393, ..., 0.0061, 0.0190, -0.0220], [-0.0009, 0.0106, -0.0100, ..., -0.0401, 0.0322, -0.0491], ..., [-0.0390, 0.0538, -0.0025, ..., 0.0705, -0.0584, -0.0758], [-0.0805, 0.0092, -0.0577, ..., -0.0677, 0.0392, 0.0216], [-0.0809, -0.0496, 0.0092, ..., 0.0328, -0.0212, -0.0592]], requires_grad=True) torch.Size([256, 128]) hidden2.bias Parameter containing: tensor([ 3.4251e-02, 6.6245e-02, -5.3668e-02, -2.6804e-02, 6.2958e-02, -9.6022e-03, -6.8485e-02, 4.8541e-02, -5.9051e-02, 4.6563e-02, -5.6918e-02, 1.7858e-02, 8.2025e-03, 2.9629e-03, -1.5890e-02, -6.8271e-02, -6.1588e-02, -4.5848e-02, 9.3048e-03, -2.9849e-02, 8.0482e-02, 7.4823e-02, -8.7341e-02, -2.9350e-02, -3.2482e-02, -5.5491e-02, -3.0718e-02, 3.8584e-03, 7.6252e-02, 6.1728e-02, -5.9369e-02, 1.9814e-02, -4.4111e-02, -8.1931e-02, 5.8540e-02, -7.8754e-02, 6.1414e-02, 5.6715e-02, -3.9198e-02, 5.3766e-02, -8.8324e-02, -1.5136e-02, 2.3646e-02, 5.9494e-02, 2.2706e-02, -6.8846e-02, 5.5657e-02, 1.4645e-03, 6.5911e-02, 2.0665e-02, 6.7541e-02, -7.0019e-02, -3.5764e-02, 4.7367e-02, -8.4237e-02, 1.7093e-02, -8.1570e-02, 1.0706e-02, -5.1387e-02, 4.2129e-02, 1.8284e-03, -3.9581e-02, -3.6075e-02, -6.6867e-02, -2.4424e-02, -1.0475e-02, 5.6918e-02, -6.2000e-02, 5.0072e-02, -8.2728e-02, -7.2127e-02, 3.4423e-02, -6.4720e-02, 5.8055e-02, 1.6779e-02, -7.0746e-02, -5.2716e-02, -5.1131e-03, 9.4748e-03, 2.4467e-02, 8.2609e-02, -3.9185e-02, -7.0271e-02, 3.6122e-02, -5.2537e-02, 1.5991e-02, -4.4633e-02, 6.4542e-02, -6.3199e-02, 4.3374e-02, 4.8441e-02, -4.5691e-02, -3.7080e-02, -5.6189e-02, 3.5212e-03, 3.6487e-02, -7.0284e-02, -6.9327e-02, -5.8588e-02, 7.6064e-02, -3.9589e-02, -5.5521e-02, -5.5006e-02, 2.7749e-03, -5.6363e-02, 1.9204e-03, -7.5818e-02, -5.8002e-02, 5.6914e-02, -5.1079e-02, 4.6740e-02, 1.7789e-02, -5.2705e-03, -1.1432e-02, 6.5533e-02, 2.4519e-02, -8.1965e-02, 1.8052e-02, 8.7573e-02, 8.9896e-03, 8.2436e-02, 5.1366e-02, 8.4385e-03, -7.5690e-02, 4.4200e-02, 9.7071e-03, -7.8598e-02, -4.1634e-02, -8.3067e-02, -7.1623e-03, 2.3093e-02, -7.4160e-02, 5.5457e-02, -1.6331e-02, 8.2332e-02, 6.6278e-03, -4.3818e-02, -6.5338e-02, -2.7475e-02, -5.3869e-02, -3.6781e-02, -1.5129e-02, -1.8047e-02, 5.0949e-03, 8.1808e-03, -2.4383e-02, -6.0043e-02, -3.9112e-02, 6.2894e-02, 7.3802e-02, -2.9938e-02, -6.8608e-03, -1.6504e-03, 5.6096e-02, -6.5193e-02, 6.0587e-02, -4.6924e-02, -2.5903e-02, -6.5839e-02, -6.3925e-02, -4.2316e-02, -7.7518e-02, 8.3411e-02, -1.8981e-02, 1.4119e-02, -3.8588e-02, 2.0915e-02, -5.4507e-02, 7.9109e-02, -6.8472e-03, 4.4487e-02, 4.2394e-06, 6.2874e-02, -2.4991e-02, 3.1016e-02, 2.4549e-02, -3.6681e-03, -7.3148e-02, -4.5987e-02, 7.7052e-02, -2.7164e-02, 2.3189e-02, -2.8427e-02, 3.0965e-02, 7.4590e-02, -5.9826e-02, -1.9704e-02, -5.7558e-02, -7.8640e-02, 5.7251e-02, -2.4419e-02, -5.4338e-02, -3.5999e-02, 2.8274e-02, -5.1797e-02, 1.3047e-02, -4.3136e-02, 2.6374e-02, -4.1364e-02, -2.8603e-02, -2.2389e-02, 1.5212e-02, -4.9188e-02, 5.9365e-02, -5.5241e-02, -6.8918e-02, -2.2441e-02, 8.8711e-03, -6.9141e-02, -4.5297e-02, -2.1198e-02, 8.0370e-02, -8.7273e-02, 7.5763e-02, 8.6972e-02, -1.4734e-03, -5.7440e-02, 1.9224e-02, -1.1871e-02, 3.2531e-02, 4.8892e-02, 7.4199e-02, -5.4005e-02, -9.7446e-03, -1.3905e-02, -6.7276e-02, 7.1744e-02, -6.7407e-02, 6.1195e-02, 6.6671e-02, -3.3406e-02, 4.8141e-02, -7.3937e-02, -1.0370e-02, -3.7342e-02, -7.0859e-02, -1.6061e-02, 3.7811e-02, -7.5927e-02, 4.5883e-02, 6.8066e-02, 2.1853e-02, -2.7309e-03, 7.8225e-02, 3.9355e-02, -8.8315e-02, 3.8017e-02, -4.9178e-02, -6.8704e-02, 8.1418e-02, -1.0785e-02, -5.9308e-02, -2.2016e-02, 4.5492e-02, 5.6116e-02, -8.6721e-02], requires_grad=True) torch.Size([256]) out.weight Parameter containing: tensor([[ 0.0394, -0.0572, -0.0190, ..., 0.0153, 0.0324, 0.0284], [-0.0188, 0.0441, -0.0214, ..., 0.0510, 0.0085, 0.0408], [ 0.0054, -0.0528, 0.0188, ..., -0.0515, -0.0436, -0.0300], ..., [ 0.0105, -0.0091, -0.0491, ..., 0.0312, -0.0291, -0.0333], [ 0.0023, 0.0401, 0.0541, ..., -0.0160, 0.0430, -0.0306], [ 0.0065, -0.0045, 0.0404, ..., 0.0034, -0.0315, -0.0219]], requires_grad=True) torch.Size([10, 256]) out.bias Parameter containing: tensor([-0.0267, -0.0337, -0.0451, 0.0126, 0.0417, -0.0079, 0.0078, 0.0493, -0.0088, -0.0089], requires_grad=True) torch.Size([10])

使用TensorDataset和Dateloder来简化

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
​
valid_ds = TensorDataset(x_valid,y_valid)
valis_dl = DataLoader(valid_ds, batch_size=bs * 2)


def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds,batch_size=bs * 2),
    )

一般模型训练时加上model.train(),这样会正常使用Batch Normalization和Dropout 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和Dropout

import numpy as np
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()
        for xb, yb, in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb, opt) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))
from torch import optim
def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)


def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.requires_grad_(True)
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)

三行搞定

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
当前step:0 验证集损失:2.2827181243896484
当前step:1 验证集损失:2.2478832061767577
当前step:2 验证集损失:2.1945471561431886
当前step:3 验证集损失:2.1076549449920656
当前step:4 验证集损失:1.9676035526275635
当前step:5 验证集损失:1.764887052345276
当前step:6 验证集损失:1.522310170173645
当前step:7 验证集损失:1.287570206260681
当前step:8 验证集损失:1.0948256267547607
当前step:9 验证集损失:0.9480554515838623
当前step:10 验证集损失:0.8378450462341308
当前step:11 验证集损失:0.7521841385841369
当前step:12 验证集损失:0.6849521107673645
当前step:13 验证集损失:0.6307255907058715
当前step:14 验证集损失:0.5866909859657288
当前step:15 验证集损失:0.5503184311389923
当前step:16 验证集损失:0.5185965992450714
当前step:17 验证集损失:0.4931147786140442
当前step:18 验证集损失:0.4708563284873962
当前step:19 验证集损失:0.452236319065094
当前step:20 验证集损失:0.435598934841156
当前step:21 验证集损失:0.42189919352531435
当前step:22 验证集损失:0.409069230890274
当前step:23 验证集损失:0.39823391227722166
当前step:24 验证集损失:0.3882803885936737

这样一个手写数字识别就完成了。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

荔枝味啊~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值