CRNN学习笔记

 

最近学习了CRNN网络,大体训练流程如下:

1、准备输入数据和标签,标签为稀疏矩阵

inputs = tf.placeholder(tf.float32, [batch_size, input_height, input_width, 1])

label = tf.sparse_placeholder(tf.int32, name='label')

seq_len = tf.placeholder(tf.int32, [None], name='seq_len') 

2、通过CNN网络提取特征

cnn_out = self._cnn(inputs)

3、通过2次双向RNN,得到神经网络输出结果

crnn_model = self._rnn(cnn_out, self._seq_len)

4、根据最终字符的类别得到最终的输出

logits = tf.reshape(crnn_model, [-1, 512])

W = tf.Variable(tf.truncated_normal([512, self._class_num], stddev=0.1), name="W")
b = tf.Variable(tf.constant(0., shape=[self._class_num]), name="b")

logits = tf.matmul(logits, W) + b

logits = tf.reshape(logits, [self._batch_size, -1, self._class_num])

        # 网络层输出
net_output = tf.transpose(logits, (1, 0, 2))

5、解析网络输出,其中decoded[0]是一个稀疏张量,类型和label一样

decoded, log_prob = tf.nn.ctc_greedy_decoder(net_output, self._seq_len)

6、损失函数loss

with tf.name_scope('loss'):
    loss = tf.nn.ctc_loss(self._label, self._net_output, self._seq_len)
    loss = tf.reduce_mean(loss)

7、优化器optimizer

with tf.name_scope('optimizer'):
    train_op = tf.train.AdamOptimizer(self._learning_rate).minimize(loss)

8、准确率accuracy

with tf.name_scope('accuracy'):
    accuracy = 1 - tf.reduce_mean(tf.edit_distance(tf.cast(self._decoded[0], tf.int32), self._label))
    accuracy_broad = tf.summary.scalar("accuracy", accuracy)

9、喂数据进行训练

feed_dict = {self._inputs: batch_data,self._label: batch_label, \
self._seq_len: [self._max_char_count] * self.batch_size}
sess.run(train_op, feed_dict=feed_dict)

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值