pytorch进阶-图像预处理transform和数据读取Dataloader

东阳的学习记录,坚持就是胜利!

Dataloader和transform的关系

dataloader是一个多进程迭代器

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False)[source]
Data loader. Combines a dataset and a sampler, and providessingle- or multi-process iterators over the dataset.

建议自己debug一遍,在这里不详细写了:

  • transform:在__getitem__()中调用
    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label
  • collate_fn:collate_fn的作用是将生成的data列表转化成(B, C, H, W)的四维张量模式。
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

在这里插入图片描述

transform

例子:

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),

    # 1 Pad
    transforms.Pad(padding=32, fill=(255, 0, 0), padding_mode='constant'),
    transforms.Pad(padding=(8, 64), fill=(255, 0, 0), padding_mode='constant'),
    transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='constant'),
    transforms.Pad(padding=(8, 16, 32, 64), fill=(255, 0, 0), padding_mode='symmetric'),
 
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

transform的常用方法

  • 数据中心化
  • 数据标准化
  • 缩放
  • 裁剪
  • 旋转
  • 翻转
  • 填充
  • 噪声添加
  • 灰度变换
  • 线性变换
  • 仿射变换(对形状进行变换)
  • 亮度、饱和度及对比度变换

怎么选择适合的数据增强/处理方法

原则: 让 训练集 与 测试集 更接近(而不是瞎增强)

  • 空间位置:平移、旋转
  • 色彩:灰度图,色 彩 抖动
  • 形状:仿射变换
  • 上下文场景:遮挡 , 填充

Dataset类

  • getitem():接收一个索引,返 回 一个 样 本
  • collate_fn():collate_fn的作用是将生成的data列表转化成(B, C, H, W)的四维张量模式。

Dataloader类

  • dataset: D ataset 类 , 决定数据从 哪 读取
    及如何读取
  • batchsize : 批大小
  • num _works: 是否多进程读取数 据
  • shuffle: 每个epoch是否乱序
  • drop_last:当 样 本数 不 能被 batchsize整除时,是否舍弃最后一批数据

在Dataset中加入数据增强的方法

例子:num_rep:表示数据集扩张倍数

    def __len__(self):
        return len(self.annotations) * self.num_rep

    def __getitem__(self, item) -> (Study, Any, (torch.Tensor, torch.Tensor)):
        item = item % len(self.annotations)
        key, (v_annotation, d_annotation) = self.annotations[item]
        return self.studies[key[0]], key, v_annotation, d_annotation
不要在getitem中写太多的逻辑

可以使用一个静态方法,避免多次重复调用

@staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

东阳z

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值