TensorFlow: multi-GPU, multi-threaded training of VGG19 for CIFAR-10 classification

This post shows how to train a VGG19 model in parallel on multiple GPUs with TensorFlow to solve the CIFAR-10 classification task. It covers the key points of data input, network construction, training, and validation — in particular the ordering problem with multi-threaded data reading, the importance of turning off dropout on the validation set, and the pitfall to avoid with tf.nn.in_top_k().

Background: a few days ago I needed to write a multi-GPU training model. After much searching I settled on the official CIFAR-10 code on the TensorFlow website, and spent three days stripping out the redundant parts and rewriting it in my own style.

The code consists of six parts:

1、data_input.py — the official CIFAR-10 training and test sets are both .bin files, so the same data-loading code serves both.

2、network.py — builds VGG19 and returns the sum of the weight-decay variable losses and the cross-entropy as the total loss.

3、train.py — builds a tower on each GPU and trains in parallel.

4、val.py — validates the model with multiple reader threads.

5、toTFRecords.py — labels cannot be read in through the multi-threaded pipeline for the test images, so the images and their file names (used as labels) are packed into a TFRecords file.

6、test_tfrecords.py — reads the TFRecords file and writes the model's predictions to a CSV.
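The tower pattern in the official CIFAR-10 multi-GPU code, which train.py follows, boils down to: each GPU computes gradients on its own shard of the batch, and the optimizer applies their mean. A minimal NumPy sketch of that averaging step (`average_tower_grads` is a hypothetical name, not from the original code):

```python
import numpy as np

# Hedged sketch of multi-GPU gradient averaging: each "tower" (GPU) produces
# one gradient array per variable; the optimizer applies the per-variable mean.
def average_tower_grads(tower_grads):
    # tower_grads: list (one entry per GPU) of lists of per-variable gradients
    return [np.mean(grads, axis=0) for grads in zip(*tower_grads)]

g0 = [np.array([1.0, 2.0]), np.array([[1.0]])]  # gradients from tower 0
g1 = [np.array([3.0, 4.0]), np.array([[3.0]])]  # gradients from tower 1
avg = average_tower_grads([g0, g1])
print(avg[0])  # [2. 3.]
```

Because the mean of per-shard gradients equals the gradient of the mean loss over the combined batch, this is equivalent to single-GPU training with a larger batch.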


1、data_input.py: pitfall — when reading data with multiple threads, num_threads == 1 is perfectly safe; otherwise, because the threads run at different speeds, every group of num_threads records comes out somewhat shuffled.
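The same reordering is easy to reproduce with plain Python threads, no TensorFlow needed: several consumers pulling from one shared queue at different speeds hand the records back out of order, while a single consumer preserves the order (a toy sketch; the function name is mine):

```python
import queue
import random
import threading
import time

def read_in_order(num_threads=4, num_items=32):
    # Stand-in for the filename/record queue: items enqueued in order 0..n-1.
    q = queue.Queue()
    for i in range(num_items):
        q.put(i)
    results = []
    lock = threading.Lock()

    def worker():
        while True:
            try:
                item = q.get_nowait()
            except queue.Empty:
                return
            time.sleep(random.random() * 0.01)  # threads run at different speeds
            with lock:
                results.append(item)

    threads = [threading.Thread(target=worker) for _ in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

out = read_in_order()
print(out == sorted(out))  # with >1 thread this is usually False: order is scrambled
```

No record is lost — only the order changes — which is harmless for shuffled training batches but matters when you need outputs aligned with inputs (hence toTFRecords.py below).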

import os
import tensorflow as tf

class data_input(object):
  def __init__(self, data_dir, batch_size, num_classes, is_training):
    self.data_dir = data_dir
    self.batch_size = batch_size
    self.num_classes = num_classes
    self.is_training = is_training
    
    self.image_batch, self.label_batch = self.load_data()
    
  #Input of this function is args, output is batch of images and batch of labels.
  def load_data(self):
    
    if self.is_training:
      filenames = [os.path.join(self.data_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]
    else:
      filenames = [os.path.join(self.data_dir, 'test_batch.bin')]
      
    for f in filenames:
      if not tf.gfile.Exists(f):
        raise ValueError('Failed to find file: ' + f)
    
    filename_queue = tf.train.string_input_producer(filenames, shuffle = False)
    image, label = self.load_one_sample(filename_queue)
    
    image.set_shape([32, 32, 3])
    label.set_shape([self.num_classes])
    
    #if the "num_threads" is not 1, due to the speed difference of each thread, there will be a shuffle in every num_threads data.
    if self.is_training:  # training implies data augmentation and shuffling
      image = self.data_augmentation(image)
      image_batch, label_batch = tf.train.shuffle_batch([image, label], batch_size = self.batch_size, num_threads = 16, capacity = 20400, min_after_dequeue = 20000)
    else:
      image = tf.image.resize_image_with_crop_or_pad(image, 28, 28)
      image_batch, label_batch = tf.train.batch([image, label], batch_size = self.batch_size, num_threads = 16, capacity = 20400)
    
    return image_batch, tf.reshape(label_batch, [self.batch_size, self.num_classes])

  #From the filename queue, read one image and label. The image occupies 32*32*3 bytes, the label occupies 1 byte.
  def load_one_sample(self, filename_queue):
    
    image_bytes = 32 * 32 * 3
    label_bytes = 1
    
    record_bytes = image_bytes + label_bytes
    reader = tf.FixedLengthRecordReader(record_bytes = record_bytes)
    
    key, value = reader.read(filename_queue)
    record_bytes = tf.decode_raw(value, tf.uint8)
    
    label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
    label = tf.one_hot(label, self.num_classes)
    
    image = tf.reshape(tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]), [3, 32, 32])
    image = tf.transpose(image, [1, 2, 0])
    image = tf.cast(image, tf.float32)
    
    return image, tf.reshape(label, [self.num_classes])
    
  
  def data_augmentation(self, image):
    
    image = tf.random_crop(image, [28, 28, 3])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta = 63)
    image = tf.image.random_contrast(image, lower = 0.2, upper = 1.8)
    image = tf.image.per_image_standardization(image)
    
    return image
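The byte layout that load_one_sample decodes can be sanity-checked with plain NumPy: each CIFAR-10 record is 1 label byte followed by 32*32*3 image bytes in depth-major (CHW) order, which is why the code reshapes to [3, 32, 32] and then transposes to HWC. A minimal sketch on a synthetic record:

```python
import numpy as np

# One CIFAR-10 binary record: 1 label byte followed by 3072 image bytes (CHW).
label_bytes, image_bytes = 1, 32 * 32 * 3
record = np.arange(label_bytes + image_bytes, dtype=np.uint8)  # synthetic record

label = int(record[0])
image = record[label_bytes:].reshape(3, 32, 32)  # depth-major, as stored on disk
image = image.transpose(1, 2, 0)                 # to HWC, like tf.transpose([1, 2, 0])

print(label, image.shape)  # 0 (32, 32, 3)
```

Getting the reshape/transpose pair wrong does not raise an error — it silently interleaves the color channels — so a check like this is worth a minute.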


2、network.py: pitfall — don't forget to turn off dropout on the validation and test sets.
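Setting keep_prob = 1 works as the "off switch" because tf.nn.dropout implements inverted dropout: kept units are scaled by 1/keep_prob at training time, so keep_prob = 1 is exactly the identity. A toy NumPy sketch of that behavior (not the TF implementation):

```python
import numpy as np

def dropout(x, keep_prob, rng=None):
    # Inverted dropout: kept units are scaled by 1/keep_prob at training time,
    # so keep_prob = 1.0 (validation / test) leaves activations untouched.
    if keep_prob >= 1.0:
        return x
    if rng is None:
        rng = np.random.default_rng(0)
    mask = rng.random(x.shape) < keep_prob
    return x * mask / keep_prob

x = np.ones((2, 4))
print(np.array_equal(dropout(x, 1.0), x))  # True: the eval path is the identity
```

With keep_prob = 0.5, each surviving unit is doubled, so the expected activation matches the evaluation-time value — forgetting to set keep_prob = 1 at evaluation instead zeroes half the features and tanks accuracy.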

import tensorflow as tf

class NETWORK(object):
  def __init__(self, image_batch, label_batch, keep_prob, num_classes, is_training):
    self.image_batch = image_batch
    self.label_batch = label_batch
    
    if is_training:
      self.keep_prob = keep_prob
    else:
      self.keep_prob = 1.0  # disable dropout for validation / test
      
    self.num_classes = num_classes
    
    self.logits = self.inference()
    self.losses = self.loss()
  
  
  def inference(self):
    conv1_1 = conv(self.image_batch, 3, 64, 1, 'SAME', 0.0004, 'conv1_1')
    conv1_2 = conv(conv1_1, 3, 64, 2, 'SAME', None, 'conv1_2')
    pool1 = max_pool(conv1_2, 3, 2, 'SAME', 'pool1')
    
    conv2_1 = conv(pool1, 3, 128, 1, 'SAME', 0.0004, 'conv2_1')
    conv2_2 = conv(conv2_1, 3, 128, 1, 'SAME', None, 'conv2_2')
    pool2 = max_pool(conv2_2, 3, 2, 'SAME', 'pool2')

    conv3_1 = conv(pool2, 3, 256, 1, 'SAME', 0.0004, 'conv3_1')
    conv3_2 = conv(conv3_1, 3, 256, 1, 'SAME', None, 'conv3_2')
    conv3_3 = conv(conv3_2, 3, 256, 1, 'SAME', None, 'conv3_3')
    conv3_4 = conv(conv3_3, 3, 256, 1, 'SAME', None, 'conv3_4')
    pool3 = max_pool(conv3_4, 3, 2, 'SAME', 'pool3') 
    
    conv4_1 = conv(pool3, 3, 512, 1, 'SAME', 0.0004, 'conv4_1')
    conv4_2 = conv(conv4_1, 3, 512, 1, 'SAME', None, 'conv4_2')
    conv4_3 = conv(conv4_2, 3, 512, 1, 'SAME', None, 'conv4_3')
    conv4_4 = conv(conv4_3, 3, 512, 1, 'SAME', None, 'conv4_4')
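The excerpt ends before network.py's loss() is shown; per the file list above, the total loss is the sum of the weight-decay variable losses and the cross-entropy. A hedged NumPy sketch of that combination (function names are mine; note that tf.nn.l2_loss includes an extra factor of 1/2 that this sketch omits):

```python
import numpy as np

def softmax_cross_entropy(logits, onehot):
    # Numerically stable softmax cross-entropy, averaged over the batch.
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -(onehot * log_probs).sum(axis=1).mean()

def total_loss(logits, onehot, weights, wd=0.0004):
    # wd matches the weight-decay argument passed to the first conv of each block.
    decay = sum(wd * np.sum(w ** 2) for w in weights)
    return softmax_cross_entropy(logits, onehot) + decay

logits = np.zeros((2, 10))          # uniform logits over 10 CIFAR-10 classes
onehot = np.eye(10)[[3, 7]]         # one-hot labels for the two samples
print(round(softmax_cross_entropy(logits, onehot), 4))  # ln(10) ≈ 2.3026
```

Folding the decay terms into the training loss is what makes the L2 penalty take effect; the validation loss is usually reported without them.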