PyTorch数据读取

9 篇文章 0 订阅
5 篇文章 1 订阅

torch.utils.data.DataLoader

torch.utils.data.DataLoader(torch.utils.data.dataset,batch_size,shuffle,num_workers,pin_memory)

关键是这两个类:
torch.utils.data.DataLoader
torch.utils.data.dataset

import torchvision.transforms as transforms

train_loader = torch.utils.data.DataLoader(
ImageList(root=opt.root_path, fileList=opt.train_list, 
transform=transforms.Compose([ 
transforms.ToTensor(),              #将读取的图片变为Tensor类型,很重要
])),
batch_size=opt.batch_size, shuffle=True,
num_workers=opt.workers, pin_memory=True)

写一个类作为数据读取器,继承torch.utils.data.dataset

#load_imglist.py
import torch.utils.data 

from PIL import Image
import os



def default_list_reader(fileList):
    imgList = []
    with open(fileList, 'r') as file:
        for line in file.readlines():
            imgPath, label = line.strip().split(' ')
            imgList.append((imgPath, int(label)))
    return imgList


class ImageList(torch.utils.data.Dataset):
    def __init__(self, root, fileList, transform=None):
        self.root      = root
        self.imgList   = default_list_reader(fileList)
        self.transform = transform


    def __getitem__(self, index):

        imgPath, target = self.imgList[index]

        print(imgPath)

        img_loc=os.path.join(self.root, imgPath)
        img = Image.open(img_loc).convert('L')  #默认读取彩色图象,这儿转化为RGB图像

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self.imgList)

使用

for i,(input,target) in enumerate(train_loader):
    print(i,target)
    print(input.shape)

输出的最后一个结果为

(1093, 
 928
[torch.LongTensor of size 1]
)
(1L, 1L, 64L, 64L)

输出的Tensor是4维,将图像自动加了一维。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值