一.DataLoader配合数据集的用法
二.若使用别人现成的数据集（例如 torchvision 内置的 CIFAR10），用以下代码直接加载即可：
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Load the CIFAR10 test split as tensors. download=True fetches the archive into
# ./dataset when it is not already there; otherwise the local copy is used.
test_data = torchvision.datasets.CIFAR10(
    "./dataset",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# Batches of 64 samples, reshuffled on each pass; drop_last=False keeps the
# final, smaller batch instead of discarding it.
test_loader = DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=False,
)
# Log every batch of images under the "test_data" tag, one global step per batch.
# The with-block closes the writer (flushing pending events) even if the loop raises.
with SummaryWriter("dataloader") as writer:
    for step, (imgs, targets) in enumerate(test_loader):
        writer.add_images("test_data", imgs, step)
三.若是自己构建数据集，则用第一天学的方法：继承 Dataset 父类，重写构造方法 __init__ 以及 __len__ 和 __getitem__ 两个魔法方法。详情可看：
自己构建数据集的方法
四.用 DataLoader 加载自己构建的数据集
import torchvision
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import os
class MyDate(Dataset):
    """Image dataset built from a single class folder ``root_path/label_path``.

    Each directory entry in that folder is one sample, and the folder name
    (``label_path``) is returned as the label of every sample.

    Args:
        root_path: Directory containing the class sub-folder(s).
        label_path: Name of the sub-folder to load; also used as the label.
        transform: Optional callable applied to each RGB ``PIL.Image``.
            When omitted, images are resized to 50x50 and converted with
            ``ToTensor()``.
    """

    def __init__(self, root_path, label_path, transform=None):
        self.root_path = root_path
        self.label_path = label_path
        self.path = os.path.join(self.root_path, self.label_path)
        # os.listdir order is arbitrary; sorting makes index -> file stable
        # across runs and platforms.
        self.img_path = sorted(os.listdir(self.path))
        # The dataset owns its transform instead of reading module-level
        # globals that must be defined elsewhere before items are fetched.
        self.transform = (
            transform
            if transform is not None
            else torchvision.transforms.Compose(
                [
                    torchvision.transforms.Resize((50, 50)),
                    torchvision.transforms.ToTensor(),
                ]
            )
        )

    def __getitem__(self, item):
        """Return ``(transformed_image, label)`` for the ``item``-th file."""
        img_name = self.img_path[item]
        img_truepath = os.path.join(self.path, img_name)
        # Close the file handle immediately: convert() returns a fully loaded
        # copy. Forcing RGB also gives grayscale/RGBA/palette images three
        # channels so the default collate can stack them into one batch.
        with Image.open(img_truepath) as raw:
            img = raw.convert("RGB")
        return self.transform(img), self.label_path

    def __len__(self):
        """Return the number of entries in the class folder."""
        return len(self.img_path)
# Load the "ants" class of the hymenoptera training split.
root = "hymenoptera_data/hymenoptera_data/train"
label = "ants"
# Resize-to-50x50 and to-tensor transforms, kept as module-level names because
# MyDate may look them up by name when loading items.
trans_tensor = torchvision.transforms.ToTensor()
trans_resize = torchvision.transforms.Resize((50, 50))
ants_dataset = MyDate(root, label)
data_loader = DataLoader(
    dataset=ants_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)
# Log every batch of images under the "test" tag, one global step per batch.
# The with-block closes the writer so buffered events are flushed to ./logs.
with SummaryWriter("logs") as writer:
    for i, (imgs, labels) in enumerate(data_loader):
        writer.add_images("test", imgs, i)