Resnet Tensorflow源码实现解析(1)

这次开始读一个基于tensorflow的resnet的完整实现的源码,整个项目一共包含了三个文件cifar_input.py、resnet_main.py、resnet_model.py。这次先针对cifar_input.py文件进行源码的解析,该文件主要是实现数据集的预处理和读入的功能,具体代码先贴在下面。

关于cifar数据集的介绍以及如何将二进制数据输出为jpg格式可以参考博客【1】
在代码中使用tf.RandomShuffleQueue具体的介绍可以参考博客
【2】
在代码中使用tf.FIFOQueue具体的介绍可以参考博客【3】
对于tensorflow中批量输入数据的实现可以参考博客【4】

import tensorflow as tf

def build_input(dataset, data_path, batch_size, mode):
  """Build CIFAR image and labels.

  Args:
    dataset(数据集): Either 'cifar10' or 'cifar100'.
    data_path(数据集路径): Filename for data.
    batch_size: Input batch size.
    mode(模式): Either 'train' or 'eval'.
  Returns:
    images(图片): Batches of images. [batch_size, image_size, image_size, 3]
    labels(类别标签): Batches of labels. [batch_size, num_classes]
  Raises:
    ValueError: when the specified dataset is not supported.
  """
  
  # 数据集参数
  # 由于数据是二进制文件读入的,数据集中标签和图片数据在读入后是存放在一个一维
  # 列表中的,所以label_offset表示了二进制文件中表示标签字节起始位置的偏移量
  # label_bytes代表了标签在二进制文件中所占的字节数。
  image_size = 32
  if dataset == 'cifar10':
    label_bytes = 1
    label_offset = 0
    num_classes = 10
  elif dataset == 'cifar100':
    label_bytes = 1
    label_offset = 1
    num_classes = 100
  else:
    raise ValueError('Not supported dataset %s', dataset)

  # 数据读取参数
  depth = 3
  # image_bytes代表了一张图片所占的字节数量
  image_bytes = image_size * image_size * depth
  # record_bytes代表了一个二进制文件内标签加上图片本身数据总体的字节数量
  record_bytes = label_bytes + label_offset + image_bytes

  # 获取文件名列表
  # data_path是一个路径下的文件名或者符合正则pattern的文件,data_file就是在
  # data_path中所匹配的所有文件所构成的列表
  data_files = tf.gfile.Glob(data_path)
  # 文件名列表生成器
  file_queue = tf.train.string_input_producer(data_files, shuffle=True)
  # 文件名列表里读取原始二进制数据
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  _, value = reader.read(file_queue)

  # 将原始二进制数据转换成图片数据及类别标签
  record = tf.reshape(tf.decode_raw(value, tf.uint8), [record_bytes])
  label = tf.cast(tf.slice(record, [label_offset], [label_bytes]), tf.int32)
  # 将数据串 [depth * height * width] 转换成矩阵 [depth, height, width].
  depth_major = tf.reshape(tf.slice(record, [label_bytes], [image_bytes]),
                           [depth, image_size, image_size])
  # 转换维数:[depth, height, width]转成[height, width, depth].
  image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

  if mode == 'train':
    # 增减图片尺寸
    image = tf.image.resize_image_with_crop_or_pad(
                        image, image_size+4, image_size+4)
    # 随机裁剪图片
    image = tf.random_crop(image, [image_size, image_size, 3])
    # 随机水平翻转图片
    image = tf.image.random_flip_left_right(image)
    # 逐图片做像素值中心化(减均值)
    image = tf.image.per_image_standardization(image)

    # 建立输入数据队列(随机洗牌)
    example_queue = tf.RandomShuffleQueue(
        # 队列容量
        capacity=16 * batch_size,
        # 队列数据的最小容许量
        min_after_dequeue=8 * batch_size,
        dtypes=[tf.float32, tf.int32],
        # 图片数据尺寸,标签尺寸
        shapes=[[image_size, image_size, depth], [1]])
    # 读线程的数量
    num_threads = 16
  else:
    # 获取测试图片,并做像素值中心化
    image = tf.image.resize_image_with_crop_or_pad(
                        image, image_size, image_size)
    image = tf.image.per_image_standardization(image)

    # 建立输入数据队列(先入先出队列)
    example_queue = tf.FIFOQueue(
        3 * batch_size,
        dtypes=[tf.float32, tf.int32],
        shapes=[[image_size, image_size, depth], [1]])
    # 读线程的数量
    num_threads = 1

  # 数据入队操作
  example_enqueue_op = example_queue.enqueue([image, label])
  # 队列执行器
  tf.train.add_queue_runner(tf.train.queue_runner.QueueRunner(
      example_queue, [example_enqueue_op] * num_threads))

  # 数据出队操作,从队列读取Batch数据
  images, labels = example_queue.dequeue_many(batch_size)
  # 将标签数据由稀疏格式转换成稠密格式
  # [ 2,       [[0,1,0,0,0]
  #   4,        [0,0,0,1,0]  
  #   3,   -->  [0,0,1,0,0]    
  #   5,        [0,0,0,0,1]
  #   1 ]       [1,0,0,0,0]]
  labels = tf.reshape(labels, [batch_size, 1])
  indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
  labels = tf.sparse_to_dense(
                  tf.concat(values=[indices, labels], axis=1),
                  [batch_size, num_classes], 1.0, 0.0)

  #检测数据维度
  assert len(images.get_shape()) == 4
  assert images.get_shape()[0] == batch_size
  assert images.get_shape()[-1] == 3
  assert len(labels.get_shape()) == 2
  assert labels.get_shape()[0] == batch_size
  assert labels.get_shape()[1] == num_classes

  # 添加图片总结
  tf.summary.image('images', images)
  return images, labels

