ResNet101+BiLSTM +CTC手写汉字识别(一)

文章讨论了在手写字体识别中,作者尝试使用ResNet101替代CRNN中的CNN部分,并连接双层BiLSTM进行改进。重点在于调整网络结构以优化字体识别性能,以及提到使用CTCLOSS函数时的输入和输出处理策略。
摘要由CSDN通过智能技术生成

在练习的时候发现CRNN对是写字体的识别不是太理想,一琢磨,好像CNN这部分可以换一个别的结构的,于是随便选了个ResNet101,于是全连接层连给拆了,然后接上了双层的BiLSTM

class BiLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BiLSTM, self).__init__()

        self.LSTM = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.Linear = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.LSTM(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.Linear(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output
    
class CRNN(nn.Module):
   #接入全连接层之前的resNet101最后的输出的是2048,所以这里用1024 
   #nClass = 预测的类别数(是预测类型数+1因为还有一个为空的占位),
   #maxLen = 这组训练数据的最大文字长度,
   #nHidden = 进入BiLSTM前的数据的维度数/2
   #这里直接用的resNet101的三通道输入,如果想要训练1通道的图,那记得改一                     #下resNet101的头

    def __init__(self,nClass,maxLen,nHidden = 1024) :
        super(CRNN,self).__init__()
        #不加载预训练模型
        #当然你也可以加载,只是我这里没加载罢了,pytorch帮你用kaiming法初始了权                    #重的,如果想要换别的,可以自行操作更改

        resnet101 = models.resnet101()
        #其实这个可以没有的,在大小不确定的时候你可以用来调大小,
        #但是别让b c h w中的 w 比(maxLen)*2+1)小就行了,

        #None的意思是和上面输入的h一样的意思,
        #当然你也可以两个都设个大小规范一下。反正别比变化后的大小小就行了

        resnet101.avgpool = nn.AdaptiveAvgPool2d(output_size=(None, (maxLen)*2+1))
        self.cnn = nn.Sequential(
           #[0:-2]意思是,从0开始一直取到结尾前两个,[0:-1]自然是前一个 
           #如果你不要avgpool那么这里就是*list(resnet101.children())[0:-2]

            *list(resnet101.children())[0:-1])
        self.rnn = nn.Sequential(
            BiLSTM(2048,nHidden,nHidden),
            BiLSTM(nHidden,nHidden,nClass)
            )
    def forward(self,input):
        cnn = self.cnn(input)
        # 因为此时h为1所以将维度压缩
        # 这样做的原因是,将字符特征压缩到同一直线

        cnn = cnn.squeeze(dim=2)
        # 调整各个维度的位置(B,C,W)->(W,B,C),对应lstm的输入(seq,batch,input_size)
        #因为上面压缩了h(因为H = 1所以可以压缩掉)所以可以直接调整位置

        cnn = cnn.permute(2,0,1)
        #得出双层BiLSTM预测结果
        #这里可以加一个log_softmax进行结果的概率预测,当然也可以加载后面的训练代码里

        return self.rnn(cnn)

被拆解后的ResNet101

    #上面的部分省略,太多了,直接看接上后的下面被改后的的样子
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (8): AdaptiveAvgPool2d(output_size=(None, 23))
)

#这下面的全连接层被去掉了,AdaptiveAvgPool2d的大小也边了不是(1,1)了

接上BiLSTM后的样子 

#上省略,太多了,直接看接上后的下面被改后的的样子

 (2): Bottleneck(
        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
    (8): AdaptiveAvgPool2d(output_size=(None, 23))
  )
  (rnn): Sequential(
    (0): BiLSTM(
      (LSTM): LSTM(2048, 1024, bidirectional=True)
      (Linear): Linear(in_features=2048, out_features=1024, bias=True)
    )
    (1): BiLSTM(
      (LSTM): LSTM(1024, 1024, bidirectional=True)
      (Linear): Linear(in_features=2048, out_features=12, bias=True)
    )
  )
)

 这里的AdaptiveAvgPool2d留着是因为可以在进入BiLSTM之前改变形状,让训练的时候图片大小选择更灵活

然后需要注意的是,CTCLOSS损失函数中,官方推荐的input_lengths值是 maxLen*2+1(要考虑到每个预测值之间都有一个空位的情况,当然一般来说可能会有多个空位)而且预测值 (也就是进入BiLSTM前的W大小)的长度不能小于target_lengths(也就是这组数据的标签长度),不然CTCLOSS的loss会直接变成inf,当然如果你图片过小,导致进入BiLSTM的时候W不够,那你可以试试,把图片同比变大点,到了AdaptiveAvgPool2d的时候把output_size设成(1,None),这样可以保持宽不变的情况下可以让高变成1,然后正常进入BiLSTM 不用改别的东西了

  • 9
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值