一个包含数据输入和预处理流程的使用数据集进行训练和测试的完整例子
import tensorflow as tf
train_files = tf.train.match_filenames_once("path/to/train-file-*")
test_files = tf.train.match_filenames_once("path/to/test-file-*")
# 定义parser方法从TFRecord中解析数据
def parser(record):
features = tf.parse_single_example(
record,
features = {
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'channels': tf.FixedLenFeature([], tf.int64)
}
)
# 从原始图像数据解析出像素矩阵, 并根据图像尺寸还原图像
decode_image = tf.decode_raw(features['image'], tf.uint8)
decode_image.set_shape([features['height'], features['width'], features['channels']])
label = features['l