# 代码来源《深度学习:卷积神经网络从入门到精通》,使用oxflower-17数据集
# (Source: "Deep Learning: CNNs from Beginner to Mastery"; uses the oxflower-17 dataset.)
# train.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import os.path
import time
import numpy as np
from six.moves import xrange
import tensorflow as tf
import data_loader
import arch
import sys
import argparse
def loss(logits, labels):
    """Build the mean cross-entropy loss op and log it as a scalar summary.

    Args:
        logits: unscaled class scores, shape (batch, num_classes).
        labels: integer class ids, shape (batch,).

    Returns:
        Scalar tensor: mean sparse softmax cross-entropy over the batch.
    """
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels,
        logits=logits,
        name='cross_entropy_per_example',
    )
    mean_loss = tf.reduce_mean(per_example, name='cross_entropy')
    # Record the batch-mean loss so it appears in TensorBoard.
    tf.summary.scalar('Cross Entropy Loss', mean_loss)
    return mean_loss
def average_gradients(tower_grads):
    """Average per-variable gradients across all GPU towers.

    Args:
        tower_grads: list with one entry per tower, each a list of
            (gradient, variable) pairs as returned by
            Optimizer.compute_gradients.

    Returns:
        A single list of (gradient, variable) pairs in which every
        gradient is the mean of that variable's gradients over all towers.
    """
    averaged = []
    # zip(*...) regroups the per-tower lists so each iteration sees the
    # (grad, var) pairs belonging to ONE variable across ALL towers.
    for per_var_pairs in zip(*tower_grads):
        # Give every gradient a leading "tower" axis, then average over it.
        expanded = [tf.expand_dims(g, 0) for g, _ in per_var_pairs]
        stacked = tf.concat(axis=0, values=expanded)
        mean_grad = tf.reduce_mean(stacked, 0)
        # Variables are shared between towers, so the first pair's variable
        # stands in for all of them.
        averaged.append((mean_grad, per_var_pairs[0][1]))
    return averaged
def train(args): # 定义训练过程
with tf.device('/cpu:0'):
images, labels = data_loader.read_inputs(True, args)
epoch_number = tf.get_variable('epoch_number', [], dtype=tf.int32,
initializer=tf.constant_initializer(0), trainable=False)
lr = tf.train.piecewise_constant(epoch_number, [19, 30, 44, 53],
[0.01, 0.005, 0.001, 0.0005, 0.0001], name='LearningRate')
wd = tf.train.piecewise_constant(epoch_number, [30], [0.0005, 0.0],
name='WeightDecay')
opt = tf.train.MomentumOptimizer(lr, 0.9) # 使用动量优化方法
tower_grads = []
with tf.variable_scope(tf.get_variable_scope()):
for i in xrange(args.num_gpus):
with tf.device('/gpu:%d' % i):
with tf.name_scope('Tower_%d' % i) as scope:
logits = arch.get_model(images, wd, True, args)
top1acc = tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits,
labels, 1), tf.float32)) # top-1准确率
top5acc = tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits,
labels, 5), tf.float32)) # top-5准确率
cross_entropy_mean = loss(logits, labels)
regularization_losses = tf.get_collection(tf.
GraphKeys.REGULARIZATION_LOSSES)
reg_loss = tf.add_n(regularization_losses)
# 对应位置元素相加
tf.summary.scalar('Regularization Loss', reg_loss)
# 对reg_loss标量汇总和记录
total_loss = tf.add(cross_entropy_mean, reg_loss)
tf.summary.scalar('Total Loss', total_loss)
# 对total_loss标量汇总和记录
tf.summary.scalar('Top-1 Accuracy', top1acc)
# 对top1acc标量汇总和记录
tf.summary.scalar('Top-5 Accuracy', top5acc)
# 对top5acc标量汇总和记录
tf.get_variable_scope().reuse_variables()
# 表示允许重用当前scope下所有变量
summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
batchnorm_updates = tf.get_collection(tf.GraphKeys.
UPDATE_OPS, scope)
grads = opt.compute_gradients(total_loss)
# 按批计算数据的梯度
tower_grads.append(grads)
grads = average_gradients(tower_grads)
summaries.append(tf.summary.scalar('learning_rate', lr))
summaries.append(tf.summary.scalar('weight_decay', wd))
apply_gradient_op = opt.apply_gradients(grads) # 更新模型的权值参数
batchnorm_updates_op = tf.group(*batchnorm_updates) # 更新BN层的参数
train_op = tf.group(apply_gradient_op, batchnorm_updates_op)
saver = tf.train.Saver(tf.global_variables(), max_to_keep=args.num_epochs)
summary_op = tf.summary.merge_all()
init = tf.global_variables_initializer()
if args.log_debug_info:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
else:
run_options = None
run_metadata = None
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
log_device_placement=args.log_device_placement))
if args.retrain_from is not None:
saver.restore(sess, args.retrain_from)
else:
sess.run(init)
tf.train.start_queue_runners(sess=sess) # 启动输入管道的线程
summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph)
start_epoch = sess.run(epoch_number + 1)
for epoch in range(start_epoch, start_epoch + args.num_epochs):
sess.run(epoch_number.assign(epoch))
for step in range(args.num_batches):
start_time = time.time()
_, loss_value, top1_accuracy, top5_accuracy = sess.run([train_op,
cross_entropy_mean,
top1acc, top5acc],
options=run_options,
run_metadata=run_metadata)
duration = time.time() - start_time
assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
if step % 10 == 0:
num_examples_per_step = args.chunked_batch_size * args.num_gpus
examples_per_sec = num_examples_per_step / duration
sec_per_batch = duration / args.num_gpus
format_str = (
'%s: epoch %d, step %d, loss = %.2f, Top-1 = %.2f Top-5 = %.2f (%.1f examples/sec; %.3f sec/batch)')
print(format_str % (datetime.now(), epoch, step, loss_value,
top1_accuracy, top5_accuracy,
examples_per_sec, sec_per_batch))
sys.stdout.flush() # 等到程序执行完毕在屏幕上一次性输出结果
if step % 100 == 0:
summary_str = sess.run(summary_op)
summary_writer.add_summary(summary_str, args.num_batches * epoch + step)
# 写入文件
if args.log_debug_info:
summary_writer.add_run_metadata(run_metadata, 'e