数据类型只支持long类型
import torch
import torch.nn as nn
loss = nn.CrossEntropyLoss()
# input, NxC=2x3
input = torch.randn(2, 3, requires_grad=True)
# target, N
target = torch.empty(2, dtype=torch.long).random_(3)
output = loss(input, target)
output.backward()
多维度会报错:
1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size: : [1, 3,
例如:
criterion = nn.CrossEntropyLoss()
loss = criterion(logit, target.long())
其中,logit: torch.Size([4, 9, 256, 256]); target: [4,1, 256, 256]
就会出现1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size的错误
错误原因:
Pytorch的CrossEntropy官方文档截图如下:
显然,其中N=4, C=31, d1=d2=256, 故target中的1是多余的,是报错的原因。