DL with python(9)——TensorFlow实现神经网络模型的断点续训

本文涉及到的是中国大学慕课《人工智能实践:Tensorflow笔记》第四讲第四节的内容,通过tensorflow实现神经网络模型的断点续训。

断点续训

神经网络模型的断点续训指的是将训练好的模型保存下来,并在之后的运行中直接调用的训练方法。

保存模型使用以下代码,将模型参数保存到callbacks中,并将callbacks添加到模型训练的history中。

callbacks = tf.keras.callbacks.ModelCheckpoint( (
filepath= = 路径文件名,
save_weights_only= = True/False,
save_best_only= = True/False) )
history = model.fit( ( callbacks=[cp_callback]

读取模型使用以下代码,基本原理是生成.ckpt文件的同时会生成对应的索引表(.index文件),通过判断索引表是否存在来决定是否导入保存的模型。

checkpoint_save_path = "./checkpoint/mnist.ckpt"           # 给出模型保存的路径以及文件名
if os.path.exists(checkpoint_save_path + '.index'):        # 通过索引表判断保存的模型是否存在
    print('-------------load the model-----------------')  # 是,则打印"导入模型"
    model.load_weights(checkpoint_save_path)               # 导入模型

代码实现

从神经网络搭建的六步法来看,与DL with python(6)——Keras实现手写数字识别(全连接网络)中直接导入mnist数据的代码相比,断点续训的代码在第一步、第四步和第五步有所改动。

# 第一步,导入相关模块,os模块用于判断文件是否存在
import tensorflow as tf
import os
# 第二步,导入数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 第三步,搭建网络结构
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
# 第四步,配置训练方法
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])
# 导入保存的模型,第二次运行才可以进行的操作
checkpoint_save_path = "./checkpoint/mnist.ckpt"           # 给出模型保存的路径以及文件名
if os.path.exists(checkpoint_save_path + '.index'):        # 通过索引表判断保存文件是否存在
    print('-------------load the model-----------------')  # 是,则打印"导入模型"
    model.load_weights(checkpoint_save_path)               # 导入模型
# 保存模型,第一次运行执行这一步操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,  # 模型保存路径
                                                 save_weights_only=True,         # 只保留模型参数
                                                 save_best_only=True)            # 只保留最优结果

# 第五步,执行训练,依次为训练集样本,训练集标签,小批量大小32,训练轮次5,测试集,训练集循环1轮次进行一次测试
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback]) # 最后添加callbacks回调选项,将前面保存的模型参数赋给后面的模型
# 第六步,打印网络结构和参数统计
model.summary()

第一次运行代码后,会得到一个checkpoint文件夹,其中含有四个文件,含有模型的相关信息。
在这里插入图片描述第一次运行,模型的最终表现

60000/60000 [==============================] - 5s 77us/sample - loss: 0.0454 - sparse_categorical_accuracy: 0.9863 - val_loss: 0.0865 - val_sparse_categorical_accuracy: 0.9752

然后第二次运行,在第一次的基础上进行训练,最终表现如下,可以看到各方面指标都有了很大的进步。

60000/60000 [==============================] - 4s 69us/sample - loss: 0.0180 - sparse_categorical_accuracy: 0.9943 - val_loss: 0.0785 - val_sparse_categorical_accuracy: 0.9782
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值