CRNN代码解析

CRNN主要分为四步
1.特征提取
2.序列转换
3.执行LSTM获取序列输出
4.进行CTC转换

CRNN使用

以下代码有三个作用
1.特征提取
2.序列转换
3.执行LSTM获取序列输出

    def inference(self, inputdata, name, reuse=False):
        """
        Main routine to construct the network
        :param inputdata:
        :param name:
        :param reuse:
        :return:
        """
        with tf.variable_scope(name_or_scope=name, reuse=reuse):
            # centerlized data
            inputdata = tf.divide(inputdata, 255.0)
            #1.特征提取阶段
            # first apply the cnn feature extraction stage
            cnn_out = self._feature_sequence_extraction(
                inputdata=inputdata, name='feature_extraction_module'
            )
            #2.第二步,  batch*1*25*512  变成 batch * 25 * 512
            # second apply the map to sequence stage
            sequence = self._map_to_sequence(
                inputdata=cnn_out, name='map_to_sequence_module'
            )
            #第三步,应用序列标签阶段
            # third apply the sequence label stage
            # net_out width, batch, n_classes
            # raw_pred   width, batch, 1
            net_out, raw_pred = self._sequence_label(
                inputdata=sequence, name='sequence_rnn_module'
            )

        return net_out

以下代码进行CTC解码

        train_decoded, train_log_prob = tf.nn.ctc_beam_search_decoder(
            train_inference_ret,
            CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TRAIN.BATCH_SIZE),
            merge_repeated=False
        )

下面我们详细看看各个部分是如何实现的

1.特征提取层

使用的VGG提取特征
输入是batchsize321003 batchsizehwc
输出是batchsize125512 batchsizehwc

 def _feature_sequence_extraction(self, inputdata, name):
        """ Implements section 2.1 of the paper: "Feature Sequence Extraction"

        :param inputdata: eg. batch*32*100*3 NHWC format
        :param name:
        :return:
        _conv_stage:conv + bn + relu +  max_pool
        """
        with tf.variable_scope(name_or_scope=name):
            # batch*32*100*3
            conv1 = self._conv_stage(
                inputdata=inputdata, out_dims=64, name='conv1'
            )
            #batch*16*50*64
            conv2 = self._conv_stage(
                inputdata=conv1, out_dims=128, name='conv2'
            )
            # batch*8*25*128
            conv3 = self.conv2d(
                inputdata=conv2, out_channel=256, kernel_size=3, stride=1, use_bias=False, name='conv3'
            )
            # batch*8*25*256
            bn3 = self.layerbn(
                inputdata=conv3, is_training=self._is_training, name='bn3'
            )
            relu3 = self.relu(
                inputdata=bn3, name='relu3'
            )

            conv4 = self.conv2d(
                inputdata=relu3, out_channel=256, kernel_size=3, stride=1, use_bias=False, name='conv4'
            )
            # batch*8*25*256
            bn4 = self.layerbn(
                inputdata=conv4, is_training=self._is_training, name='bn4'
            )
            relu4 = self.relu(
                inputdata=bn4, name='relu4')


            max_pool4 = self.maxpooling(
                inputdata=relu4, kernel_size=[2, 1], stride=[2, 1], padding='VALID', name='max_pool4'
            )
            # batch*4*25*256
            conv5 = self.conv2d(
                inputdata=max_pool4, out_channel=512, kernel_size=3, stride=1, use_bias=False, name='conv5'
            )
            # batch*4*25*512
            bn5 = self.layerbn(
                inputdata=conv5, is_training=self._is_training, name='bn5'
            )
            relu5 = self.relu(
                inputdata=bn5, name='bn5'
            )
            conv6 = self.conv2d(
                inputdata=relu5, out_channel=512, kernel_size=3, stride=1, use_bias=False, name='conv6'
            )

            # batch*4*25*512
            bn6 = self.layerbn(
                inputdata=conv6, is_training=self._is_training, name='bn6'
            )

            relu6 = self.relu(
                inputdata=bn6, name='relu6'
            )

            max_pool6 = self.maxpooling(
                inputdata=relu6, kernel_size=[2, 1], stride=[2, 1], name='max_pool6'
            )
            # batch*2*25*512
            conv7 = self.conv2d(
                inputdata=max_pool6, out_channel=512, kernel_size=2, stride=[2, 1], use_bias=False, name='conv7'
            )
            # batch*1*25*512
            bn7 = self.layerbn(
                inputdata=conv7, is_training=self._is_training, name='bn7'
            )
            relu7 = self.relu(
                inputdata=bn7, name='bn7'
            )
        #return  batch*1*25*512
        return relu7

