Transforms的基本概念
transforms是torchvision下的一个模块,主要帮助用户方便的对图像数据进行处理
它要求数据是(C, H, W)
的三维数组,其中字母含义为:
C
: Channel, 图片的通道,例如R、G、BH,W
, Height, Weight,图片的宽高
使用PIL读取一张图片
在使用Transforms前,先读取一张图片,用于后续使用
from PIL import Image
image = Image.open("images/mary.jpg")
image
Transforms的常用方法
Transforms的常用方法有如下:
1.ToTensor()
: 将一个PIL Image
或一个numpy.ndarray
转为Tensor
trans = transforms.ToTensor()
img_data = trans(image)
img_data.shape
torch.Size([3, 225, 225])
输出[3, 255, 255]
表示有3个通道(R,G,B),每个通道有255x255个像素点
2.Normalize(mean, std, inplace=False)
: 将tensor归一化为均值为mean
,方差为std
的数据
# 将三个通道分别做归一化
# 第一个通道归一化为 均值为0,方差为1
# 第二个通道归一化为 均值为1,方差为2
# 第三个通道归一化为 均值为2,方差为3
img_data = transforms.Normalize(mean=(0, 1, 2), std=(1,2,3))(img_data)
img_data.shape
torch.Size([3, 225, 225])
Transforms的Compose方法
一张图片可能需要执行很多次Transforms方法,所以Transform提供了Compose方法,方便用户一次将其全部处理完毕
img_data = transforms.ToTensor()(image)
img_data = transforms.Normalize(mean=(0, 1, 2), std=(1,2,3))(img_data)
compose = transforms.Compose(
[ # 将要对图片做的处理,全部一次性写全
transforms.ToTensor(),
transforms.Normalize(mean=(0, 1, 2), std=(1,2,3))
]
)
compose(image).equal(img_data)
True
参考资料
transforms官方文档:https://pytorch.org/vision/stable/transforms.html