mindspore报错:x_shape[C_in] / group must equal to w_shape[C_in] = 1, but got 3

想问一下出现这个报错:x_shape[C_in] / group must equal to w_shape[C_in] = 1, but got 3

我使用的是自己的数据集,错误显示w_shape[C_in]必须是一维,但是自己的数据集是彩色的哇,不是三维吗?是我理解的有问题还是哪里需要修改的哇。

下面是自定义数据集函数,错误日志放在了后面

class GetDatasetGenerator:
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.img_name = [i for i in os.listdir(os.path.join(root_dir, "train_image")) if i.endswith(".jpg")]
        self.seg_name = [i for i in os.listdir(os.path.join(root_dir, "train_mask")) if i.endswith(".jpg")]
    def __getitem__(self, index):
        segment_name = self.seg_name[index]
        img_name = self.img_name[index]
        segment_path = os.path.join(self.root_dir,"train_mask" ,segment_name)  #数据
        image_path = os.path.join(self.root_dir,"train_image", img_name)         #标签
        image_img = Image.open(image_path)
        segment_img = Image.open(segment_path)

        image_np = np.array(image_img)
        segment_np = np.array(segment_img)
        return (image_np,segment_np)

    def __len__(self):
        return len(self.img_name)

然后网络结构使用的是gitee上面U_net中的网络:https://gitee.com/mindspore/course/tree/master/unet/src   

【操作步骤&问题现象】

解答:

https://www.mindspore.cn/docs/api/zh-CN/master/api_python/nn/mindspore.nn.Conv2d.html#mindspore.nn.Conv2d

MindSpore的Conv2D接口默认是采用NCHW格式输入。

报错“x_shape[C_in] / group must equal to w_shape[C_in] = 1, but got 3” 的意思是:

x_shape[C_in] / group 必须要 ==  w_shape[C_in]

但是用户给的w_shape[C_in] 值是1
但是x_shape[C_in] / group 却==3

所以,问题应该是你输入的w_shape[C_in]不对了。这个w_shape[C_in]就是权重的channels维的大小,也就是你传的in_channels属性值。你检查一下是不是把nn.Conv2d初始化时的in_channels属性设置成1了。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值