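Attention-OCR is an end-to-end method for recognizing text in images. Its input is an image containing text, and its output is the text itself. The overall processing pipeline is:
Image → CNN feature extraction → Encoder → Visual Attention → Decoder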
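To make the Visual Attention step more concrete, here is a minimal pure-Python sketch of one attention step. The decoder state scores each encoder feature vector, a softmax turns the scores into weights, and the weighted sum of the features becomes the context vector handed to the decoder. Dot-product scoring is an assumption chosen for brevity; the actual Attention-OCR decoder may use a different (for example additive) scoring function, and `attend` is a hypothetical helper name.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(decoder_state, encoder_features):
    """One dot-product attention step (hypothetical helper, simplified scoring).

    decoder_state: list of d floats
    encoder_features: list of feature vectors (one per image position), each of length d
    Returns (weights, context): attention weights over positions and the
    attention-weighted sum of the encoder features.
    """
    # Assumption: score = dot(decoder_state, feature); the real model may differ.
    scores = [sum(s * f for s, f in zip(decoder_state, feat)) for feat in encoder_features]
    weights = softmax(scores)
    dim = len(encoder_features[0])
    context = [sum(w * feat[i] for w, feat in zip(weights, encoder_features))
               for i in range(dim)]
    return weights, context

if __name__ == "__main__":
    # Three encoder positions with 2-dimensional features.
    features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    weights, context = attend([2.0, 0.0], features)
    print(weights)
    print(context)
```

Positions whose features align with the decoder state get larger weights, so the context vector is dominated by the image regions the decoder is currently "looking at".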
The steps below install and test the Attention-OCR project from GitHub in the following environment:
Ubuntu 16.04
GPU: GTX 1080
TensorFlow 1.8.0
Python 2.7.12
1. Install Distance (a Python library for computing edit distances)
wget http://www.cs.cmu.edu/~yuntiand/Distance-0.1.3.tar.gz
tar zxf Distance-0.1.3.tar.gz
cd distance; sudo python setup.py install
2. Download Attention-OCR
git clone https://github.com/da03/Attention-OCR.git
3. Prepare the dataset
Each line of the annotation file has the format:
image_path label
For example, train.txt might contain:
1.jpg hello
2.jpg 2345
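A minimal sketch for writing and reading an annotation file in this format. It assumes one image per line and that neither the image path nor the label contains spaces, so each line splits into exactly two fields. The helper names are hypothetical and not part of Attention-OCR.

```python
def write_annotations(path, samples):
    """Write (image_path, label) pairs as 'image_path label' lines."""
    with open(path, "w") as f:
        for img_path, label in samples:
            # Assumption: spaces would break the two-field format, so reject them.
            if " " in img_path or " " in label:
                raise ValueError("spaces are not allowed: %r" % ((img_path, label),))
            f.write("%s %s\n" % (img_path, label))

def read_annotations(path):
    """Parse an annotation file into a list of (image_path, label) pairs."""
    samples = []
    with open(path) as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue  # skip blank lines
            fields = line.split()
            if len(fields) != 2:
                raise ValueError("line %d: expected 'image_path label', got %r" % (lineno, line))
            samples.append((fields[0], fields[1]))
    return samples
```

For example, `write_annotations("train.txt", [("1.jpg", "hello"), ("2.jpg", "2345")])` produces exactly the two lines shown above, and `read_annotations` returns the same pairs.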
The author also provides some sample data. Download it and extract it into the Attention-OCR root directory:
wget http://www.cs.cmu.edu/~yuntiand/sample.tgz
tar zxf sample.tgz
4. Training
Train on the sample data provided by the author:
python src/launcher.py --phase=train --data-path=sample/sample.txt --data-base-dir=sample --log-path=log.txt --no-load-model
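When training on your own data, make sure the paths passed to --data-path and --data-base-dir are correct.
If everything is working, the output looks like this: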
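A quick pre-flight check can catch wrong paths before training starts. This sketch assumes every image path in the annotation file is resolved relative to --data-base-dir (as in the sample command, where sample/sample.txt lists images under sample). `check_dataset` is a hypothetical helper, not part of Attention-OCR.

```python
import os
import sys

def check_dataset(data_path, data_base_dir):
    """Return a list of problems; an empty list means the paths look fine.

    Assumption: each non-empty line of data_path is 'image_path label' and
    image_path is relative to data_base_dir.
    """
    if not os.path.isfile(data_path):
        return ["annotation file not found: %s" % data_path]
    if not os.path.isdir(data_base_dir):
        return ["image directory not found: %s" % data_base_dir]
    problems = []
    with open(data_path) as f:
        for lineno, line in enumerate(f, 1):
            fields = line.split()
            if not fields:
                continue  # skip blank lines
            if len(fields) != 2:
                problems.append("line %d: expected 2 fields, got %d" % (lineno, len(fields)))
                continue
            img = os.path.join(data_base_dir, fields[0])
            if not os.path.isfile(img):
                problems.append("line %d: missing image %s" % (lineno, img))
    return problems

if __name__ == "__main__" and len(sys.argv) == 3:
    # Usage: python check_dataset.py sample/sample.txt sample
    issues = check_dataset(sys.argv[1], sys.argv[2])
    print("\n".join(issues) if issues else "all paths OK")
```

Running it with the same two values you pass to --data-path and --data-base-dir should print "all paths OK" before you launch training.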
...
2016-06-08 20:47:22,335 root INFO Created model with fresh parameters.
2016-06-08 20:47:52,852 root INFO current_step: 0
2016-06-08 20:48:01,253 root INFO step_time: 8.400597, step perplexity: 38.998714
2016-06-08 20:48:01,385 root INFO current_step: 1
2016-06-08 20:48:07,166 root INFO step_time: 5.781749, step perplexity: 38.998445
2016-06-08 20:48:07,337 root INFO current_step: 2
2016-06-08 20:48:12,322 root INFO step_time: 4.984972, step perplexity: 39.006730
2016-06-08 20:48:12,347 root INFO current_step: 3
2016-06-08 20:48:16,821 root INFO step_time: 4.473902, step perplexity: 39.000267
2016-06-08 20:48:16,859 root INFO current_step: 4
2016-06-08 20:48:21,452 root INFO step_time: 4.593249, step perplexity: 39.009864
2016-06-08 20:48:21,530 root INFO current_step: 5
2016-06-08 20:48:25,878 root INFO step_time: 4.348195, step perplexity: 38.987707
2016-06-08 20:48:26,016 root INFO current_step: 6
2016-06-08 20:48:30,851 root INFO step_time: 4.835423, step perplexity: 39.022887
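Perplexity is conventionally the exponential of the average per-symbol cross-entropy, so a model that predicts uniformly over V output symbols has perplexity exactly V. The values near 39 in the first steps are therefore consistent with near-uniform guessing over roughly 39 output symbols; as training progresses the value should fall toward 1. The sketch below extracts step perplexity from log lines in the format shown above (for example from the file given by --log-path) so the trend is easy to follow. The regular expression is written against these sample lines only and is an assumption about the log format.

```python
import math
import re
import sys

# Written against log lines like:
# 2016-06-08 20:48:01,253 root INFO step_time: 8.400597, step perplexity: 38.998714
PPL_RE = re.compile(r"step_time: ([0-9.]+), step perplexity: ([0-9.]+|inf)")

def read_perplexities(lines):
    """Return the step perplexities found in the given log lines, in order."""
    values = []
    for line in lines:
        match = PPL_RE.search(line)
        if match:
            values.append(float(match.group(2)))
    return values

def uniform_perplexity(vocab_size):
    """Perplexity of a uniform guess over vocab_size symbols: exp(-log(1/V)) == V."""
    return math.exp(-math.log(1.0 / vocab_size))

if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python ppl.py log.txt
    with open(sys.argv[1]) as f:
        ppl = read_perplexities(f)
    if ppl:
        print("steps: %d, first: %.3f, last: %.3f" % (len(ppl), ppl[0], ppl[-1]))
```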
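5. Testing
After training finishes, a folder named train is created in the root directory; it holds the saved model files. Run the test with: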
python src/launcher.py --phase=test --visualize --data-path=test.txt --data-base-dir=data/img --log-path=log.txt --load-model --model-dir=train --output-dir=results
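Parameters:
--data-path: the test annotation file, e.g. test.txt
--data-base-dir: the directory containing the images. For example, if test.txt is in data, the images are in data/img, and test.txt contains the line "1.jpg hello", use --data-base-dir=data/img
--load-model: load an existing model instead of creating one with fresh parameters; --model-dir gives the folder that contains it, e.g. --load-model --model-dir=train
--output-dir: the folder where the generated results are saved, e.g. --output-dir=results
The test output looks like this: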
2016-06-08 22:36:31,638 root INFO Reading model parameters from model/translate.ckpt-47200
2016-06-08 22:36:40,529 root INFO Compare word based on edit distance.
2016-06-08 22:36:41,652 root INFO step_time: 1.119277, step perplexity: 1.056626
2016-06-08 22:36:41,660 root INFO 1.000000 out of 1 correct
2016-06-08 22:36:42,358 root INFO step_time: 0.696687, step perplexity: 2.003350
2016-06-08 22:36:42,363 root INFO 1.666667 out of 2 correct
2016-06-08 22:36:42,831 root INFO step_time: 0.466550, step perplexity: 1.501963
2016-06-08 22:36:42,835 root INFO 2.466667 out of 3 correct
2016-06-08 22:36:43,402 root INFO step_time: 0.562091, step perplexity: 1.269991
2016-06-08 22:36:43,418 root INFO 3.366667 out of 4 correct
2016-06-08 22:36:43,897 root INFO step_time: 0.477545, step perplexity: 1.072437
2016-06-08 22:36:43,905 root INFO 4.366667 out of 5 correct
2016-06-08 22:36:44,107 root INFO step_time: 0.195361, step perplexity: 2.071796
2016-06-08 22:36:44,127 root INFO 5.144444 out of 6 correct
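The "correct" counts in the test log are fractional because predictions are compared by edit distance ("Compare word based on edit distance"). The increments in the sample log (1.0, 0.666667, 0.8, 0.9, 1.0, 0.777778) are consistent with a per-sample score of 1 - min(1, d / len(label)), where d is the Levenshtein distance between prediction and label. The sketch below implements that scoring; it is an inference from the log, not code taken from Attention-OCR, and the function names are hypothetical.

```python
def levenshtein(a, b):
    """Dynamic-programming edit distance (insert, delete, substitute each cost 1)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                 # delete ca
                           cur[j - 1] + 1,              # insert cb
                           prev[j - 1] + (ca != cb)))   # substitute (free if equal)
        prev = cur
    return prev[-1]

def sample_score(prediction, label):
    """Partial credit in [0, 1]: 1 minus the length-normalized edit distance, clipped."""
    # Assumption inferred from the log increments, not from the project source.
    d = levenshtein(prediction, label)
    return 1.0 - min(1.0, float(d) / len(label))

def running_correct(pairs):
    """Yield (cumulative score, samples seen) after each (prediction, label) pair."""
    total = 0.0
    for n, (pred, label) in enumerate(pairs, 1):
        total += sample_score(pred, label)
        yield total, n
```

Formatting each yielded pair with `"%f out of %d correct"` gives lines of the same shape as the log above; for instance, one wrong character in a 3-character label adds 0.666667.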
Errors encountered during training
Errors caused by TensorFlow version differences:
seq2seq.py, line 47:
linear = core_rnn_cell._linear
seq2seq_model.py, line 87:
single_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(attn_num_hidden, forget_bias=0.0, state_is_tuple=False)
Change it to:
single_cell = tf.contrib.rnn.BasicLSTMCell(attn_num_hidden, forget_bias=0.0, state_is_tuple=False)
seq2seq_model.py, lines 100 and 102:
lstm_fw_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False)
# Backward direction cell
lstm_bw_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False)
Change them to:
lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False)
# Backward direction cell
lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(num_hidden, forget_bias=0.0, state_is_tuple=False)
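Instead of editing these lines by hand, the same substitution can be scripted. This sketch rewrites every occurrence of `tf.contrib.rnn.core_rnn_cell.BasicLSTMCell` to `tf.contrib.rnn.BasicLSTMCell` in the given files and keeps a `.bak` copy of each modified file. The path in the usage comment (src/model/seq2seq_model.py) is an assumption about the repository layout; point it at wherever seq2seq_model.py lives in your checkout.

```python
import shutil
import sys

OLD = "tf.contrib.rnn.core_rnn_cell.BasicLSTMCell"
NEW = "tf.contrib.rnn.BasicLSTMCell"

def patch_file(path, old=OLD, new=NEW):
    """Replace every occurrence of old with new in path; return the number replaced.

    A backup is written to path + '.bak' before the file is modified.
    """
    with open(path) as f:
        text = f.read()
    count = text.count(old)
    if count:
        shutil.copyfile(path, path + ".bak")
        with open(path, "w") as f:
            f.write(text.replace(old, new))
    return count

if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage (path is an assumption): python patch_lstm.py src/model/seq2seq_model.py
    for p in sys.argv[1:]:
        print("%s: %d replacement(s)" % (p, patch_file(p)))
```

Running it a second time reports 0 replacements, so it is safe to re-run.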
If the test results contain no "correct" entries, no prediction was scored as correct, and the model file path is probably wrong.
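A first thing to check in that case is whether the directory passed to --model-dir actually contains a TensorFlow checkpoint. In TensorFlow 1.x, tf.train.Saver writes a small state file named `checkpoint` plus `.index` and `.data-*` files for each saved step (for example translate.ckpt-47200 in the test log). The sketch below only inspects file names and does not load TensorFlow; `find_checkpoints` is a hypothetical helper.

```python
import glob
import os
import sys

def find_checkpoints(model_dir):
    """Return (has_state_file, checkpoint prefixes sorted by name) for a TF 1.x model dir."""
    has_state = os.path.isfile(os.path.join(model_dir, "checkpoint"))
    # Assumption: V2-format checkpoints, where each prefix has a '<prefix>.index' file.
    index_files = glob.glob(os.path.join(model_dir, "*.index"))
    prefixes = sorted(os.path.basename(p)[:-len(".index")] for p in index_files)
    return has_state, prefixes

if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python find_ckpt.py train
    has_state, prefixes = find_checkpoints(sys.argv[1])
    if not has_state or not prefixes:
        print("No checkpoint found in %s; check --model-dir" % sys.argv[1])
    else:
        print("Found checkpoints: %s" % ", ".join(prefixes))
```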