import torch
import torchvision
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms
# 下载并加载预训练模型 ResNet-18.
resnet = torchvision.models.resnet18(pretrained=True)
# 保持该模型的其它层参数,仅使用最后一层进行 finetuning.
for param in resnet.parameters():
param.requires_grad = False
resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 设定输出为100.
# 前向传递.
images = torch.randn(64, 3, 224, 224)
outputs = resnet(images)
print (outputs.size()) # 输出 (64, 100)
# 保存和加载整个模型.
torch.save(resnet, 'model.ckpt')
model = torch.load('model.ckpt')
# 保存和加载整个模型的参数. (数据量较少,推荐这种方式).
torch.save(resnet.state_dict(), 'params.ckpt')
resnet.load_state_dict(torch.load('params.ckpt'))