一、简介
今天在处理二分类问题时,使用CrossEntropyLoss()作为损失函数,出现了Target 2 is out of bounds.这样的报错。
二、原因
原因是因为数据集中的标签是1和2,不是0和1。一般会认为处理二分类问题,模型的输出就设置为2。但是由于CrossEntropyLoss()函数内置了softmax()函数将标签转换为独热编码的形式。如数据集中的标签是1和2,则会被转换成(0, 0, 1)或(0, 1, 0),但是模型的输出只有2,如(0.4, 0.6),两个输入的形状不一样,所以报错。
三、解决办法
第一个解决办法:将标签转换为0和1.
label_pipeline = lambda x: 0 if x == 1 else 1
第二个解决办法:模型输出设置为3.
self.fc = nn.Linear(embed_dim, 3)