Python 数字识别:用 TensorFlow 实现 0 和 1 的数字识别

def read_and_decode(filename_queue, image_shape=(100, 100, 3)):
    """Read one serialized example from a TFRecord filename queue and decode it.

    Each record must contain the features ``height``, ``width``, ``channels``
    (int64), ``image_data`` (raw uint8 bytes) and ``label`` (int64). All five
    are parsed (so records missing any of them are rejected), but only
    ``image_data`` and ``label`` are used to build the outputs.

    Args:
        filename_queue: A TF1 queue of TFRecord filenames, e.g. the result of
            ``tf.train.string_input_producer``.
        image_shape: Static ``(height, width, channels)`` of every stored
            image. Defaults to ``(100, 100, 3)``. ``image_data`` must hold
            exactly ``height * width * channels`` bytes or the reshape fails
            at run time.

    Returns:
        A tuple ``(image, label)``: ``image`` is a float32 tensor of shape
        ``image_shape`` with pixel values mapped from [0, 255] to
        [-0.5, 0.5]; ``label`` is an int32 scalar tensor.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64),
            'image_data': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })

    image = tf.decode_raw(features['image_data'], tf.uint8)
    # The per-record height/width/channels are only known at run time, but
    # tf.train.shuffle_batch needs a fully static shape, so the configured
    # image_shape is used instead.
    image = tf.reshape(image, list(image_shape))
    # Scale to [0, 1], then shift so pixels are centred around zero.
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

    label = tf.cast(features['label'], tf.int32)
    return image, label

def inputs(filename, batch_size, num_epochs=2000, min_after_dequeue=2,
           num_threads=1):
    """Build a shuffled (images, labels) batch pipeline over one TFRecord file.

    Args:
        filename: Path to the TFRecord file to read.
        batch_size: Number of examples per returned batch.
        num_epochs: Passes over ``filename`` before the pipeline raises
            ``tf.errors.OutOfRangeError``. This creates a local variable, so
            ``tf.local_variables_initializer()`` must be run before use.
        min_after_dequeue: Minimum number of examples kept in the shuffle
            buffer after a dequeue. Larger values shuffle better but use
            more memory. The default of 2 matches the previous behavior.
        num_threads: Number of threads enqueuing decoded examples.

    Returns:
        A tuple ``(images, labels)`` of batched tensors as produced by
        ``read_and_decode``, with a leading batch dimension of ``batch_size``.
    """
    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [filename], num_epochs=num_epochs)

        image, label = read_and_decode(filename_queue)

        # The capacity must exceed min_after_dequeue by at least batch_size.
        # Otherwise a full batch can never be assembled: the previous fixed
        # capacity=4 / min_after_dequeue=2 only worked for batch_size <= 2.
        # Sizing follows the TensorFlow recommendation
        # min_after_dequeue + (num_threads + small margin) * batch_size.
        capacity = min_after_dequeue + (num_threads + 2) * batch_size

        images, labels = tf.train.shuffle_batch(
            [image, label], batch_size=batch_size, num_threads=num_threads,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue)

    return images, labels

def _restore_or_initialize(sess, saver, init_op, local_init_op):
    """Restore the latest checkpoint from FLAGS.model_dir, or initialize fresh.

    Raises:
        ValueError: if FLAGS.reload_model == 1 but no checkpoint exists in
            FLAGS.model_dir.
    """
    import os

    if FLAGS.reload_model == 1:
        ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise ValueError(
                "reload_model is 1 but no checkpoint was found in %s"
                % FLAGS.model_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Checkpoints are written as "<prefix>-<global_step>".
        save_step = int(
            os.path.basename(ckpt.model_checkpoint_path).split('-')[-1])
        print("reload model from %s, save_step = %d"
              % (ckpt.model_checkpoint_path, save_step))
    else:
        print("Create model with fresh parameters.")
        sess.run(init_op)

    # The num_epochs counters in the input pipeline are local variables.
    # tf.train.Saver does not restore them, so initialize them on both paths.
    sess.run(local_init_op)


def _save_checkpoint(sess, saver, global_step):
    """Save all variables under FLAGS.model_dir and return the written path."""
    import os

    # os.path.join works whether or not model_dir ends with a separator.
    path = saver.save(sess, os.path.join(FLAGS.model_dir, "model.ckpt"),
                      global_step=global_step)
    # Saver appends "-<step>" itself, so report the path it actually wrote.
    print("save model to %s" % path)
    return path


def train(train_file="./tfrecord_data/train.tfrecord",
          test_file="./tfrecord_data/train.tfrecord"):
    """Train the 0/1 digit classifier, logging, evaluating and checkpointing.

    Uses plain gradient descent on softmax cross-entropy over 2 classes.
    Training runs until the input pipeline raises OutOfRangeError (the
    ``num_epochs`` limit in ``inputs``). Reads FLAGS.batch_size,
    FLAGS.gpu_memory_fraction, FLAGS.reload_model, FLAGS.model_dir and
    FLAGS.event_dir.

    Schedule (by global step):
      * every 2 steps: write loss/accuracy summaries to FLAGS.event_dir
      * every 100 steps: print training accuracy and loss
      * every 1000 steps: print accuracy on one batch from ``test_file``
      * every 2000 steps, and once more at the end: save a checkpoint

    Args:
        train_file: TFRecord file used for training.
        test_file: TFRecord file used for the periodic test accuracy.
    """
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False, dtype=tf.int32)

    batch_size = FLAGS.batch_size
    train_images, train_labels = inputs(train_file, batch_size)
    # NOTE(review): by default test_file is the *training* file (as in the
    # original code), so "test_acc" is really accuracy on training data.
    # Pass a held-out TFRecord file to get a genuine test metric.
    test_images, test_labels = inputs(test_file, batch_size)

    train_labels_one_hot = tf.one_hot(train_labels, 2, on_value=1.0, off_value=0.0)
    test_labels_one_hot = tf.one_hot(test_labels, 2, on_value=1.0, off_value=0.0)

    # The task is simple, so the learning rate is deliberately tiny to
    # stretch out the training process.
    learning_rate = 0.000001

    # One set of weights is shared by the train and test towers.
    # NOTE(review): `inference` is defined elsewhere; it is assumed to return
    # logits of shape (batch, 2). Confirm against its definition.
    with tf.variable_scope("inference") as scope:
        train_y_conv = inference(train_images)
        scope.reuse_variables()
        test_y_conv = inference(test_images)

    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=train_labels_one_hot,
                                                logits=train_y_conv))

    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.minimize(cross_entropy, global_step=global_step)

    train_correct_prediction = tf.equal(tf.argmax(train_y_conv, 1),
                                        tf.argmax(train_labels_one_hot, 1))
    train_accuracy = tf.reduce_mean(tf.cast(train_correct_prediction, tf.float32))

    test_correct_prediction = tf.equal(tf.argmax(test_y_conv, 1),
                                       tf.argmax(test_labels_one_hot, 1))
    test_accuracy = tf.reduce_mean(tf.cast(test_correct_prediction, tf.float32))

    init_op = tf.global_variables_initializer()
    local_init_op = tf.local_variables_initializer()
    saver = tf.train.Saver()

    tf.summary.scalar('cross_entropy_loss', cross_entropy)
    tf.summary.scalar('train_acc', train_accuracy)
    summary_op = tf.summary.merge_all()

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
    config = tf.ConfigProto(gpu_options=gpu_options)

    with tf.Session(config=config) as sess:
        _restore_or_initialize(sess, saver, init_op, local_init_op)

        summary_writer = tf.summary.FileWriter(FLAGS.event_dir, sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Track the step in Python. Fetching global_step in the same run as
        # train_op is unordered with respect to its increment.
        step = sess.run(global_step)
        try:
            while not coord.should_stop():
                next_step = step + 1
                fetches = {'train': train_op}
                if next_step % 2 == 0:
                    fetches['summary'] = summary_op
                if next_step % 100 == 0:
                    fetches['acc'] = train_accuracy
                    fetches['loss'] = cross_entropy
                # One run per step. Every separate sess.run on a tensor fed by
                # the training queue would dequeue (and waste) another batch.
                results = sess.run(fetches)
                step = next_step

                if 'summary' in results:
                    summary_writer.add_summary(results['summary'], step)

                if 'acc' in results:
                    print("step %d training_acc is %.2f, loss is %.4f"
                          % (step, results['acc'], results['loss']))

                if step % 1000 == 0:
                    test_accuracy_value = sess.run(test_accuracy)
                    print("step %d test_acc is %.2f" % (step, test_accuracy_value))

                if step % 2000 == 0:
                    _save_checkpoint(sess, saver, global_step)

        except tf.errors.OutOfRangeError:
            # Raised once the input queues have served num_epochs; this is the
            # normal end of training. Persist the final weights.
            print("Input exhausted after step %d; training finished." % step)
            if step % 2000 != 0:
                _save_checkpoint(sess, saver, global_step)

        finally:
            coord.request_stop()
            coord.join(threads)
            # Flush any buffered events to disk.
            summary_writer.close()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值