12_TFRecord数据打包

4.TensorFlow中的数据读取方式

4.1直接将数据加载到内存中

在处理小规模数据,比如mnist的时候可以将数据直接存入内存。但是对于较大的数据集,这是很麻烦的。因此好的方法就是将数据存放在磁盘上,在需要的时候再加载进内存,比如训练时每次加载一个batchsize的数据。
但是这样的方法也是存在问题的,如下图,因为计算和加载数据是不同步的,若计算设备完成 计算时数据还未完成加载,这个时候就会造成计算训练处的延迟,导致资源浪费。
在这里插入图片描述
因此更好地办法是预取数据并且使用独立的线程进行加载和训练。

4.2TFRecord

TensorFlow提供了TFRecord:高效的TensorFlow文件格式来统一存储数据,TFRecord就是一个简单的包含序列化输入数据的二进制文件,序列化是基于协议缓冲区的,这样可以快速地实现数据的复制,移动,读取和存储。
如下图,TensorFlow数据读取机制(数据->文件名队列->内存队列->计算设备读取数据):
在这里插入图片描述
首先要理解EPOCH的概念,一个epoch就是把数据集里的所有图片都学习了一遍,比如一个数据集有100张图片,训练时batchsize为10,那么一个epoch就是把这100张图片都计算了一遍,而一个epoch就包含了10(总量100除以每批的数据量10)个batch的数据。文件名队列就构成了数据的队列,设定shuffle参数可以让数据乱序。
比如原来有A,B,C三张图片,若设定训练3个epoch,shuffle=false,则最终文件名队列中的结果就是[A,B,C,A,B,C,A,B,C]。
比如原来有A,B,C三张图片,若设定训练3个epoch,shuffle=true,则最终文件名队列中的结果可能是[A,C,B,B,C,A,C,A,B]。每个epoch内的ABC顺序不一定。

4.3数据操作实践cifar_10数据集

4.3.1下载并解析cifar10

首先下载cifar10数据集,由于官网下载可能很慢,我将数据集放在了百度网盘里(提取码o5yu):[Download cifar10]

我的目录结构:

.
├── cifar10_bin_to_jpg.py
├── cifar10_jpg_to_tfrecord.py
├── test_data
│   └── data_batch_6
└── train_data
    ├── batches.meta
    ├── cifar10_jpgs
    ├── data_batch_1
    ├── data_batch_2
    ├── data_batch_3
    ├── data_batch_4
    └── data_batch_5

首先解析下载下来的二进制文件,将里面的数据解析并存放到每个数据到train_data/cifar10_jpgs目录中对应label的目录下。
cifar10_bin_to_jpg.py:

import os
import pickle
import numpy as np
import cv2

classes= ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

data  = []
labels = []
for i in range(1,6):
	file_name = "train_data/data_batch_"+str(i)
	result = unpickle(file_name)
	data += list(result[b"data"])
	labels += list(result[b"labels"])
	print(file_name+" loaded.")
	
imgs = np.reshape(data, [-1, 3, 32, 32])

for i in range(imgs.shape[0]):
    im_data = imgs[i, ...]
    im_data = np.transpose(im_data, [1, 2, 0])
    im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)
	
    f = "{}/{}".format("train_data/cifar10_jpgs", classes[labels[i]])

    if not os.path.exists(f):
        os.mkdir(f)

    cv2.imwrite("{}/{}.jpg".format(f, str(i)), im_data)

print("All Done.")

完成后可以在train_data/cifar10_jpgs目录下看到十个类别的数据里面的图片数据。
在这里插入图片描述

4.3.将解析后的cifar10打包成tfrecord文件。

涉及到的tf api:
数据读取:tf.train.string_input_producer
数据解析:tf.TFRecordReader tf.parse_single_example
数据写入:tf.python_io.TFRecordWriter

数据写入:
writer = tf.python_io.TFRecordWriter()
example = tf.train.example() Feature{图像数据,图像label等。。}
//序列化后写入
writer.write(example.SerializeToString())
writer.close()

import tensorflow as tf
import glob
import cv2
import numpy as np
classes = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

idx = 0
im_data = []
im_labels = []
for class_ in classes:
	path = "train_data/cifar10_jpgs/" + class_
	im_list = glob.glob(path + "/*")
	im_label = [idx for i in  range(im_list.__len__())]
	idx += 1
	im_data += im_list
	im_labels += im_label

#实例化一个TFRecordWriter对象
tfrecord_file = "train_data/train_data.tfrecord"
writer = tf.python_io.TFRecordWriter(tfrecord_file)

index = [i for i in range(im_data.__len__())]
np.random.shuffle(index)

print("strat make")
print("total:",im_data.__len__())

for i in range(im_data.__len__()):
	if i%5000 is 0 :
		print(i)
	im_d = im_data[index[i]]
	im_l = im_labels[index[i]]
	data = cv2.imread(im_d)
	#tf.train.Example是一个存放数据的结构,一个example包含一个Features对象
	ex = tf.train.Example(
		features = tf.train.Features(
            feature = {
                "image":tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[data.tobytes()])),
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(
                        value=[im_l])),
            }
        )
    )
	writer.write(ex.SerializeToString())

writer.close()
print("make tfrecord_file done.")
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值