解决keras模型保存问题(避免系统崩溃,模型训练无效),并且可以解决训练越来越慢的问题

修改之前

首先贴出来训练部分的代码:

    def train(self, train_generator, validation_generator, pre_model_path=None):
        '''
        :param train_generator: 训练集
        :param validation_generator: 测试集
        :param pre_model_path: 预训练模型,在之前模型上继续训练,目前仅支持h5模型
        '''

        # 在已有模型基础上继续训练
        if pre_model_path:
            self.model = load_model(pre_model_path)

        # 配置模型
        with open(pjoin(TXT_DIR, 'message.txt'), 'r') as f:
            _, TRAIN_SIZE, VAL_SIZE, _ = list(map(int, f.readline().split(',')))
        STEP_PER_EPOCH = TRAIN_SIZE // BATCH_SIZE + 1
        VALIDATION_STEPS = VAL_SIZE // BATCH_SIZE + 1
        optimizer = optimizers.RMSprop(lr=LEARNING_RATE)
        self.model.compile(loss='mse', optimizer=optimizer, metrics=['mae'])
        # 训练
        self.history = self.model.fit(train_generator, steps_per_epoch=STEP_PER_EPOCH, epochs=EPOCH,
                                      validation_data=validation_generator, validation_steps=VALIDATION_STEPS)
        # 保存
        self.save()
        return self.model

在fit之后调用了模型保存:

    def save(self):
        if os.path.exists(MODEL_SAVE_DIR):
            shutil.rmtree(MODEL_SAVE_DIR)
        os.mkdir(MODEL_SAVE_DIR)

        # 保存h5模型
        h5_path = pjoin(MODEL_SAVE_DIR, VERSION + '.h5')
        self.model.save(h5_path)
        print('成功保存h5模型:%s' % h5_path)

        # 保存pb模型
        # 定义输入输出
        model_signature = predict_signature_def(inputs={INPUT_KEY: self.model.input},
                                                outputs={OUTPUT_KEY: self.model.output})
        with tf.keras.backend.get_session() as sess:
            pb_path = pjoin(MODEL_SAVE_DIR, VERSION + '_pb')
            try:
                legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
                builder = tf.saved_model.builder.SavedModelBuilder(pb_path)
                builder.add_meta_graph_and_variables(sess, [MODEL_TAG],
                                                     clear_devices=True,
                                                     signature_def_map={SIGNATURE_DEF_KEY: model_signature},
                                                     legacy_init_op=legacy_init_op)
                builder.save()
                print('成功保存PB模型:%s' % pb_path)
            except Exception as e:
                print("Fail to export saved model, exception: {}".format(e))

这样训练过程中只能在fit之后才能对模型进行保存,也就是所有的epoch执行完之后才能保存我们训练好的参数,暂时还没有找到其他解决方法,就说一下我的解决办法。


修改之后

1、首先是在调用train函数的外围建立一个大循环,可直接写死:while True:

然后每次训练完epoch回合之后,清理一下内存,使用backend.clear_session(),不然的话会训练越来越慢,直到电脑卡死

这一步主要解决训练越来越慢的问题,相当于是重新建立sess

想要不中断训练代码,还要退出训练对模型进行保存,这些操作就必须在一个大的session中完成,需要将其中的session抽离出来,不然不能循环调用,因为执行完一次session,系统会自动关掉之前建立的session,所以我们现在外围建立一个session,作为参数传入函数中,按理来说sess也可以建立在while True外围,但是这样所有的操作都在一个sess里面完成,会耗费很多内存,尤其是长时间训练,所以将sess放在了while True里面,这样的话每次调用都会释放内存,然后继续训练,会解决耗费内存的问题,具体操作如下 :

import tensorflow.keras.backend as backend
while True:
    net.train(train_generator, val_generator)
    backend.clear_session()

2、然后是修改epoch参数,也就是迭代的步数,epoch的值的大小设置为想要保存的步数,也就是epoch迭代次数之后保存一次模型,然后在此模型的基础上继续进行训练,以此类推,每次训练好的模型覆盖上一次的模型:

将保存的模型的.h5文件路径加载至train中,以便于调用上次训练好的模型,加入判断模型是否存在的语句:os.path.exists(model_path)来判断是否存在训练好的模型,若是有那么就载入,在现有模型的基础上进行训练。

h5_path = pjoin(MODEL_SAVE_DIR, VERSION + '.h5'),代码如下:

    def train(self, train_generator, validation_generator, sess, pre_model_path=None):
        '''
        :param train_generator: 训练集
        :param validation_generator: 测试集
        :param pre_model_path: 预训练模型,在之前模型上继续训练,目前仅支持h5模型
        '''

        # **********************修改的部分*************************
        pre_model_path = pjoin(MODEL_SAVE_DIR, VERSION + '.h5')
        # ********************************************************

        # 在已有模型基础上继续训练
        if os.path.exists(pre_model_path):
            self.model = load_model(pre_model_path)

        # 配置模型
        with open(pjoin(TXT_DIR, 'message.txt'), 'r') as f:
            _, TRAIN_SIZE, VAL_SIZE, _ = list(map(int, f.readline().split(',')))
        STEP_PER_EPOCH = TRAIN_SIZE // BATCH_SIZE + 1
        VALIDATION_STEPS = VAL_SIZE // BATCH_SIZE + 1
        optimizer = optimizers.RMSprop(lr=LEARNING_RATE)
        self.model.compile(loss='mse', optimizer=optimizer, metrics=['mae'])
        # 训练
        self.history = self.model.fit(train_generator, steps_per_epoch=STEP_PER_EPOCH, epochs=EPOCH,
                                      validation_data=validation_generator, validation_steps=VALIDATION_STEPS)
        # 保存
        self.save(sess)
        return self.model

3、接下里是修改save函数,save中就不需要重新建立session了,直接调用外围的session即可可以直接把转pb注释掉,我们只保存.h5模型就行,到最后训练好之后,将pb保存放开跑一次就行

    def save(self, sess):
        if os.path.exists(MODEL_SAVE_DIR):
            shutil.rmtree(MODEL_SAVE_DIR)
        os.mkdir(MODEL_SAVE_DIR)

        # 保存h5模型
        h5_path = pjoin(MODEL_SAVE_DIR, VERSION + '.h5')
        self.model.save(h5_path)
        print('成功保存h5模型:%s' % h5_path)

        # 保存pb模型
        # 定义输入输出
        #model_signature = predict_signature_def(inputs={INPUT_KEY: self.model.input},
                                                outputs={OUTPUT_KEY: self.model.output})
        
        #pb_path = pjoin(MODEL_SAVE_DIR, VERSION + '_pb')
       # try:
           # legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
           # builder = tf.saved_model.builder.SavedModelBuilder(pb_path)
           # builder.add_meta_graph_and_variables(sess, [MODEL_TAG],
                                                     clear_devices=True,
                                                     signature_def_map={SIGNATURE_DEF_KEY: model_signature},
                                                     legacy_init_op=legacy_init_op)
           # builder.save()
           # print('成功保存PB模型:%s' % pb_path)
       # except Exception as e:
           # print("Fail to export saved model, exception: {}".format(e))

 

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值