TF-slim download_and_convert_flowers.py代码解析

参考:
1、https://github.com/tensorflow/models/blob/master/research/slim/datasets/download_and_convert_flowers.py


1、提取类名

import os

filenames=r'C:\Users\Administrator\Desktop\test3\hymenoptera_data\train\ants\0013035.jpg'
class_name = os.path.basename(os.path.dirname(filenames)) # 提取文件所在的文件夹名,即分类名,用于后续做标签
print(class_name)
# ants

2、解析所有图片路径

def _get_filenames_and_classes(dataset_dir):
  """Returns a list of filenames and inferred class names.
  返回所有图片对应的路径以及文件名(对应每一类)
  Args:
    dataset_dir: A directory containing a set of subdirectories representing
      class names. Each subdirectory should contain PNG or JPG encoded images.
  Returns:
    A list of image file paths, relative to `dataset_dir` and the list of
    subdirectories, representing class names.
  """
  flower_root = os.path.join(dataset_dir, 'flower_photos')
  directories = []
  class_names = []
  for filename in os.listdir(flower_root):
    path = os.path.join(flower_root, filename)
    if os.path.isdir(path): # 如果是文件夹
      directories.append(path)
      class_names.append(filename) # 文件名对应类名 用于做标签

  photo_filenames = [] # 存储了所有图片的路径
  for directory in directories:
    for filename in os.listdir(directory):
      path = os.path.join(directory, filename) #图片完整路径
      photo_filenames.append(path)

  return photo_filenames, sorted(class_names)

3、图像转成tfr数据

def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
  """Converts the given filenames to a TFRecord dataset.  图像转成tfr数据
  Args:
    split_name: The name of the dataset, either 'train' or 'validation'.
    filenames: A list of absolute paths to png or jpg images. 存放所有图片的路径 list
    class_names_to_ids: A dictionary from class names (strings) to ids
      (integers).
    dataset_dir: The directory where the converted datasets are stored.
  """
  assert split_name in ['train', 'validation']

  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS))) # 将总的文件分成_NUM_SHARDS份,每一份是num_per_shard个图片

  with tf.Graph().as_default(): # 新建图表
    image_reader = ImageReader()

    with tf.Session('') as sess:

      for shard_id in range(_NUM_SHARDS):
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id)

        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          end_ndx = min((shard_id+1) * num_per_shard, len(filenames)) # 每num_per_shard个图片生成一个tfr文件
          for i in range(start_ndx, end_ndx):
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                i+1, len(filenames), shard_id)) # 输出进度条
            sys.stdout.flush()

            # Read the filename:
            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
            height, width = image_reader.read_image_dims(sess, image_data)

            class_name = os.path.basename(os.path.dirname(filenames[i])) # 类名(文件夹名)
            class_id = class_names_to_ids[class_name] # 换成对应的类的id 如 0,1等

            example = dataset_utils.image_to_tfexample(
                image_data, b'jpg', height, width, class_id)
            tfrecord_writer.write(example.SerializeToString())

  sys.stdout.write('\n')
  sys.stdout.flush()

4、数据下载及解压

def download_and_uncompress_tarball(tarball_url, dataset_dir):
  """Downloads the `tarball_url` and uncompresses it locally.
  Args:
    tarball_url: The URL of a tarball file.
    dataset_dir: The directory where the temporary files are stored.
  """
  filename = tarball_url.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)

  def _progress(count, block_size, total_size):# 下载进度
    sys.stdout.write('\r>> Downloading %s %.1f%%' % (
        filename, float(count * block_size) / float(total_size) * 100.0))
    sys.stdout.flush()
  filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
  print()
  statinfo = os.stat(filepath)
  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dataset_dir)

5、实现类名与id对应

class_names=['1','2','3','4','5']

class_names_to_ids = dict(zip(class_names, range(len(class_names)))) # 将类别名与id对应

print(class_names_to_ids)
# {'1': 0, '2': 1, '3': 2, '4': 3, '5': 4}

labels_to_class_names = dict(zip(range(len(class_names)), class_names)) # 让id与文件名(类别)对应起来
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值