pytorch-Dataset类的基本写法
1、Dataset类简介
是pytorch中用来说明数据集的类,主要说明文件路径、数据集的特性、观察数据、格式转换(转为tensor)等
2、DataLoader类简介
是pytorch中用来加载数据集的类,主要负责把数据按小批量(batch)取出、打乱顺序(shuffle)并拼接成一个批次;完整遍历一次数据集即为一个训练轮次(epoch)
3、Dataset说明
(1) MyData 继承了Dataset类,要重写三个特殊方法:`__init__`、`__getitem__`、`__len__`(注意方法名两侧各有两个下划线,中间没有空格)
(2) self 相当于其他语言中的 this;self.x 中的 x 是实例属性,在本类的所有方法中都可以通过 self.x 访问
(3) 路径字符串连接要用os.path.join(root_dir, sub_dir)
4、Dataset代码
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
    """Image-folder dataset where every image in one sub-directory shares a label.

    The images are expected at ``<root_dir>/<label>/``; the sub-directory
    name doubles as the label returned with each sample.
    """

    def __init__(self, root_dir, label):
        """Index the image files found under ``root_dir/label``.

        Args:
            root_dir: Directory that contains one sub-directory per label.
            label: Name of the sub-directory to read; also used as the label.

        Raises:
            OSError: If ``root_dir/label`` cannot be listed.
        """
        self.root_dir = root_dir
        self.label = label
        self.path = os.path.join(self.root_dir, self.label)
        # os.listdir returns entries in arbitrary order; sort them so that
        # a given index always maps to the same file across runs.
        self.img_names = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is opened with ``PIL.Image.open`` and returned as-is
        (no conversion to a tensor is performed here).
        """
        img_name = self.img_names[idx]
        # self.path already is root_dir/label, so only the file name is joined.
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        return img, self.label

    def __len__(self):
        """Return the number of image files indexed by this dataset."""
        return len(self.img_names)
if __name__ == "__main__":
    # Demo: build a dataset over the "ants" folder of the training split.
    root_dir, ants_label = "data/hymenoptera_data/train", "ants"
    ants_dataset = MyData(root_dir=root_dir, label=ants_label)
5、DataLoader代码
import torchvision
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
# Preprocessing applied to every sample: a single ToTensor step wrapped in
# Compose, so more transforms can be appended to the list later.
trans = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)
if __name__ == '__main__':
    # Load the CIFAR10 test split (downloaded into ./dataset if missing) and
    # wrap it in a DataLoader that yields shuffled batches of 64 samples.
    test_set = torchvision.datasets.CIFAR10(
        root="./dataset", train=False, transform=trans, download=True
    )
    test_loader = DataLoader(
        dataset=test_set,
        batch_size=64,
        shuffle=True,
        num_workers=0,
        drop_last=False,
    )

    writer = SummaryWriter("dataloader")
    try:
        for epoch in range(2):
            # One tag per epoch, so the two shuffled passes over the data
            # show up as separate image series.
            tag = "Epoch: {}".format(epoch)
            # enumerate restarts the step counter at 0 for every epoch.
            for step, (img, _) in enumerate(test_loader):
                writer.add_images(tag, img, step)
    finally:
        # Close the writer even if loading or logging raises, so pending
        # events are flushed to disk.
        writer.close()