Here is the relevant part of the code:
for data in train_loader:
    imgs, targets = data
    # print("labels:", targets.shape)
    imgs.byte()  # note: .byte() returns a new uint8 tensor; imgs itself is not modified
    output = unet_l2(imgs)
    loss = criterion(output, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
The error:
RuntimeError: expected scalar type Byte but found Float
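The message points to a dtype mismatch between the input and the layer weights. A quick way to confirm this is to print both dtypes. The sketch below is a minimal reproduction under stated assumptions: `nn.Conv2d` is a hypothetical stand-in for `unet_l2`, and the random uint8 batch is a hypothetical stand-in for what `train_loader` yields. The exact error wording varies between PyTorch versions, but it is a `RuntimeError` either way.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for unet_l2: one conv layer, whose parameters are float32 by default.
model = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3, padding=1)

# Hypothetical image batch stored as uint8 (Byte), e.g. raw 8-bit pixel values.
imgs = torch.randint(0, 256, (2, 3, 8, 8), dtype=torch.uint8)

input_dtype = imgs.dtype
weight_dtype = next(model.parameters()).dtype
print("input dtype:", input_dtype)
print("weight dtype:", weight_dtype)

try:
    model(imgs)  # uint8 input vs float32 weights: the forward pass fails
    raised = False
except RuntimeError as err:
    print("RuntimeError:", err)  # wording depends on the PyTorch version
    raised = True
```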
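Solution: the input images are evidently uint8 (Byte), while the network's weights are float32, so the forward pass cannot mix them. Convert the input to float32 before passing it to the model: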
output = unet_l2(imgs.to(torch.float32))
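For reference, here is the same fix placed inside a complete training loop. This is a minimal runnable sketch, not the original project: the tiny `nn.Conv2d` model, `MSELoss`, SGD optimizer and random uint8 dataset are all hypothetical stand-ins for `unet_l2`, `criterion`, `optimizer` and `train_loader`. `imgs.float()` would work equally well, since it is shorthand for `imgs.to(torch.float32)`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the objects used in the original loop.
unet_l2 = nn.Conv2d(3, 1, kernel_size=3, padding=1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(unet_l2.parameters(), lr=0.01)

# Hypothetical dataset: uint8 images (as in the problem) and float32 targets.
imgs_u8 = torch.randint(0, 256, (4, 3, 8, 8), dtype=torch.uint8)
targets_f = torch.rand(4, 1, 8, 8)
train_loader = DataLoader(TensorDataset(imgs_u8, targets_f), batch_size=2)

losses = []
for data in train_loader:
    imgs, targets = data
    # Convert the uint8 batch to float32 so it matches the model's parameters.
    output = unet_l2(imgs.to(torch.float32))
    loss = criterion(output, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

If the uint8 tensors hold raw pixel values from 0 to 255, many pipelines also divide by 255 after the conversion (torchvision's `transforms.ToTensor` does this for PIL images and uint8 arrays); whether that is appropriate here depends on how the dataset was built.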