# 思想
# - 采用数据并行:将一个 batch 切分到多张 GPU 上,反向传播后在 CPU 上合并各 GPU 的梯度
# - 话不多说,上代码
from __future__ import absolute_import, division, print_function
import os
import sys
import csv
import cv2
import time
import random
import argparse
import importlib
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
# Pin TensorFlow to an explicit set of GPUs; the number of towers built
# later (N_GPU) is derived from this comma-separated id list.
devices = '0,1,2,3'
os.environ['CUDA_VISIBLE_DEVICES'] = devices
N_GPU = devices.count(',') + 1
def get_loss(pred_angle, labels_batch):
    """Return the total training loss for one tower.

    Total loss = mean squared error between predictions and labels,
    plus every regularization loss registered in the graph collection.
    """
    mse = tf.reduce_mean(tf.square(pred_angle - labels_batch))
    reg_terms = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    return tf.add_n([mse] + reg_terms, name='total_loss')
def average_gradients(tower_grads):
    """Average per-variable gradients across all GPU towers.

    Args:
        tower_grads: list (one entry per GPU) of lists of
            (gradient, variable) pairs as returned by
            ``Optimizer.compute_gradients``; the variable order is the
            same in every tower because variables are shared.

    Returns:
        A single list of (averaged_gradient, variable) pairs suitable
        for ``Optimizer.apply_gradients``.
    """
    average_grads = []
    # zip(*tower_grads) groups the N tower entries for the same variable.
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0)
                 for g, _ in grad_and_vars if g is not None]
        var = grad_and_vars[0][1]
        if not grads:
            # Bug fix: no tower produced a gradient for this variable.
            # The original code could append stale/undefined values here;
            # keep the variable with an explicit None gradient instead.
            average_grads.append((None, var))
            continue
        stacked = tf.concat(grads, axis=0)      # shape [n_towers, ...]
        mean_grad = tf.reduce_mean(stacked, axis=0)
        average_grads.append((mean_grad, var))
    return average_grads
def data_augment(image_name):
    """Load one image from disk and normalize it for the network.

    Args:
        image_name: scalar string tensor holding the image file path.

    Returns:
        float32 tensor of shape [224, 224, 3], values roughly in
        [-0.5, 0.5].
    """
    content = tf.read_file(image_name)
    # Bug fix: the decoded tensor was bound to `img` while every later
    # line referenced the undefined name `images` (NameError at graph
    # construction time).
    images = tf.image.decode_png(content, dtype=tf.uint16)
    # NOTE(review): dividing by 255 assumes 8-bit pixel values even though
    # the PNG is decoded as uint16 — confirm the dataset's real bit depth.
    images = tf.cast(images, tf.float32) * (1.0 / 255) - 0.5
    images = tf.reshape(images, shape=[224, 224, 3])
    return images
def read_decode_csv(filename_queue):
    """Read and parse one line of the training CSV.

    Each row is: image_path, angle1, angle2, angle3.

    Args:
        filename_queue: queue of CSV file names
            (from ``tf.train.string_input_producer``).

    Returns:
        (features, labels): the image path (string tensor) and a
        3-element float32 label vector.
    """
    reader = tf.TextLineReader()
    _, value = reader.read(filename_queue)  # line key is unused
    # Column 1: image path; columns 2-4: float regression targets.
    record_defaults = [["NULL"], [0.0], [0.0], [0.0]]
    features, col2, col3, col4 = tf.decode_csv(
        value, record_defaults=record_defaults)
    labels = tf.stack([col2, col3, col4])
    # Removed the dead no-op statement `labels = labels`.
    return features, labels
def input(sub_dir, batch_size):
    """Build the shuffled training input pipeline.

    NOTE: this function shadows the builtin ``input``; the name is kept
    because ``main`` calls it by this name.

    Args:
        sub_dir: directory containing ``train_list.csv``.
        batch_size: number of samples per dequeued batch.

    Returns:
        (image_names, batch_image, batch_label) batched tensors.
    """
    csv_path = os.path.join(sub_dir, 'train_list.csv')
    file_queue = tf.train.string_input_producer([csv_path])
    name_op, label_op = read_decode_csv(file_queue)
    image_op = data_augment(name_op)
    return tf.train.shuffle_batch(
        [name_op, image_op, label_op],
        batch_size=batch_size,
        capacity=1000 + 3 * batch_size,
        min_after_dequeue=256,
        num_threads=4,
        allow_smaller_final_batch=False)
