本次将一个使用Pytorch的一个实战项目,记录流程:自定义数据集->数据加载->搭建神经网络->迁移学习->保存模型->加载模型->测试模型git
自定义数据集
参考个人上一篇博客:自定义数据集处理github
数据加载
默认小伙伴有对深度学习框架有必定的了解,这里就不作过多的说明了。
好吧,仍是简单的说一下吧:
咱们在作好了自定义数据集以后,其实数据的加载和MNSIT 、CIFAR-10 、CIFAR-100等数据集的都是类似的,过程以下所示:web
导入必要的包
import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader
加载数据
能够发现和MNIST 、CIFAR的加载基本上是同样的
train_db = Pokemon('pokeman', 224, mode='train')
val_db = Pokemon('pokeman', 224, mode='val')
test_db = Pokemon('pokeman', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
num_workers=4)
val_loader = DataLoader(val_db