TensorFlow22: 手写汉字识别

NIST手写数字数据集通常做为深度学习的练习数据集,这个数据集恐怕早已经被大家玩坏了。本帖就介绍一个和MNIST类似,同时又适合国人练习的数据集-手写汉字数据集,然后训练一个简单的Deep Convolutional Network识别手写汉字。

识别手写汉字要把识别手写洋文难上很多。首先,英文字符的分类少,总共10+26*2;而中文总共50,000多汉字,常用的就有3000多。其次,汉字有书法,每个人书写风格多样。

手写汉字数据集: CASIA-HWDB

下载HWDB1.1数据集:

[python]  view plain  copy
  1. $ wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1trn_gnt.zip  
  2. # zip解压没得说, 之后还要解压alz压缩文件  
  3. $ wget http://www.nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.1tst_gnt.zip  

这个数据集由模式识别国家重点实验室共享,它还共享了其它几个数据库,先mark:

  • 行为分析数据库
  • 三维人脸数据库
  • 中文语言资源库
  • 步态数据库
  • 掌纹数据库
  • 虹膜库数据

手写汉字的样子:

[python]  view plain  copy
  1. import os  
  2. import numpy as np  
  3. import struct  
  4. import PIL.Image  
  5.    
  6. train_data_dir = "HWDB1.1trn_gnt"  
  7. test_data_dir = "HWDB1.1tst_gnt"  
  8.    
  9. # 读取图像和对应的汉字  
  10. def read_from_gnt_dir(gnt_dir=train_data_dir):  
  11.     def one_file(f):  
  12.         header_size = 10  
  13.         while True:  
  14.             header = np.fromfile(f, dtype='uint8', count=header_size)  
  15.             if not header.size: break  
  16.             sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)  
  17.             tagcode = header[5] + (header[4]<<8)  
  18.             width = header[6] + (header[7]<<8)  
  19.             height = header[8] + (header[9]<<8)  
  20.             if header_size + width*height != sample_size:  
  21.                 break  
  22.             image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))  
  23.             yield image, tagcode  
  24.    
  25.     for file_name in os.listdir(gnt_dir):  
  26.         if file_name.endswith('.gnt'):  
  27.             file_path = os.path.join(gnt_dir, file_name)  
  28.             with open(file_path, 'rb') as f:  
  29.                 for image, tagcode in one_file(f):  
  30.                     yield image, tagcode  
  31.    
  32. # 统计样本数  
  33. train_counter = 0  
  34. test_counter = 0  
  35. for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):  
  36.     tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')  
  37.     """ 
  38.     # 提取点图像, 看看什么样 
  39.     if train_counter < 1000: 
  40.         im = PIL.Image.fromarray(image) 
  41.         im.convert('RGB').save('png/' + tagcode_unicode + str(train_counter) + '.png') 
  42.     """  
  43.     train_counter += 1  
  44. for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):  
  45.     tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')  
  46.     test_counter += 1  
  47.    
  48. # 样本数  
  49. print(train_counter, test_counter)  

TensorFlow练习: 手写汉字识别

由于时间和系统资源有限,我只使用数据集的一部分(只识别最常用的140个汉字)。

训练模型

