torchvision是pytorch的一个图形库,主要用来构建计算机视觉模型。torchvision.transforms主要包括一些常用的图形变换,它由以下四个部分组成:
- torchvision.datasets:一些加载数据的函数及常用的数据集接口
- torchvision.models:包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等
- torchvision.transforms:常用的图片变换,例如裁剪、旋转等
- torchvision.utils:其他的一些有用的方法
1 datasets
2 models
3 transforms
transforms.Compose()
这个类的主要作用是串联多个图片变换的操作
类定义如下:
CLASS torchvision.transforms.Compose(transforms)
参数:
- transforms:这是Transform类对象的list,包含要组合在一起的transforms
使用示例:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.PILToTensor(),
>>> transforms.ConvertImageDtype(torch.float),
>>> ])
事实上,Compose()
类会将transforms列表里面的transform操作进行遍历,实现的代码比较简单:
## 这里对源码进行了部分截取。
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img