python dataset用法_tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用

本文介绍了如何使用Python创建和使用tfrecord文件,包括将数据写入tfrecord,创建TFRecordDataset,以及如何从tfrecord文件解析数据进行训练。详细步骤包括数据预处理、创建TFRecordDataset、定义解析函数以及使用迭代器进行数据读取。
摘要由CSDN通过智能技术生成

1.创建tfrecord

tfrecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList、tf.train.Int64List、tf.train.FloatList写入tf.train.Feature,如下所示:

tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()])) #feature一般是多维数组,要先转为list

tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape))) #tostring函数后feature的形状信息会丢失,把shape也写入

tf.train.Feature(float_list=tf.train.FloatList(value=[label]))

通过上述操作,以dict的形式把要写入的数据汇总,并构建tf.train.Features,然后构建tf.train.Example,如下:

def get_tfrecords_example(feature, label):

tfrecords_features = {}

feat_shape = feature.shape

tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))

tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))

tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))

return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

把创建的tf.train.Example序列化下,便可通过tf.python_io.TFRecordWriter写入tfrecord文件,如下:

tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord') #创建tfrecord的writer,文件名为xxx

exmp = get_tfrecords_example(feats[inx], labels[inx]) #把数据写入Example

exmp_serial = exmp.SerializeToString() #Example序列化

tfrecord_wrt.write(exmp_serial) #写入tfrecord文件

tfrecord_wrt.close() #写完后关闭tfrecord的writer

代码汇总:

import tensorflow as tf

from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

mnist = read_data_sets("MNIST_data/", one_hot=True)

#把数据写入Example

def get_tfrecords_example(feature, label):

tfrecords_features = {}

feat_shape = feature.shape

tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))

tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))

tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))

return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

#把所有数据写入tfrecord文件

def make_tfrecord(data, outf_nm='mnist-train'):

feats, labels = data

outf_nm += '.tfrecord'

tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)

ndatas = len(labels)

for inx in range(ndatas):

exmp = get_tfrecords_example(feats[inx], labels[inx])

exmp_serial = exmp.SerializeToString()

tfrecord_wrt.write(exmp_serial)

tfrecord_wrt.close()

import random

nDatas = len(mnist.train.labels)

inx_lst = range(nDatas)

random.shuffle(inx_lst)

random.shuffle(inx_lst)

ntrains = int(0.85*nDatas)

# make training set

data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \

[mnist.train.labels[i] for i in inx_lst[:ntra

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值