简介:
这是PyTorch中的一个图像转换模块,任何图像增强或者归一化操作都需要用到此模块
1.torchvision.transforms.Compose(transforms)
作用:
将几个transforms组合到一起
参数:
由Transform
objects构成的列表,可以是:
a).torchvision.transforms的类对象,
b).也可以是自定义的类对象
2.torchvision.transforms.ToTensor
作用:
将PIL读取的图像或者np.array (H * W * C) 格式的数据([0, 255])转换成torch.FloatTensor类型的 (C * H * W) 在([0.0, 1.0])范围内的数据.
黄色线前提:
a).PIL image属于(L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
b).np.array数据类型为dtype=np.uint8
其他情况下不会缩放到[0.0, 1.0]的范围
返回: Tensor
3.torchvision.transforms.Normalize(mean, std, inplace=False)
作用:
(标准化/归一化)一个tensor 类型的数据,归一化的均值和方差分别是:mean和std. 要给定每个通道的均值和方差(M1, M2, M3), (S1, S2, S3).
例如: input[channel] = (input[channel] - mean[channel]) / std[channel]
注意:默认情况下转换操作不是in_palce操作,一般配合**torchvision.transforms.ToTensor()**一起使用
返回: Tensor