TensorFlow 2.x series 03: resuming training from a checkpoint and stopping training automatically


This article belongs to a course series; it is the third set of review notes in that series.

(1)Build and train neural network models using TensorFlow 2.x
(2)Image classification
(3)Natural language processing(NLP)
(4)Time series, sequences and predictions

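Resuming training from a checkpoint

Resuming training means serializing the model and later reloading it to continue training. This is very useful in production, because it allows online, real-time training without losing earlier training progress. The weights are loaded with the model.load_weights function; TensorFlow model files are generally stored in the ckpt format.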

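Loading the model

import os

checkponit_save_path="./checkponit/fashion.ckpt"
# A ckpt checkpoint is saved with an accompanying .index file, so test for that file
if os.path.exists(checkponit_save_path+".index"):
    print("------------load the model -------------")
    model.load_weights(checkponit_save_path)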
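The check on checkponit_save_path+".index" works because a weights checkpoint in TensorFlow's native format is not a single file: the path acts as a prefix, and saving writes an index file (prefix plus ".index") alongside one or more data shards. Below is a minimal sketch of that existence check, written without TensorFlow so it can be run anywhere. The helper name checkpoint_exists and the temporary paths are made up for illustration, and the empty file only stands in for a real index file.

```python
import os
import tempfile


def checkpoint_exists(prefix):
    """Return True if an index file exists for this checkpoint prefix."""
    return os.path.exists(prefix + ".index")


# Demonstration with an empty stand-in file (not a real checkpoint).
with tempfile.TemporaryDirectory() as tmp_dir:
    prefix = os.path.join(tmp_dir, "fashion.ckpt")
    before = checkpoint_exists(prefix)    # nothing has been saved yet
    open(prefix + ".index", "w").close()  # stand-in for the index file
    after = checkpoint_exists(prefix)

print(before, after)  # False True
```

On the first run the check is False and training starts from scratch; on later runs it is True and load_weights restores the saved weights first.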

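Saving the model

tf.keras.callbacks.ModelCheckpoint(
    filepath="path/filename",    # where the checkpoint is written
    save_weights_only=True,      # True: weights only; False: the whole model
    save_best_only=True          # True: keep only the best result; False: save every time
)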
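With save_best_only=True the callback writes a checkpoint only when the monitored quantity (val_loss by default) improves on the best value seen so far. The following is a simplified, framework-free illustration of that rule, not Keras's actual implementation. The function name epochs_that_save and the loss values are invented for the example.

```python
def epochs_that_save(val_losses):
    """Return the epoch indices at which a 'save best only' rule would save."""
    best = float("inf")
    saved = []
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best:      # lower loss counts as an improvement
            best = val_loss
            saved.append(epoch)  # a real callback would write the checkpoint here
    return saved


print(epochs_that_save([0.50, 0.60, 0.40, 0.45]))  # [0, 2]
```

Epochs 1 and 3 are skipped because their validation loss is worse than the best value recorded so far, so the checkpoint on disk always holds the best weights.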

Using the callback during training

cp_callback=tf.keras.callbacks.ModelCheckpoint(filepath=checkponit_save_path,
    save_weights_only=True,
    save_best_only=True
)
# Pass the callback to fit so a checkpoint is written during training
history=model.fit(x,y,batch_size=32,epochs=5,validation_data=(x_test,y_test),validation_freq=1,callbacks=[cp_callback])

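Stopping training automatically

The idea is to subclass Callback, read the metrics reported during training, and set the model's stop_training attribute to end training.

class MyCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self,epoch,logs=None):
        # logs holds the metrics of the epoch that just finished
        loss=(logs or {}).get("loss")
        if loss is not None and loss<0.25:
            print("\n loss is low so cancel train")
            # fit checks this flag and ends training
            self.model.stop_training=True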
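How the flag takes effect: after each epoch the training loop calls on_epoch_end on its callbacks and then checks model.stop_training before starting the next epoch. The loop below is a toy stand-in for fit, assuming that behaviour. TinyModel, StopBelow, run_epochs and the loss values are hypothetical and exist only to show the control flow.

```python
class TinyModel:
    """Stand-in for a Keras model: it only carries the stop_training flag."""

    def __init__(self):
        self.stop_training = False


class StopBelow:
    """Mimics the callback above: request a stop once loss drops below a threshold."""

    def __init__(self, model, threshold):
        self.model = model
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss < self.threshold:
            self.model.stop_training = True


def run_epochs(model, callback, losses):
    """Toy loop with one made-up loss per epoch; returns the number of epochs run."""
    epochs_run = 0
    for epoch, loss in enumerate(losses):
        epochs_run += 1
        callback.on_epoch_end(epoch, {"loss": loss})
        if model.stop_training:  # checked after the callback has run
            break
    return epochs_run


model = TinyModel()
print(run_epochs(model, StopBelow(model, 0.25), [0.60, 0.40, 0.20, 0.10]))  # 3
```

The epoch in which the threshold is crossed still runs to completion; only the following epochs are skipped.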

Full example

import os

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
(train_x,train_y),(test_x,test_y)=tf.keras.datasets.fashion_mnist.load_data()

# Simple fully connected classifier for 28x28 Fashion-MNIST images
model=tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
model.add(tf.keras.layers.Dense(128,activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10,activation=tf.nn.softmax))
model.compile(optimizer="adam",loss=tf.keras.losses.sparse_categorical_crossentropy,metrics=["accuracy"])

# Resume from the previous checkpoint if one exists
checkponit_save_path="/tmp/checkponit/fashion.ckpt"
if os.path.exists(checkponit_save_path+".index"):
    print("------------load the model -------------")
    model.load_weights(checkponit_save_path)

# Stop training once the training loss falls below 0.3
class MyCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self,epoch,logs=None):
        loss=(logs or {}).get("loss")
        if loss is not None and loss<0.3:
            print("\n loss is low so cancel train")
            self.model.stop_training=True

# Keep only the weights with the best validation loss
model_save_callback=tf.keras.callbacks.ModelCheckpoint(filepath=checkponit_save_path,
    monitor = 'val_loss',
    save_weights_only=True,
    save_best_only=True
)

auto_stop_callback=MyCallback()

# Scale the training and the validation images to [0,1] in the same way
history=model.fit(train_x/255,train_y,batch_size=32,epochs=50,
    validation_data=(test_x/255,test_y),validation_freq=1,
    callbacks=[auto_stop_callback,model_save_callback])
model.evaluate(test_x/255,test_y)
# Predict the class of the first test image and display the image
t=np.array(test_x[0]/255).reshape(1,28,28)
print(np.argmax(model.predict(t)))
plt.imshow(test_x[0])
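
Output of one run. The validation images were not divided by 255 in this run, which explains the large val_loss:

2.3.0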
------------load the model -------------
Epoch 1/50
1853/1875 [============================>.] - ETA: 0s - loss: 0.2806 - accuracy: 0.8959
 loss is low so cancel train
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2801 - accuracy: 0.8960 - val_loss: 54.3266 - val_accuracy: 0.8685
313/313 [==============================] - 0s 1ms/step - loss: 0.3418 - accuracy: 0.8796
WARNING:tensorflow:7 out of the last 7 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f76704549d8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
9

[Image: plot of test_x[0] produced by plt.imshow]

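Summary

By implementing different callbacks you can control model training precisely, including resuming training from a checkpoint and stopping training automatically.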
