TensorFlow学习(十一):保存TFRecord文件

更新:

2018.3.26 对于每个例子添加了详细的解释,方便理解.

做过kaggle竞赛的应该很熟悉.csv文件了,.csv文件非常方便,但是通常读取的时候,是一次性读取到内存里面的.要是内存小的话,就要想其他的办法了,那就变得很麻烦了.
或者有时候,从硬盘上面直接读取图片啊什么的,因为图片的文件格式,存放位置各种各样等等一些因素,要是想在训练阶段直接这么使用的话,就更加麻烦了.所以,对于数据进行统一的管理是很有必要的.TFRecord就是对于输入数据做统一管理的格式.加上一些多线程的处理方式,使得在训练期间对于数据管理把控的效率和舒适度都好于暴力的方法.
小的任务什么方法差别不大,但是对于大的任务,使用统一格式管理的好处就非常显著了.因此,TFRecord的使用方法很有必要熟悉.
这节并不准备将TFRcord文件的读取,只讲怎么保存为TFRecord文件,读取还涉及到其他的操作,所以之后会和其他的操作一起讲.
本文的顺序是先讲保存TFRecord文件的时候常见的API,然后再举例子在实际中怎么使用这些API.

一.重要API

Ⅰ tf.python_io.TFRecordWriter

把记录写入到TFRecords文件的类.

__init__(path,options=None)

作用:创建一个TFRecordWriter对象,这个对象就负责写记录到指定的文件中去了.
参数:
path: TFRecords 文件路径
options: (可选) TFRecordOptions对象

close()

作用:关闭对象.

write(record)

作用:把字符串形式的记录写到文件中去.
参数:
record: 字符串,待写入的记录

Ⅱ.tf.train.Example

这个类是非常重要的,TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的.在这里,不会非常详细的讲这个类,但是会给出常见的使用方法和一些重要函数的解释.其他的细节可以参考文档.
class tf.train.Example

属性:

features Magic attribute generated for “features” proto field.

函数:

__init__(**kwargs)

这个函数是初始化函数,会生成一个Example对象,一般我们使用的时候,是传入一个tf.train.Features对象进去.

SerializeToString()

作用:把example序列化为一个字符串,因为在写入到TFRcorde的时候,write方法的参数是字符串的.

Ⅲ.tf.train.Features

class tf.train.Features

属性:

feature

函数:

__init__(**kwargs)
作用:初始化Features对象,一般我们是传入一个字典,字典的键是一个字符串,表示名字,字典的值是一个tf.train.Feature对象.

Ⅳ.tf.train.Feature

class tf.train.Feature

属性:

bytes_list
float_list
int64_list

函数:
__init__(**kwargs)

作用:构造一个Feature对象,一般使用的时候,传入 tf.train.Int64List, tf.train.BytesList, tf.train.FloatList对象.

Ⅴ.tf.train.Int64List, tf.train.BytesList, tf.train.FloatList

使用的时候,一般传入一个具体的值,比如学习任务中的标签就可以传进value=tf.train.Int64List,而图片就可以先转为字符串的格式之后,传入value=tf.train.BytesList中.

说明:

以上的函数的API都可以对照着例子的代码来熟悉,看在例子中是怎么使用的这些对象.

二.几个例子

这里直接写两个非常常见的例子(有多常见你看一下就知道了),来体会一下TFRecord可能的用法,在这里可以并不用知道程序运行每一行的具体含义,但是要大概知道是怎么回事.更加详细的对这两个例子的讲解会在后面的内容.这里的目的就是先感受一下.
在这个例子中使用的.csv文件就是kaggle竞赛里面MNIST手写体识别的train.csv文件,可以在官网上面下载实验一下.

Ⅰ.csv文件转化为TFRecord格式

import tensorflow as tf
import numpy as np
import pandas as pd

#---------------------loading data from .csv file------------------------------#
#load .csv file
train_frame=pd.read_csv(filepath_or_buffer="../Mnist/train.csv")
#print(train_frame.head())
train_labels_frame=train_frame.pop(item="label")
#print(train_labels_frame.shape)

train_values=train_frame.values
train_size=train_values.shape[0]
train_labels_values=train_labels_frame.values


#print(train_values[0].shape)
#print(train_values[0].dtype)
#print(train_labels_values[0])