[python]  view plain  copy
  1. import os  
  2. import numpy as np  
  3. import struct  
  4. import PIL.Image  
  5.    
  6. train_data_dir = "HWDB1.1trn_gnt"  
  7. test_data_dir = "HWDB1.1tst_gnt"  
  8.    
  9. # 读取图像和对应的汉字  
  10. def read_from_gnt_dir(gnt_dir=train_data_dir):  
  11.     def one_file(f):  
  12.         header_size = 10  
  13.         while True:  
  14.             header = np.fromfile(f, dtype='uint8', count=header_size)  
  15.             if not header.size: break  
  16.             sample_size = header[0] + (header[1]<<8) + (header[2]<<16) + (header[3]<<24)  
  17.             tagcode = header[5] + (header[4]<<8)  
  18.             width = header[6] + (header[7]<<8)  
  19.             height = header[8] + (header[9]<<8)  
  20.             if header_size + width*height != sample_size:  
  21.                 break  
  22.             image = np.fromfile(f, dtype='uint8', count=width*height).reshape((height, width))  
  23.             yield image, tagcode  
  24.    
  25.     for file_name in os.listdir(gnt_dir):  
  26.         if file_name.endswith('.gnt'):  
  27.             file_path = os.path.join(gnt_dir, file_name)  
  28.             with open(file_path, 'rb') as f:  
  29.                 for image, tagcode in one_file(f):  
  30.                     yield image, tagcode  
  31.    
  32. import scipy.misc  
  33. from sklearn.utils import shuffle  
  34. import tensorflow as tf  
  35.    
  36. # 我取常用的前140个汉字进行测试  
  37. char_set = "的一是了我不人在他有这个上们来到时大地为子中你说生国年着就那和要她出也得里后自以会家可下而过天去能对小多然于心学么之都好看起发当没成只如事把还用第样道想作种开美总从无情己面最女但现前些所同日手又行意动方期它头经长儿回位分爱老因很给名法间斯知世什两次使身者被高已亲其进此话常与活正感"  
  38.    
  39. def resize_and_normalize_image(img):  
  40.     # 补方  
  41.     pad_size = abs(img.shape[0]-img.shape[1]) // 2  
  42.     if img.shape[0] < img.shape[1]:  
  43.         pad_dims = ((pad_size, pad_size), (00))  
  44.     else:  
  45.         pad_dims = ((00), (pad_size, pad_size))  
  46.     img = np.lib.pad(img, pad_dims, mode='constant', constant_values=255)  
  47.     # 缩放  
  48.     img = scipy.misc.imresize(img, (64 - 4*264 - 4*2))  
  49.     img = np.lib.pad(img, ((44), (44)), mode='constant', constant_values=255)  
  50.     assert img.shape == (6464)  
  51.    
  52.     img = img.flatten()  
  53.     # 像素值范围-1到1  
  54.     img = (img - 128) / 128  
  55.     return img  
  56.    
  57. # one hot  
  58. def convert_to_one_hot(char):  
  59.     vector = np.zeros(len(char_set))  
  60.     vector[char_set.index(char)] = 1  
  61.     return vector  
  62.    
  63. # 由于数据量不大, 可一次全部加载到RAM  
  64. train_data_x = []  
  65. train_data_y = []  
  66.    
  67. for image, tagcode in read_from_gnt_dir(gnt_dir=train_data_dir):  
  68.     tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')  
  69.     if tagcode_unicode in char_set:  
  70.         train_data_x.append(resize_and_normalize_image(image))  
  71.         train_data_y.append(convert_to_one_hot(tagcode_unicode))  
  72.    
  73. # shuffle样本  
  74. train_data_x, train_data_y = shuffle(train_data_x, train_data_y, random_state=0)  
  75.    
  76. batch_size = 128  
  77. num_batch = len(train_data_x) // batch_size  
  78.    
  79. text_data_x = []  
  80. text_data_y = []  
  81. for image, tagcode in read_from_gnt_dir(gnt_dir=test_data_dir):  
  82.     tagcode_unicode = struct.pack('>H', tagcode).decode('gb2312')  
  83.     if tagcode_unicode in char_set:  
  84.         text_data_x.append(resize_and_normalize_image(image))  
  85.         text_data_y.append(convert_to_one_hot(tagcode_unicode))  
  86. # shuffle样本  
  87. text_data_x, text_data_y = shuffle(text_data_x, text_data_y, random_state=0)  
  88.    
  89.    
  90. X = tf.placeholder(tf.float32, [None64*64])  
  91. Y = tf.placeholder(tf.float32, [None140])  
  92. keep_prob = tf.placeholder(tf.float32)  
  93.    
  94. def chinese_hand_write_cnn():  
  95.     x = tf.reshape(X, shape=[-164641])  
  96.     # 3 conv layers  
  97.     w_c1 = tf.Variable(tf.random_normal([33132], stddev=0.01))  
  98.     b_c1 = tf.Variable(tf.zeros([32]))  
  99.     conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, w_c1, strides=[1111], padding='SAME'), b_c1))  
  100.     conv1 = tf.nn.max_pool(conv1, ksize=[1221], strides=[1221], padding='SAME')  
  101.       
  102.     w_c2 = tf.Variable(tf.random_normal([333264], stddev=0.01))  
  103.     b_c2 = tf.Variable(tf.zeros([64]))  
  104.     conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, w_c2, strides=[1111], padding='SAME'), b_c2))  
  105.     conv2 = tf.nn.max_pool(conv2, ksize=[1221], strides=[1221], padding='SAME')  
  106.       
  107.     """ 
  108.     # 训练开始之后我就去睡觉了, 早晨起来一看, 白跑了, 准确率不足10%; 把网络变量改少了再来一发 
  109.     w_c3 = tf.Variable(tf.random_normal([3, 3, 64, 128], stddev=0.01)) 
  110.     b_c3 = tf.Variable(tf.zeros([128])) 
  111.     conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, w_c3, strides=[1, 1, 1, 1], padding='SAME'), b_c3)) 
  112.     conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 
  113.     conv3 = tf.nn.dropout(conv3, keep_prob) 
  114.     """  
  115.    
  116.     # fully connect layer  
  117.     w_d = tf.Variable(tf.random_normal([8*32*641024], stddev=0.01))  
  118.     b_d = tf.Variable(tf.zeros([1024]))  
  119.     dense = tf.reshape(conv2, [-1, w_d.get_shape().as_list()[0]])  
  120.     dense = tf.nn.relu(tf.add(tf.matmul(dense, w_d), b_d))  
  121.     dense = tf.nn.dropout(dense, keep_prob)  
  122.    
  123.     w_out = tf.Variable(tf.random_normal([1024140], stddev=0.01))  
  124.     b_out = tf.Variable(tf.zeros([140]))  
  125.     out = tf.add(tf.matmul(dense, w_out), b_out)  
  126.    
  127.     return out  
  128.    
  129. def train_hand_write_cnn():  
  130.     output = chinese_hand_write_cnn()  
  131.    
  132.     loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(output, Y))  
  133.     optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)  
  134.    
  135.     accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(output, 1), tf.argmax(Y, 1)), tf.float32))  
  136.    
  137.     # TensorBoard  
  138.     tf.scalar_summary("loss", loss)  
  139.     tf.scalar_summary("accuracy", accuracy)  
  140.     merged_summary_op = tf.merge_all_summaries()  
  141.    
  142.     saver = tf.train.Saver()  
  143.     with tf.Session() as sess:  
  144.         sess.run(tf.global_variables_initializer())  
  145.    
  146.         # 命令行执行 tensorboard --logdir=./log  打开浏览器访问http://0.0.0.0:6006  
  147.         summary_writer = tf.train.SummaryWriter('./log', graph=tf.get_default_graph())  
  148.    
  149.         for e in range(50):  
  150.             for i in range(num_batch):  
  151.                 batch_x = train_data_x[i*batch_size : (i+1)*batch_size]  
  152.                 batch_y = train_data_y[i*batch_size : (i+1)*batch_size]  
  153.                 _, loss_, summary = sess.run([optimizer, loss, merged_summary_op], feed_dict={X: batch_x, Y: batch_y, keep_prob: 0.5})  
  154.                 # 每次迭代都保存日志  
  155.                 summary_writer.add_summary(summary, e*num_batch+i)  
  156.                 print(e*num_batch+i, loss_)  
  157.    
  158.                 if e*num_batch+i % 100 == 0:  
  159.                     # 计算准确率  
  160.                     acc = accuracy.eval({X: text_data_x[:500], Y: text_data_y[:500], keep_prob: 1.})  
  161.                     #acc = sess.run(accuracy, feed_dict={X: text_data_x[:500], Y: text_data_y[:500], keep_prob: 1.})  
  162.                     print(e*num_batch+i, acc)  
  163.    
  164. train_hand_write_cnn()  

Computation Graph:

TensorFlow练习22: 手写汉字识别

loss:

TensorFlow练习22: 手写汉字识别

准确率:

TensorFlow练习22: 手写汉字识别


  • 0
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值