2.特征转换为序列

提取的特征LSTM不能直接使用,需要先进序列转换
输入batchsize125512
输出batchsize
25*512

 def _map_to_sequence(self, inputdata, name):
        """ Implements the map to sequence part of the network.

        This is used to convert the CNN feature map to the sequence used in the stacked LSTM layers later on.
        Note that this determines the length of the sequences that the LSTM expects
        :param inputdata:
        :param name:
        :return:
        """
        with tf.variable_scope(name_or_scope=name):

            shape = inputdata.get_shape().as_list()
            # H必须是一,这是LSTM网络输入的要求
            assert shape[1] == 1  # H of the feature map must equal to 1
            
            ret = self.squeeze(inputdata=inputdata, axis=1, name='squeeze')

        return ret

其中用到了 self.squeeze函数,我们看看它做了什么

    def squeeze(inputdata, axis=None, name=None):
        """
        :param inputdata:
        :param axis:
        :param name:
        :return:
        """
        return tf.squeeze(input=inputdata, axis=axis, name=name)

3.进行LSTM获取输出序列

输入batchsize125*512
输出width,batchn_classes

    def _sequence_label(self, inputdata, name):
        """ Implements the sequence label part of the network

        :param inputdata:
        :param name:
        :return:
        """
        with tf.variable_scope(name_or_scope=name):
            # construct stack lstm rcnn layer
            # forward lstm cell
            fw_cell_list = [tf.nn.rnn_cell.LSTMCell(nh, forget_bias=1.0) for
                            nh in [self._hidden_nums] * self._layers_nums]
            # Backward direction cells
            bw_cell_list = [tf.nn.rnn_cell.LSTMCell(nh, forget_bias=1.0) for
                            nh in [self._hidden_nums] * self._layers_nums]

            stack_lstm_layer, _, _ = rnn.stack_bidirectional_dynamic_rnn(
                fw_cell_list, bw_cell_list, inputdata,
                dtype=tf.float32
            )
            stack_lstm_layer = self.dropout(
                inputdata=stack_lstm_layer,
                keep_prob=0.5,
                is_training=self._is_training,
                name='sequence_drop_out'
            )

            [batch_s, _, hidden_nums] = inputdata.get_shape().as_list()  # [batch, width, 2*n_hidden]

            shape = tf.shape(stack_lstm_layer)
            rnn_reshaped = tf.reshape(stack_lstm_layer, [shape[0] * shape[1], shape[2]])

            w = tf.get_variable(
                name='w',
                shape=[hidden_nums, self._num_classes],
                initializer=tf.truncated_normal_initializer(stddev=0.02),
                trainable=True
            )

            # Doing the affine projection
            logits = tf.matmul(rnn_reshaped, w, name='logits')

            logits = tf.reshape(logits, [shape[0], shape[1], self._num_classes], name='logits_reshape')

            raw_pred = tf.argmax(tf.nn.softmax(logits), axis=2, name='raw_prediction')

            # Swap batch and batch axis
            rnn_out = tf.transpose(logits, [1, 0, 2], name='transpose_time_major')  # [width, batch, n_classes]

        return rnn_out, raw_pred
