【pytorch学习】transforms与数据集的使用

1 dataset

新建一个名为test_dataset_transforms.py的文件

先下载数据集,dataset下载地址:Datasets — Torchvision 0.19 documentation

import torchvision

train_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=True, download=True)
test_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=False, download=True)
print(train_data[0])

print(test_data.classes)# 打印出test_data的类别
img , target = test_data[0]# 获取第一张图片的数据和标签
print(img)# 打印出图片数据
print(test_data.classes[target])# 打印出图片标签对应的类别
img.show()# 展示图片

由运行结果可以知道, 当前图片是PIL类型,现需要将其转为tensor类型,使用torch.ToTensor()

import torchvision

data_transform = torchvision.transforms.ToTensor()#将PIL.Image转换为torch.FloatTensor
train_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=True, download=True, transform=data_transform)
test_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=False, download=True, transform=data_transform)
print(train_data[0])

 展示图片

import torchvision
from torch.utils.tensorboard import SummaryWriter

data_transform = torchvision.transforms.ToTensor()#将PIL.Image转换为torch.FloatTensor
train_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=True, download=True, transform=data_transform)
test_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=False, download=True, transform=data_transform)


writer = SummaryWriter("dataset_transforms_logs")

for i in range(10):
    img, target = train_data[i]
    writer.add_image("train_data", img, i)

writer.close()

 结果如下:

 2 dataloader

dataloader相关介绍:torch.utils.data — PyTorch 2.4 documentation

 其中,

  • batch_size (intoptional) – 每批次加载多少个样本 (default: 1).

  • shuffle (booloptional) – 设置为 True 代表每次加载样本时重新洗牌数据 (default: False).

  • drop_last (booloptional) – 如果数据集大小不能被批次大小整除,当设置为 True 时会删除最后一个不完整的批次。如果为 False 则保留最后一个批次。 (default: False)

 实例:

import torch
import torchvision
from torch.utils.data import dataloader, DataLoader
from torch.utils.tensorboard import SummaryWriter

data_transform = torchvision.transforms.ToTensor()#将PIL.Image转换为torch.FloatTensor
train_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=True, download=True, transform=data_transform)
test_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=False, download=True, transform=data_transform)

test_loader = DataLoader(test_data, batch_size=64, shuffle=True)
# test_loader = DataLoader(test_data, batch_size=64, shuffle=True, drop_last=True)

writer = SummaryWriter("dataloader_logs")

step = 0
for data in test_loader:
    imgs, targets = data
    writer.add_images("drop_last_false", imgs, step)# 注意,此时是images,非image
    step += 1

writer.close()

结果如图

可以看出,当 batch_size=64时,每次取出64张图片, shuffle=True时,每次取的图片都打乱了顺序,当drop_last=false时,最后剩下的16张图片也被留下来了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

神通广大白居易

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值