def main(args):
    """Train the network with synchronous data-parallelism over N_GPU GPUs.

    One tower per GPU computes loss/gradients on its slice of the batch;
    gradients are averaged on the CPU and applied once per step.
    """
    print("training online stage")
    steps_per_epoch = args.nums_samples // args.batch_size
    # NOTE(review): project-local module providing `siamese_eval` — not
    # visible here; assumed to return (predictions, endpoints).
    network = importlib.import_module('inference.ShuffleNet_v1')
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        image_names, image_batch, labels_batch = input(
            args.sub_dir, batch_size=args.batch_size)
        # Bug fix: global_step was created with dtype=tf.float16.
        # Optimizers and exponential_decay require an integer step counter.
        global_step = tf.Variable(0, trainable=False, name="global_step",
                                  dtype=tf.int64)
        learning_rate = tf.train.exponential_decay(
            args.learning_rate, global_step,
            args.learning_rate_decay_epochs * steps_per_epoch,
            args.learning_rate_decay_factor, staircase=True)
        opt = tf.train.RMSPropOptimizer(learning_rate, 0.9,
                                        momentum=0.9, epsilon=1.0)
        # Split the batch evenly across towers.
        images_splits = tf.split(image_batch, num_or_size_splits=N_GPU, axis=0)
        label_splits = tf.split(labels_batch, num_or_size_splits=N_GPU, axis=0)
        reuse_variables = None
        tower_loss_value = []
        tower_grads_and_vars = []
        for i in range(N_GPU):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('GPU_%d' % i):
                    # All towers share variables: first tower creates them,
                    # the rest reuse.
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=reuse_variables):
                        outputs, _ = network.siamese_eval(
                            images_splits[i], num_classes=3, is_training=True)
                    loss_value = get_loss(outputs, label_splits[i])
                    grads_and_vars = opt.compute_gradients(loss_value)
                    tower_loss_value.append(loss_value)
                    tower_grads_and_vars.append(grads_and_vars)
                    reuse_variables = True
        total_loss = tf.reduce_mean(tower_loss_value, axis=0)
        grads = average_gradients(tower_grads_and_vars)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        variable_averages = tf.train.ExponentialMovingAverage(
            args.moving_average_decay, global_step)
        variables_averages_op = variable_averages.apply(
            tf.trainable_variables() + tf.moving_average_variables())
        with tf.control_dependencies([apply_gradient_op,
                                      variables_averages_op]):
            train_op = tf.no_op(name='train')
        saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=3)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.Session(config=config)
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        if args.pretrained_model != '':
            saver.restore(sess, args.pretrained_model)
        # Bug fix: os.mkdir fails when intermediate directories are missing
        # or the directory already exists.
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        checkpoint_file = os.path.join(args.checkpoint_dir, 'model.ckpt')
        step_global = 0
        epoch = 0
        while epoch < args.nums_epoch:
            step = 0
            while step < steps_per_epoch:
                began = time.time()
                # Bug fix: the original ran an extra sess.run that dequeued
                # a full batch only to discard it (the fetched values were
                # never used), silently skipping half the training data.
                _, step_global, loss, lr = sess.run(
                    [train_op, global_step, total_loss, learning_rate])
                duration = time.time() - began
                print('epoch %02d [%d/%d]\tTime %.3f\tloss %2.5f\t(lr %.5f)' %
                      (epoch, step_global, steps_per_epoch, duration, loss, lr))
                step += 1
                if step_global % 4000 == 0:
                    saver.save(sess, checkpoint_file, global_step=step_global)
            epoch += 1
        saver.save(sess, checkpoint_file, global_step=step_global)
        coord.request_stop()
        coord.join(threads)
        sess.close()
def parse_arguments(argv):
    """Parse command-line arguments for the training script.

    Args:
        argv: list of argument strings (e.g. ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with the training hyperparameters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--learning_rate', type=float, default=0.1,
                        help='Initial learning rate.')
    parser.add_argument('--learning_rate_decay_epochs', type=int, default=5,
                        help='Number of epochs between learning-rate decays.')
    parser.add_argument('--moving_average_decay', type=float, default=0.99,
                        help='Decay rate for the exponential moving average '
                             'of the model variables.')
    parser.add_argument('--learning_rate_decay_factor', type=float,
                        default=0.98,
                        help='Multiplicative learning-rate decay factor.')
    parser.add_argument('--batch_size', type=int, default=400,
                        help='Total batch size (split across all GPUs).')
    parser.add_argument('--pretrained_model', type=str, default='',
                        help='Checkpoint to restore before training; '
                             'empty string starts from scratch.')
    parser.add_argument('--checkpoint_dir', type=str, default='./model',
                        help='Directory where checkpoints are saved.')
    parser.add_argument('--sub_dir', type=str, default='./',
                        help='Directory containing train_list.csv.')
    parser.add_argument('--nums_samples', type=int, default=300000,
                        help='Number of training samples per epoch.')
    parser.add_argument('--nums_epoch', type=int, default=1000,
                        help='Number of training epochs.')
    return parser.parse_args(argv)
# Script entry point: parse CLI arguments and launch multi-GPU training.
if __name__ == '__main__':
    main(parse_arguments(sys.argv[1:]))