本文主要叙述MindSpore下model的各参数设置、训练、验证过程。
1、需要输入预处理好的数据集,可参考文章:MindSpore数据预处理,并生成batch装载数据中的函数,能够生成batch数据集。
2、本文中代码使用了文章:MindSpore自定义在train的过程中实时验证的回调函数中的evalcallback类。
3、本文中代码使用了文章:MindSpore自定义网络模型中的网络模型lenet5b。
def process(my_net, net_name, batch__size=32):
# 获取当前文件夹路径,再根据不同过程添加对应的数据集存放路径
current_path = os.getcwd()
train_data_path = os.path.join(current_path, 'data\\10-batches-bin')
test_data_path = os.path.join(current_path, 'data\\10-verify-bin')
# 设置运行过程标志,与process_dataset()函数有关联
status = "train"
print("=============== 来到训练过程 ==============")
# 生成训练数据集
train_ds = ds.Cifar10Dataset(train_data_path)
ds_train = my_dataset(train_ds, batch__size=batch__size, status=status)
# 构建网络
network = my_net(10, 3)
# 返回当前设备
device_target = mindspore.context.get_context('device_target')
print("--------------- 当前设备为: {} ---------------".format(device_target))
# 确定图模型是否下沉到芯片上
dataset_sink_mode = True if device_target in ['Ascend', 'GPU'] else False
# 设置模型的设备与图的模式
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
# 使用交叉熵函数作为损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
# 优化器为Adam
net_opt = nn.Adam(params=network.trainable_params(), learning_rate=0.001)
# 建立可训练模型
model = Model(network=network, loss_fn=net_loss, optimizer=net_opt, metrics={"Accuracy": Accuracy()})
# 设置CheckpointConfig,callback函数。save_checkpoint_steps=训练总数/batch_size
ckp_config = CheckpointConfig(save_checkpoint_steps=1562, keep_checkpoint_max=10)
ckp_cb = ModelCheckpoint(prefix="checkpoint_{}_verified".format(net_name), directory='results', config=ckp_config)
# 监控每个epoch训练的时间
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
# 监控每次打印时的loss值
loss_cb = LossMonitor(per_print_times=1)
# 使用自定义的EvalCallBack, 记录每次测试结果并打印
per_eval = {"epoch": [], "acc": []}
eval_cb = EvalCallBack(model, ds_train, 1, per_eval, dataset_sink_mode)
print("============== 开始训练 ==============")
model.train(25, ds_train, callbacks=[ckp_cb, loss_cb, eval_cb, time_cb], dataset_sink_mode=dataset_sink_mode)
print("============== 训练结束 ==============")
# -----------------------------------------------------------------------------------------------------------------------
# 设置运行过程标志,与process_dataset()函数有关联
status = "test"
print("============== 来到测试过程 ==============")
# 生成测试数据集
test_ds = ds.Cifar10Dataset(test_data_path)
ds_eval = my_dataset(test_ds, batch__size=batch__size, status=status)
print("============== 开始测试 ==============")
res = model.eval(ds_eval, dataset_sink_mode=dataset_sink_mode)
print("============== 测试结束 ==============")
# 评估测试集
print('测试结果:', res)
if __name__ == "__main__":
# print("------------------ 使用网络LeNet5a -----------------")
# process(my_net=LeNet5a, net_name="LeNet5a", batch__size=32)
print("------------------ 使用网络LeNet5b -----------------")
process(my_net=LeNet5b, net_name="LeNet5b", batch__size=32)