我们先搞一个自定义训练
首先导入库
读取数据集后归一化然后创建新的数据集并设置batch
设置优化器与损失函数
设置四个指标
设置损失计算
设置训练与测试步骤
定义整体训练步骤并训练
目录
1 保存检查点
首先我们创建一个文件夹用于保存检查点文件
- 一定要保存到另一个文件夹中,要不后面的调用会有问题
之后我们添加检查点,使用tf.train.Checkpoint()添加检查点对象,参数为优化器与模型
之后使用save()保存检查点,参数为保存路径
我们一共训练5次,使用这个方法我们每一个epoch都可以获得一个检查点
- 我们同样在循环中可以使用model.save()或model.save_weights()
- 上面只有两次是因为我训练两次后就停掉了
2 使用检查点
我们使用tf.train.latest_checkpoint获取目录中最新的检查点,之后使用checkpoint.restore()调用已有的检查点
- 上面两个方法如果在指定文件夹中没有保存点也不会报错
我们在训练前使用模型测试一下,如果没读进来检查点效果是非常不好的,如果读进来效果就还可以
我们使用预测结果与标签比较一下,之后再计算正确率
可以看到我们第一个batch的正确率达到0.875,这样就说明我们的检查点是已经读进来的