训练East文本检测模型若干感悟

East的模型结构

总的来说,真是简单。但是简单,却又好用,为旷视点赞。

分别用PVANet特征提取,采用UNet结构做特征融合,最后直接输出置信度,坐标位置,和角度。

虽然看起来简单,是不是训练起来也一样简单呢????

 

不然,其实置信度是最难train的。说说我的经历。

本来大家都习惯用vgg来提取特征,效果不错,偏偏又人自命不凡,比如我自己写来一个四不像的模型,大家看看。

最终结果勉勉强强,却还是能检测处文字的。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class YYBlock(nn.Module):
    def __init__(self, in_ch, extend,blockType):
        super(YYBlock, self).__init__()
        self.blockType = blockType
        self.groups = 1
        self.extend = extend
        self.branch9x9 = nn.Sequential(
            nn.Conv2d(in_ch, 2 * extend, 9, 1, padding=0, dilation=2, bias=True),
            nn.BatchNorm2d(2 * extend),
            nn.ReLU(inplace=True),
            nn.Conv2d(2 * extend, 2 * extend, 9, 1, padding=0, dilation=2, bias=True),
            nn.BatchNorm2d(2 * extend),
            nn.ReLU(inplace=True)
        )

        self.branch7x7 = nn.Sequential(
            nn.Conv2d(in_ch, 4 * extend, 7, 1, padding=0, dilation=2, bias=True),
            nn.BatchNorm2d(4 * extend),
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * extend, 4 * extend, 7, 1, padding=0, dilation=2, bias=True),
            nn.BatchNorm2d(4 * extend),
            nn.ReLU(inplace=True)
        )

        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_ch, 6 * extend, 5, 1, padding=0,groups=self.groups, dilation=1, bias=True),
            nn.BatchNorm2d(6 * extend),
            nn.ReLU(inplace=True),
            nn.Conv2d(6 * extend, 6 * extend, 5, 1, padding=0, groups=self.groups, dilation=1, bias=True),
            nn.BatchNorm2d(6 * extend),
            nn.ReLU(inplace=True)
        )

        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_ch, 8 * extend, 3, 1, padding=0,groups=self.groups, dilation=1, bias=True),
            nn.BatchNorm2d(8 * extend),
            nn.ReLU(inplace=True),
            nn.Conv2d(8 * extend, 8 * extend, 3, 1, padding=0, groups=self.groups, dilation=1, bias=True),
            nn.BatchNorm2d(8 * extend),
            nn.ReLU(inplace=True)
        )

        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_ch, 10 * extend, 1, 1, padding=0,groups=self.groups, dilation=1, bias=True),
            nn.BatchNorm2d(10 * extend),
            nn.ReLU(inplace=True)
        )

        self.avg = nn.AvgPool2d(kernel_size=2,stride=2)
    def forward(self, x):
        residual = x

        b, c, h, w = x.size()
        conv5x5 = self.branch5x5(x)
        conv3x3 = self.branch3x3(x)
        conv1x1 = self.branch1x1(x)

        if self.blockType == 'L':
            conv9x9 = self.branch9x9(x)
            conv7x7 = self.branch7x7(x)
        elif self.blockType == 'M':
            conv7x7 = self.branch7x7(x)

        feature5x5 = F.interpolate(conv5x5, (h, w), None, 'bilinear', True)
        feature3x3 = F.interpolate(conv3x3, (h, w), None, 'bilinear', True)
        feature1x1 = F.interpolate(conv1x1, (h, w), None, 'bilinear', True)

        if self.blockType == 'L':
            feature9x9 = F.interpolate(conv9x9, (h, w), None, 'bilinear', True)
            feature7x7 = F.interpolate(conv7x7, (h, w), None, 'bilinear', True)
            feature_cat = torch.cat([feature9x9, feature7x7, feature5x5, feature3x3, feature1x1, residual], dim=1)
        elif self.blockType == 'M':
            feature7x7 = F.interpolate(conv7x7, (h, w), None, 'bilinear', True)
            feature_cat = torch.cat([feature7x7, feature5x5, feature3x3, feature1x1, residual], dim=1)
        else:
            feature_cat = torch.cat([feature5x5, feature3x3, feature1x1, residual], dim=1)

        out = self.avg(feature_cat)
        return out

cfg = [{'type':'L','extend':1},{'type':'M','extend':4},
       {'type':'S','extend':4},{'type':'S','extend':8},
       {'type':'S','extend':8}]

def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for config in cfg:
        blockType = config['type']
        extend = config['extend']
        layers += [YYBlock(in_channels, extend=extend, blockType=blockType)]
        if blockType == 'L':
            in_channels = 30 * extend + in_channels
        elif blockType == 'M':
            in_channels = 28 * extend + in_channels
        else:
            in_channels = 24 * extend + in_channels
    return nn.Sequential(*layers)

class YYModel(nn.Module):
    def __init__(self,cfg):
        super(YYModel, self).__init__()
        self.features = make_layers(cfg=cfg)
        self.blocks = []

        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias,0)
            elif isinstance(m,nn.BatchNorm2d):
                nn.init.constant_(m.weight,1)
                nn.init.constant_(m.bias,0)

    def forward(self,x):
        out = []
        for feature in self.features:
            x = feature(x)
            out.append(x)
            # print("feature.shape:{0}".format(x.shape))
        return out[1:]

以上为我自以为是的作品,idea来自语义分割中deeplab-v3的ASPP结构,谷歌mobile net,谷歌BlazeFace

结果画虎不成,反类犬。没有我想的那么高效。

原因有几点:

         首先:我从BlazeFace得到一个信息,大卷积核使用较少的filter个数在浅层效果好过小卷积核使用较多的filter个数。本来是没错,错在我使用来过大的卷积核,大部分层都使用了大卷积核,然后导致参数量上来了,然后缩减了模型层次。导致deep learning已经不deep了,没办法深挖到语义信息。

        然后:小的kernel根本不需要用双线性差值,3x3一个padding就可以了,1x1当然padding都不需要,分支过多,然后filter个数不足,特别是大核fileter个数更少,导致单层融合多视野能力实际比较差。大核的权重完全被小核覆盖住了(小核filter个数多),导致学到的东西很容易丢失。          

        这个模型不是只有缺点,优点也有,学习的比vgg快,因为层次少,25个eporch就能学出比较好的效果。但是潜力不足,后续优化空间少。

        后期我换回了vgg16,只是把所有层次filter减半了。降低模型的参数量,提高模型泛化能力。

        训练也有一些trick:我发现刚开始train,基本学不到什么东西,可能跟数据集有关,跟显卡也有关。数据集只有1000张图片,很多图片没有文字,导致的后果是总样本数少,正样本数更少,样本特别不平衡。我开始使用的batchsize是8,lr =0.1,基本学不到啥,还总是nan.

       分析了一下nan,很可能是梯度爆炸了,果断降低了lr=0.001,batchsize=8肯定太小了,增大到16,可视化之后,发现置信度特别难train,果断使用dice loss来应对类别不平衡,并且加大20倍置信度损失的权重。效果开始出来了,完美。

 

 

 

 

 

 

 

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值