- How do transforms and datasets fit together?
- How to download a dataset and how to use it
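Datasets are downloaded through this link.
To download the CIFAR10 dataset, for example, the page documents each parameter in detail and shows how to call the dataset class.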
root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
Sets the path where the dataset is downloaded and stored.
train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
If set to True, the training set is loaded; if set to False, the test set is loaded.
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
A transform that is applied to each PIL image as it is loaded, such as a random crop or a conversion to a tensor.
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
Applies a transform to the target (the label).
download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
Whether to download the dataset from the internet.
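The target_transform parameter is the least obvious of these, so here is a small sketch of what such a callable might look like. The one-hot encoding, the function name to_one_hot, and the constant NUM_CLASSES are my own illustrative choices and not something the torchvision page prescribes; any callable that takes the integer label would do.

```python
import torch
import torch.nn.functional as F

NUM_CLASSES = 10  # CIFAR10 has 10 classes


def to_one_hot(target: int) -> torch.Tensor:
    """Turn an integer class index into a one-hot float vector (illustrative helper)."""
    return F.one_hot(torch.tensor(target), num_classes=NUM_CLASSES).float()


# Applied by hand to the label 3, the same way a dataset would apply it per sample.
print(to_one_hot(3))

# Passing it to the dataset would look like this (commented out because it downloads data):
# train_set = torchvision.datasets.CIFAR10(
#     root="./dataset", train=True, target_transform=to_one_hot, download=True)
```

With this in place, indexing the dataset would return the transformed target instead of the raw integer label.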
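If the download is slow, find the download link in the console output or in the source code and fetch the file with a download manager such as Xunlei (迅雷).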
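After a manual download, it may be worth checking that the archive is intact before pointing torchvision at it. The sketch below uses only the standard library. The helper names file_md5 and archive_is_ready are hypothetical, and I am assuming the archive is named cifar-10-python.tar.gz and sits directly in the root directory; the expected checksum is left as a parameter (as far as I know it can be read from the CIFAR10 class attribute tgz_md5, but treat that as an assumption to verify).

```python
import hashlib
from pathlib import Path


def file_md5(path, chunk_size=1024 * 1024):
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def archive_is_ready(root, filename, expected_md5):
    """Return True if root/filename exists and its MD5 matches expected_md5."""
    archive = Path(root) / filename
    return archive.is_file() and file_md5(archive) == expected_md5


# Hypothetical usage: the checksum below is a placeholder, not the real value.
print(archive_is_ready("./dataset", "cifar-10-python.tar.gz", "0" * 32))
```

If the check passes, creating the dataset with the same root should be able to reuse the file instead of downloading it again.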
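You can index into the dataset to see what a sample contains. You can also inspect the dataset's attributes in the console and then view them with print.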
import torchvision

# Download (if needed) and load the training and test splits of CIFAR10.
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
print(f"test_set = {test_set}")
print(f"test_set[0] = {test_set[0]}")
print(f"test_set.classes = {test_set.classes}")
The output shows that the test set has 10 classes.
img, target = test_set[0]
print(f"img = {img}")  # img is a PIL image
print(f"target = {target}")  # the target class index is 3
print(f"test_set.classes[target] = {test_set.classes[target]}")  # class name for that index
img.show()
print(test_set[0])
The output shows that img is a PIL image and that target is 3.
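To make the label lookup concrete without loading the dataset, here is a minimal sketch. The class list is hard-coded in CIFAR10's label order purely so the snippet runs on its own; with the data loaded you would read it from test_set.classes instead. The class_to_idx dictionary is rebuilt by hand here for illustration.

```python
# CIFAR10 class names in label order (hard-coded for this sketch;
# with the dataset loaded, use test_set.classes instead).
classes = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# Reverse mapping from class name to label index.
class_to_idx = {name: idx for idx, name in enumerate(classes)}

target = 3  # the label of the first test sample, as seen above
print(f"classes[target] = {classes[target]}")
print(f"class_to_idx['cat'] = {class_to_idx['cat']}")
```

So the first test image, with target 3, is labelled as a cat.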
To keep working with the data in PyTorch, the images must be converted to tensors. The process is as follows.
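import torchvision
from torch.utils.tensorboard import SummaryWriter

# Compose the transforms that are applied to every image.
dataset_tensor = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # Other transforms can be added here, e.g. torchvision.transforms.RandomCrop(24).
    # The crop size must not exceed CIFAR10's 32x32 images, so 512 would fail.
])
# Then pass the transform when creating the datasets.
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_tensor, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_tensor, download=True)
# Inspect a sample: the image is now a tensor.
print(test_set[0])
# Log the first 10 test images to TensorBoard.
writer = SummaryWriter("tb_blogs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()
# Afterwards, open TensorBoard to view the images, e.g. tensorboard --logdir=tb_blogs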