CRNN模型Python实现笔记三

一、函数讲解

1. numel()函数

在Pytorch中, numel函数是torch.Tensor类的一个方法,它可以返回张量中的元素总数。
例如:

import torch

x = torch.randn(2, 3)
print(x.numel())

这将输出6,因为x有2行3列,总共有6个元素。

值得注意的是,torch.Tensor.nelement()也和torch.Tensor.numel()做同样的事情,所以您可以使用任何一个。

二、疑难代码段理解

1. strLabelConverter

# copy from utils
class strLabelConverter(object):
    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '_'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    # print(self.dict)
    def encode(self, text):
        length = []
        result = []
        for item in text:
            item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                if char not in self.dict.keys():
                    index = 0
                else:
                    index = self.dict[char]
                result.append(index)
        text = result
        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
                                                                                                         length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts

这段代码是定义了一个 strLabelConverter 类,它主要用于文本编码和解码。

类中定义了三个函数:

  • init(self, alphabet, ignore_case=False):初始化函数,用于创建一个 strLabelConverter 实例。alphabet 是字符集,ignore_case=False 表示是否忽略大小写。

  • encode(self, text):将文本编码为整数序列。

  • decode(self, t, length, raw=False):将整数序列解码为文本。

这个类中使用了字典 self.dict 和字符集 self.alphabet 将文本编码为整数序列和将整数序列解码为文本。

其中, self 是类中所有函数的第一个参数,它代表的是类的实例本身,在类的函数中可以通过 self 引用类的其它成员。

(1) def encode(self, text):函数的作用

这段代码是在strLabelConverter类中encode函数中,主要用于将文本编码为整数序列。

item = item.decode('utf-8', 'strict')
Python中字符串默认是以unicode编码的,如果字符串是以其他编码格式存储的,那么就需要使用decode()函数进行解码。

decode()函数接受两个参数,第一个参数是解码的编码格式,第二个参数是非法字节序列的处理方式,'strict’表示如果遇到非法字节序列将会抛出一个UnicodeDecodeError异常。

第二个参数的值有’strict’,‘ignore’,‘replace’,‘backslashreplace’ 四种。 'strict' 表示遇到非法字节序列会抛出异常, 'ignore' 表示忽略非法字节序列,'replace' 表示替换非法字节序列,'backslashreplace' 表示反斜杠替换。

首先,定义了两个空的列表lengthresult。然后遍历text中的每一项,对每一项进行decode('utf-8', 'strict')操作,将其解码为utf-8格式。接着将每一项的长度添加到length列表中。

接下来,对每一项中的每一个字符进行遍历,如果字符不在self.dict.keys()中,将其索引赋值为0;否则,将其索引赋值为self.dict[char]。最后把索引添加到result列表中。

最后,将result赋值给text,并用torch.IntTensor()textlength转换为张量后返回。

总之,这个函数的作用是将文本中的每个字符转换为相应的索引,并返回索引和文本长度的张量。

(2) def decode(self, text):函数的作用

raw=False有什么用?

raw=False表示在解码时,只保留非重复字符。

strLabelConverter类中decode函数中,如果raw参数为False,那么对于t中的每一个整数序列,如果不为0并且不是重复的,那么将其对应的字符添加到char_list中,最后将char_list拼接成字符串并返回。

这样的做法的意图是去除重复的字符,在实际应用中可能用于去除预测结果中的重复字符,如果raw=True则会返回所有字符。

(3) 关于函数定义中self的疑惑

python在类中定义函数为什么参数列表里面都有self?

在 Python 中,类中定义函数的参数列表里面加入 self 是为了让函数能够访问类的其它成员,也是一种约定俗成的用法。

self 是类中所有函数的第一个参数,它代表的是类的实例本身,在类的函数中可以通过 self 引用类的其它成员。

当一个类的实例调用一个函数时,系统会自动传入这个实例作为第一个参数,这样函数就能访问类的其它成员。

简而言之,self 参数起到类和函数之间连接的作用,让函数能够访问类。




三、附录crnn_recognizer.py

import torch.nn as nn
# import torchvision.models as models
import torch, os
from PIL import Image
import cv2
import torchvision.transforms as transforms
from torch.autograd import Variable
import numpy as np
import random
from crnn import CRNN
import config

# copy from mydataset
class resizeNormalize(object):
    def __init__(self, size, interpolation=Image.LANCZOS, is_test=True):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()
        self.is_test = is_test

    def __call__(self, img):
        w, h = self.size
        w0 = img.size[0]
        h0 = img.size[1]
        if w <= (w0 / h0 * h):
            img = img.resize(self.size, self.interpolation)
            img = self.toTensor(img)
            img.sub_(0.5).div_(0.5)
        else:
            w_real = int(w0 / h0 * h)
            img = img.resize((w_real, h), self.interpolation)
            img = self.toTensor(img)
            img.sub_(0.5).div_(0.5)
            tmp = torch.zeros([img.shape[0], h, w])
            start = random.randint(0, w - w_real - 1)
            if self.is_test:
                start = 0
            tmp[:, :, start:start + w_real] = img
            img = tmp
        return img

