tensorflow中checkpoint断点生成,保存,下载参数续训。

四段代码实现参数保存、下载、续训练

1.设置ckpt文件保存路径

其中保存有模型参数(特定文件类型


checkpoint_save_path = "checkpoint/mnist.ckpt"

2.判断是否有索引,如果保存了参数模型,会有相应索引文件

if os.path.exists(checkpoint_save_path + '.index'):
    print("---------------loading the model---------------")
    model.load_weights(checkpoint_save_path)

3.callback定义模型参数保存相关内容

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 # save_best_only=True,
                                                 save_weights_only=True
                                                 )

4.在训练中加入callback保存模型参数,创建checkpoint文件

model.fit(x_train, y_train, epochs=5, callbacks=[cp_callback])

5.源码

训练fashion_mnist全连接神经网络

import os
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt


# # callback类
# class myCallback(tf.keras.callbacks.Callback):
#     def on_epoch_end(self, epoch, logs=None):
# 
#         if logs.get('loss') < 0.4:
#             print("\n Loss is low so cancelling training!")
#             self.model.stop_training = True
# 
# 
# # 实例化类
# callback = myCallback()

(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
# print(x_train[0].shape)
x_train = x_train / 255.0
y_train = y_train / 255.0
# plt.imshow(x_train[0])
# print(x_train[0])
# print(y_train[0])
# plt.show()
model = keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])

# 断点续训
# 寻找路径,没有则生成
checkpoint_save_path = "checkpoint/mnist.ckpt"
# 判断是否存在索引
if os.path.exists(checkpoint_save_path + '.index'):
    print("---------------loading the model---------------")
    # 从路径中下载模型
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 # save_best_only=True,
                                                 save_weights_only=True
                                                 )

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(x_train, y_train, epochs=5, callbacks=[cp_callback])


  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

plus_left

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值