PyTorch学习笔记-Transform

1. Transform的概念与基本用法

transforms 在计算机视觉工具包 torchvision 下,包含了很多种对图像数据进行变换的类,这些都是在我们进行图像数据读入步骤中必不可少的。

transforms 主要使用的类为:transforms.ToTensor,该类能够将 PIL Image 或者 ndarray 转换为 tensor,并且归一化至[0-1]。注意归一化至[0-1]是直接除以255,若自己的 ndarray 数据尺度有变化,则需要自行修改。

为什么需要 tensor 数据类型?因为它包装了反向传播神经网络所需要的一些基础的参数,因此在神经网络中需要将图片类型转换为 tensor 类型进行训练。

例如:

from PIL import Image
from torchvision import transforms
import cv2

img_path = 'dataset/hymenoptera_data/train/ants_image/0013035.jpg'
img_PIL = Image.open(img_path)  # <class 'PIL.JpegImagePlugin.JpegImageFile'>

tensor_trans = transforms.ToTensor()  # 创建 ToTensor 的实例对象
img_tensor1 = tensor_trans(img_PIL)  # 将 PIL Image 转换成 tensor
print(type(img_tensor1))  # <class 'torch.Tensor'>

img_cv = cv2.imread(img_path)  # <class 'numpy.ndarray'>
img_tensor2 = tensor_trans(img_cv)  # 将 OpenCV Image 转换成 tensor
print(type(img_tensor2))

2. Transform的常用类

  • transforms.Compose:Compose 能够将多种变换组合在一起。例如下面的代码可以先将 PIL Image 中心裁切,然后再转换成 tensor:
img_path = 'dataset/hymenoptera_data/train/ants_image/0013035.jpg'
img_PIL = Image.open(img_path)

trans = transforms.Compose([
    transforms.CenterCrop(100),
    transforms.ToTensor()
])

img_trans = trans(img_PIL)
  • transforms.CenterCrop:需要传入参数 size,表示以 (size, size) 的大小从中心裁剪,参数也可以为 (height, width)。例如:
img_PIL.show()

trans_centercrop = transforms.CenterCrop((100, 150))
img_centercrop = trans_centercrop(img_PIL)
img_centercrop.show()
  • transforms.RandomCrop:需要传入参数 size,表示以 (size, size) 的大小随机裁剪,参数也可以为 (height, width)
  • transforms.Normalize(mean, std):对数据按通道进行标准化,即先减均值 mean,再除以标准差 std,注意是 HWC 格式,处理公式为:output[channel] = (input[channel] - mean[channel]) / std[channel],例如:
trans_tensor = transforms.ToTensor()
img_tensor = trans_tensor(img_PIL)

# 如果 input 的范围是[0, 1],那么用该参数归一化后的范围就变为[-1, 1]
trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
img_norm = trans_norm(img_tensor)
print(img_norm)
  • transforms.Resize:需要传入参数 (height, width) interpolation,表示重置图像的分辨率为 (h, w),也可以传入一个整数 size,这样会将较短的那条边缩放至 size,另一条边按原图大小等比例缩放。interpolation 为插值方法选择,默认为 PIL.Image.BILINEAR,例如:
trans_tensor = transforms.ToTensor()
img_tensor = trans_tensor(img_PIL)

print(img_tensor.size())  # torch.Size([3, 512, 768]),tensor 图像使用 size() 获取大小,PIL 图像使用 size

trans_resize = transforms.Resize((256, 300))
img_resize = trans_resize(img_tensor)
print(img_resize.size())  # torch.Size([3, 256, 300]),修改比例

trans_resize = transforms.Resize(30)
img_resize = trans_resize(img_tensor)
print(img_resize.size())  # torch.Size([3, 30, 45]),与原图等比例
  • transforms.ToPILImage::将 tensor 或者 ndarray 的数据转换为 PIL Image 类型数据,参数 mode 默认为 None,表示1通道, mode=3 表示3通道,默认转换为 RGB,4通道默认转换为 RGBA。
  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

柃歌

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值