最近,想使用谷歌的Attention OCR做中文文本识别,项目github地址:https://github.com/A-bone1/Attention-ocr-Chinese-Version,中文介绍可参考CSDN博客:https://blog.csdn.net/qq_40003316/article/details/80062023。
研究后发现该模型的训练数据需要提供FSNS格式的训练数据,而官方也没有给出相关的文档,只给了一个stackoverflow的链接https://stackoverflow.com/a/44461910/743658,可是说的也不清楚。所以自己参考网上的一些办法,写了一个生成FSNS格式tfrecord的小代码。github地址为:https://github.com/A-bone1/FSNS-tfrecord-generate。
FSNS的具体格式在这篇论文有说:https://arxiv.org/pdf/1702.03970.pdf
但是,我们只需关心表四即可:
image/format表示图片的格式,是‘png’ ,如果你生的tfrecord是使用jpg格式,可改成‘raw’
image/encoded 表示图片的具体内容,占用一个string,以‘png’的格式编码
iamge/class表示图片真实的类别id,是37个int64数据,每一个int64对应一个字符编码,具体的映射方式在charset_size=134.txt文件中,要生成自己的数据需要自己创建类似的字典,如我自己创建的包含5400个中文的dic.txt。
image/unpadded_class 表示图片在没有被填充之前真实的id。
image/width:表示图片的像素的宽度
image/orig_width:表示图片在没有填充之前像素的宽度
image/height:表示图片的像素的高度,在tensorflow代码中,这一部分并没有写入代码,因为图片高度固定为150
image/test:占用一个string,是使用UTF-8编码的真实的字符形式的标记
下面直接上代码:(上传的代码是将jpg图片直接存储为tfrecord,速度较快,如果读者想生成png编码的tfrecord,可以参考我的github。
from random import shuffle
import numpy as np
import glob
import tensorflow as tf
import cv2
import sys
import os
import PIL.Image as Image
def encode_utf8_string(text, length, dic, null_char_id=5462):
char_ids_padded = [null_char_id]*length
char_ids_unpadded = [null_char_id]*len(text)
for i in range(len(text)):
hash_id = dic[text[i]]
char_ids_padded[i] = hash_id
char_ids_unpadded[i] = hash_id
return char_ids_padded, char_ids_unpadded
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
dict={}
with open('dic.txt', encoding="utf") as dict_file:
for line in dict_file:
(key, value) = line.strip().split('\t')
dict[value] = int(key)
print((dict))
image_path = 'data/*/*.jpg'
addrs_image = glob.glob(image_path)
label_path = 'data/*/*.txt'
addrs_label = glob.glob(label_path)
print(len(addrs_image))
print(len(addrs_label))
tfrecord_writer = tf.python_io.TFRecordWriter("tfexample_train")
for j in range(0,int(len(addrs_image))):
# 这是写入操作可视化处理
print('Train data: {}/{}'.format(j,int(len(addrs_image))))
sys.stdout.flush()
img = Image.open(addrs_image[j])
img = img.resize((600, 150), Image.ANTIALIAS)
np_data = np.array(img)
image_data = img.tobytes()
for text in open(addrs_label[j], encoding="utf"):
char_ids_padded, char_ids_unpadded = encode_utf8_string(
text=text,
dic=dict,
length=37,
null_char_id=5462)
example = tf.train.Example(features=tf.train.Features(
feature={
'image/encoded': _bytes_feature(image_data),
'image/format': _bytes_feature(b"raw"),
'image/width': _int64_feature([np_data.shape[1]]),
'image/orig_width': _int64_feature([np_data.shape[1]]),
'image/class': _int64_feature(char_ids_padded),
'image/unpadded_class': _int64_feature(char_ids_unpadded),
'image/text': _bytes_feature(bytes(text, 'utf-8')),
# 'height': _int64_feature([crop_data.shape[0]]),
}
))
tfrecord_writer.write(example.SerializeToString())
tfrecord_writer.close()
sys.stdout.flush()
原文地址: