pytorch-自定义图片数据集

步骤:

  1. 图片分类存储在不同文件夹下
    在这里插入图片描述
  2. 写一个类继承自torch.utils.data.Dataset并重写__len()__和__getitem()__方法
  3. 打标签
  4. 写一个把图片路径与标签以”,“分隔存入csv文件,若文件存在能加载数据出来的方法
  5. __getitem()__方法把把csv的路径对应的图片读出来,进行转换,return,便于用torch.utils.data.Dataloader加载
import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


class ADS_B(Dataset):
    def __init__(self, root, resize, mode):
        super(ADS_B, self).__init__()
        # 文件根路径
        self.root = root
        # 调整大小
        self.resize = resize

        # 打标签
        self.name2label = {}
        for name in sorted(os.listdir((os.path.join(root)))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())

        # 从csv文件加载文件路径与标签
        self.imagepaths, self.labels = self.load_csv('imagepaths.csv')

        # 按作用取数据集
        if mode == 'train':  # 60%
            self.imagepaths = self.imagepaths[:int(0.6 * len(self.imagepaths))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 20%: 60% => 80%
            self.imagepaths = self.imagepaths[int(0.6 * len(self.imagepaths)):int(0.8 * len(self.imagepaths))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.imagepaths))]
        else:  # 20%: 80% => 100%
            self.imagepaths = self.imagepaths[int(0.8 * len(self.imagepaths)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    def load_csv(self, filename):
        # 不存在csv文件则创建
        if not os.path.exists(os.path.join(self.root, filename)):
            imagepaths = []
            for name in self.name2label.keys():
                # 把所有格式的图片路径全加入到list中
                imagepaths += glob.glob(os.path.join(self.root, name, '*.png'))
                imagepaths += glob.glob(os.path.join(self.root, name, '*.jpg'))

            # print(len(imagepaths), imagepaths)

            # 打乱
            random.shuffle(imagepaths)

            # 以imagepath, label的形式保存
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for imagepath in imagepaths:
                    # 取出类别名称,以得到label
                    name = imagepath.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([imagepath, label])
                print('write into csv file:', filename)

        # 存在则直接加载出来
        imagepaths, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                imagepath, label = row
                label = int(label)

                imagepaths.append(imagepath)
                labels.append(label)

        # 保证图片路径的个数与标签个数相等
        assert len(imagepaths) == len(labels)

        return imagepaths, labels

    # 数据集大小
    def __len__(self):
        return len(self.imagepaths)

    # 把像素值从[-1, 1恢复到原来的值,方便显示
    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat = (x - mean) / std
        # x = x_hat * std + mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)

        x = x_hat * std + mean

        return x

    # 取图片路径与标签
    def __getitem__(self, idx):
        imagepath, label = self.imagepaths[idx], self.labels[idx]
        # 转换器
        tf = transforms.Compose([
            # 读取图片
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            # 更改大小
            transforms.Resize((self.resize, self.resize)),
            # 随机旋转15度
            # transforms.RandomRotation(15),
            # 中心裁剪
            # transforms.CenterCrop(self.resize),
            # 转tensor
            transforms.ToTensor(),
            # 把图片转到[-1,1]上,方便训练
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

        ])

        image = tf(imagepath)
        label = torch.tensor(label)

        return image, label


def main():

    import visdom
    import time
    viz = visdom.Visdom()

    db = ADS_B('ADS-B', 224, 'train')
    # 用迭代器迭代,调用了getitem方法
    x, y = next(iter(db))
    # sample: torch.Size([3, 224, 224]) torch.Size([]) tensor(7)
    print('sample:', x.shape, y.shape, y)
    # 显示i一张图片
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    # 使用torchvision.datasets.ImageFolder做数据集
    # tf = transforms.Compose([
    #     transforms.Resize((64, 64)),
    #     transforms.ToTensor()
    # ])
    # db = torchvision.datasets.ImageFolder(root='ADS-B', transform=tf)
    # print(db.class_to_idx)

    # 使用DataLoader加载数据集
    loader = DataLoader(db, batch_size=32, shuffle=True)
    # loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8) # 多线程取数据
    # len of loader: 75 = 400 * 10 * 0.6 / 32 = batch数量
    print('len of loader:', len(loader))
    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))

        time.sleep(10)

if __name__ == '__main__':
    main()

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值