Tensorflow2 二次训练和断点续训

Environment

  • Tensorflow2.0.0
  • python3.6

问题描述

每轮训练中以特定方式(固定频率、最高准确率或最低loss等)存储模型,停止训练后,基于已存储的模型进行二次训练。

以特定方式保存模型

callbacks = [
            tf.keras.callbacks.ModelCheckpoint(filepath=save_args['./saved_models/model_epoch{epoch}.h5'],
                                               # save_freq=save_args['2'], 
                                               # save_weights_only=True,
                                               #monitor='val_accuracy',
                                               #mode='max',
                                               #save_best_only=True
                                               ),
            tf.keras.callbacks.TensorBoard(log_dir="./logs/callback_test", update_freq='batch')]
  • save_freq=save_args['2']
    • 每2 epoches保存一次
  • monitor='val_accuracy',
    mode='max',
    save_best_only=True
    • 无论多少epoch,只保存val_accuracy最大的一次

加载模型保存好的模型

# 重新创建完全相同的模型,包括其权重和优化程序
self.model = tf.keras.models.load_model('./saved_models/callback_test/model_epoch1.h5')

# 显示网络结构
self.model.summary()

存储最后一轮epoch模型

#sava the model as pb/h5
#self.model.save(save_args['pb_save_path']+'/1120202001.h5')
#tf.saved_model.save(self.model, save_args['pb_save_path'])

完整代码

import os
import numpy as np
import tensorflow as tf
import random
import yaml
import warnings

warnings.filterwarnings('ignore')


class Trainer:

    def __init__(self, config_path, config_load_mode=
  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值