[Tensorflow] TFRecord介绍和教程

官方tutorials Using TFRecords and tf.Example

我的理解:
在图像转化为tfrecord时,是将图像转换成一系列二进制记录(a sequence of binary records),用{“string”: value}来存储高,宽,channel以及label。
Protocol buffers 是一种跨平台、跨语言的高效的序列化数据结构。

tf.Example message (or protobuf)用来存储.tfrecord的编码信息。利用它我们可以将这个信息还原。

很好的教程:

TensorFlow TFRecord数据集的生成与显示

暂时没有解决的问题(可能会出现图像size不匹配)

code

利用TF-slim中flowers数据集的download_and_convert_flowers.py,改成适合自己数据集的代码。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os,glob
import random
import sys

import tensorflow as tf

import dataset_utils

_NUM_SHARDS = 5

class ImageReader(object):
  """Helper that decodes encoded image bytes with a reusable TensorFlow op.

  The placeholder/decode-op pair is built once, in the graph that is the
  default when the reader is constructed. Each `decode_jpeg` call then only
  needs a `sess.run` on a session bound to that graph.
  """

  def __init__(self):
    # Placeholder fed with the raw encoded bytes of one image.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    # Decode op that always produces 3 channels (RGB).
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    """Returns `(height, width)` of the encoded image `image_data`.

    Args:
      sess: Session whose graph contains this reader's ops.
      image_data: Encoded image bytes.

    Raises:
      ValueError: see `decode_jpeg`.
    """
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    """Decodes `image_data` and returns the decoded image.

    Args:
      sess: Session whose graph contains this reader's ops.
      image_data: Encoded image bytes.

    Returns:
      The decoded image, guaranteed to have shape (height, width, 3).

    Raises:
      ValueError: if the decoded image is not rank 3 with 3 channels.
    """
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    # Explicit checks rather than `assert`, so they still run under `python -O`.
    if len(image.shape) != 3:
      raise ValueError('Expected a rank-3 image, got shape %s' % (image.shape,))
    if image.shape[2] != 3:
      raise ValueError('Expected 3 channels, got shape %s' % (image.shape,))
    return image

def _get_filenames_and_classes(dataset_dir):
  """Lists the class subdirectories of `dataset_dir` and the entries inside them.

  Args:
    dataset_dir: A directory whose immediate subdirectories are named after the
      classes; each subdirectory is expected to hold PNG or JPG images.

  Returns:
    A tuple `(photo_filenames, class_names)`. `photo_filenames` holds the path
    of every entry inside every class subdirectory; each path is joined onto
    `dataset_dir` and listed in `os.listdir` order. `class_names` holds the
    class subdirectory names, sorted. Non-directory entries directly inside
    `dataset_dir` are ignored.
  """
  root = os.path.join(dataset_dir)
  candidates = [(entry, os.path.join(root, entry)) for entry in os.listdir(root)]
  class_dirs = [(name, path) for name, path in candidates if os.path.isdir(path)]

  photo_filenames = [os.path.join(class_path, fname)
                     for _, class_path in class_dirs
                     for fname in os.listdir(class_path)]
  class_names = sorted(name for name, _ in class_dirs)
  return photo_filenames, class_names

def _get_dataset_filename(dataset_dir, split_name, shard_id):
  """Returns the full path of shard `shard_id` of split `split_name`.

  The basename is 'nwfs_<split>_<shard>-of-<_NUM_SHARDS>.tfrecord'. Both
  numbers are zero-padded to five digits.
  """
  format_args = (split_name, shard_id, _NUM_SHARDS)
  basename = 'nwfs_%s_%05d-of-%05d.tfrecord' % format_args
  return os.path.join(dataset_dir, basename)

def _list_class_files(data_dirs):
  """Returns the sorted paths of the regular files directly inside `data_dirs`."""
  return sorted(os.path.join(data_dir, name)
                for data_dir in data_dirs
                for name in os.listdir(data_dir)
                if os.path.isfile(os.path.join(data_dir, name)))


def _image_file_to_example(filename, image_reader, sess, class_names_to_ids):
  """Reads one image file and builds its tf.Example; raises on any failure."""
  with tf.gfile.GFile(filename, 'rb') as f:
    image_data = f.read()
  height, width = image_reader.read_image_dims(sess, image_data)
  # The label is taken from the name of the file's parent directory.
  class_name = os.path.basename(os.path.dirname(filename))
  class_id = class_names_to_ids[class_name]
  # NOTE(review): the format is recorded as b'jpg' for every file, even though
  # class directories may contain PNGs; confirm only JPEGs are expected.
  return dataset_utils.image_to_tfexample(
      image_data, b'jpg', height, width, class_id)


