第一步:数据的预处理
1.训练集和测试集数据的路径
#指定数据的路径
data_dir='/Users/macbook/Desktop/flower_data/'
train_dir=data_dir+'/train'
valid_dir=data_dir+'/valid'
2.训练集的图像增强Image Augmentation
Image Augmentation:只需要在训练集中进行,测试集不需要进行数据增强操作
transforms:用来进行图像增强操作
data_tansforms={
#数据增强Data Augmentation 训练集需要进行数据增强
'train':transforms.Compose(
[transforms.RandomRotation(45),#随机旋转,表示图像在-45~45度之间随机旋转
#对于比较大的图像,如1024*1024的图像,应该先要经过resize操作,变成256*256大小,再从中心开始裁剪成224*224
transforms.CenterCrop(224),#从中心开始裁剪,裁成224*224尺寸的图像
transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转,50%可能性执行水平翻转
transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转,50%可能性执行垂直翻转
#参数1是亮度,参数2是对比度,参数3是饱和度,参数4是色相
transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),
#2.5%的可能性把当前图像转换成灰度图
transforms.RandomGrayscale(p=0.025),
transforms.ToTensor(),#把数据转换成Tensor格式
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#均值和标准差,标准化操作
]),
#验证集不需要进行数据增强:图片大小和标准化操作
'valid':transforms.Compose(
[
transforms.Resize(256),#验证集图像的大小并不确定,所以要将图像缩放成256*256
transforms.CenterCrop(224),#从图片的中心裁剪成224*224大小
transforms.ToTensor(),#将数据转换成Tensor格式
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#均值和标准差,标准化操作
]
),
}
第二步:数据的读取
1.定义超参数
#定义超参数
batch_size=8
2.训练集和测试集的读取
#训练集和测试集的数据读取:image_datasets
image_datasets={
x:datasets.ImageFolder(os.path.join(data_dir,x),data_tansforms[x])
for x in ['train','valid']
}
3.构建batch数据,DataLoader用来迭代取数据
#构建batch数据,DataLoader用来迭代取数据:dataLoaders
dataloaders={
x:data.DataLoader(image_datasets[x],batch_size=batch_size,shuffle=True)
for x in ['train','valid']
}
4.参数:训练集和测试集的大小,花卉种类的名称
#训练集和测试集的大小
dataset_sizes={
x:len(image_datasets[x])