The Dataset only tells the program where our data lives; the DataLoader is a loader that feeds that data into a neural network. Think of a deck of cards as the dataset and your hand as the neural network: the DataLoader draws cards from the deck and hands them to you.
Each time, the DataLoader fetches data from the Dataset; how much to fetch and how to fetch it are controlled by the DataLoader's parameters.
Open the DataLoader section of the official PyTorch documentation.
As you can see, DataLoader lives in the torch.utils.data package.
DataLoader has quite a few parameters, but only dataset has no default value.
- dataset (Dataset) – dataset from which to load the data.
- batch_size (int, optional) – how many samples per batch to load (default: 1).
- shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
- sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
- batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
- num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
- drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
batch_size is how many samples to fetch per batch.
shuffle controls whether the data is reshuffled: if True, the order differs from one epoch to the next; if False, the order is the same every pass.
num_workers is the number of worker subprocesses; 0 means the data is loaded in the main process only. Note that num_workers > 0 can raise errors on Windows.
drop_last controls whether the last, incomplete batch is dropped when the dataset size is not divisible by the batch size.
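A minimal sketch of how batch_size and drop_last interact, using a small synthetic TensorDataset (the data here is made up purely for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 samples of shape (3,): not divisible by batch_size=4
dataset = TensorDataset(torch.arange(30).reshape(10, 3).float())

keep_last = DataLoader(dataset, batch_size=4, drop_last=False)
drop_last = DataLoader(dataset, batch_size=4, drop_last=True)

print(len(keep_last))   # 3 batches: sizes 4, 4, 2
print(len(drop_last))   # 2 batches: the incomplete batch of 2 is dropped

print([batch[0].shape[0] for batch in keep_last])  # [4, 4, 2]
```

With drop_last=False the final batch is simply smaller; with drop_last=True it is discarded entirely.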
First prepare the data, setting batch_size to 4 so that four samples are fetched at a time.
import torchvision
from torch.utils.data import DataLoader
# the test set we prepared
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
Looking at the CIFAR10 dataset's __getitem__:
def __getitem__(self, index: int) -> Tuple[Any, Any]:
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], self.targets[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img)

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target
Since __getitem__ returns the data as img, target, we receive it the same way.
# the first image in the test set and its target
img, target = test_data[0]
print(img.shape)
print(target)
Output:
torch.Size([3, 32, 32])
3
With batch_size=4, the loader takes test_data[0], test_data[1], test_data[2], test_data[3] and packs them together. Note that this sequential order is only the default; since we passed shuffle=True above, each batch is actually drawn in a shuffled order.
As you can see, when packing a batch the DataLoader stacks the imgs and the targets separately.
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)
Output:
torch.Size([4, 3, 32, 32])
tensor([1, 6, 7, 1])
torch.Size([4, 3, 32, 32])
tensor([9, 4, 8, 1])
...
[4, 3, 32, 32] means 4 images, 3 channels, each 32×32.
tensor([1, 6, 7, 1]) gives the classes of those 4 images.
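Conceptually, the default collate function just stacks the individual samples along a new first dimension. A sketch with random tensors standing in for the CIFAR10 images (the shapes match, but the data and labels here are fake):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# stand-ins for CIFAR10: 8 fake 3x32x32 images with integer class labels
images = torch.randn(8, 3, 32, 32)
labels = torch.tensor([3, 8, 8, 0, 6, 6, 1, 6])
fake_data = TensorDataset(images, labels)

loader = DataLoader(fake_data, batch_size=4, shuffle=False)
imgs, targets = next(iter(loader))

print(imgs.shape)   # torch.Size([4, 3, 32, 32])
print(targets)      # tensor([3, 8, 8, 0])

# with shuffle=False, the batch equals stacking the first four samples by hand
manual = torch.stack([fake_data[i][0] for i in range(4)])
print(torch.equal(imgs, manual))  # True
```

This is why imgs gains a leading batch dimension of 4 while each target becomes one entry in a length-4 tensor.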
Let's visualize this with TensorBoard:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs, targets = data
    # add_images (plural), since imgs is a whole batch
    writer.add_images("test_data", imgs, step)
    step = step + 1
writer.close()
Open TensorBoard to view the results. If the step slider skips steps, you can raise the number of images kept per tag with the following command:
tensorboard --samples_per_plugin images=500 --logdir="dataloader"