使用PyTorch中Dataset和Dataloader遇到的问题

今天在使用PyTorch中Dataset遇到了一个问题。先看代码

class psDataset(Dataset):
    def __init__(self, x, y, transforms = None):
        super(Dataset, self).__init__()
        self.x = x
        self.y = y
        if transforms == None:
            self.transforms = Compose([Resize((224, 224)), ToTensor()])
        else:
            self.transforms = transforms
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        img = Image.open(self.x[idx])
        img = self.transforms(img)       
        return img, torch.tensor([[self.y[idx]]])

结果运行时报错:RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1 at /opt/conda/conda-bld/pytorch_1522182087074/work/torch/lib/TH/generic/THTensorMath.c:2897

Google了一下发现是这样的:读入的图片有些是灰度图(1个通道),绝大多数是RGB图片(3通道),也有些是带透明度的(4通道)。这导致在读入后最后一个维度(通道数)不一致(可能是1、3或者4)。Dataloader在制作batch data时,tensor的shape必须一样,就报了这个错误。解决的方法是:img = img.convert(“RGB”)。完整代码如下:

class psDataset(Dataset):
    def __init__(self, x, y, transforms = None):
        super(Dataset, self).__init__()
        self.x = x
        self.y = y
        if transforms == None:
            self.transforms = Compose([Resize((224, 224)), ToTensor()])
        else:
            self.transforms = transforms
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        img = Image.open(self.x[idx])
        img = img.convert("RGB")
        img = self.transforms(img)       
        return img, torch.tensor([[self.y[idx]]])
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值