TensorFlow实现单机多卡GPU训练

版权声明:未经本人许可,不得用于商业用途及传统媒体。转载请注明出处! https://blog.csdn.net/qikaihuting/article/details/83042049

思想

  • 采用数据并行:将每个batch切分后分发到多张GPU上,各GPU分别计算梯度;优化时再对各GPU的梯度取平均进行合并,然后统一更新参数
  • 话不多说,上代码
#coding=utf-8
from __future__ import absolute_import, division, print_function

import os
import sys
import csv
import cv2
import time
import random
import argparse

import importlib
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

# GPUs used for training, given as a comma-separated list of CUDA device ids.
devices = '0,1,2,3'
# Expose only those GPUs to this process; TensorFlow then addresses them as
# /gpu:0 ... /gpu:N_GPU-1 (which is how main() places its towers).
os.environ['CUDA_VISIBLE_DEVICES'] = devices
# One tower per listed device: one more device than there are separators.
N_GPU = devices.count(',') + 1

def get_loss(pred_angle, labels_batch):
    """Return the training objective: mean squared error plus regularization.

    Args:
        pred_angle: model predictions.
        labels_batch: regression targets; subtracted element-wise from
            ``pred_angle``.

    Returns:
        Tensor named 'total_loss' (inside the current name scope) equal to
        the MSE plus every tensor in the REGULARIZATION_LOSSES collection.
    """
    residual = pred_angle - labels_batch
    mse = tf.reduce_mean(tf.square(residual))
    # Every regularization term registered in the graph so far is added.
    loss_terms = [mse] + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    return tf.add_n(loss_terms, name='total_loss')

def average_gradients(tower_grads):
    """Average per-variable gradients across GPU towers.

    Args:
        tower_grads: one list of ``(gradient, variable)`` pairs per tower, as
            returned by ``Optimizer.compute_gradients``. The pairs are matched
            by position, so the code assumes every tower lists the variables
            in the same order.

    Returns:
        A list of ``(mean_gradient, variable)`` pairs, using the variable from
        the first tower. A variable whose gradient is None in every tower is
        left out; otherwise the mean is taken over the non-None gradients.
    """
    averaged = []
    # zip(*...) regroups the per-tower lists into per-variable tuples.
    for per_variable in zip(*tower_grads):
        present = [grad for grad, _ in per_variable if grad is not None]
        if not present:
            continue
        # Stack along a new leading "tower" axis, then average over it.
        mean_grad = tf.reduce_mean(tf.stack(present, axis=0), axis=0)
        averaged.append((mean_grad, per_variable[0][1]))
    return averaged
def data_augment(image_name):
    """Read one PNG file and turn it into a normalized float32 image tensor.

    Despite the name, no augmentation is applied: the image is only decoded,
    rescaled and reshaped.

    Args:
        image_name: scalar string tensor holding the path of a PNG file.

    Returns:
        float32 tensor of shape [224, 224, 3].
    """
    content = tf.read_file(image_name)
    # channels=3 forces an RGB decode, so the reshape below does not fail on
    # grayscale or RGBA PNGs.
    img = tf.image.decode_png(content, channels=3, dtype=tf.uint16)
    # NOTE(review): the PNG is decoded as uint16 but scaled by 1/255, which
    # only lands in [-0.5, 0.5] if pixel values never exceed 255. Confirm the
    # data's bit depth (uint8 decoding or a 1/65535 scale may be intended).
    images = tf.cast(img, tf.float32) * (1.0 / 255) - 0.5
    # reshape does not resize: source images must already be 224x224.
    images = tf.reshape(images, shape=[224, 224, 3])
    return images

def read_decode_csv(filename_queue):
    """Read one CSV line from the queue and split it into path and labels.

    Each line is decoded as four comma-separated columns: a string (used as
    the image path by ``input``) followed by three float targets.

    Args:
        filename_queue: queue of CSV file names, e.g. from
            ``tf.train.string_input_producer``.

    Returns:
        (features, labels): a scalar string tensor holding the first column,
        and a float32 tensor of shape [3] stacking the other three columns.
    """
    reader = tf.TextLineReader()
    _, value = reader.read(filename_queue)
    # Defaults fill empty columns and also fix each column's decoded dtype.
    record_defaults = [["NULL"], [0.0], [0.0], [0.0]]
    features, col2, col3, col4 = tf.decode_csv(value, record_defaults=record_defaults)
    labels = tf.stack([col2, col3, col4])
    return features, labels

