该问题可以参考https://blog.csdn.net/u014484783/article/details/88849971中的描述,无论是CTC_decode还是CTC_loss,每次运行都会创建节点,避免的方法是将其封装到model中,这样就固定了计算节点。
测试方法: 在初始化节点后(注意是在运行fit/predict至少一次后,因为这些方法也会更改计算图状态),运行K.get_session().graph.finalize()
锁定节点,此时如果图节点变了会报错并提示出错代码。
from keras import backend as K
from keras.layers import Lambda,Input
from keras import Model
from tensorflow.python.ops import ctc_ops as ctc
import tensorflow as tf
from keras.layers import Layer
class CTC_Batch_Cost():
'''
用于计算CTC loss
'''
def ctc_lambda_func(self,args):
"""Runs CTC loss algorithm on each batch element.
# Arguments
y_true: tensor `(samples, max_string_length)` 真实标签
y_pred: tensor `(samples, time_steps, num_categories)` 预测前未经过softmax的向量
input_length: tensor `(samples, 1)` 每一个y_pred的长度
label_length: tensor `(samples, 1)` 每一个y_true的长度
# Returns
Tensor with shape (samples,1) 包含了每一个样本的ctc loss
"""
y_true, y_pred, input_length, label_length = args
# y_pred = y_pred[:, :, :]
# y_pred = y_pred[:, 2:, :]
return self.ctc_batch_cost(y_true, y_pred, input_length, label_length)
def __call__(self, args):
'''
ctc_decode 每次创建会生成一个节点,这里参考了https://blog.csdn.net/u014484783/article/details/88849971
将ctc封装成模型,是否会解决这个问题还没有测试过这种方法是否还会出现创建节点的问题
'''
y_true = Input(shape=(None,))
y_pred = Input(shape=(None,None))
input_length