参考的是 torchvision - PyTorch中文文档
相比于各种博客的良莠不齐的质量(也可能是我太菜看不懂),PyTorch中文文档 要友好的多。
torchvision
库:torchvision
是独立于 PyTorch 的关于图像操作的一些方便工具库,包含了目前流行的数据集,模型结构和常用的图片转换工具。
其常用的包:
vision.datasets
:
几个常用视觉数据集,可以下载和加载。包括以下的数据集:
MNIST
COCO
(用于图像标注和目标检测)(Captioning and Detection)LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10
vision.models
:
包含的模型结构有:
AlexNet
,VGG
,ResNet
SqueezeNet
Densenet
可以使用随机初始化的权重来创建这些模型。
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()
对于ResNet variants
和AlexNet
,也可以使用预训练(pre-trained
)的模型。通过加上pretrained=True
来调用。
注:预训练是在ImageNet
上进行的。
# pretrained=True 来使用预训练的模型
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
附:可调用的全部模型
torchvision.models.alexnet(pretrained=False, ** kwargs)
torchvision.models.resnet18(pretrained=False, ** kwargs)
torchvision.models.resnet34(pretrained=False, ** kwargs)
torchvision.models.resnet50(pretrained=False, ** kwargs)
torchvision.models.resnet101(pretrained=False, ** kwargs)
torchvision.models.resnet152(pretrained=False, ** kwargs)
torchvision.models.vgg11(pretrained=False, ** kwargs)
torchvision.models.vgg11_bn(** kwargs) # with batch normalization
torchvision.models.vgg13(pretrained=False, ** kwargs)
torchvision.models.vgg13_bn(** kwargs) # with batch normalization
torchvision.models.vgg16(pretrained=False, ** kwargs)
torchvision.models.vgg16_bn(** kwargs) # with batch normalization
torchvision.models.vgg19(pretrained=False, ** kwargs)
torchvision.models.vgg19_bn(** kwargs) # with batch normalization
vision.transforms
:
对PIL.Image进行变换
包括一些常用的图像操作。
class torchvision.transforms.Compose(transforms)
:
将多个transform
组合起来使用。
transforms
: 由transform
构成的列表。
示例:
transforms.Compose([
transforms.CenterCrop(10),
transforms.ToTensor(),
])
class torchvision.transforms.CenterCrop(size)
:
将给定的PIL.Image
进行中心切割,得到给定的size
。
size
可以是tuple
,(target_height, target_width)
。
size
也可以是一个Integer
,在这种情况下,切出来的图片的形状是正方形。
class torchvision.transforms.Pad(padding, fill=0)
:
将给定的PIL.Image
的所有边用给定的pad value
填充。 padding
:要填充多少像素。
fill
:用什么值填。
示例:
from torchvision import transforms
from PIL import Image
padding_img = transforms.Pad(padding=10, fill=0)
img = Image.open('test.jpg')
print(type(img))
print(img.size)
padded_img=padding(img)
print(type(padded_img))
print(padded_img.size)
<class 'PIL.PngImagePlugin.PngImageFile'>
(10, 10)
<class 'PIL.Image.Image'>
(30, 30) #由于上下左右都要填充10个像素,所以填充后的size是(30,30)
class torchvision.transforms.RandomCrop(size, padding=0)
:
切割中心点的位置随机选取。size
可以是tuple
也可以是Integer
。
class torchvision.transforms.RandomHorizontalFlip
:
随机水平翻转给定的PIL.Image
,概率为0.5
。即:一半的概率翻转,一半的概率不翻转。
class torchvision.transforms.RandomSizedCrop(size, interpolation=2)
:
先将给定的PIL.Image
随机切,然后再resize
成给定的size
大小。
对Tensor进行变换
class torchvision.transforms.Normalize(mean, std)
:
给定均值:(R,G,B)
、方差:(R,G,B)
,将会把Tensor
正则化。
即:Normalized_image
=(image
-mean
)/std
。
class torchvision.transforms.ToTensor
:
把一个 取值范围是[0,255]
的PIL.Image
或者是 shape
为(H,W,C)
的numpy.ndarray
,转换成shape
为[C,H,W]
,取值范围是[0,1.0]
的torch.FloadTensor
。
简单来说,就是PIL.Image
、ndarray
—> tensor
data = np.random.randint(0, 255, size=300)
img = data.reshape(10,10,3)
print(img.shape)
img_tensor = transforms.ToTensor()(img) # 转换成tensor
print(img_tensor)
class torchvision.transforms.ToPILImage
:
将shape
为(C,H,W)
的Tensor
或shape
为(H,W,C)
的numpy.ndarray
转换成PIL.Image
,值不变。
通用变换 class torchvision.transforms.Lambda(lambd)
使用lambd
作为转换器。