使用CrossEntropyLoss损失函数时出现 RuntimeError: multi-target not supported at,以及动手计算CrossEntropyLoss

应用CrossEntropyLoss时,注意

Input: (N, C) where C = number of classes

Target: (N)

不要写程序把target独热化。这是我遇到时候的解决办法。

另外可能有用的:https://blog.csdn.net/ccbrid/article/details/79844005

https://blog.csdn.net/geter_CS/article/details/84857220

 

看了CrossEntropyLoss的官方文档:

>>> loss = nn.CrossEntropyLoss(weight=torch.tensor([1.,1,1,1,1])
>>> indata = torch.randn(3, 5, requires_grad=True) #size:(batch,nclasses)==>(3,5) 
tensor([[-0.6566,  0.5376,  0.4272,  0.0929, -0.9580],
        [-0.0633, -0.2284,  0.8822,  1.3089,  2.4364],
        [-1.0384, -1.4110, -0.2139, -2.3812, -0.3994]], requires_grad=True)

#target是一个(num_sample=3)的张量,并没有独热化。
>>> target = torch.empty(3, dtype=torch.long).random_(5) 
tensor([2, 4, 0])

>>> output = loss(indata, target)
tensor(1.1884, grad_fn=<NllLossBackward>)

>>> output.backward()


#Compute Loss
>>> indata_softmax = nn.Softmax(1)(indata)  #计算sofmax,也可以用nn.LogSoftmax(),这样就免去
#了下一步计算log

tensor([[0.0989, 0.3264, 0.2923, 0.2092, 0.0732],
        [0.0487, 0.0413, 0.1253, 0.1920, 0.5928],
        [0.1633, 0.1125, 0.3723, 0.0426, 0.3093]], grad_fn=<SoftmaxBackward>)

>>> pre_target = indata_softmax.argmax(dim=1) #if predicting
tensor([1, 4, 2])

>>> pre_target = torch.log(indata_softmax)  #计算log
tensor([[-2.3138, -1.1196, -1.2300, -1.5643, -2.6152],
        [-3.0226, -3.1877, -2.0771, -1.6504, -0.5229],
        [-1.8124, -2.1850, -0.9879, -3.1552, -1.1734]])

>>> -pre_target[range(target.shape[0]), target].mean() #计算loss:method 1
>>> -(pre_target [0][2] + pre_target [1][4] + pre_target [2][0])/3  #计算loss:method 2, 算#出log后的张量,乘以target独热后的张量(可以称之为mask),然后求和并反正负号,得出loss。
#target的独热张量为:
tensor([[0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0]])
tensor(1.1884)

 

计算loss公式:

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值