0. 导入相关库
import paddle
import paddle.vision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Sized
# Show which PaddlePaddle version was imported (also confirms the import worked).
paddle_version = paddle.__version__
print(paddle_version)
2.0.0
1. 数据准备
# Load MNIST and normalize every image on the fly as (x - 127.5) / 127.5,
# which maps raw pixels (assumed to be in [0, 255]) into [-1, 1].
transform = T.Normalize(mean=[127.5], std=[127.5])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
eval_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)

# Inspect the first training sample: element 0 is the image, element 1 the label.
sample = train_dataset[0]
print(sample[0].shape)  # image shape, e.g. (1, 28, 28)
print(sample[1][0])  # label value

# Render the sample as a 28x28 grayscale picture.
plt.figure()
plt.imshow(sample[0].reshape([28, 28]), cmap=plt.cm.binary)
plt.show()
(1, 28, 28)
5
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
'a.item() instead', DeprecationWarning, stacklevel=1)
2. 模型的选择和开发
# Network: flatten each 1x28x28 image into a 784-vector, pass it through one
# hidden layer of 512 ReLU units, and emit 10 raw scores (logits), one per class.
network = paddle.nn.Sequential(
    paddle.nn.Flatten(),         # input: [-1, 1, 28, 28], output: [-1, 784]
    paddle.nn.Linear(784, 512),  # input: [-1, 784], output: [-1, 512]
    paddle.nn.ReLU(),            # input: [-1, 512], output: [-1, 512] (element-wise)
    paddle.nn.Linear(512, 10),   # input: [-1, 512], output: [-1, 10]
)
# Wrap the network in the high-level paddle.Model API.
model = paddle.Model(network)

# Optionally resume from previously saved parameters:
# model.load('model/mnist/mnist')

# Configure training: optimizer, loss function and evaluation metric.
optimizer = paddle.optimizer.Adam(
    learning_rate=0.001,
    parameters=network.parameters(),
)
loss_fn = paddle.nn.CrossEntropyLoss()
metric = paddle.metric.Accuracy()
model.prepare(optimizer, loss_fn, metric)

# Print a per-layer summary (input/output shapes and parameter counts).
model.summary((1, 28, 28))
---------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===========================================================================
Flatten-1 [[1, 28, 28]] [1, 784] 0
Linear-1 [[1, 784]] [1, 512] 401,920
ReLU-1 [[1, 512]] [1, 512] 0
Linear-2 [[1, 512]] [1, 10] 5,130
===========================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 1.55
Estimated Total Size (MB): 1.57
---------------------------------------------------------------------------
{'total_params': 407050, 'trainable_params': 407050}
# Train for 5 epochs in mini-batches of 64, evaluating on the test split after
# each epoch; verbose=1 prints a progress bar per epoch.
model.fit(
    train_data=train_dataset,
    eval_data=eval_dataset,
    epochs=5,
    batch_size=64,
    verbose=1,
)
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/5
step 30/938 [..............................] - loss: 0.5047 - acc: 0.6646 - ETA: 9s - 11ms/ste
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return (isinstance(seq, collections.Sequence) and
step 50/938 [>.............................] - loss: 0.6721 - acc: 0.7375 - ETA: 8s - 10ms/stepstep 938/938 [==============================] - loss: 0.1966 - acc: 0.9141 - 11ms/step
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0678 - acc: 0.9481 - 7ms/step
Eval samples: 10000
Epoch 2/5
step 938/938 [==============================] - loss: 0.0522 - acc: 0.9590 - 18ms/step
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0126 - acc: 0.9621 - 6ms/step
Eval samples: 10000
Epoch 3/5
step 938/938 [==============================] - loss: 0.0171 - acc: 0.9686 - 17ms/step
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0100 - acc: 0.9694 - 6ms/step
Eval samples: 10000
Epoch 4/5
step 938/938 [==============================] - loss: 0.0055 - acc: 0.9741 - 17ms/step
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0049 - acc: 0.9677 - 6ms/step
Eval samples: 10000
Epoch 5/5
step 938/938 [==============================] - loss: 0.1630 - acc: 0.9770 - 17ms/step
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 0.0020 - acc: 0.9743 - 6ms/step
Eval samples: 10000
3. 模型评估
# Evaluate the trained model on the test split and print the resulting
# metrics dict (it contains 'loss' and 'acc' entries).
res = model.evaluate(eval_data=eval_dataset, verbose=1)
print(res)
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 10000/10000 [==============================] - loss: 1.1921e-07 - acc: 0.9743 - 2ms/step
Eval samples: 10000
{'loss': [1.192093e-07], 'acc': 0.9743}
4. 模型预测
def draw(img, label):
    """Display one image with its predicted label as the plot title.

    Args:
        img: Image exposing a NumPy-style ``reshape`` with 28 * 28 elements
            (e.g. an ndarray of shape (1, 28, 28)); it is shown as 28x28.
        label: Predicted class, rendered in the title as "predict: <label>".
    """
    plt.figure()
    plt.title("predict: {}".format(label))
    plt.imshow(img.reshape([28, 28]), cmap=plt.cm.binary)
    plt.show()
# 1. Batch prediction: run the model over the whole test split, then show a
# few selected samples together with their predicted digit.
res = model.predict(eval_dataset)
indices = [2, 15, 38, 211]
for idx in indices:
    # res[0] holds the per-step outputs of the network's single output; with
    # predict's default batch size each step is one sample (10000 steps for
    # 10000 samples), so res[0][idx] are the 10 scores for sample idx.
    draw(eval_dataset[idx][0], np.argmax(res[0][idx]))
Predict begin...
step 10000/10000 [==============================] - 2ms/step
Predict samples: 10000
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
'a.item() instead', DeprecationWarning, stacklevel=1)
# 2. Single-image prediction: push one test image through predict_batch and
# show the class with the highest score.
img = eval_dataset[501][0]
res = model.predict_batch([img])  # a list holding one (1, 10) score array
print(res)
predicted_digit = np.argmax(res[0])
draw(img, predicted_digit)
[array([[ -3.2495534, -16.848299 , -2.9633887, -4.262614 , -2.6580017,
-5.9259615, -13.474122 , -1.5577601, -3.4419522, 13.671286 ]],
dtype=float32)]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
'a.item() instead', DeprecationWarning, stacklevel=1)
5. 保存模型
# 5. Save the dynamic-graph model under the 'model/mnist/mnist' prefix; it can
# later be restored with model.load('model/mnist/mnist').
model.save('model/mnist/mnist')
# To export a static-graph model instead, use:
# model.save('model/mnist/mnist', training=False)