tf2自定义损失函数测试

main.py

import tensorflow as tf
from custom_loss import focal_loss


mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=focal_loss(), # 自定义损失函数
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
print()
model.evaluate(x_test,  y_test, verbose=2)

custom_loss

from tensorflow.keras.losses import Loss
from tensorflow.keras.losses import binary_crossentropy
import tensorflow as tf


class focal_loss(Loss):
    def __init__(self, alpha=0.25, gamma=2,**kwargs):
        super(focal_loss,self).__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha
    
    def call(self,y_true,y_pred):
        # y_true转成和y_pred一样的shape
        y_true = tf.squeeze(tf.one_hot(y_true,depth=10))
        BCE = binary_crossentropy(y_true,y_pred)
        pt = tf.math.exp(-BCE)

        F_loss = self.alpha * (1-pt)**self.gamma * BCE
        loss = tf.reduce_mean(F_loss)
        return loss


class MeanSquaredError(tf.keras.losses.Loss):
    def call(self, y_true, y_pred):
        return tf.reduce_mean(tf.square(y_pred - y_true))


Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0011 - accuracy: 0.8727
Epoch 2/5
1875/1875 [==============================] - 3s 2ms/step - loss: 3.9561e-04 - accuracy: 0.9392  
Epoch 3/5
1875/1875 [==============================] - 3s 2ms/step - loss: 2.8197e-04 - accuracy: 0.9538  
Epoch 4/5
1875/1875 [==============================] - 3s 2ms/step - loss: 2.2211e-04 - accuracy: 0.9626  
Epoch 5/5
1875/1875 [==============================] - 3s 2ms/step - loss: 1.8357e-04 - accuracy: 0.9680  

313/313 - 0s - loss: 1.8853e-04 - accuracy: 0.9706 - 414ms/epoch - 1ms/step

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值