pytorch中的CrossEntropyLoss

这里主要探讨torch.nn.CrossEntropyLoss函数的用法。

使用方法如下:

# 首先定义该类
loss = torch.nn.CrossEntropyLoss()
#然后传参进去
loss(target, label)

第一个参数的维度为m1 * m2,第二个参数维度为m1。我们在做多分类问题的时候,target应该为我们网络生成的值,而label则是非one-hot类型的值。

用手写体数字识别简单举个例子:

for (trainX,trainY) in trainLoader:
# forward
        out = net(trainX) # batch_size * num_classes
        loss = loss_fn(out, trainY)
# backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

另外,trainY必须为Long类型,如果为非类型则会报错。RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'trainY'
CrossEntropyLoss还会自动对out作用softmax。

  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值