加载预训练模型: PyTorch提供了许多经过预训练的模型,可以通过torchvision库直接加载。例如,可以使用以下代码加载预训练的ResNet模型
import torch
import torchvision.models as models
model = models.resnet50(pretrained=True)
可以根据自己的需求修改已加载的模型的结构。例如,可以修改ResNet模型的全连接层,其中num_classes是新的输出类别数量:
model.fc = torch.nn.Linear(2048, num_classes)
保存模型: 使用torch.save()函数可以将模型保存到文件中。例如,可以使用以下代码保存修改后的模型:
torch.save(model.state_dict(), 'model.pth')
使用torch.load()函数加载已保存的模型参数。例如,可以使用以下代码加载已保存的模型:
model = models.resnet50()
model.load_state_dict(torch.load('model.pth'))
model.eval()