def read_and_decode(filename_queue):
    """Read one serialized example from a TFRecord queue and decode it.

    Args:
        filename_queue: queue of TFRecord file names, passed straight to
            ``tf.TFRecordReader.read``.

    Returns:
        A ``(image, label)`` pair of tensors:
        ``image`` is float32 with shape ``[100, 100, 3]``, scaled from the
        raw uint8 bytes to the range ``[-0.5, 0.5]``; ``label`` is a scalar
        int32.
    """
    # 'height', 'width' and 'channels' are part of the parsed feature spec
    # but are not used below: the image shape is fixed by the reshape.
    feature_spec = {
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'channels': tf.FixedLenFeature([], tf.int64),
        'image_data': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }

    record_reader = tf.TFRecordReader()
    _, serialized_example = record_reader.read(filename_queue)
    parsed = tf.parse_single_example(serialized_example,
                                     features=feature_spec)

    # Raw bytes -> uint8 vector -> fixed-size HxWxC image.
    raw_pixels = tf.decode_raw(parsed['image_data'], tf.uint8)
    pixels = tf.reshape(raw_pixels, [100, 100, 3])

    # Map [0, 255] to [-0.5, 0.5].
    image = tf.cast(pixels, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(parsed['label'], tf.int32)
    return image, label
def inputs(filename, batch_size, num_epochs=2000, num_threads=1,
           capacity=4, min_after_dequeue=2):
    """Build a shuffled batch input pipeline over one TFRecord file.

    The defaults of the optional arguments reproduce the values that were
    previously hard-coded, so ``inputs(filename, batch_size)`` behaves as
    before.

    Args:
        filename: path of the TFRecord file to read.
        batch_size: number of examples per returned batch.
        num_epochs: how many times the file name is produced before the
            queue signals end of input. When this is not ``None`` the
            producer keeps a local epoch counter, so the caller has to run
            the local variables initializer before starting queue runners
            (see the TensorFlow ``string_input_producer`` documentation).
        num_threads: number of threads enqueuing examples.
        capacity: maximum number of examples held in the shuffle queue.
        min_after_dequeue: minimum number of examples left in the queue
            after a dequeue; this bounds how well examples are mixed.
            NOTE(review): the default buffer (capacity 4, minimum 2) is
            very small, so shuffling is weak -- consider larger values.

    Returns:
        An ``(images, labels)`` pair of batched tensors built from the
        ``(image, label)`` pair produced by ``read_and_decode``.
    """
    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [filename], num_epochs=num_epochs)
        image, label = read_and_decode(filename_queue)
        images, labels = tf.train.shuffle_batch(
            [image, label], batch_size=batch_size, num_threads=num_threads,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue)
    return images, labels
