PyTorch数据处理

0. 自定义数据集要自己写Dataset类并根据需求写DataLoaderiter

torch.utils.data包括Dataset(抽象类)和DataLoader
自定义数据集需要继承Dataset类并实现__len__(返回数据的大小)和__getitem__(通过给定索引获取数据、标签或一个样本)。
__getitem__一次只能获取一个样本,所以通过DataLoader来定义一个新的迭代器实现批量读取

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# torch.utils.data包括Dataset(抽象类)和DataLoader。
# 自定义数据集需要继承Dataset类并实现__len__(返回数据的大小)和__getitem__(通过给定索引获取数据、标签或一个样本)。
# __getitem__一次只能获取一个样本,所以通过DataLoader来定义一个新的迭代器实现批量读取
class myData(Dataset):
    def __init__(self):
        super(myData, self).__init__()
        self.img_pth_lst = [r'D:\cat\cat01.jpg',r'D:\cat\cat02.jpg',r'D:\cat\cat03.jpg',r'D:\cat\cat04.jpg',r'D:\cat\cat05.jpg',r'D:\cat\cat06.jpg',
                            r'D:\dog\dog01.jpg',r'D:\dog\dog02.jpg',r'D:\dog\dog03.jpg',r'D:\dog\dog04.jpg',r'D:\dog\dog05.jpg',r'D:\dog\dog06.jpg',]
        self.labels = [0,0,0,0,0,0,1,1,1,1,1,1]

    def __len__(self):
        return len(self.img_pth_lst)

    def __getitem__(self, item):
        img_pth = self.img_pth_lst[item]
        label = torch.tensor(self.labels[item])
        return img_pth, label

if __name__ == '__main__':
    ### 直接通过Dataset访问数据,一次只能获取一个数据
    my_data = myData()
    print(my_data.__len__())
    print(my_data[2])

    ### 通过DataLoader可以获取批量数据
    batch_size = 3
    my_data_loader = DataLoader(my_data, batch_size=batch_size,shuffle=True, num_workers=2)
    # 通过for-loop访问
    for i, data in enumerate(my_data_loader):
        print('batch_idx:', i)
        img_pths, labels = data
        print('image_paths:', img_pths)
        print('labels:', labels)
    print('*'*30)
    # 通过构造迭代器访问
    my_data_iter = iter(my_data_loader)
    for i in range(int(my_data.__len__()/batch_size)):
        img_pths, labels = next(my_data_iter)
        print('batch_idx:', i)
        img_pths, labels = data
        print('image_paths:', img_pths)
        print('labels:', labels)

1. torch.utils.data.Dataset

创建自己的Dataset类,然后通过索引或for-loop的方式访问、遍历。

import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
from torch.utils import data

transform = T.Compose([T.Resize(224),
                       T.CenterCrop(224),
                       T.ToTensor(),
                       T.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])

class DogCat(data.Dataset):
    def __init__(self, root, transforms=None):
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transforms = transforms

    def __getitem__(self, idx):
        img_path = self.imgs[idx]
        label = 1 if 'dog' == img_path.split('\\')[-1].split('.')[0] else 0
        data = Image.open(img_path)
        if self.transforms:
            data = self.transforms(data)
        return data, label

    def __len__(self):
        return len(self.imgs)

dataset = DogCat(r'D:\03.Data\01.CatDog\image\train', transforms=transform)

ct = 0
img, label = dataset[0]      # 通过索引的方式访问。

for img, label in dataset:   # 通过for-loop的方式遍历。
    ct += 1
    print(ct, ':', img.size(), label)

print()

2. torchvision.datasets.ImageFolder

直接遍历文件夹,文件夹的最后一级目录是类别名。

from torchvision.datasets import ImageFolder
from torchvision import transforms as T

dataset = ImageFolder(r'D:\03.Data\01.CatDog\image\train')
print(type(dataset))
print(dataset.class_to_idx)

if 0:
    for pth, label in dataset.imgs:
        print(pth, label)

print('*'*30)

transform = T.Compose([T.RandomCrop(224),
                       T.RandomHorizontalFlip(),
                       T.ToTensor(),
                       T.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])])
dataset = ImageFolder(r'D:\03.Data\01.CatDog\image\train', transform=transform)
print(dataset[0][0].size())

to_img = T.ToPILImage()
a = to_img(dataset[10][0]*0.5+0.5)
a.show()

3. torch.utils.data.DataLoader

将自定义的Dataset类放到DataLoader中然后可以通过创建迭代器iter再通过next遍历、访问。

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
from torch.utils import data

transform = T.Compose([T.Resize(224),
                       T.CenterCrop(224),
                       T.ToTensor(),
                       T.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])

class DogCat(data.Dataset):
    def __init__(self, root, transforms=None):
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transforms = transforms

    def __getitem__(self, idx):
        img_path = self.imgs[idx]
        label = 1 if 'dog' == img_path.split('\\')[-1].split('.')[0] else 0
        data = Image.open(img_path)
        if self.transforms:
            data = self.transforms(data)
        return data, label

    def __len__(self):
        return len(self.imgs)

dataset = DogCat(r'D:\03.Data\01.CatDog\image\train\dog', transforms=transform)
if 0:
    ct = 0
    img, label = dataset[0]
    for img, label in dataset:
        ct += 1
        print(ct, ':', img.size(), label)
print()

dataloader = DataLoader(dataset, batch_size=10, shuffle=True, sampler=None, num_workers=0,
                        pin_memory=False, drop_last=False)

# 使用方式一,使用next不断获取一个batch的数据
dataiter = iter(dataloader)
imgs, labels = next(dataiter)

print(imgs.size())

# 使用方式二,在for循环中不断获取一个batch的数据
for batch_data, batch_labels in dataloader:
    print(batch_data.size(), batch_labels.size())

print()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值