模型保存
1.只保存和加载模型参数:
##保存
torch.save(the_model.state_dict(), PATH)
##导入
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
2.保存和加载整个模型:
##保存
torch.save(the_model, PATH)
##导入
the_model = torch.load(PATH)
torchvision介绍
1.torchvision.datasets
- MNIST
- COCO(用于图像标注和目标检测)(Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR10 and CIFAR100
- STL10
2.torchvision.models
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
import torchvision.models as models
#pretrained=True就可以使用预训练的模型
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
3. torchvision.transforms
对PIL.Image进行变换
class torchvision.transforms.CenterCrop(size)
将给定的PIL.Image进行中心切割,得到给定的size,size可以是tuple,(target_height, target_width)。size也可以是一个Integer,在这种情况下,切出来的图片的形状是正方形。
class torchvision.transforms.RandomCrop(size, padding=0)
切割中心点的位置随机选取。size可以是tuple也可以是Integer。
class torchvision.transforms.RandomHorizontalFlip
随机水平翻转给定的PIL.Image,概率为0.5。即:一半的概率翻转,一半的概率不翻转。
class torchvision.transforms.RandomSizedCrop(size, interpolation=2)
先将给定的PIL.Image随机切,然后再resize成给定的size大小。
class torchvision.transforms.Pad(padding, fill=0)
将给定的PIL.Image的所有边用给定的pad value填充。 padding:要填充多少像素 fill:用什么值填充。
对Tensor进行变换
class torchvision.transforms.Normalize(mean, std)
给定均值:(R,G,B) 方差:(R,G,B),将会把Tensor正则化。即:Normalized_image=(image-mean)/std。
class tor