报错详情:
RuntimeError: only batches of spatial targets supported (3D tensors)
but got targets of size: : 【8, 480, 480, 3】
原因:
损失函数 nn.CrossEntropyLoss()的输入应该是一个4维的张量(网络的输出)和一个三维的张量(target),而数据集的标签为RGB三通道的图片 [batch size,weight,height,RGB]。需要将该四维张量的RGB图片输入转为单值的类别信息,重新将标签制作为单值灰度图。
通俗来说就是:
SegmentationClass文件夹中的mask标签,即png图片是24位的,应该为8位的索引图。
解决方法:
将24位的png改为8位的索引图(8位索引图也是彩色的),python代码如下:
from PIL import Image
import cv2
import os
image_path = r"F:\1ship_height\COCO\Boats\VOC\SegmentationClass" # 要转换的图片所在路径
save_path = r"F:\1ship_height\COCO\Boats\VOC\SegmentationClass_8" # 保存路径
for image_name in os.listdir(image_path):
image = os.path.join(image_path, image_name)
img = Image.open(image)
img = img.convert('P')
save_name = os.path.join(save_path, image_name)
img.save(save_name)