为了把手写文本(如历史值班手册等)转化为电子文本,最近尝试在github上查找相关开源项目。功夫不负有心人,找到了基于TensorFlow2.0的中文手写字识别模型,其采用CASIA-HWDB数据集,设计了神经网络达到了相当不错的识别率。
项目地址如下:
https://github.com/jjcheer/ocrcn_tf2
按照操作流程,下载数据,转换格式,生成训练样本。
结果报错如下:
tensorflow.python.framework.errors_impl.InvalidArgumentError: {
{function_node __inference_Dataset_map_parse_example_25}} Input to reshape is a tensor with 4446 values, but the requested shape has 4096
看起来是进入模型的数据格式有问题啊,看了下./train_simple.py里面如下代码报错:
def train():
all_characters = load_characters()
num_classes = len(all_characters)
logging.info('all characters: {}'.format(num_classes))
train_dataset = load_ds()
train_dataset = train_dataset.shuffle(100).map(preprocess).batch(32).repeat()
应该是整个load_ds的问题。
观察./dataset/casia_hwdb.py中的parse_example函数存在原版和V2版本,而在load_ds上调用的是parse_example,而load_val_ds中调用的则是parse_example_v2。
def parse_example(record):
features = tf.io.parse_single_example(record,
features={