Pytorch(四)——数据加载及预处理

1. 使用torch.utils.data.Dataset进行数据读取

  • 通过继承该类进行数据读取
    文件路径为:
    在这里插入图片描述
import torch
from torch.utils.data import Dataset,DataLoader
import os
import csv
import glob
import random
from PIL import Image
from torchvision import transforms
import visdom
from torchvision.datasets import ImageFolder

class AnimalData(Dataset):
    def __init__(self,root,resize = [28,28],mode="train"):
        super(AnimalData,self).__init__()
        self.root = root
        self.resize = resize # [h,w]

        # 依据子文件夹名字获取各个类别的标签
        self.class2label = {}
        for name in sorted(os.listdir(os.path.join(self.root))):
            if not os.path.isdir(os.path.join(self.root,name)):
                continue
            self.class2label[name] = len(self.class2label.keys())
        print(self.class2label)

        # 从csv文件中加载数据的存储路径及其标签
        images,labels = self.load_csv("animal.csv")
        # 根据任务需求,返回数据
        if mode == "train":
            self.images = images[:int(0.6*len(images))]
            self.labels = labels[:int(0.6*len(images))]
        elif mode == "val":
            self.images = images[int(0.6 * len(images)):int(0.8 * len(images))]
            self.labels = labels[int(0.6 * len(images)):int(0.8 * len(images))]
        elif mode == "test":
            self.images = images[int(0.8 * len(images)):]
            self.labels = labels[int(0.8 * len(images)):]

    def load_csv(self,file_name):

        if not os.path.exists(file_name):
            images = []
            for name in self.class2label.keys():
                # glob.glob()方法可以匹配该路径下的文件,返回完整路径
                images += glob.glob(os.path.join(self.root,name,"*.png"))
                images += glob.glob(os.path.join(self.root,name,".jpg"))

            # 打乱数据顺序
            random.shuffle(images)

            # 写入csv文件,便于下次读取
            with open(file_name,"w",encoding="utf-8",newline="") as f:
                writer = csv.writer(f)
                for path in images:
                    name = path.split(os.sep)[1]
                    label = self.class2label[name]
                    writer.writerow([path,label])

        # 通过csv加载数据
        with open(file_name,"r",encoding="utf-8") as f:
            reader = csv.reader(f)
            images = []
            labels = []
            for line in reader:
                images.append(line[0])
                labels.append(int(line[1]))
        return images,labels

    # 重写该方法,返回数据大小
    def __len__(self):
        return len(self.images)

    # 反标准化,便于可视化
    def de_normalize(self,x_hat):
        mean = torch.tensor([0.485, 0.456, 0.406]).unsqueeze(1).unsqueeze(1)
        std = torch.tensor([0.229, 0.224, 0.225]).unsqueeze(1).unsqueeze(1)
        x = x_hat *std + mean
        return x

    # 重写该方法,返回Tensor格式的数据及标签
    def __getitem__(self,idx):
        label = torch.tensor(self.labels[idx])
        tf = transforms.Compose([
             lambda x: Image.open(x).convert("RGB"), # 读取图片
             transforms.Resize([int(self.resize[0]*1.25),int(self.resize[1]*1.25)]),
             transforms.RandomRotation(15), # 数据增强
             transforms.CenterCrop(self.resize), # 中心化裁剪
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
         ])
        image = tf(self.images[idx])

        return image,label

if __name__ == '__main__':
    resize = [128,100]
    db = AnimalData(root="animal",resize=resize)
{'cat': 0, 'dog': 1, 'rabbit': 2}

2. 使用torch.utils.data.DataLoader进行数据加载

if __name__ == '__main__':

    resize = [128,100]
    db = AnimalData(root="animal",resize=resize)

    it_db = iter(db)
    vis = visdom.Visdom()
    image,label = next(it_db)
    vis.image(db.de_normalize(image),win="iter_image",opts=dict(title="iter_image"))

    # 使用数据加载器,设定batch
    loader = DataLoader(dataset=db,batch_size=16,shuffle=True,num_workers=8) # num_workers参数为多线程读取数据
    for x,y in loader:
        vis.images(db.de_normalize(x),win="batch_imags",nrow=4,opts=dict(title="batch"))

在这里插入图片描述

3. 使用torchvision.datasets.ImageFolder进行快速读取数据

  # ImageFolder 可以一步实现上述过程
    tf = transforms.Compose([

        transforms.Resize([int(resize[0] * 1.25), int(resize[1] * 1.25)]),
        transforms.RandomRotation(15),  # 数据增强
        transforms.CenterCrop(resize),  # 中心化裁剪
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    db = ImageFolder(root = "animal",
                     transform=tf)

by CyrusMay 2022 06 30

一生要有多少的辗转
才能走到幸福的彼岸
才能 活得此生无恨无憾
平凡却不平淡
——————五月天(青空未来)——————

  • 7
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值