PyTorch---(四)MNIST 手写数字识别 (实战)(上)

b站视频(9,10):https://www.bilibili.com/video/av49008640/?p=9

1 训练过程概览

计算一次 ->
通过误差 loss ->
求梯度 ->
再根据梯度去更新参数w,b

2 回顾

最后一层会根据实际任务需要选择激活函数: siamoid 或者softmax,

3 流程

1 加载数据

2 建立模型

3 训练

4 测试

4 开始写代码

1 代码结构

2 文件1: utils.py辅助文件

# utils.py
import torch
from matplotlib import pyplot as plt


def plot_curve(data):  #下降曲线的绘制
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')  #
    plt.legend(['value'], loc='upper right')  #
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()



def plot_image(img, label, name):  # 画图片,帮助看识别结果
    fig = plt.figure()
    for i in range(6):  # 6个图像,两行三列
        # print(i) 012345
        plt.subplot(2, 3, i+1)
        plt.tight_layout()  # 紧密排版
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')  
        # 均值是0.1307,标准差是0.3081,
        
        plt.title("{}:{}".format(name, label[i].item()))  
        # name:image_sample   label[i].item():数字
        
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1,  index=idx, value=1)
    return out









 

3 文件2:mnist_train.py主文件


import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision
from matplotlib import pyplot as plt

from utils import plot_curve, plot_image, one_hot


batch_size = 512
# step1.load_dataset
# 'mnist_data':加载mnist数据集,路径
# train=True:选择训练集还是测试
# download=True:如果当前文件没有mnist文件就会自动从网上去下载
# torchvision.transforms.ToTensor():下载好的数据一般是numpy格式,转换成Tensor
# torchvision.transforms.Normalisze((0.1307,), (0.3081,)):正则化过程,为了让数据更好的在0的附近均匀的分布
# 上面一行可注释掉:但是性能会差到百分之70,加上是百分之80,更加方便神经网络去优化

train_loader = torch.utils.data.DataLoader(  # 加载训练集
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))  # 均值是0.1307,
                                       # 标准差是0.3081,这些系数都是数据集提供方计算好的数据
                                   ])),
    batch_size=batch_size, shuffle=True)
# batch_size=batch_size:表示一次加载多少张图片
# shuffle = True 加载的时候做一个随机的打散

test_loader = torch.utils.data.DataLoader(  # 加载测试集
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                                   ])),
    batch_size=batch_size, shuffle=False)  # 测试集不用打散

x, y = next(iter(train_loader))
print(x.shape, y.shape)
# print(x.min(), x.max())
# plot_image(x, y, 'image_sample')

4 右键->运行

 

PS:报错如下 

Cannot find reference 'utils' in '__init__.pyi' less... (Ctrl+F1) 
Inspection info: This inspection detects names that should resolve but don't.
ue to dynamic dispatch and duck typing, this is possible in a limited but useful number of cases.
Top-level and class-level items are supported better than instance items.

检查信息:此检查检测应该解析但没有解析的名称。
由于动态分派duck类型,这在有限但有用的情况下是可能的。
顶级和类级项目实例项目更受支持。

PS:问题原因: 

我的目录结构如下,应该用pycharm设置source在标记位置
这样mnist_data.py才能找到utils.py的位置
from utils import plot_curve, plot_image, one_shot

PS:解决办法: https://www.jianshu.com/p/f36f32f34ce1

file->setting->

Project:ProjectPython37-> Project Structure->选定文件夹位置右键->Source(快捷键ALT+S)

5 先下载 mnist文件

PS:中间遇到很多下载问题:下载慢然后停止

如下图

trian-images-ids...成功下载

train-labels-...成功下载


上面两个文件如下

 

t10k-images-idx3...下载停止

PS:解决办法:手动下载

1 直接下载压缩包

urls = [
    'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
    'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
    'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
    'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
]

2 下载之后如下

放在相应文件夹中

3 运行程序解压

 

 

 

 6  输出的结果

x.shap:
512张图,1个通道 ,像素28x28

y.shap:
512个label

 

7 输出改写:显示tensor范围,看看加入Normalize前后效果

原来

x, y = next(iter(train_loader))
print(x.shape, y.shape)

现在

x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())

输出范围如下,

 

这是加入Normalize的结果,0到1 的数据做了一个等效的均匀的变换,使得其在1的附近

均值是0.1307,标准差是0.3081,这些系数都是数据集提供方计算好的数据
https://blog.csdn.net/zjc910997316/article/details/93465763

 

 

如果注释掉在运行,果然是 0到1

 

8 utils.py的plot_image功能,展示六张样例

x, y = next(iter(train_loader))
print(x.shape, y.shape)
plot_image(x, y, 'image_sample')

 

plot_image(x, y, 'image_sample')

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

  • 0
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

计算机视觉-Archer

图像分割没有团队的同学可加群

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值