tensorflow 实现RBFSoftmax

参考论文:

RBF_softmax:Learning Deep Representative Prototypes with Radial Basis Function Softmax

 


"""
RBF_softmax:Learning Deep Representative Prototypes with Radial Basis Function Softmax
交叉熵是深度学习中非常常用的一种损失,通过交叉熵学到的特征表示会有比较大的类内的多样性。因为传统的softmax损失优化的是类内和类间的差异的最大化,也就是类内和类间的距离(logits)的差别的最大化,没有办法得到表示类别的向量表示来对类内距离进行正则化。之前的方法都是想办法增加类内的内聚性,而忽视了不同的类别之间的关系。本文提出了Radial Basis Function(RBF)距离来代替原来的softmax中的內积,这样可以自适应的给类内和类间距离施加正则化,可以得到更好的表示类别的向量,从而提高性能。

github源码:https://github.com/2han9x1a0release/RBF-Softmax
主要参考实现:https://github.com/2han9x1a0release/RBF-Softmax/blob/master/pycls/losses/rbflogit.py

具体内容可以参考原始论文:
中文讲解:https://blog.csdn.net/u011984148/article/details/108688071

"""

# 以下是使用tensorflow重现RBF softmax
"""
以下是使用tensorflow重现RBF softmax

"""


import tensorflow as tf


class RBFSoftmax(tf.layers.Layer):

    def __init__(self, feature_dim, num_classes, scale, gamma):
        super(RBFSoftmax, self).__init__()

        self.feature_dim = feature_dim
        self.num_classes = num_classes
        self.scale = scale
        self.gamma = gamma

    def build(self, input_shape):
        # 代表着每个类都有一个类中心的向量,用来计算RBF score (也可以添加偏置bias)
        self.weight = tf.Variable(tf.truncated_normal(shape=(self.num_classes, self.feature_dims), stddev=0.02))
        # self.bias = tf.Variable([0] * self.num_classes)
        self.built = True

    def call(self, inputs, training=None):
        """ 计算RBF logits

        :param inputs: Tensor, shape:(batch, feature_dim)
        :return: Tensor, shape:(batch, num_classes)

        可以先经过若干层的dense层,再进行计算RBFSoftmax

        """

        diff = tf.expand_dims(self.weight, axis=0) - tf.expand_dims(inputs, axis=1)
        diff = tf.multiply(diff, diff)
        metric = tf.reduce_sum(diff, axis=-1)  # shape: (batch, num_classes)
        kernel_metric = tf.exp(-1.0 * metric / self.gamma)
        logits = self.scale * kernel_metric
        return logits


"""
使用样例demo: 

rbflogit = RBFSoftmax(...)
logits = rbflogit(inputs)
true_one_hot_labels = ...
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=true_one_hot_labels, logits=logits)
loss = tf.reduce_mean(losses)
"""

 

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值