例如加载MNIST数据集:
import torchvision.datasets
import numpy as np
import torch
from torch.utils.data import DataLoader
# Preprocessing pipeline applied to every sample fetched through the dataset:
# ToTensor turns each image into a float32 tensor scaled to [0, 1].
initial_process = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# Training set (shares the same preprocessing pipeline as the test set).
train_dataset = torchvision.datasets.MNIST(
    root="./dataset",
    train=True,
    download=True,
    transform=initial_process,
)
# Test set
test_dataset = torchvision.datasets.MNIST(
    root="./dataset",
    train=False,
    download=True,
    transform=initial_process,
)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Print the dataset sizes.
print("训练集长度为", len(train_dataset))
print("测试集长度为", len(test_dataset))

# Dtype, mean and standard deviation of the raw stored images.
# `.data` bypasses the transform, so these statistics are on the uint8
# 0-255 scale, not the [0, 1] scale produced by ToTensor.
print(train_dataset.data.dtype)
print(train_dataset.data.float().mean())
print(train_dataset.data.float().std())

# A single batch is enough to show the dtype after the transform has been
# applied; iterating over the whole loader would print 60000 batches.
data, targets = next(iter(train_loader))
print(data.dtype)
print(targets)
运行后可以看到,直接打印数据集data的dtype,结果为torch.uint8,取值范围为[0,255]
但是从DataLoader中取出来之后,数据类型就变为了torch.float32,取值范围为[0,1]
这是因为transform是在dataset的__getitem__中对每个样本应用的:DataLoader通过__getitem__取样本,所以得到的是变换后的数据;而train_dataset.data是未经transform处理的原始数据