Pytorch学习笔记8——迁移学习与autoencoder
上一回我们使用自定义数据集完成了训练,但由于自定义数据集数量较少,准确度较低,因此我们可以考虑迁移学习。利用同domain的数据来增强学习效果。
红色绿色为原数据,黑色为我们之前训练效果不好的数据,我们将黑色数据加入原数据中是原模型微调,达到迁移学习的效果(起点更高)。
实现的方法很简单,在训练程序中,将主程序进行变更:
#model=ResNet18(5).to(device)
trained_model=resnet18(pretrained=True)
model=nn.Sequential(*list(trained_model.children())[:-1], #[b,512,1,1])
Flatten(),#[b,512,1,1])=>[b,512]
nn.Linear(512,5)
).to(device)
x=torch.randn(2,3,224,224).to(device)
print(model(x).shape)
使用pytorch的预训练模型即可。