背景:几天前需要写个多GPU训练的算法模型,翻来覆去在tensorflow的官网上看到cifar-10的官方代码,花了三天时间去掉代码的冗余部分和改写成自己的风格。
代码共有6部分构成:
1、data_input.py 由于cifar-10官方训练集和验证集都是.bin格式文件,故使用同一份数据读入代码
2、network.py 搭建VGG19,返回带weight decay的变量loss和交叉熵之和作为总loss
3、train.py 在每个GPU中建立tower,并行训练
4、val.py 多线程进行模型验证
5、toTFRecords.py 由于使用多线程无法读入label,故将图像和图像名(作label)制作成TFRecords
6、test_tfrecords.py 读TFRecords文件,计算模型输出并保存成csv
1、data_input.py:坑点为多线程读入数据时,如果num_threads==1,则万事大吉,否则,由于不同线程读取数据快慢不同,读入num_threads个数据的时候会出现一定程度的顺序打乱
import os
import tensorflow as tf
class data_input(object):
    """Input pipeline for the CIFAR-10 binary (.bin) files.

    Builds a queue-based (TF1.x) pipeline that yields batches of 28x28x3
    float images and one-hot labels.  Training mode reads data_batch_1..5.bin
    with augmentation and shuffling; eval mode reads test_batch.bin with a
    deterministic center crop.
    """

    def __init__(self, data_dir, batch_size, num_classes, is_training):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.is_training = is_training
        # Build the graph nodes once; callers read image_batch / label_batch.
        self.image_batch, self.label_batch = self.load_data()

    def load_data(self):
        """Return (image_batch, label_batch) tensors for one epoch stream.

        Raises:
            ValueError: if any expected .bin file is missing.
        """
        if self.is_training:
            filenames = [os.path.join(self.data_dir, 'data_batch_%d.bin' % i)
                         for i in range(1, 6)]
        else:
            filenames = [os.path.join(self.data_dir, 'test_batch.bin')]
        for f in filenames:
            if not tf.gfile.Exists(f):
                raise ValueError('Failed to find file: ' + f)
        filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
        image, label = self.load_one_sample(filename_queue)
        image.set_shape([32, 32, 3])
        label.set_shape([self.num_classes])
        # NOTE: with num_threads > 1, threads dequeue at different speeds, so
        # examples may be mildly reordered within every num_threads records.
        if self.is_training:
            # Augmentation (includes per-image standardization) + shuffling.
            image = self.data_augmentation(image)
            image_batch, label_batch = tf.train.shuffle_batch(
                [image, label], batch_size=self.batch_size, num_threads=16,
                capacity=20400, min_after_dequeue=20000)
        else:
            image = tf.image.resize_image_with_crop_or_pad(image, 28, 28)
            # Standardize eval images too, so train and eval see inputs with
            # the same statistics (the original only standardized training data).
            image = tf.image.per_image_standardization(image)
            image_batch, label_batch = tf.train.batch(
                [image, label], batch_size=self.batch_size, num_threads=16,
                capacity=20400)
        return image_batch, tf.reshape(label_batch,
                                       [self.batch_size, self.num_classes])

    def load_one_sample(self, filename_queue):
        """Decode one fixed-length record: 1 label byte + 32*32*3 image bytes.

        Returns:
            (image, label): float32 [32, 32, 3] image in HWC order and a
            float one-hot label vector of length num_classes.
        """
        image_bytes = 32 * 32 * 3
        label_bytes = 1
        record_bytes = image_bytes + label_bytes
        reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
        _, value = reader.read(filename_queue)
        # Use a distinct name for the decoded tensor; the original shadowed
        # the integer record_bytes with a tensor of the same name.
        record = tf.decode_raw(value, tf.uint8)
        label = tf.cast(tf.strided_slice(record, [0], [label_bytes]), tf.int32)
        label = tf.one_hot(label, self.num_classes)
        # Raw layout is [channel, height, width]; transpose to [h, w, c].
        image = tf.reshape(
            tf.strided_slice(record, [label_bytes], [label_bytes + image_bytes]),
            [3, 32, 32])
        image = tf.transpose(image, [1, 2, 0])
        image = tf.cast(image, tf.float32)
        return image, tf.reshape(label, [self.num_classes])

    def data_augmentation(self, image):
        """Random crop/flip/brightness/contrast, then per-image standardization."""
        image = tf.random_crop(image, [28, 28, 3])
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, max_delta=63)
        image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
        image = tf.image.per_image_standardization(image)
        return image
2、network.py:坑点为不要忘记在验证集和测试集上关闭dropout。。。。
import tensorflow as tf
class NETWORK(object):
def __init__(self, image_batch, label_batch, keep_prob, num_classes, is_training):
    """Build the network graph.

    Args:
        image_batch: batch of input images fed to inference().
        label_batch: batch of one-hot labels used by loss().
        keep_prob: dropout keep probability used during training.
        num_classes: number of output classes.
        is_training: truthy when training; dropout is disabled otherwise.
    """
    self.image_batch = image_batch
    self.label_batch = label_batch
    # Use truthiness, not `is True`: the original identity check silently
    # disabled training-time dropout for any truthy non-bool flag (e.g. 1).
    # Outside training, keep_prob = 1 turns dropout off entirely.
    self.keep_prob = keep_prob if is_training else 1
    self.num_classes = num_classes
    self.logits = self.inference()
    self.losses = self.loss()
def inference(self):
conv1_1 = conv(self.image_batch, 3, 64, 1, 'SAME', 0.0004, 'conv1_1')
conv1_2 = conv(conv1_1, 3, 64, 2, 'SAME', None, 'conv1_2')
pool1 = max_pool(conv1_2, 3, 2, 'SAME', 'pool1')
conv2_1 = conv(pool1, 3, 128, 1, 'SAME', 0.0004, 'conv2_1')
conv2_2 = conv(conv2_1, 3, 128, 1, 'SAME', None, 'conv2_2')
pool2 = max_pool(conv2_2, 3, 2, 'SAME', 'pool2')
conv3_1 = conv(pool2, 3, 256, 1, 'SAME', 0.0004, 'conv3_1')
conv3_2 = conv(conv3_1, 3, 256, 1, 'SAME', None, 'conv3_2')
conv3_3 = conv(conv3_2, 3, 256, 1, 'SAME', None, 'conv3_3')
conv3_4 = conv(conv3_3, 3, 256, 1, 'SAME', None, 'conv3_4')
pool3 = max_pool(conv3_4, 3, 2, 'SAME', 'pool3')
conv4_1 = conv(pool3, 3, 512, 1, 'SAME', 0.0004, 'conv4_1')
conv4_2 = conv(conv4_1, 3, 512, 1, 'SAME', None, 'conv4_2')
conv4_3 = conv(conv4_2, 3, 512, 1, 'SAME', None, 'conv4_3')
conv4_4 = conv(conv4_3, 3,