查看下面的一段代码内容
from keras.layers import *
import tensorflow.keras as keras
import tensorflow.keras.backend as K
#from tensorflow.backend import keras,K
import tensorflow as tf
class Loss(Layer):
    """Special layer used to define complex losses.

    The layer computes a loss from its inputs through ``compute_loss``,
    registers it with ``add_loss``, and then returns either all inputs or
    the subset selected by ``output_axis``.
    """

    def __init__(self, output_axis=None, **kwargs):
        """Remember which input(s) the layer should return.

        Args:
            output_axis: ``None`` to return the inputs unchanged, a list of
                indices to return several inputs, or a single index to
                return exactly one input.
            **kwargs: Forwarded to the base ``Layer`` constructor.
        """
        super().__init__(**kwargs)
        self.output_axis = output_axis

    def call(self, inputs, mask=None):
        """Register the loss, then pass through the selected inputs."""
        loss_value = self.compute_loss(inputs, mask)
        # NOTE(review): `inputs=` is a legacy keyword of `add_loss`; confirm
        # that the installed Keras version still accepts it.
        self.add_loss(loss_value, inputs=inputs)

        axis = self.output_axis
        if axis is None:
            return inputs
        if isinstance(axis, list):
            return [inputs[idx] for idx in axis]
        return inputs[axis]

    def compute_loss(self, inputs, mask=None):
        """Return the loss for ``inputs``; subclasses must override this."""
        raise NotImplementedError

    def compute_output_shape(self, input_shape):
        """Select output shapes with the same rule ``call`` uses for inputs."""
        axis = self.output_axis
        if axis is None:
            return input_shape
        if isinstance(axis, list):
            return [input_shape[idx] for idx in axis]
        return input_shape[axis]

    def compute_mask(self, inputs, mask):
        """Select masks with the same rule as ``call``; ``None`` stays ``None``."""
        if mask is None:
            return None

        axis = self.output_axis
        if axis is None:
            return mask
        if isinstance(axis, list):
            return [mask[idx] for idx in axis]
        return mask[axis]

    def get_config(self):
        """Return the base layer config extended with ``output_axis``."""
        merged = dict(super().get_config())
        merged.update({'output_axis': self.output_axis})
        return merged
class CrossEntropy(Loss):
    """Cross-entropy loss that masks out the input portion."""

    def compute_loss(self, inputs, mask=None):
        """Return the masked mean of the per-position cross-entropy.

        Args:
            inputs: A sequence of three tensors, unpacked in the order
                ``(y_true, y_mask, y_pred)``.
            mask: Unused here.

        Returns:
            The sum of the masked per-position losses divided by the sum of
            the mask.
        """
        targets, weights, preds = inputs
        targets = targets[:, 1:]   # target token ids
        weights = weights[:, 1:]   # segment ids, which mark the part to predict
        preds = preds[:, :-1]      # predicted sequence, shifted by one position
        per_position = K.sparse_categorical_crossentropy(targets, preds)
        return K.sum(per_position * weights) / K.sum(weights)
# Toy inputs: data1 (shape [1, 2]) is used as both y_true and y_mask,
# data2 (shape [1, 2, 5]) is used as y_pred.
data1 = tf.ones(shape=[1, 2], dtype=tf.float32)
data2 = tf.ones(shape=[1, 2, 5], dtype=tf.float32)
result = [data1, data1, data2]
# Build the layer with output_axis=2, then call it on the input list.
output = CrossEntropy(2)(result)
这里关键讲解的语句内容为
output = CrossEntropy(2)(result)
这里的 2 是传给 CrossEntropy 构造函数的参数。CrossEntropy 没有定义自己的 __init__,因此这个参数由其父类 Loss 的 __init__ 接收,对应 output_axis:
class Loss(Layer):
    """Special layer used to define complex losses."""

    def __init__(self, output_axis=None, **kwargs):
        # Remaining keyword arguments go to the base Layer constructor.
        super().__init__(**kwargs)
        # Selects which input(s) `call` returns.
        self.output_axis = output_axis
这样 output_axis 就被赋值为 2。接着,把 result 作为输入传给这个层对象(即 CrossEntropy(2)(result) 中的第二个括号),层对象被调用时会进而执行对应的 call 函数:
def call(self, inputs, mask=None):
    """Register the loss, then return the inputs picked by ``output_axis``."""
    loss_value = self.compute_loss(inputs, mask)
    self.add_loss(loss_value, inputs=inputs)

    axis = self.output_axis
    if axis is None:
        # No selection: pass every input through unchanged.
        return inputs
    if isinstance(axis, list):
        # Several indices: return the matching inputs as a list.
        return [inputs[idx] for idx in axis]
    # Single index: return just that input.
    return inputs[axis]
这里的 inputs 就是传入的 result 列表。所以第一个括号里的 2 是初始化时传入的参数(output_axis),第二个括号里的 result 是调用 call 时传入的输入列表;由于 output_axis 为 2,call 最终返回 inputs[2],也就是 data2。