unet网络自定义训练代码_OCR模型训练

本文介绍了基于深度学习的文本检测算法,包括unet网络的自定义训练过程,以及OCR文本识别中常用的CRNN结构。详细阐述了训练数据的准备、网络设计、损失函数和训练策略,提供了具体的网络结构图和训练流程。
摘要由CSDN通过智能技术生成
d6fef1042e9b7524991fa8de96cbf49e.png

点击上方“蓝字”关注我们

作者 | 李虎

编辑 | 张婵

OCR 从流程上包括两步: 文本检测文本识别,即将图片输入到文本检测算法中得到一个个的文本框,将每个文本框分别送入到文本识别算法中得到识别结果。

1. 基于深度学习的文本检测算法大致分为两类:基于候选框回归的算法基于分割的算法。

 
  • 基于候选框回归的文本检测,是源于目标检测算法,然后结合文本框的特点改造而成的,包括 CTPN、EAST 和 Seglink 算法等。CTPN 是基于 faster RCNN 改进的算法,在 CNN 后加入 RNN 网络,主要思想是把文本行切分成小的细长矩形进行检测再拼接起来;SegLink 算法的检测思路与 CTPN 类似,也是先检测文本行的小块然后拼起来,但网络结构上采取了 SSD 的思路,在多个特征图尺度上进行文本检测,然后将多尺度的结果融合起来,另外输出中加入了角度信息的回归;EAST 算法,它是直接回归的整个文本行的坐标,而不是细长矩形拼接,网络结构上利用了 Unet 的上采样结构来提取特征,融入了浅层和深层的信息,并且在输出层回归了角度信息,可以检测斜框。

  • 基于分割的文本检测,其基本思路是通过分割网络进行像素级别的语义分割,再基于分割的结果构建文本行,包括 PixelLink、Psenet 和 Craft 算法等。PixelLink 算法,网络结构上采用 FCN 提取特征,直接通过实例分割结果中提取文本位置,输出的特征图包括像素分类特征图和像素 link 特征图。Psenet 算法,网络结构上采用 FPN 特征金字塔提取特征,对每个分割区域预测出多个分割结果,然后提出一种新颖的渐进扩展算法,将多个分割的结果进行融合。Craft 算法,网络结构上采用 UNet 的结构,输出的特征图包括 Region score 特征图和像素 Affinity score 特征图,另外特征图中使用了高斯函数,将预测像素点分类的问题转成了像素点的回归问题,能更好的适应文字没有严格包围边界的特点。

2. 基于深度学习的文本识别算法则相对较为统一,一般都采用CNN+RNN+CTC 的结构,俗称 CRNN 结构,因为这种结构的识别效果很好,且泛化性好,工业上大多都用的这种结构,然后在该框架上做一些改进,如更换 CNN 主干网络,缩减卷积层以提高速度缩减空间,或者改进 RNN 加入 Attention 结构等。

  本文主要介绍了我们在生产上使用的文本检测和文本识别算法。算法的训练流程一般包括以下步骤:

1. 准备训练数据,有的是需要标注(如文本检测中),有的主要是造数据(如文本识别中);

2. 定义算法网络,这里主要是明确输入和输出;

3. 准备好 batch 数据集,这里主要是处理输入的图片和标签数据,标签数据结构与第 2 步中的网络输出对应,例如 craft 要进行高斯函数计算等,而文本识别中则无需处理,直接将造好的数据输入即可;

4. 定义 loss,优化器和学习率等参数;

5. 训练,这里主要是定义每批次数据训练的操作策略,如保存策略,日志策略,测试策略等。

OCR文本检测

我们在文本定位中采用的是 Craft 算法,它是一种基于分割的算法,无需进行大量候选框的回归,也无需进行 NMS 后处理,因此极大提升了速度,并且它是字符级别的文本检测器,定位的是字符,对于尺寸缩放不敏感,无需多尺度训练和预测来解决尺度方差问题,最后其泛化性能也能达到 SOTA 的水平。

1. 训练数据标注

该方法是基于分割的,背景文字是指的原本就在票据上的文字,如“姓名”、“出生年月”等文字,前景文字是指的待识别的文字,也就是用户后填进去的内容。标注步骤就是将这些文字框出来,标上相应的类别。我们采用自己开发的标注工具,这里也可以使用开源的 labelme 工具,生成的标注文件如下所示,第一行是图片所在路径,从第二行开始就是坐标框信息,最后一位是类别。

f50b3204b5de39fafdc2cc96f9ccd0bd.png

2. 网络设计

下图是网络结构图,整体采用了 Unet 的主结构,主干网络用的 vgg16,输入图片首先经过 vgg16 后,接 UNet 的上采样结构,其作用是使得深层和浅层的特征图进行拼接作为输出。然后再接一系列的卷积操作,充分提取特征。最后输出的特征图包括 Region score 特征图和像素 Affinity score 特征图。

c4b32412809e40d0cd94bfb85c569a53.png

网络的代码如下所示:
  1. class CRAFT(nn.Module):

  2. def __init__(self, pretrained=False, freeze=False, phase='test'):

  3. super(CRAFT, self).__init__()

  4. """ Base network """

  5. self.basenet = vgg16_bn(pretrained, freeze)

  6. """ 固定部分参数,用于迁移学习"""

  7. if phase == 'train':

  8. for p in self.parameters():

  9. p.requires_grad=False

  10. """ U network """

  11. self.upconv1 = double_conv(1024, 512, 256)

  12. self.upconv2 = double_conv(512, 256, 128)

  13. self.upconv3 = double_conv(256, 128, 64)

  14. self.upconv4 = double_conv(128, 64, 32)

  15. num_class = 2

  16. self.conv_cls = nn.Sequential(

  17. nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),

  18. nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),

  19. nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),

  20. nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),

  21. nn.Conv2d(16, num_class, kernel_size=1),

  22. )

  23. init_weights(self.upconv1.modules())

  24. init_weights(self.upconv2.modules())

  25. init_weights(self.upconv3.modules())

  26. init_weights(self.upconv4.modules())

  27. init_weights(self.conv_cls.modules())

  28. def forward(self, x):

  29. """ Base network """

  30. sources = self.basenet(x)

  31. """ U network """

  32. y = torch.cat([sources[0], sources[1]], dim=1)

  33. y = self.upconv1(y)

  34. y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)

  35. y = torch.cat([y, sources[2]], dim=1)

  36. y = self.upconv2(y)

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值