【OCR Technology Series, Part 6】Implementing the CTPN Text Detector

For the past few days I have been reproducing CTPN, a text detection paper, in PyTorch. This article reproduces CTPN step by step, covering data processing, training label generation, network construction, loss function design, and the main training loop. For the theory behind CTPN, see here.

Training Data Processing

For training we use two datasets: Tianchi ICPR2018 and MSRA_TD500. The Tianchi ICPR images are web images (product pictures that merchants upload to Taobao), and their labels follow the ICDAR2015 format: each text box is given by its four corners (top-left, top-right, bottom-right, bottom-left), eight values in total, written as [x1 y1 x2 y2 x3 y3 x4 y4].

1093303-20181202175424984-936238952.png
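For concreteness, one line of such a ground-truth file simply lists the eight comma-separated corner coordinates of one box; the values below are hypothetical, not taken from the dataset:

```
100,50,420,53,421,98,99,95
```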

The Tianchi ICPR2018 data looks like the image below. Font shapes, styles, and colors vary a lot, and the text is often embedded in objects, so detection is difficult:

1093303-20181202181734519-193227309.png

MSRA_TD500 is a text detection and recognition dataset collected by Microsoft. Most of its images are street scenes; the backgrounds are fairly complex, but the text locations are prominent and easy to spot. MSRA_TD500 uses a different label format, whose last field is the rotation angle of the box.

1093303-20181202175515910-1923330407.png

So the first step is to unify the label formats of the two datasets. My approach is to convert the MSRA labels into the ICDAR format, which simplifies the later training code. Since an MSRA_TD500 label is stored as [index difficulty_label x y w h angle], we need to use the rotation angle to recover the four corner coordinates of the rotated box. The implementation is as follows:


    
    
  1. """
  2. This file is to change MSRA_TD500 dataset format to ICDAR2015 dataset format.
  3. MSRA_TD500 format: [index difficulty_label x y w h angle]
  4. ICDAR2015 format: [left_top_x left_top_y right_top_X right_top_y right_bottom_x right_bottom_y left_bottom_x left_bottom_y]
  5. """
  6. import math
  7. import cv2
  8. import os
  9. # 求旋转后矩形的4个坐标
  10. def get_box_img(x, y, w, h, angle):
  11. # 矩形框中点(x0,y0)
  12. x0 = x + w/ 2
  13. y0 = y + h/ 2
  14. l = math.sqrt(pow(w/ 2, 2) + pow(h/ 2, 2)) # 即对角线的一半
  15. # angle小于0,逆时针转
  16. if angle < 0:
  17. a1 = -angle + math.atan(h / float(w)) # 旋转角度-对角线与底线所成的角度
  18. a2 = -angle - math.atan(h / float(w)) # 旋转角度+对角线与底线所成的角度
  19. pt1 = (x0 - l * math.cos(a2), y0 + l * math.sin(a2))
  20. pt2 = (x0 + l * math.cos(a1), y0 - l * math.sin(a1))
  21. pt3 = (x0 + l * math.cos(a2), y0 - l * math.sin(a2)) # x0+左下点旋转后在水平线上的投影, y0-左下点在垂直线上的投影,显然逆时针转时,左下点上一和左移了。
  22. pt4 = (x0 - l * math.cos(a1), y0 + l * math.sin(a1))
  23. else:
  24. a1 = angle + math.atan(h / float(w))
  25. a2 = angle - math.atan(h / float(w))
  26. pt1 = (x0 - l * math.cos(a1), y0 - l * math.sin(a1))
  27. pt2 = (x0 + l * math.cos(a2), y0 + l * math.sin(a2))
  28. pt3 = (x0 + l * math.cos(a1), y0 + l * math.sin(a1))
  29. pt4 = (x0 - l * math.cos(a2), y0 - l * math.sin(a2))
  30. return [pt1[ 0], pt1[ 1], pt2[ 0], pt2[ 1], pt3[ 0], pt3[ 1], pt4[ 0], pt4[ 1]]
  31. def read_file(path):
  32. result = []
  33. for line in open(path):
  34. info = []
  35. data = line.split( ' ')
  36. info.append(int(data[ 2]))
  37. info.append(int(data[ 3]))
  38. info.append(int(data[ 4]))
  39. info.append(int(data[ 5]))
  40. info.append(float(data[ 6]))
  41. info.append(data[ 0])
  42. result.append(info)
  43. return result
  44. if __name__ == '__main__':
  45. file_path = '/home/ljs/OCR_dataset/MSRA-TD500/test/'
  46. save_img_path = '../dataset/OCR_dataset/ctpn/test_im/'
  47. save_gt_path = '../dataset/OCR_dataset/ctpn/test_gt/'
  48. file_list = os.listdir(file_path)
  49. for f in file_list:
  50. if '.gt' in f:
  51. continue
  52. name = f[ 0: 8]
  53. txt_path = file_path + name + '.gt'
  54. im_path = file_path + f
  55. im = cv2.imread(im_path)
  56. coordinate = read_file(txt_path)
  57. # 仿照ICDAR格式,图片名字写做img_xx.jpg,对应的标签文件写做gt_img_xx.txt
  58. cv2.imwrite(save_img_path + name.lower() + '.jpg', im)
  59. save_gt = open(save_gt_path + 'gt_' + name.lower() + '.txt', 'w')
  60. for i in coordinate:
  61. box = get_box_img(i[ 0], i[ 1], i[ 2], i[ 3], i[ 4])
  62. box = [int(box[i]) for i in range(len(box))]
  63. box = [str(box[i]) for i in range(len(box))]
  64. save_gt.write( ','.join(box))
  65. save_gt.write( '\n')

After this format conversion, both datasets are in good shape. We still need to split the data into a training set and a test set. My directory convention is: train_im and test_im hold the training and test images, and train_gt and test_gt hold the corresponding labels.

1093303-20181202175530276-1351005026.png
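As a side note, a minimal split script matching this directory layout could look like the sketch below. The source paths, the .jpg extension, and the 90/10 ratio are illustrative assumptions, not necessarily what my repository uses.

```python
import os
import random
import shutil


def split_dataset(im_dir, gt_dir, out_root, train_ratio=0.9):
    """Randomly split images and their gt files into train_im/test_im and train_gt/test_gt."""
    names = [os.path.splitext(f)[0] for f in os.listdir(im_dir) if f.endswith('.jpg')]
    random.shuffle(names)
    n_train = int(len(names) * train_ratio)
    for i, name in enumerate(names):
        sub = 'train' if i < n_train else 'test'
        os.makedirs(os.path.join(out_root, sub + '_im'), exist_ok=True)
        os.makedirs(os.path.join(out_root, sub + '_gt'), exist_ok=True)
        shutil.copy(os.path.join(im_dir, name + '.jpg'),
                    os.path.join(out_root, sub + '_im', name + '.jpg'))
        shutil.copy(os.path.join(gt_dir, 'gt_' + name + '.txt'),
                    os.path.join(out_root, sub + '_gt', 'gt_' + name + '.txt'))
```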

Generating Training Labels

Because CTPN is built on the same region proposal idea as Faster R-CNN, the original box labels have to be converted into anchor labels. This label-generation code was the hardest part to write: turning one complete text box label into a series of fine-scale text box labels is genuinely tricky, and the anchors are generated in a slightly different way than in Faster R-CNN. Here is my approach:

The first step is to convert each image's bbox labels into per-anchor labels. To do this, we divide the image into vertical anchor strips of width 16.

  • First compute how many width-16 anchors the image can hold (for an image of width w there are w/16 horizontal anchor positions), then work out how many anchors the text box covers and which are the leftmost and rightmost ones;
  • Compute the height and vertical center of each anchor inside the text box: draw the text box (in white) onto an all-black mask, then for each anchor scan from top to bottom and from bottom to top; the first white pixel found in each direction gives the anchor's top and bottom boundaries;
  • Finally, store and return each anchor's position (horizontal index), center y coordinate, and height.

    
    
```python
def generate_gt_anchor(img, box, anchor_width=16):
    """
    calculate ground truth fine-scale boxes
    :param img: input image
    :param box: ground truth box (4 points)
    :param anchor_width:
    :return: list of tuples (position, cy, h)
    """
    if not isinstance(box[0], float):
        box = [float(box[i]) for i in range(len(box))]
    result = []
    # figure out how many width-16 anchors the bbox spans, and the ids of the leftmost and rightmost anchors
    left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))  # the left side anchor of the text box, rounded down
    right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))  # the right side anchor of the text box, rounded up
    # handle the extreme case where the right side anchor exceeds the image width
    if right_anchor_num * 16 + 15 > img.shape[1]:
        right_anchor_num -= 1
    # combine the left-side and right-side x coordinates of each text anchor into one pair
    position_pair = [(i * anchor_width, (i + 1) * anchor_width - 1) for i in range(left_anchor_num, right_anchor_num)]
    # compute the true vertical extent of each gt anchor, i.e. its top and bottom boundaries
    y_top, y_bottom = cal_y_top_and_bottom(img, position_pair, box)
    # finally store each anchor's position (horizontal index), center y coordinate and height
    for i in range(len(position_pair)):
        position = int(position_pair[i][0] / anchor_width)  # index of the anchor box
        h = y_bottom[i] - y_top[i] + 1                      # height of the anchor box
        cy = (float(y_bottom[i]) + float(y_top[i])) / 2.0   # center y of the anchor box
        result.append((position, cy, h))
    return result
```

The method for computing each anchor's top and bottom boundaries:


    
    
```python
# compute the top and bottom y coordinates of each gt anchor box
def cal_y_top_and_bottom(raw_img, position_pair, box):
    """
    :param raw_img:
    :param position_pair: for example: [(0, 15), (16, 31), ...]
    :param box: gt box (4 points)
    :return: top and bottom coordinates on the y-axis
    """
    img = copy.deepcopy(raw_img)
    y_top = []
    y_bottom = []
    height = img.shape[0]
    # build the image mask: set channel 0 to all black
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            img[i, j, 0] = 0
    top_flag = False
    bottom_flag = False
    # draw the text box from its 4 points; in channel 0 the box becomes white
    img = other.draw_box_4pt(img, box, color=(255, 0, 0))
    for k in range(len(position_pair)):
        # for each anchor (left to right), scan pixels from top to bottom; the first white pixel (255)
        # gives the y coordinate of this gt anchor's top boundary
        # calc top y coordinate
        for y in range(0, height - 1):
            # loop over this anchor's columns, from left to right
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_top.append(y)
                    top_flag = True
                    break
            if top_flag is True:
                break
        # for each anchor (left to right), scan pixels from bottom to top; the first white pixel (255)
        # gives the y coordinate of this gt anchor's bottom boundary
        # calc bottom y coordinate, scanning pixels from bottom to top
        for y in range(height - 1, -1, -1):
            # loop over this anchor's columns, from left to right
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_bottom.append(y)
                    bottom_flag = True
                    break
            if bottom_flag is True:
                break
        top_flag = False
        bottom_flag = False
    return y_top, y_bottom
```

With this label processing done, the original full text box labels have been converted into fine-scale anchor labels. Here is what the converted labels look like:

1093303-20181202181756341-612737932.png

Visualized like this, the anchor labels look reasonable, but note that this generation scheme is not very precise. If the edge of a text box only just reaches into a new anchor column, that column still gets a full 16-pixel anchor; for example, a box whose right edge sits one pixel inside an anchor still claims the whole anchor, so the label can overshoot by up to 15 pixels. This is worth thinking about, but for now we leave it as is and move on.

During the conversion we also ran into some odd cases, such as labels that extend beyond the image boundary (as in the image below). These need special handling, for example clamping the label's x coordinates to the image width:


    
    
```python
left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))   # the left side anchor of the text box, rounded down
right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))  # the right side anchor of the text box, rounded up
```

1093303-20181202175613774-256051821.png

CTPN Network Architecture

CTPN combines a CNN with a bidirectional LSTM, so we build the architecture in parts.

1093303-20181202175625213-1331358551.png

For the CNN part, CTPN uses VGG16 to extract low-level features.


    
    
```python
class VGG_16(nn.Module):
    """
    VGG-16 without the pooling layer before the fc layers
    """
    def __init__(self):
        super(VGG_16, self).__init__()
        self.convolution1_1 = nn.Conv2d(3, 64, 3, padding=1)
        self.convolution1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.pooling1 = nn.MaxPool2d(2, stride=2)
        self.convolution2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.convolution2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.pooling2 = nn.MaxPool2d(2, stride=2)
        self.convolution3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.convolution3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.convolution3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.pooling3 = nn.MaxPool2d(2, stride=2)
        self.convolution4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.convolution4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.convolution4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.pooling4 = nn.MaxPool2d(2, stride=2)
        self.convolution5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.convolution5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.convolution5_3 = nn.Conv2d(512, 512, 3, padding=1)

    def forward(self, x):
        x = F.relu(self.convolution1_1(x), inplace=True)
        x = F.relu(self.convolution1_2(x), inplace=True)
        x = self.pooling1(x)
        x = F.relu(self.convolution2_1(x), inplace=True)
        x = F.relu(self.convolution2_2(x), inplace=True)
        x = self.pooling2(x)
        x = F.relu(self.convolution3_1(x), inplace=True)
        x = F.relu(self.convolution3_2(x), inplace=True)
        x = F.relu(self.convolution3_3(x), inplace=True)
        x = self.pooling3(x)
        x = F.relu(self.convolution4_1(x), inplace=True)
        x = F.relu(self.convolution4_2(x), inplace=True)
        x = F.relu(self.convolution4_3(x), inplace=True)
        x = self.pooling4(x)
        x = F.relu(self.convolution5_1(x), inplace=True)
        x = F.relu(self.convolution5_2(x), inplace=True)
        x = F.relu(self.convolution5_3(x), inplace=True)
        return x
```

Next we implement the bidirectional LSTM, which strengthens the learning of sequential context.


    
    
```python
class BLSTM(nn.Module):
    def __init__(self, channel, hidden_unit, bidirectional=True):
        """
        :param channel: lstm input channel num
        :param hidden_unit: lstm hidden unit
        :param bidirectional:
        """
        super(BLSTM, self).__init__()
        self.lstm = nn.LSTM(channel, hidden_unit, bidirectional=bidirectional)

    def forward(self, x):
        """
        WARNING: The batch size of x must be 1.
        """
        x = x.transpose(1, 3)
        recurrent, _ = self.lstm(x[0])
        recurrent = recurrent[np.newaxis, :, :, :]
        recurrent = recurrent.transpose(1, 3)
        return recurrent
```

Then comes an intermediate layer that connects the CNN and the LSTM: it unfolds the feature map produced by the last VGG convolution layer into vectors that can be fed to the LSTM.


    
    
```python
class Im2col(nn.Module):
    def __init__(self, kernel_size, stride, padding):
        super(Im2col, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        height = x.shape[2]
        x = F.unfold(x, self.kernel_size, padding=self.padding, stride=self.stride)
        x = x.reshape((x.shape[0], x.shape[1], height, -1))
        return x
```

Finally, the three parts are assembled into the complete CTPN network: VGG16 for feature extraction -> BLSTM for sequence learning -> per-anchor outputs of score, h, y, and side_refinement.


    
    
```python
class CTPN(nn.Module):
    def __init__(self):
        super(CTPN, self).__init__()
        self.cnn = nn.Sequential()
        self.cnn.add_module('VGG_16', VGG_16())
        self.rnn = nn.Sequential()
        self.rnn.add_module('im2col', Net.Im2col((3, 3), (1, 1), (1, 1)))
        self.rnn.add_module('blstm', BLSTM(3 * 3 * 512, 128))
        self.FC = nn.Conv2d(256, 512, 1)
        self.vertical_coordinate = nn.Conv2d(512, 2 * 10, 1)  # outputs 2k values (k=10): 10 anchor sizes, 2 values (h and dy) per anchor
        self.score = nn.Conv2d(512, 2 * 10, 1)                # outputs 2k scores (k=10): text / non-text for each of the 10 anchor sizes
        self.side_refinement = nn.Conv2d(512, 10, 1)          # outputs k values (k=10): the horizontal offset of each anchor, used to refine the side edges of the text box

    def forward(self, x, val=False):
        x = self.cnn(x)
        x = self.rnn(x)
        x = self.FC(x)
        x = F.relu(x, inplace=True)
        vertical_pred = self.vertical_coordinate(x)
        score = self.score(x)
        if val:
            score = score.reshape((score.shape[0], 10, 2, score.shape[2], score.shape[3]))
            score = score.squeeze(0)
            score = score.transpose(1, 2)
            score = score.transpose(2, 3)
            score = score.reshape((-1, 2))
            # score = F.softmax(score, dim=1)
            score = score.reshape((10, vertical_pred.shape[2], -1, 2))
            vertical_pred = vertical_pred.reshape((vertical_pred.shape[0], 10, 2, vertical_pred.shape[2], vertical_pred.shape[3]))
        side_refinement = self.side_refinement(x)
        return vertical_pred, score, side_refinement
```

Loss Function Design

The CTPN loss consists of three parts:

  • the regression loss on h and y, using SmoothL1Loss;
  • the classification loss on the score, using CrossEntropyLoss;
  • the side-refinement loss, also using SmoothL1Loss.

1093303-20181202175646079-972094450.png
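Written out, the multi-task loss from the CTPN paper, which matches the three terms above, is:

$$
L(\mathbf{s}_i, \mathbf{v}_j, \mathbf{o}_k) = \frac{1}{N_s}\sum_i L_s^{cls}(\mathbf{s}_i, \mathbf{s}_i^*) + \frac{\lambda_1}{N_v}\sum_j L_v^{re}(\mathbf{v}_j, \mathbf{v}_j^*) + \frac{\lambda_2}{N_o}\sum_k L_o^{re}(\mathbf{o}_k, \mathbf{o}_k^*)
$$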

First we define a few fixed parameters:


    
    
```python
class CTPN_Loss(nn.Module):
    def __init__(self, using_cuda=False):
        super(CTPN_Loss, self).__init__()
        self.Ns = 128
        self.ratio = 0.5
        self.lambda1 = 1.0
        self.lambda2 = 1.0
        self.Ls_cls = nn.CrossEntropyLoss()
        self.Lv_reg = nn.SmoothL1Loss()
        self.Lo_reg = nn.SmoothL1Loss()
        self.using_cuda = using_cuda
```

First, the classification loss:


    
    
```python
cls_loss = 0.0
if self.using_cuda:
    for p in positive_batch:
        cls_loss += self.Ls_cls(score[0, p[2] * 2:((p[2] + 1) * 2), p[1], p[0]].unsqueeze(0),
                                torch.LongTensor([1]).cuda())
    for n in negative_batch:
        cls_loss += self.Ls_cls(score[0, n[2] * 2:((n[2] + 1) * 2), n[1], n[0]].unsqueeze(0),
                                torch.LongTensor([0]).cuda())
else:
    for p in positive_batch:
        cls_loss += self.Ls_cls(score[0, p[2] * 2:((p[2] + 1) * 2), p[1], p[0]].unsqueeze(0),
                                torch.LongTensor([1]))
    for n in negative_batch:
        cls_loss += self.Ls_cls(score[0, n[2] * 2:((n[2] + 1) * 2), n[1], n[0]].unsqueeze(0),
                                torch.LongTensor([0]))
cls_loss = cls_loss / self.Ns
```

Next comes the vertical coordinate regression loss, which measures the deviation of y and h.
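As in the CTPN paper, the vertical targets are parameterized relative to the anchor center $c_y^a$ and anchor height $h^a$:

$$
v_c = (c_y - c_y^a) / h^a, \qquad v_h = \log(h / h^a)
$$

with the starred ground-truth versions $v_c^*, v_h^*$ defined the same way from the gt box.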


    
    
```python
# calculate vertical coordinate regression loss
v_reg_loss = 0.0
Nv = len(vertical_reg)
if self.using_cuda:
    for v in vertical_reg:
        v_reg_loss += self.Lv_reg(vertical_pred[0, v[2] * 2:((v[2] + 1) * 2), v[1], v[0]].unsqueeze(0),
                                  torch.FloatTensor([v[3], v[4]]).unsqueeze(0).cuda())
else:
    for v in vertical_reg:
        v_reg_loss += self.Lv_reg(vertical_pred[0, v[2] * 2:((v[2] + 1) * 2), v[1], v[0]].unsqueeze(0),
                                  torch.FloatTensor([v[3], v[4]]).unsqueeze(0))
v_reg_loss = v_reg_loss / float(Nv)
```

Finally we compute the side-refinement regression loss, which is used to refine the horizontal edges of the text line.
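The side-refinement target in the paper is the horizontal offset of the ground-truth side coordinate from the anchor center, normalized by the fixed anchor width:

$$
o = (x_{side} - c_x^a) / w^a, \qquad w^a = 16
$$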


    
    
```python
# calculate side refinement regression loss
o_reg_loss = 0.0
No = len(side_refinement_reg)
if self.using_cuda:
    for s in side_refinement_reg:
        o_reg_loss += self.Lo_reg(side_refinement[0, s[2]:s[2] + 1, s[1], s[0]].unsqueeze(0),
                                  torch.FloatTensor([s[3]]).unsqueeze(0).cuda())
else:
    for s in side_refinement_reg:
        o_reg_loss += self.Lo_reg(side_refinement[0, s[2]:s[2] + 1, s[1], s[0]].unsqueeze(0),
                                  torch.FloatTensor([s[3]]).unsqueeze(0))
o_reg_loss = o_reg_loss / float(No)
```

And of course there is a total loss that combines the three terms during training:

```python
loss = cls_loss + v_reg_loss * self.lambda1 + o_reg_loss * self.lambda2
```
    
    

Training Loop Design

Training: we use SGD as the optimizer with two learning rates, a larger lr for the first N epochs and a smaller one afterwards for better convergence. During training we track four losses: total_cls_loss, total_v_reg_loss, total_o_reg_loss, and total_loss (the sum of the first three).


    
    
```python
net = Net.CTPN()  # build the network
for name, value in net.named_parameters():
    if name in no_grad:
        value.requires_grad = False
    else:
        value.requires_grad = True
# for name, value in net.named_parameters():
#     print('name: {0}, grad: {1}'.format(name, value.requires_grad))
net.load_state_dict(torch.load('./lib/vgg16.model'))
# net.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
lib.utils.init_weight(net)
if using_cuda:
    net.cuda()
net.train()
print(net)

criterion = Loss.CTPN_Loss(using_cuda=using_cuda)  # build the loss
train_im_list, train_gt_list, val_im_list, val_gt_list = create_train_val()  # build the training and validation sets
total_iter = len(train_im_list)
print("total training image num is %s" % len(train_im_list))
print("total val image num is %s" % len(val_im_list))
train_loss_list = []
test_loss_list = []

# start the training loop
for i in range(epoch):
    if i >= change_epoch:
        lr = lr_behind
    else:
        lr = lr_front
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    # optimizer = optim.Adam(net.parameters(), lr=lr)
    iteration = 1
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()
    random.shuffle(train_im_list)  # shuffle the training set
    # print(random_im_list)
    for im in train_im_list:
        root, file_name = os.path.split(im)
        root, _ = os.path.split(root)
        name, _ = os.path.splitext(file_name)
        gt_name = 'gt_' + name + '.txt'
        gt_path = os.path.join(root, "train_gt", gt_name)
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue
        gt_txt = lib.dataset_handler.read_gt_file(gt_path)  # read the corresponding label
        # print("processing image %s" % os.path.join(img_root1, im))
        img = cv2.imread(im)
        if img is None:
            iteration += 1
            continue
        img, gt_txt = lib.dataset_handler.scale_img(img, gt_txt)  # rescale the image and its labels
        tensor_img = img[np.newaxis, :, :, :]
        tensor_img = tensor_img.transpose((0, 3, 1, 2))
        if using_cuda:
            tensor_img = torch.FloatTensor(tensor_img).cuda()
        else:
            tensor_img = torch.FloatTensor(tensor_img)
        vertical_pred, score, side_refinement = net(tensor_img)  # forward pass to get the predictions
        del tensor_img
        # transform bbox gt to anchor gt for training
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []
        visual_img = copy.deepcopy(img)  # this copy is used to visualize the labels
        try:
            # loop over all bboxes in one image
            for box in gt_txt:
                # generate anchors from one bbox
                gt_anchor, visual_img = lib.generate_gt_anchor.generate_gt_anchor(img, box, draw_img_gt=visual_img)  # anchor labels for this image
                positive1, negative1, vertical_reg1, side_refinement_reg1 = lib.tag_anchor.tag_anchor(gt_anchor, score, box)  # match predictions to the anchors
                positive += positive1
                negative += negative1
                vertical_reg += vertical_reg1
                side_refinement_reg += side_refinement_reg1
        except:
            print("warning: img %s raise error!" % im)
            iteration += 1
            continue
        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            iteration += 1
            continue
        cv2.imwrite(os.path.join(DRAW_PREFIX, file_name), visual_img)
        optimizer.zero_grad()
        # compute the loss
        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        # backpropagation
        loss.backward()
        optimizer.step()
        iteration += 1
        # save gpu memory by converting the losses to float
        total_loss += float(loss)
        total_cls_loss += float(cls_loss)
        total_v_reg_loss += float(v_reg_loss)
        total_o_reg_loss += float(o_reg_loss)
        if iteration % display_iter == 0:
            end_time = time.time()
            total_time = end_time - start_time
            print('Epoch: {2}/{3}, Iteration: {0}/{1}, loss: {4}, cls_loss: {5}, v_reg_loss: {6}, o_reg_loss: {7}, {8}'.
                  format(iteration, total_iter, i, epoch, total_loss / display_iter, total_cls_loss / display_iter,
                         total_v_reg_loss / display_iter, total_o_reg_loss / display_iter, im))
            logger.info('Epoch: {2}/{3}, Iteration: {0}/{1}'.format(iteration, total_iter, i, epoch))
            logger.info('loss: {0}'.format(total_loss / display_iter))
            logger.info('classification loss: {0}'.format(total_cls_loss / display_iter))
            logger.info('vertical regression loss: {0}'.format(total_v_reg_loss / display_iter))
            logger.info('side-refinement regression loss: {0}'.format(total_o_reg_loss / display_iter))
            train_loss_list.append(total_loss)
            total_loss = 0
            total_cls_loss = 0
            total_v_reg_loss = 0
            total_o_reg_loss = 0
            start_time = time.time()
        # periodically evaluate on the validation set
        if iteration % val_iter == 0:
            net.eval()
            logger.info('Start evaluate at {0} epoch {1} iteration.'.format(i, iteration))
            val_loss = evaluate.val(net, criterion, val_batch_size, using_cuda, logger, val_im_list)
            logger.info('End evaluate.')
            net.train()
            start_time = time.time()
            test_loss_list.append(val_loss)
        # periodically save the model
        if iteration % save_iter == 0:
            print('Model saved at ./model/ctpn-{0}-{1}.model'.format(i, iteration))
            torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-{1}.model'.format(i, iteration))
    print('Model saved at ./model/ctpn-{0}-end.model'.format(i))
    torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-end.model'.format(i))
# plot the loss curves
draw_loss_plot(train_loss_list, test_loss_list)
```

Image scaling follows a simple rule: after scaling, the shortest side of the image should equal 600. We compute the scale factor as scale = float(shortest_side)/float(min(height, width)), resize the original image by it, and then scale the labels by the same factors.


    
    
```python
def scale_img(img, gt, shortest_side=600):
    height = img.shape[0]
    width = img.shape[1]
    scale = float(shortest_side) / float(min(height, width))
    img = cv2.resize(img, (0, 0), fx=scale, fy=scale)
    # cv2.resize takes dsize as (width, height); force the shortest side to be exactly 600
    if img.shape[0] < img.shape[1] and img.shape[0] != 600:
        img = cv2.resize(img, (img.shape[1], 600))
    elif img.shape[0] > img.shape[1] and img.shape[1] != 600:
        img = cv2.resize(img, (600, img.shape[0]))
    elif img.shape[0] != 600:
        img = cv2.resize(img, (600, 600))
    h_scale = float(img.shape[0]) / float(height)
    w_scale = float(img.shape[1]) / float(width)
    scale_gt = []
    for box in gt:
        scale_box = []
        for i in range(len(box)):
            # x coordinate
            if i % 2 == 0:
                scale_box.append(int(int(box[i]) * w_scale))
            # y coordinate
            else:
                scale_box.append(int(int(box[i]) * h_scale))
        scale_gt.append(scale_box)
    return img, scale_gt
```

Evaluation on the validation set:


    
    
```python
def val(net, criterion, batch_num, using_cuda, logger):
    img_root = '../dataset/OCR_dataset/ctpn/test_im'
    gt_root = '../dataset/OCR_dataset/ctpn/test_gt'
    img_list = os.listdir(img_root)
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()
    for im in random.sample(img_list, batch_num):
        name, _ = os.path.splitext(im)
        gt_name = 'gt_' + name + '.txt'
        gt_path = os.path.join(gt_root, gt_name)
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue
        gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True)
        img = cv2.imread(os.path.join(img_root, im))
        img, gt_txt = Dataset.scale_img(img, gt_txt)
        tensor_img = img[np.newaxis, :, :, :]
        tensor_img = tensor_img.transpose((0, 3, 1, 2))
        if using_cuda:
            tensor_img = torch.FloatTensor(tensor_img).cuda()
        else:
            tensor_img = torch.FloatTensor(tensor_img)
        vertical_pred, score, side_refinement = net(tensor_img)
        del tensor_img
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []
        for box in gt_txt:
            gt_anchor = Dataset.generate_gt_anchor(img, box)
            positive1, negative1, vertical_reg1, side_refinement_reg1 = Net.tag_anchor(gt_anchor, score, box)
            positive += positive1
            negative += negative1
            vertical_reg += vertical_reg1
            side_refinement_reg += side_refinement_reg1
        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            batch_num -= 1
            continue
        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        total_loss += loss
        total_cls_loss += cls_loss
        total_v_reg_loss += v_reg_loss
        total_o_reg_loss += o_reg_loss
    end_time = time.time()
    total_time = end_time - start_time
    print('#################### Start evaluate ####################')
    print('loss: {0}'.format(total_loss / float(batch_num)))
    logger.info('Evaluate loss: {0}'.format(total_loss / float(batch_num)))
    print('classification loss: {0}'.format(total_cls_loss / float(batch_num)))
    logger.info('Evaluate classification loss: {0}'.format(total_cls_loss / float(batch_num)))
    print('vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
    logger.info('Evaluate vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
    print('side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
    logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
    print('{1} iterations for {0} seconds.'.format(total_time, batch_num))
    print('##################### Evaluate end #####################')
    print('\n')
```

The training process:

1093303-20181202175714392-2085371767.png

Training Results and Inference

Testing: feed in an image and produce the final detection result.


    
    
```python
def infer_one(im_name, net):
    im = cv2.imread(im_name)
    im = lib.dataset_handler.scale_img_only(im)  # rescale the image
    img = copy.deepcopy(im)
    img = img.transpose(2, 0, 1)
    img = img[np.newaxis, :, :, :]
    img = torch.Tensor(img)
    v, score, side = net(img, val=True)  # forward pass through the network
    result = []
    # collect the anchors whose text score exceeds the threshold
    for i in range(score.shape[0]):
        for j in range(score.shape[1]):
            for k in range(score.shape[2]):
                if score[i, j, k, 1] > THRESH_HOLD:
                    result.append((j, k, i, float(score[i, j, k, 1].detach().numpy())))
    # filter with NMS
    for_nms = []
    for box in result:
        pt = lib.utils.trans_to_2pt(box[1], box[0] * 16 + 7.5, anchor_height[box[2]])
        for_nms.append([pt[0], pt[1], pt[2], pt[3], box[3], box[0], box[1], box[2]])
    for_nms = np.array(for_nms, dtype=np.float32)
    nms_result = lib.nms.cpu_nms(for_nms, NMS_THRESH)
    out_nms = []
    for i in nms_result:
        out_nms.append(for_nms[i, 0:8])
    # decide which anchors belong to the same text line
    connect = get_successions(v, out_nms)
    # merge each group of anchors into one text line
    texts = get_text_lines(connect, im.shape)
    for box in texts:
        box = np.array(box)
        print(box)
        lib.draw_image.draw_ploy_4pt(im, box[0:8])
    _, basename = os.path.split(im_name)
    cv2.imwrite('./infer_' + basename, im)
```

Inference uses get_successions to collect all the anchors that belong to one predicted text line. In other words, we end up with many anchors predicted to contain text, but how do we know which of them form a single text line? We need an anchor-merging algorithm, and this is the trickiest part of the whole CTPN implementation.

The CTPN paper describes text line construction as follows: simply connect consecutive text proposals whose text/no-text score is > 0.7. A text line is built like this:

  • First, define a neighbour Bj for a proposal Bi (Bj -> Bi), where:
  1. Bj is the closest proposal to Bi in the horizontal direction;
  2. that horizontal distance is less than 50 pixels;
  3. their vertical overlap is > 0.7.

The theory looks simple, but implementing it yourself is a different story; as the saying goes, knowledge from books is shallow until you put it into practice. get_successions takes v, which holds the predicted h and y of every anchor, and anchors, which holds the box coordinates of every anchor.


    
    
```python
def get_successions(v, anchors=[]):
    texts = []
    for i, anchor in enumerate(anchors):
        neighbours = []  # records the anchors in this group
        neighbours.append(i)
        center_x1 = (anchor[2] + anchor[0]) / 2
        h1 = get_anchor_h(anchor, v)  # height of this anchor
        # find i's neighbours
        # scan the remaining anchors and look for neighbours
        for j in range(i + 1, len(anchors)):
            center_x2 = (anchors[j][2] + anchors[j][0]) / 2  # center x coordinate
            h2 = get_anchor_h(anchors[j], v)
            # two anchors are neighbours if their horizontal distance is below 50 pixels
            # and their vertical overlap is above the threshold
            if abs(center_x1 - center_x2) < NEIGHBOURS_MIN_DIST and \
                    meet_v_iou(max(anchor[1], anchors[j][1]), min(anchor[3], anchors[j][3]), h1, h2):
                neighbours.append(j)
        if len(neighbours) != 0:
            texts.append(neighbours)
    # at this point every anchor's neighbours have been collected into a group;
    # now keep merging groups that share an anchor until nothing changes
    need_merge = True
    while need_merge:
        need_merge = False
        # ok, we combine again.
        for i, line in enumerate(texts):
            if len(line) == 0:
                continue
            for index in line:
                for j in range(i + 1, len(texts)):
                    if index in texts[j]:
                        texts[i] += texts[j]
                        texts[i] = list(set(texts[i]))
                        texts[j] = []
                        need_merge = True
    result = []
    # print(texts)
    for text in texts:
        if len(text) < 2:
            continue
        local = []
        for j in text:
            local.append(anchors[j])
        result.append(local)
    return result
```
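The helpers get_anchor_h and meet_v_iou are not shown above. As a reference, a minimal sketch of the vertical-overlap test consistent with how it is called here (the exact thresholds in my repository may differ) could be:

```python
def meet_v_iou(y_top_max, y_bottom_min, h1, h2, overlap_thresh=0.7, size_sim_thresh=0.7):
    """Check vertical overlap and height similarity of two anchors.

    y_top_max is the larger of the two top coordinates, y_bottom_min the smaller
    of the two bottom coordinates, h1 and h2 the two anchor heights.
    """
    overlap = max(0, y_bottom_min - y_top_max + 1)       # length of the overlapping vertical extent
    v_overlap = overlap / float(min(h1, h2))             # overlap relative to the shorter anchor
    size_similarity = min(h1, h2) / float(max(h1, h2))   # how close the two heights are
    return v_overlap >= overlap_thresh and size_similarity >= size_sim_thresh
```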

Once we have the group of anchors belonging to one text box, the next step is to string the anchors in the group into a single text line. That is exactly what get_text_lines does.


    
    
```python
def get_text_lines(text_proposals, im_size, scores=0):
    """
    text_proposals: boxes
    """
    text_lines = np.zeros((len(text_proposals), 8), np.float32)
    for index, tp_indices in enumerate(text_proposals):
        text_line_boxes = np.array(tp_indices)  # all the small boxes of this text line
        # print(text_line_boxes)
        # print(type(text_line_boxes))
        # print(text_line_boxes.shape)
        X = (text_line_boxes[:, 0] + text_line_boxes[:, 2]) / 2  # center x of every small box
        Y = (text_line_boxes[:, 1] + text_line_boxes[:, 3]) / 2  # center y of every small box
        # print(X)
        # print(Y)
        z1 = np.polyfit(X, Y, 1)  # least-squares line fit through the box centers
        x0 = np.min(text_line_boxes[:, 0])  # smallest x coordinate of the text line
        x1 = np.max(text_line_boxes[:, 2])  # largest x coordinate of the text line
        offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5  # half of a small box's width
        # fit a line through the top-left corners of all small boxes, then evaluate it at the far-left and far-right x of the text line
        lt_y, rt_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
        # fit a line through the bottom-left corners of all small boxes, then evaluate it at the far-left and far-right x of the text line
        lb_y, rb_y = fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)
        # score = scores[list(tp_indices)].sum() / float(len(tp_indices))  # mean score of the small boxes as the text line score
        text_lines[index, 0] = x0
        text_lines[index, 1] = min(lt_y, rt_y)  # smaller y of the text line's top segment
        text_lines[index, 2] = x1
        text_lines[index, 3] = max(lb_y, rb_y)  # larger y of the text line's bottom segment
        text_lines[index, 4] = scores           # text line score
        text_lines[index, 5] = z1[0]            # slope k of the line fitted through the centers
        text_lines[index, 6] = z1[1]            # intercept b of that line
        height = np.mean((text_line_boxes[:, 3] - text_line_boxes[:, 1]))  # mean height of the small boxes
        text_lines[index, 7] = height + 2.5

    text_recs = np.zeros((len(text_lines), 9), np.float32)
    index = 0
    for line in text_lines:
        b1 = line[6] - line[7] / 2  # from the center line and the height, derive the intercepts of the top and bottom lines
        b2 = line[6] + line[7] / 2
        x1 = line[0]
        y1 = line[5] * line[0] + b1  # top left
        x2 = line[2]
        y2 = line[5] * line[2] + b1  # top right
        x3 = line[0]
        y3 = line[5] * line[0] + b2  # bottom left
        x4 = line[2]
        y4 = line[5] * line[2] + b2  # bottom right
        disX = x2 - x1
        disY = y2 - y1
        width = np.sqrt(disX * disX + disY * disY)  # width of the text line
        fTmp0 = y3 - y1                             # height of the text line
        fTmp1 = fTmp0 * disY / width
        x = np.fabs(fTmp1 * disX / width)  # compensation along x
        y = np.fabs(fTmp1 * disY / width)  # compensation along y
        if line[5] < 0:
            x1 -= x
            y1 += y
            x4 += x
            y4 -= y
        else:
            x2 += x
            y2 += y
            x3 -= x
            y3 -= y
        # clock-wise order
        text_recs[index, 0] = x1
        text_recs[index, 1] = y1
        text_recs[index, 2] = x2
        text_recs[index, 3] = y2
        text_recs[index, 4] = x4
        text_recs[index, 5] = y4
        text_recs[index, 6] = x3
        text_recs[index, 7] = y3
        text_recs[index, 8] = line[4]
        index = index + 1
    text_recs = clip_boxes(text_recs, im_size)
    return text_recs
```
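get_text_lines also relies on two small helpers that are not shown here, fit_y and clip_boxes. A minimal fit_y consistent with how it is called above (fit a straight line through the given corner points and evaluate it at two x positions) might look like this:

```python
import numpy as np


def fit_y(X, Y, x1, x2):
    """Fit y = k * x + b through the points (X, Y) and return its value at x1 and x2."""
    X, Y = np.asarray(X), np.asarray(Y)
    if np.all(X == X[0]):            # all boxes share the same x: just return the constant y
        return Y[0], Y[0]
    p = np.poly1d(np.polyfit(X, Y, 1))
    return p(x1), p(x2)
```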

Detection Results and Takeaways

First, the text detection results of the trained model. To make them easier to inspect, I draw both the anchors and the final merged text boxes:

1093303-20181202181822551-1270702271.png

And here are some of the better detection results:

1093303-20181202181837840-1395668682.png

Some takeaways and thoughts from the implementation:

  1. CTPN does not detect rotated text well, and this follows from the nature of the algorithm: a series of fixed-width quadrilaterals is hard to merge into an accurate text box. Some anchors are hard to group at all, and even when grouped it is hard to recover a precise, complete text rectangle (a weakness of the inference stage). For horizontally arranged text, though, I think the approach works quite well.
  2. The side-refinement in CTPN does not help much in practice: if the detected text is sent straight to recognition, the few pixels it corrects are negligible.
  3. CTPN has quite a few intermediate steps, from generating anchor labels, to computing the loss on them, to building text lines at inference time, and each step introduces some error. This is exactly the drawback pointed out in the EAST paper: the simpler the training pipeline and the fewer the intermediate stages, the easier it is to guarantee accuracy.
  4. Judging from the results, CTPN has low precision but high recall. Detection based on 16-pixel anchors also seems to misfire badly on large non-text symbols (such as road signs), because the anchor width is simply too small; even though the LSTM relates neighbouring anchors, it still feels like missing the forest for the trees. So CTPN does not do very well on text that is much larger or much smaller than its anchors.
  5. CTPN is a fairly old algorithm by now (2016). Its ideas were quite innovative at the time, but it also has many weaknesses, and newer methods have largely addressed them; EAST and PixelNet, for example, are excellent more recent algorithms.

The complete CTPN implementation is available on my GitHub.
