制作自己的TFRecords数据集——Tensorflow

前言

       最近一直在研究深度学习,主要是针对卷积神经网络(CNN),接触过的数据集也有了几个,最经典的就是MNIST, CIFAR10/100, NOTMNIST, CATS_VS_DOGS 这几种,由于这几种是在深度学习入门中最被广泛应用的,所以很多深度学习框架 Tensorflow、keras和pytorch都有针对这些数据集专用的数据导入的函数封装,但是一般情况下我们的数据集并不是这种很规范的形式,那么如何把自己的数据集转换成这些框架能够使用的数据形式至关重要,首先是针对最流行的Tensorflow。
        tensorflow官网给出了三种读取数据的方法:
		对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练。
		但是,如果数据量较大,这样的方法就不适用了,因为太耗内存,另外io运算较耗时,所以这时最好使用tensorflow提供的队列queue,也就是第二种方法 从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这儿我介绍一种比较通用,高效的读取方法(官网介绍的少),即使用tensorflow内定标准格式——TFRecord.
		那下面就让我们了解一下什么是TFRecord:
		TFRecords其实是一种二进制文件,但是它能更好的利用内存,并且不需要单独的标签文件,总而言之,这样的文件格式好处多多,所以让我们用起来吧。这里注意:TFRecord会根据你输入的文件的类,自动给每一类打上同样的标签。
		下面以cifar10数据集为例

TFRecord数据的制作

cifar10数据集的下载

def download_and_uncompress_tarball(tarball_url, dataset_dir):
  """Downloads the `tarball_url` and uncompresses it locally.
  Args:
    tarball_url: The URL of a tarball file.
    dataset_dir: The directory where the temporary files are stored.
  """
  filename = tarball_url.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)

  def _progress(count, block_size, total_size):
    sys.stdout.write('\r>> Downloading %s %.1f%%' % (
        filename, float(count * block_size) / float(total_size) * 100.0))
    sys.stdout.flush()
  filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
  print()
  statinfo = os.stat(filepath)
  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_DIR = 'data'
download_and_uncompress_tarball(DATA_URL, DATA_DIR)
	上述代码为在线下载cifar10数据集代码程序,当然我们也可以网站上直接下载cifar10数据集,为二进制格式数据集。下载好的数据集格式如下:

在这里插入图片描述

数据解析和转换

		cifar10数据集自带的解析函数,代码如下,故可使用自带的解析函数对数据集进行解析。
def unpickle(file):
    import pickle
    with open(file,"rb") as fo:
        dict=pickle.load(fo,encoding="bytes")
    return dict
	对上述数据集进行转换处理,转换成如下截图所示的数据格式。所有训练样本在train文件夹下,每个类别的训练数据又都在其类别文件夹下。

在这里插入图片描述
转换的代码如下:

folders="/home/zhouzhou/deep_learning/original_code/tf_read_write/TF_read_write/data_manager/data/cifar-10-batches-py"
trfiles=glob.glob(folders+"/data_batch*")
print(trfiles)
data  = []
labels = []
for file in trfiles:
    dt = unpickle(file)
    data += list(dt[b"data"])
    labels += dt[b"labels"]
imgs = np.reshape(data, [-1,3,32,32])
for i in range(imgs.shape[0]):
    im_data = imgs[i, ...]
    im_data = np.transpose(im_data, [1, 2, 0])
    im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)
    f = "{}/{}".format("data/image/train", classification[labels[i]])
    if not os.path.exists(f):
        os.mkdir(f)
    cv2.imwrite("{}/{}.jpg".format(f, str(i)), im_data)

TFRecord数据的打包制作

	如上面步骤完成了上述数据格式的数据准备工作,下面进行TFRecord数据格式制作。
idx=0
im_data =[]
im_labels =[]
for path in classification:
    folder ="data/image/test/"+path
    im_list =glob.glob(folder+"/*")
    im_label =[idx for i in range(im_list.__len__())]
    idx +=1
    im_data +=im_list
    im_labels +=im_label
tf_file ="data/train.tfrecord"
writer =tf.python_io.TFRecordWriter(tf_file)
index =[i for i in range(im_data.__len__())]
np.random.shuffle(index)
for i in range(im_data.__len__()):
    im_d =im_data[index[i]]
    im_l =im_labels[index[i]]
    data =cv2.imread(im_d)
    ex =tf.train.Example(
        features =tf.train.Features(
            feature ={
                "image":tf.train.Feature(
                    bytes_list =tf.train.BytesList(value=[data.tobytes()])),
                "label":tf.train.Feature(
                    int64_list =tf.train.Int64List(value=[im_l]))
            }
        )
    )
    writer.write(ex.SerializeToString())
writer.close()

小结

	如上步骤,完成了TFrecord数据集制作。

tf.train.Features函数详解

引用

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值