def train():
    """Build the training graph and run the training loop.

    The graph consists of two input pipelines (train and test), two calls to
    ``inference`` that share variables through the "inference" variable
    scope, a softmax cross-entropy loss over two classes, and a plain
    gradient-descent optimizer.

    Depending on ``FLAGS.reload_model`` the variables are either restored
    from the latest checkpoint in ``FLAGS.model_dir`` or freshly
    initialized. The loop then runs until the input queue is exhausted
    (``tf.errors.OutOfRangeError``) or the coordinator requests a stop:

    * every 2 steps a summary is written to ``FLAGS.event_dir``;
    * every 100 steps the training accuracy and loss are printed;
    * every 1000 steps the test accuracy is printed;
    * every 2000 steps a checkpoint is saved under ``FLAGS.model_dir``.

    Raises:
        ValueError: if ``FLAGS.reload_model == 1`` but no checkpoint can be
            found in ``FLAGS.model_dir``.
    """
    # Local import so this function does not depend on the file's import
    # block already providing ``os``.
    import os

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False, dtype=tf.int32)
    batch_size = FLAGS.batch_size
    train_images, train_labels = inputs(
        "./tfrecord_data/train.tfrecord", batch_size)
    # NOTE(review): the "test" pipeline reads the same file as the training
    # pipeline, so test_acc is measured on training data -- confirm whether
    # a separate test file was intended.
    test_images, test_labels = inputs(
        "./tfrecord_data/train.tfrecord", batch_size)
    train_labels_one_hot = tf.one_hot(train_labels, 2,
                                      on_value=1.0, off_value=0.0)
    test_labels_one_hot = tf.one_hot(test_labels, 2,
                                     on_value=1.0, off_value=0.0)

    # The task is fairly simple, so the learning rate is deliberately kept
    # small to stretch out the training process.
    learning_rate = 0.000001

    # Both towers are built in one variable scope; after the first call the
    # scope is switched to reuse so the test tower shares the same weights.
    with tf.variable_scope("inference") as scope:
        train_y_conv = inference(train_images)
        scope.reuse_variables()
        test_y_conv = inference(test_images)

    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=train_labels_one_hot, logits=train_y_conv))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.minimize(cross_entropy, global_step=global_step)

    train_correct_prediction = tf.equal(
        tf.argmax(train_y_conv, 1), tf.argmax(train_labels_one_hot, 1))
    train_accuracy = tf.reduce_mean(
        tf.cast(train_correct_prediction, tf.float32))
    test_correct_prediction = tf.equal(
        tf.argmax(test_y_conv, 1), tf.argmax(test_labels_one_hot, 1))
    test_accuracy = tf.reduce_mean(
        tf.cast(test_correct_prediction, tf.float32))

    init_op = tf.global_variables_initializer()
    local_init_op = tf.local_variables_initializer()
    saver = tf.train.Saver()

    tf.summary.scalar('cross_entropy_loss', cross_entropy)
    tf.summary.scalar('train_acc', train_accuracy)
    summary_op = tf.summary.merge_all()

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
    config = tf.ConfigProto(gpu_options=gpu_options)

    # os.path.join keeps the checkpoint inside model_dir whether or not the
    # flag value ends with a path separator.
    checkpoint_prefix = os.path.join(FLAGS.model_dir, "model.ckpt")

    with tf.Session(config=config) as sess:
        if FLAGS.reload_model == 1:
            ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
            if ckpt is None or not ckpt.model_checkpoint_path:
                # Fail with a clear message instead of an AttributeError on
                # None; silently training from scratch could overwrite
                # checkpoints the user meant to resume from.
                raise ValueError(
                    "no checkpoint found in %s" % FLAGS.model_dir)
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Checkpoint files are named "<prefix>-<step>".
            save_step = int(
                ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
            print("reload model from%s, save_step =%d"
                  % (ckpt.model_checkpoint_path, save_step))
        else:
            print("Create model with fresh paramters.")
            sess.run(init_op)
        # Local variables are initialized on both paths: the input
        # pipelines are built with an epoch limit, and that counter is not
        # among the variables restored from a checkpoint.
        sess.run(local_init_op)

        summary_writer = tf.summary.FileWriter(FLAGS.event_dir, sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                # Fetch loss, accuracy and the summary in the same run as
                # the training step. Separate sess.run calls would each
                # dequeue another training batch just for reporting, and
                # would report on a different batch than the one trained on.
                _, g_step, loss, train_accuracy_value, summary_str = sess.run(
                    [train_op, global_step, cross_entropy,
                     train_accuracy, summary_op])
                if g_step % 2 == 0:
                    summary_writer.add_summary(summary_str, g_step)
                if g_step % 100 == 0:
                    print("step%dtraining_acc is%.2f, loss is%.4f"
                          % (g_step, train_accuracy_value, loss))
                if g_step % 1000 == 0:
                    test_accuracy_value = sess.run(test_accuracy)
                    print("step%dtest_acc is%.2f"
                          % (g_step, test_accuracy_value))
                if g_step % 2000 == 0:
                    # Save a checkpoint and report the path actually
                    # written by the saver.
                    saved_path = saver.save(sess, checkpoint_prefix,
                                            global_step=global_step)
                    print("save model to%s" % saved_path)
        except tf.errors.OutOfRangeError:
            # The input queues ran out of epochs: normal end of training.
            pass
        finally:
            try:
                coord.request_stop()
                coord.join(threads)
            finally:
                # Flush and release the event file even if joining the
                # queue-runner threads raises.
                summary_writer.close()