理解Pytorch里面nn.CrossEntropyLoss的含义

理解Pytorch里面nn.CrossEntropyLoss的含义

  • 先说nn.CrossEntropyLoss的参数,如果神经网络的输出output是一个(batch_size, num_class, h, w)的tensor(其中,num_class代表分类问题的类别数,h为图像高度,w为图像宽度),则nn.CrossEntropyLoss需要的label形状为(batch_size, h, w),对每一个batch而言,label中的数据代表每一个像素所属的类别,如果是一个二分类问题,则label中的数值只能是0或者1,如果是三分类问题,则label中的数值可以是0,1,2,以此类推。
  • 交叉熵就是用来衡量两个分布之间的相似性的,因此也可以判断神经网络实际的输出与期望的输出的接近程度。假设有两个分布 p ( x ) , q ( x ) p(x),q(x) p(x),q(x),则两者的交叉熵为
    C E H = − ∑ x ∈ χ p ( x ) l o g ( q ( x ) ) CEH = -\sum_{x\in\chi}p(x)log(q(x)) CEH=xχp(x)log(q(x))
  • 在分类问题中,给定label和样本之后,该样本只能属于一个种类,假设该样本属于种类 k k k,则 p ( x = k ) = 1 , p ( x ≠ k ) = 0 p(x=k)=1, p(x\neq k)=0 p(x=k)=1,p(x=k)=0,因此该样本的输出和label的交叉熵可以化简为
    C E H = − l o g ( q ( x = k ) ) CEH = -log(q(x=k)) CEH=log(q(x=k))
  • 神经网络的输出一般为与类别数相等的向量,为了将向量转换为概率分布,即 q ( x = k ) q(x=k) q(x=k)的形式,必须使用softmax函数对神经网络的输出进行转换,而加入softmax函数后的交叉熵函数形式如下,该式即为nn.CrossEntropyLoss的公式
    l o s s ( x , k ) = − l o g ( e x p ( x [ k ] ) ∑ j e x p ( x [ j ] ) ) loss(x,k)=-log\left(\dfrac{exp(x[k])}{\sum_{j}exp(x[j])}\right) loss(x,k)=log(jexp(x[j])exp(x[k]))
import torch
import numpy as np
import torch.nn as nn
import math

a = torch.randn((4,3,8,8 ))
b = np.random.randint(0,3,(4,8,8))
b = torch.from_numpy(b)
loss_fn = nn.CrossEntropyLoss()
b = b.long()
loss = loss_fn(a, b)
loss
# tensor(1.3822)
#验证softmax2d就是对每一个N维度沿着C维度做softmax
m = nn.Softmax2d()
output = m(a)
#验证softmax2d就是对每一个N维度沿着C维度做softmax
a01 = math.exp(a[0,0,0,0])
a02 = math.exp(a[0,1,0,0])
aa = a01 + a02
print(a01/aa)
print(a02/aa)
print(output[0,0,0,0])
print(output[0,1,0,0])
loss = 0
for batch in range(4):
    for i in range(8):
        for j in range(8):
            if b[batch, i, j] == 1:
                loss = loss - math.log(output[batch, 1, i, j])
            if b[batch, i, j] == 0:
                loss = loss - math.log(output[batch, 0, i, j])
            if b[batch, i, j] == 2:
                loss = loss - math.log(output[batch, 2, i, j])
print(loss/64/4)  #将总的loss对总样本数取平均值,样本数为图像中像素数量8*8再*batch_size即为8*8*4
# 1.3822217100148755
  • 上述结果能够看出,手动计算的loss等于loss_fn计算得到的loss
  • 4
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值