一、什么是Transforms?
transforms 是 PyTorch 中 torchvision 库的一个模块,用于进行图像数据预处理和增强。它提供了一系列常用的数据预处理函数,如图像裁剪、缩放、翻转、归一化等。
transforms 模块中的函数通常用于创建数据预处理管道,以将输入数据转换为模型所需的格式。例如,在训练神经网络时,通常会对输入图像进行随机翻转、随机裁剪等增强操作,以扩展数据集并提高模型的泛化能力。
以下是 transforms 模块中一些常用的函数:
Compose(transforms):将多个数据预处理函数组合成一个管道。
ToTensor():将 PIL 图像或 Numpy 数组转换为 PyTorch 张量。
Normalize(mean, std):对张量进行归一化,减去均值并除以标准差。
Resize(size):将图像缩放到指定大小。
CenterCrop(size):对图像进行中心裁剪。
RandomCrop(size):对图像进行随机裁剪。
RandomHorizontalFlip():以一定的概率对图像进行随机水平翻转。
RandomRotation(degrees):对图像进行随机旋转。
使用 transforms 模块中的函数可以方便地将输入数据进行预处理和增强,以提高模型的表现。
二、将PIL图像和numpy数组转化为Tensor类型
import cv2
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from PIL import Image
#通过transforms.Totensor去解决两个问题
#1、transforms该如何使用
#2、为什么我们需要Tensor数据类型
img_path="hymenoptera_data/hymenoptera_data/train/bees/95238259_98470c5b10.jpg"
#ALT+ENTER可以瞬间调库
PIL_image=Image.open(img_path)
writer=SummaryWriter("logs")
cv2_img=cv2.imread(img_path)
#创造实例化tensor生产工具
tensor_trans=transforms.ToTensor()
#进行实际生产,将PIL类型转为Tensor类型
tensor_img=tensor_trans(PIL_image)
tensor_img2=tensor_trans(cv2_img)
writer.add_image("Tensor_img",tensor_img,0)
writer.add_image("Tensor_img1",tensor_img2,0)
writer.close()
三、结果展示
四、常见类型图片打开
五、transforms里的Normalize
#将每个通道的像素值减去 0.5,然后除以 0.5,以将像素值归一化到范围 [-1, 1]。
trans_norm=transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
norm_img=trans_norm(tensor_img)
六、transforms里的Resize
trans_resize=transforms.Resize((512,512))
resize_img=trans_resize(PIL_image)
resize_img=trans_totensor(resize_img)
不过如今可以通过trans_resize.forward(这里直接传入tensor对象)
七、transforms里的compose
通过一个列表可以接收多个预操作,不用再一项项的输入,注意是个列表
trens_size=transforms.Resize((512,1024))
trens_normal=transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
trens_tensor=transforms.ToTensor()
trens_compose=transforms.Compose([trens_size,trens_tensor,trens_normal])
重点还是要学会在结构中看文档!!!