def _convert_dataset(split_name, data_dir_class, class_names_to_ids, dataset_dir,
                     shuffle_seed=0):
  """Converts the images of one split into `_NUM_SHARDS` TFRecord shards.

  Each call (re)writes every shard of `split_name`. All class directories of a
  split must therefore be passed in a single call; converting them in
  separate calls would make later calls overwrite earlier output.

  Each file is decoded with `ImageReader` to get its height and width. Its
  label is the id of its parent directory name. Files that cannot be read,
  decoded or labelled are skipped and reported on stderr.

  Args:
    split_name: The name of the split, either 'train' or 'validation'.
    data_dir_class: A class directory path, or a list/tuple of class directory
      paths. Each directory's basename must be a key of `class_names_to_ids`.
    class_names_to_ids: A dictionary from class names (strings) to ids
      (integers).
    dataset_dir: The directory where the TFRecord shards are written. It is
      created if it does not exist.
    shuffle_seed: Seed for the fixed-order shuffle applied before sharding.

  Raises:
    ValueError: if `split_name` is not 'train' or 'validation'.
  """
  if split_name not in ('train', 'validation'):
    raise ValueError("split_name must be 'train' or 'validation', got %r"
                     % (split_name,))

  if isinstance(data_dir_class, (list, tuple)):
    data_dirs = list(data_dir_class)
  else:
    data_dirs = [data_dir_class]

  # Sort for reproducibility (os.listdir order is arbitrary), then shuffle with
  # a fixed seed. Otherwise each class would fill its own run of shards.
  filenames = _list_class_files(data_dirs)
  random.Random(shuffle_seed).shuffle(filenames)
  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

  # Create the output directory up front instead of relying on the writer.
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  num_written = 0
  with tf.Graph().as_default():
    image_reader = ImageReader()

    with tf.Session('') as sess:
      for shard_id in range(_NUM_SHARDS):
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id)
        start_ndx = shard_id * num_per_shard
        end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))

        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          for i in range(start_ndx, end_ndx):
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                i + 1, len(filenames), shard_id))
            sys.stdout.flush()
            try:
              example = _image_file_to_example(
                  filenames[i], image_reader, sess, class_names_to_ids)
            except Exception as e:  # pylint: disable=broad-except
              # Best effort: skip unreadable/undecodable/unlabelled files,
              # but report them instead of dropping them silently.
              sys.stderr.write('\nSkipping %s: %s\n' % (filenames[i], e))
              continue
            tfrecord_writer.write(example.SerializeToString())
            num_written += 1

  sys.stdout.write('\n')
  sys.stdout.write('Wrote %d of %d files to split %s\n' % (
      num_written, len(filenames), split_name))
  sys.stdout.flush()


def _dataset_exists(dataset_dir):
  """Returns True if every train and validation shard exists in `dataset_dir`."""
  for split_name in ['train', 'validation']:
    for shard_id in range(_NUM_SHARDS):
      output_filename = _get_dataset_filename(
          dataset_dir, split_name, shard_id)
      if not tf.gfile.Exists(output_filename):
        return False
  return True


def run(dataset_dir, class_names_to_ids=None):
  """Converts `dataset_dir/{train,validation}/<class>/` images to TFRecords.

  Shards are written to `dataset_dir/tfrecord`. The label file (id -> class
  name) is written into `dataset_dir` via `dataset_utils.write_label_file`.
  If all shards already exist, nothing is done.

  Args:
    dataset_dir: The dataset directory holding the `train` and `validation`
      subdirectories.
    class_names_to_ids: Optional mapping from class directory name to integer
      label. Defaults to {"neg": 0, "pos": 1}.
  """
  if class_names_to_ids is None:
    class_names_to_ids = {"neg": 0, "pos": 1}

  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  # Check the directory the shards are actually written to.
  tfrecord_dir = os.path.join(dataset_dir, 'tfrecord')
  if _dataset_exists(tfrecord_dir):
    print('Dataset files already exist. Exiting without re-creating them.')
    return

  # Convert each split exactly once, with all of its class directories.
  # Converting per class would overwrite the shards of the same split.
  class_names = sorted(class_names_to_ids, key=class_names_to_ids.get)
  for split_name in ('train', 'validation'):
    class_dirs = [os.path.join(dataset_dir, split_name, name)
                  for name in class_names]
    _convert_dataset(split_name, class_dirs, class_names_to_ids, tfrecord_dir)

  # Write the labels file from the class mapping actually used for labelling.
  labels_to_class_names = {class_id: name
                           for name, class_id in class_names_to_ids.items()}
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

  print('\nFinished converting the dataset!')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值