TensorFlow笔记:数据集导出

2018/5/17更新
有问题欢迎联系邮箱 lvjc2010@qq.com

公布一下自己的一段读入数据的操作。
我在做端到端的任务,所以实际上我的image和label都是array。
准备工作:需要按照下文方法先生成tfrecord文件以便tf.data.Dataset操作。

def read_and_decode(serialized_example):
    '''read and decode tfrecord file, generate (image, label) batches
    Args:
        tfrecords_file: the directory of tfrecord file
        batch_size: number of images in each batch
    Returns:
        image: 4D tensor - [batch_size, width, height, channel]
        label: 1D tensor - [batch_size]
    '''
    # make an input queue from the tfrecord file
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label_raw': tf.FixedLenFeature([], tf.string),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)
    image = tf.reshape(image, [480, 640, 3])
    image = tf.div(tf.to_float(image), 255.0)
    label = tf.decode_raw(img_features['label_raw'], tf.uint8)
    label = tf.reshape(label, [480, 640, 1])
    label = tf.div(tf.to_float(label), 255.0)
    return image, label

#根据自己需要来修改下面路径
TRAIN_TFREC = 'data/.../train.tfrecords'
VAL_TFREC = 'data/.../test.tfrecords'

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(read_and_decode,num_parallel_calls=4)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(1)
#dataset = dataset.repeat()
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
init_train = iterator.make_initializer(dataset)
images, labels = iterator.get_next()

#这里是GPU的基本操作
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
config = tf.ConfigProto(allow_soft_placement=True,gpu_options=gpu_options)
sess=tf.train.MonitoredTrainingSession(config=config)
# Compute for 2 epochs.
for epoch in range(2):
  sess.run(init_train, feed_dict={filenames: [TRAIN_TFREC]})
  step=0
  while True:
    step +=1
    try:
      label=sess.run(labels)
      if step%1000==0:
        print(label[0][225][90:100])
        print(epoch,step)
    except tf.errors.OutOfRangeError:
      break

更新至Tensorflow 1.4

I. 读输入数据

1. 如果数据库大小可以全部被内存读入 使用最简单的Numpy arrays格式:

1). 将npy文件转换成tf.Tensor
2). 使用Dataset.from_tensor_slices()
示例:

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"]
# Assume that each row of features corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
dataset = tf.data.Dataset.from_tensor_slices((features, labels))

请注意,上面的代码片段会将TensorFlow graph中的featureslabels数组作为

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值