文章目录
一、数据加载和处理
torch.util.data
模块为我们提供了一种十分方便的数据读取机制,即使用Dataset
类与DataLoader
类的组合,来得到数据迭代器。在训练或预测时,数据迭代器能够输出每一批次所需的数据,并且对数据进行相应的预处理与数据增强操作:Dataset
承担了你的自定义的数据(可以是任何一种格式)与标准 PyTorch 张量之间的转换任务DataLoader
可以在后台生成子进程来从Dataset
中加载数据,使数据准备就绪并在循环可以使用后立即等待训练
- 可以通过
torch.nn.DataParallel
和torch.distributed
来使用多个 GPU
二、torch.utils
数据读取流程如下
1、torch.utils.data
a、torch.utils.data.Dataset
- 定义了数据集的内容,它相当于一个类似
列表
的数据结构,具有确定的长度,能够用索引
获取数据集中的元素 - 解决了 从哪读数据(data_dir) 和 怎么读数据及数据增强?(getitem&transform) 的问题
- Dataset 的抽象类, 所有自定义的 Dataset 都需要继承它,并实现
__init__
初始化方法、__len__
和__getitem__
成员方法__init__
初始化方法:指定 数据所在路径和对应标签 成员变量(以dir
或txt
形式给出)、数据增强成员变量(transform)
及生成数据路径和标签组成的列表(data_info
),尽量统一使用txt_path
的方式,因为数据量一多直接读取路径还是挺费时的__getitem__
成员方法:接收一个索引,返回一个样本对(img_data, img_label)
,可在此方法中实现数据增强(transform)
。可以使用索引的方式获取某条数据(魔法方法__getitem__
可以让对象实现迭代功能(不需要实现__iter__ 和 __next__
迭代器协议了),这样就可以使用for…in…
来迭代该对象了)__len__
成员方法:获取数据集的长度,使用 len 方法读取初始化生成的数据列表data_info
的长度,主要用于后面dataloader
的sampler
去获取随机的index
。可以使用len(train_data)
获取数据集的大小
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
# 数据预处理设置:事先从训练集中算出每个通道的均值和标准差
norm_mean = [0.4948052, 0.48568845, 0.44682974]
norm_std = [0.24580306, 0.24236229, 0.2603115]
train_transform = transforms.Compose([
transforms.Resize(32), # 将图像最短边缩小至 32,宽高比例不变
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std) # 减去每个通道的均值,除以每个通道的标准差
])
valid_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std) # 减去每个通道的均值,除以每个通道的标准差
])
# 自定义数据集(可以单独放在一个包中)
# 构建完实例对象可以查看数据集大小 len(train_data)、进行索引 train_data[i]
class KedaClsDataset(Dataset):
# 1、数据从哪读取?
def __init__(self, txt_path, transform=None):
super().__init__() # py3 super 简化
self.transform = transform
self.data_info = [] # data_info 存储所有图片路径和标签,在 DataLoader 中通过 index 读取样本
with open(txt_path, 'r', encoding='utf-8') as f:
for each_line in f:
img_path, img_label = each_line.strip().split()
self.data_info.append((img_path, int(img_label)))
# 2、怎么读数据?做哪些数据增强?
def __getitem__(self, index):
img_path, img_label = self.data_info[index] # 根据下标取数据集中的元素
img_rgb = kd_read_image(img_path) # pil image hwc rgb
if self.transform is not None:
img_rgb = self.transform(img_rgb) # 在这里做 transform,转为 tensor 等
return img_rgb, img_label
# 3、确定数据集的大小
def __len__(self):
return len(self.data_info)
# 构建 MyDataset 实例对象,构建完可以查看数据集大小 len(train_data)、进行索引 train_data[i]
train_data = MyDataset(txt_path=train_txt_path, transform=train_transform)
valid_data = MyDataset(txt_path=valid_txt_path, transform=valid_transform)
# 创建数据集的 filelist
import os
import codecs
import numpy as np
def get_all_labels(label_dir):
all_label = []
for folder_dir, sub_folder_dir, file_names in os.walk(label_dir):
for name in file_names:
all_label.append(os.path.join(folder_dir, name))
return all_label
def create_file_list(img_dir, list_save_path):
"""
Args:
img_dir: where the train/val data stored
list_save_path: where the train/val list stored
Returns:
None
"""
img_list = []
cnt = 0
wrong_cnt = 0
label_img_list = get_all_labels(img_dir)
for img_path in label_img_list:
try:
img_label = img_path.split('/')[-2]
if img_label == 'ants':
img_list.append(img_path + ' 0\n') # for lmdb use img_path[1:] remove /
else:
img_list.append(img_path + ' 1\n') # for lmdb use img_path[1:] remove /
cnt += 1
except Exception as e:
print("Error img is {}".format(img_path))
print("Error reason is {}".format(e))
wrong_cnt += 1
continue
np.random.shuffle(img_list)
with codecs.open(list_save_path, 'w', 'utf-8') as f1:
f1.writelines(img_list)
print('Successfully create {} num img to list'.format(cnt))
print('Failed create {} num img to list'.format(wrong_cnt))
if __name__ == "__main__":
train_dir = "/data/hymenoptera_data/train"
val_dir = "/data/hymenoptera_data/val"
train_list_path = os.path.join(train_dir.split('train')[0], 'train_list.txt')
val_list_path = os.path.join(val_dir.split('val')[0], 'val_list.txt')
create_file_list(train_dir, train_list_path)
create_file_list(val_dir, val_list_path)
b、torch.utils.data.DataLoader
- 在实际项目中,如果数据量很大,考虑到内存有限、I/O 速度等问题,在训练过程中不可能一次性的将所有数据全部加载到内存中,也不能只用一个进程去加载,所以就需要多进程、迭代加载,而
DataLoader
就是基于这些需要被设计出来的:- 它定义了按
batch
加载数据集的方法,它是一个实现了__iter__
方法的可迭代对象,每次迭代输出一个batch
的数据 - 它能够控制
batch
的大小、batch
中元素的采样方法,以及将batch
结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据
- 它定义了按
- 解决了 读哪些数据,怎么采样?(sampler 输出的 index) 和 怎么组成 batch?(collate_fn) 的问题
- 如何自定义
collate_fn
? - 如何自定义
customized_sampler
?
- 如何自定义
sampler 与 batch_sampler 的区别
:当采用auto_collation
时,采用batch_sampler
,它在 sampler 的基础上封装了一个 batch 抽取的功能,一次 yield 一个 batch 的 index,而样本采样的顺序取决于RandomSampler
和SequentialSample
if shuffle:
sampler = RandomSampler(dataset, generator=generator)
else:
sampler = SequentialSampler(dataset)
# BatchSampler
def __iter__(self) -> Iterator[List[int]]:
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
# 当for循环结束,且batch的数量又不满足batchsize时,则进入以下代码, 其实就是 drop_last 的逻辑代码
if len(batch) > 0 and not self.drop_last:
yield batch
# 通常情况下使用默认的 collate_fn 即可,需要自定义的场景使用下面的函数进行更改
def batch_fn(self, batch_samples):
# zip(*)函数:使用 * 进行解包,把数组中的每个对应元素提取出来,组成一个新的元组的可迭代对象
# imgs = (img1, img2, ..., imgN); labels = (label1, label2, ..., labelN)
imgs, labels = zip(*batch_samples)
# imgs 增加一个 batch 维度,256,3,224,224; labels 转换为 tensor 类型,256
return torch.stack(imgs, dim=0), torch.tensor(labels, dtype=torch.long)
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
# 重要参数解析
- dataset(Dataset 类):需要加载的数据集对象,决定数据从哪读取数据及如何读取
- batch_size(int, optional):每个 batch 加载多少个样本
- shuffe(bool, optional):设置为 True 时会在每个 epoch 重新打乱数据
- num_workers(int, optional):用多少个子进程加载数据,0 表示数据将在主进程中加载,大于 1 时使用多进程读取数据
- prefetch_factor(int, optional):Number of batches loaded in advance by each worker
- drop_last(bool, optional):当样本数不能被 batch_size 整除时,是否舍弃最后一批数据
- collate_fn(callable, optional):用这个函数来打包 batch(merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.)
- sampler(Sampler or Iterable, optional):defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, :attr:`shuffle` must not be specified.
- batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but returns a batch of indices at a time. Mutually exclusive with :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
- pin_memory:
- 锁页内存,当计算机的内存充足的时候可设置为 True,此时生成的 Tensor 是属于内存中的锁页内存(显卡中均为锁页内存),
这样将内存的 Tensor 转到 GPU 显存就会更快一些。
- 主机中的内存,有锁页和不锁页两种存在方式,锁页内存存放的内容在任何情况下都不会与主机的虚拟内存进行交换
(注:虚拟内存就是硬盘),而不锁页内存在主机内存不足时,数据会存放在虚拟内存中
class DataLoader(object):
def __init__(self, dataset, batch_size, collate_fn, shuffle = True, drop_last = False):
self.dataset = dataset
self.sampler =torch.utils.data.RandomSampler if shuffle else \
torch.utils.data.SequentialSampler
self.batch_sampler = torch.utils.data.BatchSampler
self.sample_iter = self.batch_sampler(
self.sampler(range(len(dataset))),
batch_size = batch_size, drop_last = drop_last)
def __next__(self):
# 1、读哪些数据? sampler 输出的 indices
indices = next(self.sample_iter)
# 2、 如何组成 batch ? imgs = torch.stack([X[1], X[5]]), labels = torch.stack([Y[1], Y[5]])
batch = self.collate_fn([self.dataset[i] for i in indices])
return batch
# 构建 DataLoder 实例对象
train_loader = DataLoader(dataset=train_data, batch_size=train_bs, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=valid_bs)
# 输出一个 batch 数据和标签
# 首先,通过内置函数 iter 将可迭代对象列表转换为迭代器对象
# 然后,通过内置函数 next 依次获取迭代器的下一个元素
# print('One batch tensor data: ', next(iter(train_loader)))
for epoch in range(max_epoch):
for i, data in enumerate(train_loader):
images, labels = data # 获取训练集的批图像和标签
# validate the model
if (epoch + 1) % val_interval == 0:
with torch.no_grad():
for j, data in enumerate(valid_loader):
images, labels = data # 获取验证集的批图像和标签
# 多 GPU :训练迭代器改用分布式采样器,同时 shuffle 设置为 false
self.train_loader = torch.utils.data.DataLoader(
dataset=self.train_data,
batch_size=self.cfg.DATA.TRAIN_BATCH_SIZE,
shuffle=False, # ddp 中 shuflle 改为 false
num_workers=self.cfg.DATA.NUM_WORKERS,
prefetch_factor=self.cfg.DATA.PREFETCH_FACTOR,
pin_memory=self.cfg.DATA.PIN_MEMORY,
sampler=torch.utils.data.distributed.DistributedSampler(self.train_data) # NOTE:改成分布式采样,调试的时候可以注释掉
# collate_fn=self.batch_fn,
)
self.test_loader = torch.utils.data.DataLoader(
dataset=self.test_data,
batch_size=self.cfg.DATA.TEST_BATCH_SIZE,
num_workers=self.cfg.DATA.NUM_WORKERS,
prefetch_factor=self.cfg.DATA.PREFETCH_FACTOR,
pin_memory=self.cfg.DATA.PIN_MEMORY,
# collate_fn=self.batch_fn,
)
c、collate_fn 和 customized_sampler
- 不均衡样本采样策略:https://github.com/ufoym/imbalanced-dataset-sampler
- WeightedRandomSampler: https://tingsongyu.github.io/PyTorch-Tutorial-2nd/chapter-3/3.3-dataset-useful-api.html
2、torch.utils.model_zoo
# 在给定 URL 上加载 Torch 序列化对象
torch.utils.model_zoo.load_url(url, model_dir=None, map_location=None, progress=True, check_hash=False)
# 重要参数解析
- url(string):要下载对象的 URL
- model_dir(string,optional):保存对象的目录,默认值为 ~/.cache/torch/hub/checkpoints,可以使用 $TORCH_MODEL_ZOO 环境变量来覆盖默认目录
# 加载预训练模型
state_dict = torch.utils.model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
# 1.4 版本,放在了 hub
state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
三、torchvision
torchvision
库就是常用数据集 + 常见网络模型 +常用图像处理方法
的集合
1、torchvision.datasets
- 提供常用的数据集加载,设计上都是继承
torch.utils.data.Dataset
- 主要包括:
MNIST、CIFAR10/100、VOC、ImageNet、COCO
等 - 基本都有两个参数:
transform
和target_transform
分别对输入(img
)和目标(label
)做变换# 数据集路径(root)中如果不存在数据,则会进行下载,train 参数改为 False 则会下载测试集 mnist_train = torchvision.datasets.MNIST(root='path/to/imagenet_root/', train=True, transform=None, target_transform=None, download=False) # 送入 DataLoader 产生 batch 数据(shuffled) data_loader_train = torch.utils.data.DataLoader(mnist_train, batch_size=4, shuffle=True, num_workers=args.nThreads)
2、torchvision.models
- 提供深度学习中各种经典网络的
网络结构及预训练好的模型
- 主要包括:
ResNet
系列、DenseNet
系列、MobileNet
系列、ShuffleNet
系列等
import torchvision.models as models
from torch import nn
# 加载网络结构和预训练的模型,若不存在则会下载(默认保存在 ~/.cache/torch/hub/checkpoints 下)
res18 = models.resnet18(pretrained=True) # 从 torch 官网下载预训练模型
# 从本地加载预训练模型
res18 = models.resnet18(pretrained=False) # 默认为 False,不进行模型下载
pretrained_model = os.path.join(BASEDIR, "resnet18-5c106cde.pth")
res18.load_state_dict(torch.load(pretrained_model))
# 修改最后的全连接层为 10 分类问题(默认 ImageNet 为 1000 分类)
num_ftrs = res18.fc.in_features
res18.fc = nn.Linear(num_ftrs, 10)
# 若想要从 10 类预训练好的模型进行 finetune,需要进行如下操作:
res18 = models.resnet18(pretrained=False) # 默认为 False,不进行模型下载
num_ftrs = res18.fc.in_features
res18.fc = nn.Linear(num_ftrs, 10)
pretrained_model = os.path.join(BASEDIR, "resnet18-ft.pth")
res18.load_state_dict(torch.load(pretrained_model))
# 将模型放到指定设备上运行
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
res18.to(device)
3、torchvision.transforms
- 默认操作对象是 PIL 的
Image(RGB/HWC)
或者 Tensor; 如果需要进行多个变换功能,可以利用transforms.Compose
将多个变换整合起来- 目前版本的 Torchvision 对各种图像变换操作已经基本同时支持
PIL Image
和Tensor
类型了,因此只针对 Tensor 的变换操作很少,只有 4 个,分别是LinearTransformation
(线性变换)、Normalize
(标准化)、RandomErasing
(随机擦除)、ConvertImageDtype
(格式转换)- 神经网络模型接收的数据类型是
Tensor
,而不是 PIL 对象,可以用transforms.ToTensor() # 内部做了维度变换 HWC-->CHW 和 归一化:img.permute((2, 0, 1)).float().div(255)
对其进行转换;而反之,将Tensor
或Numpy.ndarray
格式的数据转化为 PIL.Image 格式,则用transforms.ToPILImage(mode=None)
类
-
像素内容变换
from PIL import Image img = Image.open(img_path) # 首先创建匿名实例化对象,然后使用 __call__ 功能使其变成可调用对象 img = transforms.ColorJitter(brightness=1)(img) # 类实例化后,可以直接像函数一样调用,因为其内部实现了 __call__ 方法 # 1. 亮度、对比度、饱和度、色度调整,Transforms on PIL Image(RGB 顺序) torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0) - brightness/contrast/saturation:float or tuple of float, should be non negative numbers - 当为 a 时,从[max(0, 1-a), 1+a]中随机选择 - 当为(a, b)时,从 [a, b]中随机选择 - hue:色相参数 - 当为 a 时,从 [-a, a] 中随机选择,注: 0<= a <= 0.5 - 当为(a, b)时,从 [a, b] 中随机选择,注:-0.5 <= a <= b <= 0.5 torchvision.transforms.functional.adjust_gamma(img, gamma, gain=1) # gamma 小于 1 提升亮度 torchvision.transforms.functional.adjust_brightness(img, brightness_factor) # 大于 1 提升亮度 torchvision.transforms.functional.adjust_contrast(img, contrast_factor) # 大于 1 提升对比度 torchvision.transforms.functional.adjust_saturation(img, saturation_factor) # 大于 1 提升饱和度 torchvision.transforms.functional.adjust_hue(img, hue_factor) # [-0.5, 0.5] # 2. 高斯模糊、转为灰度图、补零、擦除一部分(用 0 填充),Transforms on PIL Image(RGB 顺序) torchvision.transforms.GaussianBlur(kernel_size, sigma=(0.1, 2.0)) torchvision.transforms.Grayscale(num_output_channels=1) torchvision.transforms.Pad(padding, fill=0, padding_mode='constant') # 用法跟 RandomCrop 中的 pad 一致 # 对图像进行随机遮挡,scale:遮挡区域的面积,ratio:遮挡区域的长宽比,value:遮挡区域的值 torchvision.transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False) torchvision.transforms.functional.to_grayscale(img, num_output_channels=1) torchvision.transforms.functional.pad(img, padding, fill=0, padding_mode='constant') torchvision.transforms.functional.erase(img, i, j, h, w, v, inplace=False) # Randomly selects a rectangle region in an image and erases its pixels. # 3. 标准化处理,需在 ToTensor 后执行,Transforms on torch.*Tensor torchvision.transforms.Normalize(mean, std, inplace=False) # RGB 三通道的均值和标准差(需在训练集上统计) torchvision.transforms.functional.normalize(tensor, mean, std, inplace=False)
-
空间几何变换
# 1. 裁剪: Transforms on PIL Image(RGB 顺序) or Tensor # 1.1 随机裁剪 torchvision.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant') - size:int(size) or sequence(h,w), output is (size, size) or(h, w) - padding:int or sequence, - 当为 a 时,上下左右均填充 a 个像素; - 当为(a, b)时,左右填充 a 个像素,上下填充 b 个像素 - 当为(a, b, c, d)时,左,上,右,下分别填充 a, b, c, d 个像素 - pad_if_needed:boolean,若图像小于设定 size,则进行填充 - padding_mode:填充模式,有 4 种模式 - 1、constant:像素值由 fill 值设定 - 2、edge:像素值由图像边缘像素决定 - 3、reflect:镜像填充,最后一个像素不镜像,eg:[1,2,3,4] → [3,2,1,2,3,4,3,2] - 4、symmetric:镜像填充,最后一个像素镜像,eg:[1,2,3,4] → [2,1,1,2,3,4,4,3] - fill:填充模式为 constant 时,设置填充的像素值 # 1.2 随机面积大小和长宽比裁剪原始图片,然后将图片 resize 到设定好的 size torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3./4., 4./3.), interpolation=Image.BILINEAR) - size:int(size) or sequence(h,w), output is (size, size) or(h, w) - scale:随机裁剪面积比例, 默认(0.08, 1) - ratio:随机长宽比,默认(3/4, 4/3) - interpolation:插值模式,有 3 种, PIL.Image.NEAREST/BILINEAR/BICUBIC # 1.3 中心裁剪出尺寸为 size 的图片 torchvision.transforms.CenterCrop(size) # 1.4 在图像的上下左右以及中心裁剪出尺寸为 size 的 5 张图片 torchvision.transforms.FiveCrop(size) # 1.5 在原图像及其水平镜像图像的上下左右以及中心裁剪出尺寸为 size 的 10 张图片 torchvision.transforms.TenCrop(size, vertical_flip=False) # horizontal flipping is used by default # Functional Transforms on PIL Image(RGB 顺序) torchvision.transforms.functional.crop(img, top, left, height, width) torchvision.transforms.functional.center_crop(img, output_size) torchvision.transforms.functional.five_crop(img, size) torchvision.transforms.functional.ten_crop(img, size, vertical_flip=False) torchvision.transforms.functional.resized_crop(img, top, left, height, width, size, interpolation=2) # eg:Five Crop >>> transform = Compose([ >>> FiveCrop(size), # this is a list of PIL Images >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor >>> ]) >>> #In your test loop you can do the following: >>> input, target = batch # input is a 5d tensor, target is 2d >>> bs, ncrops, c, h, w = input.size() >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops # 2. 翻转、旋转、缩放:Transforms on PIL Image(RGB 顺序) # 2.1 旋转 torchvision.transforms.RandomRotation(degrees, expand=False, resample=False, center=None, fill=0) - degrees:sequence or float or int - 当为 float or int 时(eg: a),旋转角度范围为 (-a, +a) - 当为 sequence 时(eg: (min, max)),,旋转角度范围为 (min, max) - expand:是否扩大图片以保持原图信息 - center:旋转中心设置,默认为图像中心 - resample:重采样方法 - fill:默认为 0 进行填充 # 2.2 缩放 torchvision.transforms.Resize(size, interpolation=Image.BILINEAR) - size:int(size) or sequence(h,w), output is (size, size) or(h, w) - 若为 int(eg:size):将图像最短边缩小至 size,宽高比例不变,若 h>w,则输出大小为 (size*h/w, size). - 若为 sequence(eg:(h,w)):将图像将缩放至 (h, w) - interpolation:插值模式,有 3 种, PIL.Image.NEAREST/BILINEAR/BICUBIC # 2.3 翻转 torchvision.transforms.RandomHorizontalFlip(p=0.5) # 依概率左右翻转 torchvision.transforms.RandomVerticalFlip(p=0.5) # 依概率上下翻转 torchvision.transforms.functional.hflip(img) torchvision.transforms.functional.vflip(img) torchvision.transforms.functional.rotate(img, angle, resample=False, expand=False, center=None, fill=0) torchvision.transforms.functional.resize(img, size, interpolation=2) # 3. 仿射变换和透视变换 torchvision.transforms.RandomAffine(degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0) - degrees:sequence or float or int - 当为 float or int 时(eg: a),旋转角度范围为 (-a, +a) - 当为 sequence 时(eg: (min, max)),,旋转角度范围为 (min, max) - translate:tuple,宽和高平移比例设置,eg: (w_ratio, h_ratio), - 图像在宽维度平移范围为 (-img_width * w_ratio, img_width * w_ratio) - 图像在高维度平移范围为 (-img_height * h_ratio, img_height * h_ratio) - scale:tuple,面积缩放比例设置,eg:(a, b) - shear:sequence or float or int,错切角度参数设置 - 一个值时:沿 x 轴错切,范围取 (-shear, +shear) - 二个值时:沿 x 轴错切,范围取 (shear[0], shear[1]) - 四个值时:沿 x 轴错切,范围取 (shear[0], shear[1]);沿 y 轴错切,范围取 (shear[2], shear[3]) - resample:重采样方法 - fillcolor:默认为 0 进行填充 torchvision.transforms.RandomPerspective(distortion_scale=0.5, p=0.5, interpolation=3) torchvision.transforms.functional.affine(img, angle, translate, scale, shear, resample=0, fillcolor=None) torchvision.transforms.functional.perspective(img, startpoints, endpoints, interpolation=3)
-
数据类型转换、随机操作及组合操作
1. PIL Image 和 Tensor 的相互转换
# Convert a tensor(CHW) or an ndarray(HWC) to PIL Image
torchvision.transforms.ToPILImage(mode=None)
# Converts a PIL Image or numpy.ndarray (HWC, dtype=uint8) in the range [0, 255] to a torch.FloatTensor
# of shape (CHW) in the range [0.0, 1.0];具体内部操作示例:img.permute((2, 0, 1)).float().div(255)
torchvision.transforms.ToTensor()
torchvision.transforms.functional.to_pil_image(pic, mode=None)
torchvision.transforms.functional.to_tensor(pic)
2. 随机对 transforms 操作,使数据增强更加灵活
torchvision.transforms.RandomApply([transforms1,transforms2, ...], p=0.5) # 依据概率执行一组 transforms 操作
torchvision.transforms.RandomChoice([transforms1,transforms2, ...]) # 从一系列transforms 方法中随机挑选一个
torchvision.transforms.RandomOrder([transforms1,transforms2, ...]) # 对一组 transformers 随机打乱顺序后执行
torchvision.transforms.Lambda(lambd) # Apply a user-defined lambda as a transform(eg:TenCrop后处理函数)
3. 将多个 transform 操作组合起来使用
torchvision.transforms.Compose(transforms) # transforms:list of Transform objects
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
transforms.RandomErasing(),
])
# 4. 使用示例
import numpy as np
from PIL import Image
from utils import kd_read_image_pil
import matplotlib.pyplot as plt
from torchvision import transforms
if __name__ == "__main__":
# 0、读取图片,并转换为 BGR HWC
img_path = "./imgs/test.jpg"
pil_img = kd_read_image_pil(img_path) # pil rgb hwc
pil_img = np.array(pil_img)[..., ::-1] # np bgr hwc
pil_img = Image.fromarray(pil_img) # pil bgr hwc
# 1、ColorJitter
T_ColorJitter = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5)
img_ColorJitter = T_ColorJitter(pil_img)
# 2、Pad
T_Pad = transforms.Pad(padding=15, fill=0, padding_mode='constant')
# T_Pad = transforms.Pad(padding=(10, 0, 5, 0), fill=0, padding_mode='constant') # 左,上,右,下
img_pad = T_Pad(pil_img)
# 3、RandomRotation
T_RandomRotation = transforms.RandomRotation(degrees=(-8, -8), expand=True)
img_rotated = T_RandomRotation(pil_img)
# 5、RandomPerspective
T_Perspective = transforms.RandomPerspective(distortion_scale=0.3, p=1.0)
img_perspective = T_Perspective(pil_img)
# 6、GaussianBlur
T_GaussianBlur = transforms.GaussianBlur(kernel_size=7)
img_blur = T_GaussianBlur(pil_img)
# 7、Resize
T_Resize = transforms.Resize(size=(48, 128))
img_resized = T_Resize(pil_img)
# 8、组合使用
transform = transforms.Compose([
transforms.RandomApply([
transforms.RandomChoice([
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
transforms.RandomPerspective(distortion_scale=0.2, p=1.0),
transforms.RandomRotation(degrees=(-8, -8), expand=True),
transforms.GaussianBlur(kernel_size=7),
transforms.Pad(padding=(10, 0, 10, 0), fill=0, padding_mode='constant'),
transforms.Pad(padding=(0, 10, 0, 10), fill=0, padding_mode='constant'),
transforms.Pad(padding=(10, 10, 10, 10), fill=0, padding_mode='constant'),
])], p=0.5), # 概率为0.5随机应用上述转换操作中的一个
transforms.Resize(size=(48, 128)),
# transforms.ToTensor(),
# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
img_transformed = transform(pil_img)
# 9、画图显示
plt.subplot(211).imshow(pil_img)
plt.title("raw img")
plt.subplot(212).imshow(img_transformed)
plt.title("img_transformed")
plt.show()
- 自定义一个 transforms 方法
# 只需实现 __init__ 和 __call__ 即可
class AddPepperNoise(object):
"""增加椒盐噪声
Args:
snr (float): Signal Noise Rate
p (float): 概率值,依概率执行该操作
"""
def __init__(self, snr, p=0.9):
assert isinstance(snr, float) or (isinstance(p, float))
self.snr = snr
self.p = p
def __call__(self, img):
"""
Args:
img (PIL Image): PIL Image
Returns:
PIL Image: PIL image.
"""
if random.uniform(0, 1) < self.p:
img_ = np.array(img).copy()
h, w, c = img_.shape
signal_pct = self.snr
noise_pct = (1 - self.snr)
mask = np.random.choice((0, 1, 2), size=(h, w, 1), p=[signal_pct, noise_pct/2., noise_pct/2.])
mask = np.repeat(mask, c, axis=2)
img_[mask == 1] = 255 # 盐噪声
img_[mask == 2] = 0 # 椒噪声
return Image.fromarray(img_.astype('uint8')).convert('RGB')
else:
return img
# 无参数据增强实现
class ToTensor(object):
def __call__(self, cvimage, boxes=None, labels=None):
return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels
# 自定义 transformer 方法的使用
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
AddPepperNoise(0.9, p=0.5),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
- 注意:调用默认的
Compose
只适合transforms
方法只有一个参数的情形;若像SSD
增强那种需要额外传入boxes
和labels
参数的情形,则需要重新定义 Compose 类(或者将多个参数封装成一个字典
,则不重新定义 Compose 类,此时只能用自定义数据增强方式)
class Compose(object):
"""Composes several augmentations together.
Args:
transforms (List[Transform]): list of transforms to compose.
Example:
>>> augmentations.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img, boxes=None, labels=None):
for t in self.transforms:
img, boxes, labels = t(img, boxes, labels)
return img, boxes, labels
- transform Normalize 和 ToTensor 逆变换
def transform_invert(img_, transform_train):
"""
将data 进行反transfrom操作
:param img_: tensor
:param transform_train: torchvision.transforms
:return: PIL image
"""
if 'Normalize' in str(transform_train):
norm_transform = list(filter(lambda x: isinstance(x, transforms.Normalize), transform_train.transforms))
mean = torch.tensor(norm_transform[0].mean, dtype=img_.dtype, device=img_.device)
std = torch.tensor(norm_transform[0].std, dtype=img_.dtype, device=img_.device)
img_.mul_(std[:, None, None]).add_(mean[:, None, None])
img_ = img_.transpose(0, 2).transpose(0, 1) # C*H*W --> H*W*C
img_ = np.array(img_) * 255
if img_.shape[2] == 3:
img_ = Image.fromarray(img_.astype('uint8')).convert('RGB')
elif img_.shape[2] == 1:
img_ = Image.fromarray(img_.astype('uint8').squeeze())
else:
raise Exception("Invalid img shape, expected 1 or 3 in axis 2, but got {}!".format(img_.shape[2]))
return img_
for epoch in range(MAX_EPOCH):
for i, data in enumerate(train_loader):
inputs, labels = data # N C H W
img_tensor = inputs[0, ...] # C H W
img = transform_invert(img_tensor, train_transform) # HWC RGB
plt.imshow(img)
plt.show()
plt.pause(0.5)
plt.close()
4、torchvision.utils
- 提供两个常用的函数
make_grid
和save_img
make_grid
将多张图片拼接在一个网格中;save_img
将Tensor
保存成图片torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0) # 主要参数解析: - tensor (Tensor or list):4D mini-batch Tensor(B x C x H x W),或者 a list of image(形状相同) - nrow (python:int, optional):每行几张图片. 拼成一个(B/nrow, nrow)的网格图片 - padding (python:int, optional):补零的数量 torchvision.utils.save_image(tensor, fp, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0, format=None) # 主要参数解析: tensor (Tensor or list):默认 为一个Tensor,如果给定的是 mini-batch tensor,那就用 make-grid 做成网格图 normalize=True ,会将图片的像素值归一化处理 fp:保存的文件名(带后缀) eg: torchvision.utils.save_image(img, 'a.png')
四、参考资料
1、https://pytorch.org/docs/stable/torchvision/index.html
2、Pytorch数据读取(Dataset, DataLoader, DataLoaderIter)
3、4个例子让你的pytorch数据增强过程不随机
4、图像分类:数据增强(Pytorch版)
5、目标检测:数据增强(Numpy+Pytorch)