在用自己的数据集训练unet时,碰到了这样的问题。
RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1, 640, 959, 3]
损失函数 nn.CrossEntropyLoss() 的输入应该是一个四维张量(网络的输出,形状为 [batch size, classes, height, width])和一个三维张量(target,形状为 [batch size, height, width],每个元素是类别编号),而读取的数据集中的标签是 RGB 三通道图片,形状为 [batch size, height, width, RGB],多出了一个通道维度。
因此需要把四维的 RGB 标签转换为每个像素只有一个类别编号的单值标签。
解决办法是用下面的脚本把标签重新制作为单值灰度图(像素值即类别编号)。
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
# Colour of each class as [R, G, B, A] on a 0-255 scale; edit this table to
# match your own labels. Despite the name, the mapping is class id -> colour.
color2class_dict = {
    0: [64.0, 140.0, 214.0, 255.0],
    1: [2.0, 4.0, 244.0, 255.0],
    2: [210.0, 21.0, 28.0, 255.0],
    3: [9.0, 243.0, 25.0, 255.0],
}


def get_keys1(value):
    """Return the class id whose colour in ``color2class_dict`` equals ``value``.

    Args:
        value: One pixel as a sequence of channel values on a 0-255 scale.
            Any sequence is accepted (list, tuple, array row). A 3-element
            pixel is matched against the first three components of each
            table entry, so images without an alpha channel still resolve.

    Returns:
        The matching class id, or 0 when no table entry matches.
    """
    # Lists only compare equal to lists; normalise so tuples and array rows
    # are not silently treated as "no match".
    pixel = list(value)
    for class_id, color in color2class_dict.items():
        # Drop the table's alpha component when the pixel carries none.
        reference = color[:3] if len(pixel) == 3 else color
        if reference == pixel:
            return class_id
    return 0
def get_keys2(value):
    """Classify one pixel with coarse channel thresholds.

    The thresholds are examples; adjust them to the colour ranges of your
    own labels.

    Args:
        value: Indexable pixel; components 0, 1 and 2 are read. They are
            presumably R, G, B on a 0-255 scale (that is what ``main``
            passes) -- confirm for other callers.

    Returns:
        1 if component 0 is above 150; otherwise 2 if component 1 is above
        150 and component 2 is below 50; otherwise 3 if component 1 is below
        50 and component 2 is above 150; otherwise 0.
    """
    # NOTE(review): the class ids produced here do not line up with
    # color2class_dict -- its class-2 colour [210, 21, 28] yields 1 here, its
    # class-3 colour yields 2 and its class-1 colour yields 3. Confirm that
    # the two modes are meant to number classes differently.
    if value[0] > 150:
        return 1
    if value[1] > 150 and value[2] < 50:
        return 2
    if value[1] < 50 and value[2] > 150:
        return 3
    return 0
def main(input_path, save_path, mode):
    """Convert every colour-coded label image in a folder to a class-id mask.

    Each file in ``input_path`` is read, every pixel colour is mapped to a
    class id, and the result is written to ``save_path`` as
    ``<name>_mask.png``: a single-channel 8-bit image whose pixel values are
    the class ids.

    Args:
        input_path: Directory with the colour label images. Every entry
            returned by ``os.listdir`` is treated as an image.
        save_path: Directory that receives the masks; it is not created here.
        mode: 0 maps colours with ``get_keys1``; any other value uses
            ``get_keys2``.

    Raises:
        ValueError: If an image has no channel axis (e.g. greyscale).
        OSError: If OpenCV reports that a mask could not be written.
    """
    get_keys = get_keys1 if mode == 0 else get_keys2
    for image in os.listdir(input_path):
        img_path = os.path.join(input_path, image)
        # splitext keeps dots inside the name ("a.1.png" -> "a.1_mask.png");
        # split(".")[0] would collapse "a.1.png" and "a.2.png" to the same mask.
        stem = os.path.splitext(image)[0]
        save_path_img = os.path.join(save_path, stem + "_mask.png")

        img = plt.imread(img_path)
        if img.ndim != 3:
            raise ValueError(
                "{} has no colour channels (shape {})".format(img_path, img.shape)
            )
        # matplotlib returns PNGs as floats in [0, 1] and other formats as
        # integers, so only float data is rescaled. Rounding removes float32
        # error that could otherwise break exact colour matching.
        if np.issubdtype(img.dtype, np.floating):
            img = np.rint(img * 255.0)

        # Classify each distinct colour once instead of once per pixel.
        pixels = img.reshape(-1, img.shape[2])
        colors, inverse = np.unique(pixels, axis=0, return_inverse=True)
        class_ids = np.array(
            [get_keys(color.tolist()) for color in colors], dtype=np.uint8
        )
        img_label = class_ids[inverse.reshape(-1)].reshape(img.shape[0], img.shape[1])

        # imwrite signals failure (e.g. missing output folder) only through
        # its return value.
        if not cv2.imwrite(save_path_img, img_label):
            raise OSError("failed to write " + save_path_img)
        print(image+" done")
# Script entry: fill in the folder of colour label images and the folder that
# should receive the masks before running. mode 0 selects the exact colour
# lookup (get_keys1); any other value selects the threshold rules (get_keys2).
input_path, save_path, mode = "", "", 0
main(input_path, save_path, mode)