可以使用 torchvision 库中的 models.vgg19() 方法来加载预训练好的 VGG19 模型。然后使用 torch.nn.Sequential 将 VGG19 模型的部分层替换成新的自定义层,训练得到一个新的模型。最后使用这个新模型来对 MINST 手写数字进行分类。
代码示例:
import torch
import torch.nn as nn
import torchvision.models as models
# 加载预训练模型
vgg19 = models.vgg19(pretrained=True)
# 获取最后一层的输出维度
num_ftrs = vgg19.classifier[6].in_features
# 定义新的分类层
classifier = nn.Sequential(
nn.Linear(num_ftrs, 4096),
nn.ReLU(),
nn.Linear(4096,