pytorch-自定义图片数据集

步骤:

  1. 图片分类存储在不同文件夹下
    在这里插入图片描述
  2. 写一个类继承自torch.utils.data.Dataset并重写__len__()和__getitem__()方法
  3. 打标签
  4. 写一个方法:把图片路径与标签以“,”分隔存入csv文件;若csv文件已存在,则直接从中加载数据
  5. __getitem__()方法把csv中路径对应的图片读出来,进行转换后return,便于用torch.utils.data.DataLoader加载
import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class ADS_B(Dataset):
    """Image-classification dataset read from a folder-per-class directory.

    Every immediate sub-directory of ``root`` is treated as one class; class
    names are sorted and numbered from 0 to build ``name2label``. The list of
    ``(image path, label)`` pairs is shuffled once and persisted to
    ``<root>/imagepaths.csv``; later instances reload that file, so the
    'train', 'val' and test views are disjoint slices of the same ordering.

    Args:
        root: Directory containing one sub-directory per class.
        resize: Side length (pixels) of the square image returned by
            ``__getitem__``.
        mode: ``'train'`` selects the first 60% of the samples, ``'val'`` the
            next 20%, and any other value the final 20%.
    """

    # Per-channel statistics shared by __getitem__ (Normalize) and denormalize
    # so the two can never drift apart.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, root, resize, mode):
        super().__init__()
        # Root directory of the dataset.
        self.root = root
        # Target side length used when resizing images.
        self.resize = resize

        # Assign labels: sorted class-directory names -> 0, 1, 2, ...
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        # Load (or create) the persisted list of image paths and labels.
        imagepaths, labels = self.load_csv('imagepaths.csv')

        # Compute both split boundaries once from the FULL sample count, then
        # apply the same slice to paths and labels. Deriving a boundary from
        # an already-sliced list would leave paths and labels misaligned.
        total = len(imagepaths)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if mode == 'train':  # 60%
            span = slice(None, train_end)
        elif mode == 'val':  # 20%: 60% => 80%
            span = slice(train_end, val_end)
        else:  # 20%: 80% => 100%
            span = slice(val_end, None)
        self.imagepaths = imagepaths[span]
        self.labels = labels[span]

    def load_csv(self, filename):
        """Return ``(imagepaths, labels)`` from ``<root>/<filename>``.

        If the CSV file does not exist yet, it is created first: all ``*.png``
        and ``*.jpg`` files in every class directory are collected, shuffled,
        and written as ``path,label`` rows.

        Args:
            filename: Name of the CSV file, relative to ``self.root``.

        Returns:
            A pair of equally long lists: image paths (str) and labels (int).

        Raises:
            ValueError: If a non-empty row does not have exactly two columns
                or its label is not an integer.
        """
        csv_path = os.path.join(self.root, filename)

        # Create the CSV file on first use.
        if not os.path.exists(csv_path):
            imagepaths = []
            for name in self.name2label:
                # Escape the directory part so characters such as '[' in the
                # path are not interpreted as glob wildcards.
                class_dir = glob.escape(os.path.join(self.root, name))
                for pattern in ('*.png', '*.jpg'):
                    imagepaths += glob.glob(os.path.join(class_dir, pattern))

            # Shuffle once; the order is frozen by writing it to disk.
            random.shuffle(imagepaths)

            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for imagepath in imagepaths:
                    # The parent directory name is the class name.
                    name = os.path.basename(os.path.dirname(imagepath))
                    writer.writerow([imagepath, self.name2label[name]])
            print('write into csv file:', filename)

        # Load the (possibly just written) CSV file.
        imagepaths, labels = [], []
        with open(csv_path, newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines, e.g. a trailing newline.
                    continue
                imagepath, label = row
                imagepaths.append(imagepath)
                labels.append(int(label))

        return imagepaths, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.imagepaths)

    def denormalize(self, x_hat):
        """Undo the per-channel normalization applied in ``__getitem__``.

        Computes ``x = x_hat * std + mean`` so that the tensor holds the
        values it had before normalization, which is convenient for display.

        Args:
            x_hat: Normalized image tensor of shape ``[3, h, w]``, or a batch
                ``[b, 3, h, w]`` (the constants broadcast over leading dims).

        Returns:
            Tensor of the same shape as ``x_hat``.
        """
        # mean/std: [3] => [3, 1, 1] so they broadcast over height and width;
        # created on x_hat's device so GPU tensors work as well.
        mean = torch.tensor(self._MEAN, device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(self._STD, device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean

    def __getitem__(self, idx):
        """Load, transform and return the ``idx``-th ``(image, label)`` pair.

        Returns:
            ``image``: float tensor of shape ``[3, resize, resize]``,
            normalized per channel as ``(x - mean) / std``.
            ``label``: zero-dimensional integer tensor with the class index.
        """
        imagepath, label = self.imagepaths[idx], self.labels[idx]
        tf = transforms.Compose([
            # String path => RGB image data.
            lambda x: Image.open(x).convert('RGB'),
            # Resize to a fixed square.
            transforms.Resize((self.resize, self.resize)),
            # Convert to a tensor.
            transforms.ToTensor(),
            # Normalize each channel: (x - mean) / std.
            transforms.Normalize(mean=list(self._MEAN),
                                 std=list(self._STD)),
        ])

        image = tf(imagepath)
        label = torch.tensor(label)

        return image, label


def main():
    """Demo: show one sample and then successive batches of ADS_B in visdom.

    Builds the 'train' split of the dataset rooted at ``ADS-B`` with images
    resized to 224, sends the first sample to a visdom window, then iterates
    a shuffled DataLoader and displays each batch with its labels, pausing
    10 seconds between batches.
    """
    import time
    import visdom

    board = visdom.Visdom()

    dataset = ADS_B('ADS-B', 224, 'train')

    # Iterating the dataset goes through __getitem__; grab the first sample.
    image, target = next(iter(dataset))
    print('sample:', image.shape, target.shape, target)

    # Display that single image with the normalization undone.
    board.image(
        dataset.denormalize(image),
        win='sample_x',
        opts={'title': 'sample_x'},
    )

    # Alternative (not used here): torchvision.datasets.ImageFolder can build
    # a dataset straight from the same folder-per-class layout.

    # Wrap the dataset in a DataLoader; pass num_workers=... to load batches
    # in parallel worker processes.
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    # len(loader) is the number of batches per epoch.
    print('len of loader:', len(loader))

    for images, targets in loader:
        board.images(
            dataset.denormalize(images),
            nrow=8,
            win='batch',
            opts={'title': 'batch'},
        )
        board.text(
            str(targets.numpy()),
            win='label',
            opts={'title': 'batch-y'},
        )

        # Leave each batch on screen for a while before replacing it.
        time.sleep(10)

# Run the visdom demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在Pytorch中加载图片数据集一般有两种方法。第一种是使用torchvision.datasets中的ImageFolder来读取图片,然后用DataLoader来并行加载,适合图片分类问题,简单但不灵活。\[1\]您可以通过设置各种参数,例如批处理大小以及是否在每个epoch之后对数据打乱顺序,来自定义DataLoader。例如,可以使用以下代码创建一个DataLoader:dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)。\[2\]第二种方法是通过继承torch.utils.data.Dataset实现用户自定义读取数据集,然后用DataLoader来并行加载,这种方法更为灵活。您可以将分类图片的父目录作为路径传递给ImageFolder(),并传入transform来加载数据集。然后可以使用DataLoader加载数据,并构建网络训练。\[3\] #### 引用[.reference_title] - *1* [Pytorch加载图片数据集的两种方式](https://blog.csdn.net/weixin_43917574/article/details/114625616)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [Pytorch加载图像数据](https://blog.csdn.net/qq_28368377/article/details/105635898)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [pytorch加载自己的图片数据集的两种方法](https://blog.csdn.net/qq_53345829/article/details/124308515)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值