1.数据准备
(1)ImageFolder
torchvision已经预先实现了常用的Dataset,包括前面使用过的CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等数据集,可通过诸如torchvision.datasets.CIFAR10来调用。在这里介绍一个会经常使用到的Dataset——ImageFolder。ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,详见ImageFolder
(2)数据示例
或
按照类别划分目录
2.训练代码
'''
基于PyTorch的VGG16迁移学习
'''
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
from torchvision import models
import matplotlib.pyplot as plt
batch_size = 8
learning_rate = 0.0002
epoch = 10
train_transforms = transforms.Compose([
transforms.Resize(256),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip()