CRNN主要分为四步
1.特征提取
2.序列转换
3.执行LSTM获取序列输出
4.进行CTC转换
CRNN的使用
以下代码依次完成前三步的工作:
1.特征提取
2.序列转换
3.执行LSTM获取序列输出
def inference(self, inputdata, name, reuse=False):
    """ Build the CRNN forward graph.

    Chains the three stages: CNN feature extraction, map-to-sequence,
    and the stacked bidirectional LSTM sequence labelling.
    :param inputdata: input image tensor (pixel values in [0, 255])
    :param name: variable scope name for the whole network
    :param reuse: whether to reuse variables inside the scope
    :return: time-major logits tensor [width, batch, n_classes]
    """
    with tf.variable_scope(name_or_scope=name, reuse=reuse):
        # Scale pixels into [0, 1] before feeding the CNN.
        normalized = tf.divide(inputdata, 255.0)

        # Stage 1: CNN feature extraction.
        feature_map = self._feature_sequence_extraction(
            inputdata=normalized, name='feature_extraction_module'
        )

        # Stage 2: collapse the height axis, e.g. batch*1*25*512 -> batch*25*512.
        feature_sequence = self._map_to_sequence(
            inputdata=feature_map, name='map_to_sequence_module'
        )

        # Stage 3: sequence labelling.
        # net_out: [width, batch, n_classes]; the per-step raw prediction
        # is discarded here.
        net_out, _ = self._sequence_label(
            inputdata=feature_sequence, name='sequence_rnn_module'
        )
        return net_out
以下代码进行CTC解码
# CTC beam-search decoding of the time-major logits.
# sequence_length must be an int32 vector; np.ones defaults to float64,
# which tf.nn.ctc_beam_search_decoder rejects, so pin the dtype.
train_decoded, train_log_prob = tf.nn.ctc_beam_search_decoder(
    train_inference_ret,
    CFG.ARCH.SEQ_LENGTH * np.ones(CFG.TRAIN.BATCH_SIZE, dtype=np.int32),
    merge_repeated=False
)
下面我们详细看看各个部分是如何实现的
1.特征提取层
使用VGG风格的网络提取特征
输入是 batchsize×32×100×3,即 (batchsize, h, w, c)
输出是 batchsize×1×25×512,即 (batchsize, h, w, c)
def _feature_sequence_extraction(self, inputdata, name):
    """ Implements section 2.1 of the paper: "Feature Sequence Extraction".

    VGG-style CNN backbone that reduces the input image to a feature map
    of height 1, so that each column can later be read as one time step.
    _conv_stage = conv + bn + relu + max_pool.
    :param inputdata: input image tensor, e.g. batch*32*100*3 (NHWC)
    :param name: variable scope name
    :return: feature map tensor, e.g. batch*1*25*512
    """
    with tf.variable_scope(name_or_scope=name):
        # batch*32*100*3 -> batch*16*50*64
        conv1 = self._conv_stage(
            inputdata=inputdata, out_dims=64, name='conv1'
        )
        # batch*16*50*64 -> batch*8*25*128
        conv2 = self._conv_stage(
            inputdata=conv1, out_dims=128, name='conv2'
        )
        # batch*8*25*128 -> batch*8*25*256
        conv3 = self.conv2d(
            inputdata=conv2, out_channel=256, kernel_size=3, stride=1, use_bias=False, name='conv3'
        )
        bn3 = self.layerbn(
            inputdata=conv3, is_training=self._is_training, name='bn3'
        )
        relu3 = self.relu(
            inputdata=bn3, name='relu3'
        )
        conv4 = self.conv2d(
            inputdata=relu3, out_channel=256, kernel_size=3, stride=1, use_bias=False, name='conv4'
        )
        # batch*8*25*256
        bn4 = self.layerbn(
            inputdata=conv4, is_training=self._is_training, name='bn4'
        )
        relu4 = self.relu(
            inputdata=bn4, name='relu4')
        # Pool only along the height axis to keep the sequence width at 25.
        max_pool4 = self.maxpooling(
            inputdata=relu4, kernel_size=[2, 1], stride=[2, 1], padding='VALID', name='max_pool4'
        )
        # batch*4*25*256 -> batch*4*25*512
        conv5 = self.conv2d(
            inputdata=max_pool4, out_channel=512, kernel_size=3, stride=1, use_bias=False, name='conv5'
        )
        bn5 = self.layerbn(
            inputdata=conv5, is_training=self._is_training, name='bn5'
        )
        # BUGFIX: op was mis-named 'bn5', colliding with the batch-norm
        # above (TF uniquified it to 'bn5_1'). ReLU ops have no variables,
        # so renaming is checkpoint-safe.
        relu5 = self.relu(
            inputdata=bn5, name='relu5'
        )
        conv6 = self.conv2d(
            inputdata=relu5, out_channel=512, kernel_size=3, stride=1, use_bias=False, name='conv6'
        )
        # batch*4*25*512
        bn6 = self.layerbn(
            inputdata=conv6, is_training=self._is_training, name='bn6'
        )
        relu6 = self.relu(
            inputdata=bn6, name='relu6'
        )
        # Height-only pooling again: batch*4*25*512 -> batch*2*25*512
        max_pool6 = self.maxpooling(
            inputdata=relu6, kernel_size=[2, 1], stride=[2, 1], name='max_pool6'
        )
        # Final conv collapses the remaining height: batch*2*25*512 -> batch*1*25*512
        conv7 = self.conv2d(
            inputdata=max_pool6, out_channel=512, kernel_size=2, stride=[2, 1], use_bias=False, name='conv7'
        )
        bn7 = self.layerbn(
            inputdata=conv7, is_training=self._is_training, name='bn7'
        )
        # BUGFIX: op was mis-named 'bn7' (same collision as relu5 above).
        relu7 = self.relu(
            inputdata=bn7, name='relu7'
        )
        # batch*1*25*512
        return relu7
