PaddlePaddle High-Level API Training Camp, Day 1: Handwritten Digit Recognition

0. Import the required libraries

import paddle
import paddle.vision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
# Print the Paddle version to confirm the import succeeded
print(paddle.__version__)
2.0.0

1. Data preparation

# Load and preprocess the data (normalization is applied at load time)
transform = T.Normalize(mean=[127.5], std=[127.5])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
eval_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

# Inspect the dataset
print(train_dataset[0][0].shape)  # image shape
print(train_dataset[0][1][0])  # label value

# Visualize the first sample
plt.figure()
plt.imshow(train_dataset[0][0].reshape([28, 28]), cmap=plt.cm.binary)
plt.show()
(1, 28, 28)
5

(Figure: the first training image rendered with plt.imshow; its label is 5.)
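T.Normalize(mean=[127.5], std=[127.5]) computes (x - mean) / std per channel, so raw pixel values in [0, 255] end up in [-1, 1]. A minimal NumPy sketch of the same arithmetic, assuming uint8 input on the usual 0-255 scale (normalize_like_paddle is a hypothetical helper, not a Paddle API):

```python
import numpy as np

def normalize_like_paddle(img, mean=127.5, std=127.5):
    """Hypothetical helper: apply (x - mean) / std, the formula T.Normalize uses per channel."""
    return (img.astype(np.float32) - mean) / std

# A fake 1x28x28 "image": all black (0) except one white (255) pixel.
img = np.zeros((1, 28, 28), dtype=np.uint8)
img[0, 0, 0] = 255

out = normalize_like_paddle(img)
print(out.min(), out.max())  # -1.0 1.0
```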

2. Model selection and development

# Build the network structure
network = paddle.nn.Sequential(
    paddle.nn.Flatten(),                # input: [-1, 1, 28, 28], output: [-1, 784]
    paddle.nn.Linear(784, 512),         # input: [-1, 784], output: [-1, 512]
    paddle.nn.ReLU(),                   # input: [-1, 512], output: [-1, 512]
    paddle.nn.Linear(512, 10)           # input: [-1, 512], output: [-1, 10]
)

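# Wrap the network with the high-level paddle.Model API
model = paddle.Model(network)

# Optionally preload previously saved parameters
# model.load('model/mnist/mnist')

# Configure the model (optimizer, loss function and evaluation metric)
model.prepare(
    paddle.optimizer.Adam(learning_rate=0.001, parameters=network.parameters()),  # optimizer
    paddle.nn.CrossEntropyLoss(),  # loss function
    paddle.metric.Accuracy()  # evaluation metric
)

# Print a summary of the model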
model.summary((1, 28, 28))

---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Flatten-1       [[1, 28, 28]]           [1, 784]              0       
   Linear-1          [[1, 784]]            [1, 512]           401,920    
    ReLU-1           [[1, 512]]            [1, 512]              0       
   Linear-2          [[1, 512]]            [1, 10]             5,130     
===========================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 1.55
Estimated Total Size (MB): 1.57
---------------------------------------------------------------------------
{'total_params': 407050, 'trainable_params': 407050}
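The parameter counts in the summary follow directly from the layer shapes: a Linear layer with n_in inputs and n_out outputs has n_in * n_out weights plus n_out biases. The NumPy sketch below mirrors the network with random placeholder weights (not the trained ones) to show the shape flow and to reproduce the 407,050 total and the 1.55 MB parameter size, assuming 4-byte float32 parameters:

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder weights with the same shapes as the two paddle.nn.Linear layers
# (Paddle stores a Linear weight as [in_features, out_features]).
W1, b1 = rng.normal(size=(784, 512)).astype(np.float32), np.zeros(512, np.float32)
W2, b2 = rng.normal(size=(512, 10)).astype(np.float32), np.zeros(10, np.float32)

def forward(x):
    """x: [N, 1, 28, 28] -> logits: [N, 10]."""
    h = x.reshape(x.shape[0], -1)      # Flatten: [N, 784]
    h = np.maximum(h @ W1 + b1, 0.0)   # Linear-1 + ReLU: [N, 512]
    return h @ W2 + b2                 # Linear-2: [N, 10]

logits = forward(rng.normal(size=(64, 1, 28, 28)).astype(np.float32))
print(logits.shape)  # (64, 10)

# Parameter count: weights plus biases of both Linear layers.
n_params = sum(a.size for a in (W1, b1, W2, b2))
print(n_params)  # 407050
print(round(n_params * 4 / 1024**2, 2))  # 4 bytes per float32 -> 1.55 (MB)
```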
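# Set the training parameters and start training
model.fit(
    train_dataset,   # training set
    eval_dataset,    # test set, evaluated after every epoch
    epochs=5,        # number of training epochs
    batch_size=64,   # batch size
    verbose=1        # logging style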
)
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/5
step 938/938 [==============================] - loss: 0.1966 - acc: 0.9141 - 11ms/step
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0678 - acc: 0.9481 - 7ms/step        
Eval samples: 10000
Epoch 2/5
step 938/938 [==============================] - loss: 0.0522 - acc: 0.9590 - 18ms/step         
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0126 - acc: 0.9621 - 6ms/step         
Eval samples: 10000
Epoch 3/5
step 938/938 [==============================] - loss: 0.0171 - acc: 0.9686 - 17ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0100 - acc: 0.9694 - 6ms/step         
Eval samples: 10000
Epoch 4/5
step 938/938 [==============================] - loss: 0.0055 - acc: 0.9741 - 17ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0049 - acc: 0.9677 - 6ms/step         
Eval samples: 10000
Epoch 5/5
step 938/938 [==============================] - loss: 0.1630 - acc: 0.9770 - 17ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0020 - acc: 0.9743 - 6ms/step         
Eval samples: 10000
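The step counts in the training log follow from the dataset sizes: MNIST has 60,000 training and 10,000 test images, and with batch_size=64 neither divides evenly. A quick check of that arithmetic, assuming the final partial batch is kept rather than dropped (which is what the 938 and 157 in the log suggest):

```python
import math

train_size, eval_size, batch_size = 60000, 10000, 64

# Number of steps per pass when the last, partial batch is kept.
train_steps = math.ceil(train_size / batch_size)
eval_steps = math.ceil(eval_size / batch_size)
print(train_steps, eval_steps)  # 938 157

# Size of that last partial batch.
last_train = train_size - (train_steps - 1) * batch_size
last_eval = eval_size - (eval_steps - 1) * batch_size
print(last_train, last_eval)  # 32 16
```

This also helps explain why the eval loss jumps around between epochs: per the log's own note, the printed loss is that of the current batch, which at the end of each evaluation is a batch of only 16 images, while acc is averaged over the whole pass.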

3. Model evaluation

res = model.evaluate(eval_dataset, verbose=1)
print(res)
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10000/10000 [==============================] - loss: 1.1921e-07 - acc: 0.9743 - 2ms/step         
Eval samples: 10000
{'loss': [1.192093e-07], 'acc': 0.9743}
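model.evaluate ran with one sample per step (the log shows 10000/10000 steps, consistent with evaluate's default batch_size=1), so the reported loss of 1.1921e-07 is the loss of the last single test image, while acc=0.9743 is averaged over all 10,000. The sketch below, in plain NumPy with made-up logits, shows how top-1 accuracy and softmax cross-entropy are computed from raw logits; paddle.nn.CrossEntropyLoss applies the softmax internally, which is why the network ends in a plain Linear layer. The helpers are hypothetical illustrations of the formulas, not Paddle's implementation:

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Per-sample cross-entropy on raw logits (softmax applied inside, numerically stable)."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

def top1_accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    return float((logits.argmax(axis=1) == labels).mean())

# Made-up logits for 4 samples and 3 classes, for illustration only.
logits = np.array([[5.0, 0.0, 0.0],
                   [0.0, 4.0, 1.0],
                   [2.0, 0.0, 3.0],
                   [0.0, 3.0, 0.0]])
labels = np.array([0, 1, 0, 1])

losses = softmax_cross_entropy(logits, labels)
print(top1_accuracy(logits, labels))  # 0.75
print(losses.mean())  # average loss over all samples
print(losses[-1])     # loss of the last sample only, the analogue of the logged value at batch_size=1
```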

4. Model prediction

def draw(img, label):
    """Show a 28x28 image with the predicted label as the title."""
    plt.figure()
    plt.title("predict: {}".format(label))
    plt.imshow(img.reshape([28, 28]), cmap=plt.cm.binary)
    plt.show()
# 1. Batch prediction over the whole test set
res = model.predict(eval_dataset)
indices = [2, 15, 38, 211]
for idx in indices:
    draw(eval_dataset[idx][0], np.argmax(res[0][idx]))
Predict begin...
step 10000/10000 [==============================] - 2ms/step        
Predict samples: 10000


(Figures: test images 2, 15, 38 and 211, each titled with its predicted label.)
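In the batch-prediction loop, res[0][idx] indexes first by network output (this network has one) and then by batch; because predict processed one image per step here (10000/10000 in the log), each entry holds a (1, 10) array of logits. Assuming that layout, the per-batch arrays can be stacked to get every prediction at once. The res below is a random stand-in with the same structure, not real model output:

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for `res = model.predict(eval_dataset)`: one network output,
# split into 10000 batches of shape (1, 10). The logits are random placeholders.
res = [[rng.normal(size=(1, 10)).astype(np.float32) for _ in range(10000)]]

all_logits = np.concatenate(res[0], axis=0)  # (10000, 10)
preds = all_logits.argmax(axis=1)            # predicted digit for every test image

# Matches what the loop computes one image at a time with np.argmax(res[0][idx]).
idx = 15
assert preds[idx] == np.argmax(res[0][idx])
print(all_logits.shape, preds.shape)  # (10000, 10) (10000,)
```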

# 2. Single-image prediction
img = eval_dataset[501][0]
res = model.predict_batch([img])
print(res)
draw(img, np.argmax(res))
[array([[ -3.2495534, -16.848299 ,  -2.9633887,  -4.262614 ,  -2.6580017,
         -5.9259615, -13.474122 ,  -1.5577601,  -3.4419522,  13.671286 ]],
      dtype=float32)]


(Figure: test image 501, titled "predict: 9".)
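predict_batch returns raw logits, not probabilities, because the network ends in a Linear layer (the softmax lives inside CrossEntropyLoss during training). Applying a softmax to the logits printed above turns them into class probabilities; since softmax is monotonic, the argmax, and so the predicted digit, does not change. A NumPy sketch, with the softmax helper written out here rather than taken from Paddle:

```python
import numpy as np

# Logits printed by model.predict_batch for test image 501.
logits = np.array([-3.2495534, -16.848299, -2.9633887, -4.262614, -2.6580017,
                   -5.9259615, -13.474122, -1.5577601, -3.4419522, 13.671286],
                  dtype=np.float32)

def softmax(z):
    """Numerically stable softmax: subtract the max before exponentiating."""
    e = np.exp(z - z.max())
    return e / e.sum()

probs = softmax(logits)
pred = int(np.argmax(probs))
print(pred)                 # 9
print(probs[pred] > 0.999)  # True
```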

5. Save the model

# Save the dynamic-graph model (for continued training)
model.save('model/mnist/mnist')

# Save a static-graph model for inference
# model.save('model/mnist/mnist', training=False)