解决tensorflow的tfrecords文件在训练时所占内存过大的问题

tfrecords文件在训练时所占内存过大的原因

项目中需要把tiff格式的图像先制作成tensorflow的tfrecords格式,然后在训练时读取tfrecords文件为tf.data.Dataset 进行训练。而在读取tfrecords文件生成Dataset时一般都需要设置shuffle的大小,目的是在训练过程中,每个step都随机从Dataset中随机的获取一个batch的数据进行训练,能较好的防止过拟合的问题。

dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

首先需要说一下shuffle的原理,shuffle操作后,在训练开始前,会从dataset的第一个元素开始(一个样本或者一个batch)依次取buffer_size个元素放入了缓存中,然后训练时随机从缓存的元素中取一个用于第一个step的训练,并取dataset的下一个元素放入缓存替换到取出来训练的元素的位置。之后每个的训练step重复以上的操作。
更为具体的对dataset的操作性能的讲解可以参考这篇博客:TensorFlow 高性能数据输入管道设计指南
所以可以想象一下,当我们的buffer_size设置的比较小时,这个时候就不能实现完全的随机性抽取训练样本,这样会影响训练的效果。比如我们有一个二分类的任务对猫和狗进行分类,如果buffer_size设置较小,那么前面的epoch的训练模型只能看到猫的图像,看不到狗的图像,影响训练效果。
这时往往需要设置buffer_size=训练元素的数目 来实现均匀的随机化。而当我们的训练集很大的时候,设置buffer_size=训练元素的数目会将整个训练集放入缓存中,会占用很大的内存,甚至会内存溢出。

实现均匀随机化又要节省大量内存的解决方案

我们可以使用下面的方法实现训练样本的均匀随机化:
1.首先将训练样本随机的制作成很多个小的tfrecords文件
2.使用tensorflow的interleave函数读取多个tfrecordds文件前将tfrecordds文件先随机排列
3.设置shuffle的buffer_size=小的tfrecords文件的元素数目
由于buffer_size设置的较小,所以训练占的内存就很小,这样又实现了均匀随机化,减缓了过拟合的问题。

具体的实现代码

1)将tiff格式的图像随机的制作成很多个小的tfrecords文件

import os
import tensorflow as tf
from PIL import Image
import glob
from skimage import io
import numpy as np
import time
import random

def img2TFRecord(img_dir,label_dir,tfrecords_dir,num_examples_per_tfrecords_txt_file,num_examples_per_tfrecords):
    """
    将图像和对应的标签制作成TFRecord文件,之后使用tf.data.TFRecordDataset读取TFRecord文件制作成dataset
    :param img_dir: 图像的文件夹
    :param label_dir: 标签的文件夹
    :param tfrecords_dir: 制作成的tfrecords文件夹
    :param num_examples_per_tfrecords_txt_file: 每个tfrecords文件包含的样本数的txt文件
    :param num_examples_per_tfrecords: 每个tfrecords文件包含的样本数
    :return:
    """
    t0 = time.time()
    img_filenames = glob.glob(img_dir+'/*.tif')

    # 对列表进行随机打乱
    random.shuffle(img_filenames)

    print(img_filenames[:10])

    # 把样本数目写入在一个txt中,之后需要用
    with open(num_examples_per_tfrecords_txt_file,
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值