Tensorflow保存和恢复模型

在tensorflow中保存和恢复模型主要通过tf.train.Saver(),具体如下:

  • 保存模型
  1. saver = tf.train.Saver()获得一个文件句柄,将训练中的某一个快照状态保存到文件中去
  2. saver.save(sess, os.path.join(model_dir, ‘ckp-%05d’%(i+1))),将训练好的模型保存到文件中
  3. 参数1:sess,session会话,参数2:模型保存路径
  • 恢复模型
  1. saver.restore(sess, model_path)从文件中恢复模型
  2. 参数1:sess,session会话;参数2:model_path需要被恢复模型

示例代码:

with tf.Session() as sess:
    sess.run( init ) # 注意: 这一步必须要有!!

    # 打开一个writer,向writer中写数据
    train_writer = tf.summary.FileWriter(train_log_dir, sess.graph) # 参数2:显示计算图
    test_writer = tf.summary.FileWriter(test_log_dir)
    
    fixed_test_batch_data, fixed_test_batch_labels = test_data.next_batch(batch_size)

    if os.path.exists(model_path + '.index'):
        saver.restore(sess, model_path)
        print('model restored from %s' % model_path)
    else:
        print('model %s dose not exist' % model_path)

    # 开始训练
    for i in range( train_steps ):
        # 得到batch
        batch_data, batch_labels = train_data.next_batch( batch_size )
        
        eval_ops = [loss, accuracy, train_op]
        should_output_summary = ((i+1) % output_summary_every_steps == 0)

        if should_output_summary:
            eval_ops.append(merged_summary)

        # 获得 损失值, 准确率
        eval_val_results = sess.run( eval_ops, feed_dict={x:batch_data, y:batch_labels} )
        loss_val, acc_val = eval_val_results[0:2]

        if should_output_summary:
            train_summary_str = eval_val_results[-1]
            train_writer.add_summary(train_summary_str, i+1)
            test_summary_str = sess.run([merged_summary_test],
                                        feed_dict = {x: fixed_test_batch_data,y: fixed_test_batch_labels} )[0]
            test_writer.add_summary(test_summary_str, i+1)

        # 每 500 次 输出一条信息
        if ( i+1 ) % 500 == 0:
            print('[Train] Step: %d, loss: %4.5f, acc: %4.5f' % ( i+1, loss_val, acc_val ))
        # 每 5000 次 进行一次 测试
        if ( i+1 ) % 5000 == 0:
            # 获取数据集,但不随机
            test_data = CifarData( test_filename, False )
            all_test_acc_val = []
            for j in range( test_steps ):
                test_batch_data, test_batch_labels = test_data.next_batch( batch_size )
                test_acc_val = sess.run( [accuracy], feed_dict={ x:test_batch_data, y:test_batch_labels } )
                all_test_acc_val.append( test_acc_val )
            test_acc = np.mean( all_test_acc_val )

            print('[Test ] Step: %d, acc: %4.5f' % ( (i+1), test_acc ))

        if (i+1) % output_model_every_steps == 0:
            # saver 机制,保存最近的5个模型
            saver.save(sess, os.path.join(model_dir, 'ckp-%05d'%(i+1)))
            print('model saved to ckp-%05d' % (i+1))

参考:

Tensorflow保存和恢复模型

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值