一、 Transform与dataset数据集的使用
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root = "./dataset", train=True, transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root = "./dataset", train=True, transform=dataset_transform, download=True)
第三行代码实现的是数据集中的PIL型图片转换为tensor数据类型的图片;
在第四行与第五行代码中实现的是将从CIFAR10的数据集中获取数据,将下载的数据存放于dataset中,将数据图片转化为tensor型,否则报错,得到的是PIL型数据,而非numpy数据或者tensor数据。
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
使用tensorboard来查看。
tensorboard --logdir=learningplan1/p10
可查看到:
二、DataLoader的使用
作用:从dataset数据集中取出数据,如何取出数据,这就是由dataloader中的参数所决定。
import torchvision
# 准备的测试数据集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10("./dataset", train=False,transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)
# 测试数据集中第一张图片集target
img, target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter("dataloader")
for epoch in range(2):
step = 0
for data in test_loader: # 取出每一个循环
imgs,targets = data
# print(imgs.shape)
# print(targets)
writer.add_images("epoch:{}".format(epoch), imgs, step)
step +=1
writer.close()
结果:
上述实现的就是,循环0与1后,使用shuffle为Flase后,打乱的两版数据集。