(1)二进制文件的读入操作

从源码中可以发现在读入二进制文件的原始信息时调用了tf.gfile.Glob、tf.train.string_input_producer、tf.FixedLengthRecordReader、reader.read等API。下面就针对这几个函数进行详细的介绍:

tf.gfile.Glob()

函数功能:查找匹配pattern的文件并以列表的形式返回,filename可以是一个具体的文件名,也可以是包含通配符的正则表达式。因此该函数返回的是所有满足匹配规则的文件名所构成的列表。
样例: 在样例代码中我首先设计了一个匹配规则data_path,之后调用tf.gfile.Glob(data_path)来获取文件列表。

data_path = r"D:\SourceInsightData\ResNet_TF\test*.txt"
data_files = tf.gfile.Glob(data_path)
print(data_files)

最终在控制台的输出结果为:

['D:\\SourceInsightData\\ResNet_TF\\test.txt', 'D:\\SourceInsightData\\ResNet_TF\\test1.txt']

tf.train.string_input_producer()

函数功能: 输出字符串到一个输入管道队列。
主要参数定义:
1、string_tensor: 1-D字符串Tensor
2、num_epochs=None: 一个整数(可选)。如果指定,string_input_producer在产生OutOfRange错误之前从string_tensor中产生num_epochs次字符串。如果未指定,则可以无限次循环遍历字符串。需要注意的是:当指定该参数为某一整数时,需要在调用tf.run之前初始化局部变量sess.run(tf.local_variables_initializer()),否则会报错。
3、shuffle=True: 当设置为True时文件队列的存放顺序会被随机打乱。
4、seed=None: 一个整数(可选)。如果shuffle==True,则使用种子。
5、capacity=32: 一个整数。设置队列容量。

tf.FixedLengthRecordReader()

函数功能: 创建一个可从文件名列表中读取内容的对象
在源码中可以发现该函数包含一个参数record_bytes,这就表示reader对象在每次调用read函数时都只能从相应的文件列表内读取固定大小的字节数量。

reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
 _, value = reader.read(file_queue)

样例: 在样例中每次从文件名队列内的文件读取7个字节的数据

data_path = r"D:\SourceInsightData\ResNet_TF\test*.txt"
data_files = tf.gfile.Glob(data_path)
file_queue = tf.train.string_input_producer(data_files, num_epochs=3, shuffle=True)
reader = tf.FixedLengthRecordReader(record_bytes=7)
key, value = reader.read(file_queue)

sess = tf.InteractiveSession()
init = tf.local_variables_initializer()
tf.train.start_queue_runners(sess=sess)
sess.run(init)

print(sess.run([key, value]))
print(sess.run([key, value]))

最终在控制台的输出结果为:

[b'D:\\SourceInsightData\\ResNet_TF\\test.txt:0', b'1234567']
[b'D:\\SourceInsightData\\ResNet_TF\\test1.txt:0', b'abcdefg']

(2)数据增强和预处理

在源码中调用了tf.image.resize_image_with_crop_or_pad、tf.random_crop、tf.image.random_flip_left_right三个API对原始数据进行了数据增强,调用了tf.image.per_image_standardization对图像数据进行了去中心话的操作。

tf.image.resize_image_with_crop_or_pad()

函数功能: 对图片进行裁剪和填充,当原图片尺寸大于目标图片
时,自动截取原图片居中位置。
从源码中可以知道函数包含三个参数,Image表示输入的图像数据,后两个参数分别表示所需要裁剪的最终图像的大小。

image = tf.image.resize_image_with_crop_or_pad(image, 
							image_size+4, image_size+4)

tf.random_crop()

函数功能: 随机裁剪图片。
从源码中可以发现该函数包含两个参数,image表示输入的图像数据,后一个参数表示随机裁剪的图片的shape。

image = tf.random_crop(image, [image_size, image_size, 3])

tf.image.random_flip_left_right()

函数功能: 随机水平翻转图片。

tf.image.per_image_standardization()

函数功能: 逐图片做像素值中心化。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值