Java 整合 handwrite_cnn:手写汉字识别模型使用原生 TensorFlow 进行训练和预测

训练数据为100个汉字的手写图片,放在 data 目录下。将下面的脚本与 data 目录放在同一个目录下,直接运行即可。

关键参数:

run_mode = "train" 表示训练模型;修改为 "validation" 表示验证100张图片的预测精度;修改为 "inference" 表示对图片 './data/00098/102544.png' 进行手写识别,返回 top3 结果。

charset_size = 100 表示汉字数目。如果是全量数据,则为3755.

代码参考了:https://github.com/burness/tensorflow-101/blob/master/chinese_hand_write_rec/src/chinese_rec.py

在其基础上加入了:(1)图像随机左右旋转(最大30度)的数据增强;(2)从检查点恢复,支持断点续训;(3)为了达到更高精度,增加了一个卷积层,参见 https://github.com/AmemiyaYuko/HandwrittenChineseCharacterRecognition

import tensorflow as tf

import os

import random

import math

import tensorflow.contrib.slim as slim

import time

import logging

import numpy as np

import pickle

from PIL import Image

# Console logging for the whole script: a single stream handler at INFO level
# attached to a named logger. No formatter is installed on the handler, so
# records are emitted with the logging module's default format.
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)

logger = logging.getLogger('Training a chinese write char recognition')
logger.setLevel(logging.INFO)
logger.addHandler(ch)

# ---- Run configuration: these values are the defaults of the flags below ----

# Which branch main() takes: "train", "validation" or "inference".
run_mode = "train"

# Number of character classes used (the full data set has 3755).
charset_size = 100

# Training stops once the global step exceeds max_steps; save_steps is the
# checkpointing interval.
max_steps, save_steps = 12002, 2000

# Paths used for the online 3755-character training run, kept for reference:
#   checkpoint_dir = '/aiml/dfs/checkpoint/'
#   train_data_dir = '/aiml/data/train/'
#   test_data_dir = '/aiml/data/test/'
#   log_dir = '/aiml/dfs/'

# Local paths: training and test images are read from the same directory.
train_data_dir = test_data_dir = './data/'
checkpoint_dir = './checkpoint2/'
log_dir = './'

# Command-line flags. Defaults come from the module-level run configuration
# (run_mode, charset_size, max_steps, save_steps and the *_dir paths).

# The accepted values are the ones main() dispatches on.
tf.app.flags.DEFINE_string('mode', run_mode, 'Running mode. One of {"train", "validation", "inference"}')

# NOTE: the flag name is historical. DataIterator.data_augmentation uses it to
# gate the random rotation; the up/down flip itself is commented out there.
tf.app.flags.DEFINE_boolean('random_flip_up_down', True, "Whether to apply random rotation (up/down flip is disabled)")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random constrast")

tf.app.flags.DEFINE_integer('charset_size', charset_size, "Choose the first `charset_size` character to conduct our experiment.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rbg to gray")

tf.app.flags.DEFINE_integer('max_steps', max_steps, 'the max training steps ')
tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', save_steps, "the steps to save")

tf.app.flags.DEFINE_string('checkpoint_dir', checkpoint_dir, 'the checkpoint dir')
tf.app.flags.DEFINE_string('train_data_dir', train_data_dir, 'the train dataset dir')
tf.app.flags.DEFINE_string('test_data_dir', test_data_dir, 'the test dataset dir')
tf.app.flags.DEFINE_string('log_dir', log_dir, 'the logging dir')

##############################
# resume training
tf.app.flags.DEFINE_boolean('restore', True, 'whether to restore from checkpoint')
##############################

# These two carry integer values, so they must be integer flags: a boolean
# flag cannot represent 10 or 128 (and cannot be overridden with a number
# on the command line).
tf.app.flags.DEFINE_integer('epoch', 10, 'Number of epoches')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size for training and validation')

FLAGS = tf.app.flags.FLAGS

