RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of s

在用自己的数据集训练unet时,碰到了这样的问题。

RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1, 640, 959, 3]

损失函数 nn.CrossEntropyLoss()的输入应该是一个4维的张量(网络的输出)和一个三维的张量(target),而读取的数据集中的标签为RGB三通道的图片 [batch size,weight,height,RGB]。

需要将该四维张量的RGB图片输入转为单值的类别信息。

重新将标签制作为单值灰度图。

import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2


color2class_dict = {   
        0: [64.0, 140.0, 214.0, 255.0],
        1: [2.0, 4.0, 244.0, 255.0],
        2: [210.0, 21.0, 28.0, 255],
        3: [9.0, 243.0, 25.0, 255.0]

    }   #自行设置类别对应的颜色字典
def get_keys1(value):   #按字典中的颜色对应关系分类
    p = 0
    for k, v in color2class_dict.items():
        if v == value:
            p = k
            break
    return p
def get_keys2(value):   #自行设置颜色范围
    if value[0] > 150:
        return 1
    elif value[1] > 150 and value[2] < 50:
        return 2
    elif value[1] < 50 and value[2] > 150:
        return 3
    else:
        return 0

def main(input_path, save_path, mode):
    get_keys = get_keys1 if mode == 0 else get_keys2
    img_list = os.listdir(input_path)
    for image in img_list:
        img_path = os.path.join(input_path, image)
        save_path_img = os.path.join(save_path, image.split(".")[0]+"_mask.png")
        img = plt.imread(img_path)*255.0
        img_label = np.zeros((img.shape[0], img.shape[1]))
        img_new_label = np.zeros((img.shape[0], img.shape[1]))
        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                value = list(img[i, j])
                img_label[i, j] = get_keys(value)
                img_new_label[i, j] = img_label[i, j]
        label0 = Image.fromarray(np.uint8(img_new_label))
        cv2.imwrite(save_path_img, img_label)
        print(image+" done")


input_path = ""
save_path = ""
mode = 0
main(input_path, save_path, mode)

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值