PyTorch data pipeline: loading images and visualizing augmentation effects (CutMix, RandomPatch)

Example: CutMix data augmentation. Load images with PyTorch and display the augmented results. CutMix cuts a rectangular region out of one image in the batch, pastes it into another, and mixes the two labels in proportion to the patch area.

# -*- coding: utf-8 -*-
# @Time    : 18-3-15 下午6:43
# @Author  : zhwzhong
# @File    : model.py
# @Contact : zhwzhong.hit@gmail.com
# @Function:
from torchvision import transforms
import torchvision as tv
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch

transform = transforms.Compose(
    [
       #transforms.Resize(cfg.INPUT.SIZE_TRAIN),
       #transforms.RandomHorizontalFlip(p=cfg.INPUT.PROB), #
       #transforms.Pad(cfg.INPUT.PADDING),
       #transforms.RandomCrop(cfg.INPUT.SIZE_TRAIN),    
        transforms.ToTensor()
    ]
)
train_set = tv.datasets.ImageFolder(root='/home/shiyy/nas/data/yidongface/yidong_recognition_img_800', transform=transform)
data_loader = DataLoader(dataset=train_set, batch_size=8, shuffle=True)

to_pil_image = transforms.ToPILImage()
def rand_bbox(size, lam):
    # Note: for an NCHW tensor, size[2] is the height and size[3] the width.
    # The W/H naming below follows the official CutMix implementation and is
    # self-consistent with the slicing used in the loop.
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)  # np.int was removed in NumPy 1.24; use the builtin int
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2
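
# Quick sanity check of rand_bbox (illustrative numbers, not part of the
# original script): with lam = 0.75, cut_rat = sqrt(1 - 0.75) = 0.5, so the
# box covers at most a quarter of a 224x224 image.
_bx1, _by1, _bx2, _by2 = rand_bbox((8, 3, 224, 224), 0.75)
print((_bx2 - _bx1) * (_by2 - _by1))  # at most 112 * 112 = 12544 (less if the box is clipped at a border)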

for input, target in data_loader:
    # Method 1: Image.show()
    # transforms.ToPILImage() internally does
    # npimage = np.transpose(pic.numpy(), (1, 2, 0))
    # so pic must be a 3-D tensor; index the batch (input[i]) to drop the batch dim
    print(target)
    r = np.random.rand(1)  # random float in [0, 1), e.g. array([0.33473484])
    beta = 1.0
    cutmix_prob = 1  # probability 1 means the CutMix branch below always runs
    if beta > 0 and r < cutmix_prob:
        # generate mixed sample
        lam = np.random.beta(beta, beta)
        rand_index = torch.randperm(input.size()[0])  # keep on CPU here; add .cuda() only when input is on the GPU
        target_a = target
        target_b = target[rand_index]
        bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
        input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]
        # adjust lambda to exactly match pixel ratio
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
        
    print(lam)
    print(target_a)
    print(target_b)
    print(input.shape)
    for i in range(len(target_b)):
        image = to_pil_image(input[i])
        image.show()
    # image.save("1.jpg")

    # # Method 2: plt.imshow(ndarray)
    # image = image[0]  # plt.imshow() also expects a 3-D array, so drop the batch dim
    # image = image.numpy()  # FloatTensor -> ndarray
    # image = np.transpose(image, (1, 2, 0))  # move the channel dim to the end
    #
    # # display the image
    # # plt.savefig("filename.png")
    # plt.imshow(image)
    # plt.show()
    break

Displaying images loaded through a DataLoader: https://blog.csdn.net/qq_34535410/article/details/79574327
CutMix: https://blog.csdn.net/weixin_38715903/article/details/103999227
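
The script above only visualizes the mixed images. During training, lam, target_a and target_b are combined in the loss. A minimal sketch, assuming a hypothetical model and the standard cross-entropy criterion (neither appears in the script above):

import torch.nn as nn

criterion = nn.CrossEntropyLoss()

def cutmix_criterion(output, target_a, target_b, lam):
    # each label contributes in proportion to the pixel area its image
    # occupies after pasting (lam was adjusted to the exact pixel ratio above)
    return lam * criterion(output, target_a) + (1. - lam) * criterion(output, target_b)

# hypothetical training step:
# output = model(input)
# loss = cutmix_criterion(output, target_a, target_b, lam)
# loss.backward()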

RandomPatch image display: a patch is randomly cropped from each image in the batch and stored in a pool, then pooled patches are randomly pasted onto other images.
It is similar to random erasing, except that random erasing fills the region with a fixed color (e.g. all black); see the sketch below.
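
For contrast, a minimal sketch using torchvision's built-in RandomErasing, which blanks a random region with a constant value instead of pasting pixels from another image. RandomErasing operates on tensors, so it must come after ToTensor(); p=1.0 here just forces it to trigger for visualization:

from torchvision import transforms

erase_transform = transforms.Compose([
    transforms.ToTensor(),
    # fill a random region with a constant value (0 = black)
    transforms.RandomErasing(p=1.0, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0),
])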

from torchvision import transforms
import torchvision as tv
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import math
from collections import deque
from PIL import Image
class RandomPatch(object):
    """Random patch data augmentation.
    There is a patch pool that stores randomly extracted patches from person images.
    For each input image, RandomPatch
        1) extracts a random patch and stores the patch in the patch pool;
        2) randomly selects a patch from the patch pool and pastes it on the
           input (at random position) to simulate occlusion.
    Reference:
        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
        - Zhou et al. Learning Generalisable Omni-Scale Representations
          for Person Re-Identification. arXiv preprint, 2019.

          Note: min_sample_size interacts with the batch size. For example, with
          batch 64 and min_sample_size=60, roughly 61 images keep their original
          look (a patch is only copied from them into the pool) while about 3
          images get a randomly drawn patch pasted onto them.
    """

    def __init__(self, prob_happen=1, pool_capacity=50000, min_sample_size=5,
                 patch_min_area=0.01, patch_max_area=0.5, patch_min_ratio=0.1,
                 prob_rotate=0.5, prob_flip_leftright=0.5,
                 ):

        self.prob_happen = prob_happen

        self.patch_min_area = patch_min_area
        self.patch_max_area = patch_max_area
        self.patch_min_ratio = patch_min_ratio

        self.prob_rotate = prob_rotate
        self.prob_flip_leftright = prob_flip_leftright

        self.patchpool = deque(maxlen=pool_capacity)

        self.min_sample_size = min_sample_size

    def generate_wh(self, W, H):
        area = W * H
        for attempt in range(100):
            target_area = random.uniform(self.patch_min_area, self.patch_max_area) * area
            aspect_ratio = random.uniform(self.patch_min_ratio, 1. / self.patch_min_ratio)
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < W and h < H:
                return w, h
        return None, None

    def transform_patch(self, patch):
        if random.uniform(0, 1) > self.prob_flip_leftright:
            patch = patch.transpose(Image.FLIP_LEFT_RIGHT)
        if random.uniform(0, 1) > self.prob_rotate:
            patch = patch.rotate(random.randint(-10, 10))
        return patch

    def __call__(self, img):
        W, H = img.size  # original image size

        # collect new patch
        w, h = self.generate_wh(W, H)
        if w is not None and h is not None:
            x1 = random.randint(0, W - w)
            y1 = random.randint(0, H - h)
            new_patch = img.crop((x1, y1, x1 + w, y1 + h))  # crop a patch from the image
            self.patchpool.append(new_patch)
        print("**************************")
        if len(self.patchpool) < self.min_sample_size:
            print(len(self.patchpool))
            # print(np.self.patchpool)
            print(self.min_sample_size)
            return img

        if random.uniform(0, 1) > self.prob_happen:
            return img

        # paste a randomly selected patch at a random position
        # (assumes the patch fits inside the current image; without a Resize in
        # the pipeline, images of varying sizes can make W - patchW negative)
        patch = random.sample(self.patchpool, 1)[0]
        patchW, patchH = patch.size
        x1 = random.randint(0, W - patchW)
        y1 = random.randint(0, H - patchH)
        patch = self.transform_patch(patch)
        img.paste(patch, (x1, y1))

        return img


# Visualize the RandomPatch augmentation results
transform = transforms.Compose(
    [
        RandomPatch(),
        transforms.ToTensor()
    ]
)
train_set = tv.datasets.ImageFolder(root='/home/shiyy/nas/data/yidongface/yidong_recognition_img_800', transform=transform)
data_loader = DataLoader(dataset=train_set, batch_size=8, shuffle=False)

to_pil_image = transforms.ToPILImage()

for input, target in data_loader:
    for i in range(len(target)):
        image = to_pil_image(input[i])
        image.show()
    # image.save("1.jpg")

    # display the last image again with matplotlib
    plt.imshow(image)
    plt.show()
    break
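
If image.show() is unavailable (e.g. on a headless server), an alternative is to tile the batch into a single grid image and save it to disk. A small sketch, assuming input still holds the batch from the loop above; the filename is arbitrary:

import torchvision

grid = torchvision.utils.make_grid(input, nrow=4)  # 8 images in 2 rows of 4
torchvision.utils.save_image(grid, "randompatch_batch.jpg")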