TensorFlow Deep Learning (18): Multi-GPU Parallel Training

This article follows the multi-GPU example from the book 《TensorFlow实战Google深度学习框架》. The approach is data parallelism: each GPU builds an identical "tower" that computes the loss and gradients for its own batch of CIFAR-10 images, while the CPU averages the per-tower gradients and applies a single synchronous parameter update.

import time
import tensorflow as tf
import cifar10

batch_size = 128   # examples processed per GPU per step
max_steps = 1000   # total number of training steps
num_gpus = 1       # number of GPUs to use

def tower_loss(scope):
    """Build the CIFAR-10 model on one GPU ("tower") and return its total loss."""
    # Each tower reads its own batch from the shared input queue.
    images, labels = cifar10.distorted_inputs()
    logits = cifar10.inference(images)
    _ = cifar10.loss(logits, labels)
    # Collect only the losses created inside this tower's name scope.
    losses = tf.get_collection('losses', scope)
    total_loss = tf.add_n(losses, name='total_loss')
    return total_loss

def average_gradients(tower_grads):
    """Average the gradients computed by the individual towers.

    tower_grads is a list (one entry per GPU) of lists of (gradient, variable)
    pairs; the result is a single list of averaged (gradient, variable) pairs.
    """
    average_grads = []
    # zip(*tower_grads) groups the gradients of the same variable across towers.
    for grad_and_vars in zip(*tower_grads):
        grads = []
        for g, _ in grad_and_vars:
            # Add a leading "tower" dimension so the gradients can be stacked.
            expanded_g = tf.expand_dims(g, 0)
            grads.append(expanded_g)

        # Stack along the tower dimension and average over it.
        grad = tf.concat(grads, 0)
        grad = tf.reduce_mean(grad, 0)
        # Variables are shared across towers, so the first tower's pointer suffices.
        v = grad_and_vars[0][1]
        average_grads.append((grad, v))
    return average_grads

def train():
    # Build the graph with the CPU as the default device; per-tower computation
    # is explicitly placed on the GPUs below.
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        # Exponentially decay the learning rate based on the number of epochs seen.
        num_batches_per_epoch = cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / batch_size
        decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY)
        lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE,
                                        global_step,
                                        decay_steps,
                                        cifar10.LEARNING_RATE_DECAY_FACTOR,
                                        staircase=True)
        opt = tf.train.GradientDescentOptimizer(lr)

        # Build one tower per GPU; all towers share the same variables.
        tower_grads = []
        for i in range(num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
                    loss = tower_loss(scope)
                    # Reuse the variables created by the first tower in all later towers.
                    tf.get_variable_scope().reuse_variables()
                    grads = opt.compute_gradients(loss)
                    tower_grads.append(grads)

        # Average the per-tower gradients and apply a single synchronous update.
        grads = average_gradients(tower_grads)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        saver = tf.train.Saver(tf.global_variables())
        init = tf.global_variables_initializer()
        # allow_soft_placement lets ops without a GPU kernel fall back to the CPU.
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(init)
        # Start the input-queue threads that feed distorted CIFAR-10 batches.
        tf.train.start_queue_runners(sess=sess)

        for step in range(max_steps):
            start_time = time.time()
            _, loss_value = sess.run([apply_gradient_op, loss])
            duration = time.time() - start_time

            if step % 10 == 0:
                num_examples_per_step = batch_size * num_gpus
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = duration / num_gpus

                format_str = 'step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)'
                print(format_str % (step, loss_value, examples_per_sec, sec_per_batch))

            if step % 1000 == 0 or (step + 1) == max_steps:
                # The cifar10_train directory must already exist next to this script.
                saver.save(sess, 'cifar10_train/model.ckpt', global_step=step)


if __name__ == '__main__':
    cifar10.maybe_download_and_extract()
    train()
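
The core of the multi-GPU pattern is average_gradients. Below is a minimal sketch, assuming TensorFlow 1.x and two hypothetical towers with made-up gradient values, that exercises the function on constants so its averaging behaviour can be checked without a full CIFAR-10 run:

import tensorflow as tf

# A single shared variable, as all towers would see it (name and values are hypothetical).
w = tf.get_variable('toy_w', initializer=tf.constant([1.0, 2.0]))

# Hypothetical per-tower (gradient, variable) lists for that variable.
tower_grads = [
    [(tf.constant([2.0, 4.0]), w)],   # gradients from "tower 0"
    [(tf.constant([4.0, 8.0]), w)],   # gradients from "tower 1"
]

avg = average_gradients(tower_grads)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(avg[0][0]))   # expected: [3. 6.], the element-wise mean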
Console output from a run of the script above (num_gpus = 1):

D:\Python\Anaconda3\python.exe C:/Users/amax/Desktop/UNet/test.py
>> Downloading cifar-10-binary.tar.gz 100.0%
Successfully downloaded cifar-10-binary.tar.gz 170052171 bytes.
Filling queue with 20000 CIFAR images before starting to train. This will take a few minutes.

step 0, loss = 4.67 (3.9 examples/sec; 33.204 sec/batch)
step 10, loss = 4.60 (335.4 examples/sec; 0.382 sec/batch)
step 20, loss = 4.35 (451.1 examples/sec; 0.284 sec/batch)
step 30, loss = 4.28 (245.9 examples/sec; 0.520 sec/batch)
step 40, loss = 4.35 (630.1 examples/sec; 0.203 sec/batch)
step 50, loss = 4.30 (655.4 examples/sec; 0.195 sec/batch)
step 60, loss = 4.39 (409.6 examples/sec; 0.313 sec/batch)
step 70, loss = 4.19 (340.4 examples/sec; 0.376 sec/batch)
step 80, loss = 4.09 (543.5 examples/sec; 0.235 sec/batch)
step 90, loss = 4.45 (512.0 examples/sec; 0.250 sec/batch)
step 100, loss = 4.08 (716.9 examples/sec; 0.179 sec/batch)
step 110, loss = 4.09 (459.7 examples/sec; 0.278 sec/batch)
step 120, loss = 4.05 (744.7 examples/sec; 0.172 sec/batch)
step 130, loss = 4.09 (621.2 examples/sec; 0.206 sec/batch)
step 140, loss = 3.98 (512.0 examples/sec; 0.250 sec/batch)
step 150, loss = 3.96 (420.5 examples/sec; 0.304 sec/batch)
step 160, loss = 3.92 (585.1 examples/sec; 0.219 sec/batch)
step 170, loss = 3.82 (701.5 examples/sec; 0.182 sec/batch)
step 180, loss = 3.74 (585.1 examples/sec; 0.219 sec/batch)
step 190, loss = 3.96 (251.9 examples/sec; 0.508 sec/batch)
step 200, loss = 3.74 (561.4 examples/sec; 0.228 sec/batch)
step 210, loss = 3.91 (440.4 examples/sec; 0.291 sec/batch)
step 220, loss = 3.74 (627.5 examples/sec; 0.204 sec/batch)
step 230, loss = 3.94 (608.2 examples/sec; 0.210 sec/batch)
step 240, loss = 3.74 (650.8 examples/sec; 0.197 sec/batch)
step 250, loss = 3.65 (658.4 examples/sec; 0.194 sec/batch)
step 260, loss = 3.69 (716.3 examples/sec; 0.179 sec/batch)
step 270, loss = 3.73 (768.2 examples/sec; 0.167 sec/batch)
step 280, loss = 3.52 (384.9 examples/sec; 0.333 sec/batch)
step 290, loss = 3.69 (273.5 examples/sec; 0.468 sec/batch)
step 300, loss = 3.57 (189.2 examples/sec; 0.677 sec/batch)
step 310, loss = 3.54 (656.7 examples/sec; 0.195 sec/batch)
step 320, loss = 3.76 (323.2 examples/sec; 0.396 sec/batch)
step 330, loss = 3.46 (682.7 examples/sec; 0.188 sec/batch)
step 340, loss = 3.45 (803.8 examples/sec; 0.159 sec/batch)
step 350, loss = 3.48 (808.5 examples/sec; 0.158 sec/batch)
step 360, loss = 3.29 (429.3 examples/sec; 0.298 sec/batch)
step 370, loss = 3.46 (682.7 examples/sec; 0.188 sec/batch)
step 380, loss = 3.28 (791.7 examples/sec; 0.162 sec/batch)
step 390, loss = 3.38 (355.7 examples/sec; 0.360 sec/batch)
step 400, loss = 3.53 (716.6 examples/sec; 0.179 sec/batch)
step 410, loss = 3.18 (497.8 examples/sec; 0.257 sec/batch)
step 420, loss = 3.26 (630.1 examples/sec; 0.203 sec/batch)
step 430, loss = 3.23 (714.0 examples/sec; 0.179 sec/batch)
step 440, loss = 3.24 (148.4 examples/sec; 0.862 sec/batch)
step 450, loss = 3.23 (400.4 examples/sec; 0.320 sec/batch)
step 460, loss = 3.19 (552.3 examples/sec; 0.232 sec/batch)
step 470, loss = 3.26 (608.2 examples/sec; 0.210 sec/batch)
step 480, loss = 3.25 (228.6 examples/sec; 0.560 sec/batch)
step 490, loss = 3.22 (615.7 examples/sec; 0.208 sec/batch)
step 500, loss = 3.44 (720.1 examples/sec; 0.178 sec/batch)
step 510, loss = 3.06 (682.7 examples/sec; 0.188 sec/batch)
step 520, loss = 3.27 (453.8 examples/sec; 0.282 sec/batch)
step 530, loss = 3.09 (682.7 examples/sec; 0.188 sec/batch)
step 540, loss = 3.02 (674.3 examples/sec; 0.190 sec/batch)
step 550, loss = 3.17 (591.7 examples/sec; 0.216 sec/batch)
step 560, loss = 3.11 (570.4 examples/sec; 0.224 sec/batch)
step 570, loss = 3.17 (271.3 examples/sec; 0.472 sec/batch)
step 580, loss = 3.13 (536.0 examples/sec; 0.239 sec/batch)
step 590, loss = 2.98 (732.6 examples/sec; 0.175 sec/batch)
step 600, loss = 3.03 (819.2 examples/sec; 0.156 sec/batch)
step 610, loss = 3.01 (764.0 examples/sec; 0.168 sec/batch)
step 620, loss = 3.08 (752.6 examples/sec; 0.170 sec/batch)
step 630, loss = 3.09 (512.0 examples/sec; 0.250 sec/batch)
step 640, loss = 3.09 (481.9 examples/sec; 0.266 sec/batch)
step 650, loss = 2.96 (498.8 examples/sec; 0.257 sec/batch)
step 660, loss = 2.71 (658.9 examples/sec; 0.194 sec/batch)
step 670, loss = 2.85 (186.4 examples/sec; 0.687 sec/batch)
step 680, loss = 2.77 (557.5 examples/sec; 0.230 sec/batch)
step 690, loss = 3.06 (326.0 examples/sec; 0.393 sec/batch)
step 700, loss = 2.86 (715.7 examples/sec; 0.179 sec/batch)
step 710, loss = 2.77 (442.5 examples/sec; 0.289 sec/batch)
step 720, loss = 2.86 (496.5 examples/sec; 0.258 sec/batch)
step 730, loss = 2.85 (254.7 examples/sec; 0.503 sec/batch)
step 740, loss = 2.83 (388.6 examples/sec; 0.329 sec/batch)
step 750, loss = 2.73 (294.3 examples/sec; 0.435 sec/batch)
step 760, loss = 2.77 (481.9 examples/sec; 0.266 sec/batch)
step 770, loss = 2.59 (245.2 examples/sec; 0.522 sec/batch)
step 780, loss = 2.75 (555.0 examples/sec; 0.231 sec/batch)
step 790, loss = 2.73 (691.7 examples/sec; 0.185 sec/batch)
step 800, loss = 2.89 (706.2 examples/sec; 0.181 sec/batch)
step 810, loss = 2.76 (434.3 examples/sec; 0.295 sec/batch)
step 820, loss = 2.80 (274.2 examples/sec; 0.467 sec/batch)
step 830, loss = 2.93 (268.9 examples/sec; 0.476 sec/batch)
step 840, loss = 2.72 (379.6 examples/sec; 0.337 sec/batch)
step 850, loss = 2.53 (608.1 examples/sec; 0.210 sec/batch)
step 860, loss = 2.58 (715.8 examples/sec; 0.179 sec/batch)
step 870, loss = 2.53 (703.4 examples/sec; 0.182 sec/batch)
step 880, loss = 2.58 (277.7 examples/sec; 0.461 sec/batch)
step 890, loss = 2.77 (471.5 examples/sec; 0.272 sec/batch)
step 900, loss = 2.55 (744.7 examples/sec; 0.172 sec/batch)
step 910, loss = 2.60 (274.1 examples/sec; 0.467 sec/batch)
step 920, loss = 2.51 (630.1 examples/sec; 0.203 sec/batch)
step 930, loss = 2.53 (538.1 examples/sec; 0.238 sec/batch)
step 940, loss = 2.54 (630.1 examples/sec; 0.203 sec/batch)
step 950, loss = 2.45 (398.2 examples/sec; 0.321 sec/batch)
step 960, loss = 2.43 (283.2 examples/sec; 0.452 sec/batch)
step 970, loss = 2.60 (350.2 examples/sec; 0.365 sec/batch)
step 980, loss = 2.41 (566.6 examples/sec; 0.226 sec/batch)
step 990, loss = 2.55 (296.2 examples/sec; 0.432 sec/batch)
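
The checkpoints written to cifar10_train/model.ckpt can later be restored to evaluate the model or to continue training. A minimal sketch, assuming TensorFlow 1.x and that the same graph (for example via cifar10.inference) has already been rebuilt in the current process:

import tensorflow as tf

# Rebuild the model graph first, then create a Saver over its variables.
saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint('cifar10_train')

with tf.Session() as sess:
    saver.restore(sess, ckpt)   # variables now hold the trained values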