0. 自定义数据集要自己写 `Dataset` 类,并根据需求写 `DataLoader`、`iter`
`torch.utils.data` 包括 `Dataset`(抽象类)和 `DataLoader`。
自定义数据集需要继承 `Dataset` 类并实现 `__len__`(返回数据集的大小)和 `__getitem__`(通过给定索引获取数据、标签或一个样本)。
`__getitem__` 一次只能获取一个样本,所以通过 `DataLoader` 来定义一个新的迭代器,实现批量读取。
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# torch.utils.data包括Dataset(抽象类)和DataLoader。
# 自定义数据集需要继承Dataset类并实现__len__(返回数据的大小)和__getitem__(通过给定索引获取数据、标签或一个样本)。
# __getitem__一次只能获取一个样本,所以通过DataLoader来定义一个新的迭代器实现批量读取
class myData(Dataset):
    """Toy map-style dataset pairing hard-coded image paths with class labels.

    The first six entries are cat images (label 0) and the last six are dog
    images (label 1). Only the path strings are stored; no image file is
    opened by this class.
    """

    def __init__(self):
        super().__init__()
        self.img_pth_lst = [
            r'D:\cat\cat01.jpg',
            r'D:\cat\cat02.jpg',
            r'D:\cat\cat03.jpg',
            r'D:\cat\cat04.jpg',
            r'D:\cat\cat05.jpg',
            r'D:\cat\cat06.jpg',
            r'D:\dog\dog01.jpg',
            r'D:\dog\dog02.jpg',
            r'D:\dog\dog03.jpg',
            r'D:\dog\dog04.jpg',
            r'D:\dog\dog05.jpg',
            r'D:\dog\dog06.jpg',
        ]
        # One label per path, in the same order: 0 for cats, 1 for dogs.
        self.labels = [0] * 6 + [1] * 6

    def __len__(self):
        """Return the number of samples (one per stored path)."""
        return len(self.img_pth_lst)

    def __getitem__(self, item):
        """Return ``(image_path, label)`` for index ``item``.

        The label is wrapped in a 0-dim tensor; the path stays a plain string.
        """
        return self.img_pth_lst[item], torch.tensor(self.labels[item])
if __name__ == '__main__':
    # Accessing the Dataset directly yields a single sample at a time.
    my_data = myData()
    print(len(my_data))
    print(my_data[2])

    # Wrapping it in a DataLoader yields shuffled batches of `batch_size`.
    # The __main__ guard matters here: num_workers > 0 starts worker processes.
    batch_size = 3
    my_data_loader = DataLoader(my_data, batch_size=batch_size, shuffle=True, num_workers=2)

    # Option 1: iterate over the loader with a for-loop.
    for i, (img_pths, labels) in enumerate(my_data_loader):
        print('batch_idx:', i)
        print('image_paths:', img_pths)
        print('labels:', labels)

    print('*'*30)

    # Option 2: build an iterator explicitly and pull batches with next().
    # len(my_data_loader) is the number of batches, including a final partial
    # batch, so next() is called exactly as many times as there are batches.
    my_data_iter = iter(my_data_loader)
    for i in range(len(my_data_loader)):
        # Use the batch just fetched; do not reuse the variable left over
        # from the for-loop above.
        img_pths, labels = next(my_data_iter)
        print('batch_idx:', i)
        print('image_paths:', img_pths)
        print('labels:', labels)
