tensorflow2.0损失函数总结和损失函数自定义

当我们尝试训练神经网络的时候,不可避免地要接触到损失函数,损失函数计算真实值和预测值的误差。tensorflow2.0已经给我们封装好的具备很多用途的损失函数,我们可以只用两行代码就可以直接使用,简直方便地不要不要的。

我先说如何使用,再说有哪些可以供我们挑选使用

如何使用看下面代码,分析过程在代码的注释里面,注意看代码注释,注意看代码注释,注意看代码注释。

from tensorflow.keras import losses

# 假设y_true是真实值, y_pred是网络预测值
import tensorflow as tf

y_true = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
y_pred = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))

# 实例化一个损失对象
loss_object = losses.CategoricalCrossentropy()  # 这个类里面的参数很值的研究,一般都是默认即可
"""
类似于这种CategoricalCrossentropy类,tensorflow2.0给我提供了几种呢?
答案是有好多好多,具体分析接着看博客。你也可以按住ctrl健点选“losses”,进入源码里面看看
"""

# 通过该损失对象计算损失
losses_hjx = loss_object(y_true=y_true, y_pred=y_pred)  # losses_hjx为tf.Tensor(15.427775, shape=(), dtype=float32)


pass

我觉得,就算我现在把所有常用的损失函数都告诉你了,我相信你也是一刷而过,丝毫没有觉得有用的感觉,反而感觉压力很大。所以我就不把这些内置的损失函数逐一告诉你们了。换言之,我们完全可以不用别人写好的东西呀,我们想要什么就自己来自定义什么呗,难道不是很快乐吗?

所以

接着我要告诉你们如何自定义损失函数,当然啦,tensorflow2.0确实已经给我们做好了太多东西了,你可以直接使用他们的内置函数。想知道还有哪些内置函数的童鞋,评论区call我,我发给你10G资料研究研究啧啧啧。

如何自定义损失函数,代码分析在注释里面,注意看代码注释,注意看代码注释,注意看代码注释

from tensorflow.keras import losses

# 假设y_true是真实值, y_pred是网络预测值
import tensorflow as tf

y_true = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
y_pred = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))


class FocalLoss(losses.Loss):  # 继承Loss类

    # 重写初始化方法,其实就是定义一些自己损失逻辑可能使用到的参数,格式如下
    def __init__(self, gamma=2.0, alpha=0.25, **kwargs):
        super(FocalLoss, self).__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha

    # call函数是重点,重写了损失函数的运算逻辑,这也是一个损失函数的本质了,下面损失逻辑是我随便写的
    def call(self, y_true, y_pred):
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        loss = pt_0 + pt_1
        return loss


# 使用方法跟内置损失函数的使用方法一样,看下面
# 实例化一个损失对象
loss_object = FocalLoss(name='focalloss')  # 这个类里面的参数必须要传递一个参数name,name的值可以自定义

# 通过该损失对象计算损失
losses_hjx = loss_object(y_true=y_true, y_pred=y_pred)  # losses_hjx为tf.Tensor(15.427775, shape=(), dtype=float32)

pass

有一些童鞋可能聪明一点,现在是不是在想,为啥我不直接以一个函数的形式来实现这个损失函数对吧。如果你想到这个问题,说明你太聪明了。

对的

为啥我们不使用自己写的普通形式的函数来定义损失函数呢。

原因就是

如果你这样做,你无法通过tensorflow2.0其它定义的优化器和回调函数来使用这个损失函数。大白话就是,tensorflow2.0希望作者按照他们之前规定的规则来做,这样能最大性能发挥tensorflow框架的性能。当然,你想自己通过手写函数的形式使用低阶tensorflow的api实现也是可以的,就是费力费时间而已。

好啦,本篇文章就到此结束了,恭喜你又懂得了一点新知识哦,爱你么么哒

  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值