import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
warnings.filterwarnings("ignore")
import random
import sys
import json
from PIL import Image
# 数据读取与预处理操作 (Data loading and preprocessing)
# Root folder of the flower dataset; each split lives in its own sub-folder.
data_dir = './flower_data/'
# os.path.join avoids the doubled separator that plain string concatenation
# produced ('./flower_data/' + '/train' -> './flower_data//train').
train_dir = os.path.join(data_dir, 'train')  # training set
valid_dir = os.path.join(data_dir, 'valid')  # validation set
# 制作数据源 (Build the data sources: one preprocessing pipeline per split)
# Channel-wise normalization statistics (the standard ImageNet mean/std that
# torchvision's pretrained models expect). Defined once so the train and
# valid pipelines cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Per-split preprocessing pipelines, keyed by the split's sub-folder name.
data_transforms = {
    'train': transforms.Compose([  # the transforms are applied in order
        # The CNN needs every input image to have the same size.
        transforms.Resize([96, 96]),
        # Random rotation by an angle drawn from [-45, 45] degrees.
        transforms.RandomRotation(45),
        # Keep the central 64x64 region.
        transforms.CenterCrop(64),
        # Random horizontal flip with probability 0.5.
        transforms.RandomHorizontalFlip(p=0.5),
        # Random vertical flip with probability 0.5.
        transforms.RandomVerticalFlip(p=0.5),
        # Random jitter of brightness, contrast, saturation and hue.
        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        # Convert to grayscale with probability 0.025; a 3-channel input stays
        # 3-channel (R == G == B).
        transforms.RandomGrayscale(p=0.025),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),  # (mean, std)
    ]),
    'valid': transforms.Compose([
        # NOTE(review): training crops the central 64px of a 96px image while
        # validation resizes the whole image to 64px, so the two splits see
        # objects at different scales — confirm this is intended.
        transforms.Resize([64, 64]),
        transforms.ToTensor(),
        # Must use exactly the same statistics as the training pipeline.
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
}
# 构建图像数据加载和处理 (Build the image datasets and data loaders)
# Number of images per mini-batch.
batch_size = 128
splits = ['train', 'valid']

# One ImageFolder dataset per split, read from data_dir/<split>/ and
# preprocessed with the matching pipeline from data_transforms.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in splits
}
# Batched, shuffled loaders over each split. The keyword is `batch_size`;
# the original passed `bath_size=natch_size`, which fails with a NameError.
dataloaders = {
    split: torch.utils.data.DataLoader(image_datasets[split], batch_size=batch_size, shuffle=True)
    for split in splits
}
# Number of samples in each split.
dataset_sizes = {split: len(image_datasets[split]) for split in splits}
# Class names as discovered by ImageFolder on the training split.
class_names = image_datasets['train'].classes

# Aliases for the previously misspelled names, kept in case later cells use them.
dataste_sizes = dataset_sizes
calss_names = class_names

image_datasets
# 读取标签对应的实际名字 (Read the actual name corresponding to each label)
# Load the label -> human-readable name mapping from JSON.
# NOTE(review): presumably the keys match the class-folder names in
# class_names — confirm against the data.
# An explicit UTF-8 encoding keeps this from depending on the platform's
# default locale encoding (e.g. cp936 on Chinese Windows).
with open('cat_to_name.json', 'r', encoding='utf-8') as f:
    cat_to_name = json.load(f)
cat_to_name
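# 加载 models 中提供的模型，并且直接用训练好的权重当作初始化参数
# (Load a model provided by torchvision.models and use its pretrained weights as the initialization)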
# Backbone architecture to use. Candidates mentioned for this notebook:
# ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet'].
model_name = 'resnet'
# Whether to reuse the pretrained network's features as-is: True means use
# its features and do not update them for now.
feature_extract = True
# Alias for the previously misspelled name, kept in case later cells use it.
feature_extraxt = feature_extract
# Whether a CUDA GPU is available for training.
train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:
    print('CUDA is not available. Training on CPU...')
else:
    print('CUDA is available! Training on GPU...')
# Target device for tensors and the model. Reuse the flag computed above
# rather than querying CUDA again.
device = torch.device("cuda:0" if train_on_gpu else "cpu")
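# 模型参数要不要更新
# 有时候用人家的模型，就一直用了，更不更新咱们可以自己定
# (Should the model parameters be updated? When reusing someone else's pretrained
# model we can decide for ourselves whether to keep training it.)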
# Instantiate an 18-layer ResNet from torchvision.models. It is faster than the
# deeper variants; with more compute, resnet152 is also an option.
# NOTE(review): no weights argument is passed, so per torchvision's defaults
# the parameters are not pretrained, although the section heading talks about
# using trained weights — confirm whether pretrained weights are wanted here.
model_ft = models.resnet18()
model_ft