torchvision中的transforms,很常用,用于图片的变换。本节通过处理单张图片介绍transforms的基本使用,下节将transforms与数据集结合。
from torchvision import transforms
1. transforms结构及用法
可以用ctrl查看使用方法。一个python文件transforms.py,工具箱,其中包含许多class方法。
(1)compose
结合不同的transforms together,可以将多个不同的transforms操作整合成一个块操作
(2)ToTensor(最常用)
把“PIL Image” 或 “numpy.array” 转换成tensor
(3)ToPILImage,Normalize正则化,Resize尺寸变化
2. tensor数据类型
通过transforms.ToTensor解决两个问题:
1.transforms如何使用
from torchvision import transforms
import cv2
img_path = '....jpg'
img = cv2.imread(img_path)
#1.transforms如何使用
#查看如何使用ToTensor。需要传入一个pic,返回tensor类型图片
tensor_trans = transforms.ToTensor() #创建具体的工具(也是一种实例化?)原本工具箱相当于模板
tensor_img = tensor_trans(img)
2.tensor数据类型特点,为何需要该数据类型:神经网络中必须使用到的数据类型
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import cv2
img_path = '....jpg'
img = cv2.imread(img_path)
writer = SummaryWriter("logs")
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
writer.add_image("tensor_img", tensor_img)
writer.close()
#运行后,在终端运行tensorboard --logdir=logs
单纯生成一个tensor类型:
x = torch.tensor(1.0)
3. 常见transforms的使用
需要关注的几点,主要是格式方面:
常用的几种功能函数在前有所提及,具体使用方法在实例化后,用ctrl进去看需要参数即可。
归一化Normalize
公式:
input[channel] = ( input[channel] - mean[channel] ) / std[channel]
实例化时需要输入mean和std(平均值和标准差,三个通道需要写[0, 0, 0]这样)
作用:限制输出可以在一个范围内。如输入均值和标准差为0.5,则有:
4. Resize()的使用
输入要求数据为PIL Image格式。resize((512, 512))为转换为(512, 512)尺寸
#之后的img指的是由PIL Image方法读取的Image数据
trans_totensor = transforms.ToTensor()
trans_resize = transforms.Resize((512, 512))
# img PIL -> resize -> img_resize PIL
img_resize = trans_resize(img)
# img_resize PIL -> totensor -> img_resize tensor 覆盖
img_resize = trans_totensor(img_resize)
print(img_resize)
5. Compose()的使用
常规第一种用法不提,用来合并不同的transforms操作。
Compose()中的参数需要是一个列表。python中列表表示形式为[数据1, 数据2, …],在Compose中,数据需要transforms类型,即Compose([transforms参数1, transforms参数2, …])
trans_resize_2 = transforms.Resize(512) #等比例缩放,输入一个值
# PIL -> PIL -> tensor 顺序不能变,否则先变成tensor类型会出错
trans_compose = transforms.Compose([trans_resize_2, trans_totensor])
img_resize_2 = trans_compose(img) #需要Image数据
writer.add_image(img_resize_2)
6. 随机裁剪RandomCrop()的使用
trans_random = transforms.RandomCrop(512)
trans_compose_2 = transforms.Compose([trans_random, trans_totensor])
#例如说裁剪10个
for i in range(10):
img_crop = trans_compose_2(img)
writer.add_image("RandomCrop", img_crop, i)
writer.close()
#指定高和宽
trans_random = transforms.RandomCrop(500, 1000)
trans_compose_2 = transforms.Compose([trans_random, trans_totensor])
for i in range(10):
img_crop = trans_compose_2(img)
writer.add_image("RandomCrop", img_crop, i)
writer.close()
7. 总结
- 关注输入、输出类型;多看官方文档。
- 不知道返回值的时候:
- print(type())
- debug