PaddlePaddle图像分类神经网络构建正则化笔记

本文主要根据第二次作业进行分析

作业要求:

补全网络代码,并运行手写数字识别项目。以出现最后的图片和预测结果为准。

 

首先导入必要的包

numpy---------->python第三方库,用于进行科学计算

PIL------------> Python Image Library,python第三方图像处理库

matplotlib----->python的绘图库 pyplot:matplotlib的绘图框架

os------------->提供了丰富的方法来处理文件和目录


#导入需要的包
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os
import paddle
print("本教程基于Paddle的版本号为:"+paddle.__version__)
本教程基于Paddle的版本号为:2.0.1

 

Step1:准备数据。

(1)数据集介绍

MNIST数据集包含60000个训练集和10000测试数据集。分为图片和标签,图片是28*28的像素矩阵,标签为0~9共10个数字。

(2)transform函数是定义了一个归一化标准化的标准

(3)train_dataset和test_dataset

paddle.vision.datasets.MNIST()中的mode='train'和mode='test'分别用于获取mnist训练集和测试集

transform=transform参数则为归一化标准

#导入数据集Compose的作用是将用于数据集预处理的接口以列表的方式进行组合。
#导入数据集Normalize的作用是图像归一化处理,支持两种方式: 1. 用统一的均值和标准差值对图像的每个通道进行归一化处理; 2. 对每个通道指定不同的均值和标准差值进行归一化处理。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
# 使用transform对数据集做归一化
print('下载并加载训练数据')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('加载完成')
下载并加载训练数据
加载完成
#让我们一起看看数据集中的图片是什么样子的
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
print(plt.imshow(train_data0, cmap=plt.cm.binary))
print('train_data0 的标签为: ' + str(train_label_0))
AxesImage(18,18;111.6x108.72)
train_data0 的标签为: [5]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  'a.item() instead', DeprecationWarning, stacklevel=1)
#让我们再来看看数据样子是什么样的吧
print(train_data0)
[[-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -0.9764706  -0.85882354 -0.85882354 -0.85882354 -0.01176471  0.06666667
   0.37254903 -0.79607844  0.3019608   1.          0.9372549  -0.00392157
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -0.7647059  -0.7176471  -0.2627451   0.20784314
   0.33333334  0.9843137   0.9843137   0.9843137   0.9843137   0.9843137
   0.7647059   0.34901962  0.9843137   0.8980392   0.5294118  -0.49803922
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -0.6156863   0.8666667   0.9843137   0.9843137   0.9843137
   0.9843137   0.9843137   0.9843137   0.9843137   0.9843137   0.96862745
  -0.27058825 -0.35686275 -0.35686275 -0.56078434 -0.69411767 -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -0.85882354  0.7176471   0.9843137   0.9843137   0.9843137
   0.9843137   0.9843137   0.5529412   0.42745098  0.9372549   0.8901961
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -0.37254903  0.22352941 -0.16078432  0.9843137
   0.9843137   0.60784316 -0.9137255  -1.         -0.6627451   0.20784314
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -0.8901961  -0.99215686  0.20784314
   0.9843137  -0.29411766 -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.          0.09019608
   0.9843137   0.49019608 -0.9843137  -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -0.9137255
   0.49019608  0.9843137  -0.4509804  -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -0.7254902   0.8901961   0.7647059   0.25490198 -0.15294118 -0.99215686
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -0.3647059   0.88235295  0.9843137   0.9843137  -0.06666667
  -0.8039216  -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -0.64705884  0.45882353  0.9843137   0.9843137
   0.1764706  -0.7882353  -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -0.8745098  -0.27058825  0.9764706
   0.9843137   0.46666667 -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.          0.9529412
   0.9843137   0.9529412  -0.49803922 -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -0.6392157   0.01960784  0.43529412  0.9843137
   0.9843137   0.62352943 -0.9843137  -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -0.69411767  0.16078432  0.79607844  0.9843137   0.9843137   0.9843137
   0.9607843   0.42745098 -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -0.8117647  -0.10588235
   0.73333335  0.9843137   0.9843137   0.9843137   0.9843137   0.5764706
  -0.3882353  -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -0.81960785 -0.48235294  0.67058825  0.9843137
   0.9843137   0.9843137   0.9843137   0.5529412  -0.3647059  -0.9843137
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -0.85882354  0.34117648  0.7176471   0.9843137   0.9843137   0.9843137
   0.9843137   0.5294118  -0.37254903 -0.92941177 -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -0.5686275   0.34901962
   0.77254903  0.9843137   0.9843137   0.9843137   0.9843137   0.9137255
   0.04313726 -0.9137255  -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.          0.06666667  0.9843137
   0.9843137   0.9843137   0.6627451   0.05882353  0.03529412 -0.8745098
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]
 [-1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.         -1.         -1.
  -1.         -1.         -1.         -1.        ]]

Step2.网络配置

