The previous post covered synthesizing the text dataset. We now have a printed-character image dataset for the 3755 characters of the level-1 charset, and we can use it to build a recognition system for those 3755 characters. For text recognition with deep learning, the network of choice is of course a CNN, but which classic architecture: VGG? ResNet? Something else? My thinking was that a deeper network should train to a better model, but given the difficulty of training and the prediction speed needed later for online deployment, I decided to first build a fairly shallow network (an improvement on LeNet) for basic character recognition, and then try other architectures as the project requires. The deep learning framework used for this task is the powerful TensorFlow.
Building the network
The first step, naturally, is to build the network and its computation graph.
Text recognition is really just a multi-class classification task: recognizing 3755 characters means a 3755-way classification problem. The network we define is very simple, essentially an improved LeNet; one thing worth noting is that we add batch normalization. For the loss we choose sparse_softmax_cross_entropy_with_logits, the optimizer is Adam, and the learning rate is set to 0.1.
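A quick aside on the loss: since our labels are plain integer class indices rather than one-hot vectors, sparse_softmax_cross_entropy_with_logits can consume them directly. A minimal sketch of the equivalence (illustrative only, TF 1.x API, random logits):

import numpy as np
import tensorflow as tf

# Random logits for a batch of 4 over 3755 classes; labels are plain integer indices.
logits = tf.constant(np.random.randn(4, 3755), dtype=tf.float32)
labels = tf.constant([5, 100, 2048, 3754], dtype=tf.int64)

# The sparse version consumes the integer indices directly...
sparse_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
# ...while the dense version needs the labels one-hot encoded first.
dense_loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=tf.one_hot(labels, 3755))

with tf.Session() as sess:
    a, b = sess.run([sparse_loss, dense_loss])
    print(np.allclose(a, b))  # True: the two formulations agree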
#network: conv2d->max_pool2d->conv2d->max_pool2d->conv2d->max_pool2d->conv2d->conv2d->max_pool2d->fully_connected->fully_connected
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.ops import control_flow_ops

# FLAGS.charset_size (3755 here) is assumed to be defined via tf.app.flags elsewhere in the script.

def build_graph(top_k):
    keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')
    images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1], name='image_batch')
    labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')
    is_training = tf.placeholder(dtype=tf.bool, shape=[], name='train_flag')
    with tf.device('/gpu:5'):
        # Give slim.conv2d and slim.fully_connected default parameters: batch_norm
        with slim.arg_scope([slim.conv2d, slim.fully_connected],
                            normalizer_fn=slim.batch_norm,
                            normalizer_params={'is_training': is_training}):
            conv3_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv3_1')
            max_pool_1 = slim.max_pool2d(conv3_1, [2, 2], [2, 2], padding='SAME', scope='pool1')
            conv3_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv3_2')
            max_pool_2 = slim.max_pool2d(conv3_2, [2, 2], [2, 2], padding='SAME', scope='pool2')
            conv3_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3_3')
            max_pool_3 = slim.max_pool2d(conv3_3, [2, 2], [2, 2], padding='SAME', scope='pool3')
            conv3_4 = slim.conv2d(max_pool_3, 512, [3, 3], padding='SAME', scope='conv3_4')
            conv3_5 = slim.conv2d(conv3_4, 512, [3, 3], padding='SAME', scope='conv3_5')
            max_pool_4 = slim.max_pool2d(conv3_5, [2, 2], [2, 2], padding='SAME', scope='pool4')

            flatten = slim.flatten(max_pool_4)
            fc1 = slim.fully_connected(slim.dropout(flatten, keep_prob), 1024,
                                       activation_fn=tf.nn.relu, scope='fc1')
            logits = slim.fully_connected(slim.dropout(fc1, keep_prob), FLAGS.charset_size,
                                          activation_fn=None, scope='fc2')
        # Since the labels are not one-hot encoded, use sparse_softmax_cross_entropy_with_logits
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

        # Batch norm keeps its moving averages in UPDATE_OPS; make the loss depend on those updates
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
            updates = tf.group(*update_ops)
            loss = control_flow_ops.with_dependencies([updates], loss)

        global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)
        optimizer = tf.train.AdamOptimizer(learning_rate=0.1)
        train_op = slim.learning.create_train_op(loss, optimizer, global_step=global_step)

        # The snippet is truncated at this point in the original; the rest is a minimal,
        # assumed completion so build_graph returns the ops (including the top-k predictions
        # implied by the top_k argument) that a training/evaluation loop would fetch.
        probabilities = tf.nn.softmax(logits)
        predicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)
        accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))

    return {'images': images, 'labels': labels, 'keep_prob': keep_prob, 'is_training': is_training,
            'global_step': global_step, 'train_op': train_op, 'loss': loss, 'accuracy': accuracy,
            'probabilities': probabilities,
            'predicted_val_top_k': predicted_val_top_k,
            'predicted_index_top_k': predicted_index_top_k,
            'accuracy_top_k': accuracy_in_top_k}
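With the graph built, a quick smoke test confirms that everything wires up before attaching the real data pipeline. This is just a sketch under the assumptions above: a random batch stands in for real images, and allow_soft_placement lets the /gpu:5 pin fall back gracefully on machines with fewer GPUs.

import numpy as np
import tensorflow as tf

graph = build_graph(top_k=3)
config = tf.ConfigProto(allow_soft_placement=True)  # fall back if /gpu:5 is unavailable
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    # One training step on a random batch of 32 "images" and integer labels
    batch_images = np.random.rand(32, 64, 64, 1).astype(np.float32)
    batch_labels = np.random.randint(0, 3755, size=32)
    _, loss_val, acc_val = sess.run(
        [graph['train_op'], graph['loss'], graph['accuracy']],
        feed_dict={graph['images']: batch_images,
                   graph['labels']: batch_labels,
                   graph['keep_prob']: 0.8,
                   graph['is_training']: True})
    print('loss: %.4f, accuracy: %.4f' % (loss_val, acc_val))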