CTPN Text Detection and a TensorFlow Implementation

 

1. Introduction

    In recent years, with the development of artificial intelligence, text detection has become a basic component of many applications, such as recognizing text on billboards, detecting road signs for intelligent driving, ID card recognition, and reading delivery addresses. The first step in all of these tasks is text detection: locating the text in the image so that it can be recognized afterwards.

    Text detection resembles object detection, but it is harder. In object detection, seeing only half of an object is often enough to classify it. In text detection, a local view may contain only a few letters of a word or one radical of a Chinese character, which is rarely enough to tell what the text is. Text detection therefore demands much higher localization precision than object detection.

    Text detection currently falls into two settings: OCR (Optical Character Recognition) and STR (Scene Text Recognition). The former is mainly applied to scanned documents, where the background must be fairly clean, and its accuracy and speed are limited. The latter handles text in natural scenes and is much harder: the background can be extremely complex, and the angle, font, and color of the text vary endlessly, so the challenges are far greater.

    In recent years, with the development of convolutional neural networks, scene text recognition has advanced greatly. The classic model is CTPN (Connectionist Text Proposal Network), proposed by Zhi Tian et al. in 2016, which greatly simplified the detection pipeline and brought a qualitative leap in detection quality, speed, and robustness. This post introduces the model and implements it in TensorFlow.

2. The CTPN Model

2.1 Model structure

    The CTPN model consists of three parts: convolutional layers, a Bi-LSTM layer, and fully connected layers, as shown in the figure below.

    First, for the convolutional part, CTPN uses VGG16, a popular model at the time, for feature extraction; the authors take the conv5 feature maps of VGG16 as the final image features. Suppose these feature maps have size $H \times W \times C$.

    Next, because text has a sequential structure, the authors introduce a recurrent network: a single Bi-LSTM layer. They found that it improves detection markedly; in the figure below, the first row shows results without the recurrent network and the second row shows results with the Bi-LSTM. Concretely, a $3 \times 3$ sliding window extracts the $3 \times 3$ neighborhood around every point of the feature maps as that point's feature vector, giving a tensor of size $H \times W \times 9C$. Each row is then treated as one sequence, with the height acting as the batch dimension, and fed into a Bi-LSTM with 128 hidden units per direction, whose output has size $W \times H \times 256$.
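    To make the sliding-window-plus-recurrence step concrete, here is a minimal TF 1.x sketch (mine, not the paper's or the implementation below; the sizes H, W, C are illustrative) that gathers the 3 × 3 neighborhood of every feature-map point into a 9C-dimensional vector and runs a 128-unit Bi-LSTM over each row:

import tensorflow as tf

H, W, C = 38, 50, 512  # assumed conv5 feature-map size, for illustration only
feat = tf.placeholder(tf.float32, (1, H, W, C))

# 3x3 sliding window: gather the 9 neighbors of every position
patches = tf.extract_image_patches(feat, ksizes=[1, 3, 3, 1],
                                   strides=[1, 1, 1, 1],
                                   rates=[1, 1, 1, 1],
                                   padding='SAME')  # 1 x H x W x 9C

# each of the H rows becomes one sequence of length W
seq = tf.reshape(patches, [H, W, 9 * C])

fw_cell = tf.contrib.rnn.LSTMCell(128)
bw_cell = tf.contrib.rnn.LSTMCell(128)
outputs, _ = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, seq, dtype=tf.float32)
bilstm_out = tf.concat(outputs, axis=-1)  # H x W x 256, both directions concatenated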

    Finally, the Bi-LSTM output is fed into fully connected layers. Here the authors introduce an anchor mechanism: at every point, k anchors are predicted, each anchor being a box of fixed width whose height takes one of 10 values decreasing from 273 down to 11 pixels, each height being 0.7 times the previous one. The head consists of three fully connected branches.

     The first branch predicts the vertical coordinates of the k anchors. Each anchor has two vertical values, the y-coordinate of the box center and the box height, so this branch has dimension 2k. Concretely:

        $$v_c = (c_y - c_y^a)/h^a, \qquad v_h = \log(h/h^a)$$
        $$v_c^* = (c_y^* - c_y^a)/h^a, \qquad v_h^* = \log(h^*/h^a)$$

where $\mathbf{v} = \{v_c, v_h\}$ and $\mathbf{v}^* = \{v_c^*, v_h^*\}$ are the predicted and ground-truth coordinates respectively, $c_y^a$ and $h^a$ are the y-center and height of the anchor, $c_y$ and $h$ are the predicted y-center and height, and $c_y^*$ and $h^*$ are the ground-truth y-center and height.
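    As a sanity check on this parameterization, a small NumPy sketch of the encoding and its inverse, together with the anchor-height schedule described above, might look like this (the function names are illustrative, not from the paper):

import numpy as np

# ten anchor heights, decreasing from 273 to 11 pixels, each 0.7x the previous
anchor_heights = [int(round(273 * 0.7 ** i)) for i in range(10)]
# -> [273, 191, 134, 94, 66, 46, 32, 22, 16, 11]

def encode_vertical(gt_cy, gt_h, anchor_cy, anchor_h):
    """Relative vertical targets (v_c*, v_h*) of a ground-truth box w.r.t. an anchor."""
    v_c = (gt_cy - anchor_cy) / anchor_h
    v_h = np.log(gt_h / anchor_h)
    return v_c, v_h

def decode_vertical(v_c, v_h, anchor_cy, anchor_h):
    """Invert the parameterization back to an absolute center and height."""
    return v_c * anchor_h + anchor_cy, np.exp(v_h) * anchor_h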

    The second branch predicts a score for each of the k anchors, i.e. the probability that the box contains text. Each box gets a two-class text/non-text output, so this branch has dimension 2k. An anchor is considered to contain text when its score exceeds 0.7.

    The third branch predicts a horizontal offset for each of the k anchors, namely the x-coordinate of the left or right side of the box, so the k boxes need only k dimensions. This branch exists to refine the horizontal position of the anchors. The offset is defined as:

        $$o = (x_{side} - c_x^a)/w^a, \qquad o^* = (x_{side}^* - c_x^a)/w^a$$

where $x_{side}$ is the predicted x-coordinate closest to the horizontal side (left or right) of the anchor, $x_{side}^*$ is the ground-truth x-coordinate, $c_x^a$ is the x-center of the anchor, and $w^a$ is the anchor width, i.e. 16.
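    The same kind of sketch for the side-refinement offset, with the fixed anchor width $w^a = 16$ (again, illustrative names only):

def encode_side(x_side, anchor_cx, anchor_w=16.0):
    # relative offset o of the text line's left/right side w.r.t. the anchor center
    return (x_side - anchor_cx) / anchor_w

def decode_side(o, anchor_cx, anchor_w=16.0):
    # invert the offset back to an absolute x-coordinate
    return o * anchor_w + anchor_cx

# e.g. an anchor centered at x = 200 whose text line really starts at x = 195:
# encode_side(195, 200) == -0.3125 and decode_side(-0.3125, 200) == 195.0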

2.2 Side-refinement

    Once the model has made its predictions, each position of the feature map yields only a single anchor, as shown in the figure above; these anchors still have to be connected into complete text lines. For this the authors introduce the side-refinement step, whose text-line construction works roughly as follows:

    First, for the i-th text proposal, denoted $B_i$, search for a paired neighbor $B_j$, writing the pairing as $B_j \rightarrow B_i$, where $B_j$ must satisfy the following conditions:

  • $B_j$ is within 50 pixels of $B_i$ horizontally and has the highest score among such candidates;

  • B _ { j }B _ { i }在垂直方向的重合度必须大于0.7

     Then the nearest neighbor of $B_j$ is searched in the reverse direction; if that neighbor is exactly $B_i$, the pairing is mutual and $B_i$ and $B_j$ are successfully joined into the same text line.
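    A minimal sketch of this mutual pairing rule, assuming each proposal is a box (x1, y1, x2, y2) with a text score (the helper names are mine, not the paper's):

def v_overlap(a, b):
    # vertical overlap: intersection over union of the two height intervals
    inter = min(a[3], b[3]) - max(a[1], b[1])
    union = max(a[3], b[3]) - min(a[1], b[1])
    return inter / union if union > 0 else 0.0

def best_neighbor(i, boxes, scores, to_right=True):
    # highest-scoring proposal within 50 px of boxes[i] on the given side,
    # with vertical overlap above 0.7
    best, best_score = None, -1.0
    for j, b in enumerate(boxes):
        gap = b[0] - boxes[i][2] if to_right else boxes[i][0] - b[2]
        if j != i and 0 <= gap < 50 and v_overlap(boxes[i], b) > 0.7:
            if scores[j] > best_score:
                best, best_score = j, scores[j]
    return best

def paired(i, j, boxes, scores):
    # B_i and B_j belong to the same text line only when the search is mutual
    return j is not None and best_neighbor(j, boxes, scores, to_right=False) == i

boxes = [(0, 10, 16, 40), (20, 12, 36, 41), (300, 200, 316, 230)]
scores = [0.9, 0.8, 0.95]
j = best_neighbor(0, boxes, scores)  # -> 1
print(paired(0, j, boxes, scores))   # True: proposals 0 and 1 join one line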

The figure below compares results with and without refinement: the red lines use refinement, the yellow lines do not.

2.3 Loss function

     Since the fully connected head has three branches, there are three corresponding losses, which the authors combine with weights as follows:

        $$L(\mathbf{s}_i, \mathbf{v}_j, \mathbf{o}_k) = \frac{1}{N_s}\sum_i L_s^{cl}(\mathbf{s}_i, \mathbf{s}_i^*) + \frac{\lambda_1}{N_v}\sum_j L_v^{re}(\mathbf{v}_j, \mathbf{v}_j^*) + \frac{\lambda_2}{N_o}\sum_k L_o^{re}(\mathbf{o}_k, \mathbf{o}_k^*)$$

where $L_s^{cl}$, $L_v^{re}$ and $L_o^{re}$ are the losses of the score, coordinate, and side-refinement branches; $L_s^{cl}$ is a softmax loss, while $L_v^{re}$ and $L_o^{re}$ are regression losses. $\mathbf{s}_i, \mathbf{s}_i^*$ are the predicted and ground-truth scores, $\mathbf{v}_j, \mathbf{v}_j^*$ the predicted and ground-truth vertical coordinates, and $\mathbf{o}_k, \mathbf{o}_k^*$ the predicted and ground-truth side-refinement offsets. $N_s$, $N_v$ and $N_o$ are anchor counts, and they differ per term: $N_s$ counts anchors with score > 0.7, $N_v$ counts anchors with $\mathbf{s}_i^* = 1$ or with overlap above 0.5 against a ground-truth anchor, and $N_o$ counts anchors lying within 32 pixels of the left or right side of a ground-truth text line. $\lambda_1$ and $\lambda_2$ are the weighting parameters, set empirically to 1 and 2.
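    A schematic NumPy version of this joint loss follows. It is not the paper's exact formulation: the softmax classification term is written as binary cross-entropy, the regression terms use smooth L1 (the usual choice in this family of detectors), and the per-term anchor selection and $N_s$/$N_v$/$N_o$ normalizations are simplified to plain means over already-sampled anchors.

import numpy as np

def smooth_l1(x):
    # robust regression loss, as in Fast/Faster R-CNN-style detectors
    x = np.abs(x)
    return np.where(x < 1.0, 0.5 * x ** 2, x - 0.5)

def ctpn_loss(s_pred, s_true, v_pred, v_true, o_pred, o_true,
              lambda1=1.0, lambda2=2.0, eps=1e-8):
    # classification term + lambda1 * vertical term + lambda2 * side term
    l_cls = -np.mean(s_true * np.log(s_pred + eps)
                     + (1.0 - s_true) * np.log(1.0 - s_pred + eps))
    l_ver = np.mean(smooth_l1(v_pred - v_true))
    l_hor = np.mean(smooth_l1(o_pred - o_true))
    return l_cls + lambda1 * l_ver + lambda2 * l_hor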

3. A TensorFlow Implementation of CTPN

    This post reimplements CTPN in TensorFlow. The code is based mainly on LI Mingfan's implementation; the original link is as follows:

    I tidied up that code, made some changes, and retrained the model on the ICDAR competition dataset, which contains 7,200 images in several languages (English, Korean, Japanese, Chinese, and so on); training here focuses on detecting English and Chinese. The ICDAR dataset can be downloaded at the link below:

    In the model structure, the main change is in the convolutional part, which now uses ResNet-style blocks; the horizontal branch also uses 2k dimensions, the loss is a squared loss, and a focal-loss-style weighting is added. For reasons of space the structural code is not walked through line by line here; the comments in the code are fairly detailed. The code is as follows:

 
# -*- coding: utf-8 -*-
import os
import random
import time

import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.python.framework import graph_util
from tensorflow.python.training.moving_averages import assign_moving_average

import data_loader


class ModelDetect:
    def __init__(self,
                 model_detect_dir,
                 model_detect_pb_file,
                 LEARNING_RATE_BASE,
                 TRAINING_STEPS,
                 VALID_FREQ,
                 LOSS_FREQ,
                 KEEP_NEAR,
                 KEEP_FREQ,
                 anchor_heights,
                 MOMENTUM,
                 dir_results_valid,
                 threshold,
                 model_detect_name,
                 rnn_size,
                 fc_size,
                 keep_prob):
        self.model_detect_dir = model_detect_dir
        self.model_detect_pb_file = model_detect_pb_file
        self.pb_file = os.path.join(model_detect_dir, model_detect_pb_file)
        self.sess_config = tf.ConfigProto()
        self.is_train = False
        self.graph = None
        self.sess = None
        self.learning_rate_base = LEARNING_RATE_BASE
        self.train_steps = TRAINING_STEPS
        self.valid_freq = VALID_FREQ
        self.loss_freq = LOSS_FREQ
        self.keep_near = KEEP_NEAR
        self.keep_freq = KEEP_FREQ
        self.anchor_heights = anchor_heights
        self.MOMENTUM = MOMENTUM
        self.dir_results_valid = dir_results_valid
        self.threshold = threshold
        self.model_detect_name = model_detect_name
        self.rnn_size = rnn_size
        self.fc_size = fc_size
        self.keep_prob = keep_prob

    def prepare_for_prediction(self, pb_file_path=None):
        """
        Load the frozen computation graph.
        :param pb_file_path: path of the pb file
        :return:
        """
        if pb_file_path is None:
            pb_file_path = self.pb_file

        if not os.path.exists(pb_file_path):
            print('ERROR: %s NOT exists, when load_pb_for_predict()' % pb_file_path)
            return -1

        self.graph = tf.Graph()

        # import the graph from the pb file
        with self.graph.as_default():
            with open(pb_file_path, "rb") as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name="")

            self.x = self.graph.get_tensor_by_name('x-input:0')
            self.w = self.graph.get_tensor_by_name('w-input:0')

            self.rnn_cls = self.graph.get_tensor_by_name('rnn_cls:0')
            self.rnn_ver = self.graph.get_tensor_by_name('rnn_ver:0')
            self.rnn_hor = self.graph.get_tensor_by_name('rnn_hor:0')

        print('graph loaded for prediction')
        self.sess = tf.Session(graph=self.graph, config=self.sess_config)

    def predict(self, img_file, out_dir=None):
        """
        :param img_file: image path. [str]
        :param out_dir: directory for the rendered results. [str]
        :return:
        """
        # load the image
        img = Image.open(img_file)

        # optional preprocessing
        # img_data = data_loader.mean_gray(img_data)
        # img_data = data_loader.two_value_binary(img_data)
        # img_data = data_loader.convert2rgb(img_data)

        # rescale the image: shorter side to 600, longer side capped at 800
        img_size = img.size  # (width, height)
        im_size_min = np.min(img_size[0:2])
        im_size_max = np.max(img_size[0:2])
        im_scale = float(600) / float(im_size_min)
        if np.round(im_scale * im_size_max) > 800:
            im_scale = float(800) / float(im_size_max)
        width = int(img_size[0] * im_scale)
        height = int(img_size[1] * im_scale)
        img = img.resize((width, height), Image.ANTIALIAS)
        # re_im = cv2.resize(img, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)

        # normalize the image; converting to RGB first handles grayscale and
        # RGBA inputs uniformly (the original try/except channel slicing was buggy)
        img = img.convert('RGB')
        img_data = [np.array(img, dtype=np.float32) / 255]
        w_arr = np.array([width], dtype=np.int32)

        # run the prediction
        with self.graph.as_default():
            feed_dict = {self.x: img_data, self.w: w_arr}
            r_cls, r_ver, r_hor = self.sess.run([self.rnn_cls, self.rnn_ver, self.rnn_hor], feed_dict)
            text_bbox, conf_bbox = data_loader.trans_results(r_cls, r_ver, r_hor,
                                                             self.anchor_heights, self.threshold)
            # refinement: connect the anchors into text lines
            conn_bbox = data_loader.do_nms_and_connection(text_bbox, conf_bbox)

            if out_dir is None:
                return conn_bbox, text_bbox, conf_bbox

            if not os.path.exists(out_dir):
                os.mkdir(out_dir)

            # draw the per-anchor text boxes
            filename = os.path.basename(img_file)
            basename, _ = os.path.splitext(filename)
            file_target = os.path.join(out_dir, 'predicted_' + basename + '.png')
            img_target = Image.fromarray(np.uint8(img_data[0] * 255))  # .convert('RGB')
            img_target.save(file_target)
            data_loader.draw_text_boxes(file_target, text_bbox)

            # draw the text lines obtained by connecting the anchors
            file_target = os.path.join(out_dir, 'connected_' + basename + '.png')
            img_target = Image.fromarray(np.uint8(img_data[0] * 255))  # .convert('RGB')
            img_target.save(file_target)
            data_loader.draw_text_boxes(file_target, conn_bbox)

        return conn_bbox, text_bbox, conf_bbox

    def create_graph_all(self, training):
        """
        Build the computation graph.
        :param training: whether the graph is built for training. [boolean]
        :return:
        """
        self.is_train = training
        self.graph = tf.Graph()

        with self.graph.as_default():
            # placeholders
            self.x = tf.placeholder(tf.float32, (1, None, None, 3), name='x-input')
            self.w = tf.placeholder(tf.int32, (1,), name='w-input')  # width
            self.t_cls = tf.placeholder(tf.float32, (None, None, None), name='c-input')
            self.t_ver = tf.placeholder(tf.float32, (None, None, None), name='v-input')
            self.t_hor = tf.placeholder(tf.float32, (None, None, None), name='h-input')

            # convolutional layers, with ResNet-style blocks
            self.conv_feat, self.seq_len = self.conv_feat_layers(self.x, self.w, self.is_train)

            # Bi-LSTM + fully connected layers
            self.rnn_cls, self.rnn_ver, self.rnn_hor = self.rnn_detect_layers(self.conv_feat,
                                                                              self.seq_len,
                                                                              len(self.anchor_heights))

            # loss of the model
            self.loss = self.detect_loss(self.rnn_cls,
                                         self.rnn_ver,
                                         self.rnn_hor,
                                         self.t_cls,
                                         self.t_ver,
                                         self.t_hor)

            # optimizer
            self.global_step = tf.train.get_or_create_global_step()
            self.learning_rate = tf.get_variable("learning_rate", shape=[], dtype=tf.float32, trainable=False)
            optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=self.MOMENTUM)
            grads_applying = optimizer.compute_gradients(self.loss)
            self.train_op = optimizer.apply_gradients(grads_applying, global_step=self.global_step)

        if self.is_train:
            print('graph defined for training')
        else:
            print('graph defined for validation')

    def train_and_valid(self, data_train, data_valid):
        """
        Train the model.
        :param data_train: list of training image paths. [list]
        :param data_valid: list of validation image paths. [list]
        :return:
        """
        # create the model directory
        if not os.path.exists(self.model_detect_dir):
            os.mkdir(self.model_detect_dir)

        # build the graph
        self.create_graph_all(training=True)

        # load and train the model
        with self.graph.as_default():
            saver = tf.train.Saver()
            with tf.Session(config=self.sess_config) as sess:
                # initialize the variables
                tf.global_variables_initializer().run()
                sess.run(tf.assign(self.learning_rate, tf.constant(self.learning_rate_base, dtype=tf.float32)))

                # restore from a checkpoint if one exists
                ckpt = tf.train.get_checkpoint_state(self.model_detect_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)

                # start training
                print('begin to train ...')
                start_time = time.time()
                begin_time = start_time
                step = sess.run(self.global_step)
                train_step_half = int(self.train_steps * 0.5)
                train_step_quar = int(self.train_steps * 0.75)

                while step < self.train_steps:
                    # learning-rate schedule: at 50% of the total steps drop the
                    # rate to 1/10 of the base, at 75% drop it to 1/100
                    if step == train_step_half:
                        sess.run(tf.assign(self.learning_rate,
                                           tf.constant(self.learning_rate_base / 10, dtype=tf.float32)))
                    if step == train_step_quar:
                        sess.run(tf.assign(self.learning_rate,
                                           tf.constant(self.learning_rate_base / 100, dtype=tf.float32)))

                    # save and validate the model
                    if (step + 1) % self.valid_freq == 0:
                        # save a checkpoint
                        print('save model to ckpt ...')
                        saver.save(sess, os.path.join(self.model_detect_dir, self.model_detect_name),
                                   global_step=step)

                        # validate
                        print('validating ...')
                        model_v = ModelDetect(self.model_detect_dir,
                                              self.model_detect_pb_file,
                                              self.learning_rate_base,
                                              self.train_steps,
                                              self.valid_freq,
                                              self.loss_freq,
                                              self.keep_near,
                                              self.keep_freq,
                                              self.anchor_heights,
                                              self.MOMENTUM,
                                              self.dir_results_valid,
                                              self.threshold,
                                              self.model_detect_name,
                                              self.rnn_size,
                                              self.fc_size,
                                              1.0)
                        model_v.validate(data_valid, step)

                    # pick a random image from the training set
                    img_file = random.choice(data_train)
                    if not os.path.exists(img_file):
                        print('image_file: %s NOT exist' % img_file)
                        continue

                    # path of the text-line annotation file of this image
                    txt_file = data_loader.get_target_txt_file(img_file)
                    if not os.path.exists(txt_file):
                        print('label_file: %s NOT exist' % txt_file)
                        continue

                    # load the image and its ground-truth targets
                    img_data, feat_size, target_cls, target_ver, target_hor = \
                        data_loader.get_image_and_targets(img_file, txt_file, self.anchor_heights)

                    # one training step
                    img_size = img_data[0].shape  # height, width, channel
                    w_arr = np.array([img_size[1]], dtype=np.int32)

                    feed_dict = {self.x: img_data,
                                 self.w: w_arr,
                                 self.t_cls: target_cls,
                                 self.t_ver: target_ver,
                                 self.t_hor: target_hor}

                    _, loss_value, step, lr = sess.run(
                        [self.train_op, self.loss, self.global_step, self.learning_rate],
                        feed_dict)

                    if step % self.loss_freq == 0:
                        curr_time = time.time()
                        print('step: %d, loss: %g, lr: %g, sect_time: %.1f, total_time: %.1f, %s' %
                              (step, loss_value, lr,
                               curr_time - begin_time,
                               curr_time - start_time,
                               os.path.basename(img_file)))
                        begin_time = curr_time

    def validate(self, data_valid, step):
        """
        Validate the model.
        :param data_valid: list of validation image paths. [list]
        :param step: current training step. [int]
        :return:
        """
        # make sure the validation output directory exists
        if not os.path.exists(self.dir_results_valid):
            os.mkdir(self.dir_results_valid)

        # build the graph in inference mode
        self.create_graph_all(training=False)

        with self.graph.as_default():
            saver = tf.train.Saver()
            with tf.Session(config=self.sess_config) as sess:
                # initialize the variables
                tf.global_variables_initializer().run()

                # restore the latest checkpoint
                ckpt = tf.train.get_checkpoint_state(self.model_detect_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)

                # freeze the variables into constants and export a pb file
                constant_graph = graph_util.convert_variables_to_constants(
                    sess, sess.graph_def, output_node_names=['rnn_cls', 'rnn_ver', 'rnn_hor'])
                with tf.gfile.FastGFile(self.pb_file, mode='wb') as f:
                    f.write(constant_graph.SerializeToString())

                # run prediction over the validation set
                NumImages = len(data_valid)
                curr = 0
                for img_file in data_valid:
                    print(img_file)
                    # path of the text-line annotation file of this image
                    txt_file = data_loader.get_target_txt_file(img_file)

                    # pixel array, feature-map size and targets of the three branches
                    img_data, feat_size, target_cls, target_ver, target_hor = \
                        data_loader.get_image_and_targets(img_file, txt_file, self.anchor_heights)

                    # size of the current image
                    img_size = img_data[0].shape  # height, width, channel
                    w_arr = np.array([img_size[1]], dtype=np.int32)

                    feed_dict = {self.x: img_data,
                                 self.w: w_arr,
                                 self.t_cls: target_cls,
                                 self.t_ver: target_ver,
                                 self.t_hor: target_hor}

                    # predicted labels and loss
                    r_cls, r_ver, r_hor, loss_value = sess.run(
                        [self.rnn_cls, self.rnn_ver, self.rnn_hor, self.loss], feed_dict)

                    curr += 1
                    print('curr: %d / %d, loss: %f' % (curr, NumImages, loss_value))

                    # convert the relative coordinates back to absolute image
                    # coordinates, keeping the predicted text boxes and scores
                    text_bbox, conf_bbox = data_loader.trans_results(r_cls,
                                                                     r_ver,
                                                                     r_hor,
                                                                     self.anchor_heights,
                                                                     self.threshold)

                    # draw the text boxes on the image and save it
                    filename = os.path.basename(img_file)
                    file_target = os.path.join(self.dir_results_valid, str(step) + '_predicted_' + filename)
                    img_target = Image.fromarray(np.uint8(img_data[0] * 255))  # .convert('RGB')
                    img_target.save(file_target)
                    data_loader.draw_text_boxes(file_target, text_bbox)

                    # remove the result files of earlier validation rounds
                    # (keeping only every keep_freq-th round)
                    id_remove = step - self.valid_freq * self.keep_near
                    if id_remove % self.keep_freq:
                        file_temp = os.path.join(self.dir_results_valid,
                                                 str(id_remove) + '_predicted_' + filename)
                        if os.path.exists(file_temp):
                            os.remove(file_temp)

                print('validation finished')

    def norm_layer(self, x, train, eps=1e-05, decay=0.9, affine=True, name=None):
        """
        Batch normalization.
        :param x: input. [tensor]
        :param train: whether in training mode. [boolean]
        :param eps: variance epsilon
        :param decay: moving-average decay
        :param affine: whether to apply the learned scale/shift (gamma/beta)
        :param name: variable scope name
        :return:
        """
        with tf.variable_scope(name, default_name='batch_norm'):
            params_shape = [x.shape[-1]]
            batch_dims = list(range(0, len(x.shape) - 1))
            moving_mean = tf.get_variable('mean', params_shape,
                                          initializer=tf.zeros_initializer(),
                                          trainable=False)
            moving_variance = tf.get_variable('variance', params_shape,
                                              initializer=tf.ones_initializer(),
                                              trainable=False)

            def mean_var_with_update():
                # compute the batch mean and variance
                batch_mean, batch_variance = tf.nn.moments(x, batch_dims, name='moments')
                # update moving_mean and moving_variance
                with tf.control_dependencies([assign_moving_average(moving_mean, batch_mean, decay),
                                              assign_moving_average(moving_variance, batch_variance, decay)]):
                    return tf.identity(batch_mean), tf.identity(batch_variance)

            if train:
                mean, variance = mean_var_with_update()
            else:
                mean, variance = moving_mean, moving_variance

            if affine:
                beta = tf.get_variable('beta', params_shape,
                                       initializer=tf.zeros_initializer(),
                                       trainable=True)
                gamma = tf.get_variable('gamma', params_shape,
                                        initializer=tf.ones_initializer(),
                                        trainable=True)
                x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, eps)
            else:
                x = tf.nn.batch_normalization(x, mean, variance, None, None, eps)

            return x

    def conv_layer(self, inputs, params, training):
        """
        Convolutional layer with optional batch normalization and ReLU.
        :param inputs: 4-D tensor: [batch_size, height, width, channels]
        :param params: layer parameters,
                       [filters, kernel_size, strides, padding, batch_norm, relu, name]. [list]
        :param training: whether in training mode. [boolean]
        :return:
        """
        kernel_initializer = tf.contrib.layers.variance_scaling_initializer()
        bias_initializer = tf.constant_initializer(value=0.0)
        gamma_initializer = tf.random_normal_initializer(1, 0.02)

        # conv
        outputs = tf.layers.conv2d(inputs, params[0], params[1], strides=params[2],
                                   padding=params[3],
                                   kernel_initializer=kernel_initializer,
                                   bias_initializer=bias_initializer,
                                   name=params[6])

        # batch_norm
        if params[4]:
            outputs = self.norm_layer(outputs, training, name=params[6] + '/batch_norm')
            # outputs = tf.layers.batch_normalization(inputs,
            #                                         axis=3,
            #                                         epsilon=1e-5,
            #                                         momentum=0.1,
            #                                         training=training,
            #                                         gamma_initializer=gamma_initializer,
            #                                         name=params[6] + '/batch_norm')

        # relu
        if params[5]:
            outputs = tf.nn.relu(outputs, name=params[6] + '/relu')

        return outputs

    def block_resnet_others(self, inputs, layer_params, relu, training, name):
        """
        ResNet block.
        :param inputs: input. [tensor]
        :param layer_params: parameters of the convolutional layers. [list]
        :param relu: whether to apply ReLU after the residual addition. [boolean]
        :param training: whether in training mode. [boolean]
        :param name: layer name. [str]
        :return:
        """
        with tf.variable_scope(name):
            short_cut = tf.identity(inputs)

            for item in layer_params:
                inputs = self.conv_layer(inputs, item, training)

            outputs = tf.add(inputs, short_cut, name='add')
            if relu:
                outputs = tf.nn.relu(outputs, 'last_relu')
        return outputs

    def conv_feat_layers(self, inputs, width, training):
        """
        Convolutional part of the CTPN, used to extract the feature map.
        :param inputs: input image. [placeholder]
        :param width: image width. [placeholder]
        :param training: whether in training mode. [boolean]
        :return:
        """
        # parameters of the convolutional layers
        layer_params = [[64,  (3, 3), (1, 1), 'same',  True, True, 'conv1'],
                        [128, (3, 3), (1, 1), 'same',  True, True, 'conv2'],
                        [128, (2, 2), (2, 2), 'valid', True, True, 'pool1'],
                        [128, (3, 3), (1, 1), 'same',  True, True, 'conv3'],
                        [256, (3, 3), (1, 1), 'same',  True, True, 'conv4'],
                        [256, (2, 2), (2, 2), 'valid', True, True, 'pool2'],
                        [256, (3, 3), (1, 1), 'same',  True, True, 'conv5'],
                        [512, (3, 3), (1, 1), 'same',  True, True, 'conv6'],
                        [512, (3, 2), (3, 2), 'valid', True, True, 'pool3'],
                        [512, (3, 1), (1, 1), 'valid', True, True, 'conv_feat']]

        resnet_params = [[[128, 3, (1, 1), 'same', True, True,  'conv1'],
                          [128, 3, (1, 1), 'same', True, False, 'conv2']],
                         [[256, 3, (1, 1), 'same', True, True,  'conv1'],
                          [256, 3, (1, 1), 'same', True, False, 'conv2']],
                         [[512, 3, (1, 1), 'same', True, True,  'conv1'],
                          [512, 3, (1, 1), 'same', True, False, 'conv2']]]

        # build the convolutional layers
        with tf.variable_scope("conv_comm"):
            inputs = self.conv_layer(inputs, layer_params[0], training)
            inputs = self.conv_layer(inputs, layer_params[1], training)
            inputs = tf.pad(inputs, [[0, 0], [0, 1], [0, 1], [0, 0]], name='padd1')
            inputs = tf.layers.max_pooling2d(inputs, (2, 2), (2, 2), 'valid', 'channels_last', 'pool1')

            inputs = self.block_resnet_others(inputs, resnet_params[0], True, training, 'res1')

            inputs = self.conv_layer(inputs, layer_params[3], training)
            inputs = self.conv_layer(inputs, layer_params[4], training)
            inputs = tf.pad(inputs, [[0, 0], [0, 1], [0, 1], [0, 0]], name='padd2')
            inputs = tf.layers.max_pooling2d(inputs, (2, 2), (2, 2), 'valid', 'channels_last', 'pool2')

            inputs = self.block_resnet_others(inputs, resnet_params[1], True, training, 'res2')

            inputs = self.conv_layer(inputs, layer_params[6], training)
            inputs = self.conv_layer(inputs, layer_params[7], training)
            inputs = tf.pad(inputs, [[0, 0], [0, 0], [0, 1], [0, 0]], name='padd3')
            inputs = tf.layers.max_pooling2d(inputs, (3, 2), (3, 2), 'valid', 'channels_last', 'pool3')

            inputs = self.block_resnet_others(inputs, resnet_params[2], True, training, 'res3')

            conv_feat = self.conv_layer(inputs, layer_params[9], training)
            feat_size = tf.shape(conv_feat)

            # compute the sequence length of each row of the feature map;
            # each row is one sequence (the width is halved three times)
            two = tf.constant(2, dtype=tf.float32, name='two')
            w = tf.cast(width, tf.float32)
            for i in range(3):
                w = tf.div(w, two)
                w = tf.ceil(w)

            # replicate the length once per row and flatten it into a vector
            w = tf.cast(w, tf.int32)
            w = tf.tile(w, [feat_size[1]])
            sequence_length = tf.reshape(w, [-1], name='seq_len')  # one length per row

        return conv_feat, sequence_length

    def rnn_detect_layers(self, conv_feat, sequence_length, num_anchors):
        """
        Bi-LSTM + fully connected layers.
        :param conv_feat: feature map from the convolutional layers. [tensor]
        :param sequence_length: length of each row sequence; its size equals
                                the height of conv_feat. [tensor]
        :param num_anchors: number of anchors
        :return:
        """
        # drop the batch dimension (batch_size is fixed to 1)
        conv_feat = tf.squeeze(conv_feat, axis=0)
        conv_feat = tf.transpose(conv_feat, [1, 0, 2])

        # Bi-LSTM
        en_lstm1 = tf.contrib.rnn.LSTMCell(self.rnn_size)
        en_lstm1 = tf.contrib.rnn.DropoutWrapper(en_lstm1, output_keep_prob=self.keep_prob)
        en_lstm2 = tf.contrib.rnn.LSTMCell(self.rnn_size)
        en_lstm2 = tf.contrib.rnn.DropoutWrapper(en_lstm2, output_keep_prob=self.keep_prob)
        # encoder_cell_fw = tf.contrib.rnn.MultiRNNCell([en_lstm1])
        # encoder_cell_bw = tf.contrib.rnn.MultiRNNCell([en_lstm2])
        bi_encoder_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
            en_lstm1,
            en_lstm2,
            conv_feat,
            sequence_length=sequence_length,
            time_major=True,
            dtype=tf.float32)  # 2 * batch_size * seq_len * hidden_dim
        conv_feat = tf.concat(bi_encoder_outputs, 2)

        # fully connected layer
        weight_initializer = tf.contrib.layers.variance_scaling_initializer()
        bias_initializer = tf.constant_initializer(value=0.0)

        rnn_feat = tf.layers.dense(conv_feat, self.fc_size,
                                   activation=tf.nn.relu,
                                   kernel_initializer=weight_initializer,
                                   bias_initializer=bias_initializer,
                                   name='rnn_feat')

        # output layer with three branches
        rnn_cls = tf.layers.dense(rnn_feat, num_anchors * 2,
                                  activation=tf.nn.sigmoid,
                                  kernel_initializer=weight_initializer,
                                  bias_initializer=bias_initializer,
                                  name='text_cls')

        rnn_ver = tf.layers.dense(rnn_feat, num_anchors * 2,
                                  activation=tf.nn.tanh,
                                  kernel_initializer=weight_initializer,
                                  bias_initializer=bias_initializer,
                                  name='text_ver')

        rnn_hor = tf.layers.dense(rnn_feat, num_anchors * 2,
                                  activation=tf.nn.tanh,
                                  kernel_initializer=weight_initializer,
                                  bias_initializer=bias_initializer,
                                  name='text_hor')

        rnn_cls = tf.transpose(rnn_cls, perm=[1, 0, 2], name='rnn_cls')
        rnn_ver = tf.transpose(rnn_ver, perm=[1, 0, 2], name='rnn_ver')
        rnn_hor = tf.transpose(rnn_hor, perm=[1, 0, 2], name='rnn_hor')

        return rnn_cls, rnn_ver, rnn_hor

    def detect_loss(self, rnn_cls, rnn_ver, rnn_hor, target_cls, target_ver, target_hor):
        """
        Loss of the model.
        :param rnn_cls: predicted cls, the text/non-text probabilities
        :param rnn_ver: predicted ver, vertical center and height of the anchors
        :param rnn_hor: predicted hor, horizontal side offsets of the anchors
        :param target_cls: ground-truth cls
        :param target_ver: ground-truth ver
        :param target_hor: ground-truth hor
        :return:
        """
        # rnn_cls of the positive and negative anchors
        rnn_cls_posi = rnn_cls * target_cls
        rnn_cls_neg = rnn_cls - rnn_cls_posi

        # squared classification losses
        pow_posi = tf.square(rnn_cls_posi - target_cls)
        pow_neg = tf.square(rnn_cls_neg)

        # weight the losses (a focal-loss-like modulation)
        mod_posi = tf.pow(pow_posi / 0.24, 5)  # 0.3, 0.2, 0.5, 0.4
        mod_neg = tf.pow(pow_neg / 0.24, 5)    # 0.7, 0.6,
        mod_con = tf.pow(0.25 / 0.2, 5)

        # count the positive and negative anchors
        num_posi = tf.reduce_sum(target_cls) / 2 + 1
        num_neg = tf.reduce_sum(target_cls + 1) / 2 - num_posi * 2 + 1

        # classification losses of the positives and negatives
        loss_cls_posi = tf.reduce_sum(pow_posi * mod_posi) / 2
        loss_cls_neg = tf.reduce_sum(pow_neg * mod_neg) / 2

        # average the positive and negative losses separately, then add them;
        # since an image contains far more negatives than positives, this works
        # better than averaging over their sum
        loss_cls = loss_cls_posi / num_posi + loss_cls_neg / num_neg
        print('loss_cls:%s' % str(loss_cls))

        # rnn_ver and rnn_hor of the positive anchors
        rnn_ver_posi = rnn_ver * target_cls
        rnn_hor_posi = rnn_hor * target_cls

        # rnn_ver and rnn_hor of the negative anchors
        rnn_ver_neg = rnn_ver - rnn_ver_posi
        rnn_hor_neg = rnn_hor - rnn_hor_posi

        # squared ver/hor losses of the positives
        pow_ver_posi = tf.square(rnn_ver_posi - target_ver)
        pow_hor_posi = tf.square(rnn_hor_posi - target_hor)

        # squared ver/hor losses of the negatives
        pow_ver_neg = tf.square(rnn_ver_neg)
        pow_hor_neg = tf.square(rnn_hor_neg)

        # weight and average the positive regression losses
        # (again a focal-loss-like weighting)
        loss_ver_posi = tf.reduce_sum(pow_ver_posi * mod_con) / num_posi
        loss_hor_posi = tf.reduce_sum(pow_hor_posi * mod_con) / num_posi

        # weight and average the negative regression losses
        loss_ver_neg = tf.reduce_sum(pow_ver_neg * mod_neg) / num_neg
        loss_hor_neg = tf.reduce_sum(pow_hor_neg * mod_neg) / num_neg

        # sum the positive and negative ver/hor losses
        loss_ver = loss_ver_posi + loss_ver_neg
        loss_hor = loss_hor_posi + loss_hor_neg

        loss = tf.add(loss_cls, loss_ver + 2 * loss_hor, name='loss')

        return loss
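    For completeness, here is a hypothetical way to drive the class above; every path and hyper-parameter value below is illustrative, not the setting actually used in the experiments:

# hypothetical usage of ModelDetect; all values below are illustrative
anchor_heights = [6, 12, 24, 36, 48, 66, 90, 120, 160, 200]  # example only

model = ModelDetect(model_detect_dir='./model_detect',
                    model_detect_pb_file='model_detect.pb',
                    LEARNING_RATE_BASE=1e-4,
                    TRAINING_STEPS=250000,
                    VALID_FREQ=1000,
                    LOSS_FREQ=100,
                    KEEP_NEAR=3,
                    KEEP_FREQ=5,
                    anchor_heights=anchor_heights,
                    MOMENTUM=0.9,
                    dir_results_valid='./results_valid',
                    threshold=0.7,
                    model_detect_name='model_detect',
                    rnn_size=256,
                    fc_size=512,
                    keep_prob=0.7)

# training:
# model.train_and_valid(train_image_list, valid_image_list)

# inference with the frozen pb graph:
# model.prepare_for_prediction()
# conn_bbox, text_bbox, conf_bbox = model.predict('test.png', out_dir='./out')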

     Below are the data preprocessing functions:

 
import os
from math import ceil, floor
from operator import itemgetter

import numpy as np
from PIL import Image, ImageDraw


def getFilesInDirect(path, str_dot_ext):
    """
    List the image files in a directory with the given extension.
    :param path: directory of the images. [str]
    :param str_dot_ext: file extension, including the dot. [str]
    :return:
    """
    file_list = []
    for file in os.listdir(path):
        file_path = os.path.join(path, file)
        if os.path.splitext(file_path)[1] == str_dot_ext:
            file_list.append(file_path)
    return file_list


def get_files_with_ext(path, str_ext):
    """
    List the files in a directory whose names end with str_ext.
    :param path: directory. [str]
    :param str_ext: file suffix, e.g. '.png'. [str]
    :return:
    """
    file_list = []
    for file in os.listdir(path):
        file_path = os.path.join(path, file)
        if file_path.endswith(str_ext):
            file_list.append(file_path)
    return file_list


def get_target_txt_file(img_file):
    """
    Path of the text-line annotation file of an image.
    :param img_file: image path. [str]
    :return:
    """
    # directory where the text-line files are stored
    pre_dir = os.path.abspath(os.path.dirname(img_file) + os.path.sep + "..")
    txt_dir = os.path.join(pre_dir, 'contents')

    # annotation file corresponding to the image
    filename = os.path.basename(img_file)
    arr_split = os.path.splitext(filename)
    filename = arr_split[0] + '.txt'
    txt_file = os.path.join(txt_dir, filename)
    return txt_file


def get_list_contents(content_file):
    """
    Read the coordinates and labels from a text-line file and return them as a
    list, e.g. [[[1, 2, 3, 4], 'hello']].
    :param content_file: path of the text-line file. [str]
    :return:
    """
    contents = []
    if not os.path.exists(content_file):
        return contents

    with open(content_file, 'r', encoding='utf-8') as fp:
        lines = fp.readlines()

    for line in lines:
        arr_str = line.split('|')
        item = list(map(lambda x: int(x), arr_str[0].split(',')))
        contents.append([item, arr_str[1]])
    return contents


def get_image_and_targets(img_file, txt_file, anchor_heights):
    """
    Load an image and compute the feature-map size and the ground-truth targets.
    :param img_file: image path. [str]
    :param txt_file: path of the text-line file of the image. [str]
    :param anchor_heights: list of anchor heights. [list]
    :return:
    """
    # load the image; converting to RGB handles grayscale and RGBA inputs
    # uniformly (the original try/except channel slicing was buggy)
    img = Image.open(img_file)
    img = img.convert('RGB')
    img_data = np.array(img, dtype=np.float32) / 255  # [height, width, channel]

    # coordinates and labels of the text lines, as a list
    txt_list = get_list_contents(txt_file)

    # targets
    img_size = img_data.shape  # height, width, channel

    # height and width of the feature map after the convolutions
    height_feat = floor(ceil(ceil(img_size[0] / 2.0) / 2.0) / 3.0) - 2
    width_feat = ceil(ceil(ceil(img_size[1] / 2.0) / 2.0) / 2.0)

    # initialize the targets of the three branches
    num_anchors = len(anchor_heights)
    target_cls = np.zeros((height_feat, width_feat, 2 * num_anchors))
    target_ver = np.zeros((height_feat, width_feat, 2 * num_anchors))
    target_hor = np.zeros((height_feat, width_feat, 2 * num_anchors))

    # ground-truth targets at every position of the feature map
    ash = 12  # anchor stride - height
    asw = 8   # anchor stride - width
    hc_start = 18
    wc_start = 4

    for h in range(height_feat):
        hc = hc_start + ash * h  # anchor height center
        for w in range(width_feat):
            cls, ver, hor = calculate_targets_at([hc, wc_start + asw * w], txt_list, anchor_heights)
            target_cls[h, w] = cls
            target_ver[h, w] = ver
            target_hor[h, w] = hor

    return [img_data], [height_feat, width_feat], target_cls, target_ver, target_hor


def calculate_targets_at(anchor_center, txt_list, anchor_heights):
    """
    Ground-truth targets of the anchors at one position.
    :param anchor_center: anchor center, [height_center, width_center]. [list]
    :param txt_list: list of text lines. [list]
    :param anchor_heights: list of anchor heights. [list]
    :return:
    """
    # anchor width and the height/width strides
    anchor_width = 8
    ash = 12  # anchor stride - height
    asw = 8   # anchor stride - width

    # anchor center
    hc = anchor_center[0]
    wc = anchor_center[1]

    # initialize maxIoU and anchor_posi
    maxIoU = 0
    anchor_posi = 0
    text_bbox = []

    # check whether this position contains text; if it does,
    # take the anchor with the largest IoU as the positive one
    for item in txt_list:
        # the four coordinates of this text line
        bbox = item[0]

        flag = 0
        # mark the position when the anchor's horizontal center falls inside the
        # text line, or when the text line lies within [wc, wc + 8) and closer
        # to wc, or within (wc - 8, wc] and closer to wc
        if bbox[0] < wc and wc <= bbox[2]:
            flag = 1
        elif wc < bbox[0] and bbox[2] < wc + asw:
            if bbox[0] - wc < wc + asw - bbox[2]:
                flag = 1
        elif wc - asw < bbox[0] and bbox[2] < wc:
            if bbox[2] - wc <= wc - asw - bbox[0]:
                flag = 1

        if flag == 0:
            continue

        # vertical center of the text line
        bcenter = (bbox[1] + bbox[3]) / 2.0

        # the anchor center must not be too far from the true center
        d0 = abs(hc - bcenter)
        dm = abs(hc - ash - bcenter)
        dp = abs(hc + ash - bcenter)

        if d0 < ash and d0 <= dm and d0 < dp:
            pass
        else:
            continue

        # text detected: compute the IoU of every anchor and keep the largest
        posi = 0
        for ah in anchor_heights:
            hah = ah // 2  # half_ah

            IoU = 1.0 * (min(hc + hah, bbox[3]) - max(hc - hah, bbox[1])) \
                  / (max(hc + hah, bbox[3]) - min(hc - hah, bbox[1]))

            if IoU > maxIoU:
                maxIoU = IoU
                anchor_posi = posi
                text_bbox = bbox

            posi += 1
        break

    # no text detected: the targets of all three branches are zeros
    if maxIoU <= 0:
        num_anchors = len(anchor_heights)
        cls = [0, 0] * num_anchors
        ver = [0, 0] * num_anchors
        hor = [0, 0] * num_anchors
        return cls, ver, hor

    # text detected: the anchor with the largest IoU is positive, the rest negative
    cls = []
    ver = []
    hor = []
    for idx, ah in enumerate(anchor_heights):
        if not idx == anchor_posi:
            cls.extend([0, 0])
            ver.extend([0, 0])
            hor.extend([0, 0])
            continue
        cls.extend([1, 1])

        half_ah = ah // 2
        half_aw = anchor_width // 2

        # absolute coordinates of the anchor
        anchor_bbox = [wc - half_aw, hc - half_ah, wc + half_aw, hc + half_ah]

        # relative coordinates used to correct the anchor
        ratio_bbox = [0, 0, 0, 0]
        ratio = (text_bbox[0] - anchor_bbox[0]) / anchor_width
        if abs(ratio) < 1:
            ratio_bbox[0] = ratio

        ratio = (text_bbox[2] - anchor_bbox[2]) / anchor_width
        if abs(ratio) < 1:
            ratio_bbox[2] = ratio

        ratio_bbox[1] = (text_bbox[1] - anchor_bbox[1]) / ah
        ratio_bbox[3] = (text_bbox[3] - anchor_bbox[3]) / ah

        ver.extend([ratio_bbox[1], ratio_bbox[3]])
        hor.extend([ratio_bbox[0], ratio_bbox[2]])

    return cls, ver, hor


def trans_results(r_cls, r_ver, r_hor, anchor_heights, threshold):
    """
    Convert the relative predictions back to absolute image coordinates,
    returning the predicted text boxes and their scores.
    :param r_cls: cls predictions
    :param r_ver: ver predictions
    :param r_hor: hor predictions
    :param anchor_heights: list of anchor heights. [list]
    :param threshold: classification threshold. [float]
    :return:
    """
    anchor_width = 8
    ash = 12  # anchor stride - height
    asw = 8   # anchor stride - width
    hc_start = 18
    wc_start = 4
    aw = anchor_width
    list_bbox = []
    list_conf = []
    feat_shape = r_cls.shape

    for h in range(feat_shape[0]):
        for w in range(feat_shape[1]):
            if max(r_cls[h, w, :]) < threshold:
                continue

            # the anchor with the largest probability
            anchor_posi = np.argmax(r_cls[h, w, :])  # in r_cls
            anchor_id = anchor_posi // 2             # in anchor_heights

            # coordinates of the anchor
            ah = anchor_heights[anchor_id]
            anchor_posi = anchor_id * 2  # for retrieval in r_ver, r_hor

            hc = hc_start + ash * h  # anchor center
            wc = wc_start + asw * w  # anchor center

            half_ah = ah // 2
            half_aw = aw // 2

            anchor_bbox = [wc - half_aw, hc - half_ah, wc + half_aw, hc + half_ah]

            # coordinates of the predicted text box
            text_bbox = [0, 0, 0, 0]
            text_bbox[0] = anchor_bbox[0] + aw * r_hor[h, w, anchor_posi]
            text_bbox[1] = anchor_bbox[1] + ah * r_ver[h, w, anchor_posi]
            text_bbox[2] = anchor_bbox[2] + aw * r_hor[h, w, anchor_posi + 1]
            text_bbox[3] = anchor_bbox[3] + ah * r_ver[h, w, anchor_posi + 1]

            list_bbox.append(text_bbox)
            list_conf.append(max(r_cls[h, w, :]))

    return list_bbox, list_conf


def draw_text_boxes(img_file, text_bbox):
    """
    Draw text boxes on an image.
    :param img_file: image path. [str]
    :param text_bbox: list of text-box coordinates. [list]
    :return:
    """
    img_draw = Image.open(img_file)
    draw = ImageDraw.Draw(img_draw)
    for item in text_bbox:
        xs = item[0]
        ys = item[1]
        xe = item[2]
        ye = item[3]
        line_width = 1  # round(text_size / 10.0)
        draw.line([(xs, ys), (xs, ye), (xe, ye), (xe, ys), (xs, ys)],
                  width=line_width, fill=(255, 0, 0))

    img_draw.save(img_file)


def do_nms_and_connection(list_bbox, list_conf):
    """Connect the anchors into text boxes.
    :param list_bbox: anchor list; each anchor holds the four coordinates
                      left-top-right-bottom. [list]
    :param list_conf: text probability of each anchor, aligned with list_bbox. [list]
    :return: conn_bbox, the text boxes after connection, each with the four
             coordinates left-top-right-bottom. [list]
    """
    # # previous approach: connect two anchors when they are less than 50 px
    # # apart, otherwise treat them as two separate text boxes
    # max_margin = 50
    # len_list_box = len(list_bbox)
    # conn_bbox = []
    # head = tail = 0
    # for i in range(1, len_list_box):
    #     distance_i_j = abs(list_bbox[i][0] - list_bbox[i - 1][0])
    #     overlap_i_j = overlap(list_bbox[i][1], list_bbox[i][3],
    #                           list_bbox[i - 1][1], list_bbox[i - 1][3])
    #     if distance_i_j < max_margin and overlap_i_j > 0.7:
    #         tail = i
    #         if i == len_list_box - 1:
    #             this_test_box = [list_bbox[head][0], list_bbox[head][1],
    #                              list_bbox[tail][2], list_bbox[tail][3]]
    #             conn_bbox.append(this_test_box)
    #             head = tail = i
    #     else:
    #         this_test_box = [list_bbox[head][0], list_bbox[head][1],
    #                          list_bbox[tail][2], list_bbox[tail][3]]
    #         conn_bbox.append(this_test_box)
    #         head = tail = i

    # collect the neighbors of every anchor: two anchors are neighbors when they
    # are less than 50 px apart horizontally and their vertical overlap is above 0.4
    # (iterating over every index so an isolated last anchor still forms its own box)
    neighbor_list = []
    for i in range(len(list_bbox)):
        this_neighbor_list = [i]
        for j in range(i + 1, len(list_bbox)):
            distance_i_j = abs(list_bbox[i][2] - list_bbox[j][0])
            overlap_i_j = overlap(list_bbox[i][1], list_bbox[i][3],
                                  list_bbox[j][1], list_bbox[j][3])
            if distance_i_j < 50 and overlap_i_j > 0.4:
                this_neighbor_list.append(j)
        neighbor_list.append(this_neighbor_list)

    # merge the neighbor lists: whenever two lists share an element, join them
    conn_bbox = []
    while len(neighbor_list) > 0:
        this_conn_bbox = set(neighbor_list[0])
        filter_list = [0]
        for i in range(1, len(neighbor_list)):
            if len(this_conn_bbox & set(neighbor_list[i])) > 0:
                this_conn_bbox = this_conn_bbox | set(neighbor_list[i])
                filter_list.append(i)
        min_x = min([list_bbox[i][0] for i in list(this_conn_bbox)])
        min_y = np.mean([list_bbox[i][1] for i in list(this_conn_bbox)])
        max_x = max([list_bbox[i][2] for i in list(this_conn_bbox)])
        max_y = np.mean([list_bbox[i][3] for i in list(this_conn_bbox)])

        conn_bbox.append([min_x, min_y, max_x, max_y])
        neighbor_list = [neighbor_list[i] for i in range(len(neighbor_list)) if i not in filter_list]

    return conn_bbox


def overlap(h_up1, h_dw1, h_up2, h_dw2):
    """
    Vertical overlap: intersection over union of the two height intervals.
    :param h_up1: top of the first interval
    :param h_dw1: bottom of the first interval
    :param h_up2: top of the second interval
    :param h_dw2: bottom of the second interval
    :return:
    """
    overlap_value = (min(h_dw1, h_dw2) - max(h_up1, h_up2)) \
                    / (max(h_dw1, h_dw2) - min(h_up1, h_up2))
    return overlap_value


def mean_gray(img):
    """Convert an image to grayscale by averaging its channels.
    :param img: image array as read by cv2.imread()
    :return: the grayscale image array
    """
    row, col, channel = img.shape
    img_gray = np.zeros(shape=(row, col))
    for r in range(row):
        for l in range(col):
            img_gray[r, l] = img[r, l, :].mean()

    return img_gray


def two_value_binary(img_gray, threshold=100, reverse=False):
    """
    Binarization (as data augmentation).
    :param img_gray: grayscale image array.
    :param threshold: binarization threshold; values above it become 1, below it 0.
    :param reverse: whether to swap foreground and background, default False. [boolean]
    :return:
    """
    threshold /= 255
    img_binary = np.zeros_like(img_gray)
    row, col = img_binary.shape
    for i in range(row):
        for j in range(col):
            if img_gray[i, j] >= threshold:
                img_binary[i, j] = 1
            if reverse:
                img_binary[i, j] = 1 - img_binary[i, j]
    return img_binary


def convert2rgb(img_binary):
    """Expand a binarized image to three channels.
    :param img_binary: binarized 2-D image. [numpy.ndarray]
    :return:
    """
    rows, cols = img_binary.shape
    img_binary_rgb = np.zeros((rows, cols, 3))
    for i in range(rows):
        for j in range(cols):
            img_binary_rgb[i, j, 0:3] = np.tile(img_binary[i, j], 3)
    return img_binary_rgb

    After 250,000 training iterations on the ICDAR dataset, the results are as follows:

 

4. Strengths and Weaknesses of CTPN

    First, the strengths of CTPN can be summarized roughly as follows:

  1. It recasts text detection as detecting a sequence of fine-scale text proposals and proposes an anchor regression mechanism that jointly predicts the vertical position of each proposal and its text/non-text score.
  2. It connects the proposals extracted by the CNN with an RNN, so the context of a whole text line is taken into account, making detection more reliable.
  3. It handles text at multiple scales and in multiple languages, with a fairly simple pipeline.
  4. It detects images quickly.

    CTPN also has weaknesses: its results on rotated text lines are mediocre, and text-line construction is restricted to rectangles, so the constructed lines look clumsy when the text is tilted. All the same, the model has had an enormous influence on text detection.
