想问一下出现这个报错:x_shape[C_in] / group must equal to w_shape[C_in] = 1, but got 3
我使用的是自己的数据集,错误显示w_shape[C_in]必须是一维,但是自己的数据集是彩色的哇,不是三维吗?是我理解的有问题还是哪里需要修改的哇。
下面是自定义数据集函数,错误日志放在了后面
class GetDatasetGenerator:
def __init__(self, root_dir):
self.root_dir = root_dir
self.img_name = [i for i in os.listdir(os.path.join(root_dir, "train_image")) if i.endswith(".jpg")]
self.seg_name = [i for i in os.listdir(os.path.join(root_dir, "train_mask")) if i.endswith(".jpg")]
def __getitem__(self, index):
segment_name = self.seg_name[index]
img_name = self.img_name[index]
segment_path = os.path.join(self.root_dir,"train_mask" ,segment_name) #数据
image_path = os.path.join(self.root_dir,"train_image", img_name) #标签
image_img = Image.open(image_path)
segment_img = Image.open(segment_path)
image_np = np.array(image_img)
segment_np = np.array(segment_img)
return (image_np,segment_np)
def __len__(self):
return len(self.img_name)
然后网络结构使用的是gitee上面U_net中的网络:https://gitee.com/mindspore/course/tree/master/unet/src
【操作步骤&问题现象】
解答:
https://www.mindspore.cn/docs/api/zh-CN/master/api_python/nn/mindspore.nn.Conv2d.html#mindspore.nn.Conv2d
MindSpore的Conv2D接口默认是采用NCHW格式输入。
报错“x_shape[C_in] / group must equal to w_shape[C_in] = 1, but got 3” 的意思是:
x_shape[C_in] / group 必须要 == w_shape[C_in]
但是用户给的w_shape[C_in] 值是1
但是x_shape[C_in] / group 却==3
所以,问题应该是你输入的w_shape[C_in]不对了。这个w_shape[C_in]就是权重的channels维的大小,也就是你传的in_channels属性值。你检查一下是不是把nn.Conv2d初始化时的in_channels属性设置成1了。