Example: CutMix data augmentation — load images with PyTorch and display the augmented results.
# -*- coding: utf-8 -*-
# @Time : 18-3-15 下午6:43
# @Author : zhwzhong
# @File : model.py
# @Contact : zhwzhong.hit@gmail.com
# @Function:
from torchvision import transforms, datasets as ds
import torchvision as tv
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch
transform = transforms.Compose(
    [
        # transforms.Resize(cfg.INPUT.SIZE_TRAIN),
        # transforms.RandomHorizontalFlip(p=cfg.INPUT.PROB),
        # transforms.Pad(cfg.INPUT.PADDING),
        # transforms.RandomCrop(cfg.INPUT.SIZE_TRAIN),
        transforms.ToTensor()
    ]
)
train_set = tv.datasets.ImageFolder(root='/home/shiyy/nas/data/yidongface/yidong_recognition_img_800', transform=transform)
data_loader = DataLoader(dataset=train_set,batch_size=8,shuffle=True)
to_pil_image = transforms.ToPILImage()
def rand_bbox(size, lam):
    # size is (batch, channels, H, W); the W/H names below follow the official
    # CutMix code and are applied consistently to dims 2 and 3.
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    # uniform: pick a random center, then clip the box to the image
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
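As a quick sanity check (illustrative only; the 224x224 size is an assumption, not taken from the dataset above), the box returned by rand_bbox covers roughly 1 - lam of the image area, a bit less when it is clipped at a border:
# Illustrative sanity check for rand_bbox, not needed for the augmentation itself.
size = (8, 3, 224, 224)  # assumed (batch, channels, H, W)
bbx1, bby1, bbx2, bby2 = rand_bbox(size, lam=0.75)
print((bbx2 - bbx1) * (bby2 - bby1) / (224 * 224))  # ~0.25, i.e. ~1 - lam, unless clipped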
for input, target in data_loader:
    # Option 1: Image.show()
    # transforms.ToPILImage() internally does
    #   npimage = np.transpose(pic.numpy(), (1, 2, 0))
    # so pic must be a 3-D tensor; index a single image (input[i] below) to drop the batch dimension.
    print(target)
    r = np.random.rand(1)  # random float in [0, 1), e.g. array([0.33473484])
    beta = 1.0
    cutmix_prob = 1
    if beta > 0 and r < cutmix_prob:
        # generate mixed sample
        lam = np.random.beta(beta, beta)
        rand_index = torch.randperm(input.size()[0])  # kept on CPU for this visualization script
        target_a = target
        target_b = target[rand_index]
        bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
        input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]
        # adjust lambda to exactly match the pixel ratio
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
        print(lam)
        print(target_a)
        print(target_b)
        print(input.shape)
    for i in range(len(target_b)):  # target_b always exists here because cutmix_prob = 1
        image = to_pil_image(input[i])
        image.show()
        # image.save("1.jpg")

        # # Option 2: plt.imshow(ndarray)
        # image = input[0]                         # plt.imshow() also needs a 3-D tensor, so drop the batch dimension
        # image = image.numpy()                    # FloatTensor -> ndarray
        # image = np.transpose(image, (1, 2, 0))   # move the channel dimension to the end
        #
        # # display the image
        # # plt.savefig("filename.png")
        # plt.imshow(image)
        # plt.show()
    break
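The loop above only visualizes the mixed batch. During actual training, target_a, target_b and the adjusted lam are combined in the loss, as in the standard CutMix recipe. A minimal sketch, assuming a model and a criterion such as torch.nn.CrossEntropyLoss() that are not defined in the script above:
# Sketch only: standard CutMix loss; `model` and `criterion` are assumed placeholders.
output = model(input)
loss = criterion(output, target_a) * lam + criterion(output, target_b) * (1. - lam)
loss.backward()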
Displaying images after the DataLoader: https://blog.csdn.net/qq_34535410/article/details/79574327
CutMix: https://blog.csdn.net/weixin_38715903/article/details/103999227
RandomPatch image display: within a batch, a random region is cropped from each image into a patch pool, and patches drawn from that pool are pasted at random positions onto other images.
This is similar to Random Erasing, except that Random Erasing fills the region with a fixed value (e.g. solid black) rather than real image content; a minimal Random Erasing sketch is shown right below for contrast, and the full RandomPatch example follows after it.
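For contrast (illustrative only; this helper is not part of the original scripts), a minimal Random Erasing-style sketch that blanks a random rectangle of a PIL image with a fixed color, whereas RandomPatch pastes real content cropped from other images:
import random
from PIL import Image

def random_erase_fixed_color(img, area_frac=0.2, fill=(0, 0, 0)):
    # Illustrative Random Erasing: overwrite a random box with a constant color (black by default).
    # Assumes an RGB PIL image; area_frac is the fraction of the image area to erase.
    W, H = img.size
    w = int(W * area_frac ** 0.5)
    h = int(H * area_frac ** 0.5)
    x1 = random.randint(0, W - w)
    y1 = random.randint(0, H - h)
    img.paste(Image.new("RGB", (w, h), fill), (x1, y1))
    return img
torchvision also provides transforms.RandomErasing, which does the same thing on tensors (place it after ToTensor() in a Compose).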
from torchvision import transforms, datasets as ds
import torchvision as tv
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
import math
from collections import deque
from PIL import Image
class RandomPatch(object):
    """Random patch data augmentation.

    There is a patch pool that stores randomly extracted patches from person images.
    For each input image, RandomPatch
    1) extracts a random patch and stores the patch in the patch pool;
    2) randomly selects a patch from the patch pool and pastes it on the
       input (at a random position) to simulate occlusion.

    Reference:
        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
        - Zhou et al. Learning Generalisable Omni-Scale Representations
          for Person Re-Identification. arXiv preprint, 2019.

    Note: min_sample_size interacts with the batch size. With batch 64 and
    min_sample_size=60, about 61 images pass through unchanged (a patch is only
    copied from them into the pool) and about 3 images get a pooled patch
    pasted onto them.
    """
    def __init__(self, prob_happen=1, pool_capacity=50000, min_sample_size=5,
                 patch_min_area=0.01, patch_max_area=0.5, patch_min_ratio=0.1,
                 prob_rotate=0.5, prob_flip_leftright=0.5,
                 ):
        self.prob_happen = prob_happen
        self.patch_min_area = patch_min_area
        self.patch_max_area = patch_max_area
        self.patch_min_ratio = patch_min_ratio
        self.prob_rotate = prob_rotate
        self.prob_flip_leftright = prob_flip_leftright
        self.patchpool = deque(maxlen=pool_capacity)
        self.min_sample_size = min_sample_size
    def generate_wh(self, W, H):
        area = W * H
        for attempt in range(100):
            target_area = random.uniform(self.patch_min_area, self.patch_max_area) * area
            aspect_ratio = random.uniform(self.patch_min_ratio, 1. / self.patch_min_ratio)
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < W and h < H:
                return w, h
        return None, None
    def transform_patch(self, patch):
        if random.uniform(0, 1) > self.prob_flip_leftright:
            patch = patch.transpose(Image.FLIP_LEFT_RIGHT)
        if random.uniform(0, 1) > self.prob_rotate:
            patch = patch.rotate(random.randint(-10, 10))
        return patch
    def __call__(self, img):
        W, H = img.size  # original image size
        # collect a new patch from this image
        w, h = self.generate_wh(W, H)
        if w is not None and h is not None:
            x1 = random.randint(0, W - w)
            y1 = random.randint(0, H - h)
            new_patch = img.crop((x1, y1, x1 + w, y1 + h))  # crop a region from the image
            self.patchpool.append(new_patch)

        print("**************************")  # debug output
        if len(self.patchpool) < self.min_sample_size:
            print(len(self.patchpool))
            print(self.min_sample_size)
            return img

        if random.uniform(0, 1) > self.prob_happen:
            return img

        # paste a randomly selected patch at a random position
        patch = random.sample(self.patchpool, 1)[0]
        patchW, patchH = patch.size
        x1 = random.randint(0, W - patchW)
        y1 = random.randint(0, H - patchH)
        patch = self.transform_patch(patch)
        img.paste(patch, (x1, y1))
        return img
# ### Display the augmented images
transform = transforms.Compose(
    [
        RandomPatch(),
        transforms.ToTensor()
    ]
)
train_set = tv.datasets.ImageFolder(root='/home/shiyy/nas/data/yidongface/yidong_recognition_img_800', transform=transform)
data_loader = DataLoader(dataset=train_set,batch_size=8,shuffle=False,)
to_pil_image = transforms.ToPILImage()
for input, target in data_loader:
    for i in range(len(target)):
        image = to_pil_image(input[i])
        image.show()
        # image.save("1.jpg")

        # also display with matplotlib
        plt.imshow(image)
        plt.show()
    break