OCR技术是一种能够将图像中的文本内容转化为可编辑文本的技术,其中ctpn和crnn是OCR技术中的两个重要组成部分。 ctpn(Connectionist Text Proposal Network)是一种基于深度学习的文本检测算法,其主要任务是检测图像中的文本行和单个字符,并将其转换为一组矩形边界框。这些边界框可以用于后续的文本识别操作。 crnn(Convolutional Recurrent Neural Network)是一种基于深度学习的文本识别算法,其主要任务是根据文本检测阶段生成的文本行或单个字符图像,识别其中的文本内容。crnn算法通常由卷积神经网络(CNN)和循环神经网络(RNN)两个部分组成,其中CNN用于提取图像特征,RNN用于对特征序列进行建模。 以下是一个基于ctpn和crnn的OCR代码实现示例(Python): ```python import cv2 import numpy as np import tensorflow as tf # 加载ctpn模型 ctpn_model = cv2.dnn.readNet('ctpn.pb') # 加载crnn模型 crnn_model = tf.keras.models.load_model('crnn.h5') # 定义字符集 charset = '0123456789abcdefghijklmnopqrstuvwxyz' # 定义字符到索引的映射表 char_to_index = {char: index for index, char in enumerate(charset)} # 定义CTPN参数 ctpn_params = { 'model': 'ctpn', 'scale': 600, 'max_scale': 1200, 'text_proposals': 2000, 'min_size': 16, 'line_min_score': 0.9, 'text_proposal_min_score': 0.7, 'text_proposal_nms_threshold': 0.3, 'min_num_proposals': 2, 'max_num_proposals': 10 } # 定义CRNN参数 crnn_params = { 'model': 'crnn', 'img_w': 100, 'img_h': 32, 'num_classes': len(charset), 'rnn_units': 128, 'rnn_dropout': 0.25, 'rnn_recurrent_dropout': 0.25, 'rnn_activation': 'relu', 'rnn_type': 'lstm', 'rnn_direction': 'bidirectional', 'rnn_merge_mode': 'concat', 'cnn_filters': 32, 'cnn_kernel_size': (3, 3), 'cnn_activation': 'relu', 'cnn_pool_size': (2, 2) } # 定义文本检测函数 def detect_text(image): # 将图像转换为灰度图 gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # 缩放图像 scale = ctpn_params['scale'] max_scale = ctpn_params['max_scale'] if np.max(gray) > 1: gray = gray / 255 rows, cols = gray.shape if rows > max_scale: scale = max_scale / rows gray = cv2.resize(gray, (0, 0), fx=scale, fy=scale) rows, cols = gray.shape elif rows < scale: scale = scale / rows gray = cv2.resize(gray, (0, 0), fx=scale, fy=scale) rows, cols = gray.shape # 文本检测 ctpn_model.setInput(cv2.dnn.blobFromImage(gray)) output = ctpn_model.forward() boxes = [] for i in range(output.shape[2]): score = output[0, 0, i, 2] if score > ctpn_params['text_proposal_min_score']: x1 = int(output[0, 0, i, 3] * cols / scale) y1 = int(output[0, 0, i, 4] * rows / scale) x2 = int(output[0, 0, i, 5] * cols / scale) y2 = int(output[0, 0, i, 6] * rows / scale) boxes.append([x1, y1, x2, y2]) # 合并重叠的文本框 boxes = cv2.dnn.NMSBoxes(boxes, output[:, :, :, 2], ctpn_params['text_proposal_min_score'], ctpn_params['text_proposal_nms_threshold']) # 提取文本行图像 lines = [] for i in boxes: i = i[0] x1, y1, x2, y2 = boxes[i] line = gray[y1:y2, x1:x2] lines.append(line) return lines # 定义文本识别函数 def recognize_text(image): # 缩放图像 img_w, img_h = crnn_params['img_w'], crnn_params['img_h'] image = cv2.resize(image, (img_w, img_h)) # 归一化图像 if np.max(image) > 1: image = image / 255 # 转换图像格式 image = image.transpose([1, 0, 2]) image = np.expand_dims(image, axis=0) # 预测文本 y_pred = crnn_model.predict(image) y_pred = np.argmax(y_pred, axis=2)[0] # 将预测结果转换为文本 text = '' for i in y_pred: if i != len(charset) - 1 and (not (len(text) > 0 and text[-1] == charset[i])): text += charset[i] return text # 读取图像 image = cv2.imread('test.png') # 检测文本行 lines = detect_text(image) # 识别文本 texts = [] for line in lines: text = recognize_text(line) texts.append(text) # 输出识别结果 print(texts) ``` 上述代码实现了一个基于ctpn和crnn的OCR系统,其中ctpn用于检测文本行,crnn用于识别文本内容。在使用代码时,需要将ctpn和crnn的模型文件替换为自己训练的模型文件,并根据实际情况调整参数。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值