1 dataset
Create a new file named test_dataset_transforms.py.
First, download the dataset. The available datasets are listed in: Datasets — Torchvision 0.19 documentation
import torchvision
train_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=True, download=True)
test_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=False, download=True)
print(train_data[0])
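print(test_data.classes)  # print the class names of test_data
img, target = test_data[0]  # get the first image and its label
print(img)  # print the image object
print(test_data.classes[target])  # print the class name that the label refers to
img.show()  # display the image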
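The output shows that the image is a PIL Image. To convert it to a tensor, pass torchvision.transforms.ToTensor() (not torch.ToTensor()) as the dataset's transform argument: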
import torchvision
data_transform = torchvision.transforms.ToTensor()  # converts a PIL.Image to a torch.FloatTensor
train_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=True, download=True, transform=data_transform)
test_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=False, download=True, transform=data_transform)
print(train_data[0])
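Conceptually, ToTensor() turns an H×W×C image with 8-bit values in [0, 255] into a C×H×W float tensor with values in [0.0, 1.0], and the dataset applies the transform it was given each time a sample is accessed in `__getitem__`. The pure-Python sketch below mimics both ideas on nested lists so it runs without torch; `to_tensor_like` and `ToyDataset` are hypothetical names for illustration, not torchvision's actual implementation.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Image = List[List[List[int]]]     # H x W x C, values 0..255 (stand-in for a PIL image)
Tensor = List[List[List[float]]]  # C x H x W, values 0.0..1.0 (stand-in for a FloatTensor)


def to_tensor_like(img: Image) -> Tensor:
    """Hypothetical stand-in for ToTensor: HWC ints in [0, 255] -> CHW floats in [0, 1]."""
    height, width, channels = len(img), len(img[0]), len(img[0][0])
    return [
        [[img[h][w][c] / 255.0 for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


class ToyDataset:
    """Hypothetical dataset following the same (img, target) protocol as CIFAR10."""

    def __init__(self, images: Sequence[Image], targets: Sequence[int],
                 classes: List[str], transform: Optional[Callable] = None):
        self.images = images
        self.targets = targets
        self.classes = classes
        self.transform = transform

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> Tuple[object, int]:
        img = self.images[index]
        if self.transform is not None:
            img = self.transform(img)  # the transform runs on every access
        return img, self.targets[index]


# One 1 x 2 RGB "image": a red pixel followed by a white pixel.
images = [[[[255, 0, 0], [255, 255, 255]]]]
raw = ToyDataset(images, [0], ["red"])                      # no transform: raw pixels
tensorized = ToyDataset(images, [0], ["red"], transform=to_tensor_like)

img, target = tensorized[0]
print(len(img), len(img[0]), len(img[0][0]))  # 3 1 2  (channels, height, width)
print(img[0])                                 # [[1.0, 1.0]]  (red channel)
print(tensorized.classes[target])             # red
```

Without a transform the dataset hands back the raw image unchanged, which is why the first version of the script printed a PIL Image.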
Displaying the images in TensorBoard:
import torchvision
from torch.utils.tensorboard import SummaryWriter
data_transform = torchvision.transforms.ToTensor()  # converts a PIL.Image to a torch.FloatTensor
train_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=True, download=True, transform=data_transform)
test_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=False, download=True, transform=data_transform)
writer = SummaryWriter("dataset_transforms_logs")
for i in range(10):
    img, target = train_data[i]  # img is now a 3 x 32 x 32 tensor
    writer.add_image("train_data", img, i)  # log one image per step
writer.close()
The result is shown below:
2 dataloader
For an introduction to DataLoader, see: torch.utils.data — PyTorch 2.4 documentation
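The key parameters are:
- batch_size (int, optional) – how many samples to load per batch (default: 1).
- shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
- drop_last (bool, optional) – set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size. If False, the last, smaller batch is kept (default: False).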
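How these three parameters interact can be sketched in plain Python: the sample indices are optionally shuffled, cut into chunks of batch_size, the short final chunk is discarded only when drop_last=True, and each chunk of (img, target) pairs is collated into (imgs, targets). `make_batches` and `collate` are hypothetical helpers; the real DataLoader uses samplers and a collate function that stacks tensors into one batch tensor rather than building Python lists.

```python
import random
from typing import List, Sequence, Tuple


def make_batches(n_samples: int, batch_size: int = 1, shuffle: bool = False,
                 drop_last: bool = False, seed: int = 0) -> List[List[int]]:
    """Hypothetical sketch of how a loader groups sample indices into batches."""
    indices = list(range(n_samples))
    if shuffle:
        # New random order; the real DataLoader draws a fresh order every epoch.
        random.Random(seed).shuffle(indices)
    batches = [indices[i:i + batch_size] for i in range(0, n_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the final, incomplete batch
    return batches


def collate(samples: Sequence[Tuple[object, int]]) -> Tuple[list, list]:
    """Hypothetical collate: a list of (img, target) pairs -> (imgs, targets)."""
    imgs = [img for img, _ in samples]
    targets = [target for _, target in samples]
    return imgs, targets


# A toy dataset of 10 (img, target) pairs, batched 4 at a time.
dataset = [(f"img{i}", i % 3) for i in range(10)]

for drop_last in (False, True):
    batches = make_batches(len(dataset), batch_size=4, shuffle=True, drop_last=drop_last)
    print(f"drop_last={drop_last}: batch sizes {[len(b) for b in batches]}")
# drop_last=False: batch sizes [4, 4, 2]
# drop_last=True: batch sizes [4, 4]

first_batch = make_batches(len(dataset), batch_size=4)[0]  # no shuffle: [0, 1, 2, 3]
imgs, targets = collate([dataset[i] for i in first_batch])
print(imgs, targets)  # ['img0', 'img1', 'img2', 'img3'] [0, 1, 2, 0]
```

Shuffling only changes which samples land in each batch, never the batch sizes; those depend only on the dataset size, batch_size and drop_last.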
Example:
import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
data_transform = torchvision.transforms.ToTensor()  # converts a PIL.Image to a torch.FloatTensor
train_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=True, download=True, transform=data_transform)
test_data = torchvision.datasets.CIFAR10('./dataset/cifar10', train=False, download=True, transform=data_transform)
test_loader = DataLoader(test_data, batch_size=64, shuffle=True)
# test_loader = DataLoader(test_data, batch_size=64, shuffle=True, drop_last=True)
writer = SummaryWriter("dataloader_logs")
step = 0
for data in test_loader:
    imgs, targets = data  # imgs: a batch of images, targets: their labels
    writer.add_images("drop_last_false", imgs, step)  # note: add_images (plural), not add_image
    step += 1
writer.close()
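The result is shown in the figure.
As the figure shows, with batch_size=64 each step loads 64 images; with shuffle=True the images are drawn in a random order; and with drop_last=False the final 16 images, which do not fill a complete batch, are kept as a last, smaller batch.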
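The 16 leftover images follow from the size of the CIFAR-10 test set (10,000 images): 10000 = 156 × 64 + 16. The short check below computes the batch count in each mode; `num_batches` is a hypothetical helper that should match what len(test_loader) reports for this map-style dataset with the default sampler.

```python
def num_batches(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Hypothetical helper: number of batches a loader yields in one epoch."""
    if drop_last:
        return n_samples // batch_size                 # full batches only
    return (n_samples + batch_size - 1) // batch_size  # ceiling division keeps the remainder


n_test, batch_size = 10000, 64  # CIFAR-10 test set size and the batch size used above
full, leftover = divmod(n_test, batch_size)
print(full, leftover)                                    # 156 16
print(num_batches(n_test, batch_size, drop_last=False))  # 157
print(num_batches(n_test, batch_size, drop_last=True))   # 156
```

So switching to the commented-out drop_last=True loader should log 156 steps instead of 157, with no 16-image batch at the end.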