Pytorch 学习第【4】天

1. torchvision库

torchvision包含了目前流行的数据集,模型结构和常用的图片转换工具。
结合前边的内容学习:
Pytorch 学习第【1】天
Pytorch 学习第【2】天
Pytorch 学习第【3】天

1.1. datasets数据集

torchvision.datasets中包含了以下数据集
图像数据集文本数据集音频数据集

1.1.1 加载数据集

例子:

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

#训练集
training_data = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=ToTensor()
)
#测试集
test_data = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=ToTensor()
)

【PS:若是下载太慢,可以复制给出的网址到迅雷下载,之后放在指定的目录下,再运行代码回自行解压。显示Files already downloaded and verified就ok了,查看CIFAR10数据集

datasets参数解析

def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    )

其中
root:是存储训练/测试数据的路径,
train:指定训练或测试数据集,
download=True(如果无法在root路径上获得,从互联网下载数据)
transform和target_transform并指定特征和标签转换

1.1.2 查看数据集

导入的库:

from torchvision.transforms import *
import matplotlib.pyplot as plt
from torchvision import *

随机显示部分图片并标以对应的标签:

# 将tensor 格式转化为PIL格式
to_pil_image = transforms.ToPILImage()
# 标签对应的含义
labels_map = {
    0: "airplane",
    1: "automobile",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}
#图片显示大小
figure = plt.figure(figsize=(8, 8))
#这里显示是三行三列
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    # 随机选择一张图片
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    # 将tensor类型转成PIL类型
    img = to_pil_image(img)
    #准备画图
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    #不显示坐标轴
    plt.axis("off")
    plt.imshow(img,)
plt.show()

【PS:需要注意的导入数据的时候是tensor类型的,在显示图片需要将tensor类型转化成PIL类型。】
当然也可以通过tensorBoard进行查看

writer = SummaryWriter("p9")
for i in range(9):
    img,label = training_data[i]
    writer.add_image("train_set",img,i)
writer.close()

【PS:命令忘记了可以看Pytorch 学习第【3】天】

1.1.3 输出结果

在这里插入图片描述
在tensorboard上:
在这里插入图片描述

1.2 DataLoader数据加载器

数据加载程序。结合一个数据集和一个采样器,并在给定数据集上提供一个可迭代对象。DataLoader支持地图样式和可迭代样式的数据集,支持单进程或多进程加载、自定义加载顺序、可选的自动批处理(排序)和内存固定。
参数:

CLASStorch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)

其中比较常设置的参数

dataset (dataset):用于加载数据的数据集。
Batch_size (int,可选):每批加载多少个样本(默认值:1)。
shuffle (bool,可选):设置为True,在每个epoch重新洗牌数据(:False)。
sampler (sampler或Iterable,可选):定义了从数据集中抽取样本的策略。
Num_workers (int,可选):用于数据加载的子进程数。0表示数据将加载到主进程中。(默认值:0)
drop_last (bool,可选):设置为True表示删除最后一个不完整的批处理,如果数据集大小不能被批处理大小整除。如果为False且数据集的大小不能被批量大小整除,则最后一批数据将更小。(默认值:False)

1.2.1 tensorboard上进行查看

代码:

# 从训练集中,batch_size=64每批加载64个样本,在每个epoch重新洗牌数据,
# 加载到主进程中,保留不能整除的部分
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True,num_workers=0,drop_last=False)
writer =SummaryWriter("dataloader")
i=0
for data in test_dataloader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)
    writer.add_images("test_dataloader",imgs,i)
    i+=1
writer.close()

结果:在这里插入图片描述
【PS:因为drop_last = false,将保留不能除尽的部分在tensorboard的显示效果,同时这里使用的是add_images。若drop_last = true将删除最后一个不完整的批处理,可自行尝试】
在这里插入图片描述

本文作者:九重!
本文链接:https://blog.csdn.net/weixin_43798572/article/details/124237002
关于博主:评论和私信会在第一时间回复。或者直接私信我。
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角【点赞】【收藏】一下。您的鼓励是博主的最大动力!

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

九重!

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值