了解ResNet18的网络结构;掌握模型的保存和加载方法;掌握批量测试图片的方法。
结合图像分类任务,使用典型的图像分类网络ResNet18,实现手写数字识别。
ResNet作为经典的图像分类网络有其明显的优点:
-
首先,它足够深,常见的有34层,50层,101层。通常层次越深,表征能力越强,分类准确率越高。
-
其次,可学习,采用了残差结构,通过shortcut连接把低层直接跟高层相连,解决了反向传播过程中因为网络太深造成的梯度消失问题。
-
此外,ResNet网络的性能很好,既表现为识别的准确率,也包括它本身模型的大小和参数量。
1. 加载并处理数据集
import os
import sys
import moxing as mox
datasets_dir = '../datasets'
if not os.path.exists(datasets_dir):
os.makedirs(datasets_dir)
if not os.path.exists(os.path.join(datasets_dir, 'MNIST_Data.zip')):
mox.file.copy('obs://modelarts-labs-bj4-v2/course/hwc_edu/python_module_framework/datasets/mindspore_data/MNIST_Data.zip',
os.path.join(datasets_dir, 'MNIST_Data.zip'))
os.system('cd %s; unzip MNIST_Data.zip' % (datasets_dir))
sys.path.insert(0, os.path.join(os.getcwd(), '../datasets/MNIST_Data'))
from load_data_all import load_data_all
from process_dataset import process_dataset
mnist_ds_train, mnist_ds_test, train_len, test_len = load_data_all(datasets_dir) # 加载数据集
mnist_ds_train = process_dataset(mnist_ds_train, batch_size= 64, resize= 28) # 处理训练集,分批加载
mnist_ds_test = process_dataset(mnist_ds_test, batch_size= 32, resize= 28) # 处理测试集, 分批加载
训练集规模:60000,测试集规模:10000