PyTorch读取自己的数据集

方法一:用ImageFloder类

图片结构如下所示: 子文件夹的名字为label

transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
   # T.RandomCrop(c.cropsize),
    T.Resize([c.imageSize, c.imageSize]), 
    T.ToTensor()
])

train_dataset = ImageFolder(
            c.TRAIN_PATH,
            transform)

# Training data loader
trainloader = DataLoader(
    train_dataset,
    batch_size=c.batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=8,
    drop_last=True
)

trainloader返回的是一个列表[image,label]

 for i_batch, datalist in enumerate(datasets.trainloader):
           
            #print(np.array(datalist).shape) #datalist是列表
           
            data, labels = datalist
            data = data.to(device)
            #print(data.shape) #[batchsize,3,128,128]

方法二:自定义Dataset类

用glob.glob读取图片(不能读子文件夹)

图片结构为:

class Hinet_Dataset(Dataset):
    def __init__(self, transforms_=None, mode="train"):

        self.transform = transforms_
        self.mode = mode
        if mode == 'train':
            # train
            self.files = natsorted(sorted(glob.glob(c.TRAIN_PATH + "/*." + c.format_train, recursive=True)))
        else:
            # test
            self.files = sorted(glob.glob(c.VAL_PATH + "/*." + c.format_val,recursive=True))
        print("Total training examples:", len(self.files))

    def __getitem__(self, index):
        try:
            image = Image.open(self.files[index])
            image = to_rgb(image)
            item = self.transform(image)
            return item  #返回的形式为一张图片

        except:
            return self.__getitem__(index + 1)

    def __len__(self):
        if self.mode == 'shuffle':
            return max(len(self.files_cover), len(self.files_secret))

        else:
            return len(self.files)

 _getitem_(self,index)函数定义放回的形式  即trainloader加载的形式,这里返回是只有图片没有标签

train_datset = Hinet_Dataset(transforms_=transform, mode="train")

trainloader = DataLoader(
    train_dataset,
    batch_size=c.batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=8,
    drop_last=True
)

使用时:

      for i_batch, data in enumerate(datasets.trainloader):
           
            #print(data.shape) #[batchsize,3,128,128] 
            data = data.to(device)
            cover = data[data.shape[0] // 2:] #data里的前一半
            secret = data[:data.shape[0] // 2]#data里的后一半

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值