CTC_loss和CTC_decode的模型封装代码避免节点不断增加

为解决CTC_loss和CTC_decode每次运行导致计算图节点不断增多的问题,可以将它们封装到Keras或TensorFlow模型中。如此一来,计算节点数量将保持固定。一种测试方法是初始化模型并至少运行一次fit或predict,之后尝试锁定节点,如果节点变化会引发错误并显示错误代码。
摘要由CSDN通过智能技术生成

该问题可以参考https://blog.csdn.net/u014484783/article/details/88849971中的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。

测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()锁定节点,此时如果图节点变了会报错并提示出错代码。

from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
    '''
    用于计算CTC loss
    '''
    def ctc_lambda_func(self,args):
        """Runs CTC loss algorithm on each batch element.

        # Arguments
            y_true: tensor `(samples, max_string_length)` 真实标签
            y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
            input_length: tensor `(samples, 1)` 每一个y_pred的长度
            label_length: tensor `(samples, 1)` 每一个y_true的长度

            # Returns
                Tensor with shape (samples,1) 包含了每一个样本的ctc loss
            """
        y_true, y_pred, input_length, label_length = args

        # y_pred = y_pred[:, :, :]
        # y_pred = y_pred[:, 2:, :]
        return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)

    def __call__(self, args):
        '''
        ctc_decode 每次创建会生成一个节点,这里参考了https://blog.csdn.net/u014484783/article/details/88849971
        将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
        '''
        y_true = Input(shape=(None,))
        y_pred = Input(shape=(None,None))
        input_length 
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值