def input(sub_dir, batch_size):
    """Build the shuffled training input pipeline.

    Reads ``train_list.csv`` from ``sub_dir``, loads the image named in each
    row together with its three labels, and batches the examples in random
    order. No ``num_epochs`` is passed to the producer, so the list is
    presumably cycled without limit.

    Args:
        sub_dir: directory containing ``train_list.csv``.
        batch_size: number of examples per returned batch.

    Returns:
        (image_names, batch_image, batch_label) tensors, each with a leading
        dimension of exactly ``batch_size`` (smaller final batches are
        disallowed).
    """
    csv_path = os.path.join(sub_dir, 'train_list.csv')
    filename_queue = tf.train.string_input_producer([csv_path])
    name_op, target_op = read_decode_csv(filename_queue)
    pixels_op = data_augment(name_op)
    # shuffle_batch is backed by a RandomShuffleQueue filled by 4 enqueue
    # threads; min_after_dequeue bounds how thoroughly examples are mixed.
    batch_tensors = tf.train.shuffle_batch(
        [name_op, pixels_op, target_op],
        batch_size=batch_size,
        capacity=1000 + 3 * batch_size,
        min_after_dequeue=256,
        num_threads=4,
        allow_smaller_final_batch=False,
    )
    image_names, batch_image, batch_label = batch_tensors
    return image_names, batch_image, batch_label