1. torch.utils.data.Dataset
创建自己的Dataset类,然后通过索引或for-loop的方式访问、遍历。
import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
from torch.utils import data
# Preprocessing applied to every image: resize, center-crop to 224, convert to
# a tensor, then normalize each of the three channels as (x - 0.5) / 0.5.
transform = T.Compose(
    [
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
class DogCat(data.Dataset):
    """Map-style dataset over every entry directly inside ``root``.

    A sample is labelled 1 when the part of its file name before the first
    ``.`` is exactly ``dog`` (e.g. ``dog.12.jpg``) and 0 otherwise.

    Args:
        root: Directory whose entries are all treated as image files.
        transforms: Optional callable applied to each opened PIL image.
    """

    def __init__(self, root, transforms=None):
        self.imgs = [os.path.join(root, name) for name in os.listdir(root)]
        self.transforms = transforms

    @staticmethod
    def _label_from_path(img_path):
        """Return 1 if the file name's first dot-separated token is ``dog``, else 0."""
        # Accept both '\\' and '/' as separators so the label does not depend
        # on the operating system or on how ``root`` was written.
        file_name = img_path.replace('\\', '/').rsplit('/', 1)[-1]
        return 1 if file_name.split('.')[0] == 'dog' else 0

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is the opened PIL image, passed through ``transforms`` when
        one was supplied.
        """
        img_path = self.imgs[idx]
        label = self._label_from_path(img_path)
        img = Image.open(img_path)
        if self.transforms:
            img = self.transforms(img)
        return img, label

    def __len__(self):
        """Return the number of entries found in ``root``."""
        return len(self.imgs)
dataset = DogCat(r'D:\03.Data\01.CatDog\image\train', transforms=transform)
ct = 0
# Access a single sample by index.
img, label = dataset[0]
# Walk the whole dataset with a for-loop, numbering the samples from 1.
for ct, (img, label) in enumerate(dataset, start=1):
    print(ct, ':', img.size(), label)
print()
2. torchvision.datasets.ImageFolder
直接遍历文件夹,文件夹的最后一级目录是类别名。
from torchvision.datasets import ImageFolder
from torchvision import transforms as T
# ImageFolder scans the directory; each sub-directory name becomes a class.
dataset = ImageFolder(r'D:\03.Data\01.CatDog\image\train')
print(type(dataset))
print(dataset.class_to_idx)

# Disabled by default: change the condition to list every (path, label) pair.
if 0:
    for pth, label in dataset.imgs:
        print(pth, label)
print('*'*30)

transform = T.Compose([
    # pad_if_needed keeps RandomCrop from raising on images that are smaller
    # than the 224x224 crop; larger images are cropped exactly as before.
    T.RandomCrop(224, pad_if_needed=True),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
dataset = ImageFolder(r'D:\03.Data\01.CatDog\image\train', transform=transform)
print(dataset[0][0].size())

# Undo the normalization (x * std + mean) before converting back to an image.
to_img = T.ToPILImage()
a = to_img(dataset[10][0]*0.5+0.5)
a.show()
3. torch.utils.data.DataLoader
将自定义的 `Dataset` 类放到 `DataLoader` 中,然后可以通过 `iter` 创建迭代器,再通过 `next` 遍历、访问。
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
from torch.utils import data
# Preprocessing applied to every image: resize, center-crop to 224, convert to
# a tensor, then normalize each of the three channels as (x - 0.5) / 0.5.
transform = T.Compose(
    [
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
class DogCat(data.Dataset):
    """Map-style dataset over every entry directly inside ``root``.

    A sample is labelled 1 when the part of its file name before the first
    ``.`` is exactly ``dog`` (e.g. ``dog.12.jpg``) and 0 otherwise.

    Args:
        root: Directory whose entries are all treated as image files.
        transforms: Optional callable applied to each opened PIL image.
    """

    def __init__(self, root, transforms=None):
        self.imgs = [os.path.join(root, name) for name in os.listdir(root)]
        self.transforms = transforms

    @staticmethod
    def _label_from_path(img_path):
        """Return 1 if the file name's first dot-separated token is ``dog``, else 0."""
        # Accept both '\\' and '/' as separators so the label does not depend
        # on the operating system or on how ``root`` was written.
        file_name = img_path.replace('\\', '/').rsplit('/', 1)[-1]
        return 1 if file_name.split('.')[0] == 'dog' else 0

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is the opened PIL image, passed through ``transforms`` when
        one was supplied.
        """
        img_path = self.imgs[idx]
        label = self._label_from_path(img_path)
        img = Image.open(img_path)
        if self.transforms:
            img = self.transforms(img)
        return img, label

    def __len__(self):
        """Return the number of entries found in ``root``."""
        return len(self.imgs)
dataset = DogCat(r'D:\03.Data\01.CatDog\image\train\dog', transforms=transform)

# Disabled by default: change the condition to walk the dataset sample by sample.
if 0:
    ct = 0
    img, label = dataset[0]
    for ct, (img, label) in enumerate(dataset, start=1):
        print(ct, ':', img.size(), label)
    print()

# Batch the dataset: 10 samples per batch, reshuffled on every pass, loaded in
# the main process, and keeping the final (possibly smaller) batch.
dataloader = DataLoader(
    dataset,
    batch_size=10,
    shuffle=True,
    sampler=None,
    num_workers=0,
    pin_memory=False,
    drop_last=False,
)

# Usage 1: create an iterator and fetch one batch at a time with next().
dataiter = iter(dataloader)
imgs, labels = next(dataiter)
print(imgs.size())

# Usage 2: let a for-loop pull one batch per iteration.
for batch_data, batch_labels in dataloader:
    print(batch_data.size(), batch_labels.size())
print()