#------------------------------create TFRecord file----------------------------#
writer=tf.python_io.TFRecordWriter(path="train.tfrecords")
for i in range(train_size):
    image_raw=train_values[i].tostring()
    example=tf.train.Example(
        features=tf.train.Features(
            feature={
                "image_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[train_labels_values[i]]))
            }
        )
    )
    writer.write(record=example.SerializeToString())

writer.close()

Ⅱ.硬盘上图片转成TFRecord格式

这里使用的例子还是来自于kaggle,是在kaggle中的CIFAR-10识别比赛中的数据集.
CIFAR-10 - Object Recognition in Images
数据集主要是两个压缩包,train数据集和test数据集,解压之后就是两个文件夹,里面放了一大推图片.
这里写图片描述
要是不用TFRecord的话,直接处理是比较低效的.那么来看看用TFRecord怎么处理.

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
import pandas as pd
import os


#get the amount of files in folder
def sizeOfFolder(folder_path):
    fileNameList = os.listdir(path=folder_path)
    size = 0
    for fileName in fileNameList:
        if (os.path.isfile(path=os.path.join(folder_path, fileName))):
            size += 1
    return size


#path(path of folder)
#if isTrain=True,labels can't be None
def pics_to_TFRecord(folder_path,labels=None,isTrain=False):
    size=sizeOfFolder(folder_path=folder_path)

    #train set
    if isTrain:
        if labels is None:
            print("labels can't be None!!!")
            return None
        if labels.shape[0]!=size:
            print("something wrong with shape!!!")
            return None
        writer=tf.python_io.TFRecordWriter("../data/TFRecords/train.tfrecords")
        for i in range(1,size+1):
            print("----------processing the ",i,"\'th image----------")
            filename=folder_path+str(i)+".png"
            img=mpimg.imread(fname=filename)
            width=img.shape[0]
            print(width)
            #trans to string
            img_raw=img.tostring()

            example=tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "img_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                        "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i-1]])),
                        "width":tf.train.Feature(int64_list=tf.train.Int64List(value=[width]))
                    }
                )

            )
            writer.write(record=example.SerializeToString())
        writer.close()

    #test set
    else:
        writer = tf.python_io.TFRecordWriter("../data/TFRecords/test.tfrecords")
        for i in range(1, size + 1):
            print("----------processing the ", i, "\'th image----------")
            filename = folder_path + str(i) + ".png"
            img = mpimg.imread(fname=filename)
            width = img.shape[0]
            print(width)
            # trans to string
            img_raw = img.tostring()

            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                        "width": tf.train.Feature(int64_list=tf.train.Int64List(value=[width]))
                    }
                )

            )
            writer.write(record=example.SerializeToString())
        writer.close()

train_labels_frame=pd.read_csv("../data/trainLabels.csv")
train_labels_frame_dummy=pd.get_dummies(data=train_labels_frame)
#print(train_labels_frame_dummy)
train_labels_frame_dummy.pop(item="id")
#print(train_labels_frame_dummy)
train_labels_values_dummy=train_labels_frame_dummy.values
#print(train_labels_values_dummy)
train_labels_values=np.argmax(train_labels_values_dummy,axis=1)
#print(train_labels_values)


#write train record
pics_to_TFRecord(folder_path="../data/train/",labels=train_labels_values,isTrain=True)

#write test record
pics_to_TFRecord(folder_path="../data/test/")

上面的代码运行之后,就会在设置的文件夹下面得到两个.tfrecords文件.
其中训练集的有500多M,测试集达到了惊人的3G多.但是为了后面训练的方便,这是值得的.

Ⅲ.存为多个TFRecord文件

通过前面两个方法,我们知道可以把你想要的文件或者记录通过或多或少的方法转为TFRecord格式.
那么数据量很大的时候,你会发现,单个TFRecord文件是非常非常大的,这对于硬盘是不小的负担,所以,可以通过存储多个TFRecord文件来解决问题.

import tensorflow as tf

num_files=3
num_instance=100

for i in range(num_files):
    print("write ",i," file")
    fileName=("test.tfrecords-%.5d-of-%.5d" % (i,num_files))
    writer=tf.python_io.TFRecordWriter(path=fileName)

    for j in range(num_instance):
        print("write ",j," record")
        example=tf.train.Example(
            features=tf.train.Features(
                    feature={
                        "i":tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
                        "j":tf.train.Feature(int64_list=tf.train.Int64List(value=[j]))
                    }
            )

        )
        writer.write(record=example.SerializeToString())
    writer.close()
  • 19
    点赞
  • 58
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值