pytorch学习笔记8

#pip install torchvision


import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datesets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

#数据读取与预处理操作
data_dir=
train_dir=data_dir+
valid_dir=data_dir+
#数据增强
data_transform={
    'train':transforms.compose([transforms.randomrotation(45),#随机旋转,-45到45度之间随机选
    transforms.centercrop(224),#从中心开始裁剪
    transforms.randomhorizontalflip(p=0.5),#随机水平反转,以一个概率
    transforms.randomverticalflip(p=0.5)#随机垂直翻转
    transforms.colorjitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1),#参数1:亮度;参数2:对比度;参数3:饱和度;参数4:色相
    transforms.randomgrayscale(p=0.025),#概率转换成灰度图,3通道就是R=G=B
    transforms.totensor(),
    transforms.normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#均值,方差
    ]),
    'valid':transforms.compose([transforms.resize(256),
    transforms.centercrop(224),
    transforms.totensor(),
    transforms.normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
    }

    batch_size=8
    image_datesets={x:datasets.imagefolder(os.path.join(data_dir,x),data_transforms[x]) for x in ['train','vaild']}
    dataloader={x:torch.utils.data.DataLoader(image_datasets[x],batch_size=batch_size,shuffle=true) for x in ['train','valid']}//划分batch
    detaset_sizes={x:len(image_datasets[x]) for x in ['train','valid']}
    class_names=image_datasets['train'].classes
    #读取标签对应的名字
    with open('cat_to_name.json','r') as f:
        cat_to_name=json.load(f)
    #展示数据
    def im_convert(tensor):
        image=tensor.to('cpu').clone().detach()//tensor.clone():返回tensor的拷贝,返回的新tensor和原来的tensor具有同样的大小和数据类型;tensor.detach():返回一个新的tensor,新的tensor和原来的tensor共享数据内存
        image=image.numpy().squeeze()//作用:从数组的形状中删除单维度条目,即把shape中为1的维度去掉;场景:在机器学习和深度学习中,通常算法的结果是可以表示向量的数组(即包含两对或以上的方括号形式[[]]),如果直接利用这个数组进行画图可能显示界面为空。我们可以利用squeeze()函数将表示向量的数组转换为秩为1的数组,这样利用matplotlib库函数画图时,就可以正常的显示结果了。
        image=image.tanspose(1,2,0)//主要是Pytorch中使用的数据格式与plt.imshow()函数的格式不一致,Pytorch中为[Channels, H, W],plt.imshow()中则是[H, W, Channels]
        image=image*np.array((0.229,0.224,0.225))+np.array((0.485,0.456,0.406))
        image=image.clip(0,1)//image.clip函数:下界,区间的最小值,a中比a_min小的数都会强制变成a_min;上界,区间的最大值,a中比a_max大的数都会强制变成a_max
        return image;

    fig=plt.figure(figsize=(20,12))//生成一个图框
    columns=4
    rows=2

    dataiter=iter(dataloaders['valid'])//iter() 函数用来生成迭代器
    inputs,classes=dataiter.next()//指针指向下一条记录

    for idx in range (columns*rows)
        ax=fig.add_subplot(rows,columns,idx+1,xticks=[],yticks=[])//第一个参数表示行数,第二个参数表示列数,第三个参数表示你正在绘制图的位置。
        ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
        plt.imshow(im_convert(inputs[idx]))//将数字从一个度量系统转换到另一个度量系统中的函数。
    plt.show()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值