RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4
https://discuss.pytorch.org/t/runtimeerror-only-batches-of-spatial-targets-supported-3d-tensors-but-got-targets-of-dimension-4/82098
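A minimal reproduction, as a sketch. It assumes a segmentation setup with C classes and a one-hot mask laid out like an image batch, and it calls nn.NLLLoss directly. For class-index targets, nn.CrossEntropyLoss applies log_softmax and then runs the same NLL computation. The sizes N, C, H, W are illustrative only. The exact wording of the message differs between PyTorch versions, so only its common prefix is checked.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes: batch N=2, C=3 classes, 4x4 images.
N, C, H, W = 2, 3, 4, 4
logits = torch.randn(N, C, H, W)            # model output: [N, C, H, W]
labels = torch.randint(0, C, (N, H, W))     # integer class map: [N, H, W]

# One-hot mask with a channel dimension: [N, C, H, W] (int64)
one_hot_mask = F.one_hot(labels, num_classes=C).permute(0, 3, 1, 2)

loss_fn = nn.NLLLoss()
log_probs = F.log_softmax(logits, dim=1)

try:
    loss_fn(log_probs, one_hot_mask)        # 4-D target -> RuntimeError
    error_message = None
except RuntimeError as e:
    error_message = str(e)

print(error_message)
```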
Solution:
target = torch.argmax(target, dim=1)  # one-hot [batch, C, H, W] -> class indices [batch, H, W]
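This fix assumes the target is one-hot encoded along the channel dimension: shape [batch, C, H, W] with exactly one 1 per pixel. Taking argmax over dim=1 returns each pixel's class index as an int64 tensor of shape [batch, H, W]. That is the format nn.CrossEntropyLoss and nn.NLLLoss expect for class-index targets. Here is a small sketch with illustrative shapes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes: batch N=2, C=3 classes, 4x4 images.
N, C, H, W = 2, 3, 4, 4
logits = torch.randn(N, C, H, W)                       # model output: [N, C, H, W]
labels = torch.randint(0, C, (N, H, W))                # ground-truth class map
one_hot_mask = F.one_hot(labels, num_classes=C).permute(0, 3, 1, 2).float()  # [N, C, H, W]

# Collapse the channel dimension: index of the "hot" channel for each pixel.
target = torch.argmax(one_hot_mask, dim=1)             # [N, H, W], dtype torch.int64

loss = nn.CrossEntropyLoss()(logits, target)           # scalar loss, no shape error
print(target.shape, target.dtype, loss.item())
```

The mask might instead have a single channel that already stores class ids ([batch, 1, H, W]). In that case, argmax over dim=1 returns 0 for every pixel, so use target.squeeze(1).long() instead.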
Context in which the error occurred:
In my training function:
# get X (input) and y (target) as NumPy arrays from the dataloader
X = normalize_zero_to_one(X)  # input
y = normalize_zero_to_one(y)  # target
images = torch.from_numpy(X).to(device)  # [batch, channel, H, W]
masks = torch.from_numpy(y).to(device)   # [batch, channel, H, W], i.e. a 4-D target
optim.zero_grad()
outputs = model(images)
loss = loss_new(outputs, masks)  # (preds, target); raises the RuntimeError because masks is 4-D
loss.backward()
optim.step()  # update weights
I know the target (here, masks) should be [batch_size, H, W]. However, it is currently [batch_size, channels, H, W].
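Applied to the training step above, a runnable sketch could look like the following. Several pieces are hypothetical stand-ins: the model (a single 1x1 convolution), the random NumPy batch, nn.CrossEntropyLoss in place of loss_new, and the SGD optimizer. Variable is omitted because plain tensors have replaced it since PyTorch 0.4.

```python
import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative sizes: 1 input channel, 3 classes, 8x8 images.
N, C_in, C_out, H, W = 2, 1, 3, 8, 8

# Hypothetical stand-ins for the poster's model, loss and optimizer.
model = nn.Conv2d(C_in, C_out, kernel_size=1).to(device)
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=0.1)

# Hypothetical batch as NumPy arrays, as a dataloader might yield them.
X = np.random.rand(N, C_in, H, W).astype(np.float32)
class_map = np.random.randint(0, C_out, size=(N, H, W))
# One-hot target with a channel dimension: [N, C_out, H, W]
y = np.ascontiguousarray(np.eye(C_out, dtype=np.float32)[class_map].transpose(0, 3, 1, 2))

images = torch.from_numpy(X).to(device)       # [N, C_in, H, W]
masks = torch.from_numpy(y).to(device)        # [N, C_out, H, W]
masks = torch.argmax(masks, dim=1)            # [N, H, W], int64 class indices

optim.zero_grad()
outputs = model(images)                       # [N, C_out, H, W]
loss = loss_fn(outputs, masks)                # 3-D target: no RuntimeError
loss.backward()
optim.step()                                  # update weights
print(masks.shape, loss.item())
```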