Solidifying the Basics Series: CRNN Core Code

Introduction
  • CRNN is a classic text recognition algorithm. This post uses it to solidify the basics: understanding CRNN's core principles and its PyTorch implementation.
Basic Principles
  • CRNN adopts a CNN + RNN + CTC architecture:
    • CNN: a deep CNN extracts features from the input image, producing a feature map
    • RNN: a bidirectional RNN predicts over the feature sequence, learning from each feature vector and outputting a label distribution per time step
    • CTC: the CTC loss converts the sequence of label distributions produced by the recurrent layers into the final label sequence (a minimal collapse sketch follows this list)
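To make the CTC step concrete: greedy (best-path) decoding first merges consecutive repeated labels, then removes the blank token. A minimal sketch of this collapse rule (ctc_collapse is a hypothetical helper name, with blank fixed at index 0 to match the code below):

def ctc_collapse(path, blank=0):
    # Merge consecutive repeats, then drop blanks.
    # e.g. [0, 8, 8, 0, 5, 12, 12, 0, 12, 15] -> [8, 5, 12, 12, 15]
    # which spells "hello" under the mapping a=1, ..., z=26
    out = []
    prev = None
    for p in path:
        if p != prev and p != blank:
            out.append(p)
        prev = p
    return out

Note that the blank between the two runs of 12 is what allows a doubled letter ("ll") to survive the merge step.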
Core Code Implementation (ready to copy and run)
  • The input to torch.nn.CTCLoss() must first be passed through log_softmax
  • The code below covers four parts: training, loss computation, inference, and decoding
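Before the full script, here is a standalone illustration of the inputs torch.nn.CTCLoss expects: log-probabilities of shape [T, N, C], integer targets, and per-sample lengths (the sizes T=50, N=2, C=20 below are arbitrary):

import torch
from torch import nn

T, N, C = 50, 2, 20                    # time steps, batch size, classes (incl. blank)
log_probs = torch.randn(T, N, C).log_softmax(2)  # CTCLoss expects log-probabilities
targets = torch.randint(1, C, (N, 10), dtype=torch.long)  # 0 is reserved for blank
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 10, dtype=torch.long)
loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)

The full CRNN implementation follows.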
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import string

import torch
from torch import nn
import torch.nn.functional as F


class CRNN(nn.Module):
    def __init__(self, img_height, input_channel, n_class, hidden_size):
        super().__init__()

        if img_height % 16 != 0:
            raise ValueError('img_height has to be a multiple of 16')

        kernel_size = [3, 3, 3, 3, 3, 3, 2]
        padding_size = [1, 1, 1, 1, 1, 1, 0]
        stride = [1, 1, 1, 1, 1, 1, 1]
        channel = [64, 128, 256, 256, 512, 512, 512]

        def conv_relu(i, batchNormalization=False):
            in_channels = input_channel if i == 0 else channel[i - 1]
            out_channels = channel[i]
            cnn.add_module(f'conv{i}',
                           nn.Conv2d(in_channels, out_channels,
                                     kernel_size[i],
                                     stride[i],
                                     padding_size[i]))

            if batchNormalization:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(out_channels))
            cnn.add_module(f'relu{i}', nn.ReLU(True))

        # x: 1 x 32 x 320
        cnn = nn.Sequential()
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(2, 2))  # 64x16x160

        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(2, 2))  # 128x8x80

        conv_relu(2, True)
        conv_relu(3)
        cnn.add_module('pooling2',
                       nn.MaxPool2d(kernel_size=(2, 2),
                                    stride=(2, 1),
                                    padding=(0, 1)))  # 256x4x81

        conv_relu(4, True)
        conv_relu(5)
        cnn.add_module('pooling3',
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x82
        conv_relu(6, True)  # 512x1x81

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, n_class)
        )

    def forward(self, x):
        cnn_feature = self.cnn(x)

        # 1 x 512 x 1 x 81
        h = cnn_feature.size()[2]
        if h != 1:
            raise ValueError("the height of cnn_feature must be 1")

        cnn_feature = cnn_feature.squeeze(2)

        # 81: sequence length, 1: batch size, 512: feature dimension
        cnn_feature = cnn_feature.permute(2, 0, 1)

        output = self.rnn(cnn_feature)
        # output: [81, 1, num_classes]; CTCLoss expects log-probabilities
        output = F.log_softmax(output, dim=2)
        return output


class BidirectionalLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, out_feature):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True)
        self.embedding = nn.Linear(hidden_size * 2, out_feature)

    def forward(self, x):
        # x: [81, 1, 512] → [sequence_length, batch_size, input_size]
        recurrent, _ = self.rnn(x)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output


def decode(preds, preds_length):
    # Greedy CTC decoding: merge consecutive repeats, then drop blanks
    # (a quick sanity check of this function appears after the script)
    length = preds_length[0]
    char_list = []
    for i in range(length):
        # index 0 is the blank token
        if preds[i] != 0 and (not (i > 0 and preds[i - 1] == preds[i])):
            char_list.append(alphabet[preds[i] - 1])
    return ''.join(char_list)


if __name__ == '__main__':
    # class 0 is the blank; class k (k > 0) maps to alphabet[k - 1]
    alphabet = list(string.ascii_lowercase)
    num_classes = len(alphabet) + 1  # 26 letters + blank = 27
    
    img = torch.randn((1, 1, 32, 320))

    ctc_loss = nn.CTCLoss()

    crnn = CRNN(32, 1, num_classes, 256)

    # Inference
    preds = crnn(img)

    # Inference: decode into the final text
    # For each time step, take the index of the highest-scoring class
    _, infer_preds = preds.max(2)  # infer_preds: [81, 1]
    infer_preds = infer_preds.transpose(1, 0).contiguous().view(-1)  # [81]
    preds_len = torch.IntTensor([infer_preds.shape[0]])
    text = decode(infer_preds, preds_len)
    print(text)

    # Training: compute the CTC loss
    min_seq_length = 10
    max_seq_length = 30
    batch_size = img.shape[0]
    time_step = preds.shape[0]
    input_length = torch.IntTensor([time_step] * batch_size)
    target = torch.randint(low=1, high=num_classes,
                           size=(batch_size, max_seq_length),
                           dtype=torch.long)
    target_length = torch.randint(low=min_seq_length,
                                  high=max_seq_length,
                                  size=(batch_size,), dtype=torch.long)

    # preds shape: [81, 1, num_classes]
    # target shape: [1, 30]
    # input_length: [1]
    # target_length: [1]
    loss = ctc_loss(preds, target, input_length, target_length)

    print(preds.shape)
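As a quick sanity check of decode() (referenced in the comment inside that function), feeding it a hand-built prediction path shows the merge-then-drop-blank behavior; this snippet can be appended inside the __main__ block:

    # Sanity check of decode(): blank=0, a=1, ..., z=26
    demo = torch.tensor([0, 8, 8, 0, 5, 12, 12, 0, 12, 15])
    print(decode(demo, torch.IntTensor([len(demo)])))  # -> "hello"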
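In an actual training loop, the loss computed above would drive a standard backward pass. A minimal sketch continuing the script (Adam with lr=1e-3 is an illustrative choice, not from the original post):

    # Hypothetical training step for the loss computed above
    optimizer = torch.optim.Adam(crnn.parameters(), lr=1e-3)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(loss.item())

Running the full script prints a short (possibly empty) random string, since the network is untrained, followed by torch.Size([81, 1, 27]) and the scalar loss.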