类的初始化参数与调用参数(`__init__` 与 `__call__`)的代码例子

查看下面的一段代码内容

from keras.layers import *
import tensorflow.keras as keras
import tensorflow.keras.backend as K
#from tensorflow.backend import keras,K
import tensorflow as tf
class Loss(Layer):
    """Special layer used to define a complex loss.

    Subclasses implement ``compute_loss``; this layer registers the
    computed loss via ``add_loss`` and then forwards the subset of its
    inputs selected by ``output_axis`` (None = all, int = one item,
    list = several items).
    """
    def __init__(self, output_axis=None, **kwargs):
        super(Loss, self).__init__(**kwargs)
        self.output_axis = output_axis

    def _select(self, items):
        # Apply the output_axis selection rule to any indexable `items`
        # (inputs, shapes or masks all use the same rule).
        axis = self.output_axis
        if axis is None:
            return items
        if isinstance(axis, list):
            return [items[i] for i in axis]
        return items[axis]

    def call(self, inputs, mask=None):
        loss = self.compute_loss(inputs, mask)
        self.add_loss(loss, inputs=inputs)
        return self._select(inputs)

    def compute_loss(self, inputs, mask=None):
        # Must be provided by subclasses.
        raise NotImplementedError

    def compute_output_shape(self, input_shape):
        return self._select(input_shape)

    def compute_mask(self, inputs, mask):
        # Returns None implicitly when no mask is given, matching the
        # Keras convention for mask propagation.
        if mask is not None:
            return self._select(mask)

    def get_config(self):
        merged = dict(super(Loss, self).get_config())
        merged['output_axis'] = self.output_axis
        return merged

class CrossEntropy(Loss):
    """Cross-entropy loss that masks out the input (non-target) part.

    Expects ``inputs = [y_true, y_mask, y_pred]`` where y_mask are the
    segment ids marking exactly the positions to be predicted.
    """
    def compute_loss(self, inputs, mask=None):
        y_true, y_mask, y_pred = inputs
        # Shift by one position: token t is predicted from tokens < t.
        y_true = y_true[:, 1:]   # target token_ids
        y_mask = y_mask[:, 1:]   # segment_ids, marks positions to predict
        y_pred = y_pred[:, :-1]  # predictions, offset by one step
        per_token = K.sparse_categorical_crossentropy(y_true, y_pred)
        # Average only over the masked-in (to-be-predicted) positions.
        return K.sum(per_token * y_mask) / K.sum(y_mask)

# Toy inputs: two rank-2 tensors of shape (1, 2) standing in for token
# ids / segment ids, and a rank-3 tensor of shape (1, 2, 5) standing in
# for per-token class probabilities.
data1 = tf.ones([1, 2], dtype=tf.float32)
data2 = tf.ones([1, 2, 5], dtype=tf.float32)
result = [data1, data1, data2]
# output_axis=2: the layer registers the loss and returns result[2].
output = CrossEntropy(2)(result)

这里关键讲解的语句内容为

output = CrossEntropy(2)(result)

这里面的 2,对应的是 CrossEntropy 的父类 Loss 在初始化时接收的参数 output_axis

class Loss(Layer):
    """Special layer used to define a complex loss.
    """
    def __init__(self, output_axis=None, **kwargs):
        super(Loss, self).__init__(**kwargs)
        self.output_axis = output_axis  # which input(s) call() will forward

将相应的 output_axis 赋值为 2。接着,由于把输入参数 result 传给了该层实例(触发 Layer 的 `__call__`),进而调用对应的 call 函数

def call(self, inputs, mask=None):
    """Compute the loss, register it on the layer, then forward the
    input(s) selected by self.output_axis."""
    loss = self.compute_loss(inputs, mask)
    self.add_loss(loss, inputs=inputs)
    if self.output_axis is None:
        return inputs
    elif isinstance(self.output_axis, list):
        return [inputs[i] for i in self.output_axis]
    else:
        return inputs[self.output_axis]

这里的 inputs 接收的就是前面的 result 内容。也就是说,第一个括号中的 (2) 是传入 `__init__` 的初始化参数 output_axis,第二个括号中的 result 才是触发 call 时传入的输入列表。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值