原创
torch.nn.CrossEntropyLoss的相关
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
<!--一个博主专栏付费入口结束-->
<link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/ck_htmledit_views-d284373521.css">
<link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/ck_htmledit_views-d284373521.css">
<div class="htmledit_views" id="content_views">
<p> 参数可以看下面</p>
-
class CrossEntropyLoss(_WeightedLoss):
-
-
def __init__(self, weight=None, size_average=True, ignore_index=-100, reduce=True):
-
pass
-
def forward(self, input, target):
-
pass
解释:网上的一些解释,该损失的计算公式:https://blog.csdn.net/tmk_01/article/details/80839810
这里面的公式的2个例子举的很好,该例子是判断每一个实例的分类,比如图片是属于哪一个类的。所以每一个实例的真实标签是一维的。做不一样的任务,实例的标签也是不一样的,有1维或者是多维的。举例:如果输入的是图像,任务是分类每个图像对应一个类别即数字是一维的,总共有class个类别,那么input=【n,c,h,w】,target=【n】是一维的,元素个数为n;如果是做语义分割任务是得到一张【h,w】的语义图,那么input=【n,c,h,w】,target=【n,h,w】。
import torch.nn as nn
import torch.nn.functional as F
#可以的格式
#首先要定义这个函数,也就是实例化才能使用
loss = nn.CrossEntropyLoss()(input,target)
#这个函数已经进行了定义为criterion
criterion = nn.CrossEntropyLoss()
loss = criterion(input,target)
#F函数里面有这个函数的相关定义,直接调用就可以
loss = F.cross_entropy(input,target)
#上述三个的输出结果是一样的,loss 类型为torch.Tensor是一个可求导的tensor; loss.item类型为float是python类型的常数值.
#不可以的格式
loss = nn.CrossEntropyLoss(input,target)
#得到的loss类型为torch.nn.modules.loss.CrossEntropyLoss,相当于干函数的实例化,而不是上面的tensor类型