def main(args):
    """Build the data-parallel multi-GPU training graph and run training.

    One input batch is split evenly into ``N_GPU`` shards. Each GPU tower runs
    the network on its shard with shared variables and computes its own
    gradients. The gradients are averaged (``average_gradients``) and applied
    once per step. Checkpoints are written to ``args.checkpoint_dir``.

    Args:
        args: namespace produced by ``parse_arguments``.

    Raises:
        ValueError: if ``args.batch_size`` is not divisible by ``N_GPU``, or if
            ``args.nums_samples`` is smaller than ``args.batch_size``. In the
            second case an epoch would have zero steps and the learning-rate
            decay period would be zero.
    """
    print("training online stage")
    # tf.split below needs one equal-sized shard per GPU.
    if args.batch_size % N_GPU != 0:
        raise ValueError('batch_size (%d) must be divisible by the number of GPUs (%d)'
                         % (args.batch_size, N_GPU))
    steps_per_epoch = args.nums_samples // args.batch_size
    if steps_per_epoch < 1:
        raise ValueError('nums_samples (%d) must be >= batch_size (%d)'
                         % (args.nums_samples, args.batch_size))
    network = importlib.import_module('inference.ShuffleNet_v1')
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Input pipeline producing the full (all-GPU) batch.
        _, image_batch, labels_batch = input(args.sub_dir, batch_size=args.batch_size)

        # Integer step counter. A float16 counter stops advancing at 2048,
        # because 2048 + 1 rounds back to 2048 in half precision. That would
        # freeze the learning-rate schedule and the periodic checkpointing.
        global_step = tf.Variable(0, trainable=False, name="global_step", dtype=tf.int64)

        learning_rate = tf.train.exponential_decay(
            args.learning_rate, global_step,
            args.learning_rate_decay_epochs * steps_per_epoch,
            args.learning_rate_decay_factor, staircase=True)
        opt = tf.train.RMSPropOptimizer(learning_rate, 0.9, momentum=0.9, epsilon=1.0)

        # Data parallelism: split the batch into one shard per GPU.
        images_splits = tf.split(image_batch, num_or_size_splits=N_GPU, axis=0)
        label_splits = tf.split(labels_batch, num_or_size_splits=N_GPU, axis=0)
        reuse_variables = None
        tower_loss_value = []
        tower_grads_and_vars = []
        update_ops = []
        for i in range(N_GPU):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('GPU_%d' % i) as scope:
                    # Tower 0 creates the variables; later towers reuse them.
                    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
                        outputs, _ = network.siamese_eval(images_splits[i], num_classes=3, is_training=True)
                        loss_value = get_loss(outputs, label_splits[i])
                        grads_and_vars = opt.compute_gradients(loss_value)
                        tower_loss_value.append(loss_value)
                        tower_grads_and_vars.append(grads_and_vars)
                        if i == 0:
                            # apply_gradients does not run ops the network put in
                            # UPDATE_OPS (e.g. batch-norm moving-statistics updates,
                            # if the network uses batch norm). They are taken from
                            # the first tower only, so each statistic is updated
                            # once per step.
                            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)
            reuse_variables = True

        # Reported loss: mean of the per-tower losses.
        total_loss = tf.reduce_mean(tower_loss_value, axis=0)
        # Merge the per-GPU gradients and apply them once.
        grads = average_gradients(tower_grads_and_vars)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        variable_averages = tf.train.ExponentialMovingAverage(
            args.moving_average_decay, global_step)
        variables_averages_op = variable_averages.apply(tf.trainable_variables() + tf.moving_average_variables())

        with tf.control_dependencies([apply_gradient_op, variables_averages_op] + update_ops):
            train_op = tf.no_op(name='train')

        # Save every global variable. This way checkpoints also contain the EMA
        # shadow variables, other non-trainable state, optimizer slots and
        # global_step. A Saver over trainable variables alone drops all of these.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=3)
        # Pretrained weights are restored from trainable variables only, so a
        # checkpoint that contains just those variables can still be loaded.
        restorer = None
        if args.pretrained_model != '':
            restorer = tf.train.Saver(tf.trainable_variables())

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True

        sess = tf.Session(config=config)
        coord = tf.train.Coordinator()
        threads = []
        try:
            sess.run(tf.global_variables_initializer())
            threads = tf.train.start_queue_runners(coord=coord, sess=sess)

            if restorer is not None:
                restorer.restore(sess, args.pretrained_model)

            if not os.path.exists(args.checkpoint_dir):
                os.makedirs(args.checkpoint_dir)
            checkpoint_file = os.path.join(args.checkpoint_dir, 'model.ckpt')

            for epoch in range(args.nums_epoch):
                for step in range(steps_per_epoch):
                    began = time.time()
                    # train_op dequeues its own batch. Fetching the input tensors
                    # in a separate run would dequeue and discard an extra batch.
                    _, step_global, loss, lr = sess.run([train_op, global_step, total_loss, learning_rate])
                    duration = time.time() - began
                    print('epoch %02d [%d/%d]\tTime %.3f\tloss %2.5f\t(lr %.5f)' %
                          (epoch, step + 1, steps_per_epoch, duration, loss, lr))
                    if step_global % 4000 == 0:
                        saver.save(sess, checkpoint_file, global_step=step_global)
            # Read the counter directly so the final save also works when no
            # training step ran (e.g. nums_epoch == 0).
            saver.save(sess, checkpoint_file, global_step=sess.run(global_step))
        finally:
            coord.request_stop()
            coord.join(threads)
            sess.close()


def parse_arguments(argv):
    """Parse the command-line options of the multi-GPU training script.

    Args:
        argv: argument strings without the program name, e.g. ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option in the table below.
    """
    # (flag, type, default), registered with argparse in this order.
    option_table = (
        ('--learning_rate', float, 0.1),
        ('--learning_rate_decay_epochs', int, 5),
        ('--moving_average_decay', float, 0.99),
        ('--learning_rate_decay_factor', float, 0.98),
        ('--batch_size', int, 400),
        ('--pretrained_model', str, ''),
        ('--checkpoint_dir', str, './model'),
        ('--sub_dir', str, './'),
        ('--nums_samples', int, 300000),
        ('--nums_epoch', int, 1000),
    )
    parser = argparse.ArgumentParser()
    for flag, arg_type, default_value in option_table:
        parser.add_argument(flag, type=arg_type, default=default_value)
    return parser.parse_args(argv)

if __name__ == '__main__':
    # Parse everything after the program name and hand it to the trainer.
    cli_args = parse_arguments(sys.argv[1:])
    main(cli_args)
展开阅读全文

没有更多推荐了,返回首页