torchvision学习——datasets、models(加载数据、调用模型)

系列文章目录

torchvision是pytorch的一个图形库,它包含了torchvision.datasets、torchvision.models、torchvision.transforms、torchvision.utils四部分。
1、torchvision.datasets: 一些数据集。
2、torchvision.models: 常见卷积网络模型。
3、torchvision.transforms: 数据预处理、图片变换等操作。
4、torchvision.utils: 其他函数。


前言

最近在学习pytorch,总结一下。


1.torchvision.datasets

datasets这个包有很多数据集,比如MINIST、COCO、CIFAR10 and CIFAR100、LSUN 、Classification、ImageFolder、Imagenet-12、STL10。torchvision.datasets中的数据集封装都是torch.utils.data.Dataset子类,它们都实现了__getitem__ 和 __len__方法,都可以用DataLoader进行数据加载。

  torchvision.datasets.MNIST(root,train = True,transform = None,target_transform = None,download = False

参数介绍:
root:数据集的根目录
train:如果为True,训练集,否则是测试集
download:如果为true,根目录没有数据集就会自动在这个目录下载。
transform:数据集预处理,比如归一化当图形转换类的操作
target_transform:接收目标并对其进行转换的函数/转换。

MNIST数据集示例

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
# 数据预处理
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081))])
# 训练集
train_dataset = datasets.MNIST(root='../data/mnist',train=True, download=True, transform=transform)
# 测试集
test_dataset=datasets.MNIST(root='../data/mnist',train=False, download=True, transform=transform)
# 数据集加载器 
# train_loader=DataLoader(train_dataset,shuffle=True,batch_size=batch_size)
# test_loader=DataLoader(test_dataset,shuffle=False,batch_size=batch_size)

2.DataLoader

DataLoader数据加载器,PyTorch数据读取重要接口,用PyTorch架构来训练模型基本都会用到该接口,把数据分块送入模型进行训练。

from torch.utils.data import DataLoader

train_loader= DataLoader(dataset = train_dataset,  # 数据加载
                        batch_size = 4,    # 送入多少张图片
                        shuffle = True,    #对原有数据排序是否打乱
                        num_workers = 0,   #是否进行多进程加载数据设置
                        drop_last = False) #最后的数据组不成一个batch_size 是否丢弃

参数:
dataset:数据加载
batch_size :送入多少张图片
shuffle :是否打乱数据
sampler :指定数据加载中使用的索引/键的序列
batch_sampler = None,#和sampler类似
num_workers :是否进行多进程加载数据设置
collat​​e_fn = None,#是否合并样本列表以形成一小批Tensor
pin_memory :数据加载器会在返回之前将Tensors复制到CUDA固定内存
drop_last :最后的数据组不成一个batch_size 是否丢弃

3.torchvision.models

models包含以下模型:
AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3
GoogLeNet
ShuffleNet v2
MobileNetV2
MobileNetV3
ResNeXt
Wide ResNet
MNASNet
EfficientNet
RegNet

导入模型

import torchvision.models as models
 
#alexnet = models.alexnet(pretrained=True)  # 加载预训练权重
alexnet = models.alexnet()   # AlexNet   不加载
vgg16 = models.vgg16()       # VGG16
resnet18 = models.resnet18() # ResetNet模型
print(vgg16 )   # 打印模型

改模型默认下载目录

import os
os.environ['TORCH_HOME']='E:/Data/torch-model'

修改模型

import torchvision.models as models
vgg16 = models.vgg16()       # VGG16
# 在classifier层添加add_linear
vgg16.classifier.add_module("add_linear",Linear(1000,10))
# 在classifier层修改add_linear参数
vgg16_false.classifier[6]=Linear(4096,10)

保存模型

path = "D:/code/text/model1.pth"
    #torch.save(model,path)
    torch.save(model.state_dict(),path)   # 保存模型

加载模型

解决pytorch加载模型报错TypeError: ‘collections.OrderedDict‘ object is not callable
# 错误原因:之前保存网络时用的方法是torch.save(model, ‘Nei.pkl’),这样保存下来的Net.pkl是一个状态字典,而不是模型本身,也就是说Net.pkl中保存的只是网络的参数,而没有网络结构。

    model = torchvision.models.vgg16(pretrained=False)
    model.load_state_dict(torch.load('D:/code/text/model1.pth')) # 导入网络的参数

4.torchvision.utils

拼接图片
组成图像的网络,将多张图片组合成一张图片

torchvision.utils.make_grid(images)

参数:
tensor:4D张量,形状为(B x C x H x W),图像列表
nrow:每行的图片数量,默认值为8
padding:相邻图像之间的间隔。默认值为2
normalize:如果为True,则把图像的像素值通过range指定的最大值和最小值归一化到0-1。默认为False
range:元组,用于指定最大值和最小值。默认使用图像像素的最大最小值。
sacle_each:如果为True,就单独对每张图像进行normalize;如果是False,统一对所有图像进行normalize。默认为Flase
pad_value:float,上述padding会使得图像之间留出空隙,默认为0

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

train_dataset = datasets.MNIST(root='../data/mnist',train=True,transform=data_tf,download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=10,shuffle=False)

images, labels = next(iter(train_loader))   # batch_size 高 长 宽
# 组成图像的网络,将多张图片组合成一张图片
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)

def cv_show(name, img):  # 长宽高
    cv2.imshow(name, img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

cv_show('image', img)

在这里插入图片描述

保存图片

torchvision.utils.save_image(img, imgPath)

总结

未完待续,,,

  • 4
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Chaoy6565

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值