torch.utils.data.DataLoader
This section describes how to use collate_fn. The DataLoader constructor signature (from the PyTorch source) is:
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
             batch_sampler=None, num_workers=0, collate_fn=None,
             pin_memory=False, drop_last=False, timeout=0,
             worker_init_fn=None, multiprocessing_context=None):
    torch._C._log_api_usage_once("python.data_loader")
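In everyday use only a few of these arguments matter; collate_fn is just one more constructor argument, and leaving it as None selects the built-in default collate function. A minimal sketch (the dataset and batch size below are made up for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

# a toy map-style dataset: 10 samples of (feature, label)
toy_dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))

# collate_fn=None means the default collate function is used
loader = DataLoader(toy_dataset, batch_size=4, shuffle=False, collate_fn=None)
for data, target in loader:
    print(data.shape, target.shape)   # torch.Size([4, 3]) torch.Size([4]) (the last batch has 2)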
The official documentation describes collate_fn as:
"merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset." In other words, it takes the list of samples fetched from the dataset and does extra processing on them to produce the mini-batch.
Below is the default collate function from the official source; it shows the various conversions applied to the data.
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]
    raise TypeError(default_collate_err_msg_format.format(elem_type))
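Helpers such as np_str_obj_array_pattern, int_classes, string_classes and container_abcs are defined or imported in the same module (torch/utils/data/_utils/collate.py), so the snippet above is not self-contained. To see which branches are taken, you can call default_collate directly; a small sketch (it imports the function from that private module, and recent PyTorch versions also expose it as torch.utils.data.default_collate):

import torch
from torch.utils.data._utils.collate import default_collate

# each sample is a (tensor, int) tuple, so the Sequence branch applies and recurses
samples = [(torch.tensor([1.0, 2.0]), 0),
           (torch.tensor([3.0, 4.0]), 1),
           (torch.tensor([5.0, 6.0]), 0)]

data, target = default_collate(samples)
print(data)     # shape [3, 2]: the Tensor branch stacks along a new batch dimension
print(target)   # tensor([0, 1, 0]): the int branch builds one tensor from the scalars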
For example, if the dataset is built as shown below, then after collation each batch is turned into the corresponding data-target pairs.
# inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
# target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))
# torch_dataset = Data.TensorDataset(inputing, target)
tuple(tensor[index] for tensor in self.tensors)  # the __getitem__ of TensorDataset
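Putting that together into a runnable sketch (test here is just a made-up 1-D array used for illustration):

import numpy as np
import torch
import torch.utils.data as Data

test = np.arange(13, dtype=np.float32)   # hypothetical toy sequence
inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))   # shape [10, 3]
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))     # shape [10, 1]
torch_dataset = Data.TensorDataset(inputing, target)

loader = Data.DataLoader(torch_dataset, batch_size=5)   # default collate_fn
for data, tgt in loader:
    print(data.shape, tgt.shape)   # torch.Size([5, 3]) torch.Size([5, 1])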
Here is a concrete example with a custom collate_fn:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

# a simple custom collate function, just to show the idea
def my_collate(batch):
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)
    return [data, target]

def show_image_batch(img_list, title=None):
    num = len(img_list)
    fig = plt.figure()
    for i in range(num):
        ax = fig.add_subplot(1, num, i + 1)
        ax.imshow(img_list[i].numpy().transpose([1, 2, 0]))
        ax.set_title(title[i])
    plt.show()

# do not do RandomCrop, to show that the custom collate_fn can handle images of different sizes
train_transforms = transforms.Compose([transforms.Resize(size=224),   # called transforms.Scale in very old torchvision
                                       transforms.ToTensor(),
                                       ])

# change root to a valid dir in your system, see the ImageFolder documentation for more info
train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset",
                                     transform=train_transforms)

trainset = DataLoader(dataset=train_dataset,
                      batch_size=4,
                      shuffle=True,
                      collate_fn=my_collate,   # use the custom collate function here
                      pin_memory=True)

trainiter = iter(trainset)
imgs, labels = next(trainiter)   # trainiter.next() only works on Python 2
# print(type(imgs), type(labels))
show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels])
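Note that my_collate returns the images as a plain Python list (so each image can keep its own size) and only stacks the labels into a tensor. Another common reason to write a custom collate_fn is padding variable-length sequences; a minimal sketch (the dataset class, names and data below are made up for illustration, pad_sequence is the standard torch.nn.utils.rnn helper):

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

class VarLenDataset(Dataset):
    """Hypothetical dataset whose samples are 1-D sequences of different lengths."""
    def __init__(self):
        self.seqs = [torch.arange(n, dtype=torch.float32) for n in (3, 5, 2, 4)]
        self.labels = [0, 1, 0, 1]
    def __len__(self):
        return len(self.seqs)
    def __getitem__(self, idx):
        return self.seqs[idx], self.labels[idx]

def pad_collate(batch):
    seqs = [item[0] for item in batch]
    labels = torch.LongTensor([item[1] for item in batch])
    # pad every sequence in the batch to the length of the longest one
    padded = pad_sequence(seqs, batch_first=True)
    return padded, labels

loader = DataLoader(VarLenDataset(), batch_size=2, collate_fn=pad_collate)
for padded, labels in loader:
    print(padded.shape, labels)   # e.g. torch.Size([2, 5]) tensor([0, 1])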
References:
- https://zhuanlan.zhihu.com/p/96083063
- https://www.cnblogs.com/zf-blog/p/11360557.html
- https://blog.csdn.net/weixin_42028364/article/details/81675021