2.特征转换为序列
提取的特征LSTM不能直接使用,需要先进行序列转换
输入 batchsize×1×25×512
输出 batchsize×25×512
def _map_to_sequence(self, inputdata, name):
    """ Implements the map-to-sequence part of the network.

    Converts the CNN feature map to the sequence consumed by the stacked
    LSTM layers later on; this fixes the sequence length the LSTM expects.
    :param inputdata: feature map tensor of shape batch*1*W*C
    :param name: variable scope name
    :return: tensor of shape batch*W*C with the height axis removed
    :raises ValueError: if the feature-map height is not 1
    """
    with tf.variable_scope(name_or_scope=name):
        shape = inputdata.get_shape().as_list()
        # The LSTM consumes one feature vector per image column, so the
        # feature-map height must already be collapsed to 1.
        # Raise instead of assert: asserts are stripped under `python -O`.
        if shape[1] != 1:
            raise ValueError(
                'feature map height must be 1, got {}'.format(shape[1])
            )
        return self.squeeze(inputdata=inputdata, axis=1, name='squeeze')
其中用到了 self.squeeze函数,我们看看它做了什么
def squeeze(inputdata, axis=None, name=None):
    """ Thin wrapper around tf.squeeze.

    Removes size-1 dimensions from the shape of *inputdata*.
    :param inputdata: tensor to squeeze
    :param axis: axis (or axes) of size 1 to remove; None removes all
    :param name: optional op name
    :return: the squeezed tensor
    """
    squeezed = tf.squeeze(input=inputdata, axis=axis, name=name)
    return squeezed
3.进行LSTM获取输出序列
输入 batchsize×25×512
输出 width×batch×n_classes
def _sequence_label(self, inputdata, name):
    """ Implements the sequence label part of the network.

    Runs a stacked bidirectional LSTM over the feature sequence, projects
    every time step onto the character classes with a single shared weight
    matrix, and returns time-major logits (for CTC) plus a greedy per-step
    prediction.
    :param inputdata: sequence tensor of shape [batch, width, channels]
    :param name: variable scope name
    :return: (rnn_out, raw_pred) — rnn_out is [width, batch, n_classes]
        time-major logits; raw_pred is the argmax class index per step
    """
    with tf.variable_scope(name_or_scope=name):
        # construct stack lstm rcnn layer
        # forward lstm cell
        fw_cell_list = [tf.nn.rnn_cell.LSTMCell(nh, forget_bias=1.0) for
                        nh in [self._hidden_nums] * self._layers_nums]
        # Backward direction cells
        bw_cell_list = [tf.nn.rnn_cell.LSTMCell(nh, forget_bias=1.0) for
                        nh in [self._hidden_nums] * self._layers_nums]
        stack_lstm_layer, _, _ = rnn.stack_bidirectional_dynamic_rnn(
            fw_cell_list, bw_cell_list, inputdata,
            dtype=tf.float32
        )
        # Dropout is gated on the training flag (no-op at inference).
        stack_lstm_layer = self.dropout(
            inputdata=stack_lstm_layer,
            keep_prob=0.5,
            is_training=self._is_training,
            name='sequence_drop_out'
        )
        # NOTE(review): hidden_nums is read from the *input* tensor's last
        # dim, but the matmul below consumes the LSTM output, whose last
        # dim is 2*self._hidden_nums. This only lines up if the input
        # channel count equals 2*self._hidden_nums (e.g. 512 == 2*256) —
        # confirm against the class configuration.
        [batch_s, _, hidden_nums] = inputdata.get_shape().as_list()  # [batch, width, 2*n_hidden]
        shape = tf.shape(stack_lstm_layer)
        # Flatten (batch, width) so one matmul projects every time step.
        rnn_reshaped = tf.reshape(stack_lstm_layer, [shape[0] * shape[1], shape[2]])
        w = tf.get_variable(
            name='w',
            shape=[hidden_nums, self._num_classes],
            initializer=tf.truncated_normal_initializer(stddev=0.02),
            trainable=True
        )
        # Doing the affine projection (note: no bias term)
        logits = tf.matmul(rnn_reshaped, w, name='logits')
        logits = tf.reshape(logits, [shape[0], shape[1], self._num_classes], name='logits_reshape')
        # Greedy per-time-step class prediction (not CTC-decoded).
        raw_pred = tf.argmax(tf.nn.softmax(logits), axis=2, name='raw_prediction')
        # Swap batch and time axes: the CTC ops expect time-major input.
        rnn_out = tf.transpose(logits, [1, 0, 2], name='transpose_time_major')  # [width, batch, n_classes]
        return rnn_out, raw_pred