点击上方“算法数据侠”,选择“星标”公众号
第一时间获取最新推文与资源分享小侠客们好呀,我是oubahe。今天迎来“好技巧”专题第4文,在Tensorflow2.0发布以来,学术界和工业界就引起了广泛关注。谷歌及其团队对之前的Keras库做出了大量的Tensorflow专属优化与改动。但是与初代Tensorflow相比,实现复杂的多输入参数损失函数变得复杂晦涩了。本文将介绍Tensorflow2.0中tf.keras自定义实现复杂的损失函数,使得我们更容易在不同的模型框架中重复使用,构建符合实际场景需求的损失函数。来吧,展示~
01
—
问题描述
Tensorflow2.0中tf.keras可以直接利用API快速实现一些简单的损失函数,例如model.compile(loss="mse")。但是任何的简单方法都是有代价的,这个内置方法定义的损失函数有且只能有y_true和y_pred两个参数:
def simple_loss(y_true, y_pred): pass
那么一向以简洁易懂著称的Keras如何自定义复杂的损失函数,还不影响Keras漂亮的训练进度条呢?