class DataIterator:
    """Collects labelled image files under a directory and exposes them as TF input queues.

    The directory is walked recursively. Each image's label is the integer
    parsed from the first path component after ``data_dir``, so images are
    expected to live in per-class folders named with the class index
    (the truncation logic assumes zero-padded five-digit names such as
    ``00098``).
    """

    def __init__(self, data_dir):
        """Scan ``data_dir`` and build the shuffled file list and matching labels.

        Args:
            data_dir: root directory of the images. The label parsing slices
                off ``len(data_dir)`` characters, so this assumes the value
                ends with a path separator (as the default './data/' does).
        """
        # Set FLAGS.charset_size to a small value if available computation power is limited.
        truncate_path = data_dir + ('%05d' % FLAGS.charset_size)
        print(truncate_path)
        self.image_names = []
        for root, sub_folder, file_list in os.walk(data_dir):
            # Plain string comparison: with zero-padded folder names this keeps
            # only the class folders whose index is below FLAGS.charset_size.
            if root < truncate_path:
                self.image_names += [os.path.join(root, file_path) for file_path in file_list]
        random.shuffle(self.image_names)
        self.labels = [int(file_name[len(data_dir):].split(os.sep)[0]) for file_name in self.image_names]

    @property
    def size(self):
        """Number of images found by the constructor."""
        return len(self.labels)

    @staticmethod
    def data_augmentation(images):
        """Apply the augmentations enabled by the flags to one image tensor.

        Args:
            images: the decoded image tensor produced by input_pipeline.

        Returns:
            The augmented image tensor.
        """
        if FLAGS.random_flip_up_down:
            # images = tf.image.random_flip_up_down(images)
            # Draw the angle inside the graph so that every dequeued sample is
            # rotated by a different amount in [-30, +30] degrees. A Python
            # random.randint() here would be evaluated once, when the graph is
            # built, and rotate every image by the same fixed angle.
            max_angle = 30 * math.pi / 180
            angle = tf.random_uniform([], minval=-max_angle, maxval=max_angle)
            images = tf.contrib.image.rotate(images, angle, interpolation='BILINEAR')
        if FLAGS.random_brightness:
            images = tf.image.random_brightness(images, max_delta=0.3)
        if FLAGS.random_contrast:
            images = tf.image.random_contrast(images, 0.8, 1.2)
        return images

    def input_pipeline(self, batch_size, num_epochs=None, aug=False):
        """Build the queue-based input pipeline for this data set.

        Args:
            batch_size: number of samples per dequeued batch.
            num_epochs: passed to tf.train.slice_input_producer; None means
                the file list is cycled without an epoch limit.
            aug: when True, data_augmentation is applied to each image.

        Returns:
            A pair (image_batch, label_batch) of tensors.
        """
        images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
        labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
        input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)

        labels = input_queue[1]
        images_content = tf.read_file(input_queue[0])
        # Decode as a single-channel PNG and convert to float32.
        images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
        if aug:
            images = self.data_augmentation(images)
        # Resize after augmentation to the square size the network expects.
        new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
        images = tf.image.resize_images(images, new_size)
        image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
                                                          min_after_dequeue=10000)
        return image_batch, label_batch

def build_graph(top_k):
    """Build the recognition network together with its training and evaluation ops.

    The network is four 3x3 convolution + 2x2 max-pool stages (64, 128, 256
    and 512 filters, the last convolution with stride 2), followed by a
    1024-unit tanh layer and a linear layer with FLAGS.charset_size outputs.
    Dropout is applied before both fully connected layers.

    Args:
        top_k: how many best predictions to expose per sample.

    Returns:
        A dict holding the placeholders ('images', 'labels', 'keep_prob'),
        the value of ``top_k``, and the tensors/ops 'global_step',
        'train_op', 'loss', 'accuracy', 'accuracy_top_k',
        'merged_summary_op', 'predicted_distribution',
        'predicted_index_top_k' and 'predicted_val_top_k'.
    """
    # with tf.device('/cpu:0'):
    # Inputs. The explicit names allow these tensors to be looked up by name.
    keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
    images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1], name='image_batch')
    labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')

    # Convolutional trunk as (filters, stride, scope); every stage is followed
    # by a 2x2 max-pool with stride 2.
    conv_specs = [
        (64, 1, 'conv1'),
        (128, 1, 'conv2'),
        (256, 1, 'conv3'),
        (512, [2, 2], 'conv4'),
    ]
    net = images
    for num_filters, stride, scope in conv_specs:
        net = slim.conv2d(net, num_filters, [3, 3], stride, padding='SAME', scope=scope)
        net = slim.max_pool2d(net, [2, 2], [2, 2], padding='SAME')

    # Classifier head.
    flat = slim.flatten(net)
    hidden = slim.fully_connected(slim.dropout(flat, keep_prob), 1024, activation_fn=tf.nn.tanh, scope='fc1')
    logits = slim.fully_connected(slim.dropout(hidden, keep_prob), FLAGS.charset_size, activation_fn=None,
                                  scope='fc2')
    # logits = slim.fully_connected(flatten, FLAGS.charset_size, activation_fn=None, reuse=reuse, scope='fc')

    # Loss and top-1 accuracy.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    loss = tf.reduce_mean(cross_entropy)
    hits = tf.equal(tf.argmax(logits, 1), labels)
    accuracy = tf.reduce_mean(tf.cast(hits, tf.float32))

    # Adam with a staircase exponentially decayed learning rate.
    global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
    rate = tf.train.exponential_decay(2e-4, global_step, decay_steps=2000, decay_rate=0.97, staircase=True)
    optimizer = tf.train.AdamOptimizer(learning_rate=rate)
    train_op = optimizer.minimize(loss, global_step=global_step)

    probabilities = tf.nn.softmax(logits)

    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary_op = tf.summary.merge_all()

    # Top-k predictions and top-k accuracy.
    predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
    in_top_k = tf.nn.in_top_k(probabilities, labels, top_k)
    accuracy_in_top_k = tf.reduce_mean(tf.cast(in_top_k, tf.float32))

    return dict(
        images=images,
        labels=labels,
        keep_prob=keep_prob,
        top_k=top_k,
        global_step=global_step,
        train_op=train_op,
        loss=loss,
        accuracy=accuracy,
        accuracy_top_k=accuracy_in_top_k,
        merged_summary_op=merged_summary_op,
        predicted_distribution=probabilities,
        predicted_index_top_k=predicted_index_top_k,
        predicted_val_top_k=predicted_val_top_k,
    )

