# 7.2 TensorFlow笔记(基础篇): 生成TFRecords文件

167人阅读 评论(0)

## 前言

1. 把样本数据写入TFRecords二进制文件
2. 从队列中读取

TFRecords二进制文件,能够更好的利用内存,更方便的移动和复制,并且不需要单独的标记文件

## CODE

### 源码与解析

import tensorflow as tf
import os
import argparse
import sys

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

#1.0 生成TFRecords 文件
from tensorflow.contrib.learn.python.learn.datasets import mnist

FLAGS = None

# 编码函数如下:
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert_to(data_set, name):
"""Converts a dataset to tfrecords."""
images = data_set.images
labels = data_set.labels
num_examples = data_set.num_examples

if images.shape[0] != num_examples:
raise ValueError('Images size %d does not match label size %d.' %
(images.shape[0], num_examples))
rows = images.shape[1] # 28
cols = images.shape[2] # 28
depth = images.shape[3] # 1. 是黑白图像,所以是单通道

filename = os.path.join(FLAGS.directory, name + '.tfrecords')
print('Writing', filename)
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
image_raw = images[index].tostring()

# 写入协议缓存区,height,width,depth,label编码成int64类型,image_raw 编码成二进制
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'label': _int64_feature(int(labels[index])),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString()) # 序列化为字符串
writer.close()

def main(unused_argv):
# Get the data.
dtype=tf.uint8,
reshape=False,
validation_size=FLAGS.validation_size)

# Convert to Examples and write the result to TFRecords.
convert_to(data_sets.train, 'train')
convert_to(data_sets.validation, 'validation')
convert_to(data_sets.test, 'test')

if __name__ == '__main__':
parser = argparse.ArgumentParser()
'--directory',
type=str,
default='MNIST_data/',
)
'--validation_size',
type=int,
default=5000,
help="""\
Number of examples to separate from the training data for the validation
set.\
"""
)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)


### 运行结果

#### 打印输出

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Writing MNIST_data/train.tfrecords
Writing MNIST_data/validation.tfrecords
Writing MNIST_data/test.tfrecords

## 相关

1. argparse是python用于解析命令行参数和选项的标准模块，用于代替已经过时的optparse模块。argparse模块的作用是用于解析命令行参数,详情请参见这里:python中的argparse模块:http://blog.csdn.net/fontthrone/article/details/76735591
2. 把样本数据写入TFRecords二进制文件 : http://blog.csdn.net/fontthrone/article/details/76727412
3. TensorFlow笔记(基础篇):加载数据之预加载数据与填充数据:http://blog.csdn.net/fontthrone/article/details/76727466
4. TensorFlow笔记(基础篇):加载数据之从队列中读取:http://blog.csdn.net/fontthrone/article/details/76728083
1
0

* 以上用户言论只代表其个人观点，不代表CSDN网站的观点或立场
个人资料
• 访问：223228次
• 积分：2535
• 等级：
• 排名：第14515名
• 原创：71篇
• 转载：17篇
• 译文：2篇
• 评论：20条
博客专栏
 Python大战人工智能 文章：13篇 阅读：3160
 剑指数据科学 文章：9篇 阅读：27909
 剑指汉语自然语言处理 文章：14篇 阅读：49307