python中seq函数作用_Seq2Seq模型和损失函数(以keras为单位)

在损失函数中使用K.eval或{}不是一个好主意。关于张量的所有想法是,它们有一个由tensorflow/keras管理的内部连接,通过它可以计算梯度和其他东西。在

使用eval并处理numpy值将破坏此连接并破坏模型。使用eval只能查看结果,而不是创建函数。在

使用ifs将不起作用,因为张量值不可用。但是有一些keras函数,如K.switch、K.greater、K.less等,都列在backend documentation中。在

您可以使用这些函数重新创建函数。在

但老实说,我认为你应该改用“掩蔽”或“班级加权”。在

掩蔽(溶液1)

如果您使用的是嵌入层,您可以有意地为“结束后无任何内容”保留零值。在

然后,您可以在嵌入层中使用mask_zero=True,并具有如下所示的输入:[2, #start token

3,

123,

1548, #end token

0, #nothing, value to be masked

0,

0,

0,

0,

0]

另一个选择是不使用“结束标记”,而是使用“零”。在

类别权重(解决方案2)

因为这很可能是因为您在期望的输出中拥有比其他任何东西都多的结束标记,所以您可以减少结束标记的相关性。在

计算输出中出现的每个类,并计算结束标记的比率。例如:计算所有其他类出现的平均值

计算结束标记的出现次数

ratio = other_classes_mean / end_token_occurences

然后在fit方法中,使用:

^{pr2}$

易于实现:class_weight = {i:1. for i in range(totalTokens)}

class_weight[1548] = ratio

model.fit(...,...,....., class_weight = class_weight,...)

(在这种情况下,请确保将0作为可能的类,或将索引移动1)

类似的损失函数(解决方案3)

请注意,y_pred永远不会“等于”y_true。在y_pred是可变的、连续的、可微的

y_true是精确且恒定的

为了进行比较,您应该使用“argmax”,它非常类似(如果不完全是)类索引。在def mean_absolute_error(y_true, y_pred):

#for comparing, let's take exact values

y_true_max = K.argmax(y_true)

y_pred_max = K.argmax(y_pred)

#compare with a proper tensor function

equal_mask = K.equal(y_true_max,y_pred_max)

is_start = K.equal(y_true_max, self.startTokenAsIndex)

is_end = K.equal(y_true_max, self.endTokenAsIndex)

#cast to float for multiplying and summing

equal_mask = K.cast(equal_mask, K.floatx())

is_start = K.cast(is_start, K.floatx())

is_end = K.cast(is_end, K.floatx())

#these are tensors with 0 (false) and 1 (true) as float

#entire condition as you wanted

condition = (is_start + is_end) * equal_mask

# sum = or ||| multiply = and

# we don't have to worry about the sum resulting in 2

# because you will never have startToken == endToken

#reverse condition:

condition = 1 - condition

#result

return condition * K.mean(K.abs(y_pred - y_true), axis=-1)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值