数据增强 cutout改进imbalance

总结:

数据加载的时候会同时加载数据和数据增强方式。数据增强的时候会默认调用加载数据集时的getitem方法,去获取对应的数据和标签。

定义好普通的transformers

    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

结合有Cutout的自定义Transformer

自定义的transformer可以接收两个参数:img,label。

传统的transformer只接受一个参数:img

class ConditionalTransform:
    def __init__(self, transform, num_per_cls_dict):
        self.transform = transform
        self.num_per_cls_dict = num_per_cls_dict
        self.n_holes_dict = {}

    def __call__(self, img, label):
        total_samples = sum(self.num_per_cls_dict.values())
        cls_num_list = list(self.num_per_cls_dict.values())
        per_cls_weights = 1.0 / np.array(cls_num_list)
        per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
        per_cls_weights = torch.FloatTensor(per_cls_weights).to(torch.device('cpu'))  # 假设我们在 CPU 上运行

        # 计算 n_holes 的数量
        n_holes = 1 + int(per_cls_weights[label] * 3)  # 确保 n_holes 在 1 到 4 之间
        n_holes = min(4, max(1, n_holes))

        # 保存每个类别的 n_holes 数量
        self.n_holes_dict[label] = n_holes

        img = self.transform(img)
        cutout_transform = Cutout(n_holes=n_holes, length=16)
        return cutout_transform(img)

cutout

class Cutout(object):
    """Randomly mask out one or more patches from an image.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

        return img

数据加载

标准的CIFAR10

部分代码:

class CIFAR10(VisionDataset):

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set


        self.data: Any = []
        self.targets = []

        # now load the picked numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

__getitem__通常是会被自动调用的?

以下是一些常见的情况:

1. 索引操作

当你使用索引操作(如 dataset[index])访问数据集对象时,__getitem__ 方法会被自动调用。例如:

dataset = CIFAR10(root='data', train=True, download=True)
image, label = dataset[0]  # 这里会自动调用 dataset.__getitem__(0)

在上面的代码中,当你尝试访问 dataset[0] 时,__getitem__ 方法会被调用,返回第一个图像和标签。

2. DataLoader 一起使用

在深度学习中,__getitem__ 方法经常与 PyTorch 的 DataLoader 类一起使用。DataLoader 会在训练或测试过程中自动调用数据集的 __getitem__ 方法来获取数据。例如:

from torch.utils.data import DataLoader

dataset = CIFAR10(root='data', train=True, download=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

for images, labels in dataloader:
    # 这里会自动调用 dataset.__getitem__(index) 来获取数据
    # 进行训练或测试的相关操作
    pass

在这个例子中,DataLoader 会在每次迭代时调用 __getitem__ 方法来获取数据集中的样本。

3. 自定义数据集

在创建自定义数据集时,你可以通过实现 __getitem__ 方法来定义如何访问和处理数据。例如:

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        # 这里定义如何获取和处理数据
        return self.data[index], self.labels[index]

data = ...  # 数据
labels = ...  # 标签
dataset = MyDataset(data, labels)

在这个自定义数据集中,__getitem__ 方法定义了如何根据索引访问数据和标签。

总结

__getitem__ 方法是一种魔法方法(magic method),在特定场景下会被自动调用,尤其是当你使用索引操作访问对象或与某些库(如 PyTorch 的 DataLoader)一起使用时。通过实现 __getitem__ 方法,你可以自定义对象的索引行为,从而更方便地处理和访问数据。

自定义CIFAR10

class CustomCIFAR10(datasets.CIFAR10):
    cls_num = 10

    def __init__(self, root, imb_type='exp', imb_factor=0.01, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CustomCIFAR10, self).__init__(root, train=train, transform=None, target_transform=target_transform, download=download)
        self.num_per_cls_dict = {}
        if imb_type is not None:
            img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
            self.gen_imbalanced_data(img_num_list)

        # 初始化条件变换,传入类别分布
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor()
            ])

        # 初始化条件变换,传入类别分布
        self.transform = ConditionalTransform(transform, self.num_per_cls_dict)

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img, target)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

和官方的cifar10的区别在与,自定义的getitem方法里的transformer里多了一个参数target

官方的:

img = self.transform(img)

自制的:

img = self.transform(img, target)

如果不重写getitem方法,就会出现这样的报错:

  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 120, in __getitem__
    img = self.transform(img)
TypeError: __call__() missing 1 required positional argument: 'label'

  • 8
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值