之前关注过文本检测,只了解到CTPN,现在开始看PSENet(Shape Robust Text Detection with Progressive Scale Expansion Network)
参考博客:https://mp.weixin.qq.com/s/-zMVO47AL1iKFmF16KsfOw
PSENet文本检测算法来自论文《Shape Robust Text Detection with Progressive Scale Expansion Network》,2018年7月发表于arxiv,已被CVPR 2019 接收。
解决一些使用python2与python3不同造成的函数或模块找不到问题,以及模型路径问题。
参考:tensorflow版PSENet训练自己的数据及测试进行OCR文本检测,Linux和Windows详细复现过程
解决windows上pse.cpp的编译问题。
这样就可以在windows上跑通 tensorflow-PSENet的测试了。
我整理好了一份放在百度云盘上。 提取码:f55e
关于训练,有时间再详细笔记。
关于训练和测试,由于我比较懒就写了一个脚本,每次想执行的时候双击一下就可以执行了。
测试脚本:
python eval.py --test_data_path=data\images\ --gpu_list=0 --checkpoint_path=model\ --output_dir=data\result
pause
训练脚本:
python train.py --gpu_list=0 --input_size=512 --batch_size_per_gpu=8 --checkpoint_path=./train_model/ --training_data_path=./data/icdar2015/
pause
源码解读:
复现的论文神经网络部分使用的是tensorflow,广度优先搜索部分使用C++实现。
先从train.py开始看,103行定义了损失函数,在tower_loss函数中,构建了模型。模型的输出seg_maps是一个6通道的tensor,对应了论文中segmentation result。在train.py中没有引用到pse,pse在训练的过程中没有用到。
预测的过程在eval.py中。
def detect(seg_maps, timer, image_w, image_h, min_area_thresh=10, seg_map_thresh=0.9, ratio = 1):
其中,min_area_thresh是一个连通分量中至少有10个像素,seg_map_thresh是在返回的seg_map中,0.9以下的被变成0,以上的变为1,以此将seg_map变成二值图。在论文中出现的kernel就是图中变成1的部分。在detect函数中,调用了pse,这部分是使用C++实现的。python中调用C++使用了pybind11。
编译pse
pse文件夹中包含了pse的实现。其中include目录是pybind11的开源代码。广度优先的过程都在pse.cpp中。
在pse/__init__py中队pse.cpp进行了编译。
pybind11中,规定PYBIND!!_MODULE作为一个接口,写在C++文件中,编译的时候会将函数与python中的函数绑定。
PYBIND11_MODULE(pse, m){
m.def("pse_cpp", &pse::pse, " re-implementation pse algorithm(cpp)", py::arg("label_map"), py::arg("Sn"), py::arg("c")=6);
}
第一个pse_cpp是python中绑定的函数名,第二个&pse::pse是在C++文件中待绑定的函数py::arg声明了参数以及默认值。在pse::pse中实现了一个广度优先搜索。
__init__.py中,对pse进行了一次封装。
label_num, label = cv2.connectedComponents(kernals[kernal_num - 1].astype(np.uint8), connectivity=4)
cv2.connectedComponents对模型求出来的最后一个kernel求了一次连通分量。label_num是图中连通分量的个数,label是带有标签的图,如果在连通分量里面,那个像素的值就是对应的连通分量编号,否则就是0。
接下来的for循环将小于10个像素的连通分量删除。
加载数据
在train.py中,数据生成使用的是:
data_generator = data_provider.get_batch(num_workers=FLAGS.num_readers,
input_size=FLAGS.input_size,
batch_size=FLAGS.batch_size_per_gpu * len(gpus))
get_batch中,类GeneratorEnqueuer使用的数据生成器是generator(**kargs),生成器的返回值有images, image_fns, seg_maps, training_masks,其中有用的是images, seg_maps, training_masks,image_fns是文件名,所以是没用的。对应的就是下面有data[0],data[2],data[3]没有data[1]。
ml, tl, _ = sess.run([model_loss, total_loss, train_op], feed_dict={input_images: data[0],
input_seg_maps: data[2],
input_training_masks: data[3]})
读取标注的函数是data_provider.py中的load_annotation,在第280行调用了这个函数,这个函数就是读取标记用的,如果要兼容别的数据集需要修改这个函数。返回值text_polys是一个三维数组
其中,每一层保存了一个多边形,由于数据集中只支持矩形,所以就只有四个点。第三个维度表示一张图片中存在多个文字块。text_tags是一个布尔型的数组。
text_polys, text_tags = load_annoataion(txt_fn)
在load_anotation之后调用check_and_validate_polys对text_polys和text_tags进行矫正。在这个函数中pyclipper.Area计算多边形内的面积,如果面积小于1,则舍去。使用pyclipper.Orientation使点的方向变成顺时针。
然后对图片进行随机放缩
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
标签中的最后数据仅用作是否是文本行的判断依据
if label == '*' or label == '###' or label == '?':
text_tags.append(True)
else:
text_tags.append(False)
crop_area会随机选择一块区域,如果有文字则为样本
否则为背景。
然后,generate_seg会将文字的矩形区域缩放成6种不同的大小,供金字塔结构使用
seg_map_per_image, training_mask = generate_seg((new_h, new_w), text_polys, text_tags,
image_list[i], scale_ratio)
在generage_seg函数中,调用了函数shrink_poly这个函数用来将ground_truth进行不同比例的缩小(论文的3.3节label generation)
# seg map
shrinked_polys = []
if poly_idx not in ignore_poly_mark:
shrinked_polys = shrink_poly(poly.copy(), scale_ratio[i])
模型实现:
在model.py中,首先建立金字塔特征:
feature_pyramid = build_feature_pyramid(end_points, weight_decay=weight_decay)
其中endpoints是resnet中几个特征图。
然后讲feature_pyramid进行concat,由于每一层的feature_pyramid的大小不一定一样,所以需要先进行缩放(unpool函数)
然后经过两个卷积层,得到seg_S_pred。
关于参数传递:
下面是取自 eval.py中的代码
tf.app.flags.DEFINE_string('test_data_path', None, '')
tf.app.flags.DEFINE_string('gpu_list', '0', '')
tf.app.flags.DEFINE_string('checkpoint_path', './', '')
tf.app.flags.DEFINE_string('output_dir', './results/', '')
tf.app.flags.DEFINE_bool('no_write_images', False, 'do not write images')
“DEFINE_xxx”函数带3个参数,分别是变量名称,默认值,用法描述
参考上面我自己写的测试脚本,可知,当有对应参数传递时,默认参数就会被覆盖。
我参考的是:https://github.com/liuheng92/tensorflow_PSENet
其中还有一些代码不太理解:
标签:
if label == '*' or label == '###' or label == '?':
text_tags.append(True)
else:
text_tags.append(False)
这里是说只有标签行以 “*” 或者“###”或者“?”为最后一个元素时才标记为True,那么对icdar2015标签数据中很多是以数字或者词组为标签数据的最后一个元素的,这个代码会将这些标记为False。我的认知是,有任何文本都应该标注为True,没有文本标注为False。当然上面这样写肯定是有道理的,毕竟我代码都没有看全,作者勿怪。
后来参考:https://github.com/whai362/PSENet/blob/master/dataset/icdar2015_loader.py
其中关于标签的设置比较符合我的认知。
if gt[-1][0] == '#':
tags.append(False)
else:
tags.append(True)
即标签行最后一个元素如果是以‘#’开头就认为该行没有文本,否则为文本行,即使这样也会有文本行以‘#’开头的,不过这种情况比较少见。
参考:PSENet源码阅读笔记