以下的代码判断就是定义一个简单的多层感知器,一共有三层,两个大小为100的隐层和一个大小为10的输出层,因为MNIST数据集是手写0到9的灰度图像,类别有10个,所以最后的输出大小是10。最后输出层的激活函数是Softmax,所以最后的输出层相当于一个分类器。加上一个输入层的话,多层感知器的结构是:输入层-->>隐层-->>隐层-->>输出层。

 

重头戏开始了!!!!

构建的每一个步骤都去查了查官方的API文档说明,大致了解了构建神经网络的过程。

问题最大的就在于“正则化”这个概念:

正则化简单来说就是找一条曲线去贴合坐标系的点,当点逐渐多的时候,曲线的契合度如果过高会导致预测结果没有普遍性(由于点的离散程度不同)

正则化

如上图,在拟合的过程中我们要保证贴合效果,需要“舍弃”一些点,那么到底舍弃多少呢?我选择了20%的点

paddle.nn.Dropout()默认是0.5,也就是舍弃50%的点去进行拟合,从而达到预测效果

这里附上大佬写的正则化详解:【直观详解】什么是正则化 https://blog.csdn.net/gqixf/article/details/85319510

# 定义多层感知器 
#动态图定义多层感知器
class multilayer_perceptron(paddle.nn.Layer):
    def __init__(self):
        super(multilayer_perceptron,self).__init__()
        #请在这里补全网络代码
        self.flatten = paddle.nn.Flatten()#将一个连续维度的Tensor展平成一维Tensor
        self.linear_1 = paddle.nn.Linear(784, 128)##第一层感知机
        self.linear_2 = paddle.nn.Linear(128, 10)##第二层感知机
        self.relu = paddle.nn.ReLU()##激活层
        self.dropout = paddle.nn.Dropout(0.2)##正则化,归零率为0.5

    def forward(self, x):
        #请在这里补全传播过程的代码
        y = self.flatten(x)
        y = self.linear_1(y)
        y = self.relu(y)
        y = self.dropout(y)
        y = self.linear_2(y)
        return y
##自定义感知网络
#请在这里定义卷积网络的代码
#注意:定义完成卷积的代码后,后面的代码是需要修改的!
LeNet = multilayer_perceptron()
# 配置模型
model.prepare(optim,paddle.nn.CrossEntropyLoss(),Accuracy())


# 训练保存并验证模型
model.fit(train_dataset,test_dataset,epochs=2,batch_size=64,save_dir='multilayer_perceptron',verbose=1)

运行时长: 15秒660毫秒
结束时间: 2021-03-12 09:56:12
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/2
step  30/938 [..............................] - loss: 0.9485 - acc: 0.5526 - ETA: 7s - 8ms/step 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return (isinstance(seq, collections.Sequence) and
step 938/938 [==============================] - loss: 0.2986 - acc: 0.8793 - 7ms/step         
save checkpoint at /home/aistudio/multilayer_perceptron/0
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0635 - acc: 0.9431 - 6ms/step         
Eval samples: 10000
Epoch 2/2
step 938/938 [==============================] - loss: 0.1512 - acc: 0.9350 - 8ms/step         
save checkpoint at /home/aistudio/multilayer_perceptron/1
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0231 - acc: 0.9546 - 6ms/step         
Eval samples: 10000
save checkpoint at /home/aistudio/multilayer_perceptron/final
# 训练保存并验证模型
model.fit(train_dataset,test_dataset,epochs=2,batch_size=64,save_dir='multilayer_perceptron',verbose=1)
运行时长: 17秒487毫秒
结束时间: 2021-03-12 09:56:30
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/2
step 938/938 [==============================] - loss: 0.1907 - acc: 0.9479 - 10ms/step         
save checkpoint at /home/aistudio/multilayer_perceptron/0
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0292 - acc: 0.9619 - 6ms/step         
Eval samples: 10000
Epoch 2/2
step 938/938 [==============================] - loss: 0.0512 - acc: 0.9520 - 7ms/step         
save checkpoint at /home/aistudio/multilayer_perceptron/1
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0102 - acc: 0.9659 - 6ms/step         
Eval samples: 10000
save checkpoint at /home/aistudio/multilayer_perceptron/final
#获取测试集的第一个图片
test_data0, test_label_0 = test_dataset[0][0],test_dataset[0][1]
test_data0 = test_data0.reshape([28,28])
plt.figure(figsize=(2,2))
#展示测试集中的第一个图片
print(plt.imshow(test_data0, cmap=plt.cm.binary))
print('test_data0 的标签为: ' + str(test_label_0))
#模型预测
result = model.predict(test_dataset, batch_size=1)
#打印模型预测的结果

运行时长: 12秒828毫秒
结束时间: 2021-03-12 09:56:42
AxesImage(18,18;111.6x108.72)
test_data0 的标签为: [7]
Predict begin...
step 10000/10000 [==============================] - 1ms/step        
Predict samples: 10000
test_data0 预测的数值为:7

写在最后:

一直想学习一些深度学习的东西,感觉这个图像分类课程很值得学习

附上课程连接:

https://aistudio.baidu.com/aistudio/course/introduce/11939?directly=1&shared=1

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

迟暮 .

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值