Tensorflow不仅支持经典损失函数,还可以优化任意的自定义损失函数。下面以预测产品销量为例。
在预测产品销量时,如果预测多了,商家损失的是生产商品的成本,如果预测少了,损失的是商品的利润。前面所讲到的均方误差损失函数不能很好的最大化销售利润。下面的公式给出了一个当预测多于真实值和预测少于真实值时有不同损失系数的损失函数:
其中yi为一个batch中第i个数据的正确答案,yi‘为神经网络得到的预测值,a和b为常量,通过以下代码实现这个损失函数:
loss=tf.reduce_sum(tf.where(tf.greater(v1,v2),(v1-v2)*a,(v2-v1)*b))
上面代码用到了tf.greater和tf.where来实现选择操作。tf.greater的输入是两个张量,此函数会比较这两个输入张量中每一个元素的大小,并返回比较结果。当输入的张量维度不一样时,Tensorflow会进行类似NumPy广播操作的处理。tf.where函数有三个参数。第一个为选择条件根据,当选择条件为True时,tf.where函数会选择第二个参数的值,否则会使用第三个参数中的值。tf.where函数判断和选择都是在元素级别进行,以下代码展示了tf.where函数和tf.greater函数的用法。
import tensorflow as tf
v1=tf.constant([1.0,2.0,3.0,4.0])
v2=tf.const