# copy from utils
class strLabelConverter(object):
    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '_'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    # print(self.dict)
    def encode(self, text):
        length = []
        result = []
        for item in text:
            item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                if char not in self.dict.keys():
                    index = 0
                else:
                    index = self.dict[char]
                result.append(index)
        text = result
        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
                                                                                                         length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts

# recognize api
class PytorchOcr():
    def __init__(self, model_path):
        alphabet_unicode = config.alphabet_v2
        self.alphabet = ''.join([chr(uni) for uni in alphabet_unicode])
        # print(len(self.alphabet))
        self.nclass = len(self.alphabet) + 1
        self.model = CRNN(config.imgH, 1, self.nclass, 256)
        self.cuda = False
        if torch.cuda.is_available():
            self.cuda = True
            self.model.cuda()
            self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path).items()})
        else:
            # self.model = nn.DataParallel(self.model)
            self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        self.model.eval()
        self.converter = strLabelConverter(self.alphabet)

    def recognize(self, img):
        h,w = img.shape[:2]
        if len(img.shape) == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        image = Image.fromarray(img)
        transformer = resizeNormalize((int(w/h*32), 32))
        image = transformer(image)
        image = image.view(1, *image.size())
        image = Variable(image)

        if self.cuda:
            image = image.cuda()

        preds = self.model(image)

        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)

        preds_size = Variable(torch.IntTensor([preds.size(0)]))
        txt = self.converter.decode(preds.data, preds_size.data, raw=False)

        return txt


if __name__ == '__main__':
    model_path = './crnn_models/CRNN-1008.pth'
    recognizer = PytorchOcr(model_path)
    img_name = 't1.jpg'
    img = cv2.imread(img_name)
    h, w = img.shape[:2]
    res = recognizer.recognize(img)
    print(res)


  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
CRNN是一种深度学习模型,结合了卷积神经网络(CNN)和循环神经网络(RNN),用于图像文本识别。下面是一个使用Keras实现CRNN的示例代码: ```python from keras.models import Sequential, Model from keras.layers import Input, Dense, Dropout, Activation, Flatten, Reshape, Permute from keras.layers.convolutional import Conv2D, MaxPooling2D from keras.layers.recurrent import GRU from keras.layers.wrappers import TimeDistributed from keras.layers.normalization import BatchNormalization # 定义模型输入形状和参数 input_shape = (32, 280, 1) num_classes = 10 max_text_len = 32 # 定义CNN层 def cnn_layers(input_shape): # 定义模型 model = Sequential() # 第一层卷积 model.add(Conv2D(64, (3, 3), activation='relu', padding='same', name='conv1', input_shape=input_shape)) model.add(MaxPooling2D(pool_size=(2, 2), name='pool1')) # 第二层卷积 model.add(Conv2D(128, (3, 3), activation='relu', padding='same', name='conv2')) model.add(MaxPooling2D(pool_size=(2, 2), name='pool2')) # 第层卷积 model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='conv3')) # 第四层卷积 model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='conv4')) model.add(MaxPooling2D(pool_size=(1, 2), name='pool3')) # 第五层卷积 model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='conv5')) model.add(BatchNormalization(name='batchnorm1')) # 第六层卷积 model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='conv6')) model.add(BatchNormalization(name='batchnorm2')) model.add(MaxPooling2D(pool_size=(1, 2), name='pool4')) # 输出CNN结果 model.add(Permute((2, 1, 3), name='permute')) model.add(TimeDistributed(Flatten(), name='timedistrib')) return model # 定义CRNN模型 def create_crnn(input_shape, num_classes, max_text_len): # 定义CNN层 cnn = cnn_layers(input_shape) # 定义RNNrnn = Sequential() rnn.add(GRU(256, return_sequences=True, name='gru1')) rnn.add(GRU(256, return_sequences=True, name='gru2')) rnn.add(Dropout(0.25, name='dropout')) # 定义最终输出层 input_data = Input(name='the_input', shape=input_shape, dtype='float32') inner = cnn(input_data) inner = rnn(inner) y_pred = Dense(num_classes, activation='softmax', name='dense')(inner) model = Model(inputs=input_data, outputs=y_pred) # 定义模型输出形状 model.output_length = lambda x: cnn.output_shape[1] return model # 创建CRNN模型 crnn_model = create_crnn(input_shape, num_classes, max_text_len) ``` 以上代码中,`cnn_layers`函数定义了CNN层,`create_crnn`函数定义了CRNN模型,包括CNN层、RNN层和最终输出层。`input_shape`参数指定了输入图像的形状,`num_classes`参数指定了输出类别数,`max_text_len`参数指定了输出文本的最大长度。最后,使用`create_crnn`函数创建CRNN模型并保存在`crnn_model`变量中。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

No_one-_-2022

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值