def train():
    """Train the network, evaluating on one test batch and checkpointing periodically.

    Training resumes from the latest checkpoint in FLAGS.checkpoint_dir when
    FLAGS.restore is set and a checkpoint exists. The loop ends when the
    global step exceeds FLAGS.max_steps or the input queue is exhausted.
    """
    print('Begin training')
    train_feeder = DataIterator(FLAGS.train_data_dir)
    test_feeder = DataIterator(FLAGS.test_data_dir)

    # Saver.save() does not create a missing parent directory, so make sure
    # the checkpoint directory exists before the first save.
    if not os.path.isdir(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)

    with tf.Session() as sess:
        train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)
        graph = build_graph(top_k=1)
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()

        train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
        test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')

        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
            if ckpt:
                # The global step is a saved variable, so restoring also
                # resumes the step counter.
                saver.restore(sess, ckpt)
                print("restore from the checkpoint {0}".format(ckpt))

        logger.info(':::Training Start:::')
        try:
            while not coord.should_stop():
                start_time = time.time()
                train_images_batch, train_labels_batch = sess.run([train_images, train_labels])
                feed_dict = {graph['images']: train_images_batch,
                             graph['labels']: train_labels_batch,
                             graph['keep_prob']: 0.8}
                _, loss_val, train_summary, step = sess.run(
                    [graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],
                    feed_dict=feed_dict)
                train_writer.add_summary(train_summary, step)
                end_time = time.time()
                logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))
                if step > FLAGS.max_steps:
                    break
                if step % FLAGS.eval_steps == 1:
                    # Evaluate on a single test batch with dropout disabled.
                    test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                    feed_dict = {graph['images']: test_images_batch,
                                 graph['labels']: test_labels_batch,
                                 graph['keep_prob']: 1.0}
                    accuracy_test, test_summary = sess.run(
                        [graph['accuracy'], graph['merged_summary_op']],
                        feed_dict=feed_dict)
                    test_writer.add_summary(test_summary, step)
                    logger.info('===============Eval a batch=======================')
                    logger.info('the step {0} test accuracy: {1}'
                                .format(step, accuracy_test))
                    logger.info('===============Eval a batch=======================')
                if step % FLAGS.save_steps == 1:
                    logger.info('Save the ckpt of {0}'.format(step))
                    saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'),
                               global_step=graph['global_step'])
        except tf.errors.OutOfRangeError:
            logger.info('==================Train Finished================')
            saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'my-model'), global_step=graph['global_step'])
        finally:
            coord.request_stop()
            # Flush and release the summary event files.
            train_writer.close()
            test_writer.close()
        coord.join(threads)

