深度学习实战(六) 多机多卡分布式训练cifar10完整实现

准备工作:

数据集下载地址:

http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz

 

实现部分(附详细注释):

首先获取用于训练的小批量数据,由于获取过程中需要对图像进行处理,避免阻塞训练进程,我们开启16个线程来从队列获取批量图像。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os 
import tensorflow as tf
import matplotlib.pyplot as plt
 
# %matplotlib inline

IMAGE_SIZE = 32

NUM_CLASSES = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 #训练集的样本总数
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000 #验证集的样本总数

cifar_label_bytes = 1  # 2 for CIFAR-100 第一个字节为label
cifar_height = 32
cifar_width = 32
cifar_depth = 3 #通道数

#生产批量输入
def generate_batch_inputs(eval_data, shuffle, data_dir, batch_size):
    """
    参数:
    eval_data: bool值,指定训练或者验证.
    shuffle: bool值,是否将数据顺序打乱.
    data_dir: CIFAR-10数据集所在目录.
    batch_size: 批量大小.

    返回值:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
    """
 
    if not eval_data:
        filepath = os.path.join(data_dir, 'data_batch_*') 
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
    else:
        filepath = os.path.join(data_dir, 'test_batch*')
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

    files = tf.train.match_filenames_once(filepath)

    with tf.name_scope('input'):
        # tf.train.string_input_producer会使用初始化时提供的文件列表创建一个输入队列,
        # 创建好的输入队列可以作为文件读取函数的参数.
        # shuffle参数为True时,文件在加入队列之前会被打乱顺序
        # tf.train.string_input_producer生成的输入队列可以同时被多个文件读取线程操作,
        # 而且输入队列会将队列中的文件均匀地分配给不同的线程,不会出现有些文件被处理过多次而有些文件还没有被处理过的情况
        # 当一个输入队列中的所有文件都被处理完后,它会将初始化时提供的文件类表中的文件全部重新加入队列,
        # 通过num_epochs参数来限制加载初始化文件列表的最大轮数。当所有文件都已经被使用了设定的轮数后,
        # 如果继续尝试读取新的文件,输入队列会报OutOfRange的错误。这里我们取None不做限制
        filename_queue = tf.train.string_input_producer(files, shuffle=False, num_epochs=None)
    
        # 从文件队列读取样本
        image_bytes = cifar_height * cifar_width * cifar_depth
        #每条数据的长度
        record_bytes = cifar_label_bytes + image_bytes
        # 读取固定长度的一条数据
        reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
        key, value = reader.read(filename_queue)

        # 格式转换
        record_bytes = tf.decode_raw(value, tf.uint8)

        # 第一个字节为分类标签
        label = tf.cast(
            tf.strided_slice(record_bytes, [0], [cifar_label_bytes]), tf.int32)

        # 标签字节后面的字节表示图片信息 
        # reshape from [depth * height * width] to [depth, height, width].
        depth_major = tf.reshape(
            tf.strided_slice(record_bytes, [cifar_label_bytes],
                           [cifar_label_bytes + image_bytes]),
            [cifar_depth, cifar_height, cifar_width])
        # Convert from [depth, height, width] to [height, width, depth].
        uint8image = tf.transpose(depth_major, [1, 2, 0])
    
        reshaped_image = tf.cast(uint8image, tf.float32)
        #plt.imshow(reshaped_image)
        '''
        if not eval_data:
            # 数据增强用于训练
            # 随机的对图片进行一些处理,原来的一张图片在多次epoch中就会生成多张不同的图片,这样就增加了样本数量
            #由于数据增强会耗费大量的CPU时间,因此我们用16个线程来处理

            # Randomly crop a [IMAGE_SIZE, IMAGE_SIZE] section of the image.
            resized_image = tf.random_crop(reshaped_image, [IMAGE_SIZE, IMAGE_SIZE, 3])

            # Randomly flip the image horizontally.
            resized_image = tf.image.random_flip_left_right(resized_image)

            # Because these operations are not commutative, consider randomizing
            # the order their operation.
            # NOTE: since per_image_standardization zeros the mean and makes
            # the stddev unit, this likely has no effect see tensorflow#1458.
            resized_image = tf.image.random_brightness(resized_image,
                                                     max_delta=63)
            resized_image = tf.image.random_contrast(resized_image,
                                                   lower=0.2, upper=1.8)

        else:
            # 裁剪中间部分用于验证
            resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
                                                               IMAGE_SIZE, IMAGE_SIZE)
        
        # 减去均值并除以像素的方差
        float_image = tf.image.per_image_standardization(resized_image)
        '''
        
        # 这里我们不对图片进行任何处理,得到更大的图像,以便后面训练得到更快的收敛和更好的精度
        float_image 
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值