torchvision介绍
torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。
torchvision的构成:
torchvision.datasets:一些加载数据的函数以及常用的数据集接口
torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
torchvision.transforms: 常用的图形变换,例如裁剪、旋转等;
torchvision.utils: 其他的一些有用的方法。
transforms.Compose()详解
torchvision.transforms.Compose()类:用于串联多个图片变换的操作
Compose()类会将transforms列表里面的tansform操作进行遍历
(ps:括号里面是一个列表,列表里面的元素是transform操作)
例如:
transforms.Compose([transforms.CenterCrop(10),transform.ToTensor()])
常见的transform操纵:
transforms.CenterCrop(size)
将给定的PIL.Image进行中心切割,得到给定的size。
transforms.RandomCrop(size, padding=0)
切割中心点的位置随机选取。
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)
进行数据的归一化处理,公式为:channel=(channel-mean)/std
transforms.ToTensor()
把数据处理成[0,1]
例如:
transform2 = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
]
)
transforms.Normalize将数据进行归一化:
channel=(channel-mean)/std(因为transforms.ToTensor()已经把数据处理成[0,1],那么(x-0.5)/0.5就是[-1.0, 1.0]),
这样一来,我们的数据中的每个值就变成了[-1,1]的数了