PyTorch: error when reading data with a Dataset

1. The error is as follows:

Traceback (most recent call last):
  File "read_data.py", line 100, in <module>
    for i , (image,seg) in enumerate(train_loader):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 819, in __next__
    return self._process_data(data)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 846, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.6/dist-packages/torch/_utils.py", line 369, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "read_data.py", line 91, in __getitem__
    ])(img)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 238, in __call__
    return F.center_crop(img, self.size)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py", line 374, in center_crop
    w, h = img.size
TypeError: 'int' object is not iterable
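The last frame is the key. `center_crop` unpacks `img.size` as `(width, height)`, which is what a PIL `Image` provides. `cv2.imread` returns a NumPy `ndarray` instead, and `ndarray.size` is the total number of elements, a single `int`, so the tuple unpacking fails. The minimal sketch below reproduces the difference with NumPy and Pillow only. The 200×300 array shape is an arbitrary choice, and newer Python versions word the `TypeError` message differently from the Python 3.6 traceback above.

```python
import numpy as np
from PIL import Image

# What cv2.imread + cv2.resize return: an H x W x C uint8 ndarray (here H=200, W=300)
arr = np.zeros((200, 300, 3), dtype=np.uint8)
print(type(arr.size).__name__, arr.size)  # int 180000 -> total number of elements

# What the PIL code path in torchvision expects: Image.size is a (width, height) tuple
img = Image.fromarray(arr)
print(img.size)  # (300, 200)
w, h = img.size  # unpacks fine

# The same unpacking that center_crop performs, applied to the ndarray
unpack_failed = False
try:
    w, h = arr.size
except TypeError as exc:
    unpack_failed = True
    print("TypeError:", exc)  # exact wording depends on the Python version
```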

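It turns out I had overlooked the relationship between the Dataset and PIL's `Image`: the torchvision transforms in my pipeline expect a PIL `Image` (newer versions also accept tensors), but my `__getitem__` handed them the NumPy array returned by `cv2.imread`: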

    def __getitem__(self, index):
        img = cv2.imread(self.image_name[index],cv2.COLOR_BGR2RGB)
        #img = np.transpose(img,(2,1,0))
        img = cv2.resize(img,(self.size,self.size))
        seg = cv2.imread(self.image_seg[index],cv2.COLOR_BGR2RGB)
        seg = cv2.resize(seg,(self.size,self.size) )
        seg = convert_from_color_segmentation(seg)
        #seg = torch.from_numpy(seg)
        if self.transform is not None:
            img = self.transform(img)
        return img , seg

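The traceback says so plainly: `w, h = img.size` inside `center_crop` fails because `img` is not a PIL image. That is when I realized my data-loading code was wrong.

It should be: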

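    def __getitem__(self, index):
        # Open the image with PIL rather than cv2.imread: the torchvision transforms
        # below operate on a PIL Image, not on a NumPy array.
        # .convert('RGB') guarantees 3 channels to match the 3-value Normalize.
        img = Image.open(self.image_name[index]).convert('RGB')
        # No cv2.resize for the image any more: transforms.Resize takes care of it.
        # cv2.imread expects an IMREAD_* flag; the BGR->RGB swap needs cv2.cvtColor.
        seg = cv2.imread(self.image_seg[index], cv2.IMREAD_COLOR)
        seg = cv2.cvtColor(seg, cv2.COLOR_BGR2RGB)
        # Nearest-neighbour keeps the label colours exact (no blended boundary pixels).
        seg = cv2.resize(seg, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
        seg = convert_from_color_segmentation(seg)
        if self.transform is not None:
            img = self.transform(img)
        return img, seg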

Printing the batch shapes as a test shows the problem is solved:

from torch.utils.data import DataLoader
from torchvision import transforms
transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.RandomCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet mean/std
])
# Alternative pipeline using CenterCrop:
# transform = transforms.Compose([
#     transforms.CenterCrop((278, 278)), transforms.Resize((224, 224)), transforms.ToTensor()
# ])
train_data = GetParasetData(size=224, train=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)
for i, (image, seg) in enumerate(train_loader):
    print(image.shape, seg.shape)
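One thing the printed shapes do not reveal: `Resize((300, 300))` and `RandomCrop((224, 224))` are applied to `img` only, while `seg` is resized straight to 224×224 inside `__getitem__`, so each image crop and its label map cover different regions. For segmentation the same random crop has to be applied to both. A minimal sketch using Pillow and the standard library, assuming the mask is also opened as a PIL image (the function name `paired_resize_random_crop` is hypothetical); the mask uses nearest-neighbour resampling so it never gains blended colours:

```python
import random
from PIL import Image

def paired_resize_random_crop(img, mask, resize=300, crop=224):
    """Resize an image/mask pair to resize x resize, then cut the *same*
    random crop x crop window out of both so pixels and labels stay aligned."""
    img = img.resize((resize, resize), Image.BILINEAR)   # smooth resampling for the photo
    mask = mask.resize((resize, resize), Image.NEAREST)  # no blended colours in the label map
    top = random.randint(0, resize - crop)   # randint bounds are inclusive
    left = random.randint(0, resize - crop)
    box = (left, top, left + crop, top + crop)  # PIL box: (left, upper, right, lower)
    return img.crop(box), mask.crop(box)
```

The cropped image would then go through `ToTensor()` and `Normalize(...)` alone, and the cropped mask through the label conversion, so the random part of the pipeline is shared by both.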
    

 