def validation():
    """Run one pass over the test set and collect top-3 predictions.

    Returns:
        A dict with 'prob' (top-k probabilities per sample), 'indices'
        (top-k class indices per sample) and 'groundtruth' (true labels),
        each a list with one entry per evaluated sample.
    """
    print('validation')
    test_feeder = DataIterator(FLAGS.test_data_dir)

    final_predict_val = []
    final_predict_index = []
    groundtruth = []

    with tf.Session() as sess:
        test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)
        graph = build_graph(top_k=3)

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())  # initialize test_feeder's inside state

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
            print("restore from the checkpoint {0}".format(ckpt))
        else:
            logger.warning('No checkpoint found in {0}; validating randomly initialised weights'
                           .format(FLAGS.checkpoint_dir))

        print(':::Start validation:::')
        i = 0
        # Running sums of correct predictions (batch accuracy * batch length).
        acc_top_1, acc_top_k = 0.0, 0.0
        try:
            while not coord.should_stop():
                i += 1
                start_time = time.time()
                test_images_batch, test_labels_batch = sess.run([test_images, test_labels])
                feed_dict = {graph['images']: test_images_batch,
                             graph['labels']: test_labels_batch,
                             graph['keep_prob']: 1.0}
                batch_labels, probs, indices, acc_1, acc_k = sess.run([graph['labels'],
                                                                       graph['predicted_val_top_k'],
                                                                       graph['predicted_index_top_k'],
                                                                       graph['accuracy'],
                                                                       graph['accuracy_top_k']], feed_dict=feed_dict)
                final_predict_val += probs.tolist()
                final_predict_index += indices.tolist()
                groundtruth += batch_labels.tolist()
                batch_len = len(batch_labels)
                acc_top_1 += acc_1 * batch_len
                acc_top_k += acc_k * batch_len
                end_time = time.time()
                logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)"
                            .format(i, end_time - start_time, acc_1, acc_k))
        except tf.errors.OutOfRangeError:
            logger.info('==================Validation Finished================')
            # Average over the samples that were actually evaluated.
            # shuffle_batch discards the final partial batch by default, so
            # dividing by test_feeder.size would count samples that were never
            # scored and bias the accuracy downwards.
            evaluated = len(groundtruth)
            if evaluated:
                acc_top_1 = acc_top_1 / evaluated
                acc_top_k = acc_top_k / evaluated
            logger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))
        finally:
            coord.request_stop()
        coord.join(threads)
    return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}

def inference(image):
    """Predict the top-3 classes for a single image file.

    Args:
        image: path of the image to classify.

    Returns:
        A pair (predict_val, predict_index): the top-3 probabilities and the
        corresponding class indices, each with a leading batch dimension of 1.
    """
    print('inference')
    # Grayscale, resize to the network input size, scale pixel values to [0, 1].
    temp_image = Image.open(image).convert('L')
    temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.ANTIALIAS)
    temp_image = np.asarray(temp_image) / 255.0
    # Use the same size as the resize above instead of a hard-coded 64, so a
    # size mismatch with the graph fails at feed time rather than silently
    # re-slicing the pixels into a different number of images.
    temp_image = temp_image.reshape([-1, FLAGS.image_size, FLAGS.image_size, 1])

    with tf.Session() as sess:
        logger.info('========start inference============')
        # images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1])
        # No label is fed: the fetched prediction tensors do not depend on it.
        graph = build_graph(top_k=3)
        # Initialise first (as train() and validation() do) so the graph is
        # runnable even when no checkpoint is available.
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        if ckpt:
            saver.restore(sess, ckpt)
        else:
            logger.warning('No checkpoint found in {0}; predicting with randomly initialised weights'
                           .format(FLAGS.checkpoint_dir))
        predict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],
                                              feed_dict={graph['images']: temp_image, graph['keep_prob']: 1.0})
    return predict_val, predict_index

def main(_):
    """Dispatch on FLAGS.mode: train, validate (and pickle the result) or run one inference.

    Args:
        _: unused positional argument supplied by tf.app.run.
    """
    print(FLAGS.mode)
    if FLAGS.mode == "train":
        train()
    elif FLAGS.mode == 'validation':
        dct = validation()
        result_file = 'result.dict'
        logger.info('Write result into {0}'.format(result_file))
        with open(result_file, 'wb') as f:
            pickle.dump(dct, f)
        logger.info('Write file ends')
    elif FLAGS.mode == 'inference':
        image_path = './data/00098/102544.png'
        # The ground-truth label is the name of the class folder that holds
        # the image, so derive it from the path instead of hard-coding it.
        label = int(os.path.basename(os.path.dirname(image_path)))
        final_predict_val, final_predict_index = inference(image_path)
        logger.info('the result info label {0} predict index {1} predict_val {2}'.format(label, final_predict_index,
                                                                                         final_predict_val))
    else:
        logger.error('Unknown mode {0}; expected "train", "validation" or "inference"'.format(FLAGS.mode))

# Script entry point: let TensorFlow parse the command-line flags and then
# invoke main().
if __name__ == "__main__":
    tf.app.run(main=main)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值