[OCR Series] Part 5: Text Recognition Based on Encoder-Decoder

Contents

1. Paper Reading

2. Code Implementation

3. Results and Discussion


1. Paper Reading

In terms of pipeline, OCR recognition can be broken into three stages: 1. a CNN extracts image features; 2. an RNN/BiLSTM combines those features with contextual sequence information; 3. an alignment objective produces the loss used to train the whole network, as shown in the figure below:

[Figure: OCR pipeline — CNN feature extraction → BiLSTM sequence modeling → alignment loss]

In the pipeline above, the CNN in step 1 can be swapped for image backbones such as ResNet or VGG, and the BiLSTM in step 2 can be replaced by sequence models such as an RNN, Transformer, or BERT. What most distinguishes text recognition from ordinary tasks, however, is step 3: the choice and implementation of the alignment objective. An earlier post in this OCR series, 《【OCR】基于RCNN-CTC的不定长文本识别》, explained and implemented CTC-based text alignment. This post explains and implements variable-length text recognition aligned with an Encoder-Attention-Decoder instead.
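As a minimal sketch of this three-stage pipeline (PyTorch-style and purely illustrative; the layer sizes here are assumptions, and the full runnable implementation follows in Section 2):

import torch
import torch.nn as nn

# 1. CNN extracts visual features from the image
cnn = nn.Conv2d(3, 512, kernel_size=3, padding=1)
# 2. BiLSTM combines per-column features with sequence context
rnn = nn.LSTM(input_size=512, hidden_size=256, bidirectional=True)
# 3. per-step class scores, fed to an alignment objective (CTC, or the
#    attention decoder implemented below)
head = nn.Linear(512, 12)  # 10 digits + 2 special tokens

x = torch.randn(1, 3, 32, 280)        # [batch, channels, height, width]
feat = cnn(x).mean(dim=2)             # collapse height -> [1, 512, 280]
seq, _ = rnn(feat.permute(2, 0, 1))   # [width, batch, 2 * 256]
logits = head(seq)                    # [width, batch, num_classes]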

Original paper: "Robust Scene Text Recognition with Automatic Rectification"

The paper uses the encoding stage for feature extraction, with a ConvNet and a BiLSTM jointly producing the features; in the decoding stage, an attention mechanism lets the decoder draw on all of the encoder's time-step features during inference. The network structure is shown below:

As the figure shows, the Encoder is a ConvNet+BiLSTM, while the Decoder relies solely on attention for inference.

Note the difference in alignment: CTC defines a max_length and then, via the CTC collapsing rules, reduces the max_length-dimensional features to align the text; the Encoder-Attention-Decoder approach instead defines an <EOS> token, and decoding stops as soon as <EOS> is produced.
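Concretely, with the digit lexicon used below (SOS = 0, EOS = 1, and each character mapped to its lexicon index + 2, matching the ConvertBetweenStringAndLabel class in Section 2), a label such as "2023" is framed as:

lexicon = "0123456789"
char2id = {c: i + 2 for i, c in enumerate(lexicon)}  # '0' -> 2, ..., '9' -> 11

label = "2023"
target = [0] + [char2id[c] for c in label] + [1]  # [SOS, char ids..., EOS]
print(target)  # [0, 4, 2, 4, 5, 1]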

Following this idea, the code below implements the Encoder-Attention-Decoder approach; its structure matches the earlier posts in this OCR series.

2. Code Implementation

Note: in the config class, train_list points to a txt file listing training image paths and their labels, eval_list points to the same for the validation set, and test_img_paths is the folder holding the images to run inference on.

The corresponding train_list.txt is organized as file path + space + label + "\n", one sample per line; eval_list keeps the same format.
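For example (the paths and labels below are purely illustrative):

E:\data\captcha\train\00001_2023.png 2023
E:\data\captcha\train\00002_7815.png 7815
E:\data\captcha\train\00003_0492.png 0492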

save_model_dir is the directory where models are saved, while test_encoder_path and test_decoder_path are the model paths loaded at inference time. istrain and istest then toggle training and inference. Set these parameters accordingly and the script can train or infer.

import os
import random
import numpy as np

from PIL import Image
import cv2

import torch
import torch.utils.data
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.autograd import Variable
import collections
import collections.abc

cudnn.benchmark = True

class configs():
    def __init__(self):
        #Data
        self.train_list = r'E:\code\OCR\crnn_seq2seq_ocr_pytorch-master\data\train_list.txt'
        self.eval_list = r'E:\code\OCR\crnn_seq2seq_ocr_pytorch-master\data\valid_list.txt'
        self.img_height = 32
        self.img_width = 280

        self.save_model_dir = 'seq_models'
        self.get_lexicon_dir = './lbl2id_map.txt'

        # self.lexicon = self.get_lexicon(lexicon_name=self.get_lexicon_dir)
        self.lexicon = "0123456789"
        self.all_chars = {char: idx for idx, char in enumerate(self.lexicon)}  # char -> index
        self.all_nums = dict(enumerate(self.lexicon))                          # index -> char
        self.class_num = len(self.lexicon) + 2  # +2 for the SOS and EOS tokens
        self.label_word_length = 4

        self.random_sample = True  # whether to sample training data randomly

        self.teaching_forcing_prob = 0.5

        #train
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.batch_size = 64
        self.epoch = 31
        self.save_model_fre_epoch = 1
        self.hidden_size = 256  # hidden state size of the LSTM/GRU layers
        self.learning_rate = 0.0001
        self.encoder = ''
        self.decoder = ''
        self.max_width = 71  # max sequence width (CNN output length for a 280-px-wide input)
        #test/infer
        self.test_img_paths = r'E:\code\OCR\new_ocr\captcha_datasets\test-data-1'
        self.test_encoder_path = r'E:\code\OCR\crnn_seq2seq_ocr_pytorch-master\model\encoder_30.pth'
        self.test_decoder_path = r'E:\code\OCR\crnn_seq2seq_ocr_pytorch-master\model\decoder_30.pth'


        self.istrain = False
        self.istest = True

    def get_lexicon(self, lexicon_name):
        '''
        Load the lexicon from lbl2id_map.txt; the file holds one "char\tid" pair per line:
        #0\t0\n
        #a\t1\n
        #...
        #z\t63\n
        :param lexicon_name:
        :return: all lexicon characters joined into a single string
        '''
        lexicons = open(lexicon_name, 'r', encoding='utf-8').readlines()
        lexicons_str = ''.join(word.split('\t')[0] for word in lexicons)
        return lexicons_str

cfg = configs()


# Dataset and data loading
class TextLineDataset(torch.utils.data.Dataset):

    def __init__(self, text_line_file=None, transform=None, target_transform=None):
        self.text_line_file = text_line_file
        with open(text_line_file) as fp:
            self.lines = fp.readlines()
            self.nSamples = len(self.lines)

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index < len(self), 'index range error'

        line_splits = self.lines[index].strip().split()
        img_path = line_splits[0]
        try:
            img = Image.open(img_path).convert('RGB')
        except IOError:
            print('Corrupted image for %d' % index)
            # fall back to the next sample (wrap around at the end)
            return self[(index + 1) % len(self)]

        if self.transform is not None:
            img = self.transform(img)

        label = line_splits[1]

        if self.target_transform is not None:
            label = self.target_transform(label)

        return (img, label)

class ResizeNormalize(object):

    def __init__(self, img_width, img_height):
        self.img_width = img_width
        self.img_height = img_height
        self.toTensor = T.ToTensor()

    def __call__(self, img):
        # resize keeping the aspect ratio, then right-pad with zeros to img_width
        img = np.array(img)
        h, w, c = img.shape
        height = self.img_height
        width = int(w * height / h)
        if width >= self.img_width:
            img = cv2.resize(img, (self.img_width, self.img_height))
        else:
            img = cv2.resize(img, (width, height))
            img_pad = np.zeros((self.img_height, self.img_width, c), dtype=img.dtype)
            img_pad[:height, :width, :] = img
            img = img_pad
        img = Image.fromarray(img)
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)  # normalize from [0, 1] to [-1, 1]
        return img

class RandomSequentialSampler(torch.utils.data.sampler.Sampler):

    def __init__(self, data_source, batch_size):
        self.num_samples = len(data_source)
        self.batch_size = batch_size

    def __iter__(self):
        n_batches = len(self) // self.batch_size
        tail = len(self) % self.batch_size
        index = torch.LongTensor(len(self)).fill_(0)
        for i in range(n_batches):
            random_start = random.randint(0, len(self) - self.batch_size)
            batch_index = random_start + torch.arange(0, self.batch_size)
            index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
        # deal with the tail (when len(self) is not a multiple of batch_size)
        if tail:
            random_start = random.randint(0, len(self) - tail)
            tail_index = random_start + torch.arange(0, tail)
            index[n_batches * self.batch_size:] = tail_index

        return iter(index)

    def __len__(self):
        return self.num_samples

class AlignCollate(object):

    def __init__(self, img_height=32, img_width=100):
        self.img_height = img_height
        self.img_width = img_width
        self.transform = ResizeNormalize(img_width=self.img_width, img_height=self.img_height)

    def __call__(self, batch):
        images, labels = zip(*batch)

        images = [self.transform(image) for image in images]
        images = torch.cat([t.unsqueeze(0) for t in images], 0)

        return images, labels

def load_data(v, data):
    # copy a CPU batch into the pre-allocated (possibly GPU) tensor in place
    with torch.no_grad():
        v.resize_(data.size()).copy_(data)

SOS_TOKEN = 0  # special token for start of sentence
EOS_TOKEN = 1  # special token for end of sentence
class ConvertBetweenStringAndLabel(object):
    """Convert between str and label.

    NOTE:
        Insert `SOS` and `EOS` tokens ahead of the alphabet for attention decoding.

    Args:
        alphabet (str): set of the possible characters.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet

        self.dict = {}
        self.dict['SOS_TOKEN'] = SOS_TOKEN
        self.dict['EOS_TOKEN'] = EOS_TOKEN
        for i, item in enumerate(self.alphabet):
            self.dict[item] = i + 2

    def encode(self, text):
        """
        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor targets:max_length × batch_size
        """
        if isinstance(text, str):
            # map each character to its id; unknown characters fall back to id 2
            text = [self.dict[item] if item in self.dict else 2 for item in text]
        elif isinstance(text, collections.abc.Iterable):
            text = [self.encode(s) for s in text]
            max_length = max([len(x) for x in text])
            nb = len(text)
            # frame each target as [SOS, char ids..., EOS]; id 2 is used as padding
            targets = torch.ones(nb, max_length + 2) * 2
            for i in range(nb):
                targets[i][0] = SOS_TOKEN
                targets[i][1:len(text[i]) + 1] = text[i]
                targets[i][len(text[i]) + 1] = EOS_TOKEN
            text = targets.transpose(0, 1).contiguous()
            text = text.long()
        return torch.LongTensor(text)

    def decode(self, t):
        """Decode a single encoded token id back into its character.

        Args:
            t (int or one-element torch.Tensor): token id to convert.

        Returns:
            str: the character (or special-token name) for this id.
        """
        t = int(t)
        texts = list(self.dict.keys())[list(self.dict.values()).index(t)]
        return texts

converter = ConvertBetweenStringAndLabel(cfg.lexicon)
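# Illustrative usage: with lexicon "0123456789",
#   converter.encode(["12", "345"]) -> LongTensor of shape [max_len + 2, batch],
#   whose columns read [SOS, 3, 4, EOS, pad] and [SOS, 5, 6, 7, EOS];
#   converter.decode(4) -> '2'.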

# Models

class CNN(nn.Module):
    def __init__(self, channel_size):
        super(CNN, self).__init__()
        self.cnn = nn.Sequential(
                      nn.Conv2d(channel_size, 64, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),
                      nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2),
                      nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),
                      nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d((2,2), (2,1), (0,1)),
                      nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),
                      nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d((2,2), (2,1), (0,1)),
                      nn.Conv2d(512, 512, 2, 1, 0), nn.BatchNorm2d(512), nn.ReLU(True))
    def forward(self, input):
        # [n, channel_size, 32, 280] -> [n, 512, 1, 71]
        conv = self.cnn(input)
        return conv

class BidirectionalLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True)
        self.embedding = nn.Linear(hidden_size * 2, output_size)
    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, output_size]
        output = output.view(T, b, -1)
        return output

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=71):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length
        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)
    def forward(self, input, hidden, encoder_outputs):
        # input: [b] token ids; hidden: [1, b, hidden]; encoder_outputs: [w, b, hidden]
        embedded = self.embedding(input)  # [b, hidden]
        embedded = self.dropout(embedded)
        # attention weights over the encoder time steps (w == max_length)
        attn_weights = F.softmax(self.attn(torch.cat((embedded, hidden[0]), 1)), dim=1)
        # weighted sum of encoder outputs: [b, 1, hidden]
        attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs.permute(1, 0, 2))
        output = torch.cat((embedded, attn_applied.squeeze(1)), 1)
        output = self.attn_combine(output).unsqueeze(0)  # [1, b, hidden]
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = F.log_softmax(self.out(output[0]), dim=1)  # [b, output_size]
        return output, hidden, attn_weights
    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=cfg.device)

class Encoder(nn.Module):
    def __init__(self, channel_size, hidden_size):
        super(Encoder, self).__init__()
        self.cnn = CNN(channel_size)
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        # rnn feature
        conv = conv.squeeze(2)        # [b, c, 1, w] -> [b, c, w]
        conv = conv.permute(2, 0, 1)  # [b, c, w] -> [w, b, c]
        output = self.rnn(conv)
        return output

class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=71):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.decoder = AttnDecoderRNN(hidden_size, output_size, dropout_p, max_length)
    def forward(self, input, hidden, encoder_outputs):
        return self.decoder(input, hidden, encoder_outputs)
    def initHidden(self, batch_size):
        return torch.zeros(1, batch_size, self.hidden_size)


# Utility functions
# model weight initialization
def weights_init(model):
    # Official init from torch repo.
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.constant_(m.bias, 0)
# running average for the loss
class Averager(object):
    """Compute average for `torch.Variable` and `torch.Tensor`. """

    def __init__(self):
        self.reset()

    def add(self, v):
        if isinstance(v, Variable):
            count = v.data.numel()
            v = v.data.sum()
        elif isinstance(v, torch.Tensor):
            count = v.numel()
            v = v.sum()

        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res

class ocr():
    def train(self):
        # create train dataset
        train_dataset = TextLineDataset(text_line_file=cfg.train_list, transform=None)
        sampler = RandomSequentialSampler(train_dataset, cfg.batch_size)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=cfg.batch_size, shuffle=False, sampler=sampler, num_workers=4,
            collate_fn=AlignCollate(img_height=cfg.img_height, img_width=cfg.img_width))

        # create test dataset
        test_dataset = TextLineDataset(text_line_file=cfg.eval_list,
                                               transform=ResizeNormalize(img_width=cfg.img_width,
                                                                                 img_height=cfg.img_height))
        test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=1,
                                                  num_workers=4)

        # create crnn/seq2seq/attention network
        encoder = Encoder(channel_size=3, hidden_size=cfg.hidden_size)
        # for prediction of an indefinite long sequence
        decoder = Decoder(hidden_size=cfg.hidden_size, output_size=cfg.class_num, dropout_p=0.1,
                               max_length=cfg.max_width)
        encoder.apply(weights_init)
        decoder.apply(weights_init)

        # create input tensor
        image = torch.FloatTensor(cfg.batch_size, 3, cfg.img_height, cfg.img_width)
        text = torch.LongTensor(cfg.batch_size)

        criterion = torch.nn.NLLLoss()  # pairs with the decoder's log_softmax (equivalent to cross-entropy)

        encoder.to(cfg.device)
        decoder.to(cfg.device)
        image = image.to(cfg.device)
        text = text.to(cfg.device)
        criterion = criterion.to(cfg.device)

        # optimizer
        encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=cfg.learning_rate, betas=(0.5, 0.999))
        decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=cfg.learning_rate, betas=(0.5, 0.999))

        # loss averager
        loss_avg = Averager()

        for epoch in range(cfg.epoch):
            train_iter = iter(train_loader)

            for i in range(len(train_loader)):
                cpu_images, cpu_texts = next(train_iter)
                batch_size = cpu_images.size(0)

                for param in encoder.parameters():
                    param.requires_grad = True
                for param in decoder.parameters():
                    param.requires_grad = True
                encoder.train()
                decoder.train()

                target_variable = converter.encode(cpu_texts)
                load_data(image, cpu_images)

                # CNN + BiLSTM
                encoder_outputs = encoder(image)
                target_variable = target_variable.to(cfg.device)
                # start decoding from the SOS_TOKEN row
                decoder_input = target_variable[SOS_TOKEN].to(cfg.device)
                decoder_hidden = decoder.initHidden(batch_size).to(cfg.device)

                loss = 0.0
                # teacher forcing is applied with probability (1 - teaching_forcing_prob)
                teach_forcing = random.random() > cfg.teaching_forcing_prob
                if teach_forcing:
                    # teacher forcing: the next decoder input is the ground-truth token
                    for di in range(1, target_variable.shape[0]):
                        decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden,
                                                                                    encoder_outputs)
                        loss += criterion(decoder_output, target_variable[di])
                        decoder_input = target_variable[di]
                else:
                    # free running: the next decoder input is the model's own top-1 prediction
                    for di in range(1, target_variable.shape[0]):
                        decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden,
                                                                                    encoder_outputs)
                        loss += criterion(decoder_output, target_variable[di])
                        topv, topi = decoder_output.data.topk(1)
                        decoder_input = topi.squeeze()
                encoder.zero_grad()
                decoder.zero_grad()
                loss.backward()
                encoder_optimizer.step()
                decoder_optimizer.step()

                loss_avg.add(loss)

                if i % 10 == 0:
                    print(
                        '[Epoch {0}/{1}] [Batch {2}/{3}] Loss: {4}'.format(epoch, cfg.epoch, i, len(train_loader),
                                                                           loss_avg.val()))
                    loss_avg.reset()

            # save checkpoint
            torch.save(encoder.state_dict(), '{0}/encoder_{1}.pth'.format(cfg.save_model_dir, epoch))
            torch.save(decoder.state_dict(), '{0}/decoder_{1}.pth'.format(cfg.save_model_dir, epoch))

    def infer(self):

        encoder_name = cfg.test_encoder_path
        decoder_name = cfg.test_decoder_path
        correct = 0
        transformer = ResizeNormalize(img_width=cfg.img_width, img_height=cfg.img_height)

        # build the networks and load the trained weights once, outside the image loop
        encoder = Encoder(3, cfg.hidden_size)
        # no dropout during inference
        decoder = Decoder(cfg.hidden_size, cfg.class_num, dropout_p=0.0, max_length=cfg.max_width)
        encoder = encoder.to(cfg.device)
        decoder = decoder.to(cfg.device)
        encoder.load_state_dict(torch.load(encoder_name, map_location=cfg.device))
        decoder.load_state_dict(torch.load(decoder_name, map_location=cfg.device))
        encoder.eval()
        decoder.eval()

        for test_img_name in os.listdir(cfg.test_img_paths):
            test_img_path = os.path.join(cfg.test_img_paths, test_img_name)

            image = Image.open(test_img_path).convert('RGB')
            image = transformer(image)
            image = image.to(cfg.device)
            image = image.view(1, *image.size())

            encoder_out = encoder(image)

            max_length = 20
            decoder_input = torch.zeros(1).long()  # start decoding from SOS_TOKEN
            decoder_hidden = decoder.initHidden(1)
            decoder_input = decoder_input.to(cfg.device)
            decoder_hidden = decoder_hidden.to(cfg.device)

            words, prob = self.seq2seq_decode(encoder_out, decoder, decoder_input, decoder_hidden, max_length)
            # print('predict_string: {} => predict_probility: {}'.format(words, prob))

            # file names are assumed to look like "<id>_<label>.png"
            if words == test_img_name.replace('.png', '').split('_')[1]:
                correct += 1
        print("model" + '\t' + "|| acc: " + str(correct / len(os.listdir(cfg.test_img_paths))) + '\n')

    # greedy decoding: feed back the top-1 prediction until EOS or max_length
    def seq2seq_decode(self,encoder_out, decoder, decoder_input, decoder_hidden, max_length):
        decoded_words = []
        prob = 1.0
        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_out)
            probs = torch.exp(decoder_output)
            _, topi = decoder_output.data.topk(1)
            ni = topi.squeeze(1)
            decoder_input = ni
            prob *= probs[:, ni]
            if ni == EOS_TOKEN:
                break
            else:
                decoded_words.append(converter.decode(ni))

        words = ''.join(decoded_words)
        prob = prob.item()

        return words, prob

if __name__ == '__main__':
    myocr = ocr()
    if cfg.istrain:
        myocr.train()
    if cfg.istest:
        myocr.infer()

3. Results and Discussion

The code above was trained for 30 epochs on the captcha_datasets dataset, split train:validation:test = 25000:10000:10000; the images are mainly numeric captchas. The train NLL loss and validation/test accuracy over the 30 epochs are listed in the table below.

epoch    loss           val/test-acc
1        8.233852386    0
5        4.418142796    0.12
10       0.334442675    0.94
15       0.315635592    0.97
20       0.058577325    0.99
25       0.074402176    0.98
30       0.059875246    0.97

Some sample recognition results: [figures omitted]

Thanks to its attention mechanism, the Encoder-Decoder structure tends to perform better on long text, but its inference is considerably slower than the CTC approach, so choosing between them is ultimately a balance of speed versus accuracy. Comments and discussion are welcome.
