在 pytorch 中,交叉熵损失函数会报以下的错:
该错误的原因是目标的类别数与预测的类别数不匹配。
只将 CrossEntropyLoss损失函数拿出来,看一下输入和输出。以下代码可以正常的输出loss的值:
import torch
criterion = torch.nn.CrossEntropyLoss()
x = torch.rand(10, 10) # x.shape[10,10]
y =</
在 pytorch 中,交叉熵损失函数会报以下的错:
该错误的原因是目标的类别数与预测的类别数不匹配。
只将 CrossEntropyLoss损失函数拿出来,看一下输入和输出。以下代码可以正常的输出loss的值:
import torch
criterion = torch.nn.CrossEntropyLoss()
x = torch.rand(10, 10) # x.shape[10,10]
y =</