torch.nn.CrossEntropyLoss用法(原理, nlp, cv例子)

前言

早上想花一个小时参照网上其他教程,修改模型结构,写一个手写识别数字的出来,结果卡在了这个上面,loss一直降不下来,然后我就去查看了一下CrossEntropyLoss的用法,毕竟分类问题一般都用这个。



代码

原理层面

引入一个库:

import torch



假如是一个四分类任务,batch为2(只是为了显示简单,举个例子罢了)

logists = torch.randn(2, 4, requires_grad=True)
print(logists)


其实根据这个模型预测出来就是, 第一个样本预测的类别是1, 第二个样本预测的类别是2。
这里我们假设模型足够好,都预测对了,那么其实target就是ground_truth。

target = logists.argmax(dim=-1)


通过查看官方文档 CrossEntropyLoss–PyTorch 1.10.1 document, 可以知道loss有两种算法。
target可one-hot也可以不one-hot。


定义损失函数:

crition = torch.nn.CrossEntropyLoss()



先来看个target_1d版的loss:

crition(logists, target)



再来看个target one-hot版的:
注意: 该版本在我的macbook python3.7.8, torch1.10.2的版本上没有问题, 但是在我的windows python3.7.6 torch1.9.1就出问题了!!! 因此稳妥起见还是直接用target比较好
先把target转为one

t_onehot = torch.nn.functional.one_hot(target, num_classes=4)

如何是one_hot, 要求target也是浮点类型的,所以t_onehot再调用float()转为浮点类型。

crition(logists, t_onehot.float())

最后发现两种方法其实算出来的loss都是0.5601


另外插一嘴,crossEntropyLoss也可以通过nll_loss实现(如果你去看torch.nn.crossEntropyLoss的源码就会发现官方就是使用torch.nn.functional.nll_loss实现的,只不过模型输出的logists值要先经过log_softmax



nlp

这里来个序列标注的例子。模型输出是 (batch_size, seq_len, hidden_dim)。
这里我演示用两种方法,方法一会更加简洁,但可读性不如方法二,看个人喜欢。

# 序列标注,把每个token分成一类
import torch


cel = torch.nn.CrossEntropyLoss()
# 也可以用以下的方式,结果一样
# cel = torch.nn.functional.cross_entropy

batch_size, seq_len, hidden_dim = 4, 28, 128

# output logits
x = torch.randn(batch_size, seq_len, hidden_dim)
# ground truth
gt = torch.ones(batch_size, seq_len).long()

print('method 1:')
print(cel(x.permute(0, 2, 1), gt))
print()

print('method 2:')
print(cel(x.view(-1, hidden_dim), gt.view(-1)))
可以看到两种方法一样。




cv

图像分割例子,模型输出是 (batch_size, channel, height, width), 有多少个类别就有多少个channel, 通常医疗上的语义分割是2分类,因此输出channel为2。
这里我演示用两种方法,方法一会更加简洁,但可读性不如方法二,看个人喜欢。

# 语义分割
import torch


cel = torch.nn.CrossEntropyLoss()
# 也可以用以下的方式,结果一样
# cel = torch.nn.functional.cross_entropy

# batch_size, input_channel, ouput_channel, height, width
# input_channel: 3 代表彩色图, output_channel: 2 代表语义分割二分类
b, ic, oc, h, w = 4, 3, 2, 28, 28

# output logits, predict: x.argmax(dim=1)
x = torch.randn(b, oc, h, w)
# ground truth
gt = torch.ones(b, 1, h, w).long()

# 这里可以把 h * w 这种图像的二维数据展平成一维数据,就和序列标注一样了

print('method 1:')
print(cel(x.view(b, oc, -1), gt.squeeze().view(b, -1)))
print()

print('method 2:')
print(cel(x.permute(0, 2, 3, 1).reshape(-1, oc), gt.squeeze().reshape(-1)))
  • 8
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 11
    评论
这段代码导入了一些常用的Python库和模块,以及一些特定的工具和函数。让我逐一解释它们的作用: - `import numpy as np`: 导入NumPy库并将其命名为`np`,用于进行数值计算和数组操作。 - `import random`: 导入Python的随机数模块,用于生成随机数和进行随机抽样。 - `import math`: 导入Python的数学模块,提供了一些数学函数和常量。 - `import os`: 导入Python的操作系统模块,用于进行文件和目录操作。 - `import scipy.io`: 导入SciPy库中的io模块,用于读取和写入各种数据文件。 - `import matplotlib.pyplot as plt`: 导入Matplotlib库中的pyplot模块,并将其命名为`plt`,用于绘制数据可视化图形。 - `import torch`: 导入PyTorch深度学习库。 - `import torch.nn as nn`: 导入PyTorch中的神经网络模块,用于定义和构建神经网络模型。 - `import torch.nn.functional as F`: 导入PyTorch中的函数式接口模块,提供了一些常用的非线性函数和损失函数。 - `import torchvision`: 导入PyTorch中的计算机视觉库,用于处理图像和视频数据。 - `import transformers`: 导入Hugging Face的Transformers库,用于自然语言处理任务和预训练模型。 - `%matplotlib inline`: 这是一个Jupyter Notebook的魔术命令,用于在Notebook中内联显示Matplotlib绘图的结果。 通过导入这些库和模块,代码可以使用它们提供的功能来进行数据处理、数学计算、文件操作、绘图、深度学习模型构建和自然语言处理等任务。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值