pytorch数据集加载的两种方式

数据加载的两种方式:

1. 直接采用pytorch官方定义的方法: torchvision.datasets

包含一些常见的数据集
数据集无需转化为图片,直接将下载的压缩包放在root=opt.dataset_dir下即可
缺点是自己无法自定义划分数据集了

import torchvision.datasets
import torchvision.transforms as transforms
#一些转化方法
transform=transforms.Compose([transforms.Grayscale(1),
transforms.Resize(opt.picSize),
transforms.ToTensor()])
dataset=torchvision.datasets.CIFAR10(root=opt.dataset_dir,
#split=opt.data_split,
train=True, #采用训练集
transform=transform,
download=True)#是否从互联网下载
dataloader=DataLoader(dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=opt.n_cpu)

官方文档都有解释:https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/

2. 自己重写dataset方法进行加载
importtorch.utils.dataasdataf
fromtorch.utils.dataimportDataset
 
Class dataset(Dataset):
def __init__(self,input_dir):
super(dataset,self).__init__()
path_list=os.listdir(input_dir)
path_list.sort(key=lambdax:int(x.split('.')[0])) #将图片进行排序
self.input_filenames=[os.path.join(input_dir,x)forxinpath_list]
 
def__getitem__(self,index):
input=cv2.imread(self.input_filenames[index],0)
input=cv2.resize(input,(32,32))
input=ToTensor()(input)
#input=input.float()
returninput
 
def__len__(self):
returnlen(self.input_filenames)
 
test_img=dataset('./test/')#调用上面写的dataset,path写图片文件路径
dataloader=dataf.DataLoader(test_img,batch_size=1,shuffle=True)

读取dataloader: for i ,(img,target) in dataloader
其中img的形状:(batch_size, channel, size, size)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值