机器学习实战笔记9 花卉图像识别

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
warnings.filterwarnings("ignore")
import random
import sys
import json
from PIL import Image

数据读取与预处理操作

data_dir = './flower_data/'
train_dir = data_dir+'/train'
valid_dir = data_dir+'/valid' #验证集

制作数据源

data_transforms = {
    'train':
        transforms.Compose([  #按顺序做接下来的操作
        transforms.Resize([96,96]), #卷积神经网络要求数据格式相同
        transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
        transforms.CenterCrop(64),#从中间开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),#随机水平旋转,选择一个概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),#参数1为亮度参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换为灰度率,3通道就是RGB
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#均值,标准差
        ]),
    'valid':
        transforms.Compose([
        transforms.Resize([64,64]),
        transforms.ToTensor(),
        transforms.Normalize([-.485,0.456,0.406],[0.229,0.224,0.225])
        ]),
}

构建图像数据加载和处理

batch_size = 128

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x),data_transforms[x])for x in ['train','valid']}
dataloaders={x:torch.utils.data.DataLoader(image_datasets[x],bath_size=natch_size,shuffle=True) for x in ['train','valid']}
dataste_sizes = {x:len(image_datasets[x])for x in['train','valid']}
calss_names = image_datasets['train'].classes
                                                  
                                                                                                

image_datasets

读取标签对应的实际名字

with open('cat_to_name.json','r')as f:
    cat_to_name = json.load(f)

cat_to_name

加载models中提供的模型,并且直接用训练的好权重当作初始化参数

'model_name' = 'resnet'#可选的比较多['resent','alexnet','vgg''squeezenet','densenet']
#是否用人家训练好的特征来做
feature_extraxt =True#都用人家的特征,咱先不更新
#是否用GPU训练
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available. Training on CPU...')
else:
    print('CUDA is avilable! Training on GPU...')
device = torch.device("cuda:0"if torch.cuda.is_avilable() else "cpu")

模型参数要不要更新

有时候用人家模型,就一直用了,更不更新咱们可以自己定
model_ft = model.resent18()#18层的能快点,条件好的也可以选152
model_ft

  • 4
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值