RuntimeError: output with shape [1,100, 100] doesn’t match the broadcast shape [3, 100, 100]
报错溯源: 建立我自己的DataSet时报错。这是因为transformer.Normalize里面的mean和std只赋了两个值[0.5,0.52], [0.2, 0.23],它只是对一张图片的两个channel标准化。所以对于RGB图片,需要赋三个值,如下图所示
解决方法:
shape不匹配,是因为输入图片为灰度图,一个通道,需要的图片是RGB图,所以要将灰度图转化为RGB图片
im = Image.open(im_path).convert('RGB') # 图片
class RSDataset(Dataset):
def __init__(self, txt_path, width=100, height=100, transform=None, test=False):
self.ims, self.labels = read_txt(txt_path)
self.width = width
self.height = height
self.transform = transform
self.test = test
def __getitem__(self, index): # 通过下标来索引图片
im_path = self.ims[index]
label = self.labels[index] # 标签
im_path = os.path.join(config.data_root, im_path)
im = Image.open(im_path).convert('RGB') # 图片
# im = Image.open(im_path) # 图片
# Image._show(im)
# im = im.resize((self.width, self.height))
if self.transform is not None:
im = self.transform(im) # 做transform, 将图像转化为tensor等操作
return im, label
def __len__(self):
return len(self.ims)